diff --git a/TODO.md b/TODO.md index 6f99708..d56fb60 100644 --- a/TODO.md +++ b/TODO.md @@ -22,7 +22,7 @@ Dispatched from core/go orchestration. Pick up tasks in order. - [x] **LoRA fine-tuning end-to-end** — ✅ Full pipeline validated: load Gemma3-1B → apply LoRA (rank=8, q_proj+v_proj, 745K params) → 5 training steps with cross-entropy loss (7.15→6.31) → save adapter (2.9MB safetensors) → reload and verify weights match. Uses ValueAndGrad + AdamW. 1 new test in train_test.go. - [x] **Gradient checkpointing** — ✅ `Checkpoint()` validates with real model training. Wraps forward pass to recompute activations during backward. Verified: produces correct gradients (loss 7.15→7.08 in 3 steps, matching non-checkpointed initial loss). 2 new tests: unit (grad_test.go) + model (train_test.go). -- [ ] **Mixed precision training** — MLX supports BFloat16/Float16. Add dtype selection for training (currently inference uses model's native dtype). +- [x] **Mixed precision training** — ✅ `LoRAConfig.DType` selects training dtype for A/B matrices. BFloat16 validated: loss 7.15→6.29 in 5 steps, matches Float32 accuracy with half param memory. MLX auto-promotes for cross-dtype ops. 1 new test in train_test.go. ## Phase 4: Backend Abstraction — ✅ COMPLETE (19 Feb 2026) diff --git a/internal/metal/gemma3.go b/internal/metal/gemma3.go index 0c9c7a3..4542200 100644 --- a/internal/metal/gemma3.go +++ b/internal/metal/gemma3.go @@ -440,7 +440,7 @@ func (m *GemmaModel) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter { proj = layer.Attention.OProj } if proj != nil { - lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha) + lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha, cfg.DType) proj.LoRA = lora adapter.Layers[prefix+"."+target] = lora } diff --git a/internal/metal/lora.go b/internal/metal/lora.go index 73d4d33..f42abcd 100644 --- a/internal/metal/lora.go +++ b/internal/metal/lora.go @@ -36,7 +36,12 @@ type LoRALinear struct { // NewLoRALinear wraps an existing Linear layer with LoRA adapters. 
// rank: decomposition rank (typically 4, 8, or 16) // alpha: scaling factor (typically 2*rank) -func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear { +// dtype: optional training dtype (pass 0 or DTypeFloat32 for default) +func NewLoRALinear(base *Linear, rank int, alpha float32, dtype ...DType) *LoRALinear { + dt := DTypeFloat32 + if len(dtype) > 0 && dtype[0] != 0 { + dt = dtype[0] + } // Determine dimensions from the base weight. // Weight shape is [out_features, in_features] for standard linear, // or quantized shape which we handle via the base layer. @@ -56,10 +61,10 @@ func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear { // A: Kaiming normal initialisation — N(0, 1/sqrt(in_features)) stddev := float32(1.0 / math.Sqrt(float64(inFeatures))) - a := RandomNormal(0, stddev, []int32{int32(rank), inFeatures}, DTypeFloat32) + a := RandomNormal(0, stddev, []int32{int32(rank), inFeatures}, dt) // B: zero initialisation — LoRA starts as identity (no change to base) - b := Zeros([]int32{outFeatures, int32(rank)}, DTypeFloat32) + b := Zeros([]int32{outFeatures, int32(rank)}, dt) Materialize(a, b) @@ -112,6 +117,7 @@ type LoRAConfig struct { Rank int // Decomposition rank (default 8) Alpha float32 // Scaling factor (default 16) TargetKeys []string // Weight name suffixes to target (default: q_proj, v_proj) + DType DType // Training dtype for A/B (default Float32; use BFloat16 for mixed precision) } // DefaultLoRAConfig returns the standard LoRA configuration for LLM fine-tuning. 
@@ -120,6 +126,7 @@ func DefaultLoRAConfig() LoRAConfig {
 		Rank:       8,
 		Alpha:      16,
 		TargetKeys: []string{"q_proj", "v_proj"},
+		DType:      DTypeFloat32,
 	}
 }
 
diff --git a/internal/metal/qwen3.go b/internal/metal/qwen3.go
index 7923320..5b4e7a3 100644
--- a/internal/metal/qwen3.go
+++ b/internal/metal/qwen3.go
@@ -360,7 +360,7 @@ func (m *Qwen3Model) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter {
 				proj = layer.Attention.OProj
 			}
 			if proj != nil {
-				lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
+				lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha, cfg.DType)
 				proj.LoRA = lora
 				adapter.Layers[prefix+"."+target] = lora
 			}
