feat(mlx): LoRA injection into models + masked cross-entropy loss

Add LoRA field to Linear for transparent adapter injection via the model's
Forward() path. ApplyLoRA() on Qwen3/Gemma3 wraps the target attention projections.
Deterministic param ordering for adapter save/load consistency.
MaskedCrossEntropyLoss for training on assistant tokens only.

Co-Authored-By: Virgil <virgil@lethean.io>
Snider 2026-02-17 17:37:44 +00:00
parent 0eaf3d5a17
commit 2d870385f9
6 changed files with 128 additions and 5 deletions


@@ -289,6 +289,22 @@ func CrossEntropyLoss(logits, targets *Array) *Array {
return MeanAll(perPos)
}
// MaskedCrossEntropyLoss computes cross-entropy loss only on masked positions.
// logits: [B, L, V], targets: [B, L], mask: [B, L] (1.0 = compute loss, 0.0 = ignore)
// Returns scalar loss averaged over masked positions only.
func MaskedCrossEntropyLoss(logits, targets, mask *Array) *Array {
Init()
lse := LogSumExp(logits, -1, false)
tgtExpanded := ExpandDims(targets, -1)
gathered := TakeAlongAxis(logits, tgtExpanded, -1)
gathered = Squeeze(gathered, -1)
perPos := Subtract(lse, gathered) // [B, L]
// Apply mask and average over masked positions
masked := Mul(perPos, mask)
return Divide(SumAll(masked), SumAll(mask))
}
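A typical call site, as a minimal sketch: logits, targets and assistantMask are assumed to be prebuilt *Array values of shape [B, L, V], [B, L] and [B, L], with the mask set to 1.0 on assistant-token positions. Since SumAll(mask) is the denominator, callers should skip batches whose mask is all zeros.
// Caller-side sketch (not part of this diff): average loss over assistant tokens only.
loss := mlx.MaskedCrossEntropyLoss(logits, targets, assistantMask)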
// MSELoss computes the mean squared error loss: mean((predictions - targets)^2).
func MSELoss(predictions, targets *Array) *Array {
diff := Subtract(predictions, targets)


@@ -11,6 +11,7 @@ import "C"
import (
"fmt"
"math"
"sort"
"unsafe"
)
@@ -73,8 +74,10 @@ func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear {
}
// Forward computes: base(x) + scale * (x @ A^T) @ B^T
// Calls baseForward on the underlying Linear to avoid infinite recursion
// when the Linear's Forward method dispatches through LoRA.
func (l *LoRALinear) Forward(x *Array) *Array {
baseOut := l.Base.Forward(x)
baseOut := l.Base.baseForward(x)
// LoRA path: x @ A^T gives [B, L, rank], then @ B^T gives [B, L, out]
loraOut := Matmul(x, Transpose(l.A))
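For reference, after injection the call flow looks like this; baseForward is what keeps it from looping (these are the methods in this diff, nothing new):
// proj.LoRA = lora               // injection (see ApplyLoRA further down)
// proj.Forward(x)                // Linear.Forward sees l.LoRA != nil
//   -> lora.Forward(x)           // base output + low-rank delta
//        -> proj.baseForward(x)  // raw (possibly quantized) matmul, no LoRA check,
//                                // so the chain terminates instead of recursing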
@@ -135,16 +138,39 @@ func (a *LoRAAdapter) TotalParams() int {
return total
}
// SortedNames returns layer names in deterministic sorted order.
func (a *LoRAAdapter) SortedNames() []string {
names := make([]string, 0, len(a.Layers))
for name := range a.Layers {
names = append(names, name)
}
sort.Strings(names)
return names
}
// AllTrainableParams returns all trainable arrays (A and B from every layer),
// in a deterministic order by layer name.
// in a deterministic order sorted by layer name.
func (a *LoRAAdapter) AllTrainableParams() []*Array {
var params []*Array
for _, l := range a.Layers {
params = append(params, l.TrainableParams()...)
names := a.SortedNames()
params := make([]*Array, 0, len(names)*2)
for _, name := range names {
l := a.Layers[name]
params = append(params, l.A, l.B)
}
return params
}
// SetAllParams updates all LoRA A and B arrays from a flat slice,
// in the same deterministic order as AllTrainableParams.
func (a *LoRAAdapter) SetAllParams(params []*Array) {
names := a.SortedNames()
for i, name := range names {
l := a.Layers[name]
l.A = params[i*2]
l.B = params[i*2+1]
}
}
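Because both methods walk SortedNames(), a training loop can treat the adapter as a flat parameter list and hand updated arrays straight back. A minimal sketch, assuming grads is a gradient slice aligned with AllTrainableParams and lr is a scalar *mlx.Array (neither exists in this commit), and that the element-wise Mul broadcasts the scalar:
params := adapter.AllTrainableParams() // [A0, B0, A1, B1, ...] sorted by layer name
updated := make([]*mlx.Array, len(params))
for i, p := range params {
	// plain SGD step using the same element-wise ops as the loss code
	updated[i] = mlx.Subtract(p, mlx.Mul(lr, grads[i]))
}
adapter.SetAllParams(updated) // same order, so each A/B pair lands back on its layer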
// Save writes the LoRA adapter weights to a safetensors file.
// Only saves the A and B matrices — not the frozen base weights.
func (a *LoRAAdapter) Save(path string) error {


@@ -414,3 +414,35 @@ func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
// ModelType returns the architecture identifier.
func (m *GemmaModel) ModelType() string { return "gemma3" }
// ApplyLoRA wraps target projection layers with LoRA adapters.
func (m *GemmaModel) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
adapter := &mlx.LoRAAdapter{
Layers: make(map[string]*mlx.LoRALinear),
Config: cfg,
}
for i, layer := range m.Layers {
prefix := fmt.Sprintf("model.layers.%d.self_attn", i)
for _, target := range cfg.TargetKeys {
var proj *mlx.Linear
switch target {
case "q_proj":
proj = layer.Attention.QProj
case "k_proj":
proj = layer.Attention.KProj
case "v_proj":
proj = layer.Attention.VProj
case "o_proj":
proj = layer.Attention.OProj
}
if proj != nil {
lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
proj.LoRA = lora
adapter.Layers[prefix+"."+target] = lora
}
}
}
return adapter
}
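Caller-side sketch; the hyperparameter values are illustrative, not defaults shipped in this commit:
cfg := mlx.LoRAConfig{
	Rank:       8,
	Alpha:      16,
	TargetKeys: []string{"q_proj", "k_proj", "v_proj", "o_proj"},
}
adapter := model.ApplyLoRA(cfg) // model is a loaded *GemmaModel (or any Model, see below)
fmt.Printf("LoRA trainable params: %d\n", adapter.TotalParams())
// ... train ...
if err := adapter.Save("adapter.safetensors"); err != nil {
	// handle error
}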


@@ -30,6 +30,10 @@ type Model interface {
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
ModelType() string
// ApplyLoRA wraps target projection layers with LoRA adapters for training.
// Returns the adapter which holds references to all LoRA layers.
ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter
}
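Since ApplyLoRA is part of the interface, training code stays architecture-agnostic. A minimal sketch (the helper name is hypothetical):
// attachLoRA works for any Model implementation (gemma3, qwen3, ...).
func attachLoRA(m Model, cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
	adapter := m.ApplyLoRA(cfg)
	fmt.Printf("%s: wrapped %d projections with LoRA\n", m.ModelType(), len(adapter.Layers))
	return adapter
}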
// QuantizationConfig holds quantization parameters from config.json.


@@ -303,3 +303,35 @@ func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
// ModelType returns the architecture identifier.
func (m *Qwen3Model) ModelType() string { return "qwen3" }
// ApplyLoRA wraps target projection layers with LoRA adapters.
func (m *Qwen3Model) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
adapter := &mlx.LoRAAdapter{
Layers: make(map[string]*mlx.LoRALinear),
Config: cfg,
}
for i, layer := range m.Layers {
prefix := fmt.Sprintf("model.layers.%d.self_attn", i)
for _, target := range cfg.TargetKeys {
var proj *mlx.Linear
switch target {
case "q_proj":
proj = layer.Attention.QProj
case "k_proj":
proj = layer.Attention.KProj
case "v_proj":
proj = layer.Attention.VProj
case "o_proj":
proj = layer.Attention.OProj
}
if proj != nil {
lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
proj.LoRA = lora
adapter.Layers[prefix+"."+target] = lora
}
}
}
return adapter
}


@@ -4,6 +4,7 @@ package mlx
// Linear is a fully-connected layer: y = x @ W.T + bias.
// For quantized models, set Scales/Biases/GroupSize/Bits to use QuantizedMatmul.
// Set LoRA to inject a low-rank adapter (training only).
type Linear struct {
Weight *Array `weight:"weight"`
Scales *Array `weight:"scales"`
@@ -11,6 +12,8 @@ type Linear struct {
Bias *Array `weight:"bias"`
GroupSize int
Bits int
LoRA *LoRALinear // Optional LoRA adapter — if set, Forward routes through it
}
// NewLinear creates a dense Linear layer with optional bias.
@@ -31,8 +34,18 @@ func NewQuantizedLinear(weight, scales, biases, bias *Array, groupSize, bits int
}
// Forward computes the linear transformation.
// If a LoRA adapter is attached, routes through it instead (base + low-rank delta).
// Uses QuantizedMatmul when quantization parameters are present.
func (l *Linear) Forward(x *Array) *Array {
if l.LoRA != nil {
return l.LoRA.Forward(x)
}
return l.baseForward(x)
}
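Detaching the adapter is just clearing the field; the nil check then falls through to baseForward, so the layer behaves exactly as before injection:
proj.LoRA = lora // active: Forward(x) returns base(x) + scaled low-rank delta
proj.LoRA = nil  // detached: Forward(x) == baseForward(x), frozen weights only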
// baseForward is the raw linear transformation without LoRA.
// Used internally by LoRALinear to avoid infinite recursion.
func (l *Linear) baseForward(x *Array) *Array {
var out *Array
if l.Scales != nil {
out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits)