diff --git a/classify.go b/classify.go
index 1db1be6..681a169 100644
--- a/classify.go
+++ b/classify.go
@@ -1,7 +1,10 @@
 package i18n
 
 import (
+	"bufio"
 	"context"
+	"encoding/json"
+	"fmt"
 	"io"
 	"strings"
 	"time"
@@ -70,10 +73,98 @@ func mapTokenToDomain(token string) string {
 	}
 }
 
-// Ensure imports are used (ClassifyCorpus will be added in Task 3).
-var (
-	_ = (*inference.Token)(nil)
-	_ context.Context
-	_ io.Reader
-	_ time.Duration
-)
+// ClassifyCorpus reads JSONL from input, batch-classifies each entry through
+// model, and writes JSONL with domain_1b field added to output.
+func ClassifyCorpus(ctx context.Context, model inference.TextModel,
+	input io.Reader, output io.Writer, opts ...ClassifyOption) (*ClassifyStats, error) {
+
+	cfg := defaultClassifyConfig()
+	for _, o := range opts {
+		o(&cfg)
+	}
+
+	stats := &ClassifyStats{ByDomain: make(map[string]int)}
+	start := time.Now()
+
+	scanner := bufio.NewScanner(input)
+	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
+
+	type pending struct {
+		record map[string]any
+		prompt string
+	}
+
+	var batch []pending
+
+	flush := func() error {
+		if len(batch) == 0 {
+			return nil
+		}
+		prompts := make([]string, len(batch))
+		for i, p := range batch {
+			prompts[i] = fmt.Sprintf(cfg.promptTemplate, p.prompt)
+		}
+		results, err := model.Classify(ctx, prompts, inference.WithMaxTokens(1))
+		if err != nil {
+			return fmt.Errorf("classify batch: %w", err)
+		}
+		if len(results) != len(batch) {
+			return fmt.Errorf("classify batch: got %d results for %d prompts", len(results), len(batch))
+		}
+		for i, r := range results {
+			domain := mapTokenToDomain(r.Token.Text)
+			batch[i].record["domain_1b"] = domain
+			stats.ByDomain[domain]++
+			stats.Total++
+
+			line, err := json.Marshal(batch[i].record)
+			if err != nil {
+				return fmt.Errorf("marshal output: %w", err)
+			}
+			if _, err := fmt.Fprintf(output, "%s\n", line); err != nil {
+				return fmt.Errorf("write output: %w", err)
+			}
+		}
+		batch = batch[:0]
+		return nil
+	}
+
+	for scanner.Scan() {
+		var record map[string]any
+		if err := json.Unmarshal(scanner.Bytes(), &record); err != nil {
+			stats.Skipped++
+			continue
+		}
+		promptVal, ok := record[cfg.promptField]
+		if !ok {
+			stats.Skipped++
+			continue
+		}
+		prompt, ok := promptVal.(string)
+		if !ok || prompt == "" {
+			stats.Skipped++
+			continue
+		}
+
+		batch = append(batch, pending{record: record, prompt: prompt})
+		if len(batch) >= cfg.batchSize {
+			if err := flush(); err != nil {
+				return stats, err
+			}
+		}
+	}
+
+	if err := scanner.Err(); err != nil {
+		return stats, fmt.Errorf("read input: %w", err)
+	}
+	if err := flush(); err != nil {
+		return stats, err
+	}
+
+	stats.Duration = time.Since(start)
+	if stats.Duration > 0 {
+		stats.PromptsPerSec = float64(stats.Total) / stats.Duration.Seconds()
+	}
+
+	return stats, nil
+}
diff --git a/classify_test.go b/classify_test.go
index 315442c..d197bdd 100644
--- a/classify_test.go
+++ b/classify_test.go
@@ -1,6 +1,15 @@
 package i18n
 
-import "testing"
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"iter"
+	"strings"
+	"testing"
+
+	"forge.lthn.ai/core/go-inference"
+)
 
 func TestMapTokenToDomain(t *testing.T) {
 	tests := []struct {
@@ -32,3 +41,141 @@ func TestMapTokenToDomain(t *testing.T) {
 		})
 	}
 }
