From 09da05d799922723a4630f83525ae24d7fa0fee0 Mon Sep 17 00:00:00 2001
From: Claude
Date: Mon, 16 Feb 2026 02:19:33 +0000
Subject: [PATCH] fix: use affine quantization mode and infer head_dim from weights

---
 pkg/mlx/model/gemma3.go | 21 +++++++++++++++++----
 pkg/mlx/ops.go          |  4 ++--
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/pkg/mlx/model/gemma3.go b/pkg/mlx/model/gemma3.go
index 0bb1e4e..c9ff7c8 100644
--- a/pkg/mlx/model/gemma3.go
+++ b/pkg/mlx/model/gemma3.go
@@ -148,11 +148,10 @@ func parseConfig(data []byte) (*TextConfig, error) {
 	// Quantization is always top-level
 	cfg.Quantization = wrapper.Quantization
 
-	// Compute defaults
-	if cfg.HeadDim == 0 && cfg.NumAttentionHeads > 0 {
-		cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
+	// Compute scale (head_dim may be inferred later from weights if not in config)
+	if cfg.HeadDim > 0 {
+		cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
 	}
-	cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
 	if cfg.RopeTheta == 0 {
 		cfg.RopeTheta = 1000000
 	}
@@ -213,6 +212,20 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
 	// Helper to resolve weight with language_model. prefix fallback
 	w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
 
+	// Infer head_dim from q_proj weight shape when not in config.
+	// Gemma 3 uses head_dim=256 which differs from hidden_size/num_heads.
+	if cfg.HeadDim == 0 {
+		qWeight := w("model.layers.0.self_attn.q_proj.weight")
+		if qWeight != nil {
+			qShape := qWeight.Shape()
+			if len(qShape) > 0 {
+				cfg.HeadDim = qShape[0] / cfg.NumAttentionHeads
+				cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
+				slog.Info("mlx: inferred head_dim from q_proj weight", "head_dim", cfg.HeadDim)
+			}
+		}
+	}
+
 	// Helper to create linear layer (quantized or dense)
 	q := cfg.Quantization
 	if q != nil {
diff --git a/pkg/mlx/ops.go b/pkg/mlx/ops.go
index 70f7efd..7c388f9 100644
--- a/pkg/mlx/ops.go
+++ b/pkg/mlx/ops.go
@@ -138,7 +138,7 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit
 	out := New("QMATMUL", x, w, scales, biases)
 	gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
 	b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
-	mode := C.CString("default")
+	mode := C.CString("affine")
 	defer C.free(unsafe.Pointer(mode))
 	C.mlx_quantized_matmul(
 		&out.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx,
@@ -309,7 +309,7 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array {
 	out := New("DEQUANTIZE", w, scales, biases)
 	gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
 	b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
-	mode := C.CString("default")
+	mode := C.CString("affine")
 	defer C.free(unsafe.Pointer(mode))
 	noDtype := C.mlx_optional_dtype{has_value: C._Bool(false)}
 	C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, biases.ctx, gs, b, mode, noDtype, DefaultStream().ctx)
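
Note: the sketch below is not part of the patch; it only illustrates, in isolation, the head_dim inference that the gemma3.go hunk performs. It assumes, as the patch does, that the q_proj weight is laid out as [num_heads*head_dim, hidden_size]; the helper name and the concrete shape values are hypothetical Gemma 3-style numbers, not taken from the repository.

package main

import (
	"fmt"
	"math"
)

// inferHeadDim derives head_dim and the attention scale from the q_proj
// weight shape, mirroring the logic added in LoadGemma3: the first dimension
// of q_proj is num_heads*head_dim, so dividing by the head count recovers
// head_dim even when it differs from hidden_size/num_heads.
func inferHeadDim(qProjShape []int, numHeads int) (int, float32) {
	headDim := qProjShape[0] / numHeads
	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	return headDim, scale
}

func main() {
	// Hypothetical Gemma 3-style shapes: 8 query heads, q_proj of [2048, 2560].
	// hidden_size/num_heads would give 320, but the weight shape yields 256.
	headDim, scale := inferHeadDim([]int{2048, 2560}, 8)
	fmt.Println(headDim, scale) // 256 0.0625
}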