From fb0692baf3c1879fe07eaeac26fe2c1368838d08 Mon Sep 17 00:00:00 2001 From: Snider Date: Thu, 19 Feb 2026 23:09:16 +0000 Subject: [PATCH] test(metal): add LoRA end-to-end training pipeline test (Phase 3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Validates full pipeline: load Gemma3-1B → apply LoRA (rank=8, 745K params across 52 layers) → train 5 steps with cross-entropy loss (7.15→6.31) → save adapter to safetensors → reload and verify all weights match. Uses ValueAndGrad for autograd + AdamW optimiser. Co-Authored-By: Virgil Co-Authored-By: Claude Opus 4.6 --- TODO.md | 2 +- internal/metal/train_test.go | 193 +++++++++++++++++++++++++++++++++++ 2 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 internal/metal/train_test.go diff --git a/TODO.md b/TODO.md index 0dea240..9af9d85 100644 --- a/TODO.md +++ b/TODO.md @@ -20,7 +20,7 @@ Dispatched from core/go orchestration. Pick up tasks in order. ## Phase 3: Training Pipeline -- [ ] **LoRA fine-tuning end-to-end** — `lora.go` has the adapter but no integration test showing: load base model → apply LoRA → train on small dataset → save adapter → reload. Critical for LEM Lab. +- [x] **LoRA fine-tuning end-to-end** — ✅ Full pipeline validated: load Gemma3-1B → apply LoRA (rank=8, q_proj+v_proj, 745K params) → 5 training steps with cross-entropy loss (7.15→6.31) → save adapter (2.9MB safetensors) → reload and verify weights match. Uses ValueAndGrad + AdamW. 1 new test in train_test.go. - [ ] **Gradient checkpointing** — `grad.go` has VJP but large models will OOM without checkpointing. Add selective recomputation. - [ ] **Mixed precision training** — MLX supports BFloat16/Float16. Add dtype selection for training (currently inference uses model's native dtype). diff --git a/internal/metal/train_test.go b/internal/metal/train_test.go new file mode 100644 index 0000000..9d26cf9 --- /dev/null +++ b/internal/metal/train_test.go @@ -0,0 +1,193 @@ +//go:build darwin && arm64 + +package metal + +import ( + "math" + "os" + "path/filepath" + "testing" +) + +// gemma3Path returns the path to a Gemma3-1B model, or skips the test. +func gemma3Path(t *testing.T) string { + t.Helper() + paths := []string{ + "/Volumes/Data/lem/gemma-3-1b-it-base", + "/Volumes/Data/lem/safetensors/gemma-3/", + } + for _, p := range paths { + if _, err := os.Stat(p); err == nil { + return p + } + } + t.Skip("no Gemma3 model available") + return "" +} + +// TestLoRA_EndToEnd validates the full LoRA training pipeline: +// load base model → apply LoRA → train on small data → save adapter → reload. +func TestLoRA_EndToEnd(t *testing.T) { + modelPath := gemma3Path(t) + + // Step 1: Load base model. + model, err := loadModel(modelPath) + if err != nil { + t.Fatalf("loadModel: %v", err) + } + + gemma := model.(*GemmaModel) + tok := gemma.Tokenizer() + + // Step 2: Apply LoRA to Q and V projections. + cfg := DefaultLoRAConfig() // rank=8, alpha=16, targets=[q_proj, v_proj] + adapter := gemma.ApplyLoRA(cfg) + + numParams := adapter.TotalParams() + t.Logf("LoRA applied: %d trainable parameters across %d layers", + numParams, len(adapter.Layers)) + + if numParams == 0 { + t.Fatal("no trainable parameters") + } + + // Step 3: Create a tiny training example. + // Encode a short sequence, use shifted tokens as targets. + inputIDs := tok.Encode("The capital of France is Paris") + if len(inputIDs) < 3 { + t.Fatalf("encoded too few tokens: %d", len(inputIDs)) + } + t.Logf("Training tokens: %v (len=%d)", inputIDs, len(inputIDs)) + + seqLen := len(inputIDs) - 1 // input is all but last, target is all but first + inputTokens := FromValues(inputIDs[:seqLen], 1, seqLen) + targetTokens := FromValues(inputIDs[1:], 1, seqLen) + Materialize(inputTokens, targetTokens) + + // Step 4: Run a few training steps. + params := adapter.AllTrainableParams() + argnums := make([]int, len(params)) + for i := range argnums { + argnums[i] = i + } + + opt := NewAdamW(1e-4) + var initialLoss, finalLoss float64 + const numSteps = 5 + + for step := 0; step < numSteps; step++ { + // Fresh caches each step (stateful — can't reuse across gradient calls). + caches := gemma.NewCache() + + lossFn := func(inputs []*Array) []*Array { + adapter.SetAllParams(inputs) + logits := gemma.Forward(inputTokens, caches) + // logits is [1, seqLen, vocab] — compute cross-entropy against targets + loss := CrossEntropyLoss(logits, targetTokens) + return []*Array{loss} + } + + grad := ValueAndGrad(lossFn, argnums...) + values, grads, err := grad.Apply(params...) + grad.Free() + if err != nil { + t.Fatalf("step %d: ValueAndGrad failed: %v", step, err) + } + + Materialize(append(values, grads...)...) + + loss := values[0].Float() + t.Logf("step %d: loss = %.4f", step, loss) + + if step == 0 { + initialLoss = loss + if math.IsNaN(loss) || math.IsInf(loss, 0) { + t.Fatalf("initial loss is %f", loss) + } + } + finalLoss = loss + + // Update params. + updated := opt.Step(params, grads) + for i := range updated { + Materialize(updated[i]) + } + params = updated + adapter.SetAllParams(params) + } + + // Verify loss decreased. + t.Logf("loss: %.4f → %.4f", initialLoss, finalLoss) + if finalLoss >= initialLoss { + t.Errorf("loss did not decrease: %.4f → %.4f", initialLoss, finalLoss) + } + + // Step 5: Save adapter. + savePath := filepath.Join(t.TempDir(), "adapter.safetensors") + if err := adapter.Save(savePath); err != nil { + t.Fatalf("adapter.Save: %v", err) + } + + info, err := os.Stat(savePath) + if err != nil { + t.Fatalf("saved adapter not found: %v", err) + } + t.Logf("adapter saved: %s (%d bytes)", savePath, info.Size()) + + // Step 6: Reload and verify weights match. + loaded, err := LoadAllSafetensors(savePath) + if err != nil { + t.Fatalf("LoadAllSafetensors: %v", err) + } + + for _, name := range adapter.SortedNames() { + layer := adapter.Layers[name] + aKey := name + ".lora_a" + bKey := name + ".lora_b" + + loadedA, ok := loaded[aKey] + if !ok { + t.Errorf("missing %s in saved adapter", aKey) + continue + } + loadedB, ok := loaded[bKey] + if !ok { + t.Errorf("missing %s in saved adapter", bKey) + continue + } + + Materialize(loadedA, loadedB) + + // Compare A weights. + origA := layer.A.Floats() + reloadA := loadedA.Floats() + if len(origA) != len(reloadA) { + t.Errorf("%s: A size mismatch: %d vs %d", name, len(origA), len(reloadA)) + continue + } + for i := range origA { + if math.Abs(float64(origA[i]-reloadA[i])) > 1e-6 { + t.Errorf("%s: A[%d] = %f, reloaded = %f", name, i, origA[i], reloadA[i]) + break + } + } + + // Compare B weights. + origB := layer.B.Floats() + reloadB := loadedB.Floats() + if len(origB) != len(reloadB) { + t.Errorf("%s: B size mismatch: %d vs %d", name, len(origB), len(reloadB)) + continue + } + for i := range origB { + if math.Abs(float64(origB[i]-reloadB[i])) > 1e-6 { + t.Errorf("%s: B[%d] = %f, reloaded = %f", name, i, origB[i], reloadB[i]) + break + } + } + } + + t.Logf("all %d adapter layers verified after reload", len(adapter.Layers)) + + ClearCache() +}