From 2a67653bf72dcfdb1836d920ce6d4f46aee64aa0 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 16 Feb 2026 02:00:40 +0000 Subject: [PATCH] feat: handle nested text_config and language_model weight prefix Supports both multimodal (Gemma3ForConditionalGeneration) and text-only configs. Resolves weights with language_model. prefix fallback. Computes head_dim from hidden_size when missing. Co-Authored-By: Claude Opus 4.6 --- pkg/mlx/model/gemma3.go | 101 +++++++++++++++++++++++++++++----------- 1 file changed, 74 insertions(+), 27 deletions(-) diff --git a/pkg/mlx/model/gemma3.go b/pkg/mlx/model/gemma3.go index 6ea5da51..3fc87f6f 100644 --- a/pkg/mlx/model/gemma3.go +++ b/pkg/mlx/model/gemma3.go @@ -113,19 +113,30 @@ func geluApprox(x *mlx.Array) *mlx.Array { return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusT) } -// LoadGemma3 loads a Gemma 3 text model from a directory. -func LoadGemma3(modelPath string) (*GemmaModel, error) { - data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) - if err != nil { - return nil, fmt.Errorf("gemma3: load config: %w", err) +// parseConfig handles both flat and nested (text_config) Gemma 3 configs. +func parseConfig(data []byte) (*TextConfig, error) { + // Try parsing text_config from multimodal wrapper + var wrapper struct { + TextConfig TextConfig `json:"text_config"` + ModelType string `json:"model_type"` + } + if err := json.Unmarshal(data, &wrapper); err != nil { + return nil, err } - var cfg TextConfig - if err := json.Unmarshal(data, &cfg); err != nil { - return nil, fmt.Errorf("gemma3: parse config: %w", err) + cfg := wrapper.TextConfig + + // If text_config was empty, try top-level + if cfg.NumHiddenLayers == 0 { + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, err + } } - // Defaults + // Compute defaults + if cfg.HeadDim == 0 && cfg.NumAttentionHeads > 0 { + cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads + } cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) if cfg.RopeTheta == 0 { cfg.RopeTheta = 1000000 @@ -139,6 +150,35 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) { if cfg.SlidingWindowPattern == 0 { cfg.SlidingWindowPattern = 6 } + if cfg.VocabSize == 0 { + cfg.VocabSize = 262208 // Gemma 3 default + } + + return &cfg, nil +} + +// resolveWeight looks up a weight with optional "language_model." prefix. +func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array { + if w, ok := weights[name]; ok { + return w + } + if w, ok := weights["language_model."+name]; ok { + return w + } + return nil +} + +// LoadGemma3 loads a Gemma 3 text model from a directory. +func LoadGemma3(modelPath string) (*GemmaModel, error) { + data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) + if err != nil { + return nil, fmt.Errorf("gemma3: load config: %w", err) + } + + cfg, err := parseConfig(data) + if err != nil { + return nil, fmt.Errorf("gemma3: parse config: %w", err) + } // Load tokenizer tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) @@ -156,41 +196,48 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) { } m := &GemmaModel{ - EmbedTokens: &mlx.Embedding{Weight: weights["model.embed_tokens.weight"]}, + EmbedTokens: &mlx.Embedding{Weight: resolveWeight(weights, "model.embed_tokens.weight")}, Layers: make([]*DecoderLayer, cfg.NumHiddenLayers), - Norm: &mlx.RMSNormModule{Weight: weights["model.norm.weight"]}, + Norm: &mlx.RMSNormModule{Weight: resolveWeight(weights, "model.norm.weight")}, Tok: tok, - Cfg: &cfg, + Cfg: cfg, } + // Helper to resolve weight with language_model. prefix fallback + w := func(name string) *mlx.Array { return resolveWeight(weights, name) } + // Initialize layers for i := int32(0); i < cfg.NumHiddenLayers; i++ { prefix := fmt.Sprintf("model.layers.%d", i) m.Layers[i] = &DecoderLayer{ - InputNorm: &mlx.RMSNormModule{Weight: weights[prefix+".input_layernorm.weight"]}, - PostAttnNorm: &mlx.RMSNormModule{Weight: weights[prefix+".post_attention_layernorm.weight"]}, - PreFFNorm: &mlx.RMSNormModule{Weight: weights[prefix+".pre_feedforward_layernorm.weight"]}, - PostFFNorm: &mlx.RMSNormModule{Weight: weights[prefix+".post_feedforward_layernorm.weight"]}, + InputNorm: &mlx.RMSNormModule{Weight: w(prefix + ".input_layernorm.weight")}, + PostAttnNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_attention_layernorm.weight")}, + PreFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")}, + PostFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")}, Attention: &Attention{ - QProj: mlx.NewLinear(weights[prefix+".self_attn.q_proj.weight"], nil), - KProj: mlx.NewLinear(weights[prefix+".self_attn.k_proj.weight"], nil), - VProj: mlx.NewLinear(weights[prefix+".self_attn.v_proj.weight"], nil), - OProj: mlx.NewLinear(weights[prefix+".self_attn.o_proj.weight"], nil), - QNorm: &mlx.RMSNormModule{Weight: weights[prefix+".self_attn.q_norm.weight"]}, - KNorm: &mlx.RMSNormModule{Weight: weights[prefix+".self_attn.k_norm.weight"]}, + QProj: mlx.NewLinear(w(prefix+".self_attn.q_proj.weight"), nil), + KProj: mlx.NewLinear(w(prefix+".self_attn.k_proj.weight"), nil), + VProj: mlx.NewLinear(w(prefix+".self_attn.v_proj.weight"), nil), + OProj: mlx.NewLinear(w(prefix+".self_attn.o_proj.weight"), nil), + QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")}, + KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")}, }, MLP: &MLP{ - GateProj: mlx.NewLinear(weights[prefix+".mlp.gate_proj.weight"], nil), - UpProj: mlx.NewLinear(weights[prefix+".mlp.up_proj.weight"], nil), - DownProj: mlx.NewLinear(weights[prefix+".mlp.down_proj.weight"], nil), + GateProj: mlx.NewLinear(w(prefix+".mlp.gate_proj.weight"), nil), + UpProj: mlx.NewLinear(w(prefix+".mlp.up_proj.weight"), nil), + DownProj: mlx.NewLinear(w(prefix+".mlp.down_proj.weight"), nil), }, LayerIdx: i, IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern), } } - // Tied embeddings - m.Output = mlx.NewLinear(m.EmbedTokens.Weight, nil) + // Tied embeddings — check for separate lm_head first + lmHead := w("lm_head.weight") + if lmHead == nil { + lmHead = m.EmbedTokens.Weight // tied + } + m.Output = mlx.NewLinear(lmHead, nil) // Materialize all weights var allArrays []*mlx.Array