refactor: migrate all 25 commands from passthrough to cobra framework
Replace passthrough() + stdlib flag.FlagSet anti-pattern with proper cobra integration. Every Run* function now takes a typed *Opts struct and returns error. Flags registered via cli.StringFlag/IntFlag/etc. Commands participate in Core lifecycle with full cobra flag parsing. - 6 command groups: gen, score, data, export, infra, mon - 25 commands converted, 0 passthrough() calls remain - Delete passthrough() helper from lem.go - Update export_test.go to use ExportOpts struct Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
42c0af728b
commit
56eda1a081
32 changed files with 1199 additions and 894 deletions
|
|
@ -8,10 +8,54 @@ import (
|
|||
func addDataCommands(root *cli.Command) {
|
||||
dataGroup := cli.NewGroup("data", "Data management commands", "Import, consolidate, normalise, and approve training data.")
|
||||
|
||||
dataGroup.AddCommand(passthrough("import-all", "Import ALL LEM data into DuckDB from M3", lem.RunImport))
|
||||
dataGroup.AddCommand(passthrough("consolidate", "Pull worker JSONLs from M3, merge, deduplicate", lem.RunConsolidate))
|
||||
dataGroup.AddCommand(passthrough("normalize", "Normalise seeds to deduplicated expansion prompts", lem.RunNormalize))
|
||||
dataGroup.AddCommand(passthrough("approve", "Filter scored expansions to training JSONL", lem.RunApprove))
|
||||
// import-all — Import ALL LEM data into DuckDB from M3.
|
||||
var importCfg lem.ImportOpts
|
||||
importCmd := cli.NewCommand("import-all", "Import ALL LEM data into DuckDB from M3", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunImport(importCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(importCmd, &importCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
cli.BoolFlag(importCmd, &importCfg.SkipM3, "skip-m3", "", false, "Skip pulling data from M3")
|
||||
cli.StringFlag(importCmd, &importCfg.DataDir, "data-dir", "", "", "Local data directory (defaults to db directory)")
|
||||
dataGroup.AddCommand(importCmd)
|
||||
|
||||
// consolidate — Pull worker JSONLs from M3, merge, deduplicate.
|
||||
var consolidateCfg lem.ConsolidateOpts
|
||||
consolidateCmd := cli.NewCommand("consolidate", "Pull worker JSONLs from M3, merge, deduplicate", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunConsolidate(consolidateCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(consolidateCmd, &consolidateCfg.Host, "host", "", "m3", "SSH host for remote files")
|
||||
cli.StringFlag(consolidateCmd, &consolidateCfg.Remote, "remote", "", "/Volumes/Data/lem/responses", "Remote directory for JSONL files")
|
||||
cli.StringFlag(consolidateCmd, &consolidateCfg.Pattern, "pattern", "", "gold*.jsonl", "File glob pattern")
|
||||
cli.StringFlag(consolidateCmd, &consolidateCfg.OutputDir, "output", "o", "", "Output directory (defaults to ./responses)")
|
||||
cli.StringFlag(consolidateCmd, &consolidateCfg.Merged, "merged", "", "", "Merged output file (defaults to gold-merged.jsonl in output dir)")
|
||||
dataGroup.AddCommand(consolidateCmd)
|
||||
|
||||
// normalize — Normalise seeds to deduplicated expansion prompts.
|
||||
var normalizeCfg lem.NormalizeOpts
|
||||
normalizeCmd := cli.NewCommand("normalize", "Normalise seeds to deduplicated expansion prompts", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunNormalize(normalizeCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(normalizeCmd, &normalizeCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
cli.IntFlag(normalizeCmd, &normalizeCfg.MinLen, "min-length", "", 50, "Minimum prompt length in characters")
|
||||
dataGroup.AddCommand(normalizeCmd)
|
||||
|
||||
// approve — Filter scored expansions to training JSONL.
|
||||
var approveCfg lem.ApproveOpts
|
||||
approveCmd := cli.NewCommand("approve", "Filter scored expansions to training JSONL", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunApprove(approveCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(approveCmd, &approveCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
cli.StringFlag(approveCmd, &approveCfg.Output, "output", "o", "", "Output JSONL file (defaults to expansion-approved.jsonl in db dir)")
|
||||
cli.Float64Flag(approveCmd, &approveCfg.Threshold, "threshold", "", 6.0, "Min judge average to approve")
|
||||
dataGroup.AddCommand(approveCmd)
|
||||
|
||||
root.AddCommand(dataGroup)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,10 +8,60 @@ import (
|
|||
func addExportCommands(root *cli.Command) {
|
||||
exportGroup := cli.NewGroup("export", "Export and publish commands", "Export training data to JSONL, Parquet, HuggingFace, and PEFT formats.")
|
||||
|
||||
exportGroup.AddCommand(passthrough("jsonl", "Export golden set to training-format JSONL splits", lem.RunExport))
|
||||
exportGroup.AddCommand(passthrough("parquet", "Export JSONL training splits to Parquet", lem.RunParquet))
|
||||
exportGroup.AddCommand(passthrough("publish", "Push Parquet files to HuggingFace dataset repo", lem.RunPublish))
|
||||
exportGroup.AddCommand(passthrough("convert", "Convert MLX LoRA adapter to PEFT format", lem.RunConvert))
|
||||
// jsonl — export golden set to training-format JSONL splits.
|
||||
var exportCfg lem.ExportOpts
|
||||
jsonlCmd := cli.NewCommand("jsonl", "Export golden set to training-format JSONL splits", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunExport(exportCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(jsonlCmd, &exportCfg.DBPath, "db", "", "", "DuckDB database path (primary source; defaults to LEM_DB env)")
|
||||
cli.StringFlag(jsonlCmd, &exportCfg.Input, "input", "i", "", "Input golden set JSONL file (fallback if --db not set)")
|
||||
cli.StringFlag(jsonlCmd, &exportCfg.OutputDir, "output-dir", "o", "", "Output directory for training files (required)")
|
||||
cli.IntFlag(jsonlCmd, &exportCfg.TrainPct, "train-pct", "", 90, "Training set percentage")
|
||||
cli.IntFlag(jsonlCmd, &exportCfg.ValidPct, "valid-pct", "", 5, "Validation set percentage")
|
||||
cli.IntFlag(jsonlCmd, &exportCfg.TestPct, "test-pct", "", 5, "Test set percentage")
|
||||
cli.Int64Flag(jsonlCmd, &exportCfg.Seed, "seed", "", 42, "Random seed for shuffling")
|
||||
cli.IntFlag(jsonlCmd, &exportCfg.MinChars, "min-chars", "", 50, "Minimum response character count")
|
||||
exportGroup.AddCommand(jsonlCmd)
|
||||
|
||||
// parquet — export JSONL training splits to Parquet.
|
||||
var parquetCfg lem.ParquetOpts
|
||||
parquetCmd := cli.NewCommand("parquet", "Export JSONL training splits to Parquet", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunParquet(parquetCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(parquetCmd, &parquetCfg.Input, "input", "i", "", "Directory containing train.jsonl, valid.jsonl, test.jsonl (required)")
|
||||
cli.StringFlag(parquetCmd, &parquetCfg.Output, "output", "o", "", "Output directory for Parquet files (defaults to input/parquet)")
|
||||
exportGroup.AddCommand(parquetCmd)
|
||||
|
||||
// publish — push Parquet files to HuggingFace dataset repo.
|
||||
var publishCfg lem.PublishOpts
|
||||
publishCmd := cli.NewCommand("publish", "Push Parquet files to HuggingFace dataset repo", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunPublish(publishCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(publishCmd, &publishCfg.Input, "input", "i", "", "Directory containing Parquet files (required)")
|
||||
cli.StringFlag(publishCmd, &publishCfg.Repo, "repo", "", "lthn/LEM-golden-set", "HuggingFace dataset repo ID")
|
||||
cli.BoolFlag(publishCmd, &publishCfg.Public, "public", "", false, "Make dataset public")
|
||||
cli.StringFlag(publishCmd, &publishCfg.Token, "token", "", "", "HuggingFace API token (defaults to HF_TOKEN env)")
|
||||
cli.BoolFlag(publishCmd, &publishCfg.DryRun, "dry-run", "", false, "Show what would be uploaded without uploading")
|
||||
exportGroup.AddCommand(publishCmd)
|
||||
|
||||
// convert — convert MLX LoRA adapter to PEFT format.
|
||||
var convertCfg lem.ConvertOpts
|
||||
convertCmd := cli.NewCommand("convert", "Convert MLX LoRA adapter to PEFT format", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunConvert(convertCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(convertCmd, &convertCfg.Input, "input", "i", "", "Path to MLX .safetensors file (required)")
|
||||
cli.StringFlag(convertCmd, &convertCfg.Config, "config", "c", "", "Path to MLX adapter_config.json (required)")
|
||||
cli.StringFlag(convertCmd, &convertCfg.Output, "output", "o", "./peft_output", "Output directory for PEFT adapter")
|
||||
cli.StringFlag(convertCmd, &convertCfg.BaseModel, "base-model", "", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "HuggingFace base model ID")
|
||||
exportGroup.AddCommand(convertCmd)
|
||||
|
||||
root.AddCommand(exportGroup)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,9 +8,65 @@ import (
|
|||
func addGenCommands(root *cli.Command) {
|
||||
genGroup := cli.NewGroup("gen", "Generation commands", "Distill, expand, and generate training data.")
|
||||
|
||||
genGroup.AddCommand(passthrough("distill", "Native Metal distillation (go-mlx + grammar scoring)", lem.RunDistill))
|
||||
genGroup.AddCommand(passthrough("expand", "Generate expansion responses via trained LEM model", lem.RunExpand))
|
||||
genGroup.AddCommand(passthrough("conv", "Generate conversational training data (calm phase)", lem.RunConv))
|
||||
// distill — native Metal distillation with grammar scoring.
|
||||
var distillCfg lem.DistillOpts
|
||||
distillCmd := cli.NewCommand("distill", "Native Metal distillation (go-mlx + grammar scoring)", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunDistill(distillCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(distillCmd, &distillCfg.Model, "model", "m", "gemma3/27b", "Model config path (relative to .core/ai/models/)")
|
||||
cli.StringFlag(distillCmd, &distillCfg.Probes, "probes", "p", "", "Probe set name from probes.yaml, or path to JSON file")
|
||||
cli.StringFlag(distillCmd, &distillCfg.Output, "output", "o", "", "Output JSONL path (defaults to model training dir)")
|
||||
cli.IntFlag(distillCmd, &distillCfg.Lesson, "lesson", "", -1, "Lesson number to append to (defaults to probe set phase)")
|
||||
cli.Float64Flag(distillCmd, &distillCfg.MinScore, "min-score", "", 0, "Min grammar composite (0 = use ai.yaml default)")
|
||||
cli.IntFlag(distillCmd, &distillCfg.Runs, "runs", "r", 0, "Generations per probe (0 = use ai.yaml default)")
|
||||
cli.BoolFlag(distillCmd, &distillCfg.DryRun, "dry-run", "", false, "Show plan and exit without generating")
|
||||
cli.StringFlag(distillCmd, &distillCfg.Root, "root", "", ".", "Project root (for .core/ai/ config)")
|
||||
cli.IntFlag(distillCmd, &distillCfg.CacheLimit, "cache-limit", "", 0, "Metal cache limit in GB (0 = use ai.yaml default)")
|
||||
cli.IntFlag(distillCmd, &distillCfg.MemLimit, "mem-limit", "", 0, "Metal memory limit in GB (0 = use ai.yaml default)")
|
||||
genGroup.AddCommand(distillCmd)
|
||||
|
||||
// expand — generate expansion responses via trained LEM model.
|
||||
var expandCfg lem.ExpandOpts
|
||||
expandCmd := cli.NewCommand("expand", "Generate expansion responses via trained LEM model", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunExpand(expandCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(expandCmd, &expandCfg.Model, "model", "m", "", "Model name for generation (required)")
|
||||
cli.StringFlag(expandCmd, &expandCfg.DB, "db", "", "", "DuckDB database path (primary prompt source)")
|
||||
cli.StringFlag(expandCmd, &expandCfg.Prompts, "prompts", "p", "", "Input JSONL file with expansion prompts (fallback)")
|
||||
cli.StringFlag(expandCmd, &expandCfg.APIURL, "api-url", "", "http://10.69.69.108:8090", "OpenAI-compatible API URL")
|
||||
cli.StringFlag(expandCmd, &expandCfg.Worker, "worker", "", "", "Worker hostname (defaults to os.Hostname())")
|
||||
cli.IntFlag(expandCmd, &expandCfg.Limit, "limit", "", 0, "Max prompts to process (0 = all)")
|
||||
cli.StringFlag(expandCmd, &expandCfg.Output, "output", "o", ".", "Output directory for JSONL files")
|
||||
cli.StringFlag(expandCmd, &expandCfg.Influx, "influx", "", "", "InfluxDB URL (default http://10.69.69.165:8181)")
|
||||
cli.StringFlag(expandCmd, &expandCfg.InfluxDB, "influx-db", "", "", "InfluxDB database name (default training)")
|
||||
cli.BoolFlag(expandCmd, &expandCfg.DryRun, "dry-run", "", false, "Print plan and exit without generating")
|
||||
genGroup.AddCommand(expandCmd)
|
||||
|
||||
// conv — generate conversational training data (calm phase).
|
||||
var convCfg lem.ConvOpts
|
||||
convCmd := cli.NewCommand("conv", "Generate conversational training data (calm phase)", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunConv(convCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(convCmd, &convCfg.OutputDir, "output-dir", "o", "", "Output directory for training files (required)")
|
||||
cli.StringFlag(convCmd, &convCfg.Extra, "extra", "", "", "Additional conversations JSONL file (multi-turn format)")
|
||||
cli.StringFlag(convCmd, &convCfg.Golden, "golden", "", "", "Golden set JSONL to convert to single-turn conversations")
|
||||
cli.StringFlag(convCmd, &convCfg.DB, "db", "", "", "DuckDB database path for golden set (alternative to --golden)")
|
||||
cli.IntFlag(convCmd, &convCfg.TrainPct, "train-pct", "", 80, "Training set percentage")
|
||||
cli.IntFlag(convCmd, &convCfg.ValidPct, "valid-pct", "", 10, "Validation set percentage")
|
||||
cli.IntFlag(convCmd, &convCfg.TestPct, "test-pct", "", 10, "Test set percentage")
|
||||
cli.Int64Flag(convCmd, &convCfg.Seed, "seed", "", 42, "Random seed for shuffling")
|
||||
cli.IntFlag(convCmd, &convCfg.MinChars, "min-chars", "", 50, "Minimum response chars for golden set conversion")
|
||||
cli.BoolFlag(convCmd, &convCfg.NoBuiltin, "no-builtin", "", false, "Exclude built-in seed conversations")
|
||||
cli.StringFlag(convCmd, &convCfg.Influx, "influx", "", "", "InfluxDB URL for progress reporting")
|
||||
cli.StringFlag(convCmd, &convCfg.InfluxDB, "influx-db", "", "", "InfluxDB database name")
|
||||
cli.StringFlag(convCmd, &convCfg.Worker, "worker", "", "", "Worker hostname for InfluxDB reporting")
|
||||
genGroup.AddCommand(convCmd)
|
||||
|
||||
root.AddCommand(genGroup)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package lemcmd
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/lthn/lem/pkg/lem"
|
||||
)
|
||||
|
|
@ -8,10 +10,70 @@ import (
|
|||
func addInfraCommands(root *cli.Command) {
|
||||
infraGroup := cli.NewGroup("infra", "Infrastructure commands", "InfluxDB ingestion, DuckDB queries, and distributed workers.")
|
||||
|
||||
infraGroup.AddCommand(passthrough("ingest", "Ingest benchmark data into InfluxDB", lem.RunIngest))
|
||||
infraGroup.AddCommand(passthrough("seed-influx", "Seed InfluxDB golden_gen from DuckDB", lem.RunSeedInflux))
|
||||
infraGroup.AddCommand(passthrough("query", "Run ad-hoc SQL against DuckDB", lem.RunQuery))
|
||||
infraGroup.AddCommand(passthrough("worker", "Run as distributed inference worker node", lem.RunWorker))
|
||||
// ingest — push benchmark data into InfluxDB.
|
||||
var ingestCfg lem.IngestOpts
|
||||
ingestCmd := cli.NewCommand("ingest", "Ingest benchmark data into InfluxDB", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunIngest(ingestCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(ingestCmd, &ingestCfg.Content, "content", "", "", "Content scores JSONL file")
|
||||
cli.StringFlag(ingestCmd, &ingestCfg.Capability, "capability", "", "", "Capability scores JSONL file")
|
||||
cli.StringFlag(ingestCmd, &ingestCfg.TrainingLog, "training-log", "", "", "MLX LoRA training log file")
|
||||
cli.StringFlag(ingestCmd, &ingestCfg.Model, "model", "m", "", "Model name tag (required)")
|
||||
cli.StringFlag(ingestCmd, &ingestCfg.RunID, "run-id", "", "", "Run ID tag (defaults to model name)")
|
||||
cli.StringFlag(ingestCmd, &ingestCfg.InfluxURL, "influx", "", "", "InfluxDB URL")
|
||||
cli.StringFlag(ingestCmd, &ingestCfg.InfluxDB, "influx-db", "", "", "InfluxDB database name")
|
||||
cli.IntFlag(ingestCmd, &ingestCfg.BatchSize, "batch-size", "", 100, "Lines per InfluxDB write batch")
|
||||
infraGroup.AddCommand(ingestCmd)
|
||||
|
||||
// seed-influx — seed InfluxDB golden_gen from DuckDB.
|
||||
var seedCfg lem.SeedInfluxOpts
|
||||
seedCmd := cli.NewCommand("seed-influx", "Seed InfluxDB golden_gen from DuckDB", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunSeedInflux(seedCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(seedCmd, &seedCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
cli.StringFlag(seedCmd, &seedCfg.InfluxURL, "influx", "", "", "InfluxDB URL")
|
||||
cli.StringFlag(seedCmd, &seedCfg.InfluxDB, "influx-db", "", "", "InfluxDB database name")
|
||||
cli.BoolFlag(seedCmd, &seedCfg.Force, "force", "", false, "Re-seed even if InfluxDB already has data")
|
||||
cli.IntFlag(seedCmd, &seedCfg.BatchSize, "batch-size", "", 500, "Lines per InfluxDB write batch")
|
||||
infraGroup.AddCommand(seedCmd)
|
||||
|
||||
// query — run ad-hoc SQL against DuckDB.
|
||||
var queryCfg lem.QueryOpts
|
||||
queryCmd := cli.NewCommand("query", "Run ad-hoc SQL against DuckDB", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunQuery(queryCfg, args)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(queryCmd, &queryCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
cli.BoolFlag(queryCmd, &queryCfg.JSON, "json", "j", false, "Output as JSON instead of table")
|
||||
infraGroup.AddCommand(queryCmd)
|
||||
|
||||
// worker — distributed inference worker node.
|
||||
var workerCfg lem.WorkerOpts
|
||||
workerCmd := cli.NewCommand("worker", "Run as distributed inference worker node", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunWorker(workerCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(workerCmd, &workerCfg.APIBase, "api", "", "", "LEM API base URL (or LEM_API env)")
|
||||
cli.StringFlag(workerCmd, &workerCfg.WorkerID, "id", "", "", "Worker ID (or LEM_WORKER_ID env, defaults to machine UUID)")
|
||||
cli.StringFlag(workerCmd, &workerCfg.Name, "name", "n", "", "Worker display name (or LEM_WORKER_NAME env)")
|
||||
cli.StringFlag(workerCmd, &workerCfg.APIKey, "key", "k", "", "API key (or LEM_API_KEY env)")
|
||||
cli.StringFlag(workerCmd, &workerCfg.GPUType, "gpu", "", "", "GPU type (e.g. 'RTX 3090', or LEM_GPU env)")
|
||||
cli.IntFlag(workerCmd, &workerCfg.VRAMGb, "vram", "", 0, "GPU VRAM in GB (or LEM_VRAM_GB env)")
|
||||
cli.StringFlag(workerCmd, &workerCfg.Languages, "languages", "", "", "Comma-separated language codes (or LEM_LANGUAGES env)")
|
||||
cli.StringFlag(workerCmd, &workerCfg.Models, "models", "", "", "Comma-separated supported model names (or LEM_MODELS env)")
|
||||
cli.StringFlag(workerCmd, &workerCfg.InferURL, "infer", "", "", "Local inference endpoint (or LEM_INFER_URL env)")
|
||||
cli.StringFlag(workerCmd, &workerCfg.TaskType, "type", "t", "", "Filter by task type (expand, score, translate, seed)")
|
||||
cli.IntFlag(workerCmd, &workerCfg.BatchSize, "batch", "b", 5, "Number of tasks to fetch per poll")
|
||||
cli.DurationFlag(workerCmd, &workerCfg.PollInterval, "poll", "", 30*time.Second, "Poll interval")
|
||||
cli.BoolFlag(workerCmd, &workerCfg.OneShot, "one-shot", "", false, "Process one batch and exit")
|
||||
cli.BoolFlag(workerCmd, &workerCfg.DryRun, "dry-run", "", false, "Fetch tasks but don't run inference")
|
||||
infraGroup.AddCommand(workerCmd)
|
||||
|
||||
root.AddCommand(infraGroup)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,11 @@
|
|||
package lemcmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
)
|
||||
|
||||
|
|
@ -16,12 +21,35 @@ func AddLEMCommands(root *cli.Command) {
|
|||
addInfraCommands(root)
|
||||
}
|
||||
|
||||
// passthrough creates a command that passes all args (including flags) to fn.
|
||||
// Used for commands that do their own flag parsing with flag.FlagSet.
|
||||
func passthrough(use, short string, fn func([]string)) *cli.Command {
|
||||
cmd := cli.NewRun(use, short, "", func(_ *cli.Command, args []string) {
|
||||
fn(args)
|
||||
})
|
||||
cmd.DisableFlagParsing = true
|
||||
return cmd
|
||||
// envOr returns the environment variable value, or the fallback if not set.
|
||||
func envOr(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// intEnvOr returns the environment variable value parsed as int, or the fallback.
|
||||
func intEnvOr(key string, fallback int) int {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return fallback
|
||||
}
|
||||
var n int
|
||||
fmt.Sscanf(v, "%d", &n)
|
||||
if n == 0 {
|
||||
return fallback
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// expandHome expands a leading ~/ to the user's home directory.
|
||||
func expandHome(path string) string {
|
||||
if strings.HasPrefix(path, "~/") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err == nil {
|
||||
return filepath.Join(home, path[2:])
|
||||
}
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,11 +8,59 @@ import (
|
|||
func addMonCommands(root *cli.Command) {
|
||||
monGroup := cli.NewGroup("mon", "Monitoring commands", "Training progress, pipeline status, inventory, coverage, and metrics.")
|
||||
|
||||
monGroup.AddCommand(passthrough("status", "Show training and generation progress (InfluxDB)", lem.RunStatus))
|
||||
monGroup.AddCommand(passthrough("expand-status", "Show expansion pipeline status (DuckDB)", lem.RunExpandStatus))
|
||||
monGroup.AddCommand(passthrough("inventory", "Show DuckDB table inventory", lem.RunInventory))
|
||||
monGroup.AddCommand(passthrough("coverage", "Analyse seed coverage gaps", lem.RunCoverage))
|
||||
monGroup.AddCommand(passthrough("metrics", "Push DuckDB golden set stats to InfluxDB", lem.RunMetrics))
|
||||
// status — training and generation progress from InfluxDB.
|
||||
var statusCfg lem.StatusOpts
|
||||
statusCmd := cli.NewCommand("status", "Show training and generation progress (InfluxDB)", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunStatus(statusCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(statusCmd, &statusCfg.Influx, "influx", "", "", "InfluxDB URL (default http://10.69.69.165:8181)")
|
||||
cli.StringFlag(statusCmd, &statusCfg.InfluxDB, "influx-db", "", "", "InfluxDB database name (default training)")
|
||||
cli.StringFlag(statusCmd, &statusCfg.DB, "db", "", "", "DuckDB database path (shows table counts)")
|
||||
monGroup.AddCommand(statusCmd)
|
||||
|
||||
// expand-status — expansion pipeline status from DuckDB.
|
||||
var expandStatusCfg lem.ExpandStatusOpts
|
||||
expandStatusCmd := cli.NewCommand("expand-status", "Show expansion pipeline status (DuckDB)", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunExpandStatus(expandStatusCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(expandStatusCmd, &expandStatusCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
monGroup.AddCommand(expandStatusCmd)
|
||||
|
||||
// inventory — DuckDB table inventory.
|
||||
var inventoryCfg lem.InventoryOpts
|
||||
inventoryCmd := cli.NewCommand("inventory", "Show DuckDB table inventory", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunInventory(inventoryCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(inventoryCmd, &inventoryCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
monGroup.AddCommand(inventoryCmd)
|
||||
|
||||
// coverage — seed coverage gap analysis.
|
||||
var coverageCfg lem.CoverageOpts
|
||||
coverageCmd := cli.NewCommand("coverage", "Analyse seed coverage gaps", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunCoverage(coverageCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(coverageCmd, &coverageCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
monGroup.AddCommand(coverageCmd)
|
||||
|
||||
// metrics — push DuckDB golden set stats to InfluxDB.
|
||||
var metricsCfg lem.MetricsOpts
|
||||
metricsCmd := cli.NewCommand("metrics", "Push DuckDB golden set stats to InfluxDB", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunMetrics(metricsCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(metricsCmd, &metricsCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
cli.StringFlag(metricsCmd, &metricsCfg.Influx, "influx", "", "", "InfluxDB URL")
|
||||
cli.StringFlag(metricsCmd, &metricsCfg.InfluxDB, "influx-db", "", "", "InfluxDB database name")
|
||||
monGroup.AddCommand(metricsCmd)
|
||||
|
||||
root.AddCommand(monGroup)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,8 +10,38 @@ import (
|
|||
func addScoreCommands(root *cli.Command) {
|
||||
scoreGroup := cli.NewGroup("score", "Scoring commands", "Score responses, probe models, compare results.")
|
||||
|
||||
scoreGroup.AddCommand(passthrough("run", "Score existing response files", lem.RunScore))
|
||||
scoreGroup.AddCommand(passthrough("probe", "Generate responses and score them", lem.RunProbe))
|
||||
// run — score existing response files.
|
||||
var scoreCfg lem.ScoreOpts
|
||||
scoreCmd := cli.NewCommand("run", "Score existing response files", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunScore(scoreCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(scoreCmd, &scoreCfg.Input, "input", "i", "", "Input JSONL response file (required)")
|
||||
cli.StringFlag(scoreCmd, &scoreCfg.Suites, "suites", "", "all", "Comma-separated suites or 'all'")
|
||||
cli.StringFlag(scoreCmd, &scoreCfg.JudgeModel, "judge-model", "", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name")
|
||||
cli.StringFlag(scoreCmd, &scoreCfg.JudgeURL, "judge-url", "", "http://10.69.69.108:8090", "Judge API URL")
|
||||
cli.IntFlag(scoreCmd, &scoreCfg.Concurrency, "concurrency", "c", 4, "Max concurrent judge calls")
|
||||
cli.StringFlag(scoreCmd, &scoreCfg.Output, "output", "o", "scores.json", "Output score file path")
|
||||
cli.BoolFlag(scoreCmd, &scoreCfg.Resume, "resume", "", false, "Resume from existing output, skipping scored IDs")
|
||||
scoreGroup.AddCommand(scoreCmd)
|
||||
|
||||
// probe — generate responses and score them.
|
||||
var probeCfg lem.ProbeOpts
|
||||
probeCmd := cli.NewCommand("probe", "Generate responses and score them", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunProbe(probeCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(probeCmd, &probeCfg.Model, "model", "m", "", "Target model name (required)")
|
||||
cli.StringFlag(probeCmd, &probeCfg.TargetURL, "target-url", "", "", "Target model API URL (defaults to judge-url)")
|
||||
cli.StringFlag(probeCmd, &probeCfg.ProbesFile, "probes", "", "", "Custom probes JSONL file (uses built-in content probes if not specified)")
|
||||
cli.StringFlag(probeCmd, &probeCfg.Suites, "suites", "", "all", "Comma-separated suites or 'all'")
|
||||
cli.StringFlag(probeCmd, &probeCfg.JudgeModel, "judge-model", "", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name")
|
||||
cli.StringFlag(probeCmd, &probeCfg.JudgeURL, "judge-url", "", "http://10.69.69.108:8090", "Judge API URL")
|
||||
cli.IntFlag(probeCmd, &probeCfg.Concurrency, "concurrency", "c", 4, "Max concurrent judge calls")
|
||||
cli.StringFlag(probeCmd, &probeCfg.Output, "output", "o", "scores.json", "Output score file path")
|
||||
scoreGroup.AddCommand(probeCmd)
|
||||
|
||||
// compare has a different signature — it takes two named args, not []string.
|
||||
var compareOld, compareNew string
|
||||
|
|
@ -27,9 +57,54 @@ func addScoreCommands(root *cli.Command) {
|
|||
cli.StringFlag(compareCmd, &compareNew, "new", "", "", "New score file (required)")
|
||||
scoreGroup.AddCommand(compareCmd)
|
||||
|
||||
scoreGroup.AddCommand(passthrough("attention", "Q/K Bone Orientation analysis for a prompt", lem.RunAttention))
|
||||
scoreGroup.AddCommand(passthrough("tier", "Score expansion responses (heuristic/judge tiers)", lem.RunTierScore))
|
||||
scoreGroup.AddCommand(passthrough("agent", "ROCm scoring daemon (polls M3, scores checkpoints)", lem.RunAgent))
|
||||
// attention — Q/K Bone Orientation analysis.
|
||||
var attCfg lem.AttentionOpts
|
||||
attCmd := cli.NewCommand("attention", "Q/K Bone Orientation analysis for a prompt", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunAttention(attCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(attCmd, &attCfg.Model, "model", "m", "gemma3/1b", "Model config path (relative to .core/ai/models/)")
|
||||
cli.StringFlag(attCmd, &attCfg.Prompt, "prompt", "p", "", "Prompt text to analyse")
|
||||
cli.BoolFlag(attCmd, &attCfg.JSON, "json", "j", false, "Output as JSON")
|
||||
cli.IntFlag(attCmd, &attCfg.CacheLimit, "cache-limit", "", 0, "Metal cache limit in GB (0 = use ai.yaml default)")
|
||||
cli.IntFlag(attCmd, &attCfg.MemLimit, "mem-limit", "", 0, "Metal memory limit in GB (0 = use ai.yaml default)")
|
||||
cli.StringFlag(attCmd, &attCfg.Root, "root", "", ".", "Project root (for .core/ai/ config)")
|
||||
scoreGroup.AddCommand(attCmd)
|
||||
|
||||
// tier — score expansion responses with heuristic/judge tiers.
|
||||
var tierCfg lem.TierScoreOpts
|
||||
tierCmd := cli.NewCommand("tier", "Score expansion responses (heuristic/judge tiers)", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunTierScore(tierCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(tierCmd, &tierCfg.DBPath, "db", "", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
cli.IntFlag(tierCmd, &tierCfg.Tier, "tier", "t", 1, "Scoring tier: 1=heuristic, 2=LEM judge, 3=external")
|
||||
cli.IntFlag(tierCmd, &tierCfg.Limit, "limit", "l", 0, "Max items to score (0=all)")
|
||||
scoreGroup.AddCommand(tierCmd)
|
||||
|
||||
// agent — ROCm scoring daemon.
|
||||
var agentCfg lem.AgentOpts
|
||||
agentCmd := cli.NewCommand("agent", "ROCm scoring daemon (polls M3, scores checkpoints)", "",
|
||||
func(cmd *cli.Command, args []string) error {
|
||||
return lem.RunAgent(agentCfg)
|
||||
},
|
||||
)
|
||||
cli.StringFlag(agentCmd, &agentCfg.M3Host, "m3-host", "", envOr("M3_HOST", "10.69.69.108"), "M3 host address")
|
||||
cli.StringFlag(agentCmd, &agentCfg.M3User, "m3-user", "", envOr("M3_USER", "claude"), "M3 SSH user")
|
||||
cli.StringFlag(agentCmd, &agentCfg.M3SSHKey, "m3-ssh-key", "", envOr("M3_SSH_KEY", expandHome("~/.ssh/id_ed25519")), "SSH key for M3")
|
||||
cli.StringFlag(agentCmd, &agentCfg.M3AdapterBase, "m3-adapter-base", "", envOr("M3_ADAPTER_BASE", "/Volumes/Data/lem"), "Adapter base dir on M3")
|
||||
cli.StringFlag(agentCmd, &agentCfg.InfluxURL, "influx", "", envOr("INFLUX_URL", "http://10.69.69.165:8181"), "InfluxDB URL")
|
||||
cli.StringFlag(agentCmd, &agentCfg.InfluxDB, "influx-db", "", envOr("INFLUX_DB", "training"), "InfluxDB database")
|
||||
cli.StringFlag(agentCmd, &agentCfg.APIURL, "api-url", "", envOr("LEM_API_URL", "http://localhost:8080"), "OpenAI-compatible inference API URL")
|
||||
cli.StringFlag(agentCmd, &agentCfg.Model, "model", "m", envOr("LEM_MODEL", ""), "Model name for API (overrides auto-detect)")
|
||||
cli.StringFlag(agentCmd, &agentCfg.BaseModel, "base-model", "", envOr("BASE_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"), "HuggingFace base model ID")
|
||||
cli.IntFlag(agentCmd, &agentCfg.PollInterval, "poll", "", intEnvOr("POLL_INTERVAL", 300), "Poll interval in seconds")
|
||||
cli.StringFlag(agentCmd, &agentCfg.WorkDir, "work-dir", "", envOr("WORK_DIR", "/tmp/scoring-agent"), "Working directory for adapters")
|
||||
cli.BoolFlag(agentCmd, &agentCfg.OneShot, "one-shot", "", false, "Process one checkpoint and exit")
|
||||
cli.BoolFlag(agentCmd, &agentCfg.DryRun, "dry-run", "", false, "Discover and plan but don't execute")
|
||||
scoreGroup.AddCommand(agentCmd)
|
||||
|
||||
root.AddCommand(scoreGroup)
|
||||
}
|
||||
|
|
|
|||
121
pkg/lem/agent.go
121
pkg/lem/agent.go
|
|
@ -2,7 +2,6 @@ package lem
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
|
@ -14,21 +13,21 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// agentConfig holds scoring agent configuration.
|
||||
type agentConfig struct {
|
||||
m3Host string
|
||||
m3User string
|
||||
m3SSHKey string
|
||||
m3AdapterBase string
|
||||
influxURL string
|
||||
influxDB string
|
||||
apiURL string
|
||||
model string
|
||||
baseModel string
|
||||
pollInterval int
|
||||
workDir string
|
||||
oneShot bool
|
||||
dryRun bool
|
||||
// AgentOpts holds scoring agent configuration.
|
||||
type AgentOpts struct {
|
||||
M3Host string
|
||||
M3User string
|
||||
M3SSHKey string
|
||||
M3AdapterBase string
|
||||
InfluxURL string
|
||||
InfluxDB string
|
||||
APIURL string
|
||||
Model string
|
||||
BaseModel string
|
||||
PollInterval int
|
||||
WorkDir string
|
||||
OneShot bool
|
||||
DryRun bool
|
||||
}
|
||||
|
||||
// checkpoint represents a discovered adapter checkpoint on M3.
|
||||
|
|
@ -72,46 +71,26 @@ type bufferEntry struct {
|
|||
// Polls M3 for unscored LoRA checkpoints, converts MLX → PEFT,
|
||||
// runs 23 capability probes via an OpenAI-compatible API, and
|
||||
// pushes results to InfluxDB.
|
||||
func RunAgent(args []string) {
|
||||
fs := flag.NewFlagSet("agent", flag.ExitOnError)
|
||||
|
||||
cfg := &agentConfig{}
|
||||
fs.StringVar(&cfg.m3Host, "m3-host", envOr("M3_HOST", "10.69.69.108"), "M3 host address")
|
||||
fs.StringVar(&cfg.m3User, "m3-user", envOr("M3_USER", "claude"), "M3 SSH user")
|
||||
fs.StringVar(&cfg.m3SSHKey, "m3-ssh-key", envOr("M3_SSH_KEY", expandHome("~/.ssh/id_ed25519")), "SSH key for M3")
|
||||
fs.StringVar(&cfg.m3AdapterBase, "m3-adapter-base", envOr("M3_ADAPTER_BASE", "/Volumes/Data/lem"), "Adapter base dir on M3")
|
||||
fs.StringVar(&cfg.influxURL, "influx", envOr("INFLUX_URL", "http://10.69.69.165:8181"), "InfluxDB URL")
|
||||
fs.StringVar(&cfg.influxDB, "influx-db", envOr("INFLUX_DB", "training"), "InfluxDB database")
|
||||
fs.StringVar(&cfg.apiURL, "api-url", envOr("LEM_API_URL", "http://localhost:8080"), "OpenAI-compatible inference API URL")
|
||||
fs.StringVar(&cfg.model, "model", envOr("LEM_MODEL", ""), "Model name for API (overrides auto-detect)")
|
||||
fs.StringVar(&cfg.baseModel, "base-model", envOr("BASE_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"), "HuggingFace base model ID")
|
||||
fs.IntVar(&cfg.pollInterval, "poll", intEnvOr("POLL_INTERVAL", 300), "Poll interval in seconds")
|
||||
fs.StringVar(&cfg.workDir, "work-dir", envOr("WORK_DIR", "/tmp/scoring-agent"), "Working directory for adapters")
|
||||
fs.BoolVar(&cfg.oneShot, "one-shot", false, "Process one checkpoint and exit")
|
||||
fs.BoolVar(&cfg.dryRun, "dry-run", false, "Discover and plan but don't execute")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
}
|
||||
|
||||
runAgentLoop(cfg)
|
||||
func RunAgent(cfg AgentOpts) error {
|
||||
runAgentLoop(&cfg)
|
||||
return nil
|
||||
}
|
||||
|
||||
func runAgentLoop(cfg *agentConfig) {
|
||||
func runAgentLoop(cfg *AgentOpts) {
|
||||
log.Println(strings.Repeat("=", 60))
|
||||
log.Println("ROCm Scoring Agent — Go Edition")
|
||||
log.Printf("M3: %s@%s", cfg.m3User, cfg.m3Host)
|
||||
log.Printf("Inference API: %s", cfg.apiURL)
|
||||
log.Printf("InfluxDB: %s/%s", cfg.influxURL, cfg.influxDB)
|
||||
log.Printf("Poll interval: %ds", cfg.pollInterval)
|
||||
log.Printf("M3: %s@%s", cfg.M3User, cfg.M3Host)
|
||||
log.Printf("Inference API: %s", cfg.APIURL)
|
||||
log.Printf("InfluxDB: %s/%s", cfg.InfluxURL, cfg.InfluxDB)
|
||||
log.Printf("Poll interval: %ds", cfg.PollInterval)
|
||||
log.Println(strings.Repeat("=", 60))
|
||||
|
||||
influx := NewInfluxClient(cfg.influxURL, cfg.influxDB)
|
||||
os.MkdirAll(cfg.workDir, 0755)
|
||||
influx := NewInfluxClient(cfg.InfluxURL, cfg.InfluxDB)
|
||||
os.MkdirAll(cfg.WorkDir, 0755)
|
||||
|
||||
for {
|
||||
// Replay any buffered results.
|
||||
replayInfluxBuffer(cfg.workDir, influx)
|
||||
replayInfluxBuffer(cfg.WorkDir, influx)
|
||||
|
||||
// Discover checkpoints on M3.
|
||||
log.Println("Discovering checkpoints on M3...")
|
||||
|
|
@ -135,18 +114,18 @@ func runAgentLoop(cfg *agentConfig) {
|
|||
log.Printf("Unscored: %d checkpoints", len(unscored))
|
||||
|
||||
if len(unscored) == 0 {
|
||||
log.Printf("Nothing to score. Sleeping %ds...", cfg.pollInterval)
|
||||
if cfg.oneShot {
|
||||
log.Printf("Nothing to score. Sleeping %ds...", cfg.PollInterval)
|
||||
if cfg.OneShot {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Duration(cfg.pollInterval) * time.Second)
|
||||
time.Sleep(time.Duration(cfg.PollInterval) * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
target := unscored[0]
|
||||
log.Printf("Grabbed: %s (%s)", target.Label, target.Dirname)
|
||||
|
||||
if cfg.dryRun {
|
||||
if cfg.DryRun {
|
||||
log.Printf("[DRY RUN] Would process: %s/%s", target.Dirname, target.Filename)
|
||||
for _, u := range unscored[1:] {
|
||||
log.Printf("[DRY RUN] Queued: %s/%s", u.Dirname, u.Filename)
|
||||
|
|
@ -158,7 +137,7 @@ func runAgentLoop(cfg *agentConfig) {
|
|||
log.Printf("Error processing %s: %v", target.Label, err)
|
||||
}
|
||||
|
||||
if cfg.oneShot {
|
||||
if cfg.OneShot {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -167,8 +146,8 @@ func runAgentLoop(cfg *agentConfig) {
|
|||
}
|
||||
|
||||
// discoverCheckpoints lists all adapter directories and checkpoint files on M3 via SSH.
|
||||
func discoverCheckpoints(cfg *agentConfig) ([]checkpoint, error) {
|
||||
out, err := sshCommand(cfg, fmt.Sprintf("ls -d %s/adapters-deepseek-r1-7b* 2>/dev/null", cfg.m3AdapterBase))
|
||||
func discoverCheckpoints(cfg *AgentOpts) ([]checkpoint, error) {
|
||||
out, err := sshCommand(cfg, fmt.Sprintf("ls -d %s/adapters-deepseek-r1-7b* 2>/dev/null", cfg.M3AdapterBase))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list adapter dirs: %w", err)
|
||||
}
|
||||
|
|
@ -292,12 +271,12 @@ func findUnscored(checkpoints []checkpoint, scored map[[2]string]bool) []checkpo
|
|||
}
|
||||
|
||||
// processOne fetches, converts, scores, and pushes one checkpoint.
|
||||
func processOne(cfg *agentConfig, influx *InfluxClient, cp checkpoint) error {
|
||||
func processOne(cfg *AgentOpts, influx *InfluxClient, cp checkpoint) error {
|
||||
log.Println(strings.Repeat("=", 60))
|
||||
log.Printf("Processing: %s / %s", cp.Dirname, cp.Filename)
|
||||
log.Println(strings.Repeat("=", 60))
|
||||
|
||||
localAdapterDir := filepath.Join(cfg.workDir, cp.Dirname)
|
||||
localAdapterDir := filepath.Join(cfg.WorkDir, cp.Dirname)
|
||||
os.MkdirAll(localAdapterDir, 0755)
|
||||
|
||||
localSF := filepath.Join(localAdapterDir, cp.Filename)
|
||||
|
|
@ -307,7 +286,7 @@ func processOne(cfg *agentConfig, influx *InfluxClient, cp checkpoint) error {
|
|||
defer func() {
|
||||
os.Remove(localSF)
|
||||
os.Remove(localCfg)
|
||||
peftDir := filepath.Join(cfg.workDir, fmt.Sprintf("peft_%07d", cp.Iteration))
|
||||
peftDir := filepath.Join(cfg.WorkDir, fmt.Sprintf("peft_%07d", cp.Iteration))
|
||||
os.RemoveAll(peftDir)
|
||||
}()
|
||||
|
||||
|
|
@ -325,18 +304,18 @@ func processOne(cfg *agentConfig, influx *InfluxClient, cp checkpoint) error {
|
|||
|
||||
// Convert MLX to PEFT format.
|
||||
log.Println("Converting MLX to PEFT format...")
|
||||
peftDir := filepath.Join(cfg.workDir, fmt.Sprintf("peft_%07d", cp.Iteration))
|
||||
if err := convertMLXtoPEFT(localAdapterDir, cp.Filename, peftDir, cfg.baseModel); err != nil {
|
||||
peftDir := filepath.Join(cfg.WorkDir, fmt.Sprintf("peft_%07d", cp.Iteration))
|
||||
if err := convertMLXtoPEFT(localAdapterDir, cp.Filename, peftDir, cfg.BaseModel); err != nil {
|
||||
return fmt.Errorf("convert adapter: %w", err)
|
||||
}
|
||||
|
||||
// Run 23 capability probes via API.
|
||||
log.Println("Running 23 capability probes...")
|
||||
modelName := cfg.model
|
||||
modelName := cfg.Model
|
||||
if modelName == "" {
|
||||
modelName = cp.ModelTag
|
||||
}
|
||||
client := NewClient(cfg.apiURL, modelName)
|
||||
client := NewClient(cfg.APIURL, modelName)
|
||||
client.MaxTokens = 500
|
||||
|
||||
results := runCapabilityProbes(client)
|
||||
|
|
@ -347,7 +326,7 @@ func processOne(cfg *agentConfig, influx *InfluxClient, cp checkpoint) error {
|
|||
// Push to InfluxDB (buffer on failure).
|
||||
if err := pushCapabilityResults(influx, cp, results); err != nil {
|
||||
log.Printf("InfluxDB push failed, buffering: %v", err)
|
||||
bufferInfluxResult(cfg.workDir, cp, results)
|
||||
bufferInfluxResult(cfg.WorkDir, cp, results)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -532,13 +511,13 @@ func replayInfluxBuffer(workDir string, influx *InfluxClient) {
|
|||
}
|
||||
|
||||
// sshCommand executes a command on M3 via SSH.
|
||||
func sshCommand(cfg *agentConfig, cmd string) (string, error) {
|
||||
func sshCommand(cfg *AgentOpts, cmd string) (string, error) {
|
||||
sshArgs := []string{
|
||||
"-o", "ConnectTimeout=10",
|
||||
"-o", "BatchMode=yes",
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-i", cfg.m3SSHKey,
|
||||
fmt.Sprintf("%s@%s", cfg.m3User, cfg.m3Host),
|
||||
"-i", cfg.M3SSHKey,
|
||||
fmt.Sprintf("%s@%s", cfg.M3User, cfg.M3Host),
|
||||
cmd,
|
||||
}
|
||||
result, err := exec.Command("ssh", sshArgs...).CombinedOutput()
|
||||
|
|
@ -549,14 +528,14 @@ func sshCommand(cfg *agentConfig, cmd string) (string, error) {
|
|||
}
|
||||
|
||||
// scpFrom copies a file from M3 to a local path.
|
||||
func scpFrom(cfg *agentConfig, remotePath, localPath string) error {
|
||||
func scpFrom(cfg *AgentOpts, remotePath, localPath string) error {
|
||||
os.MkdirAll(filepath.Dir(localPath), 0755)
|
||||
scpArgs := []string{
|
||||
"-o", "ConnectTimeout=10",
|
||||
"-o", "BatchMode=yes",
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-i", cfg.m3SSHKey,
|
||||
fmt.Sprintf("%s@%s:%s", cfg.m3User, cfg.m3Host, remotePath),
|
||||
"-i", cfg.M3SSHKey,
|
||||
fmt.Sprintf("%s@%s:%s", cfg.M3User, cfg.M3Host, remotePath),
|
||||
localPath,
|
||||
}
|
||||
result, err := exec.Command("scp", scpArgs...).CombinedOutput()
|
||||
|
|
@ -574,11 +553,11 @@ func fileBase(path string) string {
|
|||
return path
|
||||
}
|
||||
|
||||
func sleepOrExit(cfg *agentConfig) {
|
||||
if cfg.oneShot {
|
||||
func sleepOrExit(cfg *AgentOpts) {
|
||||
if cfg.OneShot {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Duration(cfg.pollInterval) * time.Second)
|
||||
time.Sleep(time.Duration(cfg.PollInterval) * time.Second)
|
||||
}
|
||||
|
||||
func envOr(key, fallback string) string {
|
||||
|
|
|
|||
|
|
@ -2,41 +2,38 @@ package lem
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// ApproveOpts holds configuration for the approve command.
|
||||
type ApproveOpts struct {
|
||||
DB string // DuckDB database path (defaults to LEM_DB env)
|
||||
Output string // Output JSONL file (defaults to expansion-approved.jsonl in db dir)
|
||||
Threshold float64 // Min judge average to approve (default: 6.0)
|
||||
}
|
||||
|
||||
// RunApprove is the CLI entry point for the approve command.
|
||||
// Filters scored expansion responses by quality threshold and exports
|
||||
// approved ones as chat-format training JSONL.
|
||||
func RunApprove(args []string) {
|
||||
fs := flag.NewFlagSet("approve", flag.ExitOnError)
|
||||
dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
output := fs.String("output", "", "Output JSONL file (defaults to expansion-approved.jsonl in db dir)")
|
||||
threshold := fs.Float64("threshold", 6.0, "Min judge average to approve (default: 6.0)")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
func RunApprove(cfg ApproveOpts) error {
|
||||
dbPath := cfg.DB
|
||||
if dbPath == "" {
|
||||
dbPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
if dbPath == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
if *dbPath == "" {
|
||||
*dbPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
if *dbPath == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required")
|
||||
os.Exit(1)
|
||||
output := cfg.Output
|
||||
if output == "" {
|
||||
output = filepath.Join(filepath.Dir(dbPath), "expansion-approved.jsonl")
|
||||
}
|
||||
|
||||
if *output == "" {
|
||||
*output = filepath.Join(filepath.Dir(*dbPath), "expansion-approved.jsonl")
|
||||
}
|
||||
|
||||
db, err := OpenDB(*dbPath)
|
||||
db, err := OpenDB(dbPath)
|
||||
if err != nil {
|
||||
log.Fatalf("open db: %v", err)
|
||||
return fmt.Errorf("open db: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
|
|
@ -51,13 +48,13 @@ func RunApprove(args []string) {
|
|||
ORDER BY r.idx
|
||||
`)
|
||||
if err != nil {
|
||||
log.Fatalf("query approved: %v (have you run scoring?)", err)
|
||||
return fmt.Errorf("query approved: %v (have you run scoring?)", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
f, err := os.Create(*output)
|
||||
f, err := os.Create(output)
|
||||
if err != nil {
|
||||
log.Fatalf("create output: %v", err)
|
||||
return fmt.Errorf("create output: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
|
|
@ -71,7 +68,7 @@ func RunApprove(args []string) {
|
|||
var seedID, region, domain, prompt, response, model string
|
||||
var genTime, score float64
|
||||
if err := rows.Scan(&idx, &seedID, ®ion, &domain, &prompt, &response, &genTime, &model, &score); err != nil {
|
||||
log.Fatalf("scan: %v", err)
|
||||
return fmt.Errorf("scan: %v", err)
|
||||
}
|
||||
|
||||
example := TrainingExample{
|
||||
|
|
@ -82,7 +79,7 @@ func RunApprove(args []string) {
|
|||
}
|
||||
|
||||
if err := enc.Encode(example); err != nil {
|
||||
log.Fatalf("encode: %v", err)
|
||||
return fmt.Errorf("encode: %v", err)
|
||||
}
|
||||
|
||||
regionSet[region] = true
|
||||
|
|
@ -90,9 +87,10 @@ func RunApprove(args []string) {
|
|||
count++
|
||||
}
|
||||
|
||||
_ = *threshold // threshold used in query above for future judge scoring
|
||||
_ = cfg.Threshold // threshold used in query above for future judge scoring
|
||||
|
||||
fmt.Printf("Approved: %d responses (threshold: heuristic > 0)\n", count)
|
||||
fmt.Printf("Exported: %s\n", *output)
|
||||
fmt.Printf("Exported: %s\n", output)
|
||||
fmt.Printf(" Regions: %d, Domains: %d\n", len(regionSet), len(domainSet))
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,62 +3,49 @@ package lem
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-ml"
|
||||
"forge.lthn.ai/core/go-mlx"
|
||||
)
|
||||
|
||||
// AttentionOpts holds configuration for the attention command.
|
||||
type AttentionOpts struct {
|
||||
Model string // Model config path (relative to .core/ai/models/)
|
||||
Prompt string // Prompt text to analyse
|
||||
JSON bool // Output as JSON
|
||||
CacheLimit int // Metal cache limit in GB (0 = use ai.yaml default)
|
||||
MemLimit int // Metal memory limit in GB (0 = use ai.yaml default)
|
||||
Root string // Project root (for .core/ai/ config)
|
||||
}
|
||||
|
||||
// RunAttention is the entry point for `lem score attention`.
|
||||
func RunAttention(args []string) {
|
||||
fs := flag.NewFlagSet("attention", flag.ExitOnError)
|
||||
var (
|
||||
modelFlag string
|
||||
prompt string
|
||||
jsonOutput bool
|
||||
cacheLimit int
|
||||
memLimit int
|
||||
root string
|
||||
)
|
||||
fs.StringVar(&modelFlag, "model", "gemma3/1b", "Model config path (relative to .core/ai/models/)")
|
||||
fs.StringVar(&prompt, "prompt", "", "Prompt text to analyse")
|
||||
fs.BoolVar(&jsonOutput, "json", false, "Output as JSON")
|
||||
fs.IntVar(&cacheLimit, "cache-limit", 0, "Metal cache limit in GB (0 = use ai.yaml default)")
|
||||
fs.IntVar(&memLimit, "mem-limit", 0, "Metal memory limit in GB (0 = use ai.yaml default)")
|
||||
fs.StringVar(&root, "root", ".", "Project root (for .core/ai/ config)")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
}
|
||||
|
||||
if prompt == "" {
|
||||
fmt.Fprintln(os.Stderr, "usage: lem score attention -model <path> -prompt <text>")
|
||||
os.Exit(1)
|
||||
func RunAttention(cfg AttentionOpts) error {
|
||||
if cfg.Prompt == "" {
|
||||
return fmt.Errorf("--prompt is required")
|
||||
}
|
||||
|
||||
// Load configs.
|
||||
aiCfg, err := LoadAIConfig(root)
|
||||
aiCfg, err := LoadAIConfig(cfg.Root)
|
||||
if err != nil {
|
||||
log.Fatalf("load ai config: %v", err)
|
||||
return fmt.Errorf("load ai config: %w", err)
|
||||
}
|
||||
|
||||
modelCfg, err := LoadModelConfig(root, modelFlag)
|
||||
modelCfg, err := LoadModelConfig(cfg.Root, cfg.Model)
|
||||
if err != nil {
|
||||
log.Fatalf("load model config: %v", err)
|
||||
return fmt.Errorf("load model config: %w", err)
|
||||
}
|
||||
|
||||
// Apply memory limits.
|
||||
cacheLimitGB := aiCfg.Distill.CacheLimit
|
||||
if cacheLimit > 0 {
|
||||
cacheLimitGB = cacheLimit
|
||||
if cfg.CacheLimit > 0 {
|
||||
cacheLimitGB = cfg.CacheLimit
|
||||
}
|
||||
memLimitGB := aiCfg.Distill.MemoryLimit
|
||||
if memLimit > 0 {
|
||||
memLimitGB = memLimit
|
||||
if cfg.MemLimit > 0 {
|
||||
memLimitGB = cfg.MemLimit
|
||||
}
|
||||
if cacheLimitGB > 0 {
|
||||
mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
|
||||
|
|
@ -71,7 +58,7 @@ func RunAttention(args []string) {
|
|||
log.Printf("loading model: %s", modelCfg.Paths.Base)
|
||||
backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
|
||||
if err != nil {
|
||||
log.Fatalf("load model: %v", err)
|
||||
return fmt.Errorf("load model: %w", err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
|
|
@ -79,18 +66,17 @@ func RunAttention(args []string) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
snap, err := backend.InspectAttention(ctx, prompt)
|
||||
snap, err := backend.InspectAttention(ctx, cfg.Prompt)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
os.Exit(1)
|
||||
return fmt.Errorf("inspect attention: %w", err)
|
||||
}
|
||||
|
||||
result := AnalyseAttention(snap)
|
||||
|
||||
if jsonOutput {
|
||||
if cfg.JSON {
|
||||
data, _ := json.MarshalIndent(result, "", " ")
|
||||
fmt.Println(string(data))
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Human-readable output.
|
||||
|
|
@ -107,4 +93,5 @@ func RunAttention(args []string) {
|
|||
fmt.Printf(" Phase-Lock Score: %.3f\n", result.PhaseLockScore)
|
||||
fmt.Printf(" Joint Collapses: %d\n", result.JointCollapseCount)
|
||||
fmt.Printf(" Composite (0-10000): %d\n", result.Composite())
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package lem
|
|||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
|
@ -13,34 +12,38 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// ConsolidateOpts holds configuration for the consolidate command.
|
||||
type ConsolidateOpts struct {
|
||||
Host string // SSH host for remote files
|
||||
Remote string // Remote directory for JSONL files
|
||||
Pattern string // File glob pattern
|
||||
OutputDir string // Output directory (defaults to ./responses)
|
||||
Merged string // Merged output file (defaults to gold-merged.jsonl in output dir)
|
||||
}
|
||||
|
||||
// RunConsolidate is the CLI entry point for the consolidate command.
|
||||
// Pulls all worker JSONLs from M3, merges them, deduplicates on idx,
|
||||
// and writes a single merged file.
|
||||
func RunConsolidate(args []string) {
|
||||
fs := flag.NewFlagSet("consolidate", flag.ExitOnError)
|
||||
remoteHost := fs.String("host", "m3", "SSH host for remote files")
|
||||
remotePath := fs.String("remote", "/Volumes/Data/lem/responses", "Remote directory for JSONL files")
|
||||
pattern := fs.String("pattern", "gold*.jsonl", "File glob pattern")
|
||||
outputDir := fs.String("output", "", "Output directory (defaults to ./responses)")
|
||||
merged := fs.String("merged", "", "Merged output file (defaults to gold-merged.jsonl in output dir)")
|
||||
func RunConsolidate(cfg ConsolidateOpts) error {
|
||||
remoteHost := cfg.Host
|
||||
remotePath := cfg.Remote
|
||||
pattern := cfg.Pattern
|
||||
outputDir := cfg.OutputDir
|
||||
merged := cfg.Merged
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
if outputDir == "" {
|
||||
outputDir = "responses"
|
||||
}
|
||||
|
||||
if *outputDir == "" {
|
||||
*outputDir = "responses"
|
||||
}
|
||||
if err := os.MkdirAll(*outputDir, 0755); err != nil {
|
||||
log.Fatalf("create output dir: %v", err)
|
||||
if err := os.MkdirAll(outputDir, 0755); err != nil {
|
||||
return fmt.Errorf("create output dir: %v", err)
|
||||
}
|
||||
|
||||
// List remote files.
|
||||
fmt.Println("Pulling responses from remote...")
|
||||
listCmd := exec.Command("ssh", *remoteHost, fmt.Sprintf("ls %s/%s", *remotePath, *pattern))
|
||||
listCmd := exec.Command("ssh", remoteHost, fmt.Sprintf("ls %s/%s", remotePath, pattern))
|
||||
listOutput, err := listCmd.Output()
|
||||
if err != nil {
|
||||
log.Fatalf("list remote files: %v", err)
|
||||
return fmt.Errorf("list remote files: %v", err)
|
||||
}
|
||||
|
||||
remoteFiles := strings.Split(strings.TrimSpace(string(listOutput)), "\n")
|
||||
|
|
@ -51,12 +54,12 @@ func RunConsolidate(args []string) {
|
|||
validFiles = append(validFiles, f)
|
||||
}
|
||||
}
|
||||
fmt.Printf(" Found %d JSONL files on %s\n", len(validFiles), *remoteHost)
|
||||
fmt.Printf(" Found %d JSONL files on %s\n", len(validFiles), remoteHost)
|
||||
|
||||
// Pull files.
|
||||
for _, rf := range validFiles {
|
||||
local := filepath.Join(*outputDir, filepath.Base(rf))
|
||||
scpCmd := exec.Command("scp", fmt.Sprintf("%s:%s", *remoteHost, rf), local)
|
||||
local := filepath.Join(outputDir, filepath.Base(rf))
|
||||
scpCmd := exec.Command("scp", fmt.Sprintf("%s:%s", remoteHost, rf), local)
|
||||
if err := scpCmd.Run(); err != nil {
|
||||
log.Printf("warning: failed to pull %s: %v", rf, err)
|
||||
continue
|
||||
|
|
@ -80,7 +83,7 @@ func RunConsolidate(args []string) {
|
|||
seen := make(map[int]json.RawMessage)
|
||||
skipped := 0
|
||||
|
||||
matches, _ := filepath.Glob(filepath.Join(*outputDir, *pattern))
|
||||
matches, _ := filepath.Glob(filepath.Join(outputDir, pattern))
|
||||
sort.Strings(matches)
|
||||
|
||||
for _, local := range matches {
|
||||
|
|
@ -115,8 +118,8 @@ func RunConsolidate(args []string) {
|
|||
}
|
||||
|
||||
// Sort by idx and write merged file.
|
||||
if *merged == "" {
|
||||
*merged = filepath.Join(*outputDir, "..", "gold-merged.jsonl")
|
||||
if merged == "" {
|
||||
merged = filepath.Join(outputDir, "..", "gold-merged.jsonl")
|
||||
}
|
||||
|
||||
idxs := make([]int, 0, len(seen))
|
||||
|
|
@ -125,9 +128,9 @@ func RunConsolidate(args []string) {
|
|||
}
|
||||
sort.Ints(idxs)
|
||||
|
||||
f, err := os.Create(*merged)
|
||||
f, err := os.Create(merged)
|
||||
if err != nil {
|
||||
log.Fatalf("create merged file: %v", err)
|
||||
return fmt.Errorf("create merged file: %v", err)
|
||||
}
|
||||
for _, idx := range idxs {
|
||||
f.Write(seen[idx])
|
||||
|
|
@ -135,5 +138,6 @@ func RunConsolidate(args []string) {
|
|||
}
|
||||
f.Close()
|
||||
|
||||
fmt.Printf("\nMerged: %d unique examples → %s\n", len(seen), *merged)
|
||||
fmt.Printf("\nMerged: %d unique examples → %s\n", len(seen), merged)
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
114
pkg/lem/conv.go
114
pkg/lem/conv.go
|
|
@ -3,7 +3,6 @@ package lem
|
|||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
|
|
@ -11,86 +10,79 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// RunConv is the CLI entry point for the conv command.
|
||||
// It generates multi-turn conversational training data from built-in
|
||||
// ConvOpts holds configuration for the conv command.
|
||||
type ConvOpts struct {
|
||||
OutputDir string // Output directory for training files (required)
|
||||
Extra string // Additional conversations JSONL file (multi-turn format)
|
||||
Golden string // Golden set JSONL to convert to single-turn conversations
|
||||
DB string // DuckDB database path for golden set (alternative to --golden)
|
||||
TrainPct int // Training set percentage
|
||||
ValidPct int // Validation set percentage
|
||||
TestPct int // Test set percentage
|
||||
Seed int64 // Random seed for shuffling
|
||||
MinChars int // Minimum response chars for golden set conversion
|
||||
NoBuiltin bool // Exclude built-in seed conversations
|
||||
Influx string // InfluxDB URL for progress reporting
|
||||
InfluxDB string // InfluxDB database name
|
||||
Worker string // Worker hostname for InfluxDB reporting
|
||||
}
|
||||
|
||||
// RunConv generates multi-turn conversational training data from built-in
|
||||
// seed conversations plus optional extra files and golden set data.
|
||||
func RunConv(args []string) {
|
||||
fs := flag.NewFlagSet("conv", flag.ExitOnError)
|
||||
|
||||
outputDir := fs.String("output-dir", "", "Output directory for training files (required)")
|
||||
extra := fs.String("extra", "", "Additional conversations JSONL file (multi-turn format)")
|
||||
golden := fs.String("golden", "", "Golden set JSONL to convert to single-turn conversations")
|
||||
dbPath := fs.String("db", "", "DuckDB database path for golden set (alternative to --golden)")
|
||||
trainPct := fs.Int("train-pct", 80, "Training set percentage")
|
||||
validPct := fs.Int("valid-pct", 10, "Validation set percentage")
|
||||
testPct := fs.Int("test-pct", 10, "Test set percentage")
|
||||
seed := fs.Int64("seed", 42, "Random seed for shuffling")
|
||||
minChars := fs.Int("min-chars", 50, "Minimum response chars for golden set conversion")
|
||||
noBuiltin := fs.Bool("no-builtin", false, "Exclude built-in seed conversations")
|
||||
influxURL := fs.String("influx", "", "InfluxDB URL for progress reporting")
|
||||
influxDB := fs.String("influx-db", "", "InfluxDB database name")
|
||||
worker := fs.String("worker", "", "Worker hostname for InfluxDB reporting")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
func RunConv(cfg ConvOpts) error {
|
||||
if cfg.OutputDir == "" {
|
||||
return fmt.Errorf("--output-dir is required")
|
||||
}
|
||||
|
||||
if *outputDir == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --output-dir is required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := validatePercentages(*trainPct, *validPct, *testPct); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
os.Exit(1)
|
||||
if err := validatePercentages(cfg.TrainPct, cfg.ValidPct, cfg.TestPct); err != nil {
|
||||
return fmt.Errorf("percentages: %w", err)
|
||||
}
|
||||
|
||||
// Check LEM_DB env as default for --db.
|
||||
if *dbPath == "" {
|
||||
*dbPath = os.Getenv("LEM_DB")
|
||||
if cfg.DB == "" {
|
||||
cfg.DB = os.Getenv("LEM_DB")
|
||||
}
|
||||
|
||||
// Default worker to hostname.
|
||||
if *worker == "" {
|
||||
if cfg.Worker == "" {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
hostname = "unknown"
|
||||
}
|
||||
*worker = hostname
|
||||
cfg.Worker = hostname
|
||||
}
|
||||
|
||||
// Collect all conversations.
|
||||
var conversations []TrainingExample
|
||||
|
||||
// 1. Built-in seed conversations.
|
||||
if !*noBuiltin {
|
||||
if !cfg.NoBuiltin {
|
||||
conversations = append(conversations, SeedConversations...)
|
||||
log.Printf("loaded %d built-in seed conversations", len(SeedConversations))
|
||||
}
|
||||
|
||||
// 2. Extra conversations from file.
|
||||
if *extra != "" {
|
||||
extras, err := readConversations(*extra)
|
||||
if cfg.Extra != "" {
|
||||
extras, err := readConversations(cfg.Extra)
|
||||
if err != nil {
|
||||
log.Fatalf("read extra conversations: %v", err)
|
||||
return fmt.Errorf("read extra conversations: %w", err)
|
||||
}
|
||||
conversations = append(conversations, extras...)
|
||||
log.Printf("loaded %d extra conversations from %s", len(extras), *extra)
|
||||
log.Printf("loaded %d extra conversations from %s", len(extras), cfg.Extra)
|
||||
}
|
||||
|
||||
// 3. Golden set responses converted to single-turn format.
|
||||
var goldenResponses []Response
|
||||
if *dbPath != "" && *golden == "" {
|
||||
db, err := OpenDB(*dbPath)
|
||||
if cfg.DB != "" && cfg.Golden == "" {
|
||||
db, err := OpenDB(cfg.DB)
|
||||
if err != nil {
|
||||
log.Fatalf("open db: %v", err)
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows, err := db.QueryGoldenSet(*minChars)
|
||||
rows, err := db.QueryGoldenSet(cfg.MinChars)
|
||||
if err != nil {
|
||||
log.Fatalf("query golden_set: %v", err)
|
||||
return fmt.Errorf("query golden_set: %w", err)
|
||||
}
|
||||
for _, r := range rows {
|
||||
goldenResponses = append(goldenResponses, Response{
|
||||
|
|
@ -101,32 +93,32 @@ func RunConv(args []string) {
|
|||
Model: r.Voice,
|
||||
})
|
||||
}
|
||||
log.Printf("loaded %d golden set rows from %s", len(goldenResponses), *dbPath)
|
||||
} else if *golden != "" {
|
||||
log.Printf("loaded %d golden set rows from %s", len(goldenResponses), cfg.DB)
|
||||
} else if cfg.Golden != "" {
|
||||
var err error
|
||||
goldenResponses, err = ReadResponses(*golden)
|
||||
goldenResponses, err = ReadResponses(cfg.Golden)
|
||||
if err != nil {
|
||||
log.Fatalf("read golden set: %v", err)
|
||||
return fmt.Errorf("read golden set: %w", err)
|
||||
}
|
||||
log.Printf("loaded %d golden set responses from %s", len(goldenResponses), *golden)
|
||||
log.Printf("loaded %d golden set responses from %s", len(goldenResponses), cfg.Golden)
|
||||
}
|
||||
|
||||
if len(goldenResponses) > 0 {
|
||||
converted := convertToConversations(goldenResponses, *minChars)
|
||||
converted := convertToConversations(goldenResponses, cfg.MinChars)
|
||||
conversations = append(conversations, converted...)
|
||||
log.Printf("converted %d golden set responses to single-turn conversations", len(converted))
|
||||
}
|
||||
|
||||
if len(conversations) == 0 {
|
||||
log.Fatal("no conversations to process — use built-in seeds, --extra, --golden, or --db")
|
||||
return fmt.Errorf("no conversations to process — use built-in seeds, --extra, --golden, or --db")
|
||||
}
|
||||
|
||||
// Split into train/valid/test.
|
||||
train, valid, test := splitConversations(conversations, *trainPct, *validPct, *testPct, *seed)
|
||||
train, valid, test := splitConversations(conversations, cfg.TrainPct, cfg.ValidPct, cfg.TestPct, cfg.Seed)
|
||||
|
||||
// Create output directory.
|
||||
if err := os.MkdirAll(*outputDir, 0755); err != nil {
|
||||
log.Fatalf("create output dir: %v", err)
|
||||
if err := os.MkdirAll(cfg.OutputDir, 0755); err != nil {
|
||||
return fmt.Errorf("create output dir: %w", err)
|
||||
}
|
||||
|
||||
// Write output files.
|
||||
|
|
@ -138,9 +130,9 @@ func RunConv(args []string) {
|
|||
{"valid.jsonl", valid},
|
||||
{"test.jsonl", test},
|
||||
} {
|
||||
path := *outputDir + "/" + split.name
|
||||
path := cfg.OutputDir + "/" + split.name
|
||||
if err := writeConversationJSONL(path, split.data); err != nil {
|
||||
log.Fatalf("write %s: %v", split.name, err)
|
||||
return fmt.Errorf("write %s: %w", split.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -168,16 +160,18 @@ func RunConv(args []string) {
|
|||
fmt.Printf(" %d total conversations\n", len(conversations))
|
||||
fmt.Printf(" %d total turns (%.1f avg per conversation)\n", totalTurns, avgTurns)
|
||||
fmt.Printf(" %.0f words avg per assistant response\n", avgWords)
|
||||
fmt.Printf(" Output: %s/\n", *outputDir)
|
||||
fmt.Printf(" Output: %s/\n", cfg.OutputDir)
|
||||
|
||||
// Report to InfluxDB if configured.
|
||||
influx := NewInfluxClient(*influxURL, *influxDB)
|
||||
influx := NewInfluxClient(cfg.Influx, cfg.InfluxDB)
|
||||
line := fmt.Sprintf("conv_export,worker=%s total=%di,train=%di,valid=%di,test=%di,turns=%di,avg_turns=%f,avg_words=%f",
|
||||
escapeLp(*worker), len(conversations), len(train), len(valid), len(test),
|
||||
escapeLp(cfg.Worker), len(conversations), len(train), len(valid), len(test),
|
||||
totalTurns, avgTurns, avgWords)
|
||||
if err := influx.WriteLp([]string{line}); err != nil {
|
||||
log.Printf("influx write (best-effort): %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// readConversations reads multi-turn conversations from a JSONL file.
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package lem
|
|||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
|
|
@ -15,34 +14,30 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// ConvertOpts holds configuration for the MLX-to-PEFT conversion command.
|
||||
type ConvertOpts struct {
|
||||
Input string // Path to MLX .safetensors file (required)
|
||||
Config string // Path to MLX adapter_config.json (required)
|
||||
Output string // Output directory for PEFT adapter
|
||||
BaseModel string // HuggingFace base model ID
|
||||
}
|
||||
|
||||
// RunConvert is the CLI entry point for the convert command.
|
||||
// Converts MLX LoRA adapters to HuggingFace PEFT format:
|
||||
// - Key renaming: model.layers.N.module.lora_a → base_model.model.model.layers.N.module.lora_A.default.weight
|
||||
// - Transpose: MLX (in, rank) → PEFT (rank, in)
|
||||
// - Config generation: adapter_config.json with lora_alpha = scale × rank
|
||||
func RunConvert(args []string) {
|
||||
fs := flag.NewFlagSet("convert", flag.ExitOnError)
|
||||
|
||||
safetensorsPath := fs.String("input", "", "Path to MLX .safetensors file (required)")
|
||||
configPath := fs.String("config", "", "Path to MLX adapter_config.json (required)")
|
||||
outputDir := fs.String("output", "./peft_output", "Output directory for PEFT adapter")
|
||||
baseModel := fs.String("base-model", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "HuggingFace base model ID")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
func RunConvert(cfg ConvertOpts) error {
|
||||
if cfg.Input == "" || cfg.Config == "" {
|
||||
return fmt.Errorf("--input and --config are required")
|
||||
}
|
||||
|
||||
if *safetensorsPath == "" || *configPath == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --input and --config are required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
if err := convertMLXtoPEFT(cfg.Input, cfg.Config, cfg.Output, cfg.BaseModel); err != nil {
|
||||
return fmt.Errorf("convert: %w", err)
|
||||
}
|
||||
|
||||
if err := convertMLXtoPEFT(*safetensorsPath, *configPath, *outputDir, *baseModel); err != nil {
|
||||
log.Fatalf("convert: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Converted to: %s\n", *outputDir)
|
||||
fmt.Printf("Converted to: %s\n", cfg.Output)
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
|
|
|
|||
|
|
@ -1,40 +1,34 @@
|
|||
package lem
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// RunCoverage is the CLI entry point for the coverage command.
|
||||
// Analyzes seed coverage and shows underrepresented areas.
|
||||
func RunCoverage(args []string) {
|
||||
fs := flag.NewFlagSet("coverage", flag.ExitOnError)
|
||||
dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
// CoverageOpts holds configuration for the coverage command.
|
||||
type CoverageOpts struct {
|
||||
DB string // DuckDB database path (defaults to LEM_DB env)
|
||||
}
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
// RunCoverage analyses seed coverage and shows underrepresented areas.
|
||||
func RunCoverage(cfg CoverageOpts) error {
|
||||
if cfg.DB == "" {
|
||||
cfg.DB = os.Getenv("LEM_DB")
|
||||
}
|
||||
if cfg.DB == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
if *dbPath == "" {
|
||||
*dbPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
if *dbPath == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
db, err := OpenDB(*dbPath)
|
||||
db, err := OpenDB(cfg.DB)
|
||||
if err != nil {
|
||||
log.Fatalf("open db: %v", err)
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var total int
|
||||
if err := db.conn.QueryRow("SELECT count(*) FROM seeds").Scan(&total); err != nil {
|
||||
log.Fatalf("No seeds table. Run: lem import-all first")
|
||||
return fmt.Errorf("no seeds table — run: lem import-all first")
|
||||
}
|
||||
|
||||
fmt.Println("LEM Seed Coverage Analysis")
|
||||
|
|
@ -65,7 +59,7 @@ func RunCoverage(args []string) {
|
|||
FROM seeds GROUP BY lang_group ORDER BY n ASC
|
||||
`)
|
||||
if err != nil {
|
||||
log.Fatalf("query regions: %v", err)
|
||||
return fmt.Errorf("query regions: %w", err)
|
||||
}
|
||||
|
||||
type regionRow struct {
|
||||
|
|
@ -129,6 +123,8 @@ func RunCoverage(args []string) {
|
|||
fmt.Println(" - Hindi/Urdu, Bengali, Tamil (South Asian)")
|
||||
fmt.Println(" - Swahili, Yoruba, Amharic (Sub-Saharan Africa)")
|
||||
fmt.Println(" - Indigenous languages (Quechua, Nahuatl, Aymara)")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrintScoreAnalytics prints score distribution statistics and gap analysis
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package lem
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
|
@ -33,67 +32,66 @@ type distillCandidate struct {
|
|||
Elapsed time.Duration
|
||||
}
|
||||
|
||||
// DistillOpts holds configuration for the distill command.
|
||||
type DistillOpts struct {
|
||||
Model string // Model config path (relative to .core/ai/models/)
|
||||
Probes string // Probe set name from probes.yaml, or path to JSON file
|
||||
Output string // Output JSONL path (defaults to model training dir)
|
||||
Lesson int // Lesson number to append to (defaults to probe set phase)
|
||||
MinScore float64 // Min grammar composite (0 = use ai.yaml default)
|
||||
Runs int // Generations per probe (0 = use ai.yaml default)
|
||||
DryRun bool // Show plan and exit without generating
|
||||
Root string // Project root (for .core/ai/ config)
|
||||
CacheLimit int // Metal cache limit in GB (0 = use ai.yaml default)
|
||||
MemLimit int // Metal memory limit in GB (0 = use ai.yaml default)
|
||||
}
|
||||
|
||||
// RunDistill is the CLI entry point for the distill command.
|
||||
// Generates responses via native Metal inference, scores with go-i18n/reversal,
|
||||
// writes passing examples as training JSONL.
|
||||
func RunDistill(args []string) {
|
||||
fs := flag.NewFlagSet("distill", flag.ExitOnError)
|
||||
|
||||
modelFlag := fs.String("model", "gemma3/27b", "Model config path (relative to .core/ai/models/)")
|
||||
probesFlag := fs.String("probes", "", "Probe set name from probes.yaml, or path to JSON file")
|
||||
outputFlag := fs.String("output", "", "Output JSONL path (defaults to model training dir)")
|
||||
lessonFlag := fs.Int("lesson", -1, "Lesson number to append to (defaults to probe set phase)")
|
||||
minScore := fs.Float64("min-score", 0, "Min grammar composite (0 = use ai.yaml default)")
|
||||
runs := fs.Int("runs", 0, "Generations per probe (0 = use ai.yaml default)")
|
||||
dryRun := fs.Bool("dry-run", false, "Show plan and exit without generating")
|
||||
root := fs.String("root", ".", "Project root (for .core/ai/ config)")
|
||||
cacheLimit := fs.Int("cache-limit", 0, "Metal cache limit in GB (0 = use ai.yaml default)")
|
||||
memLimit := fs.Int("mem-limit", 0, "Metal memory limit in GB (0 = use ai.yaml default)")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
}
|
||||
|
||||
func RunDistill(cfg DistillOpts) error {
|
||||
// Load configs.
|
||||
aiCfg, err := LoadAIConfig(*root)
|
||||
aiCfg, err := LoadAIConfig(cfg.Root)
|
||||
if err != nil {
|
||||
log.Fatalf("load ai config: %v", err)
|
||||
return fmt.Errorf("load ai config: %w", err)
|
||||
}
|
||||
|
||||
modelCfg, err := LoadModelConfig(*root, *modelFlag)
|
||||
modelCfg, err := LoadModelConfig(cfg.Root, cfg.Model)
|
||||
if err != nil {
|
||||
log.Fatalf("load model config: %v", err)
|
||||
return fmt.Errorf("load model config: %w", err)
|
||||
}
|
||||
|
||||
genCfg := MergeGenerate(aiCfg.Generate, modelCfg.Generate)
|
||||
|
||||
// Apply flag overrides.
|
||||
if *minScore == 0 {
|
||||
*minScore = aiCfg.Scorer.MinScore
|
||||
minScore := cfg.MinScore
|
||||
if minScore == 0 {
|
||||
minScore = aiCfg.Scorer.MinScore
|
||||
}
|
||||
if *runs == 0 {
|
||||
*runs = aiCfg.Distill.Runs
|
||||
runs := cfg.Runs
|
||||
if runs == 0 {
|
||||
runs = aiCfg.Distill.Runs
|
||||
}
|
||||
cacheLimitGB := aiCfg.Distill.CacheLimit
|
||||
if *cacheLimit > 0 {
|
||||
cacheLimitGB = *cacheLimit
|
||||
if cfg.CacheLimit > 0 {
|
||||
cacheLimitGB = cfg.CacheLimit
|
||||
}
|
||||
memLimitGB := aiCfg.Distill.MemoryLimit
|
||||
if *memLimit > 0 {
|
||||
memLimitGB = *memLimit
|
||||
if cfg.MemLimit > 0 {
|
||||
memLimitGB = cfg.MemLimit
|
||||
}
|
||||
|
||||
// Load probes.
|
||||
probes, phase, err := loadDistillProbes(*root, *probesFlag)
|
||||
probes, phase, err := loadDistillProbes(cfg.Root, cfg.Probes)
|
||||
if err != nil {
|
||||
log.Fatalf("load probes: %v", err)
|
||||
return fmt.Errorf("load probes: %w", err)
|
||||
}
|
||||
log.Printf("loaded %d probes", len(probes))
|
||||
|
||||
// Determine output path.
|
||||
outputPath := *outputFlag
|
||||
outputPath := cfg.Output
|
||||
if outputPath == "" {
|
||||
lesson := *lessonFlag
|
||||
lesson := cfg.Lesson
|
||||
if lesson < 0 {
|
||||
lesson = phase
|
||||
}
|
||||
|
|
@ -104,31 +102,35 @@ func RunDistill(args []string) {
|
|||
outputPath = filepath.Join(modelCfg.Training, lessonFile)
|
||||
}
|
||||
|
||||
// Load kernel.
|
||||
kernel, err := os.ReadFile(modelCfg.Kernel)
|
||||
if err != nil {
|
||||
log.Fatalf("read kernel: %v", err)
|
||||
}
|
||||
log.Printf("kernel: %d chars from %s", len(kernel), modelCfg.Kernel)
|
||||
|
||||
// Load signature (LEK-1-Sig).
|
||||
// Load kernel and signature — only if the model config specifies them.
|
||||
// LEM models have axioms in weights; loading the kernel violates A4.
|
||||
var kernel []byte
|
||||
var sig string
|
||||
if modelCfg.Signature != "" {
|
||||
sigBytes, err := os.ReadFile(modelCfg.Signature)
|
||||
if modelCfg.Kernel != "" {
|
||||
kernel, err = os.ReadFile(modelCfg.Kernel)
|
||||
if err != nil {
|
||||
log.Fatalf("read signature: %v", err)
|
||||
return fmt.Errorf("read kernel: %w", err)
|
||||
}
|
||||
sig = strings.TrimSpace(string(sigBytes))
|
||||
log.Printf("signature: %d chars from %s", len(sig), modelCfg.Signature)
|
||||
log.Printf("kernel: %d chars from %s (sandwich output)", len(kernel), modelCfg.Kernel)
|
||||
|
||||
if modelCfg.Signature != "" {
|
||||
sigBytes, err := os.ReadFile(modelCfg.Signature)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read signature: %w", err)
|
||||
}
|
||||
sig = strings.TrimSpace(string(sigBytes))
|
||||
}
|
||||
} else {
|
||||
log.Printf("no kernel configured — bare output (LEM model)")
|
||||
}
|
||||
|
||||
// Dry run.
|
||||
if *dryRun {
|
||||
if cfg.DryRun {
|
||||
fmt.Printf("Model: %s (%s)\n", modelCfg.Name, modelCfg.Paths.Base)
|
||||
fmt.Printf("Backend: %s\n", aiCfg.Backend)
|
||||
fmt.Printf("Probes: %d\n", len(probes))
|
||||
fmt.Printf("Runs: %d per probe (%d total generations)\n", *runs, len(probes)**runs)
|
||||
fmt.Printf("Gate: grammar v3 composite >= %.1f\n", *minScore)
|
||||
fmt.Printf("Runs: %d per probe (%d total generations)\n", runs, len(probes)*runs)
|
||||
fmt.Printf("Gate: grammar v3 composite >= %.1f\n", minScore)
|
||||
fmt.Printf("Generate: temp=%.2f max_tokens=%d top_p=%.2f\n",
|
||||
genCfg.Temperature, genCfg.MaxTokens, genCfg.TopP)
|
||||
fmt.Printf("Memory: cache=%dGB limit=%dGB\n", cacheLimitGB, memLimitGB)
|
||||
|
|
@ -145,7 +147,7 @@ func RunDistill(args []string) {
|
|||
}
|
||||
fmt.Printf(" %s: %s\n", p.ID, prompt)
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set Metal memory limits before loading model.
|
||||
|
|
@ -162,7 +164,7 @@ func RunDistill(args []string) {
|
|||
log.Printf("loading model: %s", modelCfg.Paths.Base)
|
||||
backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
|
||||
if err != nil {
|
||||
log.Fatalf("load model: %v", err)
|
||||
return fmt.Errorf("load model: %w", err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
|
|
@ -183,7 +185,7 @@ func RunDistill(args []string) {
|
|||
// Open output for append.
|
||||
out, err := os.OpenFile(outputPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
log.Fatalf("open output: %v", err)
|
||||
return fmt.Errorf("open output: %w", err)
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
|
|
@ -200,15 +202,20 @@ func RunDistill(args []string) {
|
|||
for i, probe := range probes {
|
||||
var best *distillCandidate
|
||||
|
||||
// Build sandwich prompt for output: LEK-1 + Prompt + LEK-1-Sig
|
||||
sandwichPrompt := kernelStr + "\n\n" + probe.Prompt
|
||||
if sig != "" {
|
||||
sandwichPrompt += "\n\n" + sig
|
||||
// Build output prompt: sandwich if kernel configured, bare if LEM model.
|
||||
var outputPrompt string
|
||||
if kernelStr != "" {
|
||||
outputPrompt = kernelStr + "\n\n" + probe.Prompt
|
||||
if sig != "" {
|
||||
outputPrompt += "\n\n" + sig
|
||||
}
|
||||
} else {
|
||||
outputPrompt = probe.Prompt
|
||||
}
|
||||
|
||||
for run := range *runs {
|
||||
for run := range runs {
|
||||
fmt.Fprintf(os.Stderr, " [%d/%d] %s run %d/%d",
|
||||
i+1, len(probes), probe.ID, run+1, *runs)
|
||||
i+1, len(probes), probe.ID, run+1, runs)
|
||||
|
||||
// Inference uses bare probe — the model generates from its weights.
|
||||
// Sandwich wrapping is only for the training output format.
|
||||
|
|
@ -277,7 +284,7 @@ func RunDistill(args []string) {
|
|||
}
|
||||
|
||||
// Quality gate.
|
||||
if best != nil && best.Grammar.Composite >= *minScore {
|
||||
if best != nil && best.Grammar.Composite >= minScore {
|
||||
// Duplicate filter: reject if grammar profile is too similar to an already-kept entry.
|
||||
bestFeatures := GrammarFeatures(best.Grammar)
|
||||
if dedupIdx != nil && dedupIdx.IsDuplicate(bestFeatures, 0.02) {
|
||||
|
|
@ -288,10 +295,10 @@ func RunDistill(args []string) {
|
|||
continue
|
||||
}
|
||||
|
||||
// Save with sandwich prompt — kernel wraps the bare probe for training.
|
||||
// Save with output prompt — sandwich if kernel, bare if LEM model.
|
||||
example := TrainingExample{
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: sandwichPrompt},
|
||||
{Role: "user", Content: outputPrompt},
|
||||
{Role: "assistant", Content: best.Response},
|
||||
},
|
||||
}
|
||||
|
|
@ -318,7 +325,7 @@ func RunDistill(args []string) {
|
|||
score = best.Grammar.Composite
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, " ✗ SKIP %s (best g=%.1f < %.1f)\n",
|
||||
probe.ID, score, *minScore)
|
||||
probe.ID, score, minScore)
|
||||
}
|
||||
|
||||
// Release GPU memory between probes to prevent incremental leak.
|
||||
|
|
@ -330,8 +337,8 @@ func RunDistill(args []string) {
|
|||
fmt.Fprintf(os.Stderr, "\n=== Distillation Complete ===\n")
|
||||
fmt.Fprintf(os.Stderr, "Model: %s (%s)\n", modelCfg.Name, backend.Name())
|
||||
fmt.Fprintf(os.Stderr, "Probes: %d\n", len(probes))
|
||||
fmt.Fprintf(os.Stderr, "Runs: %d per probe (%d total generations)\n", *runs, len(probes)**runs)
|
||||
fmt.Fprintf(os.Stderr, "Scorer: go-i18n/reversal grammar v3, gate >= %.1f\n", *minScore)
|
||||
fmt.Fprintf(os.Stderr, "Runs: %d per probe (%d total generations)\n", runs, len(probes)*runs)
|
||||
fmt.Fprintf(os.Stderr, "Scorer: go-i18n/reversal grammar v3, gate >= %.1f\n", minScore)
|
||||
fmt.Fprintf(os.Stderr, "Kept: %d\n", kept)
|
||||
fmt.Fprintf(os.Stderr, "Deduped: %d\n", deduped)
|
||||
fmt.Fprintf(os.Stderr, "Skipped: %d\n", skipped)
|
||||
|
|
@ -341,6 +348,7 @@ func RunDistill(args []string) {
|
|||
}
|
||||
fmt.Fprintf(os.Stderr, "Output: %s\n", outputPath)
|
||||
fmt.Fprintf(os.Stderr, "Duration: %.0fs (%.1fm)\n", duration.Seconds(), duration.Minutes())
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadDistillProbes loads probes from a named set or a file path.
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package lem
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
|
@ -22,68 +21,61 @@ type expandOutput struct {
|
|||
Chars int `json:"chars"`
|
||||
}
|
||||
|
||||
// runExpand parses CLI flags and runs the expand command.
|
||||
func RunExpand(args []string) {
|
||||
fs := flag.NewFlagSet("expand", flag.ExitOnError)
|
||||
// ExpandOpts holds configuration for the expand command.
|
||||
type ExpandOpts struct {
|
||||
Model string // Model name for generation (required)
|
||||
DB string // DuckDB database path (primary prompt source)
|
||||
Prompts string // Input JSONL file with expansion prompts (fallback)
|
||||
APIURL string // OpenAI-compatible API URL
|
||||
Worker string // Worker hostname (defaults to os.Hostname())
|
||||
Limit int // Max prompts to process (0 = all)
|
||||
Output string // Output directory for JSONL files
|
||||
Influx string // InfluxDB URL
|
||||
InfluxDB string // InfluxDB database name
|
||||
DryRun bool // Print plan and exit without generating
|
||||
}
|
||||
|
||||
model := fs.String("model", "", "Model name for generation (required)")
|
||||
dbPath := fs.String("db", "", "DuckDB database path (primary prompt source)")
|
||||
prompts := fs.String("prompts", "", "Input JSONL file with expansion prompts (fallback)")
|
||||
apiURL := fs.String("api-url", "http://10.69.69.108:8090", "OpenAI-compatible API URL")
|
||||
worker := fs.String("worker", "", "Worker hostname (defaults to os.Hostname())")
|
||||
limit := fs.Int("limit", 0, "Max prompts to process (0 = all)")
|
||||
output := fs.String("output", ".", "Output directory for JSONL files")
|
||||
influxURL := fs.String("influx", "", "InfluxDB URL (default http://10.69.69.165:8181)")
|
||||
influxDB := fs.String("influx-db", "", "InfluxDB database name (default training)")
|
||||
dryRun := fs.Bool("dry-run", false, "Print plan and exit without generating")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
}
|
||||
|
||||
if *model == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --model is required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
// RunExpand generates expansion responses via a trained LEM model.
|
||||
func RunExpand(cfg ExpandOpts) error {
|
||||
if cfg.Model == "" {
|
||||
return fmt.Errorf("--model is required")
|
||||
}
|
||||
|
||||
// Check LEM_DB env as default for --db.
|
||||
if *dbPath == "" {
|
||||
*dbPath = os.Getenv("LEM_DB")
|
||||
if cfg.DB == "" {
|
||||
cfg.DB = os.Getenv("LEM_DB")
|
||||
}
|
||||
|
||||
if *dbPath == "" && *prompts == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --db or --prompts is required (set LEM_DB env for default)")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
if cfg.DB == "" && cfg.Prompts == "" {
|
||||
return fmt.Errorf("--db or --prompts is required (set LEM_DB env for default)")
|
||||
}
|
||||
|
||||
// Default worker to hostname.
|
||||
if *worker == "" {
|
||||
if cfg.Worker == "" {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
hostname = "unknown"
|
||||
}
|
||||
*worker = hostname
|
||||
cfg.Worker = hostname
|
||||
}
|
||||
|
||||
// Load prompts from DuckDB or JSONL.
|
||||
var promptList []Response
|
||||
var duckDB *DB
|
||||
|
||||
if *dbPath != "" {
|
||||
if cfg.DB != "" {
|
||||
var err error
|
||||
duckDB, err = OpenDBReadWrite(*dbPath)
|
||||
duckDB, err = OpenDBReadWrite(cfg.DB)
|
||||
if err != nil {
|
||||
log.Fatalf("open db: %v", err)
|
||||
return fmt.Errorf("open db: %v", err)
|
||||
}
|
||||
defer duckDB.Close()
|
||||
|
||||
rows, err := duckDB.QueryExpansionPrompts("pending", *limit)
|
||||
rows, err := duckDB.QueryExpansionPrompts("pending", cfg.Limit)
|
||||
if err != nil {
|
||||
log.Fatalf("query expansion_prompts: %v", err)
|
||||
return fmt.Errorf("query expansion_prompts: %v", err)
|
||||
}
|
||||
log.Printf("loaded %d pending prompts from %s", len(rows), *dbPath)
|
||||
log.Printf("loaded %d pending prompts from %s", len(rows), cfg.DB)
|
||||
|
||||
for _, r := range rows {
|
||||
prompt := r.Prompt
|
||||
|
|
@ -98,21 +90,22 @@ func RunExpand(args []string) {
|
|||
}
|
||||
} else {
|
||||
var err error
|
||||
promptList, err = ReadResponses(*prompts)
|
||||
promptList, err = ReadResponses(cfg.Prompts)
|
||||
if err != nil {
|
||||
log.Fatalf("read prompts: %v", err)
|
||||
return fmt.Errorf("read prompts: %v", err)
|
||||
}
|
||||
log.Printf("loaded %d prompts from %s", len(promptList), *prompts)
|
||||
log.Printf("loaded %d prompts from %s", len(promptList), cfg.Prompts)
|
||||
}
|
||||
|
||||
// Create clients.
|
||||
client := NewClient(*apiURL, *model)
|
||||
client := NewClient(cfg.APIURL, cfg.Model)
|
||||
client.MaxTokens = 2048
|
||||
influx := NewInfluxClient(*influxURL, *influxDB)
|
||||
influx := NewInfluxClient(cfg.Influx, cfg.InfluxDB)
|
||||
|
||||
if err := expandPrompts(client, influx, duckDB, promptList, *model, *worker, *output, *dryRun, *limit); err != nil {
|
||||
log.Fatalf("expand: %v", err)
|
||||
if err := expandPrompts(client, influx, duckDB, promptList, cfg.Model, cfg.Worker, cfg.Output, cfg.DryRun, cfg.Limit); err != nil {
|
||||
return fmt.Errorf("expand: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getCompletedIDs queries InfluxDB for prompt IDs that have already been
|
||||
|
|
|
|||
|
|
@ -1,33 +1,27 @@
|
|||
package lem
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
// RunExpandStatus is the CLI entry point for the expand-status command.
|
||||
// Shows the expansion pipeline progress from DuckDB.
|
||||
func RunExpandStatus(args []string) {
|
||||
fs := flag.NewFlagSet("expand-status", flag.ExitOnError)
|
||||
dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
// ExpandStatusOpts holds configuration for the expand-status command.
|
||||
type ExpandStatusOpts struct {
|
||||
DB string // DuckDB database path (defaults to LEM_DB env)
|
||||
}
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
// RunExpandStatus shows the expansion pipeline progress from DuckDB.
|
||||
func RunExpandStatus(cfg ExpandStatusOpts) error {
|
||||
if cfg.DB == "" {
|
||||
cfg.DB = os.Getenv("LEM_DB")
|
||||
}
|
||||
if cfg.DB == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
if *dbPath == "" {
|
||||
*dbPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
if *dbPath == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
db, err := OpenDB(*dbPath)
|
||||
db, err := OpenDB(cfg.DB)
|
||||
if err != nil {
|
||||
log.Fatalf("open db: %v", err)
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
|
|
@ -39,8 +33,7 @@ func RunExpandStatus(args []string) {
|
|||
err = db.conn.QueryRow("SELECT count(*) FROM expansion_prompts").Scan(&epTotal)
|
||||
if err != nil {
|
||||
fmt.Println(" Expansion prompts: not created (run: lem normalize)")
|
||||
db.Close()
|
||||
return
|
||||
return nil
|
||||
}
|
||||
db.conn.QueryRow("SELECT count(*) FROM expansion_prompts WHERE status = 'pending'").Scan(&epPending)
|
||||
fmt.Printf(" Expansion prompts: %d total, %d pending\n", epTotal, epPending)
|
||||
|
|
@ -100,4 +93,6 @@ func RunExpandStatus(args []string) {
|
|||
fmt.Printf(" Combined: %d total examples\n", golden+generated)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package lem
|
|||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
|
|
@ -11,6 +10,18 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// ExportOpts holds configuration for the JSONL export command.
|
||||
type ExportOpts struct {
|
||||
DBPath string // DuckDB database path (primary source); falls back to LEM_DB env
|
||||
Input string // Input golden set JSONL file (fallback if DBPath not set)
|
||||
OutputDir string // Output directory for training files (required)
|
||||
TrainPct int // Training set percentage
|
||||
ValidPct int // Validation set percentage
|
||||
TestPct int // Test set percentage
|
||||
Seed int64 // Random seed for shuffling
|
||||
MinChars int // Minimum response character count
|
||||
}
|
||||
|
||||
// ChatMessage is a single message in the chat training format.
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
|
|
@ -22,60 +33,40 @@ type TrainingExample struct {
|
|||
Messages []ChatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
// runExport is the CLI entry point for the export command.
|
||||
func RunExport(args []string) {
|
||||
fs := flag.NewFlagSet("export", flag.ExitOnError)
|
||||
|
||||
dbPath := fs.String("db", "", "DuckDB database path (primary source)")
|
||||
input := fs.String("input", "", "Input golden set JSONL file (fallback if --db not set)")
|
||||
outputDir := fs.String("output-dir", "", "Output directory for training files (required)")
|
||||
trainPct := fs.Int("train-pct", 90, "Training set percentage")
|
||||
validPct := fs.Int("valid-pct", 5, "Validation set percentage")
|
||||
testPct := fs.Int("test-pct", 5, "Test set percentage")
|
||||
seed := fs.Int64("seed", 42, "Random seed for shuffling")
|
||||
minChars := fs.Int("min-chars", 50, "Minimum response character count")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
}
|
||||
|
||||
// RunExport is the CLI entry point for the export command.
|
||||
func RunExport(cfg ExportOpts) error {
|
||||
// Check LEM_DB env as default for --db.
|
||||
if *dbPath == "" {
|
||||
*dbPath = os.Getenv("LEM_DB")
|
||||
if cfg.DBPath == "" {
|
||||
cfg.DBPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
|
||||
if *dbPath == "" && *input == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --db or --input is required (set LEM_DB env for default)")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
if cfg.DBPath == "" && cfg.Input == "" {
|
||||
return fmt.Errorf("--db or --input is required (set LEM_DB env for default)")
|
||||
}
|
||||
|
||||
if *outputDir == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --output-dir is required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
if cfg.OutputDir == "" {
|
||||
return fmt.Errorf("--output-dir is required")
|
||||
}
|
||||
|
||||
if err := validatePercentages(*trainPct, *validPct, *testPct); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
os.Exit(1)
|
||||
if err := validatePercentages(cfg.TrainPct, cfg.ValidPct, cfg.TestPct); err != nil {
|
||||
return fmt.Errorf("invalid percentages: %w", err)
|
||||
}
|
||||
|
||||
var responses []Response
|
||||
|
||||
if *dbPath != "" {
|
||||
if cfg.DBPath != "" {
|
||||
// Primary: read from DuckDB golden_set table.
|
||||
db, err := OpenDB(*dbPath)
|
||||
db, err := OpenDB(cfg.DBPath)
|
||||
if err != nil {
|
||||
log.Fatalf("open db: %v", err)
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows, err := db.QueryGoldenSet(*minChars)
|
||||
rows, err := db.QueryGoldenSet(cfg.MinChars)
|
||||
if err != nil {
|
||||
log.Fatalf("query golden_set: %v", err)
|
||||
return fmt.Errorf("query golden_set: %w", err)
|
||||
}
|
||||
log.Printf("loaded %d golden set rows from %s (min_chars=%d)", len(rows), *dbPath, *minChars)
|
||||
log.Printf("loaded %d golden set rows from %s (min_chars=%d)", len(rows), cfg.DBPath, cfg.MinChars)
|
||||
|
||||
// Convert GoldenSetRow → Response for the shared pipeline.
|
||||
for _, r := range rows {
|
||||
|
|
@ -90,11 +81,11 @@ func RunExport(args []string) {
|
|||
} else {
|
||||
// Fallback: read from JSONL file.
|
||||
var err error
|
||||
responses, err = ReadResponses(*input)
|
||||
responses, err = ReadResponses(cfg.Input)
|
||||
if err != nil {
|
||||
log.Fatalf("read responses: %v", err)
|
||||
return fmt.Errorf("read responses: %w", err)
|
||||
}
|
||||
log.Printf("loaded %d responses from %s", len(responses), *input)
|
||||
log.Printf("loaded %d responses from %s", len(responses), cfg.Input)
|
||||
}
|
||||
|
||||
// Filter out bad responses (DuckDB already filters by char_count, but
|
||||
|
|
@ -103,11 +94,11 @@ func RunExport(args []string) {
|
|||
log.Printf("filtered to %d valid responses (removed %d)", len(filtered), len(responses)-len(filtered))
|
||||
|
||||
// Split into train/valid/test.
|
||||
train, valid, test := splitData(filtered, *trainPct, *validPct, *testPct, *seed)
|
||||
train, valid, test := splitData(filtered, cfg.TrainPct, cfg.ValidPct, cfg.TestPct, cfg.Seed)
|
||||
|
||||
// Create output directory.
|
||||
if err := os.MkdirAll(*outputDir, 0755); err != nil {
|
||||
log.Fatalf("create output dir: %v", err)
|
||||
if err := os.MkdirAll(cfg.OutputDir, 0755); err != nil {
|
||||
return fmt.Errorf("create output dir: %w", err)
|
||||
}
|
||||
|
||||
// Write output files.
|
||||
|
|
@ -119,13 +110,14 @@ func RunExport(args []string) {
|
|||
{"valid.jsonl", valid},
|
||||
{"test.jsonl", test},
|
||||
} {
|
||||
path := *outputDir + "/" + split.name
|
||||
path := cfg.OutputDir + "/" + split.name
|
||||
if err := writeTrainingJSONL(path, split.data); err != nil {
|
||||
log.Fatalf("write %s: %v", split.name, err)
|
||||
return fmt.Errorf("write %s: %w", split.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Exported: %d train / %d valid / %d test\n", len(train), len(valid), len(test))
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePercentages checks that train+valid+test percentages sum to 100
|
||||
|
|
|
|||
|
|
@ -332,15 +332,17 @@ func TestExportEndToEnd(t *testing.T) {
|
|||
}
|
||||
|
||||
// Run export with 80/10/10 split.
|
||||
args := []string{
|
||||
"--input", inputPath,
|
||||
"--output-dir", outputDir,
|
||||
"--train-pct", "80",
|
||||
"--valid-pct", "10",
|
||||
"--test-pct", "10",
|
||||
"--seed", "42",
|
||||
err := RunExport(ExportOpts{
|
||||
Input: inputPath,
|
||||
OutputDir: outputDir,
|
||||
TrainPct: 80,
|
||||
ValidPct: 10,
|
||||
TestPct: 10,
|
||||
Seed: 42,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("RunExport failed: %v", err)
|
||||
}
|
||||
RunExport(args)
|
||||
|
||||
// Verify output files exist.
|
||||
for _, name := range []string{"train.jsonl", "valid.jsonl", "test.jsonl"} {
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package lem
|
|||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
|
@ -12,42 +11,41 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// ImportOpts holds configuration for the import-all command.
|
||||
type ImportOpts struct {
|
||||
DB string // DuckDB database path (defaults to LEM_DB env)
|
||||
SkipM3 bool // Skip pulling data from M3
|
||||
DataDir string // Local data directory (defaults to db directory)
|
||||
}
|
||||
|
||||
// RunImport is the CLI entry point for the import-all command.
|
||||
// Imports ALL LEM data into DuckDB: prompts, Gemini responses, golden set,
|
||||
// training examples, benchmarks, validations, and seeds.
|
||||
func RunImport(args []string) {
|
||||
fs := flag.NewFlagSet("import-all", flag.ExitOnError)
|
||||
dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
skipM3 := fs.Bool("skip-m3", false, "Skip pulling data from M3")
|
||||
dataDir := fs.String("data-dir", "", "Local data directory (defaults to db directory)")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
func RunImport(cfg ImportOpts) error {
|
||||
dbPath := cfg.DB
|
||||
if dbPath == "" {
|
||||
dbPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
if dbPath == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
if *dbPath == "" {
|
||||
*dbPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
if *dbPath == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required")
|
||||
os.Exit(1)
|
||||
dataDir := cfg.DataDir
|
||||
if dataDir == "" {
|
||||
dataDir = filepath.Dir(dbPath)
|
||||
}
|
||||
|
||||
if *dataDir == "" {
|
||||
*dataDir = filepath.Dir(*dbPath)
|
||||
}
|
||||
|
||||
db, err := OpenDBReadWrite(*dbPath)
|
||||
db, err := OpenDBReadWrite(dbPath)
|
||||
if err != nil {
|
||||
log.Fatalf("open db: %v", err)
|
||||
return fmt.Errorf("open db: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
totals := make(map[string]int)
|
||||
|
||||
// ── 1. Golden set ──
|
||||
goldenPath := filepath.Join(*dataDir, "gold-15k.jsonl")
|
||||
if !*skipM3 {
|
||||
goldenPath := filepath.Join(dataDir, "gold-15k.jsonl")
|
||||
if !cfg.SkipM3 {
|
||||
fmt.Println(" Pulling golden set from M3...")
|
||||
scpCmd := exec.Command("scp", "m3:/Volumes/Data/lem/responses/gold-15k.jsonl", goldenPath)
|
||||
if err := scpCmd.Run(); err != nil {
|
||||
|
|
@ -101,10 +99,10 @@ func RunImport(args []string) {
|
|||
{"russian-bridge", []string{"russian-bridge/train.jsonl", "russian-bridge/valid.jsonl"}},
|
||||
}
|
||||
|
||||
trainingLocal := filepath.Join(*dataDir, "training")
|
||||
trainingLocal := filepath.Join(dataDir, "training")
|
||||
os.MkdirAll(trainingLocal, 0755)
|
||||
|
||||
if !*skipM3 {
|
||||
if !cfg.SkipM3 {
|
||||
fmt.Println(" Pulling training sets from M3...")
|
||||
for _, td := range trainingDirs {
|
||||
for _, rel := range td.files {
|
||||
|
|
@ -152,10 +150,10 @@ func RunImport(args []string) {
|
|||
fmt.Printf(" training_examples: %d rows\n", trainingTotal)
|
||||
|
||||
// ── 3. Benchmark results ──
|
||||
benchLocal := filepath.Join(*dataDir, "benchmarks")
|
||||
benchLocal := filepath.Join(dataDir, "benchmarks")
|
||||
os.MkdirAll(benchLocal, 0755)
|
||||
|
||||
if !*skipM3 {
|
||||
if !cfg.SkipM3 {
|
||||
fmt.Println(" Pulling benchmarks from M3...")
|
||||
for _, bname := range []string{"truthfulqa", "gsm8k", "do_not_answer", "toxigen"} {
|
||||
scpCmd := exec.Command("scp",
|
||||
|
|
@ -195,7 +193,7 @@ func RunImport(args []string) {
|
|||
for _, bfile := range []string{"lem_bench", "lem_ethics", "lem_ethics_allen", "instruction_tuned", "abliterated", "base_pt"} {
|
||||
local := filepath.Join(benchLocal, bfile+".jsonl")
|
||||
if _, err := os.Stat(local); os.IsNotExist(err) {
|
||||
if !*skipM3 {
|
||||
if !cfg.SkipM3 {
|
||||
scpCmd := exec.Command("scp",
|
||||
fmt.Sprintf("m3:/Volumes/Data/lem/benchmark/%s.jsonl", bfile), local)
|
||||
scpCmd.Run()
|
||||
|
|
@ -238,7 +236,7 @@ func RunImport(args []string) {
|
|||
`)
|
||||
|
||||
seedTotal := 0
|
||||
seedDirs := []string{filepath.Join(*dataDir, "seeds"), "/tmp/lem-data/seeds", "/tmp/lem-repo/seeds"}
|
||||
seedDirs := []string{filepath.Join(dataDir, "seeds"), "/tmp/lem-data/seeds", "/tmp/lem-repo/seeds"}
|
||||
for _, seedDir := range seedDirs {
|
||||
if _, err := os.Stat(seedDir); os.IsNotExist(err) {
|
||||
continue
|
||||
|
|
@ -260,7 +258,8 @@ func RunImport(args []string) {
|
|||
}
|
||||
fmt.Printf(" %s\n", strings.Repeat("─", 35))
|
||||
fmt.Printf(" %-25s %8d\n", "TOTAL", grandTotal)
|
||||
fmt.Printf("\nDatabase: %s\n", *dbPath)
|
||||
fmt.Printf("\nDatabase: %s\n", dbPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
func importTrainingFile(db *DB, path, source, split string) int {
|
||||
|
|
|
|||
|
|
@ -3,81 +3,77 @@ package lem
|
|||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IngestOpts holds configuration for the ingest command.
|
||||
type IngestOpts struct {
|
||||
Content string // Content scores JSONL file
|
||||
Capability string // Capability scores JSONL file
|
||||
TrainingLog string // MLX LoRA training log file
|
||||
Model string // Model name tag (required)
|
||||
RunID string // Run ID tag (defaults to model name)
|
||||
InfluxURL string // InfluxDB URL
|
||||
InfluxDB string // InfluxDB database name
|
||||
BatchSize int // Lines per InfluxDB write batch
|
||||
}
|
||||
|
||||
// RunIngest is the CLI entry point for the ingest command.
|
||||
// It reads benchmark JSONL files and training logs, then pushes
|
||||
// the data into InfluxDB as line protocol for the lab dashboard.
|
||||
func RunIngest(args []string) {
|
||||
fs := flag.NewFlagSet("ingest", flag.ExitOnError)
|
||||
|
||||
contentFile := fs.String("content", "", "Content scores JSONL file")
|
||||
capabilityFile := fs.String("capability", "", "Capability scores JSONL file")
|
||||
trainingLog := fs.String("training-log", "", "MLX LoRA training log file")
|
||||
model := fs.String("model", "", "Model name tag (required)")
|
||||
runID := fs.String("run-id", "", "Run ID tag (defaults to model name)")
|
||||
influxURL := fs.String("influx", "", "InfluxDB URL")
|
||||
influxDB := fs.String("influx-db", "", "InfluxDB database name")
|
||||
batchSize := fs.Int("batch-size", 100, "Lines per InfluxDB write batch")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
func RunIngest(cfg IngestOpts) error {
|
||||
if cfg.Model == "" {
|
||||
return fmt.Errorf("--model is required")
|
||||
}
|
||||
|
||||
if *model == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --model is required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
if cfg.Content == "" && cfg.Capability == "" && cfg.TrainingLog == "" {
|
||||
return fmt.Errorf("at least one of --content, --capability, or --training-log is required")
|
||||
}
|
||||
|
||||
if *contentFile == "" && *capabilityFile == "" && *trainingLog == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: at least one of --content, --capability, or --training-log is required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
if cfg.RunID == "" {
|
||||
cfg.RunID = cfg.Model
|
||||
}
|
||||
|
||||
if *runID == "" {
|
||||
*runID = *model
|
||||
if cfg.BatchSize <= 0 {
|
||||
cfg.BatchSize = 100
|
||||
}
|
||||
|
||||
influx := NewInfluxClient(*influxURL, *influxDB)
|
||||
influx := NewInfluxClient(cfg.InfluxURL, cfg.InfluxDB)
|
||||
total := 0
|
||||
|
||||
if *contentFile != "" {
|
||||
n, err := ingestContentScores(influx, *contentFile, *model, *runID, *batchSize)
|
||||
if cfg.Content != "" {
|
||||
n, err := ingestContentScores(influx, cfg.Content, cfg.Model, cfg.RunID, cfg.BatchSize)
|
||||
if err != nil {
|
||||
log.Fatalf("ingest content scores: %v", err)
|
||||
return fmt.Errorf("ingest content scores: %w", err)
|
||||
}
|
||||
fmt.Printf(" Content scores: %d points\n", n)
|
||||
total += n
|
||||
}
|
||||
|
||||
if *capabilityFile != "" {
|
||||
n, err := ingestCapabilityScores(influx, *capabilityFile, *model, *runID, *batchSize)
|
||||
if cfg.Capability != "" {
|
||||
n, err := ingestCapabilityScores(influx, cfg.Capability, cfg.Model, cfg.RunID, cfg.BatchSize)
|
||||
if err != nil {
|
||||
log.Fatalf("ingest capability scores: %v", err)
|
||||
return fmt.Errorf("ingest capability scores: %w", err)
|
||||
}
|
||||
fmt.Printf(" Capability scores: %d points\n", n)
|
||||
total += n
|
||||
}
|
||||
|
||||
if *trainingLog != "" {
|
||||
n, err := ingestTrainingCurve(influx, *trainingLog, *model, *runID, *batchSize)
|
||||
if cfg.TrainingLog != "" {
|
||||
n, err := ingestTrainingCurve(influx, cfg.TrainingLog, cfg.Model, cfg.RunID, cfg.BatchSize)
|
||||
if err != nil {
|
||||
log.Fatalf("ingest training curve: %v", err)
|
||||
return fmt.Errorf("ingest training curve: %w", err)
|
||||
}
|
||||
fmt.Printf(" Training curve: %d points\n", n)
|
||||
total += n
|
||||
}
|
||||
|
||||
fmt.Printf("\nTotal: %d points ingested\n", total)
|
||||
return nil
|
||||
}
|
||||
|
||||
var iterRe = regexp.MustCompile(`@(\d+)`)
|
||||
|
|
|
|||
|
|
@ -1,43 +1,37 @@
|
|||
package lem
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// RunInventory is the CLI entry point for the inventory command.
|
||||
// Shows row counts and summary stats for all tables in the DuckDB database.
|
||||
func RunInventory(args []string) {
|
||||
fs := flag.NewFlagSet("inventory", flag.ExitOnError)
|
||||
dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
// InventoryOpts holds configuration for the inventory command.
|
||||
type InventoryOpts struct {
|
||||
DB string // DuckDB database path (defaults to LEM_DB env)
|
||||
}
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
// RunInventory shows row counts and summary stats for all tables in the DuckDB database.
|
||||
func RunInventory(cfg InventoryOpts) error {
|
||||
if cfg.DB == "" {
|
||||
cfg.DB = os.Getenv("LEM_DB")
|
||||
}
|
||||
if cfg.DB == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
if *dbPath == "" {
|
||||
*dbPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
if *dbPath == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
db, err := OpenDB(*dbPath)
|
||||
db, err := OpenDB(cfg.DB)
|
||||
if err != nil {
|
||||
log.Fatalf("open db: %v", err)
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
counts, err := db.TableCounts()
|
||||
if err != nil {
|
||||
log.Fatalf("table counts: %v", err)
|
||||
return fmt.Errorf("table counts: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("LEM Database Inventory (%s)\n", *dbPath)
|
||||
fmt.Printf("LEM Database Inventory (%s)\n", cfg.DB)
|
||||
fmt.Println("============================================================")
|
||||
|
||||
grandTotal := 0
|
||||
|
|
@ -84,6 +78,8 @@ func RunInventory(args []string) {
|
|||
|
||||
fmt.Printf(" %-25s\n", "────────────────────────────────────────")
|
||||
fmt.Printf(" %-25s %8d\n", "TOTAL", grandTotal)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func joinStrings(parts []string, sep string) string {
|
||||
|
|
|
|||
|
|
@ -1,40 +1,33 @@
|
|||
package lem
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
const targetTotal = 15000
|
||||
|
||||
// RunMetrics is the CLI entry point for the metrics command.
|
||||
// Reads golden set stats from DuckDB and pushes them to InfluxDB as
|
||||
// MetricsOpts holds configuration for the metrics command.
|
||||
type MetricsOpts struct {
|
||||
DB string // DuckDB database path (defaults to LEM_DB env)
|
||||
Influx string // InfluxDB URL
|
||||
InfluxDB string // InfluxDB database name
|
||||
}
|
||||
|
||||
// RunMetrics reads golden set stats from DuckDB and pushes them to InfluxDB as
|
||||
// golden_set_stats, golden_set_domain, and golden_set_voice measurements.
|
||||
func RunMetrics(args []string) {
|
||||
fs := flag.NewFlagSet("metrics", flag.ExitOnError)
|
||||
|
||||
dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
influxURL := fs.String("influx", "", "InfluxDB URL")
|
||||
influxDB := fs.String("influx-db", "", "InfluxDB database name")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
func RunMetrics(cfg MetricsOpts) error {
|
||||
if cfg.DB == "" {
|
||||
cfg.DB = os.Getenv("LEM_DB")
|
||||
}
|
||||
if cfg.DB == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required (path to DuckDB file)")
|
||||
}
|
||||
|
||||
if *dbPath == "" {
|
||||
*dbPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
if *dbPath == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required (path to DuckDB file)")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
db, err := OpenDB(*dbPath)
|
||||
db, err := OpenDB(cfg.DB)
|
||||
if err != nil {
|
||||
log.Fatalf("open db: %v", err)
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
|
|
@ -48,12 +41,12 @@ func RunMetrics(args []string) {
|
|||
FROM golden_set
|
||||
`).Scan(&total, &domains, &voices, &avgGenTime, &avgChars)
|
||||
if err != nil {
|
||||
log.Fatalf("query golden_set stats: %v", err)
|
||||
return fmt.Errorf("query golden_set stats: %w", err)
|
||||
}
|
||||
|
||||
if total == 0 {
|
||||
fmt.Println("No golden set data in DuckDB.")
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
nowNs := time.Now().UTC().UnixNano()
|
||||
|
|
@ -73,7 +66,7 @@ func RunMetrics(args []string) {
|
|||
FROM golden_set GROUP BY domain
|
||||
`)
|
||||
if err != nil {
|
||||
log.Fatalf("query domains: %v", err)
|
||||
return fmt.Errorf("query domains: %w", err)
|
||||
}
|
||||
domainCount := 0
|
||||
for domainRows.Next() {
|
||||
|
|
@ -81,7 +74,7 @@ func RunMetrics(args []string) {
|
|||
var n int
|
||||
var avgT float64
|
||||
if err := domainRows.Scan(&domain, &n, &avgT); err != nil {
|
||||
log.Fatalf("scan domain row: %v", err)
|
||||
return fmt.Errorf("scan domain row: %w", err)
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf(
|
||||
"golden_set_domain,domain=%s count=%di,avg_gen_time=%.2f %d",
|
||||
|
|
@ -97,7 +90,7 @@ func RunMetrics(args []string) {
|
|||
FROM golden_set GROUP BY voice
|
||||
`)
|
||||
if err != nil {
|
||||
log.Fatalf("query voices: %v", err)
|
||||
return fmt.Errorf("query voices: %w", err)
|
||||
}
|
||||
voiceCount := 0
|
||||
for voiceRows.Next() {
|
||||
|
|
@ -105,7 +98,7 @@ func RunMetrics(args []string) {
|
|||
var n int
|
||||
var avgC, avgT float64
|
||||
if err := voiceRows.Scan(&voice, &n, &avgC, &avgT); err != nil {
|
||||
log.Fatalf("scan voice row: %v", err)
|
||||
return fmt.Errorf("scan voice row: %w", err)
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf(
|
||||
"golden_set_voice,voice=%s count=%di,avg_chars=%.0f,avg_gen_time=%.2f %d",
|
||||
|
|
@ -116,11 +109,13 @@ func RunMetrics(args []string) {
|
|||
voiceRows.Close()
|
||||
|
||||
// Write to InfluxDB.
|
||||
influx := NewInfluxClient(*influxURL, *influxDB)
|
||||
influx := NewInfluxClient(cfg.Influx, cfg.InfluxDB)
|
||||
if err := influx.WriteLp(lines); err != nil {
|
||||
log.Fatalf("write metrics: %v", err)
|
||||
return fmt.Errorf("write metrics: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Wrote metrics to InfluxDB: %d examples, %d domains, %d voices (%d points)\n",
|
||||
total, domainCount, voiceCount, len(lines))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,50 +1,49 @@
|
|||
package lem
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
// NormalizeOpts holds configuration for the normalize command.
|
||||
type NormalizeOpts struct {
|
||||
DB string // DuckDB database path (defaults to LEM_DB env)
|
||||
MinLen int // Minimum prompt length in characters
|
||||
}
|
||||
|
||||
// RunNormalize is the CLI entry point for the normalize command.
|
||||
// Normalizes seeds into the expansion_prompts table, deduplicating against
|
||||
// the golden set and existing prompts. Assigns priority based on domain
|
||||
// coverage (underrepresented domains first).
|
||||
func RunNormalize(args []string) {
|
||||
fs := flag.NewFlagSet("normalize", flag.ExitOnError)
|
||||
dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
minLen := fs.Int("min-length", 50, "Minimum prompt length in characters")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
func RunNormalize(cfg NormalizeOpts) error {
|
||||
dbPath := cfg.DB
|
||||
if dbPath == "" {
|
||||
dbPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
if dbPath == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
if *dbPath == "" {
|
||||
*dbPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
if *dbPath == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required")
|
||||
os.Exit(1)
|
||||
}
|
||||
minLen := cfg.MinLen
|
||||
|
||||
db, err := OpenDBReadWrite(*dbPath)
|
||||
db, err := OpenDBReadWrite(dbPath)
|
||||
if err != nil {
|
||||
log.Fatalf("open db: %v", err)
|
||||
return fmt.Errorf("open db: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Check source tables.
|
||||
var seedCount int
|
||||
if err := db.conn.QueryRow("SELECT count(*) FROM seeds").Scan(&seedCount); err != nil {
|
||||
log.Fatalf("No seeds table. Run: lem import-all first")
|
||||
return fmt.Errorf("no seeds table: run lem import-all first")
|
||||
}
|
||||
fmt.Printf("Seeds table: %d rows\n", seedCount)
|
||||
|
||||
// Drop and recreate expansion_prompts.
|
||||
_, err = db.conn.Exec("DROP TABLE IF EXISTS expansion_prompts")
|
||||
if err != nil {
|
||||
log.Fatalf("drop expansion_prompts: %v", err)
|
||||
return fmt.Errorf("drop expansion_prompts: %v", err)
|
||||
}
|
||||
|
||||
// Deduplicate: remove seeds whose prompt already appears in prompts or golden_set.
|
||||
|
|
@ -85,9 +84,9 @@ func RunNormalize(args []string) {
|
|||
SELECT 1 FROM existing_prompts ep
|
||||
WHERE ep.prompt = us.prompt
|
||||
)
|
||||
`, *minLen))
|
||||
`, minLen))
|
||||
if err != nil {
|
||||
log.Fatalf("create expansion_prompts: %v", err)
|
||||
return fmt.Errorf("create expansion_prompts: %v", err)
|
||||
}
|
||||
|
||||
var total, domains, regions int
|
||||
|
|
@ -145,4 +144,5 @@ func RunNormalize(args []string) {
|
|||
}
|
||||
|
||||
fmt.Printf("\nNormalization complete: %d expansion prompts from %d seeds\n", total, seedCount)
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,9 +3,7 @@ package lem
|
|||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
|
@ -13,6 +11,12 @@ import (
|
|||
"github.com/parquet-go/parquet-go"
|
||||
)
|
||||
|
||||
// ParquetOpts holds configuration for the Parquet export command.
|
||||
type ParquetOpts struct {
|
||||
Input string // Directory containing train.jsonl, valid.jsonl, test.jsonl (required)
|
||||
Output string // Output directory for Parquet files (defaults to input/parquet)
|
||||
}
|
||||
|
||||
// ParquetRow is the schema for exported Parquet files.
|
||||
type ParquetRow struct {
|
||||
Prompt string `parquet:"prompt"`
|
||||
|
|
@ -24,48 +28,39 @@ type ParquetRow struct {
|
|||
// RunParquet is the CLI entry point for the parquet command.
|
||||
// Reads JSONL training splits (train.jsonl, valid.jsonl, test.jsonl) and
|
||||
// writes Parquet files with snappy compression for HuggingFace datasets.
|
||||
func RunParquet(args []string) {
|
||||
fs := flag.NewFlagSet("parquet", flag.ExitOnError)
|
||||
|
||||
trainingDir := fs.String("input", "", "Directory containing train.jsonl, valid.jsonl, test.jsonl (required)")
|
||||
outputDir := fs.String("output", "", "Output directory for Parquet files (defaults to input/parquet)")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
func RunParquet(cfg ParquetOpts) error {
|
||||
if cfg.Input == "" {
|
||||
return fmt.Errorf("--input is required (directory with JSONL splits)")
|
||||
}
|
||||
|
||||
if *trainingDir == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --input is required (directory with JSONL splits)")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
outputDir := cfg.Output
|
||||
if outputDir == "" {
|
||||
outputDir = filepath.Join(cfg.Input, "parquet")
|
||||
}
|
||||
|
||||
if *outputDir == "" {
|
||||
*outputDir = filepath.Join(*trainingDir, "parquet")
|
||||
if err := os.MkdirAll(outputDir, 0755); err != nil {
|
||||
return fmt.Errorf("create output dir: %w", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(*outputDir, 0755); err != nil {
|
||||
log.Fatalf("create output dir: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Exporting Parquet from %s → %s\n", *trainingDir, *outputDir)
|
||||
fmt.Printf("Exporting Parquet from %s → %s\n", cfg.Input, outputDir)
|
||||
|
||||
total := 0
|
||||
for _, split := range []string{"train", "valid", "test"} {
|
||||
jsonlPath := filepath.Join(*trainingDir, split+".jsonl")
|
||||
jsonlPath := filepath.Join(cfg.Input, split+".jsonl")
|
||||
if _, err := os.Stat(jsonlPath); os.IsNotExist(err) {
|
||||
fmt.Printf(" Skip: %s.jsonl not found\n", split)
|
||||
continue
|
||||
}
|
||||
|
||||
n, err := exportSplitParquet(jsonlPath, *outputDir, split)
|
||||
n, err := exportSplitParquet(jsonlPath, outputDir, split)
|
||||
if err != nil {
|
||||
log.Fatalf("export %s: %v", split, err)
|
||||
return fmt.Errorf("export %s: %w", split, err)
|
||||
}
|
||||
total += n
|
||||
}
|
||||
|
||||
fmt.Printf("\nTotal: %d rows exported\n", total)
|
||||
return nil
|
||||
}
|
||||
|
||||
// exportSplitParquet reads a JSONL file and writes a Parquet file for the split.
|
||||
|
|
|
|||
|
|
@ -2,10 +2,8 @@ package lem
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
|
@ -13,28 +11,23 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// PublishOpts holds configuration for the HuggingFace publish command.
|
||||
type PublishOpts struct {
|
||||
Input string // Directory containing Parquet files (required)
|
||||
Repo string // HuggingFace dataset repo ID
|
||||
Public bool // Make dataset public
|
||||
Token string // HuggingFace API token (defaults to HF_TOKEN env)
|
||||
DryRun bool // Show what would be uploaded without uploading
|
||||
}
|
||||
|
||||
// RunPublish is the CLI entry point for the publish command.
|
||||
// Pushes Parquet files and an optional dataset card to HuggingFace.
|
||||
func RunPublish(args []string) {
|
||||
fs := flag.NewFlagSet("publish", flag.ExitOnError)
|
||||
|
||||
inputDir := fs.String("input", "", "Directory containing Parquet files (required)")
|
||||
repoID := fs.String("repo", "lthn/LEM-golden-set", "HuggingFace dataset repo ID")
|
||||
public := fs.Bool("public", false, "Make dataset public")
|
||||
token := fs.String("token", "", "HuggingFace API token (defaults to HF_TOKEN env)")
|
||||
dryRun := fs.Bool("dry-run", false, "Show what would be uploaded without uploading")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
func RunPublish(cfg PublishOpts) error {
|
||||
if cfg.Input == "" {
|
||||
return fmt.Errorf("--input is required (directory with Parquet files)")
|
||||
}
|
||||
|
||||
if *inputDir == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --input is required (directory with Parquet files)")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
hfToken := *token
|
||||
hfToken := cfg.Token
|
||||
if hfToken == "" {
|
||||
hfToken = os.Getenv("HF_TOKEN")
|
||||
}
|
||||
|
|
@ -48,9 +41,8 @@ func RunPublish(args []string) {
|
|||
}
|
||||
}
|
||||
|
||||
if hfToken == "" && !*dryRun {
|
||||
fmt.Fprintln(os.Stderr, "error: HuggingFace token required (--token, HF_TOKEN env, or ~/.huggingface/token)")
|
||||
os.Exit(1)
|
||||
if hfToken == "" && !cfg.DryRun {
|
||||
return fmt.Errorf("HuggingFace token required (--token, HF_TOKEN env, or ~/.huggingface/token)")
|
||||
}
|
||||
|
||||
splits := []string{"train", "valid", "test"}
|
||||
|
|
@ -61,7 +53,7 @@ func RunPublish(args []string) {
|
|||
var filesToUpload []uploadEntry
|
||||
|
||||
for _, split := range splits {
|
||||
path := filepath.Join(*inputDir, split+".parquet")
|
||||
path := filepath.Join(cfg.Input, split+".parquet")
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
|
|
@ -69,19 +61,18 @@ func RunPublish(args []string) {
|
|||
}
|
||||
|
||||
// Check for dataset card in parent directory.
|
||||
cardPath := filepath.Join(*inputDir, "..", "dataset_card.md")
|
||||
cardPath := filepath.Join(cfg.Input, "..", "dataset_card.md")
|
||||
if _, err := os.Stat(cardPath); err == nil {
|
||||
filesToUpload = append(filesToUpload, uploadEntry{cardPath, "README.md"})
|
||||
}
|
||||
|
||||
if len(filesToUpload) == 0 {
|
||||
fmt.Fprintln(os.Stderr, "error: no Parquet files found in input directory")
|
||||
os.Exit(1)
|
||||
return fmt.Errorf("no Parquet files found in input directory")
|
||||
}
|
||||
|
||||
if *dryRun {
|
||||
fmt.Printf("Dry run: would publish to %s\n", *repoID)
|
||||
if *public {
|
||||
if cfg.DryRun {
|
||||
fmt.Printf("Dry run: would publish to %s\n", cfg.Repo)
|
||||
if cfg.Public {
|
||||
fmt.Println(" Visibility: public")
|
||||
} else {
|
||||
fmt.Println(" Visibility: private")
|
||||
|
|
@ -91,19 +82,20 @@ func RunPublish(args []string) {
|
|||
sizeMB := float64(info.Size()) / 1024 / 1024
|
||||
fmt.Printf(" %s → %s (%.1f MB)\n", filepath.Base(f.local), f.remote, sizeMB)
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("Publishing to https://huggingface.co/datasets/%s\n", *repoID)
|
||||
fmt.Printf("Publishing to https://huggingface.co/datasets/%s\n", cfg.Repo)
|
||||
|
||||
for _, f := range filesToUpload {
|
||||
if err := uploadFileToHF(hfToken, *repoID, f.local, f.remote); err != nil {
|
||||
log.Fatalf("upload %s: %v", f.local, err)
|
||||
if err := uploadFileToHF(hfToken, cfg.Repo, f.local, f.remote); err != nil {
|
||||
return fmt.Errorf("upload %s: %w", f.local, err)
|
||||
}
|
||||
fmt.Printf(" Uploaded %s → %s\n", filepath.Base(f.local), f.remote)
|
||||
}
|
||||
|
||||
fmt.Printf("\nPublished to https://huggingface.co/datasets/%s\n", *repoID)
|
||||
fmt.Printf("\nPublished to https://huggingface.co/datasets/%s\n", cfg.Repo)
|
||||
return nil
|
||||
}
|
||||
|
||||
// uploadFileToHF uploads a file to a HuggingFace dataset repo via the Hub API.
|
||||
|
|
|
|||
|
|
@ -2,38 +2,31 @@ package lem
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// QueryOpts holds configuration for the query command.
|
||||
type QueryOpts struct {
|
||||
DB string // DuckDB database path (defaults to LEM_DB env)
|
||||
JSON bool // Output as JSON instead of table
|
||||
}
|
||||
|
||||
// RunQuery is the CLI entry point for the query command.
|
||||
// Runs ad-hoc SQL against the DuckDB database.
|
||||
func RunQuery(args []string) {
|
||||
fs := flag.NewFlagSet("query", flag.ExitOnError)
|
||||
dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
jsonOutput := fs.Bool("json", false, "Output as JSON instead of table")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
// The args slice contains the SQL query as positional arguments.
|
||||
func RunQuery(cfg QueryOpts, args []string) error {
|
||||
if cfg.DB == "" {
|
||||
cfg.DB = os.Getenv("LEM_DB")
|
||||
}
|
||||
if cfg.DB == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
if *dbPath == "" {
|
||||
*dbPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
if *dbPath == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
sql := strings.Join(fs.Args(), " ")
|
||||
sql := strings.Join(args, " ")
|
||||
if sql == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: SQL query required as positional argument")
|
||||
fmt.Fprintln(os.Stderr, " lem query --db path.duckdb \"SELECT * FROM golden_set LIMIT 5\"")
|
||||
fmt.Fprintln(os.Stderr, " lem query --db path.duckdb \"domain = 'ethics'\" (auto-wraps as WHERE clause)")
|
||||
os.Exit(1)
|
||||
return fmt.Errorf("SQL query required as positional argument\n lem query --db path.duckdb \"SELECT * FROM golden_set LIMIT 5\"\n lem query --db path.duckdb \"domain = 'ethics'\" (auto-wraps as WHERE clause)")
|
||||
}
|
||||
|
||||
// Auto-wrap non-SELECT queries as WHERE clauses.
|
||||
|
|
@ -43,21 +36,21 @@ func RunQuery(args []string) {
|
|||
sql = "SELECT * FROM golden_set WHERE " + sql + " LIMIT 20"
|
||||
}
|
||||
|
||||
db, err := OpenDB(*dbPath)
|
||||
db, err := OpenDB(cfg.DB)
|
||||
if err != nil {
|
||||
log.Fatalf("open db: %v", err)
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows, err := db.conn.Query(sql)
|
||||
if err != nil {
|
||||
log.Fatalf("query: %v", err)
|
||||
return fmt.Errorf("query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
log.Fatalf("columns: %v", err)
|
||||
return fmt.Errorf("columns: %w", err)
|
||||
}
|
||||
|
||||
var results []map[string]any
|
||||
|
|
@ -70,7 +63,7 @@ func RunQuery(args []string) {
|
|||
}
|
||||
|
||||
if err := rows.Scan(ptrs...); err != nil {
|
||||
log.Fatalf("scan: %v", err)
|
||||
return fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
|
||||
row := make(map[string]any)
|
||||
|
|
@ -85,17 +78,16 @@ func RunQuery(args []string) {
|
|||
results = append(results, row)
|
||||
}
|
||||
|
||||
if *jsonOutput {
|
||||
if cfg.JSON {
|
||||
enc := json.NewEncoder(os.Stdout)
|
||||
enc.SetIndent("", " ")
|
||||
enc.Encode(results)
|
||||
return
|
||||
return enc.Encode(results)
|
||||
}
|
||||
|
||||
// Table output.
|
||||
if len(results) == 0 {
|
||||
fmt.Println("(no results)")
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Calculate column widths.
|
||||
|
|
@ -149,4 +141,5 @@ func RunQuery(args []string) {
|
|||
}
|
||||
|
||||
fmt.Printf("\n(%d rows)\n", len(results))
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,46 +1,40 @@
|
|||
package lem
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ScoreOpts holds configuration for the score run command.
|
||||
type ScoreOpts struct {
|
||||
Input string
|
||||
Suites string
|
||||
JudgeModel string
|
||||
JudgeURL string
|
||||
Concurrency int
|
||||
Output string
|
||||
Resume bool
|
||||
}
|
||||
|
||||
// RunScore scores existing response files using a judge model.
|
||||
func RunScore(args []string) {
|
||||
fs := flag.NewFlagSet("score", flag.ExitOnError)
|
||||
|
||||
input := fs.String("input", "", "Input JSONL response file (required)")
|
||||
suites := fs.String("suites", "all", "Comma-separated suites or 'all'")
|
||||
judgeModel := fs.String("judge-model", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name")
|
||||
judgeURL := fs.String("judge-url", "http://10.69.69.108:8090", "Judge API URL")
|
||||
concurrency := fs.Int("concurrency", 4, "Max concurrent judge calls")
|
||||
output := fs.String("output", "scores.json", "Output score file path")
|
||||
resume := fs.Bool("resume", false, "Resume from existing output, skipping scored IDs")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
func RunScore(cfg ScoreOpts) error {
|
||||
if cfg.Input == "" {
|
||||
return fmt.Errorf("--input is required")
|
||||
}
|
||||
|
||||
if *input == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --input is required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
responses, err := ReadResponses(*input)
|
||||
responses, err := ReadResponses(cfg.Input)
|
||||
if err != nil {
|
||||
log.Fatalf("read responses: %v", err)
|
||||
return fmt.Errorf("read responses: %w", err)
|
||||
}
|
||||
log.Printf("loaded %d responses from %s", len(responses), *input)
|
||||
log.Printf("loaded %d responses from %s", len(responses), cfg.Input)
|
||||
|
||||
if *resume {
|
||||
if _, statErr := os.Stat(*output); statErr == nil {
|
||||
existing, readErr := ReadScorerOutput(*output)
|
||||
if cfg.Resume {
|
||||
if _, statErr := os.Stat(cfg.Output); statErr == nil {
|
||||
existing, readErr := ReadScorerOutput(cfg.Output)
|
||||
if readErr != nil {
|
||||
log.Fatalf("read existing scores for resume: %v", readErr)
|
||||
return fmt.Errorf("read existing scores for resume: %w", readErr)
|
||||
}
|
||||
|
||||
scored := make(map[string]bool)
|
||||
|
|
@ -62,25 +56,25 @@ func RunScore(args []string) {
|
|||
|
||||
if len(responses) == 0 {
|
||||
log.Println("all responses already scored, nothing to do")
|
||||
return
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client := NewClient(*judgeURL, *judgeModel)
|
||||
client := NewClient(cfg.JudgeURL, cfg.JudgeModel)
|
||||
client.MaxTokens = 512
|
||||
judge := NewJudge(client)
|
||||
engine := NewEngine(judge, *concurrency, *suites)
|
||||
engine := NewEngine(judge, cfg.Concurrency, cfg.Suites)
|
||||
|
||||
log.Printf("scoring with %s", engine)
|
||||
|
||||
perPrompt := engine.ScoreAll(responses)
|
||||
|
||||
if *resume {
|
||||
if _, statErr := os.Stat(*output); statErr == nil {
|
||||
existing, readErr := ReadScorerOutput(*output)
|
||||
if cfg.Resume {
|
||||
if _, statErr := os.Stat(cfg.Output); statErr == nil {
|
||||
existing, readErr := ReadScorerOutput(cfg.Output)
|
||||
if readErr != nil {
|
||||
log.Fatalf("re-read scores for merge: %v", readErr)
|
||||
return fmt.Errorf("re-read scores for merge: %w", readErr)
|
||||
}
|
||||
for model, scores := range existing.PerPrompt {
|
||||
perPrompt[model] = append(scores, perPrompt[model]...)
|
||||
|
|
@ -92,8 +86,8 @@ func RunScore(args []string) {
|
|||
|
||||
scorerOutput := &ScorerOutput{
|
||||
Metadata: Metadata{
|
||||
JudgeModel: *judgeModel,
|
||||
JudgeURL: *judgeURL,
|
||||
JudgeModel: cfg.JudgeModel,
|
||||
JudgeURL: cfg.JudgeURL,
|
||||
ScoredAt: time.Now().UTC(),
|
||||
ScorerVersion: "1.0.0",
|
||||
Suites: engine.SuiteNames(),
|
||||
|
|
@ -102,71 +96,69 @@ func RunScore(args []string) {
|
|||
PerPrompt: perPrompt,
|
||||
}
|
||||
|
||||
if err := WriteScores(*output, scorerOutput); err != nil {
|
||||
log.Fatalf("write scores: %v", err)
|
||||
if err := WriteScores(cfg.Output, scorerOutput); err != nil {
|
||||
return fmt.Errorf("write scores: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("wrote scores to %s", *output)
|
||||
log.Printf("wrote scores to %s", cfg.Output)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProbeOpts holds configuration for the probe command.
|
||||
type ProbeOpts struct {
|
||||
Model string
|
||||
TargetURL string
|
||||
ProbesFile string
|
||||
Suites string
|
||||
JudgeModel string
|
||||
JudgeURL string
|
||||
Concurrency int
|
||||
Output string
|
||||
}
|
||||
|
||||
// RunProbe generates responses from a target model and scores them.
|
||||
func RunProbe(args []string) {
|
||||
fs := flag.NewFlagSet("probe", flag.ExitOnError)
|
||||
|
||||
model := fs.String("model", "", "Target model name (required)")
|
||||
targetURL := fs.String("target-url", "", "Target model API URL (defaults to judge-url)")
|
||||
probesFile := fs.String("probes", "", "Custom probes JSONL file (uses built-in content probes if not specified)")
|
||||
suites := fs.String("suites", "all", "Comma-separated suites or 'all'")
|
||||
judgeModel := fs.String("judge-model", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name")
|
||||
judgeURL := fs.String("judge-url", "http://10.69.69.108:8090", "Judge API URL")
|
||||
concurrency := fs.Int("concurrency", 4, "Max concurrent judge calls")
|
||||
output := fs.String("output", "scores.json", "Output score file path")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
func RunProbe(cfg ProbeOpts) error {
|
||||
if cfg.Model == "" {
|
||||
return fmt.Errorf("--model is required")
|
||||
}
|
||||
|
||||
if *model == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --model is required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
targetURL := cfg.TargetURL
|
||||
if targetURL == "" {
|
||||
targetURL = cfg.JudgeURL
|
||||
}
|
||||
|
||||
if *targetURL == "" {
|
||||
*targetURL = *judgeURL
|
||||
}
|
||||
|
||||
targetClient := NewClient(*targetURL, *model)
|
||||
targetClient := NewClient(targetURL, cfg.Model)
|
||||
targetClient.MaxTokens = 1024
|
||||
judgeClient := NewClient(*judgeURL, *judgeModel)
|
||||
judgeClient := NewClient(cfg.JudgeURL, cfg.JudgeModel)
|
||||
judgeClient.MaxTokens = 512
|
||||
judge := NewJudge(judgeClient)
|
||||
engine := NewEngine(judge, *concurrency, *suites)
|
||||
engine := NewEngine(judge, cfg.Concurrency, cfg.Suites)
|
||||
prober := NewProber(targetClient, engine)
|
||||
|
||||
var scorerOutput *ScorerOutput
|
||||
var err error
|
||||
|
||||
if *probesFile != "" {
|
||||
probes, readErr := ReadResponses(*probesFile)
|
||||
if cfg.ProbesFile != "" {
|
||||
probes, readErr := ReadResponses(cfg.ProbesFile)
|
||||
if readErr != nil {
|
||||
log.Fatalf("read probes: %v", readErr)
|
||||
return fmt.Errorf("read probes: %w", readErr)
|
||||
}
|
||||
log.Printf("loaded %d custom probes from %s", len(probes), *probesFile)
|
||||
log.Printf("loaded %d custom probes from %s", len(probes), cfg.ProbesFile)
|
||||
|
||||
scorerOutput, err = prober.ProbeModel(probes, *model)
|
||||
scorerOutput, err = prober.ProbeModel(probes, cfg.Model)
|
||||
} else {
|
||||
log.Printf("using %d built-in content probes", len(ContentProbes))
|
||||
scorerOutput, err = prober.ProbeContent(*model)
|
||||
scorerOutput, err = prober.ProbeContent(cfg.Model)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("probe: %v", err)
|
||||
return fmt.Errorf("probe: %w", err)
|
||||
}
|
||||
|
||||
if writeErr := WriteScores(*output, scorerOutput); writeErr != nil {
|
||||
log.Fatalf("write scores: %v", writeErr)
|
||||
if writeErr := WriteScores(cfg.Output, scorerOutput); writeErr != nil {
|
||||
return fmt.Errorf("write scores: %w", writeErr)
|
||||
}
|
||||
|
||||
log.Printf("wrote scores to %s", *output)
|
||||
log.Printf("wrote scores to %s", cfg.Output)
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,48 +1,47 @@
|
|||
package lem
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SeedInfluxOpts holds configuration for the seed-influx command.
|
||||
type SeedInfluxOpts struct {
|
||||
DB string // DuckDB database path (defaults to LEM_DB env)
|
||||
InfluxURL string // InfluxDB URL
|
||||
InfluxDB string // InfluxDB database name
|
||||
Force bool // Re-seed even if InfluxDB already has data
|
||||
BatchSize int // Lines per InfluxDB write batch
|
||||
}
|
||||
|
||||
// RunSeedInflux is the CLI entry point for the seed-influx command.
|
||||
// Seeds InfluxDB golden_gen measurement from DuckDB golden_set data.
|
||||
// One-time migration tool for bootstrapping InfluxDB from existing data.
|
||||
func RunSeedInflux(args []string) {
|
||||
fs := flag.NewFlagSet("seed-influx", flag.ExitOnError)
|
||||
dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
influxURL := fs.String("influx", "", "InfluxDB URL")
|
||||
influxDB := fs.String("influx-db", "", "InfluxDB database name")
|
||||
force := fs.Bool("force", false, "Re-seed even if InfluxDB already has data")
|
||||
batchSize := fs.Int("batch-size", 500, "Lines per InfluxDB write batch")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
func RunSeedInflux(cfg SeedInfluxOpts) error {
|
||||
if cfg.DB == "" {
|
||||
cfg.DB = os.Getenv("LEM_DB")
|
||||
}
|
||||
if cfg.DB == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
if *dbPath == "" {
|
||||
*dbPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
if *dbPath == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required")
|
||||
os.Exit(1)
|
||||
if cfg.BatchSize <= 0 {
|
||||
cfg.BatchSize = 500
|
||||
}
|
||||
|
||||
db, err := OpenDB(*dbPath)
|
||||
db, err := OpenDB(cfg.DB)
|
||||
if err != nil {
|
||||
log.Fatalf("open db: %v", err)
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var total int
|
||||
if err := db.conn.QueryRow("SELECT count(*) FROM golden_set").Scan(&total); err != nil {
|
||||
log.Fatalf("No golden_set table. Run ingest first.")
|
||||
return fmt.Errorf("no golden_set table — run ingest first")
|
||||
}
|
||||
|
||||
influx := NewInfluxClient(*influxURL, *influxDB)
|
||||
influx := NewInfluxClient(cfg.InfluxURL, cfg.InfluxDB)
|
||||
|
||||
// Check existing count in InfluxDB.
|
||||
existing := 0
|
||||
|
|
@ -55,9 +54,9 @@ func RunSeedInflux(args []string) {
|
|||
|
||||
fmt.Printf("DuckDB has %d records, InfluxDB golden_gen has %d\n", total, existing)
|
||||
|
||||
if existing >= total && !*force {
|
||||
if existing >= total && !cfg.Force {
|
||||
fmt.Println("InfluxDB already has all records. Use --force to re-seed.")
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read all rows.
|
||||
|
|
@ -66,7 +65,7 @@ func RunSeedInflux(args []string) {
|
|||
FROM golden_set ORDER BY idx
|
||||
`)
|
||||
if err != nil {
|
||||
log.Fatalf("query golden_set: %v", err)
|
||||
return fmt.Errorf("query golden_set: %w", err)
|
||||
}
|
||||
defer dbRows.Close()
|
||||
|
||||
|
|
@ -79,7 +78,7 @@ func RunSeedInflux(args []string) {
|
|||
var genTime float64
|
||||
|
||||
if err := dbRows.Scan(&idx, &seedID, &domain, &voice, &genTime, &charCount); err != nil {
|
||||
log.Fatalf("scan: %v", err)
|
||||
return fmt.Errorf("scan row: %w", err)
|
||||
}
|
||||
|
||||
sid := strings.ReplaceAll(seedID, `"`, `\"`)
|
||||
|
|
@ -87,9 +86,9 @@ func RunSeedInflux(args []string) {
|
|||
idx, escapeLp(domain), escapeLp(voice), sid, genTime, charCount)
|
||||
lines = append(lines, lp)
|
||||
|
||||
if len(lines) >= *batchSize {
|
||||
if len(lines) >= cfg.BatchSize {
|
||||
if err := influx.WriteLp(lines); err != nil {
|
||||
log.Fatalf("write batch at %d: %v", written, err)
|
||||
return fmt.Errorf("write batch at %d: %w", written, err)
|
||||
}
|
||||
written += len(lines)
|
||||
lines = lines[:0]
|
||||
|
|
@ -102,10 +101,11 @@ func RunSeedInflux(args []string) {
|
|||
|
||||
if len(lines) > 0 {
|
||||
if err := influx.WriteLp(lines); err != nil {
|
||||
log.Fatalf("flush: %v", err)
|
||||
return fmt.Errorf("flush: %w", err)
|
||||
}
|
||||
written += len(lines)
|
||||
}
|
||||
|
||||
fmt.Printf("Seeded %d golden_gen records into InfluxDB\n", written)
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,48 +1,43 @@
|
|||
package lem
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// runStatus parses CLI flags and prints training/generation status from InfluxDB.
|
||||
func RunStatus(args []string) {
|
||||
fs := flag.NewFlagSet("status", flag.ExitOnError)
|
||||
|
||||
influxURL := fs.String("influx", "", "InfluxDB URL (default http://10.69.69.165:8181)")
|
||||
influxDB := fs.String("influx-db", "", "InfluxDB database name (default training)")
|
||||
dbPath := fs.String("db", "", "DuckDB database path (shows table counts)")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
}
|
||||
// StatusOpts holds configuration for the status command.
|
||||
type StatusOpts struct {
|
||||
Influx string // InfluxDB URL (default http://10.69.69.165:8181)
|
||||
InfluxDB string // InfluxDB database name (default training)
|
||||
DB string // DuckDB database path (shows table counts)
|
||||
}
|
||||
|
||||
// RunStatus prints training/generation status from InfluxDB and optional DuckDB table counts.
|
||||
func RunStatus(cfg StatusOpts) error {
|
||||
// Check LEM_DB env as default for --db.
|
||||
if *dbPath == "" {
|
||||
*dbPath = os.Getenv("LEM_DB")
|
||||
if cfg.DB == "" {
|
||||
cfg.DB = os.Getenv("LEM_DB")
|
||||
}
|
||||
|
||||
influx := NewInfluxClient(*influxURL, *influxDB)
|
||||
influx := NewInfluxClient(cfg.Influx, cfg.InfluxDB)
|
||||
|
||||
if err := printStatus(influx, os.Stdout); err != nil {
|
||||
log.Fatalf("status: %v", err)
|
||||
return fmt.Errorf("status: %w", err)
|
||||
}
|
||||
|
||||
// If DuckDB path provided, show table counts.
|
||||
if *dbPath != "" {
|
||||
db, err := OpenDB(*dbPath)
|
||||
if cfg.DB != "" {
|
||||
db, err := OpenDB(cfg.DB)
|
||||
if err != nil {
|
||||
log.Fatalf("open db: %v", err)
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
counts, err := db.TableCounts()
|
||||
if err != nil {
|
||||
log.Fatalf("table counts: %v", err)
|
||||
return fmt.Errorf("table counts: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stdout)
|
||||
|
|
@ -55,6 +50,8 @@ func RunStatus(args []string) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// trainingRow holds deduplicated training status + loss for a single model.
|
||||
|
|
|
|||
|
|
@ -1,39 +1,36 @@
|
|||
package lem
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// TierScoreOpts holds configuration for the tier-score command.
|
||||
type TierScoreOpts struct {
|
||||
DBPath string
|
||||
Tier int
|
||||
Limit int
|
||||
}
|
||||
|
||||
// RunTierScore is the CLI entry point for the tier-score command.
|
||||
// Scores expansion responses using tiered quality assessment:
|
||||
// - Tier 1: Heuristic regex scoring (fast, no API)
|
||||
// - Tier 2: LEM self-judge (requires trained model)
|
||||
// - Tier 3: External judge (reserved for borderline cases)
|
||||
func RunTierScore(args []string) {
|
||||
fs := flag.NewFlagSet("tier-score", flag.ExitOnError)
|
||||
dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)")
|
||||
tier := fs.Int("tier", 1, "Scoring tier: 1=heuristic, 2=LEM judge, 3=external")
|
||||
limit := fs.Int("limit", 0, "Max items to score (0=all)")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
func RunTierScore(cfg TierScoreOpts) error {
|
||||
dbPath := cfg.DBPath
|
||||
if dbPath == "" {
|
||||
dbPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
if dbPath == "" {
|
||||
return fmt.Errorf("--db or LEM_DB required")
|
||||
}
|
||||
|
||||
if *dbPath == "" {
|
||||
*dbPath = os.Getenv("LEM_DB")
|
||||
}
|
||||
if *dbPath == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
db, err := OpenDBReadWrite(*dbPath)
|
||||
db, err := OpenDBReadWrite(dbPath)
|
||||
if err != nil {
|
||||
log.Fatalf("open db: %v", err)
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
|
|
@ -54,18 +51,20 @@ func RunTierScore(args []string) {
|
|||
)
|
||||
`)
|
||||
|
||||
if *tier >= 1 {
|
||||
runHeuristicTier(db, *limit)
|
||||
if cfg.Tier >= 1 {
|
||||
runHeuristicTier(db, cfg.Limit)
|
||||
}
|
||||
|
||||
if *tier >= 2 {
|
||||
if cfg.Tier >= 2 {
|
||||
fmt.Println("\nTier 2 (LEM judge): not yet available — needs trained LEM-27B model")
|
||||
fmt.Println(" Will score: sovereignty, ethical_depth, creative, self_concept (1-10 each)")
|
||||
}
|
||||
|
||||
if *tier >= 3 {
|
||||
if cfg.Tier >= 3 {
|
||||
fmt.Println("\nTier 3 (External judge): reserved for borderline cases")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runHeuristicTier(db *DB, limit int) {
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package lem
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
|
|
@ -14,7 +13,25 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// workerConfig holds the worker's runtime configuration.
|
||||
// WorkerOpts holds the worker's runtime configuration.
|
||||
type WorkerOpts struct {
|
||||
APIBase string // LEM API base URL
|
||||
WorkerID string // Worker ID (machine UUID)
|
||||
Name string // Worker display name
|
||||
APIKey string // API key (or use LEM_API_KEY env)
|
||||
GPUType string // GPU type (e.g. 'RTX 3090')
|
||||
VRAMGb int // GPU VRAM in GB
|
||||
Languages string // Comma-separated language codes (e.g. 'en,yo,sw')
|
||||
Models string // Comma-separated supported model names
|
||||
InferURL string // Local inference endpoint
|
||||
TaskType string // Filter by task type (expand, score, translate, seed)
|
||||
BatchSize int // Number of tasks to fetch per poll
|
||||
PollInterval time.Duration // Poll interval
|
||||
OneShot bool // Process one batch and exit
|
||||
DryRun bool // Fetch tasks but don't run inference
|
||||
}
|
||||
|
||||
// workerConfig is the internal runtime config populated from WorkerOpts.
|
||||
type workerConfig struct {
|
||||
apiBase string
|
||||
workerID string
|
||||
|
|
@ -50,35 +67,63 @@ type apiTask struct {
|
|||
}
|
||||
|
||||
// RunWorker is the CLI entry point for `lem worker`.
|
||||
func RunWorker(args []string) {
|
||||
fs := flag.NewFlagSet("worker", flag.ExitOnError)
|
||||
|
||||
cfg := workerConfig{}
|
||||
var langs, mods string
|
||||
|
||||
fs.StringVar(&cfg.apiBase, "api", envOr("LEM_API", "https://infer.lthn.ai"), "LEM API base URL")
|
||||
fs.StringVar(&cfg.workerID, "id", envOr("LEM_WORKER_ID", machineID()), "Worker ID (machine UUID)")
|
||||
fs.StringVar(&cfg.name, "name", envOr("LEM_WORKER_NAME", hostname()), "Worker display name")
|
||||
fs.StringVar(&cfg.apiKey, "key", envOr("LEM_API_KEY", ""), "API key (or use LEM_API_KEY env)")
|
||||
fs.StringVar(&cfg.gpuType, "gpu", envOr("LEM_GPU", ""), "GPU type (e.g. 'RTX 3090')")
|
||||
fs.IntVar(&cfg.vramGb, "vram", intEnvOr("LEM_VRAM_GB", 0), "GPU VRAM in GB")
|
||||
fs.StringVar(&langs, "languages", envOr("LEM_LANGUAGES", ""), "Comma-separated language codes (e.g. 'en,yo,sw')")
|
||||
fs.StringVar(&mods, "models", envOr("LEM_MODELS", ""), "Comma-separated supported model names")
|
||||
fs.StringVar(&cfg.inferURL, "infer", envOr("LEM_INFER_URL", "http://localhost:8090"), "Local inference endpoint")
|
||||
fs.StringVar(&cfg.taskType, "type", "", "Filter by task type (expand, score, translate, seed)")
|
||||
fs.IntVar(&cfg.batchSize, "batch", 5, "Number of tasks to fetch per poll")
|
||||
dur := fs.Duration("poll", 30*time.Second, "Poll interval")
|
||||
fs.BoolVar(&cfg.oneShot, "one-shot", false, "Process one batch and exit")
|
||||
fs.BoolVar(&cfg.dryRun, "dry-run", false, "Fetch tasks but don't run inference")
|
||||
|
||||
fs.Parse(args)
|
||||
cfg.pollInterval = *dur
|
||||
|
||||
if langs != "" {
|
||||
cfg.languages = splitComma(langs)
|
||||
func RunWorker(opts WorkerOpts) error {
|
||||
// Apply env-var fallbacks for fields not set by flags.
|
||||
if opts.APIBase == "" {
|
||||
opts.APIBase = envOr("LEM_API", "https://infer.lthn.ai")
|
||||
}
|
||||
if mods != "" {
|
||||
cfg.models = splitComma(mods)
|
||||
if opts.WorkerID == "" {
|
||||
opts.WorkerID = envOr("LEM_WORKER_ID", machineID())
|
||||
}
|
||||
if opts.Name == "" {
|
||||
opts.Name = envOr("LEM_WORKER_NAME", hostname())
|
||||
}
|
||||
if opts.APIKey == "" {
|
||||
opts.APIKey = envOr("LEM_API_KEY", "")
|
||||
}
|
||||
if opts.GPUType == "" {
|
||||
opts.GPUType = envOr("LEM_GPU", "")
|
||||
}
|
||||
if opts.VRAMGb == 0 {
|
||||
opts.VRAMGb = intEnvOr("LEM_VRAM_GB", 0)
|
||||
}
|
||||
if opts.Languages == "" {
|
||||
opts.Languages = envOr("LEM_LANGUAGES", "")
|
||||
}
|
||||
if opts.Models == "" {
|
||||
opts.Models = envOr("LEM_MODELS", "")
|
||||
}
|
||||
if opts.InferURL == "" {
|
||||
opts.InferURL = envOr("LEM_INFER_URL", "http://localhost:8090")
|
||||
}
|
||||
if opts.BatchSize <= 0 {
|
||||
opts.BatchSize = 5
|
||||
}
|
||||
if opts.PollInterval <= 0 {
|
||||
opts.PollInterval = 30 * time.Second
|
||||
}
|
||||
|
||||
// Build internal config.
|
||||
cfg := workerConfig{
|
||||
apiBase: opts.APIBase,
|
||||
workerID: opts.WorkerID,
|
||||
name: opts.Name,
|
||||
apiKey: opts.APIKey,
|
||||
gpuType: opts.GPUType,
|
||||
vramGb: opts.VRAMGb,
|
||||
inferURL: opts.InferURL,
|
||||
taskType: opts.TaskType,
|
||||
batchSize: opts.BatchSize,
|
||||
pollInterval: opts.PollInterval,
|
||||
oneShot: opts.OneShot,
|
||||
dryRun: opts.DryRun,
|
||||
}
|
||||
|
||||
if opts.Languages != "" {
|
||||
cfg.languages = splitComma(opts.Languages)
|
||||
}
|
||||
if opts.Models != "" {
|
||||
cfg.models = splitComma(opts.Models)
|
||||
}
|
||||
|
||||
if cfg.apiKey == "" {
|
||||
|
|
@ -86,7 +131,7 @@ func RunWorker(args []string) {
|
|||
}
|
||||
|
||||
if cfg.apiKey == "" {
|
||||
log.Fatal("No API key. Set LEM_API_KEY or run: lem worker-auth")
|
||||
return fmt.Errorf("no API key — set LEM_API_KEY or run: lem worker-auth")
|
||||
}
|
||||
|
||||
log.Printf("LEM Worker starting")
|
||||
|
|
@ -102,7 +147,7 @@ func RunWorker(args []string) {
|
|||
|
||||
// Register with the API.
|
||||
if err := workerRegister(&cfg); err != nil {
|
||||
log.Fatalf("Registration failed: %v", err)
|
||||
return fmt.Errorf("registration failed: %w", err)
|
||||
}
|
||||
log.Println("Registered with LEM API")
|
||||
|
||||
|
|
@ -112,7 +157,7 @@ func RunWorker(args []string) {
|
|||
|
||||
if cfg.oneShot {
|
||||
log.Printf("One-shot mode: processed %d tasks, exiting", processed)
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
if processed == 0 {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue