RunSetup decompresses .jsonl.zst training data into DuckDB tables (training_examples, seeds, probes, distill_results) and optionally backfills InfluxDB with aggregate stats. Co-Authored-By: Virgil <virgil@lethean.io>
332 lines · 8.8 KiB · Go
// setup.go — `lem setup --data` hydrates cold JSONL.zst into warm DuckDB + hot InfluxDB.
|
|
package lem
|
|
|
|
import (
	"bufio"
	"encoding/json"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"sort"
	"strings"
)
|
|
|
|
// SetupOpts holds configuration for the `lem setup --data` command.
// Zero values are valid: every field has a sensible default resolved in RunSetup.
type SetupOpts struct {
	Root string // LEM repo root (defaults to ".")
	DB string // DuckDB path (defaults to LEM_DB env, then ./lem.duckdb)
	Influx string // InfluxDB URL (empty = default)
	InfluxDB string // InfluxDB database name (empty = default)
	SkipInflux bool // Skip InfluxDB backfill
}
|
|
|
|
// RunSetup hydrates cold storage (JSONL.zst) into warm (DuckDB) and hot (InfluxDB).
|
|
func RunSetup(cfg SetupOpts) error {
|
|
root := cfg.Root
|
|
if root == "" {
|
|
root = "."
|
|
}
|
|
|
|
dbPath := cfg.DB
|
|
if dbPath == "" {
|
|
dbPath = os.Getenv("LEM_DB")
|
|
}
|
|
if dbPath == "" {
|
|
dbPath = filepath.Join(root, "lem.duckdb")
|
|
}
|
|
|
|
log.Printf("hydrating cold storage → %s", dbPath)
|
|
|
|
db, err := OpenDBReadWrite(dbPath)
|
|
if err != nil {
|
|
return fmt.Errorf("open db: %w", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
// Create tables if they don't exist.
|
|
if err := ensureSetupTables(db); err != nil {
|
|
return fmt.Errorf("create tables: %w", err)
|
|
}
|
|
|
|
trainingDir := filepath.Join(root, "training")
|
|
totals := make(map[string]int)
|
|
|
|
// -- 1. Training examples from JSONL.zst --
|
|
fmt.Println(" Hydrating training examples...")
|
|
trainingTotal := 0
|
|
walkZstFiles(trainingDir, func(zstPath string) error {
|
|
rel, _ := filepath.Rel(trainingDir, zstPath)
|
|
// Skip seeds (they're structured differently)
|
|
if strings.HasPrefix(rel, "seeds/") || strings.HasPrefix(rel, "seeds"+string(filepath.Separator)) {
|
|
return nil
|
|
}
|
|
|
|
source := strings.TrimSuffix(rel, ".jsonl.zst")
|
|
split := "train"
|
|
if strings.Contains(rel, "valid") {
|
|
split = "valid"
|
|
} else if strings.Contains(rel, "test") {
|
|
split = "test"
|
|
}
|
|
|
|
n, err := hydrateTrainingFromZst(db, zstPath, source, split)
|
|
if err != nil {
|
|
log.Printf(" WARNING: %s: %v", rel, err)
|
|
return nil
|
|
}
|
|
if n > 0 {
|
|
trainingTotal += n
|
|
log.Printf(" %s: %d examples", rel, n)
|
|
}
|
|
return nil
|
|
})
|
|
totals["training_examples"] = trainingTotal
|
|
|
|
// -- 2. Seeds from JSONL.zst --
|
|
fmt.Println(" Hydrating seeds...")
|
|
seedsDir := filepath.Join(trainingDir, "seeds")
|
|
seedTotal := 0
|
|
if _, err := os.Stat(seedsDir); err == nil {
|
|
seedTotal = importSeeds(db, seedsDir)
|
|
}
|
|
// Also try .zst seed files
|
|
if _, err := os.Stat(seedsDir); err == nil {
|
|
walkZstFiles(seedsDir, func(zstPath string) error {
|
|
tmpFile, err := decompressToTemp(zstPath)
|
|
if err != nil {
|
|
log.Printf(" WARNING: decompress %s: %v", zstPath, err)
|
|
return nil
|
|
}
|
|
defer os.Remove(tmpFile)
|
|
|
|
// Seeds are JSONL format — import as seeds
|
|
n := importSeedsFromJSONL(db, tmpFile, zstPath)
|
|
seedTotal += n
|
|
return nil
|
|
})
|
|
}
|
|
totals["seeds"] = seedTotal
|
|
|
|
// -- 3. Probes from JSON files --
|
|
fmt.Println(" Hydrating probes...")
|
|
probeTotal := 0
|
|
probesCfg, cfgErr := LoadProbesConfig(root)
|
|
if cfgErr == nil {
|
|
for setName, set := range probesCfg.Sets {
|
|
phase := 0
|
|
if set.Phase != nil {
|
|
phase = *set.Phase
|
|
}
|
|
for _, f := range set.Files {
|
|
probePath := filepath.Join(root, "training", "lem", f)
|
|
n, err := hydrateProbes(db, probePath, phase)
|
|
if err != nil {
|
|
log.Printf(" WARNING: probes %s/%s: %v", setName, f, err)
|
|
continue
|
|
}
|
|
probeTotal += n
|
|
}
|
|
}
|
|
}
|
|
totals["probes"] = probeTotal
|
|
|
|
// -- 4. InfluxDB backfill --
|
|
if !cfg.SkipInflux {
|
|
fmt.Println(" Backfilling InfluxDB...")
|
|
influx := NewInfluxClient(cfg.Influx, cfg.InfluxDB)
|
|
backfilled := backfillInflux(db, influx)
|
|
fmt.Printf(" influx: %d points written\n", backfilled)
|
|
}
|
|
|
|
// -- Summary --
|
|
fmt.Printf("\n%s\n", strings.Repeat("=", 50))
|
|
fmt.Println("LEM Data Hydration Complete")
|
|
fmt.Println(strings.Repeat("=", 50))
|
|
grandTotal := 0
|
|
for table, count := range totals {
|
|
fmt.Printf(" %-25s %8d\n", table, count)
|
|
grandTotal += count
|
|
}
|
|
fmt.Printf(" %s\n", strings.Repeat("-", 35))
|
|
fmt.Printf(" %-25s %8d\n", "TOTAL", grandTotal)
|
|
fmt.Printf("\nDatabase: %s\n", dbPath)
|
|
return nil
|
|
}
|
|
|
|
// ensureSetupTables creates tables if they don't exist.
|
|
func ensureSetupTables(db *DB) error {
|
|
tables := []string{
|
|
`CREATE TABLE IF NOT EXISTS training_examples (
|
|
source VARCHAR, split VARCHAR, prompt TEXT, response TEXT,
|
|
num_turns INT, full_messages TEXT, char_count INT
|
|
)`,
|
|
`CREATE TABLE IF NOT EXISTS seeds (
|
|
source_file VARCHAR, region VARCHAR, seed_id VARCHAR, domain VARCHAR, prompt TEXT
|
|
)`,
|
|
`CREATE TABLE IF NOT EXISTS probes (
|
|
id TEXT, domain TEXT, prompt TEXT, phase INTEGER, source_file TEXT
|
|
)`,
|
|
`CREATE TABLE IF NOT EXISTS distill_results (
|
|
probe_id TEXT, model TEXT, phase INTEGER,
|
|
prompt TEXT, response TEXT, source_file TEXT,
|
|
imported_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
)`,
|
|
}
|
|
for _, ddl := range tables {
|
|
if _, err := db.conn.Exec(ddl); err != nil {
|
|
return fmt.Errorf("exec DDL: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// hydrateTrainingFromZst decompresses a .jsonl.zst file and imports chat-format
|
|
// training examples into DuckDB. Returns the number of rows imported.
|
|
func hydrateTrainingFromZst(db *DB, zstPath, source, split string) (int, error) {
|
|
tmpFile, err := decompressToTemp(zstPath)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer os.Remove(tmpFile)
|
|
|
|
// Ensure table exists before importing
|
|
ensureSetupTables(db)
|
|
|
|
return importTrainingFile(db, tmpFile, source, split), nil
|
|
}
|
|
|
|
// decompressToTemp decompresses a .zst file to a temporary file, returning the path.
|
|
func decompressToTemp(zstPath string) (string, error) {
|
|
tmp, err := os.CreateTemp("", "lem-hydrate-*.jsonl")
|
|
if err != nil {
|
|
return "", fmt.Errorf("create temp: %w", err)
|
|
}
|
|
tmpPath := tmp.Name()
|
|
tmp.Close()
|
|
|
|
if err := decompressZstd(zstPath, tmpPath); err != nil {
|
|
os.Remove(tmpPath)
|
|
return "", err
|
|
}
|
|
return tmpPath, nil
|
|
}
|
|
|
|
// hydrateProbes reads a JSON probe file and inserts into the probes table.
|
|
func hydrateProbes(db *DB, path string, phase int) (int, error) {
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("read %s: %w", path, err)
|
|
}
|
|
|
|
// Ensure table exists
|
|
ensureSetupTables(db)
|
|
|
|
var probes []struct {
|
|
ID string `json:"id"`
|
|
Domain string `json:"domain"`
|
|
Prompt string `json:"prompt"`
|
|
}
|
|
if err := json.Unmarshal(data, &probes); err != nil {
|
|
return 0, fmt.Errorf("parse %s: %w", filepath.Base(path), err)
|
|
}
|
|
|
|
rel := filepath.Base(path)
|
|
count := 0
|
|
for _, p := range probes {
|
|
if _, err := db.conn.Exec(
|
|
"INSERT INTO probes VALUES (?, ?, ?, ?, ?)",
|
|
p.ID, p.Domain, p.Prompt, phase, rel,
|
|
); err != nil {
|
|
log.Printf(" WARNING: insert probe %s: %v", p.ID, err)
|
|
continue
|
|
}
|
|
count++
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
// importSeedsFromJSONL imports seed data from a decompressed JSONL file.
|
|
func importSeedsFromJSONL(db *DB, path, sourcePath string) int {
|
|
f, err := os.Open(path)
|
|
if err != nil {
|
|
return 0
|
|
}
|
|
defer f.Close()
|
|
|
|
rel := filepath.Base(sourcePath)
|
|
count := 0
|
|
scanner := bufio.NewScanner(f)
|
|
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
|
|
|
for scanner.Scan() {
|
|
var rec map[string]any
|
|
if err := json.Unmarshal(scanner.Bytes(), &rec); err != nil {
|
|
continue
|
|
}
|
|
prompt := strOrEmpty(rec, "prompt")
|
|
if prompt == "" {
|
|
prompt = strOrEmpty(rec, "text")
|
|
}
|
|
if prompt == "" {
|
|
continue
|
|
}
|
|
if _, err := db.conn.Exec("INSERT INTO seeds VALUES (?, ?, ?, ?, ?)",
|
|
rel, strOrEmpty(rec, "region"),
|
|
strOrEmpty(rec, "seed_id"), strOrEmpty(rec, "domain"),
|
|
prompt,
|
|
); err != nil {
|
|
continue
|
|
}
|
|
count++
|
|
}
|
|
return count
|
|
}
|
|
|
|
// backfillInflux writes aggregate stats from DuckDB to InfluxDB.
|
|
// Best-effort — returns count of points written.
|
|
func backfillInflux(db *DB, influx *InfluxClient) int {
|
|
count := 0
|
|
|
|
// Training examples summary
|
|
rows, err := db.QueryRows("SELECT source, split, COUNT(*) AS n, SUM(char_count) AS chars FROM training_examples GROUP BY source, split")
|
|
if err == nil {
|
|
var lines []string
|
|
for _, row := range rows {
|
|
source := fmt.Sprintf("%v", row["source"])
|
|
split := fmt.Sprintf("%v", row["split"])
|
|
n := row["n"]
|
|
chars := row["chars"]
|
|
line := fmt.Sprintf("training_hydration,source=%s,split=%s examples=%vi,chars=%vi",
|
|
escapeLp(source), escapeLp(split), n, chars)
|
|
lines = append(lines, line)
|
|
}
|
|
if len(lines) > 0 {
|
|
if err := influx.WriteLp(lines); err != nil {
|
|
log.Printf(" influx training stats: %v", err)
|
|
} else {
|
|
count += len(lines)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Probes summary
|
|
probeRows, err := db.QueryRows("SELECT phase, COUNT(*) AS n FROM probes GROUP BY phase ORDER BY phase")
|
|
if err == nil {
|
|
var lines []string
|
|
for _, row := range probeRows {
|
|
phase := fmt.Sprintf("%v", row["phase"])
|
|
n := row["n"]
|
|
line := fmt.Sprintf("probe_hydration,phase=%s count=%vi", phase, n)
|
|
lines = append(lines, line)
|
|
}
|
|
if len(lines) > 0 {
|
|
if err := influx.WriteLp(lines); err != nil {
|
|
log.Printf(" influx probe stats: %v", err)
|
|
} else {
|
|
count += len(lines)
|
|
}
|
|
}
|
|
}
|
|
|
|
return count
|
|
}
|