// Changelog (from the originating commit):
//   - Stream parquet export rows instead of unbounded memory allocation
//   - Replace QueryGoldenSet/QueryExpansionPrompts with iter.Seq2 iterators
//   - Remove legacy runtime.GC() calls from distill (go-mlx handles cleanup)
//   - Replace log.Fatalf with error return in tier_score.go
//   - Add SIGINT/SIGTERM signal handling to agent and worker daemon loops
//   - Add error checks for unchecked db.conn.Exec in import.go and tier_score.go
//   - Update tests for iterator-based database methods
//
// Co-Authored-By: Gemini <noreply@google.com>
// Co-Authored-By: Virgil <virgil@lethean.io>

package lem

import (
	"bufio"
	"encoding/json"
	"fmt"
	"log"
	"math/rand"
	"os"
	"strings"
)

// ConvOpts holds configuration for the conv command, which assembles
// multi-turn conversational training data from built-in seeds, optional
// extra files, and golden-set responses.
type ConvOpts struct {
	OutputDir string // Output directory for training files (required)
	Extra     string // Additional conversations JSONL file (multi-turn format)
	Golden    string // Golden set JSONL to convert to single-turn conversations
	DB        string // DuckDB database path for golden set (alternative to --golden; defaults to LEM_DB env)
	TrainPct  int    // Training set percentage
	ValidPct  int    // Validation set percentage
	TestPct   int    // Test set percentage
	Seed      int64  // Random seed for deterministic shuffling
	MinChars  int    // Minimum response chars for golden set conversion
	NoBuiltin bool   // Exclude built-in seed conversations
	Influx    string // InfluxDB URL for progress reporting (best-effort)
	InfluxDB  string // InfluxDB database name
	Worker    string // Worker hostname for InfluxDB reporting (defaults to os.Hostname)
}
// RunConv generates multi-turn conversational training data from built-in
// seed conversations plus optional extra files and golden set data.
//
// Sources are merged in order: built-in seeds (unless cfg.NoBuiltin), the
// --extra JSONL file, and golden-set rows from either DuckDB (--db, with
// LEM_DB as fallback) or a JSONL file (--golden, which takes precedence
// over the database). The combined set is shuffled deterministically and
// split into train/valid/test JSONL files under cfg.OutputDir; summary
// statistics are printed, and a metric line is written to InfluxDB on a
// best-effort basis.
func RunConv(cfg ConvOpts) error {
	if cfg.OutputDir == "" {
		return fmt.Errorf("--output-dir is required")
	}

	if err := validatePercentages(cfg.TrainPct, cfg.ValidPct, cfg.TestPct); err != nil {
		return fmt.Errorf("percentages: %w", err)
	}

	// Check LEM_DB env as default for --db.
	if cfg.DB == "" {
		cfg.DB = os.Getenv("LEM_DB")
	}

	// Default worker to hostname; a lookup failure degrades to "unknown"
	// rather than aborting the run.
	if cfg.Worker == "" {
		hostname, err := os.Hostname()
		if err != nil {
			hostname = "unknown"
		}
		cfg.Worker = hostname
	}

	// Collect all conversations.
	var conversations []TrainingExample

	// 1. Built-in seed conversations.
	if !cfg.NoBuiltin {
		conversations = append(conversations, SeedConversations...)
		log.Printf("loaded %d built-in seed conversations", len(SeedConversations))
	}

	// 2. Extra conversations from file.
	if cfg.Extra != "" {
		extras, err := readConversations(cfg.Extra)
		if err != nil {
			return fmt.Errorf("read extra conversations: %w", err)
		}
		conversations = append(conversations, extras...)
		log.Printf("loaded %d extra conversations from %s", len(extras), cfg.Extra)
	}

	// 3. Golden set responses converted to single-turn format.
	// An explicit --golden file takes precedence over the database path.
	var goldenResponses []Response
	if cfg.DB != "" && cfg.Golden == "" {
		db, err := OpenDB(cfg.DB)
		if err != nil {
			return fmt.Errorf("open db: %w", err)
		}
		defer db.Close()

		// GoldenSet streams rows via an iterator; any row-level error
		// aborts the whole run.
		for r, err := range db.GoldenSet(cfg.MinChars) {
			if err != nil {
				return fmt.Errorf("golden set error: %w", err)
			}
			goldenResponses = append(goldenResponses, Response{
				ID:       r.SeedID,
				Domain:   r.Domain,
				Prompt:   r.Prompt,
				Response: r.Response,
				Model:    r.Voice,
			})
		}
		log.Printf("loaded %d golden set rows from %s", len(goldenResponses), cfg.DB)
	} else if cfg.Golden != "" {
		var err error
		goldenResponses, err = ReadResponses(cfg.Golden)
		if err != nil {
			return fmt.Errorf("read golden set: %w", err)
		}
		log.Printf("loaded %d golden set responses from %s", len(goldenResponses), cfg.Golden)
	}

	if len(goldenResponses) > 0 {
		converted := convertToConversations(goldenResponses, cfg.MinChars)
		conversations = append(conversations, converted...)
		log.Printf("converted %d golden set responses to single-turn conversations", len(converted))
	}

	if len(conversations) == 0 {
		return fmt.Errorf("no conversations to process — use built-in seeds, --extra, --golden, or --db")
	}

	// Split into train/valid/test.
	train, valid, test := splitConversations(conversations, cfg.TrainPct, cfg.ValidPct, cfg.TestPct, cfg.Seed)

	// Create output directory.
	if err := os.MkdirAll(cfg.OutputDir, 0755); err != nil {
		return fmt.Errorf("create output dir: %w", err)
	}

	// Write output files, one JSONL per split.
	for _, split := range []struct {
		name string
		data []TrainingExample
	}{
		{"train.jsonl", train},
		{"valid.jsonl", valid},
		{"test.jsonl", test},
	} {
		path := cfg.OutputDir + "/" + split.name
		if err := writeConversationJSONL(path, split.data); err != nil {
			return fmt.Errorf("write %s: %w", split.name, err)
		}
	}

	// Stats: total turns plus average assistant-response word count.
	totalTurns := 0
	totalAssistantWords := 0
	assistantMsgCount := 0
	for _, c := range conversations {
		totalTurns += len(c.Messages)
		for _, m := range c.Messages {
			if m.Role == "assistant" {
				totalAssistantWords += len(strings.Fields(m.Content))
				assistantMsgCount++
			}
		}
	}
	// Safe division: len(conversations) > 0 was checked above.
	avgTurns := float64(totalTurns) / float64(len(conversations))
	avgWords := 0.0
	if assistantMsgCount > 0 {
		avgWords = float64(totalAssistantWords) / float64(assistantMsgCount)
	}

	fmt.Printf("Conversational training data generated:\n")
	fmt.Printf(" %d train / %d valid / %d test\n", len(train), len(valid), len(test))
	fmt.Printf(" %d total conversations\n", len(conversations))
	fmt.Printf(" %d total turns (%.1f avg per conversation)\n", totalTurns, avgTurns)
	fmt.Printf(" %.0f words avg per assistant response\n", avgWords)
	fmt.Printf(" Output: %s/\n", cfg.OutputDir)

	// Report to InfluxDB if configured. Best-effort: a write failure is
	// logged but does not fail the command.
	influx := NewInfluxClient(cfg.Influx, cfg.InfluxDB)
	line := fmt.Sprintf("conv_export,worker=%s total=%di,train=%di,valid=%di,test=%di,turns=%di,avg_turns=%f,avg_words=%f",
		escapeLp(cfg.Worker), len(conversations), len(train), len(valid), len(test),
		totalTurns, avgTurns, avgWords)
	if err := influx.WriteLp([]string{line}); err != nil {
		log.Printf("influx write (best-effort): %v", err)
	}

	return nil
}
// readConversations reads multi-turn conversations from a JSONL file.
|
|
// Each line must be a TrainingExample with a messages array.
|
|
func readConversations(path string) ([]TrainingExample, error) {
|
|
f, err := os.Open(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open %s: %w", path, err)
|
|
}
|
|
defer f.Close()
|
|
|
|
var conversations []TrainingExample
|
|
scanner := bufio.NewScanner(f)
|
|
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
|
|
|
lineNum := 0
|
|
for scanner.Scan() {
|
|
lineNum++
|
|
line := strings.TrimSpace(scanner.Text())
|
|
if line == "" {
|
|
continue
|
|
}
|
|
|
|
var te TrainingExample
|
|
if err := json.Unmarshal([]byte(line), &te); err != nil {
|
|
return nil, fmt.Errorf("line %d: %w", lineNum, err)
|
|
}
|
|
if len(te.Messages) >= 2 {
|
|
conversations = append(conversations, te)
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return nil, fmt.Errorf("scan %s: %w", path, err)
|
|
}
|
|
|
|
return conversations, nil
|
|
}
|
|
|
|
// convertToConversations converts golden set prompt/response pairs into
|
|
// single-turn TrainingExample conversations (user → assistant).
|
|
func convertToConversations(responses []Response, minChars int) []TrainingExample {
|
|
var conversations []TrainingExample
|
|
for _, r := range responses {
|
|
if r.Response == "" || len(r.Response) < minChars {
|
|
continue
|
|
}
|
|
if strings.HasPrefix(r.Response, "ERROR:") {
|
|
continue
|
|
}
|
|
conversations = append(conversations, TrainingExample{
|
|
Messages: []ChatMessage{
|
|
{Role: "user", Content: r.Prompt},
|
|
{Role: "assistant", Content: r.Response},
|
|
},
|
|
})
|
|
}
|
|
return conversations
|
|
}
|
|
|
|
// splitConversations shuffles conversations with a deterministic seed and
|
|
// splits them into train, valid, and test sets by percentage.
|
|
func splitConversations(conversations []TrainingExample, trainPct, validPct, testPct int, seed int64) (train, valid, test []TrainingExample) {
|
|
shuffled := make([]TrainingExample, len(conversations))
|
|
copy(shuffled, conversations)
|
|
|
|
rng := rand.New(rand.NewSource(seed))
|
|
rng.Shuffle(len(shuffled), func(i, j int) {
|
|
shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
|
|
})
|
|
|
|
n := len(shuffled)
|
|
trainN := n * trainPct / 100
|
|
validN := n * validPct / 100
|
|
_ = testPct
|
|
|
|
train = shuffled[:trainN]
|
|
valid = shuffled[trainN : trainN+validN]
|
|
test = shuffled[trainN+validN:]
|
|
|
|
// Ensure at least 1 in each split when we have enough data.
|
|
if len(valid) == 0 && len(train) > 1 {
|
|
valid = train[len(train)-1:]
|
|
train = train[:len(train)-1]
|
|
}
|
|
if len(test) == 0 && len(train) > 1 {
|
|
test = train[len(train)-1:]
|
|
train = train[:len(train)-1]
|
|
}
|
|
|
|
return train, valid, test
|
|
}
|
|
|
|
// writeConversationJSONL writes TrainingExample conversations to a JSONL file.
|
|
func writeConversationJSONL(path string, conversations []TrainingExample) error {
|
|
f, err := os.Create(path)
|
|
if err != nil {
|
|
return fmt.Errorf("create %s: %w", path, err)
|
|
}
|
|
defer f.Close()
|
|
|
|
w := bufio.NewWriter(f)
|
|
defer w.Flush()
|
|
|
|
for _, c := range conversations {
|
|
data, err := json.Marshal(c)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal conversation: %w", err)
|
|
}
|
|
if _, err := w.Write(data); err != nil {
|
|
return fmt.Errorf("write line: %w", err)
|
|
}
|
|
if _, err := w.WriteString("\n"); err != nil {
|
|
return fmt.Errorf("write newline: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|