feat(metal): add Qwen2 model support (DeepSeek R1 validated)
Qwen2 and Qwen3 share the same architecture — Qwen3 adds Q/K RMS normalization which Qwen2 lacks. The loader auto-detects the variant from weight presence and reports the correct ModelType().

- Add "qwen2" to architecture dispatch in model.go
- Make Q/K norm optional in attention forward (nil-safe check)
- Store detected model type on Qwen3Model struct
- Add "qwen2" to chat template routing
- DeepSeek R1 7B (4-bit): 27 tok/s on M3 Ultra
- 2 new tests: inference + chat

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
a2493e0242
commit
535b04d5d6
5 changed files with 115 additions and 12 deletions
3
TODO.md
3
TODO.md
|
|
@ -15,7 +15,8 @@ Dispatched from core/go orchestration. Pick up tasks in order.
|
|||
|
||||
- [x] **Gemma3-1B inference validation** — ✅ End-to-end inference works. 4-bit quantised Gemma3-1B loads and generates coherently at **46 tok/s** on M3 Ultra. Fixed: `model_type: "gemma3_text"` not matched in architecture dispatch, GPT-2 BPE false detection on 262K SentencePiece vocab (checked `Ġthe` instead of bare `Ġ`). 3 new tests: inference (greedy, timing), chat template, context cancellation.
|
||||
- [x] **Model loading robustness** — ✅ 24 new tests in model_test.go covering: missing/invalid config.json, unsupported architecture, `gemma3_text` dispatch, missing tokenizer, missing safetensors (was a nil-pointer panic — fixed with early error return in both LoadGemma3 and LoadQwen3), config parsing defaults/quantization/nested text_config, `isLayerSliding`, `resolveWeight` with prefix fallback.
|
||||
- [ ] **Add Llama model support** — Only Gemma3 and Qwen3 exist. Llama architecture would cover Meta's model family (Llama 3, CodeLlama).
|
||||
- [x] **Add Qwen2 model support** — ✅ Qwen2 architecture (used by DeepSeek R1) now supported. Shares Qwen3 loader with optional Q/K RMS normalization (Qwen3 has it, Qwen2 does not). Auto-detected from weight presence. DeepSeek R1 7B: **27 tok/s** on M3 Ultra. 2 new tests.
|
||||
- [ ] **Add Llama model support** — Llama architecture would cover Meta's model family (Llama 3, CodeLlama). Needs safetensors model on disk for validation (e.g. `mlx-community/Meta-Llama-3.1-8B-Instruct-4bit`).
|
||||
|
||||
## Phase 3: Training Pipeline
|
||||
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ func (m *Model) formatChat(messages []ChatMessage) string {
|
|||
switch m.modelType {
|
||||
case "gemma3":
|
||||
return formatGemmaChat(messages)
|
||||
case "qwen3":
|
||||
case "qwen2", "qwen3":
|
||||
return formatQwenChat(messages)
|
||||
default:
|
||||
var s string
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ func loadModel(modelPath string) (InternalModel, error) {
|
|||
}
|
||||
|
||||
switch probe.ModelType {
|
||||
case "qwen3":
|
||||
case "qwen3", "qwen2":
|
||||
return LoadQwen3(modelPath)
|
||||
case "gemma3", "gemma3_text", "gemma2":
|
||||
return LoadGemma3(modelPath)
|
||||
|
|
|
|||
|
|
@ -28,15 +28,17 @@ type Qwen3Config struct {
|
|||
Scale float32 `json:"-"` // 1/sqrt(head_dim)
|
||||
}
|
||||
|
||||
// Qwen3Model is the Qwen 3 text model.
|
||||
// Qwen3Model is the Qwen 2/3 text model.
|
||||
// Qwen 2 and 3 share the same architecture; Qwen 3 adds Q/K RMS normalization.
|
||||
type Qwen3Model struct {
|
||||
EmbedTokens *Embedding
|
||||
Layers []*Qwen3DecoderLayer
|
||||
Norm *RMSNormModule
|
||||
Output *Linear
|
||||
|
||||
Tok *Tokenizer
|
||||
Cfg *Qwen3Config
|
||||
Tok *Tokenizer
|
||||
Cfg *Qwen3Config
|
||||
modelType string // "qwen2" or "qwen3"
|
||||
}
|
||||
|
||||
// Qwen3DecoderLayer is a single transformer block.
|
||||
|
|
@ -100,7 +102,7 @@ func parseQwen3Config(data []byte) (*Qwen3Config, error) {
|
|||
return &cfg, nil
|
||||
}
|
||||
|
||||
// LoadQwen3 loads a Qwen 3 model from a safetensors directory.
|
||||
// LoadQwen3 loads a Qwen 2 or 3 model from a safetensors directory.
|
||||
func LoadQwen3(modelPath string) (*Qwen3Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
|
|
@ -159,12 +161,20 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
|
|||
embed.Bits = q.Bits
|
||||
}
|
||||
|
||||
// Detect Qwen 2 vs 3 by presence of Q/K normalization weights.
|
||||
hasQKNorm := w("model.layers.0.self_attn.q_norm.weight") != nil
|
||||
detectedType := "qwen3"
|
||||
if !hasQKNorm {
|
||||
detectedType = "qwen2"
|
||||
}
|
||||
|
||||
m := &Qwen3Model{
|
||||
EmbedTokens: embed,
|
||||
Layers: make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers),
|
||||
Norm: &RMSNormModule{Weight: w("model.norm.weight")},
|
||||
Tok: tok,
|
||||
Cfg: cfg,
|
||||
modelType: detectedType,
|
||||
}
|
||||
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
|
|
@ -260,9 +270,13 @@ func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config
|
|||
v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
|
||||
// Q/K RMS normalization (Qwen 3 has this)
|
||||
q = a.QNorm.Forward(q, cfg.RMSNormEps)
|
||||
k = a.KNorm.Forward(k, cfg.RMSNormEps)
|
||||
// Q/K RMS normalization (Qwen 3 has this; Qwen 2 does not)
|
||||
if a.QNorm != nil && a.QNorm.Weight != nil {
|
||||
q = a.QNorm.Forward(q, cfg.RMSNormEps)
|
||||
}
|
||||
if a.KNorm != nil && a.KNorm.Weight != nil {
|
||||
k = a.KNorm.Forward(k, cfg.RMSNormEps)
|
||||
}
|
||||
|
||||
// RoPE — single theta for all layers (no sliding window)
|
||||
q = RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
|
|
@ -305,8 +319,8 @@ func (m *Qwen3Model) NumLayers() int { return len(m.Layers) }
|
|||
// Tokenizer returns the model's tokenizer.
|
||||
func (m *Qwen3Model) Tokenizer() *Tokenizer { return m.Tok }
|
||||
|
||||
// ModelType returns the architecture identifier.
|
||||
func (m *Qwen3Model) ModelType() string { return "qwen3" }
|
||||
// ModelType returns the architecture identifier ("qwen2" or "qwen3").
|
||||
func (m *Qwen3Model) ModelType() string { return m.modelType }
|
||||
|
||||
// ApplyLoRA wraps target projection layers with LoRA adapters.
|
||||
func (m *Qwen3Model) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter {
|
||||
|
|
|
|||
88
mlx_test.go
88
mlx_test.go
|
|
@ -294,3 +294,91 @@ func TestGemma3_1B_ContextCancel(t *testing.T) {
|
|||
}
|
||||
t.Logf("Stopped after %d tokens", count)
|
||||
}
|
||||
|
||||
// --- Qwen2 (DeepSeek R1 7B) tests ---
|
||||
|
||||
// qwen2ModelPath returns the first existing Qwen2/DeepSeek model directory
// from the known candidate locations, or skips the test when none is present.
func qwen2ModelPath(t *testing.T) string {
	t.Helper()
	candidates := []string{
		"/Volumes/Data/lem/LEK-DeepSeek-R1-7B",
	}
	for _, dir := range candidates {
		// os.Stat succeeding is enough: we only need the directory to exist.
		if _, err := os.Stat(dir); err == nil {
			return dir
		}
	}
	t.Skip("no Qwen2/DeepSeek model available")
	return ""
}
|
||||
|
||||
// TestQwen2_Inference validates Qwen2 arch (DeepSeek R1 7B) end-to-end.
|
||||
func TestQwen2_Inference(t *testing.T) {
|
||||
modelPath := qwen2ModelPath(t)
|
||||
|
||||
loadStart := time.Now()
|
||||
m, err := inference.LoadModel(modelPath)
|
||||
loadDur := time.Since(loadStart)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModel: %v", err)
|
||||
}
|
||||
defer m.Close()
|
||||
t.Logf("Model loaded in %s", loadDur)
|
||||
|
||||
if m.ModelType() != "qwen2" {
|
||||
t.Errorf("ModelType() = %q, want %q", m.ModelType(), "qwen2")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
genStart := time.Now()
|
||||
var tokens []inference.Token
|
||||
var output strings.Builder
|
||||
for tok := range m.Generate(ctx, "What is 2+2?", inference.WithMaxTokens(32)) {
|
||||
tokens = append(tokens, tok)
|
||||
output.WriteString(tok.Text)
|
||||
}
|
||||
genDur := time.Since(genStart)
|
||||
|
||||
if err := m.Err(); err != nil {
|
||||
t.Fatalf("Generate error: %v", err)
|
||||
}
|
||||
|
||||
nTokens := len(tokens)
|
||||
if nTokens == 0 {
|
||||
t.Fatal("Generate produced no tokens")
|
||||
}
|
||||
|
||||
tps := float64(nTokens) / genDur.Seconds()
|
||||
t.Logf("Generated %d tokens in %s (%.1f tok/s)", nTokens, genDur, tps)
|
||||
t.Logf("Output: %s", output.String())
|
||||
for i, tok := range tokens {
|
||||
t.Logf(" [%d] id=%d %q", i, tok.ID, tok.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwen2_Chat validates chat template for Qwen2 models.
|
||||
func TestQwen2_Chat(t *testing.T) {
|
||||
modelPath := qwen2ModelPath(t)
|
||||
|
||||
m, err := inference.LoadModel(modelPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModel: %v", err)
|
||||
}
|
||||
defer m.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var output strings.Builder
|
||||
var count int
|
||||
for tok := range m.Chat(ctx, []inference.Message{
|
||||
{Role: "user", Content: "Reply with exactly one word: the capital of France."},
|
||||
}, inference.WithMaxTokens(32)) {
|
||||
output.WriteString(tok.Text)
|
||||
count++
|
||||
}
|
||||
if err := m.Err(); err != nil {
|
||||
t.Fatalf("Chat error: %v", err)
|
||||
}
|
||||
if count == 0 {
|
||||
t.Fatal("Chat produced no tokens")
|
||||
}
|
||||
t.Logf("Chat output (%d tokens): %s", count, output.String())
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue