feat: support quantized inference (4-bit) for Gemma 3

- Add QuantizedLinear with QuantizedMatmul for packed uint32 weights
- Add quantized Embedding with Dequantize before lookup
- Parse quantization config (group_size, bits) from config.json
- Detect .scales/.biases weight tensors and auto-select quantized path
- Add Dequantize op wrapping mlx_dequantize
- Add safety guard to KVCache.Update for malformed shapes
- Handle tied embeddings with quantization (AsLinear helper)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent 7fc1571f93
commit d92d097a7f
4 changed files with 146 additions and 25 deletions
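For orientation, a minimal sketch of what 4-bit group quantization stores: each uint32 word packs eight 4-bit codes, and every group_size consecutive values share one scale and one bias, so an element dequantizes as scale*q + bias. The packing order and the exact affine formula below are assumptions for illustration, not read from this repo or from MLX itself.

package main

import "fmt"

// dequantRow unpacks one row of 4-bit codes and applies per-group scale/bias.
// Assumes codes are packed low-nibble-first into each uint32; the real MLX
// layout may differ, so treat this as illustration only.
func dequantRow(packed []uint32, scales, biases []float32, groupSize, bits int) []float32 {
	perWord := 32 / bits        // 8 codes per uint32 at 4 bits
	mask := uint32(1<<bits - 1) // 0xF
	out := make([]float32, 0, len(packed)*perWord)
	for i, word := range packed {
		for j := 0; j < perWord; j++ {
			q := (word >> uint(j*bits)) & mask
			g := (i*perWord + j) / groupSize
			out = append(out, scales[g]*float32(q)+biases[g])
		}
	}
	return out
}

func main() {
	// One packed word = eight 4-bit codes 0..7, all in one group of size 8.
	fmt.Println(dequantRow([]uint32{0x76543210}, []float32{0.5}, []float32{-1}, 8, 4))
	// Output: [-1 -0.5 0 0.5 1 1.5 2 2.5]
}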
23  pkg/mlx/cache/cache.go  vendored

@@ -35,6 +35,14 @@ func NewKVCache() *KVCache {
 func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
 	prev := c.offset
 	shape := k.Shape()
+	if len(shape) < 4 {
+		// K/V must be [B, H, L, D] — if not, pass through unchanged
+		if c.keys == nil {
+			c.keys, c.values = k, v
+		}
+		c.offset += seqLen
+		return c.keys, c.values
+	}
 	B, H, Dk := shape[0], shape[1], shape[3]
 	Dv := v.Shape()[3]

@@ -103,6 +111,13 @@ func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {

 func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
 	shape := k.Shape()
+	if len(shape) < 4 {
+		if c.keys == nil {
+			c.keys, c.values = k, v
+		}
+		c.offset++
+		return c.keys, c.values
+	}
 	B, H, Dk := shape[0], shape[1], shape[3]
 	Dv := v.Shape()[3]

@@ -139,6 +154,14 @@ func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {

 func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
 	shape := k.Shape()
+	if len(shape) < 4 {
+		// K/V must be [B, H, L, D] — if not, pass through unchanged
+		if c.keys == nil {
+			c.keys, c.values = k, v
+		}
+		c.offset += seqLen
+		return c.keys, c.values
+	}
 	B, H, Dk := shape[0], shape[1], shape[3]
 	Dv := v.Shape()[3]
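The guard added in all three update paths only checks the rank before shape[3] is read; anything that is not [B, H, L, D] takes the pass-through branch. A minimal standalone sketch of that contract, using a hypothetical helper that is not part of the cache package:

package main

import "fmt"

// guardedHeadDim mirrors the spirit of the new check in KVCache.Update:
// only a rank-4 [B, H, L, D] shape may have shape[3] read; anything else
// takes the pass-through path. Hypothetical helper for illustration only.
func guardedHeadDim(shape []int32) (int32, bool) {
	if len(shape) < 4 {
		return 0, false // pass through unchanged; no cache growth logic runs
	}
	return shape[3], true
}

func main() {
	fmt.Println(guardedHeadDim([]int32{1, 8, 16}))      // 0 false (malformed, rank 3)
	fmt.Println(guardedHeadDim([]int32{1, 8, 16, 256})) // 256 true ([B, H, L, D])
}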
@@ -6,6 +6,7 @@ package model
 import (
 	"encoding/json"
 	"fmt"
+	"log/slog"
 	"math"
 	"os"
 	"path/filepath"

@@ -15,6 +16,12 @@ import (
 	"forge.lthn.ai/core/cli/pkg/mlx/tokenizer"
 )

+// QuantizationConfig holds quantization parameters from config.json.
+type QuantizationConfig struct {
+	GroupSize int `json:"group_size"`
+	Bits      int `json:"bits"`
+}
+
 // TextConfig holds Gemma 3 text model configuration.
 type TextConfig struct {
 	HiddenSize int32 `json:"hidden_size"`

@@ -31,6 +38,7 @@ type TextConfig struct {
 	SlidingWindow        int32 `json:"sliding_window"`
 	SlidingWindowPattern int32 `json:"sliding_window_pattern"`

+	Quantization *QuantizationConfig `json:"-"` // Parsed separately from top-level
 	Scale float32 `json:"-"` // Computed: 1/sqrt(head_dim)
 }

@@ -119,6 +127,7 @@ func parseConfig(data []byte) (*TextConfig, error) {
 	var wrapper struct {
 		TextConfig   TextConfig          `json:"text_config"`
 		ModelType    string              `json:"model_type"`
+		Quantization *QuantizationConfig `json:"quantization"`
 	}
 	if err := json.Unmarshal(data, &wrapper); err != nil {
 		return nil, err

@@ -133,6 +142,9 @@ func parseConfig(data []byte) (*TextConfig, error) {
 		}
 	}

+	// Quantization is always top-level
+	cfg.Quantization = wrapper.Quantization
+
 	// Compute defaults
 	if cfg.HeadDim == 0 && cfg.NumAttentionHeads > 0 {
 		cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads

@@ -195,17 +207,41 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
 		}
 	}

+	// Helper to resolve weight with language_model. prefix fallback
+	w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
+
+	// Helper to create linear layer (quantized or dense)
+	q := cfg.Quantization
+	if q != nil {
+		slog.Info("mlx: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
+	}
+	linear := func(prefix string) *mlx.Linear {
+		weight := w(prefix + ".weight")
+		scales := w(prefix + ".scales")
+		biases := w(prefix + ".biases")
+		if scales != nil && q != nil {
+			return mlx.NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits)
+		}
+		return mlx.NewLinear(weight, nil)
+	}
+
+	// Create embedding (quantized or dense)
+	embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
+	if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
+		embed.Scales = embedScales
+		embed.Biases = w("model.embed_tokens.biases")
+		embed.GroupSize = q.GroupSize
+		embed.Bits = q.Bits
+	}
+
 	m := &GemmaModel{
-		EmbedTokens: &mlx.Embedding{Weight: resolveWeight(weights, "model.embed_tokens.weight")},
+		EmbedTokens: embed,
 		Layers:      make([]*DecoderLayer, cfg.NumHiddenLayers),
-		Norm:        &mlx.RMSNormModule{Weight: resolveWeight(weights, "model.norm.weight")},
+		Norm:        &mlx.RMSNormModule{Weight: w("model.norm.weight")},
 		Tok:         tok,
 		Cfg:         cfg,
 	}

-	// Helper to resolve weight with language_model. prefix fallback
-	w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
-
 	// Initialize layers
 	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
 		prefix := fmt.Sprintf("model.layers.%d", i)

@@ -215,29 +251,36 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
 			PreFFNorm:  &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")},
 			PostFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")},
 			Attention: &Attention{
-				QProj: mlx.NewLinear(w(prefix+".self_attn.q_proj.weight"), nil),
-				KProj: mlx.NewLinear(w(prefix+".self_attn.k_proj.weight"), nil),
-				VProj: mlx.NewLinear(w(prefix+".self_attn.v_proj.weight"), nil),
-				OProj: mlx.NewLinear(w(prefix+".self_attn.o_proj.weight"), nil),
+				QProj: linear(prefix + ".self_attn.q_proj"),
+				KProj: linear(prefix + ".self_attn.k_proj"),
+				VProj: linear(prefix + ".self_attn.v_proj"),
+				OProj: linear(prefix + ".self_attn.o_proj"),
 				QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")},
 				KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")},
 			},
 			MLP: &MLP{
-				GateProj: mlx.NewLinear(w(prefix+".mlp.gate_proj.weight"), nil),
-				UpProj:   mlx.NewLinear(w(prefix+".mlp.up_proj.weight"), nil),
-				DownProj: mlx.NewLinear(w(prefix+".mlp.down_proj.weight"), nil),
+				GateProj: linear(prefix + ".mlp.gate_proj"),
+				UpProj:   linear(prefix + ".mlp.up_proj"),
+				DownProj: linear(prefix + ".mlp.down_proj"),
 			},
 			LayerIdx:  i,
 			IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern),
 		}
 	}

-	// Tied embeddings — check for separate lm_head first
-	lmHead := w("lm_head.weight")
-	if lmHead == nil {
-		lmHead = m.EmbedTokens.Weight // tied
+	// Output head — check for separate lm_head first, else tie to embeddings
+	lmHeadWeight := w("lm_head.weight")
+	if lmHeadWeight != nil {
+		lmHeadScales := w("lm_head.scales")
+		if lmHeadScales != nil && q != nil {
+			m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
+		} else {
+			m.Output = mlx.NewLinear(lmHeadWeight, nil)
+		}
+	} else {
+		// Tied embeddings — reuse embed_tokens weights (with quantization if present)
+		m.Output = m.EmbedTokens.AsLinear()
 	}
-	m.Output = mlx.NewLinear(lmHead, nil)

 	// Materialize all weights
 	var allArrays []*mlx.Array
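A standalone sketch of the config shape this parser now expects, with the quantization block at the top level next to text_config. The numeric values are illustrative, not from any specific checkpoint, and the local struct simply mirrors the new QuantizationConfig for a self-contained demo.

package main

import (
	"encoding/json"
	"fmt"
)

// Local mirror of QuantizationConfig for a self-contained demo.
type quantizationConfig struct {
	GroupSize int `json:"group_size"`
	Bits      int `json:"bits"`
}

func main() {
	// Illustrative config.json fragment: quantization sits at the top level,
	// beside text_config, which is why parseConfig reads it from the wrapper.
	data := []byte(`{
		"model_type": "gemma3",
		"quantization": {"group_size": 64, "bits": 4},
		"text_config": {"hidden_size": 2560}
	}`)

	var wrapper struct {
		ModelType    string              `json:"model_type"`
		Quantization *quantizationConfig `json:"quantization"`
	}
	if err := json.Unmarshal(data, &wrapper); err != nil {
		panic(err)
	}
	if q := wrapper.Quantization; q != nil {
		fmt.Printf("quantized model: %d-bit, group_size=%d\n", q.Bits, q.GroupSize)
	} else {
		fmt.Println("dense model: no quantization block")
	}
}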
@@ -3,19 +3,42 @@
 package mlx

 // Linear is a fully-connected layer: y = x @ W.T + bias.
+// For quantized models, set Scales/Biases/GroupSize/Bits to use QuantizedMatmul.
 type Linear struct {
 	Weight *Array `weight:"weight"`
+	Scales *Array `weight:"scales"`
+	Biases *Array `weight:"biases"`
 	Bias   *Array `weight:"bias"`
+	GroupSize int
+	Bits      int
 }

-// NewLinear creates a Linear layer with optional bias.
+// NewLinear creates a dense Linear layer with optional bias.
 func NewLinear(weight, bias *Array) *Linear {
 	return &Linear{Weight: weight, Bias: bias}
 }

+// NewQuantizedLinear creates a quantized Linear layer.
+func NewQuantizedLinear(weight, scales, biases, bias *Array, groupSize, bits int) *Linear {
+	return &Linear{
+		Weight:    weight,
+		Scales:    scales,
+		Biases:    biases,
+		Bias:      bias,
+		GroupSize: groupSize,
+		Bits:      bits,
+	}
+}
+
 // Forward computes the linear transformation.
+// Uses QuantizedMatmul when quantization parameters are present.
 func (l *Linear) Forward(x *Array) *Array {
-	out := Matmul(x, Transpose(l.Weight))
+	var out *Array
+	if l.Scales != nil {
+		out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits)
+	} else {
+		out = Matmul(x, Transpose(l.Weight))
+	}
 	if l.Bias != nil && l.Bias.Valid() {
 		out = Add(out, l.Bias)
 	}

@@ -23,15 +46,35 @@ func (l *Linear) Forward(x *Array) *Array {
 }

 // Embedding is a lookup table for token embeddings.
+// For quantized models, set Scales/Biases/GroupSize/Bits to dequantize before lookup.
 type Embedding struct {
 	Weight *Array `weight:"weight"`
+	Scales *Array `weight:"scales"`
+	Biases *Array `weight:"biases"`
+	GroupSize int
+	Bits      int
 }

 // Forward looks up embeddings for the given token indices.
 func (e *Embedding) Forward(indices *Array) *Array {
+	if e.Scales != nil {
+		w := Dequantize(e.Weight, e.Scales, e.Biases, e.GroupSize, e.Bits)
+		return Take(w, indices, 0)
+	}
 	return Take(e.Weight, indices, 0)
 }

+// AsLinear returns a Linear layer using the embedding weights (for tied output).
+func (e *Embedding) AsLinear() *Linear {
+	return &Linear{
+		Weight:    e.Weight,
+		Scales:    e.Scales,
+		Biases:    e.Biases,
+		GroupSize: e.GroupSize,
+		Bits:      e.Bits,
+	}
+}
+
 // RMSNormModule is an RMS normalization layer wrapping the fused kernel.
 type RMSNormModule struct {
 	Weight *Array `weight:"weight"`
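Why the loader keys off the .scales tensors: a quantized checkpoint stores each projection as a packed uint32 weight plus per-group scales and biases, so the plain Matmul(x, Transpose(W)) path cannot be used on it. A rough sketch of the expected tensor shapes, assuming 32/bits values per packed word and one scale per group_size input features (assumed layout, not read from this repo):

package main

import "fmt"

// quantizedShapes estimates the tensor shapes for one quantized Linear,
// assuming values are packed 32/bits to a uint32 and scales/biases carry one
// entry per group_size input features. Illustrative helper only.
func quantizedShapes(outF, inF, groupSize, bits int) (weight, scales [2]int) {
	perWord := 32 / bits
	weight = [2]int{outF, inF / perWord}   // packed uint32 words
	scales = [2]int{outF, inF / groupSize} // biases share this shape
	return
}

func main() {
	w, s := quantizedShapes(2560, 2560, 64, 4)
	fmt.Println("weight (packed):", w) // [2560 320]
	fmt.Println("scales/biases:  ", s) // [2560 40]
}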
@@ -304,6 +304,18 @@ func Argpartition(a *Array, kth, axis int) *Array {
 	return out
 }

+// Dequantize restores a quantized array to full precision.
+func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array {
+	out := New("DEQUANTIZE", w, scales, biases)
+	gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
+	b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
+	mode := C.CString("default")
+	defer C.free(unsafe.Pointer(mode))
+	noDtype := C.mlx_optional_dtype{has_value: C._Bool(false)}
+	C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, biases.ctx, gs, b, mode, noDtype, DefaultStream().ctx)
+	return out
+}
+
 // PutAlongAxis places values into array at indices along axis.
 func PutAlongAxis(a, indices, values *Array, axis int) *Array {
 	out := New("PUT_ALONG_AXIS", a, indices, values)
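For a sense of scale, the quantized Embedding.Forward path above runs Dequantize over the whole table before Take. A back-of-envelope sketch of the sizes involved; the vocab and hidden sizes are illustrative, not read from any checkpoint:

package main

import "fmt"

func main() {
	// Illustrative sizes only; a real checkpoint's vocab/hidden may differ.
	vocab, hidden := 262144, 2560
	groupSize, bits := 64, 4
	params := vocab * hidden

	packed := params * bits / 8               // packed 4-bit payload, bytes
	scaleBias := 2 * (params / groupSize) * 2 // fp16 scales + biases, bytes
	dequantized := params * 2                 // fp16 table produced by Dequantize, bytes

	fmt.Printf("stored (packed + scales/biases): %.0f MiB\n", float64(packed+scaleBias)/(1<<20))
	fmt.Printf("dequantized fp16 table:          %.0f MiB\n", float64(dequantized)/(1<<20))
}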