diff --git a/pkg/mlx/array.go b/pkg/mlx/array.go
index 2ea18a7b..6d36df23 100644
--- a/pkg/mlx/array.go
+++ b/pkg/mlx/array.go
@@ -16,34 +16,20 @@ import (
 	"unsafe"
 )
 
-type tensorDesc struct {
-	name    string
-	inputs  []*Array
-	numRefs int
-}
-
-// Array wraps an mlx_array handle with reference-counted memory management.
+// Array wraps an mlx_array handle.
+// Memory management relies on Go GC finalizers to call mlx_array_free,
+// which decrements MLX-C's internal reference count. MLX-C handles all
+// cross-array references internally — the Go wrapper does not track them.
 type Array struct {
 	ctx  C.mlx_array
-	desc tensorDesc
+	name string // debug label
 }
 
-// New creates a named Array tracking its input dependencies for cleanup.
-// A runtime finalizer is set so Go GC can release the C handle when
-// the Array becomes unreachable — critical because Go GC cannot see
-// Metal/C memory pressure.
+// New creates a named Array and registers a GC finalizer.
+// The inputs parameter is accepted for API compatibility but not stored —
+// MLX-C tracks inter-array references via its own refcounting.
 func New(name string, inputs ...*Array) *Array {
-	t := &Array{
-		desc: tensorDesc{
-			name:   name,
-			inputs: inputs,
-		},
-	}
-	for _, input := range inputs {
-		if input != nil {
-			input.desc.numRefs++
-		}
-	}
+	t := &Array{name: name}
 	runtime.SetFinalizer(t, finalizeArray)
 	return t
 }
@@ -63,7 +49,7 @@ type scalarTypes interface {
 // FromValue creates a scalar Array from a Go value.
 func FromValue[T scalarTypes](t T) *Array {
 	Init()
-	tt := New("") // finalizer set by New
+	tt := New("")
 	switch v := any(t).(type) {
 	case bool:
 		tt.ctx = C.mlx_array_new_bool(C.bool(v))
@@ -152,17 +138,18 @@ func Zeros(shape []int32, dtype DType) *Array {
 	return tt
 }
 
-// Set replaces this array's value with another, updating ref tracking.
+// Set replaces this array's C handle with another's.
 func (t *Array) Set(other *Array) {
-	Free(t.desc.inputs...)
-	other.desc.numRefs++
-	t.desc.inputs = []*Array{other}
 	C.mlx_array_set(&t.ctx, other.ctx)
+	// t no longer holds a Go reference to other, so keep other reachable
+	// until the C call returns; otherwise its finalizer may free other.ctx.
+	runtime.KeepAlive(other)
 }
 
-// Clone creates a copy of this array sharing the same data.
+// Clone creates a new Go wrapper sharing the same C handle (increments C refcount).
 func (t *Array) Clone() *Array {
-	tt := New(t.desc.name, t.desc.inputs...)
+	tt := New(t.name)
 	C.mlx_array_set(&tt.ctx, t.ctx)
+	runtime.KeepAlive(t) // t.ctx must outlive the C call above
 	return tt
 }
@@ -254,33 +241,17 @@ func (t Array) Floats() []float32 {
 	return floats
 }
 
-// Free releases arrays using reference-counted cleanup.
-// Arrays with remaining references are not freed.
+// Free explicitly releases C array handles. Does not cascade — MLX-C's
+// internal refcounting handles dependent arrays automatically.
 func Free(s ...*Array) int {
 	var n int
-	free := make([]*Array, 0, 64)
-
-	fn := func(t *Array) {
+	for _, t := range s {
 		if t != nil && t.Valid() {
-			t.desc.numRefs--
-			if t.desc.numRefs <= 0 {
-				free = append(free, t.desc.inputs...)
-				n += t.NumBytes()
-				C.mlx_array_free(t.ctx)
-				t.ctx.ctx = nil
-			}
+			n += t.NumBytes()
+			C.mlx_array_free(t.ctx)
+			t.ctx.ctx = nil
+			runtime.SetFinalizer(t, nil) // cancel finalizer
 		}
 	}
-
-	for _, t := range s {
-		fn(t)
-	}
-
-	for len(free) > 0 {
-		tail := free[len(free)-1]
-		free = free[:len(free)-1]
-		fn(tail)
-	}
-
 	return n
 }
diff --git a/pkg/mlx/io.go b/pkg/mlx/io.go
index e4aa363c..c7247b2f 100644
--- a/pkg/mlx/io.go
+++ b/pkg/mlx/io.go
@@ -10,6 +10,7 @@ import "C"
 
 import (
 	"iter"
+	"runtime"
 	"unsafe"
 )
 
@@ -43,7 +44,9 @@ func LoadSafetensors(path string) iter.Seq2[string, *Array] {
 			}
 
 			name := C.GoString(key)
-			if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
+			arr := &Array{ctx: value, name: name}
+			runtime.SetFinalizer(arr, finalizeArray)
+			if !yield(name, arr) {
 				break
 			}
 		}