From 1f3a1bcc47874c9e02eaaf6b8d7e7c4e3f91bebf Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 16 Feb 2026 04:02:28 +0000 Subject: [PATCH] feat: port 11 LEM data management commands into core ml Ports all remaining LEM pipeline commands from pkg/lem into core ml, eliminating the standalone LEM CLI dependency. Each command is split into reusable business logic (pkg/ml/) and a thin cobra wrapper (internal/cmd/ml/). New commands: query, inventory, metrics, ingest, normalize, seed-influx, consolidate, import-all, approve, publish, coverage. Adds Path(), Exec(), QueryRowScan() convenience methods to DB type. Co-Authored-By: Claude Opus 4.6 --- internal/cmd/ml/cmd_approve.go | 53 ++++ internal/cmd/ml/cmd_consolidate.go | 41 +++ internal/cmd/ml/cmd_coverage.go | 34 +++ internal/cmd/ml/cmd_import.go | 58 ++++ internal/cmd/ml/cmd_ingest.go | 54 ++++ internal/cmd/ml/cmd_inventory.go | 34 +++ internal/cmd/ml/cmd_metrics.go | 36 +++ internal/cmd/ml/cmd_ml.go | 22 ++ internal/cmd/ml/cmd_normalize.go | 44 +++ internal/cmd/ml/cmd_publish.go | 40 +++ internal/cmd/ml/cmd_query.go | 148 ++++++++++ internal/cmd/ml/cmd_seed_influx.go | 49 ++++ pkg/ml/approve.go | 82 ++++++ pkg/ml/consolidate.go | 150 ++++++++++ pkg/ml/coverage.go | 127 +++++++++ pkg/ml/db.go | 17 ++ pkg/ml/import_all.go | 437 +++++++++++++++++++++++++++++ pkg/ml/ingest.go | 384 +++++++++++++++++++++++++ pkg/ml/inventory.go | 147 ++++++++++ pkg/ml/metrics.go | 100 +++++++ pkg/ml/normalize.go | 153 ++++++++++ pkg/ml/publish.go | 157 +++++++++++ pkg/ml/seed_influx.go | 111 ++++++++ 23 files changed, 2478 insertions(+) create mode 100644 internal/cmd/ml/cmd_approve.go create mode 100644 internal/cmd/ml/cmd_consolidate.go create mode 100644 internal/cmd/ml/cmd_coverage.go create mode 100644 internal/cmd/ml/cmd_import.go create mode 100644 internal/cmd/ml/cmd_ingest.go create mode 100644 internal/cmd/ml/cmd_inventory.go create mode 100644 internal/cmd/ml/cmd_metrics.go create mode 100644 internal/cmd/ml/cmd_normalize.go create mode 100644 internal/cmd/ml/cmd_publish.go create mode 100644 internal/cmd/ml/cmd_query.go create mode 100644 internal/cmd/ml/cmd_seed_influx.go create mode 100644 pkg/ml/approve.go create mode 100644 pkg/ml/consolidate.go create mode 100644 pkg/ml/coverage.go create mode 100644 pkg/ml/import_all.go create mode 100644 pkg/ml/ingest.go create mode 100644 pkg/ml/inventory.go create mode 100644 pkg/ml/metrics.go create mode 100644 pkg/ml/normalize.go create mode 100644 pkg/ml/publish.go create mode 100644 pkg/ml/seed_influx.go diff --git a/internal/cmd/ml/cmd_approve.go b/internal/cmd/ml/cmd_approve.go new file mode 100644 index 00000000..2b7217d7 --- /dev/null +++ b/internal/cmd/ml/cmd_approve.go @@ -0,0 +1,53 @@ +package ml + +import ( + "fmt" + "os" + "path/filepath" + + "forge.lthn.ai/core/cli/pkg/cli" + "forge.lthn.ai/core/cli/pkg/ml" +) + +var ( + approveOutput string + approveThreshold float64 +) + +var approveCmd = &cli.Command{ + Use: "approve", + Short: "Filter scored expansions into training JSONL", + Long: "Filters scored expansion responses by quality threshold and exports approved ones as chat-format training JSONL.", + RunE: runApprove, +} + +func init() { + approveCmd.Flags().StringVar(&approveOutput, "output", "", "Output JSONL file (defaults to expansion-approved.jsonl in db dir)") + approveCmd.Flags().Float64Var(&approveThreshold, "threshold", 6.0, "Min judge average to approve") +} + +func runApprove(cmd *cli.Command, args []string) error { + path := dbPath + if path == "" { + path = os.Getenv("LEM_DB") + } + 
if path == "" { + return fmt.Errorf("--db or LEM_DB required") + } + + output := approveOutput + if output == "" { + output = filepath.Join(filepath.Dir(path), "expansion-approved.jsonl") + } + + db, err := ml.OpenDB(path) + if err != nil { + return fmt.Errorf("open db: %w", err) + } + defer db.Close() + + return ml.ApproveExpansions(db, ml.ApproveConfig{ + Output: output, + Threshold: approveThreshold, + }, cmd.OutOrStdout()) +} diff --git a/internal/cmd/ml/cmd_consolidate.go b/internal/cmd/ml/cmd_consolidate.go new file mode 100644 index 00000000..4185abaa --- /dev/null +++ b/internal/cmd/ml/cmd_consolidate.go @@ -0,0 +1,41 @@ +package ml + +import ( + "forge.lthn.ai/core/cli/pkg/cli" + "forge.lthn.ai/core/cli/pkg/ml" +) + +var ( + consolidateM3Host string + consolidateRemoteDir string + consolidatePattern string + consolidateOutputDir string + consolidateMergedOut string +) + +var consolidateCmd = &cli.Command{ + Use: "consolidate", + Short: "Pull and merge response JSONL files from M3", + Long: "Pulls JSONL response files from M3 via SSH/SCP, merges them by idx, deduplicates, and writes a single merged JSONL output.", + RunE: runConsolidate, +} + +func init() { + consolidateCmd.Flags().StringVar(&consolidateM3Host, "m3-host", "m3", "M3 SSH host") + consolidateCmd.Flags().StringVar(&consolidateRemoteDir, "remote", "/Volumes/Data/lem/responses", "Remote response directory") + consolidateCmd.Flags().StringVar(&consolidatePattern, "pattern", "gold*.jsonl", "File glob pattern") + consolidateCmd.Flags().StringVar(&consolidateOutputDir, "output", "", "Local output directory (default: responses)") + consolidateCmd.Flags().StringVar(&consolidateMergedOut, "merged", "", "Merged output path (default: gold-merged.jsonl in parent of output dir)") +} + +func runConsolidate(cmd *cli.Command, args []string) error { + cfg := ml.ConsolidateConfig{ + M3Host: consolidateM3Host, + RemoteDir: consolidateRemoteDir, + Pattern: consolidatePattern, + OutputDir: consolidateOutputDir, + MergedOut: consolidateMergedOut, + } + + return ml.Consolidate(cfg, cmd.OutOrStdout()) +} diff --git a/internal/cmd/ml/cmd_coverage.go b/internal/cmd/ml/cmd_coverage.go new file mode 100644 index 00000000..2b815327 --- /dev/null +++ b/internal/cmd/ml/cmd_coverage.go @@ -0,0 +1,34 @@ +package ml + +import ( + "fmt" + "os" + + "forge.lthn.ai/core/cli/pkg/cli" + "forge.lthn.ai/core/cli/pkg/ml" +) + +var coverageCmd = &cli.Command{ + Use: "coverage", + Short: "Analyze seed coverage by region and domain", + Long: "Queries seeds by region and domain, renders ASCII bar charts, and highlights underrepresented areas.", + RunE: runCoverage, +} + +func runCoverage(cmd *cli.Command, args []string) error { + path := dbPath + if path == "" { + path = os.Getenv("LEM_DB") + } + if path == "" { + return fmt.Errorf("--db or LEM_DB required") + } + + db, err := ml.OpenDB(path) + if err != nil { + return fmt.Errorf("open db: %w", err) + } + defer db.Close() + + return ml.PrintCoverage(db, cmd.OutOrStdout()) +} diff --git a/internal/cmd/ml/cmd_import.go b/internal/cmd/ml/cmd_import.go new file mode 100644 index 00000000..99937dcf --- /dev/null +++ b/internal/cmd/ml/cmd_import.go @@ -0,0 +1,58 @@ +package ml + +import ( + "fmt" + "os" + "path/filepath" + + "forge.lthn.ai/core/cli/pkg/cli" + "forge.lthn.ai/core/cli/pkg/ml" +) + +var importCmd = &cli.Command{ + Use: "import-all", + Short: "Import all LEM data into DuckDB", + Long: "Imports golden set, training examples, benchmark results, benchmark questions, and seeds into DuckDB from M3 and local 
files.", + RunE: runImportAll, +} + +var ( + importSkipM3 bool + importDataDir string + importM3Host string +) + +func init() { + importCmd.Flags().BoolVar(&importSkipM3, "skip-m3", false, "Skip pulling data from M3") + importCmd.Flags().StringVar(&importDataDir, "data-dir", "", "Local data directory (defaults to db directory)") + importCmd.Flags().StringVar(&importM3Host, "m3-host", "m3", "M3 SSH host alias") +} + +func runImportAll(cmd *cli.Command, args []string) error { + path := dbPath + if path == "" { + path = os.Getenv("LEM_DB") + } + if path == "" { + return fmt.Errorf("--db or LEM_DB required") + } + + dataDir := importDataDir + if dataDir == "" { + dataDir = filepath.Dir(path) + } + + db, err := ml.OpenDBReadWrite(path) + if err != nil { + return fmt.Errorf("open db: %w", err) + } + defer db.Close() + + cfg := ml.ImportConfig{ + SkipM3: importSkipM3, + DataDir: dataDir, + M3Host: importM3Host, + } + + return ml.ImportAll(db, cfg, cmd.OutOrStdout()) +} diff --git a/internal/cmd/ml/cmd_ingest.go b/internal/cmd/ml/cmd_ingest.go new file mode 100644 index 00000000..84bfb674 --- /dev/null +++ b/internal/cmd/ml/cmd_ingest.go @@ -0,0 +1,54 @@ +package ml + +import ( + "fmt" + "os" + + "forge.lthn.ai/core/cli/pkg/cli" + "forge.lthn.ai/core/cli/pkg/ml" +) + +var ingestCmd = &cli.Command{ + Use: "ingest", + Short: "Ingest benchmark scores and training logs into InfluxDB", + Long: "Reads content score, capability score, and training log files and writes measurements to InfluxDB for the lab dashboard.", + RunE: runIngest, +} + +var ( + ingestContent string + ingestCapability string + ingestTraining string + ingestRunID string + ingestBatchSize int +) + +func init() { + ingestCmd.Flags().StringVar(&ingestContent, "content", "", "Content scores JSONL file") + ingestCmd.Flags().StringVar(&ingestCapability, "capability", "", "Capability scores JSONL file") + ingestCmd.Flags().StringVar(&ingestTraining, "training-log", "", "MLX LoRA training log file") + ingestCmd.Flags().StringVar(&ingestRunID, "run-id", "", "Run ID tag (defaults to model name)") + ingestCmd.Flags().IntVar(&ingestBatchSize, "batch-size", 100, "Lines per InfluxDB write batch") +} + +func runIngest(cmd *cli.Command, args []string) error { + if modelName == "" { + return fmt.Errorf("--model is required") + } + if ingestContent == "" && ingestCapability == "" && ingestTraining == "" { + return fmt.Errorf("at least one of --content, --capability, or --training-log is required") + } + + influx := ml.NewInfluxClient(influxURL, influxDB) + + cfg := ml.IngestConfig{ + ContentFile: ingestContent, + CapabilityFile: ingestCapability, + TrainingLog: ingestTraining, + Model: modelName, + RunID: ingestRunID, + BatchSize: ingestBatchSize, + } + + return ml.Ingest(influx, cfg, os.Stdout) +} diff --git a/internal/cmd/ml/cmd_inventory.go b/internal/cmd/ml/cmd_inventory.go new file mode 100644 index 00000000..1789bab8 --- /dev/null +++ b/internal/cmd/ml/cmd_inventory.go @@ -0,0 +1,34 @@ +package ml + +import ( + "fmt" + "os" + + "forge.lthn.ai/core/cli/pkg/cli" + "forge.lthn.ai/core/cli/pkg/ml" +) + +var inventoryCmd = &cli.Command{ + Use: "inventory", + Short: "Show DuckDB table inventory with stats", + Long: "Queries all DuckDB tables and prints row counts with per-table detail breakdowns.", + RunE: runInventory, +} + +func runInventory(cmd *cli.Command, args []string) error { + path := dbPath + if path == "" { + path = os.Getenv("LEM_DB") + } + if path == "" { + return fmt.Errorf("--db or LEM_DB required") + } + + db, err := ml.OpenDB(path) + 
if err != nil { + return fmt.Errorf("open db: %w", err) + } + defer db.Close() + + return ml.PrintInventory(db, os.Stdout) +} diff --git a/internal/cmd/ml/cmd_metrics.go b/internal/cmd/ml/cmd_metrics.go new file mode 100644 index 00000000..b3d2c63d --- /dev/null +++ b/internal/cmd/ml/cmd_metrics.go @@ -0,0 +1,36 @@ +package ml + +import ( + "fmt" + "os" + + "forge.lthn.ai/core/cli/pkg/cli" + "forge.lthn.ai/core/cli/pkg/ml" +) + +var metricsCmd = &cli.Command{ + Use: "metrics", + Short: "Push golden set stats to InfluxDB", + Long: "Queries golden_set stats from DuckDB and pushes summary, per-domain, and per-voice metrics to InfluxDB.", + RunE: runMetrics, +} + +func runMetrics(cmd *cli.Command, args []string) error { + path := dbPath + if path == "" { + path = os.Getenv("LEM_DB") + } + if path == "" { + return fmt.Errorf("--db or LEM_DB required") + } + + db, err := ml.OpenDB(path) + if err != nil { + return fmt.Errorf("open db: %w", err) + } + defer db.Close() + + influx := ml.NewInfluxClient(influxURL, influxDB) + + return ml.PushMetrics(db, influx, os.Stdout) +} diff --git a/internal/cmd/ml/cmd_ml.go b/internal/cmd/ml/cmd_ml.go index 4b461f33..581a20ff 100644 --- a/internal/cmd/ml/cmd_ml.go +++ b/internal/cmd/ml/cmd_ml.go @@ -11,6 +11,17 @@ // - core ml agent: Run the scoring agent daemon // - core ml worker: Run a distributed worker node // - core ml serve: Start OpenAI-compatible inference server +// - core ml inventory: Show DuckDB table inventory with stats +// - core ml query: Run ad-hoc SQL against DuckDB +// - core ml metrics: Push golden set stats to InfluxDB +// - core ml ingest: Ingest benchmark scores and training logs to InfluxDB +// - core ml normalize: Deduplicate seeds into expansion prompts +// - core ml seed-influx: Migrate golden set from DuckDB to InfluxDB +// - core ml consolidate: Pull and merge response JSONL files from M3 +// - core ml import-all: Import all LEM data into DuckDB +// - core ml approve: Filter scored expansions into training JSONL +// - core ml publish: Upload Parquet dataset to HuggingFace Hub +// - core ml coverage: Analyze seed coverage by region and domain package ml import ( @@ -40,6 +51,17 @@ func AddMLCommands(root *cli.Command) { mlCmd.AddCommand(agentCmd) mlCmd.AddCommand(workerCmd) mlCmd.AddCommand(serveCmd) + mlCmd.AddCommand(inventoryCmd) + mlCmd.AddCommand(queryCmd) + mlCmd.AddCommand(metricsCmd) + mlCmd.AddCommand(ingestCmd) + mlCmd.AddCommand(normalizeCmd) + mlCmd.AddCommand(seedInfluxCmd) + mlCmd.AddCommand(consolidateCmd) + mlCmd.AddCommand(importCmd) + mlCmd.AddCommand(approveCmd) + mlCmd.AddCommand(publishCmd) + mlCmd.AddCommand(coverageCmd) root.AddCommand(mlCmd) } diff --git a/internal/cmd/ml/cmd_normalize.go b/internal/cmd/ml/cmd_normalize.go new file mode 100644 index 00000000..5f07f9af --- /dev/null +++ b/internal/cmd/ml/cmd_normalize.go @@ -0,0 +1,44 @@ +package ml + +import ( + "fmt" + "os" + + "forge.lthn.ai/core/cli/pkg/cli" + "forge.lthn.ai/core/cli/pkg/ml" +) + +var normalizeMinLen int + +var normalizeCmd = &cli.Command{ + Use: "normalize", + Short: "Normalize seeds into expansion prompts", + Long: "Deduplicates seeds against golden_set and prompts, creating the expansion_prompts table with priority-based ordering.", + RunE: runNormalize, +} + +func init() { + normalizeCmd.Flags().IntVar(&normalizeMinLen, "min-length", 50, "Minimum prompt length in characters") +} + +func runNormalize(cmd *cli.Command, args []string) error { + path := dbPath + if path == "" { + path = os.Getenv("LEM_DB") + } + if path == "" { + return 
fmt.Errorf("--db or LEM_DB env is required") + } + + db, err := ml.OpenDBReadWrite(path) + if err != nil { + return fmt.Errorf("open db: %w", err) + } + defer db.Close() + + cfg := ml.NormalizeConfig{ + MinLength: normalizeMinLen, + } + + return ml.NormalizeSeeds(db, cfg, os.Stdout) +} diff --git a/internal/cmd/ml/cmd_publish.go b/internal/cmd/ml/cmd_publish.go new file mode 100644 index 00000000..45712367 --- /dev/null +++ b/internal/cmd/ml/cmd_publish.go @@ -0,0 +1,40 @@ +package ml + +import ( + "forge.lthn.ai/core/cli/pkg/cli" + "forge.lthn.ai/core/cli/pkg/ml" +) + +var ( + publishInputDir string + publishRepo string + publishPublic bool + publishToken string + publishDryRun bool +) + +var publishCmd = &cli.Command{ + Use: "publish", + Short: "Upload Parquet dataset to HuggingFace Hub", + Long: "Uploads train/valid/test Parquet files and an optional dataset card to a HuggingFace dataset repository.", + RunE: runPublish, +} + +func init() { + publishCmd.Flags().StringVar(&publishInputDir, "input-dir", "", "Directory containing Parquet files (required)") + publishCmd.Flags().StringVar(&publishRepo, "repo", "lthn/LEM-golden-set", "HuggingFace dataset repo ID") + publishCmd.Flags().BoolVar(&publishPublic, "public", false, "Make dataset public") + publishCmd.Flags().StringVar(&publishToken, "token", "", "HuggingFace API token (defaults to HF_TOKEN env)") + publishCmd.Flags().BoolVar(&publishDryRun, "dry-run", false, "Show what would be uploaded without uploading") + _ = publishCmd.MarkFlagRequired("input-dir") +} + +func runPublish(cmd *cli.Command, args []string) error { + return ml.Publish(ml.PublishConfig{ + InputDir: publishInputDir, + Repo: publishRepo, + Public: publishPublic, + Token: publishToken, + DryRun: publishDryRun, + }, cmd.OutOrStdout()) +} diff --git a/internal/cmd/ml/cmd_query.go b/internal/cmd/ml/cmd_query.go new file mode 100644 index 00000000..0fe93607 --- /dev/null +++ b/internal/cmd/ml/cmd_query.go @@ -0,0 +1,148 @@ +package ml + +import ( + "encoding/json" + "fmt" + "os" + "strings" + + "forge.lthn.ai/core/cli/pkg/cli" + "forge.lthn.ai/core/cli/pkg/ml" +) + +var queryCmd = &cli.Command{ + Use: "query [sql]", + Short: "Run ad-hoc SQL against DuckDB", + Long: "Executes arbitrary SQL against the DuckDB database. Non-SELECT queries are auto-wrapped as golden_set WHERE clauses.", + Example: ` core ml query "SELECT COUNT(*) FROM golden_set" + core ml query "domain = 'ethics'" + core ml query --json "SHOW TABLES"`, + Args: cli.MinimumNArgs(1), + RunE: runQuery, +} + +var queryJSON bool + +func init() { + queryCmd.Flags().BoolVar(&queryJSON, "json", false, "Output as JSON") +} + +func runQuery(cmd *cli.Command, args []string) error { + path := dbPath + if path == "" { + path = os.Getenv("LEM_DB") + } + if path == "" { + return fmt.Errorf("--db or LEM_DB env is required") + } + + db, err := ml.OpenDB(path) + if err != nil { + return fmt.Errorf("open db: %w", err) + } + defer db.Close() + + sql := strings.Join(args, " ") + + // Auto-wrap non-SELECT queries as golden_set WHERE clauses. 
+ trimmed := strings.TrimSpace(strings.ToUpper(sql)) + if !strings.HasPrefix(trimmed, "SELECT") && !strings.HasPrefix(trimmed, "SHOW") && + !strings.HasPrefix(trimmed, "DESCRIBE") && !strings.HasPrefix(trimmed, "EXPLAIN") { + sql = "SELECT * FROM golden_set WHERE " + sql + " LIMIT 20" + } + + rows, err := db.QueryRows(sql) + if err != nil { + return fmt.Errorf("query: %w", err) + } + + if queryJSON { + enc := json.NewEncoder(os.Stdout) + enc.SetIndent("", " ") + if err := enc.Encode(rows); err != nil { + return fmt.Errorf("encode json: %w", err) + } + fmt.Fprintf(os.Stderr, "\n(%d rows)\n", len(rows)) + return nil + } + + if len(rows) == 0 { + fmt.Println("(0 rows)") + return nil + } + + // Collect column names in stable order from first row. + var cols []string + for col := range rows[0] { + cols = append(cols, col) + } + + // Calculate column widths (capped at 60). + const maxWidth = 60 + widths := make([]int, len(cols)) + for i, col := range cols { + widths[i] = len(col) + } + for _, row := range rows { + for i, col := range cols { + val := formatValue(row[col]) + if l := len(val); l > widths[i] { + widths[i] = l + } + } + } + for i := range widths { + if widths[i] > maxWidth { + widths[i] = maxWidth + } + } + + // Print header. + for i, col := range cols { + if i > 0 { + fmt.Print(" | ") + } + fmt.Printf("%-*s", widths[i], truncate(col, widths[i])) + } + fmt.Println() + + // Print separator. + for i := range cols { + if i > 0 { + fmt.Print("-+-") + } + fmt.Print(strings.Repeat("-", widths[i])) + } + fmt.Println() + + // Print rows. + for _, row := range rows { + for i, col := range cols { + if i > 0 { + fmt.Print(" | ") + } + fmt.Printf("%-*s", widths[i], truncate(formatValue(row[col]), widths[i])) + } + fmt.Println() + } + + fmt.Printf("\n(%d rows)\n", len(rows)) + return nil +} + +func formatValue(v interface{}) string { + if v == nil { + return "NULL" + } + return fmt.Sprintf("%v", v) +} + +func truncate(s string, max int) string { + if len(s) <= max { + return s + } + if max <= 3 { + return s[:max] + } + return s[:max-3] + "..." 
+} diff --git a/internal/cmd/ml/cmd_seed_influx.go b/internal/cmd/ml/cmd_seed_influx.go new file mode 100644 index 00000000..a3960890 --- /dev/null +++ b/internal/cmd/ml/cmd_seed_influx.go @@ -0,0 +1,49 @@ +package ml + +import ( + "fmt" + "os" + + "forge.lthn.ai/core/cli/pkg/cli" + "forge.lthn.ai/core/cli/pkg/ml" +) + +var seedInfluxCmd = &cli.Command{ + Use: "seed-influx", + Short: "Seed InfluxDB golden_gen from DuckDB golden_set", + Long: "One-time migration: batch-loads DuckDB golden_set records into InfluxDB golden_gen measurement.", + RunE: runSeedInflux, +} + +var ( + seedInfluxForce bool + seedInfluxBatchSize int +) + +func init() { + seedInfluxCmd.Flags().BoolVar(&seedInfluxForce, "force", false, "Re-seed even if InfluxDB already has data") + seedInfluxCmd.Flags().IntVar(&seedInfluxBatchSize, "batch-size", 500, "Lines per InfluxDB write batch") +} + +func runSeedInflux(cmd *cli.Command, args []string) error { + path := dbPath + if path == "" { + path = os.Getenv("LEM_DB") + } + if path == "" { + return fmt.Errorf("--db or LEM_DB required") + } + + db, err := ml.OpenDB(path) + if err != nil { + return fmt.Errorf("open db: %w", err) + } + defer db.Close() + + influx := ml.NewInfluxClient(influxURL, influxDB) + + return ml.SeedInflux(db, influx, ml.SeedInfluxConfig{ + Force: seedInfluxForce, + BatchSize: seedInfluxBatchSize, + }, os.Stdout) +} diff --git a/pkg/ml/approve.go b/pkg/ml/approve.go new file mode 100644 index 00000000..566d8d2d --- /dev/null +++ b/pkg/ml/approve.go @@ -0,0 +1,82 @@ +package ml + +import ( + "encoding/json" + "fmt" + "io" + "os" +) + +// ApproveConfig holds options for the approve operation. +type ApproveConfig struct { + Output string + Threshold float64 +} + +// ApproveExpansions filters scored expansion responses above the threshold +// and writes approved examples to a training JSONL file. +// +// The query joins expansion_raw with expansion_scores, keeping rows where +// the heuristic passed AND the judge either passed or has not yet scored. +// Each approved row is written as a chat-format JSONL line with user/assistant +// messages. 
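+// An approved row is emitted as a single chat-format JSONL line. Illustrative
+// shape (field names assume the chat schema used by the training importers;
+// the values are made up):
+//
+//	{"messages":[{"role":"user","content":"<prompt>"},{"role":"assistant","content":"<response>"}]}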
+func ApproveExpansions(db *DB, cfg ApproveConfig, w io.Writer) error { + rows, err := db.conn.Query(` + SELECT r.idx, r.seed_id, r.region, r.domain, r.prompt, r.response, + r.gen_time, r.model, s.heuristic_score + FROM expansion_raw r + JOIN expansion_scores s ON r.idx = s.idx + WHERE s.heuristic_pass = true + AND (s.judge_pass = true OR s.judge_pass IS NULL) + ORDER BY r.idx + `) + if err != nil { + return fmt.Errorf("query approved expansions: %w (have you run scoring?)", err) + } + defer rows.Close() + + f, err := os.Create(cfg.Output) + if err != nil { + return fmt.Errorf("create output %s: %w", cfg.Output, err) + } + defer f.Close() + + enc := json.NewEncoder(f) + count := 0 + regionSet := make(map[string]bool) + domainSet := make(map[string]bool) + + for rows.Next() { + var idx int + var seedID, region, domain, prompt, response, model string + var genTime, score float64 + if err := rows.Scan(&idx, &seedID, ®ion, &domain, &prompt, &response, &genTime, &model, &score); err != nil { + return fmt.Errorf("scan approved row: %w", err) + } + + example := TrainingExample{ + Messages: []ChatMessage{ + {Role: "user", Content: prompt}, + {Role: "assistant", Content: response}, + }, + } + + if err := enc.Encode(example); err != nil { + return fmt.Errorf("encode example: %w", err) + } + + regionSet[region] = true + domainSet[domain] = true + count++ + } + + if err := rows.Err(); err != nil { + return fmt.Errorf("iterate approved rows: %w", err) + } + + fmt.Fprintf(w, "Approved: %d responses (threshold: heuristic > 0)\n", count) + fmt.Fprintf(w, "Exported: %s\n", cfg.Output) + fmt.Fprintf(w, " Regions: %d, Domains: %d\n", len(regionSet), len(domainSet)) + + return nil +} diff --git a/pkg/ml/consolidate.go b/pkg/ml/consolidate.go new file mode 100644 index 00000000..82e1db17 --- /dev/null +++ b/pkg/ml/consolidate.go @@ -0,0 +1,150 @@ +package ml + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" +) + +// ConsolidateConfig holds options for the consolidate operation. +type ConsolidateConfig struct { + M3Host string + RemoteDir string + Pattern string + OutputDir string + MergedOut string +} + +// Consolidate pulls JSONL response files from M3 via SSH, merges them by idx, +// deduplicates, and writes a single merged JSONL output. +func Consolidate(cfg ConsolidateConfig, w io.Writer) error { + if cfg.OutputDir == "" { + cfg.OutputDir = "responses" + } + if err := os.MkdirAll(cfg.OutputDir, 0755); err != nil { + return fmt.Errorf("create output dir: %w", err) + } + + // List remote files via SSH. + fmt.Fprintln(w, "Pulling responses from remote...") + listCmd := exec.Command("ssh", cfg.M3Host, fmt.Sprintf("ls %s/%s", cfg.RemoteDir, cfg.Pattern)) + listOutput, err := listCmd.Output() + if err != nil { + return fmt.Errorf("list remote files: %w", err) + } + + remoteFiles := strings.Split(strings.TrimSpace(string(listOutput)), "\n") + var validFiles []string + for _, f := range remoteFiles { + f = strings.TrimSpace(f) + if f != "" { + validFiles = append(validFiles, f) + } + } + fmt.Fprintf(w, " Found %d JSONL files on %s\n", len(validFiles), cfg.M3Host) + + // Pull each file via SCP. 
+ for _, rf := range validFiles { + local := filepath.Join(cfg.OutputDir, filepath.Base(rf)) + scpCmd := exec.Command("scp", fmt.Sprintf("%s:%s", cfg.M3Host, rf), local) + if err := scpCmd.Run(); err != nil { + fmt.Fprintf(w, " warning: failed to pull %s: %v\n", rf, err) + continue + } + + lines, err := countLines(local) + if err == nil { + fmt.Fprintf(w, " %s: %d records\n", filepath.Base(rf), lines) + } + } + + // Merge and deduplicate on idx (first occurrence wins). + seen := make(map[int]json.RawMessage) + skipped := 0 + + matches, _ := filepath.Glob(filepath.Join(cfg.OutputDir, cfg.Pattern)) + sort.Strings(matches) + + for _, local := range matches { + f, err := os.Open(local) + if err != nil { + continue + } + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + for scanner.Scan() { + line := scanner.Text() + var rec struct { + Idx *int `json:"idx"` + } + if err := json.Unmarshal([]byte(line), &rec); err != nil { + skipped++ + continue + } + if rec.Idx == nil { + skipped++ + continue + } + if _, exists := seen[*rec.Idx]; !exists { + seen[*rec.Idx] = json.RawMessage(line) + } + } + f.Close() + } + + if skipped > 0 { + fmt.Fprintf(w, " Skipped %d records without idx\n", skipped) + } + + // Sort by idx and write merged file. + mergedPath := cfg.MergedOut + if mergedPath == "" { + mergedPath = filepath.Join(cfg.OutputDir, "..", "gold-merged.jsonl") + } + + idxs := make([]int, 0, len(seen)) + for idx := range seen { + idxs = append(idxs, idx) + } + sort.Ints(idxs) + + out, err := os.Create(mergedPath) + if err != nil { + return fmt.Errorf("create merged file: %w", err) + } + defer out.Close() + + bw := bufio.NewWriter(out) + for _, idx := range idxs { + bw.Write(seen[idx]) + bw.WriteString("\n") + } + if err := bw.Flush(); err != nil { + return fmt.Errorf("flush merged file: %w", err) + } + + fmt.Fprintf(w, "\nMerged: %d unique examples -> %s\n", len(seen), mergedPath) + return nil +} + +// countLines returns the number of lines in a file. +func countLines(path string) (int, error) { + f, err := os.Open(path) + if err != nil { + return 0, err + } + defer f.Close() + + count := 0 + scanner := bufio.NewScanner(f) + for scanner.Scan() { + count++ + } + return count, scanner.Err() +} diff --git a/pkg/ml/coverage.go b/pkg/ml/coverage.go new file mode 100644 index 00000000..dc3441dc --- /dev/null +++ b/pkg/ml/coverage.go @@ -0,0 +1,127 @@ +package ml + +import ( + "fmt" + "io" + "strings" +) + +// regionRow holds a single row from the region distribution query. +type regionRow struct { + group string + n int + domains int +} + +// PrintCoverage analyzes seed coverage by region and domain, printing +// a report with bar chart visualization and gap recommendations. +func PrintCoverage(db *DB, w io.Writer) error { + rows, err := db.QueryRows("SELECT count(*) AS total FROM seeds") + if err != nil { + return fmt.Errorf("count seeds: %w (run: core ml import-all first)", err) + } + if len(rows) == 0 { + return fmt.Errorf("no seeds table found (run: core ml import-all first)") + } + total := toInt(rows[0]["total"]) + + fmt.Fprintln(w, "LEM Seed Coverage Analysis") + fmt.Fprintln(w, "==================================================") + fmt.Fprintf(w, "\nTotal seeds: %d\n", total) + + // Region distribution. 
+ regionRows, err := queryRegionDistribution(db) + if err != nil { + return fmt.Errorf("query regions: %w", err) + } + + fmt.Fprintln(w, "\nRegion distribution (underrepresented first):") + avg := float64(total) / float64(len(regionRows)) + for _, r := range regionRows { + barLen := int(float64(r.n) / avg * 10) + if barLen > 40 { + barLen = 40 + } + bar := strings.Repeat("#", barLen) + gap := "" + if float64(r.n) < avg*0.5 { + gap = " <- UNDERREPRESENTED" + } + fmt.Fprintf(w, " %-22s %6d (%4d domains) %s%s\n", r.group, r.n, r.domains, bar, gap) + } + + // Top 10 domains. + fmt.Fprintln(w, "\nTop 10 domains (most seeds):") + topRows, err := db.QueryRows(` + SELECT domain, count(*) AS n FROM seeds + WHERE domain != '' GROUP BY domain ORDER BY n DESC LIMIT 10 + `) + if err == nil { + for _, row := range topRows { + domain := strVal(row, "domain") + n := toInt(row["n"]) + fmt.Fprintf(w, " %-40s %5d\n", domain, n) + } + } + + // Bottom 10 domains. + fmt.Fprintln(w, "\nBottom 10 domains (fewest seeds, min 5):") + bottomRows, err := db.QueryRows(` + SELECT domain, count(*) AS n FROM seeds + WHERE domain != '' GROUP BY domain HAVING count(*) >= 5 ORDER BY n ASC LIMIT 10 + `) + if err == nil { + for _, row := range bottomRows { + domain := strVal(row, "domain") + n := toInt(row["n"]) + fmt.Fprintf(w, " %-40s %5d\n", domain, n) + } + } + + fmt.Fprintln(w, "\nSuggested expansion areas:") + fmt.Fprintln(w, " - Japanese, Korean, Thai, Vietnamese (no seeds found)") + fmt.Fprintln(w, " - Hindi/Urdu, Bengali, Tamil (South Asian)") + fmt.Fprintln(w, " - Swahili, Yoruba, Amharic (Sub-Saharan Africa)") + fmt.Fprintln(w, " - Indigenous languages (Quechua, Nahuatl, Aymara)") + + return nil +} + +// queryRegionDistribution returns seed counts grouped by normalized language +// region, ordered ascending (underrepresented first). +func queryRegionDistribution(db *DB) ([]regionRow, error) { + rows, err := db.QueryRows(` + SELECT + CASE + WHEN region LIKE '%cn%' THEN 'cn (Chinese)' + WHEN region LIKE '%en-%' OR region LIKE '%en_para%' OR region LIKE '%para%' THEN 'en (English)' + WHEN region LIKE '%ru%' THEN 'ru (Russian)' + WHEN region LIKE '%de%' AND region NOT LIKE '%deten%' THEN 'de (German)' + WHEN region LIKE '%es%' THEN 'es (Spanish)' + WHEN region LIKE '%fr%' THEN 'fr (French)' + WHEN region LIKE '%latam%' THEN 'latam (LatAm)' + WHEN region LIKE '%africa%' THEN 'africa' + WHEN region LIKE '%eu%' THEN 'eu (European)' + WHEN region LIKE '%me%' AND region NOT LIKE '%premium%' THEN 'me (MidEast)' + WHEN region LIKE '%multi%' THEN 'multilingual' + WHEN region LIKE '%weak%' THEN 'weak-langs' + ELSE 'other' + END AS lang_group, + count(*) AS n, + count(DISTINCT domain) AS domains + FROM seeds GROUP BY lang_group ORDER BY n ASC + `) + if err != nil { + return nil, err + } + + result := make([]regionRow, 0, len(rows)) + for _, row := range rows { + result = append(result, regionRow{ + group: strVal(row, "lang_group"), + n: toInt(row["n"]), + domains: toInt(row["domains"]), + }) + } + return result, nil +} diff --git a/pkg/ml/db.go b/pkg/ml/db.go index 95c6a14d..766b3f39 100644 --- a/pkg/ml/db.go +++ b/pkg/ml/db.go @@ -45,6 +45,23 @@ func (db *DB) Close() error { return db.conn.Close() } +// Path returns the database file path. +func (db *DB) Path() string { + return db.path +} + +// Exec executes a query without returning rows. +func (db *DB) Exec(query string, args ...interface{}) error { + _, err := db.conn.Exec(query, args...) 
+ return err +} + +// QueryRowScan executes a query expected to return at most one row and scans +// the result into dest. It is a convenience wrapper around sql.DB.QueryRow. +func (db *DB) QueryRowScan(query string, dest interface{}, args ...interface{}) error { + return db.conn.QueryRow(query, args...).Scan(dest) +} + // GoldenSetRow represents one row from the golden_set table. type GoldenSetRow struct { Idx int diff --git a/pkg/ml/import_all.go b/pkg/ml/import_all.go new file mode 100644 index 00000000..bbd288f2 --- /dev/null +++ b/pkg/ml/import_all.go @@ -0,0 +1,437 @@ +package ml + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" +) + +// ImportConfig holds options for the import-all operation. +type ImportConfig struct { + SkipM3 bool + DataDir string + M3Host string +} + +// ImportAll imports all LEM data into DuckDB from M3 and local files. +func ImportAll(db *DB, cfg ImportConfig, w io.Writer) error { + m3Host := cfg.M3Host + if m3Host == "" { + m3Host = "m3" + } + + totals := make(map[string]int) + + // ── 1. Golden set ── + goldenPath := filepath.Join(cfg.DataDir, "gold-15k.jsonl") + if !cfg.SkipM3 { + fmt.Fprintln(w, " Pulling golden set from M3...") + scpCmd := exec.Command("scp", fmt.Sprintf("%s:/Volumes/Data/lem/responses/gold-15k.jsonl", m3Host), goldenPath) + if err := scpCmd.Run(); err != nil { + fmt.Fprintf(w, " WARNING: could not pull golden set from M3: %v\n", err) + } + } + if _, err := os.Stat(goldenPath); err == nil { + db.Exec("DROP TABLE IF EXISTS golden_set") + err := db.Exec(fmt.Sprintf(` + CREATE TABLE golden_set AS + SELECT + idx::INT AS idx, + seed_id::VARCHAR AS seed_id, + domain::VARCHAR AS domain, + voice::VARCHAR AS voice, + prompt::VARCHAR AS prompt, + response::VARCHAR AS response, + gen_time::DOUBLE AS gen_time, + length(response)::INT AS char_count, + length(response) - length(replace(response, ' ', '')) + 1 AS word_count + FROM read_json_auto('%s', maximum_object_size=1048576) + `, escapeSQLPath(goldenPath))) + if err != nil { + fmt.Fprintf(w, " WARNING: golden set import failed: %v\n", err) + } else { + var n int + db.QueryRowScan("SELECT count(*) FROM golden_set", &n) + totals["golden_set"] = n + fmt.Fprintf(w, " golden_set: %d rows\n", n) + } + } + + // ── 2. 
Training examples ── + trainingDirs := []struct { + name string + files []string + }{ + {"training", []string{"training/train.jsonl", "training/valid.jsonl", "training/test.jsonl"}}, + {"training-2k", []string{"training-2k/train.jsonl", "training-2k/valid.jsonl", "training-2k/test.jsonl"}}, + {"training-expanded", []string{"training-expanded/train.jsonl", "training-expanded/valid.jsonl"}}, + {"training-book", []string{"training-book/train.jsonl", "training-book/valid.jsonl", "training-book/test.jsonl"}}, + {"training-conv", []string{"training-conv/train.jsonl", "training-conv/valid.jsonl", "training-conv/test.jsonl"}}, + {"gold-full", []string{"gold-full/train.jsonl", "gold-full/valid.jsonl"}}, + {"sovereignty-gold", []string{"sovereignty-gold/train.jsonl", "sovereignty-gold/valid.jsonl"}}, + {"composure-lessons", []string{"composure-lessons/train.jsonl", "composure-lessons/valid.jsonl"}}, + {"watts-full", []string{"watts-full/train.jsonl", "watts-full/valid.jsonl"}}, + {"watts-expanded", []string{"watts-expanded/train.jsonl", "watts-expanded/valid.jsonl"}}, + {"watts-composure", []string{"watts-composure-merged/train.jsonl", "watts-composure-merged/valid.jsonl"}}, + {"western-fresh", []string{"western-fresh/train.jsonl", "western-fresh/valid.jsonl"}}, + {"deepseek-soak", []string{"deepseek-western-soak/train.jsonl", "deepseek-western-soak/valid.jsonl"}}, + {"russian-bridge", []string{"russian-bridge/train.jsonl", "russian-bridge/valid.jsonl"}}, + } + + trainingLocal := filepath.Join(cfg.DataDir, "training") + os.MkdirAll(trainingLocal, 0755) + + if !cfg.SkipM3 { + fmt.Fprintln(w, " Pulling training sets from M3...") + for _, td := range trainingDirs { + for _, rel := range td.files { + local := filepath.Join(trainingLocal, rel) + os.MkdirAll(filepath.Dir(local), 0755) + scpCmd := exec.Command("scp", fmt.Sprintf("%s:/Volumes/Data/lem/%s", m3Host, rel), local) + scpCmd.Run() // ignore errors, file might not exist + } + } + } + + db.Exec("DROP TABLE IF EXISTS training_examples") + db.Exec(` + CREATE TABLE training_examples ( + source VARCHAR, + split VARCHAR, + prompt TEXT, + response TEXT, + num_turns INT, + full_messages TEXT, + char_count INT + ) + `) + + trainingTotal := 0 + for _, td := range trainingDirs { + for _, rel := range td.files { + local := filepath.Join(trainingLocal, rel) + if _, err := os.Stat(local); os.IsNotExist(err) { + continue + } + + split := "train" + if strings.Contains(rel, "valid") { + split = "valid" + } else if strings.Contains(rel, "test") { + split = "test" + } + + n := importTrainingFile(db, local, td.name, split) + trainingTotal += n + } + } + totals["training_examples"] = trainingTotal + fmt.Fprintf(w, " training_examples: %d rows\n", trainingTotal) + + // ── 3. 
Benchmark results ── + benchLocal := filepath.Join(cfg.DataDir, "benchmarks") + os.MkdirAll(benchLocal, 0755) + + if !cfg.SkipM3 { + fmt.Fprintln(w, " Pulling benchmarks from M3...") + for _, bname := range []string{"truthfulqa", "gsm8k", "do_not_answer", "toxigen"} { + scpCmd := exec.Command("scp", + fmt.Sprintf("%s:/Volumes/Data/lem/benchmarks/%s.jsonl", m3Host, bname), + filepath.Join(benchLocal, bname+".jsonl")) + scpCmd.Run() + } + for _, subdir := range []string{"results", "scale_results", "cross_arch_results", "deepseek-r1-7b"} { + localSub := filepath.Join(benchLocal, subdir) + os.MkdirAll(localSub, 0755) + scpCmd := exec.Command("scp", "-r", + fmt.Sprintf("%s:/Volumes/Data/lem/benchmarks/%s/", m3Host, subdir), + filepath.Join(benchLocal)+"/") + scpCmd.Run() + } + } + + db.Exec("DROP TABLE IF EXISTS benchmark_results") + db.Exec(` + CREATE TABLE benchmark_results ( + source VARCHAR, id VARCHAR, benchmark VARCHAR, model VARCHAR, + prompt TEXT, response TEXT, elapsed_seconds DOUBLE, domain VARCHAR + ) + `) + + benchTotal := 0 + for _, subdir := range []string{"results", "scale_results", "cross_arch_results", "deepseek-r1-7b"} { + resultDir := filepath.Join(benchLocal, subdir) + matches, _ := filepath.Glob(filepath.Join(resultDir, "*.jsonl")) + for _, jf := range matches { + n := importBenchmarkFile(db, jf, subdir) + benchTotal += n + } + } + + // Also import standalone benchmark files. + for _, bfile := range []string{"lem_bench", "lem_ethics", "lem_ethics_allen", "instruction_tuned", "abliterated", "base_pt"} { + local := filepath.Join(benchLocal, bfile+".jsonl") + if _, err := os.Stat(local); os.IsNotExist(err) { + if !cfg.SkipM3 { + scpCmd := exec.Command("scp", + fmt.Sprintf("%s:/Volumes/Data/lem/benchmark/%s.jsonl", m3Host, bfile), local) + scpCmd.Run() + } + } + if _, err := os.Stat(local); err == nil { + n := importBenchmarkFile(db, local, "benchmark") + benchTotal += n + } + } + totals["benchmark_results"] = benchTotal + fmt.Fprintf(w, " benchmark_results: %d rows\n", benchTotal) + + // ── 4. Benchmark questions ── + db.Exec("DROP TABLE IF EXISTS benchmark_questions") + db.Exec(` + CREATE TABLE benchmark_questions ( + benchmark VARCHAR, id VARCHAR, question TEXT, + best_answer TEXT, correct_answers TEXT, incorrect_answers TEXT, category VARCHAR + ) + `) + + benchQTotal := 0 + for _, bname := range []string{"truthfulqa", "gsm8k", "do_not_answer", "toxigen"} { + local := filepath.Join(benchLocal, bname+".jsonl") + if _, err := os.Stat(local); err == nil { + n := importBenchmarkQuestions(db, local, bname) + benchQTotal += n + } + } + totals["benchmark_questions"] = benchQTotal + fmt.Fprintf(w, " benchmark_questions: %d rows\n", benchQTotal) + + // ── 5. 
Seeds ── + db.Exec("DROP TABLE IF EXISTS seeds") + db.Exec(` + CREATE TABLE seeds ( + source_file VARCHAR, region VARCHAR, seed_id VARCHAR, domain VARCHAR, prompt TEXT + ) + `) + + seedTotal := 0 + seedDirs := []string{filepath.Join(cfg.DataDir, "seeds"), "/tmp/lem-data/seeds", "/tmp/lem-repo/seeds"} + for _, seedDir := range seedDirs { + if _, err := os.Stat(seedDir); os.IsNotExist(err) { + continue + } + n := importSeeds(db, seedDir) + seedTotal += n + } + totals["seeds"] = seedTotal + fmt.Fprintf(w, " seeds: %d rows\n", seedTotal) + + // ── Summary ── + grandTotal := 0 + fmt.Fprintf(w, "\n%s\n", strings.Repeat("=", 50)) + fmt.Fprintln(w, "LEM Database Import Complete") + fmt.Fprintln(w, strings.Repeat("=", 50)) + for table, count := range totals { + fmt.Fprintf(w, " %-25s %8d\n", table, count) + grandTotal += count + } + fmt.Fprintf(w, " %s\n", strings.Repeat("-", 35)) + fmt.Fprintf(w, " %-25s %8d\n", "TOTAL", grandTotal) + fmt.Fprintf(w, "\nDatabase: %s\n", db.Path()) + + return nil +} + +func importTrainingFile(db *DB, path, source, split string) int { + f, err := os.Open(path) + if err != nil { + return 0 + } + defer f.Close() + + count := 0 + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + + for scanner.Scan() { + var rec struct { + Messages []ChatMessage `json:"messages"` + } + if err := json.Unmarshal(scanner.Bytes(), &rec); err != nil { + continue + } + + prompt := "" + response := "" + assistantCount := 0 + for _, m := range rec.Messages { + if m.Role == "user" && prompt == "" { + prompt = m.Content + } + if m.Role == "assistant" { + if response == "" { + response = m.Content + } + assistantCount++ + } + } + + msgsJSON, _ := json.Marshal(rec.Messages) + db.Exec(`INSERT INTO training_examples VALUES (?, ?, ?, ?, ?, ?, ?)`, + source, split, prompt, response, assistantCount, string(msgsJSON), len(response)) + count++ + } + return count +} + +func importBenchmarkFile(db *DB, path, source string) int { + f, err := os.Open(path) + if err != nil { + return 0 + } + defer f.Close() + + count := 0 + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + + for scanner.Scan() { + var rec map[string]interface{} + if err := json.Unmarshal(scanner.Bytes(), &rec); err != nil { + continue + } + + db.Exec(`INSERT INTO benchmark_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + source, + fmt.Sprintf("%v", rec["id"]), + strOrEmpty(rec, "benchmark"), + strOrEmpty(rec, "model"), + strOrEmpty(rec, "prompt"), + strOrEmpty(rec, "response"), + floatOrZero(rec, "elapsed_seconds"), + strOrEmpty(rec, "domain"), + ) + count++ + } + return count +} + +func importBenchmarkQuestions(db *DB, path, benchmark string) int { + f, err := os.Open(path) + if err != nil { + return 0 + } + defer f.Close() + + count := 0 + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + + for scanner.Scan() { + var rec map[string]interface{} + if err := json.Unmarshal(scanner.Bytes(), &rec); err != nil { + continue + } + + correctJSON, _ := json.Marshal(rec["correct_answers"]) + incorrectJSON, _ := json.Marshal(rec["incorrect_answers"]) + + db.Exec(`INSERT INTO benchmark_questions VALUES (?, ?, ?, ?, ?, ?, ?)`, + benchmark, + fmt.Sprintf("%v", rec["id"]), + strOrEmpty(rec, "question"), + strOrEmpty(rec, "best_answer"), + string(correctJSON), + string(incorrectJSON), + strOrEmpty(rec, "category"), + ) + count++ + } + return count +} + +func importSeeds(db *DB, seedDir string) int { + count := 0 + filepath.Walk(seedDir, func(path string, 
info os.FileInfo, err error) error { + if err != nil || info.IsDir() || !strings.HasSuffix(path, ".json") { + return nil + } + + data, err := os.ReadFile(path) + if err != nil { + return nil + } + + rel, _ := filepath.Rel(seedDir, path) + region := strings.TrimSuffix(filepath.Base(path), ".json") + + // Try parsing as array or object with prompts/seeds field. + var seedsList []interface{} + var raw interface{} + if err := json.Unmarshal(data, &raw); err != nil { + return nil + } + + switch v := raw.(type) { + case []interface{}: + seedsList = v + case map[string]interface{}: + if prompts, ok := v["prompts"].([]interface{}); ok { + seedsList = prompts + } else if seeds, ok := v["seeds"].([]interface{}); ok { + seedsList = seeds + } + } + + for _, s := range seedsList { + switch seed := s.(type) { + case map[string]interface{}: + prompt := strOrEmpty(seed, "prompt") + if prompt == "" { + prompt = strOrEmpty(seed, "text") + } + if prompt == "" { + prompt = strOrEmpty(seed, "question") + } + db.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`, + rel, region, + strOrEmpty(seed, "seed_id"), + strOrEmpty(seed, "domain"), + prompt, + ) + count++ + case string: + db.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`, + rel, region, "", "", seed) + count++ + } + } + return nil + }) + return count +} + +func strOrEmpty(m map[string]interface{}, key string) string { + if v, ok := m[key]; ok { + return fmt.Sprintf("%v", v) + } + return "" +} + +func floatOrZero(m map[string]interface{}, key string) float64 { + if v, ok := m[key]; ok { + if f, ok := v.(float64); ok { + return f + } + } + return 0 +} + +func escapeSQLPath(p string) string { + return strings.ReplaceAll(p, "'", "''") +} diff --git a/pkg/ml/ingest.go b/pkg/ml/ingest.go new file mode 100644 index 00000000..d5a8604d --- /dev/null +++ b/pkg/ml/ingest.go @@ -0,0 +1,384 @@ +package ml + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "os" + "regexp" + "strconv" + "strings" + "time" +) + +// IngestConfig holds the configuration for a benchmark/training ingest run. +type IngestConfig struct { + ContentFile string + CapabilityFile string + TrainingLog string + Model string + RunID string + BatchSize int +} + +// contentScoreLine is the JSON structure for a content scores JSONL line. +type contentScoreLine struct { + Label string `json:"label"` + Aggregates map[string]interface{} `json:"aggregates"` + Probes map[string]contentScoreProbe `json:"probes"` +} + +// contentScoreProbe is the per-probe block within a content score line. +type contentScoreProbe struct { + Scores map[string]interface{} `json:"scores"` +} + +// capabilityScoreLine is the JSON structure for a capability scores JSONL line. +type capabilityScoreLine struct { + Label string `json:"label"` + Accuracy float64 `json:"accuracy"` + Correct int `json:"correct"` + Total int `json:"total"` + ByCategory map[string]capabilityCatBlock `json:"by_category"` +} + +// capabilityCatBlock is the per-category block within a capability score line. +type capabilityCatBlock struct { + Correct int `json:"correct"` + Total int `json:"total"` +} + +// Training log regexes. +var ( + reValLoss = regexp.MustCompile(`Iter (\d+): Val loss ([\d.]+)`) + reTrainLoss = regexp.MustCompile(`Iter (\d+): Train loss ([\d.]+), Learning Rate ([\d.eE+-]+), It/sec ([\d.]+), Tokens/sec ([\d.]+)`) +) + +// Ingest reads benchmark scores and training logs and writes them to InfluxDB. +// At least one of ContentFile, CapabilityFile, or TrainingLog must be set. 
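+//
+// Minimal sketch of a call site (the URL, database name, file path, and model
+// name are placeholder values):
+//
+//	influx := NewInfluxClient("http://localhost:8086", "lem")
+//	err := Ingest(influx, IngestConfig{
+//		ContentFile: "content-scores.jsonl",
+//		Model:       "lem-gemma-9b",
+//	}, os.Stdout)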
+func Ingest(influx *InfluxClient, cfg IngestConfig, w io.Writer) error { + if cfg.ContentFile == "" && cfg.CapabilityFile == "" && cfg.TrainingLog == "" { + return fmt.Errorf("at least one of --content, --capability, or --training-log is required") + } + if cfg.Model == "" { + return fmt.Errorf("--model is required") + } + if cfg.RunID == "" { + cfg.RunID = cfg.Model + } + if cfg.BatchSize <= 0 { + cfg.BatchSize = 100 + } + + var totalPoints int + + if cfg.ContentFile != "" { + n, err := ingestContentScores(influx, cfg, w) + if err != nil { + return fmt.Errorf("ingest content scores: %w", err) + } + totalPoints += n + } + + if cfg.CapabilityFile != "" { + n, err := ingestCapabilityScores(influx, cfg, w) + if err != nil { + return fmt.Errorf("ingest capability scores: %w", err) + } + totalPoints += n + } + + if cfg.TrainingLog != "" { + n, err := ingestTrainingLog(influx, cfg, w) + if err != nil { + return fmt.Errorf("ingest training log: %w", err) + } + totalPoints += n + } + + fmt.Fprintf(w, "Ingested %d total points into InfluxDB\n", totalPoints) + return nil +} + +// ingestContentScores reads a content scores JSONL file and writes content_score +// and probe_score measurements to InfluxDB. +func ingestContentScores(influx *InfluxClient, cfg IngestConfig, w io.Writer) (int, error) { + f, err := os.Open(cfg.ContentFile) + if err != nil { + return 0, fmt.Errorf("open %s: %w", cfg.ContentFile, err) + } + defer f.Close() + + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + + var lines []string + var totalPoints int + lineNum := 0 + + for scanner.Scan() { + lineNum++ + raw := strings.TrimSpace(scanner.Text()) + if raw == "" { + continue + } + + var entry contentScoreLine + if err := json.Unmarshal([]byte(raw), &entry); err != nil { + return totalPoints, fmt.Errorf("line %d: parse json: %w", lineNum, err) + } + + label := entry.Label + iteration := extractIteration(label) + hasKernel := "false" + if strings.Contains(strings.ToLower(label), "kernel") || strings.Contains(label, "LEK") { + hasKernel = "true" + } + ts := time.Now().UnixNano() + + // Write aggregate content_score — one point per dimension. + for dim, val := range entry.Aggregates { + score, ok := toFloat64(val) + if !ok { + continue + } + line := fmt.Sprintf( + "content_score,model=%s,run_id=%s,label=%s,dimension=%s,has_kernel=%s score=%.6f,iteration=%di %d", + EscapeLp(cfg.Model), EscapeLp(cfg.RunID), EscapeLp(label), + EscapeLp(dim), hasKernel, score, iteration, ts, + ) + lines = append(lines, line) + totalPoints++ + } + + // Write per-probe probe_score — one point per probe per dimension. + for probeID, probe := range entry.Probes { + for dim, val := range probe.Scores { + score, ok := toFloat64(val) + if !ok { + continue + } + line := fmt.Sprintf( + "probe_score,model=%s,run_id=%s,label=%s,probe_id=%s,dimension=%s,has_kernel=%s score=%.6f,iteration=%di %d", + EscapeLp(cfg.Model), EscapeLp(cfg.RunID), EscapeLp(label), + EscapeLp(probeID), EscapeLp(dim), hasKernel, score, iteration, ts, + ) + lines = append(lines, line) + totalPoints++ + } + } + + // Flush batch if needed. + if len(lines) >= cfg.BatchSize { + if err := influx.WriteLp(lines); err != nil { + return totalPoints, fmt.Errorf("write batch: %w", err) + } + lines = lines[:0] + } + } + + if err := scanner.Err(); err != nil { + return totalPoints, fmt.Errorf("scan %s: %w", cfg.ContentFile, err) + } + + // Flush remaining lines. 
+ if len(lines) > 0 { + if err := influx.WriteLp(lines); err != nil { + return totalPoints, fmt.Errorf("write final batch: %w", err) + } + } + + fmt.Fprintf(w, " content scores: %d points from %d lines\n", totalPoints, lineNum) + return totalPoints, nil +} + +// ingestCapabilityScores reads a capability scores JSONL file and writes +// capability_score measurements to InfluxDB. +func ingestCapabilityScores(influx *InfluxClient, cfg IngestConfig, w io.Writer) (int, error) { + f, err := os.Open(cfg.CapabilityFile) + if err != nil { + return 0, fmt.Errorf("open %s: %w", cfg.CapabilityFile, err) + } + defer f.Close() + + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + + var lines []string + var totalPoints int + lineNum := 0 + + for scanner.Scan() { + lineNum++ + raw := strings.TrimSpace(scanner.Text()) + if raw == "" { + continue + } + + var entry capabilityScoreLine + if err := json.Unmarshal([]byte(raw), &entry); err != nil { + return totalPoints, fmt.Errorf("line %d: parse json: %w", lineNum, err) + } + + label := entry.Label + iteration := extractIteration(label) + ts := time.Now().UnixNano() + + // Overall capability score. + line := fmt.Sprintf( + "capability_score,model=%s,run_id=%s,label=%s,category=overall accuracy=%.6f,correct=%di,total=%di,iteration=%di %d", + EscapeLp(cfg.Model), EscapeLp(cfg.RunID), EscapeLp(label), + entry.Accuracy, entry.Correct, entry.Total, iteration, ts, + ) + lines = append(lines, line) + totalPoints++ + + // Per-category breakdown. + for cat, block := range entry.ByCategory { + var catAccuracy float64 + if block.Total > 0 { + catAccuracy = float64(block.Correct) / float64(block.Total) + } + line := fmt.Sprintf( + "capability_score,model=%s,run_id=%s,label=%s,category=%s accuracy=%.6f,correct=%di,total=%di,iteration=%di %d", + EscapeLp(cfg.Model), EscapeLp(cfg.RunID), EscapeLp(label), + EscapeLp(cat), catAccuracy, block.Correct, block.Total, iteration, ts, + ) + lines = append(lines, line) + totalPoints++ + } + + // Flush batch if needed. + if len(lines) >= cfg.BatchSize { + if err := influx.WriteLp(lines); err != nil { + return totalPoints, fmt.Errorf("write batch: %w", err) + } + lines = lines[:0] + } + } + + if err := scanner.Err(); err != nil { + return totalPoints, fmt.Errorf("scan %s: %w", cfg.CapabilityFile, err) + } + + // Flush remaining lines. + if len(lines) > 0 { + if err := influx.WriteLp(lines); err != nil { + return totalPoints, fmt.Errorf("write final batch: %w", err) + } + } + + fmt.Fprintf(w, " capability scores: %d points from %d lines\n", totalPoints, lineNum) + return totalPoints, nil +} + +// ingestTrainingLog reads an MLX LoRA training log and writes training_loss +// measurements to InfluxDB for both training and validation loss entries. +func ingestTrainingLog(influx *InfluxClient, cfg IngestConfig, w io.Writer) (int, error) { + f, err := os.Open(cfg.TrainingLog) + if err != nil { + return 0, fmt.Errorf("open %s: %w", cfg.TrainingLog, err) + } + defer f.Close() + + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + + var lines []string + var totalPoints int + lineNum := 0 + + for scanner.Scan() { + lineNum++ + text := scanner.Text() + + // Try validation loss first (shorter regex, less common). 
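+		// Matches log lines of the form (illustrative): "Iter 200: Val loss 1.234".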
+ if m := reValLoss.FindStringSubmatch(text); m != nil { + iter, _ := strconv.Atoi(m[1]) + loss, _ := strconv.ParseFloat(m[2], 64) + ts := time.Now().UnixNano() + + line := fmt.Sprintf( + "training_loss,model=%s,run_id=%s,loss_type=val loss=%.6f,iteration=%di %d", + EscapeLp(cfg.Model), EscapeLp(cfg.RunID), loss, iter, ts, + ) + lines = append(lines, line) + totalPoints++ + } + + // Try training loss. + if m := reTrainLoss.FindStringSubmatch(text); m != nil { + iter, _ := strconv.Atoi(m[1]) + loss, _ := strconv.ParseFloat(m[2], 64) + lr, _ := strconv.ParseFloat(m[3], 64) + itPerSec, _ := strconv.ParseFloat(m[4], 64) + tokPerSec, _ := strconv.ParseFloat(m[5], 64) + ts := time.Now().UnixNano() + + line := fmt.Sprintf( + "training_loss,model=%s,run_id=%s,loss_type=train loss=%.6f,iteration=%di,learning_rate=%.10f,it_per_sec=%.4f,tokens_per_sec=%.2f %d", + EscapeLp(cfg.Model), EscapeLp(cfg.RunID), loss, iter, lr, itPerSec, tokPerSec, ts, + ) + lines = append(lines, line) + totalPoints++ + } + + // Flush batch if needed. + if len(lines) >= cfg.BatchSize { + if err := influx.WriteLp(lines); err != nil { + return totalPoints, fmt.Errorf("write batch: %w", err) + } + lines = lines[:0] + } + } + + if err := scanner.Err(); err != nil { + return totalPoints, fmt.Errorf("scan %s: %w", cfg.TrainingLog, err) + } + + // Flush remaining lines. + if len(lines) > 0 { + if err := influx.WriteLp(lines); err != nil { + return totalPoints, fmt.Errorf("write final batch: %w", err) + } + } + + fmt.Fprintf(w, " training log: %d points from %d lines\n", totalPoints, lineNum) + return totalPoints, nil +} + +// extractIteration extracts an iteration number from a label like "model@200". +// Returns 0 if no iteration is found. +func extractIteration(label string) int { + idx := strings.LastIndex(label, "@") + if idx < 0 || idx+1 >= len(label) { + return 0 + } + n, err := strconv.Atoi(label[idx+1:]) + if err != nil { + return 0 + } + return n +} + +// toFloat64 converts a JSON-decoded interface{} value to float64. +// Handles float64 (standard json.Unmarshal), json.Number, and string values. +func toFloat64(v interface{}) (float64, bool) { + switch val := v.(type) { + case float64: + return val, true + case int: + return float64(val), true + case int64: + return float64(val), true + case json.Number: + f, err := val.Float64() + return f, err == nil + case string: + f, err := strconv.ParseFloat(val, 64) + return f, err == nil + default: + return 0, false + } +} diff --git a/pkg/ml/inventory.go b/pkg/ml/inventory.go new file mode 100644 index 00000000..98853624 --- /dev/null +++ b/pkg/ml/inventory.go @@ -0,0 +1,147 @@ +package ml + +import ( + "fmt" + "io" + "strings" +) + +// TargetTotal is the golden set target size used for progress reporting. +const TargetTotal = 15000 + +// tableOrder defines the canonical display order for inventory tables. +var tableOrder = []string{ + "golden_set", "expansion_prompts", "seeds", "prompts", + "training_examples", "gemini_responses", "benchmark_questions", + "benchmark_results", "validations", "checkpoint_scores", + "probe_results", "scoring_results", +} + +// tableDetail holds extra context for a single table beyond its row count. +type tableDetail struct { + notes []string +} + +// PrintInventory queries all known DuckDB tables and prints a formatted +// inventory with row counts, detail breakdowns, and a grand total. 
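+// Example of the rendered output (table names come from tableOrder; the row
+// counts and annotations are illustrative):
+//
+//	golden_set                  12000 rows (80.0% of 15000 target)
+//	training_examples            4200 rows (14 sources)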
+func PrintInventory(db *DB, w io.Writer) error { + counts, err := db.TableCounts() + if err != nil { + return fmt.Errorf("table counts: %w", err) + } + + details := gatherDetails(db, counts) + + fmt.Fprintln(w, "DuckDB Inventory") + fmt.Fprintln(w, strings.Repeat("-", 52)) + + grand := 0 + for _, table := range tableOrder { + count, ok := counts[table] + if !ok { + continue + } + grand += count + fmt.Fprintf(w, " %-24s %8d rows", table, count) + + if d, has := details[table]; has && len(d.notes) > 0 { + fmt.Fprintf(w, " (%s)", strings.Join(d.notes, ", ")) + } + fmt.Fprintln(w) + } + + fmt.Fprintln(w, strings.Repeat("-", 52)) + fmt.Fprintf(w, " %-24s %8d rows\n", "TOTAL", grand) + + return nil +} + +// gatherDetails runs per-table detail queries and returns annotations keyed +// by table name. Errors on individual queries are silently ignored so the +// inventory always prints. +func gatherDetails(db *DB, counts map[string]int) map[string]*tableDetail { + details := make(map[string]*tableDetail) + + // golden_set: progress toward target + if count, ok := counts["golden_set"]; ok { + pct := float64(count) / float64(TargetTotal) * 100 + details["golden_set"] = &tableDetail{ + notes: []string{fmt.Sprintf("%.1f%% of %d target", pct, TargetTotal)}, + } + } + + // training_examples: distinct sources + if _, ok := counts["training_examples"]; ok { + rows, err := db.QueryRows("SELECT COUNT(DISTINCT source) AS n FROM training_examples") + if err == nil && len(rows) > 0 { + n := toInt(rows[0]["n"]) + details["training_examples"] = &tableDetail{ + notes: []string{fmt.Sprintf("%d sources", n)}, + } + } + } + + // prompts: distinct domains and voices + if _, ok := counts["prompts"]; ok { + d := &tableDetail{} + rows, err := db.QueryRows("SELECT COUNT(DISTINCT domain) AS n FROM prompts") + if err == nil && len(rows) > 0 { + d.notes = append(d.notes, fmt.Sprintf("%d domains", toInt(rows[0]["n"]))) + } + rows, err = db.QueryRows("SELECT COUNT(DISTINCT voice) AS n FROM prompts") + if err == nil && len(rows) > 0 { + d.notes = append(d.notes, fmt.Sprintf("%d voices", toInt(rows[0]["n"]))) + } + if len(d.notes) > 0 { + details["prompts"] = d + } + } + + // gemini_responses: group by source_model + if _, ok := counts["gemini_responses"]; ok { + rows, err := db.QueryRows( + "SELECT source_model, COUNT(*) AS n FROM gemini_responses GROUP BY source_model ORDER BY n DESC", + ) + if err == nil && len(rows) > 0 { + var parts []string + for _, row := range rows { + model := strVal(row, "source_model") + n := toInt(row["n"]) + if model != "" { + parts = append(parts, fmt.Sprintf("%s:%d", model, n)) + } + } + if len(parts) > 0 { + details["gemini_responses"] = &tableDetail{notes: parts} + } + } + } + + // benchmark_results: distinct source categories + if _, ok := counts["benchmark_results"]; ok { + rows, err := db.QueryRows("SELECT COUNT(DISTINCT source) AS n FROM benchmark_results") + if err == nil && len(rows) > 0 { + n := toInt(rows[0]["n"]) + details["benchmark_results"] = &tableDetail{ + notes: []string{fmt.Sprintf("%d categories", n)}, + } + } + } + + return details +} + +// toInt converts a DuckDB value to int. DuckDB returns integers as int64 (not +// float64 like InfluxDB), so we handle both types. 
+func toInt(v interface{}) int { + switch n := v.(type) { + case int64: + return int(n) + case int32: + return int(n) + case float64: + return int(n) + default: + return 0 + } +} diff --git a/pkg/ml/metrics.go b/pkg/ml/metrics.go new file mode 100644 index 00000000..68288dda --- /dev/null +++ b/pkg/ml/metrics.go @@ -0,0 +1,100 @@ +package ml + +import ( + "fmt" + "io" + "time" +) + +// PushMetrics queries golden_set stats from DuckDB and writes them to InfluxDB +// as golden_set_stats, golden_set_domain, and golden_set_voice measurements. +func PushMetrics(db *DB, influx *InfluxClient, w io.Writer) error { + // Overall stats. + var total, domains, voices int + var avgGenTime, avgChars float64 + err := db.conn.QueryRow( + "SELECT count(*), count(DISTINCT domain), count(DISTINCT voice), " + + "coalesce(avg(gen_time), 0), coalesce(avg(char_count), 0) FROM golden_set", + ).Scan(&total, &domains, &voices, &avgGenTime, &avgChars) + if err != nil { + return fmt.Errorf("query golden_set stats: %w", err) + } + + if total == 0 { + fmt.Fprintln(w, "golden_set is empty, nothing to push") + return nil + } + + completionPct := float64(total) / float64(TargetTotal) * 100.0 + ts := time.Now().UnixNano() + + var lines []string + + // Overall stats point. + lines = append(lines, fmt.Sprintf( + "golden_set_stats total_examples=%di,domains=%di,voices=%di,avg_gen_time=%.2f,avg_response_chars=%.0f,completion_pct=%.1f %d", + total, domains, voices, avgGenTime, avgChars, completionPct, ts, + )) + + // Per-domain breakdown. + domainRows, err := db.conn.Query( + "SELECT domain, count(*) AS cnt, coalesce(avg(gen_time), 0) AS avg_gt FROM golden_set GROUP BY domain ORDER BY domain", + ) + if err != nil { + return fmt.Errorf("query golden_set domains: %w", err) + } + defer domainRows.Close() + + for domainRows.Next() { + var domain string + var count int + var avgGT float64 + if err := domainRows.Scan(&domain, &count, &avgGT); err != nil { + return fmt.Errorf("scan domain row: %w", err) + } + lines = append(lines, fmt.Sprintf( + "golden_set_domain,domain=%s count=%di,avg_gen_time=%.2f %d", + EscapeLp(domain), count, avgGT, ts, + )) + } + if err := domainRows.Err(); err != nil { + return fmt.Errorf("iterate domain rows: %w", err) + } + + // Per-voice breakdown. + voiceRows, err := db.conn.Query( + "SELECT voice, count(*) AS cnt, coalesce(avg(char_count), 0) AS avg_cc, coalesce(avg(gen_time), 0) AS avg_gt FROM golden_set GROUP BY voice ORDER BY voice", + ) + if err != nil { + return fmt.Errorf("query golden_set voices: %w", err) + } + defer voiceRows.Close() + + for voiceRows.Next() { + var voice string + var count int + var avgCC, avgGT float64 + if err := voiceRows.Scan(&voice, &count, &avgCC, &avgGT); err != nil { + return fmt.Errorf("scan voice row: %w", err) + } + lines = append(lines, fmt.Sprintf( + "golden_set_voice,voice=%s count=%di,avg_chars=%.0f,avg_gen_time=%.2f %d", + EscapeLp(voice), count, avgCC, avgGT, ts, + )) + } + if err := voiceRows.Err(); err != nil { + return fmt.Errorf("iterate voice rows: %w", err) + } + + // Write all points to InfluxDB. 
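+	// A hypothetical per-domain point in the batch looks like:
+	//   golden_set_domain,domain=legal count=412i,avg_gen_time=8.73 1765432100000000000
+	// (tag: domain; fields: count and avg_gen_time; the trailing value is the
+	// shared nanosecond timestamp).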
+ if err := influx.WriteLp(lines); err != nil { + return fmt.Errorf("write metrics to influxdb: %w", err) + } + + fmt.Fprintf(w, "Pushed %d points to InfluxDB\n", len(lines)) + fmt.Fprintf(w, " total=%d domains=%d voices=%d completion=%.1f%%\n", + total, domains, voices, completionPct) + fmt.Fprintf(w, " avg_gen_time=%.2fs avg_chars=%.0f\n", avgGenTime, avgChars) + + return nil +} diff --git a/pkg/ml/normalize.go b/pkg/ml/normalize.go new file mode 100644 index 00000000..eb78bde9 --- /dev/null +++ b/pkg/ml/normalize.go @@ -0,0 +1,153 @@ +package ml + +import ( + "fmt" + "io" + "strings" +) + +// NormalizeConfig configures the seed normalization process. +type NormalizeConfig struct { + MinLength int +} + +// NormalizeSeeds deduplicates seeds into the expansion_prompts table. +// +// Steps: +// 1. Verify the seeds table exists and report its row count. +// 2. Drop and recreate expansion_prompts using deduplicated seeds, +// excluding prompts already present in the prompts or golden_set tables. +// 3. Assign priority based on domain coverage (underrepresented domains +// receive higher priority via RANK). +// 4. Print a region distribution summary. +func NormalizeSeeds(db *DB, cfg NormalizeConfig, w io.Writer) error { + // 1. Check seeds table exists and get count. + var seedCount int + if err := db.conn.QueryRow("SELECT count(*) FROM seeds").Scan(&seedCount); err != nil { + return fmt.Errorf("no seeds table (run import-all first): %w", err) + } + fmt.Fprintf(w, "Seeds table: %d rows\n", seedCount) + + if seedCount == 0 { + return fmt.Errorf("seeds table is empty, nothing to normalize") + } + + // 2. Drop and recreate expansion_prompts. + if _, err := db.conn.Exec("DROP TABLE IF EXISTS expansion_prompts"); err != nil { + return fmt.Errorf("drop expansion_prompts: %w", err) + } + + createSQL := fmt.Sprintf(` + CREATE TABLE expansion_prompts AS + WITH unique_seeds AS ( + SELECT + ROW_NUMBER() OVER (ORDER BY region, domain, seed_id) AS idx, + seed_id, region, domain, prompt + FROM ( + SELECT DISTINCT ON (prompt) + seed_id, region, domain, prompt + FROM seeds + WHERE length(prompt) >= %d + ORDER BY prompt, seed_id + ) + ), + existing_prompts AS ( + SELECT prompt FROM prompts + UNION ALL + SELECT prompt FROM golden_set + ) + SELECT + us.idx, us.seed_id, us.region, us.domain, + 'en' AS language, us.prompt, '' AS prompt_en, + 0 AS priority, 'pending' AS status + FROM unique_seeds us + WHERE NOT EXISTS ( + SELECT 1 FROM existing_prompts ep WHERE ep.prompt = us.prompt + ) + `, cfg.MinLength) + + if _, err := db.conn.Exec(createSQL); err != nil { + return fmt.Errorf("create expansion_prompts: %w", err) + } + + var epCount int + if err := db.conn.QueryRow("SELECT count(*) FROM expansion_prompts").Scan(&epCount); err != nil { + return fmt.Errorf("count expansion_prompts: %w", err) + } + fmt.Fprintf(w, "Expansion prompts created: %d (min length %d, deduped, excluding existing)\n", epCount, cfg.MinLength) + + if epCount == 0 { + fmt.Fprintln(w, "No new expansion prompts to process.") + return nil + } + + // 3. Assign priority based on domain coverage. 
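+	// RANK() over ascending counts gives the sparsest domain rank 1, and tied
+	// domains share a rank. Hypothetically, with counts legal=40, finance=120,
+	// health=120, the priorities become legal=1, finance=2, health=2.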
+ prioritySQL := ` + UPDATE expansion_prompts SET priority = sub.rnk + FROM ( + SELECT domain, RANK() OVER (ORDER BY cnt ASC) AS rnk + FROM ( + SELECT domain, count(*) AS cnt + FROM expansion_prompts + GROUP BY domain + ) domain_counts + ) sub + WHERE expansion_prompts.domain = sub.domain + ` + if _, err := db.conn.Exec(prioritySQL); err != nil { + return fmt.Errorf("assign priority: %w", err) + } + fmt.Fprintln(w, "Priority assigned (underrepresented domains ranked higher).") + + // 4. Region distribution summary. + fmt.Fprintln(w) + fmt.Fprintln(w, "Region distribution:") + + rows, err := db.conn.Query(` + SELECT + CASE + WHEN region LIKE 'cn%' THEN 'cn' + WHEN region LIKE 'en%' THEN 'en' + WHEN region LIKE 'ru%' THEN 'ru' + WHEN region LIKE 'de%' THEN 'de' + WHEN region LIKE 'es%' THEN 'es' + WHEN region LIKE 'fr%' THEN 'fr' + WHEN region LIKE 'latam%' THEN 'latam' + WHEN region LIKE 'africa%' THEN 'africa' + WHEN region LIKE 'eu%' THEN 'eu' + WHEN region LIKE 'me%' THEN 'me' + ELSE 'other' + END AS region_group, + count(*) AS cnt + FROM expansion_prompts + GROUP BY region_group + ORDER BY cnt DESC + `) + if err != nil { + return fmt.Errorf("region distribution query: %w", err) + } + defer rows.Close() + + var totalFromRegions int + var lines []string + for rows.Next() { + var region string + var cnt int + if err := rows.Scan(®ion, &cnt); err != nil { + return fmt.Errorf("scan region row: %w", err) + } + totalFromRegions += cnt + lines = append(lines, fmt.Sprintf(" %-10s %6d", region, cnt)) + } + if err := rows.Err(); err != nil { + return fmt.Errorf("iterate region rows: %w", err) + } + + for _, line := range lines { + fmt.Fprintln(w, line) + } + fmt.Fprintf(w, " %-10s %6d\n", strings.Repeat("-", 10), totalFromRegions) + fmt.Fprintf(w, " %-10s %6d\n", "total", totalFromRegions) + + return nil +} diff --git a/pkg/ml/publish.go b/pkg/ml/publish.go new file mode 100644 index 00000000..5c21118d --- /dev/null +++ b/pkg/ml/publish.go @@ -0,0 +1,157 @@ +package ml + +import ( + "bytes" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +// PublishConfig holds options for the publish operation. +type PublishConfig struct { + InputDir string + Repo string + Public bool + Token string + DryRun bool +} + +// uploadEntry pairs a local file path with its remote destination. +type uploadEntry struct { + local string + remote string +} + +// Publish uploads Parquet files to HuggingFace Hub. +// +// It looks for train.parquet, valid.parquet, and test.parquet in InputDir, +// plus an optional dataset_card.md in the parent directory (uploaded as README.md). +// The token is resolved from PublishConfig.Token, the HF_TOKEN environment variable, +// or ~/.huggingface/token, in that order. 
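+//
+// A minimal dry-run invocation looks like this (the repo name and input
+// directory are hypothetical):
+//
+//	err := ml.Publish(ml.PublishConfig{
+//		InputDir: "exports/parquet",
+//		Repo:     "acme/lem-golden-set",
+//		DryRun:   true,
+//	}, os.Stdout)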
+func Publish(cfg PublishConfig, w io.Writer) error { + if cfg.InputDir == "" { + return fmt.Errorf("input directory is required") + } + + token := resolveHFToken(cfg.Token) + if token == "" && !cfg.DryRun { + return fmt.Errorf("HuggingFace token required (--token, HF_TOKEN env, or ~/.huggingface/token)") + } + + files, err := collectUploadFiles(cfg.InputDir) + if err != nil { + return err + } + if len(files) == 0 { + return fmt.Errorf("no Parquet files found in %s", cfg.InputDir) + } + + if cfg.DryRun { + fmt.Fprintf(w, "Dry run: would publish to %s\n", cfg.Repo) + if cfg.Public { + fmt.Fprintln(w, " Visibility: public") + } else { + fmt.Fprintln(w, " Visibility: private") + } + for _, f := range files { + info, err := os.Stat(f.local) + if err != nil { + return fmt.Errorf("stat %s: %w", f.local, err) + } + sizeMB := float64(info.Size()) / 1024 / 1024 + fmt.Fprintf(w, " %s -> %s (%.1f MB)\n", filepath.Base(f.local), f.remote, sizeMB) + } + return nil + } + + fmt.Fprintf(w, "Publishing to https://huggingface.co/datasets/%s\n", cfg.Repo) + + for _, f := range files { + if err := uploadFileToHF(token, cfg.Repo, f.local, f.remote); err != nil { + return fmt.Errorf("upload %s: %w", filepath.Base(f.local), err) + } + fmt.Fprintf(w, " Uploaded %s -> %s\n", filepath.Base(f.local), f.remote) + } + + fmt.Fprintf(w, "\nPublished to https://huggingface.co/datasets/%s\n", cfg.Repo) + return nil +} + +// resolveHFToken returns a HuggingFace API token from the given value, +// HF_TOKEN env var, or ~/.huggingface/token file. +func resolveHFToken(explicit string) string { + if explicit != "" { + return explicit + } + if env := os.Getenv("HF_TOKEN"); env != "" { + return env + } + home, err := os.UserHomeDir() + if err != nil { + return "" + } + data, err := os.ReadFile(filepath.Join(home, ".huggingface", "token")) + if err != nil { + return "" + } + return strings.TrimSpace(string(data)) +} + +// collectUploadFiles finds Parquet split files and an optional dataset card. +func collectUploadFiles(inputDir string) ([]uploadEntry, error) { + splits := []string{"train", "valid", "test"} + var files []uploadEntry + + for _, split := range splits { + path := filepath.Join(inputDir, split+".parquet") + if _, err := os.Stat(path); os.IsNotExist(err) { + continue + } else if err != nil { + return nil, fmt.Errorf("stat %s: %w", path, err) + } + files = append(files, uploadEntry{path, fmt.Sprintf("data/%s.parquet", split)}) + } + + // Check for dataset card in parent directory. + cardPath := filepath.Join(inputDir, "..", "dataset_card.md") + if _, err := os.Stat(cardPath); err == nil { + files = append(files, uploadEntry{cardPath, "README.md"}) + } + + return files, nil +} + +// uploadFileToHF uploads a single file to a HuggingFace dataset repo via the Hub API. 
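+//
+// For a hypothetical repo "acme/lem-golden-set", train.parquet is PUT to:
+//
+//	https://huggingface.co/api/datasets/acme/lem-golden-set/upload/main/data/train.parquet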
+func uploadFileToHF(token, repoID, localPath, remotePath string) error {
+	data, err := os.ReadFile(localPath)
+	if err != nil {
+		return fmt.Errorf("read %s: %w", localPath, err)
+	}
+
+	url := fmt.Sprintf("https://huggingface.co/api/datasets/%s/upload/main/%s", repoID, remotePath)
+
+	req, err := http.NewRequest(http.MethodPut, url, bytes.NewReader(data))
+	if err != nil {
+		return fmt.Errorf("create request: %w", err)
+	}
+	req.Header.Set("Authorization", "Bearer "+token)
+	req.Header.Set("Content-Type", "application/octet-stream")
+
+	client := &http.Client{Timeout: 120 * time.Second}
+	resp, err := client.Do(req)
+	if err != nil {
+		return fmt.Errorf("upload request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode >= 300 {
+		body, _ := io.ReadAll(resp.Body)
+		return fmt.Errorf("upload failed: HTTP %d: %s", resp.StatusCode, string(body))
+	}
+
+	return nil
+}
diff --git a/pkg/ml/seed_influx.go b/pkg/ml/seed_influx.go
new file mode 100644
index 00000000..aff82659
--- /dev/null
+++ b/pkg/ml/seed_influx.go
@@ -0,0 +1,111 @@
+package ml
+
+import (
+	"fmt"
+	"io"
+	"strings"
+)
+
+// SeedInfluxConfig holds options for the SeedInflux migration.
+type SeedInfluxConfig struct {
+	Force     bool
+	BatchSize int
+}
+
+// SeedInflux migrates golden_set rows from DuckDB into InfluxDB as
+// gold_gen measurement points. This is a one-time migration tool;
+// it skips the write when InfluxDB already contains all records
+// unless Force is set.
+func SeedInflux(db *DB, influx *InfluxClient, cfg SeedInfluxConfig, w io.Writer) error {
+	if cfg.BatchSize <= 0 {
+		cfg.BatchSize = 500
+	}
+
+	// Count source rows in DuckDB.
+	var total int
+	if err := db.conn.QueryRow("SELECT count(*) FROM golden_set").Scan(&total); err != nil {
+		return fmt.Errorf("no golden_set table: %w", err)
+	}
+
+	// Check how many distinct records InfluxDB already has.
+	existing := 0
+	rows, err := influx.QuerySQL("SELECT count(DISTINCT i) AS n FROM gold_gen")
+	if err == nil && len(rows) > 0 {
+		if n, ok := rows[0]["n"].(float64); ok {
+			existing = int(n)
+		}
+	}
+
+	fmt.Fprintf(w, "DuckDB has %d records, InfluxDB gold_gen has %d\n", total, existing)
+
+	if existing >= total && !cfg.Force {
+		fmt.Fprintln(w, "InfluxDB already has all records. Use --force to re-seed.")
+		return nil
+	}
+
+	// Query all golden_set rows from DuckDB.
+	dbRows, err := db.conn.Query(
+		"SELECT idx, seed_id, domain, voice, gen_time, char_count FROM golden_set ORDER BY idx",
+	)
+	if err != nil {
+		return fmt.Errorf("query golden_set: %w", err)
+	}
+	defer dbRows.Close()
+
+	var batch []string
+	written := 0
+
+	for dbRows.Next() {
+		var idx int
+		var seedID, domain, voice string
+		var genTime float64
+		var charCount int
+
+		if err := dbRows.Scan(&idx, &seedID, &domain, &voice, &genTime, &charCount); err != nil {
+			return fmt.Errorf("scan row %d: %w", written, err)
+		}
+
+		// Build line protocol point.
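+		// Each golden_set row becomes one gold_gen point; a hypothetical example
+		// (no explicit timestamp, so InfluxDB assigns the write time):
+		//   gold_gen,i=42,w=migration,d=legal,v=neutral seed_id="s-0042",gen_time=7.5,chars=1834i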
+		// Tags: i (idx), w (worker), d (domain), v (voice)
+		// Fields: seed_id (string), gen_time (float), chars (integer)
+		escapedSeedID := strings.ReplaceAll(seedID, `"`, `\"`)
+
+		line := fmt.Sprintf(
+			"gold_gen,i=%s,w=migration,d=%s,v=%s seed_id=\"%s\",gen_time=%v,chars=%di",
+			EscapeLp(fmt.Sprintf("%d", idx)),
+			EscapeLp(domain),
+			EscapeLp(voice),
+			escapedSeedID,
+			genTime,
+			charCount,
+		)
+		batch = append(batch, line)
+
+		if len(batch) >= cfg.BatchSize {
+			if err := influx.WriteLp(batch); err != nil {
+				return fmt.Errorf("write batch at row %d: %w", written, err)
+			}
+			written += len(batch)
+			batch = batch[:0]
+
+			if written%2000 == 0 {
+				fmt.Fprintf(w, "  wrote %d / %d\n", written, total)
+			}
+		}
+	}
+
+	if err := dbRows.Err(); err != nil {
+		return fmt.Errorf("iterate golden_set rows: %w", err)
+	}
+
+	// Flush remaining batch.
+	if len(batch) > 0 {
+		if err := influx.WriteLp(batch); err != nil {
+			return fmt.Errorf("write final batch: %w", err)
+		}
+		written += len(batch)
+	}
+
+	fmt.Fprintf(w, "Seeded %d records into InfluxDB gold_gen\n", written)
+	return nil
+}