feat(mlx): LoRA injection into models + masked cross-entropy loss
Add a LoRA field to Linear for transparent adapter injection via the model's Forward() path. ApplyLoRA() on Qwen3/Gemma3 wraps the target projections. Adapter parameters are kept in a deterministic order so save/load round-trips stay consistent. MaskedCrossEntropyLoss computes the training loss on assistant tokens only.

Co-Authored-By: Virgil <virgil@lethean.io>
commit 2d870385f9
parent 0eaf3d5a17

6 changed files with 128 additions and 5 deletions
mlx/grad.go | 16

@@ -289,6 +289,22 @@ func CrossEntropyLoss(logits, targets *Array) *Array {
 	return MeanAll(perPos)
 }
 
+// MaskedCrossEntropyLoss computes cross-entropy loss only on masked positions.
+// logits: [B, L, V], targets: [B, L], mask: [B, L] (1.0 = compute loss, 0.0 = ignore)
+// Returns scalar loss averaged over masked positions only.
+func MaskedCrossEntropyLoss(logits, targets, mask *Array) *Array {
+	Init()
+	lse := LogSumExp(logits, -1, false)
+	tgtExpanded := ExpandDims(targets, -1)
+	gathered := TakeAlongAxis(logits, tgtExpanded, -1)
+	gathered = Squeeze(gathered, -1)
+	perPos := Subtract(lse, gathered) // [B, L]
+
+	// Apply mask and average over masked positions
+	masked := Mul(perPos, mask)
+	return Divide(SumAll(masked), SumAll(mask))
+}
+
 // MSELoss computes the mean squared error loss: mean((predictions - targets)^2).
 func MSELoss(predictions, targets *Array) *Array {
 	diff := Subtract(predictions, targets)
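Per masked position the loss is logsumexp(logits) - logits[target], and the masked sum is divided by the number of masked positions rather than by B*L. A minimal sketch of how a trainer might build that mask from chat-turn boundaries; the Span type, buildAssistantMask, and the flat []float32 layout are assumptions for illustration, not part of this commit.

// Sketch only: buildAssistantMask and Span are assumptions; the commit itself
// only adds MaskedCrossEntropyLoss operating on *Array inputs.
package train

// Span marks one assistant turn as a half-open token range [Start, End).
type Span struct{ Start, End int }

// buildAssistantMask flattens per-sequence assistant spans into a row-major
// [B*L] float32 buffer (1.0 = compute loss, 0.0 = ignore), ready to be loaded
// into a [B, L] mask array for MaskedCrossEntropyLoss.
func buildAssistantMask(batch [][]Span, seqLen int) []float32 {
    mask := make([]float32, len(batch)*seqLen)
    for b, spans := range batch {
        for _, s := range spans {
            for t := s.Start; t < s.End && t < seqLen; t++ {
                mask[b*seqLen+t] = 1.0
            }
        }
    }
    return mask
}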
mlx/lora.go | 36

@@ -11,6 +11,7 @@ import "C"
 import (
 	"fmt"
 	"math"
+	"sort"
 	"unsafe"
 )
 
@@ -73,8 +74,10 @@ func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear {
 }
 
 // Forward computes: base(x) + scale * (x @ A^T) @ B^T
+// Calls baseForward on the underlying Linear to avoid infinite recursion
+// when the Linear's Forward method dispatches through LoRA.
 func (l *LoRALinear) Forward(x *Array) *Array {
-	baseOut := l.Base.Forward(x)
+	baseOut := l.Base.baseForward(x)
 
 	// LoRA path: x @ A^T gives [B, L, rank], then @ B^T gives [B, L, out]
 	loraOut := Matmul(x, Transpose(l.A))
@@ -135,16 +138,39 @@ func (a *LoRAAdapter) TotalParams() int {
 	return total
 }
 
+// SortedNames returns layer names in deterministic sorted order.
+func (a *LoRAAdapter) SortedNames() []string {
+	names := make([]string, 0, len(a.Layers))
+	for name := range a.Layers {
+		names = append(names, name)
+	}
+	sort.Strings(names)
+	return names
+}
+
 // AllTrainableParams returns all trainable arrays (A and B from every layer),
-// in a deterministic order by layer name.
+// in a deterministic order sorted by layer name.
 func (a *LoRAAdapter) AllTrainableParams() []*Array {
-	var params []*Array
-	for _, l := range a.Layers {
-		params = append(params, l.TrainableParams()...)
+	names := a.SortedNames()
+	params := make([]*Array, 0, len(names)*2)
+	for _, name := range names {
+		l := a.Layers[name]
+		params = append(params, l.A, l.B)
 	}
 	return params
 }
 
+// SetAllParams updates all LoRA A and B arrays from a flat slice,
+// in the same deterministic order as AllTrainableParams.
+func (a *LoRAAdapter) SetAllParams(params []*Array) {
+	names := a.SortedNames()
+	for i, name := range names {
+		l := a.Layers[name]
+		l.A = params[i*2]
+		l.B = params[i*2+1]
+	}
+}
+
 // Save writes the LoRA adapter weights to a safetensors file.
 // Only saves the A and B matrices — not the frozen base weights.
 func (a *LoRAAdapter) Save(path string) error {
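The sorted ordering is what lets a flat parameter or gradient slice round-trip through an optimizer and land back on the matching layers. A sketch of that round trip, assuming the mlx package is imported; the step function and the grads slice are hypothetical stand-ins for whatever update rule the trainer uses.

// trainStep is a sketch, not part of the commit: it shows how the deterministic
// parameter ordering lets a flat gradient slice be applied back onto the
// matching layers. The step function (the optimizer update) and grads are
// assumptions for illustration.
func trainStep(adapter *mlx.LoRAAdapter, grads []*mlx.Array, step func(p, g *mlx.Array) *mlx.Array) {
    params := adapter.AllTrainableParams() // layer0.A, layer0.B, layer1.A, ... sorted by layer name
    updated := make([]*mlx.Array, len(params))
    for i, p := range params {
        updated[i] = step(p, grads[i]) // e.g. p - lr*g, however the trainer defines it
    }
    adapter.SetAllParams(updated) // indices match because both sides use SortedNames()
}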
[Gemma3 model file]

@@ -414,3 +414,35 @@ func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
 
 // ModelType returns the architecture identifier.
 func (m *GemmaModel) ModelType() string { return "gemma3" }
+
+// ApplyLoRA wraps target projection layers with LoRA adapters.
+func (m *GemmaModel) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
+	adapter := &mlx.LoRAAdapter{
+		Layers: make(map[string]*mlx.LoRALinear),
+		Config: cfg,
+	}
+
+	for i, layer := range m.Layers {
+		prefix := fmt.Sprintf("model.layers.%d.self_attn", i)
+		for _, target := range cfg.TargetKeys {
+			var proj *mlx.Linear
+			switch target {
+			case "q_proj":
+				proj = layer.Attention.QProj
+			case "k_proj":
+				proj = layer.Attention.KProj
+			case "v_proj":
+				proj = layer.Attention.VProj
+			case "o_proj":
+				proj = layer.Attention.OProj
+			}
+			if proj != nil {
+				lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
+				proj.LoRA = lora
+				adapter.Layers[prefix+"."+target] = lora
+			}
+		}
+	}
+
+	return adapter
+}
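Each wrapped projection is keyed as "model.layers.<i>.self_attn.<target>", so SortedNames() gives a stable, human-readable listing of what was injected. A short sketch of inspecting an adapter right after injection; the rank, alpha, and target keys are illustrative values, not defaults from the commit.

// inspectAdapter is a sketch (assumed to live alongside GemmaModel); the
// config values below are placeholders.
func inspectAdapter(m *GemmaModel) {
    adapter := m.ApplyLoRA(mlx.LoRAConfig{
        Rank:       8,
        Alpha:      16,
        TargetKeys: []string{"q_proj", "v_proj"},
    })
    for _, name := range adapter.SortedNames() {
        fmt.Println(name) // e.g. model.layers.0.self_attn.q_proj
    }
}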
[Model interface file]

@@ -30,6 +30,10 @@ type Model interface {
 
 	// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
 	ModelType() string
+
+	// ApplyLoRA wraps target projection layers with LoRA adapters for training.
+	// Returns the adapter which holds references to all LoRA layers.
+	ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter
 }
 
 // QuantizationConfig holds quantization parameters from config.json.
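Since ApplyLoRA is part of the Model interface, training setup does not need to know the architecture. A sketch of an architecture-agnostic helper, assumed to live in the package that defines Model; the helper itself is illustrative.

// attachAdapter is a sketch: it works for any Model implementation (gemma3,
// qwen3, ...) because ApplyLoRA is declared on the interface. The helper and
// its parameter choices are assumptions.
func attachAdapter(m Model, rank int, alpha float32) *mlx.LoRAAdapter {
    cfg := mlx.LoRAConfig{
        Rank:       rank,
        Alpha:      alpha,
        TargetKeys: []string{"q_proj", "k_proj", "v_proj", "o_proj"},
    }
    adapter := m.ApplyLoRA(cfg)
    fmt.Printf("%s: %d LoRA layers injected\n", m.ModelType(), len(adapter.Layers))
    return adapter
}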
[Qwen3 model file]

@@ -303,3 +303,35 @@ func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
 
 // ModelType returns the architecture identifier.
 func (m *Qwen3Model) ModelType() string { return "qwen3" }
+
+// ApplyLoRA wraps target projection layers with LoRA adapters.
+func (m *Qwen3Model) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
+	adapter := &mlx.LoRAAdapter{
+		Layers: make(map[string]*mlx.LoRALinear),
+		Config: cfg,
+	}
+
+	for i, layer := range m.Layers {
+		prefix := fmt.Sprintf("model.layers.%d.self_attn", i)
+		for _, target := range cfg.TargetKeys {
+			var proj *mlx.Linear
+			switch target {
+			case "q_proj":
+				proj = layer.Attention.QProj
+			case "k_proj":
+				proj = layer.Attention.KProj
+			case "v_proj":
+				proj = layer.Attention.VProj
+			case "o_proj":
+				proj = layer.Attention.OProj
+			}
+			if proj != nil {
+				lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
+				proj.LoRA = lora
+				adapter.Layers[prefix+"."+target] = lora
+			}
+		}
+	}
+
+	return adapter
+}
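After training, only the low-rank matrices need to be persisted; the frozen base weights stay in the original checkpoint. A sketch of the save step; the log line and path are placeholders.

// saveTrainedAdapter is a sketch: Save writes only the adapter's A and B
// matrices to a safetensors file. The output path here is illustrative.
func saveTrainedAdapter(adapter *mlx.LoRAAdapter, path string) error {
    log.Printf("saving %d LoRA params across %d layers to %s",
        adapter.TotalParams(), len(adapter.Layers), path)
    return adapter.Save(path)
}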
mlx/nn.go | 13

@@ -4,6 +4,7 @@ package mlx
 
 // Linear is a fully-connected layer: y = x @ W.T + bias.
 // For quantized models, set Scales/Biases/GroupSize/Bits to use QuantizedMatmul.
+// Set LoRA to inject a low-rank adapter (training only).
 type Linear struct {
 	Weight *Array `weight:"weight"`
 	Scales *Array `weight:"scales"`
@@ -11,6 +12,8 @@ type Linear struct {
 	Bias *Array `weight:"bias"`
 	GroupSize int
 	Bits int
+
+	LoRA *LoRALinear // Optional LoRA adapter — if set, Forward routes through it
 }
 
 // NewLinear creates a dense Linear layer with optional bias.
@@ -31,8 +34,18 @@ func NewQuantizedLinear(weight, scales, biases, bias *Array, groupSize, bits int
 }
 
 // Forward computes the linear transformation.
+// If a LoRA adapter is attached, routes through it instead (base + low-rank delta).
 // Uses QuantizedMatmul when quantization parameters are present.
 func (l *Linear) Forward(x *Array) *Array {
+	if l.LoRA != nil {
+		return l.LoRA.Forward(x)
+	}
+	return l.baseForward(x)
+}
+
+// baseForward is the raw linear transformation without LoRA.
+// Used internally by LoRALinear to avoid infinite recursion.
+func (l *Linear) baseForward(x *Array) *Array {
	var out *Array
	if l.Scales != nil {
		out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits)
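Injection is transparent at the call site: anything that goes through Linear.Forward picks up the adapter while the LoRA field is set, and clearing the field restores the plain (possibly quantized) path. A sketch; lin and x are placeholders for an existing projection layer and its input, and the rank/alpha values are illustrative.

// withAndWithoutLoRA is a sketch; only the LoRA field, NewLoRALinear, and the
// Forward routing come from this commit.
func withAndWithoutLoRA(lin *mlx.Linear, x *mlx.Array) (train, base *mlx.Array) {
    lin.LoRA = mlx.NewLoRALinear(lin, 8, 16) // rank 8, alpha 16 (illustrative)
    train = lin.Forward(x)                   // base(x) + scale * (x @ A^T) @ B^T

    lin.LoRA = nil // detach: Forward falls back to the plain linear path
    base = lin.Forward(x)
    return train, base
}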