//go:build darwin && arm64

// Package ml: `core ml train` — native MLX LoRA fine-tuning on Apple
// Silicon, no Python required. Reads chat-format JSONL, applies LoRA to
// target projections, trains with AdamW + masked cross-entropy loss on
// assistant tokens.
package ml

import (
	"bufio"
	"encoding/json"
	"fmt"
	"log/slog"
	"os"
	"runtime"
	"strings"
	"time"

	"forge.lthn.ai/core/go-ai/ml"
	"forge.lthn.ai/core/go-ai/mlx"
	"forge.lthn.ai/core/go-ai/mlx/model"
	"forge.lthn.ai/core/go-ai/mlx/tokenizer"
	"forge.lthn.ai/core/go/pkg/cli"
)

// trainCmd implements `core ml train`: LoRA fine-tuning of a local MLX
// model from chat-format JSONL training data.
var trainCmd = &cli.Command{
	Use:   "train",
	Short: "LoRA fine-tune a model on JSONL training data",
	Long: `Fine-tunes a local MLX model using LoRA (Low-Rank Adaptation).

Reads chat-format JSONL training data and trains LoRA adapter weights
using AdamW optimiser with cross-entropy loss on assistant tokens only.

Training data format (one JSON object per line):
  {"messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}`,
	RunE: runTrain,
}

// Flag storage for the train command; bound in init below.
var (
	trainModelPath   string  // --model-path: model directory (required)
	trainData        string  // --data: training JSONL file (required)
	trainOutput      string  // --output: adapter output file
	trainRank        int     // --rank: LoRA decomposition rank
	trainAlpha       float64 // --alpha: LoRA scaling factor
	trainLR          float64 // --lr: AdamW learning rate
	trainEpochs      int     // --epochs: number of passes over the data
	trainMaxSeqLen   int     // --max-seq-len: truncation length in tokens
	trainTargets     string  // --targets: comma-separated projection names
	trainMemoryLimit int     // --memory-limit: Metal memory cap in GB
)

// init registers the train command's flags and marks the mandatory ones.
func init() {
	trainCmd.Flags().StringVar(&trainModelPath, "model-path", "", "Path to model directory (required)")
	trainCmd.Flags().StringVar(&trainData, "data", "", "Training JSONL file (required)")
	trainCmd.Flags().StringVar(&trainOutput, "output", "adapters.safetensors", "Output adapter file")
	trainCmd.Flags().IntVar(&trainRank, "rank", 8, "LoRA decomposition rank")
	trainCmd.Flags().Float64Var(&trainAlpha, "alpha", 16, "LoRA scaling factor")
	trainCmd.Flags().Float64Var(&trainLR, "lr", 1e-4, "Learning rate")
	trainCmd.Flags().IntVar(&trainEpochs, "epochs", 1, "Number of training epochs")
	trainCmd.Flags().IntVar(&trainMaxSeqLen, "max-seq-len", 512, "Maximum sequence length (tokens)")
	trainCmd.Flags().StringVar(&trainTargets, "targets", "q_proj,v_proj", "Comma-separated projection targets for LoRA")
	trainCmd.Flags().IntVar(&trainMemoryLimit, "memory-limit", 24, "Metal memory limit in GB")
	// NOTE(review): if cli.Command follows cobra, MarkFlagRequired returns an
	// error that is dropped here — confirm the wrapper's signature.
	trainCmd.MarkFlagRequired("model-path")
	trainCmd.MarkFlagRequired("data")
}

// trainSample holds a tokenised training example.
type trainSample struct {
	Tokens []int32 // Full token sequence
	Mask   []int32 // 1 for assistant tokens, 0 for prompt tokens
}

