- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze, concatenate, and argpartition
- Fix size_t vs. int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
86 lines
2 KiB
Go
86 lines
2 KiB
Go
//go:build darwin && arm64 && mlx
|
|
|
|
package mlx
|
|
|
|
/*
#include "mlx/c/mlx.h"

// Callback for compiled functions. Defined in Go via the //export directive on
// goCompiledFunc; this declaration must match that Go signature and the
// function-pointer type that mlx_closure_new_func_payload expects.
extern int goCompiledFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload);

// new_closure builds an mlx_closure that invokes goCompiledFunc, handing it
// payload on every call. CompileShapeless passes an integer registry ID (not
// a real pointer) as payload, and goCompiledFunc converts it back to uintptr.
// The final NULL argument registers no payload destructor; presumably nothing
// is freed when the closure is released -- confirm against mlx-c.
//
// Must stay static: this file uses //export, so cgo copies the preamble into
// more than one C output file and a non-static definition would be duplicated
// at link time.
static mlx_closure new_closure(void *payload) {
	return mlx_closure_new_func_payload(&goCompiledFunc, payload, NULL);
}
*/
import "C"
|
|
|
|
import (
|
|
"sync"
|
|
"unsafe"
|
|
)
|
|
|
|
// CompiledFunc pairs a Go function with an mlx closure that CompileShapeless
// builds around it.
//
// NOTE(review): despite the name, Call invokes fn directly and does not use
// closure, so no mlx compilation happens on that path. Confirm whether closure
// is consumed elsewhere in the package.
type CompiledFunc struct {
	// fn is the user-supplied function; Call runs it directly.
	fn func([]*Array) []*Array

	// closure wraps goCompiledFunc with this function's registry ID as its
	// payload (built in CompileShapeless). Nothing visible here frees it.
	closure C.mlx_closure

	// mu serializes concurrent calls to Call.
	mu sync.Mutex
}
|
|
|
|
// compiledFuncs is the registry that lets C code reach Go functions: keys
// are the uintptr IDs handed out by CompileShapeless, values are the
// func([]*Array) []*Array they identify. goCompiledFunc resolves the ID it
// receives as its closure payload through this map.
var (
	compiledFuncs sync.Map
)
|
|
|
|
//export goCompiledFunc
|
|
func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
|
|
id := uintptr(payload)
|
|
fnI, ok := compiledFuncs.Load(id)
|
|
if !ok {
|
|
return 1
|
|
}
|
|
fn := fnI.(func([]*Array) []*Array)
|
|
|
|
// Convert inputs
|
|
nInputs := int(C.mlx_vector_array_size(inputs))
|
|
goInputs := make([]*Array, nInputs)
|
|
for i := 0; i < nInputs; i++ {
|
|
a := New("INPUT")
|
|
C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
|
|
goInputs[i] = a
|
|
}
|
|
|
|
// Call user function
|
|
goOutputs := fn(goInputs)
|
|
|
|
// Set outputs
|
|
for _, out := range goOutputs {
|
|
C.mlx_vector_array_append_value(*outputs, out.ctx)
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// nextID holds the most recently issued registry ID. CompileShapeless
// increments it before use, so IDs start at 1. nextIDMu guards every
// read and write of nextID.
var (
	nextIDMu sync.Mutex
	nextID   uintptr
)
|
|
|
|
// CompileShapeless compiles a function for efficient repeated execution.
|
|
// The function must accept and return arrays of consistent shapes.
|
|
func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc {
|
|
nextIDMu.Lock()
|
|
nextID++
|
|
id := nextID
|
|
nextIDMu.Unlock()
|
|
|
|
compiledFuncs.Store(id, fn)
|
|
|
|
cf := &CompiledFunc{fn: fn}
|
|
cf.closure = C.new_closure(unsafe.Pointer(id))
|
|
return cf
|
|
}
|
|
|
|
// Call executes the compiled function with the given inputs.
|
|
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
|
|
cf.mu.Lock()
|
|
defer cf.mu.Unlock()
|
|
|
|
// Fall back to direct call — compilation is an optimization.
|
|
// The compiled closure can be used via mlx_compiled but the
|
|
// direct path is simpler and still benefits from MLX's lazy evaluation.
|
|
return cf.fn(inputs)
|
|
}
|