Training
Virgil edited this page 2026-03-11 12:02:35 +00:00

Training Interface

Module: forge.lthn.ai/core/go-inference

The training interface extends TextModel with LoRA fine-tuning capabilities. Backend-specific training operations (optimisers, gradient computation, tensor creation) are provided by the backend package directly (e.g. go-mlx for Metal).

TrainableModel

type TrainableModel interface {
    TextModel
    ApplyLoRA(cfg LoRAConfig) Adapter
    Encode(text string) []int32
    Decode(ids []int32) string
    NumLayers() int
}
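A training loop consumes token IDs rather than strings. The sketch below shows one way a caller might turn raw examples into fixed-length rows using only Encode. The `encoder` interface, the toy `byteEncoder` and the `tokenise` helper are hypothetical and not part of go-inference. Right-padding with a caller-chosen pad ID is also an assumption about what a given training loop expects.

```go
package main

import "fmt"

// encoder is a hypothetical subset of TrainableModel containing only the
// method this helper needs, so any TrainableModel satisfies it.
type encoder interface {
	Encode(text string) []int32
}

// byteEncoder is a toy tokeniser that maps each byte to its numeric value.
type byteEncoder struct{}

func (byteEncoder) Encode(text string) []int32 {
	ids := make([]int32, len(text))
	for i := 0; i < len(text); i++ {
		ids[i] = int32(text[i])
	}
	return ids
}

// tokenise encodes each text, truncates it to maxLen tokens and right-pads
// shorter sequences with pad so that every row has the same length.
func tokenise(enc encoder, texts []string, maxLen int, pad int32) [][]int32 {
	batch := make([][]int32, 0, len(texts))
	for _, t := range texts {
		ids := enc.Encode(t)
		if len(ids) > maxLen {
			ids = ids[:maxLen]
		}
		row := make([]int32, maxLen)
		copy(row, ids)
		for i := len(ids); i < maxLen; i++ {
			row[i] = pad
		}
		batch = append(batch, row)
	}
	return batch
}

func main() {
	batch := tokenise(byteEncoder{}, []string{"hi", "hello"}, 4, 0)
	fmt.Println(batch) // [[104 105 0 0] [104 101 108 108]]
}
```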

Use a type assertion to check whether a loaded model supports training:

tm, ok := model.(inference.TrainableModel)

Or use the convenience function:

tm, err := inference.LoadTrainable("/path/to/model/")

LoRAConfig

type LoRAConfig struct {
    Rank       int      // Decomposition rank (default 8)
    Alpha      float32  // Scaling factor (default 16)
    TargetKeys []string // Projection layer suffixes (default: q_proj, v_proj)
    BFloat16   bool     // Mixed precision adapter weights
}

DefaultLoRAConfig() returns LoRAConfig{Rank: 8, Alpha: 16, TargetKeys: []string{"q_proj", "v_proj"}}, leaving BFloat16 at its zero value (false).
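Rank and Alpha together set the adapter's size and the strength of its update. In the standard LoRA formulation, a targeted projection W (d_out × d_in) gains two trainable matrices, A (r × d_in) and B (d_out × r). That adds r·(d_in + d_out) parameters, and the update BA is scaled by alpha/r. The sketch below applies that arithmetic to the documented defaults. It assumes the backend uses the alpha/r convention and that every targeted projection is square (hidden × hidden). The second assumption does not hold for models with grouped-query attention, where v_proj is narrower. `scale`, `estimateParams` and the 32-layer, 4096-wide model are hypothetical.

```go
package main

import "fmt"

// LoRAConfig mirrors the fields documented above.
type LoRAConfig struct {
	Rank       int
	Alpha      float32
	TargetKeys []string
	BFloat16   bool
}

// defaultLoRAConfig reproduces the documented defaults of DefaultLoRAConfig().
func defaultLoRAConfig() LoRAConfig {
	return LoRAConfig{Rank: 8, Alpha: 16, TargetKeys: []string{"q_proj", "v_proj"}}
}

// scale returns the standard LoRA scaling factor alpha / rank.
func scale(cfg LoRAConfig) float32 {
	return cfg.Alpha / float32(cfg.Rank)
}

// estimateParams counts adapter parameters, assuming every targeted
// projection in every layer is hidden x hidden.
func estimateParams(cfg LoRAConfig, layers, hidden int) int {
	perProjection := cfg.Rank * (hidden + hidden) // A: r x d_in plus B: d_out x r
	return perProjection * len(cfg.TargetKeys) * layers
}

func main() {
	cfg := defaultLoRAConfig()
	fmt.Println(scale(cfg))                    // 2
	fmt.Println(estimateParams(cfg, 32, 4096)) // 4194304
}
```

Doubling Rank doubles the parameter count but halves the scale for a fixed Alpha, which is why the two are usually tuned together.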

Adapter

type Adapter interface {
    TotalParams() int
    Save(path string) error
}

The concrete type is backend-specific (e.g. *metal.LoRAAdapter for go-mlx).

Backend Registration

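Each backend registers itself from an init() function in a file restricted by a build constraint, so it is compiled only on the platforms it supports: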

// go-mlx, in a file that begins with: //go:build darwin && arm64
func init() { inference.Register(metal.NewBackend()) }

// go-rocm, in a file that begins with: //go:build linux && amd64
func init() { inference.Register(rocm.NewBackend()) }
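This pattern works because Go runs a package's init() functions when the package is initialised, and a file excluded by its build constraint is never compiled. As a result, only backends that can run on the current platform ever call Register. Applications conventionally pull a backend in for its side effects with a blank import. The following self-contained sketch shows the registry side. The `Backend` interface with `Name` and `Available`, the `Default` function and the first-available selection rule are all hypothetical. go-inference's actual Register signature and selection logic may differ.

```go
package main

import (
	"fmt"
	"sync"
)

// Backend is a hypothetical minimal backend interface.
type Backend interface {
	Name() string
	Available() bool
}

var (
	mu       sync.RWMutex
	backends []Backend
)

// Register appends a backend; backend packages call it from init().
func Register(b Backend) {
	mu.Lock()
	defer mu.Unlock()
	backends = append(backends, b)
}

// Default returns the first registered backend that reports itself available.
func Default() (Backend, error) {
	mu.RLock()
	defer mu.RUnlock()
	for _, b := range backends {
		if b.Available() {
			return b, nil
		}
	}
	return nil, fmt.Errorf("no available inference backend (registered: %d)", len(backends))
}

// stubBackend is a fake used only to exercise the registry.
type stubBackend struct {
	name string
	ok   bool
}

func (s stubBackend) Name() string    { return s.name }
func (s stubBackend) Available() bool { return s.ok }

// In a real backend package this init() would sit in a file guarded by a
// //go:build constraint, so it is compiled only on supported platforms.
func init() {
	Register(stubBackend{name: "stub-unavailable", ok: false})
	Register(stubBackend{name: "stub-available", ok: true})
}

func main() {
	b, err := Default()
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("using backend:", b.Name()) // using backend: stub-available
}
```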