feat(distill): migrate from go-inference to go-ml Backend
Replace inference.LoadModel() with ml.NewMLXBackend() which wraps
the same Metal model with memory management (SetCacheLimit,
SetMemoryLimit). Replace raw iter.Seq token loop with backend.Chat()
returning Result{Text, Metrics}. Add runtime.GC() between probes
to prevent incremental memory leak.
Reference: go-ml/cmd/cmd_ab.go memory management pattern.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
8408cc0bab
commit
55519b24aa
1 changed file with 40 additions and 23 deletions
|
|
@@ -8,11 +8,13 @@ import (
|
|||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-i18n/reversal"
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
ml "forge.lthn.ai/core/go-ml"
|
||||
"forge.lthn.ai/core/go-mlx"
|
||||
)
|
||||
|
||||
// DistillProbe is a single input prompt for distillation.
|
||||
|
|
@@ -146,16 +148,34 @@ func RunDistill(args []string) {
|
|||
return
|
||||
}
|
||||
|
||||
// Load model via go-inference.
|
||||
// Set Metal memory limits before loading model.
|
||||
if cacheLimitGB > 0 {
|
||||
mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
|
||||
log.Printf("metal cache limit: %dGB", cacheLimitGB)
|
||||
}
|
||||
if memLimitGB > 0 {
|
||||
mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
|
||||
log.Printf("metal memory limit: %dGB", memLimitGB)
|
||||
}
|
||||
|
||||
// Load model via go-ml Backend (wraps go-inference with memory management).
|
||||
log.Printf("loading model: %s", modelCfg.Paths.Base)
|
||||
model, err := inference.LoadModel(modelCfg.Paths.Base)
|
||||
backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
|
||||
if err != nil {
|
||||
log.Fatalf("load model: %v", err)
|
||||
}
|
||||
defer model.Close()
|
||||
defer backend.Close()
|
||||
|
||||
info := model.Info()
|
||||
log.Printf("model loaded: %s %d-layer", info.Architecture, info.NumLayers)
|
||||
log.Printf("model loaded via %s backend", backend.Name())
|
||||
|
||||
// Build generation options from merged config.
|
||||
genOpts := ml.GenOpts{
|
||||
MaxTokens: genCfg.MaxTokens,
|
||||
Temperature: genCfg.Temperature,
|
||||
TopP: genCfg.TopP,
|
||||
TopK: genCfg.TopK,
|
||||
RepeatPenalty: genCfg.RepeatPenalty,
|
||||
}
|
||||
|
||||
// Initialise grammar scorer.
|
||||
tok := reversal.NewTokeniser()
|
||||
|
|
@@ -188,27 +208,18 @@ func RunDistill(args []string) {
|
|||
|
||||
// Inference uses bare probe — the model generates from its weights.
|
||||
// Sandwich wrapping is only for the training output format.
|
||||
messages := []inference.Message{
|
||||
messages := []ml.Message{
|
||||
{Role: "user", Content: probe.Prompt},
|
||||
}
|
||||
|
||||
// Generate via native Metal inference.
|
||||
// Generate via go-ml Backend (memory-managed Metal inference).
|
||||
start := time.Now()
|
||||
var sb strings.Builder
|
||||
for token := range model.Chat(ctx, messages,
|
||||
inference.WithMaxTokens(genCfg.MaxTokens),
|
||||
inference.WithTemperature(float32(genCfg.Temperature)),
|
||||
inference.WithTopP(float32(genCfg.TopP)),
|
||||
inference.WithTopK(genCfg.TopK),
|
||||
inference.WithRepeatPenalty(float32(genCfg.RepeatPenalty)),
|
||||
) {
|
||||
sb.WriteString(token.Text)
|
||||
}
|
||||
if err := model.Err(); err != nil {
|
||||
result, err := backend.Chat(ctx, messages, genOpts)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, " → ERROR: %v\n", err)
|
||||
continue
|
||||
}
|
||||
response := sb.String()
|
||||
response := result.Text
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Quick reject: empty/degenerate.
|
||||
|
|
@@ -222,11 +233,14 @@ func RunDistill(args []string) {
|
|||
delta := ComputeDelta(tok, probe.Prompt, response,
|
||||
aiCfg.Scorer.SycophancyEcho, aiCfg.Scorer.SycophancyUplift)
|
||||
|
||||
met := model.Metrics()
|
||||
tokPerSec := 0.0
|
||||
if result.Metrics != nil {
|
||||
tokPerSec = result.Metrics.DecodeTokensPerSec
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, " → %d chars, g=%.1f up=%+.1f echo=%.2f enr=%+.1f, %.1fs (%.0f tok/s)\n",
|
||||
len(response), grammar.Composite,
|
||||
delta.Uplift, delta.Echo, delta.Enrichment,
|
||||
elapsed.Seconds(), met.DecodeTokensPerSec)
|
||||
elapsed.Seconds(), tokPerSec)
|
||||
|
||||
candidate := &distillCandidate{
|
||||
Response: response,
|
||||
|
|
@@ -266,12 +280,15 @@ func RunDistill(args []string) {
|
|||
fmt.Fprintf(os.Stderr, " ✗ SKIP %s (best g=%.1f < %.1f)\n",
|
||||
probe.ID, score, *minScore)
|
||||
}
|
||||
|
||||
// Release GPU memory between probes to prevent incremental leak.
|
||||
runtime.GC()
|
||||
}
|
||||
|
||||
duration := time.Since(totalStart)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n=== Distillation Complete ===\n")
|
||||
fmt.Fprintf(os.Stderr, "Model: %s (%s)\n", modelCfg.Name, info.Architecture)
|
||||
fmt.Fprintf(os.Stderr, "Model: %s (%s)\n", modelCfg.Name, backend.Name())
|
||||
fmt.Fprintf(os.Stderr, "Probes: %d\n", len(probes))
|
||||
fmt.Fprintf(os.Stderr, "Runs: %d per probe (%d total generations)\n", *runs, len(probes)**runs)
|
||||
fmt.Fprintf(os.Stderr, "Scorer: go-i18n/reversal grammar v3, gate >= %.1f\n", *minScore)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue