- cache_test.go (10): KVCache and RotatingKVCache — creation, single/multi update, bounded sliding window, reset, state access
- sample_test.go (8): greedy, temperature (high/low), topK (single/multi), chain composition, TopP/MinP stub pass-through
- tokenizer_test.go (15): Load (valid/missing/invalid), BOS/EOS detection, encode produces tokens, decode skips specials, decode regular tokens, DecodeToken (regular/special/space/unknown), FormatGemmaPrompt, GPT-2 byte maps (round-trip, printable ASCII, control chars)

Total test count: 29 → 148 across 10 test files.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
116 lines
3 KiB
Go
116 lines
3 KiB
Go
//go:build darwin && arm64
|
|
|
|
package sample
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"forge.lthn.ai/core/go-mlx"
|
|
)
|
|
|
|
func TestGreedy(t *testing.T) {
|
|
// Logits heavily favour index 2
|
|
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
|
|
s := New(0, 0, 0, 0) // temp=0 → greedy
|
|
token := s.Sample(logits)
|
|
mlx.Materialize(token)
|
|
|
|
if token.Int() != 2 {
|
|
t.Errorf("greedy sample = %d, want 2", token.Int())
|
|
}
|
|
}
|
|
|
|
func TestTemperature_HighTemp(t *testing.T) {
|
|
// High temperature should still produce a valid index
|
|
logits := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
|
s := New(100.0, 0, 0, 0) // very high temp → near uniform
|
|
token := s.Sample(logits)
|
|
mlx.Materialize(token)
|
|
|
|
idx := token.Int()
|
|
if idx < 0 || idx >= 4 {
|
|
t.Errorf("sample index = %d, out of range [0, 4)", idx)
|
|
}
|
|
}
|
|
|
|
func TestTemperature_LowTemp(t *testing.T) {
|
|
// Very low temperature should behave like greedy
|
|
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
|
|
s := New(0.001, 0, 0, 0) // near-zero temp → near-greedy
|
|
token := s.Sample(logits)
|
|
mlx.Materialize(token)
|
|
|
|
if token.Int() != 2 {
|
|
t.Errorf("low-temp sample = %d, want 2 (near greedy)", token.Int())
|
|
}
|
|
}
|
|
|
|
func TestTopK(t *testing.T) {
|
|
// TopK=1 with clear winner should always pick that token
|
|
logits := mlx.FromValues([]float32{-100, 100, -100, -100}, 1, 4)
|
|
s := New(1.0, 0, 0, 1) // topK=1
|
|
token := s.Sample(logits)
|
|
mlx.Materialize(token)
|
|
|
|
if token.Int() != 1 {
|
|
t.Errorf("topk=1 sample = %d, want 1", token.Int())
|
|
}
|
|
}
|
|
|
|
func TestTopK_MultipleTokens(t *testing.T) {
|
|
// TopK=2, both high logits — should pick one of them
|
|
logits := mlx.FromValues([]float32{-100, 50, 50, -100}, 1, 4)
|
|
s := New(1.0, 0, 0, 2) // topK=2
|
|
|
|
seen := map[int]bool{}
|
|
for range 20 {
|
|
token := s.Sample(logits)
|
|
mlx.Materialize(token)
|
|
seen[token.Int()] = true
|
|
}
|
|
|
|
// Should only ever pick index 1 or 2
|
|
for idx := range seen {
|
|
if idx != 1 && idx != 2 {
|
|
t.Errorf("topk=2 sampled index %d, expected only 1 or 2", idx)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestNew_Chain(t *testing.T) {
|
|
// Full chain: topK + temperature
|
|
logits := mlx.FromValues([]float32{1, 2, 3, 4, 5}, 1, 5)
|
|
s := New(0.5, 0, 0, 3) // temp=0.5, topK=3
|
|
|
|
token := s.Sample(logits)
|
|
mlx.Materialize(token)
|
|
|
|
idx := token.Int()
|
|
if idx < 0 || idx >= 5 {
|
|
t.Errorf("chain sample index = %d, out of range", idx)
|
|
}
|
|
}
|
|
|
|
func TestTopP_PassThrough(t *testing.T) {
|
|
// TopP is currently a stub — verify it doesn't break the chain
|
|
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
|
|
s := New(0.5, 0.9, 0, 0) // topP=0.9 (stub), temp=0.5
|
|
token := s.Sample(logits)
|
|
mlx.Materialize(token)
|
|
|
|
if token.Int() != 2 {
|
|
t.Errorf("topP stub + temp sample = %d, want 2", token.Int())
|
|
}
|
|
}
|
|
|
|
func TestMinP_PassThrough(t *testing.T) {
|
|
// MinP is currently a stub — verify it doesn't break the chain
|
|
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
|
|
s := New(0.5, 0, 0.1, 0) // minP=0.1 (stub), temp=0.5
|
|
token := s.Sample(logits)
|
|
mlx.Materialize(token)
|
|
|
|
if token.Int() != 2 {
|
|
t.Errorf("minP stub + temp sample = %d, want 2", token.Int())
|
|
}
|
|
}
|