From 55519b24aa2464f0afbb1805696985ced394136d Mon Sep 17 00:00:00 2001
From: Snider
Date: Sun, 22 Feb 2026 18:02:16 +0000
Subject: [PATCH] feat(distill): migrate from go-inference to go-ml Backend

Replace inference.LoadModel() with ml.NewMLXBackend(), which wraps the
same Metal model and adds memory management (SetCacheLimit,
SetMemoryLimit). Replace the raw iter.Seq token loop with
backend.Chat(), which returns Result{Text, Metrics}. Add runtime.GC()
between probes to prevent an incremental memory leak.

Reference: go-ml/cmd/cmd_ab.go memory management pattern.

Co-Authored-By: Claude Opus 4.6
---
 pkg/lem/distill.go | 63 +++++++++++++++++++++++++++++-----------------
 1 file changed, 40 insertions(+), 23 deletions(-)

diff --git a/pkg/lem/distill.go b/pkg/lem/distill.go
index 04bd711..b791f11 100644
--- a/pkg/lem/distill.go
+++ b/pkg/lem/distill.go
@@ -8,11 +8,13 @@ import (
 	"log"
 	"os"
 	"path/filepath"
+	"runtime"
 	"strings"
 	"time"
 
 	"forge.lthn.ai/core/go-i18n/reversal"
-	"forge.lthn.ai/core/go-inference"
+	ml "forge.lthn.ai/core/go-ml"
+	"forge.lthn.ai/core/go-mlx"
 )
 
 // DistillProbe is a single input prompt for distillation.
@@ -146,16 +148,34 @@ func RunDistill(args []string) {
 		return
 	}
 
-	// Load model via go-inference.
+	// Set Metal memory limits before loading model.
+	if cacheLimitGB > 0 {
+		mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
+		log.Printf("metal cache limit: %dGB", cacheLimitGB)
+	}
+	if memLimitGB > 0 {
+		mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
+		log.Printf("metal memory limit: %dGB", memLimitGB)
+	}
+
+	// Load model via go-ml Backend (wraps go-inference with memory management).
 	log.Printf("loading model: %s", modelCfg.Paths.Base)
-	model, err := inference.LoadModel(modelCfg.Paths.Base)
+	backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
 	if err != nil {
 		log.Fatalf("load model: %v", err)
 	}
-	defer model.Close()
+	defer backend.Close()
 
-	info := model.Info()
-	log.Printf("model loaded: %s %d-layer", info.Architecture, info.NumLayers)
+	log.Printf("model loaded via %s backend", backend.Name())
+
+	// Build generation options from merged config.
+	genOpts := ml.GenOpts{
+		MaxTokens:     genCfg.MaxTokens,
+		Temperature:   genCfg.Temperature,
+		TopP:          genCfg.TopP,
+		TopK:          genCfg.TopK,
+		RepeatPenalty: genCfg.RepeatPenalty,
+	}
 
 	// Initialise grammar scorer.
 	tok := reversal.NewTokeniser()
@@ -188,27 +208,18 @@ func RunDistill(args []string) {
 
 		// Inference uses bare probe — the model generates from its weights.
 		// Sandwich wrapping is only for the training output format.
-		messages := []inference.Message{
+		messages := []ml.Message{
 			{Role: "user", Content: probe.Prompt},
 		}
 
-		// Generate via native Metal inference.
+		// Generate via go-ml Backend (memory-managed Metal inference).
 		start := time.Now()
-		var sb strings.Builder
-		for token := range model.Chat(ctx, messages,
-			inference.WithMaxTokens(genCfg.MaxTokens),
-			inference.WithTemperature(float32(genCfg.Temperature)),
-			inference.WithTopP(float32(genCfg.TopP)),
-			inference.WithTopK(genCfg.TopK),
-			inference.WithRepeatPenalty(float32(genCfg.RepeatPenalty)),
-		) {
-			sb.WriteString(token.Text)
-		}
-		if err := model.Err(); err != nil {
+		result, err := backend.Chat(ctx, messages, genOpts)
+		if err != nil {
 			fmt.Fprintf(os.Stderr, "  → ERROR: %v\n", err)
 			continue
 		}
-		response := sb.String()
+		response := result.Text
 		elapsed := time.Since(start)
 
 		// Quick reject: empty/degenerate.
@@ -222,11 +233,14 @@ func RunDistill(args []string) {
 
 		delta := ComputeDelta(tok, probe.Prompt, response,
 			aiCfg.Scorer.SycophancyEcho, aiCfg.Scorer.SycophancyUplift)
 
-		met := model.Metrics()
+		tokPerSec := 0.0
+		if result.Metrics != nil {
+			tokPerSec = result.Metrics.DecodeTokensPerSec
+		}
 		fmt.Fprintf(os.Stderr, "  → %d chars, g=%.1f up=%+.1f echo=%.2f enr=%+.1f, %.1fs (%.0f tok/s)\n",
 			len(response), grammar.Composite, delta.Uplift, delta.Echo, delta.Enrichment,
-			elapsed.Seconds(), met.DecodeTokensPerSec)
+			elapsed.Seconds(), tokPerSec)
 
 		candidate := &distillCandidate{
 			Response: response,
@@ -266,12 +280,15 @@ func RunDistill(args []string) {
 			fmt.Fprintf(os.Stderr, "  ✗ SKIP %s (best g=%.1f < %.1f)\n",
 				probe.ID, score, *minScore)
 		}
+
+		// Release GPU memory between probes to prevent incremental leak.
+		runtime.GC()
 	}
 
 	duration := time.Since(totalStart)
 	fmt.Fprintf(os.Stderr, "\n=== Distillation Complete ===\n")
-	fmt.Fprintf(os.Stderr, "Model: %s (%s)\n", modelCfg.Name, info.Architecture)
+	fmt.Fprintf(os.Stderr, "Model: %s (%s)\n", modelCfg.Name, backend.Name())
 	fmt.Fprintf(os.Stderr, "Probes: %d\n", len(probes))
 	fmt.Fprintf(os.Stderr, "Runs: %d per probe (%d total generations)\n",
 		*runs, len(probes)**runs)
 	fmt.Fprintf(os.Stderr, "Scorer: go-i18n/reversal grammar v3, gate >= %.1f\n", *minScore)
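
---

Reviewer note, not part of the patch: a minimal standalone sketch of the
go-ml call pattern this change adopts, for anyone reviewing who has not
used the Backend API. It assumes only the API shape visible in the diff
above (NewMLXBackend, GenOpts, Chat returning Result{Text, Metrics});
the model path and the 8/16 GB limits are placeholder values, not
defaults from this patch.

package main

import (
	"context"
	"fmt"
	"log"

	ml "forge.lthn.ai/core/go-ml"
	"forge.lthn.ai/core/go-mlx"
)

func main() {
	// Cap Metal cache and total memory before the model loads,
	// mirroring the patch's SetCacheLimit/SetMemoryLimit ordering.
	mlx.SetCacheLimit(8 * 1024 * 1024 * 1024)
	mlx.SetMemoryLimit(16 * 1024 * 1024 * 1024)

	// "/models/base" is a placeholder path.
	backend, err := ml.NewMLXBackend("/models/base")
	if err != nil {
		log.Fatalf("load model: %v", err)
	}
	defer backend.Close()

	// One-shot chat: no token loop, no trailing model.Err() check.
	result, err := backend.Chat(context.Background(),
		[]ml.Message{{Role: "user", Content: "Hello"}},
		ml.GenOpts{MaxTokens: 128, Temperature: 0.7},
	)
	if err != nil {
		log.Fatalf("chat: %v", err)
	}
	fmt.Println(result.Text)

	// Metrics may be nil, as the patch's tok/s handling assumes.
	if result.Metrics != nil {
		fmt.Printf("%.0f tok/s\n", result.Metrics.DecodeTokensPerSec)
	}
}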