From a0a0118155d8b18a20a95baec98ce2477c8f2835 Mon Sep 17 00:00:00 2001 From: Snider Date: Sun, 22 Feb 2026 18:53:15 +0000 Subject: [PATCH] refactor: move runScore and runProbe to pkg/lem All 28 commands now accessible as exported lem.Run* functions. Prerequisite for CLI framework migration. Co-Authored-By: Claude Opus 4.6 --- main.go | 199 ++++--------------------------------------- pkg/lem/score_cmd.go | 169 ++++++++++++++++++++++++++++++++++++ 2 files changed, 185 insertions(+), 183 deletions(-) create mode 100644 pkg/lem/score_cmd.go diff --git a/main.go b/main.go index cc52cfa..ce778f2 100644 --- a/main.go +++ b/main.go @@ -5,7 +5,6 @@ import ( "fmt" "log" "os" - "time" "forge.lthn.ai/lthn/lem/pkg/lem" ) @@ -62,11 +61,24 @@ func main() { case "distill": lem.RunDistill(os.Args[2:]) case "score": - runScore(os.Args[2:]) + lem.RunScore(os.Args[2:]) case "probe": - runProbe(os.Args[2:]) + lem.RunProbe(os.Args[2:]) case "compare": - runCompare(os.Args[2:]) + fs := flag.NewFlagSet("compare", flag.ExitOnError) + oldFile := fs.String("old", "", "Old score file (required)") + newFile := fs.String("new", "", "New score file (required)") + if err := fs.Parse(os.Args[2:]); err != nil { + log.Fatalf("parse flags: %v", err) + } + if *oldFile == "" || *newFile == "" { + fmt.Fprintln(os.Stderr, "error: --old and --new are required") + fs.Usage() + os.Exit(1) + } + if err := lem.RunCompare(*oldFile, *newFile); err != nil { + log.Fatalf("compare: %v", err) + } case "status": lem.RunStatus(os.Args[2:]) case "expand": @@ -114,182 +126,3 @@ func main() { os.Exit(1) } } - -func runScore(args []string) { - fs := flag.NewFlagSet("score", flag.ExitOnError) - - input := fs.String("input", "", "Input JSONL response file (required)") - suites := fs.String("suites", "all", "Comma-separated suites or 'all'") - judgeModel := fs.String("judge-model", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name") - judgeURL := fs.String("judge-url", "http://10.69.69.108:8090", "Judge API URL") - concurrency := fs.Int("concurrency", 4, "Max concurrent judge calls") - output := fs.String("output", "scores.json", "Output score file path") - resume := fs.Bool("resume", false, "Resume from existing output, skipping scored IDs") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) - } - - if *input == "" { - fmt.Fprintln(os.Stderr, "error: --input is required") - fs.Usage() - os.Exit(1) - } - - responses, err := lem.ReadResponses(*input) - if err != nil { - log.Fatalf("read responses: %v", err) - } - log.Printf("loaded %d responses from %s", len(responses), *input) - - if *resume { - if _, statErr := os.Stat(*output); statErr == nil { - existing, readErr := lem.ReadScorerOutput(*output) - if readErr != nil { - log.Fatalf("read existing scores for resume: %v", readErr) - } - - scored := make(map[string]bool) - for _, scores := range existing.PerPrompt { - for _, ps := range scores { - scored[ps.ID] = true - } - } - - var filtered []lem.Response - for _, r := range responses { - if !scored[r.ID] { - filtered = append(filtered, r) - } - } - log.Printf("resume: skipping %d already-scored, %d remaining", - len(responses)-len(filtered), len(filtered)) - responses = filtered - - if len(responses) == 0 { - log.Println("all responses already scored, nothing to do") - return - } - } - } - - client := lem.NewClient(*judgeURL, *judgeModel) - client.MaxTokens = 512 - judge := lem.NewJudge(client) - engine := lem.NewEngine(judge, *concurrency, *suites) - - log.Printf("scoring with %s", engine) - - perPrompt := engine.ScoreAll(responses) - - if *resume { - if _, statErr := os.Stat(*output); statErr == nil { - existing, _ := lem.ReadScorerOutput(*output) - for model, scores := range existing.PerPrompt { - perPrompt[model] = append(scores, perPrompt[model]...) - } - } - } - - averages := lem.ComputeAverages(perPrompt) - - scorerOutput := &lem.ScorerOutput{ - Metadata: lem.Metadata{ - JudgeModel: *judgeModel, - JudgeURL: *judgeURL, - ScoredAt: time.Now().UTC(), - ScorerVersion: "1.0.0", - Suites: engine.SuiteNames(), - }, - ModelAverages: averages, - PerPrompt: perPrompt, - } - - if err := lem.WriteScores(*output, scorerOutput); err != nil { - log.Fatalf("write scores: %v", err) - } - - log.Printf("wrote scores to %s", *output) -} - -func runProbe(args []string) { - fs := flag.NewFlagSet("probe", flag.ExitOnError) - - model := fs.String("model", "", "Target model name (required)") - targetURL := fs.String("target-url", "", "Target model API URL (defaults to judge-url)") - probesFile := fs.String("probes", "", "Custom probes JSONL file (uses built-in content probes if not specified)") - suites := fs.String("suites", "all", "Comma-separated suites or 'all'") - judgeModel := fs.String("judge-model", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name") - judgeURL := fs.String("judge-url", "http://10.69.69.108:8090", "Judge API URL") - concurrency := fs.Int("concurrency", 4, "Max concurrent judge calls") - output := fs.String("output", "scores.json", "Output score file path") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) - } - - if *model == "" { - fmt.Fprintln(os.Stderr, "error: --model is required") - fs.Usage() - os.Exit(1) - } - - if *targetURL == "" { - *targetURL = *judgeURL - } - - targetClient := lem.NewClient(*targetURL, *model) - targetClient.MaxTokens = 1024 - judgeClient := lem.NewClient(*judgeURL, *judgeModel) - judgeClient.MaxTokens = 512 - judge := lem.NewJudge(judgeClient) - engine := lem.NewEngine(judge, *concurrency, *suites) - prober := lem.NewProber(targetClient, engine) - - var scorerOutput *lem.ScorerOutput - var err error - - if *probesFile != "" { - probes, readErr := lem.ReadResponses(*probesFile) - if readErr != nil { - log.Fatalf("read probes: %v", readErr) - } - log.Printf("loaded %d custom probes from %s", len(probes), *probesFile) - - scorerOutput, err = prober.ProbeModel(probes, *model) - } else { - log.Printf("using %d built-in content probes", len(lem.ContentProbes)) - scorerOutput, err = prober.ProbeContent(*model) - } - - if err != nil { - log.Fatalf("probe: %v", err) - } - - if writeErr := lem.WriteScores(*output, scorerOutput); writeErr != nil { - log.Fatalf("write scores: %v", writeErr) - } - - log.Printf("wrote scores to %s", *output) -} - -func runCompare(args []string) { - fs := flag.NewFlagSet("compare", flag.ExitOnError) - - oldFile := fs.String("old", "", "Old score file (required)") - newFile := fs.String("new", "", "New score file (required)") - - if err := fs.Parse(args); err != nil { - log.Fatalf("parse flags: %v", err) - } - - if *oldFile == "" || *newFile == "" { - fmt.Fprintln(os.Stderr, "error: --old and --new are required") - fs.Usage() - os.Exit(1) - } - - if err := lem.RunCompare(*oldFile, *newFile); err != nil { - log.Fatalf("compare: %v", err) - } -} diff --git a/pkg/lem/score_cmd.go b/pkg/lem/score_cmd.go new file mode 100644 index 0000000..9e2fb5e --- /dev/null +++ b/pkg/lem/score_cmd.go @@ -0,0 +1,169 @@ +package lem + +import ( + "flag" + "fmt" + "log" + "os" + "time" +) + +// RunScore scores existing response files using a judge model. +func RunScore(args []string) { + fs := flag.NewFlagSet("score", flag.ExitOnError) + + input := fs.String("input", "", "Input JSONL response file (required)") + suites := fs.String("suites", "all", "Comma-separated suites or 'all'") + judgeModel := fs.String("judge-model", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name") + judgeURL := fs.String("judge-url", "http://10.69.69.108:8090", "Judge API URL") + concurrency := fs.Int("concurrency", 4, "Max concurrent judge calls") + output := fs.String("output", "scores.json", "Output score file path") + resume := fs.Bool("resume", false, "Resume from existing output, skipping scored IDs") + + if err := fs.Parse(args); err != nil { + log.Fatalf("parse flags: %v", err) + } + + if *input == "" { + fmt.Fprintln(os.Stderr, "error: --input is required") + fs.Usage() + os.Exit(1) + } + + responses, err := ReadResponses(*input) + if err != nil { + log.Fatalf("read responses: %v", err) + } + log.Printf("loaded %d responses from %s", len(responses), *input) + + if *resume { + if _, statErr := os.Stat(*output); statErr == nil { + existing, readErr := ReadScorerOutput(*output) + if readErr != nil { + log.Fatalf("read existing scores for resume: %v", readErr) + } + + scored := make(map[string]bool) + for _, scores := range existing.PerPrompt { + for _, ps := range scores { + scored[ps.ID] = true + } + } + + var filtered []Response + for _, r := range responses { + if !scored[r.ID] { + filtered = append(filtered, r) + } + } + log.Printf("resume: skipping %d already-scored, %d remaining", + len(responses)-len(filtered), len(filtered)) + responses = filtered + + if len(responses) == 0 { + log.Println("all responses already scored, nothing to do") + return + } + } + } + + client := NewClient(*judgeURL, *judgeModel) + client.MaxTokens = 512 + judge := NewJudge(client) + engine := NewEngine(judge, *concurrency, *suites) + + log.Printf("scoring with %s", engine) + + perPrompt := engine.ScoreAll(responses) + + if *resume { + if _, statErr := os.Stat(*output); statErr == nil { + existing, _ := ReadScorerOutput(*output) + for model, scores := range existing.PerPrompt { + perPrompt[model] = append(scores, perPrompt[model]...) + } + } + } + + averages := ComputeAverages(perPrompt) + + scorerOutput := &ScorerOutput{ + Metadata: Metadata{ + JudgeModel: *judgeModel, + JudgeURL: *judgeURL, + ScoredAt: time.Now().UTC(), + ScorerVersion: "1.0.0", + Suites: engine.SuiteNames(), + }, + ModelAverages: averages, + PerPrompt: perPrompt, + } + + if err := WriteScores(*output, scorerOutput); err != nil { + log.Fatalf("write scores: %v", err) + } + + log.Printf("wrote scores to %s", *output) +} + +// RunProbe generates responses from a target model and scores them. +func RunProbe(args []string) { + fs := flag.NewFlagSet("probe", flag.ExitOnError) + + model := fs.String("model", "", "Target model name (required)") + targetURL := fs.String("target-url", "", "Target model API URL (defaults to judge-url)") + probesFile := fs.String("probes", "", "Custom probes JSONL file (uses built-in content probes if not specified)") + suites := fs.String("suites", "all", "Comma-separated suites or 'all'") + judgeModel := fs.String("judge-model", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name") + judgeURL := fs.String("judge-url", "http://10.69.69.108:8090", "Judge API URL") + concurrency := fs.Int("concurrency", 4, "Max concurrent judge calls") + output := fs.String("output", "scores.json", "Output score file path") + + if err := fs.Parse(args); err != nil { + log.Fatalf("parse flags: %v", err) + } + + if *model == "" { + fmt.Fprintln(os.Stderr, "error: --model is required") + fs.Usage() + os.Exit(1) + } + + if *targetURL == "" { + *targetURL = *judgeURL + } + + targetClient := NewClient(*targetURL, *model) + targetClient.MaxTokens = 1024 + judgeClient := NewClient(*judgeURL, *judgeModel) + judgeClient.MaxTokens = 512 + judge := NewJudge(judgeClient) + engine := NewEngine(judge, *concurrency, *suites) + prober := NewProber(targetClient, engine) + + var scorerOutput *ScorerOutput + var err error + + if *probesFile != "" { + probes, readErr := ReadResponses(*probesFile) + if readErr != nil { + log.Fatalf("read probes: %v", readErr) + } + log.Printf("loaded %d custom probes from %s", len(probes), *probesFile) + + scorerOutput, err = prober.ProbeModel(probes, *model) + } else { + log.Printf("using %d built-in content probes", len(ContentProbes)) + scorerOutput, err = prober.ProbeContent(*model) + } + + if err != nil { + log.Fatalf("probe: %v", err) + } + + if writeErr := WriteScores(*output, scorerOutput); writeErr != nil { + log.Fatalf("write scores: %v", writeErr) + } + + log.Printf("wrote scores to %s", *output) +}