refactor(metal): flatten model, tokenizer, sample, cache into internal/metal

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-19 19:51:14 +00:00
parent a669d1d9c1
commit 08976aa504
9 changed files with 223 additions and 777 deletions

200
cache/cache_test.go vendored
View file

@ -1,200 +0,0 @@
//go:build darwin && arm64
package cache
import (
"testing"
"forge.lthn.ai/core/go-mlx"
)
// makeKV builds a matching key/value tensor pair for tests, each with
// shape [B=1, H=2, L=seqLen, D=4] and deterministic ramp contents.
func makeKV(seqLen int) (*mlx.Array, *mlx.Array) {
	n := 1 * 2 * seqLen * 4
	vals := make([]float32, n)
	for idx := 0; idx < n; idx++ {
		vals[idx] = 0.1 * float32(idx)
	}
	return mlx.FromValues(vals, 1, 2, seqLen, 4), mlx.FromValues(vals, 1, 2, seqLen, 4)
}
// --- KVCache ---
// TestKVCache_New verifies that a freshly constructed KVCache reports
// zero offset, zero length, and a nil state.
func TestKVCache_New(t *testing.T) {
	kv := NewKVCache()
	if off := kv.Offset(); off != 0 {
		t.Errorf("offset = %d, want 0", off)
	}
	if n := kv.Len(); n != 0 {
		t.Errorf("len = %d, want 0", n)
	}
	if kv.State() != nil {
		t.Error("state should be nil for empty cache")
	}
}
// TestKVCache_SingleUpdate checks offset, length, and output shape after
// a single three-token update.
func TestKVCache_SingleUpdate(t *testing.T) {
	kv := NewKVCache()
	k, v := makeKV(3) // 3 tokens
	gotK, gotV := kv.Update(k, v, 3)
	mlx.Materialize(gotK, gotV)

	if off := kv.Offset(); off != 3 {
		t.Errorf("offset = %d, want 3", off)
	}
	if n := kv.Len(); n != 3 {
		t.Errorf("len = %d, want 3", n)
	}

	// The returned keys must keep the [1, 2, 3, 4] layout.
	dims := gotK.Shape()
	if dims[0] != 1 || dims[1] != 2 || dims[2] != 3 || dims[3] != 4 {
		t.Errorf("outK shape = %v, want [1 2 3 4]", dims)
	}
}
// TestKVCache_MultipleUpdates verifies the cache accumulates tokens across
// a multi-token prompt update followed by a single-token decode step.
func TestKVCache_MultipleUpdates(t *testing.T) {
	kv := NewKVCache()

	// Prompt phase: five tokens at once.
	promptK, promptV := makeKV(5)
	gotK, gotV := kv.Update(promptK, promptV, 5)
	mlx.Materialize(gotK, gotV)
	if off := kv.Offset(); off != 5 {
		t.Errorf("offset = %d, want 5", off)
	}

	// Decode phase: one token appended.
	stepK, stepV := makeKV(1)
	gotK, gotV = kv.Update(stepK, stepV, 1)
	mlx.Materialize(gotK, gotV)
	if off := kv.Offset(); off != 6 {
		t.Errorf("offset = %d, want 6", off)
	}
	if dims := gotK.Shape(); dims[2] != 6 {
		t.Errorf("outK L dim = %d, want 6", dims[2])
	}
}
// TestKVCache_Reset verifies Reset returns a populated cache to its
// empty state (zero offset, nil state).
func TestKVCache_Reset(t *testing.T) {
	kv := NewKVCache()
	k, v := makeKV(3)
	kv.Update(k, v, 3)

	kv.Reset()
	if off := kv.Offset(); off != 0 {
		t.Errorf("offset after reset = %d, want 0", off)
	}
	if kv.State() != nil {
		t.Error("state should be nil after reset")
	}
}
// TestKVCache_State verifies State exposes both cached arrays once the
// cache holds data.
func TestKVCache_State(t *testing.T) {
	kv := NewKVCache()
	k, v := makeKV(2)
	kv.Update(k, v, 2)

	got := kv.State()
	if len(got) != 2 {
		t.Fatalf("state length = %d, want 2", len(got))
	}
	// got[0] holds the keys, got[1] the values.
	if got[0] == nil || got[1] == nil {
		t.Error("state arrays should not be nil")
	}
}
// --- RotatingKVCache ---
// TestRotatingKVCache_New verifies a new rotating cache starts empty.
func TestRotatingKVCache_New(t *testing.T) {
	kv := NewRotatingKVCache(16)
	if off := kv.Offset(); off != 0 {
		t.Errorf("offset = %d, want 0", off)
	}
	if n := kv.Len(); n != 0 {
		t.Errorf("len = %d, want 0", n)
	}
}
// TestRotatingKVCache_SingleToken exercises the single-token (in-place)
// update path of the rotating cache.
func TestRotatingKVCache_SingleToken(t *testing.T) {
	kv := NewRotatingKVCache(8)
	k, v := makeKV(1)
	gotK, gotV := kv.Update(k, v, 1)
	mlx.Materialize(gotK, gotV)

	if off := kv.Offset(); off != 1 {
		t.Errorf("offset = %d, want 1", off)
	}
	if n := kv.Len(); n != 1 {
		t.Errorf("len = %d, want 1", n)
	}
}
// TestRotatingKVCache_MultiTokenPrompt exercises the multi-token (concat)
// update path while the prompt fits inside the window.
func TestRotatingKVCache_MultiTokenPrompt(t *testing.T) {
	kv := NewRotatingKVCache(16)
	k, v := makeKV(5)
	gotK, gotV := kv.Update(k, v, 5)
	mlx.Materialize(gotK, gotV)

	if off := kv.Offset(); off != 5 {
		t.Errorf("offset = %d, want 5", off)
	}
	if n := kv.Len(); n != 5 {
		t.Errorf("len = %d, want 5", n)
	}
}
// TestRotatingKVCache_Bounded verifies the cache never holds more than
// maxSize tokens while Offset keeps counting every token processed.
func TestRotatingKVCache_Bounded(t *testing.T) {
	kv := NewRotatingKVCache(4)

	// Fill exactly to capacity with a 4-token prompt.
	k, v := makeKV(4)
	gotK, gotV := kv.Update(k, v, 4)
	mlx.Materialize(gotK, gotV)
	if n := kv.Len(); n != 4 {
		t.Errorf("len = %d, want 4 (at max)", n)
	}

	// One extra token must rotate: offset advances, length stays capped.
	extraK, extraV := makeKV(1)
	gotK, gotV = kv.Update(extraK, extraV, 1)
	mlx.Materialize(gotK, gotV)
	if off := kv.Offset(); off != 5 {
		t.Errorf("offset = %d, want 5", off)
	}
	if n := kv.Len(); n != 4 {
		t.Errorf("len = %d, want 4 (bounded)", n)
	}
}
// TestRotatingKVCache_Reset verifies Reset clears offset, length, and state.
func TestRotatingKVCache_Reset(t *testing.T) {
	kv := NewRotatingKVCache(8)
	k, v := makeKV(3)
	kv.Update(k, v, 3)

	kv.Reset()
	if off := kv.Offset(); off != 0 {
		t.Errorf("offset after reset = %d, want 0", off)
	}
	if n := kv.Len(); n != 0 {
		t.Errorf("len after reset = %d, want 0", n)
	}
	if kv.State() != nil {
		t.Error("state should be nil after reset")
	}
}

