From 2d870385f95aaa88197a7e7ca9c607c3d9436a1e Mon Sep 17 00:00:00 2001
From: Snider
Date: Tue, 17 Feb 2026 17:37:44 +0000
Subject: [PATCH] feat(mlx): LoRA injection into models + masked cross-entropy
 loss

Add LoRA field to Linear for transparent adapter injection via model's
Forward() path. ApplyLoRA() on Qwen3/Gemma3 wraps target projections.
Deterministic param ordering for adapter save/load consistency.
MaskedCrossEntropyLoss for training on assistant tokens only.

Co-Authored-By: Virgil
---
 mlx/grad.go         | 16 ++++++++++++++++
 mlx/lora.go         | 36 +++++++++++++++++++++++++++++++-----
 mlx/model/gemma3.go | 32 ++++++++++++++++++++++++++++++++
 mlx/model/model.go  |  4 ++++
 mlx/model/qwen3.go  | 32 ++++++++++++++++++++++++++++++++
 mlx/nn.go           | 13 +++++++++++++
 6 files changed, 128 insertions(+), 5 deletions(-)

diff --git a/mlx/grad.go b/mlx/grad.go
index f46ee7b..12ca786 100644
--- a/mlx/grad.go
+++ b/mlx/grad.go
@@ -289,6 +289,22 @@ func CrossEntropyLoss(logits, targets *Array) *Array {
 	return MeanAll(perPos)
 }
 
+// MaskedCrossEntropyLoss computes cross-entropy loss only on masked positions.
+// logits: [B, L, V], targets: [B, L], mask: [B, L] (1.0 = compute loss, 0.0 = ignore)
+// Returns scalar loss averaged over masked positions only.
+func MaskedCrossEntropyLoss(logits, targets, mask *Array) *Array {
+	Init()
+	lse := LogSumExp(logits, -1, false)
+	tgtExpanded := ExpandDims(targets, -1)
+	gathered := TakeAlongAxis(logits, tgtExpanded, -1)
+	gathered = Squeeze(gathered, -1)
+	perPos := Subtract(lse, gathered) // [B, L]
+
+	// Apply mask and average over masked positions
+	masked := Mul(perPos, mask)
+	return Divide(SumAll(masked), SumAll(mask))
+}
+
 // MSELoss computes the mean squared error loss: mean((predictions - targets)^2).
 func MSELoss(predictions, targets *Array) *Array {
 	diff := Subtract(predictions, targets)
diff --git a/mlx/lora.go b/mlx/lora.go
index 31621e1..ac7107e 100644
--- a/mlx/lora.go
+++ b/mlx/lora.go
@@ -11,6 +11,7 @@ import "C"
 import (
 	"fmt"
 	"math"
+	"sort"
 	"unsafe"
 )
 
@@ -73,8 +74,10 @@ func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear {
 }
 
 // Forward computes: base(x) + scale * (x @ A^T) @ B^T
+// Calls baseForward on the underlying Linear to avoid infinite recursion
+// when the Linear's Forward method dispatches through LoRA.
 func (l *LoRALinear) Forward(x *Array) *Array {
-	baseOut := l.Base.Forward(x)
+	baseOut := l.Base.baseForward(x)
 
 	// LoRA path: x @ A^T gives [B, L, rank], then @ B^T gives [B, L, out]
 	loraOut := Matmul(x, Transpose(l.A))
@@ -135,16 +138,39 @@ func (a *LoRAAdapter) TotalParams() int {
 	return total
 }
 
+// SortedNames returns layer names in deterministic sorted order.
+func (a *LoRAAdapter) SortedNames() []string {
+	names := make([]string, 0, len(a.Layers))
+	for name := range a.Layers {
+		names = append(names, name)
+	}
+	sort.Strings(names)
+	return names
+}
+
 // AllTrainableParams returns all trainable arrays (A and B from every layer),
-// in a deterministic order by layer name.
+// in a deterministic order sorted by layer name.
 func (a *LoRAAdapter) AllTrainableParams() []*Array {
-	var params []*Array
-	for _, l := range a.Layers {
-		params = append(params, l.TrainableParams()...)
+	names := a.SortedNames()
+	params := make([]*Array, 0, len(names)*2)
+	for _, name := range names {
+		l := a.Layers[name]
+		params = append(params, l.A, l.B)
 	}
 	return params
 }
 
+// SetAllParams updates all LoRA A and B arrays from a flat slice,
+// in the same deterministic order as AllTrainableParams.
+func (a *LoRAAdapter) SetAllParams(params []*Array) {
+	names := a.SortedNames()
+	for i, name := range names {
+		l := a.Layers[name]
+		l.A = params[i*2]
+		l.B = params[i*2+1]
+	}
+}
+
 // Save writes the LoRA adapter weights to a safetensors file.
 // Only saves the A and B matrices — not the frozen base weights.
 func (a *LoRAAdapter) Save(path string) error {
diff --git a/mlx/model/gemma3.go b/mlx/model/gemma3.go
index 11d3ae1..849df4c 100644
--- a/mlx/model/gemma3.go
+++ b/mlx/model/gemma3.go
@@ -414,3 +414,35 @@ func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
 
 // ModelType returns the architecture identifier.
 func (m *GemmaModel) ModelType() string { return "gemma3" }
+
+// ApplyLoRA wraps target projection layers with LoRA adapters.
+func (m *GemmaModel) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
+	adapter := &mlx.LoRAAdapter{
+		Layers: make(map[string]*mlx.LoRALinear),
+		Config: cfg,
+	}
+
+	for i, layer := range m.Layers {
+		prefix := fmt.Sprintf("model.layers.%d.self_attn", i)
+		for _, target := range cfg.TargetKeys {
+			var proj *mlx.Linear
+			switch target {
+			case "q_proj":
+				proj = layer.Attention.QProj
+			case "k_proj":
+				proj = layer.Attention.KProj
+			case "v_proj":
+				proj = layer.Attention.VProj
+			case "o_proj":
+				proj = layer.Attention.OProj
+			}
+			if proj != nil {
+				lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
+				proj.LoRA = lora
+				adapter.Layers[prefix+"."+target] = lora
+			}
+		}
+	}
+
+	return adapter
+}
diff --git a/mlx/model/model.go b/mlx/model/model.go
index 5e5481d..faf1faa 100644
--- a/mlx/model/model.go
+++ b/mlx/model/model.go
@@ -30,6 +30,10 @@ type Model interface {
 
 	// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
 	ModelType() string
+
+	// ApplyLoRA wraps target projection layers with LoRA adapters for training.
+	// Returns the adapter which holds references to all LoRA layers.
+	ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter
 }
 
 // QuantizationConfig holds quantization parameters from config.json.
diff --git a/mlx/model/qwen3.go b/mlx/model/qwen3.go
index 3c15578..469f9f7 100644
--- a/mlx/model/qwen3.go
+++ b/mlx/model/qwen3.go
@@ -303,3 +303,35 @@ func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
 
 // ModelType returns the architecture identifier.
 func (m *Qwen3Model) ModelType() string { return "qwen3" }
+
+// ApplyLoRA wraps target projection layers with LoRA adapters.
+func (m *Qwen3Model) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
+	adapter := &mlx.LoRAAdapter{
+		Layers: make(map[string]*mlx.LoRALinear),
+		Config: cfg,
+	}
+
+	for i, layer := range m.Layers {
+		prefix := fmt.Sprintf("model.layers.%d.self_attn", i)
+		for _, target := range cfg.TargetKeys {
+			var proj *mlx.Linear
+			switch target {
+			case "q_proj":
+				proj = layer.Attention.QProj
+			case "k_proj":
+				proj = layer.Attention.KProj
+			case "v_proj":
+				proj = layer.Attention.VProj
+			case "o_proj":
+				proj = layer.Attention.OProj
+			}
+			if proj != nil {
+				lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
+				proj.LoRA = lora
+				adapter.Layers[prefix+"."+target] = lora
+			}
+		}
+	}
+
+	return adapter
+}
diff --git a/mlx/nn.go b/mlx/nn.go
index ccd0c9e..4de9db9 100644
--- a/mlx/nn.go
+++ b/mlx/nn.go
@@ -4,6 +4,7 @@ package mlx
 
 // Linear is a fully-connected layer: y = x @ W.T + bias.
 // For quantized models, set Scales/Biases/GroupSize/Bits to use QuantizedMatmul.
+// Set LoRA to inject a low-rank adapter (training only).
 type Linear struct {
 	Weight *Array `weight:"weight"`
 	Scales *Array `weight:"scales"`
@@ -11,6 +12,8 @@ type Linear struct {
 	Bias *Array `weight:"bias"`
 	GroupSize int
 	Bits int
+
+	LoRA *LoRALinear // Optional LoRA adapter — if set, Forward routes through it
 }
 
 // NewLinear creates a dense Linear layer with optional bias.
@@ -31,8 +34,18 @@ func NewQuantizedLinear(weight, scales, biases, bias *Array, groupSize, bits int
 }
 
 // Forward computes the linear transformation.
+// If a LoRA adapter is attached, routes through it instead (base + low-rank delta).
 // Uses QuantizedMatmul when quantization parameters are present.
 func (l *Linear) Forward(x *Array) *Array {
+	if l.LoRA != nil {
+		return l.LoRA.Forward(x)
+	}
+	return l.baseForward(x)
+}
+
+// baseForward is the raw linear transformation without LoRA.
+// Used internally by LoRALinear to avoid infinite recursion.
+func (l *Linear) baseForward(x *Array) *Array {
 	var out *Array
 	if l.Scales != nil {
 		out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits)
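
The sketch below illustrates the mask convention MaskedCrossEntropyLoss expects (1.0 where loss should be computed, 0.0 elsewhere) for the "assistant tokens only" training mentioned in the commit message. The assistant role flags and the conversion of the returned slices into *mlx.Array are assumptions about the caller's data pipeline; only the mask semantics come from this patch.

package train // illustrative placement, not part of the patch

// buildTargetsAndMask is a sketch of how inputs for MaskedCrossEntropyLoss
// can be prepared from one tokenized conversation. assistant[i] reports
// whether token i was produced by the assistant; that flag is an assumption
// about the caller's chat-template handling, not something this patch defines.
func buildTargetsAndMask(tokens []int32, assistant []bool) (targets []int32, mask []float32) {
	if len(tokens) < 2 || len(assistant) != len(tokens) {
		return nil, nil
	}
	// Next-token prediction: position t predicts token t+1, so both outputs
	// are one element shorter than the input sequence.
	targets = make([]int32, len(tokens)-1)
	mask = make([]float32, len(tokens)-1)
	for t := 0; t < len(tokens)-1; t++ {
		targets[t] = tokens[t+1]
		if assistant[t+1] {
			// Loss is computed only where the predicted token is assistant
			// output; prompt and user positions contribute 0 to the sum, and
			// the division by SumAll(mask) averages over assistant positions.
			mask[t] = 1.0
		}
	}
	return targets, mask
}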
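
A minimal end-to-end sketch of how the new pieces compose: ApplyLoRA injects adapters so the model's existing Forward() path picks up the low-rank deltas, MaskedCrossEntropyLoss scores the batch, AllTrainableParams/SetAllParams round-trip the adapter weights in the same sorted order, and Save writes only the A/B matrices. The import paths and the forward, grads and step callbacks are placeholders for APIs this patch does not define; only the LoRAConfig fields, ApplyLoRA, MaskedCrossEntropyLoss and the adapter methods are taken from the patch.

package train // illustrative placement, not part of the patch

import (
	"log"

	"example.local/mlx"       // placeholder module path
	"example.local/mlx/model" // placeholder module path
)

// FineTuneStep runs one LoRA training step. The forward pass, gradient
// computation and optimizer update are supplied by the caller because this
// patch does not define them.
func FineTuneStep(
	m model.Model,
	inputs, targets, mask *mlx.Array, // mask: 1.0 on assistant tokens, 0.0 elsewhere
	forward func(model.Model, *mlx.Array) *mlx.Array, // returns logits [B, L, V]
	grads func(loss *mlx.Array, params []*mlx.Array) []*mlx.Array, // autograd hook
	step func(param, grad *mlx.Array) *mlx.Array, // e.g. an SGD or Adam update
) {
	// Inject adapters into the attention projections. In practice this is
	// done once before the training loop; shown here so the sketch is complete.
	adapter := m.ApplyLoRA(mlx.LoRAConfig{
		Rank:       8,
		Alpha:      16,
		TargetKeys: []string{"q_proj", "k_proj", "v_proj", "o_proj"},
	})

	// Each wrapped Linear now routes through its LoRA field, so the model's
	// ordinary forward pass already includes the low-rank deltas.
	logits := forward(m, inputs)
	loss := mlx.MaskedCrossEntropyLoss(logits, targets, mask)

	// AllTrainableParams and SetAllParams walk the layers in the same sorted
	// order, so a flat gradient slice lines up index-for-index.
	params := adapter.AllTrainableParams()
	updated := make([]*mlx.Array, len(params))
	for i, g := range grads(loss, params) {
		updated[i] = step(params[i], g)
	}
	adapter.SetAllParams(updated)

	// Persists only the A and B matrices; the frozen base weights stay put.
	if err := adapter.Save("adapter.safetensors"); err != nil {
		log.Fatal(err)
	}
}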