refactor: move runScore and runProbe to pkg/lem
All 28 commands now accessible as exported lem.Run* functions. Prerequisite for CLI framework migration. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
131d1694b2
commit
a0a0118155
2 changed files with 185 additions and 183 deletions
199
main.go
199
main.go
|
|
@ -5,7 +5,6 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/lthn/lem/pkg/lem"
|
||||
)
|
||||
|
|
@ -62,11 +61,24 @@ func main() {
|
|||
case "distill":
|
||||
lem.RunDistill(os.Args[2:])
|
||||
case "score":
|
||||
runScore(os.Args[2:])
|
||||
lem.RunScore(os.Args[2:])
|
||||
case "probe":
|
||||
runProbe(os.Args[2:])
|
||||
lem.RunProbe(os.Args[2:])
|
||||
case "compare":
|
||||
runCompare(os.Args[2:])
|
||||
fs := flag.NewFlagSet("compare", flag.ExitOnError)
|
||||
oldFile := fs.String("old", "", "Old score file (required)")
|
||||
newFile := fs.String("new", "", "New score file (required)")
|
||||
if err := fs.Parse(os.Args[2:]); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
}
|
||||
if *oldFile == "" || *newFile == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --old and --new are required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := lem.RunCompare(*oldFile, *newFile); err != nil {
|
||||
log.Fatalf("compare: %v", err)
|
||||
}
|
||||
case "status":
|
||||
lem.RunStatus(os.Args[2:])
|
||||
case "expand":
|
||||
|
|
@ -114,182 +126,3 @@ func main() {
|
|||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runScore(args []string) {
|
||||
fs := flag.NewFlagSet("score", flag.ExitOnError)
|
||||
|
||||
input := fs.String("input", "", "Input JSONL response file (required)")
|
||||
suites := fs.String("suites", "all", "Comma-separated suites or 'all'")
|
||||
judgeModel := fs.String("judge-model", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name")
|
||||
judgeURL := fs.String("judge-url", "http://10.69.69.108:8090", "Judge API URL")
|
||||
concurrency := fs.Int("concurrency", 4, "Max concurrent judge calls")
|
||||
output := fs.String("output", "scores.json", "Output score file path")
|
||||
resume := fs.Bool("resume", false, "Resume from existing output, skipping scored IDs")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
}
|
||||
|
||||
if *input == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --input is required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
responses, err := lem.ReadResponses(*input)
|
||||
if err != nil {
|
||||
log.Fatalf("read responses: %v", err)
|
||||
}
|
||||
log.Printf("loaded %d responses from %s", len(responses), *input)
|
||||
|
||||
if *resume {
|
||||
if _, statErr := os.Stat(*output); statErr == nil {
|
||||
existing, readErr := lem.ReadScorerOutput(*output)
|
||||
if readErr != nil {
|
||||
log.Fatalf("read existing scores for resume: %v", readErr)
|
||||
}
|
||||
|
||||
scored := make(map[string]bool)
|
||||
for _, scores := range existing.PerPrompt {
|
||||
for _, ps := range scores {
|
||||
scored[ps.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
var filtered []lem.Response
|
||||
for _, r := range responses {
|
||||
if !scored[r.ID] {
|
||||
filtered = append(filtered, r)
|
||||
}
|
||||
}
|
||||
log.Printf("resume: skipping %d already-scored, %d remaining",
|
||||
len(responses)-len(filtered), len(filtered))
|
||||
responses = filtered
|
||||
|
||||
if len(responses) == 0 {
|
||||
log.Println("all responses already scored, nothing to do")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client := lem.NewClient(*judgeURL, *judgeModel)
|
||||
client.MaxTokens = 512
|
||||
judge := lem.NewJudge(client)
|
||||
engine := lem.NewEngine(judge, *concurrency, *suites)
|
||||
|
||||
log.Printf("scoring with %s", engine)
|
||||
|
||||
perPrompt := engine.ScoreAll(responses)
|
||||
|
||||
if *resume {
|
||||
if _, statErr := os.Stat(*output); statErr == nil {
|
||||
existing, _ := lem.ReadScorerOutput(*output)
|
||||
for model, scores := range existing.PerPrompt {
|
||||
perPrompt[model] = append(scores, perPrompt[model]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
averages := lem.ComputeAverages(perPrompt)
|
||||
|
||||
scorerOutput := &lem.ScorerOutput{
|
||||
Metadata: lem.Metadata{
|
||||
JudgeModel: *judgeModel,
|
||||
JudgeURL: *judgeURL,
|
||||
ScoredAt: time.Now().UTC(),
|
||||
ScorerVersion: "1.0.0",
|
||||
Suites: engine.SuiteNames(),
|
||||
},
|
||||
ModelAverages: averages,
|
||||
PerPrompt: perPrompt,
|
||||
}
|
||||
|
||||
if err := lem.WriteScores(*output, scorerOutput); err != nil {
|
||||
log.Fatalf("write scores: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("wrote scores to %s", *output)
|
||||
}
|
||||
|
||||
func runProbe(args []string) {
|
||||
fs := flag.NewFlagSet("probe", flag.ExitOnError)
|
||||
|
||||
model := fs.String("model", "", "Target model name (required)")
|
||||
targetURL := fs.String("target-url", "", "Target model API URL (defaults to judge-url)")
|
||||
probesFile := fs.String("probes", "", "Custom probes JSONL file (uses built-in content probes if not specified)")
|
||||
suites := fs.String("suites", "all", "Comma-separated suites or 'all'")
|
||||
judgeModel := fs.String("judge-model", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name")
|
||||
judgeURL := fs.String("judge-url", "http://10.69.69.108:8090", "Judge API URL")
|
||||
concurrency := fs.Int("concurrency", 4, "Max concurrent judge calls")
|
||||
output := fs.String("output", "scores.json", "Output score file path")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
}
|
||||
|
||||
if *model == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --model is required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *targetURL == "" {
|
||||
*targetURL = *judgeURL
|
||||
}
|
||||
|
||||
targetClient := lem.NewClient(*targetURL, *model)
|
||||
targetClient.MaxTokens = 1024
|
||||
judgeClient := lem.NewClient(*judgeURL, *judgeModel)
|
||||
judgeClient.MaxTokens = 512
|
||||
judge := lem.NewJudge(judgeClient)
|
||||
engine := lem.NewEngine(judge, *concurrency, *suites)
|
||||
prober := lem.NewProber(targetClient, engine)
|
||||
|
||||
var scorerOutput *lem.ScorerOutput
|
||||
var err error
|
||||
|
||||
if *probesFile != "" {
|
||||
probes, readErr := lem.ReadResponses(*probesFile)
|
||||
if readErr != nil {
|
||||
log.Fatalf("read probes: %v", readErr)
|
||||
}
|
||||
log.Printf("loaded %d custom probes from %s", len(probes), *probesFile)
|
||||
|
||||
scorerOutput, err = prober.ProbeModel(probes, *model)
|
||||
} else {
|
||||
log.Printf("using %d built-in content probes", len(lem.ContentProbes))
|
||||
scorerOutput, err = prober.ProbeContent(*model)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("probe: %v", err)
|
||||
}
|
||||
|
||||
if writeErr := lem.WriteScores(*output, scorerOutput); writeErr != nil {
|
||||
log.Fatalf("write scores: %v", writeErr)
|
||||
}
|
||||
|
||||
log.Printf("wrote scores to %s", *output)
|
||||
}
|
||||
|
||||
func runCompare(args []string) {
|
||||
fs := flag.NewFlagSet("compare", flag.ExitOnError)
|
||||
|
||||
oldFile := fs.String("old", "", "Old score file (required)")
|
||||
newFile := fs.String("new", "", "New score file (required)")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
}
|
||||
|
||||
if *oldFile == "" || *newFile == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --old and --new are required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := lem.RunCompare(*oldFile, *newFile); err != nil {
|
||||
log.Fatalf("compare: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
169
pkg/lem/score_cmd.go
Normal file
169
pkg/lem/score_cmd.go
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
package lem
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RunScore scores existing response files using a judge model.
|
||||
func RunScore(args []string) {
|
||||
fs := flag.NewFlagSet("score", flag.ExitOnError)
|
||||
|
||||
input := fs.String("input", "", "Input JSONL response file (required)")
|
||||
suites := fs.String("suites", "all", "Comma-separated suites or 'all'")
|
||||
judgeModel := fs.String("judge-model", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name")
|
||||
judgeURL := fs.String("judge-url", "http://10.69.69.108:8090", "Judge API URL")
|
||||
concurrency := fs.Int("concurrency", 4, "Max concurrent judge calls")
|
||||
output := fs.String("output", "scores.json", "Output score file path")
|
||||
resume := fs.Bool("resume", false, "Resume from existing output, skipping scored IDs")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
}
|
||||
|
||||
if *input == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --input is required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
responses, err := ReadResponses(*input)
|
||||
if err != nil {
|
||||
log.Fatalf("read responses: %v", err)
|
||||
}
|
||||
log.Printf("loaded %d responses from %s", len(responses), *input)
|
||||
|
||||
if *resume {
|
||||
if _, statErr := os.Stat(*output); statErr == nil {
|
||||
existing, readErr := ReadScorerOutput(*output)
|
||||
if readErr != nil {
|
||||
log.Fatalf("read existing scores for resume: %v", readErr)
|
||||
}
|
||||
|
||||
scored := make(map[string]bool)
|
||||
for _, scores := range existing.PerPrompt {
|
||||
for _, ps := range scores {
|
||||
scored[ps.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
var filtered []Response
|
||||
for _, r := range responses {
|
||||
if !scored[r.ID] {
|
||||
filtered = append(filtered, r)
|
||||
}
|
||||
}
|
||||
log.Printf("resume: skipping %d already-scored, %d remaining",
|
||||
len(responses)-len(filtered), len(filtered))
|
||||
responses = filtered
|
||||
|
||||
if len(responses) == 0 {
|
||||
log.Println("all responses already scored, nothing to do")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client := NewClient(*judgeURL, *judgeModel)
|
||||
client.MaxTokens = 512
|
||||
judge := NewJudge(client)
|
||||
engine := NewEngine(judge, *concurrency, *suites)
|
||||
|
||||
log.Printf("scoring with %s", engine)
|
||||
|
||||
perPrompt := engine.ScoreAll(responses)
|
||||
|
||||
if *resume {
|
||||
if _, statErr := os.Stat(*output); statErr == nil {
|
||||
existing, _ := ReadScorerOutput(*output)
|
||||
for model, scores := range existing.PerPrompt {
|
||||
perPrompt[model] = append(scores, perPrompt[model]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
averages := ComputeAverages(perPrompt)
|
||||
|
||||
scorerOutput := &ScorerOutput{
|
||||
Metadata: Metadata{
|
||||
JudgeModel: *judgeModel,
|
||||
JudgeURL: *judgeURL,
|
||||
ScoredAt: time.Now().UTC(),
|
||||
ScorerVersion: "1.0.0",
|
||||
Suites: engine.SuiteNames(),
|
||||
},
|
||||
ModelAverages: averages,
|
||||
PerPrompt: perPrompt,
|
||||
}
|
||||
|
||||
if err := WriteScores(*output, scorerOutput); err != nil {
|
||||
log.Fatalf("write scores: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("wrote scores to %s", *output)
|
||||
}
|
||||
|
||||
// RunProbe generates responses from a target model and scores them.
|
||||
func RunProbe(args []string) {
|
||||
fs := flag.NewFlagSet("probe", flag.ExitOnError)
|
||||
|
||||
model := fs.String("model", "", "Target model name (required)")
|
||||
targetURL := fs.String("target-url", "", "Target model API URL (defaults to judge-url)")
|
||||
probesFile := fs.String("probes", "", "Custom probes JSONL file (uses built-in content probes if not specified)")
|
||||
suites := fs.String("suites", "all", "Comma-separated suites or 'all'")
|
||||
judgeModel := fs.String("judge-model", "mlx-community/gemma-3-27b-it-qat-4bit", "Judge model name")
|
||||
judgeURL := fs.String("judge-url", "http://10.69.69.108:8090", "Judge API URL")
|
||||
concurrency := fs.Int("concurrency", 4, "Max concurrent judge calls")
|
||||
output := fs.String("output", "scores.json", "Output score file path")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
}
|
||||
|
||||
if *model == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --model is required")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *targetURL == "" {
|
||||
*targetURL = *judgeURL
|
||||
}
|
||||
|
||||
targetClient := NewClient(*targetURL, *model)
|
||||
targetClient.MaxTokens = 1024
|
||||
judgeClient := NewClient(*judgeURL, *judgeModel)
|
||||
judgeClient.MaxTokens = 512
|
||||
judge := NewJudge(judgeClient)
|
||||
engine := NewEngine(judge, *concurrency, *suites)
|
||||
prober := NewProber(targetClient, engine)
|
||||
|
||||
var scorerOutput *ScorerOutput
|
||||
var err error
|
||||
|
||||
if *probesFile != "" {
|
||||
probes, readErr := ReadResponses(*probesFile)
|
||||
if readErr != nil {
|
||||
log.Fatalf("read probes: %v", readErr)
|
||||
}
|
||||
log.Printf("loaded %d custom probes from %s", len(probes), *probesFile)
|
||||
|
||||
scorerOutput, err = prober.ProbeModel(probes, *model)
|
||||
} else {
|
||||
log.Printf("using %d built-in content probes", len(ContentProbes))
|
||||
scorerOutput, err = prober.ProbeContent(*model)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("probe: %v", err)
|
||||
}
|
||||
|
||||
if writeErr := WriteScores(*output, scorerOutput); writeErr != nil {
|
||||
log.Fatalf("write scores: %v", writeErr)
|
||||
}
|
||||
|
||||
log.Printf("wrote scores to %s", *output)
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue