- cache_test.go (10): KVCache and RotatingKVCache — creation, single/multi update, bounded sliding window, reset, state access - sample_test.go (8): greedy, temperature (high/low), topK (single/multi), chain composition, TopP/MinP stub pass-through - tokenizer_test.go (15): Load (valid/missing/invalid), BOS/EOS detection, encode produces tokens, decode skips specials, decode regular tokens, DecodeToken (regular/special/space/unknown), FormatGemmaPrompt, GPT-2 byte maps (round-trip, printable ASCII, control chars) Total test count: 29 → 148 across 10 test files. Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
200 lines
4 KiB
Go
200 lines
4 KiB
Go
//go:build darwin && arm64
|
|
|
|
package cache
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"forge.lthn.ai/core/go-mlx"
|
|
)
|
|
|
|
// makeKV creates a small K/V pair with shape [B=1, H=2, L=seqLen, D=4].
|
|
func makeKV(seqLen int) (*mlx.Array, *mlx.Array) {
|
|
size := 1 * 2 * seqLen * 4
|
|
data := make([]float32, size)
|
|
for i := range data {
|
|
data[i] = float32(i) * 0.1
|
|
}
|
|
k := mlx.FromValues(data, 1, 2, seqLen, 4)
|
|
v := mlx.FromValues(data, 1, 2, seqLen, 4)
|
|
return k, v
|
|
}
|
|
|
|
// --- KVCache ---
|
|
|
|
func TestKVCache_New(t *testing.T) {
|
|
c := NewKVCache()
|
|
if c.Offset() != 0 {
|
|
t.Errorf("offset = %d, want 0", c.Offset())
|
|
}
|
|
if c.Len() != 0 {
|
|
t.Errorf("len = %d, want 0", c.Len())
|
|
}
|
|
if c.State() != nil {
|
|
t.Error("state should be nil for empty cache")
|
|
}
|
|
}
|
|
|
|
func TestKVCache_SingleUpdate(t *testing.T) {
|
|
c := NewKVCache()
|
|
k, v := makeKV(3) // 3 tokens
|
|
|
|
outK, outV := c.Update(k, v, 3)
|
|
mlx.Materialize(outK, outV)
|
|
|
|
if c.Offset() != 3 {
|
|
t.Errorf("offset = %d, want 3", c.Offset())
|
|
}
|
|
if c.Len() != 3 {
|
|
t.Errorf("len = %d, want 3", c.Len())
|
|
}
|
|
|
|
// Output K should have shape [1, 2, 3, 4]
|
|
shape := outK.Shape()
|
|
if shape[0] != 1 || shape[1] != 2 || shape[2] != 3 || shape[3] != 4 {
|
|
t.Errorf("outK shape = %v, want [1 2 3 4]", shape)
|
|
}
|
|
}
|
|
|
|
func TestKVCache_MultipleUpdates(t *testing.T) {
|
|
c := NewKVCache()
|
|
|
|
// Prompt: 5 tokens
|
|
k1, v1 := makeKV(5)
|
|
outK, outV := c.Update(k1, v1, 5)
|
|
mlx.Materialize(outK, outV)
|
|
|
|
if c.Offset() != 5 {
|
|
t.Errorf("offset = %d, want 5", c.Offset())
|
|
}
|
|
|
|
// Generate: 1 token at a time
|
|
k2, v2 := makeKV(1)
|
|
outK, outV = c.Update(k2, v2, 1)
|
|
mlx.Materialize(outK, outV)
|
|
|
|
if c.Offset() != 6 {
|
|
t.Errorf("offset = %d, want 6", c.Offset())
|
|
}
|
|
|
|
shape := outK.Shape()
|
|
if shape[2] != 6 {
|
|
t.Errorf("outK L dim = %d, want 6", shape[2])
|
|
}
|
|
}
|
|
|
|
func TestKVCache_Reset(t *testing.T) {
|
|
c := NewKVCache()
|
|
k, v := makeKV(3)
|
|
c.Update(k, v, 3)
|
|
|
|
c.Reset()
|
|
|
|
if c.Offset() != 0 {
|
|
t.Errorf("offset after reset = %d, want 0", c.Offset())
|
|
}
|
|
if c.State() != nil {
|
|
t.Error("state should be nil after reset")
|
|
}
|
|
}
|
|
|
|
func TestKVCache_State(t *testing.T) {
|
|
c := NewKVCache()
|
|
k, v := makeKV(2)
|
|
c.Update(k, v, 2)
|
|
|
|
state := c.State()
|
|
if len(state) != 2 {
|
|
t.Fatalf("state length = %d, want 2", len(state))
|
|
}
|
|
// state[0] = keys, state[1] = values
|
|
if state[0] == nil || state[1] == nil {
|
|
t.Error("state arrays should not be nil")
|
|
}
|
|
}
|
|
|
|
// --- RotatingKVCache ---
|
|
|
|
func TestRotatingKVCache_New(t *testing.T) {
|
|
c := NewRotatingKVCache(16)
|
|
if c.Offset() != 0 {
|
|
t.Errorf("offset = %d, want 0", c.Offset())
|
|
}
|
|
if c.Len() != 0 {
|
|
t.Errorf("len = %d, want 0", c.Len())
|
|
}
|
|
}
|
|
|
|
func TestRotatingKVCache_SingleToken(t *testing.T) {
|
|
c := NewRotatingKVCache(8)
|
|
k, v := makeKV(1)
|
|
|
|
outK, outV := c.Update(k, v, 1)
|
|
mlx.Materialize(outK, outV)
|
|
|
|
if c.Offset() != 1 {
|
|
t.Errorf("offset = %d, want 1", c.Offset())
|
|
}
|
|
if c.Len() != 1 {
|
|
t.Errorf("len = %d, want 1", c.Len())
|
|
}
|
|
}
|
|
|
|
func TestRotatingKVCache_MultiTokenPrompt(t *testing.T) {
|
|
c := NewRotatingKVCache(16)
|
|
k, v := makeKV(5)
|
|
|
|
outK, outV := c.Update(k, v, 5)
|
|
mlx.Materialize(outK, outV)
|
|
|
|
if c.Offset() != 5 {
|
|
t.Errorf("offset = %d, want 5", c.Offset())
|
|
}
|
|
if c.Len() != 5 {
|
|
t.Errorf("len = %d, want 5", c.Len())
|
|
}
|
|
}
|
|
|
|
func TestRotatingKVCache_Bounded(t *testing.T) {
|
|
c := NewRotatingKVCache(4)
|
|
|
|
// Fill with 4-token prompt (at max)
|
|
k, v := makeKV(4)
|
|
outK, outV := c.Update(k, v, 4)
|
|
mlx.Materialize(outK, outV)
|
|
|
|
if c.Len() != 4 {
|
|
t.Errorf("len = %d, want 4 (at max)", c.Len())
|
|
}
|
|
|
|
// Add one more token — should trim to maxSize
|
|
k2, v2 := makeKV(1)
|
|
outK, outV = c.Update(k2, v2, 1)
|
|
mlx.Materialize(outK, outV)
|
|
|
|
if c.Offset() != 5 {
|
|
t.Errorf("offset = %d, want 5", c.Offset())
|
|
}
|
|
// Len should be bounded by maxSize
|
|
if c.Len() != 4 {
|
|
t.Errorf("len = %d, want 4 (bounded)", c.Len())
|
|
}
|
|
}
|
|
|
|
func TestRotatingKVCache_Reset(t *testing.T) {
|
|
c := NewRotatingKVCache(8)
|
|
k, v := makeKV(3)
|
|
c.Update(k, v, 3)
|
|
|
|
c.Reset()
|
|
|
|
if c.Offset() != 0 {
|
|
t.Errorf("offset after reset = %d, want 0", c.Offset())
|
|
}
|
|
if c.Len() != 0 {
|
|
t.Errorf("len after reset = %d, want 0", c.Len())
|
|
}
|
|
if c.State() != nil {
|
|
t.Error("state should be nil after reset")
|
|
}
|
|
}
|