feat: add classify types and token-to-domain mapper
Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
c05e3fc283
commit
a5f3eb4777
3 changed files with 115 additions and 0 deletions
79
classify.go
Normal file
79
classify.go
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
package i18n
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
)
|
||||
|
||||
// ClassifyStats reports metrics from a ClassifyCorpus run.
|
||||
type ClassifyStats struct {
|
||||
Total int
|
||||
Skipped int // malformed or missing prompt field
|
||||
ByDomain map[string]int // domain_1b label -> count
|
||||
Duration time.Duration
|
||||
PromptsPerSec float64
|
||||
}
|
||||
|
||||
// ClassifyOption configures ClassifyCorpus behaviour.
|
||||
type ClassifyOption func(*classifyConfig)
|
||||
|
||||
type classifyConfig struct {
|
||||
batchSize int
|
||||
promptField string
|
||||
promptTemplate string
|
||||
}
|
||||
|
||||
func defaultClassifyConfig() classifyConfig {
|
||||
return classifyConfig{
|
||||
batchSize: 8,
|
||||
promptField: "prompt",
|
||||
promptTemplate: "Classify this text into exactly one category: technical, creative, ethical, casual.\n\nText: %s\n\nCategory:",
|
||||
}
|
||||
}
|
||||
|
||||
// WithBatchSize sets the number of prompts per Classify call. Default 8.
|
||||
func WithBatchSize(n int) ClassifyOption {
|
||||
return func(c *classifyConfig) { c.batchSize = n }
|
||||
}
|
||||
|
||||
// WithPromptField sets which JSONL field contains the text to classify. Default "prompt".
|
||||
func WithPromptField(field string) ClassifyOption {
|
||||
return func(c *classifyConfig) { c.promptField = field }
|
||||
}
|
||||
|
||||
// WithPromptTemplate sets the classification prompt. Use %s for the text placeholder.
|
||||
func WithPromptTemplate(tmpl string) ClassifyOption {
|
||||
return func(c *classifyConfig) { c.promptTemplate = tmpl }
|
||||
}
|
||||
|
||||
// mapTokenToDomain maps a model output token to a 4-way domain label.
|
||||
func mapTokenToDomain(token string) string {
|
||||
if len(token) == 0 {
|
||||
return "unknown"
|
||||
}
|
||||
lower := strings.ToLower(token)
|
||||
switch {
|
||||
case strings.HasPrefix(lower, "tech"):
|
||||
return "technical"
|
||||
case strings.HasPrefix(lower, "cre"):
|
||||
return "creative"
|
||||
case strings.HasPrefix(lower, "eth"):
|
||||
return "ethical"
|
||||
case strings.HasPrefix(lower, "cas"):
|
||||
return "casual"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure imports are used (ClassifyCorpus will be added in Task 3).
|
||||
var (
|
||||
_ = (*inference.Token)(nil)
|
||||
_ context.Context
|
||||
_ io.Reader
|
||||
_ time.Duration
|
||||
)
|
||||
34
classify_test.go
Normal file
34
classify_test.go
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
package i18n
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestMapTokenToDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
token string
|
||||
want string
|
||||
}{
|
||||
{"technical", "technical"},
|
||||
{"Technical", "technical"},
|
||||
{"tech", "technical"},
|
||||
{"creative", "creative"},
|
||||
{"Creative", "creative"},
|
||||
{"cre", "creative"},
|
||||
{"ethical", "ethical"},
|
||||
{"Ethical", "ethical"},
|
||||
{"eth", "ethical"},
|
||||
{"casual", "casual"},
|
||||
{"Casual", "casual"},
|
||||
{"cas", "casual"},
|
||||
{"unknown", "unknown"},
|
||||
{"", "unknown"},
|
||||
{"foo", "unknown"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.token, func(t *testing.T) {
|
||||
got := mapTokenToDomain(tt.token)
|
||||
if got != tt.want {
|
||||
t.Errorf("mapTokenToDomain(%q) = %q, want %q", tt.token, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
2
go.mod
2
go.mod
|
|
@ -4,4 +4,6 @@ go 1.25.5
|
|||
|
||||
require golang.org/x/text v0.33.0
|
||||
|
||||
require forge.lthn.ai/core/go-inference v0.0.0-20260219234405-c61ec9f5c724 // indirect
|
||||
|
||||
replace forge.lthn.ai/core/go-inference => ../go-inference
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue