Replace passthrough() + stdlib flag.FlagSet anti-pattern with proper cobra integration.

Every Run* function now takes a typed *Opts struct and returns error. Flags are registered via cli.StringFlag/IntFlag/etc. Commands participate in the Core lifecycle with full cobra flag parsing.

- 6 command groups: gen, score, data, export, infra, mon
- 25 commands converted, 0 passthrough() calls remain
- Delete passthrough() helper from lem.go
- Update export_test.go to use ExportOpts struct

Co-Authored-By: Virgil <virgil@lethean.io>
110 lines
6.3 KiB
Go
110 lines
6.3 KiB
Go
package lemcmd
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"forge.lthn.ai/core/cli/pkg/cli"
|
|
"forge.lthn.ai/lthn/lem/pkg/lem"
|
|
)
|
|
|
|
func addScoreCommands(root *cli.Command) {
|
|
scoreGroup := cli.NewGroup("score", "Scoring commands", "Score responses, probe models, compare results.")
|
|
|
|
// run — score existing response files.
|
|
var scoreCfg lem.ScoreOpts
|
|
scoreCmd := cli.NewCommand("run", "Score existing response files", "",
|
|
func(cmd *cli.Command, args []string) error {
|
|
return lem.RunScore(scoreCfg)
|
|
},
|
|
)
|
|
cli.StringFlag(scoreCmd, &scoreCfg.Input, "input", "i", "", "Input JSONL response file (required)")
|
|
cli.StringFlag(scoreCmd, &scoreCfg.Suites, "suites", "", "all", "Comma-separated suites or 'all'")
|
|
cli.StringFlag(scoreCmd, &scoreCfg.JudgeModel, "judge-model", "", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name")
|
|
cli.StringFlag(scoreCmd, &scoreCfg.JudgeURL, "judge-url", "", "http://10.69.69.108:8090", "Judge API URL")
|
|
cli.IntFlag(scoreCmd, &scoreCfg.Concurrency, "concurrency", "c", 4, "Max concurrent judge calls")
|
|
cli.StringFlag(scoreCmd, &scoreCfg.Output, "output", "o", "scores.json", "Output score file path")
|
|
cli.BoolFlag(scoreCmd, &scoreCfg.Resume, "resume", "", false, "Resume from existing output, skipping scored IDs")
|
|
scoreGroup.AddCommand(scoreCmd)
|
|
|
|
// probe — generate responses and score them.
|
|
var probeCfg lem.ProbeOpts
|
|
probeCmd := cli.NewCommand("probe", "Generate responses and score them", "",
|
|
func(cmd *cli.Command, args []string) error {
|
|
return lem.RunProbe(probeCfg)
|
|
},
|
|
)
|
|
cli.StringFlag(probeCmd, &probeCfg.Model, "model", "m", "", "Target model name (required)")
|
|
cli.StringFlag(probeCmd, &probeCfg.TargetURL, "target-url", "", "", "Target model API URL (defaults to judge-url)")
|
|
cli.StringFlag(probeCmd, &probeCfg.ProbesFile, "probes", "", "", "Custom probes JSONL file (uses built-in content probes if not specified)")
|
|
cli.StringFlag(probeCmd, &probeCfg.Suites, "suites", "", "all", "Comma-separated suites or 'all'")
|
|
cli.StringFlag(probeCmd, &probeCfg.JudgeModel, "judge-model", "", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name")
|
|
cli.StringFlag(probeCmd, &probeCfg.JudgeURL, "judge-url", "", "http://10.69.69.108:8090", "Judge API URL")
|
|
cli.IntFlag(probeCmd, &probeCfg.Concurrency, "concurrency", "c", 4, "Max concurrent judge calls")
|
|
cli.StringFlag(probeCmd, &probeCfg.Output, "output", "o", "scores.json", "Output score file path")
|
|
scoreGroup.AddCommand(probeCmd)
|
|
|
|
// compare has a different signature — it takes two named args, not []string.
|
|
var compareOld, compareNew string
|
|
compareCmd := cli.NewCommand("compare", "Compare two score files", "",
|
|
func(cmd *cli.Command, args []string) error {
|
|
if compareOld == "" || compareNew == "" {
|
|
return fmt.Errorf("--old and --new are required")
|
|
}
|
|
return lem.RunCompare(compareOld, compareNew)
|
|
},
|
|
)
|
|
cli.StringFlag(compareCmd, &compareOld, "old", "", "", "Old score file (required)")
|
|
cli.StringFlag(compareCmd, &compareNew, "new", "", "", "New score file (required)")
|
|
scoreGroup.AddCommand(compareCmd)
|
|
|
|
// attention — Q/K Bone Orientation analysis.
|
|
var attCfg lem.AttentionOpts
|
|
attCmd := cli.NewCommand("attention", "Q/K Bone Orientation analysis for a prompt", "",
|
|
func(cmd *cli.Command, args []string) error {
|
|
return lem.RunAttention(attCfg)
|
|
},
|
|
)
|
|
cli.StringFlag(attCmd, &attCfg.Model, "model", "m", "gemma3/1b", "Model config path (relative to .core/ai/models/)")
|
|
cli.StringFlag(attCmd, &attCfg.Prompt, "prompt", "p", "", "Prompt text to analyse")
|
|
cli.BoolFlag(attCmd, &attCfg.JSON, "json", "j", false, "Output as JSON")
|
|
cli.IntFlag(attCmd, &attCfg.CacheLimit, "cache-limit", "", 0, "Metal cache limit in GB (0 = use ai.yaml default)")
|
|
cli.IntFlag(attCmd, &attCfg.MemLimit, "mem-limit", "", 0, "Metal memory limit in GB (0 = use ai.yaml default)")
|
|
cli.StringFlag(attCmd, &attCfg.Root, "root", "", ".", "Project root (for .core/ai/ config)")
|
|
scoreGroup.AddCommand(attCmd)
|
|
|
|
// tier — score expansion responses with heuristic/judge tiers.
|
|
var tierCfg lem.TierScoreOpts
|
|
tierCmd := cli.NewCommand("tier", "Score expansion responses (heuristic/judge tiers)", "",
|
|
func(cmd *cli.Command, args []string) error {
|
|
return lem.RunTierScore(tierCfg)
|
|
},
|
|
)
|
|
cli.StringFlag(tierCmd, &tierCfg.DBPath, "db", "", "", "DuckDB database path (defaults to LEM_DB env)")
|
|
cli.IntFlag(tierCmd, &tierCfg.Tier, "tier", "t", 1, "Scoring tier: 1=heuristic, 2=LEM judge, 3=external")
|
|
cli.IntFlag(tierCmd, &tierCfg.Limit, "limit", "l", 0, "Max items to score (0=all)")
|
|
scoreGroup.AddCommand(tierCmd)
|
|
|
|
// agent — ROCm scoring daemon.
|
|
var agentCfg lem.AgentOpts
|
|
agentCmd := cli.NewCommand("agent", "ROCm scoring daemon (polls M3, scores checkpoints)", "",
|
|
func(cmd *cli.Command, args []string) error {
|
|
return lem.RunAgent(agentCfg)
|
|
},
|
|
)
|
|
cli.StringFlag(agentCmd, &agentCfg.M3Host, "m3-host", "", envOr("M3_HOST", "10.69.69.108"), "M3 host address")
|
|
cli.StringFlag(agentCmd, &agentCfg.M3User, "m3-user", "", envOr("M3_USER", "claude"), "M3 SSH user")
|
|
cli.StringFlag(agentCmd, &agentCfg.M3SSHKey, "m3-ssh-key", "", envOr("M3_SSH_KEY", expandHome("~/.ssh/id_ed25519")), "SSH key for M3")
|
|
cli.StringFlag(agentCmd, &agentCfg.M3AdapterBase, "m3-adapter-base", "", envOr("M3_ADAPTER_BASE", "/Volumes/Data/lem"), "Adapter base dir on M3")
|
|
cli.StringFlag(agentCmd, &agentCfg.InfluxURL, "influx", "", envOr("INFLUX_URL", "http://10.69.69.165:8181"), "InfluxDB URL")
|
|
cli.StringFlag(agentCmd, &agentCfg.InfluxDB, "influx-db", "", envOr("INFLUX_DB", "training"), "InfluxDB database")
|
|
cli.StringFlag(agentCmd, &agentCfg.APIURL, "api-url", "", envOr("LEM_API_URL", "http://localhost:8080"), "OpenAI-compatible inference API URL")
|
|
cli.StringFlag(agentCmd, &agentCfg.Model, "model", "m", envOr("LEM_MODEL", ""), "Model name for API (overrides auto-detect)")
|
|
cli.StringFlag(agentCmd, &agentCfg.BaseModel, "base-model", "", envOr("BASE_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"), "HuggingFace base model ID")
|
|
cli.IntFlag(agentCmd, &agentCfg.PollInterval, "poll", "", intEnvOr("POLL_INTERVAL", 300), "Poll interval in seconds")
|
|
cli.StringFlag(agentCmd, &agentCfg.WorkDir, "work-dir", "", envOr("WORK_DIR", "/tmp/scoring-agent"), "Working directory for adapters")
|
|
cli.BoolFlag(agentCmd, &agentCfg.OneShot, "one-shot", "", false, "Process one checkpoint and exit")
|
|
cli.BoolFlag(agentCmd, &agentCfg.DryRun, "dry-run", "", false, "Discover and plan but don't execute")
|
|
scoreGroup.AddCommand(agentCmd)
|
|
|
|
root.AddCommand(scoreGroup)
|
|
}
|