feat: add Metal memory budget monitoring after each request

Tracks model size at load time and checks Metal active memory after
each generation. If usage exceeds 3× model size, forces double GC
and cache clear as a safety net.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude 2026-02-16 02:52:26 +00:00 committed by Snider
parent c5689c3e83
commit 045f8fc110

View file

@ -18,11 +18,12 @@ import (
// MLXBackend implements Backend for native Metal inference via mlx-c.
type MLXBackend struct {
	model      *model.GemmaModel
	tok        *tokenizer.Tokenizer
	caches     []cache.Cache
	sampler    sample.Sampler
	mu         sync.Mutex
	modelBytes uint64 // model size at load time, for memory budget
}
// NewMLXBackend loads a model from a safetensors directory and creates // NewMLXBackend loads a model from a safetensors directory and creates
@ -43,16 +44,18 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) {
	mlx.SetCacheLimit(16 * 1024 * 1024 * 1024)  // 16 GB allocator cache
	mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap

	modelMB := mlx.GetActiveMemory() / 1024 / 1024
	slog.Info("mlx: model loaded",
		"layers", m.NumLayers(),
		"memory_mb", modelMB,
	)

	return &MLXBackend{
		model:      m,
		tok:        m.Tokenizer(),
		caches:     m.NewCache(),
		sampler:    sample.New(0.1, 0, 0, 0), // default low temp
		modelBytes: mlx.GetActiveMemory(),
	}, nil
}
@ -114,9 +117,10 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
		}
	}

	// Cleanup between requests
	runtime.GC()
	mlx.ClearCache()
	b.checkMemory()

	return b.tok.Decode(output), nil
}
@ -200,12 +204,29 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
		}
	}

	// Cleanup between requests
	runtime.GC()
	mlx.ClearCache()
	b.checkMemory()

	return b.tok.Decode(output), nil
}
// checkMemory logs Metal memory usage and forces cleanup if it exceeds budget.
//
// The budget is 3× the model's load-time footprint (modelBytes, recorded in
// NewMLXBackend). When active Metal memory exceeds it, we log a warning and
// force a double GC plus an allocator cache clear as a safety net against
// unbounded growth between requests. Called after each Generate/Chat request
// (following the normal runtime.GC + mlx.ClearCache pass).
func (b *MLXBackend) checkMemory() {
	// Guard against an unset/zero model size: a 3× budget of zero would
	// make the warn + forced-GC path fire on every single request, turning
	// this safety net into a per-request GC storm.
	if b.modelBytes == 0 {
		return
	}

	active := mlx.GetActiveMemory()
	budget := b.modelBytes * 3 // 3× model size = danger zone
	if active <= budget {
		return
	}

	slog.Warn("mlx: memory over budget, forcing cleanup",
		"active_mb", active/1024/1024,
		"model_mb", b.modelBytes/1024/1024,
		"peak_mb", mlx.GetPeakMemory()/1024/1024,
	)
	// Double GC: the first pass queues finalizers for unreachable mlx
	// arrays, the second reclaims what those finalizers released. Then drop
	// the allocator's cached Metal buffers.
	runtime.GC()
	runtime.GC()
	mlx.ClearCache()
}
// Name returns the backend identifier.
func (b *MLXBackend) Name() string { return "mlx" }