From 3606ff994b70a80cd0e6fda2f1c6dc2006afdea3 Mon Sep 17 00:00:00 2001 From: Snider Date: Mon, 23 Feb 2026 04:46:51 +0000 Subject: [PATCH] fix: memory, error handling, and signal improvements across pkg/lem - Stream parquet export rows instead of unbounded memory allocation - Replace QueryGoldenSet/QueryExpansionPrompts with iter.Seq2 iterators - Remove legacy runtime.GC() calls from distill (go-mlx handles cleanup) - Replace log.Fatalf with error return in tier_score.go - Add SIGINT/SIGTERM signal handling to agent and worker daemon loops - Add error checks for unchecked db.conn.Exec in import.go and tier_score.go - Update tests for iterator-based database methods Co-Authored-By: Gemini Co-Authored-By: Virgil --- pkg/lem/agent.go | 81 ++++++++++++++++------------- pkg/lem/conv.go | 9 ++-- pkg/lem/db.go | 118 ++++++++++++++++++++++++------------------ pkg/lem/db_test.go | 72 ++++++++++++++++---------- pkg/lem/distill.go | 7 --- pkg/lem/expand.go | 12 ++--- pkg/lem/export.go | 13 ++--- pkg/lem/import.go | 74 ++++++++++++++++++-------- pkg/lem/parquet.go | 52 +++++++++---------- pkg/lem/tier_score.go | 24 +++++---- pkg/lem/worker.go | 24 ++++++++- 11 files changed, 287 insertions(+), 199 deletions(-) diff --git a/pkg/lem/agent.go b/pkg/lem/agent.go index 9403c4a..7872f4d 100644 --- a/pkg/lem/agent.go +++ b/pkg/lem/agent.go @@ -1,15 +1,18 @@ package lem import ( + "context" "encoding/json" "fmt" "log" "os" "os/exec" + "os/signal" "path/filepath" "regexp" "sort" "strings" + "syscall" "time" ) @@ -72,11 +75,14 @@ type bufferEntry struct { // runs 23 capability probes via an OpenAI-compatible API, and // pushes results to InfluxDB. func RunAgent(cfg AgentOpts) error { - runAgentLoop(&cfg) + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + runAgentLoop(ctx, &cfg) return nil } -func runAgentLoop(cfg *AgentOpts) { +func runAgentLoop(ctx context.Context, cfg *AgentOpts) { log.Println(strings.Repeat("=", 60)) log.Println("ROCm Scoring Agent — Go Edition") log.Printf("M3: %s@%s", cfg.M3User, cfg.M3Host) @@ -88,6 +94,9 @@ func runAgentLoop(cfg *AgentOpts) { influx := NewInfluxClient(cfg.InfluxURL, cfg.InfluxDB) os.MkdirAll(cfg.WorkDir, 0755) + ticker := time.NewTicker(time.Duration(cfg.PollInterval) * time.Second) + defer ticker.Stop() + for { // Replay any buffered results. replayInfluxBuffer(cfg.WorkDir, influx) @@ -97,51 +106,51 @@ func runAgentLoop(cfg *AgentOpts) { checkpoints, err := discoverCheckpoints(cfg) if err != nil { log.Printf("Discovery failed: %v", err) - sleepOrExit(cfg) - continue - } - log.Printf("Found %d total checkpoints", len(checkpoints)) + } else { + log.Printf("Found %d total checkpoints", len(checkpoints)) - // Check what is already scored. - scored, err := getScoredLabels(influx) - if err != nil { - log.Printf("InfluxDB query failed: %v", err) - } - log.Printf("Already scored: %d (run_id, label) pairs", len(scored)) - - // Find unscored work. - unscored := findUnscored(checkpoints, scored) - log.Printf("Unscored: %d checkpoints", len(unscored)) - - if len(unscored) == 0 { - log.Printf("Nothing to score. Sleeping %ds...", cfg.PollInterval) - if cfg.OneShot { - return + // Check what is already scored. + scored, err := getScoredLabels(influx) + if err != nil { + log.Printf("InfluxDB query failed: %v", err) } - time.Sleep(time.Duration(cfg.PollInterval) * time.Second) - continue - } + log.Printf("Already scored: %d (run_id, label) pairs", len(scored)) - target := unscored[0] - log.Printf("Grabbed: %s (%s)", target.Label, target.Dirname) + // Find unscored work. + unscored := findUnscored(checkpoints, scored) + log.Printf("Unscored: %d checkpoints", len(unscored)) - if cfg.DryRun { - log.Printf("[DRY RUN] Would process: %s/%s", target.Dirname, target.Filename) - for _, u := range unscored[1:] { - log.Printf("[DRY RUN] Queued: %s/%s", u.Dirname, u.Filename) + if len(unscored) > 0 { + target := unscored[0] + log.Printf("Grabbed: %s (%s)", target.Label, target.Dirname) + + if cfg.DryRun { + log.Printf("[DRY RUN] Would process: %s/%s", target.Dirname, target.Filename) + for _, u := range unscored[1:] { + log.Printf("[DRY RUN] Queued: %s/%s", u.Dirname, u.Filename) + } + return + } + + if err := processOne(cfg, influx, target); err != nil { + log.Printf("Error processing %s: %v", target.Label, err) + } + } else { + log.Printf("Nothing to score.") } - return - } - - if err := processOne(cfg, influx, target); err != nil { - log.Printf("Error processing %s: %v", target.Label, err) } if cfg.OneShot { return } - time.Sleep(5 * time.Second) + log.Printf("Sleeping %ds...", cfg.PollInterval) + select { + case <-ctx.Done(): + log.Println("Agent shutting down...") + return + case <-ticker.C: + } } } diff --git a/pkg/lem/conv.go b/pkg/lem/conv.go index 57c3e5d..f713af6 100644 --- a/pkg/lem/conv.go +++ b/pkg/lem/conv.go @@ -80,11 +80,10 @@ func RunConv(cfg ConvOpts) error { } defer db.Close() - rows, err := db.QueryGoldenSet(cfg.MinChars) - if err != nil { - return fmt.Errorf("query golden_set: %w", err) - } - for _, r := range rows { + for r, err := range db.GoldenSet(cfg.MinChars) { + if err != nil { + return fmt.Errorf("golden set error: %w", err) + } goldenResponses = append(goldenResponses, Response{ ID: r.SeedID, Domain: r.Domain, diff --git a/pkg/lem/db.go b/pkg/lem/db.go index fbc5b9d..35d35de 100644 --- a/pkg/lem/db.go +++ b/pkg/lem/db.go @@ -3,6 +3,7 @@ package lem import ( "database/sql" "fmt" + "iter" _ "github.com/marcboeker/go-duckdb" ) @@ -71,28 +72,37 @@ type ExpansionPromptRow struct { Status string } -// QueryGoldenSet returns all golden set rows with responses >= minChars. -func (db *DB) QueryGoldenSet(minChars int) ([]GoldenSetRow, error) { - rows, err := db.conn.Query( - "SELECT idx, seed_id, domain, voice, prompt, response, gen_time, char_count "+ - "FROM golden_set WHERE char_count >= ? ORDER BY idx", - minChars, - ) - if err != nil { - return nil, fmt.Errorf("query golden_set: %w", err) - } - defer rows.Close() - - var result []GoldenSetRow - for rows.Next() { - var r GoldenSetRow - if err := rows.Scan(&r.Idx, &r.SeedID, &r.Domain, &r.Voice, - &r.Prompt, &r.Response, &r.GenTime, &r.CharCount); err != nil { - return nil, fmt.Errorf("scan golden_set row: %w", err) +// GoldenSet returns an iterator over golden set rows with responses >= minChars. +func (db *DB) GoldenSet(minChars int) iter.Seq2[GoldenSetRow, error] { + return func(yield func(GoldenSetRow, error) bool) { + rows, err := db.conn.Query( + "SELECT idx, seed_id, domain, voice, prompt, response, gen_time, char_count "+ + "FROM golden_set WHERE char_count >= ? ORDER BY idx", + minChars, + ) + if err != nil { + yield(GoldenSetRow{}, fmt.Errorf("query golden_set: %w", err)) + return + } + defer rows.Close() + + for rows.Next() { + var r GoldenSetRow + if err := rows.Scan(&r.Idx, &r.SeedID, &r.Domain, &r.Voice, + &r.Prompt, &r.Response, &r.GenTime, &r.CharCount); err != nil { + if !yield(GoldenSetRow{}, fmt.Errorf("scan golden_set row: %w", err)) { + return + } + continue + } + if !yield(r, nil) { + return + } + } + if err := rows.Err(); err != nil { + yield(GoldenSetRow{}, fmt.Errorf("rows error: %w", err)) } - result = append(result, r) } - return result, rows.Err() } // CountGoldenSet returns the total count of golden set rows. @@ -105,39 +115,47 @@ func (db *DB) CountGoldenSet() (int, error) { return count, nil } -// QueryExpansionPrompts returns expansion prompts filtered by status. -// If status is empty, returns all prompts. -func (db *DB) QueryExpansionPrompts(status string, limit int) ([]ExpansionPromptRow, error) { - query := "SELECT idx, seed_id, region, domain, language, prompt, prompt_en, priority, status " + - "FROM expansion_prompts" - var args []any +// ExpansionPrompts returns an iterator over expansion prompts filtered by status. +func (db *DB) ExpansionPrompts(status string, limit int) iter.Seq2[ExpansionPromptRow, error] { + return func(yield func(ExpansionPromptRow, error) bool) { + query := "SELECT idx, seed_id, region, domain, language, prompt, prompt_en, priority, status " + + "FROM expansion_prompts" + var args []any - if status != "" { - query += " WHERE status = ?" - args = append(args, status) - } - query += " ORDER BY priority, idx" - - if limit > 0 { - query += fmt.Sprintf(" LIMIT %d", limit) - } - - rows, err := db.conn.Query(query, args...) - if err != nil { - return nil, fmt.Errorf("query expansion_prompts: %w", err) - } - defer rows.Close() - - var result []ExpansionPromptRow - for rows.Next() { - var r ExpansionPromptRow - if err := rows.Scan(&r.Idx, &r.SeedID, &r.Region, &r.Domain, - &r.Language, &r.Prompt, &r.PromptEn, &r.Priority, &r.Status); err != nil { - return nil, fmt.Errorf("scan expansion_prompt row: %w", err) + if status != "" { + query += " WHERE status = ?" + args = append(args, status) + } + query += " ORDER BY priority, idx" + + if limit > 0 { + query += fmt.Sprintf(" LIMIT %d", limit) + } + + rows, err := db.conn.Query(query, args...) + if err != nil { + yield(ExpansionPromptRow{}, fmt.Errorf("query expansion_prompts: %w", err)) + return + } + defer rows.Close() + + for rows.Next() { + var r ExpansionPromptRow + if err := rows.Scan(&r.Idx, &r.SeedID, &r.Region, &r.Domain, + &r.Language, &r.Prompt, &r.PromptEn, &r.Priority, &r.Status); err != nil { + if !yield(ExpansionPromptRow{}, fmt.Errorf("scan expansion_prompt row: %w", err)) { + return + } + continue + } + if !yield(r, nil) { + return + } + } + if err := rows.Err(); err != nil { + yield(ExpansionPromptRow{}, fmt.Errorf("rows error: %w", err)) } - result = append(result, r) } - return result, rows.Err() } // CountExpansionPrompts returns counts by status. diff --git a/pkg/lem/db_test.go b/pkg/lem/db_test.go index 7456e3f..ffc4dc9 100644 --- a/pkg/lem/db_test.go +++ b/pkg/lem/db_test.go @@ -82,9 +82,12 @@ func TestQueryGoldenSet(t *testing.T) { } // Query with minChars=50 should return 2 (skip the short one). - rows, err := db.QueryGoldenSet(50) - if err != nil { - t.Fatalf("query: %v", err) + var rows []GoldenSetRow + for r, err := range db.GoldenSet(50) { + if err != nil { + t.Fatalf("iter error: %v", err) + } + rows = append(rows, r) } if len(rows) != 2 { t.Fatalf("got %d rows, want 2", len(rows)) @@ -97,16 +100,19 @@ func TestQueryGoldenSet(t *testing.T) { } } -func TestQueryGoldenSetEmpty(t *testing.T) { +func TestGoldenSetEmpty(t *testing.T) { db := createTestDB(t) defer db.Close() - rows, err := db.QueryGoldenSet(0) - if err != nil { - t.Fatalf("query: %v", err) + count := 0 + for _, err := range db.GoldenSet(0) { + if err != nil { + t.Fatalf("iter error: %v", err) + } + count++ } - if len(rows) != 0 { - t.Fatalf("got %d rows, want 0", len(rows)) + if count != 0 { + t.Fatalf("got %d rows, want 0", count) } } @@ -131,7 +137,7 @@ func TestCountGoldenSet(t *testing.T) { } } -func TestQueryExpansionPrompts(t *testing.T) { +func TestExpansionPrompts(t *testing.T) { db := createTestDB(t) defer db.Close() @@ -145,9 +151,12 @@ func TestQueryExpansionPrompts(t *testing.T) { } // Query pending only. - rows, err := db.QueryExpansionPrompts("pending", 0) - if err != nil { - t.Fatalf("query pending: %v", err) + var rows []ExpansionPromptRow + for r, err := range db.ExpansionPrompts("pending", 0) { + if err != nil { + t.Fatalf("iter error: %v", err) + } + rows = append(rows, r) } if len(rows) != 2 { t.Fatalf("got %d rows, want 2", len(rows)) @@ -158,18 +167,24 @@ func TestQueryExpansionPrompts(t *testing.T) { } // Query all. - all, err := db.QueryExpansionPrompts("", 0) - if err != nil { - t.Fatalf("query all: %v", err) + var all []ExpansionPromptRow + for r, err := range db.ExpansionPrompts("", 0) { + if err != nil { + t.Fatalf("iter error: %v", err) + } + all = append(all, r) } if len(all) != 3 { t.Fatalf("got %d rows, want 3", len(all)) } // Query with limit. - limited, err := db.QueryExpansionPrompts("pending", 1) - if err != nil { - t.Fatalf("query limited: %v", err) + var limited []ExpansionPromptRow + for r, err := range db.ExpansionPrompts("pending", 1) { + if err != nil { + t.Fatalf("iter error: %v", err) + } + limited = append(limited, r) } if len(limited) != 1 { t.Fatalf("got %d rows, want 1", len(limited)) @@ -217,15 +232,18 @@ func TestUpdateExpansionStatus(t *testing.T) { t.Fatalf("update: %v", err) } - rows, err := db.QueryExpansionPrompts("completed", 0) - if err != nil { - t.Fatalf("query: %v", err) + count := 0 + for r, err := range db.ExpansionPrompts("completed", 0) { + if err != nil { + t.Fatalf("iter error: %v", err) + } + if r.Status != "completed" { + t.Errorf("status = %q, want completed", r.Status) + } + count++ } - if len(rows) != 1 { - t.Fatalf("got %d rows, want 1", len(rows)) - } - if rows[0].Status != "completed" { - t.Errorf("status = %q, want completed", rows[0].Status) + if count != 1 { + t.Fatalf("got %d rows, want 1", count) } } diff --git a/pkg/lem/distill.go b/pkg/lem/distill.go index 7592838..b2b9621 100644 --- a/pkg/lem/distill.go +++ b/pkg/lem/distill.go @@ -7,7 +7,6 @@ import ( "log" "os" "path/filepath" - "runtime" "strings" "time" @@ -293,7 +292,6 @@ func RunDistill(cfg DistillOpts) error { skipped++ fmt.Fprintf(os.Stderr, " ✗ SKIP %s (BO composite %d < %d)\n", probe.ID, boScore, aiCfg.Scorer.AttentionMinScore) - runtime.GC() continue } } @@ -306,8 +304,6 @@ func RunDistill(cfg DistillOpts) error { if dedupIdx != nil && dedupIdx.IsDuplicate(bestFeatures, 0.02) { deduped++ fmt.Fprintf(os.Stderr, " ~ DEDUP %s (grammar profile too similar to existing)\n", probe.ID) - // Release GPU memory between probes to prevent incremental leak. - runtime.GC() continue } @@ -343,9 +339,6 @@ func RunDistill(cfg DistillOpts) error { fmt.Fprintf(os.Stderr, " ✗ SKIP %s (best g=%.1f < %.1f)\n", probe.ID, score, minScore) } - - // Release GPU memory between probes to prevent incremental leak. - runtime.GC() } duration := time.Since(totalStart) diff --git a/pkg/lem/expand.go b/pkg/lem/expand.go index 258cf71..8721836 100644 --- a/pkg/lem/expand.go +++ b/pkg/lem/expand.go @@ -71,13 +71,10 @@ func RunExpand(cfg ExpandOpts) error { } defer duckDB.Close() - rows, err := duckDB.QueryExpansionPrompts("pending", cfg.Limit) - if err != nil { - return fmt.Errorf("query expansion_prompts: %v", err) - } - log.Printf("loaded %d pending prompts from %s", len(rows), cfg.DB) - - for _, r := range rows { + for r, err := range duckDB.ExpansionPrompts("pending", cfg.Limit) { + if err != nil { + return fmt.Errorf("expansion prompts error: %v", err) + } prompt := r.Prompt if prompt == "" && r.PromptEn != "" { prompt = r.PromptEn // Use English translation if primary is empty. @@ -88,6 +85,7 @@ func RunExpand(cfg ExpandOpts) error { Prompt: prompt, }) } + log.Printf("loaded %d pending prompts from %s", len(promptList), cfg.DB) } else { var err error promptList, err = ReadResponses(cfg.Prompts) diff --git a/pkg/lem/export.go b/pkg/lem/export.go index 0ecb46e..da563c3 100644 --- a/pkg/lem/export.go +++ b/pkg/lem/export.go @@ -62,14 +62,10 @@ func RunExport(cfg ExportOpts) error { } defer db.Close() - rows, err := db.QueryGoldenSet(cfg.MinChars) - if err != nil { - return fmt.Errorf("query golden_set: %w", err) - } - log.Printf("loaded %d golden set rows from %s (min_chars=%d)", len(rows), cfg.DBPath, cfg.MinChars) - - // Convert GoldenSetRow → Response for the shared pipeline. - for _, r := range rows { + for r, err := range db.GoldenSet(cfg.MinChars) { + if err != nil { + return fmt.Errorf("golden set error: %w", err) + } responses = append(responses, Response{ ID: r.SeedID, Domain: r.Domain, @@ -78,6 +74,7 @@ func RunExport(cfg ExportOpts) error { Model: r.Voice, // voice maps to the "model" slot for tracking }) } + log.Printf("loaded %d golden set rows from %s (min_chars=%d)", len(responses), cfg.DBPath, cfg.MinChars) } else { // Fallback: read from JSONL file. var err error diff --git a/pkg/lem/import.go b/pkg/lem/import.go index c2dce23..837442a 100644 --- a/pkg/lem/import.go +++ b/pkg/lem/import.go @@ -53,7 +53,9 @@ func RunImport(cfg ImportOpts) error { } } if _, err := os.Stat(goldenPath); err == nil { - db.conn.Exec("DROP TABLE IF EXISTS golden_set") + if _, err := db.conn.Exec("DROP TABLE IF EXISTS golden_set"); err != nil { + log.Printf(" WARNING: drop golden_set failed: %v", err) + } _, err := db.conn.Exec(fmt.Sprintf(` CREATE TABLE golden_set AS SELECT @@ -114,8 +116,10 @@ func RunImport(cfg ImportOpts) error { } } - db.conn.Exec("DROP TABLE IF EXISTS training_examples") - db.conn.Exec(` + if _, err := db.conn.Exec("DROP TABLE IF EXISTS training_examples"); err != nil { + log.Printf(" WARNING: drop training_examples failed: %v", err) + } + if _, err := db.conn.Exec(` CREATE TABLE training_examples ( source VARCHAR, split VARCHAR, @@ -125,7 +129,9 @@ func RunImport(cfg ImportOpts) error { full_messages TEXT, char_count INT ) - `) + `); err != nil { + log.Printf(" WARNING: create training_examples failed: %v", err) + } trainingTotal := 0 for _, td := range trainingDirs { @@ -171,13 +177,17 @@ func RunImport(cfg ImportOpts) error { } } - db.conn.Exec("DROP TABLE IF EXISTS benchmark_results") - db.conn.Exec(` + if _, err := db.conn.Exec("DROP TABLE IF EXISTS benchmark_results"); err != nil { + log.Printf(" WARNING: drop benchmark_results failed: %v", err) + } + if _, err := db.conn.Exec(` CREATE TABLE benchmark_results ( source VARCHAR, id VARCHAR, benchmark VARCHAR, model VARCHAR, prompt TEXT, response TEXT, elapsed_seconds DOUBLE, domain VARCHAR ) - `) + `); err != nil { + log.Printf(" WARNING: create benchmark_results failed: %v", err) + } benchTotal := 0 for _, subdir := range []string{"results", "scale_results", "cross_arch_results", "deepseek-r1-7b"} { @@ -208,13 +218,17 @@ func RunImport(cfg ImportOpts) error { fmt.Printf(" benchmark_results: %d rows\n", benchTotal) // ── 4. Benchmark questions ── - db.conn.Exec("DROP TABLE IF EXISTS benchmark_questions") - db.conn.Exec(` + if _, err := db.conn.Exec("DROP TABLE IF EXISTS benchmark_questions"); err != nil { + log.Printf(" WARNING: drop benchmark_questions failed: %v", err) + } + if _, err := db.conn.Exec(` CREATE TABLE benchmark_questions ( benchmark VARCHAR, id VARCHAR, question TEXT, best_answer TEXT, correct_answers TEXT, incorrect_answers TEXT, category VARCHAR ) - `) + `); err != nil { + log.Printf(" WARNING: create benchmark_questions failed: %v", err) + } benchQTotal := 0 for _, bname := range []string{"truthfulqa", "gsm8k", "do_not_answer", "toxigen"} { @@ -228,12 +242,16 @@ func RunImport(cfg ImportOpts) error { fmt.Printf(" benchmark_questions: %d rows\n", benchQTotal) // ── 5. Seeds ── - db.conn.Exec("DROP TABLE IF EXISTS seeds") - db.conn.Exec(` + if _, err := db.conn.Exec("DROP TABLE IF EXISTS seeds"); err != nil { + log.Printf(" WARNING: drop seeds failed: %v", err) + } + if _, err := db.conn.Exec(` CREATE TABLE seeds ( source_file VARCHAR, region VARCHAR, seed_id VARCHAR, domain VARCHAR, prompt TEXT ) - `) + `); err != nil { + log.Printf(" WARNING: create seeds failed: %v", err) + } seedTotal := 0 seedDirs := []string{filepath.Join(dataDir, "seeds"), "/tmp/lem-data/seeds", "/tmp/lem-repo/seeds"} @@ -297,8 +315,10 @@ func importTrainingFile(db *DB, path, source, split string) int { } msgsJSON, _ := json.Marshal(rec.Messages) - db.conn.Exec(`INSERT INTO training_examples VALUES (?, ?, ?, ?, ?, ?, ?)`, - source, split, prompt, response, assistantCount, string(msgsJSON), len(response)) + if _, err := db.conn.Exec(`INSERT INTO training_examples VALUES (?, ?, ?, ?, ?, ?, ?)`, + source, split, prompt, response, assistantCount, string(msgsJSON), len(response)); err != nil { + log.Printf(" WARNING: insert training example failed: %v", err) + } count++ } return count @@ -321,7 +341,7 @@ func importBenchmarkFile(db *DB, path, source string) int { continue } - db.conn.Exec(`INSERT INTO benchmark_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + if _, err := db.conn.Exec(`INSERT INTO benchmark_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, source, fmt.Sprintf("%v", rec["id"]), strOrEmpty(rec, "benchmark"), @@ -330,7 +350,9 @@ func importBenchmarkFile(db *DB, path, source string) int { strOrEmpty(rec, "response"), floatOrZero(rec, "elapsed_seconds"), strOrEmpty(rec, "domain"), - ) + ); err != nil { + log.Printf(" WARNING: insert benchmark result failed: %v", err) + } count++ } return count @@ -356,7 +378,7 @@ func importBenchmarkQuestions(db *DB, path, benchmark string) int { correctJSON, _ := json.Marshal(rec["correct_answers"]) incorrectJSON, _ := json.Marshal(rec["incorrect_answers"]) - db.conn.Exec(`INSERT INTO benchmark_questions VALUES (?, ?, ?, ?, ?, ?, ?)`, + if _, err := db.conn.Exec(`INSERT INTO benchmark_questions VALUES (?, ?, ?, ?, ?, ?, ?)`, benchmark, fmt.Sprintf("%v", rec["id"]), strOrEmpty(rec, "question"), @@ -364,7 +386,9 @@ func importBenchmarkQuestions(db *DB, path, benchmark string) int { string(correctJSON), string(incorrectJSON), strOrEmpty(rec, "category"), - ) + ); err != nil { + log.Printf(" WARNING: insert benchmark question failed: %v", err) + } count++ } return count @@ -413,16 +437,20 @@ func importSeeds(db *DB, seedDir string) int { if prompt == "" { prompt = strOrEmpty(seed, "question") } - db.conn.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`, + if _, err := db.conn.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`, rel, region, strOrEmpty(seed, "seed_id"), strOrEmpty(seed, "domain"), prompt, - ) + ); err != nil { + log.Printf(" WARNING: insert seed failed: %v", err) + } count++ case string: - db.conn.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`, - rel, region, "", "", seed) + if _, err := db.conn.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`, + rel, region, "", "", seed); err != nil { + log.Printf(" WARNING: insert string seed failed: %v", err) + } count++ } } diff --git a/pkg/lem/parquet.go b/pkg/lem/parquet.go index e80c60d..91af161 100644 --- a/pkg/lem/parquet.go +++ b/pkg/lem/parquet.go @@ -71,7 +71,19 @@ func exportSplitParquet(jsonlPath, outputDir, split string) (int, error) { } defer f.Close() - var rows []ParquetRow + outPath := filepath.Join(outputDir, split+".parquet") + out, err := os.Create(outPath) + if err != nil { + return 0, fmt.Errorf("create %s: %w", outPath, err) + } + defer out.Close() + + writer := parquet.NewGenericWriter[ParquetRow](out, + parquet.Compression(&parquet.Snappy), + ) + defer writer.Close() + + count := 0 scanner := bufio.NewScanner(f) scanner.Buffer(make([]byte, 1024*1024), 1024*1024) @@ -107,51 +119,39 @@ func exportSplitParquet(jsonlPath, outputDir, split string) (int, error) { } msgsJSON, _ := json.Marshal(data.Messages) - rows = append(rows, ParquetRow{ + row := ParquetRow{ Prompt: prompt, Response: response, System: system, Messages: string(msgsJSON), - }) + } + + if _, err := writer.Write([]ParquetRow{row}); err != nil { + return count, fmt.Errorf("write parquet row: %w", err) + } + count++ } if err := scanner.Err(); err != nil { - return 0, fmt.Errorf("scan %s: %w", jsonlPath, err) + return count, fmt.Errorf("scan %s: %w", jsonlPath, err) } - if len(rows) == 0 { + if count == 0 { fmt.Printf(" Skip: %s — no data\n", split) return 0, nil } - outPath := filepath.Join(outputDir, split+".parquet") - - out, err := os.Create(outPath) - if err != nil { - return 0, fmt.Errorf("create %s: %w", outPath, err) - } - - writer := parquet.NewGenericWriter[ParquetRow](out, - parquet.Compression(&parquet.Snappy), - ) - - if _, err := writer.Write(rows); err != nil { - out.Close() - return 0, fmt.Errorf("write parquet rows: %w", err) - } - if err := writer.Close(); err != nil { - out.Close() - return 0, fmt.Errorf("close parquet writer: %w", err) + return count, fmt.Errorf("close parquet writer: %w", err) } if err := out.Close(); err != nil { - return 0, fmt.Errorf("close file: %w", err) + return count, fmt.Errorf("close file: %w", err) } info, _ := os.Stat(outPath) sizeMB := float64(info.Size()) / 1024 / 1024 - fmt.Printf(" %s.parquet: %d rows (%.1f MB)\n", split, len(rows), sizeMB) + fmt.Printf(" %s.parquet: %d rows (%.1f MB)\n", split, count, sizeMB) - return len(rows), nil + return count, nil } diff --git a/pkg/lem/tier_score.go b/pkg/lem/tier_score.go index f0ca09d..e6e0e26 100644 --- a/pkg/lem/tier_score.go +++ b/pkg/lem/tier_score.go @@ -2,7 +2,6 @@ package lem import ( "fmt" - "log" "os" "strings" ) @@ -35,7 +34,7 @@ func RunTierScore(cfg TierScoreOpts) error { defer db.Close() // Ensure expansion_scores table exists. - db.conn.Exec(` + if _, err := db.conn.Exec(` CREATE TABLE IF NOT EXISTS expansion_scores ( idx INT, heuristic_score DOUBLE, @@ -49,10 +48,14 @@ func RunTierScore(cfg TierScoreOpts) error { judge_model VARCHAR, scored_at TIMESTAMP ) - `) + `); err != nil { + return fmt.Errorf("create expansion_scores: %w", err) + } if cfg.Tier >= 1 { - runHeuristicTier(db, cfg.Limit) + if err := runHeuristicTier(db, cfg.Limit); err != nil { + return fmt.Errorf("tier 1: %w", err) + } } if cfg.Tier >= 2 { @@ -67,7 +70,7 @@ func RunTierScore(cfg TierScoreOpts) error { return nil } -func runHeuristicTier(db *DB, limit int) { +func runHeuristicTier(db *DB, limit int) error { // Find unscored responses. query := ` SELECT r.idx, r.response FROM expansion_raw r @@ -81,7 +84,7 @@ func runHeuristicTier(db *DB, limit int) { rows, err := db.conn.Query(query) if err != nil { - log.Fatalf("query unscored: %v", err) + return fmt.Errorf("query unscored: %w", err) } defer rows.Close() @@ -99,7 +102,7 @@ func runHeuristicTier(db *DB, limit int) { if len(unscored) == 0 { fmt.Println("Tier 1 (heuristic): all responses already scored") - return + return nil } fmt.Printf("Tier 1 (heuristic): scoring %d responses...\n", len(unscored)) @@ -112,16 +115,19 @@ func runHeuristicTier(db *DB, limit int) { passed++ } - db.conn.Exec(` + if _, err := db.conn.Exec(` INSERT INTO expansion_scores (idx, heuristic_score, heuristic_pass, scored_at) VALUES (?, ?, ?, current_timestamp) - `, r.idx, score, isPass) + `, r.idx, score, isPass); err != nil { + fmt.Printf(" WARNING: insert score failed for idx %d: %v\n", r.idx, err) + } } fmt.Printf(" Scored: %d, Passed: %d, Failed: %d\n", len(unscored), passed, len(unscored)-passed) if len(unscored) > 0 { fmt.Printf(" Pass rate: %.1f%%\n", float64(passed)/float64(len(unscored))*100) } + return nil } // heuristicExpansionScore applies fast heuristic scoring to an expansion response. diff --git a/pkg/lem/worker.go b/pkg/lem/worker.go index 3580fb3..a08ff46 100644 --- a/pkg/lem/worker.go +++ b/pkg/lem/worker.go @@ -2,14 +2,17 @@ package lem import ( "bytes" + "context" "encoding/json" "fmt" "io" "log" "net/http" "os" + "os/signal" "path/filepath" "runtime" + "syscall" "time" ) @@ -151,6 +154,12 @@ func RunWorker(opts WorkerOpts) error { } log.Println("Registered with LEM API") + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + ticker := time.NewTicker(cfg.pollInterval) + defer ticker.Stop() + // Main loop. for { processed := workerPoll(&cfg) @@ -162,7 +171,20 @@ func RunWorker(opts WorkerOpts) error { if processed == 0 { log.Printf("No tasks available, sleeping %v", cfg.pollInterval) - time.Sleep(cfg.pollInterval) + select { + case <-ctx.Done(): + log.Println("Worker shutting down...") + return nil + case <-ticker.C: + } + } else { + // Check context between tasks if needed. + select { + case <-ctx.Done(): + log.Println("Worker shutting down...") + return nil + default: + } } // Heartbeat.