# Training
## LoRA Fine-Tuning
Low-Rank Adaptation (LoRA) enables efficient fine-tuning by training small adapter matrices while the base model's weights stay frozen.
```go
// Apply LoRA to a model's attention layers
adapter := mlx.NewLoRA(model, mlx.LoRAConfig{
    Rank:    8,
    Alpha:   16.0,
    Dropout: 0.1,
})
```
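Here, `Rank` sets the dimensionality of the low-rank update matrices (higher rank means more trainable parameters), `Alpha` scales the adapter's contribution to the layer output, and `Dropout` is regularisation applied to the adapter path during training.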
```go
// Train only the adapter's parameters
optimizer := mlx.NewAdamW(adapter.Parameters(), mlx.AdamWConfig{
    LearningRate: 1e-4,
    WeightDecay:  0.01,
})
```
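A single fine-tuning step then looks roughly like the sketch below. `lossFn` stands in for your own forward pass and loss over the current batch (see Gradient Computation below); only `adapter.Parameters`, `mlx.VJP`, and `optimizer.Step` come from the API shown on this page.

```go
// One LoRA fine-tuning step (sketch); lossFn is a user-supplied closure
// over the current batch, not part of the documented API.
params := adapter.Parameters()
grads := mlx.VJP(lossFn, params) // gradients w.r.t. the adapter parameters
optimizer.Step(grads)            // apply the AdamW update
```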
## Saving/Loading Adapters
```go
// Save trained adapter
adapter.Save("/path/to/adapter.safetensors")

// Load and apply to base model
adapter, _ := mlx.LoadLoRA("/path/to/adapter.safetensors")
adapter.Apply(model)
```
## Gradient Computation
Gradients are computed with a vector-Jacobian product (VJP):
```go
// Define a loss function over the trainable parameters
loss := func(params []*mlx.Array) *mlx.Array {
    // Forward pass + loss computation go here; return a scalar loss
    // (crossEntropyLoss is a placeholder for that value)
    return crossEntropyLoss
}

// Compute gradients of the loss with respect to params
grads := mlx.VJP(loss, params)
```
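As a more concrete sketch, the closure below assumes a `model.Forward` method, an `mlx.CrossEntropy` helper, and `inputs`/`targets` arrays; none of these names are documented here, so substitute whatever forward pass and loss your model exposes.

```go
// Hypothetical loss closure: model.Forward, mlx.CrossEntropy, inputs and
// targets are assumed names, not part of the documented API.
lossFn := func(params []*mlx.Array) *mlx.Array {
    logits := model.Forward(params, inputs)  // forward pass with candidate parameters
    return mlx.CrossEntropy(logits, targets) // scalar cross-entropy loss
}

grads := mlx.VJP(lossFn, params) // gradients for each entry of params
```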
## AdamW Optimiser
```go
opt := mlx.NewAdamW(params, mlx.AdamWConfig{
    LearningRate: 1e-4,
    Beta1:        0.9,
    Beta2:        0.999,
    Epsilon:      1e-8,
    WeightDecay:  0.01,
})

// Training loop: run for a fixed number of optimisation steps
for step := 0; step < numSteps; step++ {
    grads := mlx.VJP(lossFn, params)
    opt.Step(grads)
}
```
## Mixed Precision
MLX supports multiple dtypes natively. Typical choices are:
- Float16 for inference
- BFloat16 for training
- Float32 for loss computation
Cast between dtypes with `array.AsType(mlx.BFloat16)`.
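A common pattern, sketched below, is to keep trainable weights in BFloat16 and cast logits up to Float32 before computing the loss. Only `AsType` and the dtype constants come from the API above; `weights`, `logits`, `targets`, and `crossEntropy` are illustrative placeholders.

```go
// Mixed-precision sketch: weights, logits, targets and crossEntropy are
// placeholders; only AsType, mlx.BFloat16 and mlx.Float32 come from the
// API described above.
for i, w := range weights {
    weights[i] = w.AsType(mlx.BFloat16) // train in BFloat16
}

logits32 := logits.AsType(mlx.Float32)  // cast up before the loss
loss := crossEntropy(logits32, targets) // compute the loss in Float32
```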