go-mlx/model/qwen3.go
Snider cae7ef05e8 feat: extract go-mlx from go-ai as standalone Metal inference package
Split mlx/ directory from forge.lthn.ai/core/go-ai into its own module.
Rewrites import paths, adds CLAUDE.md/TODO.md/FINDINGS.md for dedicated
Claude sessions. Zero external Go deps — pure CGO + mlx-c v0.4.1.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-19 17:57:37 +00:00

337 lines
10 KiB
Go

//go:build darwin && arm64
package model
import (
"encoding/json"
"fmt"
"log/slog"
"math"
"os"
"path/filepath"
"forge.lthn.ai/core/go-mlx"
"forge.lthn.ai/core/go-mlx/cache"
"forge.lthn.ai/core/go-mlx/tokenizer"
)
// Qwen3Config holds Qwen 3 model configuration.
//
// Fields carrying a json key are decoded directly from config.json.
// Quantization and Scale are tagged json:"-" and are filled in by
// parseQwen3Config after decoding.
type Qwen3Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
// HeadDim is derived by parseQwen3Config as HiddenSize / NumAttentionHeads
// when config.json leaves it at zero.
HeadDim int32 `json:"head_dim"`
// VocabSize falls back to 151936 in parseQwen3Config when zero.
VocabSize int32 `json:"vocab_size"`
// RMSNormEps falls back to 1e-6 in parseQwen3Config when zero.
RMSNormEps float32 `json:"rms_norm_eps"`
// RopeTheta falls back to 1000000 in parseQwen3Config when zero.
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
// Quantization is read from the top-level "quantization" object of
// config.json; it stays nil when that object is absent.
Quantization *QuantizationConfig `json:"-"`
Scale float32 `json:"-"` // 1/sqrt(head_dim)
}
// Qwen3Model is the Qwen 3 text model.
//
// Forward applies EmbedTokens, then every entry of Layers in order, then Norm,
// then Output.
type Qwen3Model struct {
EmbedTokens *mlx.Embedding
Layers []*Qwen3DecoderLayer
// Norm is applied to the output of the last layer, before Output.
Norm *mlx.RMSNormModule
// Output is built by LoadQwen3 from lm_head when the checkpoint has one,
// otherwise from EmbedTokens.AsLinear().
Output *mlx.Linear
// Tok is the tokenizer returned by Tokenizer().
Tok *tokenizer.Tokenizer
Cfg *Qwen3Config
}
// Qwen3DecoderLayer is a single transformer block.
// Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add.
// Each norm is applied to the input of its sub-block; the sub-block output is
// added back onto the un-normalised input.
type Qwen3DecoderLayer struct {
InputNorm *mlx.RMSNormModule // Pre-attention norm
PostAttnNorm *mlx.RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm)
Attention *Qwen3Attention
MLP *Qwen3MLP
}
// Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization.
//
// QProj, KProj and VProj project the layer input; OProj projects the attention
// result back. QNorm and KNorm are applied to the per-head query and key
// tensors before RoPE.
type Qwen3Attention struct {
QProj *mlx.Linear
KProj *mlx.Linear
VProj *mlx.Linear
OProj *mlx.Linear
QNorm *mlx.RMSNormModule
KNorm *mlx.RMSNormModule
}
// Qwen3MLP is the SwiGLU feed-forward network: down(silu(gate(x)) * up(x)).
// GateProj and UpProj both consume the same input; DownProj consumes their
// element-wise product.
type Qwen3MLP struct {
GateProj *mlx.Linear
UpProj *mlx.Linear
DownProj *mlx.Linear
}
// parseQwen3Config decodes the contents of a config.json file into a
// Qwen3Config, then fills in derived values and defaults.
//
// It returns an error when the JSON is malformed, when the top-level
// "quantization" value has the wrong shape, when num_attention_heads is not
// positive, or when no positive head_dim can be determined. On success:
//   - Quantization mirrors the top-level "quantization" object (nil if absent).
//   - NumKeyValueHeads defaults to NumAttentionHeads (plain multi-head
//     attention) when the config omits it.
//   - HeadDim defaults to HiddenSize / NumAttentionHeads.
//   - Scale is 1/sqrt(HeadDim).
//   - RopeTheta, RMSNormEps and VocabSize default to 1000000, 1e-6 and 151936.
func parseQwen3Config(data []byte) (*Qwen3Config, error) {
	var cfg Qwen3Config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, err
	}

	// Quantization is tagged json:"-" on Qwen3Config, so it is decoded through
	// a separate wrapper keyed on the top-level "quantization" object. A
	// wrongly-typed value is reported instead of silently loading the model as
	// unquantized.
	var wrapper struct {
		Quantization *QuantizationConfig `json:"quantization"`
	}
	if err := json.Unmarshal(data, &wrapper); err != nil {
		return nil, fmt.Errorf("quantization: %w", err)
	}
	cfg.Quantization = wrapper.Quantization

	// The head count is a divisor both here and in the attention forward pass;
	// rejecting it up front avoids an integer divide-by-zero panic.
	if cfg.NumAttentionHeads <= 0 {
		return nil, fmt.Errorf("num_attention_heads must be positive, got %d", cfg.NumAttentionHeads)
	}
	if cfg.NumKeyValueHeads == 0 {
		cfg.NumKeyValueHeads = cfg.NumAttentionHeads
	}

	// Derive head_dim when the config does not state it, then compute scale.
	if cfg.HeadDim == 0 {
		cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
	}
	if cfg.HeadDim <= 0 {
		// A zero head_dim would make Scale +Inf.
		return nil, fmt.Errorf("head_dim must be positive, got %d", cfg.HeadDim)
	}
	cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))

	// Defaults
	if cfg.RopeTheta == 0 {
		cfg.RopeTheta = 1000000
	}
	if cfg.RMSNormEps == 0 {
		cfg.RMSNormEps = 1e-6
	}
	if cfg.VocabSize == 0 {
		cfg.VocabSize = 151936
	}
	return &cfg, nil
}
// LoadQwen3 loads a Qwen 3 model from a directory holding config.json,
// tokenizer.json and one or more *.safetensors weight files.
//
// It returns an error when the config or tokenizer cannot be loaded, when the
// weight glob is malformed or matches no file, or when a tensor that the
// forward pass uses unconditionally is absent from the checkpoint. Linear and
// embedding layers are built in quantized form only when config.json carries a
// quantization section and the tensor has a matching ".scales" entry.
func LoadQwen3(modelPath string) (*Qwen3Model, error) {
	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
	if err != nil {
		return nil, fmt.Errorf("qwen3: load config: %w", err)
	}
	cfg, err := parseQwen3Config(data)
	if err != nil {
		return nil, fmt.Errorf("qwen3: parse config: %w", err)
	}
	tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
	if err != nil {
		return nil, fmt.Errorf("qwen3: load tokenizer: %w", err)
	}

	// Load weights from all safetensors files. Glob only fails on a malformed
	// pattern, which happens when modelPath itself contains glob
	// metacharacters such as an unmatched '['.
	matches, err := filepath.Glob(filepath.Join(modelPath, "*.safetensors"))
	if err != nil {
		return nil, fmt.Errorf("qwen3: find weights: %w", err)
	}
	if len(matches) == 0 {
		return nil, fmt.Errorf("qwen3: no *.safetensors files in %q", modelPath)
	}
	weights := make(map[string]*mlx.Array)
	for _, path := range matches {
		for name, arr := range mlx.LoadSafetensors(path) {
			weights[name] = arr
		}
	}

	// w looks up an optional tensor; required additionally records the name
	// when the lookup comes back nil so all gaps can be reported together.
	var missing []string
	w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
	required := func(name string) *mlx.Array {
		arr := w(name)
		if arr == nil {
			missing = append(missing, name)
		}
		return arr
	}

	// Quantization setup
	q := cfg.Quantization
	if q != nil {
		slog.Info("qwen3: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
	}
	linear := func(prefix string) *mlx.Linear {
		weight := required(prefix + ".weight")
		scales := w(prefix + ".scales")
		biases := w(prefix + ".biases")
		bias := w(prefix + ".bias")
		if scales != nil && q != nil {
			return mlx.NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits)
		}
		return mlx.NewLinear(weight, bias)
	}

	// Embedding
	embed := &mlx.Embedding{Weight: required("model.embed_tokens.weight")}
	if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
		embed.Scales = embedScales
		embed.Biases = w("model.embed_tokens.biases")
		embed.GroupSize = q.GroupSize
		embed.Bits = q.Bits
	}

	m := &Qwen3Model{
		EmbedTokens: embed,
		Layers:      make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers),
		Norm:        &mlx.RMSNormModule{Weight: required("model.norm.weight")},
		Tok:         tok,
		Cfg:         cfg,
	}
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		p := fmt.Sprintf("model.layers.%d", i)
		m.Layers[i] = &Qwen3DecoderLayer{
			InputNorm:    &mlx.RMSNormModule{Weight: required(p + ".input_layernorm.weight")},
			PostAttnNorm: &mlx.RMSNormModule{Weight: required(p + ".post_attention_layernorm.weight")},
			Attention: &Qwen3Attention{
				QProj: linear(p + ".self_attn.q_proj"),
				KProj: linear(p + ".self_attn.k_proj"),
				VProj: linear(p + ".self_attn.v_proj"),
				OProj: linear(p + ".self_attn.o_proj"),
				QNorm: &mlx.RMSNormModule{Weight: required(p + ".self_attn.q_norm.weight")},
				KNorm: &mlx.RMSNormModule{Weight: required(p + ".self_attn.k_norm.weight")},
			},
			MLP: &Qwen3MLP{
				GateProj: linear(p + ".mlp.gate_proj"),
				UpProj:   linear(p + ".mlp.up_proj"),
				DownProj: linear(p + ".mlp.down_proj"),
			},
		}
	}

	// Fail here rather than hand back a model whose first Forward call would
	// operate on nil arrays.
	if len(missing) > 0 {
		return nil, fmt.Errorf("qwen3: %d required weights missing from %q (first: %q)",
			len(missing), modelPath, missing[0])
	}

	// Output head: use a separate lm_head when the checkpoint ships one,
	// otherwise reuse the embedding matrix as the output projection (tied
	// embeddings).
	lmHeadWeight := w("lm_head.weight")
	if lmHeadWeight != nil {
		lmHeadScales := w("lm_head.scales")
		if lmHeadScales != nil && q != nil {
			m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
		} else {
			m.Output = mlx.NewLinear(lmHeadWeight, nil)
		}
	} else {
		m.Output = m.EmbedTokens.AsLinear()
	}

	// Materialise all weights onto Metal
	allArrays := make([]*mlx.Array, 0, len(weights))
	for _, a := range weights {
		allArrays = append(allArrays, a)
	}
	mlx.Materialize(allArrays...)

	slog.Info("qwen3: model loaded",
		"layers", cfg.NumHiddenLayers,
		"hidden", cfg.HiddenSize,
		"heads", cfg.NumAttentionHeads,
		"kv_heads", cfg.NumKeyValueHeads,
		"head_dim", cfg.HeadDim,
		"vocab", cfg.VocabSize,
	)
	return m, nil
}
// Forward runs the Qwen 3 forward pass over tokens and returns the result of
// the output projection applied to the final-normed hidden state.
//
// The first two dimensions of tokens are read as B and L and handed to every
// layer. caches is indexed by layer position, so it needs at least one entry
// per layer (NewCache produces a matching slice).
// Unlike Gemma, Qwen does NOT scale embeddings by sqrt(hidden_size): the
// embedding output goes into the first layer as-is.
func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	dims := tokens.Shape()
	batch, seqLen := dims[0], dims[1]

	hidden := m.EmbedTokens.Forward(tokens)
	for idx := range m.Layers {
		hidden = m.Layers[idx].forward(hidden, caches[idx], batch, seqLen, m.Cfg)
	}

	normed := m.Norm.Forward(hidden, m.Cfg.RMSNormEps)
	return m.Output.Forward(normed)
}
// forward applies one decoder block to x using the layer's KV cache c.
// Both sub-blocks follow the same pre-norm residual shape: normalise the
// input, run the sub-block, add the result back onto the un-normalised input.
func (l *Qwen3DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
	eps := cfg.RMSNormEps

	// Attention sub-block with residual connection.
	attnIn := l.InputNorm.Forward(x, eps)
	afterAttn := mlx.Add(x, l.Attention.forward(attnIn, c, B, L, cfg))

	// Feed-forward sub-block with residual connection.
	mlpIn := l.PostAttnNorm.Forward(afterAttn, eps)
	return mlx.Add(afterAttn, l.MLP.forward(mlpIn))
}
// forward runs grouped-query attention for one layer.
//
// x is projected to Q, K and V, viewed per head, RMS-normalised (Q and K
// only), rotated with RoPE at the cache's current offset, merged into the KV
// cache c, and passed through scaled dot-product attention before the output
// projection. B and L are the first two dimensions of the token tensor as
// read by Qwen3Model.Forward.
func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)

	// splitHeads views a projection laid out as [B, L, heads*head_dim] as
	// [B, heads, L, head_dim] without copying. Strides are computed in int64:
	// the int32 product L*heads*head_dim wraps once it reaches 2^31 (for
	// example 262144 tokens x 8192 features), which would produce a bogus
	// batch stride.
	headDim := int64(cfg.HeadDim)
	splitHeads := func(t *mlx.Array, heads int32) *mlx.Array {
		rowStride := int64(heads) * headDim
		return mlx.AsStrided(t,
			[]int32{B, heads, L, cfg.HeadDim},
			[]int64{int64(L) * rowStride, headDim, rowStride, 1},
			0)
	}
	q = splitHeads(q, cfg.NumAttentionHeads)
	k = splitHeads(k, cfg.NumKeyValueHeads)
	v = splitHeads(v, cfg.NumKeyValueHeads)

	// Q/K RMS normalization (Qwen 3 has this)
	q = a.QNorm.Forward(q, cfg.RMSNormEps)
	k = a.KNorm.Forward(k, cfg.RMSNormEps)

	// RoPE — single theta for all layers (no sliding window). The offset is
	// read before the cache is updated below.
	q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
	k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())

	// Update KV cache; attention below runs on the k/v it returns.
	k, v = c.Update(k, v, int(L))

	// GQA: repeat K/V heads to match Q heads
	repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
	if repeatFactor > 1 {
		k = mlx.RepeatKV(k, repeatFactor)
		v = mlx.RepeatKV(v, repeatFactor)
	}

	// Scaled dot-product attention. The final flag is set only for
	// multi-token steps (presumably it enables causal masking; confirm against
	// mlx.ScaledDotProductAttention).
	out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)

	// Back to [B, L, heads*head_dim] for the output projection.
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
	return a.OProj.Forward(out)
}
// forward computes the SwiGLU feed-forward: the SiLU-activated gate projection
// is multiplied element-wise with the up projection, and the product goes
// through the down projection, i.e. down(silu(gate(x)) * up(x)).
func (m *Qwen3MLP) forward(x *mlx.Array) *mlx.Array {
	activated := mlx.SiLU(m.GateProj.Forward(x))
	projected := m.UpProj.Forward(x)
	return m.DownProj.Forward(mlx.Mul(activated, projected))
}
// NewCache returns a fresh slice holding one cache.NewKVCache() per layer, in
// layer order. Qwen 3 uses global attention only, so every layer gets the same
// kind of cache.
func (m *Qwen3Model) NewCache() []cache.Cache {
	perLayer := make([]cache.Cache, 0, len(m.Layers))
	for range m.Layers {
		perLayer = append(perLayer, cache.NewKVCache())
	}
	return perLayer
}
// NumLayers reports how many transformer layers the model holds, i.e. the
// length of m.Layers.
func (m *Qwen3Model) NumLayers() int {
	return len(m.Layers)
}
// Tokenizer returns the tokenizer stored on the model (the Tok field).
func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer {
	return m.Tok
}
// ModelType returns the architecture identifier, which is always "qwen3".
func (m *Qwen3Model) ModelType() string {
	const arch = "qwen3"
	return arch
}
// ApplyLoRA attaches a LoRA adapter to each attention projection named in
// cfg.TargetKeys, on every layer.
//
// Recognised keys are "q_proj", "k_proj", "v_proj" and "o_proj"; any other key
// is skipped. Each matched projection gets its LoRA field set, and the same
// adapter is recorded in the returned LoRAAdapter under
// "model.layers.<i>.self_attn.<key>".
func (m *Qwen3Model) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
	// pick resolves a target key to the matching projection, or nil when the
	// key is not an attention projection.
	pick := func(attn *Qwen3Attention, key string) *mlx.Linear {
		switch key {
		case "q_proj":
			return attn.QProj
		case "k_proj":
			return attn.KProj
		case "v_proj":
			return attn.VProj
		case "o_proj":
			return attn.OProj
		default:
			return nil
		}
	}

	wrapped := make(map[string]*mlx.LoRALinear)
	for idx, layer := range m.Layers {
		for _, key := range cfg.TargetKeys {
			base := pick(layer.Attention, key)
			if base == nil {
				continue
			}
			adapter := mlx.NewLoRALinear(base, cfg.Rank, cfg.Alpha)
			base.LoRA = adapter
			wrapped[fmt.Sprintf("model.layers.%d.self_attn.%s", idx, key)] = adapter
		}
	}

	return &mlx.LoRAAdapter{
		Layers: wrapped,
		Config: cfg,
	}
}