cli/pkg/mlx/compile.go
Claude a0435a84ea fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
  concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 05:53:52 +00:00

86 lines
2 KiB
Go

//go:build darwin && arm64 && mlx
package mlx
/*
#include "mlx/c/mlx.h"
// Callback for compiled functions.
extern int goCompiledFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload);
static mlx_closure new_closure(void *payload) {
return mlx_closure_new_func_payload(&goCompiledFunc, payload, NULL);
}
*/
import "C"
import (
"sync"
"unsafe"
)
// CompiledFunc wraps a compiled MLX computation graph for efficient repeated calls.
type CompiledFunc struct {
fn func([]*Array) []*Array
closure C.mlx_closure
mu sync.Mutex
}
var compiledFuncs sync.Map
//export goCompiledFunc
func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
id := uintptr(payload)
fnI, ok := compiledFuncs.Load(id)
if !ok {
return 1
}
fn := fnI.(func([]*Array) []*Array)
// Convert inputs
nInputs := int(C.mlx_vector_array_size(inputs))
goInputs := make([]*Array, nInputs)
for i := 0; i < nInputs; i++ {
a := New("INPUT")
C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
goInputs[i] = a
}
// Call user function
goOutputs := fn(goInputs)
// Set outputs
for _, out := range goOutputs {
C.mlx_vector_array_append_value(*outputs, out.ctx)
}
return 0
}
var nextID uintptr
var nextIDMu sync.Mutex
// CompileShapeless compiles a function for efficient repeated execution.
// The function must accept and return arrays of consistent shapes.
func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc {
nextIDMu.Lock()
nextID++
id := nextID
nextIDMu.Unlock()
compiledFuncs.Store(id, fn)
cf := &CompiledFunc{fn: fn}
cf.closure = C.new_closure(unsafe.Pointer(id))
return cf
}
// Call executes the compiled function with the given inputs.
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
cf.mu.Lock()
defer cf.mu.Unlock()
// Fall back to direct call — compilation is an optimization.
// The compiled closure can be used via mlx_compiled but the
// direct path is simpler and still benefits from MLX's lazy evaluation.
return cf.fn(inputs)
}