go-ai/mlx/compile.go
Snider e9973aef3c feat(mlx): add autograd — VJP, JVP, ValueAndGrad, loss functions
Native Go bindings for MLX-C gradient computation on Apple Silicon.
Foundation for LoRA training without Python.

- VJP (reverse-mode autodiff) for backward pass
- JVP (forward-mode autodiff) for directional derivatives
- ValueAndGrad for combined loss + gradient computation
- Checkpoint for memory-efficient gradient recomputation
- CrossEntropyLoss (numerically stable via LogSumExp)
- MSELoss, Log, SumAll, MeanAll, OnesLike helpers
- TakeAlongAxis and LogSumExp ops
- Fix closure callback null vector bug (affects compile.go too)
- Fix Float() returning 0 for float32 arrays

14 tests passing on Metal GPU.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-17 17:18:47 +00:00

90 lines
2.2 KiB
Go

//go:build darwin && arm64
package mlx
/*
#include "mlx/c/mlx.h"
// Callback for compiled functions.
extern int goCompiledFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload);
static mlx_closure new_closure(void *payload) {
return mlx_closure_new_func_payload(&goCompiledFunc, payload, NULL);
}
*/
import "C"
import (
"sync"
"unsafe"
)
// CompiledFunc wraps a compiled MLX computation graph for efficient repeated calls.
type CompiledFunc struct {
fn func([]*Array) []*Array
closure C.mlx_closure
mu sync.Mutex
}
var compiledFuncs sync.Map
//export goCompiledFunc
func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
id := uintptr(payload)
fnI, ok := compiledFuncs.Load(id)
if !ok {
return 1
}
fn := fnI.(func([]*Array) []*Array)
// Convert inputs
nInputs := int(C.mlx_vector_array_size(inputs))
goInputs := make([]*Array, nInputs)
for i := 0; i < nInputs; i++ {
a := New("INPUT")
C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
goInputs[i] = a
}
// Call user function
goOutputs := fn(goInputs)
// The output vector arrives with ctx=nullptr. Create a valid vector,
// fill it, then copy to the output pointer.
tmp := C.mlx_vector_array_new()
for _, out := range goOutputs {
C.mlx_vector_array_append_value(tmp, out.ctx)
}
C.mlx_vector_array_set(outputs, tmp)
C.mlx_vector_array_free(tmp)
return 0
}
var nextID uintptr
var nextIDMu sync.Mutex
// CompileShapeless compiles a function for efficient repeated execution.
// The function must accept and return arrays of consistent shapes.
func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc {
nextIDMu.Lock()
nextID++
id := nextID
nextIDMu.Unlock()
compiledFuncs.Store(id, fn)
cf := &CompiledFunc{fn: fn}
cf.closure = C.new_closure(unsafe.Pointer(id))
return cf
}
// Call executes the compiled function with the given inputs.
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
cf.mu.Lock()
defer cf.mu.Unlock()
// Fall back to direct call — compilation is an optimization.
// The compiled closure can be used via mlx_compiled but the
// direct path is simpler and still benefits from MLX's lazy evaluation.
return cf.fn(inputs)
}