go-ml/agent_execute.go
Snider 12f3a1c79d refactor: extract hardcoded values into package constants
Move magic numbers and strings into named constants in agent_config.go:
- EpochBase (was 1739577600 in 5 locations)
- 5 InfluxDB measurement names (capability_score, probe_score, etc.)
- 2 DuckDB table names (checkpoint_scores, probe_results)
- Probe defaults (temperature, max tokens, response truncation)
- InfluxBufferFile, LogSeparatorWidth, InterCheckpointDelay

Replace hardcoded probe counts ("23", "6") with len(CapabilityProbes),
len(ContentProbes). 7 files modified, no functional changes.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-20 03:18:19 +00:00

218 lines
6 KiB
Go

package ml
import (
"context"
"fmt"
"log"
"os"
"regexp"
"sort"
"strings"
"time"
)
// RunAgentLoop is the main scoring agent loop.
//
// Each cycle it calls ReplayInfluxBuffer for cfg.WorkDir, discovers
// checkpoints on M3, and selects the ones to score: every checkpoint when
// cfg.Force is set, otherwise those whose (RunID, Label) pair is not yet in
// InfluxDB. Without Force only the first unscored checkpoint (ordered by
// dirname, then iteration) is processed per cycle. Each selected checkpoint
// is handed to ProcessOne (or only logged in DryRun mode), followed by a
// pause of InterCheckpointDelay.
//
// The loop returns after the first cycle when cfg.OneShot or cfg.DryRun is
// set. Otherwise it runs forever, sleeping cfg.PollInterval seconds after a
// failed discovery or an idle cycle.
func RunAgentLoop(cfg *AgentConfig) {
	banner := strings.Repeat("=", LogSeparatorWidth)
	log.Println(banner)
	log.Println("ROCm Scoring Agent — Go Edition")
	log.Printf("M3: %s@%s", cfg.M3User, cfg.M3Host)
	log.Printf("Inference API: %s", cfg.APIURL)
	log.Printf("Judge API: %s (%s)", cfg.JudgeURL, cfg.JudgeModel)
	log.Printf("InfluxDB: %s/%s", cfg.InfluxURL, cfg.InfluxDB)
	if cfg.DBPath != "" {
		log.Printf("DuckDB: %s", cfg.DBPath)
	}
	log.Printf("Poll interval: %ds", cfg.PollInterval)
	log.Println(banner)

	influx := NewInfluxClient(cfg.InfluxURL, cfg.InfluxDB)
	if err := os.MkdirAll(cfg.WorkDir, 0755); err != nil {
		// Not fatal here: steps that actually need WorkDir will report their
		// own errors, but the root cause is now visible in the log.
		log.Printf("Failed to create work dir %s: %v", cfg.WorkDir, err)
	}

	pollDelay := time.Duration(cfg.PollInterval) * time.Second

	for {
		ReplayInfluxBuffer(cfg.WorkDir, influx)

		log.Println("Discovering checkpoints on M3...")
		checkpoints, err := DiscoverCheckpoints(cfg)
		if err != nil {
			log.Printf("Discovery failed: %v", err)
			if cfg.OneShot {
				return
			}
			time.Sleep(pollDelay)
			continue
		}
		log.Printf("Found %d total checkpoints", len(checkpoints))

		var unscored []Checkpoint
		if cfg.Force {
			unscored = checkpoints
			log.Printf("Force mode: scoring all %d checkpoints", len(unscored))
		} else {
			scored, err := GetScoredLabels(influx)
			if err != nil {
				// Deliberately best-effort: with a nil scored map every
				// checkpoint counts as unscored.
				// NOTE(review): presumably this lets a fresh database (where
				// the measurement does not exist yet) bootstrap — confirm.
				log.Printf("InfluxDB query failed: %v", err)
			}
			log.Printf("Already scored: %d (run_id, label) pairs", len(scored))
			unscored = FindUnscored(checkpoints, scored)
			log.Printf("Unscored: %d checkpoints", len(unscored))
		}

		if len(unscored) == 0 {
			if cfg.OneShot {
				// One-shot mode exits instead of sleeping, so say so.
				log.Println("Nothing to score.")
				return
			}
			log.Printf("Nothing to score. Sleeping %ds...", cfg.PollInterval)
			time.Sleep(pollDelay)
			continue
		}

		// Outside Force mode, take one checkpoint per cycle so the next
		// discovery pass sees fresh scoring state.
		targets := unscored
		if !cfg.Force {
			targets = unscored[:1]
		}

		for i, target := range targets {
			log.Printf("Grabbed: %s (%s) [%d/%d]", target.Label, target.Dirname, i+1, len(targets))
			if cfg.DryRun {
				log.Printf("[DRY RUN] Would process: %s/%s", target.Dirname, target.Filename)
				continue
			}
			if err := ProcessOne(cfg, influx, target); err != nil {
				log.Printf("Error processing %s: %v", target.Label, err)
			}
			time.Sleep(InterCheckpointDelay)
		}

		if cfg.DryRun || cfg.OneShot {
			return
		}
	}
}
// adapterIterationRe captures the first run of decimal digits in a checkpoint
// filename. For a hypothetical "0000100_adapters.safetensors" it captures
// "0000100". It is compiled once at package initialisation rather than on
// every discovery pass.
var adapterIterationRe = regexp.MustCompile(`(\d+)`)

// DiscoverCheckpoints lists all adapter directories and checkpoint files on M3 via SSH.
//
// It runs `ls` through cfg.transport() to find directories matching
// cfg.M3AdapterBase/adapters-<cfg.Filter>*. A matched directory that contains
// gemma-3-* subdirectories is replaced by those subdirectories; otherwise the
// directory itself is used.
//
// Every *_adapters.safetensors file whose name contains a digit run becomes a
// Checkpoint:
//   - Iteration is the numeric value of the first digit run.
//   - Label is "<labelPrefix> @<digits>", with the digits kept verbatim
//     (including leading zeros).
//   - RunID is "<stem>-capability-auto".
//   - modelTag, labelPrefix and stem come from AdapterMeta(dirname).
//
// Only a failure of the initial directory listing is returned as an error.
// Directories whose file listing fails are skipped.
func DiscoverCheckpoints(cfg *AgentConfig) ([]Checkpoint, error) {
	pattern := "adapters-*"
	if cfg.Filter != "" {
		pattern = "adapters-" + cfg.Filter + "*"
	}

	t := cfg.transport()
	ctx := context.Background()

	// NOTE(review): paths are interpolated unquoted so the remote shell still
	// expands the globs (and any ~ in M3AdapterBase). This assumes trusted
	// config values and directory names without whitespace.
	out, err := t.Run(ctx, fmt.Sprintf("ls -d %s/%s 2>/dev/null", cfg.M3AdapterBase, pattern))
	if err != nil {
		return nil, fmt.Errorf("list adapter dirs: %w", err)
	}

	// Expand each top-level adapter directory into its gemma-3-* subdirectories
	// when it has any; otherwise keep the directory itself.
	var adapterDirs []string
	for _, dirpath := range strings.Split(strings.TrimSpace(out), "\n") {
		if dirpath == "" {
			continue
		}
		subOut, subErr := t.Run(ctx, fmt.Sprintf("ls -d %s/gemma-3-* 2>/dev/null", dirpath))
		if subErr == nil && strings.TrimSpace(subOut) != "" {
			for _, sub := range strings.Split(strings.TrimSpace(subOut), "\n") {
				if sub != "" {
					adapterDirs = append(adapterDirs, sub)
				}
			}
		} else {
			adapterDirs = append(adapterDirs, dirpath)
		}
	}

	var checkpoints []Checkpoint
	for _, dirpath := range adapterDirs {
		dirname := strings.TrimPrefix(dirpath, cfg.M3AdapterBase+"/")
		filesOut, err := t.Run(ctx, fmt.Sprintf("ls %s/*_adapters.safetensors 2>/dev/null", dirpath))
		if err != nil {
			// Best-effort: a directory whose listing fails contributes nothing.
			// Presumably this covers the "no matching files" case.
			continue
		}

		// Metadata depends only on the directory, so derive it once per directory
		// rather than once per file.
		modelTag, labelPrefix, stem := AdapterMeta(dirname)
		runID := fmt.Sprintf("%s-capability-auto", stem)

		for _, fp := range strings.Split(strings.TrimSpace(filesOut), "\n") {
			if fp == "" {
				continue
			}
			filename := fileBase(fp)
			match := adapterIterationRe.FindStringSubmatch(filename)
			if len(match) < 2 {
				continue
			}
			// The regexp guarantees digits, so the only possible scan failure is
			// overflow. In that case iteration stays 0, as before.
			iteration := 0
			fmt.Sscanf(match[1], "%d", &iteration)
			checkpoints = append(checkpoints, Checkpoint{
				RemoteDir: dirpath,
				Filename:  filename,
				Dirname:   dirname,
				Iteration: iteration,
				ModelTag:  modelTag,
				Label:     fmt.Sprintf("%s @%s", labelPrefix, match[1]),
				RunID:     runID,
			})
		}
	}
	return checkpoints, nil
}
// GetScoredLabels returns all (run_id, label) pairs already scored in InfluxDB.
//
// It runs "SELECT DISTINCT run_id, label" against the
// MeasurementCapabilityScore measurement.
//   - Rows whose run_id or label is missing, not a string, or empty are ignored.
//   - Keys are [2]string{runID, label}, the same key shape FindUnscored looks up.
//   - On query failure it returns a nil map and the error, wrapped with context.
func GetScoredLabels(influx *InfluxClient) (map[[2]string]bool, error) {
	rows, err := influx.QuerySQL("SELECT DISTINCT run_id, label FROM " + MeasurementCapabilityScore)
	if err != nil {
		return nil, fmt.Errorf("query scored labels from %s: %w", MeasurementCapabilityScore, err)
	}

	scored := make(map[[2]string]bool, len(rows))
	for _, row := range rows {
		// Comma-ok assertions: non-string or absent values become "" and are skipped.
		runID, _ := row["run_id"].(string)
		label, _ := row["label"].(string)
		if runID != "" && label != "" {
			scored[[2]string{runID, label}] = true
		}
	}
	return scored, nil
}
// FindUnscored returns the checkpoints whose (RunID, Label) key is absent from
// scored, ordered by Dirname and then by Iteration.
//
// A nil scored map is valid; every checkpoint then counts as unscored. The
// result is nil when nothing remains. The input slice is not reordered.
func FindUnscored(checkpoints []Checkpoint, scored map[[2]string]bool) []Checkpoint {
	var pending []Checkpoint
	for i := range checkpoints {
		key := [2]string{checkpoints[i].RunID, checkpoints[i].Label}
		if scored[key] {
			continue
		}
		pending = append(pending, checkpoints[i])
	}

	sort.Slice(pending, func(a, b int) bool {
		left, right := pending[a], pending[b]
		if left.Dirname == right.Dirname {
			return left.Iteration < right.Iteration
		}
		return left.Dirname < right.Dirname
	})
	return pending
}
// isMLXNative reports whether modelTag starts with "gemma-3-" or "gpt-oss".
//
// Such models are treated as servable directly on M3 by mlx_lm.server with
// --adapter, which skips the MLX→PEFT conversion step. The match is a
// case-sensitive prefix check.
func isMLXNative(modelTag string) bool {
	for _, prefix := range []string{"gemma-3-", "gpt-oss"} {
		if strings.HasPrefix(modelTag, prefix) {
			return true
		}
	}
	return false
}
// ProcessOne fetches, converts, scores, and pushes one checkpoint.
//
// It logs a banner naming the checkpoint's directory, file and model tag. It
// then delegates to one of two pipelines and returns that pipeline's error
// unchanged:
//   - processMLXNative when isMLXNative(cp.ModelTag) is true;
//   - processWithConversion otherwise.
func ProcessOne(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint) error {
	banner := strings.Repeat("=", LogSeparatorWidth)
	log.Println(banner)
	log.Printf("Processing: %s / %s [%s]", cp.Dirname, cp.Filename, cp.ModelTag)
	log.Println(banner)

	switch {
	case isMLXNative(cp.ModelTag):
		return processMLXNative(cfg, influx, cp)
	default:
		return processWithConversion(cfg, influx, cp)
	}
}