go/pkg/mlx/io.go
Claude 298c8d9cd8
fix: remove Go-side array ref tracking, rely on MLX-C refcounting
The Go wrapper was tracking inter-array references via desc.inputs,
creating chains that kept all intermediate arrays alive across requests.
After 3-4 requests, Metal memory grew to 170GB+ and macOS killed the
process.

Fix: remove desc.inputs/numRefs entirely. MLX-C has its own internal
reference counting — when Go GC finalizes an Array wrapper, it calls
mlx_array_free which decrements the C-side refcount. If the C-side
count reaches 0, Metal memory is freed. Go GC + MLX-C refcounting
together handle all lifecycle management correctly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 02:49:35 +00:00

63 lines
1.5 KiB
Go

//go:build darwin && arm64 && mlx
package mlx
/*
#include <stdlib.h>
#include "mlx/c/mlx.h"
*/
import "C"
import (
"iter"
"runtime"
"unsafe"
)
// LoadSafetensors loads tensors from a .safetensors file, returning an iterator
// over (name, array) pairs. Tensors are loaded lazily on the CPU stream.
func LoadSafetensors(path string) iter.Seq2[string, *Array] {
Init()
return func(yield func(string, *Array) bool) {
string2array := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(string2array)
string2string := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(string2string)
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
cpu := C.mlx_default_cpu_stream_new()
defer C.mlx_stream_free(cpu)
C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
it := C.mlx_map_string_to_array_iterator_new(string2array)
defer C.mlx_map_string_to_array_iterator_free(it)
for {
var key *C.char
value := C.mlx_array_new()
if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
break
}
name := C.GoString(key)
arr := &Array{ctx: value, name: name}
runtime.SetFinalizer(arr, finalizeArray)
if !yield(name, arr) {
break
}
}
}
}
// LoadAllSafetensors loads all tensors from a .safetensors file into a map.
func LoadAllSafetensors(path string) map[string]*Array {
tensors := make(map[string]*Array)
for name, arr := range LoadSafetensors(path) {
tensors[name] = arr
}
return tensors
}