fix: use affine quantization mode and infer head_dim from weights
This commit is contained in:
parent
f416f7e529
commit
b1f76ce7db
2 changed files with 19 additions and 6 deletions
|
|
@ -148,11 +148,10 @@ func parseConfig(data []byte) (*TextConfig, error) {
|
||||||
// Quantization is always top-level
|
// Quantization is always top-level
|
||||||
cfg.Quantization = wrapper.Quantization
|
cfg.Quantization = wrapper.Quantization
|
||||||
|
|
||||||
// Compute defaults
|
// Compute scale (head_dim may be inferred later from weights if not in config)
|
||||||
if cfg.HeadDim == 0 && cfg.NumAttentionHeads > 0 {
|
if cfg.HeadDim > 0 {
|
||||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
|
||||||
}
|
|
||||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||||
|
}
|
||||||
if cfg.RopeTheta == 0 {
|
if cfg.RopeTheta == 0 {
|
||||||
cfg.RopeTheta = 1000000
|
cfg.RopeTheta = 1000000
|
||||||
}
|
}
|
||||||
|
|
@ -213,6 +212,20 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
||||||
// Helper to resolve weight with language_model. prefix fallback
|
// Helper to resolve weight with language_model. prefix fallback
|
||||||
w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
|
w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
|
||||||
|
|
||||||
|
// Infer head_dim from q_proj weight shape when not in config.
|
||||||
|
// Gemma 3 uses head_dim=256 which differs from hidden_size/num_heads.
|
||||||
|
if cfg.HeadDim == 0 {
|
||||||
|
qWeight := w("model.layers.0.self_attn.q_proj.weight")
|
||||||
|
if qWeight != nil {
|
||||||
|
qShape := qWeight.Shape()
|
||||||
|
if len(qShape) > 0 {
|
||||||
|
cfg.HeadDim = qShape[0] / cfg.NumAttentionHeads
|
||||||
|
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||||
|
slog.Info("mlx: inferred head_dim from q_proj weight", "head_dim", cfg.HeadDim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Helper to create linear layer (quantized or dense)
|
// Helper to create linear layer (quantized or dense)
|
||||||
q := cfg.Quantization
|
q := cfg.Quantization
|
||||||
if q != nil {
|
if q != nil {
|
||||||
|
|
|
||||||
|
|
@ -138,7 +138,7 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit
|
||||||
out := New("QMATMUL", x, w, scales, biases)
|
out := New("QMATMUL", x, w, scales, biases)
|
||||||
gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
|
gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
|
||||||
b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
|
b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
|
||||||
mode := C.CString("default")
|
mode := C.CString("affine")
|
||||||
defer C.free(unsafe.Pointer(mode))
|
defer C.free(unsafe.Pointer(mode))
|
||||||
C.mlx_quantized_matmul(
|
C.mlx_quantized_matmul(
|
||||||
&out.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx,
|
&out.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx,
|
||||||
|
|
@ -309,7 +309,7 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array {
|
||||||
out := New("DEQUANTIZE", w, scales, biases)
|
out := New("DEQUANTIZE", w, scales, biases)
|
||||||
gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
|
gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
|
||||||
b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
|
b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
|
||||||
mode := C.CString("default")
|
mode := C.CString("affine")
|
||||||
defer C.free(unsafe.Pointer(mode))
|
defer C.free(unsafe.Pointer(mode))
|
||||||
noDtype := C.mlx_optional_dtype{has_value: C._Bool(false)}
|
noDtype := C.mlx_optional_dtype{has_value: C._Bool(false)}
|
||||||
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, biases.ctx, gs, b, mode, noDtype, DefaultStream().ctx)
|
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, biases.ctx, gs, b, mode, noDtype, DefaultStream().ctx)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue