feat: implement ClassifyCorpus with streaming batch classification

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Snider 2026-02-20 00:06:44 +00:00
parent a5f3eb4777
commit 94fd2f463f
2 changed files with 243 additions and 8 deletions

View file

@ -1,7 +1,10 @@
package i18n
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"strings"
"time"
@ -70,10 +73,95 @@ func mapTokenToDomain(token string) string {
}
}
// NOTE(review): the blank-identifier keep-alive declarations that previously
// pinned the context, io, time, and inference imports ("ClassifyCorpus will
// be added in Task 3") are removed — ClassifyCorpus below now uses all four
// packages directly, so the placeholders are dead code.
// ClassifyCorpus reads JSONL records from input, batch-classifies each
// record's prompt field through model, and writes the records back to output
// as JSONL with a "domain_1b" field added.
//
// Records that are not valid JSON, lack the configured prompt field, or whose
// prompt is not a non-empty string are counted in stats.Skipped and omitted
// from output. Prompts are classified in groups of cfg.batchSize; the final
// partial batch is flushed once input is exhausted.
//
// On error the partially-populated stats are returned alongside the error so
// callers can report progress made before the failure.
func ClassifyCorpus(ctx context.Context, model inference.TextModel,
	input io.Reader, output io.Writer, opts ...ClassifyOption) (*ClassifyStats, error) {
	cfg := defaultClassifyConfig()
	for _, o := range opts {
		o(&cfg)
	}

	stats := &ClassifyStats{ByDomain: make(map[string]int)}
	start := time.Now()

	scanner := bufio.NewScanner(input)
	// Allow lines up to 1 MiB; bufio.Scanner's default 64 KiB cap is too
	// small for long prompts.
	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)

	type pending struct {
		record map[string]any
		prompt string
	}
	var batch []pending

	// flush classifies the buffered batch and writes each annotated record
	// to output, then resets the batch (keeping its backing array).
	flush := func() error {
		if len(batch) == 0 {
			return nil
		}
		prompts := make([]string, len(batch))
		for i, p := range batch {
			prompts[i] = fmt.Sprintf(cfg.promptTemplate, p.prompt)
		}
		results, err := model.Classify(ctx, prompts, inference.WithMaxTokens(1))
		if err != nil {
			return fmt.Errorf("classify batch: %w", err)
		}
		// Guard against a misbehaving model: indexing batch[i] for each
		// result would panic if the model returned extra results, and a
		// short slice would silently drop records.
		if len(results) != len(batch) {
			return fmt.Errorf("classify batch: got %d results for %d prompts", len(results), len(batch))
		}
		for i, r := range results {
			domain := mapTokenToDomain(r.Token.Text)
			batch[i].record["domain_1b"] = domain
			stats.ByDomain[domain]++
			stats.Total++
			line, err := json.Marshal(batch[i].record)
			if err != nil {
				return fmt.Errorf("marshal output: %w", err)
			}
			if _, err := fmt.Fprintf(output, "%s\n", line); err != nil {
				return fmt.Errorf("write output: %w", err)
			}
		}
		batch = batch[:0]
		return nil
	}

	for scanner.Scan() {
		// Stop promptly if the caller cancelled rather than queuing more
		// work destined to fail inside model.Classify.
		if err := ctx.Err(); err != nil {
			return stats, err
		}
		var record map[string]any
		if err := json.Unmarshal(scanner.Bytes(), &record); err != nil {
			stats.Skipped++
			continue
		}
		// A missing key, a non-string value, and an empty string are all
		// skipped: the type assertion on a nil any is simply !ok.
		prompt, ok := record[cfg.promptField].(string)
		if !ok || prompt == "" {
			stats.Skipped++
			continue
		}
		batch = append(batch, pending{record: record, prompt: prompt})
		if len(batch) >= cfg.batchSize {
			if err := flush(); err != nil {
				return stats, err
			}
		}
	}
	if err := scanner.Err(); err != nil {
		return stats, fmt.Errorf("read input: %w", err)
	}
	if err := flush(); err != nil {
		return stats, err
	}

	stats.Duration = time.Since(start)
	if stats.Duration > 0 {
		stats.PromptsPerSec = float64(stats.Total) / stats.Duration.Seconds()
	}
	return stats, nil
}

View file

@ -1,6 +1,15 @@
package i18n
import "testing"
import (
"bytes"
"context"
"encoding/json"
"iter"
"strings"
"testing"
"forge.lthn.ai/core/go-inference"
)
func TestMapTokenToDomain(t *testing.T) {
tests := []struct {
@ -32,3 +41,141 @@ func TestMapTokenToDomain(t *testing.T) {
})
}
}
// mockModel is a test double implementing inference.TextModel. Classify is
// delegated to the configurable classifyFunc; every other method is an inert
// stub returning zero values.
type mockModel struct {
	classifyFunc func(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error)
}

// Generate returns an empty token stream.
func (m *mockModel) Generate(_ context.Context, _ string, _ ...inference.GenerateOption) iter.Seq[inference.Token] {
	return func(func(inference.Token) bool) {}
}

