refactor(metal): flatten model, tokenizer, sample, cache into internal/metal
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
a669d1d9c1
commit
08976aa504
9 changed files with 223 additions and 777 deletions
200
cache/cache_test.go
vendored
200
cache/cache_test.go
vendored
|
|
@ -1,200 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"forge.lthn.ai/core/go-mlx"
|
||||
)
|
||||
|
||||
// makeKV creates a small K/V pair with shape [B=1, H=2, L=seqLen, D=4].
|
||||
func makeKV(seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
size := 1 * 2 * seqLen * 4
|
||||
data := make([]float32, size)
|
||||
for i := range data {
|
||||
data[i] = float32(i) * 0.1
|
||||
}
|
||||
k := mlx.FromValues(data, 1, 2, seqLen, 4)
|
||||
v := mlx.FromValues(data, 1, 2, seqLen, 4)
|
||||
return k, v
|
||||
}
|
||||
|
||||
// --- KVCache ---
|
||||
|
||||
func TestKVCache_New(t *testing.T) {
|
||||
c := NewKVCache()
|
||||
if c.Offset() != 0 {
|
||||
t.Errorf("offset = %d, want 0", c.Offset())
|
||||
}
|
||||
if c.Len() != 0 {
|
||||
t.Errorf("len = %d, want 0", c.Len())
|
||||
}
|
||||
if c.State() != nil {
|
||||
t.Error("state should be nil for empty cache")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCache_SingleUpdate(t *testing.T) {
|
||||
c := NewKVCache()
|
||||
k, v := makeKV(3) // 3 tokens
|
||||
|
||||
outK, outV := c.Update(k, v, 3)
|
||||
mlx.Materialize(outK, outV)
|
||||
|
||||
if c.Offset() != 3 {
|
||||
t.Errorf("offset = %d, want 3", c.Offset())
|
||||
}
|
||||
if c.Len() != 3 {
|
||||
t.Errorf("len = %d, want 3", c.Len())
|
||||
}
|
||||
|
||||
// Output K should have shape [1, 2, 3, 4]
|
||||
shape := outK.Shape()
|
||||
if shape[0] != 1 || shape[1] != 2 || shape[2] != 3 || shape[3] != 4 {
|
||||
t.Errorf("outK shape = %v, want [1 2 3 4]", shape)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCache_MultipleUpdates(t *testing.T) {
|
||||
c := NewKVCache()
|
||||
|
||||
// Prompt: 5 tokens
|
||||
k1, v1 := makeKV(5)
|
||||
outK, outV := c.Update(k1, v1, 5)
|
||||
mlx.Materialize(outK, outV)
|
||||
|
||||
if c.Offset() != 5 {
|
||||
t.Errorf("offset = %d, want 5", c.Offset())
|
||||
}
|
||||
|
||||
// Generate: 1 token at a time
|
||||
k2, v2 := makeKV(1)
|
||||
outK, outV = c.Update(k2, v2, 1)
|
||||
mlx.Materialize(outK, outV)
|
||||
|
||||
if c.Offset() != 6 {
|
||||
t.Errorf("offset = %d, want 6", c.Offset())
|
||||
}
|
||||
|
||||
shape := outK.Shape()
|
||||
if shape[2] != 6 {
|
||||
t.Errorf("outK L dim = %d, want 6", shape[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCache_Reset(t *testing.T) {
|
||||
c := NewKVCache()
|
||||
k, v := makeKV(3)
|
||||
c.Update(k, v, 3)
|
||||
|
||||
c.Reset()
|
||||
|
||||
if c.Offset() != 0 {
|
||||
t.Errorf("offset after reset = %d, want 0", c.Offset())
|
||||
}
|
||||
if c.State() != nil {
|
||||
t.Error("state should be nil after reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCache_State(t *testing.T) {
|
||||
c := NewKVCache()
|
||||
k, v := makeKV(2)
|
||||
c.Update(k, v, 2)
|
||||
|
||||
state := c.State()
|
||||
if len(state) != 2 {
|
||||
t.Fatalf("state length = %d, want 2", len(state))
|
||||
}
|
||||
// state[0] = keys, state[1] = values
|
||||
if state[0] == nil || state[1] == nil {
|
||||
t.Error("state arrays should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
// --- RotatingKVCache ---
|
||||
|
||||
func TestRotatingKVCache_New(t *testing.T) {
|
||||
c := NewRotatingKVCache(16)
|
||||
if c.Offset() != 0 {
|
||||
t.Errorf("offset = %d, want 0", c.Offset())
|
||||
}
|
||||
if c.Len() != 0 {
|
||||
t.Errorf("len = %d, want 0", c.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotatingKVCache_SingleToken(t *testing.T) {
|
||||
c := NewRotatingKVCache(8)
|
||||
k, v := makeKV(1)
|
||||
|
||||
outK, outV := c.Update(k, v, 1)
|
||||
mlx.Materialize(outK, outV)
|
||||
|
||||
if c.Offset() != 1 {
|
||||
t.Errorf("offset = %d, want 1", c.Offset())
|
||||
}
|
||||
if c.Len() != 1 {
|
||||
t.Errorf("len = %d, want 1", c.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotatingKVCache_MultiTokenPrompt(t *testing.T) {
|
||||
c := NewRotatingKVCache(16)
|
||||
k, v := makeKV(5)
|
||||
|
||||
outK, outV := c.Update(k, v, 5)
|
||||
mlx.Materialize(outK, outV)
|
||||
|
||||
if c.Offset() != 5 {
|
||||
t.Errorf("offset = %d, want 5", c.Offset())
|
||||
}
|
||||
if c.Len() != 5 {
|
||||
t.Errorf("len = %d, want 5", c.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotatingKVCache_Bounded(t *testing.T) {
|
||||
c := NewRotatingKVCache(4)
|
||||
|
||||
// Fill with 4-token prompt (at max)
|
||||
k, v := makeKV(4)
|
||||
outK, outV := c.Update(k, v, 4)
|
||||
mlx.Materialize(outK, outV)
|
||||
|
||||
if c.Len() != 4 {
|
||||
t.Errorf("len = %d, want 4 (at max)", c.Len())
|
||||
}
|
||||
|
||||
// Add one more token — should trim to maxSize
|
||||
k2, v2 := makeKV(1)
|
||||
outK, outV = c.Update(k2, v2, 1)
|
||||
mlx.Materialize(outK, outV)
|
||||
|
||||
if c.Offset() != 5 {
|
||||
t.Errorf("offset = %d, want 5", c.Offset())
|
||||
}
|
||||
// Len should be bounded by maxSize
|
||||
if c.Len() != 4 {
|
||||
t.Errorf("len = %d, want 4 (bounded)", c.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotatingKVCache_Reset(t *testing.T) {
|
||||
c := NewRotatingKVCache(8)
|
||||
k, v := makeKV(3)
|
||||
c.Update(k, v, 3)
|
||||
|
||||
c.Reset()
|
||||
|
||||
if c.Offset() != 0 {
|
||||
t.Errorf("offset after reset = %d, want 0", c.Offset())
|
||||
}
|
||||
if c.Len() != 0 {
|
||||
t.Errorf("len after reset = %d, want 0", c.Len())
|
||||
}
|
||||
if c.State() != nil {
|
||||
t.Error("state should be nil after reset")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,20 +1,17 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
// Package cache provides KV cache implementations for transformer inference.
|
||||
package cache
|
||||
|
||||
import "forge.lthn.ai/core/go-mlx"
|
||||
package metal
|
||||
|
||||
// Cache manages key-value pairs for transformer attention layers.
|
||||
type Cache interface {
|
||||
// Update adds new key/value tensors and returns the full cached K/V.
|
||||
Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
|
||||
Update(k, v *Array, seqLen int) (*Array, *Array)
|
||||
// Offset returns the total number of tokens processed.
|
||||
Offset() int
|
||||
// Len returns the number of cached tokens (may differ from Offset for rotating caches).
|
||||
Len() int
|
||||
// State returns the cached K/V arrays, or nil if empty.
|
||||
State() []*mlx.Array
|
||||
State() []*Array
|
||||
// Reset clears the cache for a new generation session.
|
||||
Reset()
|
||||
}
|
||||
|
|
@ -22,7 +19,7 @@ type Cache interface {
|
|||
// KVCache implements an unbounded cache that grows as needed.
|
||||
// Pre-allocates in chunks of `step` tokens to reduce allocations.
|
||||
type KVCache struct {
|
||||
keys, values *mlx.Array
|
||||
keys, values *Array
|
||||
offset int
|
||||
step int
|
||||
}
|
||||
|
|
@ -32,7 +29,7 @@ func NewKVCache() *KVCache {
|
|||
return &KVCache{step: 256}
|
||||
}
|
||||
|
||||
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
func (c *KVCache) Update(k, v *Array, seqLen int) (*Array, *Array) {
|
||||
prev := c.offset
|
||||
shape := k.Shape()
|
||||
if len(shape) < 4 {
|
||||
|
|
@ -49,34 +46,34 @@ func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
|||
// Grow buffer if needed.
|
||||
if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
|
||||
nSteps := (c.step + seqLen - 1) / c.step
|
||||
newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
|
||||
newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
|
||||
newK := Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
|
||||
newV := Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
|
||||
|
||||
if c.keys != nil {
|
||||
if prev%c.step != 0 {
|
||||
c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
|
||||
c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
|
||||
c.keys = Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
|
||||
c.values = Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
|
||||
}
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
|
||||
c.keys = Concatenate([]*Array{c.keys, newK}, 2)
|
||||
c.values = Concatenate([]*Array{c.values, newV}, 2)
|
||||
} else {
|
||||
c.keys, c.values = newK, newV
|
||||
}
|
||||
}
|
||||
|
||||
c.offset += seqLen
|
||||
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
|
||||
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
c.keys = SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
|
||||
c.values = SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
|
||||
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
|
||||
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
return Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
|
||||
Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
}
|
||||
|
||||
func (c *KVCache) State() []*mlx.Array {
|
||||
func (c *KVCache) State() []*Array {
|
||||
if c.keys == nil {
|
||||
return nil
|
||||
}
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
return []*Array{c.keys, c.values}
|
||||
}
|
||||
|
||||
func (c *KVCache) Offset() int { return c.offset }
|
||||
|
|
@ -90,7 +87,7 @@ func (c *KVCache) Reset() {
|
|||
|
||||
// RotatingKVCache implements a bounded sliding window cache.
|
||||
type RotatingKVCache struct {
|
||||
keys, values *mlx.Array
|
||||
keys, values *Array
|
||||
offset int
|
||||
maxSize int
|
||||
step int
|
||||
|
|
@ -102,14 +99,14 @@ func NewRotatingKVCache(maxSize int) *RotatingKVCache {
|
|||
return &RotatingKVCache{maxSize: maxSize, step: 256}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
func (c *RotatingKVCache) Update(k, v *Array, seqLen int) (*Array, *Array) {
|
||||
if seqLen > 1 {
|
||||
return c.updateConcat(k, v, seqLen)
|
||||
}
|
||||
return c.updateInPlace(k, v)
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
func (c *RotatingKVCache) updateInPlace(k, v *Array) (*Array, *Array) {
|
||||
shape := k.Shape()
|
||||
if len(shape) < 4 {
|
||||
if c.keys == nil {
|
||||
|
|
@ -127,11 +124,11 @@ func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array
|
|||
cap = int(c.keys.Shape()[2])
|
||||
}
|
||||
newSize := min(c.step, c.maxSize-cap)
|
||||
newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
|
||||
newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
|
||||
newK := Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
|
||||
newV := Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
|
||||
if c.keys != nil {
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
|
||||
c.keys = Concatenate([]*Array{c.keys, newK}, 2)
|
||||
c.values = Concatenate([]*Array{c.values, newV}, 2)
|
||||
} else {
|
||||
c.keys, c.values = newK, newV
|
||||
}
|
||||
|
|
@ -141,18 +138,18 @@ func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array
|
|||
c.idx = 0
|
||||
}
|
||||
|
||||
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
|
||||
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
|
||||
c.keys = SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
|
||||
c.values = SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
|
||||
|
||||
c.offset++
|
||||
c.idx++
|
||||
|
||||
validLen := int32(min(c.offset, c.maxSize))
|
||||
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
|
||||
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
|
||||
return Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
|
||||
Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
func (c *RotatingKVCache) updateConcat(k, v *Array, seqLen int) (*Array, *Array) {
|
||||
shape := k.Shape()
|
||||
if len(shape) < 4 {
|
||||
// K/V must be [B, H, L, D] — if not, pass through unchanged
|
||||
|
|
@ -168,26 +165,26 @@ func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array,
|
|||
if c.keys == nil {
|
||||
c.keys, c.values = k, v
|
||||
} else {
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
|
||||
c.keys = Concatenate([]*Array{c.keys, k}, 2)
|
||||
c.values = Concatenate([]*Array{c.values, v}, 2)
|
||||
}
|
||||
c.offset += seqLen
|
||||
|
||||
cap := int(c.keys.Shape()[2])
|
||||
if trim := cap - c.maxSize; trim > 0 {
|
||||
c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
|
||||
c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
|
||||
c.keys = Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
|
||||
c.values = Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
|
||||
}
|
||||
|
||||
c.idx = int(c.keys.Shape()[2])
|
||||
return c.keys, c.values
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) State() []*mlx.Array {
|
||||
func (c *RotatingKVCache) State() []*Array {
|
||||
if c.keys == nil {
|
||||
return nil
|
||||
}
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
return []*Array{c.keys, c.values}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Offset() int { return c.offset }
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
// Package model provides transformer model architectures for MLX inference.
|
||||
package model
|
||||
package metal
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
|
@ -10,10 +9,6 @@ import (
|
|||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"forge.lthn.ai/core/go-mlx"
|
||||
"forge.lthn.ai/core/go-mlx/cache"
|
||||
"forge.lthn.ai/core/go-mlx/tokenizer"
|
||||
)
|
||||
|
||||
// TextConfig holds Gemma 3 text model configuration.
|
||||
|
|
@ -38,32 +33,32 @@ type TextConfig struct {
|
|||
|
||||
// GemmaModel is the Gemma 3 text model.
|
||||
type GemmaModel struct {
|
||||
EmbedTokens *mlx.Embedding
|
||||
EmbedTokens *Embedding
|
||||
Layers []*DecoderLayer
|
||||
Norm *mlx.RMSNormModule
|
||||
Output *mlx.Linear // Tied to EmbedTokens
|
||||
Norm *RMSNormModule
|
||||
Output *Linear // Tied to EmbedTokens
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
||||
NormScaled *mlx.Array
|
||||
NormScaled *Array
|
||||
|
||||
Tok *tokenizer.Tokenizer
|
||||
Tok *Tokenizer
|
||||
Cfg *TextConfig
|
||||
}
|
||||
|
||||
// DecoderLayer is a single transformer block.
|
||||
type DecoderLayer struct {
|
||||
InputNorm *mlx.RMSNormModule
|
||||
InputNorm *RMSNormModule
|
||||
Attention *Attention
|
||||
PostAttnNorm *mlx.RMSNormModule
|
||||
PreFFNorm *mlx.RMSNormModule
|
||||
PostAttnNorm *RMSNormModule
|
||||
PreFFNorm *RMSNormModule
|
||||
MLP *MLP
|
||||
PostFFNorm *mlx.RMSNormModule
|
||||
PostFFNorm *RMSNormModule
|
||||
|
||||
// Precomputed scaled weights
|
||||
InputNormScaled *mlx.Array
|
||||
PostAttnNormScaled *mlx.Array
|
||||
PreFFNormScaled *mlx.Array
|
||||
PostFFNormScaled *mlx.Array
|
||||
InputNormScaled *Array
|
||||
PostAttnNormScaled *Array
|
||||
PreFFNormScaled *Array
|
||||
PostFFNormScaled *Array
|
||||
|
||||
IsSliding bool
|
||||
LayerIdx int32
|
||||
|
|
@ -71,31 +66,31 @@ type DecoderLayer struct {
|
|||
|
||||
// Attention implements Gemma 3 attention with Q/K normalization.
|
||||
type Attention struct {
|
||||
QProj *mlx.Linear
|
||||
KProj *mlx.Linear
|
||||
VProj *mlx.Linear
|
||||
OProj *mlx.Linear
|
||||
QNorm *mlx.RMSNormModule
|
||||
KNorm *mlx.RMSNormModule
|
||||
QProj *Linear
|
||||
KProj *Linear
|
||||
VProj *Linear
|
||||
OProj *Linear
|
||||
QNorm *RMSNormModule
|
||||
KNorm *RMSNormModule
|
||||
|
||||
QNormScaled *mlx.Array
|
||||
KNormScaled *mlx.Array
|
||||
QNormScaled *Array
|
||||
KNormScaled *Array
|
||||
}
|
||||
|
||||
// MLP is the feed-forward network.
|
||||
type MLP struct {
|
||||
GateProj *mlx.Linear
|
||||
UpProj *mlx.Linear
|
||||
DownProj *mlx.Linear
|
||||
GateProj *Linear
|
||||
UpProj *Linear
|
||||
DownProj *Linear
|
||||
}
|
||||
|
||||
// compiledGELU is a singleton for the compiled GELU function.
|
||||
var compiledGELU *mlx.CompiledFunc
|
||||
var compiledGELU *CompiledFunc
|
||||
|
||||
func getCompiledGELU() *mlx.CompiledFunc {
|
||||
func getCompiledGELU() *CompiledFunc {
|
||||
if compiledGELU == nil {
|
||||
compiledGELU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
return []*mlx.Array{geluApprox(inputs[0])}
|
||||
compiledGELU = CompileShapeless(func(inputs []*Array) []*Array {
|
||||
return []*Array{geluApprox(inputs[0])}
|
||||
}, true)
|
||||
}
|
||||
return compiledGELU
|
||||
|
|
@ -103,16 +98,16 @@ func getCompiledGELU() *mlx.CompiledFunc {
|
|||
|
||||
// geluApprox computes GELU using the tanh approximation:
|
||||
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
func geluApprox(x *mlx.Array) *mlx.Array {
|
||||
func geluApprox(x *Array) *Array {
|
||||
const sqrt2OverPi = 0.7978845608028654
|
||||
const coeff = 0.044715
|
||||
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
inner := mlx.Add(x, mlx.MulScalar(x3, coeff))
|
||||
scaled := mlx.MulScalar(inner, sqrt2OverPi)
|
||||
t := mlx.Tanh(scaled)
|
||||
onePlusT := mlx.AddScalar(t, 1.0)
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusT)
|
||||
x3 := Mul(Mul(x, x), x)
|
||||
inner := Add(x, MulScalar(x3, coeff))
|
||||
scaled := MulScalar(inner, sqrt2OverPi)
|
||||
t := Tanh(scaled)
|
||||
onePlusT := AddScalar(t, 1.0)
|
||||
return Mul(MulScalar(x, 0.5), onePlusT)
|
||||
}
|
||||
|
||||
// parseConfig handles both flat and nested (text_config) Gemma 3 configs.
|
||||
|
|
@ -175,22 +170,22 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
|||
}
|
||||
|
||||
// Load tokenizer
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
tok, err := LoadTokenizer(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gemma3: load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
// Load weights from all safetensors files
|
||||
weights := make(map[string]*mlx.Array)
|
||||
weights := make(map[string]*Array)
|
||||
matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors"))
|
||||
for _, path := range matches {
|
||||
for name, arr := range mlx.LoadSafetensors(path) {
|
||||
for name, arr := range LoadSafetensors(path) {
|
||||
weights[name] = arr
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to resolve weight with language_model. prefix fallback
|
||||
w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
|
||||
w := func(name string) *Array { return resolveWeight(weights, name) }
|
||||
|
||||
// Infer head_dim from q_proj weight shape when not in config.
|
||||
// Gemma 3 uses head_dim=256 which differs from hidden_size/num_heads.
|
||||
|
|
@ -211,18 +206,18 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
|||
if q != nil {
|
||||
slog.Info("mlx: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
|
||||
}
|
||||
linear := func(prefix string) *mlx.Linear {
|
||||
linear := func(prefix string) *Linear {
|
||||
weight := w(prefix + ".weight")
|
||||
scales := w(prefix + ".scales")
|
||||
biases := w(prefix + ".biases")
|
||||
if scales != nil && q != nil {
|
||||
return mlx.NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits)
|
||||
return NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits)
|
||||
}
|
||||
return mlx.NewLinear(weight, nil)
|
||||
return NewLinear(weight, nil)
|
||||
}
|
||||
|
||||
// Create embedding (quantized or dense)
|
||||
embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
|
||||
embed := &Embedding{Weight: w("model.embed_tokens.weight")}
|
||||
if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
|
||||
embed.Scales = embedScales
|
||||
embed.Biases = w("model.embed_tokens.biases")
|
||||
|
|
@ -233,7 +228,7 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
|||
m := &GemmaModel{
|
||||
EmbedTokens: embed,
|
||||
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
|
||||
Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")},
|
||||
Norm: &RMSNormModule{Weight: w("model.norm.weight")},
|
||||
Tok: tok,
|
||||
Cfg: cfg,
|
||||
}
|
||||
|
|
@ -242,17 +237,17 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
|||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
prefix := fmt.Sprintf("model.layers.%d", i)
|
||||
m.Layers[i] = &DecoderLayer{
|
||||
InputNorm: &mlx.RMSNormModule{Weight: w(prefix + ".input_layernorm.weight")},
|
||||
PostAttnNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_attention_layernorm.weight")},
|
||||
PreFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")},
|
||||
PostFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")},
|
||||
InputNorm: &RMSNormModule{Weight: w(prefix + ".input_layernorm.weight")},
|
||||
PostAttnNorm: &RMSNormModule{Weight: w(prefix + ".post_attention_layernorm.weight")},
|
||||
PreFFNorm: &RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")},
|
||||
PostFFNorm: &RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")},
|
||||
Attention: &Attention{
|
||||
QProj: linear(prefix + ".self_attn.q_proj"),
|
||||
KProj: linear(prefix + ".self_attn.k_proj"),
|
||||
VProj: linear(prefix + ".self_attn.v_proj"),
|
||||
OProj: linear(prefix + ".self_attn.o_proj"),
|
||||
QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")},
|
||||
KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")},
|
||||
QNorm: &RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")},
|
||||
KNorm: &RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")},
|
||||
},
|
||||
MLP: &MLP{
|
||||
GateProj: linear(prefix + ".mlp.gate_proj"),
|
||||
|
|
@ -269,9 +264,9 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
|||
if lmHeadWeight != nil {
|
||||
lmHeadScales := w("lm_head.scales")
|
||||
if lmHeadScales != nil && q != nil {
|
||||
m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
|
||||
m.Output = NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
|
||||
} else {
|
||||
m.Output = mlx.NewLinear(lmHeadWeight, nil)
|
||||
m.Output = NewLinear(lmHeadWeight, nil)
|
||||
}
|
||||
} else {
|
||||
// Tied embeddings — reuse embed_tokens weights (with quantization if present)
|
||||
|
|
@ -279,11 +274,11 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
|||
}
|
||||
|
||||
// Materialize all weights
|
||||
var allArrays []*mlx.Array
|
||||
var allArrays []*Array
|
||||
for _, a := range weights {
|
||||
allArrays = append(allArrays, a)
|
||||
}
|
||||
mlx.Materialize(allArrays...)
|
||||
Materialize(allArrays...)
|
||||
|
||||
// Precompute (1 + weight) for Gemma-style RMSNorm
|
||||
precomputeScaledWeights(m)
|
||||
|
|
@ -292,25 +287,25 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
|||
}
|
||||
|
||||
func precomputeScaledWeights(m *GemmaModel) {
|
||||
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
|
||||
m.NormScaled = AddScalar(m.Norm.Weight, 1.0)
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
|
||||
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
|
||||
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
|
||||
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
|
||||
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
|
||||
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
|
||||
layer.InputNormScaled = AddScalar(layer.InputNorm.Weight, 1.0)
|
||||
layer.PostAttnNormScaled = AddScalar(layer.PostAttnNorm.Weight, 1.0)
|
||||
layer.PreFFNormScaled = AddScalar(layer.PreFFNorm.Weight, 1.0)
|
||||
layer.PostFFNormScaled = AddScalar(layer.PostFFNorm.Weight, 1.0)
|
||||
layer.Attention.QNormScaled = AddScalar(layer.Attention.QNorm.Weight, 1.0)
|
||||
layer.Attention.KNormScaled = AddScalar(layer.Attention.KNorm.Weight, 1.0)
|
||||
}
|
||||
|
||||
var scaled []*mlx.Array
|
||||
var scaled []*Array
|
||||
scaled = append(scaled, m.NormScaled)
|
||||
for _, layer := range m.Layers {
|
||||
scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled,
|
||||
layer.PreFFNormScaled, layer.PostFFNormScaled,
|
||||
layer.Attention.QNormScaled, layer.Attention.KNormScaled)
|
||||
}
|
||||
mlx.Materialize(scaled...)
|
||||
Materialize(scaled...)
|
||||
}
|
||||
|
||||
func isLayerSliding(layerIdx, pattern int32) bool {
|
||||
|
|
@ -321,56 +316,56 @@ func isLayerSliding(layerIdx, pattern int32) bool {
|
|||
}
|
||||
|
||||
// Forward runs the text model forward pass.
|
||||
func (m *GemmaModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
func (m *GemmaModel) Forward(tokens *Array, caches []Cache) *Array {
|
||||
shape := tokens.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize))))
|
||||
h = MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize))))
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.forward(h, caches[i], B, L, m.Cfg)
|
||||
}
|
||||
|
||||
return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps))
|
||||
return m.Output.Forward(RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps))
|
||||
}
|
||||
|
||||
func (l *DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
|
||||
normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
func (l *DecoderLayer) forward(x *Array, c Cache, B, L int32, cfg *TextConfig) *Array {
|
||||
normed := RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, cfg)
|
||||
attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := mlx.Add(x, attnOut)
|
||||
attnOut = RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := Add(x, attnOut)
|
||||
|
||||
normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
|
||||
normed = RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
|
||||
mlpOut := l.MLP.forward(normed)
|
||||
mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
|
||||
return mlx.Add(h, mlpOut)
|
||||
mlpOut = RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
|
||||
return Add(h, mlpOut)
|
||||
}
|
||||
|
||||
func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
|
||||
func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, cfg *TextConfig) *Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
// Reshape to [B, num_heads, L, head_dim]
|
||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
|
||||
// Q/K normalization
|
||||
q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
|
||||
k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
|
||||
q = RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
|
||||
k = RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// RoPE with appropriate theta
|
||||
ropeTheta := cfg.RopeTheta
|
||||
if isSliding {
|
||||
ropeTheta = cfg.RopeLocalBaseFreq
|
||||
}
|
||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
||||
q = RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
||||
k = RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
||||
|
||||
// Update cache
|
||||
k, v = c.Update(k, v, int(L))
|
||||
|
|
@ -378,29 +373,29 @@ func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||
// GQA: repeat K/V heads
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
if repeatFactor > 1 {
|
||||
k = mlx.RepeatKV(k, repeatFactor)
|
||||
v = mlx.RepeatKV(v, repeatFactor)
|
||||
k = RepeatKV(k, repeatFactor)
|
||||
v = RepeatKV(v, repeatFactor)
|
||||
}
|
||||
|
||||
// Scaled dot-product attention
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
out := ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
func (m *MLP) forward(x *mlx.Array) *mlx.Array {
|
||||
func (m *MLP) forward(x *Array) *Array {
|
||||
gate := getCompiledGELU().Call(m.GateProj.Forward(x))[0]
|
||||
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
|
||||
return m.DownProj.Forward(Mul(gate, m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// NewCache creates per-layer caches for generation.
|
||||
func (m *GemmaModel) NewCache() []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
func (m *GemmaModel) NewCache() []Cache {
|
||||
caches := make([]Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
if m.Layers[i].IsSliding {
|
||||
caches[i] = cache.NewRotatingKVCache(int(m.Cfg.SlidingWindow))
|
||||
caches[i] = NewRotatingKVCache(int(m.Cfg.SlidingWindow))
|
||||
} else {
|
||||
caches[i] = cache.NewKVCache()
|
||||
caches[i] = NewKVCache()
|
||||
}
|
||||
}
|
||||
return caches
|
||||
|
|
@ -410,22 +405,22 @@ func (m *GemmaModel) NewCache() []cache.Cache {
|
|||
func (m *GemmaModel) NumLayers() int { return len(m.Layers) }
|
||||
|
||||
// Tokenizer returns the model's tokenizer.
|
||||
func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
|
||||
func (m *GemmaModel) Tokenizer() *Tokenizer { return m.Tok }
|
||||
|
||||
// ModelType returns the architecture identifier.
|
||||
func (m *GemmaModel) ModelType() string { return "gemma3" }
|
||||
|
||||
// ApplyLoRA wraps target projection layers with LoRA adapters.
|
||||
func (m *GemmaModel) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
|
||||
adapter := &mlx.LoRAAdapter{
|
||||
Layers: make(map[string]*mlx.LoRALinear),
|
||||
func (m *GemmaModel) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter {
|
||||
adapter := &LoRAAdapter{
|
||||
Layers: make(map[string]*LoRALinear),
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
prefix := fmt.Sprintf("model.layers.%d.self_attn", i)
|
||||
for _, target := range cfg.TargetKeys {
|
||||
var proj *mlx.Linear
|
||||
var proj *Linear
|
||||
switch target {
|
||||
case "q_proj":
|
||||
proj = layer.Attention.QProj
|
||||
|
|
@ -437,7 +432,7 @@ func (m *GemmaModel) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
|
|||
proj = layer.Attention.OProj
|
||||
}
|
||||
if proj != nil {
|
||||
lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
|
||||
lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
|
||||
proj.LoRA = lora
|
||||
adapter.Layers[prefix+"."+target] = lora
|
||||
}
|
||||
|
|
@ -1,39 +1,34 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
// Package model provides transformer model architectures for MLX inference.
|
||||
package model
|
||||
package metal
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"forge.lthn.ai/core/go-mlx"
|
||||
"forge.lthn.ai/core/go-mlx/cache"
|
||||
"forge.lthn.ai/core/go-mlx/tokenizer"
|
||||
)
|
||||
|
||||
// Model is the common interface for all transformer model architectures.
|
||||
type Model interface {
|
||||
// InternalModel is the common interface for all transformer model architectures.
|
||||
type InternalModel interface {
|
||||
// Forward runs the model forward pass on token IDs with KV caches.
|
||||
Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
Forward(tokens *Array, caches []Cache) *Array
|
||||
|
||||
// NewCache creates per-layer KV caches for generation.
|
||||
NewCache() []cache.Cache
|
||||
NewCache() []Cache
|
||||
|
||||
// NumLayers returns the number of transformer layers.
|
||||
NumLayers() int
|
||||
|
||||
// Tokenizer returns the model's tokenizer.
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
Tokenizer() *Tokenizer
|
||||
|
||||
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
|
||||
ModelType() string
|
||||
|
||||
// ApplyLoRA wraps target projection layers with LoRA adapters for training.
|
||||
// Returns the adapter which holds references to all LoRA layers.
|
||||
ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter
|
||||
ApplyLoRA(cfg LoRAConfig) *LoRAAdapter
|
||||
}
|
||||
|
||||
// QuantizationConfig holds quantization parameters from config.json.
|
||||
|
|
@ -43,7 +38,7 @@ type QuantizationConfig struct {
|
|||
}
|
||||
|
||||
// resolveWeight looks up a weight with optional "language_model." prefix.
|
||||
func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
|
||||
func resolveWeight(weights map[string]*Array, name string) *Array {
|
||||
if w, ok := weights[name]; ok {
|
||||
return w
|
||||
}
|
||||
|
|
@ -53,8 +48,8 @@ func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
|
|||
return nil
|
||||
}
|
||||
|
||||
// LoadModel auto-detects the model architecture from config.json and loads it.
|
||||
func LoadModel(modelPath string) (Model, error) {
|
||||
// loadModel auto-detects the model architecture from config.json and loads it.
|
||||
func loadModel(modelPath string) (InternalModel, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model: load config: %w", err)
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package model
|
||||
package metal
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
|
@ -9,10 +9,6 @@ import (
|
|||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"forge.lthn.ai/core/go-mlx"
|
||||
"forge.lthn.ai/core/go-mlx/cache"
|
||||
"forge.lthn.ai/core/go-mlx/tokenizer"
|
||||
)
|
||||
|
||||
// Qwen3Config holds Qwen 3 model configuration.
|
||||
|
|
@ -34,39 +30,39 @@ type Qwen3Config struct {
|
|||
|
||||
// Qwen3Model is the Qwen 3 text model.
|
||||
type Qwen3Model struct {
|
||||
EmbedTokens *mlx.Embedding
|
||||
EmbedTokens *Embedding
|
||||
Layers []*Qwen3DecoderLayer
|
||||
Norm *mlx.RMSNormModule
|
||||
Output *mlx.Linear
|
||||
Norm *RMSNormModule
|
||||
Output *Linear
|
||||
|
||||
Tok *tokenizer.Tokenizer
|
||||
Tok *Tokenizer
|
||||
Cfg *Qwen3Config
|
||||
}
|
||||
|
||||
// Qwen3DecoderLayer is a single transformer block.
|
||||
// Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add.
|
||||
type Qwen3DecoderLayer struct {
|
||||
InputNorm *mlx.RMSNormModule // Pre-attention norm
|
||||
PostAttnNorm *mlx.RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm)
|
||||
InputNorm *RMSNormModule // Pre-attention norm
|
||||
PostAttnNorm *RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm)
|
||||
Attention *Qwen3Attention
|
||||
MLP *Qwen3MLP
|
||||
}
|
||||
|
||||
// Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization.
|
||||
type Qwen3Attention struct {
|
||||
QProj *mlx.Linear
|
||||
KProj *mlx.Linear
|
||||
VProj *mlx.Linear
|
||||
OProj *mlx.Linear
|
||||
QNorm *mlx.RMSNormModule
|
||||
KNorm *mlx.RMSNormModule
|
||||
QProj *Linear
|
||||
KProj *Linear
|
||||
VProj *Linear
|
||||
OProj *Linear
|
||||
QNorm *RMSNormModule
|
||||
KNorm *RMSNormModule
|
||||
}
|
||||
|
||||
// Qwen3MLP is the SwiGLU feed-forward network: down(silu(gate(x)) * up(x)).
|
||||
type Qwen3MLP struct {
|
||||
GateProj *mlx.Linear
|
||||
UpProj *mlx.Linear
|
||||
DownProj *mlx.Linear
|
||||
GateProj *Linear
|
||||
UpProj *Linear
|
||||
DownProj *Linear
|
||||
}
|
||||
|
||||
func parseQwen3Config(data []byte) (*Qwen3Config, error) {
|
||||
|
|
@ -114,40 +110,40 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
|
|||
return nil, fmt.Errorf("qwen3: parse config: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
tok, err := LoadTokenizer(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qwen3: load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
// Load weights from all safetensors files
|
||||
weights := make(map[string]*mlx.Array)
|
||||
weights := make(map[string]*Array)
|
||||
matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors"))
|
||||
for _, path := range matches {
|
||||
for name, arr := range mlx.LoadSafetensors(path) {
|
||||
for name, arr := range LoadSafetensors(path) {
|
||||
weights[name] = arr
|
||||
}
|
||||
}
|
||||
|
||||
w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
|
||||
w := func(name string) *Array { return resolveWeight(weights, name) }
|
||||
|
||||
// Quantization setup
|
||||
q := cfg.Quantization
|
||||
if q != nil {
|
||||
slog.Info("qwen3: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
|
||||
}
|
||||
linear := func(prefix string) *mlx.Linear {
|
||||
linear := func(prefix string) *Linear {
|
||||
weight := w(prefix + ".weight")
|
||||
scales := w(prefix + ".scales")
|
||||
biases := w(prefix + ".biases")
|
||||
bias := w(prefix + ".bias")
|
||||
if scales != nil && q != nil {
|
||||
return mlx.NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits)
|
||||
return NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits)
|
||||
}
|
||||
return mlx.NewLinear(weight, bias)
|
||||
return NewLinear(weight, bias)
|
||||
}
|
||||
|
||||
// Embedding
|
||||
embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
|
||||
embed := &Embedding{Weight: w("model.embed_tokens.weight")}
|
||||
if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
|
||||
embed.Scales = embedScales
|
||||
embed.Biases = w("model.embed_tokens.biases")
|
||||
|
|
@ -158,7 +154,7 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
|
|||
m := &Qwen3Model{
|
||||
EmbedTokens: embed,
|
||||
Layers: make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers),
|
||||
Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")},
|
||||
Norm: &RMSNormModule{Weight: w("model.norm.weight")},
|
||||
Tok: tok,
|
||||
Cfg: cfg,
|
||||
}
|
||||
|
|
@ -166,15 +162,15 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
|
|||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
p := fmt.Sprintf("model.layers.%d", i)
|
||||
m.Layers[i] = &Qwen3DecoderLayer{
|
||||
InputNorm: &mlx.RMSNormModule{Weight: w(p + ".input_layernorm.weight")},
|
||||
PostAttnNorm: &mlx.RMSNormModule{Weight: w(p + ".post_attention_layernorm.weight")},
|
||||
InputNorm: &RMSNormModule{Weight: w(p + ".input_layernorm.weight")},
|
||||
PostAttnNorm: &RMSNormModule{Weight: w(p + ".post_attention_layernorm.weight")},
|
||||
Attention: &Qwen3Attention{
|
||||
QProj: linear(p + ".self_attn.q_proj"),
|
||||
KProj: linear(p + ".self_attn.k_proj"),
|
||||
VProj: linear(p + ".self_attn.v_proj"),
|
||||
OProj: linear(p + ".self_attn.o_proj"),
|
||||
QNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.q_norm.weight")},
|
||||
KNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.k_norm.weight")},
|
||||
QNorm: &RMSNormModule{Weight: w(p + ".self_attn.q_norm.weight")},
|
||||
KNorm: &RMSNormModule{Weight: w(p + ".self_attn.k_norm.weight")},
|
||||
},
|
||||
MLP: &Qwen3MLP{
|
||||
GateProj: linear(p + ".mlp.gate_proj"),
|
||||
|
|
@ -189,20 +185,20 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
|
|||
if lmHeadWeight != nil {
|
||||
lmHeadScales := w("lm_head.scales")
|
||||
if lmHeadScales != nil && q != nil {
|
||||
m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
|
||||
m.Output = NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
|
||||
} else {
|
||||
m.Output = mlx.NewLinear(lmHeadWeight, nil)
|
||||
m.Output = NewLinear(lmHeadWeight, nil)
|
||||
}
|
||||
} else {
|
||||
m.Output = m.EmbedTokens.AsLinear()
|
||||
}
|
||||
|
||||
// Materialise all weights onto Metal
|
||||
var allArrays []*mlx.Array
|
||||
var allArrays []*Array
|
||||
for _, a := range weights {
|
||||
allArrays = append(allArrays, a)
|
||||
}
|
||||
mlx.Materialize(allArrays...)
|
||||
Materialize(allArrays...)
|
||||
|
||||
slog.Info("qwen3: model loaded",
|
||||
"layers", cfg.NumHiddenLayers,
|
||||
|
|
@ -218,7 +214,7 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
|
|||
|
||||
// Forward runs the Qwen 3 forward pass.
|
||||
// Unlike Gemma, Qwen does NOT scale embeddings by sqrt(hidden_size).
|
||||
func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
func (m *Qwen3Model) Forward(tokens *Array, caches []Cache) *Array {
|
||||
shape := tokens.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
|
||||
|
|
@ -231,29 +227,29 @@ func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
|
|||
return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps))
|
||||
}
|
||||
|
||||
func (l *Qwen3DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
|
||||
func (l *Qwen3DecoderLayer) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config) *Array {
|
||||
// Pre-attention norm → attention → residual add
|
||||
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
||||
attnOut := l.Attention.forward(normed, c, B, L, cfg)
|
||||
h := mlx.Add(x, attnOut)
|
||||
h := Add(x, attnOut)
|
||||
|
||||
// Pre-MLP norm → MLP → residual add
|
||||
normed = l.PostAttnNorm.Forward(h, cfg.RMSNormEps)
|
||||
mlpOut := l.MLP.forward(normed)
|
||||
return mlx.Add(h, mlpOut)
|
||||
return Add(h, mlpOut)
|
||||
}
|
||||
|
||||
func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
|
||||
func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config) *Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
// Reshape to [B, num_heads, L, head_dim]
|
||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
|
||||
// Q/K RMS normalization (Qwen 3 has this)
|
||||
|
|
@ -261,8 +257,8 @@ func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Q
|
|||
k = a.KNorm.Forward(k, cfg.RMSNormEps)
|
||||
|
||||
// RoPE — single theta for all layers (no sliding window)
|
||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
q = RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
k = RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
|
||||
// Update KV cache
|
||||
k, v = c.Update(k, v, int(L))
|
||||
|
|
@ -270,27 +266,27 @@ func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Q
|
|||
// GQA: repeat K/V heads to match Q heads
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
if repeatFactor > 1 {
|
||||
k = mlx.RepeatKV(k, repeatFactor)
|
||||
v = mlx.RepeatKV(v, repeatFactor)
|
||||
k = RepeatKV(k, repeatFactor)
|
||||
v = RepeatKV(v, repeatFactor)
|
||||
}
|
||||
|
||||
// Scaled dot-product attention
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
out := ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
// forward computes SwiGLU: down(silu(gate(x)) * up(x)).
|
||||
func (m *Qwen3MLP) forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(m.GateProj.Forward(x))
|
||||
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
|
||||
func (m *Qwen3MLP) forward(x *Array) *Array {
|
||||
gate := SiLU(m.GateProj.Forward(x))
|
||||
return m.DownProj.Forward(Mul(gate, m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// NewCache creates per-layer KV caches. Qwen 3 uses global attention only.
|
||||
func (m *Qwen3Model) NewCache() []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
func (m *Qwen3Model) NewCache() []Cache {
|
||||
caches := make([]Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
caches[i] = NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
|
@ -299,22 +295,22 @@ func (m *Qwen3Model) NewCache() []cache.Cache {
|
|||
func (m *Qwen3Model) NumLayers() int { return len(m.Layers) }
|
||||
|
||||
// Tokenizer returns the model's tokenizer.
|
||||
func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
|
||||
func (m *Qwen3Model) Tokenizer() *Tokenizer { return m.Tok }
|
||||
|
||||
// ModelType returns the architecture identifier.
|
||||
func (m *Qwen3Model) ModelType() string { return "qwen3" }
|
||||
|
||||
// ApplyLoRA wraps target projection layers with LoRA adapters.
|
||||
func (m *Qwen3Model) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
|
||||
adapter := &mlx.LoRAAdapter{
|
||||
Layers: make(map[string]*mlx.LoRALinear),
|
||||
func (m *Qwen3Model) ApplyLoRA(cfg LoRAConfig) *LoRAAdapter {
|
||||
adapter := &LoRAAdapter{
|
||||
Layers: make(map[string]*LoRALinear),
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
prefix := fmt.Sprintf("model.layers.%d.self_attn", i)
|
||||
for _, target := range cfg.TargetKeys {
|
||||
var proj *mlx.Linear
|
||||
var proj *Linear
|
||||
switch target {
|
||||
case "q_proj":
|
||||
proj = layer.Attention.QProj
|
||||
|
|
@ -326,7 +322,7 @@ func (m *Qwen3Model) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
|
|||
proj = layer.Attention.OProj
|
||||
}
|
||||
if proj != nil {
|
||||
lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
|
||||
lora := NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
|
||||
proj.LoRA = lora
|
||||
adapter.Layers[prefix+"."+target] = lora
|
||||
}
|
||||
|
|
@ -1,22 +1,19 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
// Package sample provides composable token sampling strategies.
|
||||
package sample
|
||||
package metal
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"forge.lthn.ai/core/go-mlx"
|
||||
)
|
||||
|
||||
// Sampler transforms logits into a sampled token index.
|
||||
type Sampler interface {
|
||||
Sample(logits *mlx.Array) *mlx.Array
|
||||
Sample(logits *Array) *Array
|
||||
}
|
||||
|
||||
// New creates a composable sampler chain from the given parameters.
|
||||
// newSampler creates a composable sampler chain from the given parameters.
|
||||
// Order: TopP -> MinP -> TopK -> Temperature -> categorical sample.
|
||||
func New(temp, topP, minP float32, topK int) Sampler {
|
||||
func newSampler(temp, topP, minP float32, topK int) Sampler {
|
||||
if temp == 0 {
|
||||
return greedy{}
|
||||
}
|
||||
|
|
@ -38,43 +35,43 @@ func New(temp, topP, minP float32, topK int) Sampler {
|
|||
// chain applies a sequence of samplers, then samples from the result.
|
||||
type chain []Sampler
|
||||
|
||||
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
|
||||
func (c chain) Sample(logits *Array) *Array {
|
||||
for _, s := range c {
|
||||
logits = s.Sample(logits)
|
||||
}
|
||||
// Final categorical sample from log-probabilities
|
||||
return mlx.RandomCategorical(logits)
|
||||
return RandomCategorical(logits)
|
||||
}
|
||||
|
||||
// greedy returns the argmax token.
|
||||
type greedy struct{}
|
||||
|
||||
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
|
||||
return mlx.Argmax(logits, -1, false)
|
||||
func (greedy) Sample(logits *Array) *Array {
|
||||
return Argmax(logits, -1, false)
|
||||
}
|
||||
|
||||
// Temperature scales logits by 1/temp.
|
||||
type Temperature float32
|
||||
|
||||
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
|
||||
return mlx.MulScalar(logits, 1.0/float32(t))
|
||||
func (t Temperature) Sample(logits *Array) *Array {
|
||||
return MulScalar(logits, 1.0/float32(t))
|
||||
}
|
||||
|
||||
// TopKSampler masks all but the top-k logits.
|
||||
type TopKSampler int
|
||||
|
||||
func (k TopKSampler) Sample(logits *mlx.Array) *mlx.Array {
|
||||
neg := mlx.Negative(logits)
|
||||
mask := mlx.Argpartition(neg, int(k)-1, -1)
|
||||
func (k TopKSampler) Sample(logits *Array) *Array {
|
||||
neg := Negative(logits)
|
||||
mask := Argpartition(neg, int(k)-1, -1)
|
||||
// Slice the indices beyond top-k
|
||||
mask = mlx.SliceAxis(mask, -1, int32(k), int32(logits.Dim(-1)))
|
||||
return mlx.PutAlongAxis(logits, mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
||||
mask = SliceAxis(mask, -1, int32(k), int32(logits.Dim(-1)))
|
||||
return PutAlongAxis(logits, mask, FromValue(float32(math.Inf(-1))), -1)
|
||||
}
|
||||
|
||||
// TopP implements nucleus sampling (cumulative probability threshold).
|
||||
type TopP float32
|
||||
|
||||
func (p TopP) Sample(logits *mlx.Array) *mlx.Array {
|
||||
func (p TopP) Sample(logits *Array) *Array {
|
||||
// TODO: full nucleus sampling requires cumsum which mlx-c doesn't expose directly.
|
||||
// For now, pass through. TopK + Temperature covers most use cases.
|
||||
return logits
|
||||
|
|
@ -83,7 +80,7 @@ func (p TopP) Sample(logits *mlx.Array) *mlx.Array {
|
|||
// MinPSampler masks tokens below min_p * max_prob.
|
||||
type MinPSampler float32
|
||||
|
||||
func (p MinPSampler) Sample(logits *mlx.Array) *mlx.Array {
|
||||
func (p MinPSampler) Sample(logits *Array) *Array {
|
||||
// For now, pass through — MinP is an optimization over TopP.
|
||||
// Full implementation requires finding max prob and masking below threshold.
|
||||
return logits
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
// Package tokenizer provides BPE tokenization for transformer models.
|
||||
package tokenizer
|
||||
package metal
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
|
@ -46,8 +45,8 @@ type tokenizerJSON struct {
|
|||
} `json:"added_tokens"`
|
||||
}
|
||||
|
||||
// Load reads a tokenizer.json file and creates a Tokenizer.
|
||||
func Load(path string) (*Tokenizer, error) {
|
||||
// LoadTokenizer reads a tokenizer.json file and creates a Tokenizer.
|
||||
func LoadTokenizer(path string) (*Tokenizer, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tokenizer: read %s: %w", path, err)
|
||||
|
|
@ -294,7 +293,7 @@ func (t *Tokenizer) DecodeToken(id int32) string {
|
|||
return t.decodeGPT2Bytes(text)
|
||||
}
|
||||
|
||||
// SentencePiece: replace ▁ with space but keep it (it's the word boundary)
|
||||
// SentencePiece: replace with space but keep it (it's the word boundary)
|
||||
return strings.ReplaceAll(text, "▁", " ")
|
||||
}
|
||||
|
||||
|
|
@ -1,116 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package sample
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"forge.lthn.ai/core/go-mlx"
|
||||
)
|
||||
|
||||
func TestGreedy(t *testing.T) {
|
||||
// Logits heavily favour index 2
|
||||
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
|
||||
s := New(0, 0, 0, 0) // temp=0 → greedy
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
|
||||
if token.Int() != 2 {
|
||||
t.Errorf("greedy sample = %d, want 2", token.Int())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemperature_HighTemp(t *testing.T) {
|
||||
// High temperature should still produce a valid index
|
||||
logits := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
||||
s := New(100.0, 0, 0, 0) // very high temp → near uniform
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
|
||||
idx := token.Int()
|
||||
if idx < 0 || idx >= 4 {
|
||||
t.Errorf("sample index = %d, out of range [0, 4)", idx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemperature_LowTemp(t *testing.T) {
|
||||
// Very low temperature should behave like greedy
|
||||
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
|
||||
s := New(0.001, 0, 0, 0) // near-zero temp → near-greedy
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
|
||||
if token.Int() != 2 {
|
||||
t.Errorf("low-temp sample = %d, want 2 (near greedy)", token.Int())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopK(t *testing.T) {
|
||||
// TopK=1 with clear winner should always pick that token
|
||||
logits := mlx.FromValues([]float32{-100, 100, -100, -100}, 1, 4)
|
||||
s := New(1.0, 0, 0, 1) // topK=1
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
|
||||
if token.Int() != 1 {
|
||||
t.Errorf("topk=1 sample = %d, want 1", token.Int())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopK_MultipleTokens(t *testing.T) {
|
||||
// TopK=2, both high logits — should pick one of them
|
||||
logits := mlx.FromValues([]float32{-100, 50, 50, -100}, 1, 4)
|
||||
s := New(1.0, 0, 0, 2) // topK=2
|
||||
|
||||
seen := map[int]bool{}
|
||||
for range 20 {
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
seen[token.Int()] = true
|
||||
}
|
||||
|
||||
// Should only ever pick index 1 or 2
|
||||
for idx := range seen {
|
||||
if idx != 1 && idx != 2 {
|
||||
t.Errorf("topk=2 sampled index %d, expected only 1 or 2", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_Chain(t *testing.T) {
|
||||
// Full chain: topK + temperature
|
||||
logits := mlx.FromValues([]float32{1, 2, 3, 4, 5}, 1, 5)
|
||||
s := New(0.5, 0, 0, 3) // temp=0.5, topK=3
|
||||
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
|
||||
idx := token.Int()
|
||||
if idx < 0 || idx >= 5 {
|
||||
t.Errorf("chain sample index = %d, out of range", idx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopP_PassThrough(t *testing.T) {
|
||||
// TopP is currently a stub — verify it doesn't break the chain
|
||||
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
|
||||
s := New(0.5, 0.9, 0, 0) // topP=0.9 (stub), temp=0.5
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
|
||||
if token.Int() != 2 {
|
||||
t.Errorf("topP stub + temp sample = %d, want 2", token.Int())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinP_PassThrough(t *testing.T) {
|
||||
// MinP is currently a stub — verify it doesn't break the chain
|
||||
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
|
||||
s := New(0.5, 0, 0.1, 0) // minP=0.1 (stub), temp=0.5
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
|
||||
if token.Int() != 2 {
|
||||
t.Errorf("minP stub + temp sample = %d, want 2", token.Int())
|
||||
}
|
||||
}
|
||||
|
|
@ -1,217 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// minimalTokenizerJSON is a valid HuggingFace tokenizer.json with a tiny vocab.
|
||||
const minimalTokenizerJSON = `{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {
|
||||
"h": 0,
|
||||
"e": 1,
|
||||
"l": 2,
|
||||
"o": 3,
|
||||
"▁": 4,
|
||||
"he": 5,
|
||||
"ll": 6,
|
||||
"▁h": 7
|
||||
},
|
||||
"merges": ["h e", "l l"],
|
||||
"byte_fallback": false
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 100, "content": "<bos>", "special": true},
|
||||
{"id": 101, "content": "<eos>", "special": true}
|
||||
]
|
||||
}`
|
||||
|
||||
func writeTestTokenizer(t *testing.T) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "tokenizer.json")
|
||||
if err := os.WriteFile(path, []byte(minimalTokenizerJSON), 0644); err != nil {
|
||||
t.Fatalf("write test tokenizer: %v", err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if tok == nil {
|
||||
t.Fatal("tokenizer is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_MissingFile(t *testing.T) {
|
||||
_, err := Load("/nonexistent/tokenizer.json")
|
||||
if err == nil {
|
||||
t.Error("expected error for missing file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_InvalidJSON(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "tokenizer.json")
|
||||
os.WriteFile(path, []byte("not json"), 0644)
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBOSEOS(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
if tok.BOSToken() != 100 {
|
||||
t.Errorf("BOS = %d, want 100", tok.BOSToken())
|
||||
}
|
||||
if tok.EOSToken() != 101 {
|
||||
t.Errorf("EOS = %d, want 101", tok.EOSToken())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncode_ProducesTokens(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
tokens := tok.Encode("hello")
|
||||
if len(tokens) == 0 {
|
||||
t.Fatal("Encode returned empty tokens")
|
||||
}
|
||||
// First token should be BOS
|
||||
if tokens[0] != tok.BOSToken() {
|
||||
t.Errorf("first token = %d, want BOS (%d)", tokens[0], tok.BOSToken())
|
||||
}
|
||||
t.Logf("Encode(\"hello\") = %v", tokens)
|
||||
}
|
||||
|
||||
func TestDecode_SpecialTokensSkipped(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
// Decoding BOS/EOS should produce empty string
|
||||
text := tok.Decode([]int32{100, 101})
|
||||
if text != "" {
|
||||
t.Errorf("Decode(BOS, EOS) = %q, want empty", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecode_RegularTokens(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
// Decode known vocab entries
|
||||
text := tok.Decode([]int32{5, 6, 3}) // "he" + "ll" + "o"
|
||||
if text != "hello" {
|
||||
t.Errorf("Decode = %q, want %q", text, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToken_Regular(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
// "he" = token 5
|
||||
text := tok.DecodeToken(5)
|
||||
if text != "he" {
|
||||
t.Errorf("DecodeToken(5) = %q, want %q", text, "he")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToken_Special(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
// Special tokens should return empty
|
||||
text := tok.DecodeToken(100)
|
||||
if text != "" {
|
||||
t.Errorf("DecodeToken(BOS) = %q, want empty", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToken_SentencePieceSpace(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
// "▁h" = token 7, should decode to " h" (space prefix)
|
||||
text := tok.DecodeToken(7)
|
||||
if text != " h" {
|
||||
t.Errorf("DecodeToken(7) = %q, want %q", text, " h")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToken_Unknown(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
text := tok.DecodeToken(9999)
|
||||
if text != "" {
|
||||
t.Errorf("DecodeToken(unknown) = %q, want empty", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatGemmaPrompt(t *testing.T) {
|
||||
got := FormatGemmaPrompt("What is 2+2?")
|
||||
want := "<start_of_turn>user\nWhat is 2+2?<end_of_turn>\n<start_of_turn>model\n"
|
||||
if got != want {
|
||||
t.Errorf("FormatGemmaPrompt = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// --- GPT-2 byte maps ---
|
||||
|
||||
func TestBuildGPT2ByteMaps(t *testing.T) {
|
||||
decoder, encoder := buildGPT2ByteMaps()
|
||||
|
||||
// All 256 bytes must be mapped
|
||||
if len(encoder) != 256 {
|
||||
t.Errorf("encoder has %d entries, want 256", len(encoder))
|
||||
}
|
||||
if len(decoder) != 256 {
|
||||
t.Errorf("decoder has %d entries, want 256", len(decoder))
|
||||
}
|
||||
|
||||
// Round-trip: every byte should survive encode → decode
|
||||
for b := 0; b < 256; b++ {
|
||||
r := encoder[byte(b)]
|
||||
got := decoder[r]
|
||||
if got != byte(b) {
|
||||
t.Errorf("byte %d: encode→decode = %d, want %d", b, got, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildGPT2ByteMaps_PrintableASCII(t *testing.T) {
|
||||
_, encoder := buildGPT2ByteMaps()
|
||||
|
||||
// Printable ASCII (33-126) should self-map
|
||||
for b := 33; b <= 126; b++ {
|
||||
if encoder[byte(b)] != rune(b) {
|
||||
t.Errorf("byte %d (%c): expected self-map, got %c", b, b, encoder[byte(b)])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildGPT2ByteMaps_ControlChars(t *testing.T) {
|
||||
_, encoder := buildGPT2ByteMaps()
|
||||
|
||||
// Space (32) and control chars (0-31) should NOT self-map
|
||||
if encoder[byte(32)] == rune(32) {
|
||||
t.Error("space (32) should not self-map in GPT-2 encoding")
|
||||
}
|
||||
if encoder[byte(0)] == rune(0) {
|
||||
t.Error("null (0) should not self-map in GPT-2 encoding")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue