fix(classify): fail on batch result mismatch
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
bd5e299575
commit
cc1dd6b898
2 changed files with 39 additions and 4 deletions
|
|
@ -111,6 +111,13 @@ func ClassifyCorpus(ctx context.Context, model inference.TextModel,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return log.E("ClassifyCorpus", "classify batch", err)
|
return log.E("ClassifyCorpus", "classify batch", err)
|
||||||
}
|
}
|
||||||
|
if len(results) != len(batch) {
|
||||||
|
return log.E(
|
||||||
|
"ClassifyCorpus",
|
||||||
|
core.Sprintf("classify batch returned %d results for %d prompts", len(results), len(batch)),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
for i, r := range results {
|
for i, r := range results {
|
||||||
domain := mapTokenToDomain(r.Token.Text)
|
domain := mapTokenToDomain(r.Token.Text)
|
||||||
batch[i].record["domain_1b"] = domain
|
batch[i].record["domain_1b"] = domain
|
||||||
|
|
|
||||||
|
|
@ -67,11 +67,11 @@ func (m *mockModel) BatchGenerate(_ context.Context, _ []string, _ ...inference.
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockModel) ModelType() string { return "mock" }
|
func (m *mockModel) ModelType() string { return "mock" }
|
||||||
func (m *mockModel) Info() inference.ModelInfo { return inference.ModelInfo{} }
|
func (m *mockModel) Info() inference.ModelInfo { return inference.ModelInfo{} }
|
||||||
func (m *mockModel) Metrics() inference.GenerateMetrics { return inference.GenerateMetrics{} }
|
func (m *mockModel) Metrics() inference.GenerateMetrics { return inference.GenerateMetrics{} }
|
||||||
func (m *mockModel) Err() error { return nil }
|
func (m *mockModel) Err() error { return nil }
|
||||||
func (m *mockModel) Close() error { return nil }
|
func (m *mockModel) Close() error { return nil }
|
||||||
|
|
||||||
func TestClassifyCorpus_Basic(t *testing.T) {
|
func TestClassifyCorpus_Basic(t *testing.T) {
|
||||||
model := &mockModel{
|
model := &mockModel{
|
||||||
|
|
@ -183,3 +183,31 @@ func TestClassifyCorpus_DomainMapping(t *testing.T) {
|
||||||
t.Errorf("ByDomain[ethical] = %d, want 1", stats.ByDomain["ethical"])
|
t.Errorf("ByDomain[ethical] = %d, want 1", stats.ByDomain["ethical"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClassifyCorpus_ResultCountMismatch(t *testing.T) {
|
||||||
|
model := &mockModel{
|
||||||
|
classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
||||||
|
if len(prompts) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return []inference.ClassifyResult{{Token: inference.Token{Text: "technical"}}}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
input := core.NewReader(
|
||||||
|
`{"prompt":"Delete the file now"}` + "\n" +
|
||||||
|
`{"prompt":"Create the repo"}` + "\n",
|
||||||
|
)
|
||||||
|
|
||||||
|
var output bytes.Buffer
|
||||||
|
stats, err := ClassifyCorpus(context.Background(), model, input, &output, WithBatchSize(16))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("ClassifyCorpus returned nil error, want mismatch failure")
|
||||||
|
}
|
||||||
|
if stats.Total != 0 {
|
||||||
|
t.Errorf("Total = %d, want 0", stats.Total)
|
||||||
|
}
|
||||||
|
if output.Len() != 0 {
|
||||||
|
t.Errorf("output len = %d, want 0", output.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue