go-ai/mlx/lora.go
Snider 2d870385f9 feat(mlx): LoRA injection into models + masked cross-entropy loss
Add LoRA field to Linear for transparent adapter injection via model's
Forward() path. ApplyLoRA() on Qwen3/Gemma3 wraps target projections.
Deterministic param ordering for adapter save/load consistency.
MaskedCrossEntropyLoss for training on assistant tokens only.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-17 17:37:44 +00:00

//go:build darwin && arm64

package mlx

/*
#include <stdlib.h>
#include "mlx/c/mlx.h"
*/
import "C"

import (
	"fmt"
	"math"
	"sort"
	"unsafe"
)

// LoRALinear wraps a frozen Linear layer with low-rank trainable adapters.
//
// Forward: base(x) + scale * (x @ A^T) @ B^T
//
// A is [rank, in_features] — initialised with Kaiming normal
// B is [out_features, rank] — initialised to zero
// Scale = alpha / rank
//
// Only A and B are trainable. The base weights stay frozen.
type LoRALinear struct {
	Base  *Linear // Frozen base weights (may be quantized)
	A     *Array  // [rank, in_features] — trainable
	B     *Array  // [out_features, rank] — trainable
	Scale float32 // alpha / rank
	Rank  int
	Alpha float32
}
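
// Worked example of the shapes involved (illustrative numbers, not tied to any
// particular model): with in_features = out_features = 2048, rank = 8 and
// alpha = 16,
//
//	A is [8, 2048], B is [2048, 8], Scale = 16/8 = 2.0
//	trainable params: 8*2048 + 2048*8 = 32768 (vs. 2048*2048 = 4194304 frozen)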

// NewLoRALinear wraps an existing Linear layer with LoRA adapters.
// rank: decomposition rank (typically 4, 8, or 16)
// alpha: scaling factor (typically 2*rank)
func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear {
	// Determine dimensions from the base weight.
	// Weight shape is [out_features, in_features] for standard linear,
	// or quantized shape which we handle via the base layer.
	var inFeatures, outFeatures int32
	if base.Scales != nil {
		// Quantized: weight is packed. Compute dims from scales.
		// scales shape is [out_features, in_features / group_size]
		scaleShape := base.Scales.Shape()
		outFeatures = scaleShape[0]
		inFeatures = scaleShape[1] * int32(base.GroupSize)
	} else {
		wShape := base.Weight.Shape()
		outFeatures = wShape[0]
		inFeatures = wShape[1]
	}
	// A: Kaiming normal initialisation — N(0, 1/sqrt(in_features))
	stddev := float32(1.0 / math.Sqrt(float64(inFeatures)))
	a := RandomNormal(0, stddev, []int32{int32(rank), inFeatures}, DTypeFloat32)
	// B: zero initialisation — LoRA starts as identity (no change to base)
	b := Zeros([]int32{outFeatures, int32(rank)}, DTypeFloat32)
	Materialize(a, b)
	return &LoRALinear{
		Base:  base,
		A:     a,
		B:     b,
		Scale: alpha / float32(rank),
		Rank:  rank,
		Alpha: alpha,
	}
}

// Forward computes: base(x) + scale * (x @ A^T) @ B^T
// Calls baseForward on the underlying Linear to avoid infinite recursion
// when the Linear's Forward method dispatches through LoRA.
func (l *LoRALinear) Forward(x *Array) *Array {
	baseOut := l.Base.baseForward(x)
	// LoRA path: x @ A^T gives [B, L, rank], then @ B^T gives [B, L, out]
	loraOut := Matmul(x, Transpose(l.A))
	loraOut = Matmul(loraOut, Transpose(l.B))
	loraOut = MulScalar(loraOut, l.Scale)
	return Add(baseOut, loraOut)
}
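
// Illustrative usage (a sketch; `qProj` stands for any *Linear pulled from a
// loaded model, e.g. an attention query projection, and `x` for a
// [batch, seq, in_features] activation; both names are hypothetical):
//
//	lora := NewLoRALinear(qProj, 8, 16) // Scale = 16/8 = 2.0
//	y := lora.Forward(x)                // base(x) + 2.0*(x @ A^T) @ B^T
//	_ = lora.ParamCount()               // rank*(in_features+out_features)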

// TrainableParams returns the LoRA A and B arrays for gradient computation.
func (l *LoRALinear) TrainableParams() []*Array {
	return []*Array{l.A, l.B}
}

// SetParams updates the LoRA A and B arrays (used by optimiser after gradient step).
func (l *LoRALinear) SetParams(a, b *Array) {
	l.A = a
	l.B = b
}

// ParamCount returns the number of trainable parameters.
func (l *LoRALinear) ParamCount() int {
	aShape := l.A.Shape()
	bShape := l.B.Shape()
	return int(aShape[0]*aShape[1] + bShape[0]*bShape[1])
}

// --- LoRA Application to Models ---

// LoRAConfig specifies which layers to apply LoRA to and with what parameters.
type LoRAConfig struct {
	Rank       int      // Decomposition rank (default 8)
	Alpha      float32  // Scaling factor (default 16)
	TargetKeys []string // Weight name suffixes to target (default: q_proj, v_proj)
}

// DefaultLoRAConfig returns the standard LoRA configuration for LLM fine-tuning.
func DefaultLoRAConfig() LoRAConfig {
	return LoRAConfig{
		Rank:       8,
		Alpha:      16,
		TargetKeys: []string{"q_proj", "v_proj"},
	}
}
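
// Illustrative override of the defaults (a sketch; the extra target keys are
// common projection names, not a requirement of this package):
//
//	cfg := DefaultLoRAConfig()
//	cfg.Rank = 16
//	cfg.Alpha = 32 // keeping alpha = 2*rank keeps Scale at 2.0
//	cfg.TargetKeys = []string{"q_proj", "k_proj", "v_proj", "o_proj"}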

// LoRAAdapter holds all LoRA layers applied to a model.
type LoRAAdapter struct {
	Layers map[string]*LoRALinear // keyed by weight path prefix
	Config LoRAConfig
}

// TotalParams returns the total number of trainable parameters across all LoRA layers.
func (a *LoRAAdapter) TotalParams() int {
	total := 0
	for _, l := range a.Layers {
		total += l.ParamCount()
	}
	return total
}

// SortedNames returns layer names in deterministic sorted order.
func (a *LoRAAdapter) SortedNames() []string {
	names := make([]string, 0, len(a.Layers))
	for name := range a.Layers {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// AllTrainableParams returns all trainable arrays (A and B from every layer),
// in a deterministic order sorted by layer name.
func (a *LoRAAdapter) AllTrainableParams() []*Array {
	names := a.SortedNames()
	params := make([]*Array, 0, len(names)*2)
	for _, name := range names {
		l := a.Layers[name]
		params = append(params, l.A, l.B)
	}
	return params
}

// SetAllParams updates all LoRA A and B arrays from a flat slice,
// in the same deterministic order as AllTrainableParams.
func (a *LoRAAdapter) SetAllParams(params []*Array) {
	names := a.SortedNames()
	for i, name := range names {
		l := a.Layers[name]
		l.A = params[i*2]
		l.B = params[i*2+1]
	}
}
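
// Sketch of the intended round trip during a training step (grads and lr are
// hypothetical; only the ordering contract matters: gradients and updated
// arrays must follow the same AllTrainableParams order):
//
//	params := adapter.AllTrainableParams() // [layer0.A, layer0.B, layer1.A, ...]
//	updated := make([]*Array, len(params))
//	for i, p := range params {
//		updated[i] = Add(p, MulScalar(grads[i], -lr)) // plain SGD step, for illustration
//	}
//	adapter.SetAllParams(updated)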

// Save writes the LoRA adapter weights to a safetensors file.
// Only saves the A and B matrices — not the frozen base weights.
func (a *LoRAAdapter) Save(path string) error {
	weights := make(map[string]*Array)
	for name, l := range a.Layers {
		weights[name+".lora_a"] = l.A
		weights[name+".lora_b"] = l.B
	}
	return SaveSafetensors(path, weights)
}
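
// Typical use after training (the file name is only an example):
//
//	if err := adapter.Save("adapters.safetensors"); err != nil {
//		// handle the write failure
//	}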

// --- Random Normal ---

// RandomNormal generates normal (Gaussian) random values with given mean and stddev.
func RandomNormal(mean, stddev float32, shape []int32, dtype DType) *Array {
	Init()
	out := New("RANDOM_NORMAL")
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	key := C.mlx_array_new()
	defer C.mlx_array_free(key)
	C.mlx_random_normal(
		&out.ctx,
		&cShape[0], C.size_t(len(cShape)),
		C.mlx_dtype(dtype),
		C.float(mean), C.float(stddev),
		key, // empty key = fall back to MLX's default global RNG state
		DefaultStream().ctx,
	)
	return out
}
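
// For instance, to draw a standalone [8, 2048] float32 matrix (the shape and
// stddev here are arbitrary illustration values):
//
//	w := RandomNormal(0, 0.02, []int32{8, 2048}, DTypeFloat32)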

// --- SaveSafetensors ---

// SaveSafetensors saves a map of named arrays to a .safetensors file.
func SaveSafetensors(path string, weights map[string]*Array) error {
	Init()
	// Build the map
	cMap := C.mlx_map_string_to_array_new()
	defer C.mlx_map_string_to_array_free(cMap)
	for name, arr := range weights {
		cName := C.CString(name)
		C.mlx_map_string_to_array_insert(cMap, cName, arr.ctx)
		C.free(unsafe.Pointer(cName))
	}
	// Empty metadata
	cMeta := C.mlx_map_string_to_string_new()
	defer C.mlx_map_string_to_string_free(cMeta)
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))
	rc := C.mlx_save_safetensors(cPath, cMap, cMeta)
	if rc != 0 {
		checkError()
		return fmt.Errorf("mlx: save safetensors failed: %s", path)
	}
	return nil
}