test(metal): add LoRA end-to-end training pipeline test (Phase 3)
Validates full pipeline: load Gemma3-1B → apply LoRA (rank=8, 745K params across 52 layers) → train 5 steps with cross-entropy loss (7.15→6.31) → save adapter to safetensors → reload and verify all weights match. Uses ValueAndGrad for autograd + AdamW optimiser. Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
19c4823b04
commit
fb0692baf3
2 changed files with 194 additions and 1 deletions
2
TODO.md
2
TODO.md
|
|
@ -20,7 +20,7 @@ Dispatched from core/go orchestration. Pick up tasks in order.
|
|||
|
||||
## Phase 3: Training Pipeline
|
||||
|
||||
- [ ] **LoRA fine-tuning end-to-end** — `lora.go` has the adapter but no integration test showing: load base model → apply LoRA → train on small dataset → save adapter → reload. Critical for LEM Lab.
|
||||
- [x] **LoRA fine-tuning end-to-end** — ✅ Full pipeline validated: load Gemma3-1B → apply LoRA (rank=8, q_proj+v_proj, 745K params) → 5 training steps with cross-entropy loss (7.15→6.31) → save adapter (2.9MB safetensors) → reload and verify weights match. Uses ValueAndGrad + AdamW. 1 new test in train_test.go.
|
||||
- [ ] **Gradient checkpointing** — `grad.go` has VJP but large models will OOM without checkpointing. Add selective recomputation.
|
||||
- [ ] **Mixed precision training** — MLX supports BFloat16/Float16. Add dtype selection for training (currently inference uses model's native dtype).
|
||||
|
||||
|
|
|
|||
193
internal/metal/train_test.go
Normal file
193
internal/metal/train_test.go
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package metal
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// gemma3Path returns the path to a Gemma3-1B model, or skips the test.
|
||||
func gemma3Path(t *testing.T) string {
|
||||
t.Helper()
|
||||
paths := []string{
|
||||
"/Volumes/Data/lem/gemma-3-1b-it-base",
|
||||
"/Volumes/Data/lem/safetensors/gemma-3/",
|
||||
}
|
||||
for _, p := range paths {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return p
|
||||
}
|
||||
}
|
||||
t.Skip("no Gemma3 model available")
|
||||
return ""
|
||||
}
|
||||
|
||||
// TestLoRA_EndToEnd validates the full LoRA training pipeline:
|
||||
// load base model → apply LoRA → train on small data → save adapter → reload.
|
||||
func TestLoRA_EndToEnd(t *testing.T) {
|
||||
modelPath := gemma3Path(t)
|
||||
|
||||
// Step 1: Load base model.
|
||||
model, err := loadModel(modelPath)
|
||||
if err != nil {
|
||||
t.Fatalf("loadModel: %v", err)
|
||||
}
|
||||
|
||||
gemma := model.(*GemmaModel)
|
||||
tok := gemma.Tokenizer()
|
||||
|
||||
// Step 2: Apply LoRA to Q and V projections.
|
||||
cfg := DefaultLoRAConfig() // rank=8, alpha=16, targets=[q_proj, v_proj]
|
||||
adapter := gemma.ApplyLoRA(cfg)
|
||||
|
||||
numParams := adapter.TotalParams()
|
||||
t.Logf("LoRA applied: %d trainable parameters across %d layers",
|
||||
numParams, len(adapter.Layers))
|
||||
|
||||
if numParams == 0 {
|
||||
t.Fatal("no trainable parameters")
|
||||
}
|
||||
|
||||
// Step 3: Create a tiny training example.
|
||||
// Encode a short sequence, use shifted tokens as targets.
|
||||
inputIDs := tok.Encode("The capital of France is Paris")
|
||||
if len(inputIDs) < 3 {
|
||||
t.Fatalf("encoded too few tokens: %d", len(inputIDs))
|
||||
}
|
||||
t.Logf("Training tokens: %v (len=%d)", inputIDs, len(inputIDs))
|
||||
|
||||
seqLen := len(inputIDs) - 1 // input is all but last, target is all but first
|
||||
inputTokens := FromValues(inputIDs[:seqLen], 1, seqLen)
|
||||
targetTokens := FromValues(inputIDs[1:], 1, seqLen)
|
||||
Materialize(inputTokens, targetTokens)
|
||||
|
||||
// Step 4: Run a few training steps.
|
||||
params := adapter.AllTrainableParams()
|
||||
argnums := make([]int, len(params))
|
||||
for i := range argnums {
|
||||
argnums[i] = i
|
||||
}
|
||||
|
||||
opt := NewAdamW(1e-4)
|
||||
var initialLoss, finalLoss float64
|
||||
const numSteps = 5
|
||||
|
||||
for step := 0; step < numSteps; step++ {
|
||||
// Fresh caches each step (stateful — can't reuse across gradient calls).
|
||||
caches := gemma.NewCache()
|
||||
|
||||
lossFn := func(inputs []*Array) []*Array {
|
||||
adapter.SetAllParams(inputs)
|
||||
logits := gemma.Forward(inputTokens, caches)
|
||||
// logits is [1, seqLen, vocab] — compute cross-entropy against targets
|
||||
loss := CrossEntropyLoss(logits, targetTokens)
|
||||
return []*Array{loss}
|
||||
}
|
||||
|
||||
grad := ValueAndGrad(lossFn, argnums...)
|
||||
values, grads, err := grad.Apply(params...)
|
||||
grad.Free()
|
||||
if err != nil {
|
||||
t.Fatalf("step %d: ValueAndGrad failed: %v", step, err)
|
||||
}
|
||||
|
||||
Materialize(append(values, grads...)...)
|
||||
|
||||
loss := values[0].Float()
|
||||
t.Logf("step %d: loss = %.4f", step, loss)
|
||||
|
||||
if step == 0 {
|
||||
initialLoss = loss
|
||||
if math.IsNaN(loss) || math.IsInf(loss, 0) {
|
||||
t.Fatalf("initial loss is %f", loss)
|
||||
}
|
||||
}
|
||||
finalLoss = loss
|
||||
|
||||
// Update params.
|
||||
updated := opt.Step(params, grads)
|
||||
for i := range updated {
|
||||
Materialize(updated[i])
|
||||
}
|
||||
params = updated
|
||||
adapter.SetAllParams(params)
|
||||
}
|
||||
|
||||
// Verify loss decreased.
|
||||
t.Logf("loss: %.4f → %.4f", initialLoss, finalLoss)
|
||||
if finalLoss >= initialLoss {
|
||||
t.Errorf("loss did not decrease: %.4f → %.4f", initialLoss, finalLoss)
|
||||
}
|
||||
|
||||
// Step 5: Save adapter.
|
||||
savePath := filepath.Join(t.TempDir(), "adapter.safetensors")
|
||||
if err := adapter.Save(savePath); err != nil {
|
||||
t.Fatalf("adapter.Save: %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(savePath)
|
||||
if err != nil {
|
||||
t.Fatalf("saved adapter not found: %v", err)
|
||||
}
|
||||
t.Logf("adapter saved: %s (%d bytes)", savePath, info.Size())
|
||||
|
||||
// Step 6: Reload and verify weights match.
|
||||
loaded, err := LoadAllSafetensors(savePath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadAllSafetensors: %v", err)
|
||||
}
|
||||
|
||||
for _, name := range adapter.SortedNames() {
|
||||
layer := adapter.Layers[name]
|
||||
aKey := name + ".lora_a"
|
||||
bKey := name + ".lora_b"
|
||||
|
||||
loadedA, ok := loaded[aKey]
|
||||
if !ok {
|
||||
t.Errorf("missing %s in saved adapter", aKey)
|
||||
continue
|
||||
}
|
||||
loadedB, ok := loaded[bKey]
|
||||
if !ok {
|
||||
t.Errorf("missing %s in saved adapter", bKey)
|
||||
continue
|
||||
}
|
||||
|
||||
Materialize(loadedA, loadedB)
|
||||
|
||||
// Compare A weights.
|
||||
origA := layer.A.Floats()
|
||||
reloadA := loadedA.Floats()
|
||||
if len(origA) != len(reloadA) {
|
||||
t.Errorf("%s: A size mismatch: %d vs %d", name, len(origA), len(reloadA))
|
||||
continue
|
||||
}
|
||||
for i := range origA {
|
||||
if math.Abs(float64(origA[i]-reloadA[i])) > 1e-6 {
|
||||
t.Errorf("%s: A[%d] = %f, reloaded = %f", name, i, origA[i], reloadA[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Compare B weights.
|
||||
origB := layer.B.Floats()
|
||||
reloadB := loadedB.Floats()
|
||||
if len(origB) != len(reloadB) {
|
||||
t.Errorf("%s: B size mismatch: %d vs %d", name, len(origB), len(reloadB))
|
||||
continue
|
||||
}
|
||||
for i := range origB {
|
||||
if math.Abs(float64(origB[i]-reloadB[i])) > 1e-6 {
|
||||
t.Errorf("%s: B[%d] = %f, reloaded = %f", name, i, origB[i], reloadB[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("all %d adapter layers verified after reload", len(adapter.Layers))
|
||||
|
||||
ClearCache()
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue