feat(metal): add Qwen2 model support (DeepSeek R1 validated)

Qwen2 and Qwen3 share the same architecture, except that Qwen3 adds
Q/K RMS normalization. The loader auto-detects the variant from the
presence of the Q/K norm weights and reports the correct ModelType().

- Add "qwen2" to architecture dispatch in model.go
- Make Q/K norm optional in attention forward (nil-safe check)
- Store detected model type on Qwen3Model struct
- Add "qwen2" to chat template routing
- DeepSeek R1 7B (4-bit): 27 tok/s on M3 Ultra
- 2 new tests: inference + chat

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Snider 2026-02-19 21:55:56 +00:00
parent a2493e0242
commit 535b04d5d6
5 changed files with 115 additions and 12 deletions

View file

@ -15,7 +15,8 @@ Dispatched from core/go orchestration. Pick up tasks in order.
- [x] **Gemma3-1B inference validation** — ✅ End-to-end inference works. 4-bit quantised Gemma3-1B loads and generates coherently at **46 tok/s** on M3 Ultra. Fixed: `model_type: "gemma3_text"` not matched in architecture dispatch, GPT-2 BPE false detection on 262K SentencePiece vocab (checked `Ġthe` instead of bare `Ġ`). 3 new tests: inference (greedy, timing), chat template, context cancellation.
- [x] **Model loading robustness** — ✅ 24 new tests in model_test.go covering: missing/invalid config.json, unsupported architecture, `gemma3_text` dispatch, missing tokenizer, missing safetensors (was a nil-pointer panic — fixed with early error return in both LoadGemma3 and LoadQwen3), config parsing defaults/quantization/nested text_config, `isLayerSliding`, `resolveWeight` with prefix fallback.
- [ ] **Add Llama model support** — Only Gemma3 and Qwen3 exist. Llama architecture would cover Meta's model family (Llama 3, CodeLlama).
- [x] **Add Qwen2 model support** — ✅ Qwen2 architecture (used by DeepSeek R1) now supported. Shares Qwen3 loader with optional Q/K RMS normalization (Qwen3 has it, Qwen2 does not). Auto-detected from weight presence. DeepSeek R1 7B: **27 tok/s** on M3 Ultra. 2 new tests.
- [ ] **Add Llama model support** — Llama architecture would cover Meta's model family (Llama 3, CodeLlama). Needs safetensors model on disk for validation (e.g. `mlx-community/Meta-Llama-3.1-8B-Instruct-4bit`).
## Phase 3: Training Pipeline

View file

@ -198,7 +198,7 @@ func (m *Model) formatChat(messages []ChatMessage) string {
switch m.modelType {
case "gemma3":
return formatGemmaChat(messages)
case "qwen3":
case "qwen2", "qwen3":
return formatQwenChat(messages)
default:
var s string

View file

@ -63,7 +63,7 @@ func loadModel(modelPath string) (InternalModel, error) {
}
switch probe.ModelType {
case "qwen3":
case "qwen3", "qwen2":
return LoadQwen3(modelPath)
case "gemma3", "gemma3_text", "gemma2":
return LoadGemma3(modelPath)

View file

@ -28,15 +28,17 @@ type Qwen3Config struct {
Scale float32 `json:"-"` // 1/sqrt(head_dim)
}
// Qwen3Model is the Qwen 3 text model.
// Qwen3Model is the Qwen 2/3 text model.
// Qwen 2 and 3 share the same architecture; Qwen 3 adds Q/K RMS normalization.
type Qwen3Model struct {
EmbedTokens *Embedding
Layers []*Qwen3DecoderLayer
Norm *RMSNormModule
Output *Linear
Tok *Tokenizer
Cfg *Qwen3Config
Tok *Tokenizer
Cfg *Qwen3Config
modelType string // "qwen2" or "qwen3"
}
// Qwen3DecoderLayer is a single transformer block.
@ -100,7 +102,7 @@ func parseQwen3Config(data []byte) (*Qwen3Config, error) {
return &cfg, nil
}
// LoadQwen3 loads a Qwen 3 model from a safetensors directory.
// LoadQwen3 loads a Qwen 2 or 3 model from a safetensors directory.
func LoadQwen3(modelPath string) (*Qwen3Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
@ -159,12 +161,20 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
embed.Bits = q.Bits
}
// Detect Qwen 2 vs 3 by presence of Q/K normalization weights.
hasQKNorm := w("model.layers.0.self_attn.q_norm.weight") != nil
detectedType := "qwen3"
if !hasQKNorm {
detectedType = "qwen2"
}
m := &Qwen3Model{
EmbedTokens: embed,
Layers: make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers),
Norm: &RMSNormModule{Weight: w("model.norm.weight")},
Tok: tok,
Cfg: cfg,
modelType: detectedType,
}
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
@ -260,9 +270,13 @@ func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config
v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
// Q/K RMS normalization (Qwen 3 has this)
q = a.QNorm.Forward(q, cfg.RMSNormEps)
k = a.KNorm.Forward(k, cfg.RMSNormEps)
// Q/K RMS normalization (Qwen 3 has this; Qwen 2 does not)
if a.QNorm != nil && a.QNorm.Weight != nil {
q = a.QNorm.Forward(q, cfg.RMSNormEps)
}
if a.KNorm != nil && a.KNorm.Weight != nil {
k = a.KNorm.Forward(k, cfg.RMSNormEps)
}
// RoPE — single theta for all layers (no sliding window)
q = RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
@ -305,8 +319,8 @@ func (m *Qwen3Model) NumLayers() int { return len(m.Layers) }
// Tokenizer returns the model's tokenizer.
func (m *Qwen3Model) Tokenizer() *Tokenizer { return m.Tok }
// ModelType returns the architecture identifier.
func (m *Qwen3Model) ModelType() string { return "qwen3" }
// ModelType returns the architecture identifier ("qwen2" or "qwen3").
func (m *Qwen3Model) ModelType() string { return m.modelType }
// ApplyLoRA wraps target projection layers with LoRA adapters.
func (m *Qwen3Model) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter {

View file

@ -294,3 +294,91 @@ func TestGemma3_1B_ContextCancel(t *testing.T) {
}
t.Logf("Stopped after %d tokens", count)
}
// --- Qwen2 (DeepSeek R1 7B) tests ---
// qwen2ModelPath returns the path of a locally available Qwen2/DeepSeek
// model directory, skipping the calling test when none of the known
// candidate locations exist on this machine.
func qwen2ModelPath(t *testing.T) string {
	t.Helper()
	candidates := []string{
		"/Volumes/Data/lem/LEK-DeepSeek-R1-7B",
	}
	for _, candidate := range candidates {
		if _, err := os.Stat(candidate); err == nil {
			return candidate
		}
	}
	t.Skip("no Qwen2/DeepSeek model available")
	return "" // unreachable: t.Skip stops the test goroutine
}
// TestQwen2_Inference validates Qwen2 arch (DeepSeek R1 7B) end-to-end.
func TestQwen2_Inference(t *testing.T) {
	modelPath := qwen2ModelPath(t)

	// Load the model and report wall-clock load time.
	start := time.Now()
	m, err := inference.LoadModel(modelPath)
	elapsed := time.Since(start)
	if err != nil {
		t.Fatalf("LoadModel: %v", err)
	}
	defer m.Close()
	t.Logf("Model loaded in %s", elapsed)

	// The loader should have auto-detected the Qwen2 variant.
	if got := m.ModelType(); got != "qwen2" {
		t.Errorf("ModelType() = %q, want %q", got, "qwen2")
	}

	// Generate a short completion and collect every token.
	var (
		tokens []inference.Token
		output strings.Builder
	)
	genStart := time.Now()
	for tok := range m.Generate(context.Background(), "What is 2+2?", inference.WithMaxTokens(32)) {
		tokens = append(tokens, tok)
		output.WriteString(tok.Text)
	}
	genElapsed := time.Since(genStart)
	if err := m.Err(); err != nil {
		t.Fatalf("Generate error: %v", err)
	}
	if len(tokens) == 0 {
		t.Fatal("Generate produced no tokens")
	}

	// Log throughput and the full decoded output for inspection.
	tps := float64(len(tokens)) / genElapsed.Seconds()
	t.Logf("Generated %d tokens in %s (%.1f tok/s)", len(tokens), genElapsed, tps)
	t.Logf("Output: %s", output.String())
	for i, tok := range tokens {
		t.Logf(" [%d] id=%d %q", i, tok.ID, tok.Text)
	}
}
// TestQwen2_Chat validates chat template for Qwen2 models.
func TestQwen2_Chat(t *testing.T) {
modelPath := qwen2ModelPath(t)
m, err := inference.LoadModel(modelPath)
if err != nil {
t.Fatalf("LoadModel: %v", err)
}
defer m.Close()
ctx := context.Background()
var output strings.Builder
var count int
for tok := range m.Chat(ctx, []inference.Message{
{Role: "user", Content: "Reply with exactly one word: the capital of France."},
}, inference.WithMaxTokens(32)) {
output.WriteString(tok.Text)
count++
}
if err := m.Err(); err != nil {
t.Fatalf("Chat error: %v", err)
}
if count == 0 {
t.Fatal("Chat produced no tokens")
}
t.Logf("Chat output (%d tokens): %s", count, output.String())
}