fix: memory, error handling, and signal improvements across pkg/lem

- Stream parquet export rows instead of unbounded memory allocation
- Replace QueryGoldenSet/QueryExpansionPrompts with iter.Seq2 iterators
- Remove legacy runtime.GC() calls from distill (go-mlx handles cleanup)
- Replace log.Fatalf with error return in tier_score.go
- Add SIGINT/SIGTERM signal handling to agent and worker daemon loops
- Add error checks for unchecked db.conn.Exec in import.go and tier_score.go
- Update tests for iterator-based database methods

Co-Authored-By: Gemini <noreply@google.com>
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-23 04:46:51 +00:00
parent de3d6a70f1
commit 3606ff994b
11 changed files with 287 additions and 199 deletions

View file

@ -1,15 +1,18 @@
package lem
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"os/exec"
"os/signal"
"path/filepath"
"regexp"
"sort"
"strings"
"syscall"
"time"
)
@ -72,11 +75,14 @@ type bufferEntry struct {
// runs 23 capability probes via an OpenAI-compatible API, and
// pushes results to InfluxDB.
// RunAgent starts the scoring-agent daemon loop. It installs SIGINT/SIGTERM
// handling via signal.NotifyContext so the loop can shut down cleanly, then
// delegates to runAgentLoop until the context is cancelled or the loop
// returns on its own (e.g. one-shot mode).
func RunAgent(cfg AgentOpts) error {
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	// stop must run on every exit path so the signal handler is uninstalled
	// and the default signal disposition is restored.
	defer stop()
	runAgentLoop(ctx, &cfg)
	return nil
}
func runAgentLoop(cfg *AgentOpts) {
func runAgentLoop(ctx context.Context, cfg *AgentOpts) {
log.Println(strings.Repeat("=", 60))
log.Println("ROCm Scoring Agent — Go Edition")
log.Printf("M3: %s@%s", cfg.M3User, cfg.M3Host)
@ -88,6 +94,9 @@ func runAgentLoop(cfg *AgentOpts) {
influx := NewInfluxClient(cfg.InfluxURL, cfg.InfluxDB)
os.MkdirAll(cfg.WorkDir, 0755)
ticker := time.NewTicker(time.Duration(cfg.PollInterval) * time.Second)
defer ticker.Stop()
for {
// Replay any buffered results.
replayInfluxBuffer(cfg.WorkDir, influx)
@ -97,51 +106,51 @@ func runAgentLoop(cfg *AgentOpts) {
checkpoints, err := discoverCheckpoints(cfg)
if err != nil {
log.Printf("Discovery failed: %v", err)
sleepOrExit(cfg)
continue
}
log.Printf("Found %d total checkpoints", len(checkpoints))
} else {
log.Printf("Found %d total checkpoints", len(checkpoints))
// Check what is already scored.
scored, err := getScoredLabels(influx)
if err != nil {
log.Printf("InfluxDB query failed: %v", err)
}
log.Printf("Already scored: %d (run_id, label) pairs", len(scored))
// Find unscored work.
unscored := findUnscored(checkpoints, scored)
log.Printf("Unscored: %d checkpoints", len(unscored))
if len(unscored) == 0 {
log.Printf("Nothing to score. Sleeping %ds...", cfg.PollInterval)
if cfg.OneShot {
return
// Check what is already scored.
scored, err := getScoredLabels(influx)
if err != nil {
log.Printf("InfluxDB query failed: %v", err)
}
time.Sleep(time.Duration(cfg.PollInterval) * time.Second)
continue
}
log.Printf("Already scored: %d (run_id, label) pairs", len(scored))
target := unscored[0]
log.Printf("Grabbed: %s (%s)", target.Label, target.Dirname)
// Find unscored work.
unscored := findUnscored(checkpoints, scored)
log.Printf("Unscored: %d checkpoints", len(unscored))
if cfg.DryRun {
log.Printf("[DRY RUN] Would process: %s/%s", target.Dirname, target.Filename)
for _, u := range unscored[1:] {
log.Printf("[DRY RUN] Queued: %s/%s", u.Dirname, u.Filename)
if len(unscored) > 0 {
target := unscored[0]
log.Printf("Grabbed: %s (%s)", target.Label, target.Dirname)
if cfg.DryRun {
log.Printf("[DRY RUN] Would process: %s/%s", target.Dirname, target.Filename)
for _, u := range unscored[1:] {
log.Printf("[DRY RUN] Queued: %s/%s", u.Dirname, u.Filename)
}
return
}
if err := processOne(cfg, influx, target); err != nil {
log.Printf("Error processing %s: %v", target.Label, err)
}
} else {
log.Printf("Nothing to score.")
}
return
}
if err := processOne(cfg, influx, target); err != nil {
log.Printf("Error processing %s: %v", target.Label, err)
}
if cfg.OneShot {
return
}
time.Sleep(5 * time.Second)
log.Printf("Sleeping %ds...", cfg.PollInterval)
select {
case <-ctx.Done():
log.Println("Agent shutting down...")
return
case <-ticker.C:
}
}
}

View file

@ -80,11 +80,10 @@ func RunConv(cfg ConvOpts) error {
}
defer db.Close()
rows, err := db.QueryGoldenSet(cfg.MinChars)
if err != nil {
return fmt.Errorf("query golden_set: %w", err)
}
for _, r := range rows {
for r, err := range db.GoldenSet(cfg.MinChars) {
if err != nil {
return fmt.Errorf("golden set error: %w", err)
}
goldenResponses = append(goldenResponses, Response{
ID: r.SeedID,
Domain: r.Domain,

View file

@ -3,6 +3,7 @@ package lem
import (
"database/sql"
"fmt"
"iter"
_ "github.com/marcboeker/go-duckdb"
)
@ -71,28 +72,37 @@ type ExpansionPromptRow struct {
Status string
}
// QueryGoldenSet returns all golden set rows with responses >= minChars.
func (db *DB) QueryGoldenSet(minChars int) ([]GoldenSetRow, error) {
rows, err := db.conn.Query(
"SELECT idx, seed_id, domain, voice, prompt, response, gen_time, char_count "+
"FROM golden_set WHERE char_count >= ? ORDER BY idx",
minChars,
)
if err != nil {
return nil, fmt.Errorf("query golden_set: %w", err)
}
defer rows.Close()
var result []GoldenSetRow
for rows.Next() {
var r GoldenSetRow
if err := rows.Scan(&r.Idx, &r.SeedID, &r.Domain, &r.Voice,
&r.Prompt, &r.Response, &r.GenTime, &r.CharCount); err != nil {
return nil, fmt.Errorf("scan golden_set row: %w", err)
// GoldenSet returns an iterator over golden set rows with responses >= minChars.
func (db *DB) GoldenSet(minChars int) iter.Seq2[GoldenSetRow, error] {
return func(yield func(GoldenSetRow, error) bool) {
rows, err := db.conn.Query(
"SELECT idx, seed_id, domain, voice, prompt, response, gen_time, char_count "+
"FROM golden_set WHERE char_count >= ? ORDER BY idx",
minChars,
)
if err != nil {
yield(GoldenSetRow{}, fmt.Errorf("query golden_set: %w", err))
return
}
defer rows.Close()
for rows.Next() {
var r GoldenSetRow
if err := rows.Scan(&r.Idx, &r.SeedID, &r.Domain, &r.Voice,
&r.Prompt, &r.Response, &r.GenTime, &r.CharCount); err != nil {
if !yield(GoldenSetRow{}, fmt.Errorf("scan golden_set row: %w", err)) {
return
}
continue
}
if !yield(r, nil) {
return
}
}
if err := rows.Err(); err != nil {
yield(GoldenSetRow{}, fmt.Errorf("rows error: %w", err))
}
result = append(result, r)
}
return result, rows.Err()
}
// CountGoldenSet returns the total count of golden set rows.
@ -105,39 +115,47 @@ func (db *DB) CountGoldenSet() (int, error) {
return count, nil
}
// QueryExpansionPrompts returns expansion prompts filtered by status.
// If status is empty, returns all prompts.
func (db *DB) QueryExpansionPrompts(status string, limit int) ([]ExpansionPromptRow, error) {
query := "SELECT idx, seed_id, region, domain, language, prompt, prompt_en, priority, status " +
"FROM expansion_prompts"
var args []any
// ExpansionPrompts returns an iterator over expansion prompts filtered by status.
func (db *DB) ExpansionPrompts(status string, limit int) iter.Seq2[ExpansionPromptRow, error] {
return func(yield func(ExpansionPromptRow, error) bool) {
query := "SELECT idx, seed_id, region, domain, language, prompt, prompt_en, priority, status " +
"FROM expansion_prompts"
var args []any
if status != "" {
query += " WHERE status = ?"
args = append(args, status)
}
query += " ORDER BY priority, idx"
if limit > 0 {
query += fmt.Sprintf(" LIMIT %d", limit)
}
rows, err := db.conn.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("query expansion_prompts: %w", err)
}
defer rows.Close()
var result []ExpansionPromptRow
for rows.Next() {
var r ExpansionPromptRow
if err := rows.Scan(&r.Idx, &r.SeedID, &r.Region, &r.Domain,
&r.Language, &r.Prompt, &r.PromptEn, &r.Priority, &r.Status); err != nil {
return nil, fmt.Errorf("scan expansion_prompt row: %w", err)
if status != "" {
query += " WHERE status = ?"
args = append(args, status)
}
query += " ORDER BY priority, idx"
if limit > 0 {
query += fmt.Sprintf(" LIMIT %d", limit)
}
rows, err := db.conn.Query(query, args...)
if err != nil {
yield(ExpansionPromptRow{}, fmt.Errorf("query expansion_prompts: %w", err))
return
}
defer rows.Close()
for rows.Next() {
var r ExpansionPromptRow
if err := rows.Scan(&r.Idx, &r.SeedID, &r.Region, &r.Domain,
&r.Language, &r.Prompt, &r.PromptEn, &r.Priority, &r.Status); err != nil {
if !yield(ExpansionPromptRow{}, fmt.Errorf("scan expansion_prompt row: %w", err)) {
return
}
continue
}
if !yield(r, nil) {
return
}
}
if err := rows.Err(); err != nil {
yield(ExpansionPromptRow{}, fmt.Errorf("rows error: %w", err))
}
result = append(result, r)
}
return result, rows.Err()
}
// CountExpansionPrompts returns counts by status.

View file

@ -82,9 +82,12 @@ func TestQueryGoldenSet(t *testing.T) {
}
// Query with minChars=50 should return 2 (skip the short one).
rows, err := db.QueryGoldenSet(50)
if err != nil {
t.Fatalf("query: %v", err)
var rows []GoldenSetRow
for r, err := range db.GoldenSet(50) {
if err != nil {
t.Fatalf("iter error: %v", err)
}
rows = append(rows, r)
}
if len(rows) != 2 {
t.Fatalf("got %d rows, want 2", len(rows))
@ -97,16 +100,19 @@ func TestQueryGoldenSet(t *testing.T) {
}
}
func TestQueryGoldenSetEmpty(t *testing.T) {
func TestGoldenSetEmpty(t *testing.T) {
db := createTestDB(t)
defer db.Close()
rows, err := db.QueryGoldenSet(0)
if err != nil {
t.Fatalf("query: %v", err)
count := 0
for _, err := range db.GoldenSet(0) {
if err != nil {
t.Fatalf("iter error: %v", err)
}
count++
}
if len(rows) != 0 {
t.Fatalf("got %d rows, want 0", len(rows))
if count != 0 {
t.Fatalf("got %d rows, want 0", count)
}
}
@ -131,7 +137,7 @@ func TestCountGoldenSet(t *testing.T) {
}
}
func TestQueryExpansionPrompts(t *testing.T) {
func TestExpansionPrompts(t *testing.T) {
db := createTestDB(t)
defer db.Close()
@ -145,9 +151,12 @@ func TestQueryExpansionPrompts(t *testing.T) {
}
// Query pending only.
rows, err := db.QueryExpansionPrompts("pending", 0)
if err != nil {
t.Fatalf("query pending: %v", err)
var rows []ExpansionPromptRow
for r, err := range db.ExpansionPrompts("pending", 0) {
if err != nil {
t.Fatalf("iter error: %v", err)
}
rows = append(rows, r)
}
if len(rows) != 2 {
t.Fatalf("got %d rows, want 2", len(rows))
@ -158,18 +167,24 @@ func TestQueryExpansionPrompts(t *testing.T) {
}
// Query all.
all, err := db.QueryExpansionPrompts("", 0)
if err != nil {
t.Fatalf("query all: %v", err)
var all []ExpansionPromptRow
for r, err := range db.ExpansionPrompts("", 0) {
if err != nil {
t.Fatalf("iter error: %v", err)
}
all = append(all, r)
}
if len(all) != 3 {
t.Fatalf("got %d rows, want 3", len(all))
}
// Query with limit.
limited, err := db.QueryExpansionPrompts("pending", 1)
if err != nil {
t.Fatalf("query limited: %v", err)
var limited []ExpansionPromptRow
for r, err := range db.ExpansionPrompts("pending", 1) {
if err != nil {
t.Fatalf("iter error: %v", err)
}
limited = append(limited, r)
}
if len(limited) != 1 {
t.Fatalf("got %d rows, want 1", len(limited))
@ -217,15 +232,18 @@ func TestUpdateExpansionStatus(t *testing.T) {
t.Fatalf("update: %v", err)
}
rows, err := db.QueryExpansionPrompts("completed", 0)
if err != nil {
t.Fatalf("query: %v", err)
count := 0
for r, err := range db.ExpansionPrompts("completed", 0) {
if err != nil {
t.Fatalf("iter error: %v", err)
}
if r.Status != "completed" {
t.Errorf("status = %q, want completed", r.Status)
}
count++
}
if len(rows) != 1 {
t.Fatalf("got %d rows, want 1", len(rows))
}
if rows[0].Status != "completed" {
t.Errorf("status = %q, want completed", rows[0].Status)
if count != 1 {
t.Fatalf("got %d rows, want 1", count)
}
}

View file

@ -7,7 +7,6 @@ import (
"log"
"os"
"path/filepath"
"runtime"
"strings"
"time"
@ -293,7 +292,6 @@ func RunDistill(cfg DistillOpts) error {
skipped++
fmt.Fprintf(os.Stderr, " ✗ SKIP %s (BO composite %d < %d)\n",
probe.ID, boScore, aiCfg.Scorer.AttentionMinScore)
runtime.GC()
continue
}
}
@ -306,8 +304,6 @@ func RunDistill(cfg DistillOpts) error {
if dedupIdx != nil && dedupIdx.IsDuplicate(bestFeatures, 0.02) {
deduped++
fmt.Fprintf(os.Stderr, " ~ DEDUP %s (grammar profile too similar to existing)\n", probe.ID)
// Release GPU memory between probes to prevent incremental leak.
runtime.GC()
continue
}
@ -343,9 +339,6 @@ func RunDistill(cfg DistillOpts) error {
fmt.Fprintf(os.Stderr, " ✗ SKIP %s (best g=%.1f < %.1f)\n",
probe.ID, score, minScore)
}
// Release GPU memory between probes to prevent incremental leak.
runtime.GC()
}
duration := time.Since(totalStart)

View file

@ -71,13 +71,10 @@ func RunExpand(cfg ExpandOpts) error {
}
defer duckDB.Close()
rows, err := duckDB.QueryExpansionPrompts("pending", cfg.Limit)
if err != nil {
return fmt.Errorf("query expansion_prompts: %v", err)
}
log.Printf("loaded %d pending prompts from %s", len(rows), cfg.DB)
for _, r := range rows {
for r, err := range duckDB.ExpansionPrompts("pending", cfg.Limit) {
if err != nil {
return fmt.Errorf("expansion prompts error: %v", err)
}
prompt := r.Prompt
if prompt == "" && r.PromptEn != "" {
prompt = r.PromptEn // Use English translation if primary is empty.
@ -88,6 +85,7 @@ func RunExpand(cfg ExpandOpts) error {
Prompt: prompt,
})
}
log.Printf("loaded %d pending prompts from %s", len(promptList), cfg.DB)
} else {
var err error
promptList, err = ReadResponses(cfg.Prompts)

View file

@ -62,14 +62,10 @@ func RunExport(cfg ExportOpts) error {
}
defer db.Close()
rows, err := db.QueryGoldenSet(cfg.MinChars)
if err != nil {
return fmt.Errorf("query golden_set: %w", err)
}
log.Printf("loaded %d golden set rows from %s (min_chars=%d)", len(rows), cfg.DBPath, cfg.MinChars)
// Convert GoldenSetRow → Response for the shared pipeline.
for _, r := range rows {
for r, err := range db.GoldenSet(cfg.MinChars) {
if err != nil {
return fmt.Errorf("golden set error: %w", err)
}
responses = append(responses, Response{
ID: r.SeedID,
Domain: r.Domain,
@ -78,6 +74,7 @@ func RunExport(cfg ExportOpts) error {
Model: r.Voice, // voice maps to the "model" slot for tracking
})
}
log.Printf("loaded %d golden set rows from %s (min_chars=%d)", len(responses), cfg.DBPath, cfg.MinChars)
} else {
// Fallback: read from JSONL file.
var err error

View file

@ -53,7 +53,9 @@ func RunImport(cfg ImportOpts) error {
}
}
if _, err := os.Stat(goldenPath); err == nil {
db.conn.Exec("DROP TABLE IF EXISTS golden_set")
if _, err := db.conn.Exec("DROP TABLE IF EXISTS golden_set"); err != nil {
log.Printf(" WARNING: drop golden_set failed: %v", err)
}
_, err := db.conn.Exec(fmt.Sprintf(`
CREATE TABLE golden_set AS
SELECT
@ -114,8 +116,10 @@ func RunImport(cfg ImportOpts) error {
}
}
db.conn.Exec("DROP TABLE IF EXISTS training_examples")
db.conn.Exec(`
if _, err := db.conn.Exec("DROP TABLE IF EXISTS training_examples"); err != nil {
log.Printf(" WARNING: drop training_examples failed: %v", err)
}
if _, err := db.conn.Exec(`
CREATE TABLE training_examples (
source VARCHAR,
split VARCHAR,
@ -125,7 +129,9 @@ func RunImport(cfg ImportOpts) error {
full_messages TEXT,
char_count INT
)
`)
`); err != nil {
log.Printf(" WARNING: create training_examples failed: %v", err)
}
trainingTotal := 0
for _, td := range trainingDirs {
@ -171,13 +177,17 @@ func RunImport(cfg ImportOpts) error {
}
}
db.conn.Exec("DROP TABLE IF EXISTS benchmark_results")
db.conn.Exec(`
if _, err := db.conn.Exec("DROP TABLE IF EXISTS benchmark_results"); err != nil {
log.Printf(" WARNING: drop benchmark_results failed: %v", err)
}
if _, err := db.conn.Exec(`
CREATE TABLE benchmark_results (
source VARCHAR, id VARCHAR, benchmark VARCHAR, model VARCHAR,
prompt TEXT, response TEXT, elapsed_seconds DOUBLE, domain VARCHAR
)
`)
`); err != nil {
log.Printf(" WARNING: create benchmark_results failed: %v", err)
}
benchTotal := 0
for _, subdir := range []string{"results", "scale_results", "cross_arch_results", "deepseek-r1-7b"} {
@ -208,13 +218,17 @@ func RunImport(cfg ImportOpts) error {
fmt.Printf(" benchmark_results: %d rows\n", benchTotal)
// ── 4. Benchmark questions ──
db.conn.Exec("DROP TABLE IF EXISTS benchmark_questions")
db.conn.Exec(`
if _, err := db.conn.Exec("DROP TABLE IF EXISTS benchmark_questions"); err != nil {
log.Printf(" WARNING: drop benchmark_questions failed: %v", err)
}
if _, err := db.conn.Exec(`
CREATE TABLE benchmark_questions (
benchmark VARCHAR, id VARCHAR, question TEXT,
best_answer TEXT, correct_answers TEXT, incorrect_answers TEXT, category VARCHAR
)
`)
`); err != nil {
log.Printf(" WARNING: create benchmark_questions failed: %v", err)
}
benchQTotal := 0
for _, bname := range []string{"truthfulqa", "gsm8k", "do_not_answer", "toxigen"} {
@ -228,12 +242,16 @@ func RunImport(cfg ImportOpts) error {
fmt.Printf(" benchmark_questions: %d rows\n", benchQTotal)
// ── 5. Seeds ──
db.conn.Exec("DROP TABLE IF EXISTS seeds")
db.conn.Exec(`
if _, err := db.conn.Exec("DROP TABLE IF EXISTS seeds"); err != nil {
log.Printf(" WARNING: drop seeds failed: %v", err)
}
if _, err := db.conn.Exec(`
CREATE TABLE seeds (
source_file VARCHAR, region VARCHAR, seed_id VARCHAR, domain VARCHAR, prompt TEXT
)
`)
`); err != nil {
log.Printf(" WARNING: create seeds failed: %v", err)
}
seedTotal := 0
seedDirs := []string{filepath.Join(dataDir, "seeds"), "/tmp/lem-data/seeds", "/tmp/lem-repo/seeds"}
@ -297,8 +315,10 @@ func importTrainingFile(db *DB, path, source, split string) int {
}
msgsJSON, _ := json.Marshal(rec.Messages)
db.conn.Exec(`INSERT INTO training_examples VALUES (?, ?, ?, ?, ?, ?, ?)`,
source, split, prompt, response, assistantCount, string(msgsJSON), len(response))
if _, err := db.conn.Exec(`INSERT INTO training_examples VALUES (?, ?, ?, ?, ?, ?, ?)`,
source, split, prompt, response, assistantCount, string(msgsJSON), len(response)); err != nil {
log.Printf(" WARNING: insert training example failed: %v", err)
}
count++
}
return count
@ -321,7 +341,7 @@ func importBenchmarkFile(db *DB, path, source string) int {
continue
}
db.conn.Exec(`INSERT INTO benchmark_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
if _, err := db.conn.Exec(`INSERT INTO benchmark_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
source,
fmt.Sprintf("%v", rec["id"]),
strOrEmpty(rec, "benchmark"),
@ -330,7 +350,9 @@ func importBenchmarkFile(db *DB, path, source string) int {
strOrEmpty(rec, "response"),
floatOrZero(rec, "elapsed_seconds"),
strOrEmpty(rec, "domain"),
)
); err != nil {
log.Printf(" WARNING: insert benchmark result failed: %v", err)
}
count++
}
return count
@ -356,7 +378,7 @@ func importBenchmarkQuestions(db *DB, path, benchmark string) int {
correctJSON, _ := json.Marshal(rec["correct_answers"])
incorrectJSON, _ := json.Marshal(rec["incorrect_answers"])
db.conn.Exec(`INSERT INTO benchmark_questions VALUES (?, ?, ?, ?, ?, ?, ?)`,
if _, err := db.conn.Exec(`INSERT INTO benchmark_questions VALUES (?, ?, ?, ?, ?, ?, ?)`,
benchmark,
fmt.Sprintf("%v", rec["id"]),
strOrEmpty(rec, "question"),
@ -364,7 +386,9 @@ func importBenchmarkQuestions(db *DB, path, benchmark string) int {
string(correctJSON),
string(incorrectJSON),
strOrEmpty(rec, "category"),
)
); err != nil {
log.Printf(" WARNING: insert benchmark question failed: %v", err)
}
count++
}
return count
@ -413,16 +437,20 @@ func importSeeds(db *DB, seedDir string) int {
if prompt == "" {
prompt = strOrEmpty(seed, "question")
}
db.conn.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
if _, err := db.conn.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
rel, region,
strOrEmpty(seed, "seed_id"),
strOrEmpty(seed, "domain"),
prompt,
)
); err != nil {
log.Printf(" WARNING: insert seed failed: %v", err)
}
count++
case string:
db.conn.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
rel, region, "", "", seed)
if _, err := db.conn.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
rel, region, "", "", seed); err != nil {
log.Printf(" WARNING: insert string seed failed: %v", err)
}
count++
}
}

View file

@ -71,7 +71,19 @@ func exportSplitParquet(jsonlPath, outputDir, split string) (int, error) {
}
defer f.Close()
var rows []ParquetRow
outPath := filepath.Join(outputDir, split+".parquet")
out, err := os.Create(outPath)
if err != nil {
return 0, fmt.Errorf("create %s: %w", outPath, err)
}
defer out.Close()
writer := parquet.NewGenericWriter[ParquetRow](out,
parquet.Compression(&parquet.Snappy),
)
defer writer.Close()
count := 0
scanner := bufio.NewScanner(f)
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
@ -107,51 +119,39 @@ func exportSplitParquet(jsonlPath, outputDir, split string) (int, error) {
}
msgsJSON, _ := json.Marshal(data.Messages)
rows = append(rows, ParquetRow{
row := ParquetRow{
Prompt: prompt,
Response: response,
System: system,
Messages: string(msgsJSON),
})
}
if _, err := writer.Write([]ParquetRow{row}); err != nil {
return count, fmt.Errorf("write parquet row: %w", err)
}
count++
}
if err := scanner.Err(); err != nil {
return 0, fmt.Errorf("scan %s: %w", jsonlPath, err)
return count, fmt.Errorf("scan %s: %w", jsonlPath, err)
}
if len(rows) == 0 {
if count == 0 {
fmt.Printf(" Skip: %s — no data\n", split)
return 0, nil
}
outPath := filepath.Join(outputDir, split+".parquet")
out, err := os.Create(outPath)
if err != nil {
return 0, fmt.Errorf("create %s: %w", outPath, err)
}
writer := parquet.NewGenericWriter[ParquetRow](out,
parquet.Compression(&parquet.Snappy),
)
if _, err := writer.Write(rows); err != nil {
out.Close()
return 0, fmt.Errorf("write parquet rows: %w", err)
}
if err := writer.Close(); err != nil {
out.Close()
return 0, fmt.Errorf("close parquet writer: %w", err)
return count, fmt.Errorf("close parquet writer: %w", err)
}
if err := out.Close(); err != nil {
return 0, fmt.Errorf("close file: %w", err)
return count, fmt.Errorf("close file: %w", err)
}
info, _ := os.Stat(outPath)
sizeMB := float64(info.Size()) / 1024 / 1024
fmt.Printf(" %s.parquet: %d rows (%.1f MB)\n", split, len(rows), sizeMB)
fmt.Printf(" %s.parquet: %d rows (%.1f MB)\n", split, count, sizeMB)
return len(rows), nil
return count, nil
}

View file

@ -2,7 +2,6 @@ package lem
import (
"fmt"
"log"
"os"
"strings"
)
@ -35,7 +34,7 @@ func RunTierScore(cfg TierScoreOpts) error {
defer db.Close()
// Ensure expansion_scores table exists.
db.conn.Exec(`
if _, err := db.conn.Exec(`
CREATE TABLE IF NOT EXISTS expansion_scores (
idx INT,
heuristic_score DOUBLE,
@ -49,10 +48,14 @@ func RunTierScore(cfg TierScoreOpts) error {
judge_model VARCHAR,
scored_at TIMESTAMP
)
`)
`); err != nil {
return fmt.Errorf("create expansion_scores: %w", err)
}
if cfg.Tier >= 1 {
runHeuristicTier(db, cfg.Limit)
if err := runHeuristicTier(db, cfg.Limit); err != nil {
return fmt.Errorf("tier 1: %w", err)
}
}
if cfg.Tier >= 2 {
@ -67,7 +70,7 @@ func RunTierScore(cfg TierScoreOpts) error {
return nil
}
func runHeuristicTier(db *DB, limit int) {
func runHeuristicTier(db *DB, limit int) error {
// Find unscored responses.
query := `
SELECT r.idx, r.response FROM expansion_raw r
@ -81,7 +84,7 @@ func runHeuristicTier(db *DB, limit int) {
rows, err := db.conn.Query(query)
if err != nil {
log.Fatalf("query unscored: %v", err)
return fmt.Errorf("query unscored: %w", err)
}
defer rows.Close()
@ -99,7 +102,7 @@ func runHeuristicTier(db *DB, limit int) {
if len(unscored) == 0 {
fmt.Println("Tier 1 (heuristic): all responses already scored")
return
return nil
}
fmt.Printf("Tier 1 (heuristic): scoring %d responses...\n", len(unscored))
@ -112,16 +115,19 @@ func runHeuristicTier(db *DB, limit int) {
passed++
}
db.conn.Exec(`
if _, err := db.conn.Exec(`
INSERT INTO expansion_scores (idx, heuristic_score, heuristic_pass, scored_at)
VALUES (?, ?, ?, current_timestamp)
`, r.idx, score, isPass)
`, r.idx, score, isPass); err != nil {
fmt.Printf(" WARNING: insert score failed for idx %d: %v\n", r.idx, err)
}
}
fmt.Printf(" Scored: %d, Passed: %d, Failed: %d\n", len(unscored), passed, len(unscored)-passed)
if len(unscored) > 0 {
fmt.Printf(" Pass rate: %.1f%%\n", float64(passed)/float64(len(unscored))*100)
}
return nil
}
// heuristicExpansionScore applies fast heuristic scoring to an expansion response.

View file

@ -2,14 +2,17 @@ package lem
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"os/signal"
"path/filepath"
"runtime"
"syscall"
"time"
)
@ -151,6 +154,12 @@ func RunWorker(opts WorkerOpts) error {
}
log.Println("Registered with LEM API")
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
ticker := time.NewTicker(cfg.pollInterval)
defer ticker.Stop()
// Main loop.
for {
processed := workerPoll(&cfg)
@ -162,7 +171,20 @@ func RunWorker(opts WorkerOpts) error {
if processed == 0 {
log.Printf("No tasks available, sleeping %v", cfg.pollInterval)
time.Sleep(cfg.pollInterval)
select {
case <-ctx.Done():
log.Println("Worker shutting down...")
return nil
case <-ticker.C:
}
} else {
// Check context between tasks if needed.
select {
case <-ctx.Done():
log.Println("Worker shutting down...")
return nil
default:
}
}
// Heartbeat.