From e9973aef3cb7a622d7567491e48cee8ccaefe075 Mon Sep 17 00:00:00 2001 From: Snider Date: Tue, 17 Feb 2026 17:18:47 +0000 Subject: [PATCH] =?UTF-8?q?feat(mlx):=20add=20autograd=20=E2=80=94=20VJP,?= =?UTF-8?q?=20JVP,=20ValueAndGrad,=20loss=20functions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Native Go bindings for MLX-C gradient computation on Apple Silicon. Foundation for LoRA training without Python. - VJP (reverse-mode autodiff) for backward pass - JVP (forward-mode autodiff) for directional derivatives - ValueAndGrad for combined loss + gradient computation - Checkpoint for memory-efficient gradient recomputation - CrossEntropyLoss (numerically stable via LogSumExp) - MSELoss, Log, SumAll, MeanAll, OnesLike helpers - TakeAlongAxis and LogSumExp ops - Fix closure callback null vector bug (affects compile.go too) - Fix Float() returning 0 for float32 arrays 14 tests passing on Metal GPU. Co-Authored-By: Virgil --- mlx/array.go | 14 +- mlx/compile.go | 8 +- mlx/grad.go | 337 +++++++++++++++++++++++++++++++++++++++++++++++ mlx/grad_test.go | 332 ++++++++++++++++++++++++++++++++++++++++++++++ mlx/ops.go | 16 +++ 5 files changed, 702 insertions(+), 5 deletions(-) create mode 100644 mlx/grad.go create mode 100644 mlx/grad_test.go diff --git a/mlx/array.go b/mlx/array.go index ee4f4a8..0968bda 100644 --- a/mlx/array.go +++ b/mlx/array.go @@ -204,10 +204,18 @@ func (t Array) Int() int { } // Float extracts a scalar float64 value. +// Handles both float32 and float64 array dtypes. func (t Array) Float() float64 { - var item C.double - C.mlx_array_item_float64(&item, t.ctx) - return float64(item) + switch t.Dtype() { + case DTypeFloat32: + var item C.float + C.mlx_array_item_float32(&item, t.ctx) + return float64(item) + default: + var item C.double + C.mlx_array_item_float64(&item, t.ctx) + return float64(item) + } } // Ints extracts all elements as int slice (from int32 data). diff --git a/mlx/compile.go b/mlx/compile.go index f62426f..8440f4b 100644 --- a/mlx/compile.go +++ b/mlx/compile.go @@ -49,10 +49,14 @@ func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payl // Call user function goOutputs := fn(goInputs) - // Set outputs + // The output vector arrives with ctx=nullptr. Create a valid vector, + // fill it, then copy to the output pointer. + tmp := C.mlx_vector_array_new() for _, out := range goOutputs { - C.mlx_vector_array_append_value(*outputs, out.ctx) + C.mlx_vector_array_append_value(tmp, out.ctx) } + C.mlx_vector_array_set(outputs, tmp) + C.mlx_vector_array_free(tmp) return 0 } diff --git a/mlx/grad.go b/mlx/grad.go new file mode 100644 index 0000000..f46ee7b --- /dev/null +++ b/mlx/grad.go @@ -0,0 +1,337 @@ +//go:build darwin && arm64 + +package mlx + +/* +#include "mlx/c/mlx.h" + +// Callback for gradient closures — same signature as goCompiledFunc. 
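+// The payload carries the Go-side registry ID used to look up the wrapped
+// function; the callback returns 0 on success and non-zero on failure.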
+extern int goGradFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload); + +static mlx_closure new_grad_closure(void *payload) { + return mlx_closure_new_func_payload(&goGradFunc, payload, NULL); +} +*/ +import "C" + +import ( + "fmt" + "sync" + "sync/atomic" + "unsafe" +) + +// --- Closure registry (separate from compile.go's registry) --- + +var ( + gradFuncs sync.Map + gradNextID atomic.Uintptr +) + +//export goGradFunc +func goGradFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int { + id := uintptr(payload) + fnI, ok := gradFuncs.Load(id) + if !ok { + return 1 + } + fn := fnI.(func([]*Array) []*Array) + + nInputs := int(C.mlx_vector_array_size(inputs)) + goInputs := make([]*Array, nInputs) + for i := 0; i < nInputs; i++ { + a := New("GRAD_INPUT") + C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i)) + goInputs[i] = a + } + + goOutputs := fn(goInputs) + + // The output vector arrives with ctx=nullptr (from internal mlx_vector_array_new_). + // Create a valid vector, fill it, then copy to the output pointer. + tmp := C.mlx_vector_array_new() + for _, out := range goOutputs { + C.mlx_vector_array_append_value(tmp, out.ctx) + } + C.mlx_vector_array_set(outputs, tmp) + C.mlx_vector_array_free(tmp) + return 0 +} + +// newClosure registers a Go function as an MLX closure for autograd. +func newClosure(fn func([]*Array) []*Array) C.mlx_closure { + id := gradNextID.Add(1) + gradFuncs.Store(id, fn) + return C.new_grad_closure(unsafe.Pointer(id)) +} + +// --- VJP (Vector-Jacobian Product) — Reverse Mode --- + +// VJP computes the vector-Jacobian product (reverse-mode autodiff). +// Given a function fn, input primals, and output cotangents (upstream gradients), +// returns (outputs, gradients) where gradients are w.r.t. the primals. +// +// This is the fundamental backward pass operation. +func VJP(fn func([]*Array) []*Array, primals []*Array, cotangents []*Array) (outputs []*Array, vjps []*Array, err error) { + Init() + + closure := newClosure(fn) + defer C.mlx_closure_free(closure) + + // Pack primals into vector + primalsVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(primalsVec) + for _, p := range primals { + C.mlx_vector_array_append_value(primalsVec, p.ctx) + } + + // Pack cotangents into vector + cotangentsVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(cotangentsVec) + for _, c := range cotangents { + C.mlx_vector_array_append_value(cotangentsVec, c.ctx) + } + + // Call mlx_vjp + var outVec, vjpVec C.mlx_vector_array + outVec = C.mlx_vector_array_new() + vjpVec = C.mlx_vector_array_new() + defer C.mlx_vector_array_free(outVec) + defer C.mlx_vector_array_free(vjpVec) + + rc := C.mlx_vjp(&outVec, &vjpVec, closure, primalsVec, cotangentsVec) + if rc != 0 { + checkError() + return nil, nil, fmt.Errorf("mlx: vjp failed") + } + + outputs = vectorToArrays(outVec) + vjps = vectorToArrays(vjpVec) + return outputs, vjps, nil +} + +// --- JVP (Jacobian-Vector Product) — Forward Mode --- + +// JVP computes the Jacobian-vector product (forward-mode autodiff). +// Given a function fn, input primals, and input tangents (perturbation directions), +// returns (outputs, output_tangents). +// +// Useful for directional derivatives and Hessian-vector products. 
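+//
+// A minimal sketch, reusing helpers defined elsewhere in this package
+// (FromValue, Mul, Materialize). For f(x) = x^2 the tangent maps to 2*x*v:
+//
+//	fn := func(in []*Array) []*Array { return []*Array{Mul(in[0], in[0])} }
+//	x := FromValue(float32(3.0))
+//	v := FromValue(float32(1.0))
+//	_, jvps, _ := JVP(fn, []*Array{x}, []*Array{v})
+//	// jvps[0] is 2*x*v = 6 (call Materialize before reading it)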
+func JVP(fn func([]*Array) []*Array, primals []*Array, tangents []*Array) (outputs []*Array, jvps []*Array, err error) { + Init() + + closure := newClosure(fn) + defer C.mlx_closure_free(closure) + + primalsVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(primalsVec) + for _, p := range primals { + C.mlx_vector_array_append_value(primalsVec, p.ctx) + } + + tangentsVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(tangentsVec) + for _, t := range tangents { + C.mlx_vector_array_append_value(tangentsVec, t.ctx) + } + + var outVec, jvpVec C.mlx_vector_array + outVec = C.mlx_vector_array_new() + jvpVec = C.mlx_vector_array_new() + defer C.mlx_vector_array_free(outVec) + defer C.mlx_vector_array_free(jvpVec) + + rc := C.mlx_jvp(&outVec, &jvpVec, closure, primalsVec, tangentsVec) + if rc != 0 { + checkError() + return nil, nil, fmt.Errorf("mlx: jvp failed") + } + + outputs = vectorToArrays(outVec) + jvps = vectorToArrays(jvpVec) + return outputs, jvps, nil +} + +// --- ValueAndGrad — Combined Forward + Backward --- + +// GradFn is a function that computes both the loss value and gradients +// with respect to specified arguments. Call it with inputs to get +// (values, gradients). +type GradFn struct { + cls C.mlx_closure_value_and_grad +} + +// ValueAndGrad creates a GradFn that computes both the function value and +// gradients with respect to the arguments at the given indices. +// +// The returned GradFn can be called repeatedly with different inputs. +// This is the primary API for training loops: +// +// lossFn := func(params []*Array) []*Array { return []*Array{loss} } +// grad := mlx.ValueAndGrad(lossFn, 0) // differentiate w.r.t. first arg +// values, grads := grad.Apply(params...) +func ValueAndGrad(fn func([]*Array) []*Array, argnums ...int) *GradFn { + Init() + + closure := newClosure(fn) + defer C.mlx_closure_free(closure) + + // Default: differentiate w.r.t. first argument + if len(argnums) == 0 { + argnums = []int{0} + } + + cArgs := make([]C.int, len(argnums)) + for i, a := range argnums { + cArgs[i] = C.int(a) + } + + g := &GradFn{} + C.mlx_value_and_grad(&g.cls, closure, &cArgs[0], C.size_t(len(cArgs))) + return g +} + +// Apply calls the gradient function with the given inputs. +// Returns (values, gradients) where values are the function outputs +// and gradients are w.r.t. the arguments specified in ValueAndGrad. +func (g *GradFn) Apply(inputs ...*Array) (values []*Array, grads []*Array, err error) { + inputVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(inputVec) + for _, in := range inputs { + C.mlx_vector_array_append_value(inputVec, in.ctx) + } + + var valVec, gradVec C.mlx_vector_array + valVec = C.mlx_vector_array_new() + gradVec = C.mlx_vector_array_new() + defer C.mlx_vector_array_free(valVec) + defer C.mlx_vector_array_free(gradVec) + + rc := C.mlx_closure_value_and_grad_apply(&valVec, &gradVec, g.cls, inputVec) + if rc != 0 { + checkError() + return nil, nil, fmt.Errorf("mlx: value_and_grad apply failed") + } + + values = vectorToArrays(valVec) + grads = vectorToArrays(gradVec) + return values, grads, nil +} + +// Free releases the underlying C closure. +func (g *GradFn) Free() { + if g.cls.ctx != nil { + C.mlx_closure_value_and_grad_free(g.cls) + g.cls.ctx = nil + } +} + +// --- Checkpoint — Memory-Efficient Gradient Recomputation --- + +// Checkpoint wraps a function so that during backward pass, intermediate +// activations are recomputed rather than stored. Trades compute for memory. 
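+// A rough usage sketch (fn here is any forward function; the wrapper returns
+// the same outputs, but gradients through it recompute intermediates):
+//
+//	cpFn := Checkpoint(fn)
+//	outputs := cpFn(inputs)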
+// +// Use this for memory-constrained training (large models on limited VRAM). +func Checkpoint(fn func([]*Array) []*Array) func([]*Array) []*Array { + Init() + + closure := newClosure(fn) + defer C.mlx_closure_free(closure) + + var checkpointed C.mlx_closure + C.mlx_checkpoint(&checkpointed, closure) + + // Register the checkpointed closure so it stays alive + id := gradNextID.Add(1) + cpClosure := checkpointed // capture for the returned function + + // Return a Go function that calls through the checkpointed closure + return func(inputs []*Array) []*Array { + _ = id // prevent GC of closure reference + + inputVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(inputVec) + for _, in := range inputs { + C.mlx_vector_array_append_value(inputVec, in.ctx) + } + + outVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(outVec) + + C.mlx_closure_apply(&outVec, cpClosure, inputVec) + return vectorToArrays(outVec) + } +} + +// --- Loss Functions --- + +// CrossEntropyLoss computes cross-entropy loss between logits and integer targets. +// logits: [..., V] (raw model output, pre-softmax, last dim = vocab) +// targets: [...] (integer token IDs, same shape as logits minus last dim) +// Returns scalar loss averaged over all positions. +// +// Numerically stable: loss_i = logsumexp(logits_i) - logits_i[target_i] +func CrossEntropyLoss(logits, targets *Array) *Array { + Init() + // LogSumExp along last axis (vocab dimension) for numerical stability + lse := LogSumExp(logits, -1, false) + + // Gather the logit at the target index: logits[..., target] + // targets needs to be [..., 1] for take_along_axis + tgtExpanded := ExpandDims(targets, -1) + gathered := TakeAlongAxis(logits, tgtExpanded, -1) + gathered = Squeeze(gathered, -1) // back to [...] shape + + // Per-position loss: logsumexp - logit_at_target + perPos := Subtract(lse, gathered) + + // Mean over all positions + return MeanAll(perPos) +} + +// MSELoss computes the mean squared error loss: mean((predictions - targets)^2). +func MSELoss(predictions, targets *Array) *Array { + diff := Subtract(predictions, targets) + sq := Square(diff) + return MeanAll(sq) +} + +// --- Helpers --- + +// vectorToArrays extracts all arrays from an mlx_vector_array. +func vectorToArrays(vec C.mlx_vector_array) []*Array { + n := int(C.mlx_vector_array_size(vec)) + out := make([]*Array, n) + for i := 0; i < n; i++ { + a := New("VEC_OUT") + C.mlx_vector_array_get(&a.ctx, vec, C.size_t(i)) + out[i] = a + } + return out +} + +// Log returns element-wise natural logarithm. +func Log(a *Array) *Array { + out := New("LOG", a) + C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +// SumAll reduces by summation over all elements, returning a scalar. +func SumAll(a *Array) *Array { + flat := Reshape(a, -1) + return Sum(flat, 0, false) +} + +// MeanAll reduces by averaging over all elements, returning a scalar. +func MeanAll(a *Array) *Array { + flat := Reshape(a, -1) + return Mean(flat, 0, false) +} + +// OnesLike creates an array of ones with the same shape and type as the input. 
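+// Handy, for example, as an all-ones cotangent (upstream gradient) when
+// seeding a VJP call.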
+func OnesLike(a *Array) *Array { + out := New("ONES_LIKE", a) + C.mlx_ones_like(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} diff --git a/mlx/grad_test.go b/mlx/grad_test.go new file mode 100644 index 0000000..0a00751 --- /dev/null +++ b/mlx/grad_test.go @@ -0,0 +1,332 @@ +//go:build darwin && arm64 + +package mlx + +import ( + "math" + "testing" +) + +func TestVJP_SimpleSquare(t *testing.T) { + // f(x) = x^2, df/dx = 2x + // At x=3: f(3)=9, df/dx=6 + fn := func(inputs []*Array) []*Array { + x := inputs[0] + return []*Array{Mul(x, x)} + } + + x := FromValue(float32(3.0)) + cotangent := FromValue(float32(1.0)) // upstream grad = 1 + + outputs, grads, err := VJP(fn, []*Array{x}, []*Array{cotangent}) + if err != nil { + t.Fatalf("VJP failed: %v", err) + } + + Materialize(outputs[0], grads[0]) + + got := outputs[0].Float() + if math.Abs(got-9.0) > 1e-5 { + t.Errorf("output = %f, want 9.0", got) + } + + grad := grads[0].Float() + if math.Abs(grad-6.0) > 1e-5 { + t.Errorf("grad = %f, want 6.0", grad) + } +} + +func TestVJP_Addition(t *testing.T) { + // f(x, y) = x + y, df/dx = 1, df/dy = 1 + fn := func(inputs []*Array) []*Array { + return []*Array{Add(inputs[0], inputs[1])} + } + + x := FromValue(float32(2.0)) + y := FromValue(float32(5.0)) + cotangent := FromValue(float32(1.0)) + + _, grads, err := VJP(fn, []*Array{x, y}, []*Array{cotangent}) + if err != nil { + t.Fatalf("VJP failed: %v", err) + } + + Materialize(grads...) + + if math.Abs(grads[0].Float()-1.0) > 1e-5 { + t.Errorf("dx = %f, want 1.0", grads[0].Float()) + } + if math.Abs(grads[1].Float()-1.0) > 1e-5 { + t.Errorf("dy = %f, want 1.0", grads[1].Float()) + } +} + +func TestVJP_MatmulGrad(t *testing.T) { + // f(W) = sum(W @ x) — gradient of sum(matmul) w.r.t. W + // For W=[2,2], x=[2,1]: dL/dW = ones @ x^T + x := FromValues([]float32{1.0, 2.0}, 2, 1) + w := FromValues([]float32{1.0, 0.0, 0.0, 1.0}, 2, 2) // identity + + fn := func(inputs []*Array) []*Array { + result := Matmul(inputs[0], x) + return []*Array{SumAll(result)} + } + + cotangent := FromValue(float32(1.0)) + + outputs, grads, err := VJP(fn, []*Array{w}, []*Array{cotangent}) + if err != nil { + t.Fatalf("VJP failed: %v", err) + } + + Materialize(outputs[0], grads[0]) + + // W @ x with W=I, x=[1,2]^T gives [1,2]^T, sum=3 + got := outputs[0].Float() + if math.Abs(got-3.0) > 1e-5 { + t.Errorf("output = %f, want 3.0", got) + } + + // Gradient of sum(W@x) w.r.t. 
W is outer product: ones @ x^T + // = [[1,2],[1,2]] + gradFloats := grads[0].Floats() + expected := []float32{1.0, 2.0, 1.0, 2.0} + for i, exp := range expected { + if math.Abs(float64(gradFloats[i]-exp)) > 1e-5 { + t.Errorf("grad[%d] = %f, want %f", i, gradFloats[i], exp) + } + } +} + +func TestJVP_SimpleSquare(t *testing.T) { + // f(x) = x^2, JVP with tangent v: df = 2x * v + // At x=3, v=1: df = 6 + fn := func(inputs []*Array) []*Array { + x := inputs[0] + return []*Array{Mul(x, x)} + } + + x := FromValue(float32(3.0)) + tangent := FromValue(float32(1.0)) + + outputs, jvps, err := JVP(fn, []*Array{x}, []*Array{tangent}) + if err != nil { + t.Fatalf("JVP failed: %v", err) + } + + Materialize(outputs[0], jvps[0]) + + got := outputs[0].Float() + if math.Abs(got-9.0) > 1e-5 { + t.Errorf("output = %f, want 9.0", got) + } + + jvp := jvps[0].Float() + if math.Abs(jvp-6.0) > 1e-5 { + t.Errorf("jvp = %f, want 6.0", jvp) + } +} + +func TestValueAndGrad_Quadratic(t *testing.T) { + // f(x) = x^2 + 2x + 1 = (x+1)^2 + // f'(x) = 2x + 2 + // At x=3: f(3) = 16, f'(3) = 8 + fn := func(inputs []*Array) []*Array { + x := inputs[0] + x2 := Mul(x, x) + two_x := MulScalar(x, 2.0) + one := FromValue(float32(1.0)) + return []*Array{Add(Add(x2, two_x), one)} + } + + grad := ValueAndGrad(fn, 0) + defer grad.Free() + + x := FromValue(float32(3.0)) + values, grads, err := grad.Apply(x) + if err != nil { + t.Fatalf("ValueAndGrad failed: %v", err) + } + + Materialize(values[0], grads[0]) + + val := values[0].Float() + if math.Abs(val-16.0) > 1e-5 { + t.Errorf("value = %f, want 16.0", val) + } + + g := grads[0].Float() + if math.Abs(g-8.0) > 1e-5 { + t.Errorf("grad = %f, want 8.0", g) + } +} + +func TestValueAndGrad_MultiArg(t *testing.T) { + // f(x, y) = x*y, df/dx = y, df/dy = x + // At x=3, y=4: f=12, dx=4, dy=3 + fn := func(inputs []*Array) []*Array { + return []*Array{Mul(inputs[0], inputs[1])} + } + + // Differentiate w.r.t. 
both arguments + grad := ValueAndGrad(fn, 0, 1) + defer grad.Free() + + x := FromValue(float32(3.0)) + y := FromValue(float32(4.0)) + values, grads, err := grad.Apply(x, y) + if err != nil { + t.Fatalf("ValueAndGrad failed: %v", err) + } + + Materialize(values[0], grads[0], grads[1]) + + val := values[0].Float() + if math.Abs(val-12.0) > 1e-5 { + t.Errorf("value = %f, want 12.0", val) + } + + dx := grads[0].Float() + if math.Abs(dx-4.0) > 1e-5 { + t.Errorf("dx = %f, want 4.0 (y)", dx) + } + + dy := grads[1].Float() + if math.Abs(dy-3.0) > 1e-5 { + t.Errorf("dy = %f, want 3.0 (x)", dy) + } +} + +func TestValueAndGrad_Reusable(t *testing.T) { + // Verify GradFn can be called multiple times + fn := func(inputs []*Array) []*Array { + x := inputs[0] + return []*Array{Mul(x, x)} // x^2, grad = 2x + } + + grad := ValueAndGrad(fn) + defer grad.Free() + + for _, tc := range []struct { + x float32 + want float64 // expected gradient + }{ + {2.0, 4.0}, + {5.0, 10.0}, + {-3.0, -6.0}, + {0.0, 0.0}, + } { + x := FromValue(tc.x) + _, grads, err := grad.Apply(x) + if err != nil { + t.Fatalf("Apply failed for x=%f: %v", tc.x, err) + } + Materialize(grads[0]) + + g := grads[0].Float() + if math.Abs(g-tc.want) > 1e-5 { + t.Errorf("x=%f: grad = %f, want %f", tc.x, g, tc.want) + } + } +} + +func TestCrossEntropyLoss(t *testing.T) { + // Simple 3-class classification + // logits = [1.0, 2.0, 3.0], target = 2 (class index) + // Manual: logsumexp([1,2,3]) = 3 + log(exp(-2)+exp(-1)+1) + // = 3 + log(0.1353 + 0.3679 + 1.0) = 3 + log(1.5032) = 3.4076 + // loss = 3.4076 - 3.0 = 0.4076 + logits := FromValues([]float32{1.0, 2.0, 3.0}, 1, 3) // [1, 3] + targets := FromValues([]int32{2}, 1) // [1] + + loss := CrossEntropyLoss(logits, targets) + Materialize(loss) + + got := loss.Float() + expected := 0.4076 + if math.Abs(got-expected) > 0.01 { + t.Errorf("CrossEntropyLoss = %f, want ~%f", got, expected) + } +} + +func TestMSELoss(t *testing.T) { + pred := FromValues([]float32{1.0, 2.0, 3.0}, 3) + target := FromValues([]float32{1.5, 2.5, 3.5}, 3) + + loss := MSELoss(pred, target) + Materialize(loss) + + // MSE = mean((0.5)^2, (0.5)^2, (0.5)^2) = mean(0.25, 0.25, 0.25) = 0.25 + got := loss.Float() + if math.Abs(got-0.25) > 1e-5 { + t.Errorf("MSELoss = %f, want 0.25", got) + } +} + +func TestLogSumExp(t *testing.T) { + // logsumexp([1, 2, 3]) along axis -1 + a := FromValues([]float32{1.0, 2.0, 3.0}, 1, 3) + result := LogSumExp(a, -1, false) + Materialize(result) + + // = 3 + log(exp(-2) + exp(-1) + 1) = 3 + log(1.5032) ≈ 3.4076 + got := result.Float() + expected := 3.4076 + if math.Abs(got-expected) > 0.01 { + t.Errorf("LogSumExp = %f, want ~%f", got, expected) + } +} + +func TestOnesLike(t *testing.T) { + a := FromValues([]float32{1.0, 2.0, 3.0}, 3) + ones := OnesLike(a) + Materialize(ones) + + floats := ones.Floats() + for i, f := range floats { + if f != 1.0 { + t.Errorf("OnesLike[%d] = %f, want 1.0", i, f) + } + } +} + +func TestCheckpoint(t *testing.T) { + // Checkpoint should produce the same result as the original function + fn := func(inputs []*Array) []*Array { + x := inputs[0] + return []*Array{Mul(x, x)} + } + + cpFn := Checkpoint(fn) + + x := FromValue(float32(5.0)) + result := cpFn([]*Array{x}) + Materialize(result[0]) + + got := result[0].Float() + if math.Abs(got-25.0) > 1e-5 { + t.Errorf("Checkpoint result = %f, want 25.0", got) + } +} + +func TestSumAll(t *testing.T) { + a := FromValues([]float32{1.0, 2.0, 3.0, 4.0}, 2, 2) + result := SumAll(a) + Materialize(result) + + got := result.Float() + if 
math.Abs(got-10.0) > 1e-5 { + t.Errorf("SumAll = %f, want 10.0", got) + } +} + +func TestMeanAll(t *testing.T) { + a := FromValues([]float32{2.0, 4.0, 6.0, 8.0}, 2, 2) + result := MeanAll(a) + Materialize(result) + + got := result.Float() + if math.Abs(got-5.0) > 1e-5 { + t.Errorf("MeanAll = %f, want 5.0", got) + } +} diff --git a/mlx/ops.go b/mlx/ops.go index c743ce9..99f953b 100644 --- a/mlx/ops.go +++ b/mlx/ops.go @@ -335,3 +335,19 @@ func PutAlongAxis(a, indices, values *Array, axis int) *Array { C.mlx_put_along_axis(&out.ctx, a.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx) return out } + +// TakeAlongAxis gathers elements from a along axis using indices. +// Unlike Take, this uses the same number of dimensions for indices and input. +func TakeAlongAxis(a, indices *Array, axis int) *Array { + out := New("TAKE_ALONG_AXIS", a, indices) + C.mlx_take_along_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx) + return out +} + +// LogSumExp computes log(sum(exp(a))) along the given axis. +// Numerically stable reduction for cross-entropy loss. +func LogSumExp(a *Array, axis int, keepDims bool) *Array { + out := New("LOGSUMEXP", a) + C.mlx_logsumexp_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx) + return out +}
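
A hedged sketch of how these pieces compose into a training step. Only
ValueAndGrad, CrossEntropyLoss, Matmul, Subtract, MulScalar, and Materialize
come from this patch and its tests; x, targets, w, and lr are stand-ins for
real data and hyperparameters, and the SGD update is illustrative (no
optimizer ships here).

    // loss(W) = cross-entropy(x @ W, targets)
    lossFn := func(params []*Array) []*Array {
        logits := Matmul(x, params[0])
        return []*Array{CrossEntropyLoss(logits, targets)}
    }

    grad := ValueAndGrad(lossFn, 0) // differentiate w.r.t. W
    defer grad.Free()

    values, grads, err := grad.Apply(w)
    if err != nil {
        // handle error
    }
    Materialize(values[0], grads[0])

    // plain SGD step: W <- W - lr * dL/dW
    w = Subtract(w, MulScalar(grads[0], lr))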