fix: address Virgil review — 5 fixes for classify pipeline

- Remove go-mlx from go.mod (breaks non-darwin builds)
- Fix go-inference pseudo-version for CI compatibility
- Fix mapTokenToDomain prefix collision (castle, credential)
- Add testing.Short() skip to slow classification benchmarks
- Add 80% accuracy threshold to integration test

Integration test moved to integration/ sub-module with its own go.mod
to cleanly isolate go-mlx dependency from the main module.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-20 00:44:35 +00:00
parent c23a271716
commit ff376830c0
8 changed files with 114 additions and 62 deletions

10
TODO.md
View file

@ -59,15 +59,15 @@ models, _ := inference.Discover("/Volumes/Data/lem/")
**Do these first, in order, before picking up the next Phase 2a task.**
- [ ] **Fix go.mod: remove go-mlx from module require** — go-mlx is darwin/arm64 CGO. Having it in go.mod makes go-i18n uncompilable on Linux and breaks any downstream consumer that vendors this module. go-i18n's founding principle is "only dependency: golang.org/x/text" plus go-inference. The integration test imports go-mlx behind a `//go:build integration` tag, which is fine — but the go.mod `require` is unconditional. Fix: remove the `require` and `replace` lines for go-mlx from go.mod. For local integration test builds, use a go.work file instead. Verify `go mod tidy` succeeds without go-mlx.
- [x] **Fix go.mod: remove go-mlx from module require** — Removed go-mlx `require` and `replace` from go.mod. Moved integration test to `integration/` sub-module with its own go.mod that depends on go-mlx. Main module now compiles cleanly on all platforms. `go mod tidy` no longer pulls go-mlx.
- [ ] **Fix go.mod: go-inference pseudo-version** — `v0.0.0` without a timestamp is not a valid Go module version. Run `go get forge.lthn.ai/core/go-inference@latest` with the workspace active to compute a proper pseudo-version (like `v0.0.0-20260219...`), or add a proper `replace` and let `go mod tidy` resolve it. This prevents `go mod verify` failures in CI.
- [x] **Fix go.mod: go-inference pseudo-version** — `go mod tidy` resolved to the standard replaced-module pseudo-version `v0.0.0-00010101000000-000000000000`. CI-safe.
- [ ] **Fix mapTokenToDomain prefix collision** — `strings.HasPrefix(lower, "cas")` matches "castle", "cascade", etc. — not just "casual". Same risk with "cre" matching "credential". Use exact match with a short-prefix fallback only for known token fragments: `lower == "casual" || lower == "cas"`. Add a comment explaining why prefix matching exists (BPE token fragmentation can produce partial words). Add test cases for "castle" and "credential" to verify they return "unknown".
- [x] **Fix mapTokenToDomain prefix collision** — Replaced `strings.HasPrefix` with exact match + known BPE fragment fallback. Added test cases for "castle", "cascade", "credential", "creature" — all return "unknown".
- [ ] **Fix classify_bench_test.go naming** — File contains 6 `Test*` functions that run on every `go test ./...`. The O(n^2) `TestClassification_LeaveOneOut` (220 sentences, ~48K similarity computations) is slow for a normal test run. Either: (a) rename to `classify_test.go` and add `testing.Short()` skip to slow tests, or (b) add a `//go:build !short` build tag. Option (a) preferred.
- [x] **Fix classify_bench_test.go naming** — Added `testing.Short()` skip to `TestClassification_DomainSeparation` and `TestClassification_LeaveOneOut` (the two O(n^2) tests). Verified with `go test -short -v`.
- [ ] **Add accuracy assertion to integration test** — Currently only checks `Total == 50` and `Skipped == 0`. If the model returns all "unknown", the test still passes. Add a minimum threshold: at least 80% of the 50 technical prompts should classify as "technical". The FINDINGS data shows 100% accuracy on controlled technical input, so 80% is a conservative floor.
- [x] **Add accuracy assertion to integration test** — Integration test now asserts at least 80% (40/50) of technical prompts classified as "technical". Logs full domain breakdown and misclassified entries on failure. Test moved to `integration/` sub-module.
### Remaining Phase 2a Tasks

View file

