From c8e66918c3fcc4fb1c2eeb0c993a1a6e096c4082 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 16 Feb 2026 02:12:31 +0000 Subject: [PATCH] feat: support quantized inference (4-bit) for Gemma 3 - Add QuantizedLinear with QuantizedMatmul for packed uint32 weights - Add quantized Embedding with Dequantize before lookup - Parse quantization config (group_size, bits) from config.json - Detect .scales/.biases weight tensors and auto-select quantized path - Add Dequantize op wrapping mlx_dequantize - Add safety guard to KVCache.Update for malformed shapes - Handle tied embeddings with quantization (AsLinear helper) Co-Authored-By: Claude Opus 4.6 --- pkg/mlx/cache/cache.go | 23 ++++++++++++ pkg/mlx/model/gemma3.go | 83 +++++++++++++++++++++++++++++++---------- pkg/mlx/nn.go | 53 +++++++++++++++++++++++--- pkg/mlx/ops.go | 12 ++++++ 4 files changed, 146 insertions(+), 25 deletions(-) diff --git a/pkg/mlx/cache/cache.go b/pkg/mlx/cache/cache.go index c3e8f92..6c82785 100644 --- a/pkg/mlx/cache/cache.go +++ b/pkg/mlx/cache/cache.go @@ -35,6 +35,14 @@ func NewKVCache() *KVCache { func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { prev := c.offset shape := k.Shape() + if len(shape) < 4 { + // K/V must be [B, H, L, D] — if not, pass through unchanged + if c.keys == nil { + c.keys, c.values = k, v + } + c.offset += seqLen + return c.keys, c.values + } B, H, Dk := shape[0], shape[1], shape[3] Dv := v.Shape()[3] @@ -103,6 +111,13 @@ func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx. func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) { shape := k.Shape() + if len(shape) < 4 { + if c.keys == nil { + c.keys, c.values = k, v + } + c.offset++ + return c.keys, c.values + } B, H, Dk := shape[0], shape[1], shape[3] Dv := v.Shape()[3] @@ -139,6 +154,14 @@ func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { shape := k.Shape() + if len(shape) < 4 { + // K/V must be [B, H, L, D] — if not, pass through unchanged + if c.keys == nil { + c.keys, c.values = k, v + } + c.offset += seqLen + return c.keys, c.values + } B, H, Dk := shape[0], shape[1], shape[3] Dv := v.Shape()[3] diff --git a/pkg/mlx/model/gemma3.go b/pkg/mlx/model/gemma3.go index 3fc87f6..4892218 100644 --- a/pkg/mlx/model/gemma3.go +++ b/pkg/mlx/model/gemma3.go @@ -6,6 +6,7 @@ package model import ( "encoding/json" "fmt" + "log/slog" "math" "os" "path/filepath" @@ -15,6 +16,12 @@ import ( "forge.lthn.ai/core/cli/pkg/mlx/tokenizer" ) +// QuantizationConfig holds quantization parameters from config.json. +type QuantizationConfig struct { + GroupSize int `json:"group_size"` + Bits int `json:"bits"` +} + // TextConfig holds Gemma 3 text model configuration. type TextConfig struct { HiddenSize int32 `json:"hidden_size"` @@ -31,7 +38,8 @@ type TextConfig struct { SlidingWindow int32 `json:"sliding_window"` SlidingWindowPattern int32 `json:"sliding_window_pattern"` - Scale float32 `json:"-"` // Computed: 1/sqrt(head_dim) + Quantization *QuantizationConfig `json:"-"` // Parsed separately from top-level + Scale float32 `json:"-"` // Computed: 1/sqrt(head_dim) } // GemmaModel is the Gemma 3 text model. 
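Note on the packed layout behind `group_size`/`bits` (a sketch, not part of the patch): with MLX-style affine quantization, each group of group_size values along a row shares one scale/bias pair, and 4-bit values are packed eight to a uint32 along the last axis. The dimensions below are illustrative only:

package main

import "fmt"

func main() {
	const (
		outFeatures = 2560 // illustrative dimensions
		inFeatures  = 2560
		bits        = 4
		groupSize   = 64
	)
	perU32 := 32 / bits // 8 four-bit values per packed uint32
	fmt.Println("weight:", outFeatures, "x", inFeatures/perU32, "(uint32)")
	fmt.Println("scales:", outFeatures, "x", inFeatures/groupSize)
	fmt.Println("biases:", outFeatures, "x", inFeatures/groupSize)

	// Dequantizing one packed word: w[i] ≈ scale*q[i] + bias for each
	// 4-bit value q[i]. Low-bits-first extraction is a layout assumption.
	var word uint32 = 0x76543210
	scale, bias := float32(0.02), float32(-0.1)
	for i := 0; i < perU32; i++ {
		q := (word >> (uint(i) * 4)) & 0xF
		fmt.Printf("q=%d  w≈%.3f\n", q, scale*float32(q)+bias)
	}
}

This is why a 4-bit checkpoint carries `.scales` and `.biases` tensors next to each packed `.weight`, which is what the loader changes below key off.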
@@ -117,8 +125,9 @@ func geluApprox(x *mlx.Array) *mlx.Array { func parseConfig(data []byte) (*TextConfig, error) { // Try parsing text_config from multimodal wrapper var wrapper struct { - TextConfig TextConfig `json:"text_config"` - ModelType string `json:"model_type"` + TextConfig TextConfig `json:"text_config"` + ModelType string `json:"model_type"` + Quantization *QuantizationConfig `json:"quantization"` } if err := json.Unmarshal(data, &wrapper); err != nil { return nil, err @@ -133,6 +142,9 @@ func parseConfig(data []byte) (*TextConfig, error) { } } + // Quantization is always top-level + cfg.Quantization = wrapper.Quantization + // Compute defaults if cfg.HeadDim == 0 && cfg.NumAttentionHeads > 0 { cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads @@ -195,17 +207,41 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) { } } + // Helper to resolve weight with language_model. prefix fallback + w := func(name string) *mlx.Array { return resolveWeight(weights, name) } + + // Helper to create linear layer (quantized or dense) + q := cfg.Quantization + if q != nil { + slog.Info("mlx: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize) + } + linear := func(prefix string) *mlx.Linear { + weight := w(prefix + ".weight") + scales := w(prefix + ".scales") + biases := w(prefix + ".biases") + if scales != nil && q != nil { + return mlx.NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits) + } + return mlx.NewLinear(weight, nil) + } + + // Create embedding (quantized or dense) + embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")} + if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil { + embed.Scales = embedScales + embed.Biases = w("model.embed_tokens.biases") + embed.GroupSize = q.GroupSize + embed.Bits = q.Bits + } + m := &GemmaModel{ - EmbedTokens: &mlx.Embedding{Weight: resolveWeight(weights, "model.embed_tokens.weight")}, + EmbedTokens: embed, Layers: make([]*DecoderLayer, cfg.NumHiddenLayers), - Norm: &mlx.RMSNormModule{Weight: resolveWeight(weights, "model.norm.weight")}, + Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")}, Tok: tok, Cfg: cfg, } - // Helper to resolve weight with language_model. 
prefix fallback - w := func(name string) *mlx.Array { return resolveWeight(weights, name) } - // Initialize layers for i := int32(0); i < cfg.NumHiddenLayers; i++ { prefix := fmt.Sprintf("model.layers.%d", i) @@ -215,29 +251,36 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) { PreFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")}, PostFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")}, Attention: &Attention{ - QProj: mlx.NewLinear(w(prefix+".self_attn.q_proj.weight"), nil), - KProj: mlx.NewLinear(w(prefix+".self_attn.k_proj.weight"), nil), - VProj: mlx.NewLinear(w(prefix+".self_attn.v_proj.weight"), nil), - OProj: mlx.NewLinear(w(prefix+".self_attn.o_proj.weight"), nil), + QProj: linear(prefix + ".self_attn.q_proj"), + KProj: linear(prefix + ".self_attn.k_proj"), + VProj: linear(prefix + ".self_attn.v_proj"), + OProj: linear(prefix + ".self_attn.o_proj"), QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")}, KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")}, }, MLP: &MLP{ - GateProj: mlx.NewLinear(w(prefix+".mlp.gate_proj.weight"), nil), - UpProj: mlx.NewLinear(w(prefix+".mlp.up_proj.weight"), nil), - DownProj: mlx.NewLinear(w(prefix+".mlp.down_proj.weight"), nil), + GateProj: linear(prefix + ".mlp.gate_proj"), + UpProj: linear(prefix + ".mlp.up_proj"), + DownProj: linear(prefix + ".mlp.down_proj"), }, LayerIdx: i, IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern), } } - // Tied embeddings — check for separate lm_head first - lmHead := w("lm_head.weight") - if lmHead == nil { - lmHead = m.EmbedTokens.Weight // tied + // Output head — check for separate lm_head first, else tie to embeddings + lmHeadWeight := w("lm_head.weight") + if lmHeadWeight != nil { + lmHeadScales := w("lm_head.scales") + if lmHeadScales != nil && q != nil { + m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits) + } else { + m.Output = mlx.NewLinear(lmHeadWeight, nil) + } + } else { + // Tied embeddings — reuse embed_tokens weights (with quantization if present) + m.Output = m.EmbedTokens.AsLinear() } - m.Output = mlx.NewLinear(lmHead, nil) // Materialize all weights var allArrays []*mlx.Array diff --git a/pkg/mlx/nn.go b/pkg/mlx/nn.go index e1dcb4d..f06aada 100644 --- a/pkg/mlx/nn.go +++ b/pkg/mlx/nn.go @@ -3,19 +3,42 @@ package mlx // Linear is a fully-connected layer: y = x @ W.T + bias. +// For quantized models, set Scales/Biases/GroupSize/Bits to use QuantizedMatmul. type Linear struct { - Weight *Array `weight:"weight"` - Bias *Array `weight:"bias"` + Weight *Array `weight:"weight"` + Scales *Array `weight:"scales"` + Biases *Array `weight:"biases"` + Bias *Array `weight:"bias"` + GroupSize int + Bits int } -// NewLinear creates a Linear layer with optional bias. +// NewLinear creates a dense Linear layer with optional bias. func NewLinear(weight, bias *Array) *Linear { return &Linear{Weight: weight, Bias: bias} } +// NewQuantizedLinear creates a quantized Linear layer. +func NewQuantizedLinear(weight, scales, biases, bias *Array, groupSize, bits int) *Linear { + return &Linear{ + Weight: weight, + Scales: scales, + Biases: biases, + Bias: bias, + GroupSize: groupSize, + Bits: bits, + } +} + // Forward computes the linear transformation. +// Uses QuantizedMatmul when quantization parameters are present. 
func (l *Linear) Forward(x *Array) *Array { - out := Matmul(x, Transpose(l.Weight)) + var out *Array + if l.Scales != nil { + out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits) + } else { + out = Matmul(x, Transpose(l.Weight)) + } if l.Bias != nil && l.Bias.Valid() { out = Add(out, l.Bias) } @@ -23,15 +46,35 @@ func (l *Linear) Forward(x *Array) *Array { } // Embedding is a lookup table for token embeddings. +// For quantized models, set Scales/Biases/GroupSize/Bits to dequantize before lookup. type Embedding struct { - Weight *Array `weight:"weight"` + Weight *Array `weight:"weight"` + Scales *Array `weight:"scales"` + Biases *Array `weight:"biases"` + GroupSize int + Bits int } // Forward looks up embeddings for the given token indices. func (e *Embedding) Forward(indices *Array) *Array { + if e.Scales != nil { + w := Dequantize(e.Weight, e.Scales, e.Biases, e.GroupSize, e.Bits) + return Take(w, indices, 0) + } return Take(e.Weight, indices, 0) } +// AsLinear returns a Linear layer using the embedding weights (for tied output). +func (e *Embedding) AsLinear() *Linear { + return &Linear{ + Weight: e.Weight, + Scales: e.Scales, + Biases: e.Biases, + GroupSize: e.GroupSize, + Bits: e.Bits, + } +} + // RMSNormModule is an RMS normalization layer wrapping the fused kernel. type RMSNormModule struct { Weight *Array `weight:"weight"` diff --git a/pkg/mlx/ops.go b/pkg/mlx/ops.go index c9ba959..70f7efd 100644 --- a/pkg/mlx/ops.go +++ b/pkg/mlx/ops.go @@ -304,6 +304,18 @@ func Argpartition(a *Array, kth, axis int) *Array { return out } +// Dequantize restores a quantized array to full precision. +func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array { + out := New("DEQUANTIZE", w, scales, biases) + gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)} + b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)} + mode := C.CString("default") + defer C.free(unsafe.Pointer(mode)) + noDtype := C.mlx_optional_dtype{has_value: C._Bool(false)} + C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, biases.ctx, gs, b, mode, noDtype, DefaultStream().ctx) + return out +} + // PutAlongAxis places values into array at indices along axis. func PutAlongAxis(a, indices, values *Array, axis int) *Array { out := New("PUT_ALONG_AXIS", a, indices, values)
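A cheap way to sanity-check the two new ops together (a sketch using only functions visible in this patch; building the input arrays and picking an error tolerance are left to the caller, and the helper name is hypothetical):

package quantcheck

import mlx "forge.lthn.ai/core/cli/pkg/mlx"

// denseVsQuantized returns both outputs of a linear layer so the caller
// can compare them: they should agree up to quantization error.
func denseVsQuantized(x, w, scales, biases *mlx.Array, groupSize, bits int) (quant, dense *mlx.Array) {
	// Quantized path, exactly as Linear.Forward dispatches it: y = x @ W.T
	// computed directly against the packed weights.
	quant = mlx.QuantizedMatmul(x, w, scales, biases, true, groupSize, bits)
	// Dense reference: restore W to full precision, then plain matmul.
	dense = mlx.Matmul(x, mlx.Transpose(mlx.Dequantize(w, scales, biases, groupSize, bits)))
	return quant, dense
}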
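Design note on the quantized Embedding.Forward above: dequantizing the whole table before the lookup is the simplest correct approach, but it materializes a full-precision copy of the vocabulary on every call. An alternative worth considering (a sketch, not in this patch, same package/imports as the sketch above; it assumes groups run along the last axis, so gathering whole rows keeps scale/bias alignment):

// lookupRowsThenDequantize gathers the packed rows for the requested
// tokens first, then dequantizes only those rows.
func lookupRowsThenDequantize(e *mlx.Embedding, indices *mlx.Array) *mlx.Array {
	rows := mlx.Take(e.Weight, indices, 0)   // packed uint32 rows
	scales := mlx.Take(e.Scales, indices, 0) // per-group scales for those rows
	biases := mlx.Take(e.Biases, indices, 0)
	return mlx.Dequantize(rows, scales, biases, e.GroupSize, e.Bits)
}

MLX's own nn.QuantizedEmbedding takes this gather-then-dequantize route, so the two variants should agree numerically while the sketch avoids the full-table copy.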