View file

@ -1,20 +1,17 @@
//go:build darwin && arm64
// Package cache provides KV cache implementations for transformer inference.
package cache
import "forge.lthn.ai/core/go-mlx"
package metal
// Cache manages key-value pairs for transformer attention layers.
type Cache interface {
// Update adds new key/value tensors and returns the full cached K/V.
Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
Update(k, v *Array, seqLen int) (*Array, *Array)
// Offset returns the total number of tokens processed.
Offset() int
// Len returns the number of cached tokens (may differ from Offset for rotating caches).
Len() int
// State returns the cached K/V arrays, or nil if empty.
State() []*mlx.Array
State() []*Array
// Reset clears the cache for a new generation session.
Reset()
}
@ -22,7 +19,7 @@ type Cache interface {
// KVCache implements an unbounded cache that grows as needed.
// Pre-allocates in chunks of `step` tokens to reduce allocations.
type KVCache struct {
keys, values *mlx.Array
keys, values *Array
offset int
step int
}
@ -32,7 +29,7 @@ func NewKVCache() *KVCache {
return &KVCache{step: 256}
}
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
func (c *KVCache) Update(k, v *Array, seqLen int) (*Array, *Array) {
prev := c.offset
shape := k.Shape()
if len(shape) < 4 {
@ -49,34 +46,34 @@ func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
// Grow buffer if needed.
if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
nSteps := (c.step + seqLen - 1) / c.step
newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
newK := Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
newV := Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
if c.keys != nil {
if prev%c.step != 0 {
c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
c.keys = Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
c.values = Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
}
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
c.keys = Concatenate([]*Array{c.keys, newK}, 2)
c.values = Concatenate([]*Array{c.values, newV}, 2)
} else {
c.keys, c.values = newK, newV
}
}
c.offset += seqLen
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
c.keys = SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
c.values = SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
return Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
}
func (c *KVCache) State() []*mlx.Array {
func (c *KVCache) State() []*Array {
if c.keys == nil {
return nil
}
return []*mlx.Array{c.keys, c.values}
return []*Array{c.keys, c.values}
}
// Offset returns the total number of tokens processed; it is advanced by
// seqLen on every Update call.
func (c *KVCache) Offset() int { return c.offset }
@ -90,7 +87,7 @@ func (c *KVCache) Reset() {
// RotatingKVCache implements a bounded sliding window cache.
type RotatingKVCache struct {
keys, values *mlx.Array
keys, values *Array
offset int
maxSize int
step int
@ -102,14 +99,14 @@ func NewRotatingKVCache(maxSize int) *RotatingKVCache {
return &RotatingKVCache{maxSize: maxSize, step: 256}
}
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
func (c *RotatingKVCache) Update(k, v *Array, seqLen int) (*Array, *Array) {
if seqLen > 1 {
return c.updateConcat(k, v, seqLen)
}
return c.updateInPlace(k, v)
}
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
func (c *RotatingKVCache) updateInPlace(k, v *Array) (*Array, *Array) {
shape := k.Shape()
if len(shape) < 4 {
if c.keys == nil {
@ -127,11 +124,11 @@ func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array
cap = int(c.keys.Shape()[2])
}
newSize := min(c.step, c.maxSize-cap)
newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
newK := Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
newV := Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
if c.keys != nil {
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
c.keys = Concatenate([]*Array{c.keys, newK}, 2)
c.values = Concatenate([]*Array{c.values, newV}, 2)
} else {
c.keys, c.values = newK, newV
}
@ -141,18 +138,18 @@ func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array
c.idx = 0
}
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
c.keys = SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
c.values = SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
c.offset++
c.idx++
validLen := int32(min(c.offset, c.maxSize))
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
return Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
}
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
func (c *RotatingKVCache) updateConcat(k, v *Array, seqLen int) (*Array, *Array) {
shape := k.Shape()
if len(shape) < 4 {
// K/V must be [B, H, L, D] — if not, pass through unchanged
@ -168,26 +165,26 @@ func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array,
if c.keys == nil {
c.keys, c.values = k, v
} else {
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
c.keys = Concatenate([]*Array{c.keys, k}, 2)
c.values = Concatenate([]*Array{c.values, v}, 2)
}
c.offset += seqLen
cap := int(c.keys.Shape()[2])
if trim := cap - c.maxSize; trim > 0 {
c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
c.keys = Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
c.values = Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
}
c.idx = int(c.keys.Shape()[2])
return c.keys, c.values
}
func (c *RotatingKVCache) State() []*mlx.Array {
func (c *RotatingKVCache) State() []*Array {
if c.keys == nil {
return nil
}
return []*mlx.Array{c.keys, c.values}
return []*Array{c.keys, c.values}
}
// Offset returns the total number of tokens processed. Unlike Len, this
// keeps growing past maxSize as old tokens are rotated out.
func (c *RotatingKVCache) Offset() int { return c.offset }

View file

@ -1,7 +1,6 @@
//go:build darwin && arm64
// Package model provides transformer model architectures for MLX inference.
package model
package metal
import (
"encoding/json"
@ -10,10 +9,6 @@ import (
"math"
"os"
"path/filepath"
"forge.lthn.ai/core/go-mlx"
"forge.lthn.ai/core/go-mlx/cache"
"forge.lthn.ai/core/go-mlx/tokenizer"
)
// TextConfig holds Gemma 3 text model configuration.
@ -38,32 +33,32 @@ type TextConfig struct {
// GemmaModel is the Gemma 3 text model.
type GemmaModel struct {
EmbedTokens *mlx.Embedding
EmbedTokens *Embedding
Layers []*DecoderLayer
Norm *mlx.RMSNormModule
Output *mlx.Linear // Tied to EmbedTokens
Norm *RMSNormModule
Output *Linear // Tied to EmbedTokens
// Precomputed (1 + weight) for Gemma-style RMSNorm
NormScaled *mlx.Array
NormScaled *Array
Tok *tokenizer.Tokenizer
Tok *Tokenizer
Cfg *TextConfig
}
// DecoderLayer is a single transformer block.
type DecoderLayer struct {
InputNorm *mlx.RMSNormModule
InputNorm *RMSNormModule
Attention *Attention
PostAttnNorm *mlx.RMSNormModule
PreFFNorm *mlx.RMSNormModule
PostAttnNorm *RMSNormModule
PreFFNorm *RMSNormModule
MLP *MLP
PostFFNorm *mlx.RMSNormModule
PostFFNorm *RMSNormModule
// Precomputed scaled weights
InputNormScaled *mlx.Array
PostAttnNormScaled *mlx.Array
PreFFNormScaled *mlx.Array
PostFFNormScaled *mlx.Array
InputNormScaled *Array
PostAttnNormScaled *Array
PreFFNormScaled *Array
PostFFNormScaled *Array
IsSliding bool
LayerIdx int32
@ -71,31 +66,31 @@ type DecoderLayer struct {
// Attention implements Gemma 3 attention with Q/K normalization.
type Attention struct {
QProj *mlx.Linear
KProj *mlx.Linear
VProj *mlx.Linear
OProj *mlx.Linear
QNorm *mlx.RMSNormModule
KNorm *mlx.RMSNormModule
QProj *Linear
KProj *Linear
VProj *Linear
OProj *Linear
QNorm *RMSNormModule
KNorm *RMSNormModule
QNormScaled *mlx.Array
KNormScaled *mlx.Array
QNormScaled *Array
KNormScaled *Array
}
// MLP is the feed-forward network.
type MLP struct {
GateProj *mlx.Linear
UpProj *mlx.Linear
DownProj *mlx.Linear
GateProj *Linear
UpProj *Linear
DownProj *Linear
}
// compiledGELU is a singleton for the compiled GELU function.
var compiledGELU *mlx.CompiledFunc
var compiledGELU *CompiledFunc
func getCompiledGELU() *mlx.CompiledFunc {
func getCompiledGELU() *CompiledFunc {
if compiledGELU == nil {
compiledGELU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
return []*mlx.Array{geluApprox(inputs[0])}
compiledGELU = CompileShapeless(func(inputs []*Array) []*Array {
return []*Array{geluApprox(inputs[0])}
}, true)
}
return compiledGELU
@ -103,16 +98,16 @@ func getCompiledGELU() *mlx.CompiledFunc {
// geluApprox computes GELU using the tanh approximation:
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
func geluApprox(x *mlx.Array) *mlx.Array {
func geluApprox(x *Array) *Array {
const sqrt2OverPi = 0.7978845608028654
const coeff = 0.044715
x3 := mlx.Mul(mlx.Mul(x, x), x)
inner := mlx.Add(x, mlx.MulScalar(x3, coeff))
scaled := mlx.MulScalar(inner, sqrt2OverPi)
t := mlx.Tanh(scaled)
onePlusT := mlx.AddScalar(t, 1.0)
return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusT)
x3 := Mul(Mul(x, x), x)
inner := Add(x, MulScalar(x3, coeff))
scaled := MulScalar(inner, sqrt2OverPi)
t := Tanh(scaled)
onePlusT := AddScalar(t, 1.0)
return Mul(MulScalar(x, 0.5), onePlusT)
}
// parseConfig handles both flat and nested (text_config) Gemma 3 configs.
@ -175,22 +170,22 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
}
// Load tokenizer
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
tok, err := LoadTokenizer(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("gemma3: load tokenizer: %w", err)
}
// Load weights from all safetensors files
weights := make(map[string]*mlx.Array)
weights := make(map[string]*Array)
matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors"))
for _, path := range matches {
for name, arr := range mlx.LoadSafetensors(path) {
for name, arr := range LoadSafetensors(path) {
weights[name] = arr
}
}
// Helper to resolve weight with language_model. prefix fallback
w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
w := func(name string) *Array { return resolveWeight(weights, name) }
// Infer head_dim from q_proj weight shape when not in config.
// Gemma 3 uses head_dim=256 which differs from hidden_size/num_heads.
@ -211,18 +206,18 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
if q != nil {
slog.Info("mlx: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
}
linear := func(prefix string) *mlx.Linear {
linear := func(prefix string) *Linear {
weight := w(prefix + ".weight")
scales := w(prefix + ".scales")
biases := w(prefix + ".biases")
if scales != nil && q != nil {
return mlx.NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits)
return NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits)
}
return mlx.NewLinear(weight, nil)
return NewLinear(weight, nil)
}
// Create embedding (quantized or dense)
embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
embed := &Embedding{Weight: w("model.embed_tokens.weight")}
if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
embed.Scales = embedScales
embed.Biases = w("model.embed_tokens.biases")
@ -233,7 +228,7 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
m := &GemmaModel{
EmbedTokens: embed,
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")},
Norm: &RMSNormModule{Weight: w("model.norm.weight")},
Tok: tok,
Cfg: cfg,
}
@ -242,17 +237,17 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("model.layers.%d", i)
m.Layers[i] = &DecoderLayer{
InputNorm: &mlx.RMSNormModule{Weight: w(prefix + ".input_layernorm.weight")},
PostAttnNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_attention_layernorm.weight")},
PreFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")},
PostFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")},
InputNorm: &RMSNormModule{Weight: w(prefix + ".input_layernorm.weight")},
PostAttnNorm: &RMSNormModule{Weight: w(prefix + ".post_attention_layernorm.weight")},
PreFFNorm: &RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")},
PostFFNorm: &RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")},
Attention: &Attention{
QProj: linear(prefix + ".self_attn.q_proj"),
KProj: linear(prefix + ".self_attn.k_proj"),
VProj: linear(prefix + ".self_attn.v_proj"),
OProj: linear(prefix + ".self_attn.o_proj"),
QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")},
KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")},
QNorm: &RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")},
KNorm: &RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")},
},
MLP: &MLP{
GateProj: linear(prefix + ".mlp.gate_proj"),
@ -269,9 +264,9 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
if lmHeadWeight != nil {
lmHeadScales := w("lm_head.scales")
if lmHeadScales != nil && q != nil {
m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
m.Output = NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
} else {
m.Output = mlx.NewLinear(lmHeadWeight, nil)
m.Output = NewLinear(lmHeadWeight, nil)
}
} else {
// Tied embeddings — reuse embed_tokens weights (with quantization if present)
@ -279,11 +274,11 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
}
// Materialize all weights
var allArrays []*mlx.Array
var allArrays []*Array
for _, a := range weights {
allArrays = append(allArrays, a)
}
mlx.Materialize(allArrays...)
Materialize(allArrays...)
// Precompute (1 + weight) for Gemma-style RMSNorm
precomputeScaledWeights(m)
@ -292,25 +287,25 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
}
func precomputeScaledWeights(m *GemmaModel) {
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
m.NormScaled = AddScalar(m.Norm.Weight, 1.0)
for _, layer := range m.Layers {
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
layer.InputNormScaled = AddScalar(layer.InputNorm.Weight, 1.0)
layer.PostAttnNormScaled = AddScalar(layer.PostAttnNorm.Weight, 1.0)
layer.PreFFNormScaled = AddScalar(layer.PreFFNorm.Weight, 1.0)
layer.PostFFNormScaled = AddScalar(layer.PostFFNorm.Weight, 1.0)
layer.Attention.QNormScaled = AddScalar(layer.Attention.QNorm.Weight, 1.0)
layer.Attention.KNormScaled = AddScalar(layer.Attention.KNorm.Weight, 1.0)
}
var scaled []*mlx.Array
var scaled []*Array
scaled = append(scaled, m.NormScaled)
for _, layer := range m.Layers {
scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled,
layer.PreFFNormScaled, layer.PostFFNormScaled,
layer.Attention.QNormScaled, layer.Attention.KNormScaled)
}
mlx.Materialize(scaled...)
Materialize(scaled...)
}
func isLayerSliding(layerIdx, pattern int32) bool {
@ -321,56 +316,56 @@ func isLayerSliding(layerIdx, pattern int32) bool {
}
// Forward runs the text model forward pass.
func (m *GemmaModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
func (m *GemmaModel) Forward(tokens *Array, caches []Cache) *Array {
shape := tokens.Shape()
B, L := shape[0], shape[1]
h := m.EmbedTokens.Forward(tokens)
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize))))
h = MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize))))
for i, layer := range m.Layers {
h = layer.forward(h, caches[i], B, L, m.Cfg)
}
return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps))
return m.Output.Forward(RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps))
}
func (l *DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
func (l *DecoderLayer) forward(x *Array, c Cache, B, L int32, cfg *TextConfig) *Array {
normed := RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, cfg)
attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
attnOut = RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := Add(x, attnOut)
normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
normed = RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
mlpOut := l.MLP.forward(normed)
mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
return mlx.Add(h, mlpOut)
mlpOut = RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
return Add(h, mlpOut)
}
func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, cfg *TextConfig) *Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape to [B, num_heads, L, head_dim]
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
// Q/K normalization
q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
q = RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
k = RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
// RoPE with appropriate theta
ropeTheta := cfg.RopeTheta
if isSliding {
ropeTheta = cfg.RopeLocalBaseFreq
}
q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
q = RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
k = RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
// Update cache
k, v = c.Update(k, v, int(L))
@ -378,29 +373,29 @@ func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
// GQA: repeat K/V heads
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = mlx.RepeatKV(k, repeatFactor)
v = mlx.RepeatKV(v, repeatFactor)
k = RepeatKV(k, repeatFactor)
v = RepeatKV(v, repeatFactor)
}
// Scaled dot-product attention
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
out := ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) forward(x *mlx.Array) *mlx.Array {
func (m *MLP) forward(x *Array) *Array {
gate := getCompiledGELU().Call(m.GateProj.Forward(x))[0]
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
return m.DownProj.Forward(Mul(gate, m.UpProj.Forward(x)))
}
// NewCache creates per-layer caches for generation.
func (m *GemmaModel) NewCache() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
func (m *GemmaModel) NewCache() []Cache {
caches := make([]Cache, len(m.Layers))
for i := range caches {
if m.Layers[i].IsSliding {
caches[i] = cache.NewRotatingKVCache(int(m.Cfg.SlidingWindow))
caches[i] = NewRotatingKVCache(int(m.Cfg.SlidingWindow))
} else {
caches[i] = cache.NewKVCache()
caches[i] = NewKVCache()
}
}
return caches
@ -410,22 +405,22 @@ func (m *GemmaModel) NewCache() []cache.Cache {
// NumLayers returns the number of transformer decoder layers in the model.
func (m *GemmaModel) NumLayers() int { return len(m.Layers) }
// Tokenizer returns the model's tokenizer.
func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
func (m *GemmaModel) Tokenizer() *Tokenizer { return m.Tok }
// ModelType returns the architecture identifier.
func (m *GemmaModel) ModelType() string { return "gemma3" }
// ApplyLoRA wraps target projection layers with LoRA adapters.
func (m *GemmaModel) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
adapter := &mlx.LoRAAdapter{
Layers: make(map[string]*mlx.LoRALinear),
func (m *GemmaModel) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter {
adapter := &LoRAAdapter{
Layers: make(map[string]*LoRALinear),
Config: cfg,
}
for i, layer := range m.Layers {
prefix := fmt.Sprintf("model.layers.%d.self_attn", i)
for _, target := range cfg.TargetKeys {
var proj *mlx.Linear
var proj *Linear
switch target {
case "q_proj":
proj = layer.Attention.QProj
@ -437,7 +432,7 @@ func (m *GemmaModel) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
proj = layer.Attention.OProj
}
if proj != nil {
lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
proj.LoRA = lora
adapter.Layers[prefix+"."+target] = lora
}

View file

@ -1,39 +1,34 @@
//go:build darwin && arm64
// Package model provides transformer model architectures for MLX inference.
package model
package metal
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"forge.lthn.ai/core/go-mlx"
"forge.lthn.ai/core/go-mlx/cache"
"forge.lthn.ai/core/go-mlx/tokenizer"
)
// Model is the common interface for all transformer model architectures.
type Model interface {
// InternalModel is the common interface for all transformer model architectures.
type InternalModel interface {
// Forward runs the model forward pass on token IDs with KV caches.
Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
Forward(tokens *Array, caches []Cache) *Array
// NewCache creates per-layer KV caches for generation.
NewCache() []cache.Cache
NewCache() []Cache
// NumLayers returns the number of transformer layers.
NumLayers() int
// Tokenizer returns the model's tokenizer.
Tokenizer() *tokenizer.Tokenizer
Tokenizer() *Tokenizer
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
ModelType() string
// ApplyLoRA wraps target projection layers with LoRA adapters for training.
// Returns the adapter which holds references to all LoRA layers.
ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter
ApplyLoRA(cfg LoRAConfig) *LoRAAdapter
}
// QuantizationConfig holds quantization parameters from config.json.
@ -43,7 +38,7 @@ type QuantizationConfig struct {
}
// resolveWeight looks up a weight with optional "language_model." prefix.
func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
func resolveWeight(weights map[string]*Array, name string) *Array {
if w, ok := weights[name]; ok {
return w
}
@ -53,8 +48,8 @@ func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
return nil
}
// LoadModel auto-detects the model architecture from config.json and loads it.
func LoadModel(modelPath string) (Model, error) {
// loadModel auto-detects the model architecture from config.json and loads it.
func loadModel(modelPath string) (InternalModel, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("model: load config: %w", err)

View file

@ -1,6 +1,6 @@
//go:build darwin && arm64
package model
package metal
import (
"encoding/json"
@ -9,10 +9,6 @@ import (
"math"
"os"
"path/filepath"
"forge.lthn.ai/core/go-mlx"
"forge.lthn.ai/core/go-mlx/cache"
"forge.lthn.ai/core/go-mlx/tokenizer"
)
// Qwen3Config holds Qwen 3 model configuration.
@ -34,39 +30,39 @@ type Qwen3Config struct {
// Qwen3Model is the Qwen 3 text model.
type Qwen3Model struct {
EmbedTokens *mlx.Embedding
EmbedTokens *Embedding
Layers []*Qwen3DecoderLayer
Norm *mlx.RMSNormModule
Output *mlx.Linear
Norm *RMSNormModule
Output *Linear
Tok *tokenizer.Tokenizer
Tok *Tokenizer
Cfg *Qwen3Config
}
// Qwen3DecoderLayer is a single transformer block.
// Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add.
type Qwen3DecoderLayer struct {
InputNorm *mlx.RMSNormModule // Pre-attention norm
PostAttnNorm *mlx.RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm)
InputNorm *RMSNormModule // Pre-attention norm
PostAttnNorm *RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm)
Attention *Qwen3Attention
MLP *Qwen3MLP
}
// Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization.
type Qwen3Attention struct {
QProj *mlx.Linear
KProj *mlx.Linear
VProj *mlx.Linear
OProj *mlx.Linear
QNorm *mlx.RMSNormModule
KNorm *mlx.RMSNormModule
QProj *Linear
KProj *Linear
VProj *Linear
OProj *Linear
QNorm *RMSNormModule
KNorm *RMSNormModule
}
// Qwen3MLP is the SwiGLU feed-forward network: down(silu(gate(x)) * up(x)).
type Qwen3MLP struct {
GateProj *mlx.Linear
UpProj *mlx.Linear
DownProj *mlx.Linear
GateProj *Linear
UpProj *Linear
DownProj *Linear
}
func parseQwen3Config(data []byte) (*Qwen3Config, error) {
@ -114,40 +110,40 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
return nil, fmt.Errorf("qwen3: parse config: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
tok, err := LoadTokenizer(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("qwen3: load tokenizer: %w", err)
}
// Load weights from all safetensors files
weights := make(map[string]*mlx.Array)
weights := make(map[string]*Array)
matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors"))
for _, path := range matches {
for name, arr := range mlx.LoadSafetensors(path) {
for name, arr := range LoadSafetensors(path) {
weights[name] = arr
}
}
w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
w := func(name string) *Array { return resolveWeight(weights, name) }
// Quantization setup
q := cfg.Quantization
if q != nil {
slog.Info("qwen3: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
}
linear := func(prefix string) *mlx.Linear {
linear := func(prefix string) *Linear {
weight := w(prefix + ".weight")
scales := w(prefix + ".scales")
biases := w(prefix + ".biases")
bias := w(prefix + ".bias")
if scales != nil && q != nil {
return mlx.NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits)
return NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits)
}
return mlx.NewLinear(weight, bias)
return NewLinear(weight, bias)
}
// Embedding
embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
embed := &Embedding{Weight: w("model.embed_tokens.weight")}
if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
embed.Scales = embedScales
embed.Biases = w("model.embed_tokens.biases")
@ -158,7 +154,7 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
m := &Qwen3Model{
EmbedTokens: embed,
Layers: make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers),
Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")},
Norm: &RMSNormModule{Weight: w("model.norm.weight")},
Tok: tok,
Cfg: cfg,
}
@ -166,15 +162,15 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
p := fmt.Sprintf("model.layers.%d", i)
m.Layers[i] = &Qwen3DecoderLayer{
InputNorm: &mlx.RMSNormModule{Weight: w(p + ".input_layernorm.weight")},
PostAttnNorm: &mlx.RMSNormModule{Weight: w(p + ".post_attention_layernorm.weight")},
InputNorm: &RMSNormModule{Weight: w(p + ".input_layernorm.weight")},
PostAttnNorm: &RMSNormModule{Weight: w(p + ".post_attention_layernorm.weight")},
Attention: &Qwen3Attention{
QProj: linear(p + ".self_attn.q_proj"),
KProj: linear(p + ".self_attn.k_proj"),
VProj: linear(p + ".self_attn.v_proj"),
OProj: linear(p + ".self_attn.o_proj"),
QNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.q_norm.weight")},
KNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.k_norm.weight")},
QNorm: &RMSNormModule{Weight: w(p + ".self_attn.q_norm.weight")},
KNorm: &RMSNormModule{Weight: w(p + ".self_attn.k_norm.weight")},
},
MLP: &Qwen3MLP{
GateProj: linear(p + ".mlp.gate_proj"),
@ -189,20 +185,20 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
if lmHeadWeight != nil {
lmHeadScales := w("lm_head.scales")
if lmHeadScales != nil && q != nil {
m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
m.Output = NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
} else {
m.Output = mlx.NewLinear(lmHeadWeight, nil)
m.Output = NewLinear(lmHeadWeight, nil)
}
} else {
m.Output = m.EmbedTokens.AsLinear()
}
// Materialise all weights onto Metal
var allArrays []*mlx.Array
var allArrays []*Array
for _, a := range weights {
allArrays = append(allArrays, a)
}
mlx.Materialize(allArrays...)
Materialize(allArrays...)
slog.Info("qwen3: model loaded",
"layers", cfg.NumHiddenLayers,
@ -218,7 +214,7 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
// Forward runs the Qwen 3 forward pass.
// Unlike Gemma, Qwen does NOT scale embeddings by sqrt(hidden_size).
func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
func (m *Qwen3Model) Forward(tokens *Array, caches []Cache) *Array {
shape := tokens.Shape()
B, L := shape[0], shape[1]
@ -231,29 +227,29 @@ func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps))
}
func (l *Qwen3DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
func (l *Qwen3DecoderLayer) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config) *Array {
// Pre-attention norm → attention → residual add
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
attnOut := l.Attention.forward(normed, c, B, L, cfg)
h := mlx.Add(x, attnOut)
h := Add(x, attnOut)
// Pre-MLP norm → MLP → residual add
normed = l.PostAttnNorm.Forward(h, cfg.RMSNormEps)
mlpOut := l.MLP.forward(normed)
return mlx.Add(h, mlpOut)
return Add(h, mlpOut)
}
func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config) *Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape to [B, num_heads, L, head_dim]
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
// Q/K RMS normalization (Qwen 3 has this)
@ -261,8 +257,8 @@ func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Q
k = a.KNorm.Forward(k, cfg.RMSNormEps)
// RoPE — single theta for all layers (no sliding window)
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
q = RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
k = RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
// Update KV cache
k, v = c.Update(k, v, int(L))
@ -270,27 +266,27 @@ func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Q
// GQA: repeat K/V heads to match Q heads
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = mlx.RepeatKV(k, repeatFactor)
v = mlx.RepeatKV(v, repeatFactor)
k = RepeatKV(k, repeatFactor)
v = RepeatKV(v, repeatFactor)
}
// Scaled dot-product attention
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
out := ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
// forward computes SwiGLU: down(silu(gate(x)) * up(x)).
func (m *Qwen3MLP) forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(m.GateProj.Forward(x))
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
func (m *Qwen3MLP) forward(x *Array) *Array {
gate := SiLU(m.GateProj.Forward(x))
return m.DownProj.Forward(Mul(gate, m.UpProj.Forward(x)))
}
// NewCache creates per-layer KV caches. Qwen 3 uses global attention only.
func (m *Qwen3Model) NewCache() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
func (m *Qwen3Model) NewCache() []Cache {
caches := make([]Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
caches[i] = NewKVCache()
}
return caches
}
@ -299,22 +295,22 @@ func (m *Qwen3Model) NewCache() []cache.Cache {
func (m *Qwen3Model) NumLayers() int { return len(m.Layers) }
// Tokenizer returns the model's tokenizer.
func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
func (m *Qwen3Model) Tokenizer() *Tokenizer { return m.Tok }
// ModelType returns the architecture identifier.
func (m *Qwen3Model) ModelType() string { return "qwen3" }
// ApplyLoRA wraps target projection layers with LoRA adapters.
func (m *Qwen3Model) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
adapter := &mlx.LoRAAdapter{
Layers: make(map[string]*mlx.LoRALinear),
func (m *Qwen3Model) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter {
adapter := &LoRAAdapter{
Layers: make(map[string]*LoRALinear),
Config: cfg,
}
for i, layer := range m.Layers {
prefix := fmt.Sprintf("model.layers.%d.self_attn", i)
for _, target := range cfg.TargetKeys {
var proj *mlx.Linear
var proj *Linear
switch target {
case "q_proj":
proj = layer.Attention.QProj
@ -326,7 +322,7 @@ func (m *Qwen3Model) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
proj = layer.Attention.OProj
}
if proj != nil {
lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
proj.LoRA = lora
adapter.Layers[prefix+"."+target] = lora
}

View file

@ -1,22 +1,19 @@
//go:build darwin && arm64
// Package sample provides composable token sampling strategies.
package sample
package metal
import (
"math"
"forge.lthn.ai/core/go-mlx"
)
// Sampler transforms logits into a sampled token index.
type Sampler interface {
Sample(logits *mlx.Array) *mlx.Array
Sample(logits *Array) *Array
}
// New creates a composable sampler chain from the given parameters.
// newSampler creates a composable sampler chain from the given parameters.
// Order: TopP -> MinP -> TopK -> Temperature -> categorical sample.
func New(temp, topP, minP float32, topK int) Sampler {
func newSampler(temp, topP, minP float32, topK int) Sampler {
if temp == 0 {
return greedy{}
}
@ -38,43 +35,43 @@ func New(temp, topP, minP float32, topK int) Sampler {
// chain applies a sequence of samplers, then samples from the result.
type chain []Sampler
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
func (c chain) Sample(logits *Array) *Array {
for _, s := range c {
logits = s.Sample(logits)
}
// Final categorical sample from log-probabilities
return mlx.RandomCategorical(logits)
return RandomCategorical(logits)
}
// greedy returns the argmax token.
type greedy struct{}
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
return mlx.Argmax(logits, -1, false)
func (greedy) Sample(logits *Array) *Array {
return Argmax(logits, -1, false)
}
// Temperature scales logits by 1/temp.
type Temperature float32
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
return mlx.MulScalar(logits, 1.0/float32(t))
func (t Temperature) Sample(logits *Array) *Array {
return MulScalar(logits, 1.0/float32(t))
}
// TopKSampler masks all but the top-k logits.
type TopKSampler int
func (k TopKSampler) Sample(logits *mlx.Array) *mlx.Array {
neg := mlx.Negative(logits)
mask := mlx.Argpartition(neg, int(k)-1, -1)
func (k TopKSampler) Sample(logits *Array) *Array {
neg := Negative(logits)
mask := Argpartition(neg, int(k)-1, -1)
// Slice the indices beyond top-k
mask = mlx.SliceAxis(mask, -1, int32(k), int32(logits.Dim(-1)))
return mlx.PutAlongAxis(logits, mask, mlx.FromValue(float32(math.Inf(-1))), -1)
mask = SliceAxis(mask, -1, int32(k), int32(logits.Dim(-1)))
return PutAlongAxis(logits, mask, FromValue(float32(math.Inf(-1))), -1)
}
// TopP implements nucleus sampling (cumulative probability threshold).
type TopP float32
func (p TopP) Sample(logits *mlx.Array) *mlx.Array {
func (p TopP) Sample(logits *Array) *Array {
// TODO: full nucleus sampling requires cumsum which mlx-c doesn't expose directly.
// For now, pass through. TopK + Temperature covers most use cases.
return logits
@ -83,7 +80,7 @@ func (p TopP) Sample(logits *mlx.Array) *mlx.Array {
// MinPSampler masks tokens below min_p * max_prob.
type MinPSampler float32
func (p MinPSampler) Sample(logits *mlx.Array) *mlx.Array {
func (p MinPSampler) Sample(logits *Array) *Array {
// For now, pass through — MinP is an optimization over TopP.
// Full implementation requires finding max prob and masking below threshold.
return logits

View file

@ -1,7 +1,6 @@
//go:build darwin && arm64
// Package tokenizer provides BPE tokenization for transformer models.
package tokenizer
package metal
import (
"encoding/json"
@ -46,8 +45,8 @@ type tokenizerJSON struct {
} `json:"added_tokens"`
}
// Load reads a tokenizer.json file and creates a Tokenizer.
func Load(path string) (*Tokenizer, error) {
// LoadTokenizer reads a tokenizer.json file and creates a Tokenizer.
func LoadTokenizer(path string) (*Tokenizer, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("tokenizer: read %s: %w", path, err)
@ -294,7 +293,7 @@ func (t *Tokenizer) DecodeToken(id int32) string {
return t.decodeGPT2Bytes(text)
}
// SentencePiece: replace with space but keep it (it's the word boundary)
// SentencePiece: replace with space but keep it (it's the word boundary)
return strings.ReplaceAll(text, "▁", " ")
}

View file

@ -1,116 +0,0 @@
//go:build darwin && arm64
package sample
import (
"testing"
"forge.lthn.ai/core/go-mlx"
)
func TestGreedy(t *testing.T) {
// Logits heavily favour index 2
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
s := New(0, 0, 0, 0) // temp=0 → greedy
token := s.Sample(logits)
mlx.Materialize(token)
if token.Int() != 2 {
t.Errorf("greedy sample = %d, want 2", token.Int())
}
}
func TestTemperature_HighTemp(t *testing.T) {
// High temperature should still produce a valid index
logits := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
s := New(100.0, 0, 0, 0) // very high temp → near uniform
token := s.Sample(logits)
mlx.Materialize(token)
idx := token.Int()
if idx < 0 || idx >= 4 {
t.Errorf("sample index = %d, out of range [0, 4)", idx)
}
}
func TestTemperature_LowTemp(t *testing.T) {
// Very low temperature should behave like greedy
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
s := New(0.001, 0, 0, 0) // near-zero temp → near-greedy
token := s.Sample(logits)
mlx.Materialize(token)
if token.Int() != 2 {
t.Errorf("low-temp sample = %d, want 2 (near greedy)", token.Int())
}
}
func TestTopK(t *testing.T) {
// TopK=1 with clear winner should always pick that token
logits := mlx.FromValues([]float32{-100, 100, -100, -100}, 1, 4)
s := New(1.0, 0, 0, 1) // topK=1
token := s.Sample(logits)
mlx.Materialize(token)
if token.Int() != 1 {
t.Errorf("topk=1 sample = %d, want 1", token.Int())
}
}
func TestTopK_MultipleTokens(t *testing.T) {
// TopK=2, both high logits — should pick one of them
logits := mlx.FromValues([]float32{-100, 50, 50, -100}, 1, 4)
s := New(1.0, 0, 0, 2) // topK=2
seen := map[int]bool{}
for range 20 {
token := s.Sample(logits)
mlx.Materialize(token)
seen[token.Int()] = true
}
// Should only ever pick index 1 or 2
for idx := range seen {
if idx != 1 && idx != 2 {
t.Errorf("topk=2 sampled index %d, expected only 1 or 2", idx)
}
}
}
func TestNew_Chain(t *testing.T) {
// Full chain: topK + temperature
logits := mlx.FromValues([]float32{1, 2, 3, 4, 5}, 1, 5)
s := New(0.5, 0, 0, 3) // temp=0.5, topK=3
token := s.Sample(logits)
mlx.Materialize(token)
idx := token.Int()
if idx < 0 || idx >= 5 {
t.Errorf("chain sample index = %d, out of range", idx)
}
}
func TestTopP_PassThrough(t *testing.T) {
// TopP is currently a stub — verify it doesn't break the chain
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
s := New(0.5, 0.9, 0, 0) // topP=0.9 (stub), temp=0.5
token := s.Sample(logits)
mlx.Materialize(token)
if token.Int() != 2 {
t.Errorf("topP stub + temp sample = %d, want 2", token.Int())
}
}
func TestMinP_PassThrough(t *testing.T) {
// MinP is currently a stub — verify it doesn't break the chain
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
s := New(0.5, 0, 0.1, 0) // minP=0.1 (stub), temp=0.5
token := s.Sample(logits)
mlx.Materialize(token)
if token.Int() != 2 {
t.Errorf("minP stub + temp sample = %d, want 2", token.Int())
}
}

View file

@ -1,217 +0,0 @@
//go:build darwin && arm64
package tokenizer
import (
"os"
"path/filepath"
"testing"
)
// minimalTokenizerJSON is a valid HuggingFace tokenizer.json with a tiny vocab.
const minimalTokenizerJSON = `{
"model": {
"type": "BPE",
"vocab": {
"h": 0,
"e": 1,
"l": 2,
"o": 3,
"▁": 4,
"he": 5,
"ll": 6,
"▁h": 7
},
"merges": ["h e", "l l"],
"byte_fallback": false
},
"added_tokens": [
{"id": 100, "content": "<bos>", "special": true},
{"id": 101, "content": "<eos>", "special": true}
]
}`
func writeTestTokenizer(t *testing.T) string {
t.Helper()
dir := t.TempDir()
path := filepath.Join(dir, "tokenizer.json")
if err := os.WriteFile(path, []byte(minimalTokenizerJSON), 0644); err != nil {
t.Fatalf("write test tokenizer: %v", err)
}
return path
}
func TestLoad(t *testing.T) {
path := writeTestTokenizer(t)
tok, err := Load(path)
if err != nil {
t.Fatalf("Load: %v", err)
}
if tok == nil {
t.Fatal("tokenizer is nil")
}
}
func TestLoad_MissingFile(t *testing.T) {
_, err := Load("/nonexistent/tokenizer.json")
if err == nil {
t.Error("expected error for missing file")
}
}
func TestLoad_InvalidJSON(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "tokenizer.json")
os.WriteFile(path, []byte("not json"), 0644)
_, err := Load(path)
if err == nil {
t.Error("expected error for invalid JSON")
}
}
func TestBOSEOS(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
if tok.BOSToken() != 100 {
t.Errorf("BOS = %d, want 100", tok.BOSToken())
}
if tok.EOSToken() != 101 {
t.Errorf("EOS = %d, want 101", tok.EOSToken())
}
}
func TestEncode_ProducesTokens(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
tokens := tok.Encode("hello")
if len(tokens) == 0 {
t.Fatal("Encode returned empty tokens")
}
// First token should be BOS
if tokens[0] != tok.BOSToken() {
t.Errorf("first token = %d, want BOS (%d)", tokens[0], tok.BOSToken())
}
t.Logf("Encode(\"hello\") = %v", tokens)
}
func TestDecode_SpecialTokensSkipped(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
// Decoding BOS/EOS should produce empty string
text := tok.Decode([]int32{100, 101})
if text != "" {
t.Errorf("Decode(BOS, EOS) = %q, want empty", text)
}
}
func TestDecode_RegularTokens(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
// Decode known vocab entries
text := tok.Decode([]int32{5, 6, 3}) // "he" + "ll" + "o"
if text != "hello" {
t.Errorf("Decode = %q, want %q", text, "hello")
}
}
func TestDecodeToken_Regular(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
// "he" = token 5
text := tok.DecodeToken(5)
if text != "he" {
t.Errorf("DecodeToken(5) = %q, want %q", text, "he")
}
}
func TestDecodeToken_Special(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
// Special tokens should return empty
text := tok.DecodeToken(100)
if text != "" {
t.Errorf("DecodeToken(BOS) = %q, want empty", text)
}
}
func TestDecodeToken_SentencePieceSpace(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
// "▁h" = token 7, should decode to " h" (space prefix)
text := tok.DecodeToken(7)
if text != " h" {
t.Errorf("DecodeToken(7) = %q, want %q", text, " h")
}
}
func TestDecodeToken_Unknown(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
text := tok.DecodeToken(9999)
if text != "" {
t.Errorf("DecodeToken(unknown) = %q, want empty", text)
}
}
func TestFormatGemmaPrompt(t *testing.T) {
got := FormatGemmaPrompt("What is 2+2?")
want := "<start_of_turn>user\nWhat is 2+2?<end_of_turn>\n<start_of_turn>model\n"
if got != want {
t.Errorf("FormatGemmaPrompt = %q, want %q", got, want)
}
}
// --- GPT-2 byte maps ---
func TestBuildGPT2ByteMaps(t *testing.T) {
decoder, encoder := buildGPT2ByteMaps()
// All 256 bytes must be mapped
if len(encoder) != 256 {
t.Errorf("encoder has %d entries, want 256", len(encoder))
}
if len(decoder) != 256 {
t.Errorf("decoder has %d entries, want 256", len(decoder))
}
// Round-trip: every byte should survive encode → decode
for b := 0; b < 256; b++ {
r := encoder[byte(b)]
got := decoder[r]
if got != byte(b) {
t.Errorf("byte %d: encode→decode = %d, want %d", b, got, b)
}
}
}
func TestBuildGPT2ByteMaps_PrintableASCII(t *testing.T) {
_, encoder := buildGPT2ByteMaps()
// Printable ASCII (33-126) should self-map
for b := 33; b <= 126; b++ {
if encoder[byte(b)] != rune(b) {
t.Errorf("byte %d (%c): expected self-map, got %c", b, b, encoder[byte(b)])
}
}
}
func TestBuildGPT2ByteMaps_ControlChars(t *testing.T) {
_, encoder := buildGPT2ByteMaps()
// Space (32) and control chars (0-31) should NOT self-map
if encoder[byte(32)] == rune(32) {
t.Error("space (32) should not self-map in GPT-2 encoding")
}
if encoder[byte(0)] == rune(0) {
t.Error("null (0) should not self-map in GPT-2 encoding")
}
}