// Chat returns an empty token stream.
func (m *mockModel) Chat(_ context.Context, _ []inference.Message, _ ...inference.GenerateOption) iter.Seq[inference.Token] {
	return func(func(inference.Token) bool) {}
}

// Classify forwards to classifyFunc so each test can script responses.
func (m *mockModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
	return m.classifyFunc(ctx, prompts, opts...)
}

func (m *mockModel) BatchGenerate(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.BatchResult, error) {
	return nil, nil
}

func (m *mockModel) ModelType() string                  { return "mock" }
func (m *mockModel) Info() inference.ModelInfo          { return inference.ModelInfo{} }
func (m *mockModel) Metrics() inference.GenerateMetrics { return inference.GenerateMetrics{} }
func (m *mockModel) Err() error                         { return nil }
func (m *mockModel) Close() error                       { return nil }
// TestClassifyCorpus_Basic verifies that well-formed JSONL records are all
// classified, annotated with domain_1b, and written out with their original
// fields preserved.
func TestClassifyCorpus_Basic(t *testing.T) {
	model := &mockModel{
		classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
			out := make([]inference.ClassifyResult, len(prompts))
			for i := range out {
				out[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
			}
			return out, nil
		},
	}

	corpus := `{"seed_id":"1","domain":"general","prompt":"Delete the file"}` + "\n" +
		`{"seed_id":"2","domain":"science","prompt":"Explain gravity"}` + "\n"

	var out bytes.Buffer
	stats, err := ClassifyCorpus(context.Background(), model, strings.NewReader(corpus), &out, WithBatchSize(16))
	if err != nil {
		t.Fatalf("ClassifyCorpus returned error: %v", err)
	}

	if got, want := stats.Total, 2; got != want {
		t.Errorf("Total = %d, want %d", got, want)
	}
	if got := stats.Skipped; got != 0 {
		t.Errorf("Skipped = %d, want 0", got)
	}

	lines := strings.Split(strings.TrimSpace(out.String()), "\n")
	if len(lines) != 2 {
		t.Fatalf("output lines = %d, want 2", len(lines))
	}
	for i, line := range lines {
		var record map[string]any
		if err := json.Unmarshal([]byte(line), &record); err != nil {
			t.Fatalf("line %d: unmarshal: %v", i, err)
		}
		if record["domain_1b"] != "technical" {
			t.Errorf("line %d: domain_1b = %v, want %q", i, record["domain_1b"], "technical")
		}
		// domain_1b must be added without clobbering the original domain.
		if _, ok := record["domain"]; !ok {
			t.Errorf("line %d: original domain field missing", i)
		}
	}
}
// TestClassifyCorpus_SkipsMalformed verifies that non-JSON lines and records
// lacking a prompt are counted as skipped instead of aborting the run.
func TestClassifyCorpus_SkipsMalformed(t *testing.T) {
	model := &mockModel{
		classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
			out := make([]inference.ClassifyResult, len(prompts))
			for i := range out {
				out[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
			}
			return out, nil
		},
	}

	// One garbage line, one good record, one record with no prompt field.
	corpus := "not valid json\n" +
		`{"seed_id":"1","domain":"general","prompt":"Hello world"}` + "\n" +
		`{"seed_id":"2","domain":"general"}` + "\n"

	var out bytes.Buffer
	stats, err := ClassifyCorpus(context.Background(), model, strings.NewReader(corpus), &out)
	if err != nil {
		t.Fatalf("ClassifyCorpus returned error: %v", err)
	}
	if got := stats.Total; got != 1 {
		t.Errorf("Total = %d, want 1", got)
	}
	if got := stats.Skipped; got != 2 {
		t.Errorf("Skipped = %d, want 2", got)
	}
}
// TestClassifyCorpus_DomainMapping verifies that per-prompt classifier tokens
// are tallied into stats.ByDomain independently of one another.
func TestClassifyCorpus_DomainMapping(t *testing.T) {
	model := &mockModel{
		classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
			out := make([]inference.ClassifyResult, len(prompts))
			for i, p := range prompts {
				// Default to "ethical"; prompts mentioning Delete are
				// scripted as "technical".
				token := "ethical"
				if strings.Contains(p, "Delete") {
					token = "technical"
				}
				out[i] = inference.ClassifyResult{Token: inference.Token{Text: token}}
			}
			return out, nil
		},
	}

	corpus := `{"prompt":"Delete the file now"}` + "\n" +
		`{"prompt":"Is it right to lie?"}` + "\n"

	var out bytes.Buffer
	stats, err := ClassifyCorpus(context.Background(), model, strings.NewReader(corpus), &out, WithBatchSize(16))
	if err != nil {
		t.Fatalf("ClassifyCorpus returned error: %v", err)
	}
	for _, domain := range []string{"technical", "ethical"} {
		if got := stats.ByDomain[domain]; got != 1 {
			t.Errorf("ByDomain[%s] = %d, want 1", domain, got)
		}
	}
}