diff --git a/main.go b/main.go index 89d281f..c57e0ba 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,7 @@ Scoring: probe Generate responses and score them compare Compare two score files tier-score Score expansion responses (heuristic/judge tiers) + agent ROCm scoring daemon (polls M3, scores checkpoints) Generation: expand Generate expansion responses via trained LEM model @@ -98,6 +99,8 @@ func main() { lem.RunSeedInflux(os.Args[2:]) case "query": lem.RunQuery(os.Args[2:]) + case "agent": + lem.RunAgent(os.Args[2:]) default: fmt.Fprintf(os.Stderr, "unknown command: %s\n\n%s", os.Args[1], usage) os.Exit(1) diff --git a/pkg/lem/agent.go b/pkg/lem/agent.go new file mode 100644 index 0000000..23fd5d7 --- /dev/null +++ b/pkg/lem/agent.go @@ -0,0 +1,612 @@ +package lem + +import ( + "encoding/json" + "flag" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "regexp" + "sort" + "strings" + "time" +) + +// agentConfig holds scoring agent configuration. +type agentConfig struct { + m3Host string + m3User string + m3SSHKey string + m3AdapterBase string + influxURL string + influxDB string + apiURL string + model string + baseModel string + pollInterval int + workDir string + oneShot bool + dryRun bool +} + +// checkpoint represents a discovered adapter checkpoint on M3. +type checkpoint struct { + RemoteDir string + Filename string + Dirname string + Iteration int + ModelTag string + Label string + RunID string +} + +// probeResult holds the result of running all probes against a checkpoint. +type probeResult struct { + Accuracy float64 `json:"accuracy"` + Correct int `json:"correct"` + Total int `json:"total"` + ByCategory map[string]categoryResult `json:"by_category"` + Probes map[string]singleProbeResult `json:"probes"` +} + +type categoryResult struct { + Correct int `json:"correct"` + Total int `json:"total"` +} + +type singleProbeResult struct { + Passed bool `json:"passed"` + Response string `json:"response"` +} + +// bufferEntry is a JSONL-buffered result for when InfluxDB is down. +type bufferEntry struct { + Checkpoint checkpoint `json:"checkpoint"` + Results probeResult `json:"results"` + Timestamp string `json:"timestamp"` +} + +// RunAgent is the CLI entry point for the agent command. +// Polls M3 for unscored LoRA checkpoints, converts MLX → PEFT, +// runs 23 capability probes via an OpenAI-compatible API, and +// pushes results to InfluxDB. +func RunAgent(args []string) { + fs := flag.NewFlagSet("agent", flag.ExitOnError) + + cfg := &agentConfig{} + fs.StringVar(&cfg.m3Host, "m3-host", envOr("M3_HOST", "10.69.69.108"), "M3 host address") + fs.StringVar(&cfg.m3User, "m3-user", envOr("M3_USER", "claude"), "M3 SSH user") + fs.StringVar(&cfg.m3SSHKey, "m3-ssh-key", envOr("M3_SSH_KEY", expandHome("~/.ssh/id_ed25519")), "SSH key for M3") + fs.StringVar(&cfg.m3AdapterBase, "m3-adapter-base", envOr("M3_ADAPTER_BASE", "/Volumes/Data/lem"), "Adapter base dir on M3") + fs.StringVar(&cfg.influxURL, "influx", envOr("INFLUX_URL", "http://10.69.69.165:8181"), "InfluxDB URL") + fs.StringVar(&cfg.influxDB, "influx-db", envOr("INFLUX_DB", "training"), "InfluxDB database") + fs.StringVar(&cfg.apiURL, "api-url", envOr("LEM_API_URL", "http://localhost:8080"), "OpenAI-compatible inference API URL") + fs.StringVar(&cfg.model, "model", envOr("LEM_MODEL", ""), "Model name for API (overrides auto-detect)") + fs.StringVar(&cfg.baseModel, "base-model", envOr("BASE_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"), "HuggingFace base model ID") + fs.IntVar(&cfg.pollInterval, "poll", intEnvOr("POLL_INTERVAL", 300), "Poll interval in seconds") + fs.StringVar(&cfg.workDir, "work-dir", envOr("WORK_DIR", "/tmp/scoring-agent"), "Working directory for adapters") + fs.BoolVar(&cfg.oneShot, "one-shot", false, "Process one checkpoint and exit") + fs.BoolVar(&cfg.dryRun, "dry-run", false, "Discover and plan but don't execute") + + if err := fs.Parse(args); err != nil { + log.Fatalf("parse flags: %v", err) + } + + runAgentLoop(cfg) +} + +func runAgentLoop(cfg *agentConfig) { + log.Println(strings.Repeat("=", 60)) + log.Println("ROCm Scoring Agent — Go Edition") + log.Printf("M3: %s@%s", cfg.m3User, cfg.m3Host) + log.Printf("Inference API: %s", cfg.apiURL) + log.Printf("InfluxDB: %s/%s", cfg.influxURL, cfg.influxDB) + log.Printf("Poll interval: %ds", cfg.pollInterval) + log.Println(strings.Repeat("=", 60)) + + influx := NewInfluxClient(cfg.influxURL, cfg.influxDB) + os.MkdirAll(cfg.workDir, 0755) + + for { + // Replay any buffered results. + replayInfluxBuffer(cfg.workDir, influx) + + // Discover checkpoints on M3. + log.Println("Discovering checkpoints on M3...") + checkpoints, err := discoverCheckpoints(cfg) + if err != nil { + log.Printf("Discovery failed: %v", err) + sleepOrExit(cfg) + continue + } + log.Printf("Found %d total checkpoints", len(checkpoints)) + + // Check what is already scored. + scored, err := getScoredLabels(influx) + if err != nil { + log.Printf("InfluxDB query failed: %v", err) + } + log.Printf("Already scored: %d (run_id, label) pairs", len(scored)) + + // Find unscored work. + unscored := findUnscored(checkpoints, scored) + log.Printf("Unscored: %d checkpoints", len(unscored)) + + if len(unscored) == 0 { + log.Printf("Nothing to score. Sleeping %ds...", cfg.pollInterval) + if cfg.oneShot { + return + } + time.Sleep(time.Duration(cfg.pollInterval) * time.Second) + continue + } + + target := unscored[0] + log.Printf("Grabbed: %s (%s)", target.Label, target.Dirname) + + if cfg.dryRun { + log.Printf("[DRY RUN] Would process: %s/%s", target.Dirname, target.Filename) + for _, u := range unscored[1:] { + log.Printf("[DRY RUN] Queued: %s/%s", u.Dirname, u.Filename) + } + return + } + + if err := processOne(cfg, influx, target); err != nil { + log.Printf("Error processing %s: %v", target.Label, err) + } + + if cfg.oneShot { + return + } + + time.Sleep(5 * time.Second) + } +} + +// discoverCheckpoints lists all adapter directories and checkpoint files on M3 via SSH. +func discoverCheckpoints(cfg *agentConfig) ([]checkpoint, error) { + out, err := sshCommand(cfg, fmt.Sprintf("ls -d %s/adapters-deepseek-r1-7b* 2>/dev/null", cfg.m3AdapterBase)) + if err != nil { + return nil, fmt.Errorf("list adapter dirs: %w", err) + } + + var checkpoints []checkpoint + iterRe := regexp.MustCompile(`(\d+)`) + + for _, dirpath := range strings.Split(strings.TrimSpace(out), "\n") { + if dirpath == "" { + continue + } + dirname := filepath.Base(dirpath) + + // List checkpoint safetensors files. + filesOut, err := sshCommand(cfg, fmt.Sprintf("ls %s/*_adapters.safetensors 2>/dev/null", dirpath)) + if err != nil { + continue + } + + for _, filepath := range strings.Split(strings.TrimSpace(filesOut), "\n") { + if filepath == "" { + continue + } + filename := fileBase(filepath) + + match := iterRe.FindStringSubmatch(filename) + if len(match) < 2 { + continue + } + iteration := 0 + fmt.Sscanf(match[1], "%d", &iteration) + + modelTag, labelPrefix, stem := adapterMeta(dirname) + label := fmt.Sprintf("%s @%s", labelPrefix, match[1]) + runID := fmt.Sprintf("%s-capability-auto", stem) + + checkpoints = append(checkpoints, checkpoint{ + RemoteDir: dirpath, + Filename: filename, + Dirname: dirname, + Iteration: iteration, + ModelTag: modelTag, + Label: label, + RunID: runID, + }) + } + } + + return checkpoints, nil +} + +// adapterMeta maps an adapter directory name to (model_tag, label_prefix, run_id_stem). +func adapterMeta(dirname string) (string, string, string) { + name := strings.TrimPrefix(dirname, "adapters-deepseek-r1-7b") + name = strings.TrimLeft(name, "-") + if name == "" { + name = "base" + } + + shortNames := map[string]string{ + "sovereignty": "R1-sov", + "russian": "R1-rus", + "composure": "R1-comp", + "sandwich": "R1-sand", + "sandwich-watts": "R1-sw", + "western": "R1-west", + "western-fresh": "R1-wf", + "base": "R1-base", + } + + short, ok := shortNames[name] + if !ok { + if len(name) > 4 { + short = "R1-" + name[:4] + } else { + short = "R1-" + name + } + } + + stem := "r1-" + name + if name == "base" { + stem = "r1-base" + } + + return "deepseek-r1-7b", short, stem +} + +// getScoredLabels returns all (run_id, label) pairs already scored in InfluxDB. +func getScoredLabels(influx *InfluxClient) (map[[2]string]bool, error) { + rows, err := influx.QuerySQL("SELECT DISTINCT run_id, label FROM capability_score") + if err != nil { + return nil, err + } + + scored := make(map[[2]string]bool) + for _, row := range rows { + runID, _ := row["run_id"].(string) + label, _ := row["label"].(string) + if runID != "" && label != "" { + scored[[2]string{runID, label}] = true + } + } + return scored, nil +} + +// findUnscored filters checkpoints to only unscored ones, sorted by (dirname, iteration). +func findUnscored(checkpoints []checkpoint, scored map[[2]string]bool) []checkpoint { + var unscored []checkpoint + for _, c := range checkpoints { + if !scored[[2]string{c.RunID, c.Label}] { + unscored = append(unscored, c) + } + } + sort.Slice(unscored, func(i, j int) bool { + if unscored[i].Dirname != unscored[j].Dirname { + return unscored[i].Dirname < unscored[j].Dirname + } + return unscored[i].Iteration < unscored[j].Iteration + }) + return unscored +} + +// processOne fetches, converts, scores, and pushes one checkpoint. +func processOne(cfg *agentConfig, influx *InfluxClient, cp checkpoint) error { + log.Println(strings.Repeat("=", 60)) + log.Printf("Processing: %s / %s", cp.Dirname, cp.Filename) + log.Println(strings.Repeat("=", 60)) + + localAdapterDir := filepath.Join(cfg.workDir, cp.Dirname) + os.MkdirAll(localAdapterDir, 0755) + + localSF := filepath.Join(localAdapterDir, cp.Filename) + localCfg := filepath.Join(localAdapterDir, "adapter_config.json") + + // Cleanup on exit. + defer func() { + os.Remove(localSF) + os.Remove(localCfg) + peftDir := filepath.Join(cfg.workDir, fmt.Sprintf("peft_%07d", cp.Iteration)) + os.RemoveAll(peftDir) + }() + + // Fetch adapter + config from M3. + log.Println("Fetching adapter from M3...") + remoteSF := fmt.Sprintf("%s/%s", cp.RemoteDir, cp.Filename) + remoteCfg := fmt.Sprintf("%s/adapter_config.json", cp.RemoteDir) + + if err := scpFrom(cfg, remoteSF, localSF); err != nil { + return fmt.Errorf("scp safetensors: %w", err) + } + if err := scpFrom(cfg, remoteCfg, localCfg); err != nil { + return fmt.Errorf("scp config: %w", err) + } + + // Convert MLX to PEFT format. + log.Println("Converting MLX to PEFT format...") + peftDir := filepath.Join(cfg.workDir, fmt.Sprintf("peft_%07d", cp.Iteration)) + if err := convertMLXtoPEFT(localAdapterDir, cp.Filename, peftDir, cfg.baseModel); err != nil { + return fmt.Errorf("convert adapter: %w", err) + } + + // Run 23 capability probes via API. + log.Println("Running 23 capability probes...") + modelName := cfg.model + if modelName == "" { + modelName = cp.ModelTag + } + client := NewClient(cfg.apiURL, modelName) + client.MaxTokens = 500 + + results := runCapabilityProbes(client) + + log.Printf("Result: %s -- %.1f%% (%d/%d)", + cp.Label, results.Accuracy, results.Correct, results.Total) + + // Push to InfluxDB (buffer on failure). + if err := pushCapabilityResults(influx, cp, results); err != nil { + log.Printf("InfluxDB push failed, buffering: %v", err) + bufferInfluxResult(cfg.workDir, cp, results) + } + + return nil +} + +// runCapabilityProbes runs all 23 probes against the inference API. +func runCapabilityProbes(client *Client) probeResult { + results := probeResult{ + ByCategory: make(map[string]categoryResult), + Probes: make(map[string]singleProbeResult), + } + + correct := 0 + total := 0 + + for _, probe := range CapabilityProbes { + response, err := client.ChatWithTemp(probe.Prompt, 0.1) + if err != nil { + log.Printf(" [%s] ERROR: %v", probe.ID, err) + results.Probes[probe.ID] = singleProbeResult{Passed: false, Response: err.Error()} + total++ + cat := results.ByCategory[probe.Category] + cat.Total++ + results.ByCategory[probe.Category] = cat + continue + } + + // Strip blocks from DeepSeek R1 responses. + clean := StripThinkBlocks(response) + + passed := probe.Check(clean) + total++ + if passed { + correct++ + } + + cat := results.ByCategory[probe.Category] + cat.Total++ + if passed { + cat.Correct++ + } + results.ByCategory[probe.Category] = cat + + // Truncate response for storage. + stored := clean + if len(stored) > 300 { + stored = stored[:300] + } + results.Probes[probe.ID] = singleProbeResult{Passed: passed, Response: stored} + + status := "FAIL" + if passed { + status = "PASS" + } + log.Printf(" [%s] %s (expected: %s)", probe.ID, status, probe.Answer) + } + + if total > 0 { + results.Accuracy = float64(correct) / float64(total) * 100 + } + results.Correct = correct + results.Total = total + + return results +} + +// pushCapabilityResults writes scoring results to InfluxDB as line protocol. +func pushCapabilityResults(influx *InfluxClient, cp checkpoint, results probeResult) error { + // Base timestamp: 2026-02-15T00:00:00Z = 1739577600 + const baseTS int64 = 1739577600 + + var lines []string + + // Overall score. + ts := (baseTS + int64(cp.Iteration)*1000 + 0) * 1_000_000_000 + lines = append(lines, fmt.Sprintf( + "capability_score,model=%s,run_id=%s,label=%s,category=overall accuracy=%.1f,correct=%di,total=%di,iteration=%di %d", + escapeLp(cp.ModelTag), escapeLp(cp.RunID), escapeLp(cp.Label), + results.Accuracy, results.Correct, results.Total, cp.Iteration, ts, + )) + + // Per-category scores (sorted for deterministic output). + cats := make([]string, 0, len(results.ByCategory)) + for cat := range results.ByCategory { + cats = append(cats, cat) + } + sort.Strings(cats) + + for i, cat := range cats { + data := results.ByCategory[cat] + catAcc := 0.0 + if data.Total > 0 { + catAcc = float64(data.Correct) / float64(data.Total) * 100 + } + ts := (baseTS + int64(cp.Iteration)*1000 + int64(i+1)) * 1_000_000_000 + lines = append(lines, fmt.Sprintf( + "capability_score,model=%s,run_id=%s,label=%s,category=%s accuracy=%.1f,correct=%di,total=%di,iteration=%di %d", + escapeLp(cp.ModelTag), escapeLp(cp.RunID), escapeLp(cp.Label), escapeLp(cat), + catAcc, data.Correct, data.Total, cp.Iteration, ts, + )) + } + + // Per-probe results (sorted). + probeIDs := make([]string, 0, len(results.Probes)) + for id := range results.Probes { + probeIDs = append(probeIDs, id) + } + sort.Strings(probeIDs) + + for j, probeID := range probeIDs { + probeRes := results.Probes[probeID] + passedInt := 0 + if probeRes.Passed { + passedInt = 1 + } + ts := (baseTS + int64(cp.Iteration)*1000 + int64(j+100)) * 1_000_000_000 + lines = append(lines, fmt.Sprintf( + "probe_score,model=%s,run_id=%s,label=%s,probe_id=%s passed=%di,iteration=%di %d", + escapeLp(cp.ModelTag), escapeLp(cp.RunID), escapeLp(cp.Label), escapeLp(probeID), + passedInt, cp.Iteration, ts, + )) + } + + if err := influx.WriteLp(lines); err != nil { + return err + } + log.Printf("Pushed %d points to InfluxDB for %s", len(lines), cp.Label) + return nil +} + +// bufferInfluxResult saves results to a local JSONL file when InfluxDB is down. +func bufferInfluxResult(workDir string, cp checkpoint, results probeResult) { + bufPath := filepath.Join(workDir, "influx_buffer.jsonl") + f, err := os.OpenFile(bufPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Printf("Cannot open buffer file: %v", err) + return + } + defer f.Close() + + entry := bufferEntry{ + Checkpoint: cp, + Results: results, + Timestamp: time.Now().UTC().Format(time.RFC3339), + } + data, _ := json.Marshal(entry) + f.Write(append(data, '\n')) + log.Printf("Buffered results to %s", bufPath) +} + +// replayInfluxBuffer retries pushing buffered results to InfluxDB. +func replayInfluxBuffer(workDir string, influx *InfluxClient) { + bufPath := filepath.Join(workDir, "influx_buffer.jsonl") + data, err := os.ReadFile(bufPath) + if err != nil { + return // No buffer file. + } + + var remaining []string + for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") { + if line == "" { + continue + } + var entry bufferEntry + if err := json.Unmarshal([]byte(line), &entry); err != nil { + remaining = append(remaining, line) + continue + } + if err := pushCapabilityResults(influx, entry.Checkpoint, entry.Results); err != nil { + remaining = append(remaining, line) + } else { + log.Printf("Replayed buffered result: %s", entry.Checkpoint.Label) + } + } + + if len(remaining) > 0 { + os.WriteFile(bufPath, []byte(strings.Join(remaining, "\n")+"\n"), 0644) + } else { + os.Remove(bufPath) + log.Println("Buffer fully replayed and cleared") + } +} + +// sshCommand executes a command on M3 via SSH. +func sshCommand(cfg *agentConfig, cmd string) (string, error) { + sshArgs := []string{ + "-o", "ConnectTimeout=10", + "-o", "BatchMode=yes", + "-o", "StrictHostKeyChecking=no", + "-i", cfg.m3SSHKey, + fmt.Sprintf("%s@%s", cfg.m3User, cfg.m3Host), + cmd, + } + result, err := exec.Command("ssh", sshArgs...).CombinedOutput() + if err != nil { + return "", fmt.Errorf("ssh %q: %w: %s", cmd, err, strings.TrimSpace(string(result))) + } + return string(result), nil +} + +// scpFrom copies a file from M3 to a local path. +func scpFrom(cfg *agentConfig, remotePath, localPath string) error { + os.MkdirAll(filepath.Dir(localPath), 0755) + scpArgs := []string{ + "-o", "ConnectTimeout=10", + "-o", "BatchMode=yes", + "-o", "StrictHostKeyChecking=no", + "-i", cfg.m3SSHKey, + fmt.Sprintf("%s@%s:%s", cfg.m3User, cfg.m3Host, remotePath), + localPath, + } + result, err := exec.Command("scp", scpArgs...).CombinedOutput() + if err != nil { + return fmt.Errorf("scp %s: %w: %s", remotePath, err, strings.TrimSpace(string(result))) + } + return nil +} + +// fileBase returns the last component of a path (works for both / and \). +func fileBase(path string) string { + if i := strings.LastIndexAny(path, "/\\"); i >= 0 { + return path[i+1:] + } + return path +} + +func sleepOrExit(cfg *agentConfig) { + if cfg.oneShot { + return + } + time.Sleep(time.Duration(cfg.pollInterval) * time.Second) +} + +func envOr(key, fallback string) string { + if v := os.Getenv(key); v != "" { + return v + } + return fallback +} + +func intEnvOr(key string, fallback int) int { + v := os.Getenv(key) + if v == "" { + return fallback + } + var n int + fmt.Sscanf(v, "%d", &n) + if n == 0 { + return fallback + } + return n +} + +func expandHome(path string) string { + if strings.HasPrefix(path, "~/") { + home, err := os.UserHomeDir() + if err == nil { + return filepath.Join(home, path[2:]) + } + } + return path +} diff --git a/pkg/lem/agent_test.go b/pkg/lem/agent_test.go new file mode 100644 index 0000000..24942d5 --- /dev/null +++ b/pkg/lem/agent_test.go @@ -0,0 +1,314 @@ +package lem + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestAdapterMeta(t *testing.T) { + tests := []struct { + dirname string + wantModel, wantShort string + wantStem string + }{ + {"adapters-deepseek-r1-7b-sovereignty", "deepseek-r1-7b", "R1-sov", "r1-sovereignty"}, + {"adapters-deepseek-r1-7b-russian", "deepseek-r1-7b", "R1-rus", "r1-russian"}, + {"adapters-deepseek-r1-7b-composure", "deepseek-r1-7b", "R1-comp", "r1-composure"}, + {"adapters-deepseek-r1-7b-sandwich", "deepseek-r1-7b", "R1-sand", "r1-sandwich"}, + {"adapters-deepseek-r1-7b-sandwich-watts", "deepseek-r1-7b", "R1-sw", "r1-sandwich-watts"}, + {"adapters-deepseek-r1-7b-western", "deepseek-r1-7b", "R1-west", "r1-western"}, + {"adapters-deepseek-r1-7b-western-fresh", "deepseek-r1-7b", "R1-wf", "r1-western-fresh"}, + {"adapters-deepseek-r1-7b", "deepseek-r1-7b", "R1-base", "r1-base"}, + {"adapters-deepseek-r1-7b-custom", "deepseek-r1-7b", "R1-cust", "r1-custom"}, + } + + for _, tt := range tests { + model, short, stem := adapterMeta(tt.dirname) + if model != tt.wantModel || short != tt.wantShort || stem != tt.wantStem { + t.Errorf("adapterMeta(%q) = (%q, %q, %q), want (%q, %q, %q)", + tt.dirname, model, short, stem, tt.wantModel, tt.wantShort, tt.wantStem) + } + } +} + +func TestFindUnscored(t *testing.T) { + checkpoints := []checkpoint{ + {RunID: "r1-sov-capability-auto", Label: "R1-sov @100", Dirname: "a", Iteration: 100}, + {RunID: "r1-sov-capability-auto", Label: "R1-sov @200", Dirname: "a", Iteration: 200}, + {RunID: "r1-sov-capability-auto", Label: "R1-sov @300", Dirname: "a", Iteration: 300}, + } + + scored := map[[2]string]bool{ + {"r1-sov-capability-auto", "R1-sov @100"}: true, + {"r1-sov-capability-auto", "R1-sov @200"}: true, + } + + unscored := findUnscored(checkpoints, scored) + if len(unscored) != 1 { + t.Fatalf("expected 1 unscored, got %d", len(unscored)) + } + if unscored[0].Label != "R1-sov @300" { + t.Errorf("expected R1-sov @300, got %s", unscored[0].Label) + } +} + +func TestFindUnscoredSorting(t *testing.T) { + checkpoints := []checkpoint{ + {RunID: "r1-a", Label: "a @300", Dirname: "a", Iteration: 300}, + {RunID: "r1-b", Label: "b @100", Dirname: "b", Iteration: 100}, + {RunID: "r1-a", Label: "a @100", Dirname: "a", Iteration: 100}, + } + + scored := make(map[[2]string]bool) + unscored := findUnscored(checkpoints, scored) + + if len(unscored) != 3 { + t.Fatalf("expected 3 unscored, got %d", len(unscored)) + } + // Should be sorted by dirname then iteration. + if unscored[0].Label != "a @100" { + t.Errorf("first should be a @100, got %s", unscored[0].Label) + } + if unscored[1].Label != "a @300" { + t.Errorf("second should be a @300, got %s", unscored[1].Label) + } + if unscored[2].Label != "b @100" { + t.Errorf("third should be b @100, got %s", unscored[2].Label) + } +} + +func TestRunCapabilityProbes(t *testing.T) { + // Mock an OpenAI-compatible API that returns correct answers. + answers := map[string]string{ + "What is 347": "The answer is 10063.", + "A store sells": "You get $28.75 in change.", + "Solve for x": "x = -12", + "If f(x)": "f(4) = 21", + "A bag has": "The probability is 1/2 or 0.5", + "A circle has": "The area is 153.94 cm²", + "next number": "The next number is 162.", + "laptop costs": "The final price is $612.", + "All cats": "Yes, a cat needs water.", + "If it rains": "No, we cannot conclude that.", + "room of 30": "The minimum is 3 people sharing a birth month.", + "farmer needs": "Take the chicken first.", + "class of 40": "5 students play neither.", + "Book is to": "eating", + "car won't start": "The starter motor is faulty.", + "facing north": "You are facing south.", + "Event A": "Event C happened in 1991.", + "APPLE = 50": "CAT = 24", + "Python code": "[2, 3]", + "def f(n)": "The output is 8.", + "code has a bug": "ZeroDivisionError when empty list.", + "train travels": "It takes 3 hours.", + "twice as many": "There are 7 children.", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req ChatRequest + json.NewDecoder(r.Body).Decode(&req) + + prompt := "" + for _, m := range req.Messages { + if m.Role == "user" { + prompt = m.Content + break + } + } + + response := "I don't know." + for prefix, ans := range answers { + if strings.Contains(prompt, prefix) { + response = ans + break + } + } + + json.NewEncoder(w).Encode(ChatResponse{ + Choices: []Choice{{Message: Message{Role: "assistant", Content: response}}}, + }) + })) + defer server.Close() + + client := NewClient(server.URL, "test-model") + client.MaxTokens = 500 + + results := runCapabilityProbes(client) + + if results.Total != 23 { + t.Errorf("expected 23 total probes, got %d", results.Total) + } + if results.Correct != 23 { + t.Errorf("expected 23 correct, got %d (accuracy: %.1f%%)", results.Correct, results.Accuracy) + } + if results.Accuracy != 100.0 { + t.Errorf("expected 100%% accuracy, got %.1f%%", results.Accuracy) + } +} + +func TestPushCapabilityResults(t *testing.T) { + var writtenLines []string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/v3/write_lp" { + body := make([]byte, r.ContentLength) + r.Body.Read(body) + writtenLines = strings.Split(strings.TrimSpace(string(body)), "\n") + w.WriteHeader(http.StatusNoContent) + } + })) + defer server.Close() + + influx := &InfluxClient{url: server.URL, db: "test", token: "t"} + + cp := checkpoint{ + ModelTag: "deepseek-r1-7b", + RunID: "r1-sov-capability-auto", + Label: "R1-sov @100", + Iteration: 100, + } + + results := probeResult{ + Accuracy: 87.0, + Correct: 20, + Total: 23, + ByCategory: map[string]categoryResult{ + "arithmetic": {Correct: 2, Total: 2}, + "code": {Correct: 2, Total: 3}, + }, + Probes: map[string]singleProbeResult{ + "math_01": {Passed: true, Response: "10063"}, + "math_02": {Passed: true, Response: "28.75"}, + "code_03": {Passed: false, Response: "I'm not sure."}, + }, + } + + err := pushCapabilityResults(influx, cp, results) + if err != nil { + t.Fatalf("push failed: %v", err) + } + + // 1 overall + 2 categories + 3 probes = 6 lines. + if len(writtenLines) != 6 { + t.Errorf("expected 6 lines, got %d", len(writtenLines)) + for i, l := range writtenLines { + t.Logf(" line %d: %s", i, l) + } + } + + // Check overall line. + if !strings.HasPrefix(writtenLines[0], "capability_score,") { + t.Errorf("first line should be capability_score, got: %s", writtenLines[0]) + } + if !strings.Contains(writtenLines[0], "category=overall") { + t.Errorf("first line should have category=overall, got: %s", writtenLines[0]) + } + if !strings.Contains(writtenLines[0], "accuracy=87.0") { + t.Errorf("first line should have accuracy=87.0, got: %s", writtenLines[0]) + } +} + +func TestBufferAndReplay(t *testing.T) { + tmpDir := t.TempDir() + + cp := checkpoint{ + ModelTag: "test-model", + RunID: "test-run", + Label: "test @100", + Iteration: 100, + } + results := probeResult{ + Accuracy: 50.0, + Correct: 1, + Total: 2, + ByCategory: map[string]categoryResult{ + "arithmetic": {Correct: 1, Total: 2}, + }, + Probes: map[string]singleProbeResult{ + "math_01": {Passed: true, Response: "10063"}, + "math_02": {Passed: false, Response: "wrong"}, + }, + } + + // Buffer a result. + bufferInfluxResult(tmpDir, cp, results) + + // Verify buffer file exists. + bufPath := filepath.Join(tmpDir, "influx_buffer.jsonl") + data, err := os.ReadFile(bufPath) + if err != nil { + t.Fatalf("buffer file not created: %v", err) + } + if !strings.Contains(string(data), "test-run") { + t.Errorf("buffer should contain run_id, got: %s", string(data)) + } + + // Parse it. + var entry bufferEntry + if err := json.Unmarshal(data, &entry); err != nil { + t.Fatalf("parse buffer entry: %v", err) + } + if entry.Checkpoint.RunID != "test-run" { + t.Errorf("expected run_id=test-run, got %s", entry.Checkpoint.RunID) + } + + // Replay to a working InfluxDB. + replayCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/v3/write_lp" { + replayCount++ + w.WriteHeader(http.StatusNoContent) + } + })) + defer server.Close() + + influx := &InfluxClient{url: server.URL, db: "test", token: "t"} + replayInfluxBuffer(tmpDir, influx) + + if replayCount == 0 { + t.Error("expected replay to push to InfluxDB") + } + + // Buffer should be cleared. + if _, err := os.Stat(bufPath); !os.IsNotExist(err) { + t.Error("buffer file should be removed after successful replay") + } +} + +func TestEnvOr(t *testing.T) { + // Test with env var set. + key := fmt.Sprintf("TEST_ENV_%d", os.Getpid()) + os.Setenv(key, "value") + defer os.Unsetenv(key) + + if got := envOr(key, "fallback"); got != "value" { + t.Errorf("envOr(%s) = %q, want %q", key, got, "value") + } + + if got := envOr("NONEXISTENT_"+key, "fallback"); got != "fallback" { + t.Errorf("envOr(nonexistent) = %q, want %q", got, "fallback") + } +} + +func TestFileBase(t *testing.T) { + tests := []struct { + input, want string + }{ + {"/foo/bar/baz.txt", "baz.txt"}, + {"baz.txt", "baz.txt"}, + {"/a/b/c", "c"}, + {"", ""}, + } + for _, tt := range tests { + if got := fileBase(tt.input); got != tt.want { + t.Errorf("fileBase(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} diff --git a/pkg/lem/probes.go b/pkg/lem/probes.go new file mode 100644 index 0000000..91f4ab3 --- /dev/null +++ b/pkg/lem/probes.go @@ -0,0 +1,273 @@ +package lem + +import ( + "regexp" + "strings" +) + +// Probe defines a binary pass/fail capability check. +// Each probe sends a prompt to the model and evaluates the response +// with a Go function — no judge model needed. +type Probe struct { + ID string + Category string + Prompt string + Answer string + Check func(response string) bool +} + +// CapabilityProbes contains all 23 binary capability probes. +// Categories: arithmetic, algebra, probability, geometry, sequences, +// percentages, deduction, puzzles, sets, analogy, causal, spatial, +// temporal, pattern, code, word. +var CapabilityProbes = []Probe{ + // === MATH (8) === + { + ID: "math_01", + Category: "arithmetic", + Prompt: "What is 347 × 29? Show your work and give the final answer.", + Answer: "10063", + Check: func(r string) bool { + clean := strings.ReplaceAll(strings.ReplaceAll(r, ",", ""), " ", "") + return strings.Contains(clean, "10063") + }, + }, + { + ID: "math_02", + Category: "arithmetic", + Prompt: "A store sells apples for $1.25 each. If I buy 17 apples and pay with a $50 bill, how much change do I get?", + Answer: "28.75", + Check: func(r string) bool { + return strings.Contains(r, "28.75") || strings.Contains(r, "$28.75") + }, + }, + { + ID: "math_03", + Category: "algebra", + Prompt: "Solve for x: 3x + 7 = 2x - 5. What is x?", + Answer: "-12", + Check: func(r string) bool { + return regexp.MustCompile(`x\s*=\s*-\s*12|=\s*-12|-12`).MatchString(r) + }, + }, + { + ID: "math_04", + Category: "algebra", + Prompt: "If f(x) = 2x² - 3x + 1, what is f(4)?", + Answer: "21", + Check: func(r string) bool { + return regexp.MustCompile(`\b21\b`).MatchString(r) + }, + }, + { + ID: "math_05", + Category: "probability", + Prompt: "A bag has 3 red balls, 5 blue balls, and 2 green balls. What is the probability of drawing a blue ball? Express as a fraction and decimal.", + Answer: "1/2 or 0.5", + Check: func(r string) bool { + return strings.Contains(r, "1/2") || strings.Contains(r, "0.5") || + strings.Contains(r, "50%") || strings.Contains(r, "5/10") + }, + }, + { + ID: "math_06", + Category: "geometry", + Prompt: "A circle has a radius of 7cm. What is its area? Use pi = 3.14159.", + Answer: "153.94", + Check: func(r string) bool { + return regexp.MustCompile(`15[34]\.9|153\.9[0-9]|154\.0|49\s*[πpi]`).MatchString(r) + }, + }, + { + ID: "math_07", + Category: "sequences", + Prompt: "What is the next number in this sequence: 2, 6, 18, 54, ...?", + Answer: "162", + Check: func(r string) bool { + return strings.Contains(r, "162") + }, + }, + { + ID: "math_08", + Category: "percentages", + Prompt: "A laptop costs $800. It's on sale for 15% off. Then you have a coupon for 10% off the sale price. What is the final price?", + Answer: "612", + Check: func(r string) bool { + return regexp.MustCompile(`\$?612`).MatchString(r) + }, + }, + // === LOGIC (5) === + { + ID: "logic_01", + Category: "deduction", + Prompt: "All cats are animals. All animals need water. Does a cat need water? Explain your reasoning.", + Answer: "Yes", + Check: func(r string) bool { + return regexp.MustCompile(`(?i)\byes\b`).MatchString(r) + }, + }, + { + ID: "logic_02", + Category: "deduction", + Prompt: "If it rains, the ground gets wet. The ground is wet. Can we conclude it rained? Why or why not?", + Answer: "No - affirming the consequent fallacy", + Check: func(r string) bool { + lower := strings.ToLower(r) + return regexp.MustCompile(`\bno\b|\bcannot\b|\bcan't\b|not necessarily|fallac|other reason|doesn't mean`).MatchString(lower) + }, + }, + { + ID: "logic_03", + Category: "deduction", + Prompt: "In a room of 30 people, what is the minimum number of people that must share a birth month?", + Answer: "3", + Check: func(r string) bool { + lower := strings.ToLower(r) + has3 := regexp.MustCompile(`\b3\b|three`).MatchString(lower) + // Avoid matching "30" in the first 50 chars (restating the problem) + prefix := lower + if len(prefix) > 50 { + prefix = prefix[:50] + } + has30 := regexp.MustCompile(`\b30\b`).MatchString(prefix) + return has3 && !has30 + }, + }, + { + ID: "logic_04", + Category: "puzzles", + Prompt: "A farmer needs to cross a river with a fox, a chicken, and a bag of grain. The boat only holds the farmer and one item. If left alone, the fox eats the chicken, and the chicken eats the grain. What is the first thing the farmer should take across?", + Answer: "The chicken", + Check: func(r string) bool { + return regexp.MustCompile(`(?i)chicken|hen`).MatchString(r) + }, + }, + { + ID: "logic_05", + Category: "sets", + Prompt: "In a class of 40 students, 25 play football, 20 play basketball, and 10 play both. How many play neither?", + Answer: "5", + Check: func(r string) bool { + return regexp.MustCompile(`(?i)\b5\b|five`).MatchString(r) + }, + }, + // === REASONING (5) === + { + ID: "reason_01", + Category: "analogy", + Prompt: "Complete the analogy: Book is to reading as fork is to ___", + Answer: "eating", + Check: func(r string) bool { + return regexp.MustCompile(`(?i)eating|food|dining`).MatchString(r) + }, + }, + { + ID: "reason_02", + Category: "causal", + Prompt: "A car won't start. The battery is new. The fuel tank is full. The starter motor clicks but the engine doesn't turn. What is the most likely problem?", + Answer: "Starter motor / solenoid", + Check: func(r string) bool { + return regexp.MustCompile(`(?i)starter|solenoid|connection|terminal|corros|ground|wire`).MatchString(r) + }, + }, + { + ID: "reason_03", + Category: "spatial", + Prompt: "You're facing north. You turn right 90 degrees, then turn right 90 degrees again. What direction are you facing?", + Answer: "South", + Check: func(r string) bool { + return regexp.MustCompile(`(?i)\bsouth\b`).MatchString(r) + }, + }, + { + ID: "reason_04", + Category: "temporal", + Prompt: "Event A happened in 1995. Event B happened 12 years before Event A. Event C happened 8 years after Event B. In what year did Event C happen?", + Answer: "1991", + Check: func(r string) bool { + return strings.Contains(r, "1991") + }, + }, + { + ID: "reason_05", + Category: "pattern", + Prompt: "If APPLE = 50 (A=1, P=16, P=16, L=12, E=5), what does CAT equal using the same system?", + Answer: "24", + Check: func(r string) bool { + return regexp.MustCompile(`\b24\b`).MatchString(r) + }, + }, + // === CODE (3) === + { + ID: "code_01", + Category: "code", + Prompt: "What does this Python code print?\nx = [1, 2, 3, 4, 5]\nprint(x[1:3])", + Answer: "[2, 3]", + Check: func(r string) bool { + return strings.Contains(r, "[2, 3]") || strings.Contains(r, "[2,3]") + }, + }, + { + ID: "code_02", + Category: "code", + Prompt: "What is the output?\ndef f(n):\n if n <= 1: return n\n return f(n-1) + f(n-2)\nprint(f(6))", + Answer: "8", + Check: func(r string) bool { + return regexp.MustCompile(`\b8\b`).MatchString(r) + }, + }, + { + ID: "code_03", + Category: "code", + Prompt: "This code has a bug. What is it?\ndef average(numbers):\n total = 0\n for n in numbers:\n total += n\n return total / len(numbers)\nprint(average([]))", + Answer: "Division by zero", + Check: func(r string) bool { + return regexp.MustCompile(`(?i)divis.*zero|zero.*divis|empty|len.*0|ZeroDivision`).MatchString(r) + }, + }, + // === WORD PROBLEMS (2) === + { + ID: "word_01", + Category: "word", + Prompt: "A train travels at 60 km/h. Another train travels at 80 km/h in the same direction from the same station, leaving 1 hour later. How long after the second train departs will it catch the first?", + Answer: "3 hours", + Check: func(r string) bool { + return regexp.MustCompile(`(?i)\b3\b.*hour|three.*hour`).MatchString(r) + }, + }, + { + ID: "word_02", + Category: "word", + Prompt: "I have twice as many sisters as brothers. My sister has as many brothers as sisters. How many children are in my family? (I am male.)", + Answer: "7", + Check: func(r string) bool { + return regexp.MustCompile(`(?i)\b7\b|seven`).MatchString(r) + }, + }, +} + +// ProbeCategories returns sorted unique categories from CapabilityProbes. +func ProbeCategories() []string { + seen := make(map[string]bool) + var cats []string + for _, p := range CapabilityProbes { + if !seen[p.Category] { + seen[p.Category] = true + cats = append(cats, p.Category) + } + } + return cats +} + +// StripThinkBlocks removes ... blocks from DeepSeek R1 responses. +func StripThinkBlocks(s string) string { + re := regexp.MustCompile(`(?s).*?`) + clean := strings.TrimSpace(re.ReplaceAllString(s, "")) + if clean == "" && len(s) > 500 { + return s[:500] + } + if clean == "" { + return s + } + return clean +} diff --git a/pkg/lem/probes_test.go b/pkg/lem/probes_test.go new file mode 100644 index 0000000..b1b3844 --- /dev/null +++ b/pkg/lem/probes_test.go @@ -0,0 +1,140 @@ +package lem + +import ( + "testing" +) + +func TestProbeCount(t *testing.T) { + if got := len(CapabilityProbes); got != 23 { + t.Errorf("expected 23 probes, got %d", got) + } +} + +func TestProbeCategories(t *testing.T) { + cats := ProbeCategories() + if len(cats) == 0 { + t.Fatal("no categories") + } + // Should have at least these categories. + want := map[string]bool{ + "arithmetic": true, "algebra": true, "deduction": true, + "code": true, "word": true, + } + catSet := make(map[string]bool) + for _, c := range cats { + catSet[c] = true + } + for w := range want { + if !catSet[w] { + t.Errorf("missing category %q", w) + } + } +} + +func TestProbeChecks(t *testing.T) { + // Verify each probe's check function works with its expected answer. + tests := []struct { + id string + response string + want bool + }{ + // Math. + {"math_01", "The answer is 10063.", true}, + {"math_01", "The answer is 10064.", false}, + {"math_02", "You'd get $28.75 in change.", true}, + {"math_02", "You'd get $29.75 in change.", false}, + {"math_03", "x = -12", true}, + {"math_03", "x = 12", false}, + {"math_04", "f(4) = 21", true}, + {"math_04", "f(4) = 22", false}, + {"math_05", "The probability is 1/2 or 0.5", true}, + {"math_05", "The probability is 1/3", false}, + {"math_06", "The area is 153.94 cm²", true}, + {"math_06", "The area is 100 cm²", false}, + {"math_07", "The next number is 162.", true}, + {"math_07", "The next number is 163.", false}, + {"math_08", "The final price is $612.", true}, + {"math_08", "The final price is $600.", false}, + // Logic. + {"logic_01", "Yes, a cat needs water.", true}, + {"logic_01", "Maybe.", false}, + {"logic_02", "No, we cannot conclude that. It's the fallacy of affirming the consequent.", true}, + {"logic_02", "Yes, it rained.", false}, + {"logic_03", "The minimum is 3 people.", true}, + {"logic_03", "The minimum is 2 people.", false}, + {"logic_04", "Take the chicken first.", true}, + {"logic_04", "Take the fox first.", false}, + {"logic_05", "5 students play neither.", true}, + {"logic_05", "10 students play neither.", false}, + // Reasoning. + {"reason_01", "eating", true}, + {"reason_01", "building", false}, + {"reason_02", "The starter motor is likely faulty.", true}, + {"reason_02", "The tires are flat.", false}, + {"reason_03", "You are facing south.", true}, + {"reason_03", "You are facing north.", false}, + {"reason_04", "Event C happened in 1991.", true}, + {"reason_04", "Event C happened in 1990.", false}, + {"reason_05", "CAT = 24", true}, + {"reason_05", "CAT = 25", false}, + // Code. + {"code_01", "[2, 3]", true}, + {"code_01", "[1, 2, 3]", false}, + {"code_02", "The output is 8.", true}, + {"code_02", "The output is 7.", false}, + {"code_03", "Division by zero when the list is empty.", true}, + {"code_03", "There is no bug.", false}, + // Word. + {"word_01", "It takes 3 hours.", true}, + {"word_01", "It takes 4 hours.", false}, + {"word_02", "There are 7 children.", true}, + {"word_02", "There are 6 children.", false}, + } + + probeMap := make(map[string]Probe) + for _, p := range CapabilityProbes { + probeMap[p.ID] = p + } + + for _, tt := range tests { + probe, ok := probeMap[tt.id] + if !ok { + t.Errorf("probe %s not found", tt.id) + continue + } + got := probe.Check(tt.response) + if got != tt.want { + t.Errorf("probe %s: Check(%q) = %v, want %v", tt.id, tt.response, got, tt.want) + } + } +} + +func TestStripThinkBlocks(t *testing.T) { + tests := []struct { + input string + want string + }{ + { + "Let me think about this...The answer is 42.", + "The answer is 42.", + }, + { + "No think blocks here.", + "No think blocks here.", + }, + { + "First\nblockHello second world", + "Hello world", + }, + { + "", "", + }, + } + + for _, tt := range tests { + got := StripThinkBlocks(tt.input) + if got != tt.want { + t.Errorf("StripThinkBlocks(%q) = %q, want %q", tt.input, got, tt.want) + } + } +}