go-i18n/classify_test.go
Snider ff376830c0 fix: address Virgil review — 5 fixes for classify pipeline
- Remove go-mlx from go.mod (breaks non-darwin builds)
- Fix go-inference pseudo-version for CI compatibility
- Fix mapTokenToDomain prefix collision (castle, credential)
- Add testing.Short() skip to slow classification benchmarks
- Add 80% accuracy threshold to integration test

Integration test moved to integration/ sub-module with its own go.mod
to cleanly isolate go-mlx dependency from the main module.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-20 00:44:35 +00:00

186 lines
5.7 KiB
Go

package i18n
import (
"bytes"
"context"
"encoding/json"
"iter"
"strings"
"testing"
"forge.lthn.ai/core/go-inference"
)
// TestMapTokenToDomain runs mapTokenToDomain over a table of tokens and
// compares each result with the expected domain label. The table covers
// lower- and upper-case domain names, their short forms, and tokens that are
// expected to map to "unknown".
func TestMapTokenToDomain(t *testing.T) {
	type testCase struct {
		token string
		want  string
	}

	cases := []testCase{
		{token: "technical", want: "technical"},
		{token: "Technical", want: "technical"},
		{token: "tech", want: "technical"},
		{token: "creative", want: "creative"},
		{token: "Creative", want: "creative"},
		{token: "cre", want: "creative"},
		{token: "ethical", want: "ethical"},
		{token: "Ethical", want: "ethical"},
		{token: "eth", want: "ethical"},
		{token: "casual", want: "casual"},
		{token: "Casual", want: "casual"},
		{token: "cas", want: "casual"},
		{token: "unknown", want: "unknown"},
		{token: "", want: "unknown"},
		{token: "foo", want: "unknown"},

		// Prefix-collision regression cases: ordinary words that merely start
		// with the same letters as a domain must NOT be mapped to that domain.
		{token: "castle", want: "unknown"},
		{token: "cascade", want: "unknown"},
		{token: "credential", want: "unknown"},
		{token: "creature", want: "unknown"},
	}

	for _, tc := range cases {
		// The token doubles as the subtest name.
		t.Run(tc.token, func(t *testing.T) {
			if got := mapTokenToDomain(tc.token); got != tc.want {
				t.Errorf("mapTokenToDomain(%q) = %q, want %q", tc.token, got, tc.want)
			}
		})
	}
}
// mockModel is a test double intended to satisfy inference.TextModel.
// Its Classify method forwards to classifyFunc, so each test supplies
// the classification behaviour it needs.
type mockModel struct {
	// classifyFunc receives the context, the batch of prompts and any generate
	// options, and returns the classification results (or an error).
	classifyFunc func(context.Context, []string, ...inference.GenerateOption) ([]inference.ClassifyResult, error)
}
// Generate is a stub: it ignores every argument and returns a sequence that
// never yields a token.
func (m *mockModel) Generate(_ context.Context, _ string, _ ...inference.GenerateOption) iter.Seq[inference.Token] {
	var empty iter.Seq[inference.Token] = func(func(inference.Token) bool) {}
	return empty
}
// Chat is a stub: it ignores every argument and returns a sequence that never
// yields a token.
func (m *mockModel) Chat(_ context.Context, _ []inference.Message, _ ...inference.GenerateOption) iter.Seq[inference.Token] {
	var empty iter.Seq[inference.Token] = func(func(inference.Token) bool) {}
	return empty
}
// Classify forwards the call unchanged to the classifyFunc configured on the
// mock and returns whatever that function returns.
func (m *mockModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
	classify := m.classifyFunc
	return classify(ctx, prompts, opts...)
}
// BatchGenerate is a stub: it ignores its arguments and returns a nil result
// slice together with a nil error.
func (m *mockModel) BatchGenerate(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.BatchResult, error) {
	var none []inference.BatchResult
	return none, nil
}
// ModelType reports the fixed type name "mock".
func (m *mockModel) ModelType() string {
	const name = "mock"
	return name
}
// Info returns the zero-value inference.ModelInfo.
func (m *mockModel) Info() inference.ModelInfo {
	var info inference.ModelInfo
	return info
}
// Metrics returns the zero-value inference.GenerateMetrics.
func (m *mockModel) Metrics() inference.GenerateMetrics {
	var metrics inference.GenerateMetrics
	return metrics
}
// Err always returns nil.
func (m *mockModel) Err() error {
	return nil
}
// Close does nothing and always returns nil.
func (m *mockModel) Close() error {
	return nil
}
func TestClassifyCorpus_Basic(t *testing.T) {
model := &mockModel{
classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
results := make([]inference.ClassifyResult, len(prompts))
for i := range prompts {
results[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
}
return results, nil
},
}
input := strings.NewReader(
`{"seed_id":"1","domain":"general","prompt":"Delete the file"}` + "\n" +
`{"seed_id":"2","domain":"science","prompt":"Explain gravity"}` + "\n",
)
var output bytes.Buffer
stats, err := ClassifyCorpus(context.Background(), model, input, &output, WithBatchSize(16))
if err != nil {
t.Fatalf("ClassifyCorpus returned error: %v", err)
}
if stats.Total != 2 {
t.Errorf("Total = %d, want 2", stats.Total)
}
if stats.Skipped != 0 {
t.Errorf("Skipped = %d, want 0", stats.Skipped)
}
lines := strings.Split(strings.TrimSpace(output.String()), "\n")
if len(lines) != 2 {
t.Fatalf("output lines = %d, want 2", len(lines))
}
for i, line := range lines {
var record map[string]any
if err := json.Unmarshal([]byte(line), &record); err != nil {
t.Fatalf("line %d: unmarshal: %v", i, err)
}
if record["domain_1b"] != "technical" {
t.Errorf("line %d: domain_1b = %v, want %q", i, record["domain_1b"], "technical")
}
// original domain field must be preserved
if _, ok := record["domain"]; !ok {
t.Errorf("line %d: original domain field missing", i)
}
}
}
func TestClassifyCorpus_SkipsMalformed(t *testing.T) {
model := &mockModel{
classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
results := make([]inference.ClassifyResult, len(prompts))
for i := range prompts {
results[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
}
return results, nil
},
}
input := strings.NewReader(
"not valid json\n" +
`{"seed_id":"1","domain":"general","prompt":"Hello world"}` + "\n" +
`{"seed_id":"2","domain":"general"}` + "\n",
)
var output bytes.Buffer
stats, err := ClassifyCorpus(context.Background(), model, input, &output)
if err != nil {
t.Fatalf("ClassifyCorpus returned error: %v", err)
}
if stats.Total != 1 {
t.Errorf("Total = %d, want 1", stats.Total)
}
if stats.Skipped != 2 {
t.Errorf("Skipped = %d, want 2", stats.Skipped)
}
}
func TestClassifyCorpus_DomainMapping(t *testing.T) {
model := &mockModel{
classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
results := make([]inference.ClassifyResult, len(prompts))
for i, p := range prompts {
if strings.Contains(p, "Delete") {
results[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
} else {
results[i] = inference.ClassifyResult{Token: inference.Token{Text: "ethical"}}
}
}
return results, nil
},
}
input := strings.NewReader(
`{"prompt":"Delete the file now"}` + "\n" +
`{"prompt":"Is it right to lie?"}` + "\n",
)
var output bytes.Buffer
stats, err := ClassifyCorpus(context.Background(), model, input, &output, WithBatchSize(16))
if err != nil {
t.Fatalf("ClassifyCorpus returned error: %v", err)
}
if stats.ByDomain["technical"] != 1 {
t.Errorf("ByDomain[technical] = %d, want 1", stats.ByDomain["technical"])
}
if stats.ByDomain["ethical"] != 1 {
t.Errorf("ByDomain[ethical] = %d, want 1", stats.ByDomain["ethical"])
}
}