go-ai/mlx
Snider e9973aef3c feat(mlx): add autograd — VJP, JVP, ValueAndGrad, loss functions
Native Go bindings for MLX-C gradient computation on Apple Silicon.
Foundation for LoRA training without Python.

- VJP (reverse-mode autodiff) for backward pass
- JVP (forward-mode autodiff) for directional derivatives
- ValueAndGrad for combined loss + gradient computation
- Checkpoint for memory-efficient gradient recomputation
- CrossEntropyLoss (numerically stable via LogSumExp)
- MSELoss, Log, SumAll, MeanAll, OnesLike helpers
- TakeAlongAxis and LogSumExp ops
- Fix closure callback null vector bug (affects compile.go too)
- Fix Float() returning 0 for float32 arrays

14 tests passing on Metal GPU.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-17 17:18:47 +00:00
cache refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64 2026-02-17 16:57:41 +00:00
model refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64 2026-02-17 16:57:41 +00:00
sample refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64 2026-02-17 16:57:41 +00:00
tokenizer refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64 2026-02-17 16:57:41 +00:00
array.go feat(mlx): add autograd — VJP, JVP, ValueAndGrad, loss functions 2026-02-17 17:18:47 +00:00
CMakeLists.txt feat: extract AI/ML packages from core/go 2026-02-16 15:25:55 +00:00
compile.go feat(mlx): add autograd — VJP, JVP, ValueAndGrad, loss functions 2026-02-17 17:18:47 +00:00
dtype.go refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64 2026-02-17 16:57:41 +00:00
fast.go refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64 2026-02-17 16:57:41 +00:00
grad.go feat(mlx): add autograd — VJP, JVP, ValueAndGrad, loss functions 2026-02-17 17:18:47 +00:00
grad_test.go feat(mlx): add autograd — VJP, JVP, ValueAndGrad, loss functions 2026-02-17 17:18:47 +00:00
io.go refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64 2026-02-17 16:57:41 +00:00
mlx.go refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64 2026-02-17 16:57:41 +00:00
mlx_stub.go refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64 2026-02-17 16:57:41 +00:00
nn.go refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64 2026-02-17 16:57:41 +00:00
ops.go feat(mlx): add autograd — VJP, JVP, ValueAndGrad, loss functions 2026-02-17 17:18:47 +00:00
random.go refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64 2026-02-17 16:57:41 +00:00
slice.go refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64 2026-02-17 16:57:41 +00:00
stream.go refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64 2026-02-17 16:57:41 +00:00