feat(mlx): add LoRA adapter layers and AdamW optimizer

LoRA: low-rank adaptation with trainable A/B matrices, Kaiming normal init,
and safetensors save/load.
AdamW: decoupled weight decay optimizer with positional moment tracking for
gradient-replaced params.
14 tests passing, including an end-to-end LoRA+AdamW training loop.

Co-Authored-By: Virgil <virgil@lethean.io>
parent e9973aef3c
commit 0eaf3d5a17
4 changed files with 808 additions and 0 deletions
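Before the per-file diffs, a minimal sketch of how the new pieces are meant to compose. It mirrors the end-to-end test TestAdamW_WithLoRA at the bottom of the commit; base, x, target, and numSteps are assumed to exist and are not part of the diff.

lora := NewLoRALinear(base, 8, 16) // freeze base; train only A and B
opt := NewAdamW(1e-5)

for step := 0; step < numSteps; step++ {
	lossFn := func(inputs []*Array) []*Array {
		lora.A, lora.B = inputs[0], inputs[1]
		return []*Array{MSELoss(lora.Forward(x), target)}
	}
	grad := ValueAndGrad(lossFn, 0, 1)
	_, grads, err := grad.Apply(lora.A, lora.B)
	grad.Free()
	if err != nil {
		// handle the error; omitted in this sketch
	}
	updated := opt.Step([]*Array{lora.A, lora.B}, grads)
	lora.SetParams(updated[0], updated[1])
	Materialize(lora.A, lora.B)
}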
mlx/lora.go (Normal file, 211 lines)
@@ -0,0 +1,211 @@
//go:build darwin && arm64

package mlx

/*
#include <stdlib.h>
#include "mlx/c/mlx.h"
*/
import "C"

import (
	"fmt"
	"math"
	"sort"
	"unsafe"
)

// LoRALinear wraps a frozen Linear layer with low-rank trainable adapters.
//
// Forward: base(x) + scale * (x @ A^T) @ B^T
//
// A is [rank, in_features] — initialised with Kaiming normal
// B is [out_features, rank] — initialised to zero
// Scale = alpha / rank
//
// Only A and B are trainable. The base weights stay frozen.
type LoRALinear struct {
	Base  *Linear // Frozen base weights (may be quantized)
	A     *Array  // [rank, in_features] — trainable
	B     *Array  // [out_features, rank] — trainable
	Scale float32 // alpha / rank
	Rank  int
	Alpha float32
}
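Equivalently, the adapter learns a rank-limited update to the frozen weight, delta_W = Scale * B @ A, since B is [out_features, rank] and A is [rank, in_features]. A comment-form sketch of the shapes and scale, using the defaults from DefaultLoRAConfig further down:

// rank = 8, alpha = 16     => Scale = 16 / 8 = 2.0
// B [out, 8] @ A [8, in]   => delta_W [out, in], rank at most 8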

// NewLoRALinear wraps an existing Linear layer with LoRA adapters.
// rank: decomposition rank (typically 4, 8, or 16)
// alpha: scaling factor (typically 2*rank)
func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear {
	// Determine dimensions from the base weight.
	// Weight shape is [out_features, in_features] for standard linear,
	// or quantized shape which we handle via the base layer.
	var inFeatures, outFeatures int32

	if base.Scales != nil {
		// Quantized: weight is packed. Compute dims from scales.
		// scales shape is [out_features, in_features / group_size]
		scaleShape := base.Scales.Shape()
		outFeatures = scaleShape[0]
		inFeatures = scaleShape[1] * int32(base.GroupSize)
	} else {
		wShape := base.Weight.Shape()
		outFeatures = wShape[0]
		inFeatures = wShape[1]
	}

	// A: Kaiming normal initialisation — N(0, 1/sqrt(in_features))
	stddev := float32(1.0 / math.Sqrt(float64(inFeatures)))
	a := RandomNormal(0, stddev, []int32{int32(rank), inFeatures}, DTypeFloat32)

	// B: zero initialisation — LoRA starts as identity (no change to base)
	b := Zeros([]int32{outFeatures, int32(rank)}, DTypeFloat32)

	Materialize(a, b)

	return &LoRALinear{
		Base:  base,
		A:     a,
		B:     b,
		Scale: alpha / float32(rank),
		Rank:  rank,
		Alpha: alpha,
	}
}
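A usage sketch (qProj is a hypothetical *Linear already loaded from a checkpoint, not part of this diff):

qLoRA := NewLoRALinear(qProj, 8, 16) // rank 8, alpha 16 => Scale = 2.0
// A is Kaiming-normal and B is zero, so qLoRA.Forward initially equals qProj.Forward;
// TestLoRALinear_ForwardMatchesBase below checks exactly this property.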

// Forward computes: base(x) + scale * (x @ A^T) @ B^T
func (l *LoRALinear) Forward(x *Array) *Array {
	baseOut := l.Base.Forward(x)

	// LoRA path: x @ A^T gives [B, L, rank], then @ B^T gives [B, L, out]
	loraOut := Matmul(x, Transpose(l.A))
	loraOut = Matmul(loraOut, Transpose(l.B))
	loraOut = MulScalar(loraOut, l.Scale)

	return Add(baseOut, loraOut)
}
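Since (x @ A^T) @ B^T = x @ (B @ A)^T, the adapter could be folded into an unquantized base weight once training is finished, so inference pays no extra matmuls. This commit keeps the two paths separate; the fold below is only a sketch built from the package's existing ops, and it would not apply to a quantized base:

// Sketch, not part of this diff: W_eff = W + Scale * (B @ A), shape [out, in].
merged := Add(l.Base.Weight, MulScalar(Matmul(l.B, l.A), l.Scale))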

// TrainableParams returns the LoRA A and B arrays for gradient computation.
func (l *LoRALinear) TrainableParams() []*Array {
	return []*Array{l.A, l.B}
}

// SetParams updates the LoRA A and B arrays (used by optimiser after gradient step).
func (l *LoRALinear) SetParams(a, b *Array) {
	l.A = a
	l.B = b
}

// ParamCount returns the number of trainable parameters.
func (l *LoRALinear) ParamCount() int {
	aShape := l.A.Shape()
	bShape := l.B.Shape()
	return int(aShape[0]*aShape[1] + bShape[0]*bShape[1])
}
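A concrete instance of the count, using the same shapes as TestLoRALinear_ParamCount below:

// Base weight [64, 128] wrapped at rank 8:
//   A [8, 128] = 1024 values, B [64, 8] = 512 values => 1536 trainable,
//   versus 64 * 128 = 8192 if the full matrix were trained directly.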

// --- LoRA Application to Models ---

// LoRAConfig specifies which layers to apply LoRA to and with what parameters.
type LoRAConfig struct {
	Rank       int      // Decomposition rank (default 8)
	Alpha      float32  // Scaling factor (default 16)
	TargetKeys []string // Weight name suffixes to target (default: q_proj, v_proj)
}

// DefaultLoRAConfig returns the standard LoRA configuration for LLM fine-tuning.
func DefaultLoRAConfig() LoRAConfig {
	return LoRAConfig{
		Rank:       8,
		Alpha:      16,
		TargetKeys: []string{"q_proj", "v_proj"},
	}
}

// LoRAAdapter holds all LoRA layers applied to a model.
type LoRAAdapter struct {
	Layers map[string]*LoRALinear // keyed by weight path prefix
	Config LoRAConfig
}
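The diff defines TargetKeys but does not ship a helper that walks a model and wraps the matching layers. One way the wiring could look, purely as an illustration (model.NamedLinears and the loop below are hypothetical and not part of this commit; strings would also need importing):

adapter := &LoRAAdapter{Layers: map[string]*LoRALinear{}, Config: DefaultLoRAConfig()}
for path, lin := range model.NamedLinears() { // hypothetical accessor
	for _, key := range adapter.Config.TargetKeys {
		if strings.HasSuffix(path, key) {
			adapter.Layers[path] = NewLoRALinear(lin, adapter.Config.Rank, adapter.Config.Alpha)
		}
	}
}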

// TotalParams returns the total number of trainable parameters across all LoRA layers.
func (a *LoRAAdapter) TotalParams() int {
	total := 0
	for _, l := range a.Layers {
		total += l.ParamCount()
	}
	return total
}

// AllTrainableParams returns all trainable arrays (A and B from every layer),
// in a deterministic order by layer name.
func (a *LoRAAdapter) AllTrainableParams() []*Array {
	// Sort the layer names so the order really is stable across calls; the AdamW
	// moments are positional, so a raw map-iteration order would mismatch them.
	names := make([]string, 0, len(a.Layers))
	for name := range a.Layers {
		names = append(names, name)
	}
	sort.Strings(names)

	var params []*Array
	for _, name := range names {
		params = append(params, a.Layers[name].TrainableParams()...)
	}
	return params
}

// Save writes the LoRA adapter weights to a safetensors file.
// Only saves the A and B matrices — not the frozen base weights.
func (a *LoRAAdapter) Save(path string) error {
	weights := make(map[string]*Array)
	for name, l := range a.Layers {
		weights[name+".lora_a"] = l.A
		weights[name+".lora_b"] = l.B
	}
	return SaveSafetensors(path, weights)
}
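There is no Load counterpart in this file, but an adapter rebuilt with the same layer names can be restored from the saved file using LoadAllSafetensors (already used by the tests below) together with SetParams; a sketch:

loaded := LoadAllSafetensors(path)
for name, l := range adapter.Layers {
	l.SetParams(loaded[name+".lora_a"], loaded[name+".lora_b"])
}
Materialize(adapter.AllTrainableParams()...)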

// --- Random Normal ---

// RandomNormal generates normal (Gaussian) random values with given mean and stddev.
func RandomNormal(mean, stddev float32, shape []int32, dtype DType) *Array {
	Init()
	out := New("RANDOM_NORMAL")
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	key := C.mlx_array_new()
	defer C.mlx_array_free(key)
	C.mlx_random_normal(
		&out.ctx,
		&cShape[0], C.size_t(len(cShape)),
		C.mlx_dtype(dtype),
		C.float(mean), C.float(stddev),
		key, // null key = use default RNG
		DefaultStream().ctx,
	)
	return out
}

// --- SaveSafetensors ---

// SaveSafetensors saves a map of named arrays to a .safetensors file.
func SaveSafetensors(path string, weights map[string]*Array) error {
	Init()

	// Build the map
	cMap := C.mlx_map_string_to_array_new()
	defer C.mlx_map_string_to_array_free(cMap)

	for name, arr := range weights {
		cName := C.CString(name)
		C.mlx_map_string_to_array_insert(cMap, cName, arr.ctx)
		C.free(unsafe.Pointer(cName))
	}

	// Empty metadata
	cMeta := C.mlx_map_string_to_string_new()
	defer C.mlx_map_string_to_string_free(cMeta)

	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))

	rc := C.mlx_save_safetensors(cPath, cMap, cMeta)
	if rc != 0 {
		checkError()
		return fmt.Errorf("mlx: save safetensors failed: %s", path)
	}
	return nil
}
mlx/lora_test.go (Normal file, 316 lines)
@@ -0,0 +1,316 @@
//go:build darwin && arm64

package mlx

import (
	"math"
	"os"
	"testing"
)

func TestNewLoRALinear(t *testing.T) {
	// Create a simple base linear layer: [4, 8] weight
	w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)

	lora := NewLoRALinear(base, 4, 8.0) // rank=4, alpha=8

	// Check dimensions
	aShape := lora.A.Shape()
	bShape := lora.B.Shape()

	if aShape[0] != 4 || aShape[1] != 8 {
		t.Errorf("A shape = %v, want [4, 8]", aShape)
	}
	if bShape[0] != 4 || bShape[1] != 4 {
		t.Errorf("B shape = %v, want [4, 4]", bShape)
	}

	// Scale should be alpha/rank = 8/4 = 2
	if math.Abs(float64(lora.Scale)-2.0) > 1e-5 {
		t.Errorf("Scale = %f, want 2.0", lora.Scale)
	}

	// B should be all zeros (LoRA starts as identity)
	Materialize(lora.B)
	bFloats := lora.B.Floats()
	for i, v := range bFloats {
		if v != 0 {
			t.Errorf("B[%d] = %f, want 0", i, v)
		}
	}
}

func TestLoRALinear_ForwardMatchesBase(t *testing.T) {
	// With B=0, LoRA forward should equal base forward
	w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)

	lora := NewLoRALinear(base, 4, 8.0)

	// Random input [1, 3, 8]
	x := RandomNormal(0, 1, []int32{1, 3, 8}, DTypeFloat32)
	Materialize(x)

	baseOut := base.Forward(x)
	loraOut := lora.Forward(x)
	Materialize(baseOut, loraOut)

	// Should be identical since B is zero
	baseFloats := baseOut.Floats()
	loraFloats := loraOut.Floats()

	if len(baseFloats) != len(loraFloats) {
		t.Fatalf("output sizes differ: base=%d, lora=%d", len(baseFloats), len(loraFloats))
	}

	for i := range baseFloats {
		diff := math.Abs(float64(baseFloats[i] - loraFloats[i]))
		if diff > 1e-4 {
			t.Errorf("output[%d] differs: base=%f, lora=%f", i, baseFloats[i], loraFloats[i])
		}
	}
}

func TestLoRALinear_ForwardWithAdapter(t *testing.T) {
	// Set A and B to known values and verify the combined output
	w := Zeros([]int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)

	lora := NewLoRALinear(base, 2, 4.0) // rank=2, alpha=4, scale=2

	// A is left as all zeros, so the LoRA path contributes nothing
	a := Zeros([]int32{2, 8}, DTypeFloat32)
	// Set B to ones: [[1,1], [1,1], [1,1], [1,1]]
	b := FromValues([]float32{
		1, 1,
		1, 1,
		1, 1,
		1, 1,
	}, 4, 2)
	Materialize(a, b)
	lora.A = a
	lora.B = b

	// With base=0, A=0, output should also be 0 (scale * x@0@B^T = 0)
	x := FromValues([]float32{1, 2, 3, 4, 5, 6, 7, 8}, 1, 1, 8)
	result := lora.Forward(x)
	Materialize(result)

	// base(x) = 0 (zero weights), lora = scale * (x @ A^T) @ B^T
	// A is zeros, so x @ A^T = [0, 0], then @ B^T = [0,0,0,0]
	for _, v := range result.Floats() {
		if v != 0 {
			t.Errorf("expected 0 with zero A, got %f", v)
		}
	}
}

func TestLoRALinear_ParamCount(t *testing.T) {
	w := RandomNormal(0, 0.01, []int32{64, 128}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)

	lora := NewLoRALinear(base, 8, 16.0) // rank=8
	// A: [8, 128] = 1024, B: [64, 8] = 512, total = 1536
	expected := 8*128 + 64*8
	if lora.ParamCount() != expected {
		t.Errorf("ParamCount = %d, want %d", lora.ParamCount(), expected)
	}
}

func TestLoRALinear_TrainableParams(t *testing.T) {
	w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)

	lora := NewLoRALinear(base, 4, 8.0)
	params := lora.TrainableParams()

	if len(params) != 2 {
		t.Fatalf("TrainableParams returned %d arrays, want 2", len(params))
	}

	// First is A, second is B
	if params[0].Shape()[0] != 4 || params[0].Shape()[1] != 8 {
		t.Errorf("param[0] (A) shape = %v, want [4, 8]", params[0].Shape())
	}
	if params[1].Shape()[0] != 4 || params[1].Shape()[1] != 4 {
		t.Errorf("param[1] (B) shape = %v, want [4, 4]", params[1].Shape())
	}
}

func TestLoRALinear_GradientFlows(t *testing.T) {
	// Verify that gradients flow through the LoRA path
	w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)

	lora := NewLoRALinear(base, 4, 8.0)
	x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32)
	Materialize(x)

	// Loss function: sum of LoRA output (differentiating w.r.t. A and B)
	lossFn := func(inputs []*Array) []*Array {
		lora.A = inputs[0]
		lora.B = inputs[1]
		out := lora.Forward(x)
		return []*Array{SumAll(out)}
	}

	grad := ValueAndGrad(lossFn, 0, 1) // grad w.r.t. A and B
	defer grad.Free()

	values, grads, err := grad.Apply(lora.A, lora.B)
	if err != nil {
		t.Fatalf("ValueAndGrad failed: %v", err)
	}

	Materialize(append(values, grads...)...)

	// Loss should be a scalar
	loss := values[0].Float()
	t.Logf("loss = %f", loss)

	// Gradients should be non-zero (A has random init, B is zero but gets grad)
	gradA := grads[0]
	gradB := grads[1]

	aGradFloats := gradA.Floats()
	bGradFloats := gradB.Floats()

	hasNonZeroA := false
	for _, v := range aGradFloats {
		if v != 0 {
			hasNonZeroA = true
			break
		}
	}

	hasNonZeroB := false
	for _, v := range bGradFloats {
		if v != 0 {
			hasNonZeroB = true
			break
		}
	}

	// A gradient might be zero if B is zero (since dL/dA depends on B)
	// But B gradient should be non-zero since A is random
	if !hasNonZeroB {
		t.Error("gradient for B is all zeros — gradients not flowing")
	}
	t.Logf("gradA has non-zero: %v, gradB has non-zero: %v", hasNonZeroA, hasNonZeroB)
}

func TestRandomNormal(t *testing.T) {
	arr := RandomNormal(0, 1, []int32{100}, DTypeFloat32)
	Materialize(arr)

	floats := arr.Floats()
	if len(floats) != 100 {
		t.Fatalf("RandomNormal returned %d elements, want 100", len(floats))
	}

	// Check rough statistics: mean should be near 0, values should have spread
	var sum float64
	for _, f := range floats {
		sum += float64(f)
	}
	mean := sum / 100
	if math.Abs(mean) > 0.5 { // generous tolerance for 100 samples
		t.Errorf("mean = %f, expected near 0", mean)
	}
}

func TestSaveSafetensors(t *testing.T) {
	a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
	b := FromValues([]float32{5, 6, 7, 8, 9, 10}, 3, 2)
	Materialize(a, b)

	path := t.TempDir() + "/test.safetensors"
	err := SaveSafetensors(path, map[string]*Array{
		"layer.lora_a": a,
		"layer.lora_b": b,
	})
	if err != nil {
		t.Fatalf("SaveSafetensors failed: %v", err)
	}

	// Verify file exists
	info, err := os.Stat(path)
	if err != nil {
		t.Fatalf("saved file not found: %v", err)
	}
	if info.Size() == 0 {
		t.Error("saved file is empty")
	}

	// Load it back
	loaded := LoadAllSafetensors(path)
	Materialize(loaded["layer.lora_a"], loaded["layer.lora_b"])

	aLoaded := loaded["layer.lora_a"].Floats()
	bLoaded := loaded["layer.lora_b"].Floats()

	expectedA := []float32{1, 2, 3, 4}
	expectedB := []float32{5, 6, 7, 8, 9, 10}

	for i, v := range expectedA {
		if aLoaded[i] != v {
			t.Errorf("loaded A[%d] = %f, want %f", i, aLoaded[i], v)
		}
	}
	for i, v := range expectedB {
		if bLoaded[i] != v {
			t.Errorf("loaded B[%d] = %f, want %f", i, bLoaded[i], v)
		}
	}
}

func TestLoRAAdapter_Save(t *testing.T) {
	w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)

	adapter := &LoRAAdapter{
		Layers: map[string]*LoRALinear{
			"model.layers.0.self_attn.q_proj": NewLoRALinear(base, 4, 8.0),
		},
		Config: DefaultLoRAConfig(),
	}

	path := t.TempDir() + "/adapter.safetensors"
	err := adapter.Save(path)
	if err != nil {
		t.Fatalf("Adapter.Save failed: %v", err)
	}

	// Load and verify
	loaded := LoadAllSafetensors(path)
	aKey := "model.layers.0.self_attn.q_proj.lora_a"
	bKey := "model.layers.0.self_attn.q_proj.lora_b"

	if _, ok := loaded[aKey]; !ok {
		t.Errorf("missing key %s in saved adapter", aKey)
	}
	if _, ok := loaded[bKey]; !ok {
		t.Errorf("missing key %s in saved adapter", bKey)
	}
}

func TestDefaultLoRAConfig(t *testing.T) {
	cfg := DefaultLoRAConfig()
	if cfg.Rank != 8 {
		t.Errorf("Rank = %d, want 8", cfg.Rank)
	}
	if cfg.Alpha != 16 {
		t.Errorf("Alpha = %f, want 16", cfg.Alpha)
	}
	if len(cfg.TargetKeys) != 2 {
		t.Errorf("TargetKeys = %v, want [q_proj, v_proj]", cfg.TargetKeys)
	}
}
mlx/optim.go (Normal file, 106 lines)
@@ -0,0 +1,106 @@
//go:build darwin && arm64

package mlx

import "math"

// AdamW implements the AdamW optimiser (Adam with decoupled weight decay).
//
// Update rule per parameter:
//
//	m = beta1 * m + (1 - beta1) * grad
//	v = beta2 * v + (1 - beta2) * grad^2
//	m_hat = m / (1 - beta1^t)
//	v_hat = v / (1 - beta2^t)
//	param = param * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + eps)
type AdamW struct {
	LR          float64 // Learning rate (default 1e-5)
	Beta1       float64 // First moment decay (default 0.9)
	Beta2       float64 // Second moment decay (default 0.999)
	Eps         float64 // Numerical stability (default 1e-8)
	WeightDecay float64 // Decoupled weight decay (default 0.01)

	step int      // Number of updates performed
	m    []*Array // First moment estimates (positional, parallel to params)
	v    []*Array // Second moment estimates (positional, parallel to params)
}
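A worked first step with the default hyperparameters makes the bias correction concrete: at t = 1 the corrections exactly cancel the (1 - beta) factors, so m_hat = grad and v_hat = grad^2, and the update is roughly lr in magnitude regardless of the gradient's scale.

// First step, grad = 0.2, lr = 0.1, defaults otherwise:
//   m = 0.1 * 0.2    = 0.02    m_hat = 0.02 / (1 - 0.9)    = 0.2
//   v = 0.001 * 0.04 = 4e-5    v_hat = 4e-5 / (1 - 0.999)  = 0.04
//   step = 0.1 * 0.2 / (sqrt(0.04) + 1e-8) ≈ 0.1, plus a 0.1% weight-decay shrink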

// NewAdamW creates an AdamW optimiser with default hyperparameters.
func NewAdamW(lr float64) *AdamW {
	return &AdamW{
		LR:          lr,
		Beta1:       0.9,
		Beta2:       0.999,
		Eps:         1e-8,
		WeightDecay: 0.01,
	}
}
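The remaining hyperparameters are public fields and can be adjusted after construction, as TestAdamW_WeightDecay does below; a quick sketch:

opt := NewAdamW(1e-4) // the learning rate is the only constructor argument
opt.WeightDecay = 0.1 // everything else is overridden on the struct
opt.Beta2 = 0.95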

// Step performs one optimisation step: updates params using gradients.
// params and grads must be parallel slices of the same length.
// Returns the updated parameter arrays; the input slice is not modified,
// so the caller must store the returned arrays back into its model.
func (o *AdamW) Step(params []*Array, grads []*Array) []*Array {
	o.step++

	// Bias correction factors
	bc1 := 1.0 - math.Pow(o.Beta1, float64(o.step))
	bc2 := 1.0 - math.Pow(o.Beta2, float64(o.step))

	updated := make([]*Array, len(params))

	// Grow moment slices if needed (first call or param count increased)
	for len(o.m) < len(params) {
		o.m = append(o.m, nil)
		o.v = append(o.v, nil)
	}

	for i, param := range params {
		grad := grads[i]

		// Initialise moments on first use
		if o.m[i] == nil {
			shape := param.Shape()
			o.m[i] = Zeros(shape, param.Dtype())
			o.v[i] = Zeros(shape, param.Dtype())
		}

		// m = beta1 * m + (1 - beta1) * grad
		m := Add(
			MulScalar(o.m[i], float32(o.Beta1)),
			MulScalar(grad, float32(1.0-o.Beta1)),
		)

		// v = beta2 * v + (1 - beta2) * grad^2
		v := Add(
			MulScalar(o.v[i], float32(o.Beta2)),
			MulScalar(Square(grad), float32(1.0-o.Beta2)),
		)

		// Bias-corrected estimates
		mHat := MulScalar(m, float32(1.0/bc1))
		vHat := MulScalar(v, float32(1.0/bc2))

		// Weight decay: param = param * (1 - lr * weight_decay)
		decayed := MulScalar(param, float32(1.0-o.LR*o.WeightDecay))

		// Update: param = decayed - lr * m_hat / (sqrt(v_hat) + eps)
		denom := AddScalar(Sqrt(vHat), float32(o.Eps))
		step := MulScalar(Divide(mHat, denom), float32(o.LR))
		newParam := Subtract(decayed, step)

		// Store updated moments
		o.m[i] = m
		o.v[i] = v

		updated[i] = newParam
	}

	return updated
}
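Because the moments are tracked positionally, the caller must pass params (and their grads) to Step in the same order on every call; shuffling the order would silently pair a parameter with another parameter's m and v. Pairing Step with LoRAAdapter.AllTrainableParams, which returns arrays in a name-sorted order, keeps that order stable. A sketch, with the gradient computation elided:

params := adapter.AllTrainableParams() // stable, name-sorted order
// ... compute grads for params, in the same order ...
updated := opt.Step(params, grads)     // updated[i] corresponds to params[i]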

// Reset clears the optimiser state (moments and step counter).
func (o *AdamW) Reset() {
	o.step = 0
	o.m = nil
	o.v = nil
}
mlx/optim_test.go (Normal file, 175 lines)
@@ -0,0 +1,175 @@
//go:build darwin && arm64

package mlx

import (
	"math"
	"testing"
)

func TestAdamW_BasicStep(t *testing.T) {
	// Simple test: minimise f(x) = x^2, starting at x=10
	x := FromValue(float32(10.0))
	Materialize(x)

	opt := NewAdamW(0.1)

	for i := 0; i < 300; i++ {
		// Gradient of x^2 is 2x
		lossFn := func(inputs []*Array) []*Array {
			p := inputs[0]
			return []*Array{Mul(p, p)}
		}

		grad := ValueAndGrad(lossFn)
		_, grads, err := grad.Apply(x)
		grad.Free()
		if err != nil {
			t.Fatalf("step %d: grad failed: %v", i, err)
		}

		updated := opt.Step([]*Array{x}, grads)
		x = updated[0]
		Materialize(x)
	}

	final := x.Float()
	if math.Abs(final) > 0.5 {
		t.Errorf("after 300 steps, x = %f, want near 0", final)
	}
	t.Logf("final x = %f (started at 10.0)", final)
}

func TestAdamW_MultiParam(t *testing.T) {
	// Minimise f(x, y) = x^2 + y^2
	x := FromValue(float32(5.0))
	y := FromValue(float32(-3.0))
	Materialize(x, y)

	opt := NewAdamW(0.1)

	for i := 0; i < 100; i++ {
		lossFn := func(inputs []*Array) []*Array {
			return []*Array{Add(Mul(inputs[0], inputs[0]), Mul(inputs[1], inputs[1]))}
		}

		grad := ValueAndGrad(lossFn, 0, 1)
		_, grads, err := grad.Apply(x, y)
		grad.Free()
		if err != nil {
			t.Fatalf("step %d failed: %v", i, err)
		}

		updated := opt.Step([]*Array{x, y}, grads)
		x = updated[0]
		y = updated[1]
		Materialize(x, y)
	}

	xFinal := x.Float()
	yFinal := y.Float()
	if math.Abs(xFinal) > 0.1 || math.Abs(yFinal) > 0.1 {
		t.Errorf("x=%f, y=%f, want both near 0", xFinal, yFinal)
	}
	t.Logf("final x=%f, y=%f", xFinal, yFinal)
}

func TestAdamW_WeightDecay(t *testing.T) {
	// With large weight decay and zero gradient, param should decay toward 0
	x := FromValue(float32(10.0))
	Materialize(x)

	opt := NewAdamW(0.01)
	opt.WeightDecay = 0.5 // aggressive decay

	zeroGrad := FromValue(float32(0.0))
	Materialize(zeroGrad)

	for i := 0; i < 10; i++ {
		updated := opt.Step([]*Array{x}, []*Array{zeroGrad})
		x = updated[0]
		Materialize(x)
	}

	final := x.Float()
	if final >= 10.0 {
		t.Errorf("x = %f, should have decayed from 10.0", final)
	}
	if final <= 0 {
		t.Errorf("x = %f, decayed too much", final)
	}
	t.Logf("after 10 steps with weight_decay=0.5: x = %f (started at 10.0)", final)
}

func TestAdamW_Reset(t *testing.T) {
	opt := NewAdamW(0.01)

	x := FromValue(float32(5.0))
	grad := FromValue(float32(1.0))
	Materialize(x, grad)

	opt.Step([]*Array{x}, []*Array{grad})
	if opt.step != 1 {
		t.Errorf("step = %d, want 1", opt.step)
	}

	opt.Reset()
	if opt.step != 0 {
		t.Errorf("after reset, step = %d, want 0", opt.step)
	}
	if opt.m != nil {
		t.Error("after reset, moments should be nil")
	}
}

func TestAdamW_WithLoRA(t *testing.T) {
	// End-to-end: create LoRA layer, compute gradients, update with AdamW
	w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)

	lora := NewLoRALinear(base, 4, 8.0)
	opt := NewAdamW(0.001)

	x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32)
	target := RandomNormal(0, 1, []int32{1, 2, 4}, DTypeFloat32)
	Materialize(x, target)

	var initialLoss, finalLoss float64

	for step := 0; step < 50; step++ {
		lossFn := func(inputs []*Array) []*Array {
			lora.A = inputs[0]
			lora.B = inputs[1]
			pred := lora.Forward(x)
			return []*Array{MSELoss(pred, target)}
		}

		grad := ValueAndGrad(lossFn, 0, 1)
		values, grads, err := grad.Apply(lora.A, lora.B)
		grad.Free()
		if err != nil {
			t.Fatalf("step %d failed: %v", step, err)
		}

		Materialize(append(values, grads...)...)

		loss := values[0].Float()
		if step == 0 {
			initialLoss = loss
		}
		if step == 49 {
			finalLoss = loss
		}

		updated := opt.Step([]*Array{lora.A, lora.B}, grads)
		lora.A = updated[0]
		lora.B = updated[1]
		Materialize(lora.A, lora.B)
	}

	t.Logf("loss: %.6f -> %.6f", initialLoss, finalLoss)
	if finalLoss >= initialLoss {
		t.Errorf("loss did not decrease: %f -> %f", initialLoss, finalLoss)
	}
}