test(sub-packages): add 33 tests for cache, sample, tokenizer
- cache_test.go (10): KVCache and RotatingKVCache — creation, single/multi update, bounded sliding window, reset, state access - sample_test.go (8): greedy, temperature (high/low), topK (single/multi), chain composition, TopP/MinP stub pass-through - tokenizer_test.go (15): Load (valid/missing/invalid), BOS/EOS detection, encode produces tokens, decode skips specials, decode regular tokens, DecodeToken (regular/special/space/unknown), FormatGemmaPrompt, GPT-2 byte maps (round-trip, printable ASCII, control chars) Total test count: 29 → 148 across 10 test files. Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
37abc496ba
commit
28b5c5bb47
5 changed files with 539 additions and 1 deletions
2
TODO.md
2
TODO.md
|
|
@ -8,7 +8,7 @@ Dispatched from core/go orchestration. Pick up tasks in order.
|
|||
|
||||
- [x] **Verify go generate → test round-trip** — ✅ 29/29 tests pass. CMake 3.24+, AppleClang 17.0.0, macOS SDK 26.2. Build takes ~2min on M3 Ultra.
|
||||
- [x] **Add missing tests for core operations** — ✅ 86 new tests across 4 files: array_test.go (25), ops_test.go (44), nn_test.go (8), fast_test.go (9). Covers: all scalar/array creation, shape ops, element-wise arithmetic, math functions, matrix ops, reductions, indexing, slicing, fused kernels (RMSNorm, LayerNorm, RoPE, SDPA), Linear, Embedding, RepeatKV. Found non-contiguous view bug in Floats()/DataInt32() — see FINDINGS.md.
|
||||
- [ ] **Add missing tests for model/tokenizer/sample/cache** — `model/`, `tokenizer/`, `sample/`, `cache/` have zero test files. Priority: tokenizer (BPE round-trip), sample (temperature/top-k), cache (KV append/trim).
|
||||
- [x] **Add missing tests for model/tokenizer/sample/cache** — ✅ 33 new tests: cache_test.go (10: KVCache + RotatingKVCache lifecycle, update, bounded, reset), sample_test.go (8: greedy, temperature, topK, chain, stub pass-through), tokenizer_test.go (15: Load/error, BOS/EOS, encode/decode, DecodeToken, SentencePiece space, GPT-2 byte maps). model/ still needs tests (requires model files on disk).
|
||||
- [ ] **Benchmark suite** — No benchmarks exist. Add: MatMul (various sizes), Softmax, model.Forward (single token), tokenizer.Encode/Decode, full Generate (tokens/sec). Baseline on M3 Ultra.
|
||||
|
||||
## Phase 2: Model Support
|
||||
|
|
|
|||
200
cache/cache_test.go
vendored
Normal file
200
cache/cache_test.go
vendored
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"forge.lthn.ai/core/go-mlx"
|
||||
)
|
||||
|
||||
// makeKV creates a small K/V pair with shape [B=1, H=2, L=seqLen, D=4].
|
||||
func makeKV(seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
size := 1 * 2 * seqLen * 4
|
||||
data := make([]float32, size)
|
||||
for i := range data {
|
||||
data[i] = float32(i) * 0.1
|
||||
}
|
||||
k := mlx.FromValues(data, 1, 2, seqLen, 4)
|
||||
v := mlx.FromValues(data, 1, 2, seqLen, 4)
|
||||
return k, v
|
||||
}
|
||||
|
||||
// --- KVCache ---
|
||||
|
||||
func TestKVCache_New(t *testing.T) {
|
||||
c := NewKVCache()
|
||||
if c.Offset() != 0 {
|
||||
t.Errorf("offset = %d, want 0", c.Offset())
|
||||
}
|
||||
if c.Len() != 0 {
|
||||
t.Errorf("len = %d, want 0", c.Len())
|
||||
}
|
||||
if c.State() != nil {
|
||||
t.Error("state should be nil for empty cache")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCache_SingleUpdate(t *testing.T) {
|
||||
c := NewKVCache()
|
||||
k, v := makeKV(3) // 3 tokens
|
||||
|
||||
outK, outV := c.Update(k, v, 3)
|
||||
mlx.Materialize(outK, outV)
|
||||
|
||||
if c.Offset() != 3 {
|
||||
t.Errorf("offset = %d, want 3", c.Offset())
|
||||
}
|
||||
if c.Len() != 3 {
|
||||
t.Errorf("len = %d, want 3", c.Len())
|
||||
}
|
||||
|
||||
// Output K should have shape [1, 2, 3, 4]
|
||||
shape := outK.Shape()
|
||||
if shape[0] != 1 || shape[1] != 2 || shape[2] != 3 || shape[3] != 4 {
|
||||
t.Errorf("outK shape = %v, want [1 2 3 4]", shape)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCache_MultipleUpdates(t *testing.T) {
|
||||
c := NewKVCache()
|
||||
|
||||
// Prompt: 5 tokens
|
||||
k1, v1 := makeKV(5)
|
||||
outK, outV := c.Update(k1, v1, 5)
|
||||
mlx.Materialize(outK, outV)
|
||||
|
||||
if c.Offset() != 5 {
|
||||
t.Errorf("offset = %d, want 5", c.Offset())
|
||||
}
|
||||
|
||||
// Generate: 1 token at a time
|
||||
k2, v2 := makeKV(1)
|
||||
outK, outV = c.Update(k2, v2, 1)
|
||||
mlx.Materialize(outK, outV)
|
||||
|
||||
if c.Offset() != 6 {
|
||||
t.Errorf("offset = %d, want 6", c.Offset())
|
||||
}
|
||||
|
||||
shape := outK.Shape()
|
||||
if shape[2] != 6 {
|
||||
t.Errorf("outK L dim = %d, want 6", shape[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCache_Reset(t *testing.T) {
|
||||
c := NewKVCache()
|
||||
k, v := makeKV(3)
|
||||
c.Update(k, v, 3)
|
||||
|
||||
c.Reset()
|
||||
|
||||
if c.Offset() != 0 {
|
||||
t.Errorf("offset after reset = %d, want 0", c.Offset())
|
||||
}
|
||||
if c.State() != nil {
|
||||
t.Error("state should be nil after reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCache_State(t *testing.T) {
|
||||
c := NewKVCache()
|
||||
k, v := makeKV(2)
|
||||
c.Update(k, v, 2)
|
||||
|
||||
state := c.State()
|
||||
if len(state) != 2 {
|
||||
t.Fatalf("state length = %d, want 2", len(state))
|
||||
}
|
||||
// state[0] = keys, state[1] = values
|
||||
if state[0] == nil || state[1] == nil {
|
||||
t.Error("state arrays should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
// --- RotatingKVCache ---
|
||||
|
||||
func TestRotatingKVCache_New(t *testing.T) {
|
||||
c := NewRotatingKVCache(16)
|
||||
if c.Offset() != 0 {
|
||||
t.Errorf("offset = %d, want 0", c.Offset())
|
||||
}
|
||||
if c.Len() != 0 {
|
||||
t.Errorf("len = %d, want 0", c.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotatingKVCache_SingleToken(t *testing.T) {
|
||||
c := NewRotatingKVCache(8)
|
||||
k, v := makeKV(1)
|
||||
|
||||
outK, outV := c.Update(k, v, 1)
|
||||
mlx.Materialize(outK, outV)
|
||||
|
||||
if c.Offset() != 1 {
|
||||
t.Errorf("offset = %d, want 1", c.Offset())
|
||||
}
|
||||
if c.Len() != 1 {
|
||||
t.Errorf("len = %d, want 1", c.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotatingKVCache_MultiTokenPrompt(t *testing.T) {
|
||||
c := NewRotatingKVCache(16)
|
||||
k, v := makeKV(5)
|
||||
|
||||
outK, outV := c.Update(k, v, 5)
|
||||
mlx.Materialize(outK, outV)
|
||||
|
||||
if c.Offset() != 5 {
|
||||
t.Errorf("offset = %d, want 5", c.Offset())
|
||||
}
|
||||
if c.Len() != 5 {
|
||||
t.Errorf("len = %d, want 5", c.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotatingKVCache_Bounded(t *testing.T) {
|
||||
c := NewRotatingKVCache(4)
|
||||
|
||||
// Fill with 4-token prompt (at max)
|
||||
k, v := makeKV(4)
|
||||
outK, outV := c.Update(k, v, 4)
|
||||
mlx.Materialize(outK, outV)
|
||||
|
||||
if c.Len() != 4 {
|
||||
t.Errorf("len = %d, want 4 (at max)", c.Len())
|
||||
}
|
||||
|
||||
// Add one more token — should trim to maxSize
|
||||
k2, v2 := makeKV(1)
|
||||
outK, outV = c.Update(k2, v2, 1)
|
||||
mlx.Materialize(outK, outV)
|
||||
|
||||
if c.Offset() != 5 {
|
||||
t.Errorf("offset = %d, want 5", c.Offset())
|
||||
}
|
||||
// Len should be bounded by maxSize
|
||||
if c.Len() != 4 {
|
||||
t.Errorf("len = %d, want 4 (bounded)", c.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotatingKVCache_Reset(t *testing.T) {
|
||||
c := NewRotatingKVCache(8)
|
||||
k, v := makeKV(3)
|
||||
c.Update(k, v, 3)
|
||||
|
||||
c.Reset()
|
||||
|
||||
if c.Offset() != 0 {
|
||||
t.Errorf("offset after reset = %d, want 0", c.Offset())
|
||||
}
|
||||
if c.Len() != 0 {
|
||||
t.Errorf("len after reset = %d, want 0", c.Len())
|
||||
}
|
||||
if c.State() != nil {
|
||||
t.Error("state should be nil after reset")
|
||||
}
|
||||
}
|
||||
|
|
@ -10,6 +10,11 @@ Tasks for the CLion Claude session. Written by GoLand Claude or Virgil.
|
|||
- [ ] **Understand the error model** — `error.h` provides `mlx_set_error_handler()`. The Go side registers a handler that logs to stderr. Research: can we get structured error info (error codes, categories)? Is the error string stable or does it vary?
|
||||
- [ ] **Check memory management patterns** — `mlx_*_free()` functions exist for each type. Verify: is double-free safe? What happens if you free during async eval? Document for the Go finaliser integration.
|
||||
|
||||
## Priority Tasks (from GoLand Claude)
|
||||
|
||||
- [ ] **Find `mlx_contiguous` or equivalent** — `Floats()`/`DataInt32()` on non-contiguous arrays (transpose, broadcast, slice views) returns wrong data because `mlx_array_data_float32` returns the physical buffer, not the logical layout. Need a C function that copies a non-contiguous array to contiguous memory. Check if `mlx_contiguous` exists in mlx-c headers or if we need `mlx_reshape` to force a copy. This is a data correctness bug — see FINDINGS.md in project root.
|
||||
- [ ] **Verify `mlx_array_data_*` eval semantics** — Does `mlx_array_data_float32()` trigger evaluation (like C++ `array::data()` does), or must we call `mlx_eval` first? The Go side calls `Materialize()` before data access but some code paths might skip it.
|
||||
|
||||
## Standing Tasks
|
||||
|
||||
- [ ] **API gap analysis** — When the GoLand Claude needs a C function that isn't exposed by mlx-c, document the gap here and research if upstream mlx-c supports it or if a patch is needed.
|
||||
|
|
|
|||
116
sample/sample_test.go
Normal file
116
sample/sample_test.go
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package sample
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"forge.lthn.ai/core/go-mlx"
|
||||
)
|
||||
|
||||
func TestGreedy(t *testing.T) {
|
||||
// Logits heavily favour index 2
|
||||
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
|
||||
s := New(0, 0, 0, 0) // temp=0 → greedy
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
|
||||
if token.Int() != 2 {
|
||||
t.Errorf("greedy sample = %d, want 2", token.Int())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemperature_HighTemp(t *testing.T) {
|
||||
// High temperature should still produce a valid index
|
||||
logits := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
||||
s := New(100.0, 0, 0, 0) // very high temp → near uniform
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
|
||||
idx := token.Int()
|
||||
if idx < 0 || idx >= 4 {
|
||||
t.Errorf("sample index = %d, out of range [0, 4)", idx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemperature_LowTemp(t *testing.T) {
|
||||
// Very low temperature should behave like greedy
|
||||
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
|
||||
s := New(0.001, 0, 0, 0) // near-zero temp → near-greedy
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
|
||||
if token.Int() != 2 {
|
||||
t.Errorf("low-temp sample = %d, want 2 (near greedy)", token.Int())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopK(t *testing.T) {
|
||||
// TopK=1 with clear winner should always pick that token
|
||||
logits := mlx.FromValues([]float32{-100, 100, -100, -100}, 1, 4)
|
||||
s := New(1.0, 0, 0, 1) // topK=1
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
|
||||
if token.Int() != 1 {
|
||||
t.Errorf("topk=1 sample = %d, want 1", token.Int())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopK_MultipleTokens(t *testing.T) {
|
||||
// TopK=2, both high logits — should pick one of them
|
||||
logits := mlx.FromValues([]float32{-100, 50, 50, -100}, 1, 4)
|
||||
s := New(1.0, 0, 0, 2) // topK=2
|
||||
|
||||
seen := map[int]bool{}
|
||||
for range 20 {
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
seen[token.Int()] = true
|
||||
}
|
||||
|
||||
// Should only ever pick index 1 or 2
|
||||
for idx := range seen {
|
||||
if idx != 1 && idx != 2 {
|
||||
t.Errorf("topk=2 sampled index %d, expected only 1 or 2", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_Chain(t *testing.T) {
|
||||
// Full chain: topK + temperature
|
||||
logits := mlx.FromValues([]float32{1, 2, 3, 4, 5}, 1, 5)
|
||||
s := New(0.5, 0, 0, 3) // temp=0.5, topK=3
|
||||
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
|
||||
idx := token.Int()
|
||||
if idx < 0 || idx >= 5 {
|
||||
t.Errorf("chain sample index = %d, out of range", idx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopP_PassThrough(t *testing.T) {
|
||||
// TopP is currently a stub — verify it doesn't break the chain
|
||||
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
|
||||
s := New(0.5, 0.9, 0, 0) // topP=0.9 (stub), temp=0.5
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
|
||||
if token.Int() != 2 {
|
||||
t.Errorf("topP stub + temp sample = %d, want 2", token.Int())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinP_PassThrough(t *testing.T) {
|
||||
// MinP is currently a stub — verify it doesn't break the chain
|
||||
logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4)
|
||||
s := New(0.5, 0, 0.1, 0) // minP=0.1 (stub), temp=0.5
|
||||
token := s.Sample(logits)
|
||||
mlx.Materialize(token)
|
||||
|
||||
if token.Int() != 2 {
|
||||
t.Errorf("minP stub + temp sample = %d, want 2", token.Int())
|
||||
}
|
||||
}
|
||||
217
tokenizer/tokenizer_test.go
Normal file
217
tokenizer/tokenizer_test.go
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// minimalTokenizerJSON is a valid HuggingFace tokenizer.json with a tiny vocab.
|
||||
const minimalTokenizerJSON = `{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {
|
||||
"h": 0,
|
||||
"e": 1,
|
||||
"l": 2,
|
||||
"o": 3,
|
||||
"▁": 4,
|
||||
"he": 5,
|
||||
"ll": 6,
|
||||
"▁h": 7
|
||||
},
|
||||
"merges": ["h e", "l l"],
|
||||
"byte_fallback": false
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 100, "content": "<bos>", "special": true},
|
||||
{"id": 101, "content": "<eos>", "special": true}
|
||||
]
|
||||
}`
|
||||
|
||||
func writeTestTokenizer(t *testing.T) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "tokenizer.json")
|
||||
if err := os.WriteFile(path, []byte(minimalTokenizerJSON), 0644); err != nil {
|
||||
t.Fatalf("write test tokenizer: %v", err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if tok == nil {
|
||||
t.Fatal("tokenizer is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_MissingFile(t *testing.T) {
|
||||
_, err := Load("/nonexistent/tokenizer.json")
|
||||
if err == nil {
|
||||
t.Error("expected error for missing file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_InvalidJSON(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "tokenizer.json")
|
||||
os.WriteFile(path, []byte("not json"), 0644)
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBOSEOS(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
if tok.BOSToken() != 100 {
|
||||
t.Errorf("BOS = %d, want 100", tok.BOSToken())
|
||||
}
|
||||
if tok.EOSToken() != 101 {
|
||||
t.Errorf("EOS = %d, want 101", tok.EOSToken())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncode_ProducesTokens(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
tokens := tok.Encode("hello")
|
||||
if len(tokens) == 0 {
|
||||
t.Fatal("Encode returned empty tokens")
|
||||
}
|
||||
// First token should be BOS
|
||||
if tokens[0] != tok.BOSToken() {
|
||||
t.Errorf("first token = %d, want BOS (%d)", tokens[0], tok.BOSToken())
|
||||
}
|
||||
t.Logf("Encode(\"hello\") = %v", tokens)
|
||||
}
|
||||
|
||||
func TestDecode_SpecialTokensSkipped(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
// Decoding BOS/EOS should produce empty string
|
||||
text := tok.Decode([]int32{100, 101})
|
||||
if text != "" {
|
||||
t.Errorf("Decode(BOS, EOS) = %q, want empty", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecode_RegularTokens(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
// Decode known vocab entries
|
||||
text := tok.Decode([]int32{5, 6, 3}) // "he" + "ll" + "o"
|
||||
if text != "hello" {
|
||||
t.Errorf("Decode = %q, want %q", text, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToken_Regular(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
// "he" = token 5
|
||||
text := tok.DecodeToken(5)
|
||||
if text != "he" {
|
||||
t.Errorf("DecodeToken(5) = %q, want %q", text, "he")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToken_Special(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
// Special tokens should return empty
|
||||
text := tok.DecodeToken(100)
|
||||
if text != "" {
|
||||
t.Errorf("DecodeToken(BOS) = %q, want empty", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToken_SentencePieceSpace(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
// "▁h" = token 7, should decode to " h" (space prefix)
|
||||
text := tok.DecodeToken(7)
|
||||
if text != " h" {
|
||||
t.Errorf("DecodeToken(7) = %q, want %q", text, " h")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToken_Unknown(t *testing.T) {
|
||||
path := writeTestTokenizer(t)
|
||||
tok, _ := Load(path)
|
||||
|
||||
text := tok.DecodeToken(9999)
|
||||
if text != "" {
|
||||
t.Errorf("DecodeToken(unknown) = %q, want empty", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatGemmaPrompt(t *testing.T) {
|
||||
got := FormatGemmaPrompt("What is 2+2?")
|
||||
want := "<start_of_turn>user\nWhat is 2+2?<end_of_turn>\n<start_of_turn>model\n"
|
||||
if got != want {
|
||||
t.Errorf("FormatGemmaPrompt = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// --- GPT-2 byte maps ---
|
||||
|
||||
func TestBuildGPT2ByteMaps(t *testing.T) {
|
||||
decoder, encoder := buildGPT2ByteMaps()
|
||||
|
||||
// All 256 bytes must be mapped
|
||||
if len(encoder) != 256 {
|
||||
t.Errorf("encoder has %d entries, want 256", len(encoder))
|
||||
}
|
||||
if len(decoder) != 256 {
|
||||
t.Errorf("decoder has %d entries, want 256", len(decoder))
|
||||
}
|
||||
|
||||
// Round-trip: every byte should survive encode → decode
|
||||
for b := 0; b < 256; b++ {
|
||||
r := encoder[byte(b)]
|
||||
got := decoder[r]
|
||||
if got != byte(b) {
|
||||
t.Errorf("byte %d: encode→decode = %d, want %d", b, got, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildGPT2ByteMaps_PrintableASCII(t *testing.T) {
|
||||
_, encoder := buildGPT2ByteMaps()
|
||||
|
||||
// Printable ASCII (33-126) should self-map
|
||||
for b := 33; b <= 126; b++ {
|
||||
if encoder[byte(b)] != rune(b) {
|
||||
t.Errorf("byte %d (%c): expected self-map, got %c", b, b, encoder[byte(b)])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildGPT2ByteMaps_ControlChars(t *testing.T) {
|
||||
_, encoder := buildGPT2ByteMaps()
|
||||
|
||||
// Space (32) and control chars (0-31) should NOT self-map
|
||||
if encoder[byte(32)] == rune(32) {
|
||||
t.Error("space (32) should not self-map in GPT-2 encoding")
|
||||
}
|
||||
if encoder[byte(0)] == rune(0) {
|
||||
t.Error("null (0) should not self-map in GPT-2 encoding")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue