test(metal): validate gradient checkpointing with real model (Phase 3)

Checkpoint() wraps forward pass to recompute activations during
backward, trading compute for memory. Verified with Gemma3-1B LoRA
training: produces correct gradients (loss 7.15→7.08, matches
non-checkpointed initial loss). Unit test confirms gradient
correctness on simple function (sum(x^2), grad=[2,4,6]).

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Snider 2026-02-19 23:11:15 +00:00
parent fb0692baf3
commit fa08ed1e2a
3 changed files with 125 additions and 1 deletions

View file

@ -21,7 +21,7 @@ Dispatched from core/go orchestration. Pick up tasks in order.
## Phase 3: Training Pipeline
- [x] **LoRA fine-tuning end-to-end** — ✅ Full pipeline validated: load Gemma3-1B → apply LoRA (rank=8, q_proj+v_proj, 745K params) → 5 training steps with cross-entropy loss (7.15→6.31) → save adapter (2.9MB safetensors) → reload and verify weights match. Uses ValueAndGrad + AdamW. 1 new test in train_test.go.
- [ ] **Gradient checkpointing**`grad.go` has VJP but large models will OOM without checkpointing. Add selective recomputation.
- [x] **Gradient checkpointing** — ✅ `Checkpoint()` validates with real model training. Wraps forward pass to recompute activations during backward. Verified: produces correct gradients (loss 7.15→7.08 in 3 steps, matching non-checkpointed initial loss). 2 new tests: unit (grad_test.go) + model (train_test.go).
- [ ] **Mixed precision training** — MLX supports BFloat16/Float16. Add dtype selection for training (currently inference uses model's native dtype).
## Phase 4: Backend Abstraction — ✅ COMPLETE (19 Feb 2026)

View file

@ -309,6 +309,45 @@ func TestCheckpoint(t *testing.T) {
}
}
func TestCheckpoint_GradientFlows(t *testing.T) {
// Checkpoint should produce correct gradients (same as non-checkpointed).
// f(x) = sum(x^2), df/dx = 2x. At x=[1,2,3]: grad=[2,4,6].
fn := func(inputs []*Array) []*Array {
x := inputs[0]
return []*Array{SumAll(Mul(x, x))}
}
cpFn := Checkpoint(fn)
x := FromValues([]float32{1.0, 2.0, 3.0}, 3)
// Gradient through checkpointed function.
grad := ValueAndGrad(func(inputs []*Array) []*Array {
return cpFn(inputs)
})
defer grad.Free()
values, grads, err := grad.Apply(x)
if err != nil {
t.Fatalf("ValueAndGrad through Checkpoint: %v", err)
}
Materialize(values[0], grads[0])
// Value: 1+4+9 = 14
val := values[0].Float()
if math.Abs(val-14.0) > 1e-4 {
t.Errorf("value = %f, want 14.0", val)
}
// Gradients: [2, 4, 6]
gFloats := grads[0].Floats()
expected := []float32{2.0, 4.0, 6.0}
for i, exp := range expected {
if math.Abs(float64(gFloats[i]-exp)) > 1e-4 {
t.Errorf("grad[%d] = %f, want %f", i, gFloats[i], exp)
}
}
}
func TestSumAll(t *testing.T) {
a := FromValues([]float32{1.0, 2.0, 3.0, 4.0}, 2, 2)
result := SumAll(a)

View file

@ -191,3 +191,88 @@ func TestLoRA_EndToEnd(t *testing.T) {
ClearCache()
}
// TestLoRA_GradientCheckpointing validates that wrapping the forward pass in
// Checkpoint produces correct gradients (same loss decrease as non-checkpointed).
func TestLoRA_GradientCheckpointing(t *testing.T) {
modelPath := gemma3Path(t)
model, err := loadModel(modelPath)
if err != nil {
t.Fatalf("loadModel: %v", err)
}
gemma := model.(*GemmaModel)
tok := gemma.Tokenizer()
adapter := gemma.ApplyLoRA(DefaultLoRAConfig())
t.Logf("LoRA: %d trainable params", adapter.TotalParams())
inputIDs := tok.Encode("The capital of France is Paris")
seqLen := len(inputIDs) - 1
inputTokens := FromValues(inputIDs[:seqLen], 1, seqLen)
targetTokens := FromValues(inputIDs[1:], 1, seqLen)
Materialize(inputTokens, targetTokens)
params := adapter.AllTrainableParams()
argnums := make([]int, len(params))
for i := range argnums {
argnums[i] = i
}
opt := NewAdamW(1e-4)
var initialLoss, finalLoss float64
const numSteps = 3
for step := 0; step < numSteps; step++ {
caches := gemma.NewCache()
// Wrap the model forward pass in Checkpoint to recompute activations
// during backward instead of storing them.
checkpointedForward := Checkpoint(func(inputs []*Array) []*Array {
adapter.SetAllParams(inputs)
logits := gemma.Forward(inputTokens, caches)
return []*Array{logits}
})
lossFn := func(inputs []*Array) []*Array {
logits := checkpointedForward(inputs)[0]
loss := CrossEntropyLoss(logits, targetTokens)
return []*Array{loss}
}
grad := ValueAndGrad(lossFn, argnums...)
values, grads, err := grad.Apply(params...)
grad.Free()
if err != nil {
t.Fatalf("step %d: ValueAndGrad failed: %v", step, err)
}
Materialize(append(values, grads...)...)
loss := values[0].Float()
t.Logf("step %d: loss = %.4f (checkpointed)", step, loss)
if step == 0 {
initialLoss = loss
if math.IsNaN(loss) || math.IsInf(loss, 0) {
t.Fatalf("initial loss is %f", loss)
}
}
finalLoss = loss
updated := opt.Step(params, grads)
for i := range updated {
Materialize(updated[i])
}
params = updated
adapter.SetAllParams(params)
}
t.Logf("checkpointed loss: %.4f → %.4f", initialLoss, finalLoss)
if finalLoss >= initialLoss {
t.Errorf("loss did not decrease with checkpointing: %.4f → %.4f", initialLoss, finalLoss)
}
ClearCache()
}