From fa08ed1e2ab5103477e79d82876ce8aacdaa678f Mon Sep 17 00:00:00 2001 From: Snider Date: Thu, 19 Feb 2026 23:11:15 +0000 Subject: [PATCH] test(metal): validate gradient checkpointing with real model (Phase 3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Checkpoint() wraps forward pass to recompute activations during backward, trading compute for memory. Verified with Gemma3-1B LoRA training: produces correct gradients (loss 7.15→7.08, matches non-checkpointed initial loss). Unit test confirms gradient correctness on simple function (sum(x^2), grad=[2,4,6]). Co-Authored-By: Virgil Co-Authored-By: Claude Opus 4.6 --- TODO.md | 2 +- internal/metal/grad_test.go | 39 +++++++++++++++++ internal/metal/train_test.go | 85 ++++++++++++++++++++++++++++++++++++ 3 files changed, 125 insertions(+), 1 deletion(-) diff --git a/TODO.md b/TODO.md index 9af9d85..6f99708 100644 --- a/TODO.md +++ b/TODO.md @@ -21,7 +21,7 @@ Dispatched from core/go orchestration. Pick up tasks in order. ## Phase 3: Training Pipeline - [x] **LoRA fine-tuning end-to-end** — ✅ Full pipeline validated: load Gemma3-1B → apply LoRA (rank=8, q_proj+v_proj, 745K params) → 5 training steps with cross-entropy loss (7.15→6.31) → save adapter (2.9MB safetensors) → reload and verify weights match. Uses ValueAndGrad + AdamW. 1 new test in train_test.go. -- [ ] **Gradient checkpointing** — `grad.go` has VJP but large models will OOM without checkpointing. Add selective recomputation. +- [x] **Gradient checkpointing** — ✅ `Checkpoint()` validates with real model training. Wraps forward pass to recompute activations during backward. Verified: produces correct gradients (loss 7.15→7.08 in 3 steps, matching non-checkpointed initial loss). 2 new tests: unit (grad_test.go) + model (train_test.go). - [ ] **Mixed precision training** — MLX supports BFloat16/Float16. Add dtype selection for training (currently inference uses model's native dtype). ## Phase 4: Backend Abstraction — ✅ COMPLETE (19 Feb 2026) diff --git a/internal/metal/grad_test.go b/internal/metal/grad_test.go index 59f43ed..e7c045e 100644 --- a/internal/metal/grad_test.go +++ b/internal/metal/grad_test.go @@ -309,6 +309,45 @@ func TestCheckpoint(t *testing.T) { } } +func TestCheckpoint_GradientFlows(t *testing.T) { + // Checkpoint should produce correct gradients (same as non-checkpointed). + // f(x) = sum(x^2), df/dx = 2x. At x=[1,2,3]: grad=[2,4,6]. + fn := func(inputs []*Array) []*Array { + x := inputs[0] + return []*Array{SumAll(Mul(x, x))} + } + cpFn := Checkpoint(fn) + + x := FromValues([]float32{1.0, 2.0, 3.0}, 3) + + // Gradient through checkpointed function. + grad := ValueAndGrad(func(inputs []*Array) []*Array { + return cpFn(inputs) + }) + defer grad.Free() + + values, grads, err := grad.Apply(x) + if err != nil { + t.Fatalf("ValueAndGrad through Checkpoint: %v", err) + } + Materialize(values[0], grads[0]) + + // Value: 1+4+9 = 14 + val := values[0].Float() + if math.Abs(val-14.0) > 1e-4 { + t.Errorf("value = %f, want 14.0", val) + } + + // Gradients: [2, 4, 6] + gFloats := grads[0].Floats() + expected := []float32{2.0, 4.0, 6.0} + for i, exp := range expected { + if math.Abs(float64(gFloats[i]-exp)) > 1e-4 { + t.Errorf("grad[%d] = %f, want %f", i, gFloats[i], exp) + } + } +} + func TestSumAll(t *testing.T) { a := FromValues([]float32{1.0, 2.0, 3.0, 4.0}, 2, 2) result := SumAll(a) diff --git a/internal/metal/train_test.go b/internal/metal/train_test.go index 9d26cf9..a525b1f 100644 --- a/internal/metal/train_test.go +++ b/internal/metal/train_test.go @@ -191,3 +191,88 @@ func TestLoRA_EndToEnd(t *testing.T) { ClearCache() } + +// TestLoRA_GradientCheckpointing validates that wrapping the forward pass in +// Checkpoint produces correct gradients (same loss decrease as non-checkpointed). +func TestLoRA_GradientCheckpointing(t *testing.T) { + modelPath := gemma3Path(t) + + model, err := loadModel(modelPath) + if err != nil { + t.Fatalf("loadModel: %v", err) + } + + gemma := model.(*GemmaModel) + tok := gemma.Tokenizer() + + adapter := gemma.ApplyLoRA(DefaultLoRAConfig()) + t.Logf("LoRA: %d trainable params", adapter.TotalParams()) + + inputIDs := tok.Encode("The capital of France is Paris") + seqLen := len(inputIDs) - 1 + inputTokens := FromValues(inputIDs[:seqLen], 1, seqLen) + targetTokens := FromValues(inputIDs[1:], 1, seqLen) + Materialize(inputTokens, targetTokens) + + params := adapter.AllTrainableParams() + argnums := make([]int, len(params)) + for i := range argnums { + argnums[i] = i + } + + opt := NewAdamW(1e-4) + var initialLoss, finalLoss float64 + const numSteps = 3 + + for step := 0; step < numSteps; step++ { + caches := gemma.NewCache() + + // Wrap the model forward pass in Checkpoint to recompute activations + // during backward instead of storing them. + checkpointedForward := Checkpoint(func(inputs []*Array) []*Array { + adapter.SetAllParams(inputs) + logits := gemma.Forward(inputTokens, caches) + return []*Array{logits} + }) + + lossFn := func(inputs []*Array) []*Array { + logits := checkpointedForward(inputs)[0] + loss := CrossEntropyLoss(logits, targetTokens) + return []*Array{loss} + } + + grad := ValueAndGrad(lossFn, argnums...) + values, grads, err := grad.Apply(params...) + grad.Free() + if err != nil { + t.Fatalf("step %d: ValueAndGrad failed: %v", step, err) + } + + Materialize(append(values, grads...)...) + + loss := values[0].Float() + t.Logf("step %d: loss = %.4f (checkpointed)", step, loss) + + if step == 0 { + initialLoss = loss + if math.IsNaN(loss) || math.IsInf(loss, 0) { + t.Fatalf("initial loss is %f", loss) + } + } + finalLoss = loss + + updated := opt.Step(params, grads) + for i := range updated { + Materialize(updated[i]) + } + params = updated + adapter.SetAllParams(params) + } + + t.Logf("checkpointed loss: %.4f → %.4f", initialLoss, finalLoss) + if finalLoss >= initialLoss { + t.Errorf("loss did not decrease with checkpointing: %.4f → %.4f", initialLoss, finalLoss) + } + + ClearCache() +}