fix: remove Go-side array ref tracking, rely on MLX-C refcounting

The Go wrapper was tracking inter-array references via desc.inputs,
creating chains that kept all intermediate arrays alive across requests.
After 3-4 requests, Metal memory grew to 170GB+ and macOS killed the
process.

Fix: remove desc.inputs/numRefs entirely. MLX-C has its own internal
reference counting — when Go GC finalizes an Array wrapper, it calls
mlx_array_free which decrements the C-side refcount. If the C-side
count reaches 0, Metal memory is freed. Go GC + MLX-C refcounting
together handle all lifecycle management correctly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude 2026-02-16 02:49:35 +00:00 committed by Snider
parent 8cdafc8d66
commit 6b603ee20b
2 changed files with 24 additions and 54 deletions

View file

@ -16,34 +16,20 @@ import (
"unsafe" "unsafe"
) )
type tensorDesc struct { // Array wraps an mlx_array handle.
name string // Memory management relies on Go GC finalizers to call mlx_array_free,
inputs []*Array // which decrements MLX-C's internal reference count. MLX-C handles all
numRefs int // cross-array references internally — the Go wrapper does not track them.
}
// Array wraps an mlx_array handle with reference-counted memory management.
type Array struct { type Array struct {
ctx C.mlx_array ctx C.mlx_array
desc tensorDesc name string // debug label
} }
// New creates a named Array tracking its input dependencies for cleanup. // New creates a named Array and registers a GC finalizer.
// A runtime finalizer is set so Go GC can release the C handle when // The inputs parameter is accepted for API compatibility but not stored —
// the Array becomes unreachable — critical because Go GC cannot see // MLX-C tracks inter-array references via its own refcounting.
// Metal/C memory pressure.
func New(name string, inputs ...*Array) *Array { func New(name string, inputs ...*Array) *Array {
t := &Array{ t := &Array{name: name}
desc: tensorDesc{
name: name,
inputs: inputs,
},
}
for _, input := range inputs {
if input != nil {
input.desc.numRefs++
}
}
runtime.SetFinalizer(t, finalizeArray) runtime.SetFinalizer(t, finalizeArray)
return t return t
} }
@ -63,7 +49,7 @@ type scalarTypes interface {
// FromValue creates a scalar Array from a Go value. // FromValue creates a scalar Array from a Go value.
func FromValue[T scalarTypes](t T) *Array { func FromValue[T scalarTypes](t T) *Array {
Init() Init()
tt := New("") // finalizer set by New tt := New("")
switch v := any(t).(type) { switch v := any(t).(type) {
case bool: case bool:
tt.ctx = C.mlx_array_new_bool(C.bool(v)) tt.ctx = C.mlx_array_new_bool(C.bool(v))
@ -152,17 +138,14 @@ func Zeros(shape []int32, dtype DType) *Array {
return tt return tt
} }
// Set replaces this array's value with another, updating ref tracking. // Set replaces this array's C handle with another's.
func (t *Array) Set(other *Array) { func (t *Array) Set(other *Array) {
Free(t.desc.inputs...)
other.desc.numRefs++
t.desc.inputs = []*Array{other}
C.mlx_array_set(&t.ctx, other.ctx) C.mlx_array_set(&t.ctx, other.ctx)
} }
// Clone creates a copy of this array sharing the same data. // Clone creates a new Go wrapper sharing the same C handle (increments C refcount).
func (t *Array) Clone() *Array { func (t *Array) Clone() *Array {
tt := New(t.desc.name, t.desc.inputs...) tt := New(t.name)
C.mlx_array_set(&tt.ctx, t.ctx) C.mlx_array_set(&tt.ctx, t.ctx)
return tt return tt
} }
@ -254,33 +237,17 @@ func (t Array) Floats() []float32 {
return floats return floats
} }
// Free releases arrays using reference-counted cleanup. // Free explicitly releases C array handles. Does not cascade — MLX-C's
// Arrays with remaining references are not freed. // internal refcounting handles dependent arrays automatically.
func Free(s ...*Array) int { func Free(s ...*Array) int {
var n int var n int
free := make([]*Array, 0, 64) for _, t := range s {
fn := func(t *Array) {
if t != nil && t.Valid() { if t != nil && t.Valid() {
t.desc.numRefs--
if t.desc.numRefs <= 0 {
free = append(free, t.desc.inputs...)
n += t.NumBytes() n += t.NumBytes()
C.mlx_array_free(t.ctx) C.mlx_array_free(t.ctx)
t.ctx.ctx = nil t.ctx.ctx = nil
runtime.SetFinalizer(t, nil) // cancel finalizer
} }
} }
}
for _, t := range s {
fn(t)
}
for len(free) > 0 {
tail := free[len(free)-1]
free = free[:len(free)-1]
fn(tail)
}
return n return n
} }

View file

@ -10,6 +10,7 @@ import "C"
import ( import (
"iter" "iter"
"runtime"
"unsafe" "unsafe"
) )
@ -43,7 +44,9 @@ func LoadSafetensors(path string) iter.Seq2[string, *Array] {
} }
name := C.GoString(key) name := C.GoString(key)
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) { arr := &Array{ctx: value, name: name}
runtime.SetFinalizer(arr, finalizeArray)
if !yield(name, arr) {
break break
} }
} }