go-ai/mlx/grad.go
Snider 2d870385f9 feat(mlx): LoRA injection into models + masked cross-entropy loss
Add LoRA field to Linear for transparent adapter injection via model's
Forward() path. ApplyLoRA() on Qwen3/Gemma3 wraps target projections.
Deterministic param ordering for adapter save/load consistency.
MaskedCrossEntropyLoss for training on assistant tokens only.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-17 17:37:44 +00:00

//go:build darwin && arm64

package mlx
/*
#include "mlx/c/mlx.h"
// Callback for gradient closures — same signature as goCompiledFunc.
extern int goGradFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload);
static mlx_closure new_grad_closure(void *payload) {
return mlx_closure_new_func_payload(&goGradFunc, payload, NULL);
}
*/
import "C"
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
)
// --- Closure registry (separate from compile.go's registry) ---
var (
gradFuncs sync.Map
gradNextID atomic.Uintptr
)
//export goGradFunc
func goGradFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
id := uintptr(payload)
fnI, ok := gradFuncs.Load(id)
if !ok {
return 1
}
fn := fnI.(func([]*Array) []*Array)
nInputs := int(C.mlx_vector_array_size(inputs))
goInputs := make([]*Array, nInputs)
for i := 0; i < nInputs; i++ {
a := New("GRAD_INPUT")
C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
goInputs[i] = a
}
goOutputs := fn(goInputs)
// The output vector arrives with ctx=nullptr (from internal mlx_vector_array_new_).
// Create a valid vector, fill it, then copy to the output pointer.
tmp := C.mlx_vector_array_new()
for _, out := range goOutputs {
C.mlx_vector_array_append_value(tmp, out.ctx)
}
C.mlx_vector_array_set(outputs, tmp)
C.mlx_vector_array_free(tmp)
return 0
}
// newClosure registers a Go function as an MLX closure for autograd.
func newClosure(fn func([]*Array) []*Array) C.mlx_closure {
id := gradNextID.Add(1)
gradFuncs.Store(id, fn)
return C.new_grad_closure(unsafe.Pointer(id))
}
// --- VJP (Vector-Jacobian Product) — Reverse Mode ---
// VJP computes the vector-Jacobian product (reverse-mode autodiff).
// Given a function fn, input primals, and output cotangents (upstream gradients),
// returns (outputs, gradients) where gradients are w.r.t. the primals.
//
// This is the fundamental backward pass operation.
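//
// A minimal usage sketch (x is a placeholder *Array; Mul and OnesLike are
// package helpers). For f(x) = x*x with an all-ones cotangent, the returned
// gradient is 2*x:
//
// square := func(in []*Array) []*Array { return []*Array{Mul(in[0], in[0])} }
// outs, grads, err := VJP(square, []*Array{x}, []*Array{OnesLike(x)})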
func VJP(fn func([]*Array) []*Array, primals []*Array, cotangents []*Array) (outputs []*Array, vjps []*Array, err error) {
Init()
closure := newClosure(fn)
defer C.mlx_closure_free(closure)
// Pack primals into vector
primalsVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(primalsVec)
for _, p := range primals {
C.mlx_vector_array_append_value(primalsVec, p.ctx)
}
// Pack cotangents into vector
cotangentsVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(cotangentsVec)
for _, c := range cotangents {
C.mlx_vector_array_append_value(cotangentsVec, c.ctx)
}
// Call mlx_vjp
var outVec, vjpVec C.mlx_vector_array
outVec = C.mlx_vector_array_new()
vjpVec = C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
defer C.mlx_vector_array_free(vjpVec)
rc := C.mlx_vjp(&outVec, &vjpVec, closure, primalsVec, cotangentsVec)
if rc != 0 {
checkError()
return nil, nil, fmt.Errorf("mlx: vjp failed")
}
outputs = vectorToArrays(outVec)
vjps = vectorToArrays(vjpVec)
return outputs, vjps, nil
}
// --- JVP (Jacobian-Vector Product) — Forward Mode ---
// JVP computes the Jacobian-vector product (forward-mode autodiff).
// Given a function fn, input primals, and input tangents (perturbation directions),
// returns (outputs, output_tangents).
//
// Useful for directional derivatives and Hessian-vector products.
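//
// A minimal usage sketch (x is a placeholder *Array; Mul and OnesLike are
// package helpers). For f(x) = x*x with an all-ones tangent, the output
// tangent is the directional derivative 2*x:
//
// square := func(in []*Array) []*Array { return []*Array{Mul(in[0], in[0])} }
// outs, jvps, err := JVP(square, []*Array{x}, []*Array{OnesLike(x)})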
func JVP(fn func([]*Array) []*Array, primals []*Array, tangents []*Array) (outputs []*Array, jvps []*Array, err error) {
Init()
closure := newClosure(fn)
defer C.mlx_closure_free(closure)
primalsVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(primalsVec)
for _, p := range primals {
C.mlx_vector_array_append_value(primalsVec, p.ctx)
}
tangentsVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(tangentsVec)
for _, t := range tangents {
C.mlx_vector_array_append_value(tangentsVec, t.ctx)
}
var outVec, jvpVec C.mlx_vector_array
outVec = C.mlx_vector_array_new()
jvpVec = C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
defer C.mlx_vector_array_free(jvpVec)
rc := C.mlx_jvp(&outVec, &jvpVec, closure, primalsVec, tangentsVec)
if rc != 0 {
checkError()
return nil, nil, fmt.Errorf("mlx: jvp failed")
}
outputs = vectorToArrays(outVec)
jvps = vectorToArrays(jvpVec)
return outputs, jvps, nil
}
// --- ValueAndGrad — Combined Forward + Backward ---
// GradFn wraps a value-and-grad closure that computes both the function
// value and gradients with respect to the specified arguments. Call Apply
// with inputs to get (values, gradients, err).
type GradFn struct {
cls C.mlx_closure_value_and_grad
}
// ValueAndGrad creates a GradFn that computes both the function value and
// gradients with respect to the arguments at the given indices.
//
// The returned GradFn can be called repeatedly with different inputs.
// This is the primary API for training loops:
//
// lossFn := func(params []*Array) []*Array { return []*Array{loss} }
// grad := mlx.ValueAndGrad(lossFn, 0) // differentiate w.r.t. first arg
// values, grads, err := grad.Apply(params...)
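// defer grad.Free() // release the underlying C closure when done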
func ValueAndGrad(fn func([]*Array) []*Array, argnums ...int) *GradFn {
Init()
closure := newClosure(fn)
defer C.mlx_closure_free(closure)
// Default: differentiate w.r.t. first argument
if len(argnums) == 0 {
argnums = []int{0}
}
cArgs := make([]C.int, len(argnums))
for i, a := range argnums {
cArgs[i] = C.int(a)
}
g := &GradFn{}
C.mlx_value_and_grad(&g.cls, closure, &cArgs[0], C.size_t(len(cArgs)))
return g
}
// Apply calls the gradient function with the given inputs.
// Returns (values, gradients) where values are the function outputs
// and gradients are w.r.t. the arguments specified in ValueAndGrad.
func (g *GradFn) Apply(inputs ...*Array) (values []*Array, grads []*Array, err error) {
inputVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(inputVec)
for _, in := range inputs {
C.mlx_vector_array_append_value(inputVec, in.ctx)
}
var valVec, gradVec C.mlx_vector_array
valVec = C.mlx_vector_array_new()
gradVec = C.mlx_vector_array_new()
defer C.mlx_vector_array_free(valVec)
defer C.mlx_vector_array_free(gradVec)
rc := C.mlx_closure_value_and_grad_apply(&valVec, &gradVec, g.cls, inputVec)
if rc != 0 {
checkError()
return nil, nil, fmt.Errorf("mlx: value_and_grad apply failed")
}
values = vectorToArrays(valVec)
grads = vectorToArrays(gradVec)
return values, grads, nil
}
// Free releases the underlying C closure.
func (g *GradFn) Free() {
if g.cls.ctx != nil {
C.mlx_closure_value_and_grad_free(g.cls)
g.cls.ctx = nil
}
}
// --- Checkpoint — Memory-Efficient Gradient Recomputation ---
// Checkpoint wraps a function so that, during the backward pass, intermediate
// activations are recomputed rather than stored. Trades compute for memory.
//
// Use this for memory-constrained training (large models on limited VRAM).
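//
// A minimal usage sketch (block and x are placeholders; Mul is a package
// helper). When the checkpointed function is used inside a function passed
// to ValueAndGrad, its intermediates are recomputed during the backward pass
// instead of stored:
//
// block := func(in []*Array) []*Array { return []*Array{Mul(in[0], in[0])} }
// ckpt := Checkpoint(block)
// ys := ckpt([]*Array{x})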
func Checkpoint(fn func([]*Array) []*Array) func([]*Array) []*Array {
Init()
closure := newClosure(fn)
defer C.mlx_closure_free(closure)
var checkpointed C.mlx_closure
C.mlx_checkpoint(&checkpointed, closure)
// The Go fn stays registered in gradFuncs (via newClosure), so it remains
// reachable. The checkpointed C closure is captured by the returned function
// and is intentionally never freed, since that function may outlive this call.
cpClosure := checkpointed
// Return a Go function that calls through the checkpointed closure.
return func(inputs []*Array) []*Array {
inputVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(inputVec)
for _, in := range inputs {
C.mlx_vector_array_append_value(inputVec, in.ctx)
}
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
C.mlx_closure_apply(&outVec, cpClosure, inputVec)
return vectorToArrays(outVec)
}
}
// --- Loss Functions ---
// CrossEntropyLoss computes cross-entropy loss between logits and integer targets.
// logits: [..., V] (raw model output, pre-softmax, last dim = vocab)
// targets: [...] (integer token IDs, same shape as logits minus last dim)
// Returns scalar loss averaged over all positions.
//
// Numerically stable: loss_i = logsumexp(logits_i) - logits_i[target_i]
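//
// Equivalently, loss_i = -log(softmax(logits_i)[target_i]); the logsumexp
// form avoids materializing the softmax.
//
// Usage sketch (logits and targets are placeholder arrays with the shapes
// described above, e.g. logits [B, L, V] and targets [B, L]):
//
// loss := CrossEntropyLoss(logits, targets) // scalar *Array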
func CrossEntropyLoss(logits, targets *Array) *Array {
Init()
// LogSumExp along last axis (vocab dimension) for numerical stability
lse := LogSumExp(logits, -1, false)
// Gather the logit at the target index: logits[..., target]
// targets needs to be [..., 1] for take_along_axis
tgtExpanded := ExpandDims(targets, -1)
gathered := TakeAlongAxis(logits, tgtExpanded, -1)
gathered = Squeeze(gathered, -1) // back to [...] shape
// Per-position loss: logsumexp - logit_at_target
perPos := Subtract(lse, gathered)
// Mean over all positions
return MeanAll(perPos)
}
// MaskedCrossEntropyLoss computes cross-entropy loss only on masked positions.
// logits: [B, L, V], targets: [B, L], mask: [B, L] (1.0 = compute loss, 0.0 = ignore)
// Returns scalar loss averaged over masked positions only.
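//
// Usage sketch (logits, targets, and mask are placeholder arrays; mask holds
// 1.0 at positions that should contribute to the loss, e.g. assistant tokens,
// and 0.0 elsewhere):
//
// loss := MaskedCrossEntropyLoss(logits, targets, mask) // scalar *Array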
func MaskedCrossEntropyLoss(logits, targets, mask *Array) *Array {
Init()
lse := LogSumExp(logits, -1, false)
tgtExpanded := ExpandDims(targets, -1)
gathered := TakeAlongAxis(logits, tgtExpanded, -1)
gathered = Squeeze(gathered, -1)
perPos := Subtract(lse, gathered) // [B, L]
// Apply mask and average over masked positions
masked := Mul(perPos, mask)
return Divide(SumAll(masked), SumAll(mask))
}
// MSELoss computes the mean squared error loss: mean((predictions - targets)^2).
func MSELoss(predictions, targets *Array) *Array {
diff := Subtract(predictions, targets)
sq := Square(diff)
return MeanAll(sq)
}
// --- Helpers ---
// vectorToArrays extracts all arrays from an mlx_vector_array.
func vectorToArrays(vec C.mlx_vector_array) []*Array {
n := int(C.mlx_vector_array_size(vec))
out := make([]*Array, n)
for i := 0; i < n; i++ {
a := New("VEC_OUT")
C.mlx_vector_array_get(&a.ctx, vec, C.size_t(i))
out[i] = a
}
return out
}
// Log returns element-wise natural logarithm.
func Log(a *Array) *Array {
out := New("LOG", a)
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// SumAll reduces by summation over all elements, returning a scalar.
func SumAll(a *Array) *Array {
flat := Reshape(a, -1)
return Sum(flat, 0, false)
}
// MeanAll reduces by averaging over all elements, returning a scalar.
func MeanAll(a *Array) *Array {
flat := Reshape(a, -1)
return Mean(flat, 0, false)
}
// OnesLike creates an array of ones with the same shape and type as the input.
func OnesLike(a *Array) *Array {
out := New("ONES_LIKE", a)
C.mlx_ones_like(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}