From 535b04d5d6462212620a14089ebc6b5944c25334 Mon Sep 17 00:00:00 2001 From: Snider Date: Thu, 19 Feb 2026 21:55:56 +0000 Subject: [PATCH] feat(metal): add Qwen2 model support (DeepSeek R1 validated) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Qwen2 and Qwen3 share the same architecture — Qwen3 adds Q/K RMS normalization which Qwen2 lacks. The loader auto-detects the variant from weight presence and reports the correct ModelType(). - Add "qwen2" to architecture dispatch in model.go - Make Q/K norm optional in attention forward (nil-safe check) - Store detected model type on Qwen3Model struct - Add "qwen2" to chat template routing - DeepSeek R1 7B (4-bit): 27 tok/s on M3 Ultra - 2 new tests: inference + chat Co-Authored-By: Virgil Co-Authored-By: Claude Opus 4.6 --- TODO.md | 3 +- internal/metal/generate.go | 2 +- internal/metal/model.go | 2 +- internal/metal/qwen3.go | 32 ++++++++++---- mlx_test.go | 88 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 115 insertions(+), 12 deletions(-) diff --git a/TODO.md b/TODO.md index 0a45484..318f467 100644 --- a/TODO.md +++ b/TODO.md @@ -15,7 +15,8 @@ Dispatched from core/go orchestration. Pick up tasks in order. - [x] **Gemma3-1B inference validation** — ✅ End-to-end inference works. 4-bit quantised Gemma3-1B loads and generates coherently at **46 tok/s** on M3 Ultra. Fixed: `model_type: "gemma3_text"` not matched in architecture dispatch, GPT-2 BPE false detection on 262K SentencePiece vocab (checked `Ġthe` instead of bare `Ġ`). 3 new tests: inference (greedy, timing), chat template, context cancellation. - [x] **Model loading robustness** — ✅ 24 new tests in model_test.go covering: missing/invalid config.json, unsupported architecture, `gemma3_text` dispatch, missing tokenizer, missing safetensors (was a nil-pointer panic — fixed with early error return in both LoadGemma3 and LoadQwen3), config parsing defaults/quantization/nested text_config, `isLayerSliding`, `resolveWeight` with prefix fallback. -- [ ] **Add Llama model support** — Only Gemma3 and Qwen3 exist. Llama architecture would cover Meta's model family (Llama 3, CodeLlama). +- [x] **Add Qwen2 model support** — ✅ Qwen2 architecture (used by DeepSeek R1) now supported. Shares Qwen3 loader with optional Q/K RMS normalization (Qwen3 has it, Qwen2 does not). Auto-detected from weight presence. DeepSeek R1 7B: **27 tok/s** on M3 Ultra. 2 new tests. +- [ ] **Add Llama model support** — Llama architecture would cover Meta's model family (Llama 3, CodeLlama). Needs safetensors model on disk for validation (e.g. `mlx-community/Meta-Llama-3.1-8B-Instruct-4bit`). ## Phase 3: Training Pipeline diff --git a/internal/metal/generate.go b/internal/metal/generate.go index aa6de02..4070bc6 100644 --- a/internal/metal/generate.go +++ b/internal/metal/generate.go @@ -198,7 +198,7 @@ func (m *Model) formatChat(messages []ChatMessage) string { switch m.modelType { case "gemma3": return formatGemmaChat(messages) - case "qwen3": + case "qwen2", "qwen3": return formatQwenChat(messages) default: var s string diff --git a/internal/metal/model.go b/internal/metal/model.go index aadf67e..6399054 100644 --- a/internal/metal/model.go +++ b/internal/metal/model.go @@ -63,7 +63,7 @@ func loadModel(modelPath string) (InternalModel, error) { } switch probe.ModelType { - case "qwen3": + case "qwen3", "qwen2": return LoadQwen3(modelPath) case "gemma3", "gemma3_text", "gemma2": return LoadGemma3(modelPath) diff --git a/internal/metal/qwen3.go b/internal/metal/qwen3.go index b70fe1b..2a91c47 100644 --- a/internal/metal/qwen3.go +++ b/internal/metal/qwen3.go @@ -28,15 +28,17 @@ type Qwen3Config struct { Scale float32 `json:"-"` // 1/sqrt(head_dim) } -// Qwen3Model is the Qwen 3 text model. +// Qwen3Model is the Qwen 2/3 text model. +// Qwen 2 and 3 share the same architecture; Qwen 3 adds Q/K RMS normalization. type Qwen3Model struct { EmbedTokens *Embedding Layers []*Qwen3DecoderLayer Norm *RMSNormModule Output *Linear - Tok *Tokenizer - Cfg *Qwen3Config + Tok *Tokenizer + Cfg *Qwen3Config + modelType string // "qwen2" or "qwen3" } // Qwen3DecoderLayer is a single transformer block. @@ -100,7 +102,7 @@ func parseQwen3Config(data []byte) (*Qwen3Config, error) { return &cfg, nil } -// LoadQwen3 loads a Qwen 3 model from a safetensors directory. +// LoadQwen3 loads a Qwen 2 or 3 model from a safetensors directory. func LoadQwen3(modelPath string) (*Qwen3Model, error) { data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) if err != nil { @@ -159,12 +161,20 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) { embed.Bits = q.Bits } + // Detect Qwen 2 vs 3 by presence of Q/K normalization weights. + hasQKNorm := w("model.layers.0.self_attn.q_norm.weight") != nil + detectedType := "qwen3" + if !hasQKNorm { + detectedType = "qwen2" + } + m := &Qwen3Model{ EmbedTokens: embed, Layers: make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers), Norm: &RMSNormModule{Weight: w("model.norm.weight")}, Tok: tok, Cfg: cfg, + modelType: detectedType, } for i := int32(0); i < cfg.NumHiddenLayers; i++ { @@ -260,9 +270,13 @@ func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - // Q/K RMS normalization (Qwen 3 has this) - q = a.QNorm.Forward(q, cfg.RMSNormEps) - k = a.KNorm.Forward(k, cfg.RMSNormEps) + // Q/K RMS normalization (Qwen 3 has this; Qwen 2 does not) + if a.QNorm != nil && a.QNorm.Weight != nil { + q = a.QNorm.Forward(q, cfg.RMSNormEps) + } + if a.KNorm != nil && a.KNorm.Weight != nil { + k = a.KNorm.Forward(k, cfg.RMSNormEps) + } // RoPE — single theta for all layers (no sliding window) q = RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) @@ -305,8 +319,8 @@ func (m *Qwen3Model) NumLayers() int { return len(m.Layers) } // Tokenizer returns the model's tokenizer. func (m *Qwen3Model) Tokenizer() *Tokenizer { return m.Tok } -// ModelType returns the architecture identifier. -func (m *Qwen3Model) ModelType() string { return "qwen3" } +// ModelType returns the architecture identifier ("qwen2" or "qwen3"). +func (m *Qwen3Model) ModelType() string { return m.modelType } // ApplyLoRA wraps target projection layers with LoRA adapters. func (m *Qwen3Model) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter { diff --git a/mlx_test.go b/mlx_test.go index 2562cbb..61744c5 100644 --- a/mlx_test.go +++ b/mlx_test.go @@ -294,3 +294,91 @@ func TestGemma3_1B_ContextCancel(t *testing.T) { } t.Logf("Stopped after %d tokens", count) } + +// --- Qwen2 (DeepSeek R1 7B) tests --- + +func qwen2ModelPath(t *testing.T) string { + t.Helper() + paths := []string{ + "/Volumes/Data/lem/LEK-DeepSeek-R1-7B", + } + for _, p := range paths { + if _, err := os.Stat(p); err == nil { + return p + } + } + t.Skip("no Qwen2/DeepSeek model available") + return "" +} + +// TestQwen2_Inference validates Qwen2 arch (DeepSeek R1 7B) end-to-end. +func TestQwen2_Inference(t *testing.T) { + modelPath := qwen2ModelPath(t) + + loadStart := time.Now() + m, err := inference.LoadModel(modelPath) + loadDur := time.Since(loadStart) + if err != nil { + t.Fatalf("LoadModel: %v", err) + } + defer m.Close() + t.Logf("Model loaded in %s", loadDur) + + if m.ModelType() != "qwen2" { + t.Errorf("ModelType() = %q, want %q", m.ModelType(), "qwen2") + } + + ctx := context.Background() + genStart := time.Now() + var tokens []inference.Token + var output strings.Builder + for tok := range m.Generate(ctx, "What is 2+2?", inference.WithMaxTokens(32)) { + tokens = append(tokens, tok) + output.WriteString(tok.Text) + } + genDur := time.Since(genStart) + + if err := m.Err(); err != nil { + t.Fatalf("Generate error: %v", err) + } + + nTokens := len(tokens) + if nTokens == 0 { + t.Fatal("Generate produced no tokens") + } + + tps := float64(nTokens) / genDur.Seconds() + t.Logf("Generated %d tokens in %s (%.1f tok/s)", nTokens, genDur, tps) + t.Logf("Output: %s", output.String()) + for i, tok := range tokens { + t.Logf(" [%d] id=%d %q", i, tok.ID, tok.Text) + } +} + +// TestQwen2_Chat validates chat template for Qwen2 models. +func TestQwen2_Chat(t *testing.T) { + modelPath := qwen2ModelPath(t) + + m, err := inference.LoadModel(modelPath) + if err != nil { + t.Fatalf("LoadModel: %v", err) + } + defer m.Close() + + ctx := context.Background() + var output strings.Builder + var count int + for tok := range m.Chat(ctx, []inference.Message{ + {Role: "user", Content: "Reply with exactly one word: the capital of France."}, + }, inference.WithMaxTokens(32)) { + output.WriteString(tok.Text) + count++ + } + if err := m.Err(); err != nil { + t.Fatalf("Chat error: %v", err) + } + if count == 0 { + t.Fatal("Chat produced no tokens") + } + t.Logf("Chat output (%d tokens): %s", count, output.String()) +}