- Remove go-mlx from go.mod (breaks non-darwin builds)
- Fix go-inference pseudo-version for CI compatibility
- Fix mapTokenToDomain prefix collision (castle, credential)
- Add testing.Short() skip to slow classification benchmarks
- Add 80% accuracy threshold to integration test

Integration test moved to integration/ sub-module with its own go.mod to cleanly isolate go-mlx dependency from the main module.

Co-Authored-By: Virgil <virgil@lethean.io>
186 lines
5.7 KiB
Go
186 lines
5.7 KiB
Go
package i18n
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"iter"
|
|
"strings"
|
|
"testing"
|
|
|
|
"forge.lthn.ai/core/go-inference"
|
|
)
|
|
|
|
func TestMapTokenToDomain(t *testing.T) {
|
|
tests := []struct {
|
|
token string
|
|
want string
|
|
}{
|
|
{"technical", "technical"},
|
|
{"Technical", "technical"},
|
|
{"tech", "technical"},
|
|
{"creative", "creative"},
|
|
{"Creative", "creative"},
|
|
{"cre", "creative"},
|
|
{"ethical", "ethical"},
|
|
{"Ethical", "ethical"},
|
|
{"eth", "ethical"},
|
|
{"casual", "casual"},
|
|
{"Casual", "casual"},
|
|
{"cas", "casual"},
|
|
{"unknown", "unknown"},
|
|
{"", "unknown"},
|
|
{"foo", "unknown"},
|
|
// Verify prefix collision fix: these must NOT match any domain
|
|
{"castle", "unknown"},
|
|
{"cascade", "unknown"},
|
|
{"credential", "unknown"},
|
|
{"creature", "unknown"},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.token, func(t *testing.T) {
|
|
got := mapTokenToDomain(tt.token)
|
|
if got != tt.want {
|
|
t.Errorf("mapTokenToDomain(%q) = %q, want %q", tt.token, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// mockModel satisfies inference.TextModel for testing.
|
|
type mockModel struct {
|
|
classifyFunc func(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error)
|
|
}
|
|
|
|
func (m *mockModel) Generate(_ context.Context, _ string, _ ...inference.GenerateOption) iter.Seq[inference.Token] {
|
|
return func(yield func(inference.Token) bool) {}
|
|
}
|
|
|
|
func (m *mockModel) Chat(_ context.Context, _ []inference.Message, _ ...inference.GenerateOption) iter.Seq[inference.Token] {
|
|
return func(yield func(inference.Token) bool) {}
|
|
}
|
|
|
|
func (m *mockModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
|
return m.classifyFunc(ctx, prompts, opts...)
|
|
}
|
|
|
|
func (m *mockModel) BatchGenerate(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.BatchResult, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
// ModelType reports the fixed identifier "mock".
func (m *mockModel) ModelType() string {
	const kind = "mock"
	return kind
}
|
|
// Info returns the zero value of inference.ModelInfo.
func (m *mockModel) Info() inference.ModelInfo {
	var info inference.ModelInfo
	return info
}
|
|
// Metrics returns the zero value of inference.GenerateMetrics.
func (m *mockModel) Metrics() inference.GenerateMetrics {
	var metrics inference.GenerateMetrics
	return metrics
}
|
|
// Err always returns nil; the mock never records an error.
func (m *mockModel) Err() error {
	var none error
	return none
}
|
|
// Close does nothing and always returns nil.
func (m *mockModel) Close() error {
	var none error
	return none
}
|
|
|
|
func TestClassifyCorpus_Basic(t *testing.T) {
|
|
model := &mockModel{
|
|
classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
|
results := make([]inference.ClassifyResult, len(prompts))
|
|
for i := range prompts {
|
|
results[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
|
|
}
|
|
return results, nil
|
|
},
|
|
}
|
|
|
|
input := strings.NewReader(
|
|
`{"seed_id":"1","domain":"general","prompt":"Delete the file"}` + "\n" +
|
|
`{"seed_id":"2","domain":"science","prompt":"Explain gravity"}` + "\n",
|
|
)
|
|
var output bytes.Buffer
|
|
|
|
stats, err := ClassifyCorpus(context.Background(), model, input, &output, WithBatchSize(16))
|
|
if err != nil {
|
|
t.Fatalf("ClassifyCorpus returned error: %v", err)
|
|
}
|
|
if stats.Total != 2 {
|
|
t.Errorf("Total = %d, want 2", stats.Total)
|
|
}
|
|
if stats.Skipped != 0 {
|
|
t.Errorf("Skipped = %d, want 0", stats.Skipped)
|
|
}
|
|
|
|
lines := strings.Split(strings.TrimSpace(output.String()), "\n")
|
|
if len(lines) != 2 {
|
|
t.Fatalf("output lines = %d, want 2", len(lines))
|
|
}
|
|
|
|
for i, line := range lines {
|
|
var record map[string]any
|
|
if err := json.Unmarshal([]byte(line), &record); err != nil {
|
|
t.Fatalf("line %d: unmarshal: %v", i, err)
|
|
}
|
|
if record["domain_1b"] != "technical" {
|
|
t.Errorf("line %d: domain_1b = %v, want %q", i, record["domain_1b"], "technical")
|
|
}
|
|
// original domain field must be preserved
|
|
if _, ok := record["domain"]; !ok {
|
|
t.Errorf("line %d: original domain field missing", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestClassifyCorpus_SkipsMalformed(t *testing.T) {
|
|
model := &mockModel{
|
|
classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
|
results := make([]inference.ClassifyResult, len(prompts))
|
|
for i := range prompts {
|
|
results[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
|
|
}
|
|
return results, nil
|
|
},
|
|
}
|
|
|
|
input := strings.NewReader(
|
|
"not valid json\n" +
|
|
`{"seed_id":"1","domain":"general","prompt":"Hello world"}` + "\n" +
|
|
`{"seed_id":"2","domain":"general"}` + "\n",
|
|
)
|
|
var output bytes.Buffer
|
|
|
|
stats, err := ClassifyCorpus(context.Background(), model, input, &output)
|
|
if err != nil {
|
|
t.Fatalf("ClassifyCorpus returned error: %v", err)
|
|
}
|
|
if stats.Total != 1 {
|
|
t.Errorf("Total = %d, want 1", stats.Total)
|
|
}
|
|
if stats.Skipped != 2 {
|
|
t.Errorf("Skipped = %d, want 2", stats.Skipped)
|
|
}
|
|
}
|
|
|
|
func TestClassifyCorpus_DomainMapping(t *testing.T) {
|
|
model := &mockModel{
|
|
classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
|
results := make([]inference.ClassifyResult, len(prompts))
|
|
for i, p := range prompts {
|
|
if strings.Contains(p, "Delete") {
|
|
results[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
|
|
} else {
|
|
results[i] = inference.ClassifyResult{Token: inference.Token{Text: "ethical"}}
|
|
}
|
|
}
|
|
return results, nil
|
|
},
|
|
}
|
|
|
|
input := strings.NewReader(
|
|
`{"prompt":"Delete the file now"}` + "\n" +
|
|
`{"prompt":"Is it right to lie?"}` + "\n",
|
|
)
|
|
var output bytes.Buffer
|
|
|
|
stats, err := ClassifyCorpus(context.Background(), model, input, &output, WithBatchSize(16))
|
|
if err != nil {
|
|
t.Fatalf("ClassifyCorpus returned error: %v", err)
|
|
}
|
|
if stats.ByDomain["technical"] != 1 {
|
|
t.Errorf("ByDomain[technical] = %d, want 1", stats.ByDomain["technical"])
|
|
}
|
|
if stats.ByDomain["ethical"] != 1 {
|
|
t.Errorf("ByDomain[ethical] = %d, want 1", stats.ByDomain["ethical"])
|
|
}
|
|
}
|