fix: address Virgil review — 5 fixes for classify pipeline
- Remove go-mlx from go.mod (breaks non-darwin builds)
- Fix go-inference pseudo-version for CI compatibility
- Fix mapTokenToDomain prefix collision (castle, credential)
- Add testing.Short() skip to slow classification benchmarks
- Add 80% accuracy threshold to integration test

Integration test moved to integration/ sub-module with its own go.mod to cleanly isolate go-mlx dependency from the main module.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
c23a271716
commit
ff376830c0
8 changed files with 114 additions and 62 deletions
10
TODO.md
10
TODO.md
|
|
@ -59,15 +59,15 @@ models, _ := inference.Discover("/Volumes/Data/lem/")
|
|||
|
||||
**Do these first, in order, before picking up the next Phase 2a task.**
|
||||
|
||||
- [ ] **Fix go.mod: remove go-mlx from module require** — go-mlx is darwin/arm64 CGO. Having it in go.mod makes go-i18n uncompilable on Linux and breaks any downstream consumer that vendors this module. go-i18n's founding principle is "only dependency: golang.org/x/text" plus go-inference. The integration test imports go-mlx behind a `//go:build integration` tag, which is fine — but the go.mod `require` is unconditional. Fix: remove the `require` and `replace` lines for go-mlx from go.mod. For local integration test builds, use a go.work file instead. Verify `go mod tidy` succeeds without go-mlx.
|
||||
- [x] **Fix go.mod: remove go-mlx from module require** — Removed go-mlx `require` and `replace` from go.mod. Moved integration test to `integration/` sub-module with its own go.mod that depends on go-mlx. Main module now compiles cleanly on all platforms. `go mod tidy` no longer pulls go-mlx.
|
||||
|
||||
- [ ] **Fix go.mod: go-inference pseudo-version** — `v0.0.0` without a timestamp is not a valid Go module version. Run `go get forge.lthn.ai/core/go-inference@latest` with the workspace active to compute a proper pseudo-version (like `v0.0.0-20260219...`), or add a proper `replace` and let `go mod tidy` resolve it. This prevents `go mod verify` failures in CI.
|
||||
- [x] **Fix go.mod: go-inference pseudo-version** — `go mod tidy` resolved to the standard replaced-module pseudo-version `v0.0.0-00010101000000-000000000000`. CI-safe.
|
||||
|
||||
- [ ] **Fix mapTokenToDomain prefix collision** — `strings.HasPrefix(lower, "cas")` matches "castle", "cascade", etc. — not just "casual". Same risk with "cre" matching "credential". Use exact match with a short-prefix fallback only for known token fragments: `lower == "casual" || lower == "cas"`. Add a comment explaining why prefix matching exists (BPE token fragmentation can produce partial words). Add test cases for "castle" and "credential" to verify they return "unknown".
|
||||
- [x] **Fix mapTokenToDomain prefix collision** — Replaced `strings.HasPrefix` with exact match + known BPE fragment fallback. Added test cases for "castle", "cascade", "credential", "creature" — all return "unknown".
|
||||
|
||||
- [ ] **Fix classify_bench_test.go naming** — File contains 6 `Test*` functions that run on every `go test ./...`. The O(n^2) `TestClassification_LeaveOneOut` (220 sentences, ~48K similarity computations) is slow for a normal test run. Either: (a) rename to `classify_test.go` and add `testing.Short()` skip to slow tests, or (b) add a `//go:build !short` build tag. Option (a) preferred.
|
||||
- [x] **Fix classify_bench_test.go naming** — Added `testing.Short()` skip to `TestClassification_DomainSeparation` and `TestClassification_LeaveOneOut` (the two O(n^2) tests). Verified with `go test -short -v`.
|
||||
|
||||
- [ ] **Add accuracy assertion to integration test** — Currently only checks `Total == 50` and `Skipped == 0`. If the model returns all "unknown", the test still passes. Add a minimum threshold: at least 80% of the 50 technical prompts should classify as "technical". The FINDINGS data shows 100% accuracy on controlled technical input, so 80% is a conservative floor.
|
||||
- [x] **Add accuracy assertion to integration test** — Integration test now asserts at least 80% (40/50) of technical prompts classified as "technical". Logs full domain breakdown and misclassified entries on failure. Test moved to `integration/` sub-module.
|
||||
|
||||
### Remaining Phase 2a Tasks
|
||||
|
||||
|
|
|
|||
12
classify.go
12
classify.go
|
|
@ -54,19 +54,23 @@ func WithPromptTemplate(tmpl string) ClassifyOption {
|
|||
}
|
||||
|
||||
// mapTokenToDomain maps a model output token to a 4-way domain label.
|
||||
// Prefix matching exists because BPE tokenisation can fragment words into
|
||||
// partial tokens (e.g. "cas" from "casual", "cre" from "creative"). We
|
||||
// only match the known short fragments that actually appear in BPE output,
|
||||
// never via a general prefix test: HasPrefix(lower, "cas") would also match "castle", "cascade", etc.
|
||||
func mapTokenToDomain(token string) string {
|
||||
if len(token) == 0 {
|
||||
return "unknown"
|
||||
}
|
||||
lower := strings.ToLower(token)
|
||||
switch {
|
||||
case strings.HasPrefix(lower, "tech"):
|
||||
case lower == "technical" || lower == "tech":
|
||||
return "technical"
|
||||
case strings.HasPrefix(lower, "cre"):
|
||||
case lower == "creative" || lower == "cre":
|
||||
return "creative"
|
||||
case strings.HasPrefix(lower, "eth"):
|
||||
case lower == "ethical" || lower == "eth":
|
||||
return "ethical"
|
||||
case strings.HasPrefix(lower, "cas"):
|
||||
case lower == "casual" || lower == "cas":
|
||||
return "casual"
|
||||
default:
|
||||
return "unknown"
|
||||
|
|
|
|||
|
|
@ -1,48 +0,0 @@
|
|||
//go:build integration
|
||||
|
||||
package i18n
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
_ "forge.lthn.ai/core/go-mlx" // registers Metal backend
|
||||
)
|
||||
|
||||
func TestClassifyCorpus_Integration(t *testing.T) {
|
||||
model, err := inference.LoadModel("/Volumes/Data/lem/LEM-Gemma3-1B-layered-v2")
|
||||
if err != nil {
|
||||
t.Skipf("model not available: %v", err)
|
||||
}
|
||||
defer model.Close()
|
||||
|
||||
// Build 50 technical prompts for throughput measurement
|
||||
var lines []string
|
||||
for i := 0; i < 50; i++ {
|
||||
lines = append(lines, fmt.Sprintf(`{"id":%d,"prompt":"Delete the configuration file and rebuild the project"}`, i))
|
||||
}
|
||||
input := strings.NewReader(strings.Join(lines, "\n") + "\n")
|
||||
|
||||
var output bytes.Buffer
|
||||
start := time.Now()
|
||||
stats, err := ClassifyCorpus(context.Background(), model, input, &output, WithBatchSize(8))
|
||||
if err != nil {
|
||||
t.Fatalf("ClassifyCorpus: %v", err)
|
||||
}
|
||||
elapsed := time.Since(start)
|
||||
|
||||
t.Logf("Classified %d prompts in %v (%.1f prompts/sec)", stats.Total, elapsed, stats.PromptsPerSec)
|
||||
t.Logf("By domain: %v", stats.ByDomain)
|
||||
|
||||
if stats.Total != 50 {
|
||||
t.Errorf("Total = %d, want 50", stats.Total)
|
||||
}
|
||||
if stats.Skipped != 0 {
|
||||
t.Errorf("Skipped = %d, want 0", stats.Skipped)
|
||||
}
|
||||
}
|
||||
|
|
@ -31,6 +31,11 @@ func TestMapTokenToDomain(t *testing.T) {
|
|||
{"unknown", "unknown"},
|
||||
{"", "unknown"},
|
||||
{"foo", "unknown"},
|
||||
// Verify prefix collision fix: these must NOT match any domain
|
||||
{"castle", "unknown"},
|
||||
{"cascade", "unknown"},
|
||||
{"credential", "unknown"},
|
||||
{"creature", "unknown"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.token, func(t *testing.T) {
|
||||
|
|
|
|||
6
go.mod
6
go.mod
|
|
@ -4,10 +4,6 @@ go 1.25.5
|
|||
|
||||
require golang.org/x/text v0.33.0
|
||||
|
||||
require forge.lthn.ai/core/go-inference v0.0.0
|
||||
|
||||
require forge.lthn.ai/core/go-mlx v0.0.0-20260219234407-d1fb26d51e62
|
||||
require forge.lthn.ai/core/go-inference v0.0.0-00010101000000-000000000000
|
||||
|
||||
replace forge.lthn.ai/core/go-inference => ../go-inference
|
||||
|
||||
replace forge.lthn.ai/core/go-mlx => ../go-mlx
|
||||
|
|
|
|||
74
integration/classify_test.go
Normal file
74
integration/classify_test.go
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
i18n "forge.lthn.ai/core/go-i18n"
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
_ "forge.lthn.ai/core/go-mlx" // registers Metal backend
|
||||
)
|
||||
|
||||
func TestClassifyCorpus_Integration(t *testing.T) {
|
||||
model, err := inference.LoadModel("/Volumes/Data/lem/LEM-Gemma3-1B-layered-v2")
|
||||
if err != nil {
|
||||
t.Skipf("model not available: %v", err)
|
||||
}
|
||||
defer model.Close()
|
||||
|
||||
// Build 50 identical technical prompts, used for both the throughput measurement and the accuracy check below
|
||||
var lines []string
|
||||
for i := 0; i < 50; i++ {
|
||||
lines = append(lines, fmt.Sprintf(`{"id":%d,"prompt":"Delete the configuration file and rebuild the project"}`, i))
|
||||
}
|
||||
input := strings.NewReader(strings.Join(lines, "\n") + "\n")
|
||||
|
||||
var output bytes.Buffer
|
||||
start := time.Now()
|
||||
stats, err := i18n.ClassifyCorpus(context.Background(), model, input, &output, i18n.WithBatchSize(8))
|
||||
if err != nil {
|
||||
t.Fatalf("ClassifyCorpus: %v", err)
|
||||
}
|
||||
elapsed := time.Since(start)
|
||||
|
||||
t.Logf("Classified %d prompts in %v (%.1f prompts/sec)", stats.Total, elapsed, stats.PromptsPerSec)
|
||||
t.Logf("By domain: %v", stats.ByDomain)
|
||||
|
||||
if stats.Total != 50 {
|
||||
t.Errorf("Total = %d, want 50", stats.Total)
|
||||
}
|
||||
if stats.Skipped != 0 {
|
||||
t.Errorf("Skipped = %d, want 0", stats.Skipped)
|
||||
}
|
||||
|
||||
// Fix 5: Assert minimum accuracy — at least 80% of the 50 technical prompts
|
||||
// should be classified as "technical". The FINDINGS data shows 100% accuracy
|
||||
// on controlled technical input, so 80% is a conservative floor.
|
||||
technicalCount := stats.ByDomain["technical"]
|
||||
minRequired := 40 // 80% of 50
|
||||
if technicalCount < minRequired {
|
||||
// Log full domain breakdown for debugging
|
||||
for domain, count := range stats.ByDomain {
|
||||
t.Logf(" domain %q: %d (%.0f%%)", domain, count, float64(count)/float64(stats.Total)*100)
|
||||
}
|
||||
|
||||
// Also inspect the output JSONL for misclassified entries
|
||||
outLines := strings.Split(strings.TrimSpace(output.String()), "\n")
|
||||
for _, line := range outLines {
|
||||
var record map[string]any
|
||||
if err := json.Unmarshal([]byte(line), &record); err == nil {
|
||||
if record["domain_1b"] != "technical" {
|
||||
t.Logf(" misclassified: id=%v domain_1b=%v", record["id"], record["domain_1b"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Errorf("accuracy too low: %d/%d (%.0f%%) classified as technical, want >= %d (80%%)",
|
||||
technicalCount, stats.Total, float64(technicalCount)/float64(stats.Total)*100, minRequired)
|
||||
}
|
||||
}
|
||||
15
integration/go.mod
Normal file
15
integration/go.mod
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
module forge.lthn.ai/core/go-i18n/integration
|
||||
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
forge.lthn.ai/core/go-i18n v0.0.0-00010101000000-000000000000
|
||||
forge.lthn.ai/core/go-inference v0.0.0-00010101000000-000000000000
|
||||
forge.lthn.ai/core/go-mlx v0.0.0-00010101000000-000000000000
|
||||
)
|
||||
|
||||
replace (
|
||||
forge.lthn.ai/core/go-i18n => ../
|
||||
forge.lthn.ai/core/go-inference => ../../go-inference
|
||||
forge.lthn.ai/core/go-mlx => ../../go-mlx
|
||||
)
|
||||
|
|
@ -325,6 +325,9 @@ func TestClassification_CorpusSize(t *testing.T) {
|
|||
// exceeds cross-domain similarity. This is the basic requirement for domain
|
||||
// classification to work.
|
||||
func TestClassification_DomainSeparation(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow domain separation test in short mode")
|
||||
}
|
||||
setup(t)
|
||||
tok := NewTokeniser()
|
||||
imprints := imprintCorpus(tok)
|
||||
|
|
@ -374,6 +377,9 @@ func TestClassification_DomainSeparation(t *testing.T) {
|
|||
// TestClassification_LeaveOneOut measures per-domain and overall accuracy
|
||||
// using leave-one-out nearest-centroid classification.
|
||||
func TestClassification_LeaveOneOut(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow classification benchmark in short mode")
|
||||
}
|
||||
setup(t)
|
||||
tok := NewTokeniser()
|
||||
imprints := imprintCorpus(tok)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue