fix(classify): fail on batch result mismatch
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
bd5e299575
commit
cc1dd6b898
2 changed files with 39 additions and 4 deletions
|
|
@ -111,6 +111,13 @@ func ClassifyCorpus(ctx context.Context, model inference.TextModel,
|
|||
if err != nil {
|
||||
return log.E("ClassifyCorpus", "classify batch", err)
|
||||
}
|
||||
if len(results) != len(batch) {
|
||||
return log.E(
|
||||
"ClassifyCorpus",
|
||||
core.Sprintf("classify batch returned %d results for %d prompts", len(results), len(batch)),
|
||||
nil,
|
||||
)
|
||||
}
|
||||
for i, r := range results {
|
||||
domain := mapTokenToDomain(r.Token.Text)
|
||||
batch[i].record["domain_1b"] = domain
|
||||
|
|
|
|||
|
|
@ -67,11 +67,11 @@ func (m *mockModel) BatchGenerate(_ context.Context, _ []string, _ ...inference.
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockModel) ModelType() string { return "mock" }
|
||||
func (m *mockModel) Info() inference.ModelInfo { return inference.ModelInfo{} }
|
||||
func (m *mockModel) ModelType() string { return "mock" }
|
||||
func (m *mockModel) Info() inference.ModelInfo { return inference.ModelInfo{} }
|
||||
func (m *mockModel) Metrics() inference.GenerateMetrics { return inference.GenerateMetrics{} }
|
||||
func (m *mockModel) Err() error { return nil }
|
||||
func (m *mockModel) Close() error { return nil }
|
||||
func (m *mockModel) Err() error { return nil }
|
||||
func (m *mockModel) Close() error { return nil }
|
||||
|
||||
func TestClassifyCorpus_Basic(t *testing.T) {
|
||||
model := &mockModel{
|
||||
|
|
@ -183,3 +183,31 @@ func TestClassifyCorpus_DomainMapping(t *testing.T) {
|
|||
t.Errorf("ByDomain[ethical] = %d, want 1", stats.ByDomain["ethical"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyCorpus_ResultCountMismatch(t *testing.T) {
|
||||
model := &mockModel{
|
||||
classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
||||
if len(prompts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return []inference.ClassifyResult{{Token: inference.Token{Text: "technical"}}}, nil
|
||||
},
|
||||
}
|
||||
|
||||
input := core.NewReader(
|
||||
`{"prompt":"Delete the file now"}` + "\n" +
|
||||
`{"prompt":"Create the repo"}` + "\n",
|
||||
)
|
||||
|
||||
var output bytes.Buffer
|
||||
stats, err := ClassifyCorpus(context.Background(), model, input, &output, WithBatchSize(16))
|
||||
if err == nil {
|
||||
t.Fatal("ClassifyCorpus returned nil error, want mismatch failure")
|
||||
}
|
||||
if stats.Total != 0 {
|
||||
t.Errorf("Total = %d, want 0", stats.Total)
|
||||
}
|
||||
if output.Len() != 0 {
|
||||
t.Errorf("output len = %d, want 0", output.Len())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue