fix: remove Go-side array ref tracking, rely on MLX-C refcounting
The Go wrapper tracked inter-array references via desc.inputs, creating chains that kept every intermediate array alive across requests. After 3-4 requests, Metal memory grew to 170GB+ and macOS killed the process.

Fix: remove desc.inputs/numRefs entirely. MLX-C has its own internal reference counting: when the Go GC finalizes an Array wrapper, the finalizer calls mlx_array_free, which decrements the C-side refcount. Once the C-side count reaches 0, the Metal memory is freed. Together, Go GC finalizers and MLX-C refcounting handle all lifecycle management correctly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
8cdafc8d66
commit
6b603ee20b
2 changed files with 24 additions and 54 deletions
|
|
@@ -16,34 +16,20 @@ import (
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
type tensorDesc struct {
|
// Array wraps an mlx_array handle.
|
||||||
name string
|
// Memory management relies on Go GC finalizers to call mlx_array_free,
|
||||||
inputs []*Array
|
// which decrements MLX-C's internal reference count. MLX-C handles all
|
||||||
numRefs int
|
// cross-array references internally — the Go wrapper does not track them.
|
||||||
}
|
|
||||||
|
|
||||||
// Array wraps an mlx_array handle with reference-counted memory management.
|
|
||||||
type Array struct {
|
type Array struct {
|
||||||
ctx C.mlx_array
|
ctx C.mlx_array
|
||||||
desc tensorDesc
|
name string // debug label
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a named Array tracking its input dependencies for cleanup.
|
// New creates a named Array and registers a GC finalizer.
|
||||||
// A runtime finalizer is set so Go GC can release the C handle when
|
// The inputs parameter is accepted for API compatibility but not stored —
|
||||||
// the Array becomes unreachable — critical because Go GC cannot see
|
// MLX-C tracks inter-array references via its own refcounting.
|
||||||
// Metal/C memory pressure.
|
|
||||||
func New(name string, inputs ...*Array) *Array {
|
func New(name string, inputs ...*Array) *Array {
|
||||||
t := &Array{
|
t := &Array{name: name}
|
||||||
desc: tensorDesc{
|
|
||||||
name: name,
|
|
||||||
inputs: inputs,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, input := range inputs {
|
|
||||||
if input != nil {
|
|
||||||
input.desc.numRefs++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
runtime.SetFinalizer(t, finalizeArray)
|
runtime.SetFinalizer(t, finalizeArray)
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
@@ -63,7 +49,7 @@ type scalarTypes interface {
|
||||||
// FromValue creates a scalar Array from a Go value.
|
// FromValue creates a scalar Array from a Go value.
|
||||||
func FromValue[T scalarTypes](t T) *Array {
|
func FromValue[T scalarTypes](t T) *Array {
|
||||||
Init()
|
Init()
|
||||||
tt := New("") // finalizer set by New
|
tt := New("")
|
||||||
switch v := any(t).(type) {
|
switch v := any(t).(type) {
|
||||||
case bool:
|
case bool:
|
||||||
tt.ctx = C.mlx_array_new_bool(C.bool(v))
|
tt.ctx = C.mlx_array_new_bool(C.bool(v))
|
||||||
|
|
@@ -152,17 +138,14 @@ func Zeros(shape []int32, dtype DType) *Array {
|
||||||
return tt
|
return tt
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set replaces this array's value with another, updating ref tracking.
|
// Set replaces this array's C handle with another's.
|
||||||
func (t *Array) Set(other *Array) {
|
func (t *Array) Set(other *Array) {
|
||||||
Free(t.desc.inputs...)
|
|
||||||
other.desc.numRefs++
|
|
||||||
t.desc.inputs = []*Array{other}
|
|
||||||
C.mlx_array_set(&t.ctx, other.ctx)
|
C.mlx_array_set(&t.ctx, other.ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clone creates a copy of this array sharing the same data.
|
// Clone creates a new Go wrapper sharing the same C handle (increments C refcount).
|
||||||
func (t *Array) Clone() *Array {
|
func (t *Array) Clone() *Array {
|
||||||
tt := New(t.desc.name, t.desc.inputs...)
|
tt := New(t.name)
|
||||||
C.mlx_array_set(&tt.ctx, t.ctx)
|
C.mlx_array_set(&tt.ctx, t.ctx)
|
||||||
return tt
|
return tt
|
||||||
}
|
}
|
||||||
|
|
@@ -254,33 +237,17 @@ func (t Array) Floats() []float32 {
|
||||||
return floats
|
return floats
|
||||||
}
|
}
|
||||||
|
|
||||||
// Free releases arrays using reference-counted cleanup.
|
// Free explicitly releases C array handles. Does not cascade — MLX-C's
|
||||||
// Arrays with remaining references are not freed.
|
// internal refcounting handles dependent arrays automatically.
|
||||||
func Free(s ...*Array) int {
|
func Free(s ...*Array) int {
|
||||||
var n int
|
var n int
|
||||||
free := make([]*Array, 0, 64)
|
for _, t := range s {
|
||||||
|
|
||||||
fn := func(t *Array) {
|
|
||||||
if t != nil && t.Valid() {
|
if t != nil && t.Valid() {
|
||||||
t.desc.numRefs--
|
|
||||||
if t.desc.numRefs <= 0 {
|
|
||||||
free = append(free, t.desc.inputs...)
|
|
||||||
n += t.NumBytes()
|
n += t.NumBytes()
|
||||||
C.mlx_array_free(t.ctx)
|
C.mlx_array_free(t.ctx)
|
||||||
t.ctx.ctx = nil
|
t.ctx.ctx = nil
|
||||||
|
runtime.SetFinalizer(t, nil) // cancel finalizer
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range s {
|
|
||||||
fn(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
for len(free) > 0 {
|
|
||||||
tail := free[len(free)-1]
|
|
||||||
free = free[:len(free)-1]
|
|
||||||
fn(tail)
|
|
||||||
}
|
|
||||||
|
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@@ -10,6 +10,7 @@ import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"iter"
|
"iter"
|
||||||
|
"runtime"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@@ -43,7 +44,9 @@ func LoadSafetensors(path string) iter.Seq2[string, *Array] {
|
||||||
}
|
}
|
||||||
|
|
||||||
name := C.GoString(key)
|
name := C.GoString(key)
|
||||||
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
|
arr := &Array{ctx: value, name: name}
|
||||||
|
runtime.SetFinalizer(arr, finalizeArray)
|
||||||
|
if !yield(name, arr) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue