feat/ml-integration #2

Merged
Snider merged 81 commits from feat/ml-integration into dev 2026-02-16 06:19:10 +00:00
2 changed files with 24 additions and 54 deletions
Showing only changes of commit 298c8d9cd8 - Show all commits

View file

@ -16,34 +16,20 @@ import (
"unsafe"
)
type tensorDesc struct {
name string
inputs []*Array
numRefs int
}
// Array wraps an mlx_array handle with reference-counted memory management.
// Array wraps an mlx_array handle.
// Memory management relies on Go GC finalizers to call mlx_array_free,
// which decrements MLX-C's internal reference count. MLX-C handles all
// cross-array references internally — the Go wrapper does not track them.
type Array struct {
ctx C.mlx_array
desc tensorDesc
name string // debug label
}
// New creates a named Array tracking its input dependencies for cleanup.
// A runtime finalizer is set so Go GC can release the C handle when
// the Array becomes unreachable — critical because Go GC cannot see
// Metal/C memory pressure.
// New creates a named Array and registers a GC finalizer.
// The inputs parameter is accepted for API compatibility but not stored —
// MLX-C tracks inter-array references via its own refcounting.
func New(name string, inputs ...*Array) *Array {
t := &Array{
desc: tensorDesc{
name: name,
inputs: inputs,
},
}
for _, input := range inputs {
if input != nil {
input.desc.numRefs++
}
}
t := &Array{name: name}
runtime.SetFinalizer(t, finalizeArray)
return t
}
@ -63,7 +49,7 @@ type scalarTypes interface {
// FromValue creates a scalar Array from a Go value.
func FromValue[T scalarTypes](t T) *Array {
Init()
tt := New("") // finalizer set by New
tt := New("")
switch v := any(t).(type) {
case bool:
tt.ctx = C.mlx_array_new_bool(C.bool(v))
@ -152,17 +138,14 @@ func Zeros(shape []int32, dtype DType) *Array {
return tt
}
// Set replaces this array's value with another, updating ref tracking.
// Set replaces this array's C handle with another's.
func (t *Array) Set(other *Array) {
Free(t.desc.inputs...)
other.desc.numRefs++
t.desc.inputs = []*Array{other}
C.mlx_array_set(&t.ctx, other.ctx)
}
// Clone creates a copy of this array sharing the same data.
// Clone creates a new Go wrapper sharing the same C handle (increments C refcount).
func (t *Array) Clone() *Array {
tt := New(t.desc.name, t.desc.inputs...)
tt := New(t.name)
C.mlx_array_set(&tt.ctx, t.ctx)
return tt
}
@ -254,33 +237,17 @@ func (t Array) Floats() []float32 {
return floats
}
// Free releases arrays using reference-counted cleanup.
// Arrays with remaining references are not freed.
// Free explicitly releases C array handles. Does not cascade — MLX-C's
// internal refcounting handles dependent arrays automatically.
func Free(s ...*Array) int {
var n int
free := make([]*Array, 0, 64)
fn := func(t *Array) {
for _, t := range s {
if t != nil && t.Valid() {
t.desc.numRefs--
if t.desc.numRefs <= 0 {
free = append(free, t.desc.inputs...)
n += t.NumBytes()
C.mlx_array_free(t.ctx)
t.ctx.ctx = nil
}
n += t.NumBytes()
C.mlx_array_free(t.ctx)
t.ctx.ctx = nil
runtime.SetFinalizer(t, nil) // cancel finalizer
}
}
for _, t := range s {
fn(t)
}
for len(free) > 0 {
tail := free[len(free)-1]
free = free[:len(free)-1]
fn(tail)
}
return n
}

View file

@ -10,6 +10,7 @@ import "C"
import (
"iter"
"runtime"
"unsafe"
)
@ -43,7 +44,9 @@ func LoadSafetensors(path string) iter.Seq2[string, *Array] {
}
name := C.GoString(key)
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
arr := &Array{ctx: value, name: name}
runtime.SetFinalizer(arr, finalizeArray)
if !yield(name, arr) {
break
}
}