go-ml/agent_eval.go

399 lines
12 KiB
Go
Raw Permalink Normal View History

package ml
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
)
// Result types produced by the capability probe runners. The JSON tags define
// the serialized field names and must stay stable.
type (
	// ProbeResult aggregates one capability probe run against a checkpoint:
	// overall accuracy as a percentage (0-100), pass and attempt counts,
	// per-category tallies, and per-probe detail keyed by probe ID.
	ProbeResult struct {
		Accuracy   float64                      `json:"accuracy"`
		Correct    int                          `json:"correct"`
		Total      int                          `json:"total"`
		ByCategory map[string]CategoryResult    `json:"by_category"`
		Probes     map[string]SingleProbeResult `json:"probes"`
	}

	// CategoryResult tallies, for one probe category, how many probes passed
	// (Correct) out of how many were attempted (Total).
	CategoryResult struct {
		Correct int `json:"correct"`
		Total   int `json:"total"`
	}

	// SingleProbeResult records whether one probe passed, together with the
	// (possibly length-capped) response text that was recorded for it.
	SingleProbeResult struct {
		Passed   bool   `json:"passed"`
		Response string `json:"response"`
	}
)
// ProbeCallback receives each capability probe's outcome as soon as it has been
// scored, so results can be streamed while the run is still in progress.
// correct and total are the running pass and attempt counts, and they already
// include the probe being reported.
type ProbeCallback func(probeID string, category string, passed bool, response string, correct int, total int)
// CapResponseEntry carries one capability probe's full, untruncated response
// together with the metadata a judge model needs to grade it: the prompt, the
// expected answer, and whether the automatic check passed.
type CapResponseEntry struct {
	ProbeID, Category string
	Prompt, Answer    string
	Response          string
	Passed            bool
}
// ContentResponse pairs a content probe with the model's reply so the reply can
// be judge-scored later.
type ContentResponse struct {
	Probe    ContentProbe // the probe whose prompt produced Response
	Response string       // model reply, with think blocks stripped by the producers in this file
}
// probeRunnerResponse decodes one JSON reply line emitted by the external
// Python probe runner. A non-empty Error means generation failed and Response
// should be ignored. Elapsed is whatever duration the runner reports
// (NOTE(review): units not visible here — presumably seconds).
type probeRunnerResponse struct {
	Response string  `json:"response"`
	Error    string  `json:"error"`
	Elapsed  float64 `json:"elapsed"`
}
// processMLXNative evaluates an MLX adapter checkpoint by serving it through
// Ollama at cfg.JudgeURL.
//
// Steps:
//   - fetch the adapter safetensors and adapter_config.json from the remote
//     host (M3) via cfg.transport();
//   - convert them to PEFT format against the mapped HF base model;
//   - create a temporary Ollama model on top of the mapped Ollama base;
//   - run the capability probes against it, streaming each result to InfluxDB
//     as it completes;
//   - push the summary to InfluxDB and cfg.DBPath;
//   - judge-score the capability responses and the content probes with
//     cfg.JudgeModel.
//
// After the model tag has been resolved, the local scratch directories are
// removed on every return path and deletion of the temporary Ollama model is
// attempted. Errors are returned only for setup failures: unknown model tag,
// scratch directory creation, fetch, conversion, or Ollama model creation. Once
// probing starts, InfluxDB failures are logged, and a failed summary push is
// buffered locally via BufferInfluxResult instead of aborting.
func processMLXNative(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint) error {
	ollamaBase, ok := OllamaBaseModelMap[cp.ModelTag]
	if !ok {
		return fmt.Errorf("unknown Ollama model for tag %s", cp.ModelTag)
	}
	// No HF mapping for this tag: fall back to the Ollama base name for conversion.
	hfBase := HFBaseModelMap[cp.ModelTag]
	if hfBase == "" {
		hfBase = ollamaBase
	}

	tempModel := fmt.Sprintf("lem-%s-%d", cp.ModelTag, cp.Iteration)
	localAdapterDir := filepath.Join(cfg.WorkDir, "adapter-"+cp.Dirname)
	peftDir := filepath.Join(cfg.WorkDir, "peft-"+cp.Dirname)

	// Registered before any scratch state is created so every exit path below
	// cleans up. Cleanup is best-effort; its errors are deliberately ignored.
	defer func() {
		os.RemoveAll(localAdapterDir)
		os.RemoveAll(peftDir)
		OllamaDeleteModel(cfg.JudgeURL, tempModel)
	}()
	// Fail fast here rather than surfacing a confusing copy error later.
	if err := os.MkdirAll(localAdapterDir, 0755); err != nil {
		return fmt.Errorf("create adapter dir: %w", err)
	}

	log.Printf("Fetching adapter from M3 (%s)...", cp.Filename)
	remoteSF := fmt.Sprintf("%s/%s", cp.RemoteDir, cp.Filename)
	remoteCfg := fmt.Sprintf("%s/adapter_config.json", cp.RemoteDir)
	localSF := filepath.Join(localAdapterDir, cp.Filename)
	localCfg := filepath.Join(localAdapterDir, "adapter_config.json")

	ctx := context.Background()
	t := cfg.transport()
	if err := t.CopyFrom(ctx, remoteSF, localSF); err != nil {
		return fmt.Errorf("scp safetensors: %w", err)
	}
	if err := t.CopyFrom(ctx, remoteCfg, localCfg); err != nil {
		return fmt.Errorf("scp config: %w", err)
	}

	log.Println("Converting MLX → PEFT format...")
	if err := ConvertMLXtoPEFT(localSF, localCfg, peftDir, hfBase); err != nil {
		return fmt.Errorf("convert adapter: %w", err)
	}

	log.Printf("Creating Ollama model %s (base: %s)...", tempModel, ollamaBase)
	if err := OllamaCreateModel(cfg.JudgeURL, tempModel, ollamaBase, peftDir); err != nil {
		return fmt.Errorf("ollama create: %w", err)
	}
	log.Printf("Ollama model %s ready", tempModel)

	probeBackend := NewHTTPBackend(cfg.JudgeURL, tempModel)
	results, fullResponses := RunCapabilityProbesFull(ctx, probeBackend, func(probeID, category string, passed bool, response string, correct, total int) {
		passedInt := 0
		if passed {
			passedInt = 1
		}
		// Deterministic synthetic timestamp. Each iteration gets a window of 1000
		// units, and each probe within it is offset by 100 plus the running probe
		// count, so every probe lands on its own point.
		// NOTE(review): the 1e9 factor presumably converts seconds to nanoseconds
		// (EpochBase in seconds) — confirm.
		ts := (EpochBase + int64(cp.Iteration)*1000 + int64(total+100)) * 1_000_000_000
		line := fmt.Sprintf(
			MeasurementProbeScore+",model=%s,run_id=%s,label=%s,probe_id=%s passed=%di,iteration=%di %d",
			EscapeLp(cp.ModelTag), EscapeLp(cp.RunID), EscapeLp(cp.Label), EscapeLp(probeID),
			passedInt, cp.Iteration, ts,
		)
		if err := influx.WriteLp([]string{line}); err != nil {
			log.Printf(" [%s] InfluxDB stream failed: %v", probeID, err)
		}
	})
	log.Printf("Capability: %s -- %.1f%% (%d/%d)",
		cp.Label, results.Accuracy, results.Correct, results.Total)

	if err := PushCapabilitySummary(influx, cp, results); err != nil {
		log.Printf("InfluxDB summary push failed, buffering: %v", err)
		BufferInfluxResult(cfg.WorkDir, cp, results)
	}
	PushCapabilityResultsDB(cfg.DBPath, cp, results)

	judgeBackend := NewHTTPBackend(cfg.JudgeURL, cfg.JudgeModel)
	judge := NewJudge(judgeBackend)
	log.Printf("Judging %d capability responses (0-10 quality scoring)...", len(fullResponses))
	ScoreCapabilityAndPush(ctx, judge, influx, cp, fullResponses)

	log.Printf("Running %d content probes (0-10 judge scoring)...", len(ContentProbes))
	contentResponses := RunContentProbesViaAPI(ctx, probeBackend)
	if len(contentResponses) > 0 {
		// Content scores are filed under a sibling run ID derived from the
		// capability run ID.
		contentRunID := strings.Replace(cp.RunID, "-capability-", "-content-", 1)
		ScoreContentAndPush(ctx, judge, influx, cp, contentRunID, contentResponses)
	}
	return nil
}
// processWithConversion fetches an MLX adapter checkpoint from the remote host
// (M3), converts it to PEFT format against cfg.BaseModel, and then runs the
// capability probes against the model served at cfg.APIURL. That model is
// named by cfg.Model, falling back to cp.ModelTag. Results go to InfluxDB,
// buffered locally via BufferInfluxResult if the push fails, and to the local
// DB at cfg.DBPath.
//
// Errors are returned for scratch directory creation, fetch, and conversion
// failures. The fetched files and the PEFT output directory are removed on
// every return path.
//
// NOTE(review): the converted adapter in peftDir is never handed to the probe
// backend here, which only receives cfg.APIURL and a model name. Presumably the
// server behind cfg.APIURL picks the adapter up some other way. Confirm this;
// otherwise the conversion step is wasted work.
func processWithConversion(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint) error {
	localAdapterDir := filepath.Join(cfg.WorkDir, cp.Dirname)
	localSF := filepath.Join(localAdapterDir, cp.Filename)
	localCfg := filepath.Join(localAdapterDir, "adapter_config.json")
	peftDir := filepath.Join(cfg.WorkDir, fmt.Sprintf("peft_%07d", cp.Iteration))

	// Best-effort cleanup of the fetched files and the converted output. The
	// adapter directory itself is left in place.
	defer func() {
		os.Remove(localSF)
		os.Remove(localCfg)
		os.RemoveAll(peftDir)
	}()
	// Fail fast here rather than surfacing a confusing copy error later.
	if err := os.MkdirAll(localAdapterDir, 0755); err != nil {
		return fmt.Errorf("create adapter dir: %w", err)
	}

	log.Println("Fetching adapter from M3...")
	remoteSF := fmt.Sprintf("%s/%s", cp.RemoteDir, cp.Filename)
	remoteCfg := fmt.Sprintf("%s/adapter_config.json", cp.RemoteDir)
	ctx := context.Background()
	t := cfg.transport()
	if err := t.CopyFrom(ctx, remoteSF, localSF); err != nil {
		return fmt.Errorf("scp safetensors: %w", err)
	}
	if err := t.CopyFrom(ctx, remoteCfg, localCfg); err != nil {
		return fmt.Errorf("scp config: %w", err)
	}

	log.Println("Converting MLX to PEFT format...")
	if err := ConvertMLXtoPEFT(localSF, localCfg, peftDir, cfg.BaseModel); err != nil {
		return fmt.Errorf("convert adapter: %w", err)
	}

	log.Printf("Running %d capability probes...", len(CapabilityProbes))
	modelName := cfg.Model
	if modelName == "" {
		modelName = cp.ModelTag
	}
	backend := NewHTTPBackend(cfg.APIURL, modelName)
	results := RunCapabilityProbes(ctx, backend)
	log.Printf("Result: %s -- %.1f%% (%d/%d)",
		cp.Label, results.Accuracy, results.Correct, results.Total)

	if err := PushCapabilityResults(influx, cp, results); err != nil {
		log.Printf("InfluxDB push failed, buffering: %v", err)
		BufferInfluxResult(cfg.WorkDir, cp, results)
	}
	PushCapabilityResultsDB(cfg.DBPath, cp, results)
	return nil
}
// RunCapabilityProbes runs every probe in CapabilityProbes against backend and
// returns the aggregate result.
//
// Each reply has think blocks stripped before probe.Check grades it. The copy
// stored in Probes is capped at MaxStoredResponseLen bytes, and the cut is
// moved back to a rune boundary so the stored text remains valid UTF-8. A
// generation error counts as an attempted, failed probe, and its error text is
// stored as the response. Accuracy is a percentage (0-100) and stays 0 when no
// probes ran.
func RunCapabilityProbes(ctx context.Context, backend Backend) ProbeResult {
	results := ProbeResult{
		ByCategory: make(map[string]CategoryResult),
		Probes:     make(map[string]SingleProbeResult),
	}
	correct, total := 0, 0
	for _, probe := range CapabilityProbes {
		res, err := backend.Generate(ctx, probe.Prompt, GenOpts{Temperature: CapabilityTemperature, MaxTokens: CapabilityMaxTokens})

		// Every probe counts as attempted, whether or not generation succeeded.
		total++
		cat := results.ByCategory[probe.Category]
		cat.Total++

		if err != nil {
			log.Printf(" [%s] ERROR: %v", probe.ID, err)
			results.ByCategory[probe.Category] = cat
			results.Probes[probe.ID] = SingleProbeResult{Passed: false, Response: err.Error()}
			continue
		}

		clean := StripThinkBlocks(res.Text)
		passed := probe.Check(clean)
		if passed {
			correct++
			cat.Correct++
		}
		results.ByCategory[probe.Category] = cat

		stored := clean
		if len(stored) > MaxStoredResponseLen {
			cut := MaxStoredResponseLen
			// Byte slicing can split a multi-byte rune. Back off past UTF-8
			// continuation bytes (10xxxxxx) so the stored prefix stays valid.
			for cut > 0 && stored[cut]&0xC0 == 0x80 {
				cut--
			}
			stored = stored[:cut]
		}
		results.Probes[probe.ID] = SingleProbeResult{Passed: passed, Response: stored}

		status := "FAIL"
		if passed {
			status = "PASS"
		}
		log.Printf(" [%s] %s (expected: %s)", probe.ID, status, probe.Answer)
	}
	if total > 0 {
		results.Accuracy = float64(correct) / float64(total) * 100
	}
	results.Correct = correct
	results.Total = total
	return results
}
// RunCapabilityProbesFull runs every probe in CapabilityProbes against backend.
// It returns the aggregate result and one CapResponseEntry per probe, which
// holds the untruncated reply for judge scoring.
//
// Each reply has think blocks stripped before grading. The copy stored in
// Probes, and passed to onProbe, is capped at MaxStoredResponseLen bytes, and
// the cut is moved back to a rune boundary so it stays valid UTF-8. If onProbe
// is non-nil, it is called after each probe with the running correct/total
// counts.
//
// A generation error is recorded as a failed probe whose response is
// "ERROR: <err>". That text is never graded by probe.Check, so an error message
// can never accidentally satisfy a probe's expected answer. The entry is still
// included in the returned responses. Accuracy is a percentage (0-100) and
// stays 0 when no probes ran.
func RunCapabilityProbesFull(ctx context.Context, backend Backend, onProbe ProbeCallback) (ProbeResult, []CapResponseEntry) {
	results := ProbeResult{
		ByCategory: make(map[string]CategoryResult),
		Probes:     make(map[string]SingleProbeResult),
	}
	var fullResponses []CapResponseEntry
	correct, total := 0, 0
	for _, probe := range CapabilityProbes {
		res, err := backend.Generate(ctx, probe.Prompt, GenOpts{Temperature: CapabilityTemperature, MaxTokens: CapabilityMaxTokens})

		// Only read the result when generation succeeded.
		var response string
		if err != nil {
			log.Printf(" [%s] ERROR: %v", probe.ID, err)
			response = fmt.Sprintf("ERROR: %v", err)
		} else {
			response = res.Text
		}
		clean := StripThinkBlocks(response)
		// Only grade real model output; an errored probe always fails.
		passed := err == nil && probe.Check(clean)

		total++
		cat := results.ByCategory[probe.Category]
		cat.Total++
		if passed {
			correct++
			cat.Correct++
		}
		results.ByCategory[probe.Category] = cat

		stored := clean
		if len(stored) > MaxStoredResponseLen {
			cut := MaxStoredResponseLen
			// Byte slicing can split a multi-byte rune. Back off past UTF-8
			// continuation bytes (10xxxxxx) so the stored prefix stays valid.
			for cut > 0 && stored[cut]&0xC0 == 0x80 {
				cut--
			}
			stored = stored[:cut]
		}
		results.Probes[probe.ID] = SingleProbeResult{Passed: passed, Response: stored}
		fullResponses = append(fullResponses, CapResponseEntry{
			ProbeID:  probe.ID,
			Category: probe.Category,
			Prompt:   probe.Prompt,
			Answer:   probe.Answer,
			Response: clean,
			Passed:   passed,
		})

		status := "FAIL"
		if passed {
			status = "PASS"
		}
		log.Printf(" [%s] %s (expected: %s)", probe.ID, status, probe.Answer)
		if onProbe != nil {
			onProbe(probe.ID, probe.Category, passed, stored, correct, total)
		}
	}
	if total > 0 {
		results.Accuracy = float64(correct) / float64(total) * 100
	}
	results.Correct = correct
	results.Total = total
	return results, fullResponses
}
// RunContentProbesViaAPI sends every prompt in ContentProbes to backend, using
// the content sampling settings. It collects the replies, with think blocks
// stripped, for later judge scoring.
//
// Probes whose generation fails are logged and left out, so the result can be
// shorter than ContentProbes. It is nil when nothing succeeded.
func RunContentProbesViaAPI(ctx context.Context, backend Backend) []ContentResponse {
	opts := GenOpts{Temperature: ContentTemperature, MaxTokens: ContentMaxTokens}

	var collected []ContentResponse
	for i := range ContentProbes {
		p := ContentProbes[i]

		out, err := backend.Generate(ctx, p.Prompt, opts)
		if err != nil {
			log.Printf(" [content:%s] ERROR: %v", p.ID, err)
			continue
		}

		text := StripThinkBlocks(out.Text)
		log.Printf(" [content:%s] got %d chars", p.ID, len(text))

		entry := ContentResponse{Probe: p, Response: text}
		collected = append(collected, entry)
	}
	return collected
}
// RunContentProbesViaRunner drives an external (SSH) probe runner over a
// line-delimited JSON protocol.
//
// For each content probe it writes one request object with the keys "prompt",
// "max_tokens", and "temp" to stdin. It then reads exactly one
// probeRunnerResponse line from scanner. Replies have think blocks stripped and
// are collected for judge scoring. A probe whose reply fails to parse, or
// reports an error, is logged and skipped.
//
// The remaining probes are abandoned if a request cannot be written, or if the
// runner's output ends through EOF or a scanner error such as
// bufio.ErrTooLong. In those cases the request/response pairing can no longer
// be trusted, and every later Scan would fail anyway. stdin is not closed here;
// that is left to the caller.
func RunContentProbesViaRunner(stdin io.WriteCloser, scanner *bufio.Scanner) []ContentResponse {
	var responses []ContentResponse
	for _, probe := range ContentProbes {
		reqJSON, err := json.Marshal(map[string]any{
			"prompt":     probe.Prompt,
			"max_tokens": ContentMaxTokens,
			"temp":       ContentTemperature,
		})
		if err != nil {
			log.Printf(" [content:%s] encode error: %v", probe.ID, err)
			continue
		}
		if _, err := fmt.Fprintf(stdin, "%s\n", reqJSON); err != nil {
			log.Printf(" [content:%s] write to runner failed: %v", probe.ID, err)
			break
		}

		if !scanner.Scan() {
			if err := scanner.Err(); err != nil {
				log.Printf(" [content:%s] no response from runner: %v", probe.ID, err)
			} else {
				log.Printf(" [content:%s] no response from runner", probe.ID)
			}
			// Once Scan has returned false it always will; stop instead of
			// sending requests nobody will answer.
			break
		}

		var resp probeRunnerResponse
		if err := json.Unmarshal(scanner.Bytes(), &resp); err != nil {
			log.Printf(" [content:%s] parse error: %v", probe.ID, err)
			continue
		}
		if resp.Error != "" {
			log.Printf(" [content:%s] ERROR: %s", probe.ID, resp.Error)
			continue
		}

		response := StripThinkBlocks(resp.Response)
		log.Printf(" [content:%s] got %d chars", probe.ID, len(response))
		responses = append(responses, ContentResponse{
			Probe:    probe,
			Response: response,
		})
	}
	return responses
}