fix(classify): fail on batch result mismatch
Some checks are pending
Security Scan / security (push) Waiting to run
Test / test (push) Waiting to run

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-03 07:31:14 +00:00
parent bd5e299575
commit cc1dd6b898
2 changed files with 39 additions and 4 deletions

View file

@ -111,6 +111,13 @@ func ClassifyCorpus(ctx context.Context, model inference.TextModel,
if err != nil { if err != nil {
return log.E("ClassifyCorpus", "classify batch", err) return log.E("ClassifyCorpus", "classify batch", err)
} }
if len(results) != len(batch) {
return log.E(
"ClassifyCorpus",
core.Sprintf("classify batch returned %d results for %d prompts", len(results), len(batch)),
nil,
)
}
for i, r := range results { for i, r := range results {
domain := mapTokenToDomain(r.Token.Text) domain := mapTokenToDomain(r.Token.Text)
batch[i].record["domain_1b"] = domain batch[i].record["domain_1b"] = domain

View file

@ -67,11 +67,11 @@ func (m *mockModel) BatchGenerate(_ context.Context, _ []string, _ ...inference.
return nil, nil return nil, nil
} }
func (m *mockModel) ModelType() string { return "mock" } func (m *mockModel) ModelType() string { return "mock" }
func (m *mockModel) Info() inference.ModelInfo { return inference.ModelInfo{} } func (m *mockModel) Info() inference.ModelInfo { return inference.ModelInfo{} }
func (m *mockModel) Metrics() inference.GenerateMetrics { return inference.GenerateMetrics{} } func (m *mockModel) Metrics() inference.GenerateMetrics { return inference.GenerateMetrics{} }
func (m *mockModel) Err() error { return nil } func (m *mockModel) Err() error { return nil }
func (m *mockModel) Close() error { return nil } func (m *mockModel) Close() error { return nil }
func TestClassifyCorpus_Basic(t *testing.T) { func TestClassifyCorpus_Basic(t *testing.T) {
model := &mockModel{ model := &mockModel{
@ -183,3 +183,31 @@ func TestClassifyCorpus_DomainMapping(t *testing.T) {
t.Errorf("ByDomain[ethical] = %d, want 1", stats.ByDomain["ethical"]) t.Errorf("ByDomain[ethical] = %d, want 1", stats.ByDomain["ethical"])
} }
} }
func TestClassifyCorpus_ResultCountMismatch(t *testing.T) {
model := &mockModel{
classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
if len(prompts) == 0 {
return nil, nil
}
return []inference.ClassifyResult{{Token: inference.Token{Text: "technical"}}}, nil
},
}
input := core.NewReader(
`{"prompt":"Delete the file now"}` + "\n" +
`{"prompt":"Create the repo"}` + "\n",
)
var output bytes.Buffer
stats, err := ClassifyCorpus(context.Background(), model, input, &output, WithBatchSize(16))
if err == nil {
t.Fatal("ClassifyCorpus returned nil error, want mismatch failure")
}
if stats.Total != 0 {
t.Errorf("Total = %d, want 0", stats.Total)
}
if output.Len() != 0 {
t.Errorf("output len = %d, want 0", output.Len())
}
}