go-mlx/internal
Snider fa08ed1e2a test(metal): validate gradient checkpointing with real model (Phase 3)
Checkpoint() wraps the forward pass so activations are recomputed
during the backward pass, trading compute for memory. Verified with
Gemma3-1B LoRA training: produces correct gradients (loss 7.15→7.08,
matching the non-checkpointed initial loss). A unit test confirms
gradient correctness on a simple function (sum(x^2), grad=[2,4,6]).
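
The recompute-on-backward idea behind Checkpoint() can be sketched as
follows. This is a hypothetical standalone illustration, not the real
go-mlx API: the Checkpoint type, Forward, and Backward names here are
invented for the sketch, which mirrors the unit test's sum(x^2) case.

```go
package main

import "fmt"

// forwardFn computes y = sum(x_i^2). The squared intermediates are
// what a non-checkpointed graph would keep alive for the backward pass.
func forwardFn(x []float64) (y float64, squares []float64) {
	squares = make([]float64, len(x))
	for i, v := range x {
		squares[i] = v * v
		y += squares[i]
	}
	return y, squares
}

// Checkpoint (hypothetical) saves only the inputs of the wrapped
// sub-graph; activations are dropped after Forward and recomputed
// inside Backward, trading extra compute for lower peak memory.
type Checkpoint struct{ x []float64 }

func (c *Checkpoint) Forward() float64 {
	y, _ := forwardFn(c.x) // intermediates discarded here
	return y
}

func (c *Checkpoint) Backward() []float64 {
	_, _ = forwardFn(c.x) // recompute the discarded activations
	grad := make([]float64, len(c.x))
	for i, v := range c.x {
		grad[i] = 2 * v // d/dx sum(x^2) = 2x
	}
	return grad
}

func main() {
	cp := &Checkpoint{x: []float64{1, 2, 3}}
	fmt.Println(cp.Forward())  // 14
	fmt.Println(cp.Backward()) // [2 4 6]
}
```

The key property checked by the commit's unit test is visible here:
the gradient from the recomputed path matches the analytic gradient
[2, 4, 6], so checkpointing changes memory behavior but not results.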

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 23:11:15 +00:00