feat/updates #1

Merged
Snider merged 51 commits from feat/updates into dev 2026-02-16 05:54:07 +00:00
Showing only changes of commit 2a67653bf7


@@ -113,19 +113,30 @@ func geluApprox(x *mlx.Array) *mlx.Array {
 	return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusT)
 }
-// LoadGemma3 loads a Gemma 3 text model from a directory.
-func LoadGemma3(modelPath string) (*GemmaModel, error) {
-	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
-	if err != nil {
-		return nil, fmt.Errorf("gemma3: load config: %w", err)
+// parseConfig handles both flat and nested (text_config) Gemma 3 configs.
+func parseConfig(data []byte) (*TextConfig, error) {
+	// Try parsing text_config from multimodal wrapper
+	var wrapper struct {
+		TextConfig TextConfig `json:"text_config"`
+		ModelType  string     `json:"model_type"`
+	}
+	if err := json.Unmarshal(data, &wrapper); err != nil {
+		return nil, err
 	}
-	var cfg TextConfig
-	if err := json.Unmarshal(data, &cfg); err != nil {
-		return nil, fmt.Errorf("gemma3: parse config: %w", err)
+	cfg := wrapper.TextConfig
+	// If text_config was empty, try top-level
+	if cfg.NumHiddenLayers == 0 {
+		if err := json.Unmarshal(data, &cfg); err != nil {
+			return nil, err
+		}
 	}
-	// Defaults
+	// Compute defaults
 	if cfg.HeadDim == 0 && cfg.NumAttentionHeads > 0 {
 		cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
 	}
 	cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
 	if cfg.RopeTheta == 0 {
 		cfg.RopeTheta = 1000000
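
An aside on why parseConfig tries two shapes: multimodal Gemma 3 checkpoints nest the text settings under a text_config object, while text-only checkpoints keep the same fields at the top level, so the helper parses the wrapper first and falls back to a flat parse when no layers were found. A minimal in-package test sketch covering both shapes (goes in a _test.go file with the standard testing import; the snake_case JSON field names and the concrete values are assumptions for illustration, not taken from this commit):

func TestParseConfigBothShapes(t *testing.T) {
	// Multimodal wrapper: text settings nested under "text_config".
	nested := []byte(`{"model_type":"gemma3","text_config":{"num_hidden_layers":34,"hidden_size":2560,"num_attention_heads":8}}`)
	// Text-only config: the same fields at the top level.
	flat := []byte(`{"model_type":"gemma3_text","num_hidden_layers":34,"hidden_size":2560,"num_attention_heads":8}`)
	for _, data := range [][]byte{nested, flat} {
		cfg, err := parseConfig(data)
		if err != nil {
			t.Fatal(err)
		}
		// Both shapes should yield the same parsed config,
		// including the derived head dim (2560 / 8 = 320).
		if cfg.NumHiddenLayers != 34 || cfg.HeadDim != 320 {
			t.Fatalf("unexpected config: %+v", cfg)
		}
	}
}
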
@@ -139,6 +150,35 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
 	if cfg.SlidingWindowPattern == 0 {
 		cfg.SlidingWindowPattern = 6
 	}
+	if cfg.VocabSize == 0 {
+		cfg.VocabSize = 262208 // Gemma 3 default
+	}
+	return &cfg, nil
+}
+
+// resolveWeight looks up a weight with optional "language_model." prefix.
+func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
+	if w, ok := weights[name]; ok {
+		return w
+	}
+	if w, ok := weights["language_model."+name]; ok {
+		return w
+	}
+	return nil
+}
+
+// LoadGemma3 loads a Gemma 3 text model from a directory.
+func LoadGemma3(modelPath string) (*GemmaModel, error) {
+	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
+	if err != nil {
+		return nil, fmt.Errorf("gemma3: load config: %w", err)
+	}
+	cfg, err := parseConfig(data)
+	if err != nil {
+		return nil, fmt.Errorf("gemma3: parse config: %w", err)
+	}
 	// Load tokenizer
 	tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
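
The resolveWeight helper covers the other half of the same story: multimodal Gemma 3 exports typically prefix every text-model tensor name with language_model., while text-only exports do not. A small usage sketch under that assumption (the map contents below are illustrative only):

var normArr *mlx.Array // stands in for a tensor loaded from the checkpoint
weights := map[string]*mlx.Array{
	"language_model.model.norm.weight": normArr, // multimodal-style key
}
norm := resolveWeight(weights, "model.norm.weight") // found via the "language_model." fallback
head := resolveWeight(weights, "lm_head.weight")    // nil: no separate lm_head in this checkpoint
_, _ = norm, head

Returning nil for a missing key is what lets the lm_head lookup further down fall back to tied embeddings.
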
@@ -156,41 +196,48 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
 	}
 	m := &GemmaModel{
-		EmbedTokens: &mlx.Embedding{Weight: weights["model.embed_tokens.weight"]},
+		EmbedTokens: &mlx.Embedding{Weight: resolveWeight(weights, "model.embed_tokens.weight")},
 		Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
-		Norm: &mlx.RMSNormModule{Weight: weights["model.norm.weight"]},
+		Norm: &mlx.RMSNormModule{Weight: resolveWeight(weights, "model.norm.weight")},
 		Tok: tok,
-		Cfg: &cfg,
+		Cfg: cfg,
 	}
+	// Helper to resolve weight with language_model. prefix fallback
+	w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
 	// Initialize layers
 	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
 		prefix := fmt.Sprintf("model.layers.%d", i)
 		m.Layers[i] = &DecoderLayer{
-			InputNorm: &mlx.RMSNormModule{Weight: weights[prefix+".input_layernorm.weight"]},
-			PostAttnNorm: &mlx.RMSNormModule{Weight: weights[prefix+".post_attention_layernorm.weight"]},
-			PreFFNorm: &mlx.RMSNormModule{Weight: weights[prefix+".pre_feedforward_layernorm.weight"]},
-			PostFFNorm: &mlx.RMSNormModule{Weight: weights[prefix+".post_feedforward_layernorm.weight"]},
+			InputNorm: &mlx.RMSNormModule{Weight: w(prefix + ".input_layernorm.weight")},
+			PostAttnNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_attention_layernorm.weight")},
+			PreFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")},
+			PostFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")},
 			Attention: &Attention{
-				QProj: mlx.NewLinear(weights[prefix+".self_attn.q_proj.weight"], nil),
-				KProj: mlx.NewLinear(weights[prefix+".self_attn.k_proj.weight"], nil),
-				VProj: mlx.NewLinear(weights[prefix+".self_attn.v_proj.weight"], nil),
-				OProj: mlx.NewLinear(weights[prefix+".self_attn.o_proj.weight"], nil),
-				QNorm: &mlx.RMSNormModule{Weight: weights[prefix+".self_attn.q_norm.weight"]},
-				KNorm: &mlx.RMSNormModule{Weight: weights[prefix+".self_attn.k_norm.weight"]},
+				QProj: mlx.NewLinear(w(prefix+".self_attn.q_proj.weight"), nil),
+				KProj: mlx.NewLinear(w(prefix+".self_attn.k_proj.weight"), nil),
+				VProj: mlx.NewLinear(w(prefix+".self_attn.v_proj.weight"), nil),
+				OProj: mlx.NewLinear(w(prefix+".self_attn.o_proj.weight"), nil),
+				QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")},
+				KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")},
 			},
 			MLP: &MLP{
-				GateProj: mlx.NewLinear(weights[prefix+".mlp.gate_proj.weight"], nil),
-				UpProj: mlx.NewLinear(weights[prefix+".mlp.up_proj.weight"], nil),
-				DownProj: mlx.NewLinear(weights[prefix+".mlp.down_proj.weight"], nil),
+				GateProj: mlx.NewLinear(w(prefix+".mlp.gate_proj.weight"), nil),
+				UpProj: mlx.NewLinear(w(prefix+".mlp.up_proj.weight"), nil),
+				DownProj: mlx.NewLinear(w(prefix+".mlp.down_proj.weight"), nil),
 			},
 			LayerIdx: i,
 			IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern),
 		}
 	}
-	// Tied embeddings
-	m.Output = mlx.NewLinear(m.EmbedTokens.Weight, nil)
+	// Tied embeddings — check for separate lm_head first
+	lmHead := w("lm_head.weight")
+	if lmHead == nil {
+		lmHead = m.EmbedTokens.Weight // tied
+	}
+	m.Output = mlx.NewLinear(lmHead, nil)
 	// Materialize all weights
 	var allArrays []*mlx.Array
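
isLayerSliding itself is not part of this commit, but the SlidingWindowPattern default of 6 set above matches the usual Gemma 3 layout of five sliding-window layers followed by one global-attention layer. A hedged sketch of what the helper is assumed to look like under that convention:

// Assumed shape of isLayerSliding (not shown in this commit): with the default
// pattern of 6, every 6th layer (indices 5, 11, 17, ...) uses global attention
// and the rest use the local sliding window.
func isLayerSliding(layerIdx, pattern int32) bool {
	return (layerIdx+1)%pattern != 0
}
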