go-ml/agent_execute.go
Snider 5340a975a7
Some checks failed
Security Scan / security (push) Successful in 9s
Test / test (push) Failing after 41s
refactor: modernise to Go 1.26 — iterators, slices, maps
- Add DiscoverCheckpointsIter, FindUnscoredIter (iter.Seq2/iter.Seq)
- Add BackendsIter on Service
- Replace sort.Slice with slices.SortFunc, sort.Strings with slices.Sort
- Replace manual map-key-collect-sort with slices.Sorted(maps.Keys(...))
- Refactor ProbeCategories to use slices.Sorted with inline iterator
- Existing slice-returning methods preserved, backed by new iterators

Co-Authored-By: Gemini <noreply@google.com>
Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-23 06:11:55 +00:00

244 lines
6.7 KiB
Go

package ml
import (
"context"
"fmt"
"iter"
"log"
"os"
"regexp"
"slices"
"strings"
"time"
)
// RunAgentLoop is the main scoring agent loop.
func RunAgentLoop(cfg *AgentConfig) {
log.Println(strings.Repeat("=", LogSeparatorWidth))
log.Println("ROCm Scoring Agent — Go Edition")
log.Printf("M3: %s@%s", cfg.M3User, cfg.M3Host)
log.Printf("Inference API: %s", cfg.APIURL)
log.Printf("Judge API: %s (%s)", cfg.JudgeURL, cfg.JudgeModel)
log.Printf("InfluxDB: %s/%s", cfg.InfluxURL, cfg.InfluxDB)
if cfg.DBPath != "" {
log.Printf("DuckDB: %s", cfg.DBPath)
}
log.Printf("Poll interval: %ds", cfg.PollInterval)
log.Println(strings.Repeat("=", LogSeparatorWidth))
influx := NewInfluxClient(cfg.InfluxURL, cfg.InfluxDB)
os.MkdirAll(cfg.WorkDir, 0755)
for {
ReplayInfluxBuffer(cfg.WorkDir, influx)
log.Println("Discovering checkpoints on M3...")
checkpoints, err := DiscoverCheckpoints(cfg)
if err != nil {
log.Printf("Discovery failed: %v", err)
if cfg.OneShot {
return
}
time.Sleep(time.Duration(cfg.PollInterval) * time.Second)
continue
}
log.Printf("Found %d total checkpoints", len(checkpoints))
var unscored []Checkpoint
if cfg.Force {
unscored = checkpoints
log.Printf("Force mode: scoring all %d checkpoints", len(unscored))
} else {
scored, err := GetScoredLabels(influx)
if err != nil {
log.Printf("InfluxDB query failed: %v", err)
}
log.Printf("Already scored: %d (run_id, label) pairs", len(scored))
unscored = FindUnscored(checkpoints, scored)
log.Printf("Unscored: %d checkpoints", len(unscored))
}
if len(unscored) == 0 {
log.Printf("Nothing to score. Sleeping %ds...", cfg.PollInterval)
if cfg.OneShot {
return
}
time.Sleep(time.Duration(cfg.PollInterval) * time.Second)
continue
}
targets := unscored
if !cfg.Force {
targets = unscored[:1]
}
for i, target := range targets {
log.Printf("Grabbed: %s (%s) [%d/%d]", target.Label, target.Dirname, i+1, len(targets))
if cfg.DryRun {
log.Printf("[DRY RUN] Would process: %s/%s", target.Dirname, target.Filename)
continue
}
if err := ProcessOne(cfg, influx, target); err != nil {
log.Printf("Error processing %s: %v", target.Label, err)
}
time.Sleep(InterCheckpointDelay)
}
if cfg.DryRun || cfg.OneShot {
return
}
}
}
// DiscoverCheckpoints lists all adapter directories and checkpoint files on M3 via SSH.
func DiscoverCheckpoints(cfg *AgentConfig) ([]Checkpoint, error) {
var checkpoints []Checkpoint
for cp, err := range DiscoverCheckpointsIter(cfg) {
if err != nil {
return nil, err
}
checkpoints = append(checkpoints, cp)
}
return checkpoints, nil
}
// DiscoverCheckpointsIter returns an iterator over discovered adapter checkpoints.
func DiscoverCheckpointsIter(cfg *AgentConfig) iter.Seq2[Checkpoint, error] {
return func(yield func(Checkpoint, error) bool) {
pattern := "adapters-*"
if cfg.Filter != "" {
pattern = "adapters-" + cfg.Filter + "*"
}
t := cfg.transport()
ctx := context.Background()
out, err := t.Run(ctx, fmt.Sprintf("ls -d %s/%s 2>/dev/null", cfg.M3AdapterBase, pattern))
if err != nil {
yield(Checkpoint{}, fmt.Errorf("list adapter dirs: %w", err))
return
}
iterRe := regexp.MustCompile(`(\d+)`)
var adapterDirs []string
for dirpath := range strings.SplitSeq(strings.TrimSpace(out), "\n") {
if dirpath == "" {
continue
}
subOut, subErr := t.Run(ctx, fmt.Sprintf("ls -d %s/gemma-3-* 2>/dev/null", dirpath))
if subErr == nil && strings.TrimSpace(subOut) != "" {
for sub := range strings.SplitSeq(strings.TrimSpace(subOut), "\n") {
if sub != "" {
adapterDirs = append(adapterDirs, sub)
}
}
} else {
adapterDirs = append(adapterDirs, dirpath)
}
}
for _, dirpath := range adapterDirs {
dirname := strings.TrimPrefix(dirpath, cfg.M3AdapterBase+"/")
filesOut, err := t.Run(ctx, fmt.Sprintf("ls %s/*_adapters.safetensors 2>/dev/null", dirpath))
if err != nil {
continue
}
for fp := range strings.SplitSeq(strings.TrimSpace(filesOut), "\n") {
if fp == "" {
continue
}
filename := fileBase(fp)
match := iterRe.FindStringSubmatch(filename)
if len(match) < 2 {
continue
}
iteration := 0
fmt.Sscanf(match[1], "%d", &iteration)
modelTag, labelPrefix, stem := AdapterMeta(dirname)
label := fmt.Sprintf("%s @%s", labelPrefix, match[1])
runID := fmt.Sprintf("%s-capability-auto", stem)
if !yield(Checkpoint{
RemoteDir: dirpath,
Filename: filename,
Dirname: dirname,
Iteration: iteration,
ModelTag: modelTag,
Label: label,
RunID: runID,
}, nil) {
return
}
}
}
}
}
// GetScoredLabels returns all (run_id, label) pairs already scored in InfluxDB.
func GetScoredLabels(influx *InfluxClient) (map[[2]string]bool, error) {
rows, err := influx.QuerySQL("SELECT DISTINCT run_id, label FROM " + MeasurementCapabilityScore)
if err != nil {
return nil, err
}
scored := make(map[[2]string]bool)
for _, row := range rows {
runID, _ := row["run_id"].(string)
label, _ := row["label"].(string)
if runID != "" && label != "" {
scored[[2]string{runID, label}] = true
}
}
return scored, nil
}
// FindUnscored filters checkpoints to only unscored ones, sorted by (dirname, iteration).
func FindUnscored(checkpoints []Checkpoint, scored map[[2]string]bool) []Checkpoint {
var unscored []Checkpoint
for c := range FindUnscoredIter(checkpoints, scored) {
unscored = append(unscored, c)
}
slices.SortFunc(unscored, func(a, b Checkpoint) int {
if a.Dirname != b.Dirname {
return strings.Compare(a.Dirname, b.Dirname)
}
return a.Iteration - b.Iteration
})
return unscored
}
// FindUnscoredIter returns an iterator over checkpoints that have not yet been scored.
func FindUnscoredIter(checkpoints []Checkpoint, scored map[[2]string]bool) iter.Seq[Checkpoint] {
return func(yield func(Checkpoint) bool) {
for _, c := range checkpoints {
if !scored[[2]string{c.RunID, c.Label}] {
if !yield(c) {
return
}
}
}
}
}
// isMLXNative returns true if this model can be served directly on M3 via
// mlx_lm.server with --adapter, avoiding the MLX→PEFT conversion step.
func isMLXNative(modelTag string) bool {
return strings.HasPrefix(modelTag, "gemma-3-") || strings.HasPrefix(modelTag, "gpt-oss")
}
// ProcessOne fetches, converts, scores, and pushes one checkpoint.
func ProcessOne(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint) error {
log.Println(strings.Repeat("=", LogSeparatorWidth))
log.Printf("Processing: %s / %s [%s]", cp.Dirname, cp.Filename, cp.ModelTag)
log.Println(strings.Repeat("=", LogSeparatorWidth))
if isMLXNative(cp.ModelTag) {
return processMLXNative(cfg, influx, cp)
}
return processWithConversion(cfg, influx, cp)
}