// runTrain is the RunE entry point for `core ml train`. It loads the model,
// applies LoRA adapters to the configured projections, runs a per-sample
// (batch size 1) AdamW training loop with masked cross-entropy loss, and
// saves the adapter weights to trainOutput.
func runTrain(cmd *cli.Command, args []string) error {
	start := time.Now()

	// --- Load model ---
	slog.Info("loading model", "path", trainModelPath)
	m, err := model.LoadModel(trainModelPath)
	if err != nil {
		return fmt.Errorf("load model: %w", err)
	}

	// Cap Metal cache and total memory to the user-specified GB limit.
	mlx.SetCacheLimit(uint64(trainMemoryLimit) * 1024 * 1024 * 1024)
	mlx.SetMemoryLimit(uint64(trainMemoryLimit) * 1024 * 1024 * 1024)

	tok := m.Tokenizer()
	slog.Info("model loaded",
		"type", m.ModelType(),
		"layers", m.NumLayers(),
	)

	// --- Apply LoRA ---
	targets := strings.Split(trainTargets, ",")
	cfg := mlx.LoRAConfig{
		Rank:       trainRank,
		Alpha:      float32(trainAlpha),
		TargetKeys: targets,
	}

	adapter := m.ApplyLoRA(cfg)
	slog.Info("LoRA applied",
		"rank", cfg.Rank,
		"alpha", cfg.Alpha,
		"targets", targets,
		"trainable_params", adapter.TotalParams(),
		"layers", len(adapter.Layers),
	)

	// --- Load training data ---
	samples, err := loadTrainingSamples(trainData, tok, m.ModelType(), trainMaxSeqLen)
	if err != nil {
		return fmt.Errorf("load training data: %w", err)
	}
	slog.Info("training data loaded", "samples", len(samples))

	if len(samples) == 0 {
		return fmt.Errorf("no training samples loaded")
	}

	// --- Training loop ---
	params := adapter.AllTrainableParams()
	opt := mlx.NewAdamW(trainLR)

	// Build argument indices for ValueAndGrad (all params)
	argIndices := make([]int, len(params))
	for i := range argIndices {
		argIndices[i] = i
	}

	var totalLoss float64
	var totalSteps int

	for epoch := 0; epoch < trainEpochs; epoch++ {
		var epochLoss float64
		epochStart := time.Now()

		for si, sample := range samples {
			// Build token tensors: input = tokens[:-1], target = tokens[1:]
			// (standard next-token prediction shift).
			seqLen := len(sample.Tokens)
			if seqLen < 2 {
				continue
			}

			inputTokens := sample.Tokens[:seqLen-1]
			targetTokens := sample.Tokens[1:]
			maskTokens := sample.Mask[1:] // mask aligned with targets

			inputArr := mlx.FromValues(inputTokens, 1, len(inputTokens))
			targetArr := mlx.FromValues(targetTokens, 1, len(targetTokens))

			// Build float32 mask
			// NOTE(review): `m` here shadows the model variable `m` above;
			// harmless in this tight loop but worth renaming.
			maskF32 := make([]float32, len(maskTokens))
			for i, m := range maskTokens {
				maskF32[i] = float32(m)
			}
			maskArr := mlx.FromValues(maskF32, 1, len(maskF32))
			mlx.Materialize(inputArr, targetArr, maskArr)

			// Loss function closure — takes LoRA params as inputs
			lossFn := func(inputs []*mlx.Array) []*mlx.Array {
				// Set LoRA params from inputs
				adapter.SetAllParams(inputs)

				// Forward pass with fresh caches (no KV caching for training)
				caches := m.NewCache()
				logits := m.Forward(inputArr, caches)

				// Cast targets to int32 for take_along_axis
				loss := mlx.MaskedCrossEntropyLoss(logits, targetArr, maskArr)
				return []*mlx.Array{loss}
			}

			// Compute value and gradients
			grad := mlx.ValueAndGrad(lossFn, argIndices...)
			values, grads, err := grad.Apply(params...)
			grad.Free()
			if err != nil {
				return fmt.Errorf("epoch %d sample %d: gradient failed: %w", epoch, si, err)
			}

			// Force evaluation of loss and gradients before the optimiser step.
			mlx.Materialize(append(values, grads...)...)

			loss := values[0].Float()
			epochLoss += loss
			totalSteps++

			// Update parameters
			params = opt.Step(params, grads)
			adapter.SetAllParams(params)
			mlx.Materialize(params...)

			// Periodic cleanup to keep Metal memory pressure down.
			if totalSteps%4 == 0 {
				runtime.GC()
				mlx.ClearCache()
			}

			// Log progress every 10 samples and at epoch end.
			// NOTE(review): avg divides by si+1 even though short (<2 token)
			// samples were skipped above, so the average is slightly low when
			// skips occur.
			if (si+1)%10 == 0 || si == len(samples)-1 {
				avgLoss := epochLoss / float64(si+1)
				slog.Info("training",
					"epoch", epoch+1,
					"step", fmt.Sprintf("%d/%d", si+1, len(samples)),
					"loss", fmt.Sprintf("%.4f", loss),
					"avg_loss", fmt.Sprintf("%.4f", avgLoss),
				)
			}
		}

		totalLoss = epochLoss / float64(len(samples))
		elapsed := time.Since(epochStart)
		slog.Info("epoch complete",
			"epoch", epoch+1,
			"avg_loss", fmt.Sprintf("%.4f", totalLoss),
			"duration", elapsed.Round(time.Second),
			"samples_per_sec", fmt.Sprintf("%.1f", float64(len(samples))/elapsed.Seconds()),
		)
	}

	// --- Save adapter ---
	if err := adapter.Save(trainOutput); err != nil {
		return fmt.Errorf("save adapter: %w", err)
	}

	slog.Info("training complete",
		"output", trainOutput,
		"total_steps", totalSteps,
		"final_loss", fmt.Sprintf("%.4f", totalLoss),
		"duration", time.Since(start).Round(time.Second),
		"trainable_params", adapter.TotalParams(),
	)

	return nil
}

