fix: use affine quantization mode and infer head_dim from weights

Claude 2026-02-16 02:19:33 +00:00
parent ef22946f35
commit 70c32135d0
2 changed files with 19 additions and 6 deletions
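
Gemma 3 sets head_dim explicitly (256), which differs from hidden_size/num_heads, so deriving head_dim from the hidden size goes wrong whenever the config omits it; the loader now infers it from the q_proj weight shape instead. The quantization mode passed to MLX also changes from "default" to "affine" in both QuantizedMatmul and Dequantize, since the quantized checkpoints ship per-group scales and biases, which is MLX's affine scheme.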


@@ -148,11 +148,10 @@ func parseConfig(data []byte) (*TextConfig, error) {
 	// Quantization is always top-level
 	cfg.Quantization = wrapper.Quantization
-	// Compute defaults
-	if cfg.HeadDim == 0 && cfg.NumAttentionHeads > 0 {
-		cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
+	// Compute scale (head_dim may be inferred later from weights if not in config)
+	if cfg.HeadDim > 0 {
+		cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
 	}
-	cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
 	if cfg.RopeTheta == 0 {
 		cfg.RopeTheta = 1000000
 	}
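
With head_dim known, the attention scale is 1/sqrt(head_dim): for Gemma 3's head_dim of 256 that is 1/sqrt(256) = 0.0625. The removed fallback would instead have derived head_dim = hidden_size/num_heads (e.g. 2560/8 = 320 on an illustrative 2560-wide, 8-head config) and computed 1/sqrt(320) ≈ 0.0559, a wrong scale.
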
@@ -213,6 +212,20 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
 	// Helper to resolve weight with language_model. prefix fallback
 	w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
+	// Infer head_dim from q_proj weight shape when not in config.
+	// Gemma 3 uses head_dim=256 which differs from hidden_size/num_heads.
+	if cfg.HeadDim == 0 {
+		qWeight := w("model.layers.0.self_attn.q_proj.weight")
+		if qWeight != nil {
+			qShape := qWeight.Shape()
+			if len(qShape) > 0 {
+				cfg.HeadDim = qShape[0] / cfg.NumAttentionHeads
+				cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
+				slog.Info("mlx: inferred head_dim from q_proj weight", "head_dim", cfg.HeadDim)
+			}
+		}
+	}
 	// Helper to create linear layer (quantized or dense)
 	q := cfg.Quantization
 	if q != nil {
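
To make the shape arithmetic above concrete, here is a minimal standalone Go sketch (hypothetical helper and illustrative shape values, not code from this repo): a q_proj weight stored as [num_heads*head_dim, hidden_size] gives head_dim as its first dimension divided by num_heads.

package main

import "fmt"

// inferHeadDim mirrors the loader logic: for a q_proj weight of shape
// [num_heads*head_dim, hidden_size], head_dim = shape[0] / num_heads.
func inferHeadDim(qProjShape []int, numHeads int) int {
	if len(qProjShape) == 0 || numHeads <= 0 {
		return 0
	}
	return qProjShape[0] / numHeads
}

func main() {
	// Illustrative: 8 heads with head_dim 256 pack into 2048 rows.
	fmt.Println(inferHeadDim([]int{2048, 2560}, 8)) // prints 256
}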


@@ -138,7 +138,7 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array {
 	out := New("QMATMUL", x, w, scales, biases)
 	gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
 	b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
-	mode := C.CString("default")
+	mode := C.CString("affine")
 	defer C.free(unsafe.Pointer(mode))
 	C.mlx_quantized_matmul(
 		&out.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx,
@@ -309,7 +309,7 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array {
 	out := New("DEQUANTIZE", w, scales, biases)
 	gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
 	b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
-	mode := C.CString("default")
+	mode := C.CString("affine")
 	defer C.free(unsafe.Pointer(mode))
 	noDtype := C.mlx_optional_dtype{has_value: C._Bool(false)}
 	C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, biases.ctx, gs, b, mode, noDtype, DefaultStream().ctx)
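
For context on the "default" to "affine" change: in MLX's affine scheme, each contiguous group of groupSize quantized values shares one scale and one bias, and dequantization is roughly w[i] ≈ scales[g]*q[i] + biases[g]. A minimal plain-Go sketch of that per-group arithmetic (illustrative values, not the cgo path above):

package main

import "fmt"

// dequantizeAffine applies w[i] = scales[g]*q[i] + biases[g], where g is the
// index of the groupSize-wide group containing element i.
func dequantizeAffine(q []int, scales, biases []float32, groupSize int) []float32 {
	out := make([]float32, len(q))
	for i, v := range q {
		g := i / groupSize
		out[i] = scales[g]*float32(v) + biases[g]
	}
	return out
}

func main() {
	// One group of four 4-bit codes with scale 0.5 and bias -1.
	fmt.Println(dequantizeAffine([]int{0, 3, 7, 15}, []float32{0.5}, []float32{-1}, 4))
	// prints [-1 0.5 2.5 6.5]
}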