feat(ml): add core ml train command for LoRA fine-tuning

Native MLX LoRA training on Apple Silicon, no Python required. Reads
chat-format JSONL, applies LoRA to the target projections, and trains with
AdamW plus masked cross-entropy loss computed on assistant tokens only.

Usage: core ml train --model-path /path/to/model --data training.jsonl

Co-Authored-By: Virgil <virgil@lethean.io>
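As a sketch of the data format (the contents here are invented; only the
"messages" schema comes from the command's help text), a single
training.jsonl line could look like:

{"messages": [{"role": "system", "content": "You are a concise assistant."}, {"role": "user", "content": "What is LoRA?"}, {"role": "assistant", "content": "Low-Rank Adaptation: small trainable matrices added alongside frozen projection weights."}]}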
parent 54b98e2d15
commit 59812e5857

2 changed files with 365 additions and 0 deletions

cmd/ml/cmd_train.go  (new file, 358 lines)
@ -0,0 +1,358 @@
//go:build darwin && arm64

package ml

import (
	"bufio"
	"encoding/json"
	"fmt"
	"log/slog"
	"os"
	"runtime"
	"strings"
	"time"

	"forge.lthn.ai/core/go-ai/ml"
	"forge.lthn.ai/core/go-ai/mlx"
	"forge.lthn.ai/core/go-ai/mlx/model"
	"forge.lthn.ai/core/go-ai/mlx/tokenizer"
	"forge.lthn.ai/core/go/pkg/cli"
)
var trainCmd = &cli.Command{
	Use:   "train",
	Short: "LoRA fine-tune a model on JSONL training data",
	Long: `Fine-tunes a local MLX model using LoRA (Low-Rank Adaptation).

Reads chat-format JSONL training data and trains LoRA adapter weights
using the AdamW optimiser with cross-entropy loss on assistant tokens only.

Training data format (one JSON object per line):
{"messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}`,
	RunE: runTrain,
}
var (
	trainModelPath   string
	trainData        string
	trainOutput      string
	trainRank        int
	trainAlpha       float64
	trainLR          float64
	trainEpochs      int
	trainMaxSeqLen   int
	trainTargets     string
	trainMemoryLimit int
)

func init() {
	trainCmd.Flags().StringVar(&trainModelPath, "model-path", "", "Path to model directory (required)")
	trainCmd.Flags().StringVar(&trainData, "data", "", "Training JSONL file (required)")
	trainCmd.Flags().StringVar(&trainOutput, "output", "adapters.safetensors", "Output adapter file")
	trainCmd.Flags().IntVar(&trainRank, "rank", 8, "LoRA decomposition rank")
	trainCmd.Flags().Float64Var(&trainAlpha, "alpha", 16, "LoRA scaling factor")
	trainCmd.Flags().Float64Var(&trainLR, "lr", 1e-4, "Learning rate")
	trainCmd.Flags().IntVar(&trainEpochs, "epochs", 1, "Number of training epochs")
	trainCmd.Flags().IntVar(&trainMaxSeqLen, "max-seq-len", 512, "Maximum sequence length (tokens)")
	trainCmd.Flags().StringVar(&trainTargets, "targets", "q_proj,v_proj", "Comma-separated projection targets for LoRA")
	trainCmd.Flags().IntVar(&trainMemoryLimit, "memory-limit", 24, "Metal memory limit in GB")
	trainCmd.MarkFlagRequired("model-path")
	trainCmd.MarkFlagRequired("data")
}

// trainSample holds a tokenised training example.
type trainSample struct {
	Tokens []int32 // full token sequence
	Mask   []int32 // 1 for assistant tokens, 0 for prompt tokens
}
func runTrain(cmd *cli.Command, args []string) error {
	start := time.Now()

	// --- Load model ---
	slog.Info("loading model", "path", trainModelPath)
	m, err := model.LoadModel(trainModelPath)
	if err != nil {
		return fmt.Errorf("load model: %w", err)
	}

	// Cap Metal cache and total memory at the configured limit (GB -> bytes)
	mlx.SetCacheLimit(uint64(trainMemoryLimit) * 1024 * 1024 * 1024)
	mlx.SetMemoryLimit(uint64(trainMemoryLimit) * 1024 * 1024 * 1024)

	tok := m.Tokenizer()
	slog.Info("model loaded",
		"type", m.ModelType(),
		"layers", m.NumLayers(),
	)

	// --- Apply LoRA ---
	targets := strings.Split(trainTargets, ",")
	cfg := mlx.LoRAConfig{
		Rank:       trainRank,
		Alpha:      float32(trainAlpha),
		TargetKeys: targets,
	}

	adapter := m.ApplyLoRA(cfg)
	slog.Info("LoRA applied",
		"rank", cfg.Rank,
		"alpha", cfg.Alpha,
		"targets", targets,
		"trainable_params", adapter.TotalParams(),
		"layers", len(adapter.Layers),
	)

	// --- Load training data ---
	samples, err := loadTrainingSamples(trainData, tok, m.ModelType(), trainMaxSeqLen)
	if err != nil {
		return fmt.Errorf("load training data: %w", err)
	}
	slog.Info("training data loaded", "samples", len(samples))

	if len(samples) == 0 {
		return fmt.Errorf("no training samples loaded")
	}

	// --- Training loop ---
	params := adapter.AllTrainableParams()
	opt := mlx.NewAdamW(trainLR)

	// Build argument indices for ValueAndGrad (all params)
	argIndices := make([]int, len(params))
	for i := range argIndices {
		argIndices[i] = i
	}

	var totalLoss float64
	var totalSteps int

	for epoch := 0; epoch < trainEpochs; epoch++ {
		var epochLoss float64
		epochStart := time.Now()

		for si, sample := range samples {
			// Build token tensors: input = tokens[:-1], target = tokens[1:]
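			// Hypothetical illustration (not from the source): tokens
			// [A B C D] give input [A B C] and target [B C D], so at
			// each position the model is scored on predicting the
			// next token of the sequence.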
			seqLen := len(sample.Tokens)
			if seqLen < 2 {
				continue
			}

			inputTokens := sample.Tokens[:seqLen-1]
			targetTokens := sample.Tokens[1:]
			maskTokens := sample.Mask[1:] // mask aligned with targets

			inputArr := mlx.FromValues(inputTokens, 1, len(inputTokens))
			targetArr := mlx.FromValues(targetTokens, 1, len(targetTokens))

			// Build float32 mask
			maskF32 := make([]float32, len(maskTokens))
			for i, mv := range maskTokens {
				maskF32[i] = float32(mv)
			}
			maskArr := mlx.FromValues(maskF32, 1, len(maskF32))
			mlx.Materialize(inputArr, targetArr, maskArr)
			// Loss function closure: takes LoRA params as inputs
			lossFn := func(inputs []*mlx.Array) []*mlx.Array {
				// Set LoRA params from inputs
				adapter.SetAllParams(inputs)

				// Forward pass with fresh caches (no KV caching for training)
				caches := m.NewCache()
				logits := m.Forward(inputArr, caches)

				// Masked cross-entropy: loss is computed only where mask is 1
				loss := mlx.MaskedCrossEntropyLoss(logits, targetArr, maskArr)
				return []*mlx.Array{loss}
			}

			// Compute value and gradients
			grad := mlx.ValueAndGrad(lossFn, argIndices...)
			values, grads, err := grad.Apply(params...)
			grad.Free()
			if err != nil {
				return fmt.Errorf("epoch %d sample %d: gradient failed: %w", epoch, si, err)
			}
			mlx.Materialize(append(values, grads...)...)

			loss := values[0].Float()
			epochLoss += loss
			totalSteps++

			// Update parameters
			params = opt.Step(params, grads)
			adapter.SetAllParams(params)
			mlx.Materialize(params...)

			// Periodic cleanup
			if totalSteps%4 == 0 {
				runtime.GC()
				mlx.ClearCache()
			}

			// Log progress
			if (si+1)%10 == 0 || si == len(samples)-1 {
				avgLoss := epochLoss / float64(si+1)
				slog.Info("training",
					"epoch", epoch+1,
					"step", fmt.Sprintf("%d/%d", si+1, len(samples)),
					"loss", fmt.Sprintf("%.4f", loss),
					"avg_loss", fmt.Sprintf("%.4f", avgLoss),
				)
			}
		}

		totalLoss = epochLoss / float64(len(samples))
		elapsed := time.Since(epochStart)
		slog.Info("epoch complete",
			"epoch", epoch+1,
			"avg_loss", fmt.Sprintf("%.4f", totalLoss),
			"duration", elapsed.Round(time.Second),
			"samples_per_sec", fmt.Sprintf("%.1f", float64(len(samples))/elapsed.Seconds()),
		)
	}

	// --- Save adapter ---
	if err := adapter.Save(trainOutput); err != nil {
		return fmt.Errorf("save adapter: %w", err)
	}

	slog.Info("training complete",
		"output", trainOutput,
		"total_steps", totalSteps,
		"final_loss", fmt.Sprintf("%.4f", totalLoss),
		"duration", time.Since(start).Round(time.Second),
		"trainable_params", adapter.TotalParams(),
	)

	return nil
}
// loadTrainingSamples reads JSONL and tokenises each conversation.
func loadTrainingSamples(path string, tok *tokenizer.Tokenizer, modelType string, maxSeqLen int) ([]trainSample, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var samples []trainSample
	scanner := bufio.NewScanner(f)
	scanner.Buffer(make([]byte, 1<<20), 1<<20) // 1 MiB line buffer

	lineNum := 0
	for scanner.Scan() {
		lineNum++
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}

		var entry struct {
			Messages []ml.Message `json:"messages"`
		}
		if err := json.Unmarshal([]byte(line), &entry); err != nil {
			slog.Warn("skipping invalid line", "line", lineNum, "error", err)
			continue
		}

		if len(entry.Messages) == 0 {
			continue
		}

		sample := tokeniseConversation(entry.Messages, tok, modelType, maxSeqLen)
		if sample != nil {
			samples = append(samples, *sample)
		}
	}

	return samples, scanner.Err()
}
// tokeniseConversation formats and tokenises a conversation, creating a mask
// that is 1 for assistant tokens and 0 for system/user tokens.
func tokeniseConversation(messages []ml.Message, tok *tokenizer.Tokenizer, modelType string, maxSeqLen int) *trainSample {
	// Strategy: tokenise the full conversation, then tokenise just the prefix
	// (non-assistant parts) to determine the mask boundary.
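	//
	// Hypothetical illustration (numbers invented): if the full conversation
	// encodes to 10 tokens and the prefix (system/user turns plus the bare
	// assistant header) encodes to 6, the mask is [0 0 0 0 0 0 1 1 1 1] and
	// loss is taken only on the final 4 assistant tokens.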

	// Build full conversation text
	fullText := formatConversation(messages, modelType, true)
	fullTokens := tok.Encode(fullText)

	if len(fullTokens) < 2 {
		return nil
	}

	// Truncate to max sequence length
	if len(fullTokens) > maxSeqLen {
		fullTokens = fullTokens[:maxSeqLen]
	}

	// Build mask: tokenise the prefix (everything before the first assistant
	// response), then mark the remaining tokens as assistant (mask = 1)
	prefixText := formatConversation(messages, modelType, false)
	prefixTokens := tok.Encode(prefixText)

	mask := make([]int32, len(fullTokens))
	for i := range mask {
		if i >= len(prefixTokens) {
			mask[i] = 1 // assistant token
		}
	}

	return &trainSample{
		Tokens: fullTokens,
		Mask:   mask,
	}
}

// formatConversation formats messages using the model's chat template.
// If includeAssistant is false, it only formats up to the first assistant
// turn header, so callers can measure the prompt-only token prefix.
func formatConversation(messages []ml.Message, modelType string, includeAssistant bool) string {
	switch modelType {
	case "qwen3":
		return formatQwen3Train(messages, includeAssistant)
	default:
		return formatGemmaTrain(messages, includeAssistant)
	}
}

func formatQwen3Train(messages []ml.Message, includeAssistant bool) string {
	var sb strings.Builder
	for _, msg := range messages {
		if msg.Role == "assistant" && !includeAssistant {
			// Write the assistant header but not the content
			sb.WriteString("<|im_start|>assistant\n")
			return sb.String()
		}
		switch msg.Role {
		case "system":
			sb.WriteString(fmt.Sprintf("<|im_start|>system\n%s<|im_end|>\n", msg.Content))
		case "user":
			sb.WriteString(fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n", msg.Content))
		case "assistant":
			sb.WriteString(fmt.Sprintf("<|im_start|>assistant\n%s<|im_end|>\n", msg.Content))
		}
	}
	return sb.String()
}

func formatGemmaTrain(messages []ml.Message, includeAssistant bool) string {
	var sb strings.Builder
	for _, msg := range messages {
		if msg.Role == "assistant" && !includeAssistant {
			sb.WriteString("<start_of_turn>model\n")
			return sb.String()
		}
		switch msg.Role {
		case "user":
			sb.WriteString(fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n", msg.Content))
		case "assistant":
			sb.WriteString(fmt.Sprintf("<start_of_turn>model\n%s<end_of_turn>\n", msg.Content))
		case "system":
			// Gemma's chat template has no system role; fold system text into a user turn
			sb.WriteString(fmt.Sprintf("<start_of_turn>user\n[System: %s]<end_of_turn>\n", msg.Content))
		}
	}
	return sb.String()
}

cmd/ml/cmd_train_init.go  (new file, 7 lines)

@ -0,0 +1,7 @@
//go:build darwin && arm64

package ml

// Register the train subcommand on the parent ml command.
func init() {
	mlCmd.AddCommand(trainCmd)
}
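
Example invocation (a sketch: the flag names and defaults come from the
command definition above, while the model path is a placeholder):

  core ml train \
    --model-path ~/models/qwen3-4b \
    --data training.jsonl \
    --output adapters.safetensors \
    --rank 8 --alpha 16 --lr 1e-4 \
    --epochs 1 --max-seq-len 512 \
    --targets q_proj,v_proj --memory-limit 24

The values shown match the flag defaults, so the minimal form from the
commit message (--model-path plus --data) trains the same configuration.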