diff --git a/pkg/ml/backend_mlx.go b/pkg/ml/backend_mlx.go index f4af0d1..7ef9f64 100644 --- a/pkg/ml/backend_mlx.go +++ b/pkg/ml/backend_mlx.go @@ -18,11 +18,12 @@ import ( // MLXBackend implements Backend for native Metal inference via mlx-c. type MLXBackend struct { - model *model.GemmaModel - tok *tokenizer.Tokenizer - caches []cache.Cache - sampler sample.Sampler - mu sync.Mutex + model *model.GemmaModel + tok *tokenizer.Tokenizer + caches []cache.Cache + sampler sample.Sampler + mu sync.Mutex + modelBytes uint64 // model size at load time, for memory budget } // NewMLXBackend loads a model from a safetensors directory and creates @@ -43,16 +44,18 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) { mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB allocator cache mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap + modelMB := mlx.GetActiveMemory() / 1024 / 1024 slog.Info("mlx: model loaded", "layers", m.NumLayers(), - "memory_mb", mlx.GetActiveMemory()/1024/1024, + "memory_mb", modelMB, ) return &MLXBackend{ - model: m, - tok: m.Tokenizer(), - caches: m.NewCache(), - sampler: sample.New(0.1, 0, 0, 0), // default low temp + model: m, + tok: m.Tokenizer(), + caches: m.NewCache(), + sampler: sample.New(0.1, 0, 0, 0), // default low temp + modelBytes: mlx.GetActiveMemory(), }, nil } @@ -114,9 +117,10 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) } } - // Full cleanup between requests + // Cleanup between requests runtime.GC() mlx.ClearCache() + b.checkMemory() return b.tok.Decode(output), nil } @@ -200,12 +204,29 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) } } - // Full cleanup between requests + // Cleanup between requests runtime.GC() mlx.ClearCache() + b.checkMemory() return b.tok.Decode(output), nil } +// checkMemory logs Metal memory usage and forces cleanup if it exceeds budget. +func (b *MLXBackend) checkMemory() { + active := mlx.GetActiveMemory() + budget := b.modelBytes * 3 // 3× model size = danger zone + if active > budget { + slog.Warn("mlx: memory over budget, forcing cleanup", + "active_mb", active/1024/1024, + "model_mb", b.modelBytes/1024/1024, + "peak_mb", mlx.GetPeakMemory()/1024/1024, + ) + runtime.GC() + runtime.GC() // double GC to run finalizers + mlx.ClearCache() + } +} + // Name returns the backend identifier. func (b *MLXBackend) Name() string { return "mlx" }