go-i18n/classify.go
Snider ff376830c0 fix: address Virgil review — 5 fixes for classify pipeline
- Remove go-mlx from go.mod (breaks non-darwin builds)
- Fix go-inference pseudo-version for CI compatibility
- Fix mapTokenToDomain prefix collision (castle, credential)
- Add testing.Short() skip to slow classification benchmarks
- Add 80% accuracy threshold to integration test

Integration test moved to integration/ sub-module with its own go.mod
to cleanly isolate go-mlx dependency from the main module.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-20 00:44:35 +00:00

171 lines
4.4 KiB
Go

package i18n
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"strings"
"time"
"forge.lthn.ai/core/go-inference"
)
// ClassifyStats reports metrics from a ClassifyCorpus run.
type ClassifyStats struct {
	Total         int            // records successfully classified and written to output
	Skipped       int            // malformed JSON, missing prompt field, or non-string/empty prompt
	ByDomain      map[string]int // domain_1b label -> count
	Duration      time.Duration  // wall-clock time for the whole run
	PromptsPerSec float64        // Total / Duration in seconds; 0 when Duration is 0
}
// ClassifyOption configures ClassifyCorpus behaviour. Options are applied
// in the order given, on top of the defaults from defaultClassifyConfig.
type ClassifyOption func(*classifyConfig)
// classifyConfig holds the tunable settings for a ClassifyCorpus run.
type classifyConfig struct {
	batchSize      int    // prompts per Classify call
	promptField    string // JSONL field holding the text to classify
	promptTemplate string // classification prompt; %s is the text placeholder
}

// defaultClassifyConfig returns the baseline settings used when no
// ClassifyOption overrides them: batches of 8, reading the "prompt"
// field, with a 4-way category classification prompt.
func defaultClassifyConfig() classifyConfig {
	var cfg classifyConfig
	cfg.batchSize = 8
	cfg.promptField = "prompt"
	cfg.promptTemplate = "Classify this text into exactly one category: technical, creative, ethical, casual.\n\nText: %s\n\nCategory:"
	return cfg
}
// WithBatchSize sets the number of prompts per Classify call. Default 8.
func WithBatchSize(n int) ClassifyOption {
	apply := func(cfg *classifyConfig) {
		cfg.batchSize = n
	}
	return apply
}
// WithPromptField sets which JSONL field contains the text to classify. Default "prompt".
func WithPromptField(field string) ClassifyOption {
	apply := func(cfg *classifyConfig) {
		cfg.promptField = field
	}
	return apply
}
// WithPromptTemplate sets the classification prompt. Use %s for the text placeholder.
func WithPromptTemplate(tmpl string) ClassifyOption {
	apply := func(cfg *classifyConfig) {
		cfg.promptTemplate = tmpl
	}
	return apply
}
// mapTokenToDomain maps a model output token to one of the four domain
// labels, or "unknown" when the token matches none of them.
//
// Alongside the full label names, the short fragments "tech", "cre",
// "eth" and "cas" are accepted because BPE tokenisation can emit only
// the first fragment of a word. Matching is exact (case-insensitive),
// never prefix-based, so e.g. "castle" does NOT map to "casual".
func mapTokenToDomain(token string) string {
	if token == "" {
		return "unknown"
	}
	switch strings.ToLower(token) {
	case "technical", "tech":
		return "technical"
	case "creative", "cre":
		return "creative"
	case "ethical", "eth":
		return "ethical"
	case "casual", "cas":
		return "casual"
	}
	return "unknown"
}
// ClassifyCorpus reads JSONL records from input, batch-classifies the text
// found in each record's prompt field through model, and writes each record
// back out as JSONL with a "domain_1b" field added.
//
// Records that are not valid JSON objects, lack the prompt field, or hold a
// non-string or empty prompt are counted in Skipped and omitted from output.
// Prompts are sent to model.Classify in batches of the configured batch size
// (default 8). On error, the returned stats reflect the work completed so far.
func ClassifyCorpus(ctx context.Context, model inference.TextModel,
	input io.Reader, output io.Writer, opts ...ClassifyOption) (*ClassifyStats, error) {
	cfg := defaultClassifyConfig()
	for _, o := range opts {
		o(&cfg)
	}
	stats := &ClassifyStats{ByDomain: make(map[string]int)}
	start := time.Now()

	scanner := bufio.NewScanner(input)
	// Allow JSONL lines up to 1 MiB; longer lines surface via scanner.Err().
	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)

	type pending struct {
		record map[string]any
		prompt string
	}
	var batch []pending

	// flush classifies the accumulated batch, streams the annotated records
	// to output, and resets the batch (keeping its capacity) for reuse.
	flush := func() error {
		if len(batch) == 0 {
			return nil
		}
		prompts := make([]string, len(batch))
		for i, p := range batch {
			prompts[i] = fmt.Sprintf(cfg.promptTemplate, p.prompt)
		}
		results, err := model.Classify(ctx, prompts, inference.WithMaxTokens(1))
		if err != nil {
			return fmt.Errorf("classify batch: %w", err)
		}
		// Guard against a misbehaving model: indexing batch[i] below would
		// panic if Classify returned more results than prompts, and records
		// would be silently dropped if it returned fewer.
		if len(results) != len(batch) {
			return fmt.Errorf("classify batch: got %d results for %d prompts", len(results), len(batch))
		}
		for i, r := range results {
			domain := mapTokenToDomain(r.Token.Text)
			batch[i].record["domain_1b"] = domain
			stats.ByDomain[domain]++
			stats.Total++
			line, err := json.Marshal(batch[i].record)
			if err != nil {
				return fmt.Errorf("marshal output: %w", err)
			}
			if _, err := fmt.Fprintf(output, "%s\n", line); err != nil {
				return fmt.Errorf("write output: %w", err)
			}
		}
		batch = batch[:0]
		return nil
	}

	for scanner.Scan() {
		var record map[string]any
		if err := json.Unmarshal(scanner.Bytes(), &record); err != nil {
			stats.Skipped++
			continue
		}
		promptVal, ok := record[cfg.promptField]
		if !ok {
			stats.Skipped++
			continue
		}
		prompt, ok := promptVal.(string)
		if !ok || prompt == "" {
			stats.Skipped++
			continue
		}
		batch = append(batch, pending{record: record, prompt: prompt})
		if len(batch) >= cfg.batchSize {
			if err := flush(); err != nil {
				return stats, err
			}
		}
	}
	if err := scanner.Err(); err != nil {
		return stats, fmt.Errorf("read input: %w", err)
	}
	// Classify any remainder smaller than a full batch.
	if err := flush(); err != nil {
		return stats, err
	}
	stats.Duration = time.Since(start)
	if stats.Duration > 0 {
		stats.PromptsPerSec = float64(stats.Total) / stats.Duration.Seconds()
	}
	return stats, nil
}