LoRA: low-rank adaptation with trainable A/B matrices, Kaiming normal init, safetensors save/load. AdamW: decoupled weight decay optimizer with positional moment tracking for gradient-replaced params. 14 tests passing including end-to-end LoRA+AdamW training loop. Co-Authored-By: Virgil <virgil@lethean.io>
175 lines · 3.8 KiB · Go
//go:build darwin && arm64

package mlx

import (
    "math"
    "testing"
)
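// The tests below exercise the AdamW optimizer and the LoRA layer described in
// the commit message. As a rough sketch (assuming the standard formulations,
// not verified against the optimizer/LoRA sources), the update applied by
// opt.Step for each parameter p with gradient g is expected to look like:
//
//	m = beta1*m + (1-beta1)*g
//	v = beta2*v + (1-beta2)*g*g
//	p = p - lr*(mhat/(sqrt(vhat)+eps) + weightDecay*p)
//
// where mhat and vhat are the bias-corrected moments and the weight-decay term
// is decoupled from the gradient (AdamW) rather than folded into g (plain Adam).
// A LoRALinear is assumed to compute base(x) plus a low-rank correction built
// from the trainable A and B matrices, with A given a Kaiming-normal init per
// the commit message, B presumably starting at zero as in the original LoRA
// recipe, and the base weight left frozen.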
func TestAdamW_BasicStep(t *testing.T) {
    // Simple test: minimise f(x) = x^2, starting at x=10
    x := FromValue(float32(10.0))
    Materialize(x)

    opt := NewAdamW(0.1)

    for i := 0; i < 300; i++ {
        // Gradient of x^2 is 2x
        lossFn := func(inputs []*Array) []*Array {
            p := inputs[0]
            return []*Array{Mul(p, p)}
        }

        grad := ValueAndGrad(lossFn)
        _, grads, err := grad.Apply(x)
        grad.Free()
        if err != nil {
            t.Fatalf("step %d: grad failed: %v", i, err)
        }

        updated := opt.Step([]*Array{x}, grads)
        x = updated[0]
        Materialize(x)
    }

    final := x.Float()
    if math.Abs(final) > 0.5 {
        t.Errorf("after 300 steps, x = %f, want near 0", final)
    }
    t.Logf("final x = %f (started at 10.0)", final)
}
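// Note on TestAdamW_BasicStep: because the gradient 2x keeps the same sign
// while x > 0, the bias-corrected ratio mhat/sqrt(vhat) stays close to 1, so
// each AdamW step moves x by roughly the learning rate (0.1). Reaching the 0.5
// tolerance from x = 10 therefore needs on the order of 100 steps, and 300
// steps leaves a comfortable margin. (Back-of-the-envelope reasoning, assuming
// the standard Adam moment estimates.)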
func TestAdamW_MultiParam(t *testing.T) {
    // Minimise f(x, y) = x^2 + y^2
    x := FromValue(float32(5.0))
    y := FromValue(float32(-3.0))
    Materialize(x, y)

    opt := NewAdamW(0.1)

    for i := 0; i < 100; i++ {
        lossFn := func(inputs []*Array) []*Array {
            return []*Array{Add(Mul(inputs[0], inputs[0]), Mul(inputs[1], inputs[1]))}
        }

        grad := ValueAndGrad(lossFn, 0, 1)
        _, grads, err := grad.Apply(x, y)
        grad.Free()
        if err != nil {
            t.Fatalf("step %d failed: %v", i, err)
        }

        updated := opt.Step([]*Array{x, y}, grads)
        x = updated[0]
        y = updated[1]
        Materialize(x, y)
    }

    xFinal := x.Float()
    yFinal := y.Float()
    if math.Abs(xFinal) > 0.1 || math.Abs(yFinal) > 0.1 {
        t.Errorf("x=%f, y=%f, want both near 0", xFinal, yFinal)
    }
    t.Logf("final x=%f, y=%f", xFinal, yFinal)
}
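// Note on TestAdamW_MultiParam: the extra 0, 1 arguments to ValueAndGrad appear
// to select which inputs of lossFn are differentiated (here both x and y), so
// grads comes back with one gradient per selected parameter, in the same order
// that opt.Step expects its params. Since x^2 + y^2 is separable, each
// coordinate is driven toward 0 independently by its own moment estimates.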
func TestAdamW_WeightDecay(t *testing.T) {
    // With large weight decay and zero gradient, param should decay toward 0
    x := FromValue(float32(10.0))
    Materialize(x)

    opt := NewAdamW(0.01)
    opt.WeightDecay = 0.5 // aggressive decay

    zeroGrad := FromValue(float32(0.0))
    Materialize(zeroGrad)

    for i := 0; i < 10; i++ {
        updated := opt.Step([]*Array{x}, []*Array{zeroGrad})
        x = updated[0]
        Materialize(x)
    }

    final := x.Float()
    if final >= 10.0 {
        t.Errorf("x = %f, should have decayed from 10.0", final)
    }
    if final <= 0 {
        t.Errorf("x = %f, decayed too much", final)
    }
    t.Logf("after 10 steps with weight_decay=0.5: x = %f (started at 10.0)", final)
}
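// Note on TestAdamW_WeightDecay: with a zero gradient both moment estimates
// stay at zero, so the Adam part of the update contributes nothing and only the
// decoupled decay term remains. Assuming the usual form x -= lr*weightDecay*x,
// each step multiplies x by (1 - 0.01*0.5) = 0.995, so after 10 steps x is
// roughly 10 * 0.995^10, about 9.51: strictly below 10 but nowhere near 0,
// which is exactly the window the assertions check.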
func TestAdamW_Reset(t *testing.T) {
    opt := NewAdamW(0.01)

    x := FromValue(float32(5.0))
    grad := FromValue(float32(1.0))
    Materialize(x, grad)

    opt.Step([]*Array{x}, []*Array{grad})
    if opt.step != 1 {
        t.Errorf("step = %d, want 1", opt.step)
    }

    opt.Reset()
    if opt.step != 0 {
        t.Errorf("after reset, step = %d, want 0", opt.step)
    }
    if opt.m != nil {
        t.Error("after reset, moments should be nil")
    }
}
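// Note on TestAdamW_Reset: the commit message describes "positional moment
// tracking for gradient-replaced params", which matches what these tests rely
// on: opt.Step returns new arrays rather than mutating its inputs, so the first
// and second moments are apparently kept per parameter position (opt.m) instead
// of being keyed by array identity. Reset() is expected to drop that state
// (step back to 0, moment slice nil) so the optimizer can be reused from
// scratch.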
func TestAdamW_WithLoRA(t *testing.T) {
    // End-to-end: create LoRA layer, compute gradients, update with AdamW
    w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32)
    Materialize(w)
    base := NewLinear(w, nil)

    lora := NewLoRALinear(base, 4, 8.0)
    opt := NewAdamW(0.001)

    x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32)
    target := RandomNormal(0, 1, []int32{1, 2, 4}, DTypeFloat32)
    Materialize(x, target)

    var initialLoss, finalLoss float64

    for step := 0; step < 50; step++ {
        lossFn := func(inputs []*Array) []*Array {
            lora.A = inputs[0]
            lora.B = inputs[1]
            pred := lora.Forward(x)
            return []*Array{MSELoss(pred, target)}
        }

        grad := ValueAndGrad(lossFn, 0, 1)
        values, grads, err := grad.Apply(lora.A, lora.B)
        grad.Free()
        if err != nil {
            t.Fatalf("step %d failed: %v", step, err)
        }

        Materialize(append(values, grads...)...)

        loss := values[0].Float()
        if step == 0 {
            initialLoss = loss
        }
        if step == 49 {
            finalLoss = loss
        }

        updated := opt.Step([]*Array{lora.A, lora.B}, grads)
        lora.A = updated[0]
        lora.B = updated[1]
        Materialize(lora.A, lora.B)
    }

    t.Logf("loss: %.6f -> %.6f", initialLoss, finalLoss)
    if finalLoss >= initialLoss {
        t.Errorf("loss did not decrease: %f -> %f", initialLoss, finalLoss)
    }
}
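// Note on TestAdamW_WithLoRA: the base Linear maps 8 features to 4 (weight
// shape [4, 8]), and NewLoRALinear(base, 4, 8.0) presumably adds a rank-4
// correction scaled by alpha/rank = 8.0/4 = 2 under the usual LoRA convention.
// Only lora.A and lora.B are passed to ValueAndGrad and opt.Step, so the frozen
// base weight never receives an update; the test only asserts that the MSE loss
// after 50 steps is lower than at step 0, a loose but robust signal that
// gradients flow through the low-rank path.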