feat: handle nested text_config and language_model weight prefix

Supports both multimodal (Gemma3ForConditionalGeneration) and
text-only configs. Resolves weights with a language_model. prefix
fallback. Computes head_dim as hidden_size / num_attention_heads when
the config omits it.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Authored by Claude on 2026-02-16 02:00:40 +00:00, committed by Snider
commit f4303fada2, parent 97ffe91cde


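For context on the two config layouts this commit targets: Gemma3ForConditionalGeneration checkpoints nest the text settings under a text_config key, while text-only checkpoints keep them at the top level. The sketch below is illustrative only and not part of the commit; it assumes TextConfig uses the usual Hugging Face JSON tags (num_hidden_layers, hidden_size, num_attention_heads, head_dim) and that it sits in the same package as the loader, and it exercises the new parseConfig with both shapes.

// Sketch (not part of the commit): parse both config layouts.
// JSON values are made up for illustration, not taken from a real checkpoint.
func exampleParseBothShapes() error {
    multimodal := []byte(`{
        "model_type": "gemma3",
        "text_config": {
            "num_hidden_layers": 34,
            "hidden_size": 2048,
            "num_attention_heads": 8
        }
    }`)
    textOnly := []byte(`{
        "model_type": "gemma3_text",
        "num_hidden_layers": 26,
        "hidden_size": 1152,
        "num_attention_heads": 4,
        "head_dim": 256
    }`)

    for _, raw := range [][]byte{multimodal, textOnly} {
        cfg, err := parseConfig(raw)
        if err != nil {
            return err
        }
        // First config: head_dim is absent, so it is derived as 2048 / 8 = 256.
        // Second config: the explicit head_dim of 256 is kept as-is.
        fmt.Println(cfg.NumHiddenLayers, cfg.HeadDim)
    }
    return nil
}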
@@ -113,19 +113,30 @@ func geluApprox(x *mlx.Array) *mlx.Array {
 	return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusT)
 }
 
-// LoadGemma3 loads a Gemma 3 text model from a directory.
-func LoadGemma3(modelPath string) (*GemmaModel, error) {
-	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
-	if err != nil {
-		return nil, fmt.Errorf("gemma3: load config: %w", err)
-	}
-	var cfg TextConfig
-	if err := json.Unmarshal(data, &cfg); err != nil {
-		return nil, fmt.Errorf("gemma3: parse config: %w", err)
+// parseConfig handles both flat and nested (text_config) Gemma 3 configs.
+func parseConfig(data []byte) (*TextConfig, error) {
+	// Try parsing text_config from multimodal wrapper
+	var wrapper struct {
+		TextConfig TextConfig `json:"text_config"`
+		ModelType  string     `json:"model_type"`
+	}
+	if err := json.Unmarshal(data, &wrapper); err != nil {
+		return nil, err
+	}
+	cfg := wrapper.TextConfig
+
+	// If text_config was empty, try top-level
+	if cfg.NumHiddenLayers == 0 {
+		if err := json.Unmarshal(data, &cfg); err != nil {
+			return nil, err
+		}
 	}
 
-	// Defaults
+	// Compute defaults
+	if cfg.HeadDim == 0 && cfg.NumAttentionHeads > 0 {
+		cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
+	}
 	cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
 	if cfg.RopeTheta == 0 {
 		cfg.RopeTheta = 1000000
@@ -139,6 +150,35 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
 	if cfg.SlidingWindowPattern == 0 {
 		cfg.SlidingWindowPattern = 6
 	}
+	if cfg.VocabSize == 0 {
+		cfg.VocabSize = 262208 // Gemma 3 default
+	}
+	return &cfg, nil
+}
+
+// resolveWeight looks up a weight with optional "language_model." prefix.
+func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
+	if w, ok := weights[name]; ok {
+		return w
+	}
+	if w, ok := weights["language_model."+name]; ok {
+		return w
+	}
+	return nil
+}
+
+// LoadGemma3 loads a Gemma 3 text model from a directory.
+func LoadGemma3(modelPath string) (*GemmaModel, error) {
+	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
+	if err != nil {
+		return nil, fmt.Errorf("gemma3: load config: %w", err)
+	}
+	cfg, err := parseConfig(data)
+	if err != nil {
+		return nil, fmt.Errorf("gemma3: parse config: %w", err)
+	}
+
 	// Load tokenizer
 	tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
@@ -156,41 +196,48 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
 	}
 
 	m := &GemmaModel{
-		EmbedTokens: &mlx.Embedding{Weight: weights["model.embed_tokens.weight"]},
+		EmbedTokens: &mlx.Embedding{Weight: resolveWeight(weights, "model.embed_tokens.weight")},
 		Layers:      make([]*DecoderLayer, cfg.NumHiddenLayers),
-		Norm:        &mlx.RMSNormModule{Weight: weights["model.norm.weight"]},
+		Norm:        &mlx.RMSNormModule{Weight: resolveWeight(weights, "model.norm.weight")},
 		Tok:         tok,
-		Cfg:         &cfg,
+		Cfg:         cfg,
 	}
 
+	// Helper to resolve weight with language_model. prefix fallback
+	w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
+
 	// Initialize layers
 	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
 		prefix := fmt.Sprintf("model.layers.%d", i)
 		m.Layers[i] = &DecoderLayer{
-			InputNorm:    &mlx.RMSNormModule{Weight: weights[prefix+".input_layernorm.weight"]},
-			PostAttnNorm: &mlx.RMSNormModule{Weight: weights[prefix+".post_attention_layernorm.weight"]},
-			PreFFNorm:    &mlx.RMSNormModule{Weight: weights[prefix+".pre_feedforward_layernorm.weight"]},
-			PostFFNorm:   &mlx.RMSNormModule{Weight: weights[prefix+".post_feedforward_layernorm.weight"]},
+			InputNorm:    &mlx.RMSNormModule{Weight: w(prefix + ".input_layernorm.weight")},
+			PostAttnNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_attention_layernorm.weight")},
+			PreFFNorm:    &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")},
+			PostFFNorm:   &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")},
 			Attention: &Attention{
-				QProj: mlx.NewLinear(weights[prefix+".self_attn.q_proj.weight"], nil),
-				KProj: mlx.NewLinear(weights[prefix+".self_attn.k_proj.weight"], nil),
-				VProj: mlx.NewLinear(weights[prefix+".self_attn.v_proj.weight"], nil),
-				OProj: mlx.NewLinear(weights[prefix+".self_attn.o_proj.weight"], nil),
-				QNorm: &mlx.RMSNormModule{Weight: weights[prefix+".self_attn.q_norm.weight"]},
-				KNorm: &mlx.RMSNormModule{Weight: weights[prefix+".self_attn.k_norm.weight"]},
+				QProj: mlx.NewLinear(w(prefix+".self_attn.q_proj.weight"), nil),
+				KProj: mlx.NewLinear(w(prefix+".self_attn.k_proj.weight"), nil),
+				VProj: mlx.NewLinear(w(prefix+".self_attn.v_proj.weight"), nil),
+				OProj: mlx.NewLinear(w(prefix+".self_attn.o_proj.weight"), nil),
+				QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")},
+				KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")},
 			},
 			MLP: &MLP{
-				GateProj: mlx.NewLinear(weights[prefix+".mlp.gate_proj.weight"], nil),
-				UpProj:   mlx.NewLinear(weights[prefix+".mlp.up_proj.weight"], nil),
-				DownProj: mlx.NewLinear(weights[prefix+".mlp.down_proj.weight"], nil),
+				GateProj: mlx.NewLinear(w(prefix+".mlp.gate_proj.weight"), nil),
+				UpProj:   mlx.NewLinear(w(prefix+".mlp.up_proj.weight"), nil),
+				DownProj: mlx.NewLinear(w(prefix+".mlp.down_proj.weight"), nil),
 			},
 			LayerIdx:  i,
 			IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern),
 		}
 	}
 
-	// Tied embeddings
-	m.Output = mlx.NewLinear(m.EmbedTokens.Weight, nil)
+	// Tied embeddings — check for separate lm_head first
+	lmHead := w("lm_head.weight")
+	if lmHead == nil {
+		lmHead = m.EmbedTokens.Weight // tied
+	}
+	m.Output = mlx.NewLinear(lmHead, nil)
 
 	// Materialize all weights
 	var allArrays []*mlx.Array
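A note on the weight-name fallback: multimodal Gemma 3 checkpoints store the text tower's tensors under a language_model. prefix, while text-only checkpoints use the bare names, which is why the loader now routes every lookup through resolveWeight. A minimal sketch of the lookup order, with a placeholder standing in for a real tensor (illustrative, not part of the commit):

// Sketch (not part of the commit): the lookup order resolveWeight applies.
func exampleResolveWeightFallback() {
    placeholder := new(mlx.Array) // stand-in for a real loaded tensor

    // Multimodal layout: only the prefixed key exists.
    weights := map[string]*mlx.Array{
        "language_model.model.embed_tokens.weight": placeholder,
    }

    // The bare name misses, so the "language_model." fallback hits.
    fmt.Println(resolveWeight(weights, "model.embed_tokens.weight") == placeholder) // true

    // A key absent under both names resolves to nil; that is how the new
    // lm_head check detects tied-embedding checkpoints.
    fmt.Println(resolveWeight(weights, "lm_head.weight") == nil) // true
}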