diff --git a/internal/metal/backend.go b/internal/metal/backend.go index 79d7d97..fd57875 100644 --- a/internal/metal/backend.go +++ b/internal/metal/backend.go @@ -6,7 +6,8 @@ import "fmt" // LoadConfig holds configuration applied during model loading. type LoadConfig struct { - ContextLen int // Context window size (0 = model default, unbounded KV cache) + ContextLen int // Context window size (0 = model default, unbounded KV cache) + AdapterPath string // Path to LoRA adapter directory (empty = no adapter) } // LoadAndInit initialises Metal and loads a model from the given path. @@ -22,8 +23,15 @@ func LoadAndInit(path string, cfg ...LoadConfig) (*Model, error) { tokenizer: im.Tokenizer(), modelType: im.ModelType(), } - if len(cfg) > 0 && cfg[0].ContextLen > 0 { - m.contextLen = cfg[0].ContextLen + if len(cfg) > 0 { + if cfg[0].ContextLen > 0 { + m.contextLen = cfg[0].ContextLen + } + if cfg[0].AdapterPath != "" { + if err := applyLoadedLoRA(im, cfg[0].AdapterPath); err != nil { + return nil, fmt.Errorf("metal: load adapter: %w", err) + } + } } return m, nil } diff --git a/internal/metal/batch.go b/internal/metal/batch.go index f25058b..3c5c04e 100644 --- a/internal/metal/batch.go +++ b/internal/metal/batch.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "math" + "slices" "sort" "time" ) @@ -86,7 +87,7 @@ func (m *Model) Classify(ctx context.Context, prompts []string, cfg GenerateConf // Gather logits at each prompt's last real token position and sample. sortedResults := make([]ClassifyResult, N) - for si := int32(0); si < N; si++ { + for si := range N { lastPos := sortedLengths[si] - 1 // Extract [1, vocab] at position lastPos for this batch element. @@ -120,11 +121,11 @@ func (m *Model) Classify(ctx context.Context, prompts []string, cfg GenerateConf totalDur := time.Since(totalStart) m.lastMetrics = Metrics{ - PromptTokens: totalPromptTokens, - GeneratedTokens: int(N), // One token sampled per prompt - PrefillDuration: totalDur, - TotalDuration: totalDur, - PeakMemoryBytes: GetPeakMemory(), + PromptTokens: totalPromptTokens, + GeneratedTokens: int(N), // One token sampled per prompt + PrefillDuration: totalDur, + TotalDuration: totalDur, + PeakMemoryBytes: GetPeakMemory(), ActiveMemoryBytes: GetActiveMemory(), } if totalDur > 0 { @@ -227,7 +228,7 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat nextIDs := make([]int32, N) allFinished := true - for si := int32(0); si < N; si++ { + for si := range N { if states[si].finished { nextIDs[si] = 0 // pad continue @@ -259,11 +260,8 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat states[si].finished = true continue } - for _, stop := range cfg.StopTokens { - if id == stop { - states[si].finished = true - break - } + if slices.Contains(cfg.StopTokens, id) { + states[si].finished = true } if !states[si].finished { text := m.tokenizer.DecodeToken(id) @@ -299,12 +297,12 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat totalDur := time.Since(totalStart) decodeDur := totalDur - prefillDur m.lastMetrics = Metrics{ - PromptTokens: totalPromptTokens, - GeneratedTokens: totalGenerated, - PrefillDuration: prefillDur, - DecodeDuration: decodeDur, - TotalDuration: totalDur, - PeakMemoryBytes: GetPeakMemory(), + PromptTokens: totalPromptTokens, + GeneratedTokens: totalGenerated, + PrefillDuration: prefillDur, + DecodeDuration: decodeDur, + TotalDuration: totalDur, + PeakMemoryBytes: GetPeakMemory(), ActiveMemoryBytes: GetActiveMemory(), } if prefillDur > 0 { @@ -324,11 +322,11 @@ func buildBatchMask(N, L int32, promptLens []int32) *Array { negInf := float32(math.Inf(-1)) data := make([]float32, int(N)*int(L)*int(L)) - for b := int32(0); b < N; b++ { + for b := range N { pLen := promptLens[b] base := int(b) * int(L) * int(L) - for i := int32(0); i < L; i++ { - for j := int32(0); j < L; j++ { + for i := range L { + for j := range L { if j <= i && j < pLen { data[base+int(i)*int(L)+int(j)] = 0 } else { diff --git a/internal/metal/gemma3.go b/internal/metal/gemma3.go index 407892c..f37b14e 100644 --- a/internal/metal/gemma3.go +++ b/internal/metal/gemma3.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "log/slog" + "maps" "math" "os" "path/filepath" @@ -182,9 +183,7 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) { return nil, fmt.Errorf("gemma3: no .safetensors files found in %s", modelPath) } for _, path := range matches { - for name, arr := range LoadSafetensors(path) { - weights[name] = arr - } + maps.Insert(weights, LoadSafetensors(path)) if err := lastError(); err != nil { return nil, fmt.Errorf("gemma3: load weights %s: %w", filepath.Base(path), err) } diff --git a/internal/metal/generate.go b/internal/metal/generate.go index 0203391..8ec2322 100644 --- a/internal/metal/generate.go +++ b/internal/metal/generate.go @@ -6,6 +6,8 @@ import ( "context" "fmt" "iter" + "slices" + "strings" "time" ) @@ -46,11 +48,11 @@ type Metrics struct { // Model wraps a loaded transformer model for text generation. type Model struct { - model InternalModel - tokenizer *Tokenizer - modelType string - contextLen int // 0 = unbounded (model default) - lastErr error + model InternalModel + tokenizer *Tokenizer + modelType string + contextLen int // 0 = unbounded (model default) + lastErr error lastMetrics Metrics } @@ -145,12 +147,12 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) decodeDur := time.Since(totalStart) - prefillDur totalDur := time.Since(totalStart) m.lastMetrics = Metrics{ - PromptTokens: promptLen, - GeneratedTokens: genCount, - PrefillDuration: prefillDur, - DecodeDuration: decodeDur, - TotalDuration: totalDur, - PeakMemoryBytes: GetPeakMemory(), + PromptTokens: promptLen, + GeneratedTokens: genCount, + PrefillDuration: prefillDur, + DecodeDuration: decodeDur, + TotalDuration: totalDur, + PeakMemoryBytes: GetPeakMemory(), ActiveMemoryBytes: GetActiveMemory(), } if prefillDur > 0 { @@ -205,10 +207,8 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) if id == m.tokenizer.EOSToken() { return } - for _, stop := range cfg.StopTokens { - if id == stop { - return - } + if slices.Contains(cfg.StopTokens, id) { + return } genCount++ @@ -286,44 +286,45 @@ func (m *Model) formatChat(messages []ChatMessage) string { case "llama": return formatLlamaChat(messages) default: - var s string + var s strings.Builder for _, msg := range messages { - s += msg.Content + "\n" + s.WriteString(msg.Content + "\n") } - return s + return s.String() } } func formatGemmaChat(messages []ChatMessage) string { - var s string + var s strings.Builder for _, msg := range messages { switch msg.Role { case "system": - s += "user\n" + msg.Content + "\n" + s.WriteString("user\n" + msg.Content + "\n") case "user": - s += "user\n" + msg.Content + "\n" + s.WriteString("user\n" + msg.Content + "\n") case "assistant": - s += "model\n" + msg.Content + "\n" + s.WriteString("model\n" + msg.Content + "\n") } } - s += "model\n" - return s + s.WriteString("model\n") + return s.String() } func formatQwenChat(messages []ChatMessage) string { - var s string + var s strings.Builder for _, msg := range messages { - s += "<|im_start|>" + msg.Role + "\n" + msg.Content + "<|im_end|>\n" + s.WriteString("<|im_start|>" + msg.Role + "\n" + msg.Content + "<|im_end|>\n") } - s += "<|im_start|>assistant\n" - return s + s.WriteString("<|im_start|>assistant\n") + return s.String() } func formatLlamaChat(messages []ChatMessage) string { - s := "<|begin_of_text|>" + var s strings.Builder + s.WriteString("<|begin_of_text|>") for _, msg := range messages { - s += "<|start_header_id|>" + msg.Role + "<|end_header_id|>\n\n" + msg.Content + "<|eot_id|>" + s.WriteString("<|start_header_id|>" + msg.Role + "<|end_header_id|>\n\n" + msg.Content + "<|eot_id|>") } - s += "<|start_header_id|>assistant<|end_header_id|>\n\n" - return s + s.WriteString("<|start_header_id|>assistant<|end_header_id|>\n\n") + return s.String() } diff --git a/internal/metal/lora.go b/internal/metal/lora.go index f42abcd..8d080b2 100644 --- a/internal/metal/lora.go +++ b/internal/metal/lora.go @@ -9,9 +9,14 @@ package metal import "C" import ( + "encoding/json" "fmt" + "log/slog" "math" + "os" + "path/filepath" "sort" + "strings" "unsafe" ) @@ -212,6 +217,238 @@ func RandomNormal(mean, stddev float32, shape []int32, dtype DType) *Array { return out } +// --- Adapter Loading (Inference) --- + +// adapterConfig holds the metadata from adapter_config.json produced by mlx-lm training. +type adapterConfig struct { + Rank int `json:"rank"` + Alpha float32 `json:"alpha"` + NumLayers int `json:"num_layers"` + TargetKeys []string `json:"lora_layers"` // e.g. ["self_attn.q_proj", "self_attn.v_proj"] +} + +// parseAdapterConfig reads and parses an adapter_config.json file. +func parseAdapterConfig(path string) (*adapterConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read adapter_config.json: %w", err) + } + var cfg adapterConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse adapter_config.json: %w", err) + } + // Apply defaults matching mlx-lm conventions. + if cfg.Rank == 0 { + cfg.Rank = 8 + } + if cfg.Alpha == 0 { + cfg.Alpha = float32(cfg.Rank) * 2 // mlx-lm default: alpha = 2 * rank + } + return &cfg, nil +} + +// loadAdapterWeights loads all safetensors files from an adapter directory into a flat weight map. +func loadAdapterWeights(dir string) (map[string]*Array, error) { + matches, err := filepath.Glob(filepath.Join(dir, "*.safetensors")) + if err != nil { + return nil, fmt.Errorf("glob adapter safetensors: %w", err) + } + if len(matches) == 0 { + return nil, fmt.Errorf("no .safetensors files found in %s", dir) + } + + weights := make(map[string]*Array) + for _, path := range matches { + for name, arr := range LoadSafetensors(path) { + weights[name] = arr + } + if err := lastError(); err != nil { + return nil, fmt.Errorf("load adapter weights %s: %w", filepath.Base(path), err) + } + } + return weights, nil +} + +// resolveLinear returns the *Linear for a given projection path within a model. +// projPath is e.g. "self_attn.q_proj" and the function resolves layer index + field. +func resolveLinear(model InternalModel, layerIdx int, projPath string) *Linear { + switch m := model.(type) { + case *Qwen3Model: + if layerIdx >= len(m.Layers) { + return nil + } + layer := m.Layers[layerIdx] + switch projPath { + case "self_attn.q_proj": + return layer.Attention.QProj + case "self_attn.k_proj": + return layer.Attention.KProj + case "self_attn.v_proj": + return layer.Attention.VProj + case "self_attn.o_proj": + return layer.Attention.OProj + } + case *GemmaModel: + if layerIdx >= len(m.Layers) { + return nil + } + layer := m.Layers[layerIdx] + switch projPath { + case "self_attn.q_proj": + return layer.Attention.QProj + case "self_attn.k_proj": + return layer.Attention.KProj + case "self_attn.v_proj": + return layer.Attention.VProj + case "self_attn.o_proj": + return layer.Attention.OProj + } + } + return nil +} + +// parseLoRAWeightName extracts the layer index, projection path, and A/B suffix +// from an adapter weight name. Returns (-1, "", "") if the name is not a recognised +// LoRA weight. +// +// Examples: +// +// "layers.0.self_attn.q_proj.lora_a" → (0, "self_attn.q_proj", "lora_a") +// "model.layers.12.self_attn.v_proj.lora_b" → (12, "self_attn.v_proj", "lora_b") +func parseLoRAWeightName(name string) (layerIdx int, projPath, suffix string) { + // Strip optional "model." prefix. + name = strings.TrimPrefix(name, "model.") + + // Must start with "layers.{N}." + if !strings.HasPrefix(name, "layers.") { + return -1, "", "" + } + + // Must end with ".lora_a" or ".lora_b". + if strings.HasSuffix(name, ".lora_a") { + suffix = "lora_a" + } else if strings.HasSuffix(name, ".lora_b") { + suffix = "lora_b" + } else { + return -1, "", "" + } + + // Remove "layers." prefix and ".lora_X" suffix. + inner := name[len("layers."):] + inner = inner[:len(inner)-len("."+suffix)] + + // Split off the layer index. + dotIdx := strings.Index(inner, ".") + if dotIdx < 0 { + return -1, "", "" + } + idxStr := inner[:dotIdx] + projPath = inner[dotIdx+1:] + + var idx int + if _, err := fmt.Sscanf(idxStr, "%d", &idx); err != nil { + return -1, "", "" + } + + return idx, projPath, suffix +} + +// applyLoadedLoRA loads a trained LoRA adapter from disk and injects it into the model +// for inference. The adapter weights are frozen (no gradients needed). +func applyLoadedLoRA(model InternalModel, adapterDir string) error { + // Step 1: Read adapter configuration. + cfg, err := parseAdapterConfig(filepath.Join(adapterDir, "adapter_config.json")) + if err != nil { + return err + } + + // Step 2: Load adapter safetensors. + weights, err := loadAdapterWeights(adapterDir) + if err != nil { + return err + } + + // Materialise all adapter weights onto Metal. + var allArrays []*Array + for _, a := range weights { + allArrays = append(allArrays, a) + } + Materialize(allArrays...) + + // Step 3: Group weights by (layerIdx, projPath) and inject LoRA. + type loraKey struct { + layerIdx int + projPath string + } + type loraPair struct { + a *Array + b *Array + } + pairs := make(map[loraKey]*loraPair) + + for name, arr := range weights { + layerIdx, projPath, suffix := parseLoRAWeightName(name) + if layerIdx < 0 { + slog.Warn("adapter: skipping unrecognised weight", "name", name) + continue + } + key := loraKey{layerIdx, projPath} + pair, ok := pairs[key] + if !ok { + pair = &loraPair{} + pairs[key] = pair + } + switch suffix { + case "lora_a": + pair.a = arr + case "lora_b": + pair.b = arr + } + } + + scale := cfg.Alpha / float32(cfg.Rank) + injected := 0 + + for key, pair := range pairs { + if pair.a == nil || pair.b == nil { + slog.Warn("adapter: incomplete LoRA pair, skipping", + "layer", key.layerIdx, "proj", key.projPath) + continue + } + + linear := resolveLinear(model, key.layerIdx, key.projPath) + if linear == nil { + slog.Warn("adapter: target layer not found, skipping", + "layer", key.layerIdx, "proj", key.projPath) + continue + } + + lora := &LoRALinear{ + Base: linear, + A: pair.a, + B: pair.b, + Scale: scale, + Rank: cfg.Rank, + Alpha: cfg.Alpha, + } + linear.LoRA = lora + injected++ + } + + if injected == 0 { + return fmt.Errorf("no LoRA layers injected from %s", adapterDir) + } + + slog.Info("adapter loaded", + "path", adapterDir, + "rank", cfg.Rank, + "alpha", cfg.Alpha, + "scale", scale, + "layers_injected", injected, + ) + return nil +} + // --- SaveSafetensors --- // SaveSafetensors saves a map of named arrays to a .safetensors file. diff --git a/internal/metal/lora_test.go b/internal/metal/lora_test.go index 64828ac..e5a70b2 100644 --- a/internal/metal/lora_test.go +++ b/internal/metal/lora_test.go @@ -5,6 +5,7 @@ package metal import ( "math" "os" + "path/filepath" "testing" ) @@ -320,3 +321,430 @@ func TestDefaultLoRAConfig(t *testing.T) { t.Errorf("TargetKeys = %v, want [q_proj, v_proj]", cfg.TargetKeys) } } + +// --- parseLoRAWeightName --- + +func TestParseLoRAWeightName_Good(t *testing.T) { + tests := []struct { + name string + input string + wantIdx int + wantProj string + wantSuf string + }{ + { + "standard_lora_a", + "layers.0.self_attn.q_proj.lora_a", + 0, "self_attn.q_proj", "lora_a", + }, + { + "standard_lora_b", + "layers.5.self_attn.v_proj.lora_b", + 5, "self_attn.v_proj", "lora_b", + }, + { + "with_model_prefix", + "model.layers.12.self_attn.q_proj.lora_a", + 12, "self_attn.q_proj", "lora_a", + }, + { + "k_proj", + "layers.3.self_attn.k_proj.lora_b", + 3, "self_attn.k_proj", "lora_b", + }, + { + "o_proj", + "layers.7.self_attn.o_proj.lora_a", + 7, "self_attn.o_proj", "lora_a", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + idx, proj, suf := parseLoRAWeightName(tt.input) + if idx != tt.wantIdx { + t.Errorf("layerIdx = %d, want %d", idx, tt.wantIdx) + } + if proj != tt.wantProj { + t.Errorf("projPath = %q, want %q", proj, tt.wantProj) + } + if suf != tt.wantSuf { + t.Errorf("suffix = %q, want %q", suf, tt.wantSuf) + } + }) + } +} + +func TestParseLoRAWeightName_Bad(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"no_lora_suffix", "layers.0.self_attn.q_proj.weight"}, + {"no_layers_prefix", "self_attn.q_proj.lora_a"}, + {"empty", ""}, + {"just_layers", "layers."}, + {"no_dot_after_idx", "layers.0lora_a"}, + {"non_numeric_idx", "layers.abc.self_attn.q_proj.lora_a"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + idx, _, _ := parseLoRAWeightName(tt.input) + if idx != -1 { + t.Errorf("expected -1 for %q, got %d", tt.input, idx) + } + }) + } +} + +// --- parseAdapterConfig --- + +func TestParseAdapterConfig_Good(t *testing.T) { + dir := t.TempDir() + cfg := `{ + "rank": 16, + "alpha": 32.0, + "num_layers": 4, + "lora_layers": ["self_attn.q_proj", "self_attn.v_proj"] + }` + os.WriteFile(filepath.Join(dir, "adapter_config.json"), []byte(cfg), 0644) + + parsed, err := parseAdapterConfig(filepath.Join(dir, "adapter_config.json")) + if err != nil { + t.Fatalf("parseAdapterConfig: %v", err) + } + if parsed.Rank != 16 { + t.Errorf("Rank = %d, want 16", parsed.Rank) + } + if parsed.Alpha != 32.0 { + t.Errorf("Alpha = %f, want 32.0", parsed.Alpha) + } + if parsed.NumLayers != 4 { + t.Errorf("NumLayers = %d, want 4", parsed.NumLayers) + } + if len(parsed.TargetKeys) != 2 { + t.Errorf("TargetKeys = %v, want 2 entries", parsed.TargetKeys) + } +} + +func TestParseAdapterConfig_Good_Defaults(t *testing.T) { + dir := t.TempDir() + // Minimal config — rank and alpha should get defaults. + cfg := `{}` + os.WriteFile(filepath.Join(dir, "adapter_config.json"), []byte(cfg), 0644) + + parsed, err := parseAdapterConfig(filepath.Join(dir, "adapter_config.json")) + if err != nil { + t.Fatalf("parseAdapterConfig: %v", err) + } + if parsed.Rank != 8 { + t.Errorf("default Rank = %d, want 8", parsed.Rank) + } + if parsed.Alpha != 16.0 { + t.Errorf("default Alpha = %f, want 16.0 (2 * rank)", parsed.Alpha) + } +} + +func TestParseAdapterConfig_Bad_MissingFile(t *testing.T) { + _, err := parseAdapterConfig("/nonexistent/adapter_config.json") + if err == nil { + t.Fatal("expected error for missing file") + } +} + +func TestParseAdapterConfig_Bad_InvalidJSON(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "adapter_config.json"), []byte("{broken"), 0644) + + _, err := parseAdapterConfig(filepath.Join(dir, "adapter_config.json")) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +// --- loadAdapterWeights --- + +func TestLoadAdapterWeights_Bad_NoFiles(t *testing.T) { + dir := t.TempDir() + _, err := loadAdapterWeights(dir) + if err == nil { + t.Fatal("expected error for directory with no safetensors files") + } +} + +func TestLoadAdapterWeights_Good(t *testing.T) { + dir := t.TempDir() + + // Save a small adapter file. + a := FromValues([]float32{1, 2, 3, 4}, 2, 2) + b := FromValues([]float32{5, 6, 7, 8}, 2, 2) + Materialize(a, b) + + err := SaveSafetensors(filepath.Join(dir, "adapters.safetensors"), map[string]*Array{ + "layers.0.self_attn.q_proj.lora_a": a, + "layers.0.self_attn.q_proj.lora_b": b, + }) + if err != nil { + t.Fatalf("SaveSafetensors: %v", err) + } + + weights, err := loadAdapterWeights(dir) + if err != nil { + t.Fatalf("loadAdapterWeights: %v", err) + } + if len(weights) != 2 { + t.Errorf("loaded %d weights, want 2", len(weights)) + } + if _, ok := weights["layers.0.self_attn.q_proj.lora_a"]; !ok { + t.Error("missing lora_a weight") + } + if _, ok := weights["layers.0.self_attn.q_proj.lora_b"]; !ok { + t.Error("missing lora_b weight") + } +} + +// --- applyLoadedLoRA integration --- + +func TestApplyLoadedLoRA_Good_SaveAndReload(t *testing.T) { + // Create a simple base Linear layer and save LoRA weights for it, + // then load them back with applyLoadedLoRA. + + // Create a small "model" with 1 layer and known dimensions. + w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32) + Materialize(w) + linear := NewLinear(w, nil) + + // Train a LoRA on this linear, then save. + lora := NewLoRALinear(linear, 4, 8.0) + // Set A and B to non-zero values so we can verify they load correctly. + newA := FromValues([]float32{ + 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, + 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, + 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4, + 2.5, 2.6, 2.7, 2.8, 2.9, 3.0, 3.1, 3.2, + }, 4, 8) // [rank=4, in=8] + newB := FromValues([]float32{ + 0.1, 0.2, 0.3, 0.4, + 0.5, 0.6, 0.7, 0.8, + 0.9, 1.0, 1.1, 1.2, + 1.3, 1.4, 1.5, 1.6, + }, 4, 4) // [out=4, rank=4] + Materialize(newA, newB) + lora.A = newA + lora.B = newB + + // Save the adapter weights. + adapterDir := t.TempDir() + err := SaveSafetensors(filepath.Join(adapterDir, "adapters.safetensors"), map[string]*Array{ + "layers.0.self_attn.q_proj.lora_a": lora.A, + "layers.0.self_attn.q_proj.lora_b": lora.B, + }) + if err != nil { + t.Fatalf("SaveSafetensors: %v", err) + } + + // Write adapter_config.json. + configJSON := `{"rank": 4, "alpha": 8.0, "num_layers": 1, "lora_layers": ["self_attn.q_proj"]}` + os.WriteFile(filepath.Join(adapterDir, "adapter_config.json"), []byte(configJSON), 0644) + + // Now create a fresh linear with the same base weights (no LoRA). + linear2 := NewLinear(w, nil) + if linear2.LoRA != nil { + t.Fatal("fresh linear should not have LoRA") + } + + // Build a minimal model for resolveLinear to work. + qwen := &Qwen3Model{ + Layers: []*Qwen3DecoderLayer{ + { + Attention: &Qwen3Attention{ + QProj: linear2, + KProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil), + VProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil), + OProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil), + }, + }, + }, + } + + // Apply the loaded adapter. + err = applyLoadedLoRA(qwen, adapterDir) + if err != nil { + t.Fatalf("applyLoadedLoRA: %v", err) + } + + // Verify LoRA was injected. + if linear2.LoRA == nil { + t.Fatal("LoRA should have been injected into q_proj") + } + + // Verify rank and scale. + if linear2.LoRA.Rank != 4 { + t.Errorf("Rank = %d, want 4", linear2.LoRA.Rank) + } + expectedScale := float32(8.0) / float32(4) // alpha / rank = 2.0 + if math.Abs(float64(linear2.LoRA.Scale-expectedScale)) > 1e-5 { + t.Errorf("Scale = %f, want %f", linear2.LoRA.Scale, expectedScale) + } + + // Verify the loaded A weights match what we saved. + Materialize(linear2.LoRA.A, linear2.LoRA.B) + loadedA := linear2.LoRA.A.Floats() + origA := newA.Floats() + if len(loadedA) != len(origA) { + t.Fatalf("A size mismatch: %d vs %d", len(loadedA), len(origA)) + } + for i := range origA { + if math.Abs(float64(loadedA[i]-origA[i])) > 1e-5 { + t.Errorf("A[%d] = %f, want %f", i, loadedA[i], origA[i]) + break + } + } + + // Verify the loaded B weights match. + loadedB := linear2.LoRA.B.Floats() + origB := newB.Floats() + if len(loadedB) != len(origB) { + t.Fatalf("B size mismatch: %d vs %d", len(loadedB), len(origB)) + } + for i := range origB { + if math.Abs(float64(loadedB[i]-origB[i])) > 1e-5 { + t.Errorf("B[%d] = %f, want %f", i, loadedB[i], origB[i]) + break + } + } +} + +func TestApplyLoadedLoRA_Bad_MissingConfig(t *testing.T) { + dir := t.TempDir() + // Write safetensors but no config. + a := FromValues([]float32{1, 2, 3, 4}, 2, 2) + Materialize(a) + SaveSafetensors(filepath.Join(dir, "adapters.safetensors"), map[string]*Array{"x": a}) + + qwen := &Qwen3Model{Layers: []*Qwen3DecoderLayer{}} + err := applyLoadedLoRA(qwen, dir) + if err == nil { + t.Fatal("expected error for missing adapter_config.json") + } +} + +func TestApplyLoadedLoRA_Bad_MissingSafetensors(t *testing.T) { + dir := t.TempDir() + // Write config but no safetensors. + os.WriteFile(filepath.Join(dir, "adapter_config.json"), []byte(`{"rank": 8}`), 0644) + + qwen := &Qwen3Model{Layers: []*Qwen3DecoderLayer{}} + err := applyLoadedLoRA(qwen, dir) + if err == nil { + t.Fatal("expected error for missing safetensors") + } +} + +func TestApplyLoadedLoRA_Bad_NoMatchingLayers(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "adapter_config.json"), []byte(`{"rank": 4, "alpha": 8.0}`), 0644) + + // Save weights that reference layer 99 (which won't exist). + a := FromValues([]float32{1, 2, 3, 4}, 2, 2) + b := FromValues([]float32{5, 6, 7, 8}, 2, 2) + Materialize(a, b) + SaveSafetensors(filepath.Join(dir, "adapters.safetensors"), map[string]*Array{ + "layers.99.self_attn.q_proj.lora_a": a, + "layers.99.self_attn.q_proj.lora_b": b, + }) + + qwen := &Qwen3Model{ + Layers: []*Qwen3DecoderLayer{ + { + Attention: &Qwen3Attention{ + QProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil), + }, + }, + }, + } + err := applyLoadedLoRA(qwen, dir) + if err == nil { + t.Fatal("expected error when no layers are injected") + } +} + +// TestApplyLoadedLoRA_Good_ForwardProducesOutput validates that a model with a +// loaded LoRA adapter produces different output than the base model alone. +func TestApplyLoadedLoRA_Good_ForwardProducesOutput(t *testing.T) { + // Create base linear [4, 8]. + w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32) + Materialize(w) + linear := NewLinear(w, nil) + + // Compute base output. + x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32) + Materialize(x) + baseOut := linear.Forward(x) + Materialize(baseOut) + baseFloats := baseOut.Floats() + + // Create and save non-trivial adapter weights. + rank := 4 + loraA := RandomNormal(0, 0.1, []int32{int32(rank), 8}, DTypeFloat32) + loraB := RandomNormal(0, 0.1, []int32{4, int32(rank)}, DTypeFloat32) + Materialize(loraA, loraB) + + adapterDir := t.TempDir() + SaveSafetensors(filepath.Join(adapterDir, "adapters.safetensors"), map[string]*Array{ + "layers.0.self_attn.q_proj.lora_a": loraA, + "layers.0.self_attn.q_proj.lora_b": loraB, + }) + os.WriteFile(filepath.Join(adapterDir, "adapter_config.json"), + []byte(`{"rank": 4, "alpha": 8.0}`), 0644) + + // Build a model and apply adapter. + qwen := &Qwen3Model{ + Layers: []*Qwen3DecoderLayer{ + { + Attention: &Qwen3Attention{ + QProj: linear, + KProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil), + VProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil), + OProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil), + }, + }, + }, + } + + err := applyLoadedLoRA(qwen, adapterDir) + if err != nil { + t.Fatalf("applyLoadedLoRA: %v", err) + } + + // Now forward should go through LoRA path. + loraOut := linear.Forward(x) + Materialize(loraOut) + loraFloats := loraOut.Floats() + + // Outputs should differ since B is non-zero. + allSame := true + for i := range baseFloats { + if math.Abs(float64(baseFloats[i]-loraFloats[i])) > 1e-6 { + allSame = false + break + } + } + if allSame { + t.Error("expected LoRA output to differ from base output with non-zero B weights") + } +} + +// --- LoadAndInit with adapter --- + +func TestLoadAndInit_Bad_AdapterMissing(t *testing.T) { + dir := t.TempDir() + writeMinimalConfig(t, dir, "qwen3") + writeMinimalTokenizer(t, dir) + + // Create a minimal safetensors file so model loading proceeds. + // The adapter path doesn't exist, so it should fail at the adapter step. + _, err := LoadAndInit(dir, LoadConfig{AdapterPath: "/nonexistent/adapter"}) + if err == nil { + t.Fatal("expected error for missing adapter") + } +} diff --git a/internal/metal/optim_test.go b/internal/metal/optim_test.go index 1cd41bd..6c754e4 100644 --- a/internal/metal/optim_test.go +++ b/internal/metal/optim_test.go @@ -14,7 +14,7 @@ func TestAdamW_BasicStep(t *testing.T) { opt := NewAdamW(0.1) - for i := 0; i < 300; i++ { + for i := range 300 { // Gradient of x^2 is 2x lossFn := func(inputs []*Array) []*Array { p := inputs[0] @@ -48,7 +48,7 @@ func TestAdamW_MultiParam(t *testing.T) { opt := NewAdamW(0.1) - for i := 0; i < 100; i++ { + for i := range 100 { lossFn := func(inputs []*Array) []*Array { return []*Array{Add(Mul(inputs[0], inputs[0]), Mul(inputs[1], inputs[1]))} } @@ -85,7 +85,7 @@ func TestAdamW_WeightDecay(t *testing.T) { zeroGrad := FromValue(float32(0.0)) Materialize(zeroGrad) - for i := 0; i < 10; i++ { + for range 10 { updated := opt.Step([]*Array{x}, []*Array{zeroGrad}) x = updated[0] Materialize(x) @@ -137,7 +137,7 @@ func TestAdamW_WithLoRA(t *testing.T) { var initialLoss, finalLoss float64 - for step := 0; step < 50; step++ { + for step := range 50 { lossFn := func(inputs []*Array) []*Array { lora.A = inputs[0] lora.B = inputs[1] diff --git a/internal/metal/qwen3.go b/internal/metal/qwen3.go index ff0b233..f8daef9 100644 --- a/internal/metal/qwen3.go +++ b/internal/metal/qwen3.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "log/slog" + "maps" "math" "os" "path/filepath" @@ -44,10 +45,10 @@ type Qwen3Model struct { // Qwen3DecoderLayer is a single transformer block. // Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add. type Qwen3DecoderLayer struct { - InputNorm *RMSNormModule // Pre-attention norm + InputNorm *RMSNormModule // Pre-attention norm PostAttnNorm *RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm) - Attention *Qwen3Attention - MLP *Qwen3MLP + Attention *Qwen3Attention + MLP *Qwen3MLP } // Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization. @@ -136,9 +137,7 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) { return nil, fmt.Errorf("qwen3: no .safetensors files found in %s", modelPath) } for _, path := range matches { - for name, arr := range LoadSafetensors(path) { - weights[name] = arr - } + maps.Insert(weights, LoadSafetensors(path)) if err := lastError(); err != nil { return nil, fmt.Errorf("qwen3: load weights %s: %w", filepath.Base(path), err) } diff --git a/internal/metal/tokenizer.go b/internal/metal/tokenizer.go index 3110aec..6f2c1f4 100644 --- a/internal/metal/tokenizer.go +++ b/internal/metal/tokenizer.go @@ -173,7 +173,7 @@ func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) { // Non-self-mapping: control chars, space, DEL, and gaps n := 0 - for b := 0; b < 256; b++ { + for b := range 256 { if _, ok := encoder[byte(b)]; !ok { r := rune(256 + n) encoder[byte(b)] = r diff --git a/internal/metal/tokenizer_test.go b/internal/metal/tokenizer_test.go index b95bf83..6574b23 100644 --- a/internal/metal/tokenizer_test.go +++ b/internal/metal/tokenizer_test.go @@ -239,7 +239,7 @@ func TestBuildGPT2ByteMaps(t *testing.T) { } // Round-trip: every byte should survive encode → decode - for b := 0; b < 256; b++ { + for b := range 256 { r := encoder[byte(b)] got := decoder[r] if got != byte(b) { diff --git a/internal/metal/train_test.go b/internal/metal/train_test.go index 864d745..2640ecd 100644 --- a/internal/metal/train_test.go +++ b/internal/metal/train_test.go @@ -75,7 +75,7 @@ func TestLoRA_EndToEnd(t *testing.T) { var initialLoss, finalLoss float64 const numSteps = 5 - for step := 0; step < numSteps; step++ { + for step := range numSteps { // Fresh caches each step (stateful — can't reuse across gradient calls). caches := gemma.NewCache() @@ -224,7 +224,7 @@ func TestLoRA_GradientCheckpointing(t *testing.T) { var initialLoss, finalLoss float64 const numSteps = 3 - for step := 0; step < numSteps; step++ { + for step := range numSteps { caches := gemma.NewCache() // Wrap the model forward pass in Checkpoint to recompute activations @@ -326,7 +326,7 @@ func TestLoRA_MixedPrecision(t *testing.T) { var initialLoss, finalLoss float64 const numSteps = 5 - for step := 0; step < numSteps; step++ { + for step := range numSteps { caches := gemma.NewCache() lossFn := func(inputs []*Array) []*Array { diff --git a/register_metal.go b/register_metal.go index 4f90432..b742bb2 100644 --- a/register_metal.go +++ b/register_metal.go @@ -63,7 +63,8 @@ func (b *metalBackend) LoadModel(path string, opts ...inference.LoadOption) (inf slog.Warn("mlx: GPULayers=0 ignored — Metal always uses full GPU offload") } m, err := metal.LoadAndInit(path, metal.LoadConfig{ - ContextLen: cfg.ContextLen, + ContextLen: cfg.ContextLen, + AdapterPath: cfg.AdapterPath, }) if err != nil { return nil, err