go-mlx/tokenizer/tokenizer_test.go
Snider 28b5c5bb47 test(sub-packages): add 33 tests for cache, sample, tokenizer
- cache_test.go (10): KVCache and RotatingKVCache — creation, single/multi
  update, bounded sliding window, reset, state access
- sample_test.go (8): greedy, temperature (high/low), topK (single/multi),
  chain composition, TopP/MinP stub pass-through
- tokenizer_test.go (15): Load (valid/missing/invalid), BOS/EOS detection,
  encode produces tokens, decode skips specials, decode regular tokens,
  DecodeToken (regular/special/space/unknown), FormatGemmaPrompt,
  GPT-2 byte maps (round-trip, printable ASCII, control chars)

Total test count: 29 → 148 across 10 test files.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 18:40:49 +00:00

217 lines
4.9 KiB
Go

//go:build darwin && arm64
package tokenizer
import (
"os"
"path/filepath"
"testing"
)
// minimalTokenizerJSON is a valid HuggingFace tokenizer.json with a tiny vocab.
const minimalTokenizerJSON = `{
"model": {
"type": "BPE",
"vocab": {
"h": 0,
"e": 1,
"l": 2,
"o": 3,
"▁": 4,
"he": 5,
"ll": 6,
"▁h": 7
},
"merges": ["h e", "l l"],
"byte_fallback": false
},
"added_tokens": [
{"id": 100, "content": "<bos>", "special": true},
{"id": 101, "content": "<eos>", "special": true}
]
}`
func writeTestTokenizer(t *testing.T) string {
t.Helper()
dir := t.TempDir()
path := filepath.Join(dir, "tokenizer.json")
if err := os.WriteFile(path, []byte(minimalTokenizerJSON), 0644); err != nil {
t.Fatalf("write test tokenizer: %v", err)
}
return path
}
func TestLoad(t *testing.T) {
path := writeTestTokenizer(t)
tok, err := Load(path)
if err != nil {
t.Fatalf("Load: %v", err)
}
if tok == nil {
t.Fatal("tokenizer is nil")
}
}
func TestLoad_MissingFile(t *testing.T) {
_, err := Load("/nonexistent/tokenizer.json")
if err == nil {
t.Error("expected error for missing file")
}
}
func TestLoad_InvalidJSON(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "tokenizer.json")
os.WriteFile(path, []byte("not json"), 0644)
_, err := Load(path)
if err == nil {
t.Error("expected error for invalid JSON")
}
}
func TestBOSEOS(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
if tok.BOSToken() != 100 {
t.Errorf("BOS = %d, want 100", tok.BOSToken())
}
if tok.EOSToken() != 101 {
t.Errorf("EOS = %d, want 101", tok.EOSToken())
}
}
func TestEncode_ProducesTokens(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
tokens := tok.Encode("hello")
if len(tokens) == 0 {
t.Fatal("Encode returned empty tokens")
}
// First token should be BOS
if tokens[0] != tok.BOSToken() {
t.Errorf("first token = %d, want BOS (%d)", tokens[0], tok.BOSToken())
}
t.Logf("Encode(\"hello\") = %v", tokens)
}
func TestDecode_SpecialTokensSkipped(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
// Decoding BOS/EOS should produce empty string
text := tok.Decode([]int32{100, 101})
if text != "" {
t.Errorf("Decode(BOS, EOS) = %q, want empty", text)
}
}
func TestDecode_RegularTokens(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
// Decode known vocab entries
text := tok.Decode([]int32{5, 6, 3}) // "he" + "ll" + "o"
if text != "hello" {
t.Errorf("Decode = %q, want %q", text, "hello")
}
}
func TestDecodeToken_Regular(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
// "he" = token 5
text := tok.DecodeToken(5)
if text != "he" {
t.Errorf("DecodeToken(5) = %q, want %q", text, "he")
}
}
func TestDecodeToken_Special(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
// Special tokens should return empty
text := tok.DecodeToken(100)
if text != "" {
t.Errorf("DecodeToken(BOS) = %q, want empty", text)
}
}
func TestDecodeToken_SentencePieceSpace(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
// "▁h" = token 7, should decode to " h" (space prefix)
text := tok.DecodeToken(7)
if text != " h" {
t.Errorf("DecodeToken(7) = %q, want %q", text, " h")
}
}
func TestDecodeToken_Unknown(t *testing.T) {
path := writeTestTokenizer(t)
tok, _ := Load(path)
text := tok.DecodeToken(9999)
if text != "" {
t.Errorf("DecodeToken(unknown) = %q, want empty", text)
}
}
func TestFormatGemmaPrompt(t *testing.T) {
got := FormatGemmaPrompt("What is 2+2?")
want := "<start_of_turn>user\nWhat is 2+2?<end_of_turn>\n<start_of_turn>model\n"
if got != want {
t.Errorf("FormatGemmaPrompt = %q, want %q", got, want)
}
}
// --- GPT-2 byte maps ---
func TestBuildGPT2ByteMaps(t *testing.T) {
decoder, encoder := buildGPT2ByteMaps()
// All 256 bytes must be mapped
if len(encoder) != 256 {
t.Errorf("encoder has %d entries, want 256", len(encoder))
}
if len(decoder) != 256 {
t.Errorf("decoder has %d entries, want 256", len(decoder))
}
// Round-trip: every byte should survive encode → decode
for b := 0; b < 256; b++ {
r := encoder[byte(b)]
got := decoder[r]
if got != byte(b) {
t.Errorf("byte %d: encode→decode = %d, want %d", b, got, b)
}
}
}
func TestBuildGPT2ByteMaps_PrintableASCII(t *testing.T) {
_, encoder := buildGPT2ByteMaps()
// Printable ASCII (33-126) should self-map
for b := 33; b <= 126; b++ {
if encoder[byte(b)] != rune(b) {
t.Errorf("byte %d (%c): expected self-map, got %c", b, b, encoder[byte(b)])
}
}
}
func TestBuildGPT2ByteMaps_ControlChars(t *testing.T) {
_, encoder := buildGPT2ByteMaps()
// Space (32) and control chars (0-31) should NOT self-map
if encoder[byte(32)] == rune(32) {
t.Error("space (32) should not self-map in GPT-2 encoding")
}
if encoder[byte(0)] == rune(0) {
t.Error("null (0) should not self-map in GPT-2 encoding")
}
}