feat(metal): add mixed precision training via LoRAConfig.DType (Phase 3)

LoRA A/B matrices can now be created in BFloat16 or Float16 for mixed
precision training. DType field added to LoRAConfig, passed through
ApplyLoRA and NewLoRALinear. MLX auto-promotes for cross-dtype ops.
BFloat16 validated: loss 7.15→6.29, matches Float32 accuracy with
half param memory.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Snider 2026-02-19 23:13:49 +00:00
parent fa08ed1e2a
commit e3fbc221ce
5 changed files with 108 additions and 6 deletions

View file

@@ -22,7 +22,7 @@ Dispatched from core/go orchestration. Pick up tasks in order.
- [x] **LoRA fine-tuning end-to-end** — ✅ Full pipeline validated: load Gemma3-1B → apply LoRA (rank=8, q_proj+v_proj, 745K params) → 5 training steps with cross-entropy loss (7.15→6.31) → save adapter (2.9MB safetensors) → reload and verify weights match. Uses ValueAndGrad + AdamW. 1 new test in train_test.go.
- [x] **Gradient checkpointing** — ✅ `Checkpoint()` validates with real model training. Wraps forward pass to recompute activations during backward. Verified: produces correct gradients (loss 7.15→7.08 in 3 steps, matching non-checkpointed initial loss). 2 new tests: unit (grad_test.go) + model (train_test.go).
- [ ] **Mixed precision training** — MLX supports BFloat16/Float16. Add dtype selection for training (currently inference uses model's native dtype).
- [x] **Mixed precision training** — ✅ `LoRAConfig.DType` selects training dtype for A/B matrices. BFloat16 validated: loss 7.15→6.29 in 5 steps, matches Float32 accuracy with half param memory. MLX auto-promotes for cross-dtype ops. 1 new test in train_test.go.
## Phase 4: Backend Abstraction — ✅ COMPLETE (19 Feb 2026)

View file

@@ -440,7 +440,7 @@ func (m *GemmaModel) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter {
proj = layer.Attention.OProj
}
if proj != nil {
lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha, cfg.DType)
proj.LoRA = lora
adapter.Layers[prefix+"."+target] = lora
}

View file