diff --git a/internal/metal/train_test.go b/internal/metal/train_test.go
index a525b1f..864d745 100644
--- a/internal/metal/train_test.go
+++ b/internal/metal/train_test.go
@@ -276,3 +276,111 @@ func TestLoRA_GradientCheckpointing(t *testing.T) {
 
 	ClearCache()
 }
+
+// TestLoRA_MixedPrecision validates training with BFloat16 LoRA parameters.
+// The base model stays in its native dtype; LoRA A/B are BFloat16.
+// MLX auto-promotes for cross-dtype operations.
+//
+// The loss is checked for NaN/Inf on every step, not only the first: NaN
+// compares false against everything, so a late bf16 blow-up would otherwise
+// slip past the final "finalLoss >= initialLoss" check unnoticed.
+func TestLoRA_MixedPrecision(t *testing.T) {
+	modelPath := gemma3Path(t)
+
+	model, err := loadModel(modelPath)
+	if err != nil {
+		t.Fatalf("loadModel: %v", err)
+	}
+	// Deferred so ClearCache still runs when the test aborts via Fatalf.
+	defer ClearCache()
+
+	gemma, ok := model.(*GemmaModel)
+	if !ok {
+		t.Fatalf("loadModel returned %T, want *GemmaModel", model)
+	}
+	tok := gemma.Tokenizer()
+
+	// Apply LoRA with BFloat16 parameters.
+	cfg := DefaultLoRAConfig()
+	cfg.DType = DTypeBFloat16
+	adapter := gemma.ApplyLoRA(cfg)
+
+	// Verify A/B are actually BFloat16. An empty adapter would make the
+	// loop below a silent no-op, so reject it explicitly.
+	if len(adapter.Layers) == 0 {
+		t.Fatal("ApplyLoRA produced no LoRA layers")
+	}
+	for name, layer := range adapter.Layers {
+		if layer.A.Dtype() != DTypeBFloat16 {
+			t.Errorf("%s: A dtype = %v, want bfloat16", name, layer.A.Dtype())
+		}
+		if layer.B.Dtype() != DTypeBFloat16 {
+			t.Errorf("%s: B dtype = %v, want bfloat16", name, layer.B.Dtype())
+		}
+		break // just check first layer
+	}
+
+	t.Logf("LoRA BFloat16: %d trainable params (half memory vs Float32)",
+		adapter.TotalParams())
+
+	inputIDs := tok.Encode("The capital of France is Paris")
+	if len(inputIDs) < 2 {
+		t.Fatalf("tokenizer returned %d tokens, need at least 2", len(inputIDs))
+	}
+	seqLen := len(inputIDs) - 1
+	inputTokens := FromValues(inputIDs[:seqLen], 1, seqLen)
+	targetTokens := FromValues(inputIDs[1:], 1, seqLen)
+	Materialize(inputTokens, targetTokens)
+
+	params := adapter.AllTrainableParams()
+	argnums := make([]int, len(params))
+	for i := range argnums {
+		argnums[i] = i
+	}
+
+	opt := NewAdamW(1e-4)
+	var initialLoss, finalLoss float64
+	const numSteps = 5
+
+	for step := 0; step < numSteps; step++ {
+		caches := gemma.NewCache()
+
+		lossFn := func(inputs []*Array) []*Array {
+			adapter.SetAllParams(inputs)
+			logits := gemma.Forward(inputTokens, caches)
+			loss := CrossEntropyLoss(logits, targetTokens)
+			return []*Array{loss}
+		}
+
+		grad := ValueAndGrad(lossFn, argnums...)
+		values, grads, err := grad.Apply(params...)
+		grad.Free()
+		if err != nil {
+			t.Fatalf("step %d: ValueAndGrad failed: %v", step, err)
+		}
+
+		Materialize(append(values, grads...)...)
+
+		loss := values[0].Float()
+		t.Logf("step %d: loss = %.4f (bf16)", step, loss)
+
+		// Check every step, not just the first (see function comment).
+		if math.IsNaN(loss) || math.IsInf(loss, 0) {
+			t.Fatalf("step %d: loss is %f — bf16 may have caused NaN", step, loss)
+		}
+		if step == 0 {
+			initialLoss = loss
+		}
+		finalLoss = loss
+
+		updated := opt.Step(params, grads)
+		Materialize(updated...)
+		params = updated
+		adapter.SetAllParams(params)
+	}
+
+	t.Logf("bf16 loss: %.4f → %.4f", initialLoss, finalLoss)
+	if finalLoss >= initialLoss {
+		t.Errorf("loss did not decrease with bf16: %.4f → %.4f", initialLoss, finalLoss)
+	}
+}