feat(mlx): add autograd — VJP, JVP, ValueAndGrad, loss functions

Native Go bindings for MLX-C gradient computation on Apple Silicon.
Foundation for LoRA training without Python.

- VJP (reverse-mode autodiff) for backward pass
- JVP (forward-mode autodiff) for directional derivatives
- ValueAndGrad for combined loss + gradient computation
- Checkpoint for memory-efficient gradient recomputation
- CrossEntropyLoss (numerically stable via LogSumExp)
- MSELoss, Log, SumAll, MeanAll, OnesLike helpers
- TakeAlongAxis and LogSumExp ops
- Fix closure callback null vector bug (affects compile.go too)
- Fix Float() returning 0 for float32 arrays

14 tests passing on Metal GPU.

Co-Authored-By: Virgil <virgil@lethean.io>
Snider 2026-02-17 17:18:47 +00:00
parent 92c6282d50
commit e9973aef3c
5 changed files with 702 additions and 5 deletions
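
For orientation, a rough sketch of how the pieces listed above compose into a single training step. It is written in-package, like the tests in this commit; the shapes, the 0.01 learning rate, and the SGD update via MulScalar and Subtract are illustrative assumptions, while ValueAndGrad, Apply, and CrossEntropyLoss are the APIs added here.

func trainStepSketch() {
	x := FromValues([]float32{1, 2, 3, 4}, 2, 2) // batch of 2 rows, 2 features
	tgt := FromValues([]int32{0, 1}, 2)          // class index per row
	w := FromValues([]float32{0.1, 0.2, 0.3, 0.4}, 2, 2)

	lossFn := func(params []*Array) []*Array {
		logits := Matmul(x, params[0])
		return []*Array{CrossEntropyLoss(logits, tgt)}
	}

	step := ValueAndGrad(lossFn, 0) // differentiate w.r.t. params[0] == w
	defer step.Free()

	values, grads, err := step.Apply(w)
	if err != nil {
		panic(err)
	}
	// Plain SGD update; MulScalar and Subtract are existing helpers in this package.
	w = Subtract(w, MulScalar(grads[0], 0.01))
	Materialize(values[0], w)
}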


@@ -204,10 +204,18 @@ func (t Array) Int() int {
 }
 
 // Float extracts a scalar float64 value.
+// Handles both float32 and float64 array dtypes.
 func (t Array) Float() float64 {
-	var item C.double
-	C.mlx_array_item_float64(&item, t.ctx)
-	return float64(item)
+	switch t.Dtype() {
+	case DTypeFloat32:
+		var item C.float
+		C.mlx_array_item_float32(&item, t.ctx)
+		return float64(item)
+	default:
+		var item C.double
+		C.mlx_array_item_float64(&item, t.ctx)
+		return float64(item)
+	}
 }
 
 // Ints extracts all elements as int slice (from int32 data).


@@ -49,10 +49,14 @@ func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
 	// Call user function
 	goOutputs := fn(goInputs)
 
 	// Set outputs
+	// The output vector arrives with ctx=nullptr. Create a valid vector,
+	// fill it, then copy to the output pointer.
+	tmp := C.mlx_vector_array_new()
 	for _, out := range goOutputs {
-		C.mlx_vector_array_append_value(*outputs, out.ctx)
+		C.mlx_vector_array_append_value(tmp, out.ctx)
 	}
-
+	C.mlx_vector_array_set(outputs, tmp)
+	C.mlx_vector_array_free(tmp)
 	return 0
 }

mlx/grad.go Normal file (337 additions)

@@ -0,0 +1,337 @@
//go:build darwin && arm64
package mlx
/*
#include "mlx/c/mlx.h"
// Callback for gradient closures — same signature as goCompiledFunc.
extern int goGradFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload);
static mlx_closure new_grad_closure(void *payload) {
return mlx_closure_new_func_payload(&goGradFunc, payload, NULL);
}
*/
import "C"
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
)
// --- Closure registry (separate from compile.go's registry) ---
var (
gradFuncs sync.Map
gradNextID atomic.Uintptr
)
//export goGradFunc
func goGradFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
id := uintptr(payload)
fnI, ok := gradFuncs.Load(id)
if !ok {
return 1
}
fn := fnI.(func([]*Array) []*Array)
nInputs := int(C.mlx_vector_array_size(inputs))
goInputs := make([]*Array, nInputs)
for i := 0; i < nInputs; i++ {
a := New("GRAD_INPUT")
C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
goInputs[i] = a
}
goOutputs := fn(goInputs)
// The output vector arrives with ctx=nullptr (from internal mlx_vector_array_new_).
// Create a valid vector, fill it, then copy to the output pointer.
tmp := C.mlx_vector_array_new()
for _, out := range goOutputs {
C.mlx_vector_array_append_value(tmp, out.ctx)
}
C.mlx_vector_array_set(outputs, tmp)
C.mlx_vector_array_free(tmp)
return 0
}
// newClosure registers a Go function as an MLX closure for autograd.
func newClosure(fn func([]*Array) []*Array) C.mlx_closure {
id := gradNextID.Add(1)
gradFuncs.Store(id, fn)
return C.new_grad_closure(unsafe.Pointer(id))
}
// --- VJP (Vector-Jacobian Product) — Reverse Mode ---
// VJP computes the vector-Jacobian product (reverse-mode autodiff).
// Given a function fn, input primals, and output cotangents (upstream gradients),
// returns (outputs, gradients) where gradients are w.r.t. the primals.
//
// This is the fundamental backward pass operation.
func VJP(fn func([]*Array) []*Array, primals []*Array, cotangents []*Array) (outputs []*Array, vjps []*Array, err error) {
Init()
closure := newClosure(fn)
defer C.mlx_closure_free(closure)
// Pack primals into vector
primalsVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(primalsVec)
for _, p := range primals {
C.mlx_vector_array_append_value(primalsVec, p.ctx)
}
// Pack cotangents into vector
cotangentsVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(cotangentsVec)
for _, c := range cotangents {
C.mlx_vector_array_append_value(cotangentsVec, c.ctx)
}
// Call mlx_vjp
var outVec, vjpVec C.mlx_vector_array
outVec = C.mlx_vector_array_new()
vjpVec = C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
defer C.mlx_vector_array_free(vjpVec)
rc := C.mlx_vjp(&outVec, &vjpVec, closure, primalsVec, cotangentsVec)
if rc != 0 {
checkError()
return nil, nil, fmt.Errorf("mlx: vjp failed")
}
outputs = vectorToArrays(outVec)
vjps = vectorToArrays(vjpVec)
return outputs, vjps, nil
}
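// A minimal usage sketch (hypothetical helper, not exercised by the tests in
// this commit): for a non-scalar output, seeding the cotangent with
// OnesLike(output) reproduces the gradient of the summed output, the usual
// starting point for a backward pass.
func gradOfSum(fn func([]*Array) []*Array, x *Array) (*Array, error) {
	y := fn([]*Array{x})[0]
	_, grads, err := VJP(fn, []*Array{x}, []*Array{OnesLike(y)})
	if err != nil {
		return nil, err
	}
	return grads[0], nil
}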
// --- JVP (Jacobian-Vector Product) — Forward Mode ---
// JVP computes the Jacobian-vector product (forward-mode autodiff).
// Given a function fn, input primals, and input tangents (perturbation directions),
// returns (outputs, output_tangents).
//
// Useful for directional derivatives and Hessian-vector products.
func JVP(fn func([]*Array) []*Array, primals []*Array, tangents []*Array) (outputs []*Array, jvps []*Array, err error) {
Init()
closure := newClosure(fn)
defer C.mlx_closure_free(closure)
primalsVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(primalsVec)
for _, p := range primals {
C.mlx_vector_array_append_value(primalsVec, p.ctx)
}
tangentsVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(tangentsVec)
for _, t := range tangents {
C.mlx_vector_array_append_value(tangentsVec, t.ctx)
}
var outVec, jvpVec C.mlx_vector_array
outVec = C.mlx_vector_array_new()
jvpVec = C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
defer C.mlx_vector_array_free(jvpVec)
rc := C.mlx_jvp(&outVec, &jvpVec, closure, primalsVec, tangentsVec)
if rc != 0 {
checkError()
return nil, nil, fmt.Errorf("mlx: jvp failed")
}
outputs = vectorToArrays(outVec)
jvps = vectorToArrays(jvpVec)
return outputs, jvps, nil
}
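// A directional-derivative sketch (hypothetical helper): for f(x, y) = x*y,
// pushing the direction (1, 1) through JVP yields y + x, i.e. 7.0 at x=3, y=4.
func directionalDerivativeSketch() (*Array, error) {
	f := func(in []*Array) []*Array { return []*Array{Mul(in[0], in[1])} }
	x, y := FromValue(float32(3)), FromValue(float32(4))
	dx, dy := FromValue(float32(1)), FromValue(float32(1))
	_, jvps, err := JVP(f, []*Array{x, y}, []*Array{dx, dy})
	if err != nil {
		return nil, err
	}
	Materialize(jvps[0]) // expect 7.0
	return jvps[0], nil
}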
// --- ValueAndGrad — Combined Forward + Backward ---
// GradFn is a function that computes both the loss value and gradients
// with respect to specified arguments. Call it with inputs to get
// (values, gradients).
type GradFn struct {
cls C.mlx_closure_value_and_grad
}
// ValueAndGrad creates a GradFn that computes both the function value and
// gradients with respect to the arguments at the given indices.
//
// The returned GradFn can be called repeatedly with different inputs.
// This is the primary API for training loops:
//
//	lossFn := func(params []*Array) []*Array { return []*Array{loss} }
//	grad := mlx.ValueAndGrad(lossFn, 0) // differentiate w.r.t. first arg
//	values, grads, err := grad.Apply(params...)
func ValueAndGrad(fn func([]*Array) []*Array, argnums ...int) *GradFn {
Init()
closure := newClosure(fn)
defer C.mlx_closure_free(closure)
// Default: differentiate w.r.t. first argument
if len(argnums) == 0 {
argnums = []int{0}
}
cArgs := make([]C.int, len(argnums))
for i, a := range argnums {
cArgs[i] = C.int(a)
}
g := &GradFn{}
C.mlx_value_and_grad(&g.cls, closure, &cArgs[0], C.size_t(len(cArgs)))
return g
}
// Apply calls the gradient function with the given inputs.
// Returns (values, gradients) where values are the function outputs
// and gradients are w.r.t. the arguments specified in ValueAndGrad.
func (g *GradFn) Apply(inputs ...*Array) (values []*Array, grads []*Array, err error) {
inputVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(inputVec)
for _, in := range inputs {
C.mlx_vector_array_append_value(inputVec, in.ctx)
}
var valVec, gradVec C.mlx_vector_array
valVec = C.mlx_vector_array_new()
gradVec = C.mlx_vector_array_new()
defer C.mlx_vector_array_free(valVec)
defer C.mlx_vector_array_free(gradVec)
rc := C.mlx_closure_value_and_grad_apply(&valVec, &gradVec, g.cls, inputVec)
if rc != 0 {
checkError()
return nil, nil, fmt.Errorf("mlx: value_and_grad apply failed")
}
values = vectorToArrays(valVec)
grads = vectorToArrays(gradVec)
return values, grads, nil
}
// Free releases the underlying C closure.
func (g *GradFn) Free() {
if g.cls.ctx != nil {
C.mlx_closure_value_and_grad_free(g.cls)
g.cls.ctx = nil
}
}
// --- Checkpoint — Memory-Efficient Gradient Recomputation ---
// Checkpoint wraps a function so that during backward pass, intermediate
// activations are recomputed rather than stored. Trades compute for memory.
//
// Use this for memory-constrained training (large models on limited VRAM).
func Checkpoint(fn func([]*Array) []*Array) func([]*Array) []*Array {
Init()
closure := newClosure(fn)
defer C.mlx_closure_free(closure)
var checkpointed C.mlx_closure
C.mlx_checkpoint(&checkpointed, closure)
// The checkpointed closure is captured by the returned Go function, which
// keeps it alive for as long as that function is reachable; it is intentionally never freed.
cpClosure := checkpointed
// Return a Go function that calls through the checkpointed closure.
return func(inputs []*Array) []*Array {
inputVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(inputVec)
for _, in := range inputs {
C.mlx_vector_array_append_value(inputVec, in.ctx)
}
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
C.mlx_closure_apply(&outVec, cpClosure, inputVec)
return vectorToArrays(outVec)
}
}
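// A sketch of the intended composition with ValueAndGrad (hypothetical helper;
// w1, w2 and target are placeholder arrays, and whether gradients propagate
// through a checkpointed closure nested inside another closure should be
// confirmed against MLX-C): the hidden activation of the two-matmul block is
// recomputed during the backward pass instead of being stored.
func checkpointedLossSketch(w1, w2, target *Array) *GradFn {
	block := Checkpoint(func(xs []*Array) []*Array {
		h := Matmul(xs[0], w1)
		return []*Array{Matmul(h, w2)}
	})
	lossFn := func(in []*Array) []*Array {
		out := block([]*Array{in[0]})[0]
		return []*Array{MSELoss(out, target)}
	}
	return ValueAndGrad(lossFn, 0) // caller is responsible for Free()
}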
// --- Loss Functions ---
// CrossEntropyLoss computes cross-entropy loss between logits and integer targets.
// logits: [..., V] (raw model output, pre-softmax, last dim = vocab)
// targets: [...] (integer token IDs, same shape as logits minus last dim)
// Returns scalar loss averaged over all positions.
//
// Numerically stable: loss_i = logsumexp(logits_i) - logits_i[target_i]
func CrossEntropyLoss(logits, targets *Array) *Array {
Init()
// LogSumExp along last axis (vocab dimension) for numerical stability
lse := LogSumExp(logits, -1, false)
// Gather the logit at the target index: logits[..., target]
// targets needs to be [..., 1] for take_along_axis
tgtExpanded := ExpandDims(targets, -1)
gathered := TakeAlongAxis(logits, tgtExpanded, -1)
gathered = Squeeze(gathered, -1) // back to [...] shape
// Per-position loss: logsumexp - logit_at_target
perPos := Subtract(lse, gathered)
// Mean over all positions
return MeanAll(perPos)
}
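// Why the logsumexp form above is the cross-entropy: for a logit vector z with
// target index t,
//
//	-log softmax(z)_t = -(z_t - logsumexp(z)) = logsumexp(z) - z_t,
//
// which is exactly Subtract(lse, gathered); staying in log space avoids
// exponentiating large logits. For logits [1, 2, 3] and target index 2,
// logsumexp is about 3.4076 and the loss about 0.4076 (see TestCrossEntropyLoss).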
// MSELoss computes the mean squared error loss: mean((predictions - targets)^2).
func MSELoss(predictions, targets *Array) *Array {
diff := Subtract(predictions, targets)
sq := Square(diff)
return MeanAll(sq)
}
// --- Helpers ---
// vectorToArrays extracts all arrays from an mlx_vector_array.
func vectorToArrays(vec C.mlx_vector_array) []*Array {
n := int(C.mlx_vector_array_size(vec))
out := make([]*Array, n)
for i := 0; i < n; i++ {
a := New("VEC_OUT")
C.mlx_vector_array_get(&a.ctx, vec, C.size_t(i))
out[i] = a
}
return out
}
// Log returns element-wise natural logarithm.
func Log(a *Array) *Array {
out := New("LOG", a)
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// SumAll reduces by summation over all elements, returning a scalar.
func SumAll(a *Array) *Array {
flat := Reshape(a, -1)
return Sum(flat, 0, false)
}
// MeanAll reduces by averaging over all elements, returning a scalar.
func MeanAll(a *Array) *Array {
flat := Reshape(a, -1)
return Mean(flat, 0, false)
}
// OnesLike creates an array of ones with the same shape and type as the input.
func OnesLike(a *Array) *Array {
out := New("ONES_LIKE", a)
C.mlx_ones_like(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}

mlx/grad_test.go Normal file (332 additions)

@@ -0,0 +1,332 @@
//go:build darwin && arm64
package mlx
import (
"math"
"testing"
)
func TestVJP_SimpleSquare(t *testing.T) {
// f(x) = x^2, df/dx = 2x
// At x=3: f(3)=9, df/dx=6
fn := func(inputs []*Array) []*Array {
x := inputs[0]
return []*Array{Mul(x, x)}
}
x := FromValue(float32(3.0))
cotangent := FromValue(float32(1.0)) // upstream grad = 1
outputs, grads, err := VJP(fn, []*Array{x}, []*Array{cotangent})
if err != nil {
t.Fatalf("VJP failed: %v", err)
}
Materialize(outputs[0], grads[0])
got := outputs[0].Float()
if math.Abs(got-9.0) > 1e-5 {
t.Errorf("output = %f, want 9.0", got)
}
grad := grads[0].Float()
if math.Abs(grad-6.0) > 1e-5 {
t.Errorf("grad = %f, want 6.0", grad)
}
}
func TestVJP_Addition(t *testing.T) {
// f(x, y) = x + y, df/dx = 1, df/dy = 1
fn := func(inputs []*Array) []*Array {
return []*Array{Add(inputs[0], inputs[1])}
}
x := FromValue(float32(2.0))
y := FromValue(float32(5.0))
cotangent := FromValue(float32(1.0))
_, grads, err := VJP(fn, []*Array{x, y}, []*Array{cotangent})
if err != nil {
t.Fatalf("VJP failed: %v", err)
}
Materialize(grads...)
if math.Abs(grads[0].Float()-1.0) > 1e-5 {
t.Errorf("dx = %f, want 1.0", grads[0].Float())
}
if math.Abs(grads[1].Float()-1.0) > 1e-5 {
t.Errorf("dy = %f, want 1.0", grads[1].Float())
}
}
func TestVJP_MatmulGrad(t *testing.T) {
// f(W) = sum(W @ x) — gradient of sum(matmul) w.r.t. W
// For W=[2,2], x=[2,1]: dL/dW = ones @ x^T
x := FromValues([]float32{1.0, 2.0}, 2, 1)
w := FromValues([]float32{1.0, 0.0, 0.0, 1.0}, 2, 2) // identity
fn := func(inputs []*Array) []*Array {
result := Matmul(inputs[0], x)
return []*Array{SumAll(result)}
}
cotangent := FromValue(float32(1.0))
outputs, grads, err := VJP(fn, []*Array{w}, []*Array{cotangent})
if err != nil {
t.Fatalf("VJP failed: %v", err)
}
Materialize(outputs[0], grads[0])
// W @ x with W=I, x=[1,2]^T gives [1,2]^T, sum=3
got := outputs[0].Float()
if math.Abs(got-3.0) > 1e-5 {
t.Errorf("output = %f, want 3.0", got)
}
// Gradient of sum(W@x) w.r.t. W is outer product: ones @ x^T
// = [[1,2],[1,2]]
gradFloats := grads[0].Floats()
expected := []float32{1.0, 2.0, 1.0, 2.0}
for i, exp := range expected {
if math.Abs(float64(gradFloats[i]-exp)) > 1e-5 {
t.Errorf("grad[%d] = %f, want %f", i, gradFloats[i], exp)
}
}
}
func TestJVP_SimpleSquare(t *testing.T) {
// f(x) = x^2, JVP with tangent v: df = 2x * v
// At x=3, v=1: df = 6
fn := func(inputs []*Array) []*Array {
x := inputs[0]
return []*Array{Mul(x, x)}
}
x := FromValue(float32(3.0))
tangent := FromValue(float32(1.0))
outputs, jvps, err := JVP(fn, []*Array{x}, []*Array{tangent})
if err != nil {
t.Fatalf("JVP failed: %v", err)
}
Materialize(outputs[0], jvps[0])
got := outputs[0].Float()
if math.Abs(got-9.0) > 1e-5 {
t.Errorf("output = %f, want 9.0", got)
}
jvp := jvps[0].Float()
if math.Abs(jvp-6.0) > 1e-5 {
t.Errorf("jvp = %f, want 6.0", jvp)
}
}
func TestValueAndGrad_Quadratic(t *testing.T) {
// f(x) = x^2 + 2x + 1 = (x+1)^2
// f'(x) = 2x + 2
// At x=3: f(3) = 16, f'(3) = 8
fn := func(inputs []*Array) []*Array {
x := inputs[0]
x2 := Mul(x, x)
two_x := MulScalar(x, 2.0)
one := FromValue(float32(1.0))
return []*Array{Add(Add(x2, two_x), one)}
}
grad := ValueAndGrad(fn, 0)
defer grad.Free()
x := FromValue(float32(3.0))
values, grads, err := grad.Apply(x)
if err != nil {
t.Fatalf("ValueAndGrad failed: %v", err)
}
Materialize(values[0], grads[0])
val := values[0].Float()
if math.Abs(val-16.0) > 1e-5 {
t.Errorf("value = %f, want 16.0", val)
}
g := grads[0].Float()
if math.Abs(g-8.0) > 1e-5 {
t.Errorf("grad = %f, want 8.0", g)
}
}
func TestValueAndGrad_MultiArg(t *testing.T) {
// f(x, y) = x*y, df/dx = y, df/dy = x
// At x=3, y=4: f=12, dx=4, dy=3
fn := func(inputs []*Array) []*Array {
return []*Array{Mul(inputs[0], inputs[1])}
}
// Differentiate w.r.t. both arguments
grad := ValueAndGrad(fn, 0, 1)
defer grad.Free()
x := FromValue(float32(3.0))
y := FromValue(float32(4.0))
values, grads, err := grad.Apply(x, y)
if err != nil {
t.Fatalf("ValueAndGrad failed: %v", err)
}
Materialize(values[0], grads[0], grads[1])
val := values[0].Float()
if math.Abs(val-12.0) > 1e-5 {
t.Errorf("value = %f, want 12.0", val)
}
dx := grads[0].Float()
if math.Abs(dx-4.0) > 1e-5 {
t.Errorf("dx = %f, want 4.0 (y)", dx)
}
dy := grads[1].Float()
if math.Abs(dy-3.0) > 1e-5 {
t.Errorf("dy = %f, want 3.0 (x)", dy)
}
}
func TestValueAndGrad_Reusable(t *testing.T) {
// Verify GradFn can be called multiple times
fn := func(inputs []*Array) []*Array {
x := inputs[0]
return []*Array{Mul(x, x)} // x^2, grad = 2x
}
grad := ValueAndGrad(fn)
defer grad.Free()
for _, tc := range []struct {
x float32
want float64 // expected gradient
}{
{2.0, 4.0},
{5.0, 10.0},
{-3.0, -6.0},
{0.0, 0.0},
} {
x := FromValue(tc.x)
_, grads, err := grad.Apply(x)
if err != nil {
t.Fatalf("Apply failed for x=%f: %v", tc.x, err)
}
Materialize(grads[0])
g := grads[0].Float()
if math.Abs(g-tc.want) > 1e-5 {
t.Errorf("x=%f: grad = %f, want %f", tc.x, g, tc.want)
}
}
}
func TestCrossEntropyLoss(t *testing.T) {
// Simple 3-class classification
// logits = [1.0, 2.0, 3.0], target = 2 (class index)
// Manual: logsumexp([1,2,3]) = 3 + log(exp(-2)+exp(-1)+1)
// = 3 + log(0.1353 + 0.3679 + 1.0) = 3 + log(1.5032) = 3.4076
// loss = 3.4076 - 3.0 = 0.4076
logits := FromValues([]float32{1.0, 2.0, 3.0}, 1, 3) // [1, 3]
targets := FromValues([]int32{2}, 1) // [1]
loss := CrossEntropyLoss(logits, targets)
Materialize(loss)
got := loss.Float()
expected := 0.4076
if math.Abs(got-expected) > 0.01 {
t.Errorf("CrossEntropyLoss = %f, want ~%f", got, expected)
}
}
func TestMSELoss(t *testing.T) {
pred := FromValues([]float32{1.0, 2.0, 3.0}, 3)
target := FromValues([]float32{1.5, 2.5, 3.5}, 3)
loss := MSELoss(pred, target)
Materialize(loss)
// MSE = mean((0.5)^2, (0.5)^2, (0.5)^2) = mean(0.25, 0.25, 0.25) = 0.25
got := loss.Float()
if math.Abs(got-0.25) > 1e-5 {
t.Errorf("MSELoss = %f, want 0.25", got)
}
}
func TestLogSumExp(t *testing.T) {
// logsumexp([1, 2, 3]) along axis -1
a := FromValues([]float32{1.0, 2.0, 3.0}, 1, 3)
result := LogSumExp(a, -1, false)
Materialize(result)
// = 3 + log(exp(-2) + exp(-1) + 1) = 3 + log(1.5032) ≈ 3.4076
got := result.Float()
expected := 3.4076
if math.Abs(got-expected) > 0.01 {
t.Errorf("LogSumExp = %f, want ~%f", got, expected)
}
}
func TestOnesLike(t *testing.T) {
a := FromValues([]float32{1.0, 2.0, 3.0}, 3)
ones := OnesLike(a)
Materialize(ones)
floats := ones.Floats()
for i, f := range floats {
if f != 1.0 {
t.Errorf("OnesLike[%d] = %f, want 1.0", i, f)
}
}
}
func TestCheckpoint(t *testing.T) {
// Checkpoint should produce the same result as the original function
fn := func(inputs []*Array) []*Array {
x := inputs[0]
return []*Array{Mul(x, x)}
}
cpFn := Checkpoint(fn)
x := FromValue(float32(5.0))
result := cpFn([]*Array{x})
Materialize(result[0])
got := result[0].Float()
if math.Abs(got-25.0) > 1e-5 {
t.Errorf("Checkpoint result = %f, want 25.0", got)
}
}
func TestSumAll(t *testing.T) {
a := FromValues([]float32{1.0, 2.0, 3.0, 4.0}, 2, 2)
result := SumAll(a)
Materialize(result)
got := result.Float()
if math.Abs(got-10.0) > 1e-5 {
t.Errorf("SumAll = %f, want 10.0", got)
}
}
func TestMeanAll(t *testing.T) {
a := FromValues([]float32{2.0, 4.0, 6.0, 8.0}, 2, 2)
result := MeanAll(a)
Materialize(result)
got := result.Float()
if math.Abs(got-5.0) > 1e-5 {
t.Errorf("MeanAll = %f, want 5.0", got)
}
}


@@ -335,3 +335,19 @@ func PutAlongAxis(a, indices, values *Array, axis int) *Array {
C.mlx_put_along_axis(&out.ctx, a.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
// TakeAlongAxis gathers elements from a along axis using indices.
// Unlike Take, this uses the same number of dimensions for indices and input.
func TakeAlongAxis(a, indices *Array, axis int) *Array {
out := New("TAKE_ALONG_AXIS", a, indices)
C.mlx_take_along_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
// LogSumExp computes log(sum(exp(a))) along the given axis.
// Numerically stable reduction for cross-entropy loss.
func LogSumExp(a *Array, axis int, keepDims bool) *Array {
out := New("LOGSUMEXP", a)
C.mlx_logsumexp_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
return out
}