@ -54,19 +54,23 @@ func WithPromptTemplate(tmpl string) ClassifyOption {
}
// mapTokenToDomain maps a model output token to a 4-way domain label.
// BPE tokenisation can fragment words into partial tokens (e.g. "cas" from
// "casual", "cre" from "creative"), so each label is accepted either in full
// or as its one known short fragment, compared case-insensitively and exactly.
// strings.HasPrefix is deliberately NOT used: it would make "castle" casual.
func mapTokenToDomain(token string) string {
if len(token) == 0 {
return "unknown"
}
lower := strings.ToLower(token)
switch {
case strings.HasPrefix(lower, "tech"):
case lower == "technical" || lower == "tech":
return "technical"
case strings.HasPrefix(lower, "cre"):
case lower == "creative" || lower == "cre":
return "creative"
case strings.HasPrefix(lower, "eth"):
case lower == "ethical" || lower == "eth":
return "ethical"
case strings.HasPrefix(lower, "cas"):
case lower == "casual" || lower == "cas":
return "casual"
default:
return "unknown"

View file

@ -1,48 +0,0 @@
//go:build integration
package i18n
import (
"bytes"
"context"
"fmt"
"strings"
"testing"
"time"
"forge.lthn.ai/core/go-inference"
_ "forge.lthn.ai/core/go-mlx" // registers Metal backend
)
// TestClassifyCorpus_Integration feeds 50 identical technical prompts through
// ClassifyCorpus using a locally stored model, logs throughput and the
// per-domain breakdown, and checks that every prompt was counted and none was
// skipped. The test is skipped when the model cannot be loaded.
func TestClassifyCorpus_Integration(t *testing.T) {
	model, err := inference.LoadModel("/Volumes/Data/lem/LEM-Gemma3-1B-layered-v2")
	if err != nil {
		t.Skipf("model not available: %v", err)
	}
	defer model.Close()

	// Assemble the JSONL corpus: one record per line, each newline-terminated.
	var corpus strings.Builder
	for id := 0; id < 50; id++ {
		fmt.Fprintf(&corpus, `{"id":%d,"prompt":"Delete the configuration file and rebuild the project"}`, id)
		corpus.WriteString("\n")
	}
	input := strings.NewReader(corpus.String())
	output := new(bytes.Buffer)

	begin := time.Now()
	stats, err := ClassifyCorpus(context.Background(), model, input, output, WithBatchSize(8))
	if err != nil {
		t.Fatalf("ClassifyCorpus: %v", err)
	}
	took := time.Since(begin)

	t.Logf("Classified %d prompts in %v (%.1f prompts/sec)", stats.Total, took, stats.PromptsPerSec)
	t.Logf("By domain: %v", stats.ByDomain)

	if got := stats.Total; got != 50 {
		t.Errorf("Total = %d, want 50", got)
	}
	if got := stats.Skipped; got != 0 {
		t.Errorf("Skipped = %d, want 0", got)
	}
}

View file

@ -31,6 +31,11 @@ func TestMapTokenToDomain(t *testing.T) {
{"unknown", "unknown"},
{"", "unknown"},
{"foo", "unknown"},
// Verify prefix collision fix: these must NOT match any domain
{"castle", "unknown"},
{"cascade", "unknown"},
{"credential", "unknown"},
{"creature", "unknown"},
}
for _, tt := range tests {
t.Run(tt.token, func(t *testing.T) {

6
go.mod
View file

@ -4,10 +4,6 @@ go 1.25.5
require golang.org/x/text v0.33.0
require forge.lthn.ai/core/go-inference v0.0.0
require forge.lthn.ai/core/go-mlx v0.0.0-20260219234407-d1fb26d51e62
require forge.lthn.ai/core/go-inference v0.0.0-00010101000000-000000000000
replace forge.lthn.ai/core/go-inference => ../go-inference
replace forge.lthn.ai/core/go-mlx => ../go-mlx

View file

@ -0,0 +1,74 @@
package integration
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"testing"
"time"
i18n "forge.lthn.ai/core/go-i18n"
"forge.lthn.ai/core/go-inference"
_ "forge.lthn.ai/core/go-mlx" // registers Metal backend
)
// TestClassifyCorpus_Integration runs i18n.ClassifyCorpus end-to-end against a
// locally stored model. It feeds in a fixed number of identical technical
// prompts and verifies that every prompt is counted, none is skipped, and at
// least 80% of them are classified as "technical". The test is skipped when
// the model cannot be loaded.
func TestClassifyCorpus_Integration(t *testing.T) {
	const (
		// promptCount is the number of technical prompts sent to the classifier.
		promptCount = 50
		// minRequired is the 80% accuracy floor. The FINDINGS data shows 100%
		// accuracy on controlled technical input, so 80% is conservative.
		minRequired = promptCount * 80 / 100
	)

	model, err := inference.LoadModel("/Volumes/Data/lem/LEM-Gemma3-1B-layered-v2")
	if err != nil {
		t.Skipf("model not available: %v", err)
	}
	defer model.Close()

	// Build the technical prompts used for both the throughput log and the
	// accuracy assertion below.
	lines := make([]string, 0, promptCount)
	for i := 0; i < promptCount; i++ {
		lines = append(lines, fmt.Sprintf(`{"id":%d,"prompt":"Delete the configuration file and rebuild the project"}`, i))
	}
	input := strings.NewReader(strings.Join(lines, "\n") + "\n")

	var output bytes.Buffer
	start := time.Now()
	stats, err := i18n.ClassifyCorpus(context.Background(), model, input, &output, i18n.WithBatchSize(8))
	if err != nil {
		t.Fatalf("ClassifyCorpus: %v", err)
	}
	elapsed := time.Since(start)

	t.Logf("Classified %d prompts in %v (%.1f prompts/sec)", stats.Total, elapsed, stats.PromptsPerSec)
	t.Logf("By domain: %v", stats.ByDomain)

	if stats.Total != promptCount {
		t.Errorf("Total = %d, want %d", stats.Total, promptCount)
	}
	if stats.Skipped != 0 {
		t.Errorf("Skipped = %d, want 0", stats.Skipped)
	}

	// percent reports count as a percentage of stats.Total, returning 0 rather
	// than NaN when nothing was classified.
	percent := func(count int) float64 {
		if stats.Total == 0 {
			return 0
		}
		return float64(count) / float64(stats.Total) * 100
	}

	technicalCount := stats.ByDomain["technical"]
	if technicalCount >= minRequired {
		return
	}

	// Accuracy is below the floor: log the full domain breakdown for debugging.
	for domain, count := range stats.ByDomain {
		t.Logf("  domain %q: %d (%.0f%%)", domain, count, percent(count))
	}

	// Also inspect the output JSONL for misclassified entries. Lines that do
	// not decode are reported instead of being dropped silently, since they
	// may be the reason the count is low.
	for _, line := range strings.Split(strings.TrimSpace(output.String()), "\n") {
		if line == "" {
			continue
		}
		var record map[string]any
		if err := json.Unmarshal([]byte(line), &record); err != nil {
			t.Logf("  undecodable output line %q: %v", line, err)
			continue
		}
		if record["domain_1b"] != "technical" {
			t.Logf("  misclassified: id=%v domain_1b=%v", record["id"], record["domain_1b"])
		}
	}

	t.Errorf("accuracy too low: %d/%d (%.0f%%) classified as technical, want >= %d (80%%)",
		technicalCount, stats.Total, percent(technicalCount), minRequired)
}

15
integration/go.mod Normal file
View file

@ -0,0 +1,15 @@
module forge.lthn.ai/core/go-i18n/integration
go 1.25.5
require (
forge.lthn.ai/core/go-i18n v0.0.0-00010101000000-000000000000
forge.lthn.ai/core/go-inference v0.0.0-00010101000000-000000000000
forge.lthn.ai/core/go-mlx v0.0.0-00010101000000-000000000000
)
replace (
forge.lthn.ai/core/go-i18n => ../
forge.lthn.ai/core/go-inference => ../../go-inference
forge.lthn.ai/core/go-mlx => ../../go-mlx
)

View file

@ -325,6 +325,9 @@ func TestClassification_CorpusSize(t *testing.T) {
// exceeds cross-domain similarity. This is the basic requirement for domain
// classification to work.
func TestClassification_DomainSeparation(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow domain separation test in short mode")
}
setup(t)
tok := NewTokeniser()
imprints := imprintCorpus(tok)
@ -374,6 +377,9 @@ func TestClassification_DomainSeparation(t *testing.T) {
// TestClassification_LeaveOneOut measures per-domain and overall accuracy
// using leave-one-out nearest-centroid classification.
func TestClassification_LeaveOneOut(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow classification benchmark in short mode")
}
setup(t)
tok := NewTokeniser()
imprints := imprintCorpus(tok)