From 8cdafc8d666a647b16c695195688bd09ab0eb1ae Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 16 Feb 2026 02:42:28 +0000 Subject: [PATCH] fix: add GC-based memory management for MLX array handles Go GC cannot see Metal/C memory pressure, so intermediate arrays from each forward pass accumulated without bound, causing OOM kills after 3-4 requests. Fix: runtime.SetFinalizer on every Array releases C handles when GC collects them, and runtime.GC() is forced every 4 tokens during generation. Also adds SetMemoryLimit(24GB) as a hard Metal ceiling. Co-Authored-By: Claude Opus 4.6 --- pkg/ml/backend_mlx.go | 25 +++++++++++++++++++------ pkg/mlx/array.go | 15 ++++++++++++++- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/pkg/ml/backend_mlx.go b/pkg/ml/backend_mlx.go index de8d5c2..f4af0d1 100644 --- a/pkg/ml/backend_mlx.go +++ b/pkg/ml/backend_mlx.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "log/slog" + "runtime" "sync" "forge.lthn.ai/core/cli/pkg/mlx" @@ -37,8 +38,10 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) { return nil, fmt.Errorf("mlx: load model: %w", err) } - // Set Metal cache limit to prevent unbounded memory growth - mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB + // Cap Metal memory: cache limit for allocator reuse, memory limit as hard ceiling. + // This prevents runaway memory growth from killing the system. + mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB allocator cache + mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap slog.Info("mlx: model loaded", "layers", m.NumLayers(), @@ -80,11 +83,13 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) maxTokens = 2048 } - // Generation loop + // Generation loop — force Go GC every 4 tokens so finalizers release + // intermediate C array handles that Go GC cannot see as memory pressure. var output []int32 for i := 0; i < maxTokens; i++ { select { case <-ctx.Done(): + runtime.GC() mlx.ClearCache() return b.tok.Decode(output), ctx.Err() default: @@ -102,12 +107,15 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) output = append(output, nextToken) input = mlx.FromValues([]int32{nextToken}, 1, 1) - // Periodically release Metal allocator cache to prevent memory growth - if i%8 == 7 { + // Force GC to collect intermediate arrays + release Metal allocator cache + if i%4 == 3 { + runtime.GC() mlx.ClearCache() } } + // Full cleanup between requests + runtime.GC() mlx.ClearCache() return b.tok.Decode(output), nil } @@ -167,6 +175,7 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) for i := 0; i < maxTokens; i++ { select { case <-ctx.Done(): + runtime.GC() mlx.ClearCache() return b.tok.Decode(output), ctx.Err() default: @@ -184,11 +193,15 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) output = append(output, nextToken) input = mlx.FromValues([]int32{nextToken}, 1, 1) - if i%8 == 7 { + // Force GC to collect intermediate arrays + release Metal allocator cache + if i%4 == 3 { + runtime.GC() mlx.ClearCache() } } + // Full cleanup between requests + runtime.GC() mlx.ClearCache() return b.tok.Decode(output), nil } diff --git a/pkg/mlx/array.go b/pkg/mlx/array.go index 091dab8..2ea18a7 100644 --- a/pkg/mlx/array.go +++ b/pkg/mlx/array.go @@ -11,6 +11,7 @@ import "C" import ( "encoding/binary" "reflect" + "runtime" "strings" "unsafe" ) @@ -28,6 +29,9 @@ type Array struct { } // New creates a named Array tracking its input dependencies for cleanup. +// A runtime finalizer is set so Go GC can release the C handle when +// the Array becomes unreachable — critical because Go GC cannot see +// Metal/C memory pressure. func New(name string, inputs ...*Array) *Array { t := &Array{ desc: tensorDesc{ @@ -40,9 +44,18 @@ func New(name string, inputs ...*Array) *Array { input.desc.numRefs++ } } + runtime.SetFinalizer(t, finalizeArray) return t } +// finalizeArray is called by Go GC to release the underlying C array handle. +func finalizeArray(t *Array) { + if t != nil && t.ctx.ctx != nil { + C.mlx_array_free(t.ctx) + t.ctx.ctx = nil + } +} + type scalarTypes interface { ~bool | ~int | ~float32 | ~float64 | ~complex64 } @@ -50,7 +63,7 @@ type scalarTypes interface { // FromValue creates a scalar Array from a Go value. func FromValue[T scalarTypes](t T) *Array { Init() - tt := New("") + tt := New("") // finalizer set by New switch v := any(t).(type) { case bool: tt.ctx = C.mlx_array_new_bool(C.bool(v))