feat: handle nested text_config and language_model weight prefix
Supports both multimodal (Gemma3ForConditionalGeneration) and text-only configs. Resolves weights with language_model. prefix fallback. Computes head_dim from hidden_size when missing. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
97ffe91cde
commit
f4303fada2
1 changed file with 74 additions and 27 deletions
|
|
@ -113,19 +113,30 @@ func geluApprox(x *mlx.Array) *mlx.Array {
|
|||
return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusT)
|
||||
}
|
||||
|
||||
// LoadGemma3 loads a Gemma 3 text model from a directory.
|
||||
func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gemma3: load config: %w", err)
|
||||
// parseConfig handles both flat and nested (text_config) Gemma 3 configs.
|
||||
func parseConfig(data []byte) (*TextConfig, error) {
|
||||
// Try parsing text_config from multimodal wrapper
|
||||
var wrapper struct {
|
||||
TextConfig TextConfig `json:"text_config"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &wrapper); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cfg TextConfig
|
||||
cfg := wrapper.TextConfig
|
||||
|
||||
// If text_config was empty, try top-level
|
||||
if cfg.NumHiddenLayers == 0 {
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("gemma3: parse config: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Defaults
|
||||
// Compute defaults
|
||||
if cfg.HeadDim == 0 && cfg.NumAttentionHeads > 0 {
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
}
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 1000000
|
||||
|
|
@ -139,6 +150,35 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
|||
if cfg.SlidingWindowPattern == 0 {
|
||||
cfg.SlidingWindowPattern = 6
|
||||
}
|
||||
if cfg.VocabSize == 0 {
|
||||
cfg.VocabSize = 262208 // Gemma 3 default
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// resolveWeight looks up a weight with optional "language_model." prefix.
|
||||
func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
|
||||
if w, ok := weights[name]; ok {
|
||||
return w
|
||||
}
|
||||
if w, ok := weights["language_model."+name]; ok {
|
||||
return w
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadGemma3 loads a Gemma 3 text model from a directory.
|
||||
func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gemma3: load config: %w", err)
|
||||
}
|
||||
|
||||
cfg, err := parseConfig(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gemma3: parse config: %w", err)
|
||||
}
|
||||
|
||||
// Load tokenizer
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
|
|
@ -156,41 +196,48 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
|||
}
|
||||
|
||||
m := &GemmaModel{
|
||||
EmbedTokens: &mlx.Embedding{Weight: weights["model.embed_tokens.weight"]},
|
||||
EmbedTokens: &mlx.Embedding{Weight: resolveWeight(weights, "model.embed_tokens.weight")},
|
||||
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
|
||||
Norm: &mlx.RMSNormModule{Weight: weights["model.norm.weight"]},
|
||||
Norm: &mlx.RMSNormModule{Weight: resolveWeight(weights, "model.norm.weight")},
|
||||
Tok: tok,
|
||||
Cfg: &cfg,
|
||||
Cfg: cfg,
|
||||
}
|
||||
|
||||
// Helper to resolve weight with language_model. prefix fallback
|
||||
w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
|
||||
|
||||
// Initialize layers
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
prefix := fmt.Sprintf("model.layers.%d", i)
|
||||
m.Layers[i] = &DecoderLayer{
|
||||
InputNorm: &mlx.RMSNormModule{Weight: weights[prefix+".input_layernorm.weight"]},
|
||||
PostAttnNorm: &mlx.RMSNormModule{Weight: weights[prefix+".post_attention_layernorm.weight"]},
|
||||
PreFFNorm: &mlx.RMSNormModule{Weight: weights[prefix+".pre_feedforward_layernorm.weight"]},
|
||||
PostFFNorm: &mlx.RMSNormModule{Weight: weights[prefix+".post_feedforward_layernorm.weight"]},
|
||||
InputNorm: &mlx.RMSNormModule{Weight: w(prefix + ".input_layernorm.weight")},
|
||||
PostAttnNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_attention_layernorm.weight")},
|
||||
PreFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")},
|
||||
PostFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")},
|
||||
Attention: &Attention{
|
||||
QProj: mlx.NewLinear(weights[prefix+".self_attn.q_proj.weight"], nil),
|
||||
KProj: mlx.NewLinear(weights[prefix+".self_attn.k_proj.weight"], nil),
|
||||
VProj: mlx.NewLinear(weights[prefix+".self_attn.v_proj.weight"], nil),
|
||||
OProj: mlx.NewLinear(weights[prefix+".self_attn.o_proj.weight"], nil),
|
||||
QNorm: &mlx.RMSNormModule{Weight: weights[prefix+".self_attn.q_norm.weight"]},
|
||||
KNorm: &mlx.RMSNormModule{Weight: weights[prefix+".self_attn.k_norm.weight"]},
|
||||
QProj: mlx.NewLinear(w(prefix+".self_attn.q_proj.weight"), nil),
|
||||
KProj: mlx.NewLinear(w(prefix+".self_attn.k_proj.weight"), nil),
|
||||
VProj: mlx.NewLinear(w(prefix+".self_attn.v_proj.weight"), nil),
|
||||
OProj: mlx.NewLinear(w(prefix+".self_attn.o_proj.weight"), nil),
|
||||
QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")},
|
||||
KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")},
|
||||
},
|
||||
MLP: &MLP{
|
||||
GateProj: mlx.NewLinear(weights[prefix+".mlp.gate_proj.weight"], nil),
|
||||
UpProj: mlx.NewLinear(weights[prefix+".mlp.up_proj.weight"], nil),
|
||||
DownProj: mlx.NewLinear(weights[prefix+".mlp.down_proj.weight"], nil),
|
||||
GateProj: mlx.NewLinear(w(prefix+".mlp.gate_proj.weight"), nil),
|
||||
UpProj: mlx.NewLinear(w(prefix+".mlp.up_proj.weight"), nil),
|
||||
DownProj: mlx.NewLinear(w(prefix+".mlp.down_proj.weight"), nil),
|
||||
},
|
||||
LayerIdx: i,
|
||||
IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern),
|
||||
}
|
||||
}
|
||||
|
||||
// Tied embeddings
|
||||
m.Output = mlx.NewLinear(m.EmbedTokens.Weight, nil)
|
||||
// Tied embeddings — check for separate lm_head first
|
||||
lmHead := w("lm_head.weight")
|
||||
if lmHead == nil {
|
||||
lmHead = m.EmbedTokens.Weight // tied
|
||||
}
|
||||
m.Output = mlx.NewLinear(lmHead, nil)
|
||||
|
||||
// Materialize all weights
|
||||
var allArrays []*mlx.Array
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue