Checkpoint() wraps the forward pass so that activations are recomputed during the backward pass, trading compute for memory.

Verified with Gemma3-1B LoRA training: it produces correct gradients (loss 7.15→7.08; the initial loss matches the non-checkpointed run). A unit test confirms gradient correctness on a simple function (sum(x^2), grad = [2, 4, 6]).

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
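For illustration, here is a minimal sketch of the sum(x^2) gradient check described above, written against PyTorch's `torch.utils.checkpoint` API; this project's own Checkpoint() wrapper is not shown here, so the mapping between the two APIs is an assumption, not the actual implementation.

```python
import torch
from torch.utils.checkpoint import checkpoint

def f(t):
    # Forward pass: sum of squares. Under checkpointing, the
    # intermediate t*t is not kept; it is rebuilt during backward.
    return (t * t).sum()

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# checkpoint() discards intermediate activations after the forward
# pass and reruns f during backward to recompute them, trading
# compute for memory exactly as described above.
loss = checkpoint(f, x, use_reentrant=False)
loss.backward()

# d/dx sum(x^2) = 2x, so the expected gradient is [2, 4, 6].
assert torch.allclose(x.grad, 2 * x.detach())
print(x.grad)  # tensor([2., 4., 6.])
```

The key property being tested is that checkpointed and non-checkpointed backward passes yield identical gradients; only peak memory and recompute cost differ.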