feat: port 11 LEM data management commands into core ml
Ports all remaining LEM pipeline commands from pkg/lem into core ml, eliminating the standalone LEM CLI dependency. Each command is split into reusable business logic (pkg/ml/) and a thin cobra wrapper (internal/cmd/ml/). New commands: query, inventory, metrics, ingest, normalize, seed-influx, consolidate, import-all, approve, publish, coverage. Adds Path(), Exec(), QueryRowScan() convenience methods to DB type. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d2bb19c6bd
commit
a290ab31e9
23 changed files with 2478 additions and 0 deletions
53
internal/cmd/ml/cmd_approve.go
Normal file
53
internal/cmd/ml/cmd_approve.go
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/core/cli/pkg/ml"
|
||||
)
|
||||
|
||||
var (
|
||||
approveOutput string
|
||||
approveThreshold float64
|
||||
)
|
||||
|
||||
var approveCmd = &cli.Command{
|
||||
Use: "approve",
|
||||
Short: "Filter scored expansions into training JSONL",
|
||||
Long: "Filters scored expansion responses by quality threshold and exports approved ones as chat-format training JSONL.",
|
||||
RunE: runApprove,
|
||||
}
|
||||
|
||||
func init() {
|
||||
approveCmd.Flags().StringVar(&approveOutput, "output", "", "Output JSONL file (defaults to expansion-approved.jsonl in db dir)")
|
||||
approveCmd.Flags().Float64Var(&approveThreshold, "threshold", 6.0, "Min judge average to approve")
|
||||
}
|
||||
|
||||
func runApprove(cmd *cli.Command, args []string) error {
|
||||
path := dbPath
|
||||
if path == "" {
|
||||
path = os.Getenv("LEM_DB")
|
||||
}
|
||||
if path == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
output := approveOutput
|
||||
if output == "" {
|
||||
output = filepath.Join(filepath.Dir(path), "expansion-approved.jsonl")
|
||||
}
|
||||
|
||||
db, err := ml.OpenDB(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
return ml.ApproveExpansions(db, ml.ApproveConfig{
|
||||
Output: output,
|
||||
Threshold: approveThreshold,
|
||||
}, cmd.OutOrStdout())
|
||||
}
|
||||
41
internal/cmd/ml/cmd_consolidate.go
Normal file
41
internal/cmd/ml/cmd_consolidate.go
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/core/cli/pkg/ml"
|
||||
)
|
||||
|
||||
var (
|
||||
consolidateM3Host string
|
||||
consolidateRemoteDir string
|
||||
consolidatePattern string
|
||||
consolidateOutputDir string
|
||||
consolidateMergedOut string
|
||||
)
|
||||
|
||||
var consolidateCmd = &cli.Command{
|
||||
Use: "consolidate",
|
||||
Short: "Pull and merge response JSONL files from M3",
|
||||
Long: "Pulls JSONL response files from M3 via SSH/SCP, merges them by idx, deduplicates, and writes a single merged JSONL output.",
|
||||
RunE: runConsolidate,
|
||||
}
|
||||
|
||||
func init() {
|
||||
consolidateCmd.Flags().StringVar(&consolidateM3Host, "m3-host", "m3", "M3 SSH host")
|
||||
consolidateCmd.Flags().StringVar(&consolidateRemoteDir, "remote", "/Volumes/Data/lem/responses", "Remote response directory")
|
||||
consolidateCmd.Flags().StringVar(&consolidatePattern, "pattern", "gold*.jsonl", "File glob pattern")
|
||||
consolidateCmd.Flags().StringVar(&consolidateOutputDir, "output", "", "Local output directory (default: responses)")
|
||||
consolidateCmd.Flags().StringVar(&consolidateMergedOut, "merged", "", "Merged output path (default: gold-merged.jsonl in parent of output dir)")
|
||||
}
|
||||
|
||||
func runConsolidate(cmd *cli.Command, args []string) error {
|
||||
cfg := ml.ConsolidateConfig{
|
||||
M3Host: consolidateM3Host,
|
||||
RemoteDir: consolidateRemoteDir,
|
||||
Pattern: consolidatePattern,
|
||||
OutputDir: consolidateOutputDir,
|
||||
MergedOut: consolidateMergedOut,
|
||||
}
|
||||
|
||||
return ml.Consolidate(cfg, cmd.OutOrStdout())
|
||||
}
|
||||
34
internal/cmd/ml/cmd_coverage.go
Normal file
34
internal/cmd/ml/cmd_coverage.go
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/core/cli/pkg/ml"
|
||||
)
|
||||
|
||||
var coverageCmd = &cli.Command{
|
||||
Use: "coverage",
|
||||
Short: "Analyze seed coverage by region and domain",
|
||||
Long: "Queries seeds by region and domain, renders ASCII bar charts, and highlights underrepresented areas.",
|
||||
RunE: runCoverage,
|
||||
}
|
||||
|
||||
func runCoverage(cmd *cli.Command, args []string) error {
|
||||
path := dbPath
|
||||
if path == "" {
|
||||
path = os.Getenv("LEM_DB")
|
||||
}
|
||||
if path == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
db, err := ml.OpenDB(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
return ml.PrintCoverage(db, cmd.OutOrStdout())
|
||||
}
|
||||
58
internal/cmd/ml/cmd_import.go
Normal file
58
internal/cmd/ml/cmd_import.go
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/core/cli/pkg/ml"
|
||||
)
|
||||
|
||||
var importCmd = &cli.Command{
|
||||
Use: "import-all",
|
||||
Short: "Import all LEM data into DuckDB",
|
||||
Long: "Imports golden set, training examples, benchmark results, benchmark questions, and seeds into DuckDB from M3 and local files.",
|
||||
RunE: runImportAll,
|
||||
}
|
||||
|
||||
var (
|
||||
importSkipM3 bool
|
||||
importDataDir string
|
||||
importM3Host string
|
||||
)
|
||||
|
||||
func init() {
|
||||
importCmd.Flags().BoolVar(&importSkipM3, "skip-m3", false, "Skip pulling data from M3")
|
||||
importCmd.Flags().StringVar(&importDataDir, "data-dir", "", "Local data directory (defaults to db directory)")
|
||||
importCmd.Flags().StringVar(&importM3Host, "m3-host", "m3", "M3 SSH host alias")
|
||||
}
|
||||
|
||||
func runImportAll(cmd *cli.Command, args []string) error {
|
||||
path := dbPath
|
||||
if path == "" {
|
||||
path = os.Getenv("LEM_DB")
|
||||
}
|
||||
if path == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
dataDir := importDataDir
|
||||
if dataDir == "" {
|
||||
dataDir = filepath.Dir(path)
|
||||
}
|
||||
|
||||
db, err := ml.OpenDBReadWrite(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
cfg := ml.ImportConfig{
|
||||
SkipM3: importSkipM3,
|
||||
DataDir: dataDir,
|
||||
M3Host: importM3Host,
|
||||
}
|
||||
|
||||
return ml.ImportAll(db, cfg, cmd.OutOrStdout())
|
||||
}
|
||||
54
internal/cmd/ml/cmd_ingest.go
Normal file
54
internal/cmd/ml/cmd_ingest.go
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/core/cli/pkg/ml"
|
||||
)
|
||||
|
||||
var ingestCmd = &cli.Command{
|
||||
Use: "ingest",
|
||||
Short: "Ingest benchmark scores and training logs into InfluxDB",
|
||||
Long: "Reads content score, capability score, and training log files and writes measurements to InfluxDB for the lab dashboard.",
|
||||
RunE: runIngest,
|
||||
}
|
||||
|
||||
var (
|
||||
ingestContent string
|
||||
ingestCapability string
|
||||
ingestTraining string
|
||||
ingestRunID string
|
||||
ingestBatchSize int
|
||||
)
|
||||
|
||||
func init() {
|
||||
ingestCmd.Flags().StringVar(&ingestContent, "content", "", "Content scores JSONL file")
|
||||
ingestCmd.Flags().StringVar(&ingestCapability, "capability", "", "Capability scores JSONL file")
|
||||
ingestCmd.Flags().StringVar(&ingestTraining, "training-log", "", "MLX LoRA training log file")
|
||||
ingestCmd.Flags().StringVar(&ingestRunID, "run-id", "", "Run ID tag (defaults to model name)")
|
||||
ingestCmd.Flags().IntVar(&ingestBatchSize, "batch-size", 100, "Lines per InfluxDB write batch")
|
||||
}
|
||||
|
||||
func runIngest(cmd *cli.Command, args []string) error {
|
||||
if modelName == "" {
|
||||
return fmt.Errorf("--model is required")
|
||||
}
|
||||
if ingestContent == "" && ingestCapability == "" && ingestTraining == "" {
|
||||
return fmt.Errorf("at least one of --content, --capability, or --training-log is required")
|
||||
}
|
||||
|
||||
influx := ml.NewInfluxClient(influxURL, influxDB)
|
||||
|
||||
cfg := ml.IngestConfig{
|
||||
ContentFile: ingestContent,
|
||||
CapabilityFile: ingestCapability,
|
||||
TrainingLog: ingestTraining,
|
||||
Model: modelName,
|
||||
RunID: ingestRunID,
|
||||
BatchSize: ingestBatchSize,
|
||||
}
|
||||
|
||||
return ml.Ingest(influx, cfg, os.Stdout)
|
||||
}
|
||||
34
internal/cmd/ml/cmd_inventory.go
Normal file
34
internal/cmd/ml/cmd_inventory.go
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/core/cli/pkg/ml"
|
||||
)
|
||||
|
||||
var inventoryCmd = &cli.Command{
|
||||
Use: "inventory",
|
||||
Short: "Show DuckDB table inventory with stats",
|
||||
Long: "Queries all DuckDB tables and prints row counts with per-table detail breakdowns.",
|
||||
RunE: runInventory,
|
||||
}
|
||||
|
||||
func runInventory(cmd *cli.Command, args []string) error {
|
||||
path := dbPath
|
||||
if path == "" {
|
||||
path = os.Getenv("LEM_DB")
|
||||
}
|
||||
if path == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
db, err := ml.OpenDB(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
return ml.PrintInventory(db, os.Stdout)
|
||||
}
|
||||
36
internal/cmd/ml/cmd_metrics.go
Normal file
36
internal/cmd/ml/cmd_metrics.go
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/core/cli/pkg/ml"
|
||||
)
|
||||
|
||||
var metricsCmd = &cli.Command{
|
||||
Use: "metrics",
|
||||
Short: "Push golden set stats to InfluxDB",
|
||||
Long: "Queries golden_set stats from DuckDB and pushes summary, per-domain, and per-voice metrics to InfluxDB.",
|
||||
RunE: runMetrics,
|
||||
}
|
||||
|
||||
func runMetrics(cmd *cli.Command, args []string) error {
|
||||
path := dbPath
|
||||
if path == "" {
|
||||
path = os.Getenv("LEM_DB")
|
||||
}
|
||||
if path == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
db, err := ml.OpenDB(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
influx := ml.NewInfluxClient(influxURL, influxDB)
|
||||
|
||||
return ml.PushMetrics(db, influx, os.Stdout)
|
||||
}
|
||||
|
|
@ -11,6 +11,17 @@
|
|||
// - core ml agent: Run the scoring agent daemon
|
||||
// - core ml worker: Run a distributed worker node
|
||||
// - core ml serve: Start OpenAI-compatible inference server
|
||||
// - core ml inventory: Show DuckDB table inventory with stats
|
||||
// - core ml query: Run ad-hoc SQL against DuckDB
|
||||
// - core ml metrics: Push golden set stats to InfluxDB
|
||||
// - core ml ingest: Ingest benchmark scores and training logs to InfluxDB
|
||||
// - core ml normalize: Deduplicate seeds into expansion prompts
|
||||
// - core ml seed-influx: Migrate golden set from DuckDB to InfluxDB
|
||||
// - core ml consolidate: Pull and merge response JSONL files from M3
|
||||
// - core ml import-all: Import all LEM data into DuckDB
|
||||
// - core ml approve: Filter scored expansions into training JSONL
|
||||
// - core ml publish: Upload Parquet dataset to HuggingFace Hub
|
||||
// - core ml coverage: Analyze seed coverage by region and domain
|
||||
package ml
|
||||
|
||||
import (
|
||||
|
|
@ -40,6 +51,17 @@ func AddMLCommands(root *cli.Command) {
|
|||
mlCmd.AddCommand(agentCmd)
|
||||
mlCmd.AddCommand(workerCmd)
|
||||
mlCmd.AddCommand(serveCmd)
|
||||
mlCmd.AddCommand(inventoryCmd)
|
||||
mlCmd.AddCommand(queryCmd)
|
||||
mlCmd.AddCommand(metricsCmd)
|
||||
mlCmd.AddCommand(ingestCmd)
|
||||
mlCmd.AddCommand(normalizeCmd)
|
||||
mlCmd.AddCommand(seedInfluxCmd)
|
||||
mlCmd.AddCommand(consolidateCmd)
|
||||
mlCmd.AddCommand(importCmd)
|
||||
mlCmd.AddCommand(approveCmd)
|
||||
mlCmd.AddCommand(publishCmd)
|
||||
mlCmd.AddCommand(coverageCmd)
|
||||
root.AddCommand(mlCmd)
|
||||
}
|
||||
|
||||
|
|
|
|||
44
internal/cmd/ml/cmd_normalize.go
Normal file
44
internal/cmd/ml/cmd_normalize.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/core/cli/pkg/ml"
|
||||
)
|
||||
|
||||
var normalizeMinLen int
|
||||
|
||||
var normalizeCmd = &cli.Command{
|
||||
Use: "normalize",
|
||||
Short: "Normalize seeds into expansion prompts",
|
||||
Long: "Deduplicates seeds against golden_set and prompts, creating the expansion_prompts table with priority-based ordering.",
|
||||
RunE: runNormalize,
|
||||
}
|
||||
|
||||
func init() {
|
||||
normalizeCmd.Flags().IntVar(&normalizeMinLen, "min-length", 50, "Minimum prompt length in characters")
|
||||
}
|
||||
|
||||
func runNormalize(cmd *cli.Command, args []string) error {
|
||||
path := dbPath
|
||||
if path == "" {
|
||||
path = os.Getenv("LEM_DB")
|
||||
}
|
||||
if path == "" {
|
||||
return fmt.Errorf("--db or LEM_DB env is required")
|
||||
}
|
||||
|
||||
db, err := ml.OpenDBReadWrite(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
cfg := ml.NormalizeConfig{
|
||||
MinLength: normalizeMinLen,
|
||||
}
|
||||
|
||||
return ml.NormalizeSeeds(db, cfg, os.Stdout)
|
||||
}
|
||||
40
internal/cmd/ml/cmd_publish.go
Normal file
40
internal/cmd/ml/cmd_publish.go
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/core/cli/pkg/ml"
|
||||
)
|
||||
|
||||
var (
|
||||
publishInputDir string
|
||||
publishRepo string
|
||||
publishPublic bool
|
||||
publishToken string
|
||||
publishDryRun bool
|
||||
)
|
||||
|
||||
var publishCmd = &cli.Command{
|
||||
Use: "publish",
|
||||
Short: "Upload Parquet dataset to HuggingFace Hub",
|
||||
Long: "Uploads train/valid/test Parquet files and an optional dataset card to a HuggingFace dataset repository.",
|
||||
RunE: runPublish,
|
||||
}
|
||||
|
||||
func init() {
|
||||
publishCmd.Flags().StringVar(&publishInputDir, "input-dir", "", "Directory containing Parquet files (required)")
|
||||
publishCmd.Flags().StringVar(&publishRepo, "repo", "lthn/LEM-golden-set", "HuggingFace dataset repo ID")
|
||||
publishCmd.Flags().BoolVar(&publishPublic, "public", false, "Make dataset public")
|
||||
publishCmd.Flags().StringVar(&publishToken, "token", "", "HuggingFace API token (defaults to HF_TOKEN env)")
|
||||
publishCmd.Flags().BoolVar(&publishDryRun, "dry-run", false, "Show what would be uploaded without uploading")
|
||||
_ = publishCmd.MarkFlagRequired("input-dir")
|
||||
}
|
||||
|
||||
func runPublish(cmd *cli.Command, args []string) error {
|
||||
return ml.Publish(ml.PublishConfig{
|
||||
InputDir: publishInputDir,
|
||||
Repo: publishRepo,
|
||||
Public: publishPublic,
|
||||
Token: publishToken,
|
||||
DryRun: publishDryRun,
|
||||
}, cmd.OutOrStdout())
|
||||
}
|
||||
148
internal/cmd/ml/cmd_query.go
Normal file
148
internal/cmd/ml/cmd_query.go
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/core/cli/pkg/ml"
|
||||
)
|
||||
|
||||
var queryCmd = &cli.Command{
|
||||
Use: "query [sql]",
|
||||
Short: "Run ad-hoc SQL against DuckDB",
|
||||
Long: "Executes arbitrary SQL against the DuckDB database. Non-SELECT queries are auto-wrapped as golden_set WHERE clauses.",
|
||||
Example: ` core ml query "SELECT COUNT(*) FROM golden_set"
|
||||
core ml query "domain = 'ethics'"
|
||||
core ml query --json "SHOW TABLES"`,
|
||||
Args: cli.MinimumNArgs(1),
|
||||
RunE: runQuery,
|
||||
}
|
||||
|
||||
var queryJSON bool
|
||||
|
||||
func init() {
|
||||
queryCmd.Flags().BoolVar(&queryJSON, "json", false, "Output as JSON")
|
||||
}
|
||||
|
||||
func runQuery(cmd *cli.Command, args []string) error {
|
||||
path := dbPath
|
||||
if path == "" {
|
||||
path = os.Getenv("LEM_DB")
|
||||
}
|
||||
if path == "" {
|
||||
return fmt.Errorf("--db or LEM_DB env is required")
|
||||
}
|
||||
|
||||
db, err := ml.OpenDB(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
sql := strings.Join(args, " ")
|
||||
|
||||
// Auto-wrap non-SELECT queries as golden_set WHERE clauses.
|
||||
trimmed := strings.TrimSpace(strings.ToUpper(sql))
|
||||
if !strings.HasPrefix(trimmed, "SELECT") && !strings.HasPrefix(trimmed, "SHOW") &&
|
||||
!strings.HasPrefix(trimmed, "DESCRIBE") && !strings.HasPrefix(trimmed, "EXPLAIN") {
|
||||
sql = "SELECT * FROM golden_set WHERE " + sql + " LIMIT 20"
|
||||
}
|
||||
|
||||
rows, err := db.QueryRows(sql)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query: %w", err)
|
||||
}
|
||||
|
||||
if queryJSON {
|
||||
enc := json.NewEncoder(os.Stdout)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(rows); err != nil {
|
||||
return fmt.Errorf("encode json: %w", err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n(%d rows)\n", len(rows))
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(rows) == 0 {
|
||||
fmt.Println("(0 rows)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Collect column names in stable order from first row.
|
||||
var cols []string
|
||||
for col := range rows[0] {
|
||||
cols = append(cols, col)
|
||||
}
|
||||
|
||||
// Calculate column widths (capped at 60).
|
||||
const maxWidth = 60
|
||||
widths := make([]int, len(cols))
|
||||
for i, col := range cols {
|
||||
widths[i] = len(col)
|
||||
}
|
||||
for _, row := range rows {
|
||||
for i, col := range cols {
|
||||
val := formatValue(row[col])
|
||||
if l := len(val); l > widths[i] {
|
||||
widths[i] = l
|
||||
}
|
||||
}
|
||||
}
|
||||
for i := range widths {
|
||||
if widths[i] > maxWidth {
|
||||
widths[i] = maxWidth
|
||||
}
|
||||
}
|
||||
|
||||
// Print header.
|
||||
for i, col := range cols {
|
||||
if i > 0 {
|
||||
fmt.Print(" | ")
|
||||
}
|
||||
fmt.Printf("%-*s", widths[i], truncate(col, widths[i]))
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
// Print separator.
|
||||
for i := range cols {
|
||||
if i > 0 {
|
||||
fmt.Print("-+-")
|
||||
}
|
||||
fmt.Print(strings.Repeat("-", widths[i]))
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
// Print rows.
|
||||
for _, row := range rows {
|
||||
for i, col := range cols {
|
||||
if i > 0 {
|
||||
fmt.Print(" | ")
|
||||
}
|
||||
fmt.Printf("%-*s", widths[i], truncate(formatValue(row[col]), widths[i]))
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
fmt.Printf("\n(%d rows)\n", len(rows))
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatValue(v interface{}) string {
|
||||
if v == nil {
|
||||
return "NULL"
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
if max <= 3 {
|
||||
return s[:max]
|
||||
}
|
||||
return s[:max-3] + "..."
|
||||
}
|
||||
49
internal/cmd/ml/cmd_seed_influx.go
Normal file
49
internal/cmd/ml/cmd_seed_influx.go
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/core/cli/pkg/ml"
|
||||
)
|
||||
|
||||
var seedInfluxCmd = &cli.Command{
|
||||
Use: "seed-influx",
|
||||
Short: "Seed InfluxDB golden_gen from DuckDB golden_set",
|
||||
Long: "One-time migration: batch-loads DuckDB golden_set records into InfluxDB golden_gen measurement.",
|
||||
RunE: runSeedInflux,
|
||||
}
|
||||
|
||||
var (
|
||||
seedInfluxForce bool
|
||||
seedInfluxBatchSize int
|
||||
)
|
||||
|
||||
func init() {
|
||||
seedInfluxCmd.Flags().BoolVar(&seedInfluxForce, "force", false, "Re-seed even if InfluxDB already has data")
|
||||
seedInfluxCmd.Flags().IntVar(&seedInfluxBatchSize, "batch-size", 500, "Lines per InfluxDB write batch")
|
||||
}
|
||||
|
||||
func runSeedInflux(cmd *cli.Command, args []string) error {
|
||||
path := dbPath
|
||||
if path == "" {
|
||||
path = os.Getenv("LEM_DB")
|
||||
}
|
||||
if path == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
db, err := ml.OpenDB(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
influx := ml.NewInfluxClient(influxURL, influxDB)
|
||||
|
||||
return ml.SeedInflux(db, influx, ml.SeedInfluxConfig{
|
||||
Force: seedInfluxForce,
|
||||
BatchSize: seedInfluxBatchSize,
|
||||
}, os.Stdout)
|
||||
}
|
||||
82
pkg/ml/approve.go
Normal file
82
pkg/ml/approve.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// ApproveConfig holds options for the approve operation.
|
||||
type ApproveConfig struct {
|
||||
Output string
|
||||
Threshold float64
|
||||
}
|
||||
|
||||
// ApproveExpansions filters scored expansion responses above the threshold
|
||||
// and writes approved examples to a training JSONL file.
|
||||
//
|
||||
// The query joins expansion_raw with expansion_scores, keeping rows where
|
||||
// the heuristic passed AND the judge either passed or has not yet scored.
|
||||
// Each approved row is written as a chat-format JSONL line with user/assistant
|
||||
// messages.
|
||||
func ApproveExpansions(db *DB, cfg ApproveConfig, w io.Writer) error {
|
||||
rows, err := db.conn.Query(`
|
||||
SELECT r.idx, r.seed_id, r.region, r.domain, r.prompt, r.response,
|
||||
r.gen_time, r.model, s.heuristic_score
|
||||
FROM expansion_raw r
|
||||
JOIN expansion_scores s ON r.idx = s.idx
|
||||
WHERE s.heuristic_pass = true
|
||||
AND (s.judge_pass = true OR s.judge_pass IS NULL)
|
||||
ORDER BY r.idx
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query approved expansions: %w (have you run scoring?)", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
f, err := os.Create(cfg.Output)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create output %s: %w", cfg.Output, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
enc := json.NewEncoder(f)
|
||||
count := 0
|
||||
regionSet := make(map[string]bool)
|
||||
domainSet := make(map[string]bool)
|
||||
|
||||
for rows.Next() {
|
||||
var idx int
|
||||
var seedID, region, domain, prompt, response, model string
|
||||
var genTime, score float64
|
||||
if err := rows.Scan(&idx, &seedID, ®ion, &domain, &prompt, &response, &genTime, &model, &score); err != nil {
|
||||
return fmt.Errorf("scan approved row: %w", err)
|
||||
}
|
||||
|
||||
example := TrainingExample{
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: prompt},
|
||||
{Role: "assistant", Content: response},
|
||||
},
|
||||
}
|
||||
|
||||
if err := enc.Encode(example); err != nil {
|
||||
return fmt.Errorf("encode example: %w", err)
|
||||
}
|
||||
|
||||
regionSet[region] = true
|
||||
domainSet[domain] = true
|
||||
count++
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return fmt.Errorf("iterate approved rows: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "Approved: %d responses (threshold: heuristic > 0)\n", count)
|
||||
fmt.Fprintf(w, "Exported: %s\n", cfg.Output)
|
||||
fmt.Fprintf(w, " Regions: %d, Domains: %d\n", len(regionSet), len(domainSet))
|
||||
|
||||
return nil
|
||||
}
|
||||
150
pkg/ml/consolidate.go
Normal file
150
pkg/ml/consolidate.go
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ConsolidateConfig holds options for the consolidate operation.
|
||||
type ConsolidateConfig struct {
|
||||
M3Host string
|
||||
RemoteDir string
|
||||
Pattern string
|
||||
OutputDir string
|
||||
MergedOut string
|
||||
}
|
||||
|
||||
// Consolidate pulls JSONL response files from M3 via SSH, merges them by idx,
|
||||
// deduplicates, and writes a single merged JSONL output.
|
||||
func Consolidate(cfg ConsolidateConfig, w io.Writer) error {
|
||||
if cfg.OutputDir == "" {
|
||||
cfg.OutputDir = "responses"
|
||||
}
|
||||
if err := os.MkdirAll(cfg.OutputDir, 0755); err != nil {
|
||||
return fmt.Errorf("create output dir: %w", err)
|
||||
}
|
||||
|
||||
// List remote files via SSH.
|
||||
fmt.Fprintln(w, "Pulling responses from remote...")
|
||||
listCmd := exec.Command("ssh", cfg.M3Host, fmt.Sprintf("ls %s/%s", cfg.RemoteDir, cfg.Pattern))
|
||||
listOutput, err := listCmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list remote files: %w", err)
|
||||
}
|
||||
|
||||
remoteFiles := strings.Split(strings.TrimSpace(string(listOutput)), "\n")
|
||||
var validFiles []string
|
||||
for _, f := range remoteFiles {
|
||||
f = strings.TrimSpace(f)
|
||||
if f != "" {
|
||||
validFiles = append(validFiles, f)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(w, " Found %d JSONL files on %s\n", len(validFiles), cfg.M3Host)
|
||||
|
||||
// Pull each file via SCP.
|
||||
for _, rf := range validFiles {
|
||||
local := filepath.Join(cfg.OutputDir, filepath.Base(rf))
|
||||
scpCmd := exec.Command("scp", fmt.Sprintf("%s:%s", cfg.M3Host, rf), local)
|
||||
if err := scpCmd.Run(); err != nil {
|
||||
fmt.Fprintf(w, " warning: failed to pull %s: %v\n", rf, err)
|
||||
continue
|
||||
}
|
||||
|
||||
lines, err := countLines(local)
|
||||
if err == nil {
|
||||
fmt.Fprintf(w, " %s: %d records\n", filepath.Base(rf), lines)
|
||||
}
|
||||
}
|
||||
|
||||
// Merge and deduplicate on idx (first occurrence wins).
|
||||
seen := make(map[int]json.RawMessage)
|
||||
skipped := 0
|
||||
|
||||
matches, _ := filepath.Glob(filepath.Join(cfg.OutputDir, cfg.Pattern))
|
||||
sort.Strings(matches)
|
||||
|
||||
for _, local := range matches {
|
||||
f, err := os.Open(local)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
scanner := bufio.NewScanner(f)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
var rec struct {
|
||||
Idx *int `json:"idx"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(line), &rec); err != nil {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
if rec.Idx == nil {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[*rec.Idx]; !exists {
|
||||
seen[*rec.Idx] = json.RawMessage(line)
|
||||
}
|
||||
}
|
||||
f.Close()
|
||||
}
|
||||
|
||||
if skipped > 0 {
|
||||
fmt.Fprintf(w, " Skipped %d records without idx\n", skipped)
|
||||
}
|
||||
|
||||
// Sort by idx and write merged file.
|
||||
mergedPath := cfg.MergedOut
|
||||
if mergedPath == "" {
|
||||
mergedPath = filepath.Join(cfg.OutputDir, "..", "gold-merged.jsonl")
|
||||
}
|
||||
|
||||
idxs := make([]int, 0, len(seen))
|
||||
for idx := range seen {
|
||||
idxs = append(idxs, idx)
|
||||
}
|
||||
sort.Ints(idxs)
|
||||
|
||||
out, err := os.Create(mergedPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create merged file: %w", err)
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
bw := bufio.NewWriter(out)
|
||||
for _, idx := range idxs {
|
||||
bw.Write(seen[idx])
|
||||
bw.WriteString("\n")
|
||||
}
|
||||
if err := bw.Flush(); err != nil {
|
||||
return fmt.Errorf("flush merged file: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "\nMerged: %d unique examples -> %s\n", len(seen), mergedPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// countLines returns the number of lines in a file.
|
||||
func countLines(path string) (int, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
count := 0
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
count++
|
||||
}
|
||||
return count, scanner.Err()
|
||||
}
|
||||
127
pkg/ml/coverage.go
Normal file
127
pkg/ml/coverage.go
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// regionRow holds a single row from the region distribution query.
|
||||
type regionRow struct {
|
||||
group string
|
||||
n int
|
||||
domains int
|
||||
}
|
||||
|
||||
// PrintCoverage analyzes seed coverage by region and domain, printing
|
||||
// a report with bar chart visualization and gap recommendations.
|
||||
func PrintCoverage(db *DB, w io.Writer) error {
|
||||
rows, err := db.QueryRows("SELECT count(*) AS total FROM seeds")
|
||||
if err != nil {
|
||||
return fmt.Errorf("count seeds: %w (run: core ml import-all first)", err)
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return fmt.Errorf("no seeds table found (run: core ml import-all first)")
|
||||
}
|
||||
total := toInt(rows[0]["total"])
|
||||
|
||||
fmt.Fprintln(w, "LEM Seed Coverage Analysis")
|
||||
fmt.Fprintln(w, "==================================================")
|
||||
fmt.Fprintf(w, "\nTotal seeds: %d\n", total)
|
||||
|
||||
// Region distribution.
|
||||
regionRows, err := queryRegionDistribution(db)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query regions: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintln(w, "\nRegion distribution (underrepresented first):")
|
||||
avg := float64(total) / float64(len(regionRows))
|
||||
for _, r := range regionRows {
|
||||
barLen := int(float64(r.n) / avg * 10)
|
||||
if barLen > 40 {
|
||||
barLen = 40
|
||||
}
|
||||
bar := strings.Repeat("#", barLen)
|
||||
gap := ""
|
||||
if float64(r.n) < avg*0.5 {
|
||||
gap = " <- UNDERREPRESENTED"
|
||||
}
|
||||
fmt.Fprintf(w, " %-22s %6d (%4d domains) %s%s\n", r.group, r.n, r.domains, bar, gap)
|
||||
}
|
||||
|
||||
// Top 10 domains.
|
||||
fmt.Fprintln(w, "\nTop 10 domains (most seeds):")
|
||||
topRows, err := db.QueryRows(`
|
||||
SELECT domain, count(*) AS n FROM seeds
|
||||
WHERE domain != '' GROUP BY domain ORDER BY n DESC LIMIT 10
|
||||
`)
|
||||
if err == nil {
|
||||
for _, row := range topRows {
|
||||
domain := strVal(row, "domain")
|
||||
n := toInt(row["n"])
|
||||
fmt.Fprintf(w, " %-40s %5d\n", domain, n)
|
||||
}
|
||||
}
|
||||
|
||||
// Bottom 10 domains.
|
||||
fmt.Fprintln(w, "\nBottom 10 domains (fewest seeds, min 5):")
|
||||
bottomRows, err := db.QueryRows(`
|
||||
SELECT domain, count(*) AS n FROM seeds
|
||||
WHERE domain != '' GROUP BY domain HAVING count(*) >= 5 ORDER BY n ASC LIMIT 10
|
||||
`)
|
||||
if err == nil {
|
||||
for _, row := range bottomRows {
|
||||
domain := strVal(row, "domain")
|
||||
n := toInt(row["n"])
|
||||
fmt.Fprintf(w, " %-40s %5d\n", domain, n)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintln(w, "\nSuggested expansion areas:")
|
||||
fmt.Fprintln(w, " - Japanese, Korean, Thai, Vietnamese (no seeds found)")
|
||||
fmt.Fprintln(w, " - Hindi/Urdu, Bengali, Tamil (South Asian)")
|
||||
fmt.Fprintln(w, " - Swahili, Yoruba, Amharic (Sub-Saharan Africa)")
|
||||
fmt.Fprintln(w, " - Indigenous languages (Quechua, Nahuatl, Aymara)")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// queryRegionDistribution returns seed counts grouped by normalized language
|
||||
// region, ordered ascending (underrepresented first).
|
||||
func queryRegionDistribution(db *DB) ([]regionRow, error) {
|
||||
rows, err := db.QueryRows(`
|
||||
SELECT
|
||||
CASE
|
||||
WHEN region LIKE '%cn%' THEN 'cn (Chinese)'
|
||||
WHEN region LIKE '%en-%' OR region LIKE '%en_para%' OR region LIKE '%para%' THEN 'en (English)'
|
||||
WHEN region LIKE '%ru%' THEN 'ru (Russian)'
|
||||
WHEN region LIKE '%de%' AND region NOT LIKE '%deten%' THEN 'de (German)'
|
||||
WHEN region LIKE '%es%' THEN 'es (Spanish)'
|
||||
WHEN region LIKE '%fr%' THEN 'fr (French)'
|
||||
WHEN region LIKE '%latam%' THEN 'latam (LatAm)'
|
||||
WHEN region LIKE '%africa%' THEN 'africa'
|
||||
WHEN region LIKE '%eu%' THEN 'eu (European)'
|
||||
WHEN region LIKE '%me%' AND region NOT LIKE '%premium%' THEN 'me (MidEast)'
|
||||
WHEN region LIKE '%multi%' THEN 'multilingual'
|
||||
WHEN region LIKE '%weak%' THEN 'weak-langs'
|
||||
ELSE 'other'
|
||||
END AS lang_group,
|
||||
count(*) AS n,
|
||||
count(DISTINCT domain) AS domains
|
||||
FROM seeds GROUP BY lang_group ORDER BY n ASC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]regionRow, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
result = append(result, regionRow{
|
||||
group: strVal(row, "lang_group"),
|
||||
n: toInt(row["n"]),
|
||||
domains: toInt(row["domains"]),
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
17
pkg/ml/db.go
17
pkg/ml/db.go
|
|
@ -45,6 +45,23 @@ func (db *DB) Close() error {
|
|||
return db.conn.Close()
|
||||
}
|
||||
|
||||
// Path returns the database file path.
|
||||
func (db *DB) Path() string {
|
||||
return db.path
|
||||
}
|
||||
|
||||
// Exec executes a query without returning rows.
|
||||
func (db *DB) Exec(query string, args ...interface{}) error {
|
||||
_, err := db.conn.Exec(query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// QueryRowScan executes a query expected to return at most one row and scans
|
||||
// the result into dest. It is a convenience wrapper around sql.DB.QueryRow.
|
||||
func (db *DB) QueryRowScan(query string, dest interface{}, args ...interface{}) error {
|
||||
return db.conn.QueryRow(query, args...).Scan(dest)
|
||||
}
|
||||
|
||||
// GoldenSetRow represents one row from the golden_set table.
|
||||
type GoldenSetRow struct {
|
||||
Idx int
|
||||
|
|
|
|||
437
pkg/ml/import_all.go
Normal file
437
pkg/ml/import_all.go
Normal file
|
|
@ -0,0 +1,437 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ImportConfig holds options for the import-all operation.
|
||||
type ImportConfig struct {
|
||||
SkipM3 bool
|
||||
DataDir string
|
||||
M3Host string
|
||||
}
|
||||
|
||||
// ImportAll imports all LEM data into DuckDB from M3 and local files.
|
||||
func ImportAll(db *DB, cfg ImportConfig, w io.Writer) error {
|
||||
m3Host := cfg.M3Host
|
||||
if m3Host == "" {
|
||||
m3Host = "m3"
|
||||
}
|
||||
|
||||
totals := make(map[string]int)
|
||||
|
||||
// ── 1. Golden set ──
|
||||
goldenPath := filepath.Join(cfg.DataDir, "gold-15k.jsonl")
|
||||
if !cfg.SkipM3 {
|
||||
fmt.Fprintln(w, " Pulling golden set from M3...")
|
||||
scpCmd := exec.Command("scp", fmt.Sprintf("%s:/Volumes/Data/lem/responses/gold-15k.jsonl", m3Host), goldenPath)
|
||||
if err := scpCmd.Run(); err != nil {
|
||||
fmt.Fprintf(w, " WARNING: could not pull golden set from M3: %v\n", err)
|
||||
}
|
||||
}
|
||||
if _, err := os.Stat(goldenPath); err == nil {
|
||||
db.Exec("DROP TABLE IF EXISTS golden_set")
|
||||
err := db.Exec(fmt.Sprintf(`
|
||||
CREATE TABLE golden_set AS
|
||||
SELECT
|
||||
idx::INT AS idx,
|
||||
seed_id::VARCHAR AS seed_id,
|
||||
domain::VARCHAR AS domain,
|
||||
voice::VARCHAR AS voice,
|
||||
prompt::VARCHAR AS prompt,
|
||||
response::VARCHAR AS response,
|
||||
gen_time::DOUBLE AS gen_time,
|
||||
length(response)::INT AS char_count,
|
||||
length(response) - length(replace(response, ' ', '')) + 1 AS word_count
|
||||
FROM read_json_auto('%s', maximum_object_size=1048576)
|
||||
`, escapeSQLPath(goldenPath)))
|
||||
if err != nil {
|
||||
fmt.Fprintf(w, " WARNING: golden set import failed: %v\n", err)
|
||||
} else {
|
||||
var n int
|
||||
db.QueryRowScan("SELECT count(*) FROM golden_set", &n)
|
||||
totals["golden_set"] = n
|
||||
fmt.Fprintf(w, " golden_set: %d rows\n", n)
|
||||
}
|
||||
}
|
||||
|
||||
// ── 2. Training examples ──
|
||||
trainingDirs := []struct {
|
||||
name string
|
||||
files []string
|
||||
}{
|
||||
{"training", []string{"training/train.jsonl", "training/valid.jsonl", "training/test.jsonl"}},
|
||||
{"training-2k", []string{"training-2k/train.jsonl", "training-2k/valid.jsonl", "training-2k/test.jsonl"}},
|
||||
{"training-expanded", []string{"training-expanded/train.jsonl", "training-expanded/valid.jsonl"}},
|
||||
{"training-book", []string{"training-book/train.jsonl", "training-book/valid.jsonl", "training-book/test.jsonl"}},
|
||||
{"training-conv", []string{"training-conv/train.jsonl", "training-conv/valid.jsonl", "training-conv/test.jsonl"}},
|
||||
{"gold-full", []string{"gold-full/train.jsonl", "gold-full/valid.jsonl"}},
|
||||
{"sovereignty-gold", []string{"sovereignty-gold/train.jsonl", "sovereignty-gold/valid.jsonl"}},
|
||||
{"composure-lessons", []string{"composure-lessons/train.jsonl", "composure-lessons/valid.jsonl"}},
|
||||
{"watts-full", []string{"watts-full/train.jsonl", "watts-full/valid.jsonl"}},
|
||||
{"watts-expanded", []string{"watts-expanded/train.jsonl", "watts-expanded/valid.jsonl"}},
|
||||
{"watts-composure", []string{"watts-composure-merged/train.jsonl", "watts-composure-merged/valid.jsonl"}},
|
||||
{"western-fresh", []string{"western-fresh/train.jsonl", "western-fresh/valid.jsonl"}},
|
||||
{"deepseek-soak", []string{"deepseek-western-soak/train.jsonl", "deepseek-western-soak/valid.jsonl"}},
|
||||
{"russian-bridge", []string{"russian-bridge/train.jsonl", "russian-bridge/valid.jsonl"}},
|
||||
}
|
||||
|
||||
trainingLocal := filepath.Join(cfg.DataDir, "training")
|
||||
os.MkdirAll(trainingLocal, 0755)
|
||||
|
||||
if !cfg.SkipM3 {
|
||||
fmt.Fprintln(w, " Pulling training sets from M3...")
|
||||
for _, td := range trainingDirs {
|
||||
for _, rel := range td.files {
|
||||
local := filepath.Join(trainingLocal, rel)
|
||||
os.MkdirAll(filepath.Dir(local), 0755)
|
||||
scpCmd := exec.Command("scp", fmt.Sprintf("%s:/Volumes/Data/lem/%s", m3Host, rel), local)
|
||||
scpCmd.Run() // ignore errors, file might not exist
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.Exec("DROP TABLE IF EXISTS training_examples")
|
||||
db.Exec(`
|
||||
CREATE TABLE training_examples (
|
||||
source VARCHAR,
|
||||
split VARCHAR,
|
||||
prompt TEXT,
|
||||
response TEXT,
|
||||
num_turns INT,
|
||||
full_messages TEXT,
|
||||
char_count INT
|
||||
)
|
||||
`)
|
||||
|
||||
trainingTotal := 0
|
||||
for _, td := range trainingDirs {
|
||||
for _, rel := range td.files {
|
||||
local := filepath.Join(trainingLocal, rel)
|
||||
if _, err := os.Stat(local); os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
split := "train"
|
||||
if strings.Contains(rel, "valid") {
|
||||
split = "valid"
|
||||
} else if strings.Contains(rel, "test") {
|
||||
split = "test"
|
||||
}
|
||||
|
||||
n := importTrainingFile(db, local, td.name, split)
|
||||
trainingTotal += n
|
||||
}
|
||||
}
|
||||
totals["training_examples"] = trainingTotal
|
||||
fmt.Fprintf(w, " training_examples: %d rows\n", trainingTotal)
|
||||
|
||||
// ── 3. Benchmark results ──
|
||||
benchLocal := filepath.Join(cfg.DataDir, "benchmarks")
|
||||
os.MkdirAll(benchLocal, 0755)
|
||||
|
||||
if !cfg.SkipM3 {
|
||||
fmt.Fprintln(w, " Pulling benchmarks from M3...")
|
||||
for _, bname := range []string{"truthfulqa", "gsm8k", "do_not_answer", "toxigen"} {
|
||||
scpCmd := exec.Command("scp",
|
||||
fmt.Sprintf("%s:/Volumes/Data/lem/benchmarks/%s.jsonl", m3Host, bname),
|
||||
filepath.Join(benchLocal, bname+".jsonl"))
|
||||
scpCmd.Run()
|
||||
}
|
||||
for _, subdir := range []string{"results", "scale_results", "cross_arch_results", "deepseek-r1-7b"} {
|
||||
localSub := filepath.Join(benchLocal, subdir)
|
||||
os.MkdirAll(localSub, 0755)
|
||||
scpCmd := exec.Command("scp", "-r",
|
||||
fmt.Sprintf("%s:/Volumes/Data/lem/benchmarks/%s/", m3Host, subdir),
|
||||
filepath.Join(benchLocal)+"/")
|
||||
scpCmd.Run()
|
||||
}
|
||||
}
|
||||
|
||||
db.Exec("DROP TABLE IF EXISTS benchmark_results")
|
||||
db.Exec(`
|
||||
CREATE TABLE benchmark_results (
|
||||
source VARCHAR, id VARCHAR, benchmark VARCHAR, model VARCHAR,
|
||||
prompt TEXT, response TEXT, elapsed_seconds DOUBLE, domain VARCHAR
|
||||
)
|
||||
`)
|
||||
|
||||
benchTotal := 0
|
||||
for _, subdir := range []string{"results", "scale_results", "cross_arch_results", "deepseek-r1-7b"} {
|
||||
resultDir := filepath.Join(benchLocal, subdir)
|
||||
matches, _ := filepath.Glob(filepath.Join(resultDir, "*.jsonl"))
|
||||
for _, jf := range matches {
|
||||
n := importBenchmarkFile(db, jf, subdir)
|
||||
benchTotal += n
|
||||
}
|
||||
}
|
||||
|
||||
// Also import standalone benchmark files.
|
||||
for _, bfile := range []string{"lem_bench", "lem_ethics", "lem_ethics_allen", "instruction_tuned", "abliterated", "base_pt"} {
|
||||
local := filepath.Join(benchLocal, bfile+".jsonl")
|
||||
if _, err := os.Stat(local); os.IsNotExist(err) {
|
||||
if !cfg.SkipM3 {
|
||||
scpCmd := exec.Command("scp",
|
||||
fmt.Sprintf("%s:/Volumes/Data/lem/benchmark/%s.jsonl", m3Host, bfile), local)
|
||||
scpCmd.Run()
|
||||
}
|
||||
}
|
||||
if _, err := os.Stat(local); err == nil {
|
||||
n := importBenchmarkFile(db, local, "benchmark")
|
||||
benchTotal += n
|
||||
}
|
||||
}
|
||||
totals["benchmark_results"] = benchTotal
|
||||
fmt.Fprintf(w, " benchmark_results: %d rows\n", benchTotal)
|
||||
|
||||
// ── 4. Benchmark questions ──
|
||||
db.Exec("DROP TABLE IF EXISTS benchmark_questions")
|
||||
db.Exec(`
|
||||
CREATE TABLE benchmark_questions (
|
||||
benchmark VARCHAR, id VARCHAR, question TEXT,
|
||||
best_answer TEXT, correct_answers TEXT, incorrect_answers TEXT, category VARCHAR
|
||||
)
|
||||
`)
|
||||
|
||||
benchQTotal := 0
|
||||
for _, bname := range []string{"truthfulqa", "gsm8k", "do_not_answer", "toxigen"} {
|
||||
local := filepath.Join(benchLocal, bname+".jsonl")
|
||||
if _, err := os.Stat(local); err == nil {
|
||||
n := importBenchmarkQuestions(db, local, bname)
|
||||
benchQTotal += n
|
||||
}
|
||||
}
|
||||
totals["benchmark_questions"] = benchQTotal
|
||||
fmt.Fprintf(w, " benchmark_questions: %d rows\n", benchQTotal)
|
||||
|
||||
// ── 5. Seeds ──
|
||||
db.Exec("DROP TABLE IF EXISTS seeds")
|
||||
db.Exec(`
|
||||
CREATE TABLE seeds (
|
||||
source_file VARCHAR, region VARCHAR, seed_id VARCHAR, domain VARCHAR, prompt TEXT
|
||||
)
|
||||
`)
|
||||
|
||||
seedTotal := 0
|
||||
seedDirs := []string{filepath.Join(cfg.DataDir, "seeds"), "/tmp/lem-data/seeds", "/tmp/lem-repo/seeds"}
|
||||
for _, seedDir := range seedDirs {
|
||||
if _, err := os.Stat(seedDir); os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
n := importSeeds(db, seedDir)
|
||||
seedTotal += n
|
||||
}
|
||||
totals["seeds"] = seedTotal
|
||||
fmt.Fprintf(w, " seeds: %d rows\n", seedTotal)
|
||||
|
||||
// ── Summary ──
|
||||
grandTotal := 0
|
||||
fmt.Fprintf(w, "\n%s\n", strings.Repeat("=", 50))
|
||||
fmt.Fprintln(w, "LEM Database Import Complete")
|
||||
fmt.Fprintln(w, strings.Repeat("=", 50))
|
||||
for table, count := range totals {
|
||||
fmt.Fprintf(w, " %-25s %8d\n", table, count)
|
||||
grandTotal += count
|
||||
}
|
||||
fmt.Fprintf(w, " %s\n", strings.Repeat("-", 35))
|
||||
fmt.Fprintf(w, " %-25s %8d\n", "TOTAL", grandTotal)
|
||||
fmt.Fprintf(w, "\nDatabase: %s\n", db.Path())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func importTrainingFile(db *DB, path, source, split string) int {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
count := 0
|
||||
scanner := bufio.NewScanner(f)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
var rec struct {
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(scanner.Bytes(), &rec); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
prompt := ""
|
||||
response := ""
|
||||
assistantCount := 0
|
||||
for _, m := range rec.Messages {
|
||||
if m.Role == "user" && prompt == "" {
|
||||
prompt = m.Content
|
||||
}
|
||||
if m.Role == "assistant" {
|
||||
if response == "" {
|
||||
response = m.Content
|
||||
}
|
||||
assistantCount++
|
||||
}
|
||||
}
|
||||
|
||||
msgsJSON, _ := json.Marshal(rec.Messages)
|
||||
db.Exec(`INSERT INTO training_examples VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
source, split, prompt, response, assistantCount, string(msgsJSON), len(response))
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func importBenchmarkFile(db *DB, path, source string) int {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
count := 0
|
||||
scanner := bufio.NewScanner(f)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
var rec map[string]interface{}
|
||||
if err := json.Unmarshal(scanner.Bytes(), &rec); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
db.Exec(`INSERT INTO benchmark_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
source,
|
||||
fmt.Sprintf("%v", rec["id"]),
|
||||
strOrEmpty(rec, "benchmark"),
|
||||
strOrEmpty(rec, "model"),
|
||||
strOrEmpty(rec, "prompt"),
|
||||
strOrEmpty(rec, "response"),
|
||||
floatOrZero(rec, "elapsed_seconds"),
|
||||
strOrEmpty(rec, "domain"),
|
||||
)
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func importBenchmarkQuestions(db *DB, path, benchmark string) int {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
count := 0
|
||||
scanner := bufio.NewScanner(f)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
var rec map[string]interface{}
|
||||
if err := json.Unmarshal(scanner.Bytes(), &rec); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
correctJSON, _ := json.Marshal(rec["correct_answers"])
|
||||
incorrectJSON, _ := json.Marshal(rec["incorrect_answers"])
|
||||
|
||||
db.Exec(`INSERT INTO benchmark_questions VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
benchmark,
|
||||
fmt.Sprintf("%v", rec["id"]),
|
||||
strOrEmpty(rec, "question"),
|
||||
strOrEmpty(rec, "best_answer"),
|
||||
string(correctJSON),
|
||||
string(incorrectJSON),
|
||||
strOrEmpty(rec, "category"),
|
||||
)
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func importSeeds(db *DB, seedDir string) int {
|
||||
count := 0
|
||||
filepath.Walk(seedDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil || info.IsDir() || !strings.HasSuffix(path, ".json") {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rel, _ := filepath.Rel(seedDir, path)
|
||||
region := strings.TrimSuffix(filepath.Base(path), ".json")
|
||||
|
||||
// Try parsing as array or object with prompts/seeds field.
|
||||
var seedsList []interface{}
|
||||
var raw interface{}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := raw.(type) {
|
||||
case []interface{}:
|
||||
seedsList = v
|
||||
case map[string]interface{}:
|
||||
if prompts, ok := v["prompts"].([]interface{}); ok {
|
||||
seedsList = prompts
|
||||
} else if seeds, ok := v["seeds"].([]interface{}); ok {
|
||||
seedsList = seeds
|
||||
}
|
||||
}
|
||||
|
||||
for _, s := range seedsList {
|
||||
switch seed := s.(type) {
|
||||
case map[string]interface{}:
|
||||
prompt := strOrEmpty(seed, "prompt")
|
||||
if prompt == "" {
|
||||
prompt = strOrEmpty(seed, "text")
|
||||
}
|
||||
if prompt == "" {
|
||||
prompt = strOrEmpty(seed, "question")
|
||||
}
|
||||
db.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
|
||||
rel, region,
|
||||
strOrEmpty(seed, "seed_id"),
|
||||
strOrEmpty(seed, "domain"),
|
||||
prompt,
|
||||
)
|
||||
count++
|
||||
case string:
|
||||
db.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
|
||||
rel, region, "", "", seed)
|
||||
count++
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return count
|
||||
}
|
||||
|
||||
func strOrEmpty(m map[string]interface{}, key string) string {
|
||||
if v, ok := m[key]; ok {
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func floatOrZero(m map[string]interface{}, key string) float64 {
|
||||
if v, ok := m[key]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
return f
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func escapeSQLPath(p string) string {
|
||||
return strings.ReplaceAll(p, "'", "''")
|
||||
}
|
||||
384
pkg/ml/ingest.go
Normal file
384
pkg/ml/ingest.go
Normal file
|
|
@ -0,0 +1,384 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IngestConfig holds the configuration for a benchmark/training ingest run.
|
||||
type IngestConfig struct {
|
||||
ContentFile string
|
||||
CapabilityFile string
|
||||
TrainingLog string
|
||||
Model string
|
||||
RunID string
|
||||
BatchSize int
|
||||
}
|
||||
|
||||
// contentScoreLine is the JSON structure for a content scores JSONL line.
|
||||
type contentScoreLine struct {
|
||||
Label string `json:"label"`
|
||||
Aggregates map[string]interface{} `json:"aggregates"`
|
||||
Probes map[string]contentScoreProbe `json:"probes"`
|
||||
}
|
||||
|
||||
// contentScoreProbe is the per-probe block within a content score line.
|
||||
type contentScoreProbe struct {
|
||||
Scores map[string]interface{} `json:"scores"`
|
||||
}
|
||||
|
||||
// capabilityScoreLine is the JSON structure for a capability scores JSONL line.
|
||||
type capabilityScoreLine struct {
|
||||
Label string `json:"label"`
|
||||
Accuracy float64 `json:"accuracy"`
|
||||
Correct int `json:"correct"`
|
||||
Total int `json:"total"`
|
||||
ByCategory map[string]capabilityCatBlock `json:"by_category"`
|
||||
}
|
||||
|
||||
// capabilityCatBlock is the per-category block within a capability score line.
|
||||
type capabilityCatBlock struct {
|
||||
Correct int `json:"correct"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
// Training log regexes.
|
||||
var (
|
||||
reValLoss = regexp.MustCompile(`Iter (\d+): Val loss ([\d.]+)`)
|
||||
reTrainLoss = regexp.MustCompile(`Iter (\d+): Train loss ([\d.]+), Learning Rate ([\d.eE+-]+), It/sec ([\d.]+), Tokens/sec ([\d.]+)`)
|
||||
)
|
||||
|
||||
// Ingest reads benchmark scores and training logs and writes them to InfluxDB.
|
||||
// At least one of ContentFile, CapabilityFile, or TrainingLog must be set.
|
||||
func Ingest(influx *InfluxClient, cfg IngestConfig, w io.Writer) error {
|
||||
if cfg.ContentFile == "" && cfg.CapabilityFile == "" && cfg.TrainingLog == "" {
|
||||
return fmt.Errorf("at least one of --content, --capability, or --training-log is required")
|
||||
}
|
||||
if cfg.Model == "" {
|
||||
return fmt.Errorf("--model is required")
|
||||
}
|
||||
if cfg.RunID == "" {
|
||||
cfg.RunID = cfg.Model
|
||||
}
|
||||
if cfg.BatchSize <= 0 {
|
||||
cfg.BatchSize = 100
|
||||
}
|
||||
|
||||
var totalPoints int
|
||||
|
||||
if cfg.ContentFile != "" {
|
||||
n, err := ingestContentScores(influx, cfg, w)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ingest content scores: %w", err)
|
||||
}
|
||||
totalPoints += n
|
||||
}
|
||||
|
||||
if cfg.CapabilityFile != "" {
|
||||
n, err := ingestCapabilityScores(influx, cfg, w)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ingest capability scores: %w", err)
|
||||
}
|
||||
totalPoints += n
|
||||
}
|
||||
|
||||
if cfg.TrainingLog != "" {
|
||||
n, err := ingestTrainingLog(influx, cfg, w)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ingest training log: %w", err)
|
||||
}
|
||||
totalPoints += n
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "Ingested %d total points into InfluxDB\n", totalPoints)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ingestContentScores reads a content scores JSONL file and writes content_score
|
||||
// and probe_score measurements to InfluxDB.
|
||||
func ingestContentScores(influx *InfluxClient, cfg IngestConfig, w io.Writer) (int, error) {
|
||||
f, err := os.Open(cfg.ContentFile)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("open %s: %w", cfg.ContentFile, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
||||
|
||||
var lines []string
|
||||
var totalPoints int
|
||||
lineNum := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
raw := strings.TrimSpace(scanner.Text())
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var entry contentScoreLine
|
||||
if err := json.Unmarshal([]byte(raw), &entry); err != nil {
|
||||
return totalPoints, fmt.Errorf("line %d: parse json: %w", lineNum, err)
|
||||
}
|
||||
|
||||
label := entry.Label
|
||||
iteration := extractIteration(label)
|
||||
hasKernel := "false"
|
||||
if strings.Contains(strings.ToLower(label), "kernel") || strings.Contains(label, "LEK") {
|
||||
hasKernel = "true"
|
||||
}
|
||||
ts := time.Now().UnixNano()
|
||||
|
||||
// Write aggregate content_score — one point per dimension.
|
||||
for dim, val := range entry.Aggregates {
|
||||
            score, ok := toFloat64(val)
            if !ok {
                continue
            }
            line := fmt.Sprintf(
                "content_score,model=%s,run_id=%s,label=%s,dimension=%s,has_kernel=%s score=%.6f,iteration=%di %d",
                EscapeLp(cfg.Model), EscapeLp(cfg.RunID), EscapeLp(label),
                EscapeLp(dim), hasKernel, score, iteration, ts,
            )
            lines = append(lines, line)
            totalPoints++
        }

        // Write per-probe probe_score — one point per probe per dimension.
        for probeID, probe := range entry.Probes {
            for dim, val := range probe.Scores {
                score, ok := toFloat64(val)
                if !ok {
                    continue
                }
                line := fmt.Sprintf(
                    "probe_score,model=%s,run_id=%s,label=%s,probe_id=%s,dimension=%s,has_kernel=%s score=%.6f,iteration=%di %d",
                    EscapeLp(cfg.Model), EscapeLp(cfg.RunID), EscapeLp(label),
                    EscapeLp(probeID), EscapeLp(dim), hasKernel, score, iteration, ts,
                )
                lines = append(lines, line)
                totalPoints++
            }
        }

        // Flush batch if needed.
        if len(lines) >= cfg.BatchSize {
            if err := influx.WriteLp(lines); err != nil {
                return totalPoints, fmt.Errorf("write batch: %w", err)
            }
            lines = lines[:0]
        }
    }

    if err := scanner.Err(); err != nil {
        return totalPoints, fmt.Errorf("scan %s: %w", cfg.ContentFile, err)
    }

    // Flush remaining lines.
    if len(lines) > 0 {
        if err := influx.WriteLp(lines); err != nil {
            return totalPoints, fmt.Errorf("write final batch: %w", err)
        }
    }

    fmt.Fprintf(w, " content scores: %d points from %d lines\n", totalPoints, lineNum)
    return totalPoints, nil
}

// ingestCapabilityScores reads a capability scores JSONL file and writes
// capability_score measurements to InfluxDB.
func ingestCapabilityScores(influx *InfluxClient, cfg IngestConfig, w io.Writer) (int, error) {
    f, err := os.Open(cfg.CapabilityFile)
    if err != nil {
        return 0, fmt.Errorf("open %s: %w", cfg.CapabilityFile, err)
    }
    defer f.Close()

    scanner := bufio.NewScanner(f)
    scanner.Buffer(make([]byte, 1024*1024), 1024*1024)

    var lines []string
    var totalPoints int
    lineNum := 0

    for scanner.Scan() {
        lineNum++
        raw := strings.TrimSpace(scanner.Text())
        if raw == "" {
            continue
        }

        var entry capabilityScoreLine
        if err := json.Unmarshal([]byte(raw), &entry); err != nil {
            return totalPoints, fmt.Errorf("line %d: parse json: %w", lineNum, err)
        }

        label := entry.Label
        iteration := extractIteration(label)
        ts := time.Now().UnixNano()

        // Overall capability score.
        line := fmt.Sprintf(
            "capability_score,model=%s,run_id=%s,label=%s,category=overall accuracy=%.6f,correct=%di,total=%di,iteration=%di %d",
            EscapeLp(cfg.Model), EscapeLp(cfg.RunID), EscapeLp(label),
            entry.Accuracy, entry.Correct, entry.Total, iteration, ts,
        )
        lines = append(lines, line)
        totalPoints++

        // Per-category breakdown.
        for cat, block := range entry.ByCategory {
            var catAccuracy float64
            if block.Total > 0 {
                catAccuracy = float64(block.Correct) / float64(block.Total)
            }
            line := fmt.Sprintf(
                "capability_score,model=%s,run_id=%s,label=%s,category=%s accuracy=%.6f,correct=%di,total=%di,iteration=%di %d",
                EscapeLp(cfg.Model), EscapeLp(cfg.RunID), EscapeLp(label),
                EscapeLp(cat), catAccuracy, block.Correct, block.Total, iteration, ts,
            )
            lines = append(lines, line)
            totalPoints++
        }

        // Flush batch if needed.
        if len(lines) >= cfg.BatchSize {
            if err := influx.WriteLp(lines); err != nil {
                return totalPoints, fmt.Errorf("write batch: %w", err)
            }
            lines = lines[:0]
        }
    }

    if err := scanner.Err(); err != nil {
        return totalPoints, fmt.Errorf("scan %s: %w", cfg.CapabilityFile, err)
    }

    // Flush remaining lines.
    if len(lines) > 0 {
        if err := influx.WriteLp(lines); err != nil {
            return totalPoints, fmt.Errorf("write final batch: %w", err)
        }
    }

    fmt.Fprintf(w, " capability scores: %d points from %d lines\n", totalPoints, lineNum)
    return totalPoints, nil
}

// ingestTrainingLog reads an MLX LoRA training log and writes training_loss
// measurements to InfluxDB for both training and validation loss entries.
func ingestTrainingLog(influx *InfluxClient, cfg IngestConfig, w io.Writer) (int, error) {
    f, err := os.Open(cfg.TrainingLog)
    if err != nil {
        return 0, fmt.Errorf("open %s: %w", cfg.TrainingLog, err)
    }
    defer f.Close()

    scanner := bufio.NewScanner(f)
    scanner.Buffer(make([]byte, 1024*1024), 1024*1024)

    var lines []string
    var totalPoints int
    lineNum := 0

    for scanner.Scan() {
        lineNum++
        text := scanner.Text()

        // Try validation loss first (shorter regex, less common).
        if m := reValLoss.FindStringSubmatch(text); m != nil {
            iter, _ := strconv.Atoi(m[1])
            loss, _ := strconv.ParseFloat(m[2], 64)
            ts := time.Now().UnixNano()

            line := fmt.Sprintf(
                "training_loss,model=%s,run_id=%s,loss_type=val loss=%.6f,iteration=%di %d",
                EscapeLp(cfg.Model), EscapeLp(cfg.RunID), loss, iter, ts,
            )
            lines = append(lines, line)
            totalPoints++
        }

        // Try training loss.
        if m := reTrainLoss.FindStringSubmatch(text); m != nil {
            iter, _ := strconv.Atoi(m[1])
            loss, _ := strconv.ParseFloat(m[2], 64)
            lr, _ := strconv.ParseFloat(m[3], 64)
            itPerSec, _ := strconv.ParseFloat(m[4], 64)
            tokPerSec, _ := strconv.ParseFloat(m[5], 64)
            ts := time.Now().UnixNano()

            line := fmt.Sprintf(
                "training_loss,model=%s,run_id=%s,loss_type=train loss=%.6f,iteration=%di,learning_rate=%.10f,it_per_sec=%.4f,tokens_per_sec=%.2f %d",
                EscapeLp(cfg.Model), EscapeLp(cfg.RunID), loss, iter, lr, itPerSec, tokPerSec, ts,
            )
            lines = append(lines, line)
            totalPoints++
        }

        // Flush batch if needed.
        if len(lines) >= cfg.BatchSize {
            if err := influx.WriteLp(lines); err != nil {
                return totalPoints, fmt.Errorf("write batch: %w", err)
            }
            lines = lines[:0]
        }
    }

    if err := scanner.Err(); err != nil {
        return totalPoints, fmt.Errorf("scan %s: %w", cfg.TrainingLog, err)
    }

    // Flush remaining lines.
    if len(lines) > 0 {
        if err := influx.WriteLp(lines); err != nil {
            return totalPoints, fmt.Errorf("write final batch: %w", err)
        }
    }

    fmt.Fprintf(w, " training log: %d points from %d lines\n", totalPoints, lineNum)
    return totalPoints, nil
}

// extractIteration extracts an iteration number from a label like "model@200".
// Returns 0 if no iteration is found.
func extractIteration(label string) int {
    idx := strings.LastIndex(label, "@")
    if idx < 0 || idx+1 >= len(label) {
        return 0
    }
    n, err := strconv.Atoi(label[idx+1:])
    if err != nil {
        return 0
    }
    return n
}

// toFloat64 converts a JSON-decoded interface{} value to float64.
// Handles float64 (standard json.Unmarshal), json.Number, and string values.
func toFloat64(v interface{}) (float64, bool) {
    switch val := v.(type) {
    case float64:
        return val, true
    case int:
        return float64(val), true
    case int64:
        return float64(val), true
    case json.Number:
        f, err := val.Float64()
        return f, err == nil
    case string:
        f, err := strconv.ParseFloat(val, 64)
        return f, err == nil
    default:
        return 0, false
    }
}
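For orientation, a short illustrative sketch of how the helpers above behave; every value below is made up and is not part of this diff:

// Illustrative only (not part of the ported code):
// extractIteration("base@200") returns 200; extractIteration("base") returns 0.
// toFloat64 accepts float64, int, int64, json.Number, and numeric strings:
//   toFloat64(json.Number("7.5")) -> 7.5, true
//   toFloat64("n/a")              -> 0, false
// A content_score point emitted above would look roughly like:
//   content_score,model=qwen3,run_id=r1,label=base@200,dimension=clarity,has_kernel=true score=8.200000,iteration=200i 1700000000000000000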
147
pkg/ml/inventory.go
Normal file
@@ -0,0 +1,147 @@
package ml

import (
    "fmt"
    "io"
    "strings"
)

// TargetTotal is the golden set target size used for progress reporting.
const TargetTotal = 15000

// tableOrder defines the canonical display order for inventory tables.
var tableOrder = []string{
    "golden_set", "expansion_prompts", "seeds", "prompts",
    "training_examples", "gemini_responses", "benchmark_questions",
    "benchmark_results", "validations", "checkpoint_scores",
    "probe_results", "scoring_results",
}

// tableDetail holds extra context for a single table beyond its row count.
type tableDetail struct {
    notes []string
}

// PrintInventory queries all known DuckDB tables and prints a formatted
// inventory with row counts, detail breakdowns, and a grand total.
func PrintInventory(db *DB, w io.Writer) error {
    counts, err := db.TableCounts()
    if err != nil {
        return fmt.Errorf("table counts: %w", err)
    }

    details := gatherDetails(db, counts)

    fmt.Fprintln(w, "DuckDB Inventory")
    fmt.Fprintln(w, strings.Repeat("-", 52))

    grand := 0
    for _, table := range tableOrder {
        count, ok := counts[table]
        if !ok {
            continue
        }
        grand += count
        fmt.Fprintf(w, " %-24s %8d rows", table, count)

        if d, has := details[table]; has && len(d.notes) > 0 {
            fmt.Fprintf(w, " (%s)", strings.Join(d.notes, ", "))
        }
        fmt.Fprintln(w)
    }

    fmt.Fprintln(w, strings.Repeat("-", 52))
    fmt.Fprintf(w, " %-24s %8d rows\n", "TOTAL", grand)

    return nil
}

// gatherDetails runs per-table detail queries and returns annotations keyed
// by table name. Errors on individual queries are silently ignored so the
// inventory always prints.
func gatherDetails(db *DB, counts map[string]int) map[string]*tableDetail {
    details := make(map[string]*tableDetail)

    // golden_set: progress toward target
    if count, ok := counts["golden_set"]; ok {
        pct := float64(count) / float64(TargetTotal) * 100
        details["golden_set"] = &tableDetail{
            notes: []string{fmt.Sprintf("%.1f%% of %d target", pct, TargetTotal)},
        }
    }

    // training_examples: distinct sources
    if _, ok := counts["training_examples"]; ok {
        rows, err := db.QueryRows("SELECT COUNT(DISTINCT source) AS n FROM training_examples")
        if err == nil && len(rows) > 0 {
            n := toInt(rows[0]["n"])
            details["training_examples"] = &tableDetail{
                notes: []string{fmt.Sprintf("%d sources", n)},
            }
        }
    }

    // prompts: distinct domains and voices
    if _, ok := counts["prompts"]; ok {
        d := &tableDetail{}
        rows, err := db.QueryRows("SELECT COUNT(DISTINCT domain) AS n FROM prompts")
        if err == nil && len(rows) > 0 {
            d.notes = append(d.notes, fmt.Sprintf("%d domains", toInt(rows[0]["n"])))
        }
        rows, err = db.QueryRows("SELECT COUNT(DISTINCT voice) AS n FROM prompts")
        if err == nil && len(rows) > 0 {
            d.notes = append(d.notes, fmt.Sprintf("%d voices", toInt(rows[0]["n"])))
        }
        if len(d.notes) > 0 {
            details["prompts"] = d
        }
    }

    // gemini_responses: group by source_model
    if _, ok := counts["gemini_responses"]; ok {
        rows, err := db.QueryRows(
            "SELECT source_model, COUNT(*) AS n FROM gemini_responses GROUP BY source_model ORDER BY n DESC",
        )
        if err == nil && len(rows) > 0 {
            var parts []string
            for _, row := range rows {
                model := strVal(row, "source_model")
                n := toInt(row["n"])
                if model != "" {
                    parts = append(parts, fmt.Sprintf("%s:%d", model, n))
                }
            }
            if len(parts) > 0 {
                details["gemini_responses"] = &tableDetail{notes: parts}
            }
        }
    }

    // benchmark_results: distinct source categories
    if _, ok := counts["benchmark_results"]; ok {
        rows, err := db.QueryRows("SELECT COUNT(DISTINCT source) AS n FROM benchmark_results")
        if err == nil && len(rows) > 0 {
            n := toInt(rows[0]["n"])
            details["benchmark_results"] = &tableDetail{
                notes: []string{fmt.Sprintf("%d categories", n)},
            }
        }
    }

    return details
}

// toInt converts a DuckDB value to int. DuckDB returns integers as int64 (not
// float64 like InfluxDB), so we handle both types.
func toInt(v interface{}) int {
    switch n := v.(type) {
    case int64:
        return int(n)
    case int32:
        return int(n)
    case float64:
        return int(n)
    default:
        return 0
    }
}
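A minimal usage sketch for the inventory report, assuming a database opened with the ml.OpenDB helper used by the cobra wrappers in this commit; the path is a placeholder:

// Sketch only: print the inventory for an existing LEM DuckDB file.
package main

import (
    "log"
    "os"

    "forge.lthn.ai/core/cli/pkg/ml"
)

func main() {
    db, err := ml.OpenDB("/data/lem/lem.duckdb") // placeholder path
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    if err := ml.PrintInventory(db, os.Stdout); err != nil {
        log.Fatal(err)
    }
}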
100
pkg/ml/metrics.go
Normal file
@@ -0,0 +1,100 @@
package ml

import (
    "fmt"
    "io"
    "time"
)

// PushMetrics queries golden_set stats from DuckDB and writes them to InfluxDB
// as golden_set_stats, golden_set_domain, and golden_set_voice measurements.
func PushMetrics(db *DB, influx *InfluxClient, w io.Writer) error {
    // Overall stats.
    var total, domains, voices int
    var avgGenTime, avgChars float64
    err := db.conn.QueryRow(
        "SELECT count(*), count(DISTINCT domain), count(DISTINCT voice), " +
            "coalesce(avg(gen_time), 0), coalesce(avg(char_count), 0) FROM golden_set",
    ).Scan(&total, &domains, &voices, &avgGenTime, &avgChars)
    if err != nil {
        return fmt.Errorf("query golden_set stats: %w", err)
    }

    if total == 0 {
        fmt.Fprintln(w, "golden_set is empty, nothing to push")
        return nil
    }

    completionPct := float64(total) / float64(TargetTotal) * 100.0
    ts := time.Now().UnixNano()

    var lines []string

    // Overall stats point.
    lines = append(lines, fmt.Sprintf(
        "golden_set_stats total_examples=%di,domains=%di,voices=%di,avg_gen_time=%.2f,avg_response_chars=%.0f,completion_pct=%.1f %d",
        total, domains, voices, avgGenTime, avgChars, completionPct, ts,
    ))

    // Per-domain breakdown.
    domainRows, err := db.conn.Query(
        "SELECT domain, count(*) AS cnt, coalesce(avg(gen_time), 0) AS avg_gt FROM golden_set GROUP BY domain ORDER BY domain",
    )
    if err != nil {
        return fmt.Errorf("query golden_set domains: %w", err)
    }
    defer domainRows.Close()

    for domainRows.Next() {
        var domain string
        var count int
        var avgGT float64
        if err := domainRows.Scan(&domain, &count, &avgGT); err != nil {
            return fmt.Errorf("scan domain row: %w", err)
        }
        lines = append(lines, fmt.Sprintf(
            "golden_set_domain,domain=%s count=%di,avg_gen_time=%.2f %d",
            EscapeLp(domain), count, avgGT, ts,
        ))
    }
    if err := domainRows.Err(); err != nil {
        return fmt.Errorf("iterate domain rows: %w", err)
    }

    // Per-voice breakdown.
    voiceRows, err := db.conn.Query(
        "SELECT voice, count(*) AS cnt, coalesce(avg(char_count), 0) AS avg_cc, coalesce(avg(gen_time), 0) AS avg_gt FROM golden_set GROUP BY voice ORDER BY voice",
    )
    if err != nil {
        return fmt.Errorf("query golden_set voices: %w", err)
    }
    defer voiceRows.Close()

    for voiceRows.Next() {
        var voice string
        var count int
        var avgCC, avgGT float64
        if err := voiceRows.Scan(&voice, &count, &avgCC, &avgGT); err != nil {
            return fmt.Errorf("scan voice row: %w", err)
        }
        lines = append(lines, fmt.Sprintf(
            "golden_set_voice,voice=%s count=%di,avg_chars=%.0f,avg_gen_time=%.2f %d",
            EscapeLp(voice), count, avgCC, avgGT, ts,
        ))
    }
    if err := voiceRows.Err(); err != nil {
        return fmt.Errorf("iterate voice rows: %w", err)
    }

    // Write all points to InfluxDB.
    if err := influx.WriteLp(lines); err != nil {
        return fmt.Errorf("write metrics to influxdb: %w", err)
    }

    fmt.Fprintf(w, "Pushed %d points to InfluxDB\n", len(lines))
    fmt.Fprintf(w, " total=%d domains=%d voices=%d completion=%.1f%%\n",
        total, domains, voices, completionPct)
    fmt.Fprintf(w, " avg_gen_time=%.2fs avg_chars=%.0f\n", avgGenTime, avgChars)

    return nil
}
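For orientation, the line-protocol points PushMetrics builds look roughly like the following; all numbers are illustrative, not real data:

// Illustrative output only (values made up):
//   golden_set_stats total_examples=9000i,domains=24i,voices=12i,avg_gen_time=3.20,avg_response_chars=2100,completion_pct=60.0 1700000000000000000
//   golden_set_domain,domain=finance count=410i,avg_gen_time=3.40 1700000000000000000
//   golden_set_voice,voice=socratic count=780i,avg_chars=1900,avg_gen_time=2.90 1700000000000000000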
153
pkg/ml/normalize.go
Normal file
@@ -0,0 +1,153 @@
package ml

import (
    "fmt"
    "io"
    "strings"
)

// NormalizeConfig configures the seed normalization process.
type NormalizeConfig struct {
    MinLength int
}

// NormalizeSeeds deduplicates seeds into the expansion_prompts table.
//
// Steps:
//  1. Verify the seeds table exists and report its row count.
//  2. Drop and recreate expansion_prompts using deduplicated seeds,
//     excluding prompts already present in the prompts or golden_set tables.
//  3. Assign priority based on domain coverage (underrepresented domains
//     receive higher priority via RANK).
//  4. Print a region distribution summary.
func NormalizeSeeds(db *DB, cfg NormalizeConfig, w io.Writer) error {
    // 1. Check seeds table exists and get count.
    var seedCount int
    if err := db.conn.QueryRow("SELECT count(*) FROM seeds").Scan(&seedCount); err != nil {
        return fmt.Errorf("no seeds table (run import-all first): %w", err)
    }
    fmt.Fprintf(w, "Seeds table: %d rows\n", seedCount)

    if seedCount == 0 {
        return fmt.Errorf("seeds table is empty, nothing to normalize")
    }

    // 2. Drop and recreate expansion_prompts.
    if _, err := db.conn.Exec("DROP TABLE IF EXISTS expansion_prompts"); err != nil {
        return fmt.Errorf("drop expansion_prompts: %w", err)
    }

    createSQL := fmt.Sprintf(`
        CREATE TABLE expansion_prompts AS
        WITH unique_seeds AS (
            SELECT
                ROW_NUMBER() OVER (ORDER BY region, domain, seed_id) AS idx,
                seed_id, region, domain, prompt
            FROM (
                SELECT DISTINCT ON (prompt)
                    seed_id, region, domain, prompt
                FROM seeds
                WHERE length(prompt) >= %d
                ORDER BY prompt, seed_id
            )
        ),
        existing_prompts AS (
            SELECT prompt FROM prompts
            UNION ALL
            SELECT prompt FROM golden_set
        )
        SELECT
            us.idx, us.seed_id, us.region, us.domain,
            'en' AS language, us.prompt, '' AS prompt_en,
            0 AS priority, 'pending' AS status
        FROM unique_seeds us
        WHERE NOT EXISTS (
            SELECT 1 FROM existing_prompts ep WHERE ep.prompt = us.prompt
        )
    `, cfg.MinLength)

    if _, err := db.conn.Exec(createSQL); err != nil {
        return fmt.Errorf("create expansion_prompts: %w", err)
    }

    var epCount int
    if err := db.conn.QueryRow("SELECT count(*) FROM expansion_prompts").Scan(&epCount); err != nil {
        return fmt.Errorf("count expansion_prompts: %w", err)
    }
    fmt.Fprintf(w, "Expansion prompts created: %d (min length %d, deduped, excluding existing)\n", epCount, cfg.MinLength)

    if epCount == 0 {
        fmt.Fprintln(w, "No new expansion prompts to process.")
        return nil
    }

    // 3. Assign priority based on domain coverage.
    prioritySQL := `
        UPDATE expansion_prompts SET priority = sub.rnk
        FROM (
            SELECT domain, RANK() OVER (ORDER BY cnt ASC) AS rnk
            FROM (
                SELECT domain, count(*) AS cnt
                FROM expansion_prompts
                GROUP BY domain
            ) domain_counts
        ) sub
        WHERE expansion_prompts.domain = sub.domain
    `
    if _, err := db.conn.Exec(prioritySQL); err != nil {
        return fmt.Errorf("assign priority: %w", err)
    }
    fmt.Fprintln(w, "Priority assigned (underrepresented domains ranked higher).")

    // 4. Region distribution summary.
    fmt.Fprintln(w)
    fmt.Fprintln(w, "Region distribution:")

    rows, err := db.conn.Query(`
        SELECT
            CASE
                WHEN region LIKE 'cn%' THEN 'cn'
                WHEN region LIKE 'en%' THEN 'en'
                WHEN region LIKE 'ru%' THEN 'ru'
                WHEN region LIKE 'de%' THEN 'de'
                WHEN region LIKE 'es%' THEN 'es'
                WHEN region LIKE 'fr%' THEN 'fr'
                WHEN region LIKE 'latam%' THEN 'latam'
                WHEN region LIKE 'africa%' THEN 'africa'
                WHEN region LIKE 'eu%' THEN 'eu'
                WHEN region LIKE 'me%' THEN 'me'
                ELSE 'other'
            END AS region_group,
            count(*) AS cnt
        FROM expansion_prompts
        GROUP BY region_group
        ORDER BY cnt DESC
    `)
    if err != nil {
        return fmt.Errorf("region distribution query: %w", err)
    }
    defer rows.Close()

    var totalFromRegions int
    var lines []string
    for rows.Next() {
        var region string
        var cnt int
        if err := rows.Scan(&region, &cnt); err != nil {
            return fmt.Errorf("scan region row: %w", err)
        }
        totalFromRegions += cnt
        lines = append(lines, fmt.Sprintf(" %-10s %6d", region, cnt))
    }
    if err := rows.Err(); err != nil {
        return fmt.Errorf("iterate region rows: %w", err)
    }

    for _, line := range lines {
        fmt.Fprintln(w, line)
    }
    fmt.Fprintf(w, " %-10s %6d\n", strings.Repeat("-", 10), totalFromRegions)
    fmt.Fprintf(w, " %-10s %6d\n", "total", totalFromRegions)

    return nil
}
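A minimal call sketch for the normalization step, again assuming the ml.OpenDB helper; the MinLength of 50 is an assumed example, the actual default lives in the cobra flag:

// Sketch only: normalize seeds into expansion_prompts.
package main

import (
    "log"
    "os"

    "forge.lthn.ai/core/cli/pkg/ml"
)

func main() {
    db, err := ml.OpenDB(os.Getenv("LEM_DB"))
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    // MinLength of 50 is an assumed example; the CLI flag default may differ.
    if err := ml.NormalizeSeeds(db, ml.NormalizeConfig{MinLength: 50}, os.Stdout); err != nil {
        log.Fatal(err)
    }
}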
157
pkg/ml/publish.go
Normal file
@@ -0,0 +1,157 @@
package ml

import (
    "bytes"
    "fmt"
    "io"
    "net/http"
    "os"
    "path/filepath"
    "strings"
    "time"
)

// PublishConfig holds options for the publish operation.
type PublishConfig struct {
    InputDir string
    Repo     string
    Public   bool
    Token    string
    DryRun   bool
}

// uploadEntry pairs a local file path with its remote destination.
type uploadEntry struct {
    local  string
    remote string
}

// Publish uploads Parquet files to HuggingFace Hub.
//
// It looks for train.parquet, valid.parquet, and test.parquet in InputDir,
// plus an optional dataset_card.md in the parent directory (uploaded as README.md).
// The token is resolved from PublishConfig.Token, the HF_TOKEN environment variable,
// or ~/.huggingface/token, in that order.
func Publish(cfg PublishConfig, w io.Writer) error {
    if cfg.InputDir == "" {
        return fmt.Errorf("input directory is required")
    }

    token := resolveHFToken(cfg.Token)
    if token == "" && !cfg.DryRun {
        return fmt.Errorf("HuggingFace token required (--token, HF_TOKEN env, or ~/.huggingface/token)")
    }

    files, err := collectUploadFiles(cfg.InputDir)
    if err != nil {
        return err
    }
    if len(files) == 0 {
        return fmt.Errorf("no Parquet files found in %s", cfg.InputDir)
    }

    if cfg.DryRun {
        fmt.Fprintf(w, "Dry run: would publish to %s\n", cfg.Repo)
        if cfg.Public {
            fmt.Fprintln(w, " Visibility: public")
        } else {
            fmt.Fprintln(w, " Visibility: private")
        }
        for _, f := range files {
            info, err := os.Stat(f.local)
            if err != nil {
                return fmt.Errorf("stat %s: %w", f.local, err)
            }
            sizeMB := float64(info.Size()) / 1024 / 1024
            fmt.Fprintf(w, " %s -> %s (%.1f MB)\n", filepath.Base(f.local), f.remote, sizeMB)
        }
        return nil
    }

    fmt.Fprintf(w, "Publishing to https://huggingface.co/datasets/%s\n", cfg.Repo)

    for _, f := range files {
        if err := uploadFileToHF(token, cfg.Repo, f.local, f.remote); err != nil {
            return fmt.Errorf("upload %s: %w", filepath.Base(f.local), err)
        }
        fmt.Fprintf(w, " Uploaded %s -> %s\n", filepath.Base(f.local), f.remote)
    }

    fmt.Fprintf(w, "\nPublished to https://huggingface.co/datasets/%s\n", cfg.Repo)
    return nil
}

// resolveHFToken returns a HuggingFace API token from the given value,
// HF_TOKEN env var, or ~/.huggingface/token file.
func resolveHFToken(explicit string) string {
    if explicit != "" {
        return explicit
    }
    if env := os.Getenv("HF_TOKEN"); env != "" {
        return env
    }
    home, err := os.UserHomeDir()
    if err != nil {
        return ""
    }
    data, err := os.ReadFile(filepath.Join(home, ".huggingface", "token"))
    if err != nil {
        return ""
    }
    return strings.TrimSpace(string(data))
}

// collectUploadFiles finds Parquet split files and an optional dataset card.
func collectUploadFiles(inputDir string) ([]uploadEntry, error) {
    splits := []string{"train", "valid", "test"}
    var files []uploadEntry

    for _, split := range splits {
        path := filepath.Join(inputDir, split+".parquet")
        if _, err := os.Stat(path); os.IsNotExist(err) {
            continue
        } else if err != nil {
            return nil, fmt.Errorf("stat %s: %w", path, err)
        }
        files = append(files, uploadEntry{path, fmt.Sprintf("data/%s.parquet", split)})
    }

    // Check for dataset card in parent directory.
    cardPath := filepath.Join(inputDir, "..", "dataset_card.md")
    if _, err := os.Stat(cardPath); err == nil {
        files = append(files, uploadEntry{cardPath, "README.md"})
    }

    return files, nil
}

// uploadFileToHF uploads a single file to a HuggingFace dataset repo via the Hub API.
func uploadFileToHF(token, repoID, localPath, remotePath string) error {
    data, err := os.ReadFile(localPath)
    if err != nil {
        return fmt.Errorf("read %s: %w", localPath, err)
    }

    url := fmt.Sprintf("https://huggingface.co/api/datasets/%s/upload/main/%s", repoID, remotePath)

    req, err := http.NewRequest(http.MethodPut, url, bytes.NewReader(data))
    if err != nil {
        return fmt.Errorf("create request: %w", err)
    }
    req.Header.Set("Authorization", "Bearer "+token)
    req.Header.Set("Content-Type", "application/octet-stream")

    client := &http.Client{Timeout: 120 * time.Second}
    resp, err := client.Do(req)
    if err != nil {
        return fmt.Errorf("upload request: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode >= 300 {
        body, _ := io.ReadAll(resp.Body)
        return fmt.Errorf("upload failed: HTTP %d: %s", resp.StatusCode, string(body))
    }

    return nil
}
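A dry-run usage sketch for Publish; the input directory and repo id are placeholders:

// Sketch only: preview a HuggingFace upload without writing anything.
package main

import (
    "log"
    "os"

    "forge.lthn.ai/core/cli/pkg/ml"
)

func main() {
    cfg := ml.PublishConfig{
        InputDir: "./parquet",            // directory holding train/valid/test.parquet
        Repo:     "example-org/lem-gold", // placeholder repo id
        DryRun:   true,
    }
    if err := ml.Publish(cfg, os.Stdout); err != nil {
        log.Fatal(err)
    }
}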
111
pkg/ml/seed_influx.go
Normal file
@@ -0,0 +1,111 @@
package ml

import (
    "fmt"
    "io"
    "strings"
)

// SeedInfluxConfig holds options for the SeedInflux migration.
type SeedInfluxConfig struct {
    Force     bool
    BatchSize int
}

// SeedInflux migrates golden_set rows from DuckDB into InfluxDB as
// gold_gen measurement points. This is a one-time migration tool;
// it skips the write when InfluxDB already contains all records
// unless Force is set.
func SeedInflux(db *DB, influx *InfluxClient, cfg SeedInfluxConfig, w io.Writer) error {
    if cfg.BatchSize <= 0 {
        cfg.BatchSize = 500
    }

    // Count source rows in DuckDB.
    var total int
    if err := db.conn.QueryRow("SELECT count(*) FROM golden_set").Scan(&total); err != nil {
        return fmt.Errorf("no golden_set table: %w", err)
    }

    // Check how many distinct records InfluxDB already has.
    existing := 0
    rows, err := influx.QuerySQL("SELECT count(DISTINCT i) AS n FROM gold_gen")
    if err == nil && len(rows) > 0 {
        if n, ok := rows[0]["n"].(float64); ok {
            existing = int(n)
        }
    }

    fmt.Fprintf(w, "DuckDB has %d records, InfluxDB gold_gen has %d\n", total, existing)

    if existing >= total && !cfg.Force {
        fmt.Fprintln(w, "InfluxDB already has all records. Use --force to re-seed.")
        return nil
    }

    // Query all golden_set rows from DuckDB.
    dbRows, err := db.conn.Query(
        "SELECT idx, seed_id, domain, voice, gen_time, char_count FROM golden_set ORDER BY idx",
    )
    if err != nil {
        return fmt.Errorf("query golden_set: %w", err)
    }
    defer dbRows.Close()

    var batch []string
    written := 0

    for dbRows.Next() {
        var idx int
        var seedID, domain, voice string
        var genTime float64
        var charCount int

        if err := dbRows.Scan(&idx, &seedID, &domain, &voice, &genTime, &charCount); err != nil {
            return fmt.Errorf("scan row %d: %w", written, err)
        }

        // Build line protocol point.
        // Tags: i (idx), w (worker), d (domain), v (voice)
        // Fields: seed_id (string), gen_time (float), chars (integer)
        escapedSeedID := strings.ReplaceAll(seedID, `"`, `\"`)

        line := fmt.Sprintf(
            "gold_gen,i=%s,w=migration,d=%s,v=%s seed_id=\"%s\",gen_time=%v,chars=%di",
            EscapeLp(fmt.Sprintf("%d", idx)),
            EscapeLp(domain),
            EscapeLp(voice),
            escapedSeedID,
            genTime,
            charCount,
        )
        batch = append(batch, line)

        if len(batch) >= cfg.BatchSize {
            if err := influx.WriteLp(batch); err != nil {
                return fmt.Errorf("write batch at row %d: %w", written, err)
            }
            written += len(batch)
            batch = batch[:0]

            if written%2000 == 0 {
                fmt.Fprintf(w, " wrote %d / %d\n", written, total)
            }
        }
    }

    if err := dbRows.Err(); err != nil {
        return fmt.Errorf("iterate golden_set rows: %w", err)
    }

    // Flush remaining batch.
    if len(batch) > 0 {
        if err := influx.WriteLp(batch); err != nil {
            return fmt.Errorf("write final batch: %w", err)
        }
        written += len(batch)
    }

    fmt.Fprintf(w, "Seeded %d records into InfluxDB gold_gen\n", written)
    return nil
}
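For orientation, an illustrative gold_gen point as built above; the values are made up:

// Illustrative only (values made up):
//   gold_gen,i=42,w=migration,d=ethics,v=plain seed_id="seed-0042",gen_time=3.7,chars=2100i
// No timestamp is appended, so InfluxDB assigns the write time server-side.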