Training

LoRA Fine-Tuning

Low-Rank Adaptation (LoRA) fine-tunes a model by training small low-rank adapter matrices while the base model weights stay frozen.

// Apply LoRA to a model's attention layers
adapter := mlx.NewLoRA(model, mlx.LoRAConfig{
    Rank:    8,
    Alpha:   16.0,
    Dropout: 0.1,
})

// Create an optimizer over the adapter's trainable parameters only
optimizer := mlx.NewAdamW(adapter.Parameters(), mlx.AdamWConfig{
    LearningRate: 1e-4,
    WeightDecay:  0.01,
})
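
A single update step might then look like the sketch below; lossFn is a placeholder for a forward-pass-plus-loss closure (see Gradient Computation below), and only the adapter's parameters receive updates.

// Minimal sketch: lossFn is defined in the Gradient Computation section below
params := adapter.Parameters()    // only the low-rank adapter weights
grads := mlx.VJP(lossFn, params)  // gradients w.r.t. the adapter parameters
optimizer.Step(grads)             // apply the AdamW update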

Saving/Loading Adapters

// Save trained adapter
adapter.Save("/path/to/adapter.safetensors")

// Load and apply to base model
adapter, _ := mlx.LoadLoRA("/path/to/adapter.safetensors")
adapter.Apply(model)
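
In real code, check the load error instead of discarding it; a minimal sketch:

adapter, err := mlx.LoadLoRA("/path/to/adapter.safetensors")
if err != nil {
    // handle or propagate the error
    panic(err)
}
adapter.Apply(model)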

Gradient Computation

Gradients are computed via the vector-Jacobian product (VJP):

// Define the loss function over the trainable parameters
lossFn := func(params []*mlx.Array) *mlx.Array {
    // Forward pass + loss computation; return the scalar loss as an *mlx.Array
    return crossEntropyLoss // placeholder for the computed loss
}

// Compute gradients of the loss with respect to params
grads := mlx.VJP(lossFn, params)
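
For concreteness, the closure could look roughly like the sketch below; model.Forward, crossEntropy, inputs, and targets are hypothetical placeholders for your own forward pass and loss, not part of the API shown on this page.

// Hypothetical sketch — model.Forward and crossEntropy are placeholders
lossFn := func(params []*mlx.Array) *mlx.Array {
    logits := model.Forward(params, inputs)  // forward pass using the supplied parameters
    return crossEntropy(logits, targets)     // scalar loss as an *mlx.Array
}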

AdamW Optimiser

opt := mlx.NewAdamW(params, mlx.AdamWConfig{
    LearningRate: 1e-4,
    Beta1:        0.9,
    Beta2:        0.999,
    Epsilon:      1e-8,
    WeightDecay:  0.01,
})

// Training loop
for range epochs {
    grads := mlx.VJP(lossFn, params)
    opt.Step(grads)
}

Mixed Precision

MLX natively supports multiple dtypes. Typical precision choices for a model:

  • Float16 for inference
  • BFloat16 for training
  • Float32 for loss computation

Casting: array.AsType(mlx.BFloat16)
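
For example, a mixed-precision sketch using the AsType call above; mlx.Float32 is assumed to exist alongside the mlx.BFloat16 shown, and input, weight, and rawLoss are placeholders:

// Run the forward pass in BFloat16 for speed and memory
x := input.AsType(mlx.BFloat16)
w := weight.AsType(mlx.BFloat16)

// ... forward pass and loss in BFloat16 ...

// Cast the loss to Float32 for a numerically stable reduction
loss := rawLoss.AsType(mlx.Float32)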