feat: add scoring agent + 23 capability probes (replaces scoring_agent.py)

Go scoring daemon that polls M3 for unscored LoRA checkpoints,
converts MLX→PEFT, runs 23 binary capability probes via OpenAI-
compatible API, and pushes results to InfluxDB. Zero Python deps.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude 2026-02-15 17:22:40 +00:00
parent 91ee389377
commit 9fac5749c2
No known key found for this signature in database
GPG key ID: AF404715446AEB41
5 changed files with 1342 additions and 0 deletions

View file

@ -17,6 +17,7 @@ Scoring:
probe Generate responses and score them
compare Compare two score files
tier-score Score expansion responses (heuristic/judge tiers)
agent ROCm scoring daemon (polls M3, scores checkpoints)
Generation:
expand Generate expansion responses via trained LEM model
@ -98,6 +99,8 @@ func main() {
lem.RunSeedInflux(os.Args[2:])
case "query":
lem.RunQuery(os.Args[2:])
case "agent":
lem.RunAgent(os.Args[2:])
default:
fmt.Fprintf(os.Stderr, "unknown command: %s\n\n%s", os.Args[1], usage)
os.Exit(1)

612
pkg/lem/agent.go Normal file
View file

@ -0,0 +1,612 @@
package lem
import (
"encoding/json"
"flag"
"fmt"
"log"
"os"
"os/exec"
"path/filepath"
"regexp"
"sort"
"strings"
"time"
)
// agentConfig holds scoring agent configuration.
type agentConfig struct {
m3Host string
m3User string
m3SSHKey string
m3AdapterBase string
influxURL string
influxDB string
apiURL string
model string
baseModel string
pollInterval int
workDir string
oneShot bool
dryRun bool
}
// checkpoint represents a discovered adapter checkpoint on M3.
type checkpoint struct {
RemoteDir string
Filename string
Dirname string
Iteration int
ModelTag string
Label string
RunID string
}
// probeResult holds the result of running all probes against a checkpoint.
type probeResult struct {
Accuracy float64 `json:"accuracy"`
Correct int `json:"correct"`
Total int `json:"total"`
ByCategory map[string]categoryResult `json:"by_category"`
Probes map[string]singleProbeResult `json:"probes"`
}
type categoryResult struct {
Correct int `json:"correct"`
Total int `json:"total"`
}
type singleProbeResult struct {
Passed bool `json:"passed"`
Response string `json:"response"`
}
// bufferEntry is a JSONL-buffered result for when InfluxDB is down.
type bufferEntry struct {
Checkpoint checkpoint `json:"checkpoint"`
Results probeResult `json:"results"`
Timestamp string `json:"timestamp"`
}
// RunAgent is the CLI entry point for the agent command.
// Polls M3 for unscored LoRA checkpoints, converts MLX → PEFT,
// runs 23 capability probes via an OpenAI-compatible API, and
// pushes results to InfluxDB.
func RunAgent(args []string) {
fs := flag.NewFlagSet("agent", flag.ExitOnError)
cfg := &agentConfig{}
fs.StringVar(&cfg.m3Host, "m3-host", envOr("M3_HOST", "10.69.69.108"), "M3 host address")
fs.StringVar(&cfg.m3User, "m3-user", envOr("M3_USER", "claude"), "M3 SSH user")
fs.StringVar(&cfg.m3SSHKey, "m3-ssh-key", envOr("M3_SSH_KEY", expandHome("~/.ssh/id_ed25519")), "SSH key for M3")
fs.StringVar(&cfg.m3AdapterBase, "m3-adapter-base", envOr("M3_ADAPTER_BASE", "/Volumes/Data/lem"), "Adapter base dir on M3")
fs.StringVar(&cfg.influxURL, "influx", envOr("INFLUX_URL", "http://10.69.69.165:8181"), "InfluxDB URL")
fs.StringVar(&cfg.influxDB, "influx-db", envOr("INFLUX_DB", "training"), "InfluxDB database")
fs.StringVar(&cfg.apiURL, "api-url", envOr("LEM_API_URL", "http://localhost:8080"), "OpenAI-compatible inference API URL")
fs.StringVar(&cfg.model, "model", envOr("LEM_MODEL", ""), "Model name for API (overrides auto-detect)")
fs.StringVar(&cfg.baseModel, "base-model", envOr("BASE_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"), "HuggingFace base model ID")
fs.IntVar(&cfg.pollInterval, "poll", intEnvOr("POLL_INTERVAL", 300), "Poll interval in seconds")
fs.StringVar(&cfg.workDir, "work-dir", envOr("WORK_DIR", "/tmp/scoring-agent"), "Working directory for adapters")
fs.BoolVar(&cfg.oneShot, "one-shot", false, "Process one checkpoint and exit")
fs.BoolVar(&cfg.dryRun, "dry-run", false, "Discover and plan but don't execute")
if err := fs.Parse(args); err != nil {
log.Fatalf("parse flags: %v", err)
}
runAgentLoop(cfg)
}
func runAgentLoop(cfg *agentConfig) {
log.Println(strings.Repeat("=", 60))
log.Println("ROCm Scoring Agent — Go Edition")
log.Printf("M3: %s@%s", cfg.m3User, cfg.m3Host)
log.Printf("Inference API: %s", cfg.apiURL)
log.Printf("InfluxDB: %s/%s", cfg.influxURL, cfg.influxDB)
log.Printf("Poll interval: %ds", cfg.pollInterval)
log.Println(strings.Repeat("=", 60))
influx := NewInfluxClient(cfg.influxURL, cfg.influxDB)
os.MkdirAll(cfg.workDir, 0755)
for {
// Replay any buffered results.
replayInfluxBuffer(cfg.workDir, influx)
// Discover checkpoints on M3.
log.Println("Discovering checkpoints on M3...")
checkpoints, err := discoverCheckpoints(cfg)
if err != nil {
log.Printf("Discovery failed: %v", err)
sleepOrExit(cfg)
continue
}
log.Printf("Found %d total checkpoints", len(checkpoints))
// Check what is already scored.
scored, err := getScoredLabels(influx)
if err != nil {
log.Printf("InfluxDB query failed: %v", err)
}
log.Printf("Already scored: %d (run_id, label) pairs", len(scored))
// Find unscored work.
unscored := findUnscored(checkpoints, scored)
log.Printf("Unscored: %d checkpoints", len(unscored))
if len(unscored) == 0 {
log.Printf("Nothing to score. Sleeping %ds...", cfg.pollInterval)
if cfg.oneShot {
return
}
time.Sleep(time.Duration(cfg.pollInterval) * time.Second)
continue
}
target := unscored[0]
log.Printf("Grabbed: %s (%s)", target.Label, target.Dirname)
if cfg.dryRun {
log.Printf("[DRY RUN] Would process: %s/%s", target.Dirname, target.Filename)
for _, u := range unscored[1:] {
log.Printf("[DRY RUN] Queued: %s/%s", u.Dirname, u.Filename)
}
return
}
if err := processOne(cfg, influx, target); err != nil {
log.Printf("Error processing %s: %v", target.Label, err)
}
if cfg.oneShot {
return
}
time.Sleep(5 * time.Second)
}
}
// discoverCheckpoints lists all adapter directories and checkpoint files on M3 via SSH.
func discoverCheckpoints(cfg *agentConfig) ([]checkpoint, error) {
out, err := sshCommand(cfg, fmt.Sprintf("ls -d %s/adapters-deepseek-r1-7b* 2>/dev/null", cfg.m3AdapterBase))
if err != nil {
return nil, fmt.Errorf("list adapter dirs: %w", err)
}
var checkpoints []checkpoint
iterRe := regexp.MustCompile(`(\d+)`)
for _, dirpath := range strings.Split(strings.TrimSpace(out), "\n") {
if dirpath == "" {
continue
}
dirname := filepath.Base(dirpath)
// List checkpoint safetensors files.
filesOut, err := sshCommand(cfg, fmt.Sprintf("ls %s/*_adapters.safetensors 2>/dev/null", dirpath))
if err != nil {
continue
}
for _, filepath := range strings.Split(strings.TrimSpace(filesOut), "\n") {
if filepath == "" {
continue
}
filename := fileBase(filepath)
match := iterRe.FindStringSubmatch(filename)
if len(match) < 2 {
continue
}
iteration := 0
fmt.Sscanf(match[1], "%d", &iteration)
modelTag, labelPrefix, stem := adapterMeta(dirname)
label := fmt.Sprintf("%s @%s", labelPrefix, match[1])
runID := fmt.Sprintf("%s-capability-auto", stem)
checkpoints = append(checkpoints, checkpoint{
RemoteDir: dirpath,
Filename: filename,
Dirname: dirname,
Iteration: iteration,
ModelTag: modelTag,
Label: label,
RunID: runID,
})
}
}
return checkpoints, nil
}
// adapterMeta maps an adapter directory name to (model_tag, label_prefix, run_id_stem).
func adapterMeta(dirname string) (string, string, string) {
name := strings.TrimPrefix(dirname, "adapters-deepseek-r1-7b")
name = strings.TrimLeft(name, "-")
if name == "" {
name = "base"
}
shortNames := map[string]string{
"sovereignty": "R1-sov",
"russian": "R1-rus",
"composure": "R1-comp",
"sandwich": "R1-sand",
"sandwich-watts": "R1-sw",
"western": "R1-west",
"western-fresh": "R1-wf",
"base": "R1-base",
}
short, ok := shortNames[name]
if !ok {
if len(name) > 4 {
short = "R1-" + name[:4]
} else {
short = "R1-" + name
}
}
stem := "r1-" + name
if name == "base" {
stem = "r1-base"
}
return "deepseek-r1-7b", short, stem
}
// getScoredLabels returns all (run_id, label) pairs already scored in InfluxDB.
func getScoredLabels(influx *InfluxClient) (map[[2]string]bool, error) {
rows, err := influx.QuerySQL("SELECT DISTINCT run_id, label FROM capability_score")
if err != nil {
return nil, err
}
scored := make(map[[2]string]bool)
for _, row := range rows {
runID, _ := row["run_id"].(string)
label, _ := row["label"].(string)
if runID != "" && label != "" {
scored[[2]string{runID, label}] = true
}
}
return scored, nil
}
// findUnscored filters checkpoints to only unscored ones, sorted by (dirname, iteration).
func findUnscored(checkpoints []checkpoint, scored map[[2]string]bool) []checkpoint {
var unscored []checkpoint
for _, c := range checkpoints {
if !scored[[2]string{c.RunID, c.Label}] {
unscored = append(unscored, c)
}
}
sort.Slice(unscored, func(i, j int) bool {
if unscored[i].Dirname != unscored[j].Dirname {
return unscored[i].Dirname < unscored[j].Dirname
}
return unscored[i].Iteration < unscored[j].Iteration
})
return unscored
}
// processOne fetches, converts, scores, and pushes one checkpoint.
func processOne(cfg *agentConfig, influx *InfluxClient, cp checkpoint) error {
log.Println(strings.Repeat("=", 60))
log.Printf("Processing: %s / %s", cp.Dirname, cp.Filename)
log.Println(strings.Repeat("=", 60))
localAdapterDir := filepath.Join(cfg.workDir, cp.Dirname)
os.MkdirAll(localAdapterDir, 0755)
localSF := filepath.Join(localAdapterDir, cp.Filename)
localCfg := filepath.Join(localAdapterDir, "adapter_config.json")
// Cleanup on exit.
defer func() {
os.Remove(localSF)
os.Remove(localCfg)
peftDir := filepath.Join(cfg.workDir, fmt.Sprintf("peft_%07d", cp.Iteration))
os.RemoveAll(peftDir)
}()
// Fetch adapter + config from M3.
log.Println("Fetching adapter from M3...")
remoteSF := fmt.Sprintf("%s/%s", cp.RemoteDir, cp.Filename)
remoteCfg := fmt.Sprintf("%s/adapter_config.json", cp.RemoteDir)
if err := scpFrom(cfg, remoteSF, localSF); err != nil {
return fmt.Errorf("scp safetensors: %w", err)
}
if err := scpFrom(cfg, remoteCfg, localCfg); err != nil {
return fmt.Errorf("scp config: %w", err)
}
// Convert MLX to PEFT format.
log.Println("Converting MLX to PEFT format...")
peftDir := filepath.Join(cfg.workDir, fmt.Sprintf("peft_%07d", cp.Iteration))
if err := convertMLXtoPEFT(localAdapterDir, cp.Filename, peftDir, cfg.baseModel); err != nil {
return fmt.Errorf("convert adapter: %w", err)
}
// Run 23 capability probes via API.
log.Println("Running 23 capability probes...")
modelName := cfg.model
if modelName == "" {
modelName = cp.ModelTag
}
client := NewClient(cfg.apiURL, modelName)
client.MaxTokens = 500
results := runCapabilityProbes(client)
log.Printf("Result: %s -- %.1f%% (%d/%d)",
cp.Label, results.Accuracy, results.Correct, results.Total)
// Push to InfluxDB (buffer on failure).
if err := pushCapabilityResults(influx, cp, results); err != nil {
log.Printf("InfluxDB push failed, buffering: %v", err)
bufferInfluxResult(cfg.workDir, cp, results)
}
return nil
}
// runCapabilityProbes runs all 23 probes against the inference API.
func runCapabilityProbes(client *Client) probeResult {
results := probeResult{
ByCategory: make(map[string]categoryResult),
Probes: make(map[string]singleProbeResult),
}
correct := 0
total := 0
for _, probe := range CapabilityProbes {
response, err := client.ChatWithTemp(probe.Prompt, 0.1)
if err != nil {
log.Printf(" [%s] ERROR: %v", probe.ID, err)
results.Probes[probe.ID] = singleProbeResult{Passed: false, Response: err.Error()}
total++
cat := results.ByCategory[probe.Category]
cat.Total++
results.ByCategory[probe.Category] = cat
continue
}
// Strip <think> blocks from DeepSeek R1 responses.
clean := StripThinkBlocks(response)
passed := probe.Check(clean)
total++
if passed {
correct++
}
cat := results.ByCategory[probe.Category]
cat.Total++
if passed {
cat.Correct++
}
results.ByCategory[probe.Category] = cat
// Truncate response for storage.
stored := clean
if len(stored) > 300 {
stored = stored[:300]
}
results.Probes[probe.ID] = singleProbeResult{Passed: passed, Response: stored}
status := "FAIL"
if passed {
status = "PASS"
}
log.Printf(" [%s] %s (expected: %s)", probe.ID, status, probe.Answer)
}
if total > 0 {
results.Accuracy = float64(correct) / float64(total) * 100
}
results.Correct = correct
results.Total = total
return results
}
// pushCapabilityResults writes scoring results to InfluxDB as line protocol.
func pushCapabilityResults(influx *InfluxClient, cp checkpoint, results probeResult) error {
// Base timestamp: 2026-02-15T00:00:00Z = 1739577600
const baseTS int64 = 1739577600
var lines []string
// Overall score.
ts := (baseTS + int64(cp.Iteration)*1000 + 0) * 1_000_000_000
lines = append(lines, fmt.Sprintf(
"capability_score,model=%s,run_id=%s,label=%s,category=overall accuracy=%.1f,correct=%di,total=%di,iteration=%di %d",
escapeLp(cp.ModelTag), escapeLp(cp.RunID), escapeLp(cp.Label),
results.Accuracy, results.Correct, results.Total, cp.Iteration, ts,
))
// Per-category scores (sorted for deterministic output).
cats := make([]string, 0, len(results.ByCategory))
for cat := range results.ByCategory {
cats = append(cats, cat)
}
sort.Strings(cats)
for i, cat := range cats {
data := results.ByCategory[cat]
catAcc := 0.0
if data.Total > 0 {
catAcc = float64(data.Correct) / float64(data.Total) * 100
}
ts := (baseTS + int64(cp.Iteration)*1000 + int64(i+1)) * 1_000_000_000
lines = append(lines, fmt.Sprintf(
"capability_score,model=%s,run_id=%s,label=%s,category=%s accuracy=%.1f,correct=%di,total=%di,iteration=%di %d",
escapeLp(cp.ModelTag), escapeLp(cp.RunID), escapeLp(cp.Label), escapeLp(cat),
catAcc, data.Correct, data.Total, cp.Iteration, ts,
))
}
// Per-probe results (sorted).
probeIDs := make([]string, 0, len(results.Probes))
for id := range results.Probes {
probeIDs = append(probeIDs, id)
}
sort.Strings(probeIDs)
for j, probeID := range probeIDs {
probeRes := results.Probes[probeID]
passedInt := 0
if probeRes.Passed {
passedInt = 1
}
ts := (baseTS + int64(cp.Iteration)*1000 + int64(j+100)) * 1_000_000_000
lines = append(lines, fmt.Sprintf(
"probe_score,model=%s,run_id=%s,label=%s,probe_id=%s passed=%di,iteration=%di %d",
escapeLp(cp.ModelTag), escapeLp(cp.RunID), escapeLp(cp.Label), escapeLp(probeID),
passedInt, cp.Iteration, ts,
))
}
if err := influx.WriteLp(lines); err != nil {
return err
}
log.Printf("Pushed %d points to InfluxDB for %s", len(lines), cp.Label)
return nil
}
// bufferInfluxResult saves results to a local JSONL file when InfluxDB is down.
func bufferInfluxResult(workDir string, cp checkpoint, results probeResult) {
bufPath := filepath.Join(workDir, "influx_buffer.jsonl")
f, err := os.OpenFile(bufPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Printf("Cannot open buffer file: %v", err)
return
}
defer f.Close()
entry := bufferEntry{
Checkpoint: cp,
Results: results,
Timestamp: time.Now().UTC().Format(time.RFC3339),
}
data, _ := json.Marshal(entry)
f.Write(append(data, '\n'))
log.Printf("Buffered results to %s", bufPath)
}
// replayInfluxBuffer retries pushing buffered results to InfluxDB.
func replayInfluxBuffer(workDir string, influx *InfluxClient) {
bufPath := filepath.Join(workDir, "influx_buffer.jsonl")
data, err := os.ReadFile(bufPath)
if err != nil {
return // No buffer file.
}
var remaining []string
for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") {
if line == "" {
continue
}
var entry bufferEntry
if err := json.Unmarshal([]byte(line), &entry); err != nil {
remaining = append(remaining, line)
continue
}
if err := pushCapabilityResults(influx, entry.Checkpoint, entry.Results); err != nil {
remaining = append(remaining, line)
} else {
log.Printf("Replayed buffered result: %s", entry.Checkpoint.Label)
}
}
if len(remaining) > 0 {
os.WriteFile(bufPath, []byte(strings.Join(remaining, "\n")+"\n"), 0644)
} else {
os.Remove(bufPath)
log.Println("Buffer fully replayed and cleared")
}
}
// sshCommand executes a command on M3 via SSH.
func sshCommand(cfg *agentConfig, cmd string) (string, error) {
sshArgs := []string{
"-o", "ConnectTimeout=10",
"-o", "BatchMode=yes",
"-o", "StrictHostKeyChecking=no",
"-i", cfg.m3SSHKey,
fmt.Sprintf("%s@%s", cfg.m3User, cfg.m3Host),
cmd,
}
result, err := exec.Command("ssh", sshArgs...).CombinedOutput()
if err != nil {
return "", fmt.Errorf("ssh %q: %w: %s", cmd, err, strings.TrimSpace(string(result)))
}
return string(result), nil
}
// scpFrom copies a file from M3 to a local path.
func scpFrom(cfg *agentConfig, remotePath, localPath string) error {
os.MkdirAll(filepath.Dir(localPath), 0755)
scpArgs := []string{
"-o", "ConnectTimeout=10",
"-o", "BatchMode=yes",
"-o", "StrictHostKeyChecking=no",
"-i", cfg.m3SSHKey,
fmt.Sprintf("%s@%s:%s", cfg.m3User, cfg.m3Host, remotePath),
localPath,
}
result, err := exec.Command("scp", scpArgs...).CombinedOutput()
if err != nil {
return fmt.Errorf("scp %s: %w: %s", remotePath, err, strings.TrimSpace(string(result)))
}
return nil
}
// fileBase returns the last component of a path (works for both / and \).
func fileBase(path string) string {
if i := strings.LastIndexAny(path, "/\\"); i >= 0 {
return path[i+1:]
}
return path
}
func sleepOrExit(cfg *agentConfig) {
if cfg.oneShot {
return
}
time.Sleep(time.Duration(cfg.pollInterval) * time.Second)
}
func envOr(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
func intEnvOr(key string, fallback int) int {
v := os.Getenv(key)
if v == "" {
return fallback
}
var n int
fmt.Sscanf(v, "%d", &n)
if n == 0 {
return fallback
}
return n
}
func expandHome(path string) string {
if strings.HasPrefix(path, "~/") {
home, err := os.UserHomeDir()
if err == nil {
return filepath.Join(home, path[2:])
}
}
return path
}

314
pkg/lem/agent_test.go Normal file
View file

@ -0,0 +1,314 @@
package lem
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
func TestAdapterMeta(t *testing.T) {
tests := []struct {
dirname string
wantModel, wantShort string
wantStem string
}{
{"adapters-deepseek-r1-7b-sovereignty", "deepseek-r1-7b", "R1-sov", "r1-sovereignty"},
{"adapters-deepseek-r1-7b-russian", "deepseek-r1-7b", "R1-rus", "r1-russian"},
{"adapters-deepseek-r1-7b-composure", "deepseek-r1-7b", "R1-comp", "r1-composure"},
{"adapters-deepseek-r1-7b-sandwich", "deepseek-r1-7b", "R1-sand", "r1-sandwich"},
{"adapters-deepseek-r1-7b-sandwich-watts", "deepseek-r1-7b", "R1-sw", "r1-sandwich-watts"},
{"adapters-deepseek-r1-7b-western", "deepseek-r1-7b", "R1-west", "r1-western"},
{"adapters-deepseek-r1-7b-western-fresh", "deepseek-r1-7b", "R1-wf", "r1-western-fresh"},
{"adapters-deepseek-r1-7b", "deepseek-r1-7b", "R1-base", "r1-base"},
{"adapters-deepseek-r1-7b-custom", "deepseek-r1-7b", "R1-cust", "r1-custom"},
}
for _, tt := range tests {
model, short, stem := adapterMeta(tt.dirname)
if model != tt.wantModel || short != tt.wantShort || stem != tt.wantStem {
t.Errorf("adapterMeta(%q) = (%q, %q, %q), want (%q, %q, %q)",
tt.dirname, model, short, stem, tt.wantModel, tt.wantShort, tt.wantStem)
}
}
}
func TestFindUnscored(t *testing.T) {
checkpoints := []checkpoint{
{RunID: "r1-sov-capability-auto", Label: "R1-sov @100", Dirname: "a", Iteration: 100},
{RunID: "r1-sov-capability-auto", Label: "R1-sov @200", Dirname: "a", Iteration: 200},
{RunID: "r1-sov-capability-auto", Label: "R1-sov @300", Dirname: "a", Iteration: 300},
}
scored := map[[2]string]bool{
{"r1-sov-capability-auto", "R1-sov @100"}: true,
{"r1-sov-capability-auto", "R1-sov @200"}: true,
}
unscored := findUnscored(checkpoints, scored)
if len(unscored) != 1 {
t.Fatalf("expected 1 unscored, got %d", len(unscored))
}
if unscored[0].Label != "R1-sov @300" {
t.Errorf("expected R1-sov @300, got %s", unscored[0].Label)
}
}
func TestFindUnscoredSorting(t *testing.T) {
checkpoints := []checkpoint{
{RunID: "r1-a", Label: "a @300", Dirname: "a", Iteration: 300},
{RunID: "r1-b", Label: "b @100", Dirname: "b", Iteration: 100},
{RunID: "r1-a", Label: "a @100", Dirname: "a", Iteration: 100},
}
scored := make(map[[2]string]bool)
unscored := findUnscored(checkpoints, scored)
if len(unscored) != 3 {
t.Fatalf("expected 3 unscored, got %d", len(unscored))
}
// Should be sorted by dirname then iteration.
if unscored[0].Label != "a @100" {
t.Errorf("first should be a @100, got %s", unscored[0].Label)
}
if unscored[1].Label != "a @300" {
t.Errorf("second should be a @300, got %s", unscored[1].Label)
}
if unscored[2].Label != "b @100" {
t.Errorf("third should be b @100, got %s", unscored[2].Label)
}
}
func TestRunCapabilityProbes(t *testing.T) {
// Mock an OpenAI-compatible API that returns correct answers.
answers := map[string]string{
"What is 347": "The answer is 10063.",
"A store sells": "You get $28.75 in change.",
"Solve for x": "x = -12",
"If f(x)": "f(4) = 21",
"A bag has": "The probability is 1/2 or 0.5",
"A circle has": "The area is 153.94 cm²",
"next number": "The next number is 162.",
"laptop costs": "The final price is $612.",
"All cats": "Yes, a cat needs water.",
"If it rains": "No, we cannot conclude that.",
"room of 30": "The minimum is 3 people sharing a birth month.",
"farmer needs": "Take the chicken first.",
"class of 40": "5 students play neither.",
"Book is to": "eating",
"car won't start": "The starter motor is faulty.",
"facing north": "You are facing south.",
"Event A": "Event C happened in 1991.",
"APPLE = 50": "CAT = 24",
"Python code": "[2, 3]",
"def f(n)": "The output is 8.",
"code has a bug": "ZeroDivisionError when empty list.",
"train travels": "It takes 3 hours.",
"twice as many": "There are 7 children.",
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req ChatRequest
json.NewDecoder(r.Body).Decode(&req)
prompt := ""
for _, m := range req.Messages {
if m.Role == "user" {
prompt = m.Content
break
}
}
response := "I don't know."
for prefix, ans := range answers {
if strings.Contains(prompt, prefix) {
response = ans
break
}
}
json.NewEncoder(w).Encode(ChatResponse{
Choices: []Choice{{Message: Message{Role: "assistant", Content: response}}},
})
}))
defer server.Close()
client := NewClient(server.URL, "test-model")
client.MaxTokens = 500
results := runCapabilityProbes(client)
if results.Total != 23 {
t.Errorf("expected 23 total probes, got %d", results.Total)
}
if results.Correct != 23 {
t.Errorf("expected 23 correct, got %d (accuracy: %.1f%%)", results.Correct, results.Accuracy)
}
if results.Accuracy != 100.0 {
t.Errorf("expected 100%% accuracy, got %.1f%%", results.Accuracy)
}
}
func TestPushCapabilityResults(t *testing.T) {
var writtenLines []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/v3/write_lp" {
body := make([]byte, r.ContentLength)
r.Body.Read(body)
writtenLines = strings.Split(strings.TrimSpace(string(body)), "\n")
w.WriteHeader(http.StatusNoContent)
}
}))
defer server.Close()
influx := &InfluxClient{url: server.URL, db: "test", token: "t"}
cp := checkpoint{
ModelTag: "deepseek-r1-7b",
RunID: "r1-sov-capability-auto",
Label: "R1-sov @100",
Iteration: 100,
}
results := probeResult{
Accuracy: 87.0,
Correct: 20,
Total: 23,
ByCategory: map[string]categoryResult{
"arithmetic": {Correct: 2, Total: 2},
"code": {Correct: 2, Total: 3},
},
Probes: map[string]singleProbeResult{
"math_01": {Passed: true, Response: "10063"},
"math_02": {Passed: true, Response: "28.75"},
"code_03": {Passed: false, Response: "I'm not sure."},
},
}
err := pushCapabilityResults(influx, cp, results)
if err != nil {
t.Fatalf("push failed: %v", err)
}
// 1 overall + 2 categories + 3 probes = 6 lines.
if len(writtenLines) != 6 {
t.Errorf("expected 6 lines, got %d", len(writtenLines))
for i, l := range writtenLines {
t.Logf(" line %d: %s", i, l)
}
}
// Check overall line.
if !strings.HasPrefix(writtenLines[0], "capability_score,") {
t.Errorf("first line should be capability_score, got: %s", writtenLines[0])
}
if !strings.Contains(writtenLines[0], "category=overall") {
t.Errorf("first line should have category=overall, got: %s", writtenLines[0])
}
if !strings.Contains(writtenLines[0], "accuracy=87.0") {
t.Errorf("first line should have accuracy=87.0, got: %s", writtenLines[0])
}
}
func TestBufferAndReplay(t *testing.T) {
tmpDir := t.TempDir()
cp := checkpoint{
ModelTag: "test-model",
RunID: "test-run",
Label: "test @100",
Iteration: 100,
}
results := probeResult{
Accuracy: 50.0,
Correct: 1,
Total: 2,
ByCategory: map[string]categoryResult{
"arithmetic": {Correct: 1, Total: 2},
},
Probes: map[string]singleProbeResult{
"math_01": {Passed: true, Response: "10063"},
"math_02": {Passed: false, Response: "wrong"},
},
}
// Buffer a result.
bufferInfluxResult(tmpDir, cp, results)
// Verify buffer file exists.
bufPath := filepath.Join(tmpDir, "influx_buffer.jsonl")
data, err := os.ReadFile(bufPath)
if err != nil {
t.Fatalf("buffer file not created: %v", err)
}
if !strings.Contains(string(data), "test-run") {
t.Errorf("buffer should contain run_id, got: %s", string(data))
}
// Parse it.
var entry bufferEntry
if err := json.Unmarshal(data, &entry); err != nil {
t.Fatalf("parse buffer entry: %v", err)
}
if entry.Checkpoint.RunID != "test-run" {
t.Errorf("expected run_id=test-run, got %s", entry.Checkpoint.RunID)
}
// Replay to a working InfluxDB.
replayCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/v3/write_lp" {
replayCount++
w.WriteHeader(http.StatusNoContent)
}
}))
defer server.Close()
influx := &InfluxClient{url: server.URL, db: "test", token: "t"}
replayInfluxBuffer(tmpDir, influx)
if replayCount == 0 {
t.Error("expected replay to push to InfluxDB")
}
// Buffer should be cleared.
if _, err := os.Stat(bufPath); !os.IsNotExist(err) {
t.Error("buffer file should be removed after successful replay")
}
}
func TestEnvOr(t *testing.T) {
// Test with env var set.
key := fmt.Sprintf("TEST_ENV_%d", os.Getpid())
os.Setenv(key, "value")
defer os.Unsetenv(key)
if got := envOr(key, "fallback"); got != "value" {
t.Errorf("envOr(%s) = %q, want %q", key, got, "value")
}
if got := envOr("NONEXISTENT_"+key, "fallback"); got != "fallback" {
t.Errorf("envOr(nonexistent) = %q, want %q", got, "fallback")
}
}
func TestFileBase(t *testing.T) {
tests := []struct {
input, want string
}{
{"/foo/bar/baz.txt", "baz.txt"},
{"baz.txt", "baz.txt"},
{"/a/b/c", "c"},
{"", ""},
}
for _, tt := range tests {
if got := fileBase(tt.input); got != tt.want {
t.Errorf("fileBase(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}

273
pkg/lem/probes.go Normal file
View file

@ -0,0 +1,273 @@
package lem
import (
"regexp"
"strings"
)
// Probe defines a binary pass/fail capability check.
// Each probe sends a prompt to the model and evaluates the response
// with a Go function — no judge model needed.
type Probe struct {
ID string
Category string
Prompt string
Answer string
Check func(response string) bool
}
// CapabilityProbes contains all 23 binary capability probes.
// Categories: arithmetic, algebra, probability, geometry, sequences,
// percentages, deduction, puzzles, sets, analogy, causal, spatial,
// temporal, pattern, code, word.
var CapabilityProbes = []Probe{
// === MATH (8) ===
{
ID: "math_01",
Category: "arithmetic",
Prompt: "What is 347 × 29? Show your work and give the final answer.",
Answer: "10063",
Check: func(r string) bool {
clean := strings.ReplaceAll(strings.ReplaceAll(r, ",", ""), " ", "")
return strings.Contains(clean, "10063")
},
},
{
ID: "math_02",
Category: "arithmetic",
Prompt: "A store sells apples for $1.25 each. If I buy 17 apples and pay with a $50 bill, how much change do I get?",
Answer: "28.75",
Check: func(r string) bool {
return strings.Contains(r, "28.75") || strings.Contains(r, "$28.75")
},
},
{
ID: "math_03",
Category: "algebra",
Prompt: "Solve for x: 3x + 7 = 2x - 5. What is x?",
Answer: "-12",
Check: func(r string) bool {
return regexp.MustCompile(`x\s*=\s*-\s*12|=\s*-12|-12`).MatchString(r)
},
},
{
ID: "math_04",
Category: "algebra",
Prompt: "If f(x) = 2x² - 3x + 1, what is f(4)?",
Answer: "21",
Check: func(r string) bool {
return regexp.MustCompile(`\b21\b`).MatchString(r)
},
},
{
ID: "math_05",
Category: "probability",
Prompt: "A bag has 3 red balls, 5 blue balls, and 2 green balls. What is the probability of drawing a blue ball? Express as a fraction and decimal.",
Answer: "1/2 or 0.5",
Check: func(r string) bool {
return strings.Contains(r, "1/2") || strings.Contains(r, "0.5") ||
strings.Contains(r, "50%") || strings.Contains(r, "5/10")
},
},
{
ID: "math_06",
Category: "geometry",
Prompt: "A circle has a radius of 7cm. What is its area? Use pi = 3.14159.",
Answer: "153.94",
Check: func(r string) bool {
return regexp.MustCompile(`15[34]\.9|153\.9[0-9]|154\.0|49\s*[πpi]`).MatchString(r)
},
},
{
ID: "math_07",
Category: "sequences",
Prompt: "What is the next number in this sequence: 2, 6, 18, 54, ...?",
Answer: "162",
Check: func(r string) bool {
return strings.Contains(r, "162")
},
},
{
ID: "math_08",
Category: "percentages",
Prompt: "A laptop costs $800. It's on sale for 15% off. Then you have a coupon for 10% off the sale price. What is the final price?",
Answer: "612",
Check: func(r string) bool {
return regexp.MustCompile(`\$?612`).MatchString(r)
},
},
// === LOGIC (5) ===
{
ID: "logic_01",
Category: "deduction",
Prompt: "All cats are animals. All animals need water. Does a cat need water? Explain your reasoning.",
Answer: "Yes",
Check: func(r string) bool {
return regexp.MustCompile(`(?i)\byes\b`).MatchString(r)
},
},
{
ID: "logic_02",
Category: "deduction",
Prompt: "If it rains, the ground gets wet. The ground is wet. Can we conclude it rained? Why or why not?",
Answer: "No - affirming the consequent fallacy",
Check: func(r string) bool {
lower := strings.ToLower(r)
return regexp.MustCompile(`\bno\b|\bcannot\b|\bcan't\b|not necessarily|fallac|other reason|doesn't mean`).MatchString(lower)
},
},
{
ID: "logic_03",
Category: "deduction",
Prompt: "In a room of 30 people, what is the minimum number of people that must share a birth month?",
Answer: "3",
Check: func(r string) bool {
lower := strings.ToLower(r)
has3 := regexp.MustCompile(`\b3\b|three`).MatchString(lower)
// Avoid matching "30" in the first 50 chars (restating the problem)
prefix := lower
if len(prefix) > 50 {
prefix = prefix[:50]
}
has30 := regexp.MustCompile(`\b30\b`).MatchString(prefix)
return has3 && !has30
},
},
{
ID: "logic_04",
Category: "puzzles",
Prompt: "A farmer needs to cross a river with a fox, a chicken, and a bag of grain. The boat only holds the farmer and one item. If left alone, the fox eats the chicken, and the chicken eats the grain. What is the first thing the farmer should take across?",
Answer: "The chicken",
Check: func(r string) bool {
return regexp.MustCompile(`(?i)chicken|hen`).MatchString(r)
},
},
{
ID: "logic_05",
Category: "sets",
Prompt: "In a class of 40 students, 25 play football, 20 play basketball, and 10 play both. How many play neither?",
Answer: "5",
Check: func(r string) bool {
return regexp.MustCompile(`(?i)\b5\b|five`).MatchString(r)
},
},
// === REASONING (5) ===
{
ID: "reason_01",
Category: "analogy",
Prompt: "Complete the analogy: Book is to reading as fork is to ___",
Answer: "eating",
Check: func(r string) bool {
return regexp.MustCompile(`(?i)eating|food|dining`).MatchString(r)
},
},
{
ID: "reason_02",
Category: "causal",
Prompt: "A car won't start. The battery is new. The fuel tank is full. The starter motor clicks but the engine doesn't turn. What is the most likely problem?",
Answer: "Starter motor / solenoid",
Check: func(r string) bool {
return regexp.MustCompile(`(?i)starter|solenoid|connection|terminal|corros|ground|wire`).MatchString(r)
},
},
{
ID: "reason_03",
Category: "spatial",
Prompt: "You're facing north. You turn right 90 degrees, then turn right 90 degrees again. What direction are you facing?",
Answer: "South",
Check: func(r string) bool {
return regexp.MustCompile(`(?i)\bsouth\b`).MatchString(r)
},
},
{
ID: "reason_04",
Category: "temporal",
Prompt: "Event A happened in 1995. Event B happened 12 years before Event A. Event C happened 8 years after Event B. In what year did Event C happen?",
Answer: "1991",
Check: func(r string) bool {
return strings.Contains(r, "1991")
},
},
{
ID: "reason_05",
Category: "pattern",
Prompt: "If APPLE = 50 (A=1, P=16, P=16, L=12, E=5), what does CAT equal using the same system?",
Answer: "24",
Check: func(r string) bool {
return regexp.MustCompile(`\b24\b`).MatchString(r)
},
},
// === CODE (3) ===
{
ID: "code_01",
Category: "code",
Prompt: "What does this Python code print?\nx = [1, 2, 3, 4, 5]\nprint(x[1:3])",
Answer: "[2, 3]",
Check: func(r string) bool {
return strings.Contains(r, "[2, 3]") || strings.Contains(r, "[2,3]")
},
},
{
ID: "code_02",
Category: "code",
Prompt: "What is the output?\ndef f(n):\n if n <= 1: return n\n return f(n-1) + f(n-2)\nprint(f(6))",
Answer: "8",
Check: func(r string) bool {
return regexp.MustCompile(`\b8\b`).MatchString(r)
},
},
{
ID: "code_03",
Category: "code",
Prompt: "This code has a bug. What is it?\ndef average(numbers):\n total = 0\n for n in numbers:\n total += n\n return total / len(numbers)\nprint(average([]))",
Answer: "Division by zero",
Check: func(r string) bool {
return regexp.MustCompile(`(?i)divis.*zero|zero.*divis|empty|len.*0|ZeroDivision`).MatchString(r)
},
},
// === WORD PROBLEMS (2) ===
{
ID: "word_01",
Category: "word",
Prompt: "A train travels at 60 km/h. Another train travels at 80 km/h in the same direction from the same station, leaving 1 hour later. How long after the second train departs will it catch the first?",
Answer: "3 hours",
Check: func(r string) bool {
return regexp.MustCompile(`(?i)\b3\b.*hour|three.*hour`).MatchString(r)
},
},
{
ID: "word_02",
Category: "word",
Prompt: "I have twice as many sisters as brothers. My sister has as many brothers as sisters. How many children are in my family? (I am male.)",
Answer: "7",
Check: func(r string) bool {
return regexp.MustCompile(`(?i)\b7\b|seven`).MatchString(r)
},
},
}
// ProbeCategories returns sorted unique categories from CapabilityProbes.
func ProbeCategories() []string {
seen := make(map[string]bool)
var cats []string
for _, p := range CapabilityProbes {
if !seen[p.Category] {
seen[p.Category] = true
cats = append(cats, p.Category)
}
}
return cats
}
// StripThinkBlocks removes <think>...</think> blocks from DeepSeek R1 responses.
func StripThinkBlocks(s string) string {
re := regexp.MustCompile(`(?s)<think>.*?</think>`)
clean := strings.TrimSpace(re.ReplaceAllString(s, ""))
if clean == "" && len(s) > 500 {
return s[:500]
}
if clean == "" {
return s
}
return clean
}

140
pkg/lem/probes_test.go Normal file
View file

@ -0,0 +1,140 @@
package lem
import (
"testing"
)
func TestProbeCount(t *testing.T) {
if got := len(CapabilityProbes); got != 23 {
t.Errorf("expected 23 probes, got %d", got)
}
}
func TestProbeCategories(t *testing.T) {
cats := ProbeCategories()
if len(cats) == 0 {
t.Fatal("no categories")
}
// Should have at least these categories.
want := map[string]bool{
"arithmetic": true, "algebra": true, "deduction": true,
"code": true, "word": true,
}
catSet := make(map[string]bool)
for _, c := range cats {
catSet[c] = true
}
for w := range want {
if !catSet[w] {
t.Errorf("missing category %q", w)
}
}
}
func TestProbeChecks(t *testing.T) {
// Verify each probe's check function works with its expected answer.
tests := []struct {
id string
response string
want bool
}{
// Math.
{"math_01", "The answer is 10063.", true},
{"math_01", "The answer is 10064.", false},
{"math_02", "You'd get $28.75 in change.", true},
{"math_02", "You'd get $29.75 in change.", false},
{"math_03", "x = -12", true},
{"math_03", "x = 12", false},
{"math_04", "f(4) = 21", true},
{"math_04", "f(4) = 22", false},
{"math_05", "The probability is 1/2 or 0.5", true},
{"math_05", "The probability is 1/3", false},
{"math_06", "The area is 153.94 cm²", true},
{"math_06", "The area is 100 cm²", false},
{"math_07", "The next number is 162.", true},
{"math_07", "The next number is 163.", false},
{"math_08", "The final price is $612.", true},
{"math_08", "The final price is $600.", false},
// Logic.
{"logic_01", "Yes, a cat needs water.", true},
{"logic_01", "Maybe.", false},
{"logic_02", "No, we cannot conclude that. It's the fallacy of affirming the consequent.", true},
{"logic_02", "Yes, it rained.", false},
{"logic_03", "The minimum is 3 people.", true},
{"logic_03", "The minimum is 2 people.", false},
{"logic_04", "Take the chicken first.", true},
{"logic_04", "Take the fox first.", false},
{"logic_05", "5 students play neither.", true},
{"logic_05", "10 students play neither.", false},
// Reasoning.
{"reason_01", "eating", true},
{"reason_01", "building", false},
{"reason_02", "The starter motor is likely faulty.", true},
{"reason_02", "The tires are flat.", false},
{"reason_03", "You are facing south.", true},
{"reason_03", "You are facing north.", false},
{"reason_04", "Event C happened in 1991.", true},
{"reason_04", "Event C happened in 1990.", false},
{"reason_05", "CAT = 24", true},
{"reason_05", "CAT = 25", false},
// Code.
{"code_01", "[2, 3]", true},
{"code_01", "[1, 2, 3]", false},
{"code_02", "The output is 8.", true},
{"code_02", "The output is 7.", false},
{"code_03", "Division by zero when the list is empty.", true},
{"code_03", "There is no bug.", false},
// Word.
{"word_01", "It takes 3 hours.", true},
{"word_01", "It takes 4 hours.", false},
{"word_02", "There are 7 children.", true},
{"word_02", "There are 6 children.", false},
}
probeMap := make(map[string]Probe)
for _, p := range CapabilityProbes {
probeMap[p.ID] = p
}
for _, tt := range tests {
probe, ok := probeMap[tt.id]
if !ok {
t.Errorf("probe %s not found", tt.id)
continue
}
got := probe.Check(tt.response)
if got != tt.want {
t.Errorf("probe %s: Check(%q) = %v, want %v", tt.id, tt.response, got, tt.want)
}
}
}
func TestStripThinkBlocks(t *testing.T) {
tests := []struct {
input string
want string
}{
{
"<think>Let me think about this...</think>The answer is 42.",
"The answer is 42.",
},
{
"No think blocks here.",
"No think blocks here.",
},
{
"<think>First\nblock</think>Hello <think>second</think> world",
"Hello world",
},
{
"", "",
},
}
for _, tt := range tests {
got := StripThinkBlocks(tt.input)
if got != tt.want {
t.Errorf("StripThinkBlocks(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}