diff --git a/pkg/lem/setup.go b/pkg/lem/setup.go new file mode 100644 index 0000000..6f3cd25 --- /dev/null +++ b/pkg/lem/setup.go @@ -0,0 +1,332 @@ +// setup.go — `lem setup --data` hydrates cold JSONL.zst into warm DuckDB + hot InfluxDB. +package lem + +import ( + "bufio" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "strings" +) + +// SetupOpts holds configuration for the setup --data command. +type SetupOpts struct { + Root string // LEM repo root (defaults to ".") + DB string // DuckDB path (defaults to LEM_DB env, then ./lem.duckdb) + Influx string // InfluxDB URL (empty = default) + InfluxDB string // InfluxDB database name (empty = default) + SkipInflux bool // Skip InfluxDB backfill +} + +// RunSetup hydrates cold storage (JSONL.zst) into warm (DuckDB) and hot (InfluxDB). +func RunSetup(cfg SetupOpts) error { + root := cfg.Root + if root == "" { + root = "." + } + + dbPath := cfg.DB + if dbPath == "" { + dbPath = os.Getenv("LEM_DB") + } + if dbPath == "" { + dbPath = filepath.Join(root, "lem.duckdb") + } + + log.Printf("hydrating cold storage → %s", dbPath) + + db, err := OpenDBReadWrite(dbPath) + if err != nil { + return fmt.Errorf("open db: %w", err) + } + defer db.Close() + + // Create tables if they don't exist. + if err := ensureSetupTables(db); err != nil { + return fmt.Errorf("create tables: %w", err) + } + + trainingDir := filepath.Join(root, "training") + totals := make(map[string]int) + + // -- 1. Training examples from JSONL.zst -- + fmt.Println(" Hydrating training examples...") + trainingTotal := 0 + walkZstFiles(trainingDir, func(zstPath string) error { + rel, _ := filepath.Rel(trainingDir, zstPath) + // Skip seeds (they're structured differently) + if strings.HasPrefix(rel, "seeds/") || strings.HasPrefix(rel, "seeds"+string(filepath.Separator)) { + return nil + } + + source := strings.TrimSuffix(rel, ".jsonl.zst") + split := "train" + if strings.Contains(rel, "valid") { + split = "valid" + } else if strings.Contains(rel, "test") { + split = "test" + } + + n, err := hydrateTrainingFromZst(db, zstPath, source, split) + if err != nil { + log.Printf(" WARNING: %s: %v", rel, err) + return nil + } + if n > 0 { + trainingTotal += n + log.Printf(" %s: %d examples", rel, n) + } + return nil + }) + totals["training_examples"] = trainingTotal + + // -- 2. Seeds from JSONL.zst -- + fmt.Println(" Hydrating seeds...") + seedsDir := filepath.Join(trainingDir, "seeds") + seedTotal := 0 + if _, err := os.Stat(seedsDir); err == nil { + seedTotal = importSeeds(db, seedsDir) + } + // Also try .zst seed files + if _, err := os.Stat(seedsDir); err == nil { + walkZstFiles(seedsDir, func(zstPath string) error { + tmpFile, err := decompressToTemp(zstPath) + if err != nil { + log.Printf(" WARNING: decompress %s: %v", zstPath, err) + return nil + } + defer os.Remove(tmpFile) + + // Seeds are JSONL format — import as seeds + n := importSeedsFromJSONL(db, tmpFile, zstPath) + seedTotal += n + return nil + }) + } + totals["seeds"] = seedTotal + + // -- 3. Probes from JSON files -- + fmt.Println(" Hydrating probes...") + probeTotal := 0 + probesCfg, cfgErr := LoadProbesConfig(root) + if cfgErr == nil { + for setName, set := range probesCfg.Sets { + phase := 0 + if set.Phase != nil { + phase = *set.Phase + } + for _, f := range set.Files { + probePath := filepath.Join(root, "training", "lem", f) + n, err := hydrateProbes(db, probePath, phase) + if err != nil { + log.Printf(" WARNING: probes %s/%s: %v", setName, f, err) + continue + } + probeTotal += n + } + } + } + totals["probes"] = probeTotal + + // -- 4. InfluxDB backfill -- + if !cfg.SkipInflux { + fmt.Println(" Backfilling InfluxDB...") + influx := NewInfluxClient(cfg.Influx, cfg.InfluxDB) + backfilled := backfillInflux(db, influx) + fmt.Printf(" influx: %d points written\n", backfilled) + } + + // -- Summary -- + fmt.Printf("\n%s\n", strings.Repeat("=", 50)) + fmt.Println("LEM Data Hydration Complete") + fmt.Println(strings.Repeat("=", 50)) + grandTotal := 0 + for table, count := range totals { + fmt.Printf(" %-25s %8d\n", table, count) + grandTotal += count + } + fmt.Printf(" %s\n", strings.Repeat("-", 35)) + fmt.Printf(" %-25s %8d\n", "TOTAL", grandTotal) + fmt.Printf("\nDatabase: %s\n", dbPath) + return nil +} + +// ensureSetupTables creates tables if they don't exist. +func ensureSetupTables(db *DB) error { + tables := []string{ + `CREATE TABLE IF NOT EXISTS training_examples ( + source VARCHAR, split VARCHAR, prompt TEXT, response TEXT, + num_turns INT, full_messages TEXT, char_count INT + )`, + `CREATE TABLE IF NOT EXISTS seeds ( + source_file VARCHAR, region VARCHAR, seed_id VARCHAR, domain VARCHAR, prompt TEXT + )`, + `CREATE TABLE IF NOT EXISTS probes ( + id TEXT, domain TEXT, prompt TEXT, phase INTEGER, source_file TEXT + )`, + `CREATE TABLE IF NOT EXISTS distill_results ( + probe_id TEXT, model TEXT, phase INTEGER, + prompt TEXT, response TEXT, source_file TEXT, + imported_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + )`, + } + for _, ddl := range tables { + if _, err := db.conn.Exec(ddl); err != nil { + return fmt.Errorf("exec DDL: %w", err) + } + } + return nil +} + +// hydrateTrainingFromZst decompresses a .jsonl.zst file and imports chat-format +// training examples into DuckDB. Returns the number of rows imported. +func hydrateTrainingFromZst(db *DB, zstPath, source, split string) (int, error) { + tmpFile, err := decompressToTemp(zstPath) + if err != nil { + return 0, err + } + defer os.Remove(tmpFile) + + // Ensure table exists before importing + ensureSetupTables(db) + + return importTrainingFile(db, tmpFile, source, split), nil +} + +// decompressToTemp decompresses a .zst file to a temporary file, returning the path. +func decompressToTemp(zstPath string) (string, error) { + tmp, err := os.CreateTemp("", "lem-hydrate-*.jsonl") + if err != nil { + return "", fmt.Errorf("create temp: %w", err) + } + tmpPath := tmp.Name() + tmp.Close() + + if err := decompressZstd(zstPath, tmpPath); err != nil { + os.Remove(tmpPath) + return "", err + } + return tmpPath, nil +} + +// hydrateProbes reads a JSON probe file and inserts into the probes table. +func hydrateProbes(db *DB, path string, phase int) (int, error) { + data, err := os.ReadFile(path) + if err != nil { + return 0, fmt.Errorf("read %s: %w", path, err) + } + + // Ensure table exists + ensureSetupTables(db) + + var probes []struct { + ID string `json:"id"` + Domain string `json:"domain"` + Prompt string `json:"prompt"` + } + if err := json.Unmarshal(data, &probes); err != nil { + return 0, fmt.Errorf("parse %s: %w", filepath.Base(path), err) + } + + rel := filepath.Base(path) + count := 0 + for _, p := range probes { + if _, err := db.conn.Exec( + "INSERT INTO probes VALUES (?, ?, ?, ?, ?)", + p.ID, p.Domain, p.Prompt, phase, rel, + ); err != nil { + log.Printf(" WARNING: insert probe %s: %v", p.ID, err) + continue + } + count++ + } + return count, nil +} + +// importSeedsFromJSONL imports seed data from a decompressed JSONL file. +func importSeedsFromJSONL(db *DB, path, sourcePath string) int { + f, err := os.Open(path) + if err != nil { + return 0 + } + defer f.Close() + + rel := filepath.Base(sourcePath) + count := 0 + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + + for scanner.Scan() { + var rec map[string]any + if err := json.Unmarshal(scanner.Bytes(), &rec); err != nil { + continue + } + prompt := strOrEmpty(rec, "prompt") + if prompt == "" { + prompt = strOrEmpty(rec, "text") + } + if prompt == "" { + continue + } + if _, err := db.conn.Exec("INSERT INTO seeds VALUES (?, ?, ?, ?, ?)", + rel, strOrEmpty(rec, "region"), + strOrEmpty(rec, "seed_id"), strOrEmpty(rec, "domain"), + prompt, + ); err != nil { + continue + } + count++ + } + return count +} + +// backfillInflux writes aggregate stats from DuckDB to InfluxDB. +// Best-effort — returns count of points written. +func backfillInflux(db *DB, influx *InfluxClient) int { + count := 0 + + // Training examples summary + rows, err := db.QueryRows("SELECT source, split, COUNT(*) AS n, SUM(char_count) AS chars FROM training_examples GROUP BY source, split") + if err == nil { + var lines []string + for _, row := range rows { + source := fmt.Sprintf("%v", row["source"]) + split := fmt.Sprintf("%v", row["split"]) + n := row["n"] + chars := row["chars"] + line := fmt.Sprintf("training_hydration,source=%s,split=%s examples=%vi,chars=%vi", + escapeLp(source), escapeLp(split), n, chars) + lines = append(lines, line) + } + if len(lines) > 0 { + if err := influx.WriteLp(lines); err != nil { + log.Printf(" influx training stats: %v", err) + } else { + count += len(lines) + } + } + } + + // Probes summary + probeRows, err := db.QueryRows("SELECT phase, COUNT(*) AS n FROM probes GROUP BY phase ORDER BY phase") + if err == nil { + var lines []string + for _, row := range probeRows { + phase := fmt.Sprintf("%v", row["phase"]) + n := row["n"] + line := fmt.Sprintf("probe_hydration,phase=%s count=%vi", phase, n) + lines = append(lines, line) + } + if len(lines) > 0 { + if err := influx.WriteLp(lines); err != nil { + log.Printf(" influx probe stats: %v", err) + } else { + count += len(lines) + } + } + } + + return count +} diff --git a/pkg/lem/setup_test.go b/pkg/lem/setup_test.go new file mode 100644 index 0000000..69b2e44 --- /dev/null +++ b/pkg/lem/setup_test.go @@ -0,0 +1,71 @@ +package lem + +import ( + "os" + "path/filepath" + "testing" +) + +func TestHydrateTrainingExamples(t *testing.T) { + dir := t.TempDir() + + // Create a .jsonl.zst file with two training examples + jsonl := `{"messages":[{"role":"user","content":"What is kindness?"},{"role":"assistant","content":"Kindness is..."}]} +{"messages":[{"role":"user","content":"Explain gravity"},{"role":"assistant","content":"Gravity is..."}]} +` + src := filepath.Join(dir, "train.jsonl") + os.WriteFile(src, []byte(jsonl), 0644) + zst := src + ".zst" + compressFileZstd(src, zst) + os.Remove(src) + + // Hydrate into DuckDB + dbPath := filepath.Join(dir, "test.duckdb") + db, err := OpenDBReadWrite(dbPath) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + n, err := hydrateTrainingFromZst(db, zst, "test-source", "train") + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Fatalf("expected 2 rows, got %d", n) + } + + // Verify data + rows, err := db.QueryRows("SELECT COUNT(*) AS n FROM training_examples") + if err != nil { + t.Fatal(err) + } + count := rows[0]["n"].(int64) + if count != 2 { + t.Fatalf("expected 2 rows in DB, got %d", count) + } +} + +func TestHydrateProbes(t *testing.T) { + dir := t.TempDir() + + // Create a probe JSON file + probeJSON := `[{"id":"P01","domain":"Identity","prompt":"Design auth..."},{"id":"P02","domain":"Network","prompt":"Build mesh..."}]` + probeFile := filepath.Join(dir, "core.json") + os.WriteFile(probeFile, []byte(probeJSON), 0644) + + dbPath := filepath.Join(dir, "test.duckdb") + db, err := OpenDBReadWrite(dbPath) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + n, err := hydrateProbes(db, probeFile, 0) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Fatalf("expected 2 probes, got %d", n) + } +}