// Source file: go-i18n/calibrate.go (Go, 153 lines, 4.5 KiB).

package i18n
import (
"context"
"time"
"dappco.re/go/core"
log "dappco.re/go/core/log"
"forge.lthn.ai/core/go-inference"
)
// CalibrationSample is one piece of text submitted to both models during a
// calibration run.
type CalibrationSample struct {
	// Text is the content to classify.
	Text string

	// TrueDomain is the optional ground-truth label. Leave it empty when the
	// correct domain is unknown.
	TrueDomain string
}
// CalibrationResult records how the two models classified a single sample.
type CalibrationResult struct {
	// Text is the sample text that was classified.
	Text string `json:"text"`

	// TrueDomain is the ground-truth label, omitted from JSON when empty.
	TrueDomain string `json:"true_domain,omitempty"`

	// DomainA is the label produced by model A.
	DomainA string `json:"domain_a"`

	// DomainB is the label produced by model B.
	DomainB string `json:"domain_b"`

	// Agree reports whether DomainA and DomainB are equal.
	Agree bool `json:"agree"`
}
// CalibrationStats is the aggregate report produced by CalibrateDomains.
type CalibrationStats struct {
	// Total is the number of samples classified.
	Total int `json:"total"`
	// Agreed is the number of samples on which both models gave the same label.
	Agreed int `json:"agreed"`
	// AgreementRate is Agreed divided by Total.
	AgreementRate float64 `json:"agreement_rate"`

	// ByDomainA counts samples per label assigned by model A.
	ByDomainA map[string]int `json:"by_domain_a"`
	// ByDomainB counts samples per label assigned by model B.
	ByDomainB map[string]int `json:"by_domain_b"`
	// ConfusionPairs counts disagreements keyed as "<A label>-><B label>",
	// e.g. "technical->creative".
	ConfusionPairs map[string]int `json:"confusion_pairs"`

	// AccuracyA is CorrectA divided by WithTruth (0 when no sample has ground truth).
	AccuracyA float64 `json:"accuracy_a"`
	// AccuracyB is CorrectB divided by WithTruth (0 when no sample has ground truth).
	AccuracyB float64 `json:"accuracy_b"`
	// CorrectA is the number of ground-truth samples model A labelled correctly.
	CorrectA int `json:"correct_a"`
	// CorrectB is the number of ground-truth samples model B labelled correctly.
	CorrectB int `json:"correct_b"`
	// WithTruth is the number of samples that carried a ground-truth label.
	WithTruth int `json:"with_truth"`

	// DurationA is the elapsed time of model A's classification pass.
	DurationA time.Duration `json:"duration_a"`
	// DurationB is the elapsed time of model B's classification pass.
	DurationB time.Duration `json:"duration_b"`

	// Results holds one entry per sample, in input order.
	Results []CalibrationResult `json:"results"`
}
// CalibrateDomains runs every sample through modelA and then modelB and
// reports how the two sets of domain labels compare.
//
// Model A is typically the smaller/faster model (1B) and model B the larger
// reference model (27B). Samples whose TrueDomain is non-empty additionally
// feed the per-model accuracy figures.
//
// It returns an error when samples is empty or when either model fails to
// classify; no partial statistics are returned in that case.
func CalibrateDomains(ctx context.Context, modelA, modelB inference.TextModel,
	samples []CalibrationSample, opts ...ClassifyOption) (*CalibrationStats, error) {
	const op = "CalibrateDomains"

	total := len(samples)
	if total == 0 {
		return nil, log.E(op, "empty sample set", nil)
	}

	cfg := defaultClassifyConfig()
	for _, apply := range opts {
		apply(&cfg)
	}

	// Render one classification prompt per sample from the configured template.
	prompts := make([]string, 0, total)
	for idx := range samples {
		prompts = append(prompts, core.Sprintf(cfg.promptTemplate, samples[idx].Text))
	}

	// Model A runs first, then model B over the identical prompts.
	labelsA, elapsedA, err := classifyAll(ctx, modelA, prompts, cfg.batchSize)
	if err != nil {
		return nil, log.E(op, "classify with model A", err)
	}
	labelsB, elapsedB, err := classifyAll(ctx, modelB, prompts, cfg.batchSize)
	if err != nil {
		return nil, log.E(op, "classify with model B", err)
	}

	report := &CalibrationStats{
		Total:          total,
		ByDomainA:      make(map[string]int),
		ByDomainB:      make(map[string]int),
		ConfusionPairs: make(map[string]int),
		DurationA:      elapsedA,
		DurationB:      elapsedB,
		Results:        make([]CalibrationResult, total),
	}

	for idx := range samples {
		sample := samples[idx]
		labelA, labelB := labelsA[idx], labelsB[idx]
		same := labelA == labelB

		report.ByDomainA[labelA]++
		report.ByDomainB[labelB]++

		if same {
			report.Agreed++
		} else {
			// Disagreements are tallied per ordered (A, B) label pair.
			report.ConfusionPairs[core.Sprintf("%s->%s", labelA, labelB)]++
		}

		if truth := sample.TrueDomain; truth != "" {
			report.WithTruth++
			if labelA == truth {
				report.CorrectA++
			}
			if labelB == truth {
				report.CorrectB++
			}
		}

		report.Results[idx] = CalibrationResult{
			Text:       sample.Text,
			TrueDomain: sample.TrueDomain,
			DomainA:    labelA,
			DomainB:    labelB,
			Agree:      same,
		}
	}

	// total is non-zero here because the empty sample set was rejected above.
	report.AgreementRate = float64(report.Agreed) / float64(total)

	// Accuracy stays at zero when no sample carried a ground-truth label.
	if withTruth := report.WithTruth; withTruth > 0 {
		report.AccuracyA = float64(report.CorrectA) / float64(withTruth)
		report.AccuracyB = float64(report.CorrectB) / float64(withTruth)
	}

	return report, nil
}
// classifyAll classifies prompts with model in consecutive batches of at most
// batchSize prompts and returns one domain label per prompt, in prompt order,
// together with the elapsed time of the whole pass.
//
// Each batch is sent to model.Classify limited to a single output token; the
// token text of every result is converted to a label by mapTokenToDomain.
//
// An error is returned when:
//   - batchSize is not positive (the batch loop could never advance),
//   - model.Classify fails for a batch, or
//   - the model returns a different number of results than prompts in the
//     batch, since labels could then no longer be matched to their prompts.
//
// On error the returned slice is nil and the duration is zero.
func classifyAll(ctx context.Context, model inference.TextModel, prompts []string, batchSize int) ([]string, time.Duration, error) {
	if batchSize <= 0 {
		return nil, 0, log.E("classifyAll", core.Sprintf("invalid batch size %d", batchSize), nil)
	}

	start := time.Now()
	domains := make([]string, len(prompts))

	for i := 0; i < len(prompts); {
		// Clamp against the remaining count rather than computing i+batchSize
		// first, so a very large batchSize cannot overflow.
		end := i + min(batchSize, len(prompts)-i)
		batch := prompts[i:end]

		results, err := model.Classify(ctx, batch, inference.WithMaxTokens(1))
		if err != nil {
			return nil, 0, log.E("classifyAll", core.Sprintf("classify batch [%d:%d]", i, end), err)
		}
		// Results are matched to prompts by position; a count mismatch would
		// either index past domains or silently leave labels empty.
		if len(results) != len(batch) {
			return nil, 0, log.E("classifyAll",
				core.Sprintf("classify batch [%d:%d]: got %d results, want %d", i, end, len(results), len(batch)), nil)
		}

		for j, r := range results {
			domains[i+j] = mapTokenToDomain(r.Token.Text)
		}

		i = end
	}

	return domains, time.Since(start), nil
}