feat(mlx): add autograd — VJP, JVP, ValueAndGrad, loss functions
Native Go bindings for MLX-C gradient computation on Apple Silicon. Foundation for LoRA training without Python.

- VJP (reverse-mode autodiff) for backward pass
- JVP (forward-mode autodiff) for directional derivatives
- ValueAndGrad for combined loss + gradient computation
- Checkpoint for memory-efficient gradient recomputation
- CrossEntropyLoss (numerically stable via LogSumExp)
- MSELoss, Log, SumAll, MeanAll, OnesLike helpers
- TakeAlongAxis and LogSumExp ops
- Fix closure callback null vector bug (affects compile.go too)
- Fix Float() returning 0 for float32 arrays

14 tests passing on Metal GPU.

Co-Authored-By: Virgil <virgil@lethean.io>
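For orientation, a minimal sketch of a training-style step using the API added in this commit (ValueAndGrad, Apply, Free, MSELoss, plus the FromValues/Matmul/Materialize helpers that appear in the tests below). The function name, shapes, and values are illustrative only, not part of the commit:

    func exampleTrainingStep() {
        x := FromValues([]float32{1, 2}, 1, 2)     // input, shape [1, 2]
        y := FromValues([]float32{3}, 1, 1)        // target, shape [1, 1]
        w := FromValues([]float32{0.5, 0.5}, 2, 1) // weights, shape [2, 1]

        // Loss as a function of all arguments; ValueAndGrad(lossFn, 0)
        // differentiates with respect to argument 0 (the weights).
        lossFn := func(args []*Array) []*Array {
            pred := Matmul(args[1], args[0]) // x @ w
            return []*Array{MSELoss(pred, args[2])}
        }

        step := ValueAndGrad(lossFn, 0)
        defer step.Free()

        values, grads, err := step.Apply(w, x, y)
        if err != nil {
            panic(err) // illustrative; real code would propagate the error
        }
        Materialize(values[0], grads[0])
        // values[0] is the scalar loss, grads[0] is dLoss/dw with shape [2, 1]
    }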
parent 92c6282d50
commit e9973aef3c

5 changed files with 702 additions and 5 deletions

mlx/array.go (14 changed lines)
@@ -204,10 +204,18 @@ func (t Array) Int() int {
 }
 
 // Float extracts a scalar float64 value.
+// Handles both float32 and float64 array dtypes.
 func (t Array) Float() float64 {
-    var item C.double
-    C.mlx_array_item_float64(&item, t.ctx)
-    return float64(item)
+    switch t.Dtype() {
+    case DTypeFloat32:
+        var item C.float
+        C.mlx_array_item_float32(&item, t.ctx)
+        return float64(item)
+    default:
+        var item C.double
+        C.mlx_array_item_float64(&item, t.ctx)
+        return float64(item)
+    }
 }
 
 // Ints extracts all elements as int slice (from int32 data).
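A quick illustration of the behavior change, using helpers that appear in the tests later in this commit (FromValue, Materialize); before this fix every dtype went through mlx_array_item_float64, so Float() came back as 0 for float32 arrays:

    x := FromValue(float32(3.0)) // float32-typed scalar array
    Materialize(x)
    v := x.Float() // 3.0 with this change; previously 0 for float32 arrays
    _ = v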
mlx/compile.go

@@ -49,10 +49,14 @@ func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
     // Call user function
     goOutputs := fn(goInputs)
 
-    // Set outputs
+    // The output vector arrives with ctx=nullptr. Create a valid vector,
+    // fill it, then copy to the output pointer.
+    tmp := C.mlx_vector_array_new()
     for _, out := range goOutputs {
-        C.mlx_vector_array_append_value(*outputs, out.ctx)
+        C.mlx_vector_array_append_value(tmp, out.ctx)
     }
+    C.mlx_vector_array_set(outputs, tmp)
+    C.mlx_vector_array_free(tmp)
     return 0
 }
 
mlx/grad.go (new file, 337 lines)

@@ -0,0 +1,337 @@
//go:build darwin && arm64

package mlx

/*
#include "mlx/c/mlx.h"

// Callback for gradient closures — same signature as goCompiledFunc.
extern int goGradFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload);

static mlx_closure new_grad_closure(void *payload) {
    return mlx_closure_new_func_payload(&goGradFunc, payload, NULL);
}
*/
import "C"

import (
    "fmt"
    "sync"
    "sync/atomic"
    "unsafe"
)

// --- Closure registry (separate from compile.go's registry) ---

var (
    gradFuncs  sync.Map
    gradNextID atomic.Uintptr
)

//export goGradFunc
func goGradFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
    id := uintptr(payload)
    fnI, ok := gradFuncs.Load(id)
    if !ok {
        return 1
    }
    fn := fnI.(func([]*Array) []*Array)

    nInputs := int(C.mlx_vector_array_size(inputs))
    goInputs := make([]*Array, nInputs)
    for i := 0; i < nInputs; i++ {
        a := New("GRAD_INPUT")
        C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
        goInputs[i] = a
    }

    goOutputs := fn(goInputs)

    // The output vector arrives with ctx=nullptr (from internal mlx_vector_array_new_).
    // Create a valid vector, fill it, then copy to the output pointer.
    tmp := C.mlx_vector_array_new()
    for _, out := range goOutputs {
        C.mlx_vector_array_append_value(tmp, out.ctx)
    }
    C.mlx_vector_array_set(outputs, tmp)
    C.mlx_vector_array_free(tmp)
    return 0
}

// newClosure registers a Go function as an MLX closure for autograd.
func newClosure(fn func([]*Array) []*Array) C.mlx_closure {
    id := gradNextID.Add(1)
    gradFuncs.Store(id, fn)
    return C.new_grad_closure(unsafe.Pointer(id))
}

// --- VJP (Vector-Jacobian Product) — Reverse Mode ---

// VJP computes the vector-Jacobian product (reverse-mode autodiff).
// Given a function fn, input primals, and output cotangents (upstream gradients),
// returns (outputs, gradients) where gradients are w.r.t. the primals.
//
// This is the fundamental backward pass operation.
func VJP(fn func([]*Array) []*Array, primals []*Array, cotangents []*Array) (outputs []*Array, vjps []*Array, err error) {
    Init()

    closure := newClosure(fn)
    defer C.mlx_closure_free(closure)

    // Pack primals into vector
    primalsVec := C.mlx_vector_array_new()
    defer C.mlx_vector_array_free(primalsVec)
    for _, p := range primals {
        C.mlx_vector_array_append_value(primalsVec, p.ctx)
    }

    // Pack cotangents into vector
    cotangentsVec := C.mlx_vector_array_new()
    defer C.mlx_vector_array_free(cotangentsVec)
    for _, c := range cotangents {
        C.mlx_vector_array_append_value(cotangentsVec, c.ctx)
    }

    // Call mlx_vjp
    var outVec, vjpVec C.mlx_vector_array
    outVec = C.mlx_vector_array_new()
    vjpVec = C.mlx_vector_array_new()
    defer C.mlx_vector_array_free(outVec)
    defer C.mlx_vector_array_free(vjpVec)

    rc := C.mlx_vjp(&outVec, &vjpVec, closure, primalsVec, cotangentsVec)
    if rc != 0 {
        checkError()
        return nil, nil, fmt.Errorf("mlx: vjp failed")
    }

    outputs = vectorToArrays(outVec)
    vjps = vectorToArrays(vjpVec)
    return outputs, vjps, nil
}

// --- JVP (Jacobian-Vector Product) — Forward Mode ---

// JVP computes the Jacobian-vector product (forward-mode autodiff).
// Given a function fn, input primals, and input tangents (perturbation directions),
// returns (outputs, output_tangents).
//
// Useful for directional derivatives and Hessian-vector products.
func JVP(fn func([]*Array) []*Array, primals []*Array, tangents []*Array) (outputs []*Array, jvps []*Array, err error) {
    Init()

    closure := newClosure(fn)
    defer C.mlx_closure_free(closure)

    primalsVec := C.mlx_vector_array_new()
    defer C.mlx_vector_array_free(primalsVec)
    for _, p := range primals {
        C.mlx_vector_array_append_value(primalsVec, p.ctx)
    }

    tangentsVec := C.mlx_vector_array_new()
    defer C.mlx_vector_array_free(tangentsVec)
    for _, t := range tangents {
        C.mlx_vector_array_append_value(tangentsVec, t.ctx)
    }

    var outVec, jvpVec C.mlx_vector_array
    outVec = C.mlx_vector_array_new()
    jvpVec = C.mlx_vector_array_new()
    defer C.mlx_vector_array_free(outVec)
    defer C.mlx_vector_array_free(jvpVec)

    rc := C.mlx_jvp(&outVec, &jvpVec, closure, primalsVec, tangentsVec)
    if rc != 0 {
        checkError()
        return nil, nil, fmt.Errorf("mlx: jvp failed")
    }

    outputs = vectorToArrays(outVec)
    jvps = vectorToArrays(jvpVec)
    return outputs, jvps, nil
}

// --- ValueAndGrad — Combined Forward + Backward ---

// GradFn is a function that computes both the loss value and gradients
// with respect to specified arguments. Call it with inputs to get
// (values, gradients).
type GradFn struct {
    cls C.mlx_closure_value_and_grad
}

// ValueAndGrad creates a GradFn that computes both the function value and
// gradients with respect to the arguments at the given indices.
//
// The returned GradFn can be called repeatedly with different inputs.
// This is the primary API for training loops:
//
//	lossFn := func(params []*Array) []*Array { return []*Array{loss} }
//	grad := mlx.ValueAndGrad(lossFn, 0) // differentiate w.r.t. first arg
//	values, grads := grad.Apply(params...)
func ValueAndGrad(fn func([]*Array) []*Array, argnums ...int) *GradFn {
    Init()

    closure := newClosure(fn)
    defer C.mlx_closure_free(closure)

    // Default: differentiate w.r.t. first argument
    if len(argnums) == 0 {
        argnums = []int{0}
    }

    cArgs := make([]C.int, len(argnums))
    for i, a := range argnums {
        cArgs[i] = C.int(a)
    }

    g := &GradFn{}
    C.mlx_value_and_grad(&g.cls, closure, &cArgs[0], C.size_t(len(cArgs)))
    return g
}

// Apply calls the gradient function with the given inputs.
// Returns (values, gradients) where values are the function outputs
// and gradients are w.r.t. the arguments specified in ValueAndGrad.
func (g *GradFn) Apply(inputs ...*Array) (values []*Array, grads []*Array, err error) {
    inputVec := C.mlx_vector_array_new()
    defer C.mlx_vector_array_free(inputVec)
    for _, in := range inputs {
        C.mlx_vector_array_append_value(inputVec, in.ctx)
    }

    var valVec, gradVec C.mlx_vector_array
    valVec = C.mlx_vector_array_new()
    gradVec = C.mlx_vector_array_new()
    defer C.mlx_vector_array_free(valVec)
    defer C.mlx_vector_array_free(gradVec)

    rc := C.mlx_closure_value_and_grad_apply(&valVec, &gradVec, g.cls, inputVec)
    if rc != 0 {
        checkError()
        return nil, nil, fmt.Errorf("mlx: value_and_grad apply failed")
    }

    values = vectorToArrays(valVec)
    grads = vectorToArrays(gradVec)
    return values, grads, nil
}

// Free releases the underlying C closure.
func (g *GradFn) Free() {
    if g.cls.ctx != nil {
        C.mlx_closure_value_and_grad_free(g.cls)
        g.cls.ctx = nil
    }
}

// --- Checkpoint — Memory-Efficient Gradient Recomputation ---

// Checkpoint wraps a function so that during backward pass, intermediate
// activations are recomputed rather than stored. Trades compute for memory.
//
// Use this for memory-constrained training (large models on limited VRAM).
func Checkpoint(fn func([]*Array) []*Array) func([]*Array) []*Array {
    Init()

    closure := newClosure(fn)
    defer C.mlx_closure_free(closure)

    var checkpointed C.mlx_closure
    C.mlx_checkpoint(&checkpointed, closure)

    // Register the checkpointed closure so it stays alive
    id := gradNextID.Add(1)
    cpClosure := checkpointed // capture for the returned function

    // Return a Go function that calls through the checkpointed closure
    return func(inputs []*Array) []*Array {
        _ = id // prevent GC of closure reference

        inputVec := C.mlx_vector_array_new()
        defer C.mlx_vector_array_free(inputVec)
        for _, in := range inputs {
            C.mlx_vector_array_append_value(inputVec, in.ctx)
        }

        outVec := C.mlx_vector_array_new()
        defer C.mlx_vector_array_free(outVec)

        C.mlx_closure_apply(&outVec, cpClosure, inputVec)
        return vectorToArrays(outVec)
    }
}

// --- Loss Functions ---

// CrossEntropyLoss computes cross-entropy loss between logits and integer targets.
// logits: [..., V] (raw model output, pre-softmax, last dim = vocab)
// targets: [...] (integer token IDs, same shape as logits minus last dim)
// Returns scalar loss averaged over all positions.
//
// Numerically stable: loss_i = logsumexp(logits_i) - logits_i[target_i]
func CrossEntropyLoss(logits, targets *Array) *Array {
    Init()
    // LogSumExp along last axis (vocab dimension) for numerical stability
    lse := LogSumExp(logits, -1, false)

    // Gather the logit at the target index: logits[..., target]
    // targets needs to be [..., 1] for take_along_axis
    tgtExpanded := ExpandDims(targets, -1)
    gathered := TakeAlongAxis(logits, tgtExpanded, -1)
    gathered = Squeeze(gathered, -1) // back to [...] shape

    // Per-position loss: logsumexp - logit_at_target
    perPos := Subtract(lse, gathered)

    // Mean over all positions
    return MeanAll(perPos)
}

// MSELoss computes the mean squared error loss: mean((predictions - targets)^2).
func MSELoss(predictions, targets *Array) *Array {
    diff := Subtract(predictions, targets)
    sq := Square(diff)
    return MeanAll(sq)
}

// --- Helpers ---

// vectorToArrays extracts all arrays from an mlx_vector_array.
func vectorToArrays(vec C.mlx_vector_array) []*Array {
    n := int(C.mlx_vector_array_size(vec))
    out := make([]*Array, n)
    for i := 0; i < n; i++ {
        a := New("VEC_OUT")
        C.mlx_vector_array_get(&a.ctx, vec, C.size_t(i))
        out[i] = a
    }
    return out
}

// Log returns element-wise natural logarithm.
func Log(a *Array) *Array {
    out := New("LOG", a)
    C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
    return out
}

// SumAll reduces by summation over all elements, returning a scalar.
func SumAll(a *Array) *Array {
    flat := Reshape(a, -1)
    return Sum(flat, 0, false)
}

// MeanAll reduces by averaging over all elements, returning a scalar.
func MeanAll(a *Array) *Array {
    flat := Reshape(a, -1)
    return Mean(flat, 0, false)
}

// OnesLike creates an array of ones with the same shape and type as the input.
func OnesLike(a *Array) *Array {
    out := New("ONES_LIKE", a)
    C.mlx_ones_like(&out.ctx, a.ctx, DefaultStream().ctx)
    return out
}
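For reference, the identity CrossEntropyLoss relies on, written out. The max-shifted form on the right is the standard trick that keeps logsumexp safe in float32; MLX's logsumexp reduction is expected to apply that shift internally:

    \ell_i = -\log \mathrm{softmax}(z_i)_{t_i}
           = \log \sum_j e^{z_{ij}} - z_{i, t_i},
    \qquad
    \log \sum_j e^{z_{ij}} = m_i + \log \sum_j e^{z_{ij} - m_i},
    \quad m_i = \max_j z_{ij}.

The returned loss is the mean of \ell_i over all positions, which is what the MeanAll(perPos) call above computes.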
mlx/grad_test.go (new file, 332 lines)

@@ -0,0 +1,332 @@
//go:build darwin && arm64

package mlx

import (
    "math"
    "testing"
)

func TestVJP_SimpleSquare(t *testing.T) {
    // f(x) = x^2, df/dx = 2x
    // At x=3: f(3)=9, df/dx=6
    fn := func(inputs []*Array) []*Array {
        x := inputs[0]
        return []*Array{Mul(x, x)}
    }

    x := FromValue(float32(3.0))
    cotangent := FromValue(float32(1.0)) // upstream grad = 1

    outputs, grads, err := VJP(fn, []*Array{x}, []*Array{cotangent})
    if err != nil {
        t.Fatalf("VJP failed: %v", err)
    }

    Materialize(outputs[0], grads[0])

    got := outputs[0].Float()
    if math.Abs(got-9.0) > 1e-5 {
        t.Errorf("output = %f, want 9.0", got)
    }

    grad := grads[0].Float()
    if math.Abs(grad-6.0) > 1e-5 {
        t.Errorf("grad = %f, want 6.0", grad)
    }
}

func TestVJP_Addition(t *testing.T) {
    // f(x, y) = x + y, df/dx = 1, df/dy = 1
    fn := func(inputs []*Array) []*Array {
        return []*Array{Add(inputs[0], inputs[1])}
    }

    x := FromValue(float32(2.0))
    y := FromValue(float32(5.0))
    cotangent := FromValue(float32(1.0))

    _, grads, err := VJP(fn, []*Array{x, y}, []*Array{cotangent})
    if err != nil {
        t.Fatalf("VJP failed: %v", err)
    }

    Materialize(grads...)

    if math.Abs(grads[0].Float()-1.0) > 1e-5 {
        t.Errorf("dx = %f, want 1.0", grads[0].Float())
    }
    if math.Abs(grads[1].Float()-1.0) > 1e-5 {
        t.Errorf("dy = %f, want 1.0", grads[1].Float())
    }
}

func TestVJP_MatmulGrad(t *testing.T) {
    // f(W) = sum(W @ x) — gradient of sum(matmul) w.r.t. W
    // For W=[2,2], x=[2,1]: dL/dW = ones @ x^T
    x := FromValues([]float32{1.0, 2.0}, 2, 1)
    w := FromValues([]float32{1.0, 0.0, 0.0, 1.0}, 2, 2) // identity

    fn := func(inputs []*Array) []*Array {
        result := Matmul(inputs[0], x)
        return []*Array{SumAll(result)}
    }

    cotangent := FromValue(float32(1.0))

    outputs, grads, err := VJP(fn, []*Array{w}, []*Array{cotangent})
    if err != nil {
        t.Fatalf("VJP failed: %v", err)
    }

    Materialize(outputs[0], grads[0])

    // W @ x with W=I, x=[1,2]^T gives [1,2]^T, sum=3
    got := outputs[0].Float()
    if math.Abs(got-3.0) > 1e-5 {
        t.Errorf("output = %f, want 3.0", got)
    }

    // Gradient of sum(W@x) w.r.t. W is outer product: ones @ x^T
    // = [[1,2],[1,2]]
    gradFloats := grads[0].Floats()
    expected := []float32{1.0, 2.0, 1.0, 2.0}
    for i, exp := range expected {
        if math.Abs(float64(gradFloats[i]-exp)) > 1e-5 {
            t.Errorf("grad[%d] = %f, want %f", i, gradFloats[i], exp)
        }
    }
}

func TestJVP_SimpleSquare(t *testing.T) {
    // f(x) = x^2, JVP with tangent v: df = 2x * v
    // At x=3, v=1: df = 6
    fn := func(inputs []*Array) []*Array {
        x := inputs[0]
        return []*Array{Mul(x, x)}
    }

    x := FromValue(float32(3.0))
    tangent := FromValue(float32(1.0))

    outputs, jvps, err := JVP(fn, []*Array{x}, []*Array{tangent})
    if err != nil {
        t.Fatalf("JVP failed: %v", err)
    }

    Materialize(outputs[0], jvps[0])

    got := outputs[0].Float()
    if math.Abs(got-9.0) > 1e-5 {
        t.Errorf("output = %f, want 9.0", got)
    }

    jvp := jvps[0].Float()
    if math.Abs(jvp-6.0) > 1e-5 {
        t.Errorf("jvp = %f, want 6.0", jvp)
    }
}

func TestValueAndGrad_Quadratic(t *testing.T) {
    // f(x) = x^2 + 2x + 1 = (x+1)^2
    // f'(x) = 2x + 2
    // At x=3: f(3) = 16, f'(3) = 8
    fn := func(inputs []*Array) []*Array {
        x := inputs[0]
        x2 := Mul(x, x)
        two_x := MulScalar(x, 2.0)
        one := FromValue(float32(1.0))
        return []*Array{Add(Add(x2, two_x), one)}
    }

    grad := ValueAndGrad(fn, 0)
    defer grad.Free()

    x := FromValue(float32(3.0))
    values, grads, err := grad.Apply(x)
    if err != nil {
        t.Fatalf("ValueAndGrad failed: %v", err)
    }

    Materialize(values[0], grads[0])

    val := values[0].Float()
    if math.Abs(val-16.0) > 1e-5 {
        t.Errorf("value = %f, want 16.0", val)
    }

    g := grads[0].Float()
    if math.Abs(g-8.0) > 1e-5 {
        t.Errorf("grad = %f, want 8.0", g)
    }
}

func TestValueAndGrad_MultiArg(t *testing.T) {
    // f(x, y) = x*y, df/dx = y, df/dy = x
    // At x=3, y=4: f=12, dx=4, dy=3
    fn := func(inputs []*Array) []*Array {
        return []*Array{Mul(inputs[0], inputs[1])}
    }

    // Differentiate w.r.t. both arguments
    grad := ValueAndGrad(fn, 0, 1)
    defer grad.Free()

    x := FromValue(float32(3.0))
    y := FromValue(float32(4.0))
    values, grads, err := grad.Apply(x, y)
    if err != nil {
        t.Fatalf("ValueAndGrad failed: %v", err)
    }

    Materialize(values[0], grads[0], grads[1])

    val := values[0].Float()
    if math.Abs(val-12.0) > 1e-5 {
        t.Errorf("value = %f, want 12.0", val)
    }

    dx := grads[0].Float()
    if math.Abs(dx-4.0) > 1e-5 {
        t.Errorf("dx = %f, want 4.0 (y)", dx)
    }

    dy := grads[1].Float()
    if math.Abs(dy-3.0) > 1e-5 {
        t.Errorf("dy = %f, want 3.0 (x)", dy)
    }
}

func TestValueAndGrad_Reusable(t *testing.T) {
    // Verify GradFn can be called multiple times
    fn := func(inputs []*Array) []*Array {
        x := inputs[0]
        return []*Array{Mul(x, x)} // x^2, grad = 2x
    }

    grad := ValueAndGrad(fn)
    defer grad.Free()

    for _, tc := range []struct {
        x    float32
        want float64 // expected gradient
    }{
        {2.0, 4.0},
        {5.0, 10.0},
        {-3.0, -6.0},
        {0.0, 0.0},
    } {
        x := FromValue(tc.x)
        _, grads, err := grad.Apply(x)
        if err != nil {
            t.Fatalf("Apply failed for x=%f: %v", tc.x, err)
        }
        Materialize(grads[0])

        g := grads[0].Float()
        if math.Abs(g-tc.want) > 1e-5 {
            t.Errorf("x=%f: grad = %f, want %f", tc.x, g, tc.want)
        }
    }
}

func TestCrossEntropyLoss(t *testing.T) {
    // Simple 3-class classification
    // logits = [1.0, 2.0, 3.0], target = 2 (class index)
    // Manual: logsumexp([1,2,3]) = 3 + log(exp(-2)+exp(-1)+1)
    //         = 3 + log(0.1353 + 0.3679 + 1.0) = 3 + log(1.5032) = 3.4076
    // loss = 3.4076 - 3.0 = 0.4076
    logits := FromValues([]float32{1.0, 2.0, 3.0}, 1, 3) // [1, 3]
    targets := FromValues([]int32{2}, 1)                 // [1]

    loss := CrossEntropyLoss(logits, targets)
    Materialize(loss)

    got := loss.Float()
    expected := 0.4076
    if math.Abs(got-expected) > 0.01 {
        t.Errorf("CrossEntropyLoss = %f, want ~%f", got, expected)
    }
}

func TestMSELoss(t *testing.T) {
    pred := FromValues([]float32{1.0, 2.0, 3.0}, 3)
    target := FromValues([]float32{1.5, 2.5, 3.5}, 3)

    loss := MSELoss(pred, target)
    Materialize(loss)

    // MSE = mean((0.5)^2, (0.5)^2, (0.5)^2) = mean(0.25, 0.25, 0.25) = 0.25
    got := loss.Float()
    if math.Abs(got-0.25) > 1e-5 {
        t.Errorf("MSELoss = %f, want 0.25", got)
    }
}

func TestLogSumExp(t *testing.T) {
    // logsumexp([1, 2, 3]) along axis -1
    a := FromValues([]float32{1.0, 2.0, 3.0}, 1, 3)
    result := LogSumExp(a, -1, false)
    Materialize(result)

    // = 3 + log(exp(-2) + exp(-1) + 1) = 3 + log(1.5032) ≈ 3.4076
    got := result.Float()
    expected := 3.4076
    if math.Abs(got-expected) > 0.01 {
        t.Errorf("LogSumExp = %f, want ~%f", got, expected)
    }
}

func TestOnesLike(t *testing.T) {
    a := FromValues([]float32{1.0, 2.0, 3.0}, 3)
    ones := OnesLike(a)
    Materialize(ones)

    floats := ones.Floats()
    for i, f := range floats {
        if f != 1.0 {
            t.Errorf("OnesLike[%d] = %f, want 1.0", i, f)
        }
    }
}

func TestCheckpoint(t *testing.T) {
    // Checkpoint should produce the same result as the original function
    fn := func(inputs []*Array) []*Array {
        x := inputs[0]
        return []*Array{Mul(x, x)}
    }

    cpFn := Checkpoint(fn)

    x := FromValue(float32(5.0))
    result := cpFn([]*Array{x})
    Materialize(result[0])

    got := result[0].Float()
    if math.Abs(got-25.0) > 1e-5 {
        t.Errorf("Checkpoint result = %f, want 25.0", got)
    }
}

func TestSumAll(t *testing.T) {
    a := FromValues([]float32{1.0, 2.0, 3.0, 4.0}, 2, 2)
    result := SumAll(a)
    Materialize(result)

    got := result.Float()
    if math.Abs(got-10.0) > 1e-5 {
        t.Errorf("SumAll = %f, want 10.0", got)
    }
}

func TestMeanAll(t *testing.T) {
    a := FromValues([]float32{2.0, 4.0, 6.0, 8.0}, 2, 2)
    result := MeanAll(a)
    Materialize(result)

    got := result.Float()
    if math.Abs(got-5.0) > 1e-5 {
        t.Errorf("MeanAll = %f, want 5.0", got)
    }
}
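Assuming a standard Go toolchain on an Apple Silicon Mac with the MLX-C libraries visible to cgo, the suite above should run with something like:

    go test ./mlx/ -v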
mlx/ops.go (16 changed lines)

@@ -335,3 +335,19 @@ func PutAlongAxis(a, indices, values *Array, axis int) *Array {
 	C.mlx_put_along_axis(&out.ctx, a.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
 	return out
 }
+
+// TakeAlongAxis gathers elements from a along axis using indices.
+// Unlike Take, this uses the same number of dimensions for indices and input.
+func TakeAlongAxis(a, indices *Array, axis int) *Array {
+    out := New("TAKE_ALONG_AXIS", a, indices)
+    C.mlx_take_along_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
+    return out
+}
+
+// LogSumExp computes log(sum(exp(a))) along the given axis.
+// Numerically stable reduction for cross-entropy loss.
+func LogSumExp(a *Array, axis int, keepDims bool) *Array {
+    out := New("LOGSUMEXP", a)
+    C.mlx_logsumexp_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
+    return out
+}