go-i18n/calibrate_test.go
Snider 3b7ef9d26a feat(calibrate): 1B vs 27B domain calibration tool
CalibrateDomains() accepts two inference.TextModel instances and a corpus
of CalibrationSamples, classifies every sample with both models, and computes
the agreement rate, per-domain distribution, confusion pairs, and each model's
accuracy against ground truth.

- calibrate.go: CalibrateDomains + classifyAll batch helper
- calibrate_test.go: 7 mock tests (agreement, disagreement, mixed,
  no ground truth, empty, batch boundary, results slice)
- integration/calibrate_test.go: 500-sample corpus (220 ground-truth
  + 280 unlabelled) for real 1B vs 27B model comparison
- TODO.md: Phase 2a calibration task marked complete

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-20 13:51:11 +00:00

277 lines
8.3 KiB
Go

package i18n
import (
"context"
"testing"
"forge.lthn.ai/core/go-inference"
)
// TestCalibrateDomains_FullAgreement feeds the same always-"technical" model
// in as both A and B over a fully labelled "technical" corpus, so every
// rate must be perfect and no confusion pair may be recorded.
func TestCalibrateDomains_FullAgreement(t *testing.T) {
	alwaysTechnical := &mockModel{
		classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
			out := make([]inference.ClassifyResult, 0, len(prompts))
			for range prompts {
				out = append(out, inference.ClassifyResult{Token: inference.Token{Text: "technical"}})
			}
			return out, nil
		},
	}

	corpus := []CalibrationSample{
		{Text: "Delete the file", TrueDomain: "technical"},
		{Text: "Build the project", TrueDomain: "technical"},
		{Text: "Run the tests", TrueDomain: "technical"},
	}

	stats, err := CalibrateDomains(context.Background(), alwaysTechnical, alwaysTechnical, corpus)
	if err != nil {
		t.Fatalf("CalibrateDomains: %v", err)
	}

	if got := stats.Total; got != 3 {
		t.Errorf("Total = %d, want 3", got)
	}
	if got := stats.Agreed; got != 3 {
		t.Errorf("Agreed = %d, want 3", got)
	}
	if got := stats.AgreementRate; got != 1.0 {
		t.Errorf("AgreementRate = %f, want 1.0", got)
	}
	if got := stats.AccuracyA; got != 1.0 {
		t.Errorf("AccuracyA = %f, want 1.0", got)
	}
	if got := stats.AccuracyB; got != 1.0 {
		t.Errorf("AccuracyB = %f, want 1.0", got)
	}
	if n := len(stats.ConfusionPairs); n != 0 {
		t.Errorf("ConfusionPairs = %v, want empty", stats.ConfusionPairs)
	}
}
// TestCalibrateDomains_Disagreement pits a model that always answers
// "technical" (A) against one that always answers "creative" (B) on a
// "creative" ground-truth corpus: nothing agrees, only B is ever correct,
// and every sample lands in the technical->creative confusion bucket.
func TestCalibrateDomains_Disagreement(t *testing.T) {
	// constant builds a mock model that labels every prompt with label.
	constant := func(label string) *mockModel {
		return &mockModel{
			classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
				results := make([]inference.ClassifyResult, len(prompts))
				for i := range results {
					results[i].Token.Text = label
				}
				return results, nil
			},
		}
	}
	modelA := constant("technical")
	modelB := constant("creative")

	corpus := []CalibrationSample{
		{Text: "She wrote a poem", TrueDomain: "creative"},
		{Text: "He painted the sky", TrueDomain: "creative"},
	}

	stats, err := CalibrateDomains(context.Background(), modelA, modelB, corpus)
	if err != nil {
		t.Fatalf("CalibrateDomains: %v", err)
	}

	if got := stats.Agreed; got != 0 {
		t.Errorf("Agreed = %d, want 0", got)
	}
	if got := stats.AgreementRate; got != 0 {
		t.Errorf("AgreementRate = %f, want 0", got)
	}
	if got := stats.CorrectA; got != 0 {
		t.Errorf("CorrectA = %d, want 0 (A said technical, truth is creative)", got)
	}
	if got := stats.CorrectB; got != 2 {
		t.Errorf("CorrectB = %d, want 2", got)
	}
	const pair = "technical->creative"
	if got := stats.ConfusionPairs[pair]; got != 2 {
		t.Errorf("ConfusionPairs[technical->creative] = %d, want 2", got)
	}
}
// TestCalibrateDomains_MixedAgreement checks that agreement is tallied per
// sample: the two models agree on the first sample and disagree on the
// second, giving an agreement rate of 0.5.
func TestCalibrateDomains_MixedAgreement(t *testing.T) {
	// Model A always answers "ethical".
	modelA := &mockModel{
		classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
			results := make([]inference.ClassifyResult, len(prompts))
			for i := range prompts {
				results[i] = inference.ClassifyResult{Token: inference.Token{Text: "ethical"}}
			}
			return results, nil
		},
	}

	// Model B answers "ethical" for the first prompt it ever receives and
	// "technical" for every later one. The counter spans calls, so the
	// outcome does not depend on how the samples are grouped into batches.
	// NOTE(review): assumes B receives prompts in sample order, one call at a
	// time — confirm against CalibrateDomains.
	seen := 0
	modelB := &mockModel{
		classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
			results := make([]inference.ClassifyResult, len(prompts))
			for i := range prompts {
				label := "technical"
				if seen == 0 {
					label = "ethical"
				}
				seen++
				results[i] = inference.ClassifyResult{Token: inference.Token{Text: label}}
			}
			return results, nil
		},
	}

	samples := []CalibrationSample{
		{Text: "We should act fairly"},
		{Text: "Delete the config"},
	}
	stats, err := CalibrateDomains(context.Background(), modelA, modelB, samples, WithBatchSize(16))
	if err != nil {
		t.Fatalf("CalibrateDomains: %v", err)
	}
	if seen != 2 {
		t.Errorf("model B classified %d prompts, want 2", seen)
	}
	if stats.Total != 2 {
		t.Errorf("Total = %d, want 2", stats.Total)
	}
	if stats.Agreed != 1 {
		t.Errorf("Agreed = %d, want 1", stats.Agreed)
	}
	if got := stats.AgreementRate; got != 0.5 {
		t.Errorf("AgreementRate = %f, want 0.5", got)
	}

	// Pin down which sample agreed, not just how many.
	if len(stats.Results) != 2 {
		t.Fatalf("Results len = %d, want 2", len(stats.Results))
	}
	if !stats.Results[0].Agree {
		t.Error("Results[0].Agree = false, want true")
	}
	if stats.Results[1].Agree {
		t.Error("Results[1].Agree = true, want false")
	}
}
// TestCalibrateDomains_NoGroundTruth uses samples with no TrueDomain. Neither
// model's accuracy can be credited (both must be 0), but agreement between
// the models is still measured.
func TestCalibrateDomains_NoGroundTruth(t *testing.T) {
	model := &mockModel{
		classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
			results := make([]inference.ClassifyResult, len(prompts))
			for i := range prompts {
				results[i] = inference.ClassifyResult{Token: inference.Token{Text: "casual"}}
			}
			return results, nil
		},
	}
	samples := []CalibrationSample{
		{Text: "Went to the store"},
		{Text: "Had coffee this morning"},
	}
	stats, err := CalibrateDomains(context.Background(), model, model, samples)
	if err != nil {
		t.Fatalf("CalibrateDomains: %v", err)
	}
	if stats.Total != 2 {
		t.Errorf("Total = %d, want 2", stats.Total)
	}
	if stats.WithTruth != 0 {
		t.Errorf("WithTruth = %d, want 0", stats.WithTruth)
	}
	// Both sides must report zero accuracy, not just model A.
	if stats.AccuracyA != 0 {
		t.Errorf("AccuracyA = %f, want 0 (no ground truth)", stats.AccuracyA)
	}
	if stats.AccuracyB != 0 {
		t.Errorf("AccuracyB = %f, want 0 (no ground truth)", stats.AccuracyB)
	}
	if stats.Agreed != 2 {
		t.Errorf("Agreed = %d, want 2", stats.Agreed)
	}
	if stats.AgreementRate != 1.0 {
		t.Errorf("AgreementRate = %f, want 1.0", stats.AgreementRate)
	}
}
// TestCalibrateDomains_EmptySamples checks that CalibrateDomains rejects an
// empty corpus. Both a nil slice and a non-nil zero-length slice count as
// empty, so each is exercised as its own subtest.
func TestCalibrateDomains_EmptySamples(t *testing.T) {
	model := &mockModel{
		classifyFunc: func(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
			return nil, nil
		},
	}
	cases := []struct {
		name    string
		samples []CalibrationSample
	}{
		{name: "nil", samples: nil},
		{name: "zero-length", samples: []CalibrationSample{}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := CalibrateDomains(context.Background(), model, model, tc.samples)
			if err == nil {
				t.Error("expected error for empty samples, got nil")
			}
		})
	}
}
// TestCalibrateDomains_BatchBoundary sends 7 samples with a batch size of 3,
// so the final batch is partial. Each model records the size of every batch
// it receives, which lets the test verify the 3/3/1 split itself and not
// only the aggregate counts.
func TestCalibrateDomains_BatchBoundary(t *testing.T) {
	// recorder returns a mock that labels everything "technical" and appends
	// the length of each incoming batch to *sizes. A and B get separate
	// recorders so the check does not depend on how calls to the two models
	// are interleaved.
	recorder := func(sizes *[]int) *mockModel {
		return &mockModel{
			classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
				*sizes = append(*sizes, len(prompts))
				results := make([]inference.ClassifyResult, len(prompts))
				for i := range prompts {
					results[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
				}
				return results, nil
			},
		}
	}
	var sizesA, sizesB []int
	modelA := recorder(&sizesA)
	modelB := recorder(&sizesB)

	samples := make([]CalibrationSample, 7)
	for i := range samples {
		samples[i] = CalibrationSample{Text: "Build the project"}
	}
	stats, err := CalibrateDomains(context.Background(), modelA, modelB, samples, WithBatchSize(3))
	if err != nil {
		t.Fatalf("CalibrateDomains: %v", err)
	}
	if stats.Total != 7 {
		t.Errorf("Total = %d, want 7", stats.Total)
	}
	if stats.Agreed != 7 {
		t.Errorf("Agreed = %d, want 7", stats.Agreed)
	}

	// checkBatches reports a mismatch between got and the expected 3/3/1 split.
	checkBatches := func(name string, got []int) {
		t.Helper()
		want := []int{3, 3, 1}
		if len(got) != len(want) {
			t.Errorf("model %s batch sizes = %v, want %v", name, got, want)
			return
		}
		for i := range want {
			if got[i] != want[i] {
				t.Errorf("model %s batch sizes = %v, want %v", name, got, want)
				return
			}
		}
	}
	checkBatches("A", sizesA)
	checkBatches("B", sizesB)
}
// TestCalibrateDomains_ResultsSlice checks that the per-sample Results entry
// carries the input text, the ground truth, each model's label, and the
// agreement flag for a single disagreeing sample.
func TestCalibrateDomains_ResultsSlice(t *testing.T) {
	// labelAll builds a mock model that answers label for every prompt.
	labelAll := func(label string) *mockModel {
		return &mockModel{
			classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
				out := make([]inference.ClassifyResult, 0, len(prompts))
				for range prompts {
					out = append(out, inference.ClassifyResult{Token: inference.Token{Text: label}})
				}
				return out, nil
			},
		}
	}

	stats, err := CalibrateDomains(
		context.Background(),
		labelAll("ethical"),
		labelAll("casual"),
		[]CalibrationSample{{Text: "Be fair to everyone", TrueDomain: "ethical"}},
	)
	if err != nil {
		t.Fatalf("CalibrateDomains: %v", err)
	}
	if n := len(stats.Results); n != 1 {
		t.Fatalf("Results len = %d, want 1", n)
	}

	first := stats.Results[0]
	if first.Text != "Be fair to everyone" {
		t.Errorf("Text = %q", first.Text)
	}
	if first.TrueDomain != "ethical" {
		t.Errorf("TrueDomain = %q", first.TrueDomain)
	}
	if first.DomainA != "ethical" {
		t.Errorf("DomainA = %q, want ethical", first.DomainA)
	}
	if first.DomainB != "casual" {
		t.Errorf("DomainB = %q, want casual", first.DomainB)
	}
	if first.Agree {
		t.Error("Agree = true, want false")
	}
}