feat: support quantized inference (4-bit) for Gemma 3

- Add QuantizedLinear with QuantizedMatmul for packed uint32 weights
- Add quantized Embedding with Dequantize before lookup
- Parse quantization config (group_size, bits) from config.json
- Detect .scales/.biases weight tensors and auto-select quantized path
- Add Dequantize op wrapping mlx_dequantize
- Add safety guard to KVCache.Update for malformed shapes
- Handle tied embeddings with quantization (AsLinear helper)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Authored by Claude on 2026-02-16 02:12:31 +00:00, committed by Snider
parent 7fc1571f93
commit d92d097a7f
4 changed files with 146 additions and 25 deletions
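
For context on the packed uint32 layout the summary mentions: at 4 bits, eight quantization codes fit in each uint32, and every group_size consecutive weights share one scale and one bias (w ≈ scale*q + bias). A minimal plain-Go sketch of that affine scheme, assuming low-nibble-first packing; the real path is the fused mlx_dequantize kernel added below:

package main

import "fmt"

// dequantizeGroup illustrates affine 4-bit dequantization: each uint32
// packs eight 4-bit codes, and one (scale, bias) pair covers the group.
func dequantizeGroup(packed []uint32, scale, bias float32) []float32 {
	out := make([]float32, 0, len(packed)*8)
	for _, word := range packed {
		for shift := uint(0); shift < 32; shift += 4 {
			q := float32((word >> shift) & 0xF) // 4-bit code in [0, 15]
			out = append(out, scale*q+bias)
		}
	}
	return out
}

func main() {
	// One packed word holding codes 0..7, lowest nibble first.
	fmt.Println(dequantizeGroup([]uint32{0x76543210}, 0.5, -1.0))
	// [-1 -0.5 0 0.5 1 1.5 2 2.5]
}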


@@ -35,6 +35,14 @@ func NewKVCache() *KVCache {
 func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
 	prev := c.offset
 	shape := k.Shape()
+	if len(shape) < 4 {
+		// K/V must be [B, H, L, D] — if not, pass through unchanged
+		if c.keys == nil {
+			c.keys, c.values = k, v
+		}
+		c.offset += seqLen
+		return c.keys, c.values
+	}
 	B, H, Dk := shape[0], shape[1], shape[3]
 	Dv := v.Shape()[3]

@@ -103,6 +111,13 @@ func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
 func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
 	shape := k.Shape()
+	if len(shape) < 4 {
+		if c.keys == nil {
+			c.keys, c.values = k, v
+		}
+		c.offset++
+		return c.keys, c.values
+	}
 	B, H, Dk := shape[0], shape[1], shape[3]
 	Dv := v.Shape()[3]

@@ -139,6 +154,14 @@ func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
 func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
 	shape := k.Shape()
+	if len(shape) < 4 {
+		// K/V must be [B, H, L, D] — if not, pass through unchanged
+		if c.keys == nil {
+			c.keys, c.values = k, v
+		}
+		c.offset += seqLen
+		return c.keys, c.values
+	}
 	B, H, Dk := shape[0], shape[1], shape[3]
 	Dv := v.Shape()[3]
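
The three guards above share one contract: K and V must be rank-4 [B, H, L, D], and anything shorter takes a pass-through branch instead of panicking on shape[3]. A standalone sketch of that branch selection, with hypothetical shapes and no MLX calls:

package main

import "fmt"

func main() {
	// Hypothetical shapes; shows which path Update would take.
	for _, shape := range [][]int{
		{1, 8, 10, 256}, // [B, H, L, D]: normal cache update
		{1, 10, 256},    // rank 3: pass-through, no shape[3] panic
	} {
		if len(shape) < 4 {
			fmt.Println("pass-through, shape =", shape)
			continue
		}
		fmt.Println("cached: B, H, D =", shape[0], shape[1], shape[3])
	}
}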


@@ -6,6 +6,7 @@ package model
 import (
 	"encoding/json"
 	"fmt"
+	"log/slog"
 	"math"
 	"os"
 	"path/filepath"

@@ -15,6 +16,12 @@ import (
 	"forge.lthn.ai/core/cli/pkg/mlx/tokenizer"
 )

+// QuantizationConfig holds quantization parameters from config.json.
+type QuantizationConfig struct {
+	GroupSize int `json:"group_size"`
+	Bits      int `json:"bits"`
+}
+
 // TextConfig holds Gemma 3 text model configuration.
 type TextConfig struct {
 	HiddenSize int32 `json:"hidden_size"`

@@ -31,7 +38,8 @@ type TextConfig struct {
 	SlidingWindow        int32 `json:"sliding_window"`
 	SlidingWindowPattern int32 `json:"sliding_window_pattern"`
-	Scale float32 `json:"-"` // Computed: 1/sqrt(head_dim)
+	Quantization *QuantizationConfig `json:"-"` // Parsed separately from top-level
+	Scale        float32             `json:"-"` // Computed: 1/sqrt(head_dim)
 }

 // GemmaModel is the Gemma 3 text model.

@@ -117,8 +125,9 @@ func geluApprox(x *mlx.Array) *mlx.Array {
 func parseConfig(data []byte) (*TextConfig, error) {
 	// Try parsing text_config from multimodal wrapper
 	var wrapper struct {
-		TextConfig TextConfig `json:"text_config"`
-		ModelType  string     `json:"model_type"`
+		TextConfig   TextConfig          `json:"text_config"`
+		ModelType    string              `json:"model_type"`
+		Quantization *QuantizationConfig `json:"quantization"`
 	}
 	if err := json.Unmarshal(data, &wrapper); err != nil {
 		return nil, err

@@ -133,6 +142,9 @@ func parseConfig(data []byte) (*TextConfig, error) {
 		}
 	}

+	// Quantization is always top-level
+	cfg.Quantization = wrapper.Quantization
+
 	// Compute defaults
 	if cfg.HeadDim == 0 && cfg.NumAttentionHeads > 0 {
 		cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
@@ -195,17 +207,41 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
 		}
 	}

+	// Helper to resolve weight with language_model. prefix fallback
+	w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
+
+	// Helper to create linear layer (quantized or dense)
+	q := cfg.Quantization
+	if q != nil {
+		slog.Info("mlx: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
+	}
+	linear := func(prefix string) *mlx.Linear {
+		weight := w(prefix + ".weight")
+		scales := w(prefix + ".scales")
+		biases := w(prefix + ".biases")
+		if scales != nil && q != nil {
+			return mlx.NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits)
+		}
+		return mlx.NewLinear(weight, nil)
+	}
+
+	// Create embedding (quantized or dense)
+	embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
+	if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
+		embed.Scales = embedScales
+		embed.Biases = w("model.embed_tokens.biases")
+		embed.GroupSize = q.GroupSize
+		embed.Bits = q.Bits
+	}
+
 	m := &GemmaModel{
-		EmbedTokens: &mlx.Embedding{Weight: resolveWeight(weights, "model.embed_tokens.weight")},
+		EmbedTokens: embed,
 		Layers:      make([]*DecoderLayer, cfg.NumHiddenLayers),
-		Norm:        &mlx.RMSNormModule{Weight: resolveWeight(weights, "model.norm.weight")},
+		Norm:        &mlx.RMSNormModule{Weight: w("model.norm.weight")},
 		Tok:         tok,
 		Cfg:         cfg,
 	}

-	// Helper to resolve weight with language_model. prefix fallback
-	w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
-
 	// Initialize layers
 	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
 		prefix := fmt.Sprintf("model.layers.%d", i)
@@ -215,29 +251,36 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
 			PreFFNorm:  &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")},
 			PostFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")},
 			Attention: &Attention{
-				QProj: mlx.NewLinear(w(prefix+".self_attn.q_proj.weight"), nil),
-				KProj: mlx.NewLinear(w(prefix+".self_attn.k_proj.weight"), nil),
-				VProj: mlx.NewLinear(w(prefix+".self_attn.v_proj.weight"), nil),
-				OProj: mlx.NewLinear(w(prefix+".self_attn.o_proj.weight"), nil),
+				QProj: linear(prefix + ".self_attn.q_proj"),
+				KProj: linear(prefix + ".self_attn.k_proj"),
+				VProj: linear(prefix + ".self_attn.v_proj"),
+				OProj: linear(prefix + ".self_attn.o_proj"),
 				QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")},
 				KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")},
 			},
 			MLP: &MLP{
-				GateProj: mlx.NewLinear(w(prefix+".mlp.gate_proj.weight"), nil),
-				UpProj:   mlx.NewLinear(w(prefix+".mlp.up_proj.weight"), nil),
-				DownProj: mlx.NewLinear(w(prefix+".mlp.down_proj.weight"), nil),
+				GateProj: linear(prefix + ".mlp.gate_proj"),
+				UpProj:   linear(prefix + ".mlp.up_proj"),
+				DownProj: linear(prefix + ".mlp.down_proj"),
 			},
 			LayerIdx:  i,
 			IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern),
 		}
 	}

-	// Tied embeddings — check for separate lm_head first
-	lmHead := w("lm_head.weight")
-	if lmHead == nil {
-		lmHead = m.EmbedTokens.Weight // tied
+	// Output head — check for separate lm_head first, else tie to embeddings
+	lmHeadWeight := w("lm_head.weight")
+	if lmHeadWeight != nil {
+		lmHeadScales := w("lm_head.scales")
+		if lmHeadScales != nil && q != nil {
+			m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
+		} else {
+			m.Output = mlx.NewLinear(lmHeadWeight, nil)
+		}
+	} else {
+		// Tied embeddings — reuse embed_tokens weights (with quantization if present)
+		m.Output = m.EmbedTokens.AsLinear()
 	}
-	m.Output = mlx.NewLinear(lmHead, nil)

 	// Materialize all weights
 	var allArrays []*mlx.Array
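
To make the parsing concrete, here is a hypothetical in-package snippet with a trimmed config.json; the numbers are illustrative, but the top-level position of the quantization block next to text_config is what the loader expects:

// Hypothetical snippet inside package model; field values are illustrative.
data := []byte(`{
	"model_type": "gemma3",
	"text_config": {"hidden_size": 2560, "num_attention_heads": 8},
	"quantization": {"group_size": 64, "bits": 4}
}`)
cfg, err := parseConfig(data)
if err != nil {
	panic(err)
}
fmt.Println(cfg.Quantization.Bits, cfg.Quantization.GroupSize) // 4 64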


@@ -3,19 +3,42 @@
 package mlx

 // Linear is a fully-connected layer: y = x @ W.T + bias.
+// For quantized models, set Scales/Biases/GroupSize/Bits to use QuantizedMatmul.
 type Linear struct {
-	Weight *Array `weight:"weight"`
-	Bias   *Array `weight:"bias"`
+	Weight    *Array `weight:"weight"`
+	Scales    *Array `weight:"scales"`
+	Biases    *Array `weight:"biases"`
+	Bias      *Array `weight:"bias"`
+	GroupSize int
+	Bits      int
 }

-// NewLinear creates a Linear layer with optional bias.
+// NewLinear creates a dense Linear layer with optional bias.
 func NewLinear(weight, bias *Array) *Linear {
 	return &Linear{Weight: weight, Bias: bias}
 }

+// NewQuantizedLinear creates a quantized Linear layer.
+func NewQuantizedLinear(weight, scales, biases, bias *Array, groupSize, bits int) *Linear {
+	return &Linear{
+		Weight:    weight,
+		Scales:    scales,
+		Biases:    biases,
+		Bias:      bias,
+		GroupSize: groupSize,
+		Bits:      bits,
+	}
+}
+
 // Forward computes the linear transformation.
+// Uses QuantizedMatmul when quantization parameters are present.
 func (l *Linear) Forward(x *Array) *Array {
-	out := Matmul(x, Transpose(l.Weight))
+	var out *Array
+	if l.Scales != nil {
+		out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits)
+	} else {
+		out = Matmul(x, Transpose(l.Weight))
+	}
 	if l.Bias != nil && l.Bias.Valid() {
 		out = Add(out, l.Bias)
 	}

@@ -23,15 +46,35 @@ func (l *Linear) Forward(x *Array) *Array {
 }

 // Embedding is a lookup table for token embeddings.
+// For quantized models, set Scales/Biases/GroupSize/Bits to dequantize before lookup.
 type Embedding struct {
-	Weight *Array `weight:"weight"`
+	Weight    *Array `weight:"weight"`
+	Scales    *Array `weight:"scales"`
+	Biases    *Array `weight:"biases"`
+	GroupSize int
+	Bits      int
 }

 // Forward looks up embeddings for the given token indices.
 func (e *Embedding) Forward(indices *Array) *Array {
+	if e.Scales != nil {
+		w := Dequantize(e.Weight, e.Scales, e.Biases, e.GroupSize, e.Bits)
+		return Take(w, indices, 0)
+	}
 	return Take(e.Weight, indices, 0)
 }

+// AsLinear returns a Linear layer using the embedding weights (for tied output).
+func (e *Embedding) AsLinear() *Linear {
+	return &Linear{
+		Weight:    e.Weight,
+		Scales:    e.Scales,
+		Biases:    e.Biases,
+		GroupSize: e.GroupSize,
+		Bits:      e.Bits,
+	}
+}
+
 // RMSNormModule is an RMS normalization layer wrapping the fused kernel.
 type RMSNormModule struct {
 	Weight *Array `weight:"weight"`
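
AsLinear is what makes tied embeddings work under quantization: the returned Linear aliases the embedding's packed weight, scales, and biases, so the output head routes through QuantizedMatmul without duplicating or dequantizing the tensor. A hedged sketch, where wq, sc, bi, and tokens stand in for arrays loaded elsewhere:

// Sketch only; wq, sc, bi and tokens are hypothetical placeholders.
var wq, sc, bi, tokens *mlx.Array

embed := &mlx.Embedding{Weight: wq, Scales: sc, Biases: bi, GroupSize: 64, Bits: 4}
head := embed.AsLinear()        // aliases the same quantized tensors
hidden := embed.Forward(tokens) // Dequantize + Take for the lookup
logits := head.Forward(hidden)  // QuantizedMatmul against the shared weights
_ = logits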


@@ -304,6 +304,18 @@ func Argpartition(a *Array, kth, axis int) *Array {
 	return out
 }

+// Dequantize restores a quantized array to full precision.
+func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array {
+	out := New("DEQUANTIZE", w, scales, biases)
+	gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
+	b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
+	mode := C.CString("default")
+	defer C.free(unsafe.Pointer(mode))
+	noDtype := C.mlx_optional_dtype{has_value: C._Bool(false)}
+	C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, biases.ctx, gs, b, mode, noDtype, DefaultStream().ctx)
+	return out
+}
+
 // PutAlongAxis places values into array at indices along axis.
 func PutAlongAxis(a, indices, values *Array, axis int) *Array {
 	out := New("PUT_ALONG_AXIS", a, indices, values)