feat(distill): migrate from go-inference to go-ml Backend

Replace inference.LoadModel() with ml.NewMLXBackend() which wraps
the same Metal model with memory management (SetCacheLimit,
SetMemoryLimit). Replace raw iter.Seq token loop with backend.Chat()
returning Result{Text, Metrics}. Add runtime.GC() between probes
to prevent an incremental memory leak.

Reference: go-ml/cmd/cmd_ab.go memory management pattern.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Snider 2026-02-22 18:02:16 +00:00
parent 8408cc0bab
commit 55519b24aa

View file

@@ -8,11 +8,13 @@ import (
"log"
"os"
"path/filepath"
"runtime"
"strings"
"time"
"forge.lthn.ai/core/go-i18n/reversal"
"forge.lthn.ai/core/go-inference"
ml "forge.lthn.ai/core/go-ml"
"forge.lthn.ai/core/go-mlx"
)
// DistillProbe is a single input prompt for distillation.
@@ -146,16 +148,34 @@ func RunDistill(args []string) {
return
}
// Load model via go-inference.
// Set Metal memory limits before loading model.
if cacheLimitGB > 0 {
mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
log.Printf("metal cache limit: %dGB", cacheLimitGB)
}
if memLimitGB > 0 {
mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
log.Printf("metal memory limit: %dGB", memLimitGB)
}
// Load model via go-ml Backend (wraps go-inference with memory management).
log.Printf("loading model: %s", modelCfg.Paths.Base)
model, err := inference.LoadModel(modelCfg.Paths.Base)
backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
if err != nil {
log.Fatalf("load model: %v", err)
}
defer model.Close()
defer backend.Close()
info := model.Info()
log.Printf("model loaded: %s %d-layer", info.Architecture, info.NumLayers)
log.Printf("model loaded via %s backend", backend.Name())
// Build generation options from merged config.
genOpts := ml.GenOpts{
MaxTokens: genCfg.MaxTokens,
Temperature: genCfg.Temperature,
TopP: genCfg.TopP,
TopK: genCfg.TopK,
RepeatPenalty: genCfg.RepeatPenalty,
}
// Initialise grammar scorer.
tok := reversal.NewTokeniser()
@@ -188,27 +208,18 @@ func RunDistill(args []string) {
// Inference uses bare probe — the model generates from its weights.
// Sandwich wrapping is only for the training output format.
messages := []inference.Message{
messages := []ml.Message{
{Role: "user", Content: probe.Prompt},
}
// Generate via native Metal inference.
// Generate via go-ml Backend (memory-managed Metal inference).
start := time.Now()
var sb strings.Builder
for token := range model.Chat(ctx, messages,
inference.WithMaxTokens(genCfg.MaxTokens),
inference.WithTemperature(float32(genCfg.Temperature)),
inference.WithTopP(float32(genCfg.TopP)),
inference.WithTopK(genCfg.TopK),
inference.WithRepeatPenalty(float32(genCfg.RepeatPenalty)),
) {
sb.WriteString(token.Text)
}
if err := model.Err(); err != nil {
result, err := backend.Chat(ctx, messages, genOpts)
if err != nil {
fmt.Fprintf(os.Stderr, " → ERROR: %v\n", err)
continue
}
response := sb.String()
response := result.Text
elapsed := time.Since(start)
// Quick reject: empty/degenerate.
@@ -222,11 +233,14 @@ func RunDistill(args []string) {
delta := ComputeDelta(tok, probe.Prompt, response,
aiCfg.Scorer.SycophancyEcho, aiCfg.Scorer.SycophancyUplift)
met := model.Metrics()
tokPerSec := 0.0
if result.Metrics != nil {
tokPerSec = result.Metrics.DecodeTokensPerSec
}
fmt.Fprintf(os.Stderr, " → %d chars, g=%.1f up=%+.1f echo=%.2f enr=%+.1f, %.1fs (%.0f tok/s)\n",
len(response), grammar.Composite,
delta.Uplift, delta.Echo, delta.Enrichment,
elapsed.Seconds(), met.DecodeTokensPerSec)
elapsed.Seconds(), tokPerSec)
candidate := &distillCandidate{
Response: response,
@@ -266,12 +280,15 @@ func RunDistill(args []string) {
fmt.Fprintf(os.Stderr, " ✗ SKIP %s (best g=%.1f < %.1f)\n",
probe.ID, score, *minScore)
}
// Release GPU memory between probes to prevent incremental leak.
runtime.GC()
}
duration := time.Since(totalStart)
fmt.Fprintf(os.Stderr, "\n=== Distillation Complete ===\n")
fmt.Fprintf(os.Stderr, "Model: %s (%s)\n", modelCfg.Name, info.Architecture)
fmt.Fprintf(os.Stderr, "Model: %s (%s)\n", modelCfg.Name, backend.Name())
fmt.Fprintf(os.Stderr, "Probes: %d\n", len(probes))
fmt.Fprintf(os.Stderr, "Runs: %d per probe (%d total generations)\n", *runs, len(probes)**runs)
fmt.Fprintf(os.Stderr, "Scorer: go-i18n/reversal grammar v3, gate >= %.1f\n", *minScore)