72 lines
4.7 KiB
Go
72 lines
4.7 KiB
Go
package lemcmd
|
|
|
|
import (
|
|
"forge.lthn.ai/core/cli/pkg/cli"
|
|
"forge.lthn.ai/lthn/lem/pkg/lem"
|
|
)
|
|
|
|
func addGenCommands(root *cli.Command) {
|
|
genGroup := cli.NewGroup("gen", "Generation commands", "Distill, expand, and generate training data.")
|
|
|
|
// distill — native Metal distillation with grammar scoring.
|
|
var distillCfg lem.DistillOpts
|
|
distillCmd := cli.NewCommand("distill", "Native Metal distillation (go-mlx + grammar scoring)", "",
|
|
func(cmd *cli.Command, args []string) error {
|
|
return lem.RunDistill(distillCfg)
|
|
},
|
|
)
|
|
cli.StringFlag(distillCmd, &distillCfg.Model, "model", "m", "", "Model config path (relative to .core/ai/models/)")
|
|
cli.StringFlag(distillCmd, &distillCfg.Probes, "probes", "p", "", "Probe set name from probes.yaml")
|
|
cli.StringFlag(distillCmd, &distillCfg.Output, "output", "o", "", "Output JSONL path (defaults to model training dir)")
|
|
cli.IntFlag(distillCmd, &distillCfg.Lesson, "lesson", "", -1, "Lesson number to append to (defaults to probe set phase)")
|
|
cli.Float64Flag(distillCmd, &distillCfg.MinScore, "min-score", "", 0, "Min grammar composite (0 = use ai.yaml default)")
|
|
cli.IntFlag(distillCmd, &distillCfg.Runs, "runs", "r", 0, "Generations per probe (0 = use ai.yaml default)")
|
|
cli.BoolFlag(distillCmd, &distillCfg.DryRun, "dry-run", "", false, "Show plan and exit without generating")
|
|
cli.StringFlag(distillCmd, &distillCfg.Root, "root", "", ".", "Project root (for .core/ai/ config)")
|
|
cli.IntFlag(distillCmd, &distillCfg.CacheLimit, "cache-limit", "", 0, "Metal cache limit in GB (0 = use ai.yaml default)")
|
|
cli.IntFlag(distillCmd, &distillCfg.MemLimit, "mem-limit", "", 0, "Metal memory limit in GB (0 = use ai.yaml default)")
|
|
genGroup.AddCommand(distillCmd)
|
|
|
|
// expand — generate expansion responses via trained LEM model.
|
|
var expandCfg lem.ExpandOpts
|
|
expandCmd := cli.NewCommand("expand", "Generate expansion responses via trained LEM model", "",
|
|
func(cmd *cli.Command, args []string) error {
|
|
return lem.RunExpand(expandCfg)
|
|
},
|
|
)
|
|
cli.StringFlag(expandCmd, &expandCfg.Model, "model", "m", "", "Model name for generation (required)")
|
|
cli.StringFlag(expandCmd, &expandCfg.DB, "db", "", "", "DuckDB database path (primary prompt source)")
|
|
cli.StringFlag(expandCmd, &expandCfg.Prompts, "prompts", "p", "", "Input JSONL file with expansion prompts (fallback)")
|
|
cli.StringFlag(expandCmd, &expandCfg.APIURL, "api-url", "", "http://10.69.69.108:8090", "OpenAI-compatible API URL")
|
|
cli.StringFlag(expandCmd, &expandCfg.Worker, "worker", "", "", "Worker hostname (defaults to os.Hostname())")
|
|
cli.IntFlag(expandCmd, &expandCfg.Limit, "limit", "", 0, "Max prompts to process (0 = all)")
|
|
cli.StringFlag(expandCmd, &expandCfg.Output, "output", "o", ".", "Output directory for JSONL files")
|
|
cli.StringFlag(expandCmd, &expandCfg.Influx, "influx", "", "", "InfluxDB URL (default http://10.69.69.165:8181)")
|
|
cli.StringFlag(expandCmd, &expandCfg.InfluxDB, "influx-db", "", "", "InfluxDB database name (default training)")
|
|
cli.BoolFlag(expandCmd, &expandCfg.DryRun, "dry-run", "", false, "Print plan and exit without generating")
|
|
genGroup.AddCommand(expandCmd)
|
|
|
|
// conv — generate conversational training data (calm phase).
|
|
var convCfg lem.ConvOpts
|
|
convCmd := cli.NewCommand("conv", "Generate conversational training data (calm phase)", "",
|
|
func(cmd *cli.Command, args []string) error {
|
|
return lem.RunConv(convCfg)
|
|
},
|
|
)
|
|
cli.StringFlag(convCmd, &convCfg.OutputDir, "output-dir", "o", "", "Output directory for training files (required)")
|
|
cli.StringFlag(convCmd, &convCfg.Extra, "extra", "", "", "Additional conversations JSONL file (multi-turn format)")
|
|
cli.StringFlag(convCmd, &convCfg.Golden, "golden", "", "", "Golden set JSONL to convert to single-turn conversations")
|
|
cli.StringFlag(convCmd, &convCfg.DB, "db", "", "", "DuckDB database path for golden set (alternative to --golden)")
|
|
cli.IntFlag(convCmd, &convCfg.TrainPct, "train-pct", "", 80, "Training set percentage")
|
|
cli.IntFlag(convCmd, &convCfg.ValidPct, "valid-pct", "", 10, "Validation set percentage")
|
|
cli.IntFlag(convCmd, &convCfg.TestPct, "test-pct", "", 10, "Test set percentage")
|
|
cli.Int64Flag(convCmd, &convCfg.Seed, "seed", "", 42, "Random seed for shuffling")
|
|
cli.IntFlag(convCmd, &convCfg.MinChars, "min-chars", "", 50, "Minimum response chars for golden set conversion")
|
|
cli.BoolFlag(convCmd, &convCfg.NoBuiltin, "no-builtin", "", false, "Exclude built-in seed conversations")
|
|
cli.StringFlag(convCmd, &convCfg.Influx, "influx", "", "", "InfluxDB URL for progress reporting")
|
|
cli.StringFlag(convCmd, &convCfg.InfluxDB, "influx-db", "", "", "InfluxDB database name")
|
|
cli.StringFlag(convCmd, &convCfg.Worker, "worker", "", "", "Worker hostname for InfluxDB reporting")
|
|
genGroup.AddCommand(convCmd)
|
|
|
|
root.AddCommand(genGroup)
|
|
}
|