fix: remove Go-side array ref tracking, rely on MLX-C refcounting
The Go wrapper was tracking inter-array references via desc.inputs, creating chains that kept all intermediate arrays alive across requests. After 3-4 requests, Metal memory grew to 170GB+ and macOS killed the process. Fix: remove desc.inputs/numRefs entirely. MLX-C has its own internal reference counting — when Go GC finalizes an Array wrapper, it calls mlx_array_free which decrements the C-side refcount. If the C-side count reaches 0, Metal memory is freed. Go GC + MLX-C refcounting together handle all lifecycle management correctly. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
1d4ec55d05
commit
c5689c3e83
2 changed files with 24 additions and 54 deletions
|
|
@ -16,34 +16,20 @@ import (
|
|||
"unsafe"
|
||||
)
|
||||
|
||||
type tensorDesc struct {
|
||||
name string
|
||||
inputs []*Array
|
||||
numRefs int
|
||||
}
|
||||
|
||||
// Array wraps an mlx_array handle with reference-counted memory management.
|
||||
// Array wraps an mlx_array handle.
|
||||
// Memory management relies on Go GC finalizers to call mlx_array_free,
|
||||
// which decrements MLX-C's internal reference count. MLX-C handles all
|
||||
// cross-array references internally — the Go wrapper does not track them.
|
||||
type Array struct {
|
||||
ctx C.mlx_array
|
||||
desc tensorDesc
|
||||
name string // debug label
|
||||
}
|
||||
|
||||
// New creates a named Array tracking its input dependencies for cleanup.
|
||||
// A runtime finalizer is set so Go GC can release the C handle when
|
||||
// the Array becomes unreachable — critical because Go GC cannot see
|
||||
// Metal/C memory pressure.
|
||||
// New creates a named Array and registers a GC finalizer.
|
||||
// The inputs parameter is accepted for API compatibility but not stored —
|
||||
// MLX-C tracks inter-array references via its own refcounting.
|
||||
func New(name string, inputs ...*Array) *Array {
|
||||
t := &Array{
|
||||
desc: tensorDesc{
|
||||
name: name,
|
||||
inputs: inputs,
|
||||
},
|
||||
}
|
||||
for _, input := range inputs {
|
||||
if input != nil {
|
||||
input.desc.numRefs++
|
||||
}
|
||||
}
|
||||
t := &Array{name: name}
|
||||
runtime.SetFinalizer(t, finalizeArray)
|
||||
return t
|
||||
}
|
||||
|
|
@ -63,7 +49,7 @@ type scalarTypes interface {
|
|||
// FromValue creates a scalar Array from a Go value.
|
||||
func FromValue[T scalarTypes](t T) *Array {
|
||||
Init()
|
||||
tt := New("") // finalizer set by New
|
||||
tt := New("")
|
||||
switch v := any(t).(type) {
|
||||
case bool:
|
||||
tt.ctx = C.mlx_array_new_bool(C.bool(v))
|
||||
|
|
@ -152,17 +138,14 @@ func Zeros(shape []int32, dtype DType) *Array {
|
|||
return tt
|
||||
}
|
||||
|
||||
// Set replaces this array's value with another, updating ref tracking.
|
||||
// Set replaces this array's C handle with another's.
|
||||
func (t *Array) Set(other *Array) {
|
||||
Free(t.desc.inputs...)
|
||||
other.desc.numRefs++
|
||||
t.desc.inputs = []*Array{other}
|
||||
C.mlx_array_set(&t.ctx, other.ctx)
|
||||
}
|
||||
|
||||
// Clone creates a copy of this array sharing the same data.
|
||||
// Clone creates a new Go wrapper sharing the same C handle (increments C refcount).
|
||||
func (t *Array) Clone() *Array {
|
||||
tt := New(t.desc.name, t.desc.inputs...)
|
||||
tt := New(t.name)
|
||||
C.mlx_array_set(&tt.ctx, t.ctx)
|
||||
return tt
|
||||
}
|
||||
|
|
@ -254,33 +237,17 @@ func (t Array) Floats() []float32 {
|
|||
return floats
|
||||
}
|
||||
|
||||
// Free releases arrays using reference-counted cleanup.
|
||||
// Arrays with remaining references are not freed.
|
||||
// Free explicitly releases C array handles. Does not cascade — MLX-C's
|
||||
// internal refcounting handles dependent arrays automatically.
|
||||
func Free(s ...*Array) int {
|
||||
var n int
|
||||
free := make([]*Array, 0, 64)
|
||||
|
||||
fn := func(t *Array) {
|
||||
for _, t := range s {
|
||||
if t != nil && t.Valid() {
|
||||
t.desc.numRefs--
|
||||
if t.desc.numRefs <= 0 {
|
||||
free = append(free, t.desc.inputs...)
|
||||
n += t.NumBytes()
|
||||
C.mlx_array_free(t.ctx)
|
||||
t.ctx.ctx = nil
|
||||
runtime.SetFinalizer(t, nil) // cancel finalizer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, t := range s {
|
||||
fn(t)
|
||||
}
|
||||
|
||||
for len(free) > 0 {
|
||||
tail := free[len(free)-1]
|
||||
free = free[:len(free)-1]
|
||||
fn(tail)
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import "C"
|
|||
|
||||
import (
|
||||
"iter"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
|
|
@ -43,7 +44,9 @@ func LoadSafetensors(path string) iter.Seq2[string, *Array] {
|
|||
}
|
||||
|
||||
name := C.GoString(key)
|
||||
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
|
||||
arr := &Array{ctx: value, name: name}
|
||||
runtime.SetFinalizer(arr, finalizeArray)
|
||||
if !yield(name, arr) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue