feat: support quantized inference (4-bit) for Gemma 3

- Add QuantizedLinear with QuantizedMatmul for packed uint32 weights
- Add quantized Embedding with Dequantize before lookup
- Parse quantization config (group_size, bits) from config.json
- Detect .scales/.biases weight tensors and auto-select quantized path
- Add Dequantize op wrapping mlx_dequantize
- Add safety guard to KVCache.Update for malformed shapes
- Handle tied embeddings with quantization (AsLinear helper)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Authored by Claude on 2026-02-16 02:12:31 +00:00; committed by Snider
parent 7fc1571f93
commit d92d097a7f
4 changed files with 146 additions and 25 deletions


@@ -35,6 +35,14 @@ func NewKVCache() *KVCache {
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
prev := c.offset
shape := k.Shape()
if len(shape) < 4 {
// K/V must be [B, H, L, D] — if not, pass through unchanged
if c.keys == nil {
c.keys, c.values = k, v
}
c.offset += seqLen
return c.keys, c.values
}
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
@@ -103,6 +111,13 @@ func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
shape := k.Shape()
if len(shape) < 4 {
if c.keys == nil {
c.keys, c.values = k, v
}
c.offset++
return c.keys, c.values
}
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
@@ -139,6 +154,14 @@ func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
shape := k.Shape()
if len(shape) < 4 {
// K/V must be [B, H, L, D] — if not, pass through unchanged
if c.keys == nil {
c.keys, c.values = k, v
}
c.offset += seqLen
return c.keys, c.values
}
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
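
The guard makes both cache variants tolerant of K/V tensors that do not follow the expected [B, H, L, D] layout, instead of indexing shape[3] and panicking. A minimal sketch of the contract; mlx.Zeros is an assumed stand-in for whatever array constructor the binding provides and is not part of this change:

cache := NewKVCache()

// Well-formed K/V: [batch, heads, seqLen, headDim] is cached and grown as usual.
k := mlx.Zeros(1, 8, 16, 256) // assumed constructor
v := mlx.Zeros(1, 8, 16, 256)
keys, values := cache.Update(k, v, 16)

// Malformed K/V (rank < 4): the guard skips the [B, H, L, D] handling and
// returns the cached arrays unchanged, advancing only the offset.
bad := mlx.Zeros(16, 256)
keys, values = cache.Update(bad, bad, 16)
_, _ = keys, values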


@@ -6,6 +6,7 @@ package model
import (
"encoding/json"
"fmt"
"log/slog"
"math"
"os"
"path/filepath"
@@ -15,6 +16,12 @@ import (
"forge.lthn.ai/core/cli/pkg/mlx/tokenizer"
)
// QuantizationConfig holds quantization parameters from config.json.
type QuantizationConfig struct {
GroupSize int `json:"group_size"`
Bits int `json:"bits"`
}
// TextConfig holds Gemma 3 text model configuration.
type TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
@@ -31,6 +38,7 @@ type TextConfig struct {
SlidingWindow int32 `json:"sliding_window"`
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
Quantization *QuantizationConfig `json:"-"` // Parsed separately from top-level
Scale float32 `json:"-"` // Computed: 1/sqrt(head_dim)
}
@@ -119,6 +127,7 @@ func parseConfig(data []byte) (*TextConfig, error) {
var wrapper struct {
TextConfig TextConfig `json:"text_config"`
ModelType string `json:"model_type"`
Quantization *QuantizationConfig `json:"quantization"`
}
if err := json.Unmarshal(data, &wrapper); err != nil {
return nil, err
@@ -133,6 +142,9 @@ func parseConfig(data []byte) (*TextConfig, error) {
}
}
// Quantization is always top-level
cfg.Quantization = wrapper.Quantization
// Compute defaults
if cfg.HeadDim == 0 && cfg.NumAttentionHeads > 0 {
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
@@ -195,17 +207,41 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
}
}
// Helper to resolve weight with language_model. prefix fallback
w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
// Helper to create linear layer (quantized or dense)
q := cfg.Quantization
if q != nil {
slog.Info("mlx: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
}
linear := func(prefix string) *mlx.Linear {
weight := w(prefix + ".weight")
scales := w(prefix + ".scales")
biases := w(prefix + ".biases")
if scales != nil && q != nil {
return mlx.NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits)
}
return mlx.NewLinear(weight, nil)
}
// Create embedding (quantized or dense)
embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
embed.Scales = embedScales
embed.Biases = w("model.embed_tokens.biases")
embed.GroupSize = q.GroupSize
embed.Bits = q.Bits
}
m := &GemmaModel{
EmbedTokens: &mlx.Embedding{Weight: resolveWeight(weights, "model.embed_tokens.weight")},
EmbedTokens: embed,
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
Norm: &mlx.RMSNormModule{Weight: resolveWeight(weights, "model.norm.weight")},
Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")},
Tok: tok,
Cfg: cfg,
}
// Helper to resolve weight with language_model. prefix fallback
w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
// Initialize layers
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("model.layers.%d", i)
@@ -215,29 +251,36 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
PreFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")},
PostFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")},
Attention: &Attention{
QProj: mlx.NewLinear(w(prefix+".self_attn.q_proj.weight"), nil),
KProj: mlx.NewLinear(w(prefix+".self_attn.k_proj.weight"), nil),
VProj: mlx.NewLinear(w(prefix+".self_attn.v_proj.weight"), nil),
OProj: mlx.NewLinear(w(prefix+".self_attn.o_proj.weight"), nil),
QProj: linear(prefix + ".self_attn.q_proj"),
KProj: linear(prefix + ".self_attn.k_proj"),
VProj: linear(prefix + ".self_attn.v_proj"),
OProj: linear(prefix + ".self_attn.o_proj"),
QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")},
KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")},
},
MLP: &MLP{
GateProj: mlx.NewLinear(w(prefix+".mlp.gate_proj.weight"), nil),
UpProj: mlx.NewLinear(w(prefix+".mlp.up_proj.weight"), nil),
DownProj: mlx.NewLinear(w(prefix+".mlp.down_proj.weight"), nil),
GateProj: linear(prefix + ".mlp.gate_proj"),
UpProj: linear(prefix + ".mlp.up_proj"),
DownProj: linear(prefix + ".mlp.down_proj"),
},
LayerIdx: i,
IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern),
}
}
// Tied embeddings — check for separate lm_head first
lmHead := w("lm_head.weight")
if lmHead == nil {
lmHead = m.EmbedTokens.Weight // tied
// Output head — check for separate lm_head first, else tie to embeddings
lmHeadWeight := w("lm_head.weight")
if lmHeadWeight != nil {
lmHeadScales := w("lm_head.scales")
if lmHeadScales != nil && q != nil {
m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
} else {
m.Output = mlx.NewLinear(lmHeadWeight, nil)
}
} else {
// Tied embeddings — reuse embed_tokens weights (with quantization if present)
m.Output = m.EmbedTokens.AsLinear()
}
m.Output = mlx.NewLinear(lmHead, nil)
// Materialize all weights
var allArrays []*mlx.Array
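
For reference, the quantization block is read from the top level of config.json (alongside model_type), while the Gemma 3 text settings stay under text_config. A hedged sketch of how parseConfig picks it up; the literal values below (4 bits, group size 64, and the text_config numbers) are illustrative rather than taken from a specific checkpoint:

raw := []byte(`{
  "model_type": "gemma3",
  "quantization": {"group_size": 64, "bits": 4},
  "text_config": {"hidden_size": 2560, "num_attention_heads": 8, "head_dim": 256}
}`)

cfg, err := parseConfig(raw)
if err != nil {
	return nil, err
}
if cfg.Quantization != nil {
	// LoadGemma3's linear() helper now takes the quantized path for every
	// projection that has a matching ".scales" tensor in the weights.
	slog.Info("quantized config", "bits", cfg.Quantization.Bits, "group_size", cfg.Quantization.GroupSize)
}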


@@ -3,19 +3,42 @@
package mlx
// Linear is a fully-connected layer: y = x @ W.T + bias.
// For quantized models, set Scales/Biases/GroupSize/Bits to use QuantizedMatmul.
type Linear struct {
Weight *Array `weight:"weight"`
Scales *Array `weight:"scales"`
Biases *Array `weight:"biases"`
Bias *Array `weight:"bias"`
GroupSize int
Bits int
}
// NewLinear creates a Linear layer with optional bias.
// NewLinear creates a dense Linear layer with optional bias.
func NewLinear(weight, bias *Array) *Linear {
return &Linear{Weight: weight, Bias: bias}
}
// NewQuantizedLinear creates a quantized Linear layer.
func NewQuantizedLinear(weight, scales, biases, bias *Array, groupSize, bits int) *Linear {
return &Linear{
Weight: weight,
Scales: scales,
Biases: biases,
Bias: bias,
GroupSize: groupSize,
Bits: bits,
}
}
// Forward computes the linear transformation.
// Uses QuantizedMatmul when quantization parameters are present.
func (l *Linear) Forward(x *Array) *Array {
out := Matmul(x, Transpose(l.Weight))
var out *Array
if l.Scales != nil {
out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits)
} else {
out = Matmul(x, Transpose(l.Weight))
}
if l.Bias != nil && l.Bias.Valid() {
out = Add(out, l.Bias)
}
@@ -23,15 +46,35 @@ func (l *Linear) Forward(x *Array) *Array {
}
// Embedding is a lookup table for token embeddings.
// For quantized models, set Scales/Biases/GroupSize/Bits to dequantize before lookup.
type Embedding struct {
Weight *Array `weight:"weight"`
Scales *Array `weight:"scales"`
Biases *Array `weight:"biases"`
GroupSize int
Bits int
}
// Forward looks up embeddings for the given token indices.
func (e *Embedding) Forward(indices *Array) *Array {
if e.Scales != nil {
w := Dequantize(e.Weight, e.Scales, e.Biases, e.GroupSize, e.Bits)
return Take(w, indices, 0)
}
return Take(e.Weight, indices, 0)
}
// AsLinear returns a Linear layer using the embedding weights (for tied output).
func (e *Embedding) AsLinear() *Linear {
return &Linear{
Weight: e.Weight,
Scales: e.Scales,
Biases: e.Biases,
GroupSize: e.GroupSize,
Bits: e.Bits,
}
}
// RMSNormModule is an RMS normalization layer wrapping the fused kernel.
type RMSNormModule struct {
Weight *Array `weight:"weight"`
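
Taken together, callers only ever see the Linear type; the quantized path is chosen by whether Scales is set, so the attention and MLP code needs no changes. A short usage sketch (w, scales, biases, x, and hidden are placeholder arrays; 64 and 4 are illustrative group size and bit width):

// Dense path: y = x @ W.T (plus optional bias).
dense := NewLinear(w, nil)
y1 := dense.Forward(x)

// Quantized path: same call site, Forward dispatches to QuantizedMatmul.
quant := NewQuantizedLinear(w, scales, biases, nil, 64, 4)
y2 := quant.Forward(x)

// Tied output head: reuse the embedding's (possibly quantized) weights.
embed := &Embedding{Weight: w, Scales: scales, Biases: biases, GroupSize: 64, Bits: 4}
logits := embed.AsLinear().Forward(hidden)
_, _, _ = y1, y2, logits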


@@ -304,6 +304,18 @@ func Argpartition(a *Array, kth, axis int) *Array {
return out
}
// Dequantize restores a quantized array to full precision.
func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array {
out := New("DEQUANTIZE", w, scales, biases)
gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
mode := C.CString("default")
defer C.free(unsafe.Pointer(mode))
noDtype := C.mlx_optional_dtype{has_value: C._Bool(false)}
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, biases.ctx, gs, b, mode, noDtype, DefaultStream().ctx)
return out
}
// PutAlongAxis places values into array at indices along axis.
func PutAlongAxis(a, indices, values *Array, axis int) *Array {
out := New("PUT_ALONG_AXIS", a, indices, values)
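
A possible sanity check for the new op, sketched here rather than included in this change: dequantizing the packed weights and doing a dense matmul should agree with the fused QuantizedMatmul output up to quantization rounding. wq, scales, biases, and x are placeholder arrays; no comparison helper is shown because none appears in this diff.

// Restore full-precision weights from the packed uint32 representation.
wFull := Dequantize(wq, scales, biases, 64, 4)

// Reference result via a dense transposed matmul, matching Linear.Forward's
// dense branch, versus the fused quantized kernel (transpose = true).
ref := Matmul(x, Transpose(wFull))
fused := QuantizedMatmul(x, wq, scales, biases, true, 64, 4)
_, _ = ref, fused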