// loadTrainingSamples reads JSONL and tokenises each conversation.
+func loadTrainingSamples(path string, tok *tokenizer.Tokenizer, modelType string, maxSeqLen int) ([]trainSample, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var samples []trainSample + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1<<20), 1<<20) // 1MB line buffer + + lineNum := 0 + for scanner.Scan() { + lineNum++ + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + var entry struct { + Messages []ml.Message `json:"messages"` + } + if err := json.Unmarshal([]byte(line), &entry); err != nil { + slog.Warn("skipping invalid line", "line", lineNum, "error", err) + continue + } + + if len(entry.Messages) == 0 { + continue + } + + sample := tokeniseConversation(entry.Messages, tok, modelType, maxSeqLen) + if sample != nil { + samples = append(samples, *sample) + } + } + + return samples, scanner.Err() +} + +// tokeniseConversation formats and tokenises a conversation, creating a mask +// that is 1 for assistant tokens and 0 for system/user tokens. +func tokeniseConversation(messages []ml.Message, tok *tokenizer.Tokenizer, modelType string, maxSeqLen int) *trainSample { + // Strategy: tokenise the full conversation, then tokenise just the prefix + // (non-assistant parts) to determine the mask boundary. 
+ + // Build full conversation text + fullText := formatConversation(messages, modelType, true) + fullTokens := tok.Encode(fullText) + + if len(fullTokens) < 2 { + return nil + } + + // Truncate to max sequence length + if len(fullTokens) > maxSeqLen { + fullTokens = fullTokens[:maxSeqLen] + } + + // Build mask: tokenise prefix (everything up to last assistant response) + // then mark remaining tokens as assistant (mask=1) + prefixText := formatConversation(messages, modelType, false) + prefixTokens := tok.Encode(prefixText) + + mask := make([]int32, len(fullTokens)) + for i := range mask { + if i >= len(prefixTokens) { + mask[i] = 1 // assistant token + } + } + + return &trainSample{ + Tokens: fullTokens, + Mask: mask, + } +} + +// formatConversation formats messages using the model's chat template. +// If includeAssistant is false, only formats up to the last assistant turn header. +func formatConversation(messages []ml.Message, modelType string, includeAssistant bool) string { + switch modelType { + case "qwen3": + return formatQwen3Train(messages, includeAssistant) + default: + return formatGemmaTrain(messages, includeAssistant) + } +} + +func formatQwen3Train(messages []ml.Message, includeAssistant bool) string { + var sb strings.Builder + for _, msg := range messages { + if msg.Role == "assistant" && !includeAssistant { + // Write the assistant header but not the content + sb.WriteString("<|im_start|>assistant\n") + return sb.String() + } + switch msg.Role { + case "system": + sb.WriteString(fmt.Sprintf("<|im_start|>system\n%s<|im_end|>\n", msg.Content)) + case "user": + sb.WriteString(fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n", msg.Content)) + case "assistant": + sb.WriteString(fmt.Sprintf("<|im_start|>assistant\n%s<|im_end|>\n", msg.Content)) + } + } + return sb.String() +} + +func formatGemmaTrain(messages []ml.Message, includeAssistant bool) string { + var sb strings.Builder + for _, msg := range messages { + if msg.Role == "assistant" && 
!includeAssistant { + sb.WriteString("model\n") + return sb.String() + } + switch msg.Role { + case "user": + sb.WriteString(fmt.Sprintf("user\n%s\n", msg.Content)) + case "assistant": + sb.WriteString(fmt.Sprintf("model\n%s\n", msg.Content)) + case "system": + sb.WriteString(fmt.Sprintf("user\n[System: %s]\n", msg.Content)) + } + } + return sb.String() +} diff --git a/cmd/ml/cmd_train_init.go b/cmd/ml/cmd_train_init.go new file mode 100644 index 00000000..263966d9 --- /dev/null +++ b/cmd/ml/cmd_train_init.go @@ -0,0 +1,7 @@ +//go:build darwin && arm64 + +package ml + +func init() { + mlCmd.AddCommand(trainCmd) +}