Replace fmt, errors, strings, encoding/json with Core primitives across 20 files. Keep strings.Fields/CutPrefix. No translation files modified. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
152 lines
4.5 KiB
Go
152 lines
4.5 KiB
Go
package i18n
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"dappco.re/go/core"
|
|
log "dappco.re/go/core/log"
|
|
"forge.lthn.ai/core/go-inference"
|
|
)
|
|
|
|
// CalibrationSample is one piece of text fed to both models during a
// calibration run, optionally tagged with its known domain.
type CalibrationSample struct {
	// Text is the content to classify.
	Text string
	// TrueDomain is the optional ground-truth label; leave it empty when the
	// correct domain is unknown.
	TrueDomain string
}
|
|
|
|
// CalibrationResult records how the two models classified a single sample.
type CalibrationResult struct {
	// Text is the sample text that was classified.
	Text string `json:"text"`
	// TrueDomain is the sample's ground-truth label; omitted from JSON when empty.
	TrueDomain string `json:"true_domain,omitempty"`
	// DomainA is the label produced by model A.
	DomainA string `json:"domain_a"`
	// DomainB is the label produced by model B.
	DomainB string `json:"domain_b"`
	// Agree reports whether DomainA and DomainB are equal.
	Agree bool `json:"agree"`
}
|
|
|
|
// CalibrationStats holds aggregate metrics from CalibrateDomains.
|
|
type CalibrationStats struct {
|
|
Total int `json:"total"`
|
|
Agreed int `json:"agreed"`
|
|
AgreementRate float64 `json:"agreement_rate"`
|
|
ByDomainA map[string]int `json:"by_domain_a"`
|
|
ByDomainB map[string]int `json:"by_domain_b"`
|
|
ConfusionPairs map[string]int `json:"confusion_pairs"` // "technical->creative": count
|
|
AccuracyA float64 `json:"accuracy_a"` // vs ground truth (0 if none)
|
|
AccuracyB float64 `json:"accuracy_b"` // vs ground truth (0 if none)
|
|
CorrectA int `json:"correct_a"`
|
|
CorrectB int `json:"correct_b"`
|
|
WithTruth int `json:"with_truth"` // samples that had ground truth
|
|
DurationA time.Duration `json:"duration_a"`
|
|
DurationB time.Duration `json:"duration_b"`
|
|
Results []CalibrationResult `json:"results"`
|
|
}
|
|
|
|
// CalibrateDomains classifies all samples with both models and computes agreement.
|
|
// Model A is typically the smaller/faster model (1B), model B the larger reference (27B).
|
|
// Samples with non-empty TrueDomain also contribute to accuracy metrics.
|
|
func CalibrateDomains(ctx context.Context, modelA, modelB inference.TextModel,
|
|
samples []CalibrationSample, opts ...ClassifyOption) (*CalibrationStats, error) {
|
|
|
|
if len(samples) == 0 {
|
|
return nil, log.E("CalibrateDomains", "empty sample set", nil)
|
|
}
|
|
|
|
cfg := defaultClassifyConfig()
|
|
for _, o := range opts {
|
|
o(&cfg)
|
|
}
|
|
|
|
stats := &CalibrationStats{
|
|
ByDomainA: make(map[string]int),
|
|
ByDomainB: make(map[string]int),
|
|
ConfusionPairs: make(map[string]int),
|
|
}
|
|
|
|
// Build classification prompts from sample texts.
|
|
prompts := make([]string, len(samples))
|
|
for i, s := range samples {
|
|
prompts[i] = core.Sprintf(cfg.promptTemplate, s.Text)
|
|
}
|
|
|
|
// Classify with model A.
|
|
domainsA, durA, err := classifyAll(ctx, modelA, prompts, cfg.batchSize)
|
|
if err != nil {
|
|
return nil, log.E("CalibrateDomains", "classify with model A", err)
|
|
}
|
|
stats.DurationA = durA
|
|
|
|
// Classify with model B.
|
|
domainsB, durB, err := classifyAll(ctx, modelB, prompts, cfg.batchSize)
|
|
if err != nil {
|
|
return nil, log.E("CalibrateDomains", "classify with model B", err)
|
|
}
|
|
stats.DurationB = durB
|
|
|
|
// Compare results.
|
|
stats.Total = len(samples)
|
|
stats.Results = make([]CalibrationResult, len(samples))
|
|
|
|
for i, s := range samples {
|
|
a, b := domainsA[i], domainsB[i]
|
|
agree := a == b
|
|
if agree {
|
|
stats.Agreed++
|
|
} else {
|
|
key := core.Sprintf("%s->%s", a, b)
|
|
stats.ConfusionPairs[key]++
|
|
}
|
|
stats.ByDomainA[a]++
|
|
stats.ByDomainB[b]++
|
|
|
|
if s.TrueDomain != "" {
|
|
stats.WithTruth++
|
|
if a == s.TrueDomain {
|
|
stats.CorrectA++
|
|
}
|
|
if b == s.TrueDomain {
|
|
stats.CorrectB++
|
|
}
|
|
}
|
|
|
|
stats.Results[i] = CalibrationResult{
|
|
Text: s.Text,
|
|
TrueDomain: s.TrueDomain,
|
|
DomainA: a,
|
|
DomainB: b,
|
|
Agree: agree,
|
|
}
|
|
}
|
|
|
|
if stats.Total > 0 {
|
|
stats.AgreementRate = float64(stats.Agreed) / float64(stats.Total)
|
|
}
|
|
if stats.WithTruth > 0 {
|
|
stats.AccuracyA = float64(stats.CorrectA) / float64(stats.WithTruth)
|
|
stats.AccuracyB = float64(stats.CorrectB) / float64(stats.WithTruth)
|
|
}
|
|
|
|
return stats, nil
|
|
}
|
|
|
|
// classifyAll runs batch classification over all prompts, returning domain labels.
|
|
func classifyAll(ctx context.Context, model inference.TextModel, prompts []string, batchSize int) ([]string, time.Duration, error) {
|
|
start := time.Now()
|
|
domains := make([]string, len(prompts))
|
|
|
|
for i := 0; i < len(prompts); i += batchSize {
|
|
end := min(i+batchSize, len(prompts))
|
|
batch := prompts[i:end]
|
|
|
|
results, err := model.Classify(ctx, batch, inference.WithMaxTokens(1))
|
|
if err != nil {
|
|
return nil, 0, log.E("classifyAll", core.Sprintf("classify batch [%d:%d]", i, end), err)
|
|
}
|
|
|
|
for j, r := range results {
|
|
domains[i+j] = mapTokenToDomain(r.Token.Text)
|
|
}
|
|
}
|
|
|
|
return domains, time.Since(start), nil
|
|
}
|