feat: implement ClassifyCorpus with streaming batch classification
Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
a5f3eb4777
commit
94fd2f463f
2 changed files with 243 additions and 8 deletions
102
classify.go
102
classify.go
|
|
@ -1,7 +1,10 @@
|
|||
package i18n
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
|
@ -70,10 +73,95 @@ func mapTokenToDomain(token string) string {
|
|||
}
|
||||
}
|
||||
|
||||
// NOTE(review): this block existed only to pin the imports before
// ClassifyCorpus landed ("Task 3"); now that ClassifyCorpus uses context,
// io, time, and inference directly, these blank assignments are dead
// weight — safe to delete in a follow-up.
var (
	_ = (*inference.Token)(nil)
	_ context.Context
	_ io.Reader
	_ time.Duration
)
|
||||
// ClassifyCorpus reads JSONL from input, batch-classifies each entry through
|
||||
// model, and writes JSONL with domain_1b field added to output.
|
||||
func ClassifyCorpus(ctx context.Context, model inference.TextModel,
|
||||
input io.Reader, output io.Writer, opts ...ClassifyOption) (*ClassifyStats, error) {
|
||||
|
||||
cfg := defaultClassifyConfig()
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
|
||||
stats := &ClassifyStats{ByDomain: make(map[string]int)}
|
||||
start := time.Now()
|
||||
|
||||
scanner := bufio.NewScanner(input)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
||||
|
||||
type pending struct {
|
||||
record map[string]any
|
||||
prompt string
|
||||
}
|
||||
|
||||
var batch []pending
|
||||
|
||||
flush := func() error {
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
prompts := make([]string, len(batch))
|
||||
for i, p := range batch {
|
||||
prompts[i] = fmt.Sprintf(cfg.promptTemplate, p.prompt)
|
||||
}
|
||||
results, err := model.Classify(ctx, prompts, inference.WithMaxTokens(1))
|
||||
if err != nil {
|
||||
return fmt.Errorf("classify batch: %w", err)
|
||||
}
|
||||
for i, r := range results {
|
||||
domain := mapTokenToDomain(r.Token.Text)
|
||||
batch[i].record["domain_1b"] = domain
|
||||
stats.ByDomain[domain]++
|
||||
stats.Total++
|
||||
|
||||
line, err := json.Marshal(batch[i].record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal output: %w", err)
|
||||
}
|
||||
if _, err := fmt.Fprintf(output, "%s\n", line); err != nil {
|
||||
return fmt.Errorf("write output: %w", err)
|
||||
}
|
||||
}
|
||||
batch = batch[:0]
|
||||
return nil
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
var record map[string]any
|
||||
if err := json.Unmarshal(scanner.Bytes(), &record); err != nil {
|
||||
stats.Skipped++
|
||||
continue
|
||||
}
|
||||
promptVal, ok := record[cfg.promptField]
|
||||
if !ok {
|
||||
stats.Skipped++
|
||||
continue
|
||||
}
|
||||
prompt, ok := promptVal.(string)
|
||||
if !ok || prompt == "" {
|
||||
stats.Skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
batch = append(batch, pending{record: record, prompt: prompt})
|
||||
if len(batch) >= cfg.batchSize {
|
||||
if err := flush(); err != nil {
|
||||
return stats, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return stats, fmt.Errorf("read input: %w", err)
|
||||
}
|
||||
if err := flush(); err != nil {
|
||||
return stats, err
|
||||
}
|
||||
|
||||
stats.Duration = time.Since(start)
|
||||
if stats.Duration > 0 {
|
||||
stats.PromptsPerSec = float64(stats.Total) / stats.Duration.Seconds()
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
|
|
|||
149
classify_test.go
149
classify_test.go
|
|
@ -1,6 +1,15 @@
|
|||
package i18n
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"iter"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
)
|
||||
|
||||
func TestMapTokenToDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
|
@ -32,3 +41,141 @@ func TestMapTokenToDomain(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockModel satisfies inference.TextModel for testing.
|
||||
type mockModel struct {
|
||||
classifyFunc func(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error)
|
||||
}
|
||||
|
||||
func (m *mockModel) Generate(_ context.Context, _ string, _ ...inference.GenerateOption) iter.Seq[inference.Token] {
|
||||
return func(yield func(inference.Token) bool) {}
|
||||
}
|
||||
|
||||
func (m *mockModel) Chat(_ context.Context, _ []inference.Message, _ ...inference.GenerateOption) iter.Seq[inference.Token] {
|
||||
return func(yield func(inference.Token) bool) {}
|
||||
}
|
||||
|
||||
func (m *mockModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
||||
return m.classifyFunc(ctx, prompts, opts...)
|
||||
}
|
||||
|
||||
func (m *mockModel) BatchGenerate(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.BatchResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockModel) ModelType() string { return "mock" }
|
||||
func (m *mockModel) Info() inference.ModelInfo { return inference.ModelInfo{} }
|
||||
func (m *mockModel) Metrics() inference.GenerateMetrics { return inference.GenerateMetrics{} }
|
||||
func (m *mockModel) Err() error { return nil }
|
||||
func (m *mockModel) Close() error { return nil }
|
||||
|
||||
func TestClassifyCorpus_Basic(t *testing.T) {
|
||||
model := &mockModel{
|
||||
classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
||||
results := make([]inference.ClassifyResult, len(prompts))
|
||||
for i := range prompts {
|
||||
results[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
|
||||
}
|
||||
return results, nil
|
||||
},
|
||||
}
|
||||
|
||||
input := strings.NewReader(
|
||||
`{"seed_id":"1","domain":"general","prompt":"Delete the file"}` + "\n" +
|
||||
`{"seed_id":"2","domain":"science","prompt":"Explain gravity"}` + "\n",
|
||||
)
|
||||
var output bytes.Buffer
|
||||
|
||||
stats, err := ClassifyCorpus(context.Background(), model, input, &output, WithBatchSize(16))
|
||||
if err != nil {
|
||||
t.Fatalf("ClassifyCorpus returned error: %v", err)
|
||||
}
|
||||
if stats.Total != 2 {
|
||||
t.Errorf("Total = %d, want 2", stats.Total)
|
||||
}
|
||||
if stats.Skipped != 0 {
|
||||
t.Errorf("Skipped = %d, want 0", stats.Skipped)
|
||||
}
|
||||
|
||||
lines := strings.Split(strings.TrimSpace(output.String()), "\n")
|
||||
if len(lines) != 2 {
|
||||
t.Fatalf("output lines = %d, want 2", len(lines))
|
||||
}
|
||||
|
||||
for i, line := range lines {
|
||||
var record map[string]any
|
||||
if err := json.Unmarshal([]byte(line), &record); err != nil {
|
||||
t.Fatalf("line %d: unmarshal: %v", i, err)
|
||||
}
|
||||
if record["domain_1b"] != "technical" {
|
||||
t.Errorf("line %d: domain_1b = %v, want %q", i, record["domain_1b"], "technical")
|
||||
}
|
||||
// original domain field must be preserved
|
||||
if _, ok := record["domain"]; !ok {
|
||||
t.Errorf("line %d: original domain field missing", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyCorpus_SkipsMalformed(t *testing.T) {
|
||||
model := &mockModel{
|
||||
classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
||||
results := make([]inference.ClassifyResult, len(prompts))
|
||||
for i := range prompts {
|
||||
results[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
|
||||
}
|
||||
return results, nil
|
||||
},
|
||||
}
|
||||
|
||||
input := strings.NewReader(
|
||||
"not valid json\n" +
|
||||
`{"seed_id":"1","domain":"general","prompt":"Hello world"}` + "\n" +
|
||||
`{"seed_id":"2","domain":"general"}` + "\n",
|
||||
)
|
||||
var output bytes.Buffer
|
||||
|
||||
stats, err := ClassifyCorpus(context.Background(), model, input, &output)
|
||||
if err != nil {
|
||||
t.Fatalf("ClassifyCorpus returned error: %v", err)
|
||||
}
|
||||
if stats.Total != 1 {
|
||||
t.Errorf("Total = %d, want 1", stats.Total)
|
||||
}
|
||||
if stats.Skipped != 2 {
|
||||
t.Errorf("Skipped = %d, want 2", stats.Skipped)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyCorpus_DomainMapping(t *testing.T) {
|
||||
model := &mockModel{
|
||||
classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
||||
results := make([]inference.ClassifyResult, len(prompts))
|
||||
for i, p := range prompts {
|
||||
if strings.Contains(p, "Delete") {
|
||||
results[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
|
||||
} else {
|
||||
results[i] = inference.ClassifyResult{Token: inference.Token{Text: "ethical"}}
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
},
|
||||
}
|
||||
|
||||
input := strings.NewReader(
|
||||
`{"prompt":"Delete the file now"}` + "\n" +
|
||||
`{"prompt":"Is it right to lie?"}` + "\n",
|
||||
)
|
||||
var output bytes.Buffer
|
||||
|
||||
stats, err := ClassifyCorpus(context.Background(), model, input, &output, WithBatchSize(16))
|
||||
if err != nil {
|
||||
t.Fatalf("ClassifyCorpus returned error: %v", err)
|
||||
}
|
||||
if stats.ByDomain["technical"] != 1 {
|
||||
t.Errorf("ByDomain[technical] = %d, want 1", stats.ByDomain["technical"])
|
||||
}
|
||||
if stats.ByDomain["ethical"] != 1 {
|
||||
t.Errorf("ByDomain[ethical] = %d, want 1", stats.ByDomain["ethical"])
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue