From e6ada25bd8b4977059e6a1ab572c10e333b7486a Mon Sep 17 00:00:00 2001
From: Claude
Date: Mon, 16 Feb 2026 02:27:10 +0000
Subject: [PATCH] fix: add Metal cache management to prevent memory growth

- Add ClearCache() wrapping mlx_clear_cache
- Clear Metal allocator cache every 8 tokens during generation
- Set 16GB cache limit on backend init
- Prevents GPU memory from growing unbounded during inference

Co-Authored-By: Claude Opus 4.6
---
 pkg/ml/backend_mlx.go | 17 ++++++++++++++++-
 pkg/mlx/stream.go     |  5 +++++
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/pkg/ml/backend_mlx.go b/pkg/ml/backend_mlx.go
index f26c89c8..de8d5c2c 100644
--- a/pkg/ml/backend_mlx.go
+++ b/pkg/ml/backend_mlx.go
@@ -37,6 +37,9 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) {
 		return nil, fmt.Errorf("mlx: load model: %w", err)
 	}
 
+	// Set Metal cache limit to prevent unbounded memory growth
+	mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB
+
 	slog.Info("mlx: model loaded",
 		"layers", m.NumLayers(),
 		"memory_mb", mlx.GetActiveMemory()/1024/1024,
@@ -82,12 +85,12 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
 	for i := 0; i < maxTokens; i++ {
 		select {
 		case <-ctx.Done():
+			mlx.ClearCache()
 			return b.tok.Decode(output), ctx.Err()
 		default:
 		}
 
 		logits := b.model.Forward(input, b.caches)
-		// Take last position: [B, L, V] → [B, V]
 		logits = lastPosition(logits)
 		next := sampler.Sample(logits)
 		mlx.Materialize(next)
@@ -98,8 +101,14 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
 		}
 		output = append(output, nextToken)
 		input = mlx.FromValues([]int32{nextToken}, 1, 1)
+
+		// Periodically release Metal allocator cache to prevent memory growth
+		if i%8 == 7 {
+			mlx.ClearCache()
+		}
 	}
 
+	mlx.ClearCache()
 	return b.tok.Decode(output), nil
 }
 
@@ -158,6 +167,7 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
 	for i := 0; i < maxTokens; i++ {
 		select {
 		case <-ctx.Done():
+			mlx.ClearCache()
 			return b.tok.Decode(output), ctx.Err()
 		default:
 		}
@@ -173,8 +183,13 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
 		}
 		output = append(output, nextToken)
 		input = mlx.FromValues([]int32{nextToken}, 1, 1)
+
+		if i%8 == 7 {
+			mlx.ClearCache()
+		}
 	}
 
+	mlx.ClearCache()
 	return b.tok.Decode(output), nil
 }
 
diff --git a/pkg/mlx/stream.go b/pkg/mlx/stream.go
index 40a80f82..261ea936 100644
--- a/pkg/mlx/stream.go
+++ b/pkg/mlx/stream.go
@@ -72,3 +72,8 @@ func GetPeakMemory() uint64 {
 	C.mlx_get_peak_memory(&mem)
 	return uint64(mem)
 }
+
+// ClearCache releases Metal memory held in the MLX allocator cache.
+func ClearCache() {
+	C.mlx_clear_cache()
+}