fix: add GC-based memory management for MLX array handles

Go GC cannot see Metal/C memory pressure, so intermediate arrays from
each forward pass accumulated without bound, causing OOM kills after
3-4 requests. Fix: runtime.SetFinalizer on every Array releases C
handles when GC collects them, and runtime.GC() is forced every 4
tokens during generation. Also adds SetMemoryLimit(24GB) as a hard
Metal ceiling.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude 2026-02-16 02:42:28 +00:00
parent 478bbdd44c
commit a27a31faad
No known key found for this signature in database
GPG key ID: AF404715446AEB41
2 changed files with 33 additions and 7 deletions

View file

@ -6,6 +6,7 @@ import (
"context"
"fmt"
"log/slog"
"runtime"
"sync"
"forge.lthn.ai/core/cli/pkg/mlx"
@ -37,8 +38,10 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) {
return nil, fmt.Errorf("mlx: load model: %w", err)
}
// Set Metal cache limit to prevent unbounded memory growth
mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB
// Cap Metal memory: cache limit for allocator reuse, memory limit as hard ceiling.
// This prevents runaway memory growth from killing the system.
mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB allocator cache
mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap
slog.Info("mlx: model loaded",
"layers", m.NumLayers(),
@ -80,11 +83,13 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
maxTokens = 2048
}
// Generation loop
// Generation loop — force Go GC every 4 tokens so finalizers release
// intermediate C array handles that Go GC cannot see as memory pressure.
var output []int32
for i := 0; i < maxTokens; i++ {
select {
case <-ctx.Done():
runtime.GC()
mlx.ClearCache()
return b.tok.Decode(output), ctx.Err()
default:
@ -102,12 +107,15 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
output = append(output, nextToken)
input = mlx.FromValues([]int32{nextToken}, 1, 1)
// Periodically release Metal allocator cache to prevent memory growth
if i%8 == 7 {
// Force GC to collect intermediate arrays + release Metal allocator cache
if i%4 == 3 {
runtime.GC()
mlx.ClearCache()
}
}
// Full cleanup between requests
runtime.GC()
mlx.ClearCache()
return b.tok.Decode(output), nil
}
@ -167,6 +175,7 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
for i := 0; i < maxTokens; i++ {
select {
case <-ctx.Done():
runtime.GC()
mlx.ClearCache()
return b.tok.Decode(output), ctx.Err()
default:
@ -184,11 +193,15 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
output = append(output, nextToken)
input = mlx.FromValues([]int32{nextToken}, 1, 1)
if i%8 == 7 {
// Force GC to collect intermediate arrays + release Metal allocator cache
if i%4 == 3 {
runtime.GC()
mlx.ClearCache()
}
}
// Full cleanup between requests
runtime.GC()
mlx.ClearCache()
return b.tok.Decode(output), nil
}

View file

@ -11,6 +11,7 @@ import "C"
import (
"encoding/binary"
"reflect"
"runtime"
"strings"
"unsafe"
)
@ -28,6 +29,9 @@ type Array struct {
}
// New creates a named Array tracking its input dependencies for cleanup.
// A runtime finalizer is set so Go GC can release the C handle when
// the Array becomes unreachable — critical because Go GC cannot see
// Metal/C memory pressure.
func New(name string, inputs ...*Array) *Array {
t := &Array{
desc: tensorDesc{
@ -40,9 +44,18 @@ func New(name string, inputs ...*Array) *Array {
input.desc.numRefs++
}
}
runtime.SetFinalizer(t, finalizeArray)
return t
}
// finalizeArray is the finalizer that New registers on every Array. When
// the Go GC invokes it, the underlying C array handle is freed and the
// handle pointer is then cleared, so a later call on the same Array does
// nothing.
func finalizeArray(t *Array) {
	// Nothing to release for a nil Array.
	if t == nil {
		return
	}
	// Handle pointer already nil: there is no C handle to free.
	if t.ctx.ctx == nil {
		return
	}
	C.mlx_array_free(t.ctx)
	t.ctx.ctx = nil
}
// scalarTypes is the type constraint accepted by FromValue. Each term is
// written with ~, so defined types whose underlying type is one of the
// listed kinds also satisfy the constraint.
type scalarTypes interface {
	~bool |
		~int |
		~float32 | ~float64 |
		~complex64
}
@ -50,7 +63,7 @@ type scalarTypes interface {
// FromValue creates a scalar Array from a Go value.
func FromValue[T scalarTypes](t T) *Array {
Init()
tt := New("")
tt := New("") // finalizer set by New
switch v := any(t).(type) {
case bool:
tt.ctx = C.mlx_array_new_bool(C.bool(v))