From c03c755104a31acde618f33d376666b17bde65db Mon Sep 17 00:00:00 2001
From: Snider
Date: Tue, 17 Feb 2026 17:55:10 +0000
Subject: [PATCH] feat(ml): add benchmark command for baseline vs trained model comparison

Runs the same prompts through baseline and fine-tuned models, scores both
with the heuristic scorer, and outputs a comparison report with LEK score
deltas and improvement/regression counts. Uses built-in content probes by
default, or a custom prompts file.

Co-Authored-By: Virgil
---
 cmd/ml/cmd_benchmark.go      | 301 +++++++++++++++++++++++++++++++++++
 cmd/ml/cmd_benchmark_init.go |   7 +
 2 files changed, 308 insertions(+)
 create mode 100644 cmd/ml/cmd_benchmark.go
 create mode 100644 cmd/ml/cmd_benchmark_init.go

diff --git a/cmd/ml/cmd_benchmark.go b/cmd/ml/cmd_benchmark.go
new file mode 100644
index 00000000..773f53ad
--- /dev/null
+++ b/cmd/ml/cmd_benchmark.go
@@ -0,0 +1,301 @@
+//go:build darwin && arm64
+
+package ml
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"log/slog"
+	"os"
+	"runtime"
+	"sort"
+	"time"
+
+	"forge.lthn.ai/core/go-ai/ml"
+	"forge.lthn.ai/core/go/pkg/cli"
+)
+
+var benchmarkCmd = &cli.Command{
+	Use:   "benchmark",
+	Short: "Compare baseline vs fine-tuned model on ethics probes",
+	Long: `Runs the same prompts through a baseline model and a fine-tuned model,
+scores both using the heuristic scorer, and outputs a comparison.
+
+Uses the built-in LEK content probes by default. Optionally takes a
+custom prompts JSONL file (same format as 'core ml score --input').
+
+The fine-tuned model can be the same model directory with a LoRA adapter
+loaded, or a separately merged model.`,
+	RunE: runBenchmark,
+}
+
+var (
+	benchmarkBaseline  string
+	benchmarkTrained   string
+	benchmarkPrompts   string
+	benchmarkOutput    string
+	benchmarkMaxTokens int
+	benchmarkTemp      float64
+	benchmarkMemLimit  int
+)
+
+func init() {
+	benchmarkCmd.Flags().StringVar(&benchmarkBaseline, "baseline", "", "Path to baseline model directory (required)")
+	benchmarkCmd.Flags().StringVar(&benchmarkTrained, "trained", "", "Path to fine-tuned model directory (required)")
+	benchmarkCmd.Flags().StringVar(&benchmarkPrompts, "prompts", "", "Custom prompts file (JSONL with 'prompt' field, or seeds JSON)")
+	benchmarkCmd.Flags().StringVar(&benchmarkOutput, "output", "benchmark.json", "Output comparison JSON file")
+	benchmarkCmd.Flags().IntVar(&benchmarkMaxTokens, "max-tokens", 1024, "Max tokens per response")
+	benchmarkCmd.Flags().Float64Var(&benchmarkTemp, "temperature", 0.4, "Sampling temperature")
+	benchmarkCmd.Flags().IntVar(&benchmarkMemLimit, "memory-limit", 24, "Metal memory limit in GB")
+	benchmarkCmd.MarkFlagRequired("baseline")
+	benchmarkCmd.MarkFlagRequired("trained")
+}
+
+// benchmarkResult holds the comparison for a single prompt.
+type benchmarkResult struct {
+	ID               string  `json:"id"`
+	Prompt           string  `json:"prompt"`
+	BaselineResponse string  `json:"baseline_response"`
+	TrainedResponse  string  `json:"trained_response"`
+	BaselineLEK      float64 `json:"baseline_lek_score"`
+	TrainedLEK       float64 `json:"trained_lek_score"`
+	Delta            float64 `json:"delta"`
+
+	BaselineHeuristic *ml.HeuristicScores `json:"baseline_heuristic"`
+	TrainedHeuristic  *ml.HeuristicScores `json:"trained_heuristic"`
+}
+
+// benchmarkSummary holds aggregate comparison metrics.
+type benchmarkSummary struct {
+	BaselineModel  string            `json:"baseline_model"`
+	TrainedModel   string            `json:"trained_model"`
+	TotalPrompts   int               `json:"total_prompts"`
+	AvgBaselineLEK float64           `json:"avg_baseline_lek"`
+	AvgTrainedLEK  float64           `json:"avg_trained_lek"`
+	AvgDelta       float64           `json:"avg_delta"`
+	Improved       int               `json:"improved"`
+	Regressed      int               `json:"regressed"`
+	Unchanged      int               `json:"unchanged"`
+	Duration       string            `json:"duration"`
+	Results        []benchmarkResult `json:"results"`
+}
+
+func runBenchmark(cmd *cli.Command, args []string) error {
+	start := time.Now()
+
+	// Load prompts — either custom file or built-in probes
+	prompts, err := loadBenchmarkPrompts()
+	if err != nil {
+		return err
+	}
+
+	slog.Info("benchmark: loaded prompts", "count", len(prompts))
+
+	opts := ml.GenOpts{
+		Temperature: benchmarkTemp,
+		MaxTokens:   benchmarkMaxTokens,
+	}
+
+	// Generate baseline responses
+	slog.Info("benchmark: loading baseline model", "path", benchmarkBaseline)
+	baselineBackend, err := ml.NewMLXBackend(benchmarkBaseline)
+	if err != nil {
+		return fmt.Errorf("load baseline: %w", err)
+	}
+
+	baselineResponses := make(map[string]string)
+	for i, p := range prompts {
+		slog.Info("benchmark: baseline",
+			"prompt", fmt.Sprintf("%d/%d", i+1, len(prompts)),
+			"id", p.id,
+		)
+		resp, err := baselineBackend.Generate(context.Background(), p.prompt, opts)
+		if err != nil {
+			slog.Error("benchmark: baseline failed", "id", p.id, "error", err)
+			continue
+		}
+		baselineResponses[p.id] = resp
+
+		if (i+1)%4 == 0 {
+			runtime.GC()
+		}
+	}
+
+	// Force cleanup before loading second model
+	baselineBackend = nil
+	runtime.GC()
+	runtime.GC()
+
+	// Generate trained responses
+	slog.Info("benchmark: loading trained model", "path", benchmarkTrained)
+	trainedBackend, err := ml.NewMLXBackend(benchmarkTrained)
+	if err != nil {
+		return fmt.Errorf("load trained: %w", err)
+	}
+
+	trainedResponses := make(map[string]string)
+	for i, p := range prompts {
+		slog.Info("benchmark: trained",
+			"prompt", fmt.Sprintf("%d/%d", i+1, len(prompts)),
+			"id", p.id,
+		)
+		resp, err := trainedBackend.Generate(context.Background(), p.prompt, opts)
+		if err != nil {
+			slog.Error("benchmark: trained failed", "id", p.id, "error", err)
+			continue
+		}
+		trainedResponses[p.id] = resp
+
+		if (i+1)%4 == 0 {
+			runtime.GC()
+		}
+	}
+
+	trainedBackend = nil
+	runtime.GC()
+
+	// Score both sets
+	var results []benchmarkResult
+	var totalBaseline, totalTrained float64
+	improved, regressed, unchanged := 0, 0, 0
+
+	for _, p := range prompts {
+		baseResp := baselineResponses[p.id]
+		trainResp := trainedResponses[p.id]
+
+		if baseResp == "" || trainResp == "" {
+			continue
+		}
+
+		baseH := ml.ScoreHeuristic(baseResp)
+		trainH := ml.ScoreHeuristic(trainResp)
+		delta := trainH.LEKScore - baseH.LEKScore
+
+		totalBaseline += baseH.LEKScore
+		totalTrained += trainH.LEKScore
+
+		if delta > 0.5 {
+			improved++
+		} else if delta < -0.5 {
+			regressed++
+		} else {
+			unchanged++
+		}
+
+		results = append(results, benchmarkResult{
+			ID:                p.id,
+			Prompt:            p.prompt,
+			BaselineResponse:  baseResp,
+			TrainedResponse:   trainResp,
+			BaselineLEK:       baseH.LEKScore,
+			TrainedLEK:        trainH.LEKScore,
+			Delta:             delta,
+			BaselineHeuristic: baseH,
+			TrainedHeuristic:  trainH,
+		})
+	}
+
+	n := float64(len(results))
+	if n == 0 {
+		return fmt.Errorf("no results to compare")
+	}
+
+	summary := benchmarkSummary{
+		BaselineModel:  benchmarkBaseline,
+		TrainedModel:   benchmarkTrained,
+		TotalPrompts:   len(results),
+		AvgBaselineLEK: totalBaseline / n,
+		AvgTrainedLEK:  totalTrained / n,
+		AvgDelta:       (totalTrained - totalBaseline) / n,
+		Improved:       improved,
+		Regressed:      regressed,
+		Unchanged:      unchanged,
+		Duration:       time.Since(start).Round(time.Second).String(),
+		Results:        results,
+	}
+
+	// Write output
+	data, err := json.MarshalIndent(summary, "", " ")
+	if err != nil {
+		return fmt.Errorf("marshal output: %w", err)
+	}
+	if err := os.WriteFile(benchmarkOutput, data, 0644); err != nil {
+		return fmt.Errorf("write output: %w", err)
+	}
+
+	// Print summary
+	fmt.Println()
+	fmt.Println("=== Benchmark Results ===")
+	fmt.Printf("Baseline: %s\n", benchmarkBaseline)
+	fmt.Printf("Trained: %s\n", benchmarkTrained)
+	fmt.Printf("Prompts: %d\n", len(results))
+	fmt.Println()
+	fmt.Printf("Avg LEK (baseline): %+.2f\n", summary.AvgBaselineLEK)
+	fmt.Printf("Avg LEK (trained): %+.2f\n", summary.AvgTrainedLEK)
+	fmt.Printf("Avg Delta: %+.2f\n", summary.AvgDelta)
+	fmt.Println()
+	fmt.Printf("Improved: %d (%.0f%%)\n", improved, float64(improved)/n*100)
+	fmt.Printf("Regressed: %d (%.0f%%)\n", regressed, float64(regressed)/n*100)
+	fmt.Printf("Unchanged: %d (%.0f%%)\n", unchanged, float64(unchanged)/n*100)
+	fmt.Printf("Duration: %s\n", summary.Duration)
+	fmt.Printf("Output: %s\n", benchmarkOutput)
+
+	return nil
+}
+
+type benchPrompt struct {
+	id     string
+	prompt string
+}
+
+func loadBenchmarkPrompts() ([]benchPrompt, error) {
+	if benchmarkPrompts == "" {
+		// Use built-in content probes
+		probes := ml.ContentProbes
+		prompts := make([]benchPrompt, len(probes))
+		for i, p := range probes {
+			prompts[i] = benchPrompt{id: p.ID, prompt: p.Prompt}
+		}
+		return prompts, nil
+	}
+
+	// Try seeds JSON format first (array of {id, prompt, ...})
+	data, err := os.ReadFile(benchmarkPrompts)
+	if err != nil {
+		return nil, fmt.Errorf("read prompts: %w", err)
+	}
+
+	var seeds []seedPrompt
+	if json.Unmarshal(data, &seeds) == nil && len(seeds) > 0 {
+		prompts := make([]benchPrompt, len(seeds))
+		for i, s := range seeds {
+			prompts[i] = benchPrompt{id: s.ID, prompt: s.Prompt}
+		}
+		return prompts, nil
+	}
+
+	// Try JSONL responses format
+	responses, err := ml.ReadResponses(benchmarkPrompts)
+	if err != nil {
+		return nil, fmt.Errorf("parse prompts: %w", err)
+	}
+
+	// Deduplicate by prompt
+	seen := make(map[string]bool)
+	var prompts []benchPrompt
+	for _, r := range responses {
+		if seen[r.Prompt] {
+			continue
+		}
+		seen[r.Prompt] = true
+		id := r.ID
+		if id == "" {
+			id = fmt.Sprintf("P%03d", len(prompts)+1)
+		}
+		prompts = append(prompts, benchPrompt{id: id, prompt: r.Prompt})
+	}
+
+	sort.Slice(prompts, func(i, j int) bool { return prompts[i].id < prompts[j].id })
+	return prompts, nil
+}
diff --git a/cmd/ml/cmd_benchmark_init.go b/cmd/ml/cmd_benchmark_init.go
new file mode 100644
index 00000000..ff4508ea
--- /dev/null
+++ b/cmd/ml/cmd_benchmark_init.go
@@ -0,0 +1,7 @@
+//go:build darwin && arm64
+
+package ml
+
+func init() {
+	mlCmd.AddCommand(benchmarkCmd)
+}
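
For reference, a run of the new command might look like the following. The model paths are hypothetical, and the 'core' binary name is inferred from the help text's mention of 'core ml score --input', so treat the invocation as an assumption rather than documented usage:

  core ml benchmark \
    --baseline ~/models/baseline \
    --trained ~/models/baseline-lora-merged \
    --output benchmark.json

With no --prompts flag the built-in LEK content probes are used; --max-tokens (default 1024) and --temperature (default 0.4) apply to both generation passes, and --prompts can point at a seeds JSON or a scored-responses JSONL file instead of the probes.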
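
A custom prompts file in the seeds JSON form is an array of objects carrying at least an id and a prompt, per the '(array of {id, prompt, ...})' comment in loadBenchmarkPrompts. An illustrative sketch with placeholder values only (the exact JSON keys of seedPrompt are defined elsewhere in the package, so the lowercase names here are an assumption):

  [
    {"id": "P001", "prompt": "First probe prompt text ..."},
    {"id": "P002", "prompt": "Second probe prompt text ..."}
  ]

The JSONL path instead goes through ml.ReadResponses and is deduplicated by prompt; entries without an id are assigned P001, P002, and so on in load order, and the final list is sorted by id.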
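
Because benchmark.json embeds the full per-prompt results, regressions can be inspected without re-running generation. A minimal, self-contained Go sketch follows; it is not part of the patch, the struct fields mirror the JSON tags of benchmarkResult and benchmarkSummary above, the input path assumes the default --output value, and the -0.5 cutoff matches the command's own regression threshold:

package main

import (
	"encoding/json"
	"fmt"
	"os"
)

// result mirrors the per-prompt fields written by the benchmark command.
type result struct {
	ID          string  `json:"id"`
	BaselineLEK float64 `json:"baseline_lek_score"`
	TrainedLEK  float64 `json:"trained_lek_score"`
	Delta       float64 `json:"delta"`
}

// summary mirrors the aggregate delta plus the embedded results array.
type summary struct {
	AvgDelta float64  `json:"avg_delta"`
	Results  []result `json:"results"`
}

func main() {
	data, err := os.ReadFile("benchmark.json")
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	var s summary
	if err := json.Unmarshal(data, &s); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Printf("avg delta: %+.2f\n", s.AvgDelta)
	// List prompts that regressed, using the same -0.5 delta cutoff as the command.
	for _, r := range s.Results {
		if r.Delta < -0.5 {
			fmt.Printf("regressed: %s (%.2f -> %.2f)\n", r.ID, r.BaselineLEK, r.TrainedLEK)
		}
	}
}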