fix: remove Go-side array ref tracking, rely on MLX-C refcounting

The Go wrapper was tracking inter-array references via desc.inputs,
creating chains that kept all intermediate arrays alive across requests.
After 3-4 requests, Metal memory grew to 170GB+ and macOS killed the
process.

Fix: remove desc.inputs/numRefs entirely. MLX-C has its own internal
reference counting — when Go GC finalizes an Array wrapper, it calls
mlx_array_free which decrements the C-side refcount. If the C-side
count reaches 0, Metal memory is freed. Go GC + MLX-C refcounting
together handle all lifecycle management correctly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude 2026-02-16 02:49:35 +00:00 committed by Snider
parent 1d4ec55d05
commit c5689c3e83
2 changed files with 24 additions and 54 deletions

View file

@ -16,34 +16,20 @@ import (
"unsafe"
)
type tensorDesc struct {
name string
inputs []*Array
numRefs int
}
// Array wraps an mlx_array handle with reference-counted memory management.
// Array wraps an mlx_array handle.
// Memory management relies on Go GC finalizers to call mlx_array_free,
// which decrements MLX-C's internal reference count. MLX-C handles all
// cross-array references internally — the Go wrapper does not track them.
type Array struct {
ctx C.mlx_array
desc tensorDesc
name string // debug label
}
// New creates a named Array tracking its input dependencies for cleanup.
// A runtime finalizer is set so Go GC can release the C handle when
// the Array becomes unreachable — critical because Go GC cannot see
// Metal/C memory pressure.
// New creates a named Array and registers a GC finalizer.
// The inputs parameter is accepted for API compatibility but not stored —
// MLX-C tracks inter-array references via its own refcounting.
func New(name string, inputs ...*Array) *Array {
t := &Array{
desc: tensorDesc{
name: name,
inputs: inputs,
},
}
for _, input := range inputs {
if input != nil {
input.desc.numRefs++
}
}
t := &Array{name: name}
runtime.SetFinalizer(t, finalizeArray)
return t
}
@ -63,7 +49,7 @@ type scalarTypes interface {
// FromValue creates a scalar Array from a Go value.
func FromValue[T scalarTypes](t T) *Array {
Init()
tt := New("") // finalizer set by New
tt := New("")
switch v := any(t).(type) {
case bool:
tt.ctx = C.mlx_array_new_bool(C.bool(v))
@ -152,17 +138,14 @@ func Zeros(shape []int32, dtype DType) *Array {
return tt
}
// Set replaces this array's value with another, updating ref tracking.
// Set replaces this array's C handle with another's.
func (t *Array) Set(other *Array) {
Free(t.desc.inputs...)
other.desc.numRefs++
t.desc.inputs = []*Array{other}
C.mlx_array_set(&t.ctx, other.ctx)
}
// Clone creates a copy of this array sharing the same data.
// Clone creates a new Go wrapper sharing the same C handle (increments C refcount).
func (t *Array) Clone() *Array {
tt := New(t.desc.name, t.desc.inputs...)
tt := New(t.name)
C.mlx_array_set(&tt.ctx, t.ctx)
return tt
}
@ -254,33 +237,17 @@ func (t Array) Floats() []float32 {
return floats
}
// Free releases arrays using reference-counted cleanup.
// Arrays with remaining references are not freed.
// Free explicitly releases C array handles. Does not cascade — MLX-C's
// internal refcounting handles dependent arrays automatically.
func Free(s ...*Array) int {
var n int
free := make([]*Array, 0, 64)
fn := func(t *Array) {
for _, t := range s {
if t != nil && t.Valid() {
t.desc.numRefs--
if t.desc.numRefs <= 0 {
free = append(free, t.desc.inputs...)
n += t.NumBytes()
C.mlx_array_free(t.ctx)
t.ctx.ctx = nil
runtime.SetFinalizer(t, nil) // cancel finalizer
}
}
}
for _, t := range s {
fn(t)
}
for len(free) > 0 {
tail := free[len(free)-1]
free = free[:len(free)-1]
fn(tail)
}
return n
}

View file

@ -10,6 +10,7 @@ import "C"
import (
"iter"
"runtime"
"unsafe"
)
@ -43,7 +44,9 @@ func LoadSafetensors(path string) iter.Seq2[string, *Array] {
}
name := C.GoString(key)
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
arr := &Array{ctx: value, name: name}
runtime.SetFinalizer(arr, finalizeArray)
if !yield(name, arr) {
break
}
}