+
+// mockModel satisfies inference.TextModel for testing.
+type mockModel struct {
+	classifyFunc func(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error)
+}
+
+func (m *mockModel) Generate(_ context.Context, _ string, _ ...inference.GenerateOption) iter.Seq[inference.Token] {
+	return func(yield func(inference.Token) bool) {}
+}
+
+func (m *mockModel) Chat(_ context.Context, _ []inference.Message, _ ...inference.GenerateOption) iter.Seq[inference.Token] {
+	return func(yield func(inference.Token) bool) {}
+}
+
+func (m *mockModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
+	return m.classifyFunc(ctx, prompts, opts...)
+}
+
+func (m *mockModel) BatchGenerate(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.BatchResult, error) {
+	return nil, nil
+}
+
+func (m *mockModel) ModelType() string { return "mock" }
+func (m *mockModel) Info() inference.ModelInfo { return inference.ModelInfo{} }
+func (m *mockModel) Metrics() inference.GenerateMetrics { return inference.GenerateMetrics{} }
+func (m *mockModel) Err() error { return nil }
+func (m *mockModel) Close() error { return nil }
+
+func TestClassifyCorpus_Basic(t *testing.T) {
+	model := &mockModel{
+		classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
+			results := make([]inference.ClassifyResult, len(prompts))
+			for i := range prompts {
+				results[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
+			}
+			return results, nil
+		},
+	}
+
+	input := strings.NewReader(
+		`{"seed_id":"1","domain":"general","prompt":"Delete the file"}` + "\n" +
+			`{"seed_id":"2","domain":"science","prompt":"Explain gravity"}` + "\n",
+	)
+	var output bytes.Buffer
+
+	stats, err := ClassifyCorpus(context.Background(), model, input, &output, WithBatchSize(16))
+	if err != nil {
+		t.Fatalf("ClassifyCorpus returned error: %v", err)
+	}
+	if stats.Total != 2 {
+		t.Errorf("Total = %d, want 2", stats.Total)
+	}
+	if stats.Skipped != 0 {
+		t.Errorf("Skipped = %d, want 0", stats.Skipped)
+	}
+
+	lines := strings.Split(strings.TrimSpace(output.String()), "\n")
+	if len(lines) != 2 {
+		t.Fatalf("output lines = %d, want 2", len(lines))
+	}
+
+	for i, line := range lines {
+		var record map[string]any
+		if err := json.Unmarshal([]byte(line), &record); err != nil {
+			t.Fatalf("line %d: unmarshal: %v", i, err)
+		}
+		if record["domain_1b"] != "technical" {
+			t.Errorf("line %d: domain_1b = %v, want %q", i, record["domain_1b"], "technical")
+		}
+		// original domain field must be preserved
+		if _, ok := record["domain"]; !ok {
+			t.Errorf("line %d: original domain field missing", i)
+		}
+	}
+}
+
+func TestClassifyCorpus_SkipsMalformed(t *testing.T) {
+	model := &mockModel{
+		classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
+			results := make([]inference.ClassifyResult, len(prompts))
+			for i := range prompts {
+				results[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
+			}
+			return results, nil
+		},
+	}
+
+	input := strings.NewReader(
+		"not valid json\n" +
+			`{"seed_id":"1","domain":"general","prompt":"Hello world"}` + "\n" +
+			`{"seed_id":"2","domain":"general"}` + "\n",
+	)
+	var output bytes.Buffer
+
+	stats, err := ClassifyCorpus(context.Background(), model, input, &output)
+	if err != nil {
+		t.Fatalf("ClassifyCorpus returned error: %v", err)
+	}
+	if stats.Total != 1 {
+		t.Errorf("Total = %d, want 1", stats.Total)
+	}
+	if stats.Skipped != 2 {
+		t.Errorf("Skipped = %d, want 2", stats.Skipped)
+	}
+}
+
+func TestClassifyCorpus_DomainMapping(t *testing.T) {
+	model := &mockModel{
+		classifyFunc: func(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
+			results := make([]inference.ClassifyResult, len(prompts))
+			for i, p := range prompts {
+				if strings.Contains(p, "Delete") {
+					results[i] = inference.ClassifyResult{Token: inference.Token{Text: "technical"}}
+				} else {
+					results[i] = inference.ClassifyResult{Token: inference.Token{Text: "ethical"}}
+				}
+			}
+			return results, nil
+		},
+	}
+
+	input := strings.NewReader(
+		`{"prompt":"Delete the file now"}` + "\n" +
+			`{"prompt":"Is it right to lie?"}` + "\n",
+	)
+	var output bytes.Buffer
+
+	stats, err := ClassifyCorpus(context.Background(), model, input, &output, WithBatchSize(16))
+	if err != nil {
+		t.Fatalf("ClassifyCorpus returned error: %v", err)
+	}
+	if stats.ByDomain["technical"] != 1 {
+		t.Errorf("ByDomain[technical] = %d, want 1", stats.ByDomain["technical"])
+	}
+	if stats.ByDomain["ethical"] != 1 {
+		t.Errorf("ByDomain[ethical] = %d, want 1", stats.ByDomain["ethical"])
+	}
+}