go-ai/mlx/optim_test.go
Snider 0eaf3d5a17 feat(mlx): add LoRA adapter layers and AdamW optimizer
LoRA: low-rank adaptation with trainable A/B matrices, Kaiming normal
init, safetensors save/load. AdamW: decoupled weight decay optimizer
with positional moment tracking for gradient-replaced params.

14 tests passing including end-to-end LoRA+AdamW training loop.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-17 17:25:42 +00:00

//go:build darwin && arm64

package mlx

import (
	"math"
	"testing"
)
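
// As described in the commit message, AdamW applies decoupled weight decay.
// For reference, the standard (Loshchilov & Hutter) update per parameter p
// with gradient g is sketched below; the defaults NewAdamW uses for beta1,
// beta2 and eps are not asserted here:
//
//	m = beta1*m + (1-beta1)*g
//	v = beta2*v + (1-beta2)*g*g
//	p = p - lr*( mHat/(sqrt(vHat)+eps) + weightDecay*p )
//
// where mHat and vHat are the bias-corrected moments m/(1-beta1^t) and
// v/(1-beta2^t). The tests below rely only on this qualitative behaviour:
// gradient steps move parameters toward a minimum, and weight decay shrinks
// them toward zero.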
func TestAdamW_BasicStep(t *testing.T) {
	// Simple test: minimise f(x) = x^2, starting at x=10
	x := FromValue(float32(10.0))
	Materialize(x)
	opt := NewAdamW(0.1)
	for i := 0; i < 300; i++ {
		// Gradient of x^2 is 2x
		lossFn := func(inputs []*Array) []*Array {
			p := inputs[0]
			return []*Array{Mul(p, p)}
		}
		grad := ValueAndGrad(lossFn)
		_, grads, err := grad.Apply(x)
		grad.Free()
		if err != nil {
			t.Fatalf("step %d: grad failed: %v", i, err)
		}
		updated := opt.Step([]*Array{x}, grads)
		x = updated[0]
		Materialize(x)
	}
	final := x.Float()
	if math.Abs(final) > 0.5 {
		t.Errorf("after 300 steps, x = %f, want near 0", final)
	}
	t.Logf("final x = %f (started at 10.0)", final)
}
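
// The commit message notes that moments are tracked positionally for
// gradient-replaced params: each slot in the slice passed to Step keeps its
// own (m, v) state across iterations even though the *Array values themselves
// are replaced every step. This test exercises that with two parameters.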
func TestAdamW_MultiParam(t *testing.T) {
	// Minimise f(x, y) = x^2 + y^2
	x := FromValue(float32(5.0))
	y := FromValue(float32(-3.0))
	Materialize(x, y)
	opt := NewAdamW(0.1)
	for i := 0; i < 100; i++ {
		lossFn := func(inputs []*Array) []*Array {
			return []*Array{Add(Mul(inputs[0], inputs[0]), Mul(inputs[1], inputs[1]))}
		}
		grad := ValueAndGrad(lossFn, 0, 1)
		_, grads, err := grad.Apply(x, y)
		grad.Free()
		if err != nil {
			t.Fatalf("step %d failed: %v", i, err)
		}
		updated := opt.Step([]*Array{x, y}, grads)
		x = updated[0]
		y = updated[1]
		Materialize(x, y)
	}
	xFinal := x.Float()
	yFinal := y.Float()
	if math.Abs(xFinal) > 0.1 || math.Abs(yFinal) > 0.1 {
		t.Errorf("x=%f, y=%f, want both near 0", xFinal, yFinal)
	}
	t.Logf("final x=%f, y=%f", xFinal, yFinal)
}
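
// With a zero gradient both moment estimates stay at zero, so (assuming the
// standard decoupled update sketched above) each step reduces to
// x = x - lr*weightDecay*x. With lr=0.01 and weightDecay=0.5 that is a factor
// of 0.995 per step, i.e. roughly 10*0.995^10 ~= 9.51 after 10 steps, strictly
// below 10 and well above 0, which is what the bounds below check.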
func TestAdamW_WeightDecay(t *testing.T) {
	// With large weight decay and zero gradient, param should decay toward 0
	x := FromValue(float32(10.0))
	Materialize(x)
	opt := NewAdamW(0.01)
	opt.WeightDecay = 0.5 // aggressive decay
	zeroGrad := FromValue(float32(0.0))
	Materialize(zeroGrad)
	for i := 0; i < 10; i++ {
		updated := opt.Step([]*Array{x}, []*Array{zeroGrad})
		x = updated[0]
		Materialize(x)
	}
	final := x.Float()
	if final >= 10.0 {
		t.Errorf("x = %f, should have decayed from 10.0", final)
	}
	if final <= 0 {
		t.Errorf("x = %f, decayed too much", final)
	}
	t.Logf("after 10 steps with weight_decay=0.5: x = %f (started at 10.0)", final)
}

func TestAdamW_Reset(t *testing.T) {
	opt := NewAdamW(0.01)
	x := FromValue(float32(5.0))
	grad := FromValue(float32(1.0))
	Materialize(x, grad)
	opt.Step([]*Array{x}, []*Array{grad})
	if opt.step != 1 {
		t.Errorf("step = %d, want 1", opt.step)
	}
	opt.Reset()
	if opt.step != 0 {
		t.Errorf("after reset, step = %d, want 0", opt.step)
	}
	if opt.m != nil {
		t.Error("after reset, moments should be nil")
	}
}
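
// LoRA adds a trainable low-rank correction to a frozen base layer; in the
// usual formulation the output is base(x) + (alpha/rank)*B*A*x, with A given
// a random init (Kaiming normal per the commit message) and B typically
// starting at zero so training begins at the base model's output. The call
// NewLoRALinear(base, 4, 8.0) is read here as rank=4, alpha=8.0; that reading
// and the exact A/B shapes are assumptions about this package's API, not
// guarantees. The loop only checks that MSE loss drops as A and B are updated.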
func TestAdamW_WithLoRA(t *testing.T) {
	// End-to-end: create LoRA layer, compute gradients, update with AdamW
	w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)
	lora := NewLoRALinear(base, 4, 8.0)
	opt := NewAdamW(0.001)
	x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32)
	target := RandomNormal(0, 1, []int32{1, 2, 4}, DTypeFloat32)
	Materialize(x, target)
	var initialLoss, finalLoss float64
	for step := 0; step < 50; step++ {
		lossFn := func(inputs []*Array) []*Array {
			lora.A = inputs[0]
			lora.B = inputs[1]
			pred := lora.Forward(x)
			return []*Array{MSELoss(pred, target)}
		}
		grad := ValueAndGrad(lossFn, 0, 1)
		values, grads, err := grad.Apply(lora.A, lora.B)
		grad.Free()
		if err != nil {
			t.Fatalf("step %d failed: %v", step, err)
		}
		Materialize(append(values, grads...)...)
		loss := values[0].Float()
		if step == 0 {
			initialLoss = loss
		}
		if step == 49 {
			finalLoss = loss
		}
		updated := opt.Step([]*Array{lora.A, lora.B}, grads)
		lora.A = updated[0]
		lora.B = updated[1]
		Materialize(lora.A, lora.B)
	}
	t.Logf("loss: %.6f -> %.6f", initialLoss, finalLoss)
	if finalLoss >= initialLoss {
		t.Errorf("loss did not decrease: %f -> %f", initialLoss, finalLoss)
	}
}