feat(metal): add mixed precision training via LoRAConfig.DType (Phase 3)
LoRA A/B matrices can now be created in BFloat16 or Float16 for mixed precision training. DType field added to LoRAConfig, passed through ApplyLoRA and NewLoRALinear. MLX auto-promotes for cross-dtype ops. BFloat16 validated: loss 7.15→6.29, matches Float32 accuracy with half param memory.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
fa08ed1e2a
commit
e3fbc221ce
5 changed files with 108 additions and 6 deletions
2
TODO.md
2
TODO.md
|
|
@ -22,7 +22,7 @@ Dispatched from core/go orchestration. Pick up tasks in order.
|
|||
|
||||
- [x] **LoRA fine-tuning end-to-end** — ✅ Full pipeline validated: load Gemma3-1B → apply LoRA (rank=8, q_proj+v_proj, 745K params) → 5 training steps with cross-entropy loss (7.15→6.31) → save adapter (2.9MB safetensors) → reload and verify weights match. Uses ValueAndGrad + AdamW. 1 new test in train_test.go.
|
||||
- [x] **Gradient checkpointing** — ✅ `Checkpoint()` validates with real model training. Wraps forward pass to recompute activations during backward. Verified: produces correct gradients (loss 7.15→7.08 in 3 steps, matching non-checkpointed initial loss). 2 new tests: unit (grad_test.go) + model (train_test.go).
|
||||
- [ ] **Mixed precision training** — MLX supports BFloat16/Float16. Add dtype selection for training (currently inference uses model's native dtype).
|
||||
- [x] **Mixed precision training** — ✅ `LoRAConfig.DType` selects training dtype for A/B matrices. BFloat16 validated: loss 7.15→6.29 in 5 steps, matches Float32 accuracy with half param memory. MLX auto-promotes for cross-dtype ops. 1 new test in train_test.go.
|
||||
|
||||
## Phase 4: Backend Abstraction — ✅ COMPLETE (19 Feb 2026)
|
||||
|
||||
|
|
|
|||
|
|
@ -440,7 +440,7 @@ func (m *GemmaModel) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter {
|
|||
proj = layer.Attention.OProj
|
||||
}
|
||||
if proj != nil {
|
||||
lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
|
||||
lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha, cfg.DType)
|
||||
proj.LoRA = lora
|
||||
adapter.Layers[prefix+"."+target] = lora
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,7 +36,12 @@ type LoRALinear struct {
|
|||
// NewLoRALinear wraps an existing Linear layer with LoRA adapters.
|
||||
// rank: decomposition rank (typically 4, 8, or 16)
|
||||
// alpha: scaling factor (typically 2*rank)
|
||||
func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear {
|
||||
// dtype: optional training dtype (pass 0 or DTypeFloat32 for default)
|
||||
func NewLoRALinear(base *Linear, rank int, alpha float32, dtype ...DType) *LoRALinear {
|
||||
dt := DTypeFloat32
|
||||
if len(dtype) > 0 && dtype[0] != 0 {
|
||||
dt = dtype[0]
|
||||
}
|
||||
// Determine dimensions from the base weight.
|
||||
// Weight shape is [out_features, in_features] for standard linear,
|
||||
// or quantized shape which we handle via the base layer.
|
||||
|
|
@ -56,10 +61,10 @@ func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear {
|
|||
|
||||
// A: Kaiming normal initialisation — N(0, 1/sqrt(in_features))
|
||||
stddev := float32(1.0 / math.Sqrt(float64(inFeatures)))
|
||||
a := RandomNormal(0, stddev, []int32{int32(rank), inFeatures}, DTypeFloat32)
|
||||
a := RandomNormal(0, stddev, []int32{int32(rank), inFeatures}, dt)
|
||||
|
||||
// B: zero initialisation — LoRA starts as identity (no change to base)
|
||||
b := Zeros([]int32{outFeatures, int32(rank)}, DTypeFloat32)
|
||||
b := Zeros([]int32{outFeatures, int32(rank)}, dt)
|
||||
|
||||
Materialize(a, b)
|
||||
|
||||
|
|
@ -112,6 +117,7 @@ type LoRAConfig struct {
|
|||
Rank int // Decomposition rank (default 8)
|
||||
Alpha float32 // Scaling factor (default 16)
|
||||
TargetKeys []string // Weight name suffixes to target (default: q_proj, v_proj)
|
||||
DType DType // Training dtype for A/B (default Float32; use BFloat16 for mixed precision)
|
||||
}
|
||||
|
||||
// DefaultLoRAConfig returns the standard LoRA configuration for LLM fine-tuning.
|
||||
|
|
@ -120,6 +126,7 @@ func DefaultLoRAConfig() LoRAConfig {
|
|||
Rank: 8,
|
||||
Alpha: 16,
|
||||
TargetKeys: []string{"q_proj", "v_proj"},
|
||||
DType: DTypeFloat32,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -360,7 +360,7 @@ func (m *Qwen3Model) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter {
|
|||
proj = layer.Attention.OProj
|
||||
}
|
||||
if proj != nil {
|
||||
lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
|
||||
lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha, cfg.DType)
|
||||
proj.LoRA = lora
|
||||
adapter.Layers[prefix+"."+target] = lora
|
||||
}
|
||||
|
|
|
|||
|
|
@ -276,3 +276,98 @@ func TestLoRA_GradientCheckpointing(t *testing.T) {
|
|||
|
||||
ClearCache()
|
||||
}
|
||||
|
||||
// TestLoRA_MixedPrecision validates training with BFloat16 LoRA parameters.
|
||||
// The base model stays in its native dtype; LoRA A/B are BFloat16.
|
||||
// MLX auto-promotes for cross-dtype operations.
|
||||
func TestLoRA_MixedPrecision(t *testing.T) {
|
||||
modelPath := gemma3Path(t)
|
||||
|
||||
model, err := loadModel(modelPath)
|
||||
if err != nil {
|
||||
t.Fatalf("loadModel: %v", err)
|
||||
}
|
||||
|
||||
gemma := model.(*GemmaModel)
|
||||
tok := gemma.Tokenizer()
|
||||
|
||||
// Apply LoRA with BFloat16 parameters.
|
||||
cfg := DefaultLoRAConfig()
|
||||
cfg.DType = DTypeBFloat16
|
||||
adapter := gemma.ApplyLoRA(cfg)
|
||||
|
||||
// Verify A/B are actually BFloat16.
|
||||
for name, layer := range adapter.Layers {
|
||||
if layer.A.Dtype() != DTypeBFloat16 {
|
||||
t.Errorf("%s: A dtype = %v, want bfloat16", name, layer.A.Dtype())
|
||||
}
|
||||
if layer.B.Dtype() != DTypeBFloat16 {
|
||||
t.Errorf("%s: B dtype = %v, want bfloat16", name, layer.B.Dtype())
|
||||
}
|
||||
break // just check first layer
|
||||
}
|
||||
|
||||
t.Logf("LoRA BFloat16: %d trainable params (half memory vs Float32)",
|
||||
adapter.TotalParams())
|
||||
|
||||
inputIDs := tok.Encode("The capital of France is Paris")
|
||||
seqLen := len(inputIDs) - 1
|
||||
inputTokens := FromValues(inputIDs[:seqLen], 1, seqLen)
|
||||
targetTokens := FromValues(inputIDs[1:], 1, seqLen)
|
||||
Materialize(inputTokens, targetTokens)
|
||||
|
||||
params := adapter.AllTrainableParams()
|
||||
argnums := make([]int, len(params))
|
||||
for i := range argnums {
|
||||
argnums[i] = i
|
||||
}
|
||||
|
||||
opt := NewAdamW(1e-4)
|
||||
var initialLoss, finalLoss float64
|
||||
const numSteps = 5
|
||||
|
||||
for step := 0; step < numSteps; step++ {
|
||||
caches := gemma.NewCache()
|
||||
|
||||
lossFn := func(inputs []*Array) []*Array {
|
||||
adapter.SetAllParams(inputs)
|
||||
logits := gemma.Forward(inputTokens, caches)
|
||||
loss := CrossEntropyLoss(logits, targetTokens)
|
||||
return []*Array{loss}
|
||||
}
|
||||
|
||||
grad := ValueAndGrad(lossFn, argnums...)
|
||||
values, grads, err := grad.Apply(params...)
|
||||
grad.Free()
|
||||
if err != nil {
|
||||
t.Fatalf("step %d: ValueAndGrad failed: %v", step, err)
|
||||
}
|
||||
|
||||
Materialize(append(values, grads...)...)
|
||||
|
||||
loss := values[0].Float()
|
||||
t.Logf("step %d: loss = %.4f (bf16)", step, loss)
|
||||
|
||||
if step == 0 {
|
||||
initialLoss = loss
|
||||
if math.IsNaN(loss) || math.IsInf(loss, 0) {
|
||||
t.Fatalf("initial loss is %f — bf16 may have caused NaN", loss)
|
||||
}
|
||||
}
|
||||
finalLoss = loss
|
||||
|
||||
updated := opt.Step(params, grads)
|
||||
for i := range updated {
|
||||
Materialize(updated[i])
|
||||
}
|
||||
params = updated
|
||||
adapter.SetAllParams(params)
|
||||
}
|
||||
|
||||
t.Logf("bf16 loss: %.4f → %.4f", initialLoss, finalLoss)
|
||||
if finalLoss >= initialLoss {
|
||||
t.Errorf("loss did not decrease with bf16: %.4f → %.4f", initialLoss, finalLoss)
|
||||
}
|
||||
|
||||
ClearCache()
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue