LEM/pkg/lem/setup.go
Snider 1269e70853 feat: add data hydration engine (cold JSONL.zst -> warm DuckDB -> hot InfluxDB)
RunSetup decompresses .jsonl.zst training data into DuckDB tables
(training_examples, seeds, probes, distill_results) and optionally
backfills InfluxDB with aggregate stats.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-28 12:20:00 +00:00

332 lines
8.8 KiB
Go

// setup.go — `lem setup --data` hydrates cold JSONL.zst into warm DuckDB + hot InfluxDB.
package lem
import (
"bufio"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strings"
)
// SetupOpts holds configuration for the `lem setup --data` command.
// The zero value is usable: every field has a sensible fallback.
type SetupOpts struct {
	Root       string // LEM repo root (defaults to ".")
	DB         string // DuckDB path (defaults to LEM_DB env, then <root>/lem.duckdb)
	Influx     string // InfluxDB URL (empty = client default)
	InfluxDB   string // InfluxDB database name (empty = client default)
	SkipInflux bool   // Skip the InfluxDB backfill stage entirely
}
// RunSetup hydrates cold storage (JSONL.zst) into warm (DuckDB) and hot (InfluxDB).
//
// The pipeline runs four best-effort stages — training examples, seeds,
// probes, and (unless cfg.SkipInflux) an InfluxDB backfill — then prints a
// summary. Per-file failures are logged as warnings rather than aborting.
// It returns an error only when the database itself cannot be opened or its
// tables cannot be created.
func RunSetup(cfg SetupOpts) error {
	root := cfg.Root
	if root == "" {
		root = "."
	}
	// Resolve the DuckDB path: explicit flag, then LEM_DB env, then default.
	dbPath := cfg.DB
	if dbPath == "" {
		dbPath = os.Getenv("LEM_DB")
	}
	if dbPath == "" {
		dbPath = filepath.Join(root, "lem.duckdb")
	}
	log.Printf("hydrating cold storage → %s", dbPath)
	db, err := OpenDBReadWrite(dbPath)
	if err != nil {
		return fmt.Errorf("open db: %w", err)
	}
	defer db.Close()
	// Create tables if they don't exist.
	if err := ensureSetupTables(db); err != nil {
		return fmt.Errorf("create tables: %w", err)
	}
	trainingDir := filepath.Join(root, "training")
	totals := make(map[string]int)

	// -- 1. Training examples from JSONL.zst --
	fmt.Println(" Hydrating training examples...")
	trainingTotal := 0
	// Walk result is intentionally ignored: the callback never returns a
	// non-nil error, and a missing training dir simply yields zero examples.
	walkZstFiles(trainingDir, func(zstPath string) error {
		rel, _ := filepath.Rel(trainingDir, zstPath)
		// Skip seeds (they're structured differently; stage 2 handles them).
		if strings.HasPrefix(rel, "seeds"+string(filepath.Separator)) || strings.HasPrefix(rel, "seeds/") {
			return nil
		}
		source := strings.TrimSuffix(rel, ".jsonl.zst")
		split := "train"
		if strings.Contains(rel, "valid") {
			split = "valid"
		} else if strings.Contains(rel, "test") {
			split = "test"
		}
		n, err := hydrateTrainingFromZst(db, zstPath, source, split)
		if err != nil {
			log.Printf(" WARNING: %s: %v", rel, err)
			return nil
		}
		if n > 0 {
			trainingTotal += n
			log.Printf(" %s: %d examples", rel, n)
		}
		return nil
	})
	totals["training_examples"] = trainingTotal

	// -- 2. Seeds from JSONL.zst --
	fmt.Println(" Hydrating seeds...")
	seedsDir := filepath.Join(trainingDir, "seeds")
	seedTotal := 0
	// Single existence check (previously stat'd twice): plain seed files
	// first, then compressed .zst ones.
	if _, err := os.Stat(seedsDir); err == nil {
		seedTotal = importSeeds(db, seedsDir)
		walkZstFiles(seedsDir, func(zstPath string) error {
			tmpFile, err := decompressToTemp(zstPath)
			if err != nil {
				log.Printf(" WARNING: decompress %s: %v", zstPath, err)
				return nil
			}
			defer os.Remove(tmpFile)
			// Seeds are JSONL format — import as seeds.
			seedTotal += importSeedsFromJSONL(db, tmpFile, zstPath)
			return nil
		})
	}
	totals["seeds"] = seedTotal

	// -- 3. Probes from JSON files --
	fmt.Println(" Hydrating probes...")
	probeTotal := 0
	if probesCfg, cfgErr := LoadProbesConfig(root); cfgErr == nil {
		for setName, set := range probesCfg.Sets {
			phase := 0
			if set.Phase != nil {
				phase = *set.Phase
			}
			for _, f := range set.Files {
				probePath := filepath.Join(root, "training", "lem", f)
				n, err := hydrateProbes(db, probePath, phase)
				if err != nil {
					log.Printf(" WARNING: probes %s/%s: %v", setName, f, err)
					continue
				}
				probeTotal += n
			}
		}
	}
	totals["probes"] = probeTotal

	// -- 4. InfluxDB backfill --
	if !cfg.SkipInflux {
		fmt.Println(" Backfilling InfluxDB...")
		influx := NewInfluxClient(cfg.Influx, cfg.InfluxDB)
		backfilled := backfillInflux(db, influx)
		fmt.Printf(" influx: %d points written\n", backfilled)
	}

	// -- Summary --
	// Iterate a fixed key order instead of ranging over the map, so the
	// summary is deterministic (Go map iteration order is randomized).
	fmt.Printf("\n%s\n", strings.Repeat("=", 50))
	fmt.Println("LEM Data Hydration Complete")
	fmt.Println(strings.Repeat("=", 50))
	grandTotal := 0
	for _, table := range []string{"training_examples", "seeds", "probes"} {
		fmt.Printf(" %-25s %8d\n", table, totals[table])
		grandTotal += totals[table]
	}
	fmt.Printf(" %s\n", strings.Repeat("-", 35))
	fmt.Printf(" %-25s %8d\n", "TOTAL", grandTotal)
	fmt.Printf("\nDatabase: %s\n", dbPath)
	return nil
}
// ensureSetupTables creates the warm-storage tables used by setup --data.
// Each statement is idempotent (CREATE TABLE IF NOT EXISTS), so calling this
// repeatedly is safe. Returns the first DDL execution error, if any.
func ensureSetupTables(db *DB) error {
	ddls := [...]string{
		`CREATE TABLE IF NOT EXISTS training_examples (
source VARCHAR, split VARCHAR, prompt TEXT, response TEXT,
num_turns INT, full_messages TEXT, char_count INT
)`,
		`CREATE TABLE IF NOT EXISTS seeds (
source_file VARCHAR, region VARCHAR, seed_id VARCHAR, domain VARCHAR, prompt TEXT
)`,
		`CREATE TABLE IF NOT EXISTS probes (
id TEXT, domain TEXT, prompt TEXT, phase INTEGER, source_file TEXT
)`,
		`CREATE TABLE IF NOT EXISTS distill_results (
probe_id TEXT, model TEXT, phase INTEGER,
prompt TEXT, response TEXT, source_file TEXT,
imported_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`,
	}
	for _, stmt := range ddls {
		_, err := db.conn.Exec(stmt)
		if err != nil {
			return fmt.Errorf("exec DDL: %w", err)
		}
	}
	return nil
}
// hydrateTrainingFromZst decompresses a .jsonl.zst file and imports chat-format
// training examples into DuckDB, tagging rows with the given source and split.
// The temporary decompressed file is removed before returning. Returns the
// number of rows imported.
func hydrateTrainingFromZst(db *DB, zstPath, source, split string) (int, error) {
	tmpFile, err := decompressToTemp(zstPath)
	if err != nil {
		return 0, err
	}
	defer os.Remove(tmpFile)
	// Ensure the target table exists before importing. Previously this
	// error was silently dropped, which would surface later as confusing
	// insert failures.
	if err := ensureSetupTables(db); err != nil {
		return 0, fmt.Errorf("ensure tables: %w", err)
	}
	return importTrainingFile(db, tmpFile, source, split), nil
}
// decompressToTemp decompresses a .zst file into a fresh temporary file and
// returns its path. On decompression failure the temp file is cleaned up and
// an error is returned; on success the caller owns (and must remove) the file.
func decompressToTemp(zstPath string) (string, error) {
	f, err := os.CreateTemp("", "lem-hydrate-*.jsonl")
	if err != nil {
		return "", fmt.Errorf("create temp: %w", err)
	}
	outPath := f.Name()
	// Close immediately: decompressZstd opens the destination itself.
	f.Close()
	if derr := decompressZstd(zstPath, outPath); derr != nil {
		os.Remove(outPath)
		return "", derr
	}
	return outPath, nil
}
// hydrateProbes reads a JSON probe file (an array of {id, domain, prompt}
// objects) and inserts each probe into the probes table with the given phase.
// Individual insert failures are logged and skipped; returns the number of
// rows actually inserted.
func hydrateProbes(db *DB, path string, phase int) (int, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return 0, fmt.Errorf("read %s: %w", path, err)
	}
	// Ensure the table exists. Previously this error was silently dropped.
	if err := ensureSetupTables(db); err != nil {
		return 0, fmt.Errorf("ensure tables: %w", err)
	}
	var probes []struct {
		ID     string `json:"id"`
		Domain string `json:"domain"`
		Prompt string `json:"prompt"`
	}
	if err := json.Unmarshal(data, &probes); err != nil {
		return 0, fmt.Errorf("parse %s: %w", filepath.Base(path), err)
	}
	// Only the base filename is stored as provenance.
	rel := filepath.Base(path)
	count := 0
	for _, p := range probes {
		if _, err := db.conn.Exec(
			"INSERT INTO probes VALUES (?, ?, ?, ?, ?)",
			p.ID, p.Domain, p.Prompt, phase, rel,
		); err != nil {
			log.Printf(" WARNING: insert probe %s: %v", p.ID, err)
			continue
		}
		count++
	}
	return count, nil
}
// importSeedsFromJSONL imports seed data from a decompressed JSONL file.
// Best-effort: an unopenable file returns 0, and malformed lines, promptless
// records, and failed inserts are skipped. sourcePath supplies the provenance
// filename stored with each row. Returns the number of rows inserted.
func importSeedsFromJSONL(db *DB, path, sourcePath string) int {
	f, err := os.Open(path)
	if err != nil {
		return 0
	}
	defer f.Close()
	rel := filepath.Base(sourcePath)
	count := 0
	scanner := bufio.NewScanner(f)
	// Allow JSONL lines up to 1 MiB.
	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
	for scanner.Scan() {
		var rec map[string]any
		if err := json.Unmarshal(scanner.Bytes(), &rec); err != nil {
			continue // skip malformed line
		}
		// Accept either "prompt" or "text" as the prompt field.
		prompt := strOrEmpty(rec, "prompt")
		if prompt == "" {
			prompt = strOrEmpty(rec, "text")
		}
		if prompt == "" {
			continue // nothing usable in this record
		}
		if _, err := db.conn.Exec("INSERT INTO seeds VALUES (?, ?, ?, ?, ?)",
			rel, strOrEmpty(rec, "region"),
			strOrEmpty(rec, "seed_id"), strOrEmpty(rec, "domain"),
			prompt,
		); err != nil {
			continue
		}
		count++
	}
	// Previously scan errors (e.g. a line exceeding the 1 MiB buffer, which
	// silently truncates the import) were never surfaced — log them.
	if err := scanner.Err(); err != nil {
		log.Printf(" WARNING: scan %s: %v", rel, err)
	}
	return count
}
// backfillInflux writes aggregate stats from DuckDB to InfluxDB.
// Best-effort — every failure is logged rather than returned. Returns the
// count of line-protocol points successfully written.
func backfillInflux(db *DB, influx *InfluxClient) int {
	count := 0

	// Training examples summary: one point per (source, split) pair.
	rows, err := db.QueryRows("SELECT source, split, COUNT(*) AS n, SUM(char_count) AS chars FROM training_examples GROUP BY source, split")
	if err != nil {
		// Previously query failures were silently swallowed; log so an
		// empty backfill is explainable.
		log.Printf(" influx training query: %v", err)
	} else {
		lines := make([]string, 0, len(rows))
		for _, row := range rows {
			source := fmt.Sprintf("%v", row["source"])
			split := fmt.Sprintf("%v", row["split"])
			lines = append(lines, fmt.Sprintf(
				"training_hydration,source=%s,split=%s examples=%vi,chars=%vi",
				escapeLp(source), escapeLp(split), row["n"], row["chars"]))
		}
		count += writeLpBatch(influx, lines, "training stats")
	}

	// Probes summary: one point per phase.
	probeRows, err := db.QueryRows("SELECT phase, COUNT(*) AS n FROM probes GROUP BY phase ORDER BY phase")
	if err != nil {
		log.Printf(" influx probe query: %v", err)
	} else {
		lines := make([]string, 0, len(probeRows))
		for _, row := range probeRows {
			phase := fmt.Sprintf("%v", row["phase"])
			lines = append(lines, fmt.Sprintf("probe_hydration,phase=%s count=%vi", phase, row["n"]))
		}
		count += writeLpBatch(influx, lines, "probe stats")
	}
	return count
}

// writeLpBatch writes one batch of line-protocol points, logging (not
// returning) any write error. Returns the number of points written — 0 for
// an empty batch or a failed write.
func writeLpBatch(influx *InfluxClient, lines []string, label string) int {
	if len(lines) == 0 {
		return 0
	}
	if err := influx.WriteLp(lines); err != nil {
		log.Printf(" influx %s: %v", label, err)
		return 0
	}
	return len(lines)
}