diff --git a/cmd/lemcmd/data.go b/cmd/lemcmd/data.go index 04192a9..cef0326 100644 --- a/cmd/lemcmd/data.go +++ b/cmd/lemcmd/data.go @@ -8,10 +8,54 @@ import ( func addDataCommands(root *cli.Command) { dataGroup := cli.NewGroup("data", "Data management commands", "Import, consolidate, normalise, and approve training data.") - dataGroup.AddCommand(passthrough("import-all", "Import ALL LEM data into DuckDB from M3", lem.RunImport)) - dataGroup.AddCommand(passthrough("consolidate", "Pull worker JSONLs from M3, merge, deduplicate", lem.RunConsolidate)) - dataGroup.AddCommand(passthrough("normalize", "Normalise seeds to deduplicated expansion prompts", lem.RunNormalize)) - dataGroup.AddCommand(passthrough("approve", "Filter scored expansions to training JSONL", lem.RunApprove)) + // import-all — Import ALL LEM data into DuckDB from M3. + var importCfg lem.ImportOpts + importCmd := cli.NewCommand("import-all", "Import ALL LEM data into DuckDB from M3", "", + func(cmd *cli.Command, args []string) error { + return lem.RunImport(importCfg) + }, + ) + cli.StringFlag(importCmd, &importCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)") + cli.BoolFlag(importCmd, &importCfg.SkipM3, "skip-m3", "", false, "Skip pulling data from M3") + cli.StringFlag(importCmd, &importCfg.DataDir, "data-dir", "", "", "Local data directory (defaults to db directory)") + dataGroup.AddCommand(importCmd) + + // consolidate — Pull worker JSONLs from M3, merge, deduplicate. 
+ var consolidateCfg lem.ConsolidateOpts + consolidateCmd := cli.NewCommand("consolidate", "Pull worker JSONLs from M3, merge, deduplicate", "", + func(cmd *cli.Command, args []string) error { + return lem.RunConsolidate(consolidateCfg) + }, + ) + cli.StringFlag(consolidateCmd, &consolidateCfg.Host, "host", "", "m3", "SSH host for remote files") + cli.StringFlag(consolidateCmd, &consolidateCfg.Remote, "remote", "", "/Volumes/Data/lem/responses", "Remote directory for JSONL files") + cli.StringFlag(consolidateCmd, &consolidateCfg.Pattern, "pattern", "", "gold*.jsonl", "File glob pattern") + cli.StringFlag(consolidateCmd, &consolidateCfg.OutputDir, "output", "o", "", "Output directory (defaults to ./responses)") + cli.StringFlag(consolidateCmd, &consolidateCfg.Merged, "merged", "", "", "Merged output file (defaults to gold-merged.jsonl in output dir)") + dataGroup.AddCommand(consolidateCmd) + + // normalize — Normalise seeds to deduplicated expansion prompts. + var normalizeCfg lem.NormalizeOpts + normalizeCmd := cli.NewCommand("normalize", "Normalise seeds to deduplicated expansion prompts", "", + func(cmd *cli.Command, args []string) error { + return lem.RunNormalize(normalizeCfg) + }, + ) + cli.StringFlag(normalizeCmd, &normalizeCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)") + cli.IntFlag(normalizeCmd, &normalizeCfg.MinLen, "min-length", "", 50, "Minimum prompt length in characters") + dataGroup.AddCommand(normalizeCmd) + + // approve — Filter scored expansions to training JSONL. 
+ var approveCfg lem.ApproveOpts + approveCmd := cli.NewCommand("approve", "Filter scored expansions to training JSONL", "", + func(cmd *cli.Command, args []string) error { + return lem.RunApprove(approveCfg) + }, + ) + cli.StringFlag(approveCmd, &approveCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)") + cli.StringFlag(approveCmd, &approveCfg.Output, "output", "o", "", "Output JSONL file (defaults to expansion-approved.jsonl in db dir)") + cli.Float64Flag(approveCmd, &approveCfg.Threshold, "threshold", "", 6.0, "Min judge average to approve") + dataGroup.AddCommand(approveCmd) root.AddCommand(dataGroup) } diff --git a/cmd/lemcmd/export.go b/cmd/lemcmd/export.go index b1f1fba..56f538f 100644 --- a/cmd/lemcmd/export.go +++ b/cmd/lemcmd/export.go @@ -8,10 +8,60 @@ import ( func addExportCommands(root *cli.Command) { exportGroup := cli.NewGroup("export", "Export and publish commands", "Export training data to JSONL, Parquet, HuggingFace, and PEFT formats.") - exportGroup.AddCommand(passthrough("jsonl", "Export golden set to training-format JSONL splits", lem.RunExport)) - exportGroup.AddCommand(passthrough("parquet", "Export JSONL training splits to Parquet", lem.RunParquet)) - exportGroup.AddCommand(passthrough("publish", "Push Parquet files to HuggingFace dataset repo", lem.RunPublish)) - exportGroup.AddCommand(passthrough("convert", "Convert MLX LoRA adapter to PEFT format", lem.RunConvert)) + // jsonl — export golden set to training-format JSONL splits. 
+ var exportCfg lem.ExportOpts + jsonlCmd := cli.NewCommand("jsonl", "Export golden set to training-format JSONL splits", "", + func(cmd *cli.Command, args []string) error { + return lem.RunExport(exportCfg) + }, + ) + cli.StringFlag(jsonlCmd, &exportCfg.DBPath, "db", "", "", "DuckDB database path (primary source; defaults to LEM_DB env)") + cli.StringFlag(jsonlCmd, &exportCfg.Input, "input", "i", "", "Input golden set JSONL file (fallback if --db not set)") + cli.StringFlag(jsonlCmd, &exportCfg.OutputDir, "output-dir", "o", "", "Output directory for training files (required)") + cli.IntFlag(jsonlCmd, &exportCfg.TrainPct, "train-pct", "", 90, "Training set percentage") + cli.IntFlag(jsonlCmd, &exportCfg.ValidPct, "valid-pct", "", 5, "Validation set percentage") + cli.IntFlag(jsonlCmd, &exportCfg.TestPct, "test-pct", "", 5, "Test set percentage") + cli.Int64Flag(jsonlCmd, &exportCfg.Seed, "seed", "", 42, "Random seed for shuffling") + cli.IntFlag(jsonlCmd, &exportCfg.MinChars, "min-chars", "", 50, "Minimum response character count") + exportGroup.AddCommand(jsonlCmd) + + // parquet — export JSONL training splits to Parquet. + var parquetCfg lem.ParquetOpts + parquetCmd := cli.NewCommand("parquet", "Export JSONL training splits to Parquet", "", + func(cmd *cli.Command, args []string) error { + return lem.RunParquet(parquetCfg) + }, + ) + cli.StringFlag(parquetCmd, &parquetCfg.Input, "input", "i", "", "Directory containing train.jsonl, valid.jsonl, test.jsonl (required)") + cli.StringFlag(parquetCmd, &parquetCfg.Output, "output", "o", "", "Output directory for Parquet files (defaults to input/parquet)") + exportGroup.AddCommand(parquetCmd) + + // publish — push Parquet files to HuggingFace dataset repo. 
+ var publishCfg lem.PublishOpts + publishCmd := cli.NewCommand("publish", "Push Parquet files to HuggingFace dataset repo", "", + func(cmd *cli.Command, args []string) error { + return lem.RunPublish(publishCfg) + }, + ) + cli.StringFlag(publishCmd, &publishCfg.Input, "input", "i", "", "Directory containing Parquet files (required)") + cli.StringFlag(publishCmd, &publishCfg.Repo, "repo", "", "lthn/LEM-golden-set", "HuggingFace dataset repo ID") + cli.BoolFlag(publishCmd, &publishCfg.Public, "public", "", false, "Make dataset public") + cli.StringFlag(publishCmd, &publishCfg.Token, "token", "", "", "HuggingFace API token (defaults to HF_TOKEN env)") + cli.BoolFlag(publishCmd, &publishCfg.DryRun, "dry-run", "", false, "Show what would be uploaded without uploading") + exportGroup.AddCommand(publishCmd) + + // convert — convert MLX LoRA adapter to PEFT format. + var convertCfg lem.ConvertOpts + convertCmd := cli.NewCommand("convert", "Convert MLX LoRA adapter to PEFT format", "", + func(cmd *cli.Command, args []string) error { + return lem.RunConvert(convertCfg) + }, + ) + cli.StringFlag(convertCmd, &convertCfg.Input, "input", "i", "", "Path to MLX .safetensors file (required)") + cli.StringFlag(convertCmd, &convertCfg.Config, "config", "c", "", "Path to MLX adapter_config.json (required)") + cli.StringFlag(convertCmd, &convertCfg.Output, "output", "o", "./peft_output", "Output directory for PEFT adapter") + cli.StringFlag(convertCmd, &convertCfg.BaseModel, "base-model", "", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "HuggingFace base model ID") + exportGroup.AddCommand(convertCmd) root.AddCommand(exportGroup) } diff --git a/cmd/lemcmd/gen.go b/cmd/lemcmd/gen.go index 87584ae..da48638 100644 --- a/cmd/lemcmd/gen.go +++ b/cmd/lemcmd/gen.go @@ -8,9 +8,65 @@ import ( func addGenCommands(root *cli.Command) { genGroup := cli.NewGroup("gen", "Generation commands", "Distill, expand, and generate training data.") - genGroup.AddCommand(passthrough("distill", "Native Metal 
distillation (go-mlx + grammar scoring)", lem.RunDistill)) - genGroup.AddCommand(passthrough("expand", "Generate expansion responses via trained LEM model", lem.RunExpand)) - genGroup.AddCommand(passthrough("conv", "Generate conversational training data (calm phase)", lem.RunConv)) + // distill — native Metal distillation with grammar scoring. + var distillCfg lem.DistillOpts + distillCmd := cli.NewCommand("distill", "Native Metal distillation (go-mlx + grammar scoring)", "", + func(cmd *cli.Command, args []string) error { + return lem.RunDistill(distillCfg) + }, + ) + cli.StringFlag(distillCmd, &distillCfg.Model, "model", "m", "gemma3/27b", "Model config path (relative to .core/ai/models/)") + cli.StringFlag(distillCmd, &distillCfg.Probes, "probes", "p", "", "Probe set name from probes.yaml, or path to JSON file") + cli.StringFlag(distillCmd, &distillCfg.Output, "output", "o", "", "Output JSONL path (defaults to model training dir)") + cli.IntFlag(distillCmd, &distillCfg.Lesson, "lesson", "", -1, "Lesson number to append to (defaults to probe set phase)") + cli.Float64Flag(distillCmd, &distillCfg.MinScore, "min-score", "", 0, "Min grammar composite (0 = use ai.yaml default)") + cli.IntFlag(distillCmd, &distillCfg.Runs, "runs", "r", 0, "Generations per probe (0 = use ai.yaml default)") + cli.BoolFlag(distillCmd, &distillCfg.DryRun, "dry-run", "", false, "Show plan and exit without generating") + cli.StringFlag(distillCmd, &distillCfg.Root, "root", "", ".", "Project root (for .core/ai/ config)") + cli.IntFlag(distillCmd, &distillCfg.CacheLimit, "cache-limit", "", 0, "Metal cache limit in GB (0 = use ai.yaml default)") + cli.IntFlag(distillCmd, &distillCfg.MemLimit, "mem-limit", "", 0, "Metal memory limit in GB (0 = use ai.yaml default)") + genGroup.AddCommand(distillCmd) + + // expand — generate expansion responses via trained LEM model. 
+ var expandCfg lem.ExpandOpts + expandCmd := cli.NewCommand("expand", "Generate expansion responses via trained LEM model", "", + func(cmd *cli.Command, args []string) error { + return lem.RunExpand(expandCfg) + }, + ) + cli.StringFlag(expandCmd, &expandCfg.Model, "model", "m", "", "Model name for generation (required)") + cli.StringFlag(expandCmd, &expandCfg.DB, "db", "", "", "DuckDB database path (primary prompt source)") + cli.StringFlag(expandCmd, &expandCfg.Prompts, "prompts", "p", "", "Input JSONL file with expansion prompts (fallback)") + cli.StringFlag(expandCmd, &expandCfg.APIURL, "api-url", "", "http://10.69.69.108:8090", "OpenAI-compatible API URL") + cli.StringFlag(expandCmd, &expandCfg.Worker, "worker", "", "", "Worker hostname (defaults to os.Hostname())") + cli.IntFlag(expandCmd, &expandCfg.Limit, "limit", "", 0, "Max prompts to process (0 = all)") + cli.StringFlag(expandCmd, &expandCfg.Output, "output", "o", ".", "Output directory for JSONL files") + cli.StringFlag(expandCmd, &expandCfg.Influx, "influx", "", "", "InfluxDB URL (default http://10.69.69.165:8181)") + cli.StringFlag(expandCmd, &expandCfg.InfluxDB, "influx-db", "", "", "InfluxDB database name (default training)") + cli.BoolFlag(expandCmd, &expandCfg.DryRun, "dry-run", "", false, "Print plan and exit without generating") + genGroup.AddCommand(expandCmd) + + // conv — generate conversational training data (calm phase). 
+ var convCfg lem.ConvOpts + convCmd := cli.NewCommand("conv", "Generate conversational training data (calm phase)", "", + func(cmd *cli.Command, args []string) error { + return lem.RunConv(convCfg) + }, + ) + cli.StringFlag(convCmd, &convCfg.OutputDir, "output-dir", "o", "", "Output directory for training files (required)") + cli.StringFlag(convCmd, &convCfg.Extra, "extra", "", "", "Additional conversations JSONL file (multi-turn format)") + cli.StringFlag(convCmd, &convCfg.Golden, "golden", "", "", "Golden set JSONL to convert to single-turn conversations") + cli.StringFlag(convCmd, &convCfg.DB, "db", "", "", "DuckDB database path for golden set (alternative to --golden)") + cli.IntFlag(convCmd, &convCfg.TrainPct, "train-pct", "", 80, "Training set percentage") + cli.IntFlag(convCmd, &convCfg.ValidPct, "valid-pct", "", 10, "Validation set percentage") + cli.IntFlag(convCmd, &convCfg.TestPct, "test-pct", "", 10, "Test set percentage") + cli.Int64Flag(convCmd, &convCfg.Seed, "seed", "", 42, "Random seed for shuffling") + cli.IntFlag(convCmd, &convCfg.MinChars, "min-chars", "", 50, "Minimum response chars for golden set conversion") + cli.BoolFlag(convCmd, &convCfg.NoBuiltin, "no-builtin", "", false, "Exclude built-in seed conversations") + cli.StringFlag(convCmd, &convCfg.Influx, "influx", "", "", "InfluxDB URL for progress reporting") + cli.StringFlag(convCmd, &convCfg.InfluxDB, "influx-db", "", "", "InfluxDB database name") + cli.StringFlag(convCmd, &convCfg.Worker, "worker", "", "", "Worker hostname for InfluxDB reporting") + genGroup.AddCommand(convCmd) root.AddCommand(genGroup) } diff --git a/cmd/lemcmd/infra.go b/cmd/lemcmd/infra.go index d55e5a6..2a4681e 100644 --- a/cmd/lemcmd/infra.go +++ b/cmd/lemcmd/infra.go @@ -1,6 +1,8 @@ package lemcmd import ( + "time" + "forge.lthn.ai/core/cli/pkg/cli" "forge.lthn.ai/lthn/lem/pkg/lem" ) @@ -8,10 +10,70 @@ import ( func addInfraCommands(root *cli.Command) { infraGroup := cli.NewGroup("infra", "Infrastructure 
commands", "InfluxDB ingestion, DuckDB queries, and distributed workers.") - infraGroup.AddCommand(passthrough("ingest", "Ingest benchmark data into InfluxDB", lem.RunIngest)) - infraGroup.AddCommand(passthrough("seed-influx", "Seed InfluxDB golden_gen from DuckDB", lem.RunSeedInflux)) - infraGroup.AddCommand(passthrough("query", "Run ad-hoc SQL against DuckDB", lem.RunQuery)) - infraGroup.AddCommand(passthrough("worker", "Run as distributed inference worker node", lem.RunWorker)) + // ingest — push benchmark data into InfluxDB. + var ingestCfg lem.IngestOpts + ingestCmd := cli.NewCommand("ingest", "Ingest benchmark data into InfluxDB", "", + func(cmd *cli.Command, args []string) error { + return lem.RunIngest(ingestCfg) + }, + ) + cli.StringFlag(ingestCmd, &ingestCfg.Content, "content", "", "", "Content scores JSONL file") + cli.StringFlag(ingestCmd, &ingestCfg.Capability, "capability", "", "", "Capability scores JSONL file") + cli.StringFlag(ingestCmd, &ingestCfg.TrainingLog, "training-log", "", "", "MLX LoRA training log file") + cli.StringFlag(ingestCmd, &ingestCfg.Model, "model", "m", "", "Model name tag (required)") + cli.StringFlag(ingestCmd, &ingestCfg.RunID, "run-id", "", "", "Run ID tag (defaults to model name)") + cli.StringFlag(ingestCmd, &ingestCfg.InfluxURL, "influx", "", "", "InfluxDB URL") + cli.StringFlag(ingestCmd, &ingestCfg.InfluxDB, "influx-db", "", "", "InfluxDB database name") + cli.IntFlag(ingestCmd, &ingestCfg.BatchSize, "batch-size", "", 100, "Lines per InfluxDB write batch") + infraGroup.AddCommand(ingestCmd) + + // seed-influx — seed InfluxDB golden_gen from DuckDB. 
+ var seedCfg lem.SeedInfluxOpts + seedCmd := cli.NewCommand("seed-influx", "Seed InfluxDB golden_gen from DuckDB", "", + func(cmd *cli.Command, args []string) error { + return lem.RunSeedInflux(seedCfg) + }, + ) + cli.StringFlag(seedCmd, &seedCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)") + cli.StringFlag(seedCmd, &seedCfg.InfluxURL, "influx", "", "", "InfluxDB URL") + cli.StringFlag(seedCmd, &seedCfg.InfluxDB, "influx-db", "", "", "InfluxDB database name") + cli.BoolFlag(seedCmd, &seedCfg.Force, "force", "", false, "Re-seed even if InfluxDB already has data") + cli.IntFlag(seedCmd, &seedCfg.BatchSize, "batch-size", "", 500, "Lines per InfluxDB write batch") + infraGroup.AddCommand(seedCmd) + + // query — run ad-hoc SQL against DuckDB. + var queryCfg lem.QueryOpts + queryCmd := cli.NewCommand("query", "Run ad-hoc SQL against DuckDB", "", + func(cmd *cli.Command, args []string) error { + return lem.RunQuery(queryCfg, args) + }, + ) + cli.StringFlag(queryCmd, &queryCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)") + cli.BoolFlag(queryCmd, &queryCfg.JSON, "json", "j", false, "Output as JSON instead of table") + infraGroup.AddCommand(queryCmd) + + // worker — distributed inference worker node. + var workerCfg lem.WorkerOpts + workerCmd := cli.NewCommand("worker", "Run as distributed inference worker node", "", + func(cmd *cli.Command, args []string) error { + return lem.RunWorker(workerCfg) + }, + ) + cli.StringFlag(workerCmd, &workerCfg.APIBase, "api", "", "", "LEM API base URL (or LEM_API env)") + cli.StringFlag(workerCmd, &workerCfg.WorkerID, "id", "", "", "Worker ID (or LEM_WORKER_ID env, defaults to machine UUID)") + cli.StringFlag(workerCmd, &workerCfg.Name, "name", "n", "", "Worker display name (or LEM_WORKER_NAME env)") + cli.StringFlag(workerCmd, &workerCfg.APIKey, "key", "k", "", "API key (or LEM_API_KEY env)") + cli.StringFlag(workerCmd, &workerCfg.GPUType, "gpu", "", "", "GPU type (e.g. 
'RTX 3090', or LEM_GPU env)") + cli.IntFlag(workerCmd, &workerCfg.VRAMGb, "vram", "", 0, "GPU VRAM in GB (or LEM_VRAM_GB env)") + cli.StringFlag(workerCmd, &workerCfg.Languages, "languages", "", "", "Comma-separated language codes (or LEM_LANGUAGES env)") + cli.StringFlag(workerCmd, &workerCfg.Models, "models", "", "", "Comma-separated supported model names (or LEM_MODELS env)") + cli.StringFlag(workerCmd, &workerCfg.InferURL, "infer", "", "", "Local inference endpoint (or LEM_INFER_URL env)") + cli.StringFlag(workerCmd, &workerCfg.TaskType, "type", "t", "", "Filter by task type (expand, score, translate, seed)") + cli.IntFlag(workerCmd, &workerCfg.BatchSize, "batch", "b", 5, "Number of tasks to fetch per poll") + cli.DurationFlag(workerCmd, &workerCfg.PollInterval, "poll", "", 30*time.Second, "Poll interval") + cli.BoolFlag(workerCmd, &workerCfg.OneShot, "one-shot", "", false, "Process one batch and exit") + cli.BoolFlag(workerCmd, &workerCfg.DryRun, "dry-run", "", false, "Fetch tasks but don't run inference") + infraGroup.AddCommand(workerCmd) root.AddCommand(infraGroup) } diff --git a/cmd/lemcmd/lem.go b/cmd/lemcmd/lem.go index b52aa6f..74886f0 100644 --- a/cmd/lemcmd/lem.go +++ b/cmd/lemcmd/lem.go @@ -3,6 +3,11 @@ package lemcmd import ( + "fmt" + "os" + "path/filepath" + "strings" + "forge.lthn.ai/core/cli/pkg/cli" ) @@ -16,12 +21,35 @@ func AddLEMCommands(root *cli.Command) { addInfraCommands(root) } -// passthrough creates a command that passes all args (including flags) to fn. -// Used for commands that do their own flag parsing with flag.FlagSet. -func passthrough(use, short string, fn func([]string)) *cli.Command { - cmd := cli.NewRun(use, short, "", func(_ *cli.Command, args []string) { - fn(args) - }) - cmd.DisableFlagParsing = true - return cmd +// envOr returns the environment variable value, or the fallback if not set. 
+func envOr(key, fallback string) string { + if v := os.Getenv(key); v != "" { + return v + } + return fallback +} + +// intEnvOr returns the environment variable value parsed as int, or the fallback if unset, malformed, or zero. +func intEnvOr(key string, fallback int) int { + v := os.Getenv(key) + if v == "" { + return fallback + } + var n int + _, err := fmt.Sscanf(v, "%d", &n) + if err != nil || n == 0 { + return fallback + } + return n +} + +// expandHome expands a leading ~/ to the user's home directory. +func expandHome(path string) string { + if strings.HasPrefix(path, "~/") { + home, err := os.UserHomeDir() + if err == nil { + return filepath.Join(home, path[2:]) + } + } + return path } diff --git a/cmd/lemcmd/mon.go index 829430e..eb6ff81 100644 --- a/cmd/lemcmd/mon.go +++ b/cmd/lemcmd/mon.go @@ -8,11 +8,59 @@ import ( func addMonCommands(root *cli.Command) { monGroup := cli.NewGroup("mon", "Monitoring commands", "Training progress, pipeline status, inventory, coverage, and metrics.") - monGroup.AddCommand(passthrough("status", "Show training and generation progress (InfluxDB)", lem.RunStatus)) - monGroup.AddCommand(passthrough("expand-status", "Show expansion pipeline status (DuckDB)", lem.RunExpandStatus)) - monGroup.AddCommand(passthrough("inventory", "Show DuckDB table inventory", lem.RunInventory)) - monGroup.AddCommand(passthrough("coverage", "Analyse seed coverage gaps", lem.RunCoverage)) - monGroup.AddCommand(passthrough("metrics", "Push DuckDB golden set stats to InfluxDB", lem.RunMetrics)) + // status — training and generation progress from InfluxDB. 
+ var statusCfg lem.StatusOpts + statusCmd := cli.NewCommand("status", "Show training and generation progress (InfluxDB)", "", + func(cmd *cli.Command, args []string) error { + return lem.RunStatus(statusCfg) + }, + ) + cli.StringFlag(statusCmd, &statusCfg.Influx, "influx", "", "", "InfluxDB URL (default http://10.69.69.165:8181)") + cli.StringFlag(statusCmd, &statusCfg.InfluxDB, "influx-db", "", "", "InfluxDB database name (default training)") + cli.StringFlag(statusCmd, &statusCfg.DB, "db", "", "", "DuckDB database path (shows table counts)") + monGroup.AddCommand(statusCmd) + + // expand-status — expansion pipeline status from DuckDB. + var expandStatusCfg lem.ExpandStatusOpts + expandStatusCmd := cli.NewCommand("expand-status", "Show expansion pipeline status (DuckDB)", "", + func(cmd *cli.Command, args []string) error { + return lem.RunExpandStatus(expandStatusCfg) + }, + ) + cli.StringFlag(expandStatusCmd, &expandStatusCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)") + monGroup.AddCommand(expandStatusCmd) + + // inventory — DuckDB table inventory. + var inventoryCfg lem.InventoryOpts + inventoryCmd := cli.NewCommand("inventory", "Show DuckDB table inventory", "", + func(cmd *cli.Command, args []string) error { + return lem.RunInventory(inventoryCfg) + }, + ) + cli.StringFlag(inventoryCmd, &inventoryCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)") + monGroup.AddCommand(inventoryCmd) + + // coverage — seed coverage gap analysis. + var coverageCfg lem.CoverageOpts + coverageCmd := cli.NewCommand("coverage", "Analyse seed coverage gaps", "", + func(cmd *cli.Command, args []string) error { + return lem.RunCoverage(coverageCfg) + }, + ) + cli.StringFlag(coverageCmd, &coverageCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)") + monGroup.AddCommand(coverageCmd) + + // metrics — push DuckDB golden set stats to InfluxDB. 
+ var metricsCfg lem.MetricsOpts + metricsCmd := cli.NewCommand("metrics", "Push DuckDB golden set stats to InfluxDB", "", + func(cmd *cli.Command, args []string) error { + return lem.RunMetrics(metricsCfg) + }, + ) + cli.StringFlag(metricsCmd, &metricsCfg.DB, "db", "", "", "DuckDB database path (defaults to LEM_DB env)") + cli.StringFlag(metricsCmd, &metricsCfg.Influx, "influx", "", "", "InfluxDB URL") + cli.StringFlag(metricsCmd, &metricsCfg.InfluxDB, "influx-db", "", "", "InfluxDB database name") + monGroup.AddCommand(metricsCmd) root.AddCommand(monGroup) } diff --git a/cmd/lemcmd/score.go b/cmd/lemcmd/score.go index cf25317..cb24ce6 100644 --- a/cmd/lemcmd/score.go +++ b/cmd/lemcmd/score.go @@ -10,8 +10,38 @@ import ( func addScoreCommands(root *cli.Command) { scoreGroup := cli.NewGroup("score", "Scoring commands", "Score responses, probe models, compare results.") - scoreGroup.AddCommand(passthrough("run", "Score existing response files", lem.RunScore)) - scoreGroup.AddCommand(passthrough("probe", "Generate responses and score them", lem.RunProbe)) + // run — score existing response files. 
+ var scoreCfg lem.ScoreOpts + scoreCmd := cli.NewCommand("run", "Score existing response files", "", + func(cmd *cli.Command, args []string) error { + return lem.RunScore(scoreCfg) + }, + ) + cli.StringFlag(scoreCmd, &scoreCfg.Input, "input", "i", "", "Input JSONL response file (required)") + cli.StringFlag(scoreCmd, &scoreCfg.Suites, "suites", "", "all", "Comma-separated suites or 'all'") + cli.StringFlag(scoreCmd, &scoreCfg.JudgeModel, "judge-model", "", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name") + cli.StringFlag(scoreCmd, &scoreCfg.JudgeURL, "judge-url", "", "http://10.69.69.108:8090", "Judge API URL") + cli.IntFlag(scoreCmd, &scoreCfg.Concurrency, "concurrency", "c", 4, "Max concurrent judge calls") + cli.StringFlag(scoreCmd, &scoreCfg.Output, "output", "o", "scores.json", "Output score file path") + cli.BoolFlag(scoreCmd, &scoreCfg.Resume, "resume", "", false, "Resume from existing output, skipping scored IDs") + scoreGroup.AddCommand(scoreCmd) + + // probe — generate responses and score them. 
+ var probeCfg lem.ProbeOpts + probeCmd := cli.NewCommand("probe", "Generate responses and score them", "", + func(cmd *cli.Command, args []string) error { + return lem.RunProbe(probeCfg) + }, + ) + cli.StringFlag(probeCmd, &probeCfg.Model, "model", "m", "", "Target model name (required)") + cli.StringFlag(probeCmd, &probeCfg.TargetURL, "target-url", "", "", "Target model API URL (defaults to judge-url)") + cli.StringFlag(probeCmd, &probeCfg.ProbesFile, "probes", "", "", "Custom probes JSONL file (uses built-in content probes if not specified)") + cli.StringFlag(probeCmd, &probeCfg.Suites, "suites", "", "all", "Comma-separated suites or 'all'") + cli.StringFlag(probeCmd, &probeCfg.JudgeModel, "judge-model", "", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name") + cli.StringFlag(probeCmd, &probeCfg.JudgeURL, "judge-url", "", "http://10.69.69.108:8090", "Judge API URL") + cli.IntFlag(probeCmd, &probeCfg.Concurrency, "concurrency", "c", 4, "Max concurrent judge calls") + cli.StringFlag(probeCmd, &probeCfg.Output, "output", "o", "scores.json", "Output score file path") + scoreGroup.AddCommand(probeCmd) // compare has a different signature — it takes two named args, not []string. var compareOld, compareNew string @@ -27,9 +57,54 @@ func addScoreCommands(root *cli.Command) { cli.StringFlag(compareCmd, &compareNew, "new", "", "", "New score file (required)") scoreGroup.AddCommand(compareCmd) - scoreGroup.AddCommand(passthrough("attention", "Q/K Bone Orientation analysis for a prompt", lem.RunAttention)) - scoreGroup.AddCommand(passthrough("tier", "Score expansion responses (heuristic/judge tiers)", lem.RunTierScore)) - scoreGroup.AddCommand(passthrough("agent", "ROCm scoring daemon (polls M3, scores checkpoints)", lem.RunAgent)) + // attention — Q/K Bone Orientation analysis. 
+ var attCfg lem.AttentionOpts + attCmd := cli.NewCommand("attention", "Q/K Bone Orientation analysis for a prompt", "", + func(cmd *cli.Command, args []string) error { + return lem.RunAttention(attCfg) + }, + ) + cli.StringFlag(attCmd, &attCfg.Model, "model", "m", "gemma3/1b", "Model config path (relative to .core/ai/models/)") + cli.StringFlag(attCmd, &attCfg.Prompt, "prompt", "p", "", "Prompt text to analyse") + cli.BoolFlag(attCmd, &attCfg.JSON, "json", "j", false, "Output as JSON") + cli.IntFlag(attCmd, &attCfg.CacheLimit, "cache-limit", "", 0, "Metal cache limit in GB (0 = use ai.yaml default)") + cli.IntFlag(attCmd, &attCfg.MemLimit, "mem-limit", "", 0, "Metal memory limit in GB (0 = use ai.yaml default)") + cli.StringFlag(attCmd, &attCfg.Root, "root", "", ".", "Project root (for .core/ai/ config)") + scoreGroup.AddCommand(attCmd) + + // tier — score expansion responses with heuristic/judge tiers. + var tierCfg lem.TierScoreOpts + tierCmd := cli.NewCommand("tier", "Score expansion responses (heuristic/judge tiers)", "", + func(cmd *cli.Command, args []string) error { + return lem.RunTierScore(tierCfg) + }, + ) + cli.StringFlag(tierCmd, &tierCfg.DBPath, "db", "", "", "DuckDB database path (defaults to LEM_DB env)") + cli.IntFlag(tierCmd, &tierCfg.Tier, "tier", "t", 1, "Scoring tier: 1=heuristic, 2=LEM judge, 3=external") + cli.IntFlag(tierCmd, &tierCfg.Limit, "limit", "l", 0, "Max items to score (0=all)") + scoreGroup.AddCommand(tierCmd) + + // agent — ROCm scoring daemon. 
+ var agentCfg lem.AgentOpts + agentCmd := cli.NewCommand("agent", "ROCm scoring daemon (polls M3, scores checkpoints)", "", + func(cmd *cli.Command, args []string) error { + return lem.RunAgent(agentCfg) + }, + ) + cli.StringFlag(agentCmd, &agentCfg.M3Host, "m3-host", "", envOr("M3_HOST", "10.69.69.108"), "M3 host address") + cli.StringFlag(agentCmd, &agentCfg.M3User, "m3-user", "", envOr("M3_USER", "claude"), "M3 SSH user") + cli.StringFlag(agentCmd, &agentCfg.M3SSHKey, "m3-ssh-key", "", envOr("M3_SSH_KEY", expandHome("~/.ssh/id_ed25519")), "SSH key for M3") + cli.StringFlag(agentCmd, &agentCfg.M3AdapterBase, "m3-adapter-base", "", envOr("M3_ADAPTER_BASE", "/Volumes/Data/lem"), "Adapter base dir on M3") + cli.StringFlag(agentCmd, &agentCfg.InfluxURL, "influx", "", envOr("INFLUX_URL", "http://10.69.69.165:8181"), "InfluxDB URL") + cli.StringFlag(agentCmd, &agentCfg.InfluxDB, "influx-db", "", envOr("INFLUX_DB", "training"), "InfluxDB database") + cli.StringFlag(agentCmd, &agentCfg.APIURL, "api-url", "", envOr("LEM_API_URL", "http://localhost:8080"), "OpenAI-compatible inference API URL") + cli.StringFlag(agentCmd, &agentCfg.Model, "model", "m", envOr("LEM_MODEL", ""), "Model name for API (overrides auto-detect)") + cli.StringFlag(agentCmd, &agentCfg.BaseModel, "base-model", "", envOr("BASE_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"), "HuggingFace base model ID") + cli.IntFlag(agentCmd, &agentCfg.PollInterval, "poll", "", intEnvOr("POLL_INTERVAL", 300), "Poll interval in seconds") + cli.StringFlag(agentCmd, &agentCfg.WorkDir, "work-dir", "", envOr("WORK_DIR", "/tmp/scoring-agent"), "Working directory for adapters") + cli.BoolFlag(agentCmd, &agentCfg.OneShot, "one-shot", "", false, "Process one checkpoint and exit") + cli.BoolFlag(agentCmd, &agentCfg.DryRun, "dry-run", "", false, "Discover and plan but don't execute") + scoreGroup.AddCommand(agentCmd) root.AddCommand(scoreGroup) } diff --git a/pkg/lem/agent.go b/pkg/lem/agent.go index 1db07a4..9403c4a 100644 
--- a/pkg/lem/agent.go +++ b/pkg/lem/agent.go @@ -2,7 +2,6 @@ package lem import ( "encoding/json" - "flag" "fmt" "log" "os" @@ -14,21 +13,21 @@ import ( "time" ) -// agentConfig holds scoring agent configuration. -type agentConfig struct { - m3Host string - m3User string - m3SSHKey string - m3AdapterBase string - influxURL string - influxDB string - apiURL string - model string - baseModel string - pollInterval int - workDir string - oneShot bool - dryRun bool +// AgentOpts holds scoring agent configuration. +type AgentOpts struct { + M3Host string + M3User string + M3SSHKey string + M3AdapterBase string + InfluxURL string + InfluxDB string + APIURL string + Model string + BaseModel string + PollInterval int + WorkDir string + OneShot bool + DryRun bool } // checkpoint represents a discovered adapter checkpoint on M3. @@ -72,46 +71,26 @@ type bufferEntry struct { // Polls M3 for unscored LoRA checkpoints, converts MLX → PEFT, // runs 23 capability probes via an OpenAI-compatible API, and // pushes results to InfluxDB. 
-func RunAgent(args []string) { - fs := flag.NewFlagSet("agent", flag.ExitOnError) - - cfg := &agentConfig{} - fs.StringVar(&cfg.m3Host, "m3-host", envOr("M3_HOST", "10.69.69.108"), "M3 host address") - fs.StringVar(&cfg.m3User, "m3-user", envOr("M3_USER", "claude"), "M3 SSH user") - fs.StringVar(&cfg.m3SSHKey, "m3-ssh-key", envOr("M3_SSH_KEY", expandHome("~/.ssh/id_ed25519")), "SSH key for M3") - fs.StringVar(&cfg.m3AdapterBase, "m3-adapter-base", envOr("M3_ADAPTER_BASE", "/Volumes/Data/lem"), "Adapter base dir on M3") - fs.StringVar(&cfg.influxURL, "influx", envOr("INFLUX_URL", "http://10.69.69.165:8181"), "InfluxDB URL") - fs.StringVar(&cfg.influxDB, "influx-db", envOr("INFLUX_DB", "training"), "InfluxDB database") - fs.StringVar(&cfg.apiURL, "api-url", envOr("LEM_API_URL", "http://localhost:8080"), "OpenAI-compatible inference API URL") - fs.StringVar(&cfg.model, "model", envOr("LEM_MODEL", ""), "Model name for API (overrides auto-detect)") - fs.StringVar(&cfg.baseModel, "base-model", envOr("BASE_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"), "HuggingFace base model ID") - fs.IntVar(&cfg.pollInterval, "poll", intEnvOr("POLL_INTERVAL", 300), "Poll interval in seconds") - fs.StringVar(&cfg.workDir, "work-dir", envOr("WORK_DIR", "/tmp/scoring-agent"), "Working directory for adapters") - fs.BoolVar(&cfg.oneShot, "one-shot", false, "Process one checkpoint and exit") - fs.BoolVar(&cfg.dryRun, "dry-run", false, "Discover and plan but don't execute") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) - } - - runAgentLoop(cfg) +func RunAgent(cfg AgentOpts) error { + runAgentLoop(&cfg) + return nil } -func runAgentLoop(cfg *agentConfig) { +func runAgentLoop(cfg *AgentOpts) { log.Println(strings.Repeat("=", 60)) log.Println("ROCm Scoring Agent — Go Edition") - log.Printf("M3: %s@%s", cfg.m3User, cfg.m3Host) - log.Printf("Inference API: %s", cfg.apiURL) - log.Printf("InfluxDB: %s/%s", cfg.influxURL, cfg.influxDB) - log.Printf("Poll 
interval: %ds", cfg.pollInterval) + log.Printf("M3: %s@%s", cfg.M3User, cfg.M3Host) + log.Printf("Inference API: %s", cfg.APIURL) + log.Printf("InfluxDB: %s/%s", cfg.InfluxURL, cfg.InfluxDB) + log.Printf("Poll interval: %ds", cfg.PollInterval) log.Println(strings.Repeat("=", 60)) - influx := NewInfluxClient(cfg.influxURL, cfg.influxDB) - os.MkdirAll(cfg.workDir, 0755) + influx := NewInfluxClient(cfg.InfluxURL, cfg.InfluxDB) + os.MkdirAll(cfg.WorkDir, 0755) for { // Replay any buffered results. - replayInfluxBuffer(cfg.workDir, influx) + replayInfluxBuffer(cfg.WorkDir, influx) // Discover checkpoints on M3. log.Println("Discovering checkpoints on M3...") @@ -135,18 +114,18 @@ func runAgentLoop(cfg *agentConfig) { log.Printf("Unscored: %d checkpoints", len(unscored)) if len(unscored) == 0 { - log.Printf("Nothing to score. Sleeping %ds...", cfg.pollInterval) - if cfg.oneShot { + log.Printf("Nothing to score. Sleeping %ds...", cfg.PollInterval) + if cfg.OneShot { return } - time.Sleep(time.Duration(cfg.pollInterval) * time.Second) + time.Sleep(time.Duration(cfg.PollInterval) * time.Second) continue } target := unscored[0] log.Printf("Grabbed: %s (%s)", target.Label, target.Dirname) - if cfg.dryRun { + if cfg.DryRun { log.Printf("[DRY RUN] Would process: %s/%s", target.Dirname, target.Filename) for _, u := range unscored[1:] { log.Printf("[DRY RUN] Queued: %s/%s", u.Dirname, u.Filename) @@ -158,7 +137,7 @@ func runAgentLoop(cfg *agentConfig) { log.Printf("Error processing %s: %v", target.Label, err) } - if cfg.oneShot { + if cfg.OneShot { return } @@ -167,8 +146,8 @@ func runAgentLoop(cfg *agentConfig) { } // discoverCheckpoints lists all adapter directories and checkpoint files on M3 via SSH. 
-func discoverCheckpoints(cfg *agentConfig) ([]checkpoint, error) { - out, err := sshCommand(cfg, fmt.Sprintf("ls -d %s/adapters-deepseek-r1-7b* 2>/dev/null", cfg.m3AdapterBase)) +func discoverCheckpoints(cfg *AgentOpts) ([]checkpoint, error) { + out, err := sshCommand(cfg, fmt.Sprintf("ls -d %s/adapters-deepseek-r1-7b* 2>/dev/null", cfg.M3AdapterBase)) if err != nil { return nil, fmt.Errorf("list adapter dirs: %w", err) } @@ -292,12 +271,12 @@ func findUnscored(checkpoints []checkpoint, scored map[[2]string]bool) []checkpo } // processOne fetches, converts, scores, and pushes one checkpoint. -func processOne(cfg *agentConfig, influx *InfluxClient, cp checkpoint) error { +func processOne(cfg *AgentOpts, influx *InfluxClient, cp checkpoint) error { log.Println(strings.Repeat("=", 60)) log.Printf("Processing: %s / %s", cp.Dirname, cp.Filename) log.Println(strings.Repeat("=", 60)) - localAdapterDir := filepath.Join(cfg.workDir, cp.Dirname) + localAdapterDir := filepath.Join(cfg.WorkDir, cp.Dirname) os.MkdirAll(localAdapterDir, 0755) localSF := filepath.Join(localAdapterDir, cp.Filename) @@ -307,7 +286,7 @@ func processOne(cfg *agentConfig, influx *InfluxClient, cp checkpoint) error { defer func() { os.Remove(localSF) os.Remove(localCfg) - peftDir := filepath.Join(cfg.workDir, fmt.Sprintf("peft_%07d", cp.Iteration)) + peftDir := filepath.Join(cfg.WorkDir, fmt.Sprintf("peft_%07d", cp.Iteration)) os.RemoveAll(peftDir) }() @@ -325,18 +304,18 @@ func processOne(cfg *agentConfig, influx *InfluxClient, cp checkpoint) error { // Convert MLX to PEFT format. 
log.Println("Converting MLX to PEFT format...") - peftDir := filepath.Join(cfg.workDir, fmt.Sprintf("peft_%07d", cp.Iteration)) - if err := convertMLXtoPEFT(localAdapterDir, cp.Filename, peftDir, cfg.baseModel); err != nil { + peftDir := filepath.Join(cfg.WorkDir, fmt.Sprintf("peft_%07d", cp.Iteration)) + if err := convertMLXtoPEFT(localAdapterDir, cp.Filename, peftDir, cfg.BaseModel); err != nil { return fmt.Errorf("convert adapter: %w", err) } // Run 23 capability probes via API. log.Println("Running 23 capability probes...") - modelName := cfg.model + modelName := cfg.Model if modelName == "" { modelName = cp.ModelTag } - client := NewClient(cfg.apiURL, modelName) + client := NewClient(cfg.APIURL, modelName) client.MaxTokens = 500 results := runCapabilityProbes(client) @@ -347,7 +326,7 @@ func processOne(cfg *agentConfig, influx *InfluxClient, cp checkpoint) error { // Push to InfluxDB (buffer on failure). if err := pushCapabilityResults(influx, cp, results); err != nil { log.Printf("InfluxDB push failed, buffering: %v", err) - bufferInfluxResult(cfg.workDir, cp, results) + bufferInfluxResult(cfg.WorkDir, cp, results) } return nil @@ -532,13 +511,13 @@ func replayInfluxBuffer(workDir string, influx *InfluxClient) { } // sshCommand executes a command on M3 via SSH. -func sshCommand(cfg *agentConfig, cmd string) (string, error) { +func sshCommand(cfg *AgentOpts, cmd string) (string, error) { sshArgs := []string{ "-o", "ConnectTimeout=10", "-o", "BatchMode=yes", "-o", "StrictHostKeyChecking=no", - "-i", cfg.m3SSHKey, - fmt.Sprintf("%s@%s", cfg.m3User, cfg.m3Host), + "-i", cfg.M3SSHKey, + fmt.Sprintf("%s@%s", cfg.M3User, cfg.M3Host), cmd, } result, err := exec.Command("ssh", sshArgs...).CombinedOutput() @@ -549,14 +528,14 @@ func sshCommand(cfg *agentConfig, cmd string) (string, error) { } // scpFrom copies a file from M3 to a local path. 
-func scpFrom(cfg *agentConfig, remotePath, localPath string) error { +func scpFrom(cfg *AgentOpts, remotePath, localPath string) error { os.MkdirAll(filepath.Dir(localPath), 0755) scpArgs := []string{ "-o", "ConnectTimeout=10", "-o", "BatchMode=yes", "-o", "StrictHostKeyChecking=no", - "-i", cfg.m3SSHKey, - fmt.Sprintf("%s@%s:%s", cfg.m3User, cfg.m3Host, remotePath), + "-i", cfg.M3SSHKey, + fmt.Sprintf("%s@%s:%s", cfg.M3User, cfg.M3Host, remotePath), localPath, } result, err := exec.Command("scp", scpArgs...).CombinedOutput() @@ -574,11 +553,11 @@ func fileBase(path string) string { return path } -func sleepOrExit(cfg *agentConfig) { - if cfg.oneShot { +func sleepOrExit(cfg *AgentOpts) { + if cfg.OneShot { return } - time.Sleep(time.Duration(cfg.pollInterval) * time.Second) + time.Sleep(time.Duration(cfg.PollInterval) * time.Second) } func envOr(key, fallback string) string { diff --git a/pkg/lem/approve.go b/pkg/lem/approve.go index b3513e0..9725839 100644 --- a/pkg/lem/approve.go +++ b/pkg/lem/approve.go @@ -2,41 +2,38 @@ package lem import ( "encoding/json" - "flag" "fmt" - "log" "os" "path/filepath" ) +// ApproveOpts holds configuration for the approve command. +type ApproveOpts struct { + DB string // DuckDB database path (defaults to LEM_DB env) + Output string // Output JSONL file (defaults to expansion-approved.jsonl in db dir) + Threshold float64 // Min judge average to approve (default: 6.0) +} + // RunApprove is the CLI entry point for the approve command. // Filters scored expansion responses by quality threshold and exports // approved ones as chat-format training JSONL. 
-func RunApprove(args []string) { - fs := flag.NewFlagSet("approve", flag.ExitOnError) - dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)") - output := fs.String("output", "", "Output JSONL file (defaults to expansion-approved.jsonl in db dir)") - threshold := fs.Float64("threshold", 6.0, "Min judge average to approve (default: 6.0)") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +func RunApprove(cfg ApproveOpts) error { + dbPath := cfg.DB + if dbPath == "" { + dbPath = os.Getenv("LEM_DB") + } + if dbPath == "" { + return fmt.Errorf("--db or LEM_DB required") } - if *dbPath == "" { - *dbPath = os.Getenv("LEM_DB") - } - if *dbPath == "" { - fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required") - os.Exit(1) + output := cfg.Output + if output == "" { + output = filepath.Join(filepath.Dir(dbPath), "expansion-approved.jsonl") } - if *output == "" { - *output = filepath.Join(filepath.Dir(*dbPath), "expansion-approved.jsonl") - } - - db, err := OpenDB(*dbPath) + db, err := OpenDB(dbPath) if err != nil { - log.Fatalf("open db: %v", err) + return fmt.Errorf("open db: %w", err) } defer db.Close() @@ -51,13 +48,13 @@ func RunApprove(args []string) { ORDER BY r.idx `) if err != nil { - log.Fatalf("query approved: %v (have you run scoring?)", err) + return fmt.Errorf("query approved: %w (have you run scoring?)", err) } defer rows.Close() - f, err := os.Create(*output) + f, err := os.Create(output) if err != nil { - log.Fatalf("create output: %v", err) + return fmt.Errorf("create output: %w", err) } defer f.Close() @@ -71,7 +68,7 @@ func RunApprove(args []string) { var seedID, region, domain, prompt, response, model string var genTime, score float64 if err := rows.Scan(&idx, &seedID, &region, &domain, &prompt, &response, &genTime, &model, &score); err != nil { - log.Fatalf("scan: %v", err) + return fmt.Errorf("scan: %w", err) } example := TrainingExample{ @@ -82,7 +79,7 @@ func RunApprove(args []string) { } if err := 
enc.Encode(example); err != nil { - log.Fatalf("encode: %v", err) + return fmt.Errorf("encode: %w", err) } regionSet[region] = true @@ -90,9 +87,10 @@ func RunApprove(args []string) { count++ } - _ = *threshold // threshold used in query above for future judge scoring + _ = cfg.Threshold // not yet applied — reserved for future judge scoring fmt.Printf("Approved: %d responses (threshold: heuristic > 0)\n", count) - fmt.Printf("Exported: %s\n", *output) + fmt.Printf("Exported: %s\n", output) fmt.Printf(" Regions: %d, Domains: %d\n", len(regionSet), len(domainSet)) + return nil } diff --git a/pkg/lem/cmd_attention.go b/pkg/lem/cmd_attention.go index e63227d..25bc19c 100644 --- a/pkg/lem/cmd_attention.go +++ b/pkg/lem/cmd_attention.go @@ -3,62 +3,49 @@ package lem import ( "context" "encoding/json" - "flag" "fmt" "log" - "os" "time" "forge.lthn.ai/core/go-ml" "forge.lthn.ai/core/go-mlx" ) +// AttentionOpts holds configuration for the attention command. +type AttentionOpts struct { + Model string // Model config path (relative to .core/ai/models/) + Prompt string // Prompt text to analyse + JSON bool // Output as JSON + CacheLimit int // Metal cache limit in GB (0 = use ai.yaml default) + MemLimit int // Metal memory limit in GB (0 = use ai.yaml default) + Root string // Project root (for .core/ai/ config) +} + // RunAttention is the entry point for `lem score attention`. 
-func RunAttention(args []string) { - fs := flag.NewFlagSet("attention", flag.ExitOnError) - var ( - modelFlag string - prompt string - jsonOutput bool - cacheLimit int - memLimit int - root string - ) - fs.StringVar(&modelFlag, "model", "gemma3/1b", "Model config path (relative to .core/ai/models/)") - fs.StringVar(&prompt, "prompt", "", "Prompt text to analyse") - fs.BoolVar(&jsonOutput, "json", false, "Output as JSON") - fs.IntVar(&cacheLimit, "cache-limit", 0, "Metal cache limit in GB (0 = use ai.yaml default)") - fs.IntVar(&memLimit, "mem-limit", 0, "Metal memory limit in GB (0 = use ai.yaml default)") - fs.StringVar(&root, "root", ".", "Project root (for .core/ai/ config)") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) - } - - if prompt == "" { - fmt.Fprintln(os.Stderr, "usage: lem score attention -model -prompt ") - os.Exit(1) +func RunAttention(cfg AttentionOpts) error { + if cfg.Prompt == "" { + return fmt.Errorf("--prompt is required") } // Load configs. - aiCfg, err := LoadAIConfig(root) + aiCfg, err := LoadAIConfig(cfg.Root) if err != nil { - log.Fatalf("load ai config: %v", err) + return fmt.Errorf("load ai config: %w", err) } - modelCfg, err := LoadModelConfig(root, modelFlag) + modelCfg, err := LoadModelConfig(cfg.Root, cfg.Model) if err != nil { - log.Fatalf("load model config: %v", err) + return fmt.Errorf("load model config: %w", err) } // Apply memory limits. 
cacheLimitGB := aiCfg.Distill.CacheLimit - if cacheLimit > 0 { - cacheLimitGB = cacheLimit + if cfg.CacheLimit > 0 { + cacheLimitGB = cfg.CacheLimit } memLimitGB := aiCfg.Distill.MemoryLimit - if memLimit > 0 { - memLimitGB = memLimit + if cfg.MemLimit > 0 { + memLimitGB = cfg.MemLimit } if cacheLimitGB > 0 { mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024) @@ -71,7 +58,7 @@ func RunAttention(args []string) { log.Printf("loading model: %s", modelCfg.Paths.Base) backend, err := ml.NewMLXBackend(modelCfg.Paths.Base) if err != nil { - log.Fatalf("load model: %v", err) + return fmt.Errorf("load model: %w", err) } defer backend.Close() @@ -79,18 +66,17 @@ func RunAttention(args []string) { ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() - snap, err := backend.InspectAttention(ctx, prompt) + snap, err := backend.InspectAttention(ctx, cfg.Prompt) if err != nil { - fmt.Fprintf(os.Stderr, "error: %v\n", err) - os.Exit(1) + return fmt.Errorf("inspect attention: %w", err) } result := AnalyseAttention(snap) - if jsonOutput { + if cfg.JSON { data, _ := json.MarshalIndent(result, "", " ") fmt.Println(string(data)) - return + return nil } // Human-readable output. @@ -107,4 +93,5 @@ func RunAttention(args []string) { fmt.Printf(" Phase-Lock Score: %.3f\n", result.PhaseLockScore) fmt.Printf(" Joint Collapses: %d\n", result.JointCollapseCount) fmt.Printf(" Composite (0-10000): %d\n", result.Composite()) + return nil } diff --git a/pkg/lem/consolidate.go b/pkg/lem/consolidate.go index 76cf31e..ecdafab 100644 --- a/pkg/lem/consolidate.go +++ b/pkg/lem/consolidate.go @@ -3,7 +3,6 @@ package lem import ( "bufio" "encoding/json" - "flag" "fmt" "log" "os" @@ -13,34 +12,38 @@ import ( "strings" ) +// ConsolidateOpts holds configuration for the consolidate command. 
+type ConsolidateOpts struct { + Host string // SSH host for remote files + Remote string // Remote directory for JSONL files + Pattern string // File glob pattern + OutputDir string // Output directory (defaults to ./responses) + Merged string // Merged output file (defaults to gold-merged.jsonl in output dir) +} + // RunConsolidate is the CLI entry point for the consolidate command. // Pulls all worker JSONLs from M3, merges them, deduplicates on idx, // and writes a single merged file. -func RunConsolidate(args []string) { - fs := flag.NewFlagSet("consolidate", flag.ExitOnError) - remoteHost := fs.String("host", "m3", "SSH host for remote files") - remotePath := fs.String("remote", "/Volumes/Data/lem/responses", "Remote directory for JSONL files") - pattern := fs.String("pattern", "gold*.jsonl", "File glob pattern") - outputDir := fs.String("output", "", "Output directory (defaults to ./responses)") - merged := fs.String("merged", "", "Merged output file (defaults to gold-merged.jsonl in output dir)") +func RunConsolidate(cfg ConsolidateOpts) error { + remoteHost := cfg.Host + remotePath := cfg.Remote + pattern := cfg.Pattern + outputDir := cfg.OutputDir + merged := cfg.Merged - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) + if outputDir == "" { + outputDir = "responses" } - - if *outputDir == "" { - *outputDir = "responses" - } - if err := os.MkdirAll(*outputDir, 0755); err != nil { - log.Fatalf("create output dir: %v", err) + if err := os.MkdirAll(outputDir, 0755); err != nil { + return fmt.Errorf("create output dir: %w", err) } // List remote files. 
fmt.Println("Pulling responses from remote...") - listCmd := exec.Command("ssh", *remoteHost, fmt.Sprintf("ls %s/%s", *remotePath, *pattern)) + listCmd := exec.Command("ssh", remoteHost, fmt.Sprintf("ls %s/%s", remotePath, pattern)) listOutput, err := listCmd.Output() if err != nil { - log.Fatalf("list remote files: %v", err) + return fmt.Errorf("list remote files: %w", err) } remoteFiles := strings.Split(strings.TrimSpace(string(listOutput)), "\n") @@ -51,12 +54,12 @@ func RunConsolidate(args []string) { validFiles = append(validFiles, f) } } - fmt.Printf(" Found %d JSONL files on %s\n", len(validFiles), *remoteHost) + fmt.Printf(" Found %d JSONL files on %s\n", len(validFiles), remoteHost) // Pull files. for _, rf := range validFiles { - local := filepath.Join(*outputDir, filepath.Base(rf)) - scpCmd := exec.Command("scp", fmt.Sprintf("%s:%s", *remoteHost, rf), local) + local := filepath.Join(outputDir, filepath.Base(rf)) + scpCmd := exec.Command("scp", fmt.Sprintf("%s:%s", remoteHost, rf), local) if err := scpCmd.Run(); err != nil { log.Printf("warning: failed to pull %s: %v", rf, err) continue @@ -80,7 +83,7 @@ func RunConsolidate(args []string) { seen := make(map[int]json.RawMessage) skipped := 0 - matches, _ := filepath.Glob(filepath.Join(*outputDir, *pattern)) + matches, _ := filepath.Glob(filepath.Join(outputDir, pattern)) sort.Strings(matches) for _, local := range matches { @@ -115,8 +118,8 @@ func RunConsolidate(args []string) { } // Sort by idx and write merged file. 
- if *merged == "" { - *merged = filepath.Join(*outputDir, "..", "gold-merged.jsonl") + if merged == "" { + merged = filepath.Join(outputDir, "..", "gold-merged.jsonl") } idxs := make([]int, 0, len(seen)) @@ -125,9 +128,9 @@ func RunConsolidate(args []string) { } sort.Ints(idxs) - f, err := os.Create(*merged) + f, err := os.Create(merged) if err != nil { - log.Fatalf("create merged file: %v", err) + return fmt.Errorf("create merged file: %w", err) } for _, idx := range idxs { f.Write(seen[idx]) @@ -135,5 +138,6 @@ func RunConsolidate(args []string) { } f.Close() - fmt.Printf("\nMerged: %d unique examples → %s\n", len(seen), *merged) + fmt.Printf("\nMerged: %d unique examples → %s\n", len(seen), merged) + return nil } diff --git a/pkg/lem/conv.go b/pkg/lem/conv.go index 9d20c1a..57c3e5d 100644 --- a/pkg/lem/conv.go +++ b/pkg/lem/conv.go @@ -3,7 +3,6 @@ package lem import ( "bufio" "encoding/json" - "flag" "fmt" "log" "math/rand" @@ -11,86 +10,79 @@ import ( "strings" ) -// RunConv is the CLI entry point for the conv command. -// It generates multi-turn conversational training data from built-in +// ConvOpts holds configuration for the conv command. 
+type ConvOpts struct { + OutputDir string // Output directory for training files (required) + Extra string // Additional conversations JSONL file (multi-turn format) + Golden string // Golden set JSONL to convert to single-turn conversations + DB string // DuckDB database path for golden set (alternative to --golden) + TrainPct int // Training set percentage + ValidPct int // Validation set percentage + TestPct int // Test set percentage + Seed int64 // Random seed for shuffling + MinChars int // Minimum response chars for golden set conversion + NoBuiltin bool // Exclude built-in seed conversations + Influx string // InfluxDB URL for progress reporting + InfluxDB string // InfluxDB database name + Worker string // Worker hostname for InfluxDB reporting +} + +// RunConv generates multi-turn conversational training data from built-in // seed conversations plus optional extra files and golden set data. -func RunConv(args []string) { - fs := flag.NewFlagSet("conv", flag.ExitOnError) - - outputDir := fs.String("output-dir", "", "Output directory for training files (required)") - extra := fs.String("extra", "", "Additional conversations JSONL file (multi-turn format)") - golden := fs.String("golden", "", "Golden set JSONL to convert to single-turn conversations") - dbPath := fs.String("db", "", "DuckDB database path for golden set (alternative to --golden)") - trainPct := fs.Int("train-pct", 80, "Training set percentage") - validPct := fs.Int("valid-pct", 10, "Validation set percentage") - testPct := fs.Int("test-pct", 10, "Test set percentage") - seed := fs.Int64("seed", 42, "Random seed for shuffling") - minChars := fs.Int("min-chars", 50, "Minimum response chars for golden set conversion") - noBuiltin := fs.Bool("no-builtin", false, "Exclude built-in seed conversations") - influxURL := fs.String("influx", "", "InfluxDB URL for progress reporting") - influxDB := fs.String("influx-db", "", "InfluxDB database name") - worker := fs.String("worker", "", "Worker hostname 
for InfluxDB reporting") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +func RunConv(cfg ConvOpts) error { + if cfg.OutputDir == "" { + return fmt.Errorf("--output-dir is required") } - if *outputDir == "" { - fmt.Fprintln(os.Stderr, "error: --output-dir is required") - fs.Usage() - os.Exit(1) - } - - if err := validatePercentages(*trainPct, *validPct, *testPct); err != nil { - fmt.Fprintf(os.Stderr, "error: %v\n", err) - os.Exit(1) + if err := validatePercentages(cfg.TrainPct, cfg.ValidPct, cfg.TestPct); err != nil { + return fmt.Errorf("percentages: %w", err) } // Check LEM_DB env as default for --db. - if *dbPath == "" { - *dbPath = os.Getenv("LEM_DB") + if cfg.DB == "" { + cfg.DB = os.Getenv("LEM_DB") } // Default worker to hostname. - if *worker == "" { + if cfg.Worker == "" { hostname, err := os.Hostname() if err != nil { hostname = "unknown" } - *worker = hostname + cfg.Worker = hostname } // Collect all conversations. var conversations []TrainingExample // 1. Built-in seed conversations. - if !*noBuiltin { + if !cfg.NoBuiltin { conversations = append(conversations, SeedConversations...) log.Printf("loaded %d built-in seed conversations", len(SeedConversations)) } // 2. Extra conversations from file. - if *extra != "" { - extras, err := readConversations(*extra) + if cfg.Extra != "" { + extras, err := readConversations(cfg.Extra) if err != nil { - log.Fatalf("read extra conversations: %v", err) + return fmt.Errorf("read extra conversations: %w", err) } conversations = append(conversations, extras...) - log.Printf("loaded %d extra conversations from %s", len(extras), *extra) + log.Printf("loaded %d extra conversations from %s", len(extras), cfg.Extra) } // 3. Golden set responses converted to single-turn format. 
var goldenResponses []Response - if *dbPath != "" && *golden == "" { - db, err := OpenDB(*dbPath) + if cfg.DB != "" && cfg.Golden == "" { + db, err := OpenDB(cfg.DB) if err != nil { - log.Fatalf("open db: %v", err) + return fmt.Errorf("open db: %w", err) } defer db.Close() - rows, err := db.QueryGoldenSet(*minChars) + rows, err := db.QueryGoldenSet(cfg.MinChars) if err != nil { - log.Fatalf("query golden_set: %v", err) + return fmt.Errorf("query golden_set: %w", err) } for _, r := range rows { goldenResponses = append(goldenResponses, Response{ @@ -101,32 +93,32 @@ func RunConv(args []string) { Model: r.Voice, }) } - log.Printf("loaded %d golden set rows from %s", len(goldenResponses), *dbPath) - } else if *golden != "" { + log.Printf("loaded %d golden set rows from %s", len(goldenResponses), cfg.DB) + } else if cfg.Golden != "" { var err error - goldenResponses, err = ReadResponses(*golden) + goldenResponses, err = ReadResponses(cfg.Golden) if err != nil { - log.Fatalf("read golden set: %v", err) + return fmt.Errorf("read golden set: %w", err) } - log.Printf("loaded %d golden set responses from %s", len(goldenResponses), *golden) + log.Printf("loaded %d golden set responses from %s", len(goldenResponses), cfg.Golden) } if len(goldenResponses) > 0 { - converted := convertToConversations(goldenResponses, *minChars) + converted := convertToConversations(goldenResponses, cfg.MinChars) conversations = append(conversations, converted...) log.Printf("converted %d golden set responses to single-turn conversations", len(converted)) } if len(conversations) == 0 { - log.Fatal("no conversations to process — use built-in seeds, --extra, --golden, or --db") + return fmt.Errorf("no conversations to process — use built-in seeds, --extra, --golden, or --db") } // Split into train/valid/test. 
- train, valid, test := splitConversations(conversations, *trainPct, *validPct, *testPct, *seed) + train, valid, test := splitConversations(conversations, cfg.TrainPct, cfg.ValidPct, cfg.TestPct, cfg.Seed) // Create output directory. - if err := os.MkdirAll(*outputDir, 0755); err != nil { - log.Fatalf("create output dir: %v", err) + if err := os.MkdirAll(cfg.OutputDir, 0755); err != nil { + return fmt.Errorf("create output dir: %w", err) } // Write output files. @@ -138,9 +130,9 @@ func RunConv(args []string) { {"valid.jsonl", valid}, {"test.jsonl", test}, } { - path := *outputDir + "/" + split.name + path := cfg.OutputDir + "/" + split.name if err := writeConversationJSONL(path, split.data); err != nil { - log.Fatalf("write %s: %v", split.name, err) + return fmt.Errorf("write %s: %w", split.name, err) } } @@ -168,16 +160,18 @@ func RunConv(args []string) { fmt.Printf(" %d total conversations\n", len(conversations)) fmt.Printf(" %d total turns (%.1f avg per conversation)\n", totalTurns, avgTurns) fmt.Printf(" %.0f words avg per assistant response\n", avgWords) - fmt.Printf(" Output: %s/\n", *outputDir) + fmt.Printf(" Output: %s/\n", cfg.OutputDir) // Report to InfluxDB if configured. - influx := NewInfluxClient(*influxURL, *influxDB) + influx := NewInfluxClient(cfg.Influx, cfg.InfluxDB) line := fmt.Sprintf("conv_export,worker=%s total=%di,train=%di,valid=%di,test=%di,turns=%di,avg_turns=%f,avg_words=%f", - escapeLp(*worker), len(conversations), len(train), len(valid), len(test), + escapeLp(cfg.Worker), len(conversations), len(train), len(valid), len(test), totalTurns, avgTurns, avgWords) if err := influx.WriteLp([]string{line}); err != nil { log.Printf("influx write (best-effort): %v", err) } + + return nil } // readConversations reads multi-turn conversations from a JSONL file. 
diff --git a/pkg/lem/convert.go b/pkg/lem/convert.go index de6a9c6..ed1e894 100644 --- a/pkg/lem/convert.go +++ b/pkg/lem/convert.go @@ -3,7 +3,6 @@ package lem import ( "encoding/binary" "encoding/json" - "flag" "fmt" "log" "math" @@ -15,34 +14,30 @@ import ( "strings" ) +// ConvertOpts holds configuration for the MLX-to-PEFT conversion command. +type ConvertOpts struct { + Input string // Path to MLX .safetensors file (required) + Config string // Path to MLX adapter_config.json (required) + Output string // Output directory for PEFT adapter + BaseModel string // HuggingFace base model ID +} + // RunConvert is the CLI entry point for the convert command. // Converts MLX LoRA adapters to HuggingFace PEFT format: // - Key renaming: model.layers.N.module.lora_a → base_model.model.model.layers.N.module.lora_A.default.weight // - Transpose: MLX (in, rank) → PEFT (rank, in) // - Config generation: adapter_config.json with lora_alpha = scale × rank -func RunConvert(args []string) { - fs := flag.NewFlagSet("convert", flag.ExitOnError) - - safetensorsPath := fs.String("input", "", "Path to MLX .safetensors file (required)") - configPath := fs.String("config", "", "Path to MLX adapter_config.json (required)") - outputDir := fs.String("output", "./peft_output", "Output directory for PEFT adapter") - baseModel := fs.String("base-model", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "HuggingFace base model ID") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +func RunConvert(cfg ConvertOpts) error { + if cfg.Input == "" || cfg.Config == "" { + return fmt.Errorf("--input and --config are required") } - if *safetensorsPath == "" || *configPath == "" { - fmt.Fprintln(os.Stderr, "error: --input and --config are required") - fs.Usage() - os.Exit(1) + if err := convertMLXtoPEFT(cfg.Input, cfg.Config, cfg.Output, cfg.BaseModel); err != nil { + return fmt.Errorf("convert: %w", err) } - if err := convertMLXtoPEFT(*safetensorsPath, *configPath, *outputDir, 
*baseModel); err != nil { - log.Fatalf("convert: %v", err) - } - - fmt.Printf("Converted to: %s\n", *outputDir) + fmt.Printf("Converted to: %s\n", cfg.Output) + return nil } var ( diff --git a/pkg/lem/coverage.go b/pkg/lem/coverage.go index 8d723ef..6861ab1 100644 --- a/pkg/lem/coverage.go +++ b/pkg/lem/coverage.go @@ -1,40 +1,34 @@ package lem import ( - "flag" "fmt" - "log" "os" "strings" ) -// RunCoverage is the CLI entry point for the coverage command. -// Analyzes seed coverage and shows underrepresented areas. -func RunCoverage(args []string) { - fs := flag.NewFlagSet("coverage", flag.ExitOnError) - dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)") +// CoverageOpts holds configuration for the coverage command. +type CoverageOpts struct { + DB string // DuckDB database path (defaults to LEM_DB env) +} - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +// RunCoverage analyses seed coverage and shows underrepresented areas. +func RunCoverage(cfg CoverageOpts) error { + if cfg.DB == "" { + cfg.DB = os.Getenv("LEM_DB") + } + if cfg.DB == "" { + return fmt.Errorf("--db or LEM_DB required") } - if *dbPath == "" { - *dbPath = os.Getenv("LEM_DB") - } - if *dbPath == "" { - fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required") - os.Exit(1) - } - - db, err := OpenDB(*dbPath) + db, err := OpenDB(cfg.DB) if err != nil { - log.Fatalf("open db: %v", err) + return fmt.Errorf("open db: %w", err) } defer db.Close() var total int if err := db.conn.QueryRow("SELECT count(*) FROM seeds").Scan(&total); err != nil { - log.Fatalf("No seeds table. 
Run: lem import-all first") + return fmt.Errorf("no seeds table — run: lem import-all first") } fmt.Println("LEM Seed Coverage Analysis") @@ -65,7 +59,7 @@ func RunCoverage(args []string) { FROM seeds GROUP BY lang_group ORDER BY n ASC `) if err != nil { - log.Fatalf("query regions: %v", err) + return fmt.Errorf("query regions: %w", err) } type regionRow struct { @@ -129,6 +123,8 @@ func RunCoverage(args []string) { fmt.Println(" - Hindi/Urdu, Bengali, Tamil (South Asian)") fmt.Println(" - Swahili, Yoruba, Amharic (Sub-Saharan Africa)") fmt.Println(" - Indigenous languages (Quechua, Nahuatl, Aymara)") + + return nil } // PrintScoreAnalytics prints score distribution statistics and gap analysis diff --git a/pkg/lem/distill.go b/pkg/lem/distill.go index 41da31f..4e2b839 100644 --- a/pkg/lem/distill.go +++ b/pkg/lem/distill.go @@ -3,7 +3,6 @@ package lem import ( "context" "encoding/json" - "flag" "fmt" "log" "os" @@ -33,67 +32,66 @@ type distillCandidate struct { Elapsed time.Duration } +// DistillOpts holds configuration for the distill command. +type DistillOpts struct { + Model string // Model config path (relative to .core/ai/models/) + Probes string // Probe set name from probes.yaml, or path to JSON file + Output string // Output JSONL path (defaults to model training dir) + Lesson int // Lesson number to append to (defaults to probe set phase) + MinScore float64 // Min grammar composite (0 = use ai.yaml default) + Runs int // Generations per probe (0 = use ai.yaml default) + DryRun bool // Show plan and exit without generating + Root string // Project root (for .core/ai/ config) + CacheLimit int // Metal cache limit in GB (0 = use ai.yaml default) + MemLimit int // Metal memory limit in GB (0 = use ai.yaml default) +} + // RunDistill is the CLI entry point for the distill command. // Generates responses via native Metal inference, scores with go-i18n/reversal, // writes passing examples as training JSONL. 
-func RunDistill(args []string) { - fs := flag.NewFlagSet("distill", flag.ExitOnError) - - modelFlag := fs.String("model", "gemma3/27b", "Model config path (relative to .core/ai/models/)") - probesFlag := fs.String("probes", "", "Probe set name from probes.yaml, or path to JSON file") - outputFlag := fs.String("output", "", "Output JSONL path (defaults to model training dir)") - lessonFlag := fs.Int("lesson", -1, "Lesson number to append to (defaults to probe set phase)") - minScore := fs.Float64("min-score", 0, "Min grammar composite (0 = use ai.yaml default)") - runs := fs.Int("runs", 0, "Generations per probe (0 = use ai.yaml default)") - dryRun := fs.Bool("dry-run", false, "Show plan and exit without generating") - root := fs.String("root", ".", "Project root (for .core/ai/ config)") - cacheLimit := fs.Int("cache-limit", 0, "Metal cache limit in GB (0 = use ai.yaml default)") - memLimit := fs.Int("mem-limit", 0, "Metal memory limit in GB (0 = use ai.yaml default)") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) - } - +func RunDistill(cfg DistillOpts) error { // Load configs. - aiCfg, err := LoadAIConfig(*root) + aiCfg, err := LoadAIConfig(cfg.Root) if err != nil { - log.Fatalf("load ai config: %v", err) + return fmt.Errorf("load ai config: %w", err) } - modelCfg, err := LoadModelConfig(*root, *modelFlag) + modelCfg, err := LoadModelConfig(cfg.Root, cfg.Model) if err != nil { - log.Fatalf("load model config: %v", err) + return fmt.Errorf("load model config: %w", err) } genCfg := MergeGenerate(aiCfg.Generate, modelCfg.Generate) // Apply flag overrides. 
- if *minScore == 0 { - *minScore = aiCfg.Scorer.MinScore + minScore := cfg.MinScore + if minScore == 0 { + minScore = aiCfg.Scorer.MinScore } - if *runs == 0 { - *runs = aiCfg.Distill.Runs + runs := cfg.Runs + if runs == 0 { + runs = aiCfg.Distill.Runs } cacheLimitGB := aiCfg.Distill.CacheLimit - if *cacheLimit > 0 { - cacheLimitGB = *cacheLimit + if cfg.CacheLimit > 0 { + cacheLimitGB = cfg.CacheLimit } memLimitGB := aiCfg.Distill.MemoryLimit - if *memLimit > 0 { - memLimitGB = *memLimit + if cfg.MemLimit > 0 { + memLimitGB = cfg.MemLimit } // Load probes. - probes, phase, err := loadDistillProbes(*root, *probesFlag) + probes, phase, err := loadDistillProbes(cfg.Root, cfg.Probes) if err != nil { - log.Fatalf("load probes: %v", err) + return fmt.Errorf("load probes: %w", err) } log.Printf("loaded %d probes", len(probes)) // Determine output path. - outputPath := *outputFlag + outputPath := cfg.Output if outputPath == "" { - lesson := *lessonFlag + lesson := cfg.Lesson if lesson < 0 { lesson = phase } @@ -104,31 +102,35 @@ func RunDistill(args []string) { outputPath = filepath.Join(modelCfg.Training, lessonFile) } - // Load kernel. - kernel, err := os.ReadFile(modelCfg.Kernel) - if err != nil { - log.Fatalf("read kernel: %v", err) - } - log.Printf("kernel: %d chars from %s", len(kernel), modelCfg.Kernel) - - // Load signature (LEK-1-Sig). + // Load kernel and signature — only if the model config specifies them. + // LEM models have axioms in weights; loading the kernel violates A4. 
+ var kernel []byte var sig string - if modelCfg.Signature != "" { - sigBytes, err := os.ReadFile(modelCfg.Signature) + if modelCfg.Kernel != "" { + kernel, err = os.ReadFile(modelCfg.Kernel) if err != nil { - log.Fatalf("read signature: %v", err) + return fmt.Errorf("read kernel: %w", err) } - sig = strings.TrimSpace(string(sigBytes)) - log.Printf("signature: %d chars from %s", len(sig), modelCfg.Signature) + log.Printf("kernel: %d chars from %s (sandwich output)", len(kernel), modelCfg.Kernel) + + if modelCfg.Signature != "" { + sigBytes, err := os.ReadFile(modelCfg.Signature) + if err != nil { + return fmt.Errorf("read signature: %w", err) + } + sig = strings.TrimSpace(string(sigBytes)) + } + } else { + log.Printf("no kernel configured — bare output (LEM model)") } // Dry run. - if *dryRun { + if cfg.DryRun { fmt.Printf("Model: %s (%s)\n", modelCfg.Name, modelCfg.Paths.Base) fmt.Printf("Backend: %s\n", aiCfg.Backend) fmt.Printf("Probes: %d\n", len(probes)) - fmt.Printf("Runs: %d per probe (%d total generations)\n", *runs, len(probes)**runs) - fmt.Printf("Gate: grammar v3 composite >= %.1f\n", *minScore) + fmt.Printf("Runs: %d per probe (%d total generations)\n", runs, len(probes)*runs) + fmt.Printf("Gate: grammar v3 composite >= %.1f\n", minScore) fmt.Printf("Generate: temp=%.2f max_tokens=%d top_p=%.2f\n", genCfg.Temperature, genCfg.MaxTokens, genCfg.TopP) fmt.Printf("Memory: cache=%dGB limit=%dGB\n", cacheLimitGB, memLimitGB) @@ -145,7 +147,7 @@ func RunDistill(args []string) { } fmt.Printf(" %s: %s\n", p.ID, prompt) } - return + return nil } // Set Metal memory limits before loading model. @@ -162,7 +164,7 @@ func RunDistill(args []string) { log.Printf("loading model: %s", modelCfg.Paths.Base) backend, err := ml.NewMLXBackend(modelCfg.Paths.Base) if err != nil { - log.Fatalf("load model: %v", err) + return fmt.Errorf("load model: %w", err) } defer backend.Close() @@ -183,7 +185,7 @@ func RunDistill(args []string) { // Open output for append. 
out, err := os.OpenFile(outputPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { - log.Fatalf("open output: %v", err) + return fmt.Errorf("open output: %w", err) } defer out.Close() @@ -200,15 +202,20 @@ func RunDistill(args []string) { for i, probe := range probes { var best *distillCandidate - // Build sandwich prompt for output: LEK-1 + Prompt + LEK-1-Sig - sandwichPrompt := kernelStr + "\n\n" + probe.Prompt - if sig != "" { - sandwichPrompt += "\n\n" + sig + // Build output prompt: sandwich if kernel configured, bare if LEM model. + var outputPrompt string + if kernelStr != "" { + outputPrompt = kernelStr + "\n\n" + probe.Prompt + if sig != "" { + outputPrompt += "\n\n" + sig + } + } else { + outputPrompt = probe.Prompt } - for run := range *runs { + for run := range runs { fmt.Fprintf(os.Stderr, " [%d/%d] %s run %d/%d", - i+1, len(probes), probe.ID, run+1, *runs) + i+1, len(probes), probe.ID, run+1, runs) // Inference uses bare probe — the model generates from its weights. // Sandwich wrapping is only for the training output format. @@ -277,7 +284,7 @@ func RunDistill(args []string) { } // Quality gate. - if best != nil && best.Grammar.Composite >= *minScore { + if best != nil && best.Grammar.Composite >= minScore { // Duplicate filter: reject if grammar profile is too similar to an already-kept entry. bestFeatures := GrammarFeatures(best.Grammar) if dedupIdx != nil && dedupIdx.IsDuplicate(bestFeatures, 0.02) { @@ -288,10 +295,10 @@ func RunDistill(args []string) { continue } - // Save with sandwich prompt — kernel wraps the bare probe for training. + // Save with output prompt — sandwich if kernel, bare if LEM model. 
example := TrainingExample{ Messages: []ChatMessage{ - {Role: "user", Content: sandwichPrompt}, + {Role: "user", Content: outputPrompt}, {Role: "assistant", Content: best.Response}, }, } @@ -318,7 +325,7 @@ func RunDistill(args []string) { score = best.Grammar.Composite } fmt.Fprintf(os.Stderr, " ✗ SKIP %s (best g=%.1f < %.1f)\n", - probe.ID, score, *minScore) + probe.ID, score, minScore) } // Release GPU memory between probes to prevent incremental leak. @@ -330,8 +337,8 @@ func RunDistill(args []string) { fmt.Fprintf(os.Stderr, "\n=== Distillation Complete ===\n") fmt.Fprintf(os.Stderr, "Model: %s (%s)\n", modelCfg.Name, backend.Name()) fmt.Fprintf(os.Stderr, "Probes: %d\n", len(probes)) - fmt.Fprintf(os.Stderr, "Runs: %d per probe (%d total generations)\n", *runs, len(probes)**runs) - fmt.Fprintf(os.Stderr, "Scorer: go-i18n/reversal grammar v3, gate >= %.1f\n", *minScore) + fmt.Fprintf(os.Stderr, "Runs: %d per probe (%d total generations)\n", runs, len(probes)*runs) + fmt.Fprintf(os.Stderr, "Scorer: go-i18n/reversal grammar v3, gate >= %.1f\n", minScore) fmt.Fprintf(os.Stderr, "Kept: %d\n", kept) fmt.Fprintf(os.Stderr, "Deduped: %d\n", deduped) fmt.Fprintf(os.Stderr, "Skipped: %d\n", skipped) @@ -341,6 +348,7 @@ func RunDistill(args []string) { } fmt.Fprintf(os.Stderr, "Output: %s\n", outputPath) fmt.Fprintf(os.Stderr, "Duration: %.0fs (%.1fm)\n", duration.Seconds(), duration.Minutes()) + return nil } // loadDistillProbes loads probes from a named set or a file path. diff --git a/pkg/lem/expand.go b/pkg/lem/expand.go index c915637..258cf71 100644 --- a/pkg/lem/expand.go +++ b/pkg/lem/expand.go @@ -2,7 +2,6 @@ package lem import ( "encoding/json" - "flag" "fmt" "log" "os" @@ -22,68 +21,61 @@ type expandOutput struct { Chars int `json:"chars"` } -// runExpand parses CLI flags and runs the expand command. -func RunExpand(args []string) { - fs := flag.NewFlagSet("expand", flag.ExitOnError) +// ExpandOpts holds configuration for the expand command. 
+type ExpandOpts struct { + Model string // Model name for generation (required) + DB string // DuckDB database path (primary prompt source) + Prompts string // Input JSONL file with expansion prompts (fallback) + APIURL string // OpenAI-compatible API URL + Worker string // Worker hostname (defaults to os.Hostname()) + Limit int // Max prompts to process (0 = all) + Output string // Output directory for JSONL files + Influx string // InfluxDB URL + InfluxDB string // InfluxDB database name + DryRun bool // Print plan and exit without generating +} - model := fs.String("model", "", "Model name for generation (required)") - dbPath := fs.String("db", "", "DuckDB database path (primary prompt source)") - prompts := fs.String("prompts", "", "Input JSONL file with expansion prompts (fallback)") - apiURL := fs.String("api-url", "http://10.69.69.108:8090", "OpenAI-compatible API URL") - worker := fs.String("worker", "", "Worker hostname (defaults to os.Hostname())") - limit := fs.Int("limit", 0, "Max prompts to process (0 = all)") - output := fs.String("output", ".", "Output directory for JSONL files") - influxURL := fs.String("influx", "", "InfluxDB URL (default http://10.69.69.165:8181)") - influxDB := fs.String("influx-db", "", "InfluxDB database name (default training)") - dryRun := fs.Bool("dry-run", false, "Print plan and exit without generating") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) - } - - if *model == "" { - fmt.Fprintln(os.Stderr, "error: --model is required") - fs.Usage() - os.Exit(1) +// RunExpand generates expansion responses via a trained LEM model. +func RunExpand(cfg ExpandOpts) error { + if cfg.Model == "" { + return fmt.Errorf("--model is required") } // Check LEM_DB env as default for --db. 
- if *dbPath == "" { - *dbPath = os.Getenv("LEM_DB") + if cfg.DB == "" { + cfg.DB = os.Getenv("LEM_DB") } - if *dbPath == "" && *prompts == "" { - fmt.Fprintln(os.Stderr, "error: --db or --prompts is required (set LEM_DB env for default)") - fs.Usage() - os.Exit(1) + if cfg.DB == "" && cfg.Prompts == "" { + return fmt.Errorf("--db or --prompts is required (set LEM_DB env for default)") } // Default worker to hostname. - if *worker == "" { + if cfg.Worker == "" { hostname, err := os.Hostname() if err != nil { hostname = "unknown" } - *worker = hostname + cfg.Worker = hostname } // Load prompts from DuckDB or JSONL. var promptList []Response var duckDB *DB - if *dbPath != "" { + if cfg.DB != "" { var err error - duckDB, err = OpenDBReadWrite(*dbPath) + duckDB, err = OpenDBReadWrite(cfg.DB) if err != nil { - log.Fatalf("open db: %v", err) + return fmt.Errorf("open db: %w", err) } defer duckDB.Close() - rows, err := duckDB.QueryExpansionPrompts("pending", *limit) + rows, err := duckDB.QueryExpansionPrompts("pending", cfg.Limit) if err != nil { - log.Fatalf("query expansion_prompts: %v", err) + return fmt.Errorf("query expansion_prompts: %w", err) } - log.Printf("loaded %d pending prompts from %s", len(rows), *dbPath) + log.Printf("loaded %d pending prompts from %s", len(rows), cfg.DB) for _, r := range rows { prompt := r.Prompt @@ -98,21 +90,22 @@ func RunExpand(args []string) { } } else { var err error - promptList, err = ReadResponses(*prompts) + promptList, err = ReadResponses(cfg.Prompts) if err != nil { - log.Fatalf("read prompts: %v", err) + return fmt.Errorf("read prompts: %w", err) } - log.Printf("loaded %d prompts from %s", len(promptList), *prompts) + log.Printf("loaded %d prompts from %s", len(promptList), cfg.Prompts) } // Create clients. 
- client := NewClient(*apiURL, *model) + client := NewClient(cfg.APIURL, cfg.Model) client.MaxTokens = 2048 - influx := NewInfluxClient(*influxURL, *influxDB) + influx := NewInfluxClient(cfg.Influx, cfg.InfluxDB) - if err := expandPrompts(client, influx, duckDB, promptList, *model, *worker, *output, *dryRun, *limit); err != nil { - log.Fatalf("expand: %v", err) + if err := expandPrompts(client, influx, duckDB, promptList, cfg.Model, cfg.Worker, cfg.Output, cfg.DryRun, cfg.Limit); err != nil { + return fmt.Errorf("expand: %w", err) } + return nil } // getCompletedIDs queries InfluxDB for prompt IDs that have already been diff --git a/pkg/lem/expand_status.go b/pkg/lem/expand_status.go index 62be3e9..2a715ee 100644 --- a/pkg/lem/expand_status.go +++ b/pkg/lem/expand_status.go @@ -1,33 +1,27 @@ package lem import ( - "flag" "fmt" - "log" "os" ) -// RunExpandStatus is the CLI entry point for the expand-status command. -// Shows the expansion pipeline progress from DuckDB. -func RunExpandStatus(args []string) { - fs := flag.NewFlagSet("expand-status", flag.ExitOnError) - dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)") +// ExpandStatusOpts holds configuration for the expand-status command. +type ExpandStatusOpts struct { + DB string // DuckDB database path (defaults to LEM_DB env) +} - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +// RunExpandStatus shows the expansion pipeline progress from DuckDB. 
+func RunExpandStatus(cfg ExpandStatusOpts) error { + if cfg.DB == "" { + cfg.DB = os.Getenv("LEM_DB") + } + if cfg.DB == "" { + return fmt.Errorf("--db or LEM_DB required") } - if *dbPath == "" { - *dbPath = os.Getenv("LEM_DB") - } - if *dbPath == "" { - fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required") - os.Exit(1) - } - - db, err := OpenDB(*dbPath) + db, err := OpenDB(cfg.DB) if err != nil { - log.Fatalf("open db: %v", err) + return fmt.Errorf("open db: %w", err) } defer db.Close() @@ -39,8 +33,7 @@ func RunExpandStatus(args []string) { err = db.conn.QueryRow("SELECT count(*) FROM expansion_prompts").Scan(&epTotal) if err != nil { fmt.Println(" Expansion prompts: not created (run: lem normalize)") - db.Close() - return + return nil } db.conn.QueryRow("SELECT count(*) FROM expansion_prompts WHERE status = 'pending'").Scan(&epPending) fmt.Printf(" Expansion prompts: %d total, %d pending\n", epTotal, epPending) @@ -100,4 +93,6 @@ func RunExpandStatus(args []string) { fmt.Printf(" Combined: %d total examples\n", golden+generated) } } + + return nil } diff --git a/pkg/lem/export.go b/pkg/lem/export.go index 3ad1ab3..0ecb46e 100644 --- a/pkg/lem/export.go +++ b/pkg/lem/export.go @@ -3,7 +3,6 @@ package lem import ( "bufio" "encoding/json" - "flag" "fmt" "log" "math/rand" @@ -11,6 +10,18 @@ import ( "strings" ) +// ExportOpts holds configuration for the JSONL export command. +type ExportOpts struct { + DBPath string // DuckDB database path (primary source); falls back to LEM_DB env + Input string // Input golden set JSONL file (fallback if DBPath not set) + OutputDir string // Output directory for training files (required) + TrainPct int // Training set percentage + ValidPct int // Validation set percentage + TestPct int // Test set percentage + Seed int64 // Random seed for shuffling + MinChars int // Minimum response character count +} + // ChatMessage is a single message in the chat training format. 
type ChatMessage struct { Role string `json:"role"` @@ -22,60 +33,40 @@ type TrainingExample struct { Messages []ChatMessage `json:"messages"` } -// runExport is the CLI entry point for the export command. -func RunExport(args []string) { - fs := flag.NewFlagSet("export", flag.ExitOnError) - - dbPath := fs.String("db", "", "DuckDB database path (primary source)") - input := fs.String("input", "", "Input golden set JSONL file (fallback if --db not set)") - outputDir := fs.String("output-dir", "", "Output directory for training files (required)") - trainPct := fs.Int("train-pct", 90, "Training set percentage") - validPct := fs.Int("valid-pct", 5, "Validation set percentage") - testPct := fs.Int("test-pct", 5, "Test set percentage") - seed := fs.Int64("seed", 42, "Random seed for shuffling") - minChars := fs.Int("min-chars", 50, "Minimum response character count") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) - } - +// RunExport is the CLI entry point for the export command. +func RunExport(cfg ExportOpts) error { // Check LEM_DB env as default for --db. 
- if *dbPath == "" { - *dbPath = os.Getenv("LEM_DB") + if cfg.DBPath == "" { + cfg.DBPath = os.Getenv("LEM_DB") } - if *dbPath == "" && *input == "" { - fmt.Fprintln(os.Stderr, "error: --db or --input is required (set LEM_DB env for default)") - fs.Usage() - os.Exit(1) + if cfg.DBPath == "" && cfg.Input == "" { + return fmt.Errorf("--db or --input is required (set LEM_DB env for default)") } - if *outputDir == "" { - fmt.Fprintln(os.Stderr, "error: --output-dir is required") - fs.Usage() - os.Exit(1) + if cfg.OutputDir == "" { + return fmt.Errorf("--output-dir is required") } - if err := validatePercentages(*trainPct, *validPct, *testPct); err != nil { - fmt.Fprintf(os.Stderr, "error: %v\n", err) - os.Exit(1) + if err := validatePercentages(cfg.TrainPct, cfg.ValidPct, cfg.TestPct); err != nil { + return fmt.Errorf("invalid percentages: %w", err) } var responses []Response - if *dbPath != "" { + if cfg.DBPath != "" { // Primary: read from DuckDB golden_set table. - db, err := OpenDB(*dbPath) + db, err := OpenDB(cfg.DBPath) if err != nil { - log.Fatalf("open db: %v", err) + return fmt.Errorf("open db: %w", err) } defer db.Close() - rows, err := db.QueryGoldenSet(*minChars) + rows, err := db.QueryGoldenSet(cfg.MinChars) if err != nil { - log.Fatalf("query golden_set: %v", err) + return fmt.Errorf("query golden_set: %w", err) } - log.Printf("loaded %d golden set rows from %s (min_chars=%d)", len(rows), *dbPath, *minChars) + log.Printf("loaded %d golden set rows from %s (min_chars=%d)", len(rows), cfg.DBPath, cfg.MinChars) // Convert GoldenSetRow → Response for the shared pipeline. for _, r := range rows { @@ -90,11 +81,11 @@ func RunExport(args []string) { } else { // Fallback: read from JSONL file. 
var err error - responses, err = ReadResponses(*input) + responses, err = ReadResponses(cfg.Input) if err != nil { - log.Fatalf("read responses: %v", err) + return fmt.Errorf("read responses: %w", err) } - log.Printf("loaded %d responses from %s", len(responses), *input) + log.Printf("loaded %d responses from %s", len(responses), cfg.Input) } // Filter out bad responses (DuckDB already filters by char_count, but @@ -103,11 +94,11 @@ func RunExport(args []string) { log.Printf("filtered to %d valid responses (removed %d)", len(filtered), len(responses)-len(filtered)) // Split into train/valid/test. - train, valid, test := splitData(filtered, *trainPct, *validPct, *testPct, *seed) + train, valid, test := splitData(filtered, cfg.TrainPct, cfg.ValidPct, cfg.TestPct, cfg.Seed) // Create output directory. - if err := os.MkdirAll(*outputDir, 0755); err != nil { - log.Fatalf("create output dir: %v", err) + if err := os.MkdirAll(cfg.OutputDir, 0755); err != nil { + return fmt.Errorf("create output dir: %w", err) } // Write output files. @@ -119,13 +110,14 @@ func RunExport(args []string) { {"valid.jsonl", valid}, {"test.jsonl", test}, } { - path := *outputDir + "/" + split.name + path := cfg.OutputDir + "/" + split.name if err := writeTrainingJSONL(path, split.data); err != nil { - log.Fatalf("write %s: %v", split.name, err) + return fmt.Errorf("write %s: %w", split.name, err) } } fmt.Printf("Exported: %d train / %d valid / %d test\n", len(train), len(valid), len(test)) + return nil } // validatePercentages checks that train+valid+test percentages sum to 100 diff --git a/pkg/lem/export_test.go b/pkg/lem/export_test.go index faecc77..9980eb5 100644 --- a/pkg/lem/export_test.go +++ b/pkg/lem/export_test.go @@ -332,15 +332,17 @@ func TestExportEndToEnd(t *testing.T) { } // Run export with 80/10/10 split. 
- args := []string{ - "--input", inputPath, - "--output-dir", outputDir, - "--train-pct", "80", - "--valid-pct", "10", - "--test-pct", "10", - "--seed", "42", + err := RunExport(ExportOpts{ + Input: inputPath, + OutputDir: outputDir, + TrainPct: 80, + ValidPct: 10, + TestPct: 10, + Seed: 42, + }) + if err != nil { + t.Fatalf("RunExport failed: %v", err) } - RunExport(args) // Verify output files exist. for _, name := range []string{"train.jsonl", "valid.jsonl", "test.jsonl"} { diff --git a/pkg/lem/import.go b/pkg/lem/import.go index 540470f..c2dce23 100644 --- a/pkg/lem/import.go +++ b/pkg/lem/import.go @@ -3,7 +3,6 @@ package lem import ( "bufio" "encoding/json" - "flag" "fmt" "log" "os" @@ -12,42 +11,41 @@ import ( "strings" ) +// ImportOpts holds configuration for the import-all command. +type ImportOpts struct { + DB string // DuckDB database path (defaults to LEM_DB env) + SkipM3 bool // Skip pulling data from M3 + DataDir string // Local data directory (defaults to db directory) +} + // RunImport is the CLI entry point for the import-all command. // Imports ALL LEM data into DuckDB: prompts, Gemini responses, golden set, // training examples, benchmarks, validations, and seeds. 
-func RunImport(args []string) { - fs := flag.NewFlagSet("import-all", flag.ExitOnError) - dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)") - skipM3 := fs.Bool("skip-m3", false, "Skip pulling data from M3") - dataDir := fs.String("data-dir", "", "Local data directory (defaults to db directory)") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +func RunImport(cfg ImportOpts) error { + dbPath := cfg.DB + if dbPath == "" { + dbPath = os.Getenv("LEM_DB") + } + if dbPath == "" { + return fmt.Errorf("--db or LEM_DB required") } - if *dbPath == "" { - *dbPath = os.Getenv("LEM_DB") - } - if *dbPath == "" { - fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required") - os.Exit(1) + dataDir := cfg.DataDir + if dataDir == "" { + dataDir = filepath.Dir(dbPath) } - if *dataDir == "" { - *dataDir = filepath.Dir(*dbPath) - } - - db, err := OpenDBReadWrite(*dbPath) + db, err := OpenDBReadWrite(dbPath) if err != nil { - log.Fatalf("open db: %v", err) + return fmt.Errorf("open db: %w", err) } defer db.Close() totals := make(map[string]int) // ── 1. 
Golden set ── - goldenPath := filepath.Join(*dataDir, "gold-15k.jsonl") - if !*skipM3 { + goldenPath := filepath.Join(dataDir, "gold-15k.jsonl") + if !cfg.SkipM3 { fmt.Println(" Pulling golden set from M3...") scpCmd := exec.Command("scp", "m3:/Volumes/Data/lem/responses/gold-15k.jsonl", goldenPath) if err := scpCmd.Run(); err != nil { @@ -101,10 +99,10 @@ func RunImport(args []string) { {"russian-bridge", []string{"russian-bridge/train.jsonl", "russian-bridge/valid.jsonl"}}, } - trainingLocal := filepath.Join(*dataDir, "training") + trainingLocal := filepath.Join(dataDir, "training") os.MkdirAll(trainingLocal, 0755) - if !*skipM3 { + if !cfg.SkipM3 { fmt.Println(" Pulling training sets from M3...") for _, td := range trainingDirs { for _, rel := range td.files { @@ -152,10 +150,10 @@ func RunImport(args []string) { fmt.Printf(" training_examples: %d rows\n", trainingTotal) // ── 3. Benchmark results ── - benchLocal := filepath.Join(*dataDir, "benchmarks") + benchLocal := filepath.Join(dataDir, "benchmarks") os.MkdirAll(benchLocal, 0755) - if !*skipM3 { + if !cfg.SkipM3 { fmt.Println(" Pulling benchmarks from M3...") for _, bname := range []string{"truthfulqa", "gsm8k", "do_not_answer", "toxigen"} { scpCmd := exec.Command("scp", @@ -195,7 +193,7 @@ func RunImport(args []string) { for _, bfile := range []string{"lem_bench", "lem_ethics", "lem_ethics_allen", "instruction_tuned", "abliterated", "base_pt"} { local := filepath.Join(benchLocal, bfile+".jsonl") if _, err := os.Stat(local); os.IsNotExist(err) { - if !*skipM3 { + if !cfg.SkipM3 { scpCmd := exec.Command("scp", fmt.Sprintf("m3:/Volumes/Data/lem/benchmark/%s.jsonl", bfile), local) scpCmd.Run() @@ -238,7 +236,7 @@ func RunImport(args []string) { `) seedTotal := 0 - seedDirs := []string{filepath.Join(*dataDir, "seeds"), "/tmp/lem-data/seeds", "/tmp/lem-repo/seeds"} + seedDirs := []string{filepath.Join(dataDir, "seeds"), "/tmp/lem-data/seeds", "/tmp/lem-repo/seeds"} for _, seedDir := range seedDirs { if _, err := 
os.Stat(seedDir); os.IsNotExist(err) { continue @@ -260,7 +258,8 @@ func RunImport(args []string) { } fmt.Printf(" %s\n", strings.Repeat("─", 35)) fmt.Printf(" %-25s %8d\n", "TOTAL", grandTotal) - fmt.Printf("\nDatabase: %s\n", *dbPath) + fmt.Printf("\nDatabase: %s\n", dbPath) + return nil } func importTrainingFile(db *DB, path, source, split string) int { diff --git a/pkg/lem/ingest.go b/pkg/lem/ingest.go index b978eae..0ce83e1 100644 --- a/pkg/lem/ingest.go +++ b/pkg/lem/ingest.go @@ -3,81 +3,77 @@ package lem import ( "bufio" "encoding/json" - "flag" "fmt" - "log" "os" "regexp" "strconv" "strings" ) +// IngestOpts holds configuration for the ingest command. +type IngestOpts struct { + Content string // Content scores JSONL file + Capability string // Capability scores JSONL file + TrainingLog string // MLX LoRA training log file + Model string // Model name tag (required) + RunID string // Run ID tag (defaults to model name) + InfluxURL string // InfluxDB URL + InfluxDB string // InfluxDB database name + BatchSize int // Lines per InfluxDB write batch +} + // RunIngest is the CLI entry point for the ingest command. // It reads benchmark JSONL files and training logs, then pushes // the data into InfluxDB as line protocol for the lab dashboard. 
-func RunIngest(args []string) { - fs := flag.NewFlagSet("ingest", flag.ExitOnError) - - contentFile := fs.String("content", "", "Content scores JSONL file") - capabilityFile := fs.String("capability", "", "Capability scores JSONL file") - trainingLog := fs.String("training-log", "", "MLX LoRA training log file") - model := fs.String("model", "", "Model name tag (required)") - runID := fs.String("run-id", "", "Run ID tag (defaults to model name)") - influxURL := fs.String("influx", "", "InfluxDB URL") - influxDB := fs.String("influx-db", "", "InfluxDB database name") - batchSize := fs.Int("batch-size", 100, "Lines per InfluxDB write batch") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +func RunIngest(cfg IngestOpts) error { + if cfg.Model == "" { + return fmt.Errorf("--model is required") } - if *model == "" { - fmt.Fprintln(os.Stderr, "error: --model is required") - fs.Usage() - os.Exit(1) + if cfg.Content == "" && cfg.Capability == "" && cfg.TrainingLog == "" { + return fmt.Errorf("at least one of --content, --capability, or --training-log is required") } - if *contentFile == "" && *capabilityFile == "" && *trainingLog == "" { - fmt.Fprintln(os.Stderr, "error: at least one of --content, --capability, or --training-log is required") - fs.Usage() - os.Exit(1) + if cfg.RunID == "" { + cfg.RunID = cfg.Model } - if *runID == "" { - *runID = *model + if cfg.BatchSize <= 0 { + cfg.BatchSize = 100 } - influx := NewInfluxClient(*influxURL, *influxDB) + influx := NewInfluxClient(cfg.InfluxURL, cfg.InfluxDB) total := 0 - if *contentFile != "" { - n, err := ingestContentScores(influx, *contentFile, *model, *runID, *batchSize) + if cfg.Content != "" { + n, err := ingestContentScores(influx, cfg.Content, cfg.Model, cfg.RunID, cfg.BatchSize) if err != nil { - log.Fatalf("ingest content scores: %v", err) + return fmt.Errorf("ingest content scores: %w", err) } fmt.Printf(" Content scores: %d points\n", n) total += n } - if *capabilityFile != "" 
{ - n, err := ingestCapabilityScores(influx, *capabilityFile, *model, *runID, *batchSize) + if cfg.Capability != "" { + n, err := ingestCapabilityScores(influx, cfg.Capability, cfg.Model, cfg.RunID, cfg.BatchSize) if err != nil { - log.Fatalf("ingest capability scores: %v", err) + return fmt.Errorf("ingest capability scores: %w", err) } fmt.Printf(" Capability scores: %d points\n", n) total += n } - if *trainingLog != "" { - n, err := ingestTrainingCurve(influx, *trainingLog, *model, *runID, *batchSize) + if cfg.TrainingLog != "" { + n, err := ingestTrainingCurve(influx, cfg.TrainingLog, cfg.Model, cfg.RunID, cfg.BatchSize) if err != nil { - log.Fatalf("ingest training curve: %v", err) + return fmt.Errorf("ingest training curve: %w", err) } fmt.Printf(" Training curve: %d points\n", n) total += n } fmt.Printf("\nTotal: %d points ingested\n", total) + return nil } var iterRe = regexp.MustCompile(`@(\d+)`) diff --git a/pkg/lem/inventory.go b/pkg/lem/inventory.go index 25c816c..ce9c659 100644 --- a/pkg/lem/inventory.go +++ b/pkg/lem/inventory.go @@ -1,43 +1,37 @@ package lem import ( - "flag" "fmt" - "log" "os" "strings" ) -// RunInventory is the CLI entry point for the inventory command. -// Shows row counts and summary stats for all tables in the DuckDB database. -func RunInventory(args []string) { - fs := flag.NewFlagSet("inventory", flag.ExitOnError) - dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)") +// InventoryOpts holds configuration for the inventory command. +type InventoryOpts struct { + DB string // DuckDB database path (defaults to LEM_DB env) +} - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +// RunInventory shows row counts and summary stats for all tables in the DuckDB database. 
+func RunInventory(cfg InventoryOpts) error { + if cfg.DB == "" { + cfg.DB = os.Getenv("LEM_DB") + } + if cfg.DB == "" { + return fmt.Errorf("--db or LEM_DB required") } - if *dbPath == "" { - *dbPath = os.Getenv("LEM_DB") - } - if *dbPath == "" { - fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required") - os.Exit(1) - } - - db, err := OpenDB(*dbPath) + db, err := OpenDB(cfg.DB) if err != nil { - log.Fatalf("open db: %v", err) + return fmt.Errorf("open db: %w", err) } defer db.Close() counts, err := db.TableCounts() if err != nil { - log.Fatalf("table counts: %v", err) + return fmt.Errorf("table counts: %w", err) } - fmt.Printf("LEM Database Inventory (%s)\n", *dbPath) + fmt.Printf("LEM Database Inventory (%s)\n", cfg.DB) fmt.Println("============================================================") grandTotal := 0 @@ -84,6 +78,8 @@ func RunInventory(args []string) { fmt.Printf(" %-25s\n", "────────────────────────────────────────") fmt.Printf(" %-25s %8d\n", "TOTAL", grandTotal) + + return nil } func joinStrings(parts []string, sep string) string { diff --git a/pkg/lem/metrics.go b/pkg/lem/metrics.go index 5b701cc..1c55e5d 100644 --- a/pkg/lem/metrics.go +++ b/pkg/lem/metrics.go @@ -1,40 +1,33 @@ package lem import ( - "flag" "fmt" - "log" "os" "time" ) const targetTotal = 15000 -// RunMetrics is the CLI entry point for the metrics command. -// Reads golden set stats from DuckDB and pushes them to InfluxDB as +// MetricsOpts holds configuration for the metrics command. +type MetricsOpts struct { + DB string // DuckDB database path (defaults to LEM_DB env) + Influx string // InfluxDB URL + InfluxDB string // InfluxDB database name +} + +// RunMetrics reads golden set stats from DuckDB and pushes them to InfluxDB as // golden_set_stats, golden_set_domain, and golden_set_voice measurements. 
-func RunMetrics(args []string) { - fs := flag.NewFlagSet("metrics", flag.ExitOnError) - - dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)") - influxURL := fs.String("influx", "", "InfluxDB URL") - influxDB := fs.String("influx-db", "", "InfluxDB database name") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +func RunMetrics(cfg MetricsOpts) error { + if cfg.DB == "" { + cfg.DB = os.Getenv("LEM_DB") + } + if cfg.DB == "" { + return fmt.Errorf("--db or LEM_DB required (path to DuckDB file)") } - if *dbPath == "" { - *dbPath = os.Getenv("LEM_DB") - } - if *dbPath == "" { - fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required (path to DuckDB file)") - os.Exit(1) - } - - db, err := OpenDB(*dbPath) + db, err := OpenDB(cfg.DB) if err != nil { - log.Fatalf("open db: %v", err) + return fmt.Errorf("open db: %w", err) } defer db.Close() @@ -48,12 +41,12 @@ func RunMetrics(args []string) { FROM golden_set `).Scan(&total, &domains, &voices, &avgGenTime, &avgChars) if err != nil { - log.Fatalf("query golden_set stats: %v", err) + return fmt.Errorf("query golden_set stats: %w", err) } if total == 0 { fmt.Println("No golden set data in DuckDB.") - return + return nil } nowNs := time.Now().UTC().UnixNano() @@ -73,7 +66,7 @@ func RunMetrics(args []string) { FROM golden_set GROUP BY domain `) if err != nil { - log.Fatalf("query domains: %v", err) + return fmt.Errorf("query domains: %w", err) } domainCount := 0 for domainRows.Next() { @@ -81,7 +74,7 @@ func RunMetrics(args []string) { var n int var avgT float64 if err := domainRows.Scan(&domain, &n, &avgT); err != nil { - log.Fatalf("scan domain row: %v", err) + return fmt.Errorf("scan domain row: %w", err) } lines = append(lines, fmt.Sprintf( "golden_set_domain,domain=%s count=%di,avg_gen_time=%.2f %d", @@ -97,7 +90,7 @@ func RunMetrics(args []string) { FROM golden_set GROUP BY voice `) if err != nil { - log.Fatalf("query voices: %v", err) + return fmt.Errorf("query 
voices: %w", err) } voiceCount := 0 for voiceRows.Next() { @@ -105,7 +98,7 @@ func RunMetrics(args []string) { var n int var avgC, avgT float64 if err := voiceRows.Scan(&voice, &n, &avgC, &avgT); err != nil { - log.Fatalf("scan voice row: %v", err) + return fmt.Errorf("scan voice row: %w", err) } lines = append(lines, fmt.Sprintf( "golden_set_voice,voice=%s count=%di,avg_chars=%.0f,avg_gen_time=%.2f %d", @@ -116,11 +109,13 @@ func RunMetrics(args []string) { voiceRows.Close() // Write to InfluxDB. - influx := NewInfluxClient(*influxURL, *influxDB) + influx := NewInfluxClient(cfg.Influx, cfg.InfluxDB) if err := influx.WriteLp(lines); err != nil { - log.Fatalf("write metrics: %v", err) + return fmt.Errorf("write metrics: %w", err) } fmt.Printf("Wrote metrics to InfluxDB: %d examples, %d domains, %d voices (%d points)\n", total, domainCount, voiceCount, len(lines)) + + return nil } diff --git a/pkg/lem/normalize.go b/pkg/lem/normalize.go index bc8ac75..e5d01b5 100644 --- a/pkg/lem/normalize.go +++ b/pkg/lem/normalize.go @@ -1,50 +1,49 @@ package lem import ( - "flag" "fmt" "log" "os" ) +// NormalizeOpts holds configuration for the normalize command. +type NormalizeOpts struct { + DB string // DuckDB database path (defaults to LEM_DB env) + MinLen int // Minimum prompt length in characters +} + // RunNormalize is the CLI entry point for the normalize command. // Normalizes seeds into the expansion_prompts table, deduplicating against // the golden set and existing prompts. Assigns priority based on domain // coverage (underrepresented domains first). 
-func RunNormalize(args []string) { - fs := flag.NewFlagSet("normalize", flag.ExitOnError) - dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)") - minLen := fs.Int("min-length", 50, "Minimum prompt length in characters") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +func RunNormalize(cfg NormalizeOpts) error { + dbPath := cfg.DB + if dbPath == "" { + dbPath = os.Getenv("LEM_DB") + } + if dbPath == "" { + return fmt.Errorf("--db or LEM_DB required") } - if *dbPath == "" { - *dbPath = os.Getenv("LEM_DB") - } - if *dbPath == "" { - fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required") - os.Exit(1) - } + minLen := cfg.MinLen - db, err := OpenDBReadWrite(*dbPath) + db, err := OpenDBReadWrite(dbPath) if err != nil { - log.Fatalf("open db: %v", err) + return fmt.Errorf("open db: %w", err) } defer db.Close() // Check source tables. var seedCount int if err := db.conn.QueryRow("SELECT count(*) FROM seeds").Scan(&seedCount); err != nil { - log.Fatalf("No seeds table. Run: lem import-all first") + return fmt.Errorf("no seeds table: run lem import-all first") } fmt.Printf("Seeds table: %d rows\n", seedCount) // Drop and recreate expansion_prompts. _, err = db.conn.Exec("DROP TABLE IF EXISTS expansion_prompts") if err != nil { - log.Fatalf("drop expansion_prompts: %v", err) + return fmt.Errorf("drop expansion_prompts: %w", err) } // Deduplicate: remove seeds whose prompt already appears in prompts or golden_set. 
@@ -85,9 +84,9 @@ func RunNormalize(args []string) { SELECT 1 FROM existing_prompts ep WHERE ep.prompt = us.prompt ) - `, *minLen)) + `, minLen)) if err != nil { - log.Fatalf("create expansion_prompts: %v", err) + return fmt.Errorf("create expansion_prompts: %v", err) } var total, domains, regions int @@ -145,4 +144,5 @@ func RunNormalize(args []string) { } fmt.Printf("\nNormalization complete: %d expansion prompts from %d seeds\n", total, seedCount) + return nil } diff --git a/pkg/lem/parquet.go b/pkg/lem/parquet.go index 0d3e136..e80c60d 100644 --- a/pkg/lem/parquet.go +++ b/pkg/lem/parquet.go @@ -3,9 +3,7 @@ package lem import ( "bufio" "encoding/json" - "flag" "fmt" - "log" "os" "path/filepath" "strings" @@ -13,6 +11,12 @@ import ( "github.com/parquet-go/parquet-go" ) +// ParquetOpts holds configuration for the Parquet export command. +type ParquetOpts struct { + Input string // Directory containing train.jsonl, valid.jsonl, test.jsonl (required) + Output string // Output directory for Parquet files (defaults to input/parquet) +} + // ParquetRow is the schema for exported Parquet files. type ParquetRow struct { Prompt string `parquet:"prompt"` @@ -24,48 +28,39 @@ type ParquetRow struct { // RunParquet is the CLI entry point for the parquet command. // Reads JSONL training splits (train.jsonl, valid.jsonl, test.jsonl) and // writes Parquet files with snappy compression for HuggingFace datasets. 
-func RunParquet(args []string) { - fs := flag.NewFlagSet("parquet", flag.ExitOnError) - - trainingDir := fs.String("input", "", "Directory containing train.jsonl, valid.jsonl, test.jsonl (required)") - outputDir := fs.String("output", "", "Output directory for Parquet files (defaults to input/parquet)") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +func RunParquet(cfg ParquetOpts) error { + if cfg.Input == "" { + return fmt.Errorf("--input is required (directory with JSONL splits)") } - if *trainingDir == "" { - fmt.Fprintln(os.Stderr, "error: --input is required (directory with JSONL splits)") - fs.Usage() - os.Exit(1) + outputDir := cfg.Output + if outputDir == "" { + outputDir = filepath.Join(cfg.Input, "parquet") } - if *outputDir == "" { - *outputDir = filepath.Join(*trainingDir, "parquet") + if err := os.MkdirAll(outputDir, 0755); err != nil { + return fmt.Errorf("create output dir: %w", err) } - if err := os.MkdirAll(*outputDir, 0755); err != nil { - log.Fatalf("create output dir: %v", err) - } - - fmt.Printf("Exporting Parquet from %s → %s\n", *trainingDir, *outputDir) + fmt.Printf("Exporting Parquet from %s → %s\n", cfg.Input, outputDir) total := 0 for _, split := range []string{"train", "valid", "test"} { - jsonlPath := filepath.Join(*trainingDir, split+".jsonl") + jsonlPath := filepath.Join(cfg.Input, split+".jsonl") if _, err := os.Stat(jsonlPath); os.IsNotExist(err) { fmt.Printf(" Skip: %s.jsonl not found\n", split) continue } - n, err := exportSplitParquet(jsonlPath, *outputDir, split) + n, err := exportSplitParquet(jsonlPath, outputDir, split) if err != nil { - log.Fatalf("export %s: %v", split, err) + return fmt.Errorf("export %s: %w", split, err) } total += n } fmt.Printf("\nTotal: %d rows exported\n", total) + return nil } // exportSplitParquet reads a JSONL file and writes a Parquet file for the split. 
diff --git a/pkg/lem/publish.go b/pkg/lem/publish.go index 08170b6..6e09b94 100644 --- a/pkg/lem/publish.go +++ b/pkg/lem/publish.go @@ -2,10 +2,8 @@ package lem import ( "bytes" - "flag" "fmt" "io" - "log" "net/http" "os" "path/filepath" @@ -13,28 +11,23 @@ import ( "time" ) +// PublishOpts holds configuration for the HuggingFace publish command. +type PublishOpts struct { + Input string // Directory containing Parquet files (required) + Repo string // HuggingFace dataset repo ID + Public bool // Make dataset public + Token string // HuggingFace API token (defaults to HF_TOKEN env) + DryRun bool // Show what would be uploaded without uploading +} + // RunPublish is the CLI entry point for the publish command. // Pushes Parquet files and an optional dataset card to HuggingFace. -func RunPublish(args []string) { - fs := flag.NewFlagSet("publish", flag.ExitOnError) - - inputDir := fs.String("input", "", "Directory containing Parquet files (required)") - repoID := fs.String("repo", "lthn/LEM-golden-set", "HuggingFace dataset repo ID") - public := fs.Bool("public", false, "Make dataset public") - token := fs.String("token", "", "HuggingFace API token (defaults to HF_TOKEN env)") - dryRun := fs.Bool("dry-run", false, "Show what would be uploaded without uploading") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +func RunPublish(cfg PublishOpts) error { + if cfg.Input == "" { + return fmt.Errorf("--input is required (directory with Parquet files)") } - if *inputDir == "" { - fmt.Fprintln(os.Stderr, "error: --input is required (directory with Parquet files)") - fs.Usage() - os.Exit(1) - } - - hfToken := *token + hfToken := cfg.Token if hfToken == "" { hfToken = os.Getenv("HF_TOKEN") } @@ -48,9 +41,8 @@ func RunPublish(args []string) { } } - if hfToken == "" && !*dryRun { - fmt.Fprintln(os.Stderr, "error: HuggingFace token required (--token, HF_TOKEN env, or ~/.huggingface/token)") - os.Exit(1) + if hfToken == "" && !cfg.DryRun { + return 
fmt.Errorf("HuggingFace token required (--token, HF_TOKEN env, or ~/.huggingface/token)") } splits := []string{"train", "valid", "test"} @@ -61,7 +53,7 @@ func RunPublish(args []string) { var filesToUpload []uploadEntry for _, split := range splits { - path := filepath.Join(*inputDir, split+".parquet") + path := filepath.Join(cfg.Input, split+".parquet") if _, err := os.Stat(path); os.IsNotExist(err) { continue } @@ -69,19 +61,18 @@ func RunPublish(args []string) { } // Check for dataset card in parent directory. - cardPath := filepath.Join(*inputDir, "..", "dataset_card.md") + cardPath := filepath.Join(cfg.Input, "..", "dataset_card.md") if _, err := os.Stat(cardPath); err == nil { filesToUpload = append(filesToUpload, uploadEntry{cardPath, "README.md"}) } if len(filesToUpload) == 0 { - fmt.Fprintln(os.Stderr, "error: no Parquet files found in input directory") - os.Exit(1) + return fmt.Errorf("no Parquet files found in input directory") } - if *dryRun { - fmt.Printf("Dry run: would publish to %s\n", *repoID) - if *public { + if cfg.DryRun { + fmt.Printf("Dry run: would publish to %s\n", cfg.Repo) + if cfg.Public { fmt.Println(" Visibility: public") } else { fmt.Println(" Visibility: private") @@ -91,19 +82,20 @@ func RunPublish(args []string) { sizeMB := float64(info.Size()) / 1024 / 1024 fmt.Printf(" %s → %s (%.1f MB)\n", filepath.Base(f.local), f.remote, sizeMB) } - return + return nil } - fmt.Printf("Publishing to https://huggingface.co/datasets/%s\n", *repoID) + fmt.Printf("Publishing to https://huggingface.co/datasets/%s\n", cfg.Repo) for _, f := range filesToUpload { - if err := uploadFileToHF(hfToken, *repoID, f.local, f.remote); err != nil { - log.Fatalf("upload %s: %v", f.local, err) + if err := uploadFileToHF(hfToken, cfg.Repo, f.local, f.remote); err != nil { + return fmt.Errorf("upload %s: %w", f.local, err) } fmt.Printf(" Uploaded %s → %s\n", filepath.Base(f.local), f.remote) } - fmt.Printf("\nPublished to https://huggingface.co/datasets/%s\n", 
*repoID) + fmt.Printf("\nPublished to https://huggingface.co/datasets/%s\n", cfg.Repo) + return nil } // uploadFileToHF uploads a file to a HuggingFace dataset repo via the Hub API. diff --git a/pkg/lem/query.go b/pkg/lem/query.go index 35aa987..6a52cd2 100644 --- a/pkg/lem/query.go +++ b/pkg/lem/query.go @@ -2,38 +2,31 @@ package lem import ( "encoding/json" - "flag" "fmt" - "log" "os" "strings" ) +// QueryOpts holds configuration for the query command. +type QueryOpts struct { + DB string // DuckDB database path (defaults to LEM_DB env) + JSON bool // Output as JSON instead of table +} + // RunQuery is the CLI entry point for the query command. // Runs ad-hoc SQL against the DuckDB database. -func RunQuery(args []string) { - fs := flag.NewFlagSet("query", flag.ExitOnError) - dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)") - jsonOutput := fs.Bool("json", false, "Output as JSON instead of table") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +// The args slice contains the SQL query as positional arguments. 
+func RunQuery(cfg QueryOpts, args []string) error { + if cfg.DB == "" { + cfg.DB = os.Getenv("LEM_DB") + } + if cfg.DB == "" { + return fmt.Errorf("--db or LEM_DB required") } - if *dbPath == "" { - *dbPath = os.Getenv("LEM_DB") - } - if *dbPath == "" { - fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required") - os.Exit(1) - } - - sql := strings.Join(fs.Args(), " ") + sql := strings.Join(args, " ") if sql == "" { - fmt.Fprintln(os.Stderr, "error: SQL query required as positional argument") - fmt.Fprintln(os.Stderr, " lem query --db path.duckdb \"SELECT * FROM golden_set LIMIT 5\"") - fmt.Fprintln(os.Stderr, " lem query --db path.duckdb \"domain = 'ethics'\" (auto-wraps as WHERE clause)") - os.Exit(1) + return fmt.Errorf("SQL query required as positional argument\n lem query --db path.duckdb \"SELECT * FROM golden_set LIMIT 5\"\n lem query --db path.duckdb \"domain = 'ethics'\" (auto-wraps as WHERE clause)") } // Auto-wrap non-SELECT queries as WHERE clauses. @@ -43,21 +36,21 @@ func RunQuery(args []string) { sql = "SELECT * FROM golden_set WHERE " + sql + " LIMIT 20" } - db, err := OpenDB(*dbPath) + db, err := OpenDB(cfg.DB) if err != nil { - log.Fatalf("open db: %v", err) + return fmt.Errorf("open db: %w", err) } defer db.Close() rows, err := db.conn.Query(sql) if err != nil { - log.Fatalf("query: %v", err) + return fmt.Errorf("query: %w", err) } defer rows.Close() cols, err := rows.Columns() if err != nil { - log.Fatalf("columns: %v", err) + return fmt.Errorf("columns: %w", err) } var results []map[string]any @@ -70,7 +63,7 @@ func RunQuery(args []string) { } if err := rows.Scan(ptrs...); err != nil { - log.Fatalf("scan: %v", err) + return fmt.Errorf("scan: %w", err) } row := make(map[string]any) @@ -85,17 +78,16 @@ func RunQuery(args []string) { results = append(results, row) } - if *jsonOutput { + if cfg.JSON { enc := json.NewEncoder(os.Stdout) enc.SetIndent("", " ") - enc.Encode(results) - return + return enc.Encode(results) } // Table output. 
if len(results) == 0 { fmt.Println("(no results)") - return + return nil } // Calculate column widths. @@ -149,4 +141,5 @@ func RunQuery(args []string) { } fmt.Printf("\n(%d rows)\n", len(results)) + return nil } diff --git a/pkg/lem/score_cmd.go b/pkg/lem/score_cmd.go index b87587a..ef92f97 100644 --- a/pkg/lem/score_cmd.go +++ b/pkg/lem/score_cmd.go @@ -1,46 +1,40 @@ package lem import ( - "flag" "fmt" "log" "os" "time" ) +// ScoreOpts holds configuration for the score run command. +type ScoreOpts struct { + Input string + Suites string + JudgeModel string + JudgeURL string + Concurrency int + Output string + Resume bool +} + // RunScore scores existing response files using a judge model. -func RunScore(args []string) { - fs := flag.NewFlagSet("score", flag.ExitOnError) - - input := fs.String("input", "", "Input JSONL response file (required)") - suites := fs.String("suites", "all", "Comma-separated suites or 'all'") - judgeModel := fs.String("judge-model", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name") - judgeURL := fs.String("judge-url", "http://10.69.69.108:8090", "Judge API URL") - concurrency := fs.Int("concurrency", 4, "Max concurrent judge calls") - output := fs.String("output", "scores.json", "Output score file path") - resume := fs.Bool("resume", false, "Resume from existing output, skipping scored IDs") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +func RunScore(cfg ScoreOpts) error { + if cfg.Input == "" { + return fmt.Errorf("--input is required") } - if *input == "" { - fmt.Fprintln(os.Stderr, "error: --input is required") - fs.Usage() - os.Exit(1) - } - - responses, err := ReadResponses(*input) + responses, err := ReadResponses(cfg.Input) if err != nil { - log.Fatalf("read responses: %v", err) + return fmt.Errorf("read responses: %w", err) } - log.Printf("loaded %d responses from %s", len(responses), *input) + log.Printf("loaded %d responses from %s", len(responses), cfg.Input) - if *resume { - if _, 
statErr := os.Stat(*output); statErr == nil { - existing, readErr := ReadScorerOutput(*output) + if cfg.Resume { + if _, statErr := os.Stat(cfg.Output); statErr == nil { + existing, readErr := ReadScorerOutput(cfg.Output) if readErr != nil { - log.Fatalf("read existing scores for resume: %v", readErr) + return fmt.Errorf("read existing scores for resume: %w", readErr) } scored := make(map[string]bool) @@ -62,25 +56,25 @@ func RunScore(args []string) { if len(responses) == 0 { log.Println("all responses already scored, nothing to do") - return + return nil } } } - client := NewClient(*judgeURL, *judgeModel) + client := NewClient(cfg.JudgeURL, cfg.JudgeModel) client.MaxTokens = 512 judge := NewJudge(client) - engine := NewEngine(judge, *concurrency, *suites) + engine := NewEngine(judge, cfg.Concurrency, cfg.Suites) log.Printf("scoring with %s", engine) perPrompt := engine.ScoreAll(responses) - if *resume { - if _, statErr := os.Stat(*output); statErr == nil { - existing, readErr := ReadScorerOutput(*output) + if cfg.Resume { + if _, statErr := os.Stat(cfg.Output); statErr == nil { + existing, readErr := ReadScorerOutput(cfg.Output) if readErr != nil { - log.Fatalf("re-read scores for merge: %v", readErr) + return fmt.Errorf("re-read scores for merge: %w", readErr) } for model, scores := range existing.PerPrompt { perPrompt[model] = append(scores, perPrompt[model]...) 
@@ -92,8 +86,8 @@ func RunScore(args []string) { scorerOutput := &ScorerOutput{ Metadata: Metadata{ - JudgeModel: *judgeModel, - JudgeURL: *judgeURL, + JudgeModel: cfg.JudgeModel, + JudgeURL: cfg.JudgeURL, ScoredAt: time.Now().UTC(), ScorerVersion: "1.0.0", Suites: engine.SuiteNames(), @@ -102,71 +96,69 @@ func RunScore(args []string) { PerPrompt: perPrompt, } - if err := WriteScores(*output, scorerOutput); err != nil { - log.Fatalf("write scores: %v", err) + if err := WriteScores(cfg.Output, scorerOutput); err != nil { + return fmt.Errorf("write scores: %w", err) } - log.Printf("wrote scores to %s", *output) + log.Printf("wrote scores to %s", cfg.Output) + return nil +} + +// ProbeOpts holds configuration for the probe command. +type ProbeOpts struct { + Model string + TargetURL string + ProbesFile string + Suites string + JudgeModel string + JudgeURL string + Concurrency int + Output string } // RunProbe generates responses from a target model and scores them. -func RunProbe(args []string) { - fs := flag.NewFlagSet("probe", flag.ExitOnError) - - model := fs.String("model", "", "Target model name (required)") - targetURL := fs.String("target-url", "", "Target model API URL (defaults to judge-url)") - probesFile := fs.String("probes", "", "Custom probes JSONL file (uses built-in content probes if not specified)") - suites := fs.String("suites", "all", "Comma-separated suites or 'all'") - judgeModel := fs.String("judge-model", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name") - judgeURL := fs.String("judge-url", "http://10.69.69.108:8090", "Judge API URL") - concurrency := fs.Int("concurrency", 4, "Max concurrent judge calls") - output := fs.String("output", "scores.json", "Output score file path") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +func RunProbe(cfg ProbeOpts) error { + if cfg.Model == "" { + return fmt.Errorf("--model is required") } - if *model == "" { - fmt.Fprintln(os.Stderr, "error: --model is required") 
- fs.Usage() - os.Exit(1) + targetURL := cfg.TargetURL + if targetURL == "" { + targetURL = cfg.JudgeURL } - if *targetURL == "" { - *targetURL = *judgeURL - } - - targetClient := NewClient(*targetURL, *model) + targetClient := NewClient(targetURL, cfg.Model) targetClient.MaxTokens = 1024 - judgeClient := NewClient(*judgeURL, *judgeModel) + judgeClient := NewClient(cfg.JudgeURL, cfg.JudgeModel) judgeClient.MaxTokens = 512 judge := NewJudge(judgeClient) - engine := NewEngine(judge, *concurrency, *suites) + engine := NewEngine(judge, cfg.Concurrency, cfg.Suites) prober := NewProber(targetClient, engine) var scorerOutput *ScorerOutput var err error - if *probesFile != "" { - probes, readErr := ReadResponses(*probesFile) + if cfg.ProbesFile != "" { + probes, readErr := ReadResponses(cfg.ProbesFile) if readErr != nil { - log.Fatalf("read probes: %v", readErr) + return fmt.Errorf("read probes: %w", readErr) } - log.Printf("loaded %d custom probes from %s", len(probes), *probesFile) + log.Printf("loaded %d custom probes from %s", len(probes), cfg.ProbesFile) - scorerOutput, err = prober.ProbeModel(probes, *model) + scorerOutput, err = prober.ProbeModel(probes, cfg.Model) } else { log.Printf("using %d built-in content probes", len(ContentProbes)) - scorerOutput, err = prober.ProbeContent(*model) + scorerOutput, err = prober.ProbeContent(cfg.Model) } if err != nil { - log.Fatalf("probe: %v", err) + return fmt.Errorf("probe: %w", err) } - if writeErr := WriteScores(*output, scorerOutput); writeErr != nil { - log.Fatalf("write scores: %v", writeErr) + if writeErr := WriteScores(cfg.Output, scorerOutput); writeErr != nil { + return fmt.Errorf("write scores: %w", writeErr) } - log.Printf("wrote scores to %s", *output) + log.Printf("wrote scores to %s", cfg.Output) + return nil } diff --git a/pkg/lem/seed_influx.go b/pkg/lem/seed_influx.go index 648d120..08a2b65 100644 --- a/pkg/lem/seed_influx.go +++ b/pkg/lem/seed_influx.go @@ -1,48 +1,47 @@ package lem import ( - "flag" "fmt" 
- "log" "os" "strings" ) +// SeedInfluxOpts holds configuration for the seed-influx command. +type SeedInfluxOpts struct { + DB string // DuckDB database path (defaults to LEM_DB env) + InfluxURL string // InfluxDB URL + InfluxDB string // InfluxDB database name + Force bool // Re-seed even if InfluxDB already has data + BatchSize int // Lines per InfluxDB write batch +} + // RunSeedInflux is the CLI entry point for the seed-influx command. // Seeds InfluxDB golden_gen measurement from DuckDB golden_set data. // One-time migration tool for bootstrapping InfluxDB from existing data. -func RunSeedInflux(args []string) { - fs := flag.NewFlagSet("seed-influx", flag.ExitOnError) - dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)") - influxURL := fs.String("influx", "", "InfluxDB URL") - influxDB := fs.String("influx-db", "", "InfluxDB database name") - force := fs.Bool("force", false, "Re-seed even if InfluxDB already has data") - batchSize := fs.Int("batch-size", 500, "Lines per InfluxDB write batch") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +func RunSeedInflux(cfg SeedInfluxOpts) error { + if cfg.DB == "" { + cfg.DB = os.Getenv("LEM_DB") + } + if cfg.DB == "" { + return fmt.Errorf("--db or LEM_DB required") } - if *dbPath == "" { - *dbPath = os.Getenv("LEM_DB") - } - if *dbPath == "" { - fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required") - os.Exit(1) + if cfg.BatchSize <= 0 { + cfg.BatchSize = 500 } - db, err := OpenDB(*dbPath) + db, err := OpenDB(cfg.DB) if err != nil { - log.Fatalf("open db: %v", err) + return fmt.Errorf("open db: %w", err) } defer db.Close() var total int if err := db.conn.QueryRow("SELECT count(*) FROM golden_set").Scan(&total); err != nil { - log.Fatalf("No golden_set table. 
Run ingest first.") + return fmt.Errorf("no golden_set table — run ingest first") } - influx := NewInfluxClient(*influxURL, *influxDB) + influx := NewInfluxClient(cfg.InfluxURL, cfg.InfluxDB) // Check existing count in InfluxDB. existing := 0 @@ -55,9 +54,9 @@ func RunSeedInflux(args []string) { fmt.Printf("DuckDB has %d records, InfluxDB golden_gen has %d\n", total, existing) - if existing >= total && !*force { + if existing >= total && !cfg.Force { fmt.Println("InfluxDB already has all records. Use --force to re-seed.") - return + return nil } // Read all rows. @@ -66,7 +65,7 @@ func RunSeedInflux(args []string) { FROM golden_set ORDER BY idx `) if err != nil { - log.Fatalf("query golden_set: %v", err) + return fmt.Errorf("query golden_set: %w", err) } defer dbRows.Close() @@ -79,7 +78,7 @@ func RunSeedInflux(args []string) { var genTime float64 if err := dbRows.Scan(&idx, &seedID, &domain, &voice, &genTime, &charCount); err != nil { - log.Fatalf("scan: %v", err) + return fmt.Errorf("scan row: %w", err) } sid := strings.ReplaceAll(seedID, `"`, `\"`) @@ -87,9 +86,9 @@ func RunSeedInflux(args []string) { idx, escapeLp(domain), escapeLp(voice), sid, genTime, charCount) lines = append(lines, lp) - if len(lines) >= *batchSize { + if len(lines) >= cfg.BatchSize { if err := influx.WriteLp(lines); err != nil { - log.Fatalf("write batch at %d: %v", written, err) + return fmt.Errorf("write batch at %d: %w", written, err) } written += len(lines) lines = lines[:0] @@ -102,10 +101,11 @@ func RunSeedInflux(args []string) { if len(lines) > 0 { if err := influx.WriteLp(lines); err != nil { - log.Fatalf("flush: %v", err) + return fmt.Errorf("flush: %w", err) } written += len(lines) } fmt.Printf("Seeded %d golden_gen records into InfluxDB\n", written) + return nil } diff --git a/pkg/lem/status.go b/pkg/lem/status.go index 0583733..b3ce9b9 100644 --- a/pkg/lem/status.go +++ b/pkg/lem/status.go @@ -1,48 +1,43 @@ package lem import ( - "flag" "fmt" "io" - "log" "os" "sort" ) -// 
runStatus parses CLI flags and prints training/generation status from InfluxDB. -func RunStatus(args []string) { - fs := flag.NewFlagSet("status", flag.ExitOnError) - - influxURL := fs.String("influx", "", "InfluxDB URL (default http://10.69.69.165:8181)") - influxDB := fs.String("influx-db", "", "InfluxDB database name (default training)") - dbPath := fs.String("db", "", "DuckDB database path (shows table counts)") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) - } +// StatusOpts holds configuration for the status command. +type StatusOpts struct { + Influx string // InfluxDB URL (default http://10.69.69.165:8181) + InfluxDB string // InfluxDB database name (default training) + DB string // DuckDB database path (shows table counts) +} +// RunStatus prints training/generation status from InfluxDB and optional DuckDB table counts. +func RunStatus(cfg StatusOpts) error { // Check LEM_DB env as default for --db. - if *dbPath == "" { - *dbPath = os.Getenv("LEM_DB") + if cfg.DB == "" { + cfg.DB = os.Getenv("LEM_DB") } - influx := NewInfluxClient(*influxURL, *influxDB) + influx := NewInfluxClient(cfg.Influx, cfg.InfluxDB) if err := printStatus(influx, os.Stdout); err != nil { - log.Fatalf("status: %v", err) + return fmt.Errorf("status: %w", err) } // If DuckDB path provided, show table counts. - if *dbPath != "" { - db, err := OpenDB(*dbPath) + if cfg.DB != "" { + db, err := OpenDB(cfg.DB) if err != nil { - log.Fatalf("open db: %v", err) + return fmt.Errorf("open db: %w", err) } defer db.Close() counts, err := db.TableCounts() if err != nil { - log.Fatalf("table counts: %v", err) + return fmt.Errorf("table counts: %w", err) } fmt.Fprintln(os.Stdout) @@ -55,6 +50,8 @@ func RunStatus(args []string) { } } } + + return nil } // trainingRow holds deduplicated training status + loss for a single model. 
diff --git a/pkg/lem/tier_score.go b/pkg/lem/tier_score.go index 59e077d..f0ca09d 100644 --- a/pkg/lem/tier_score.go +++ b/pkg/lem/tier_score.go @@ -1,39 +1,36 @@ package lem import ( - "flag" "fmt" "log" "os" "strings" ) +// TierScoreOpts holds configuration for the tier-score command. +type TierScoreOpts struct { + DBPath string + Tier int + Limit int +} + // RunTierScore is the CLI entry point for the tier-score command. // Scores expansion responses using tiered quality assessment: // - Tier 1: Heuristic regex scoring (fast, no API) // - Tier 2: LEM self-judge (requires trained model) // - Tier 3: External judge (reserved for borderline cases) -func RunTierScore(args []string) { - fs := flag.NewFlagSet("tier-score", flag.ExitOnError) - dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)") - tier := fs.Int("tier", 1, "Scoring tier: 1=heuristic, 2=LEM judge, 3=external") - limit := fs.Int("limit", 0, "Max items to score (0=all)") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) +func RunTierScore(cfg TierScoreOpts) error { + dbPath := cfg.DBPath + if dbPath == "" { + dbPath = os.Getenv("LEM_DB") + } + if dbPath == "" { + return fmt.Errorf("--db or LEM_DB required") } - if *dbPath == "" { - *dbPath = os.Getenv("LEM_DB") - } - if *dbPath == "" { - fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required") - os.Exit(1) - } - - db, err := OpenDBReadWrite(*dbPath) + db, err := OpenDBReadWrite(dbPath) if err != nil { - log.Fatalf("open db: %v", err) + return fmt.Errorf("open db: %w", err) } defer db.Close() @@ -54,18 +51,20 @@ func RunTierScore(args []string) { ) `) - if *tier >= 1 { - runHeuristicTier(db, *limit) + if cfg.Tier >= 1 { + runHeuristicTier(db, cfg.Limit) } - if *tier >= 2 { + if cfg.Tier >= 2 { fmt.Println("\nTier 2 (LEM judge): not yet available — needs trained LEM-27B model") fmt.Println(" Will score: sovereignty, ethical_depth, creative, self_concept (1-10 each)") } - if *tier >= 3 { + if cfg.Tier 
>= 3 { fmt.Println("\nTier 3 (External judge): reserved for borderline cases") } + + return nil } func runHeuristicTier(db *DB, limit int) { diff --git a/pkg/lem/worker.go b/pkg/lem/worker.go index af10554..3580fb3 100644 --- a/pkg/lem/worker.go +++ b/pkg/lem/worker.go @@ -3,7 +3,6 @@ package lem import ( "bytes" "encoding/json" - "flag" "fmt" "io" "log" @@ -14,7 +13,25 @@ import ( "time" ) -// workerConfig holds the worker's runtime configuration. +// WorkerOpts holds the worker's runtime configuration. +type WorkerOpts struct { + APIBase string // LEM API base URL + WorkerID string // Worker ID (machine UUID) + Name string // Worker display name + APIKey string // API key (or use LEM_API_KEY env) + GPUType string // GPU type (e.g. 'RTX 3090') + VRAMGb int // GPU VRAM in GB + Languages string // Comma-separated language codes (e.g. 'en,yo,sw') + Models string // Comma-separated supported model names + InferURL string // Local inference endpoint + TaskType string // Filter by task type (expand, score, translate, seed) + BatchSize int // Number of tasks to fetch per poll + PollInterval time.Duration // Poll interval + OneShot bool // Process one batch and exit + DryRun bool // Fetch tasks but don't run inference +} + +// workerConfig is the internal runtime config populated from WorkerOpts. type workerConfig struct { apiBase string workerID string @@ -50,35 +67,63 @@ type apiTask struct { } // RunWorker is the CLI entry point for `lem worker`. 
-func RunWorker(args []string) { - fs := flag.NewFlagSet("worker", flag.ExitOnError) - - cfg := workerConfig{} - var langs, mods string - - fs.StringVar(&cfg.apiBase, "api", envOr("LEM_API", "https://infer.lthn.ai"), "LEM API base URL") - fs.StringVar(&cfg.workerID, "id", envOr("LEM_WORKER_ID", machineID()), "Worker ID (machine UUID)") - fs.StringVar(&cfg.name, "name", envOr("LEM_WORKER_NAME", hostname()), "Worker display name") - fs.StringVar(&cfg.apiKey, "key", envOr("LEM_API_KEY", ""), "API key (or use LEM_API_KEY env)") - fs.StringVar(&cfg.gpuType, "gpu", envOr("LEM_GPU", ""), "GPU type (e.g. 'RTX 3090')") - fs.IntVar(&cfg.vramGb, "vram", intEnvOr("LEM_VRAM_GB", 0), "GPU VRAM in GB") - fs.StringVar(&langs, "languages", envOr("LEM_LANGUAGES", ""), "Comma-separated language codes (e.g. 'en,yo,sw')") - fs.StringVar(&mods, "models", envOr("LEM_MODELS", ""), "Comma-separated supported model names") - fs.StringVar(&cfg.inferURL, "infer", envOr("LEM_INFER_URL", "http://localhost:8090"), "Local inference endpoint") - fs.StringVar(&cfg.taskType, "type", "", "Filter by task type (expand, score, translate, seed)") - fs.IntVar(&cfg.batchSize, "batch", 5, "Number of tasks to fetch per poll") - dur := fs.Duration("poll", 30*time.Second, "Poll interval") - fs.BoolVar(&cfg.oneShot, "one-shot", false, "Process one batch and exit") - fs.BoolVar(&cfg.dryRun, "dry-run", false, "Fetch tasks but don't run inference") - - fs.Parse(args) - cfg.pollInterval = *dur - - if langs != "" { - cfg.languages = splitComma(langs) +func RunWorker(opts WorkerOpts) error { + // Apply env-var fallbacks for fields not set by flags. 
+ if opts.APIBase == "" { + opts.APIBase = envOr("LEM_API", "https://infer.lthn.ai") } - if mods != "" { - cfg.models = splitComma(mods) + if opts.WorkerID == "" { + opts.WorkerID = envOr("LEM_WORKER_ID", machineID()) + } + if opts.Name == "" { + opts.Name = envOr("LEM_WORKER_NAME", hostname()) + } + if opts.APIKey == "" { + opts.APIKey = envOr("LEM_API_KEY", "") + } + if opts.GPUType == "" { + opts.GPUType = envOr("LEM_GPU", "") + } + if opts.VRAMGb == 0 { + opts.VRAMGb = intEnvOr("LEM_VRAM_GB", 0) + } + if opts.Languages == "" { + opts.Languages = envOr("LEM_LANGUAGES", "") + } + if opts.Models == "" { + opts.Models = envOr("LEM_MODELS", "") + } + if opts.InferURL == "" { + opts.InferURL = envOr("LEM_INFER_URL", "http://localhost:8090") + } + if opts.BatchSize <= 0 { + opts.BatchSize = 5 + } + if opts.PollInterval <= 0 { + opts.PollInterval = 30 * time.Second + } + + // Build internal config. + cfg := workerConfig{ + apiBase: opts.APIBase, + workerID: opts.WorkerID, + name: opts.Name, + apiKey: opts.APIKey, + gpuType: opts.GPUType, + vramGb: opts.VRAMGb, + inferURL: opts.InferURL, + taskType: opts.TaskType, + batchSize: opts.BatchSize, + pollInterval: opts.PollInterval, + oneShot: opts.OneShot, + dryRun: opts.DryRun, + } + + if opts.Languages != "" { + cfg.languages = splitComma(opts.Languages) + } + if opts.Models != "" { + cfg.models = splitComma(opts.Models) } if cfg.apiKey == "" { @@ -86,7 +131,7 @@ func RunWorker(args []string) { } if cfg.apiKey == "" { - log.Fatal("No API key. Set LEM_API_KEY or run: lem worker-auth") + return fmt.Errorf("no API key — set LEM_API_KEY or run: lem worker-auth") } log.Printf("LEM Worker starting") @@ -102,7 +147,7 @@ func RunWorker(args []string) { // Register with the API. 
if err := workerRegister(&cfg); err != nil { - log.Fatalf("Registration failed: %v", err) + return fmt.Errorf("registration failed: %w", err) } log.Println("Registered with LEM API") @@ -112,7 +157,7 @@ func RunWorker(args []string) { if cfg.oneShot { log.Printf("One-shot mode: processed %d tasks, exiting", processed) - return + return nil } if processed == 0 {