@@ -36,7 +36,12 @@ type LoRALinear struct {
// NewLoRALinear wraps an existing Linear layer with LoRA adapters.
// rank: decomposition rank (typically 4, 8, or 16)
// alpha: scaling factor (typically 2*rank)
func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear {
// dtype: optional training dtype (pass 0 or DTypeFloat32 for default)
func NewLoRALinear(base *Linear, rank int, alpha float32, dtype ...DType) *LoRALinear {
dt := DTypeFloat32
if len(dtype) > 0 && dtype[0] != 0 {
dt = dtype[0]
}
// Determine dimensions from the base weight.
// Weight shape is [out_features, in_features] for standard linear,
// or quantized shape which we handle via the base layer.
@@ -56,10 +61,10 @@ func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear {
// A: Kaiming normal initialisation — N(0, 1/sqrt(in_features))
stddev := float32(1.0 / math.Sqrt(float64(inFeatures)))
a := RandomNormal(0, stddev, []int32{int32(rank), inFeatures}, DTypeFloat32)
a := RandomNormal(0, stddev, []int32{int32(rank), inFeatures}, dt)
// B: zero initialisation — LoRA starts as identity (no change to base)
b := Zeros([]int32{outFeatures, int32(rank)}, DTypeFloat32)
b := Zeros([]int32{outFeatures, int32(rank)}, dt)
Materialize(a, b)
@@ -112,6 +117,7 @@ type LoRAConfig struct {
Rank int // Decomposition rank (default 8)
Alpha float32 // Scaling factor (default 16)
TargetKeys []string // Weight name suffixes to target (default: q_proj, v_proj)
DType DType // Training dtype for A/B (default Float32; use BFloat16 for mixed precision)
}
// DefaultLoRAConfig returns the standard LoRA configuration for LLM fine-tuning.
@@ -120,6 +126,7 @@ func DefaultLoRAConfig() LoRAConfig {
Rank: 8,
Alpha: 16,
TargetKeys: []string{"q_proj", "v_proj"},
DType: DTypeFloat32,
}
}

View file

@@ -360,7 +360,7 @@ func (m *Qwen3Model) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter {
proj = layer.Attention.OProj
}
if proj != nil {
lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha, cfg.DType)
proj.LoRA = lora
adapter.Layers[prefix+"."+target] = lora
}

View file

@@ -276,3 +276,98 @@ func TestLoRA_GradientCheckpointing(t *testing.T) {
ClearCache()
}
// TestLoRA_MixedPrecision validates training with BFloat16 LoRA parameters.
// The base model stays in its native dtype; LoRA A/B are BFloat16.
// MLX auto-promotes for cross-dtype operations.
func TestLoRA_MixedPrecision(t *testing.T) {
	modelPath := gemma3Path(t)
	model, err := loadModel(modelPath)
	if err != nil {
		t.Fatalf("loadModel: %v", err)
	}
	gemma := model.(*GemmaModel)
	tok := gemma.Tokenizer()

	// Apply LoRA with BFloat16 parameters.
	cfg := DefaultLoRAConfig()
	cfg.DType = DTypeBFloat16
	adapter := gemma.ApplyLoRA(cfg)

	// Verify A/B are actually BFloat16.
	// NOTE(review): map iteration order is random, so the break below samples
	// one arbitrary layer rather than the "first" one; all layers are created
	// by the same NewLoRALinear path, so one sample is treated as sufficient.
	for name, layer := range adapter.Layers {
		if layer.A.Dtype() != DTypeBFloat16 {
			t.Errorf("%s: A dtype = %v, want bfloat16", name, layer.A.Dtype())
		}
		if layer.B.Dtype() != DTypeBFloat16 {
			t.Errorf("%s: B dtype = %v, want bfloat16", name, layer.B.Dtype())
		}
		break // just check first layer
	}
	t.Logf("LoRA BFloat16: %d trainable params (half memory vs Float32)",
		adapter.TotalParams())

	// Next-token prediction setup: inputs are tokens [0..n-2], targets are the
	// same sequence shifted left by one token.
	inputIDs := tok.Encode("The capital of France is Paris")
	seqLen := len(inputIDs) - 1
	inputTokens := FromValues(inputIDs[:seqLen], 1, seqLen)
	targetTokens := FromValues(inputIDs[1:], 1, seqLen)
	Materialize(inputTokens, targetTokens)

	// Differentiate with respect to every trainable LoRA parameter.
	params := adapter.AllTrainableParams()
	argnums := make([]int, len(params))
	for i := range argnums {
		argnums[i] = i
	}
	opt := NewAdamW(1e-4)

	var initialLoss, finalLoss float64
	const numSteps = 5
	for step := 0; step < numSteps; step++ {
		// Fresh KV cache each step so every forward pass starts clean.
		caches := gemma.NewCache()
		// lossFn rebinds the candidate params into the adapter before the
		// forward pass so gradients flow through the LoRA A/B matrices.
		lossFn := func(inputs []*Array) []*Array {
			adapter.SetAllParams(inputs)
			logits := gemma.Forward(inputTokens, caches)
			loss := CrossEntropyLoss(logits, targetTokens)
			return []*Array{loss}
		}
		grad := ValueAndGrad(lossFn, argnums...)
		values, grads, err := grad.Apply(params...)
		grad.Free()
		if err != nil {
			t.Fatalf("step %d: ValueAndGrad failed: %v", step, err)
		}
		// Force evaluation of the lazy loss and gradient graphs.
		Materialize(append(values, grads...)...)

		loss := values[0].Float()
		t.Logf("step %d: loss = %.4f (bf16)", step, loss)
		if step == 0 {
			initialLoss = loss
			// Reduced precision can blow up numerically; fail fast on NaN/Inf.
			if math.IsNaN(loss) || math.IsInf(loss, 0) {
				t.Fatalf("initial loss is %f — bf16 may have caused NaN", loss)
			}
		}
		finalLoss = loss

		// AdamW update, then materialize and adopt the new params for the
		// next iteration (both locally and inside the adapter).
		updated := opt.Step(params, grads)
		for i := range updated {
			Materialize(updated[i])
		}
		params = updated
		adapter.SetAllParams(params)
	}

	t.Logf("bf16 loss: %.4f → %.4f", initialLoss, finalLoss)
	if finalLoss >= initialLoss {
		t.Errorf("loss did not decrease with bf16: %.4f → %.4f", initialLoss, finalLoss)
	}
	ClearCache()
}