feat/updates #1
4 changed files with 146 additions and 25 deletions
pkg/mlx/cache/cache.go (vendored) +23
@@ -35,6 +35,14 @@ func NewKVCache() *KVCache {

 func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
     prev := c.offset
     shape := k.Shape()
+    if len(shape) < 4 {
+        // K/V must be [B, H, L, D] — if not, pass through unchanged
+        if c.keys == nil {
+            c.keys, c.values = k, v
+        }
+        c.offset += seqLen
+        return c.keys, c.values
+    }
     B, H, Dk := shape[0], shape[1], shape[3]
     Dv := v.Shape()[3]
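The guard added here (and repeated in the RotatingKVCache paths below) keeps callers working when a layer hands the cache something that is not a 4-D [B, H, L, D] tensor: the first K/V pair seen is stored and returned unchanged while the offset still advances. A minimal sketch of that contract, using plain slices as stand-ins for *mlx.Array (toyCache and its fields are illustrative only, not part of the change):

    package main

    import "fmt"

    // toyCache mimics the pass-through guard added to KVCache.Update:
    // a K/V pair that is not rank-4 is stored once and returned unchanged.
    type toyCache struct {
        keys, values []float32 // stand-in for *mlx.Array
        offset       int
    }

    func (c *toyCache) update(shape []int, k, v []float32, seqLen int) ([]float32, []float32) {
        if len(shape) < 4 {
            // Not [B, H, L, D]: keep the first K/V seen and advance the offset.
            if c.keys == nil {
                c.keys, c.values = k, v
            }
            c.offset += seqLen
            return c.keys, c.values
        }
        // The rank-4 path (writing into the preallocated cache) is omitted here.
        c.offset += seqLen
        return k, v
    }

    func main() {
        c := &toyCache{}
        k, v := c.update([]int{2, 8}, []float32{1, 2}, []float32{3, 4}, 1)
        fmt.Println(k, v, c.offset) // [1 2] [3 4] 1
    }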
@@ -103,6 +111,13 @@ func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
 }

 func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
     shape := k.Shape()
+    if len(shape) < 4 {
+        if c.keys == nil {
+            c.keys, c.values = k, v
+        }
+        c.offset++
+        return c.keys, c.values
+    }
     B, H, Dk := shape[0], shape[1], shape[3]
     Dv := v.Shape()[3]
@@ -139,6 +154,14 @@ func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
 }

 func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
     shape := k.Shape()
+    if len(shape) < 4 {
+        // K/V must be [B, H, L, D] — if not, pass through unchanged
+        if c.keys == nil {
+            c.keys, c.values = k, v
+        }
+        c.offset += seqLen
+        return c.keys, c.values
+    }
     B, H, Dk := shape[0], shape[1], shape[3]
     Dv := v.Shape()[3]
@@ -6,6 +6,7 @@ package model
 import (
     "encoding/json"
     "fmt"
+    "log/slog"
     "math"
     "os"
     "path/filepath"
@@ -15,6 +16,12 @@ import (
     "forge.lthn.ai/core/cli/pkg/mlx/tokenizer"
 )

+// QuantizationConfig holds quantization parameters from config.json.
+type QuantizationConfig struct {
+    GroupSize int `json:"group_size"`
+    Bits      int `json:"bits"`
+}
+
 // TextConfig holds Gemma 3 text model configuration.
 type TextConfig struct {
     HiddenSize int32 `json:"hidden_size"`
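For reference, group_size and bits describe the packing: every group_size consecutive weights share one scale and one bias, and each weight is stored in bits bits. A rough sketch of what that means for storage, assuming the usual group-wise affine scheme with float16 scales and biases (the row width below is just an example value):

    package main

    import "fmt"

    func main() {
        const (
            rowWidth  = 4096 // example weight-row width
            groupSize = 64   // QuantizationConfig.GroupSize
            bits      = 4    // QuantizationConfig.Bits
        )
        groups := rowWidth / groupSize
        packedBytes := rowWidth * bits / 8
        scaleBiasBytes := groups * 2 * 2 // one float16 scale + one float16 bias per group

        fmt.Println("groups per row:", groups)                    // 64
        fmt.Println("packed weight bytes per row:", packedBytes)  // 2048
        fmt.Println("scale/bias bytes per row:", scaleBiasBytes)  // 256
    }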
@@ -31,7 +38,8 @@ type TextConfig struct {
     SlidingWindow        int32 `json:"sliding_window"`
     SlidingWindowPattern int32 `json:"sliding_window_pattern"`

-    Scale float32 `json:"-"` // Computed: 1/sqrt(head_dim)
+    Quantization *QuantizationConfig `json:"-"` // Parsed separately from top-level
+    Scale        float32             `json:"-"` // Computed: 1/sqrt(head_dim)
 }

 // GemmaModel is the Gemma 3 text model.
@@ -117,8 +125,9 @@ func geluApprox(x *mlx.Array) *mlx.Array {
 func parseConfig(data []byte) (*TextConfig, error) {
     // Try parsing text_config from multimodal wrapper
     var wrapper struct {
-        TextConfig TextConfig `json:"text_config"`
-        ModelType  string     `json:"model_type"`
+        TextConfig   TextConfig          `json:"text_config"`
+        ModelType    string              `json:"model_type"`
+        Quantization *QuantizationConfig `json:"quantization"`
     }
     if err := json.Unmarshal(data, &wrapper); err != nil {
         return nil, err
@@ -133,6 +142,9 @@ func parseConfig(data []byte) (*TextConfig, error) {
         }
     }

+    // Quantization is always top-level
+    cfg.Quantization = wrapper.Quantization
+
     // Compute defaults
     if cfg.HeadDim == 0 && cfg.NumAttentionHeads > 0 {
         cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
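The quantization block sits at the top level of config.json, beside text_config rather than inside it, which is why it is read from the wrapper and then copied onto the parsed TextConfig. A self-contained sketch of that shape; the field names match the diff, the sample values are only illustrative:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    type QuantizationConfig struct {
        GroupSize int `json:"group_size"`
        Bits      int `json:"bits"`
    }

    func main() {
        // Abbreviated config.json: quantization sits beside text_config, not inside it.
        raw := []byte(`{
            "model_type": "gemma3",
            "quantization": {"group_size": 64, "bits": 4},
            "text_config": {"hidden_size": 2560}
        }`)

        var wrapper struct {
            TextConfig   map[string]any      `json:"text_config"`
            ModelType    string              `json:"model_type"`
            Quantization *QuantizationConfig `json:"quantization"`
        }
        if err := json.Unmarshal(raw, &wrapper); err != nil {
            panic(err)
        }
        fmt.Printf("%s: %d-bit, group %d\n",
            wrapper.ModelType, wrapper.Quantization.Bits, wrapper.Quantization.GroupSize)
        // Output: gemma3: 4-bit, group 64
    }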
@@ -195,17 +207,41 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
         }
     }

+    // Helper to resolve weight with language_model. prefix fallback
+    w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
+
+    // Helper to create linear layer (quantized or dense)
+    q := cfg.Quantization
+    if q != nil {
+        slog.Info("mlx: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
+    }
+    linear := func(prefix string) *mlx.Linear {
+        weight := w(prefix + ".weight")
+        scales := w(prefix + ".scales")
+        biases := w(prefix + ".biases")
+        if scales != nil && q != nil {
+            return mlx.NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits)
+        }
+        return mlx.NewLinear(weight, nil)
+    }
+
+    // Create embedding (quantized or dense)
+    embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
+    if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
+        embed.Scales = embedScales
+        embed.Biases = w("model.embed_tokens.biases")
+        embed.GroupSize = q.GroupSize
+        embed.Bits = q.Bits
+    }
+
     m := &GemmaModel{
-        EmbedTokens: &mlx.Embedding{Weight: resolveWeight(weights, "model.embed_tokens.weight")},
+        EmbedTokens: embed,
         Layers:      make([]*DecoderLayer, cfg.NumHiddenLayers),
-        Norm:        &mlx.RMSNormModule{Weight: resolveWeight(weights, "model.norm.weight")},
+        Norm:        &mlx.RMSNormModule{Weight: w("model.norm.weight")},
         Tok:         tok,
         Cfg:         cfg,
     }

-    // Helper to resolve weight with language_model. prefix fallback
-    w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
-
     // Initialize layers
     for i := int32(0); i < cfg.NumHiddenLayers; i++ {
         prefix := fmt.Sprintf("model.layers.%d", i)
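The linear helper decides per tensor: a projection becomes quantized only when both a <prefix>.scales entry exists in the weight table and the config carries a quantization block; otherwise it stays dense. A stripped-down sketch of that dispatch, with a plain map and a string result standing in for the real weight table and *mlx.Linear (all names below are illustrative):

    package main

    import "fmt"

    // buildLinear mirrors the selection logic: quantized only when both the
    // scales tensor and the quantization config are present.
    func buildLinear(weights map[string]bool, prefix string, quantized bool) string {
        hasScales := weights[prefix+".scales"]
        if hasScales && quantized {
            return "quantized (QuantizedMatmul path)"
        }
        return "dense (Matmul path)"
    }

    func main() {
        weights := map[string]bool{
            "model.layers.0.self_attn.q_proj.weight": true,
            "model.layers.0.self_attn.q_proj.scales": true,
            "model.layers.0.self_attn.q_proj.biases": true,
            "lm_head.weight":                         true, // no scales: stays dense
        }
        fmt.Println(buildLinear(weights, "model.layers.0.self_attn.q_proj", true))
        fmt.Println(buildLinear(weights, "lm_head", true))
    }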
@@ -215,29 +251,36 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
             PreFFNorm:  &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")},
             PostFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")},
             Attention: &Attention{
-                QProj: mlx.NewLinear(w(prefix+".self_attn.q_proj.weight"), nil),
-                KProj: mlx.NewLinear(w(prefix+".self_attn.k_proj.weight"), nil),
-                VProj: mlx.NewLinear(w(prefix+".self_attn.v_proj.weight"), nil),
-                OProj: mlx.NewLinear(w(prefix+".self_attn.o_proj.weight"), nil),
+                QProj: linear(prefix + ".self_attn.q_proj"),
+                KProj: linear(prefix + ".self_attn.k_proj"),
+                VProj: linear(prefix + ".self_attn.v_proj"),
+                OProj: linear(prefix + ".self_attn.o_proj"),
                 QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")},
                 KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")},
             },
             MLP: &MLP{
-                GateProj: mlx.NewLinear(w(prefix+".mlp.gate_proj.weight"), nil),
-                UpProj:   mlx.NewLinear(w(prefix+".mlp.up_proj.weight"), nil),
-                DownProj: mlx.NewLinear(w(prefix+".mlp.down_proj.weight"), nil),
+                GateProj: linear(prefix + ".mlp.gate_proj"),
+                UpProj:   linear(prefix + ".mlp.up_proj"),
+                DownProj: linear(prefix + ".mlp.down_proj"),
             },
             LayerIdx:  i,
             IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern),
         }
     }

-    // Tied embeddings — check for separate lm_head first
-    lmHead := w("lm_head.weight")
-    if lmHead == nil {
-        lmHead = m.EmbedTokens.Weight // tied
-    }
-    m.Output = mlx.NewLinear(lmHead, nil)
+    // Output head — check for separate lm_head first, else tie to embeddings
+    lmHeadWeight := w("lm_head.weight")
+    if lmHeadWeight != nil {
+        lmHeadScales := w("lm_head.scales")
+        if lmHeadScales != nil && q != nil {
+            m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
+        } else {
+            m.Output = mlx.NewLinear(lmHeadWeight, nil)
+        }
+    } else {
+        // Tied embeddings — reuse embed_tokens weights (with quantization if present)
+        m.Output = m.EmbedTokens.AsLinear()
+    }

     // Materialize all weights
     var allArrays []*mlx.Array
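The output head now has three outcomes: a dedicated quantized lm_head, a dedicated dense lm_head, or the tied embedding table via AsLinear. A compact sketch of just that ordering, with booleans standing in for tensor presence (the helper name is illustrative):

    package main

    import "fmt"

    // pickOutputHead mirrors the branch order: a dedicated head wins over tying,
    // and a dedicated head is quantized only when its scales ship with it.
    func pickOutputHead(hasLMHead, hasLMHeadScales, quantized bool) string {
        if hasLMHead {
            if hasLMHeadScales && quantized {
                return "quantized lm_head"
            }
            return "dense lm_head"
        }
        return "tied embeddings (EmbedTokens.AsLinear)"
    }

    func main() {
        fmt.Println(pickOutputHead(true, true, true))   // quantized lm_head
        fmt.Println(pickOutputHead(true, false, true))  // dense lm_head
        fmt.Println(pickOutputHead(false, false, true)) // tied embeddings (EmbedTokens.AsLinear)
    }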
@@ -3,19 +3,42 @@
 package mlx

 // Linear is a fully-connected layer: y = x @ W.T + bias.
+// For quantized models, set Scales/Biases/GroupSize/Bits to use QuantizedMatmul.
 type Linear struct {
-    Weight *Array `weight:"weight"`
-    Bias   *Array `weight:"bias"`
+    Weight    *Array `weight:"weight"`
+    Scales    *Array `weight:"scales"`
+    Biases    *Array `weight:"biases"`
+    Bias      *Array `weight:"bias"`
+    GroupSize int
+    Bits      int
 }

-// NewLinear creates a Linear layer with optional bias.
+// NewLinear creates a dense Linear layer with optional bias.
 func NewLinear(weight, bias *Array) *Linear {
     return &Linear{Weight: weight, Bias: bias}
 }

+// NewQuantizedLinear creates a quantized Linear layer.
+func NewQuantizedLinear(weight, scales, biases, bias *Array, groupSize, bits int) *Linear {
+    return &Linear{
+        Weight:    weight,
+        Scales:    scales,
+        Biases:    biases,
+        Bias:      bias,
+        GroupSize: groupSize,
+        Bits:      bits,
+    }
+}
+
 // Forward computes the linear transformation.
+// Uses QuantizedMatmul when quantization parameters are present.
 func (l *Linear) Forward(x *Array) *Array {
-    out := Matmul(x, Transpose(l.Weight))
+    var out *Array
+    if l.Scales != nil {
+        out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits)
+    } else {
+        out = Matmul(x, Transpose(l.Weight))
+    }
     if l.Bias != nil && l.Bias.Valid() {
         out = Add(out, l.Bias)
     }
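Forward now takes the QuantizedMatmul path whenever scales are present. Conceptually that computes the same thing as dequantizing the weights group by group (w ≈ q*scale + bias) and then doing the dense x @ W.T, just without materializing the full-precision matrix. A self-contained numeric sketch of that equivalence on a single weight row, using plain slices and assuming the usual affine per-group scheme (the actual bit-packing lives inside MLX):

    package main

    import "fmt"

    // dequantRow expands quantized codes back to floats: one scale/bias per group.
    func dequantRow(q []int, scales, biases []float32, groupSize int) []float32 {
        out := make([]float32, len(q))
        for i, v := range q {
            g := i / groupSize
            out[i] = float32(v)*scales[g] + biases[g]
        }
        return out
    }

    func dot(a, b []float32) float32 {
        var s float32
        for i := range a {
            s += a[i] * b[i]
        }
        return s
    }

    func main() {
        // One row of W stored as 4-bit codes (0..15), group size 4 for the demo.
        q := []int{0, 3, 7, 15, 1, 2, 4, 8}
        scales := []float32{0.5, 0.25}
        biases := []float32{-1.0, 0.0}
        x := []float32{1, 2, 3, 4, 5, 6, 7, 8}

        w := dequantRow(q, scales, biases, 4)
        // For a single row, y = x @ W.T is just a dot product; a quantized matmul
        // produces the same value without materializing w.
        fmt.Println("dequantized row:", w)
        fmt.Printf("x . w = %.2f\n", dot(x, w)) // 60.75
    }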
@@ -23,15 +46,35 @@ func (l *Linear) Forward(x *Array) *Array {
 }

 // Embedding is a lookup table for token embeddings.
+// For quantized models, set Scales/Biases/GroupSize/Bits to dequantize before lookup.
 type Embedding struct {
-    Weight *Array `weight:"weight"`
+    Weight    *Array `weight:"weight"`
+    Scales    *Array `weight:"scales"`
+    Biases    *Array `weight:"biases"`
+    GroupSize int
+    Bits      int
 }

 // Forward looks up embeddings for the given token indices.
 func (e *Embedding) Forward(indices *Array) *Array {
+    if e.Scales != nil {
+        w := Dequantize(e.Weight, e.Scales, e.Biases, e.GroupSize, e.Bits)
+        return Take(w, indices, 0)
+    }
     return Take(e.Weight, indices, 0)
 }

+// AsLinear returns a Linear layer using the embedding weights (for tied output).
+func (e *Embedding) AsLinear() *Linear {
+    return &Linear{
+        Weight:    e.Weight,
+        Scales:    e.Scales,
+        Biases:    e.Biases,
+        GroupSize: e.GroupSize,
+        Bits:      e.Bits,
+    }
+}
+
 // RMSNormModule is an RMS normalization layer wrapping the fused kernel.
 type RMSNormModule struct {
     Weight *Array `weight:"weight"`
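AsLinear is what lets the output head reuse the (possibly quantized) embedding table when no separate lm_head is shipped: the same matrix maps token ids to vectors on the way in and scores hidden states against every vocabulary row on the way out. A small numeric sketch with a toy vocabulary, plain slices only:

    package main

    import "fmt"

    func main() {
        // Toy embedding table: 4 tokens, hidden size 3.
        embed := [][]float32{
            {1, 0, 0},
            {0, 1, 0},
            {0, 0, 1},
            {1, 1, 1},
        }

        // Lookup (Embedding.Forward): token id -> row of the table.
        tok := 3
        h := embed[tok] // pretend this passed through the transformer unchanged

        // Tied output head (AsLinear + Forward): logits[i] = h . embed[i] for every token.
        logits := make([]float32, len(embed))
        for i, row := range embed {
            var s float32
            for j := range row {
                s += h[j] * row[j]
            }
            logits[i] = s
        }
        fmt.Println(logits) // [1 1 1 3]: the token we embedded scores highest
    }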
@@ -304,6 +304,18 @@ func Argpartition(a *Array, kth, axis int) *Array {
     return out
 }

+// Dequantize restores a quantized array to full precision.
+func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array {
+    out := New("DEQUANTIZE", w, scales, biases)
+    gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
+    b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
+    mode := C.CString("default")
+    defer C.free(unsafe.Pointer(mode))
+    noDtype := C.mlx_optional_dtype{has_value: C._Bool(false)}
+    C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, biases.ctx, gs, b, mode, noDtype, DefaultStream().ctx)
+    return out
+}
+
 // PutAlongAxis places values into array at indices along axis.
 func PutAlongAxis(a, indices, values *Array, axis int) *Array {
     out := New("PUT_ALONG_AXIS", a, indices, values)