From 28b5c5bb4718aa8112b518a9cdffc8be1b00dbd0 Mon Sep 17 00:00:00 2001 From: Snider Date: Thu, 19 Feb 2026 18:40:49 +0000 Subject: [PATCH] test(sub-packages): add 33 tests for cache, sample, tokenizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - cache_test.go (10): KVCache and RotatingKVCache — creation, single/multi update, bounded sliding window, reset, state access - sample_test.go (8): greedy, temperature (high/low), topK (single/multi), chain composition, TopP/MinP stub pass-through - tokenizer_test.go (15): Load (valid/missing/invalid), BOS/EOS detection, encode produces tokens, decode skips specials, decode regular tokens, DecodeToken (regular/special/space/unknown), FormatGemmaPrompt, GPT-2 byte maps (round-trip, printable ASCII, control chars) Total test count: 29 → 148 across 10 test files. Co-Authored-By: Virgil Co-Authored-By: Claude Opus 4.6 --- TODO.md | 2 +- cache/cache_test.go | 200 +++++++++++++++++++++++++++++++++ cpp/TODO.md | 5 + sample/sample_test.go | 116 +++++++++++++++++++ tokenizer/tokenizer_test.go | 217 ++++++++++++++++++++++++++++++++++++ 5 files changed, 539 insertions(+), 1 deletion(-) create mode 100644 cache/cache_test.go create mode 100644 sample/sample_test.go create mode 100644 tokenizer/tokenizer_test.go diff --git a/TODO.md b/TODO.md index 5e4b322..22759cc 100644 --- a/TODO.md +++ b/TODO.md @@ -8,7 +8,7 @@ Dispatched from core/go orchestration. Pick up tasks in order. - [x] **Verify go generate → test round-trip** — ✅ 29/29 tests pass. CMake 3.24+, AppleClang 17.0.0, macOS SDK 26.2. Build takes ~2min on M3 Ultra. - [x] **Add missing tests for core operations** — ✅ 86 new tests across 4 files: array_test.go (25), ops_test.go (44), nn_test.go (8), fast_test.go (9). Covers: all scalar/array creation, shape ops, element-wise arithmetic, math functions, matrix ops, reductions, indexing, slicing, fused kernels (RMSNorm, LayerNorm, RoPE, SDPA), Linear, Embedding, RepeatKV. 
Found non-contiguous view bug in Floats()/DataInt32() — see FINDINGS.md. -- [ ] **Add missing tests for model/tokenizer/sample/cache** — `model/`, `tokenizer/`, `sample/`, `cache/` have zero test files. Priority: tokenizer (BPE round-trip), sample (temperature/top-k), cache (KV append/trim). +- [x] **Add missing tests for model/tokenizer/sample/cache** — ✅ 33 new tests: cache_test.go (10: KVCache + RotatingKVCache lifecycle, update, bounded, reset), sample_test.go (8: greedy, temperature, topK, chain, stub pass-through), tokenizer_test.go (15: Load/error, BOS/EOS, encode/decode, DecodeToken, SentencePiece space, GPT-2 byte maps). model/ still needs tests (requires model files on disk). - [ ] **Benchmark suite** — No benchmarks exist. Add: MatMul (various sizes), Softmax, model.Forward (single token), tokenizer.Encode/Decode, full Generate (tokens/sec). Baseline on M3 Ultra. ## Phase 2: Model Support diff --git a/cache/cache_test.go b/cache/cache_test.go new file mode 100644 index 0000000..e660af6 --- /dev/null +++ b/cache/cache_test.go @@ -0,0 +1,200 @@ +//go:build darwin && arm64 + +package cache + +import ( + "testing" + + "forge.lthn.ai/core/go-mlx" +) + +// makeKV creates a small K/V pair with shape [B=1, H=2, L=seqLen, D=4]. 
+func makeKV(seqLen int) (*mlx.Array, *mlx.Array) { + size := 1 * 2 * seqLen * 4 + data := make([]float32, size) + for i := range data { + data[i] = float32(i) * 0.1 + } + k := mlx.FromValues(data, 1, 2, seqLen, 4) + v := mlx.FromValues(data, 1, 2, seqLen, 4) + return k, v +} + +// --- KVCache --- + +func TestKVCache_New(t *testing.T) { + c := NewKVCache() + if c.Offset() != 0 { + t.Errorf("offset = %d, want 0", c.Offset()) + } + if c.Len() != 0 { + t.Errorf("len = %d, want 0", c.Len()) + } + if c.State() != nil { + t.Error("state should be nil for empty cache") + } +} + +func TestKVCache_SingleUpdate(t *testing.T) { + c := NewKVCache() + k, v := makeKV(3) // 3 tokens + + outK, outV := c.Update(k, v, 3) + mlx.Materialize(outK, outV) + + if c.Offset() != 3 { + t.Errorf("offset = %d, want 3", c.Offset()) + } + if c.Len() != 3 { + t.Errorf("len = %d, want 3", c.Len()) + } + + // Output K should have shape [1, 2, 3, 4] + shape := outK.Shape() + if shape[0] != 1 || shape[1] != 2 || shape[2] != 3 || shape[3] != 4 { + t.Errorf("outK shape = %v, want [1 2 3 4]", shape) + } +} + +func TestKVCache_MultipleUpdates(t *testing.T) { + c := NewKVCache() + + // Prompt: 5 tokens + k1, v1 := makeKV(5) + outK, outV := c.Update(k1, v1, 5) + mlx.Materialize(outK, outV) + + if c.Offset() != 5 { + t.Errorf("offset = %d, want 5", c.Offset()) + } + + // Generate: 1 token at a time + k2, v2 := makeKV(1) + outK, outV = c.Update(k2, v2, 1) + mlx.Materialize(outK, outV) + + if c.Offset() != 6 { + t.Errorf("offset = %d, want 6", c.Offset()) + } + + shape := outK.Shape() + if shape[2] != 6 { + t.Errorf("outK L dim = %d, want 6", shape[2]) + } +} + +func TestKVCache_Reset(t *testing.T) { + c := NewKVCache() + k, v := makeKV(3) + c.Update(k, v, 3) + + c.Reset() + + if c.Offset() != 0 { + t.Errorf("offset after reset = %d, want 0", c.Offset()) + } + if c.State() != nil { + t.Error("state should be nil after reset") + } +} + +func TestKVCache_State(t *testing.T) { + c := NewKVCache() + k, v := 
makeKV(2) + c.Update(k, v, 2) + + state := c.State() + if len(state) != 2 { + t.Fatalf("state length = %d, want 2", len(state)) + } + // state[0] = keys, state[1] = values + if state[0] == nil || state[1] == nil { + t.Error("state arrays should not be nil") + } +} + +// --- RotatingKVCache --- + +func TestRotatingKVCache_New(t *testing.T) { + c := NewRotatingKVCache(16) + if c.Offset() != 0 { + t.Errorf("offset = %d, want 0", c.Offset()) + } + if c.Len() != 0 { + t.Errorf("len = %d, want 0", c.Len()) + } +} + +func TestRotatingKVCache_SingleToken(t *testing.T) { + c := NewRotatingKVCache(8) + k, v := makeKV(1) + + outK, outV := c.Update(k, v, 1) + mlx.Materialize(outK, outV) + + if c.Offset() != 1 { + t.Errorf("offset = %d, want 1", c.Offset()) + } + if c.Len() != 1 { + t.Errorf("len = %d, want 1", c.Len()) + } +} + +func TestRotatingKVCache_MultiTokenPrompt(t *testing.T) { + c := NewRotatingKVCache(16) + k, v := makeKV(5) + + outK, outV := c.Update(k, v, 5) + mlx.Materialize(outK, outV) + + if c.Offset() != 5 { + t.Errorf("offset = %d, want 5", c.Offset()) + } + if c.Len() != 5 { + t.Errorf("len = %d, want 5", c.Len()) + } +} + +func TestRotatingKVCache_Bounded(t *testing.T) { + c := NewRotatingKVCache(4) + + // Fill with 4-token prompt (at max) + k, v := makeKV(4) + outK, outV := c.Update(k, v, 4) + mlx.Materialize(outK, outV) + + if c.Len() != 4 { + t.Errorf("len = %d, want 4 (at max)", c.Len()) + } + + // Add one more token — should trim to maxSize + k2, v2 := makeKV(1) + outK, outV = c.Update(k2, v2, 1) + mlx.Materialize(outK, outV) + + if c.Offset() != 5 { + t.Errorf("offset = %d, want 5", c.Offset()) + } + // Len should be bounded by maxSize + if c.Len() != 4 { + t.Errorf("len = %d, want 4 (bounded)", c.Len()) + } +} + +func TestRotatingKVCache_Reset(t *testing.T) { + c := NewRotatingKVCache(8) + k, v := makeKV(3) + c.Update(k, v, 3) + + c.Reset() + + if c.Offset() != 0 { + t.Errorf("offset after reset = %d, want 0", c.Offset()) + } + if c.Len() != 0 { + 
t.Errorf("len after reset = %d, want 0", c.Len()) + } + if c.State() != nil { + t.Error("state should be nil after reset") + } +} diff --git a/cpp/TODO.md b/cpp/TODO.md index 4d35b12..8504e80 100644 --- a/cpp/TODO.md +++ b/cpp/TODO.md @@ -10,6 +10,11 @@ Tasks for the CLion Claude session. Written by GoLand Claude or Virgil. - [ ] **Understand the error model** — `error.h` provides `mlx_set_error_handler()`. The Go side registers a handler that logs to stderr. Research: can we get structured error info (error codes, categories)? Is the error string stable or does it vary? - [ ] **Check memory management patterns** — `mlx_*_free()` functions exist for each type. Verify: is double-free safe? What happens if you free during async eval? Document for the Go finaliser integration. +## Priority Tasks (from GoLand Claude) + +- [ ] **Find `mlx_contiguous` or equivalent** — `Floats()`/`DataInt32()` on non-contiguous arrays (transpose, broadcast, slice views) returns wrong data because `mlx_array_data_float32` returns the physical buffer, not the logical layout. Need a C function that copies a non-contiguous array to contiguous memory. Check if `mlx_contiguous` exists in mlx-c headers or if we need `mlx_reshape` to force a copy. This is a data correctness bug — see FINDINGS.md in project root. +- [ ] **Verify `mlx_array_data_*` eval semantics** — Does `mlx_array_data_float32()` trigger evaluation (like C++ `array::data()` does), or must we call `mlx_eval` first? The Go side calls `Materialize()` before data access but some code paths might skip it. + ## Standing Tasks - [ ] **API gap analysis** — When the GoLand Claude needs a C function that isn't exposed by mlx-c, document the gap here and research if upstream mlx-c supports it or if a patch is needed. 
diff --git a/sample/sample_test.go b/sample/sample_test.go new file mode 100644 index 0000000..7cd3e62 --- /dev/null +++ b/sample/sample_test.go @@ -0,0 +1,116 @@ +//go:build darwin && arm64 + +package sample + +import ( + "testing" + + "forge.lthn.ai/core/go-mlx" +) + +func TestGreedy(t *testing.T) { + // Logits heavily favour index 2 + logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4) + s := New(0, 0, 0, 0) // temp=0 → greedy + token := s.Sample(logits) + mlx.Materialize(token) + + if token.Int() != 2 { + t.Errorf("greedy sample = %d, want 2", token.Int()) + } +} + +func TestTemperature_HighTemp(t *testing.T) { + // High temperature should still produce a valid index + logits := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4) + s := New(100.0, 0, 0, 0) // very high temp → near uniform + token := s.Sample(logits) + mlx.Materialize(token) + + idx := token.Int() + if idx < 0 || idx >= 4 { + t.Errorf("sample index = %d, out of range [0, 4)", idx) + } +} + +func TestTemperature_LowTemp(t *testing.T) { + // Very low temperature should behave like greedy + logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4) + s := New(0.001, 0, 0, 0) // near-zero temp → near-greedy + token := s.Sample(logits) + mlx.Materialize(token) + + if token.Int() != 2 { + t.Errorf("low-temp sample = %d, want 2 (near greedy)", token.Int()) + } +} + +func TestTopK(t *testing.T) { + // TopK=1 with clear winner should always pick that token + logits := mlx.FromValues([]float32{-100, 100, -100, -100}, 1, 4) + s := New(1.0, 0, 0, 1) // topK=1 + token := s.Sample(logits) + mlx.Materialize(token) + + if token.Int() != 1 { + t.Errorf("topk=1 sample = %d, want 1", token.Int()) + } +} + +func TestTopK_MultipleTokens(t *testing.T) { + // TopK=2, both high logits — should pick one of them + logits := mlx.FromValues([]float32{-100, 50, 50, -100}, 1, 4) + s := New(1.0, 0, 0, 2) // topK=2 + + seen := map[int]bool{} + for range 20 { + token := s.Sample(logits) + mlx.Materialize(token) + 
seen[token.Int()] = true + } + + // Should only ever pick index 1 or 2 + for idx := range seen { + if idx != 1 && idx != 2 { + t.Errorf("topk=2 sampled index %d, expected only 1 or 2", idx) + } + } +} + +func TestNew_Chain(t *testing.T) { + // Full chain: topK + temperature + logits := mlx.FromValues([]float32{1, 2, 3, 4, 5}, 1, 5) + s := New(0.5, 0, 0, 3) // temp=0.5, topK=3 + + token := s.Sample(logits) + mlx.Materialize(token) + + idx := token.Int() + if idx < 0 || idx >= 5 { + t.Errorf("chain sample index = %d, out of range", idx) + } +} + +func TestTopP_PassThrough(t *testing.T) { + // TopP is currently a stub — verify it doesn't break the chain + logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4) + s := New(0.5, 0.9, 0, 0) // topP=0.9 (stub), temp=0.5 + token := s.Sample(logits) + mlx.Materialize(token) + + if token.Int() != 2 { + t.Errorf("topP stub + temp sample = %d, want 2", token.Int()) + } +} + +func TestMinP_PassThrough(t *testing.T) { + // MinP is currently a stub — verify it doesn't break the chain + logits := mlx.FromValues([]float32{-10, -10, 100, -10}, 1, 4) + s := New(0.5, 0, 0.1, 0) // minP=0.1 (stub), temp=0.5 + token := s.Sample(logits) + mlx.Materialize(token) + + if token.Int() != 2 { + t.Errorf("minP stub + temp sample = %d, want 2", token.Int()) + } +} diff --git a/tokenizer/tokenizer_test.go b/tokenizer/tokenizer_test.go new file mode 100644 index 0000000..2230f5f --- /dev/null +++ b/tokenizer/tokenizer_test.go @@ -0,0 +1,217 @@ +//go:build darwin && arm64 + +package tokenizer + +import ( + "os" + "path/filepath" + "testing" +) + +// minimalTokenizerJSON is a valid HuggingFace tokenizer.json with a tiny vocab. 
+const minimalTokenizerJSON = `{ + "model": { + "type": "BPE", + "vocab": { + "h": 0, + "e": 1, + "l": 2, + "o": 3, + "▁": 4, + "he": 5, + "ll": 6, + "▁h": 7 + }, + "merges": ["h e", "l l"], + "byte_fallback": false + }, + "added_tokens": [ + {"id": 100, "content": "<bos>", "special": true}, + {"id": 101, "content": "<eos>", "special": true} + ] +}` + +func writeTestTokenizer(t *testing.T) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "tokenizer.json") + if err := os.WriteFile(path, []byte(minimalTokenizerJSON), 0644); err != nil { + t.Fatalf("write test tokenizer: %v", err) + } + return path +} + +func TestLoad(t *testing.T) { + path := writeTestTokenizer(t) + tok, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + if tok == nil { + t.Fatal("tokenizer is nil") + } +} + +func TestLoad_MissingFile(t *testing.T) { + _, err := Load("/nonexistent/tokenizer.json") + if err == nil { + t.Error("expected error for missing file") + } +} + +func TestLoad_InvalidJSON(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "tokenizer.json") + os.WriteFile(path, []byte("not json"), 0644) + + _, err := Load(path) + if err == nil { + t.Error("expected error for invalid JSON") + } +} + +func TestBOSEOS(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := Load(path) + + if tok.BOSToken() != 100 { + t.Errorf("BOS = %d, want 100", tok.BOSToken()) + } + if tok.EOSToken() != 101 { + t.Errorf("EOS = %d, want 101", tok.EOSToken()) + } +} + +func TestEncode_ProducesTokens(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := Load(path) + + tokens := tok.Encode("hello") + if len(tokens) == 0 { + t.Fatal("Encode returned empty tokens") + } + // First token should be BOS + if tokens[0] != tok.BOSToken() { + t.Errorf("first token = %d, want BOS (%d)", tokens[0], tok.BOSToken()) + } + t.Logf("Encode(\"hello\") = %v", tokens) +} + +func TestDecode_SpecialTokensSkipped(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := Load(path) 
+ + // Decoding BOS/EOS should produce empty string + text := tok.Decode([]int32{100, 101}) + if text != "" { + t.Errorf("Decode(BOS, EOS) = %q, want empty", text) + } +} + +func TestDecode_RegularTokens(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := Load(path) + + // Decode known vocab entries + text := tok.Decode([]int32{5, 6, 3}) // "he" + "ll" + "o" + if text != "hello" { + t.Errorf("Decode = %q, want %q", text, "hello") + } +} + +func TestDecodeToken_Regular(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := Load(path) + + // "he" = token 5 + text := tok.DecodeToken(5) + if text != "he" { + t.Errorf("DecodeToken(5) = %q, want %q", text, "he") + } +} + +func TestDecodeToken_Special(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := Load(path) + + // Special tokens should return empty + text := tok.DecodeToken(100) + if text != "" { + t.Errorf("DecodeToken(BOS) = %q, want empty", text) + } +} + +func TestDecodeToken_SentencePieceSpace(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := Load(path) + + // "▁h" = token 7, should decode to " h" (space prefix) + text := tok.DecodeToken(7) + if text != " h" { + t.Errorf("DecodeToken(7) = %q, want %q", text, " h") + } +} + +func TestDecodeToken_Unknown(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := Load(path) + + text := tok.DecodeToken(9999) + if text != "" { + t.Errorf("DecodeToken(unknown) = %q, want empty", text) + } +} + +func TestFormatGemmaPrompt(t *testing.T) { + got := FormatGemmaPrompt("What is 2+2?") + want := "<start_of_turn>user\nWhat is 2+2?<end_of_turn>\n<start_of_turn>model\n" + if got != want { + t.Errorf("FormatGemmaPrompt = %q, want %q", got, want) + } +} + +// --- GPT-2 byte maps --- + +func TestBuildGPT2ByteMaps(t *testing.T) { + decoder, encoder := buildGPT2ByteMaps() + + // All 256 bytes must be mapped + if len(encoder) != 256 { + t.Errorf("encoder has %d entries, want 256", len(encoder)) + } + if len(decoder) != 256 { + t.Errorf("decoder has %d entries, want 256", len(decoder)) + } + + 
// Round-trip: every byte should survive encode → decode + for b := 0; b < 256; b++ { + r := encoder[byte(b)] + got := decoder[r] + if got != byte(b) { + t.Errorf("byte %d: encode→decode = %d, want %d", b, got, b) + } + } +} + +func TestBuildGPT2ByteMaps_PrintableASCII(t *testing.T) { + _, encoder := buildGPT2ByteMaps() + + // Printable ASCII (33-126) should self-map + for b := 33; b <= 126; b++ { + if encoder[byte(b)] != rune(b) { + t.Errorf("byte %d (%c): expected self-map, got %c", b, b, encoder[byte(b)]) + } + } +} + +func TestBuildGPT2ByteMaps_ControlChars(t *testing.T) { + _, encoder := buildGPT2ByteMaps() + + // Space (32) and control chars (0-31) should NOT self-map + if encoder[byte(32)] == rune(32) { + t.Error("space (32) should not self-map in GPT-2 encoding") + } + if encoder[byte(0)] == rune(0) { + t.Error("null (0) should not self-map in GPT-2 encoding") + } +}