From c5689c3e8309bb52003653beb22ddf3727bf1aaa Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 16 Feb 2026 02:49:35 +0000 Subject: [PATCH] fix: remove Go-side array ref tracking, rely on MLX-C refcounting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Go wrapper was tracking inter-array references via desc.inputs, creating chains that kept all intermediate arrays alive across requests. After 3-4 requests, Metal memory grew to 170GB+ and macOS killed the process. Fix: remove desc.inputs/numRefs entirely. MLX-C has its own internal reference counting — when Go GC finalizes an Array wrapper, it calls mlx_array_free which decrements the C-side refcount. If the C-side count reaches 0, Metal memory is freed. Go GC + MLX-C refcounting together handle all lifecycle management correctly. Co-Authored-By: Claude Opus 4.6 --- pkg/mlx/array.go | 73 +++++++++++++----------------------------------- pkg/mlx/io.go | 5 +++- 2 files changed, 24 insertions(+), 54 deletions(-) diff --git a/pkg/mlx/array.go b/pkg/mlx/array.go index 2ea18a7b..6d36df23 100644 --- a/pkg/mlx/array.go +++ b/pkg/mlx/array.go @@ -16,34 +16,20 @@ import ( "unsafe" ) -type tensorDesc struct { - name string - inputs []*Array - numRefs int -} - -// Array wraps an mlx_array handle with reference-counted memory management. +// Array wraps an mlx_array handle. +// Memory management relies on Go GC finalizers to call mlx_array_free, +// which decrements MLX-C's internal reference count. MLX-C handles all +// cross-array references internally — the Go wrapper does not track them. type Array struct { ctx C.mlx_array - desc tensorDesc + name string // debug label } -// New creates a named Array tracking its input dependencies for cleanup. -// A runtime finalizer is set so Go GC can release the C handle when -// the Array becomes unreachable — critical because Go GC cannot see -// Metal/C memory pressure. +// New creates a named Array and registers a GC finalizer. +// The inputs parameter is accepted for API compatibility but not stored — +// MLX-C tracks inter-array references via its own refcounting. func New(name string, inputs ...*Array) *Array { - t := &Array{ - desc: tensorDesc{ - name: name, - inputs: inputs, - }, - } - for _, input := range inputs { - if input != nil { - input.desc.numRefs++ - } - } + t := &Array{name: name} runtime.SetFinalizer(t, finalizeArray) return t } @@ -63,7 +49,7 @@ type scalarTypes interface { // FromValue creates a scalar Array from a Go value. func FromValue[T scalarTypes](t T) *Array { Init() - tt := New("") // finalizer set by New + tt := New("") switch v := any(t).(type) { case bool: tt.ctx = C.mlx_array_new_bool(C.bool(v)) @@ -152,17 +138,14 @@ func Zeros(shape []int32, dtype DType) *Array { return tt } -// Set replaces this array's value with another, updating ref tracking. +// Set replaces this array's C handle with another's. func (t *Array) Set(other *Array) { - Free(t.desc.inputs...) - other.desc.numRefs++ - t.desc.inputs = []*Array{other} C.mlx_array_set(&t.ctx, other.ctx) } -// Clone creates a copy of this array sharing the same data. +// Clone creates a new Go wrapper sharing the same C handle (increments C refcount). func (t *Array) Clone() *Array { - tt := New(t.desc.name, t.desc.inputs...) + tt := New(t.name) C.mlx_array_set(&tt.ctx, t.ctx) return tt } @@ -254,33 +237,17 @@ func (t Array) Floats() []float32 { return floats } -// Free releases arrays using reference-counted cleanup. -// Arrays with remaining references are not freed. +// Free explicitly releases C array handles. Does not cascade — MLX-C's +// internal refcounting handles dependent arrays automatically. func Free(s ...*Array) int { var n int - free := make([]*Array, 0, 64) - - fn := func(t *Array) { + for _, t := range s { if t != nil && t.Valid() { - t.desc.numRefs-- - if t.desc.numRefs <= 0 { - free = append(free, t.desc.inputs...) - n += t.NumBytes() - C.mlx_array_free(t.ctx) - t.ctx.ctx = nil - } + n += t.NumBytes() + C.mlx_array_free(t.ctx) + t.ctx.ctx = nil + runtime.SetFinalizer(t, nil) // cancel finalizer } } - - for _, t := range s { - fn(t) - } - - for len(free) > 0 { - tail := free[len(free)-1] - free = free[:len(free)-1] - fn(tail) - } - return n } diff --git a/pkg/mlx/io.go b/pkg/mlx/io.go index e4aa363c..c7247b2f 100644 --- a/pkg/mlx/io.go +++ b/pkg/mlx/io.go @@ -10,6 +10,7 @@ import "C" import ( "iter" + "runtime" "unsafe" ) @@ -43,7 +44,9 @@ func LoadSafetensors(path string) iter.Seq2[string, *Array] { } name := C.GoString(key) - if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) { + arr := &Array{ctx: value, name: name} + runtime.SetFinalizer(arr, finalizeArray) + if !yield(name, arr) { break } }