From 08976aa504560fd937831256eaea7f15040b9b0c Mon Sep 17 00:00:00 2001 From: Snider Date: Thu, 19 Feb 2026 19:51:14 +0000 Subject: [PATCH] refactor(metal): flatten model, tokenizer, sample, cache into internal/metal Co-Authored-By: Virgil --- cache/cache_test.go | 200 ------------------- {cache => internal/metal}/cache.go | 73 ++++--- {model => internal/metal}/gemma3.go | 201 ++++++++++--------- {model => internal/metal}/model.go | 25 +-- {model => internal/metal}/qwen3.go | 122 ++++++------ {sample => internal/metal}/sample.go | 37 ++-- {tokenizer => internal/metal}/tokenizer.go | 9 +- sample/sample_test.go | 116 ----------- tokenizer/tokenizer_test.go | 217 --------------------- 9 files changed, 223 insertions(+), 777 deletions(-) delete mode 100644 cache/cache_test.go rename {cache => internal/metal}/cache.go (56%) rename {model => internal/metal}/gemma3.go (63%) rename {model => internal/metal}/model.go (69%) rename {model => internal/metal}/qwen3.go (69%) rename {sample => internal/metal}/sample.go (60%) rename {tokenizer => internal/metal}/tokenizer.go (96%) delete mode 100644 sample/sample_test.go delete mode 100644 tokenizer/tokenizer_test.go diff --git a/cache/cache_test.go b/cache/cache_test.go deleted file mode 100644 index e660af6..0000000 --- a/cache/cache_test.go +++ /dev/null @@ -1,200 +0,0 @@ -//go:build darwin && arm64 - -package cache - -import ( - "testing" - - "forge.lthn.ai/core/go-mlx" -) - -// makeKV creates a small K/V pair with shape [B=1, H=2, L=seqLen, D=4]. -func makeKV(seqLen int) (*mlx.Array, *mlx.Array) { - size := 1 * 2 * seqLen * 4 - data := make([]float32, size) - for i := range data { - data[i] = float32(i) * 0.1 - } - k := mlx.FromValues(data, 1, 2, seqLen, 4) - v := mlx.FromValues(data, 1, 2, seqLen, 4) - return k, v -} - -// --- KVCache --- - -func TestKVCache_New(t *testing.T) { - c := NewKVCache() - if c.Offset() != 0 { - t.Errorf("offset = %d, want 0", c.Offset()) - } - if c.Len() != 0 { - t.Errorf("len = %d, want 0", c.Len()) - } - if c.State() != nil { - t.Error("state should be nil for empty cache") - } -} - -func TestKVCache_SingleUpdate(t *testing.T) { - c := NewKVCache() - k, v := makeKV(3) // 3 tokens - - outK, outV := c.Update(k, v, 3) - mlx.Materialize(outK, outV) - - if c.Offset() != 3 { - t.Errorf("offset = %d, want 3", c.Offset()) - } - if c.Len() != 3 { - t.Errorf("len = %d, want 3", c.Len()) - } - - // Output K should have shape [1, 2, 3, 4] - shape := outK.Shape() - if shape[0] != 1 || shape[1] != 2 || shape[2] != 3 || shape[3] != 4 { - t.Errorf("outK shape = %v, want [1 2 3 4]", shape) - } -} - -func TestKVCache_MultipleUpdates(t *testing.T) { - c := NewKVCache() - - // Prompt: 5 tokens - k1, v1 := makeKV(5) - outK, outV := c.Update(k1, v1, 5) - mlx.Materialize(outK, outV) - - if c.Offset() != 5 { - t.Errorf("offset = %d, want 5", c.Offset()) - } - - // Generate: 1 token at a time - k2, v2 := makeKV(1) - outK, outV = c.Update(k2, v2, 1) - mlx.Materialize(outK, outV) - - if c.Offset() != 6 { - t.Errorf("offset = %d, want 6", c.Offset()) - } - - shape := outK.Shape() - if shape[2] != 6 { - t.Errorf("outK L dim = %d, want 6", shape[2]) - } -} - -func TestKVCache_Reset(t *testing.T) { - c := NewKVCache() - k, v := makeKV(3) - c.Update(k, v, 3) - - c.Reset() - - if c.Offset() != 0 { - t.Errorf("offset after reset = %d, want 0", c.Offset()) - } - if c.State() != nil { - t.Error("state should be nil after reset") - } -} - -func TestKVCache_State(t *testing.T) { - c := NewKVCache() - k, v := makeKV(2) - c.Update(k, v, 2) - - state := c.State() - if len(state) != 2 { - t.Fatalf("state length = %d, want 2", len(state)) - } - // state[0] = keys, state[1] = values - if state[0] == nil || state[1] == nil { - t.Error("state arrays should not be nil") - } -} - -// --- RotatingKVCache --- - -func TestRotatingKVCache_New(t *testing.T) { - c := NewRotatingKVCache(16) - if c.Offset() != 0 { - t.Errorf("offset = %d, want 0", c.Offset()) - } - if c.Len() != 0 { - t.Errorf("len = %d, want 0", c.Len()) - } -} - -func TestRotatingKVCache_SingleToken(t *testing.T) { - c := NewRotatingKVCache(8) - k, v := makeKV(1) - - outK, outV := c.Update(k, v, 1) - mlx.Materialize(outK, outV) - - if c.Offset() != 1 { - t.Errorf("offset = %d, want 1", c.Offset()) - } - if c.Len() != 1 { - t.Errorf("len = %d, want 1", c.Len()) - } -} - -func TestRotatingKVCache_MultiTokenPrompt(t *testing.T) { - c := NewRotatingKVCache(16) - k, v := makeKV(5) - - outK, outV := c.Update(k, v, 5) - mlx.Materialize(outK, outV) - - if c.Offset() != 5 { - t.Errorf("offset = %d, want 5", c.Offset()) - } - if c.Len() != 5 { - t.Errorf("len = %d, want 5", c.Len()) - } -} - -func TestRotatingKVCache_Bounded(t *testing.T) { - c := NewRotatingKVCache(4) - - // Fill with 4-token prompt (at max) - k, v := makeKV(4) - outK, outV := c.Update(k, v, 4) - mlx.Materialize(outK, outV) - - if c.Len() != 4 { - t.Errorf("len = %d, want 4 (at max)", c.Len()) - } - - // Add one more token — should trim to maxSize - k2, v2 := makeKV(1) - outK, outV = c.Update(k2, v2, 1) - mlx.Materialize(outK, outV) - - if c.Offset() != 5 { - t.Errorf("offset = %d, want 5", c.Offset()) - } - // Len should be bounded by maxSize - if c.Len() != 4 { - t.Errorf("len = %d, want 4 (bounded)", c.Len()) - } -} - -func TestRotatingKVCache_Reset(t *testing.T) { - c := NewRotatingKVCache(8) - k, v := makeKV(3) - c.Update(k, v, 3) - - c.Reset() - - if c.Offset() != 0 { - t.Errorf("offset after reset = %d, want 0", c.Offset()) - } - if c.Len() != 0 { - t.Errorf("len after reset = %d, want 0", c.Len()) - } - if c.State() != nil { - t.Error("state should be nil after reset") - } -} diff --git a/cache/cache.go b/internal/metal/cache.go similarity index 56% rename from cache/cache.go rename to internal/metal/cache.go index 28e15f5..c2630d8 100644 --- a/cache/cache.go +++ b/internal/metal/cache.go @@ -1,20 +1,17 @@ //go:build darwin && arm64 -// Package cache provides KV cache implementations for transformer inference. -package cache - -import "forge.lthn.ai/core/go-mlx" +package metal // Cache manages key-value pairs for transformer attention layers. type Cache interface { // Update adds new key/value tensors and returns the full cached K/V. - Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) + Update(k, v *Array, seqLen int) (*Array, *Array) // Offset returns the total number of tokens processed. Offset() int // Len returns the number of cached tokens (may differ from Offset for rotating caches). Len() int // State returns the cached K/V arrays, or nil if empty. - State() []*mlx.Array + State() []*Array // Reset clears the cache for a new generation session. Reset() } @@ -22,7 +19,7 @@ type Cache interface { // KVCache implements an unbounded cache that grows as needed. // Pre-allocates in chunks of `step` tokens to reduce allocations. type KVCache struct { - keys, values *mlx.Array + keys, values *Array offset int step int } @@ -32,7 +29,7 @@ func NewKVCache() *KVCache { return &KVCache{step: 256} } -func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { +func (c *KVCache) Update(k, v *Array, seqLen int) (*Array, *Array) { prev := c.offset shape := k.Shape() if len(shape) < 4 { @@ -49,34 +46,34 @@ func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { // Grow buffer if needed. if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) { nSteps := (c.step + seqLen - 1) / c.step - newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype()) - newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype()) + newK := Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype()) + newV := Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype()) if c.keys != nil { if prev%c.step != 0 { - c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk}) - c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv}) + c.keys = Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk}) + c.values = Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv}) } - c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2) - c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2) + c.keys = Concatenate([]*Array{c.keys, newK}, 2) + c.values = Concatenate([]*Array{c.values, newV}, 2) } else { c.keys, c.values = newK, newV } } c.offset += seqLen - c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk}) - c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv}) + c.keys = SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk}) + c.values = SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv}) - return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}), - mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv}) + return Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}), + Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv}) } -func (c *KVCache) State() []*mlx.Array { +func (c *KVCache) State() []*Array { if c.keys == nil { return nil } - return []*mlx.Array{c.keys, c.values} + return []*Array{c.keys, c.values} } func (c *KVCache) Offset() int { return c.offset } @@ -90,7 +87,7 @@ func (c *KVCache) Reset() { // RotatingKVCache implements a bounded sliding window cache. type RotatingKVCache struct { - keys, values *mlx.Array + keys, values *Array offset int maxSize int step int @@ -102,14 +99,14 @@ func NewRotatingKVCache(maxSize int) *RotatingKVCache { return &RotatingKVCache{maxSize: maxSize, step: 256} } -func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { +func (c *RotatingKVCache) Update(k, v *Array, seqLen int) (*Array, *Array) { if seqLen > 1 { return c.updateConcat(k, v, seqLen) } return c.updateInPlace(k, v) } -func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) { +func (c *RotatingKVCache) updateInPlace(k, v *Array) (*Array, *Array) { shape := k.Shape() if len(shape) < 4 { if c.keys == nil { @@ -127,11 +124,11 @@ func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array cap = int(c.keys.Shape()[2]) } newSize := min(c.step, c.maxSize-cap) - newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype()) - newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype()) + newK := Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype()) + newV := Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype()) if c.keys != nil { - c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2) - c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2) + c.keys = Concatenate([]*Array{c.keys, newK}, 2) + c.values = Concatenate([]*Array{c.values, newV}, 2) } else { c.keys, c.values = newK, newV } @@ -141,18 +138,18 @@ func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array c.idx = 0 } - c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk}) - c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv}) + c.keys = SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk}) + c.values = SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv}) c.offset++ c.idx++ validLen := int32(min(c.offset, c.maxSize)) - return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}), - mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv}) + return Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}), + Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv}) } -func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { +func (c *RotatingKVCache) updateConcat(k, v *Array, seqLen int) (*Array, *Array) { shape := k.Shape() if len(shape) < 4 { // K/V must be [B, H, L, D] — if not, pass through unchanged @@ -168,26 +165,26 @@ func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, if c.keys == nil { c.keys, c.values = k, v } else { - c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2) - c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2) + c.keys = Concatenate([]*Array{c.keys, k}, 2) + c.values = Concatenate([]*Array{c.values, v}, 2) } c.offset += seqLen cap := int(c.keys.Shape()[2]) if trim := cap - c.maxSize; trim > 0 { - c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk}) - c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv}) + c.keys = Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk}) + c.values = Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv}) } c.idx = int(c.keys.Shape()[2]) return c.keys, c.values } -func (c *RotatingKVCache) State() []*mlx.Array { +func (c *RotatingKVCache) State() []*Array { if c.keys == nil { return nil } - return []*mlx.Array{c.keys, c.values} + return []*Array{c.keys, c.values} } func (c *RotatingKVCache) Offset() int { return c.offset } diff --git a/model/gemma3.go b/internal/metal/gemma3.go similarity index 63% rename from model/gemma3.go rename to internal/metal/gemma3.go index 2d0d770..6429fae 100644 --- a/model/gemma3.go +++ b/internal/metal/gemma3.go @@ -1,7 +1,6 @@ //go:build darwin && arm64 -// Package model provides transformer model architectures for MLX inference. -package model +package metal import ( "encoding/json" @@ -10,10 +9,6 @@ import ( "math" "os" "path/filepath" - - "forge.lthn.ai/core/go-mlx" - "forge.lthn.ai/core/go-mlx/cache" - "forge.lthn.ai/core/go-mlx/tokenizer" ) // TextConfig holds Gemma 3 text model configuration. @@ -38,32 +33,32 @@ type TextConfig struct { // GemmaModel is the Gemma 3 text model. type GemmaModel struct { - EmbedTokens *mlx.Embedding + EmbedTokens *Embedding Layers []*DecoderLayer - Norm *mlx.RMSNormModule - Output *mlx.Linear // Tied to EmbedTokens + Norm *RMSNormModule + Output *Linear // Tied to EmbedTokens // Precomputed (1 + weight) for Gemma-style RMSNorm - NormScaled *mlx.Array + NormScaled *Array - Tok *tokenizer.Tokenizer + Tok *Tokenizer Cfg *TextConfig } // DecoderLayer is a single transformer block. type DecoderLayer struct { - InputNorm *mlx.RMSNormModule + InputNorm *RMSNormModule Attention *Attention - PostAttnNorm *mlx.RMSNormModule - PreFFNorm *mlx.RMSNormModule + PostAttnNorm *RMSNormModule + PreFFNorm *RMSNormModule MLP *MLP - PostFFNorm *mlx.RMSNormModule + PostFFNorm *RMSNormModule // Precomputed scaled weights - InputNormScaled *mlx.Array - PostAttnNormScaled *mlx.Array - PreFFNormScaled *mlx.Array - PostFFNormScaled *mlx.Array + InputNormScaled *Array + PostAttnNormScaled *Array + PreFFNormScaled *Array + PostFFNormScaled *Array IsSliding bool LayerIdx int32 @@ -71,31 +66,31 @@ type DecoderLayer struct { // Attention implements Gemma 3 attention with Q/K normalization. type Attention struct { - QProj *mlx.Linear - KProj *mlx.Linear - VProj *mlx.Linear - OProj *mlx.Linear - QNorm *mlx.RMSNormModule - KNorm *mlx.RMSNormModule + QProj *Linear + KProj *Linear + VProj *Linear + OProj *Linear + QNorm *RMSNormModule + KNorm *RMSNormModule - QNormScaled *mlx.Array - KNormScaled *mlx.Array + QNormScaled *Array + KNormScaled *Array } // MLP is the feed-forward network. type MLP struct { - GateProj *mlx.Linear - UpProj *mlx.Linear - DownProj *mlx.Linear + GateProj *Linear + UpProj *Linear + DownProj *Linear } // compiledGELU is a singleton for the compiled GELU function. -var compiledGELU *mlx.CompiledFunc +var compiledGELU *CompiledFunc -func getCompiledGELU() *mlx.CompiledFunc { +func getCompiledGELU() *CompiledFunc { if compiledGELU == nil { - compiledGELU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array { - return []*mlx.Array{geluApprox(inputs[0])} + compiledGELU = CompileShapeless(func(inputs []*Array) []*Array { + return []*Array{geluApprox(inputs[0])} }, true) } return compiledGELU @@ -103,16 +98,16 @@ func getCompiledGELU() *mlx.CompiledFunc { // geluApprox computes GELU using the tanh approximation: // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) -func geluApprox(x *mlx.Array) *mlx.Array { +func geluApprox(x *Array) *Array { const sqrt2OverPi = 0.7978845608028654 const coeff = 0.044715 - x3 := mlx.Mul(mlx.Mul(x, x), x) - inner := mlx.Add(x, mlx.MulScalar(x3, coeff)) - scaled := mlx.MulScalar(inner, sqrt2OverPi) - t := mlx.Tanh(scaled) - onePlusT := mlx.AddScalar(t, 1.0) - return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusT) + x3 := Mul(Mul(x, x), x) + inner := Add(x, MulScalar(x3, coeff)) + scaled := MulScalar(inner, sqrt2OverPi) + t := Tanh(scaled) + onePlusT := AddScalar(t, 1.0) + return Mul(MulScalar(x, 0.5), onePlusT) } // parseConfig handles both flat and nested (text_config) Gemma 3 configs. @@ -175,22 +170,22 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) { } // Load tokenizer - tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) + tok, err := LoadTokenizer(filepath.Join(modelPath, "tokenizer.json")) if err != nil { return nil, fmt.Errorf("gemma3: load tokenizer: %w", err) } // Load weights from all safetensors files - weights := make(map[string]*mlx.Array) + weights := make(map[string]*Array) matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors")) for _, path := range matches { - for name, arr := range mlx.LoadSafetensors(path) { + for name, arr := range LoadSafetensors(path) { weights[name] = arr } } // Helper to resolve weight with language_model. prefix fallback - w := func(name string) *mlx.Array { return resolveWeight(weights, name) } + w := func(name string) *Array { return resolveWeight(weights, name) } // Infer head_dim from q_proj weight shape when not in config. // Gemma 3 uses head_dim=256 which differs from hidden_size/num_heads. @@ -211,18 +206,18 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) { if q != nil { slog.Info("mlx: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize) } - linear := func(prefix string) *mlx.Linear { + linear := func(prefix string) *Linear { weight := w(prefix + ".weight") scales := w(prefix + ".scales") biases := w(prefix + ".biases") if scales != nil && q != nil { - return mlx.NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits) + return NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits) } - return mlx.NewLinear(weight, nil) + return NewLinear(weight, nil) } // Create embedding (quantized or dense) - embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")} + embed := &Embedding{Weight: w("model.embed_tokens.weight")} if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil { embed.Scales = embedScales embed.Biases = w("model.embed_tokens.biases") @@ -233,7 +228,7 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) { m := &GemmaModel{ EmbedTokens: embed, Layers: make([]*DecoderLayer, cfg.NumHiddenLayers), - Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")}, + Norm: &RMSNormModule{Weight: w("model.norm.weight")}, Tok: tok, Cfg: cfg, } @@ -242,17 +237,17 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) { for i := int32(0); i < cfg.NumHiddenLayers; i++ { prefix := fmt.Sprintf("model.layers.%d", i) m.Layers[i] = &DecoderLayer{ - InputNorm: &mlx.RMSNormModule{Weight: w(prefix + ".input_layernorm.weight")}, - PostAttnNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_attention_layernorm.weight")}, - PreFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")}, - PostFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")}, + InputNorm: &RMSNormModule{Weight: w(prefix + ".input_layernorm.weight")}, + PostAttnNorm: &RMSNormModule{Weight: w(prefix + ".post_attention_layernorm.weight")}, + PreFFNorm: &RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")}, + PostFFNorm: &RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")}, Attention: &Attention{ QProj: linear(prefix + ".self_attn.q_proj"), KProj: linear(prefix + ".self_attn.k_proj"), VProj: linear(prefix + ".self_attn.v_proj"), OProj: linear(prefix + ".self_attn.o_proj"), - QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")}, - KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")}, + QNorm: &RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")}, + KNorm: &RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")}, }, MLP: &MLP{ GateProj: linear(prefix + ".mlp.gate_proj"), @@ -269,9 +264,9 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) { if lmHeadWeight != nil { lmHeadScales := w("lm_head.scales") if lmHeadScales != nil && q != nil { - m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits) + m.Output = NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits) } else { - m.Output = mlx.NewLinear(lmHeadWeight, nil) + m.Output = NewLinear(lmHeadWeight, nil) } } else { // Tied embeddings — reuse embed_tokens weights (with quantization if present) @@ -279,11 +274,11 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) { } // Materialize all weights - var allArrays []*mlx.Array + var allArrays []*Array for _, a := range weights { allArrays = append(allArrays, a) } - mlx.Materialize(allArrays...) + Materialize(allArrays...) // Precompute (1 + weight) for Gemma-style RMSNorm precomputeScaledWeights(m) @@ -292,25 +287,25 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) { } func precomputeScaledWeights(m *GemmaModel) { - m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0) + m.NormScaled = AddScalar(m.Norm.Weight, 1.0) for _, layer := range m.Layers { - layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0) - layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0) - layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0) - layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0) - layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0) - layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0) + layer.InputNormScaled = AddScalar(layer.InputNorm.Weight, 1.0) + layer.PostAttnNormScaled = AddScalar(layer.PostAttnNorm.Weight, 1.0) + layer.PreFFNormScaled = AddScalar(layer.PreFFNorm.Weight, 1.0) + layer.PostFFNormScaled = AddScalar(layer.PostFFNorm.Weight, 1.0) + layer.Attention.QNormScaled = AddScalar(layer.Attention.QNorm.Weight, 1.0) + layer.Attention.KNormScaled = AddScalar(layer.Attention.KNorm.Weight, 1.0) } - var scaled []*mlx.Array + var scaled []*Array scaled = append(scaled, m.NormScaled) for _, layer := range m.Layers { scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled, layer.PreFFNormScaled, layer.PostFFNormScaled, layer.Attention.QNormScaled, layer.Attention.KNormScaled) } - mlx.Materialize(scaled...) + Materialize(scaled...) } func isLayerSliding(layerIdx, pattern int32) bool { @@ -321,56 +316,56 @@ func isLayerSliding(layerIdx, pattern int32) bool { } // Forward runs the text model forward pass. -func (m *GemmaModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { +func (m *GemmaModel) Forward(tokens *Array, caches []Cache) *Array { shape := tokens.Shape() B, L := shape[0], shape[1] h := m.EmbedTokens.Forward(tokens) - h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize)))) + h = MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize)))) for i, layer := range m.Layers { h = layer.forward(h, caches[i], B, L, m.Cfg) } - return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps)) + return m.Output.Forward(RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps)) } -func (l *DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array { - normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps) +func (l *DecoderLayer) forward(x *Array, c Cache, B, L int32, cfg *TextConfig) *Array { + normed := RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps) attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, cfg) - attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps) - h := mlx.Add(x, attnOut) + attnOut = RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps) + h := Add(x, attnOut) - normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps) + normed = RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps) mlpOut := l.MLP.forward(normed) - mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps) - return mlx.Add(h, mlpOut) + mlpOut = RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps) + return Add(h, mlpOut) } -func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array { +func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, cfg *TextConfig) *Array { q := a.QProj.Forward(x) k := a.KProj.Forward(x) v := a.VProj.Forward(x) // Reshape to [B, num_heads, L, head_dim] - q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, + q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) - k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) // Q/K normalization - q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps) - k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps) + q = RMSNorm(q, a.QNormScaled, cfg.RMSNormEps) + k = RMSNorm(k, a.KNormScaled, cfg.RMSNormEps) // RoPE with appropriate theta ropeTheta := cfg.RopeTheta if isSliding { ropeTheta = cfg.RopeLocalBaseFreq } - q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset()) - k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset()) + q = RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset()) + k = RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset()) // Update cache k, v = c.Update(k, v, int(L)) @@ -378,29 +373,29 @@ func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b // GQA: repeat K/V heads repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads if repeatFactor > 1 { - k = mlx.RepeatKV(k, repeatFactor) - v = mlx.RepeatKV(v, repeatFactor) + k = RepeatKV(k, repeatFactor) + v = RepeatKV(v, repeatFactor) } // Scaled dot-product attention - out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) - out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) + out := ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) + out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) return a.OProj.Forward(out) } -func (m *MLP) forward(x *mlx.Array) *mlx.Array { +func (m *MLP) forward(x *Array) *Array { gate := getCompiledGELU().Call(m.GateProj.Forward(x))[0] - return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x))) + return m.DownProj.Forward(Mul(gate, m.UpProj.Forward(x))) } // NewCache creates per-layer caches for generation. -func (m *GemmaModel) NewCache() []cache.Cache { - caches := make([]cache.Cache, len(m.Layers)) +func (m *GemmaModel) NewCache() []Cache { + caches := make([]Cache, len(m.Layers)) for i := range caches { if m.Layers[i].IsSliding { - caches[i] = cache.NewRotatingKVCache(int(m.Cfg.SlidingWindow)) + caches[i] = NewRotatingKVCache(int(m.Cfg.SlidingWindow)) } else { - caches[i] = cache.NewKVCache() + caches[i] = NewKVCache() } } return caches @@ -410,22 +405,22 @@ func (m *GemmaModel) NewCache() []cache.Cache { func (m *GemmaModel) NumLayers() int { return len(m.Layers) } // Tokenizer returns the model's tokenizer. -func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok } +func (m *GemmaModel) Tokenizer() *Tokenizer { return m.Tok } // ModelType returns the architecture identifier. func (m *GemmaModel) ModelType() string { return "gemma3" } // ApplyLoRA wraps target projection layers with LoRA adapters. -func (m *GemmaModel) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter { - adapter := &mlx.LoRAAdapter{ - Layers: make(map[string]*mlx.LoRALinear), +func (m *GemmaModel) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter { + adapter := &LoRAAdapter{ + Layers: make(map[string]*LoRALinear), Config: cfg, } for i, layer := range m.Layers { prefix := fmt.Sprintf("model.layers.%d.self_attn", i) for _, target := range cfg.TargetKeys { - var proj *mlx.Linear + var proj *Linear switch target { case "q_proj": proj = layer.Attention.QProj @@ -437,7 +432,7 @@ func (m *GemmaModel) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter { proj = layer.Attention.OProj } if proj != nil { - lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha) + lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha) proj.LoRA = lora adapter.Layers[prefix+"."+target] = lora } diff --git a/model/model.go b/internal/metal/model.go similarity index 69% rename from model/model.go rename to internal/metal/model.go index 408d8d1..db97a03 100644 --- a/model/model.go +++ b/internal/metal/model.go @@ -1,39 +1,34 @@ //go:build darwin && arm64 -// Package model provides transformer model architectures for MLX inference. -package model +package metal import ( "encoding/json" "fmt" "os" "path/filepath" - - "forge.lthn.ai/core/go-mlx" - "forge.lthn.ai/core/go-mlx/cache" - "forge.lthn.ai/core/go-mlx/tokenizer" ) -// Model is the common interface for all transformer model architectures. -type Model interface { +// InternalModel is the common interface for all transformer model architectures. +type InternalModel interface { // Forward runs the model forward pass on token IDs with KV caches. - Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array + Forward(tokens *Array, caches []Cache) *Array // NewCache creates per-layer KV caches for generation. - NewCache() []cache.Cache + NewCache() []Cache // NumLayers returns the number of transformer layers. NumLayers() int // Tokenizer returns the model's tokenizer. - Tokenizer() *tokenizer.Tokenizer + Tokenizer() *Tokenizer // ModelType returns the architecture identifier (e.g. "gemma3", "qwen3"). ModelType() string // ApplyLoRA wraps target projection layers with LoRA adapters for training. // Returns the adapter which holds references to all LoRA layers. - ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter + ApplyLoRA(cfg LoRAConfig) *LoRAAdapter } // QuantizationConfig holds quantization parameters from config.json. @@ -43,7 +38,7 @@ type QuantizationConfig struct { } // resolveWeight looks up a weight with optional "language_model." prefix. -func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array { +func resolveWeight(weights map[string]*Array, name string) *Array { if w, ok := weights[name]; ok { return w } @@ -53,8 +48,8 @@ func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array { return nil } -// LoadModel auto-detects the model architecture from config.json and loads it. -func LoadModel(modelPath string) (Model, error) { +// loadModel auto-detects the model architecture from config.json and loads it. +func loadModel(modelPath string) (InternalModel, error) { data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) if err != nil { return nil, fmt.Errorf("model: load config: %w", err) diff --git a/model/qwen3.go b/internal/metal/qwen3.go similarity index 69% rename from model/qwen3.go rename to internal/metal/qwen3.go index 0a7ca13..e288d0a 100644 --- a/model/qwen3.go +++ b/internal/metal/qwen3.go @@ -1,6 +1,6 @@ //go:build darwin && arm64 -package model +package metal import ( "encoding/json" @@ -9,10 +9,6 @@ import ( "math" "os" "path/filepath" - - "forge.lthn.ai/core/go-mlx" - "forge.lthn.ai/core/go-mlx/cache" - "forge.lthn.ai/core/go-mlx/tokenizer" ) // Qwen3Config holds Qwen 3 model configuration. @@ -34,39 +30,39 @@ type Qwen3Config struct { // Qwen3Model is the Qwen 3 text model. type Qwen3Model struct { - EmbedTokens *mlx.Embedding + EmbedTokens *Embedding Layers []*Qwen3DecoderLayer - Norm *mlx.RMSNormModule - Output *mlx.Linear + Norm *RMSNormModule + Output *Linear - Tok *tokenizer.Tokenizer + Tok *Tokenizer Cfg *Qwen3Config } // Qwen3DecoderLayer is a single transformer block. // Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add. type Qwen3DecoderLayer struct { - InputNorm *mlx.RMSNormModule // Pre-attention norm - PostAttnNorm *mlx.RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm) + InputNorm *RMSNormModule // Pre-attention norm + PostAttnNorm *RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm) Attention *Qwen3Attention MLP *Qwen3MLP } // Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization. type Qwen3Attention struct { - QProj *mlx.Linear - KProj *mlx.Linear - VProj *mlx.Linear - OProj *mlx.Linear - QNorm *mlx.RMSNormModule - KNorm *mlx.RMSNormModule + QProj *Linear + KProj *Linear + VProj *Linear + OProj *Linear + QNorm *RMSNormModule + KNorm *RMSNormModule } // Qwen3MLP is the SwiGLU feed-forward network: down(silu(gate(x)) * up(x)). type Qwen3MLP struct { - GateProj *mlx.Linear - UpProj *mlx.Linear - DownProj *mlx.Linear + GateProj *Linear + UpProj *Linear + DownProj *Linear } func parseQwen3Config(data []byte) (*Qwen3Config, error) { @@ -114,40 +110,40 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) { return nil, fmt.Errorf("qwen3: parse config: %w", err) } - tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) + tok, err := LoadTokenizer(filepath.Join(modelPath, "tokenizer.json")) if err != nil { return nil, fmt.Errorf("qwen3: load tokenizer: %w", err) } // Load weights from all safetensors files - weights := make(map[string]*mlx.Array) + weights := make(map[string]*Array) matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors")) for _, path := range matches { - for name, arr := range mlx.LoadSafetensors(path) { + for name, arr := range LoadSafetensors(path) { weights[name] = arr } } - w := func(name string) *mlx.Array { return resolveWeight(weights, name) } + w := func(name string) *Array { return resolveWeight(weights, name) } // Quantization setup q := cfg.Quantization if q != nil { slog.Info("qwen3: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize) } - linear := func(prefix string) *mlx.Linear { + linear := func(prefix string) *Linear { weight := w(prefix + ".weight") scales := w(prefix + ".scales") biases := w(prefix + ".biases") bias := w(prefix + ".bias") if scales != nil && q != nil { - return mlx.NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits) + return NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits) } - return mlx.NewLinear(weight, bias) + return NewLinear(weight, bias) } // Embedding - embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")} + embed := &Embedding{Weight: w("model.embed_tokens.weight")} if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil { embed.Scales = embedScales embed.Biases = w("model.embed_tokens.biases") @@ -158,7 +154,7 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) { m := &Qwen3Model{ EmbedTokens: embed, Layers: make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers), - Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")}, + Norm: &RMSNormModule{Weight: w("model.norm.weight")}, Tok: tok, Cfg: cfg, } @@ -166,15 +162,15 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) { for i := int32(0); i < cfg.NumHiddenLayers; i++ { p := fmt.Sprintf("model.layers.%d", i) m.Layers[i] = &Qwen3DecoderLayer{ - InputNorm: &mlx.RMSNormModule{Weight: w(p + ".input_layernorm.weight")}, - PostAttnNorm: &mlx.RMSNormModule{Weight: w(p + ".post_attention_layernorm.weight")}, + InputNorm: &RMSNormModule{Weight: w(p + ".input_layernorm.weight")}, + PostAttnNorm: &RMSNormModule{Weight: w(p + ".post_attention_layernorm.weight")}, Attention: &Qwen3Attention{ QProj: linear(p + ".self_attn.q_proj"), KProj: linear(p + ".self_attn.k_proj"), VProj: linear(p + ".self_attn.v_proj"), OProj: linear(p + ".self_attn.o_proj"), - QNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.q_norm.weight")}, - KNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.k_norm.weight")}, + QNorm: &RMSNormModule{Weight: w(p + ".self_attn.q_norm.weight")}, + KNorm: &RMSNormModule{Weight: w(p + ".self_attn.k_norm.weight")}, }, MLP: &Qwen3MLP{ GateProj: linear(p + ".mlp.gate_proj"), @@ -189,20 +185,20 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) { if lmHeadWeight != nil { lmHeadScales := w("lm_head.scales") if lmHeadScales != nil && q != nil { - m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits) + m.Output = NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits) } else { - m.Output = mlx.NewLinear(lmHeadWeight, nil) + m.Output = NewLinear(lmHeadWeight, nil) } } else { m.Output = m.EmbedTokens.AsLinear() } // Materialise all weights onto Metal - var allArrays []*mlx.Array + var allArrays []*Array for _, a := range weights { allArrays = append(allArrays, a) } - mlx.Materialize(allArrays...) + Materialize(allArrays...) slog.Info("qwen3: model loaded", "layers", cfg.NumHiddenLayers, @@ -218,7 +214,7 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) { // Forward runs the Qwen 3 forward pass. // Unlike Gemma, Qwen does NOT scale embeddings by sqrt(hidden_size). -func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { +func (m *Qwen3Model) Forward(tokens *Array, caches []Cache) *Array { shape := tokens.Shape() B, L := shape[0], shape[1] @@ -231,29 +227,29 @@ func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps)) } -func (l *Qwen3DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array { +func (l *Qwen3DecoderLayer) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config) *Array { // Pre-attention norm → attention → residual add normed := l.InputNorm.Forward(x, cfg.RMSNormEps) attnOut := l.Attention.forward(normed, c, B, L, cfg) - h := mlx.Add(x, attnOut) + h := Add(x, attnOut) // Pre-MLP norm → MLP → residual add normed = l.PostAttnNorm.Forward(h, cfg.RMSNormEps) mlpOut := l.MLP.forward(normed) - return mlx.Add(h, mlpOut) + return Add(h, mlpOut) } -func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array { +func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config) *Array { q := a.QProj.Forward(x) k := a.KProj.Forward(x) v := a.VProj.Forward(x) // Reshape to [B, num_heads, L, head_dim] - q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, + q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) - k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) // Q/K RMS normalization (Qwen 3 has this) @@ -261,8 +257,8 @@ func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Q k = a.KNorm.Forward(k, cfg.RMSNormEps) // RoPE — single theta for all layers (no sliding window) - q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) - k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) + q = RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) + k = RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) // Update KV cache k, v = c.Update(k, v, int(L)) @@ -270,27 +266,27 @@ func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Q // GQA: repeat K/V heads to match Q heads repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads if repeatFactor > 1 { - k = mlx.RepeatKV(k, repeatFactor) - v = mlx.RepeatKV(v, repeatFactor) + k = RepeatKV(k, repeatFactor) + v = RepeatKV(v, repeatFactor) } // Scaled dot-product attention - out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) - out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) + out := ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) + out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) return a.OProj.Forward(out) } // forward computes SwiGLU: down(silu(gate(x)) * up(x)). -func (m *Qwen3MLP) forward(x *mlx.Array) *mlx.Array { - gate := mlx.SiLU(m.GateProj.Forward(x)) - return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x))) +func (m *Qwen3MLP) forward(x *Array) *Array { + gate := SiLU(m.GateProj.Forward(x)) + return m.DownProj.Forward(Mul(gate, m.UpProj.Forward(x))) } // NewCache creates per-layer KV caches. Qwen 3 uses global attention only. -func (m *Qwen3Model) NewCache() []cache.Cache { - caches := make([]cache.Cache, len(m.Layers)) +func (m *Qwen3Model) NewCache() []Cache { + caches := make([]Cache, len(m.Layers)) for i := range caches { - caches[i] = cache.NewKVCache() + caches[i] = NewKVCache() } return caches } @@ -299,22 +295,22 @@ func (m *Qwen3Model) NewCache() []cache.Cache { func (m *Qwen3Model) NumLayers() int { return len(m.Layers) } // Tokenizer returns the model's tokenizer. -func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer { return m.Tok } +func (m *Qwen3Model) Tokenizer() *Tokenizer { return m.Tok } // ModelType returns the architecture identifier. func (m *Qwen3Model) ModelType() string { return "qwen3" } // ApplyLoRA wraps target projection layers with LoRA adapters. -func (m *Qwen3Model) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter { - adapter := &mlx.LoRAAdapter{ - Layers: make(map[string]*mlx.LoRALinear), +func (m *Qwen3Model) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter { + adapter := &LoRAAdapter{ + Layers: make(map[string]*LoRALinear), Config: cfg, } for i, layer := range m.Layers { prefix := fmt.Sprintf("model.layers.%d.self_attn", i) for _, target := range cfg.TargetKeys { - var proj *mlx.Linear + var proj *Linear switch target { case "q_proj": proj = layer.Attention.QProj @@ -326,7 +322,7 @@ func (m *Qwen3Model) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter { proj = layer.Attention.OProj } if proj != nil { - lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha) + lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha) proj.LoRA = lora adapter.Layers[prefix+"."+target] = lora } diff --git a/sample/sample.go b/internal/metal/sample.go similarity index 60% rename from sample/sample.go rename to internal/metal/sample.go index 4a6b599..7ef7f97 100644 --- a/sample/sample.go +++ b/internal/metal/sample.go @@ -1,22 +1,19 @@ //go:build darwin && arm64 -// Package sample provides composable token sampling strategies. -package sample +package metal import ( "math" - - "forge.lthn.ai/core/go-mlx" ) // Sampler transforms logits into a sampled token index. type Sampler interface { - Sample(logits *mlx.Array) *mlx.Array + Sample(logits *Array) *Array } -// New creates a composable sampler chain from the given parameters. +// newSampler creates a composable sampler chain from the given parameters. // Order: TopP -> MinP -> TopK -> Temperature -> categorical sample. -func New(temp, topP, minP float32, topK int) Sampler { +func newSampler(temp, topP, minP float32, topK int) Sampler { if temp == 0 { return greedy{} } @@ -38,43 +35,43 @@ func New(temp, topP, minP float32, topK int) Sampler { // chain applies a sequence of samplers, then samples from the result. type chain []Sampler -func (c chain) Sample(logits *mlx.Array) *mlx.Array { +func (c chain) Sample(logits *Array) *Array { for _, s := range c { logits = s.Sample(logits) } // Final categorical sample from log-probabilities - return mlx.RandomCategorical(logits) + return RandomCategorical(logits) } // greedy returns the argmax token. type greedy struct{} -func (greedy) Sample(logits *mlx.Array) *mlx.Array { - return mlx.Argmax(logits, -1, false) +func (greedy) Sample(logits *Array) *Array { + return Argmax(logits, -1, false) } // Temperature scales logits by 1/temp. type Temperature float32 -func (t Temperature) Sample(logits *mlx.Array) *mlx.Array { - return mlx.MulScalar(logits, 1.0/float32(t)) +func (t Temperature) Sample(logits *Array) *Array { + return MulScalar(logits, 1.0/float32(t)) } // TopKSampler masks all but the top-k logits. type TopKSampler int -func (k TopKSampler) Sample(logits *mlx.Array) *mlx.Array { - neg := mlx.Negative(logits) - mask := mlx.Argpartition(neg, int(k)-1, -1) +func (k TopKSampler) Sample(logits *Array) *Array { + neg := Negative(logits) + mask := Argpartition(neg, int(k)-1, -1) // Slice the indices beyond top-k - mask = mlx.SliceAxis(mask, -1, int32(k), int32(logits.Dim(-1))) - return mlx.PutAlongAxis(logits, mask, mlx.FromValue(float32(math.Inf(-1))), -1) + mask = SliceAxis(mask, -1, int32(k), int32(logits.Dim(-1))) + return PutAlongAxis(logits, mask, FromValue(float32(math.Inf(-1))), -1) } // TopP implements nucleus sampling (cumulative probability threshold). type TopP float32 -func (p TopP) Sample(logits *mlx.Array) *mlx.Array { +func (p TopP) Sample(logits *Array) *Array { // TODO: full nucleus sampling requires cumsum which mlx-c doesn't expose directly. // For now, pass through. TopK + Temperature covers most use cases. return logits @@ -83,7 +80,7 @@ func (p TopP) Sample(logits *mlx.Array) *mlx.Array { // MinPSampler masks tokens below min_p * max_prob. type MinPSampler float32 -func (p MinPSampler) Sample(logits *mlx.Array) *mlx.Array { +func (p MinPSampler) Sample(logits *Array) *Array { // For now, pass through — MinP is an optimization over TopP. // Full implementation requires finding max prob and masking below threshold. return logits diff --git a/tokenizer/tokenizer.go b/internal/metal/tokenizer.go similarity index 96% rename from tokenizer/tokenizer.go rename to internal/metal/tokenizer.go index 947a1a5..048202a 100644 --- a/tokenizer/tokenizer.go +++ b/internal/metal/tokenizer.go @@ -1,7 +1,6 @@ //go:build darwin && arm64 -// Package tokenizer provides BPE tokenization for transformer models. -package tokenizer +package metal import ( "encoding/json" @@ -46,8 +45,8 @@ type tokenizerJSON struct { } `json:"added_tokens"` } -// Load reads a tokenizer.json file and creates a Tokenizer. -func Load(path string) (*Tokenizer, error) { +// LoadTokenizer reads a tokenizer.json file and creates a Tokenizer. +func LoadTokenizer(path string) (*Tokenizer, error) { data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("tokenizer: read %s: %w", path, err) @@ -294,7 +293,7 @@ func (t *Tokenizer) DecodeToken(id int32) string { return t.decodeGPT2Bytes(text) } - // SentencePiece: replace ▁ with space but keep it (it's the word boundary) + // SentencePiece: replace with space but keep it (it's the word boundary) return strings.ReplaceAll(text, "▁", " ") } diff --git a/sample/sample_test.go b/sample/sample_test.go deleted file mode 100644 index 7cd3e62..0000000 --- a/sample/sample_test.go +++ /dev/null @@ -1,116 +0,0 @@ -//go:build darwin && arm64 - -package sample - -import ( - "testing" - - "forge.lthn.ai/core/go-mlx" -) - -func TestGreedy(t *testing.T) { - // Logits heavily favour index 2 - logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4) - s := New(0, 0, 0, 0) // temp=0 → greedy - token := s.Sample(logits) - mlx.Materialize(token) - - if token.Int() != 2 { - t.Errorf("greedy sample = %d, want 2", token.Int()) - } -} - -func TestTemperature_HighTemp(t *testing.T) { - // High temperature should still produce a valid index - logits := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4) - s := New(100.0, 0, 0, 0) // very high temp → near uniform - token := s.Sample(logits) - mlx.Materialize(token) - - idx := token.Int() - if idx < 0 || idx >= 4 { - t.Errorf("sample index = %d, out of range [0, 4)", idx) - } -} - -func TestTemperature_LowTemp(t *testing.T) { - // Very low temperature should behave like greedy - logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4) - s := New(0.001, 0, 0, 0) // near-zero temp → near-greedy - token := s.Sample(logits) - mlx.Materialize(token) - - if token.Int() != 2 { - t.Errorf("low-temp sample = %d, want 2 (near greedy)", token.Int()) - } -} - -func TestTopK(t *testing.T) { - // TopK=1 with clear winner should always pick that token - logits := mlx.FromValues([]float32{-100, 100, -100, -100}, 1, 4) - s := New(1.0, 0, 0, 1) // topK=1 - token := s.Sample(logits) - mlx.Materialize(token) - - if token.Int() != 1 { - t.Errorf("topk=1 sample = %d, want 1", token.Int()) - } -} - -func TestTopK_MultipleTokens(t *testing.T) { - // TopK=2, both high logits — should pick one of them - logits := mlx.FromValues([]float32{-100, 50, 50, -100}, 1, 4) - s := New(1.0, 0, 0, 2) // topK=2 - - seen := map[int]bool{} - for range 20 { - token := s.Sample(logits) - mlx.Materialize(token) - seen[token.Int()] = true - } - - // Should only ever pick index 1 or 2 - for idx := range seen { - if idx != 1 && idx != 2 { - t.Errorf("topk=2 sampled index %d, expected only 1 or 2", idx) - } - } -} - -func TestNew_Chain(t *testing.T) { - // Full chain: topK + temperature - logits := mlx.FromValues([]float32{1, 2, 3, 4, 5}, 1, 5) - s := New(0.5, 0, 0, 3) // temp=0.5, topK=3 - - token := s.Sample(logits) - mlx.Materialize(token) - - idx := token.Int() - if idx < 0 || idx >= 5 { - t.Errorf("chain sample index = %d, out of range", idx) - } -} - -func TestTopP_PassThrough(t *testing.T) { - // TopP is currently a stub — verify it doesn't break the chain - logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4) - s := New(0.5, 0.9, 0, 0) // topP=0.9 (stub), temp=0.5 - token := s.Sample(logits) - mlx.Materialize(token) - - if token.Int() != 2 { - t.Errorf("topP stub + temp sample = %d, want 2", token.Int()) - } -} - -func TestMinP_PassThrough(t *testing.T) { - // MinP is currently a stub — verify it doesn't break the chain - logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4) - s := New(0.5, 0, 0.1, 0) // minP=0.1 (stub), temp=0.5 - token := s.Sample(logits) - mlx.Materialize(token) - - if token.Int() != 2 { - t.Errorf("minP stub + temp sample = %d, want 2", token.Int()) - } -} diff --git a/tokenizer/tokenizer_test.go b/tokenizer/tokenizer_test.go deleted file mode 100644 index 2230f5f..0000000 --- a/tokenizer/tokenizer_test.go +++ /dev/null @@ -1,217 +0,0 @@ -//go:build darwin && arm64 - -package tokenizer - -import ( - "os" - "path/filepath" - "testing" -) - -// minimalTokenizerJSON is a valid HuggingFace tokenizer.json with a tiny vocab. -const minimalTokenizerJSON = `{ - "model": { - "type": "BPE", - "vocab": { - "h": 0, - "e": 1, - "l": 2, - "o": 3, - "▁": 4, - "he": 5, - "ll": 6, - "▁h": 7 - }, - "merges": ["h e", "l l"], - "byte_fallback": false - }, - "added_tokens": [ - {"id": 100, "content": "", "special": true}, - {"id": 101, "content": "", "special": true} - ] -}` - -func writeTestTokenizer(t *testing.T) string { - t.Helper() - dir := t.TempDir() - path := filepath.Join(dir, "tokenizer.json") - if err := os.WriteFile(path, []byte(minimalTokenizerJSON), 0644); err != nil { - t.Fatalf("write test tokenizer: %v", err) - } - return path -} - -func TestLoad(t *testing.T) { - path := writeTestTokenizer(t) - tok, err := Load(path) - if err != nil { - t.Fatalf("Load: %v", err) - } - if tok == nil { - t.Fatal("tokenizer is nil") - } -} - -func TestLoad_MissingFile(t *testing.T) { - _, err := Load("/nonexistent/tokenizer.json") - if err == nil { - t.Error("expected error for missing file") - } -} - -func TestLoad_InvalidJSON(t *testing.T) { - dir := t.TempDir() - path := filepath.Join(dir, "tokenizer.json") - os.WriteFile(path, []byte("not json"), 0644) - - _, err := Load(path) - if err == nil { - t.Error("expected error for invalid JSON") - } -} - -func TestBOSEOS(t *testing.T) { - path := writeTestTokenizer(t) - tok, _ := Load(path) - - if tok.BOSToken() != 100 { - t.Errorf("BOS = %d, want 100", tok.BOSToken()) - } - if tok.EOSToken() != 101 { - t.Errorf("EOS = %d, want 101", tok.EOSToken()) - } -} - -func TestEncode_ProducesTokens(t *testing.T) { - path := writeTestTokenizer(t) - tok, _ := Load(path) - - tokens := tok.Encode("hello") - if len(tokens) == 0 { - t.Fatal("Encode returned empty tokens") - } - // First token should be BOS - if tokens[0] != tok.BOSToken() { - t.Errorf("first token = %d, want BOS (%d)", tokens[0], tok.BOSToken()) - } - t.Logf("Encode(\"hello\") = %v", tokens) -} - -func TestDecode_SpecialTokensSkipped(t *testing.T) { - path := writeTestTokenizer(t) - tok, _ := Load(path) - - // Decoding BOS/EOS should produce empty string - text := tok.Decode([]int32{100, 101}) - if text != "" { - t.Errorf("Decode(BOS, EOS) = %q, want empty", text) - } -} - -func TestDecode_RegularTokens(t *testing.T) { - path := writeTestTokenizer(t) - tok, _ := Load(path) - - // Decode known vocab entries - text := tok.Decode([]int32{5, 6, 3}) // "he" + "ll" + "o" - if text != "hello" { - t.Errorf("Decode = %q, want %q", text, "hello") - } -} - -func TestDecodeToken_Regular(t *testing.T) { - path := writeTestTokenizer(t) - tok, _ := Load(path) - - // "he" = token 5 - text := tok.DecodeToken(5) - if text != "he" { - t.Errorf("DecodeToken(5) = %q, want %q", text, "he") - } -} - -func TestDecodeToken_Special(t *testing.T) { - path := writeTestTokenizer(t) - tok, _ := Load(path) - - // Special tokens should return empty - text := tok.DecodeToken(100) - if text != "" { - t.Errorf("DecodeToken(BOS) = %q, want empty", text) - } -} - -func TestDecodeToken_SentencePieceSpace(t *testing.T) { - path := writeTestTokenizer(t) - tok, _ := Load(path) - - // "▁h" = token 7, should decode to " h" (space prefix) - text := tok.DecodeToken(7) - if text != " h" { - t.Errorf("DecodeToken(7) = %q, want %q", text, " h") - } -} - -func TestDecodeToken_Unknown(t *testing.T) { - path := writeTestTokenizer(t) - tok, _ := Load(path) - - text := tok.DecodeToken(9999) - if text != "" { - t.Errorf("DecodeToken(unknown) = %q, want empty", text) - } -} - -func TestFormatGemmaPrompt(t *testing.T) { - got := FormatGemmaPrompt("What is 2+2?") - want := "user\nWhat is 2+2?\nmodel\n" - if got != want { - t.Errorf("FormatGemmaPrompt = %q, want %q", got, want) - } -} - -// --- GPT-2 byte maps --- - -func TestBuildGPT2ByteMaps(t *testing.T) { - decoder, encoder := buildGPT2ByteMaps() - - // All 256 bytes must be mapped - if len(encoder) != 256 { - t.Errorf("encoder has %d entries, want 256", len(encoder)) - } - if len(decoder) != 256 { - t.Errorf("decoder has %d entries, want 256", len(decoder)) - } - - // Round-trip: every byte should survive encode → decode - for b := 0; b < 256; b++ { - r := encoder[byte(b)] - got := decoder[r] - if got != byte(b) { - t.Errorf("byte %d: encode→decode = %d, want %d", b, got, b) - } - } -} - -func TestBuildGPT2ByteMaps_PrintableASCII(t *testing.T) { - _, encoder := buildGPT2ByteMaps() - - // Printable ASCII (33-126) should self-map - for b := 33; b <= 126; b++ { - if encoder[byte(b)] != rune(b) { - t.Errorf("byte %d (%c): expected self-map, got %c", b, b, encoder[byte(b)]) - } - } -} - -func TestBuildGPT2ByteMaps_ControlChars(t *testing.T) { - _, encoder := buildGPT2ByteMaps() - - // Space (32) and control chars (0-31) should NOT self-map - if encoder[byte(32)] == rune(32) { - t.Error("space (32) should not self-map in GPT-2 encoding") - } - if encoder[byte(0)] == rune(0) { - t.Error("null (0) should not self-map in GPT-2 encoding") - } -}