fix: add GC-based memory management for MLX array handles

Go GC cannot see Metal/C memory pressure, so intermediate arrays from
each forward pass accumulated without bound, causing OOM kills after
3-4 requests. Fix: runtime.SetFinalizer on every Array releases C
handles when GC collects them, and runtime.GC() is forced every 4
tokens during generation. Also adds SetMemoryLimit(24GB) as a hard
Metal ceiling.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude 2026-02-16 02:42:28 +00:00 committed by Snider
parent e6ada25bd8
commit 1d4ec55d05
2 changed files with 33 additions and 7 deletions

View file

@ -6,6 +6,7 @@ import (
"context" "context"
"fmt" "fmt"
"log/slog" "log/slog"
"runtime"
"sync" "sync"
"forge.lthn.ai/core/cli/pkg/mlx" "forge.lthn.ai/core/cli/pkg/mlx"
@ -37,8 +38,10 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) {
return nil, fmt.Errorf("mlx: load model: %w", err) return nil, fmt.Errorf("mlx: load model: %w", err)
} }
// Set Metal cache limit to prevent unbounded memory growth // Cap Metal memory: cache limit for allocator reuse, memory limit as hard ceiling.
mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB // This prevents runaway memory growth from killing the system.
mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB allocator cache
mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap
slog.Info("mlx: model loaded", slog.Info("mlx: model loaded",
"layers", m.NumLayers(), "layers", m.NumLayers(),
@ -80,11 +83,13 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
maxTokens = 2048 maxTokens = 2048
} }
// Generation loop // Generation loop — force Go GC every 4 tokens so finalizers release
// intermediate C array handles that Go GC cannot see as memory pressure.
var output []int32 var output []int32
for i := 0; i < maxTokens; i++ { for i := 0; i < maxTokens; i++ {
select { select {
case <-ctx.Done(): case <-ctx.Done():
runtime.GC()
mlx.ClearCache() mlx.ClearCache()
return b.tok.Decode(output), ctx.Err() return b.tok.Decode(output), ctx.Err()
default: default:
@ -102,12 +107,15 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
output = append(output, nextToken) output = append(output, nextToken)
input = mlx.FromValues([]int32{nextToken}, 1, 1) input = mlx.FromValues([]int32{nextToken}, 1, 1)
// Periodically release Metal allocator cache to prevent memory growth // Force GC to collect intermediate arrays + release Metal allocator cache
if i%8 == 7 { if i%4 == 3 {
runtime.GC()
mlx.ClearCache() mlx.ClearCache()
} }
} }
// Full cleanup between requests
runtime.GC()
mlx.ClearCache() mlx.ClearCache()
return b.tok.Decode(output), nil return b.tok.Decode(output), nil
} }
@ -167,6 +175,7 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
for i := 0; i < maxTokens; i++ { for i := 0; i < maxTokens; i++ {
select { select {
case <-ctx.Done(): case <-ctx.Done():
runtime.GC()
mlx.ClearCache() mlx.ClearCache()
return b.tok.Decode(output), ctx.Err() return b.tok.Decode(output), ctx.Err()
default: default:
@ -184,11 +193,15 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
output = append(output, nextToken) output = append(output, nextToken)
input = mlx.FromValues([]int32{nextToken}, 1, 1) input = mlx.FromValues([]int32{nextToken}, 1, 1)
if i%8 == 7 { // Force GC to collect intermediate arrays + release Metal allocator cache
if i%4 == 3 {
runtime.GC()
mlx.ClearCache() mlx.ClearCache()
} }
} }
// Full cleanup between requests
runtime.GC()
mlx.ClearCache() mlx.ClearCache()
return b.tok.Decode(output), nil return b.tok.Decode(output), nil
} }

View file

@ -11,6 +11,7 @@ import "C"
import ( import (
"encoding/binary" "encoding/binary"
"reflect" "reflect"
"runtime"
"strings" "strings"
"unsafe" "unsafe"
) )
@ -28,6 +29,9 @@ type Array struct {
} }
// New creates a named Array tracking its input dependencies for cleanup. // New creates a named Array tracking its input dependencies for cleanup.
// A runtime finalizer is set so Go GC can release the C handle when
// the Array becomes unreachable — critical because Go GC cannot see
// Metal/C memory pressure.
func New(name string, inputs ...*Array) *Array { func New(name string, inputs ...*Array) *Array {
t := &Array{ t := &Array{
desc: tensorDesc{ desc: tensorDesc{
@ -40,9 +44,18 @@ func New(name string, inputs ...*Array) *Array {
input.desc.numRefs++ input.desc.numRefs++
} }
} }
runtime.SetFinalizer(t, finalizeArray)
return t return t
} }
// finalizeArray frees the C array handle owned by t, if it still has one.
// New registers it with runtime.SetFinalizer, so the Go garbage collector
// invokes it once the Array is no longer reachable. The handle pointer is
// cleared after the free, which makes any later call on the same Array a
// no-op.
func finalizeArray(t *Array) {
	// Nothing to release: no Array, or its handle is already cleared.
	if t == nil || t.ctx.ctx == nil {
		return
	}
	C.mlx_array_free(t.ctx)
	t.ctx.ctx = nil
}
type scalarTypes interface { type scalarTypes interface {
~bool | ~int | ~float32 | ~float64 | ~complex64 ~bool | ~int | ~float32 | ~float64 | ~complex64
} }
@ -50,7 +63,7 @@ type scalarTypes interface {
// FromValue creates a scalar Array from a Go value. // FromValue creates a scalar Array from a Go value.
func FromValue[T scalarTypes](t T) *Array { func FromValue[T scalarTypes](t T) *Array {
Init() Init()
tt := New("") tt := New("") // finalizer set by New
switch v := any(t).(type) { switch v := any(t).(type) {
case bool: case bool:
tt.ctx = C.mlx_array_new_bool(C.bool(v)) tt.ctx = C.mlx_array_new_bool(C.bool(v))