198 lines
4 KiB
Go
198 lines
4 KiB
Go
//go:build darwin && arm64
|
|
|
|
package metal
|
|
|
|
import (
|
|
"testing"
|
|
)
|
|
|
|
// makeKV creates a small K/V pair with shape [B=1, H=2, L=seqLen, D=4].
|
|
func makeKV(seqLen int) (*Array, *Array) {
|
|
size := 1 * 2 * seqLen * 4
|
|
data := make([]float32, size)
|
|
for i := range data {
|
|
data[i] = float32(i) * 0.1
|
|
}
|
|
k := FromValues(data, 1, 2, seqLen, 4)
|
|
v := FromValues(data, 1, 2, seqLen, 4)
|
|
return k, v
|
|
}
|
|
|
|
// --- KVCache ---
|
|
|
|
func TestKVCache_New(t *testing.T) {
|
|
c := NewKVCache()
|
|
if c.Offset() != 0 {
|
|
t.Errorf("offset = %d, want 0", c.Offset())
|
|
}
|
|
if c.Len() != 0 {
|
|
t.Errorf("len = %d, want 0", c.Len())
|
|
}
|
|
if c.State() != nil {
|
|
t.Error("state should be nil for empty cache")
|
|
}
|
|
}
|
|
|
|
func TestKVCache_SingleUpdate(t *testing.T) {
|
|
c := NewKVCache()
|
|
k, v := makeKV(3) // 3 tokens
|
|
|
|
outK, outV := c.Update(k, v, 3)
|
|
Materialize(outK, outV)
|
|
|
|
if c.Offset() != 3 {
|
|
t.Errorf("offset = %d, want 3", c.Offset())
|
|
}
|
|
if c.Len() != 3 {
|
|
t.Errorf("len = %d, want 3", c.Len())
|
|
}
|
|
|
|
// Output K should have shape [1, 2, 3, 4]
|
|
shape := outK.Shape()
|
|
if shape[0] != 1 || shape[1] != 2 || shape[2] != 3 || shape[3] != 4 {
|
|
t.Errorf("outK shape = %v, want [1 2 3 4]", shape)
|
|
}
|
|
}
|
|
|
|
func TestKVCache_MultipleUpdates(t *testing.T) {
|
|
c := NewKVCache()
|
|
|
|
// Prompt: 5 tokens
|
|
k1, v1 := makeKV(5)
|
|
outK, outV := c.Update(k1, v1, 5)
|
|
Materialize(outK, outV)
|
|
|
|
if c.Offset() != 5 {
|
|
t.Errorf("offset = %d, want 5", c.Offset())
|
|
}
|
|
|
|
// Generate: 1 token at a time
|
|
k2, v2 := makeKV(1)
|
|
outK, outV = c.Update(k2, v2, 1)
|
|
Materialize(outK, outV)
|
|
|
|
if c.Offset() != 6 {
|
|
t.Errorf("offset = %d, want 6", c.Offset())
|
|
}
|
|
|
|
shape := outK.Shape()
|
|
if shape[2] != 6 {
|
|
t.Errorf("outK L dim = %d, want 6", shape[2])
|
|
}
|
|
}
|
|
|
|
func TestKVCache_Reset(t *testing.T) {
|
|
c := NewKVCache()
|
|
k, v := makeKV(3)
|
|
c.Update(k, v, 3)
|
|
|
|
c.Reset()
|
|
|
|
if c.Offset() != 0 {
|
|
t.Errorf("offset after reset = %d, want 0", c.Offset())
|
|
}
|
|
if c.State() != nil {
|
|
t.Error("state should be nil after reset")
|
|
}
|
|
}
|
|
|
|
func TestKVCache_State(t *testing.T) {
|
|
c := NewKVCache()
|
|
k, v := makeKV(2)
|
|
c.Update(k, v, 2)
|
|
|
|
state := c.State()
|
|
if len(state) != 2 {
|
|
t.Fatalf("state length = %d, want 2", len(state))
|
|
}
|
|
// state[0] = keys, state[1] = values
|
|
if state[0] == nil || state[1] == nil {
|
|
t.Error("state arrays should not be nil")
|
|
}
|
|
}
|
|
|
|
// --- RotatingKVCache ---
|
|
|
|
func TestRotatingKVCache_New(t *testing.T) {
|
|
c := NewRotatingKVCache(16)
|
|
if c.Offset() != 0 {
|
|
t.Errorf("offset = %d, want 0", c.Offset())
|
|
}
|
|
if c.Len() != 0 {
|
|
t.Errorf("len = %d, want 0", c.Len())
|
|
}
|
|
}
|
|
|
|
func TestRotatingKVCache_SingleToken(t *testing.T) {
|
|
c := NewRotatingKVCache(8)
|
|
k, v := makeKV(1)
|
|
|
|
outK, outV := c.Update(k, v, 1)
|
|
Materialize(outK, outV)
|
|
|
|
if c.Offset() != 1 {
|
|
t.Errorf("offset = %d, want 1", c.Offset())
|
|
}
|
|
if c.Len() != 1 {
|
|
t.Errorf("len = %d, want 1", c.Len())
|
|
}
|
|
}
|
|
|
|
func TestRotatingKVCache_MultiTokenPrompt(t *testing.T) {
|
|
c := NewRotatingKVCache(16)
|
|
k, v := makeKV(5)
|
|
|
|
outK, outV := c.Update(k, v, 5)
|
|
Materialize(outK, outV)
|
|
|
|
if c.Offset() != 5 {
|
|
t.Errorf("offset = %d, want 5", c.Offset())
|
|
}
|
|
if c.Len() != 5 {
|
|
t.Errorf("len = %d, want 5", c.Len())
|
|
}
|
|
}
|
|
|
|
func TestRotatingKVCache_Bounded(t *testing.T) {
|
|
c := NewRotatingKVCache(4)
|
|
|
|
// Fill with 4-token prompt (at max)
|
|
k, v := makeKV(4)
|
|
outK, outV := c.Update(k, v, 4)
|
|
Materialize(outK, outV)
|
|
|
|
if c.Len() != 4 {
|
|
t.Errorf("len = %d, want 4 (at max)", c.Len())
|
|
}
|
|
|
|
// Add one more token — should trim to maxSize
|
|
k2, v2 := makeKV(1)
|
|
outK, outV = c.Update(k2, v2, 1)
|
|
Materialize(outK, outV)
|
|
|
|
if c.Offset() != 5 {
|
|
t.Errorf("offset = %d, want 5", c.Offset())
|
|
}
|
|
// Len should be bounded by maxSize
|
|
if c.Len() != 4 {
|
|
t.Errorf("len = %d, want 4 (bounded)", c.Len())
|
|
}
|
|
}
|
|
|
|
func TestRotatingKVCache_Reset(t *testing.T) {
|
|
c := NewRotatingKVCache(8)
|
|
k, v := makeKV(3)
|
|
c.Update(k, v, 3)
|
|
|
|
c.Reset()
|
|
|
|
if c.Offset() != 0 {
|
|
t.Errorf("offset after reset = %d, want 0", c.Offset())
|
|
}
|
|
if c.Len() != 0 {
|
|
t.Errorf("len after reset = %d, want 0", c.Len())
|
|
}
|
|
if c.State() != nil {
|
|
t.Error("state should be nil after reset")
|
|
}
|
|
}
|