From c612c3e0600d34c8f47a10cdbf9b84300406f07a Mon Sep 17 00:00:00 2001 From: Snider Date: Thu, 19 Feb 2026 20:00:02 +0000 Subject: [PATCH] refactor(metal): move all tests to internal/metal (148 tests passing) Co-Authored-By: Virgil --- array_test.go => internal/metal/array_test.go | 2 +- internal/metal/cache_test.go | 198 ++++++++++++++++ fast_test.go => internal/metal/fast_test.go | 2 +- grad_test.go => internal/metal/grad_test.go | 2 +- lora_test.go => internal/metal/lora_test.go | 2 +- nn_test.go => internal/metal/nn_test.go | 2 +- ops_test.go => internal/metal/ops_test.go | 2 +- optim_test.go => internal/metal/optim_test.go | 2 +- internal/metal/sample_test.go | 114 +++++++++ internal/metal/tokenizer_test.go | 217 ++++++++++++++++++ 10 files changed, 536 insertions(+), 7 deletions(-) rename array_test.go => internal/metal/array_test.go (99%) create mode 100644 internal/metal/cache_test.go rename fast_test.go => internal/metal/fast_test.go (99%) rename grad_test.go => internal/metal/grad_test.go (99%) rename lora_test.go => internal/metal/lora_test.go (99%) rename nn_test.go => internal/metal/nn_test.go (99%) rename ops_test.go => internal/metal/ops_test.go (99%) rename optim_test.go => internal/metal/optim_test.go (99%) create mode 100644 internal/metal/sample_test.go create mode 100644 internal/metal/tokenizer_test.go diff --git a/array_test.go b/internal/metal/array_test.go similarity index 99% rename from array_test.go rename to internal/metal/array_test.go index 397c811..ff745ac 100644 --- a/array_test.go +++ b/internal/metal/array_test.go @@ -1,6 +1,6 @@ //go:build darwin && arm64 -package mlx +package metal import ( "math" diff --git a/internal/metal/cache_test.go b/internal/metal/cache_test.go new file mode 100644 index 0000000..de1e68d --- /dev/null +++ b/internal/metal/cache_test.go @@ -0,0 +1,198 @@ +//go:build darwin && arm64 + +package metal + +import ( + "testing" +) + +// makeKV creates a small K/V pair with shape [B=1, H=2, L=seqLen, D=4]. +func makeKV(seqLen int) (*Array, *Array) { + size := 1 * 2 * seqLen * 4 + data := make([]float32, size) + for i := range data { + data[i] = float32(i) * 0.1 + } + k := FromValues(data, 1, 2, seqLen, 4) + v := FromValues(data, 1, 2, seqLen, 4) + return k, v +} + +// --- KVCache --- + +func TestKVCache_New(t *testing.T) { + c := NewKVCache() + if c.Offset() != 0 { + t.Errorf("offset = %d, want 0", c.Offset()) + } + if c.Len() != 0 { + t.Errorf("len = %d, want 0", c.Len()) + } + if c.State() != nil { + t.Error("state should be nil for empty cache") + } +} + +func TestKVCache_SingleUpdate(t *testing.T) { + c := NewKVCache() + k, v := makeKV(3) // 3 tokens + + outK, outV := c.Update(k, v, 3) + Materialize(outK, outV) + + if c.Offset() != 3 { + t.Errorf("offset = %d, want 3", c.Offset()) + } + if c.Len() != 3 { + t.Errorf("len = %d, want 3", c.Len()) + } + + // Output K should have shape [1, 2, 3, 4] + shape := outK.Shape() + if shape[0] != 1 || shape[1] != 2 || shape[2] != 3 || shape[3] != 4 { + t.Errorf("outK shape = %v, want [1 2 3 4]", shape) + } +} + +func TestKVCache_MultipleUpdates(t *testing.T) { + c := NewKVCache() + + // Prompt: 5 tokens + k1, v1 := makeKV(5) + outK, outV := c.Update(k1, v1, 5) + Materialize(outK, outV) + + if c.Offset() != 5 { + t.Errorf("offset = %d, want 5", c.Offset()) + } + + // Generate: 1 token at a time + k2, v2 := makeKV(1) + outK, outV = c.Update(k2, v2, 1) + Materialize(outK, outV) + + if c.Offset() != 6 { + t.Errorf("offset = %d, want 6", c.Offset()) + } + + shape := outK.Shape() + if shape[2] != 6 { + t.Errorf("outK L dim = %d, want 6", shape[2]) + } +} + +func TestKVCache_Reset(t *testing.T) { + c := NewKVCache() + k, v := makeKV(3) + c.Update(k, v, 3) + + c.Reset() + + if c.Offset() != 0 { + t.Errorf("offset after reset = %d, want 0", c.Offset()) + } + if c.State() != nil { + t.Error("state should be nil after reset") + } +} + +func TestKVCache_State(t *testing.T) { + c := NewKVCache() + k, v := makeKV(2) + c.Update(k, v, 2) + + state := c.State() + if len(state) != 2 { + t.Fatalf("state length = %d, want 2", len(state)) + } + // state[0] = keys, state[1] = values + if state[0] == nil || state[1] == nil { + t.Error("state arrays should not be nil") + } +} + +// --- RotatingKVCache --- + +func TestRotatingKVCache_New(t *testing.T) { + c := NewRotatingKVCache(16) + if c.Offset() != 0 { + t.Errorf("offset = %d, want 0", c.Offset()) + } + if c.Len() != 0 { + t.Errorf("len = %d, want 0", c.Len()) + } +} + +func TestRotatingKVCache_SingleToken(t *testing.T) { + c := NewRotatingKVCache(8) + k, v := makeKV(1) + + outK, outV := c.Update(k, v, 1) + Materialize(outK, outV) + + if c.Offset() != 1 { + t.Errorf("offset = %d, want 1", c.Offset()) + } + if c.Len() != 1 { + t.Errorf("len = %d, want 1", c.Len()) + } +} + +func TestRotatingKVCache_MultiTokenPrompt(t *testing.T) { + c := NewRotatingKVCache(16) + k, v := makeKV(5) + + outK, outV := c.Update(k, v, 5) + Materialize(outK, outV) + + if c.Offset() != 5 { + t.Errorf("offset = %d, want 5", c.Offset()) + } + if c.Len() != 5 { + t.Errorf("len = %d, want 5", c.Len()) + } +} + +func TestRotatingKVCache_Bounded(t *testing.T) { + c := NewRotatingKVCache(4) + + // Fill with 4-token prompt (at max) + k, v := makeKV(4) + outK, outV := c.Update(k, v, 4) + Materialize(outK, outV) + + if c.Len() != 4 { + t.Errorf("len = %d, want 4 (at max)", c.Len()) + } + + // Add one more token — should trim to maxSize + k2, v2 := makeKV(1) + outK, outV = c.Update(k2, v2, 1) + Materialize(outK, outV) + + if c.Offset() != 5 { + t.Errorf("offset = %d, want 5", c.Offset()) + } + // Len should be bounded by maxSize + if c.Len() != 4 { + t.Errorf("len = %d, want 4 (bounded)", c.Len()) + } +} + +func TestRotatingKVCache_Reset(t *testing.T) { + c := NewRotatingKVCache(8) + k, v := makeKV(3) + c.Update(k, v, 3) + + c.Reset() + + if c.Offset() != 0 { + t.Errorf("offset after reset = %d, want 0", c.Offset()) + } + if c.Len() != 0 { + t.Errorf("len after reset = %d, want 0", c.Len()) + } + if c.State() != nil { + t.Error("state should be nil after reset") + } +} diff --git a/fast_test.go b/internal/metal/fast_test.go similarity index 99% rename from fast_test.go rename to internal/metal/fast_test.go index b896bce..b97419d 100644 --- a/fast_test.go +++ b/internal/metal/fast_test.go @@ -1,6 +1,6 @@ //go:build darwin && arm64 -package mlx +package metal import ( "math" diff --git a/grad_test.go b/internal/metal/grad_test.go similarity index 99% rename from grad_test.go rename to internal/metal/grad_test.go index 0a00751..59f43ed 100644 --- a/grad_test.go +++ b/internal/metal/grad_test.go @@ -1,6 +1,6 @@ //go:build darwin && arm64 -package mlx +package metal import ( "math" diff --git a/lora_test.go b/internal/metal/lora_test.go similarity index 99% rename from lora_test.go rename to internal/metal/lora_test.go index 729dffa..10b16e8 100644 --- a/lora_test.go +++ b/internal/metal/lora_test.go @@ -1,6 +1,6 @@ //go:build darwin && arm64 -package mlx +package metal import ( "math" diff --git a/nn_test.go b/internal/metal/nn_test.go similarity index 99% rename from nn_test.go rename to internal/metal/nn_test.go index f7cbf21..0e6c183 100644 --- a/nn_test.go +++ b/internal/metal/nn_test.go @@ -1,6 +1,6 @@ //go:build darwin && arm64 -package mlx +package metal import ( "math" diff --git a/ops_test.go b/internal/metal/ops_test.go similarity index 99% rename from ops_test.go rename to internal/metal/ops_test.go index d0d2b8f..85a154d 100644 --- a/ops_test.go +++ b/internal/metal/ops_test.go @@ -1,6 +1,6 @@ //go:build darwin && arm64 -package mlx +package metal import ( "math" diff --git a/optim_test.go b/internal/metal/optim_test.go similarity index 99% rename from optim_test.go rename to internal/metal/optim_test.go index b2df911..1cd41bd 100644 --- a/optim_test.go +++ b/internal/metal/optim_test.go @@ -1,6 +1,6 @@ //go:build darwin && arm64 -package mlx +package metal import ( "math" diff --git a/internal/metal/sample_test.go b/internal/metal/sample_test.go new file mode 100644 index 0000000..9cead08 --- /dev/null +++ b/internal/metal/sample_test.go @@ -0,0 +1,114 @@ +//go:build darwin && arm64 + +package metal + +import ( + "testing" +) + +func TestGreedy(t *testing.T) { + // Logits heavily favour index 2 + logits := FromValues([]float32{-10, -10, 100, -10}, 1, 4) + s := newSampler(0, 0, 0, 0) // temp=0 → greedy + token := s.Sample(logits) + Materialize(token) + + if token.Int() != 2 { + t.Errorf("greedy sample = %d, want 2", token.Int()) + } +} + +func TestTemperature_HighTemp(t *testing.T) { + // High temperature should still produce a valid index + logits := FromValues([]float32{1, 2, 3, 4}, 1, 4) + s := newSampler(100.0, 0, 0, 0) // very high temp → near uniform + token := s.Sample(logits) + Materialize(token) + + idx := token.Int() + if idx < 0 || idx >= 4 { + t.Errorf("sample index = %d, out of range [0, 4)", idx) + } +} + +func TestTemperature_LowTemp(t *testing.T) { + // Very low temperature should behave like greedy + logits := FromValues([]float32{-10, -10, 100, -10}, 1, 4) + s := newSampler(0.001, 0, 0, 0) // near-zero temp → near-greedy + token := s.Sample(logits) + Materialize(token) + + if token.Int() != 2 { + t.Errorf("low-temp sample = %d, want 2 (near greedy)", token.Int()) + } +} + +func TestSampler_TopK(t *testing.T) { + // TopK=1 with clear winner should always pick that token + logits := FromValues([]float32{-100, 100, -100, -100}, 1, 4) + s := newSampler(1.0, 0, 0, 1) // topK=1 + token := s.Sample(logits) + Materialize(token) + + if token.Int() != 1 { + t.Errorf("topk=1 sample = %d, want 1", token.Int()) + } +} + +func TestSampler_TopK_MultipleTokens(t *testing.T) { + // TopK=2, both high logits — should pick one of them + logits := FromValues([]float32{-100, 50, 50, -100}, 1, 4) + s := newSampler(1.0, 0, 0, 2) // topK=2 + + seen := map[int]bool{} + for range 20 { + token := s.Sample(logits) + Materialize(token) + seen[token.Int()] = true + } + + // Should only ever pick index 1 or 2 + for idx := range seen { + if idx != 1 && idx != 2 { + t.Errorf("topk=2 sampled index %d, expected only 1 or 2", idx) + } + } +} + +func TestNew_Chain(t *testing.T) { + // Full chain: topK + temperature + logits := FromValues([]float32{1, 2, 3, 4, 5}, 1, 5) + s := newSampler(0.5, 0, 0, 3) // temp=0.5, topK=3 + + token := s.Sample(logits) + Materialize(token) + + idx := token.Int() + if idx < 0 || idx >= 5 { + t.Errorf("chain sample index = %d, out of range", idx) + } +} + +func TestTopP_PassThrough(t *testing.T) { + // TopP is currently a stub — verify it doesn't break the chain + logits := FromValues([]float32{-10, -10, 100, -10}, 1, 4) + s := newSampler(0.5, 0.9, 0, 0) // topP=0.9 (stub), temp=0.5 + token := s.Sample(logits) + Materialize(token) + + if token.Int() != 2 { + t.Errorf("topP stub + temp sample = %d, want 2", token.Int()) + } +} + +func TestMinP_PassThrough(t *testing.T) { + // MinP is currently a stub — verify it doesn't break the chain + logits := FromValues([]float32{-10, -10, 100, -10}, 1, 4) + s := newSampler(0.5, 0, 0.1, 0) // minP=0.1 (stub), temp=0.5 + token := s.Sample(logits) + Materialize(token) + + if token.Int() != 2 { + t.Errorf("minP stub + temp sample = %d, want 2", token.Int()) + } +} diff --git a/internal/metal/tokenizer_test.go b/internal/metal/tokenizer_test.go new file mode 100644 index 0000000..eb92e41 --- /dev/null +++ b/internal/metal/tokenizer_test.go @@ -0,0 +1,217 @@ +//go:build darwin && arm64 + +package metal + +import ( + "os" + "path/filepath" + "testing" +) + +// minimalTokenizerJSON is a valid HuggingFace tokenizer.json with a tiny vocab. +const minimalTokenizerJSON = `{ + "model": { + "type": "BPE", + "vocab": { + "h": 0, + "e": 1, + "l": 2, + "o": 3, + "▁": 4, + "he": 5, + "ll": 6, + "▁h": 7 + }, + "merges": ["h e", "l l"], + "byte_fallback": false + }, + "added_tokens": [ + {"id": 100, "content": "", "special": true}, + {"id": 101, "content": "", "special": true} + ] +}` + +func writeTestTokenizer(t *testing.T) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "tokenizer.json") + if err := os.WriteFile(path, []byte(minimalTokenizerJSON), 0644); err != nil { + t.Fatalf("write test tokenizer: %v", err) + } + return path +} + +func TestLoad(t *testing.T) { + path := writeTestTokenizer(t) + tok, err := LoadTokenizer(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + if tok == nil { + t.Fatal("tokenizer is nil") + } +} + +func TestLoad_MissingFile(t *testing.T) { + _, err := LoadTokenizer("/nonexistent/tokenizer.json") + if err == nil { + t.Error("expected error for missing file") + } +} + +func TestLoad_InvalidJSON(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "tokenizer.json") + os.WriteFile(path, []byte("not json"), 0644) + + _, err := LoadTokenizer(path) + if err == nil { + t.Error("expected error for invalid JSON") + } +} + +func TestBOSEOS(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := LoadTokenizer(path) + + if tok.BOSToken() != 100 { + t.Errorf("BOS = %d, want 100", tok.BOSToken()) + } + if tok.EOSToken() != 101 { + t.Errorf("EOS = %d, want 101", tok.EOSToken()) + } +} + +func TestEncode_ProducesTokens(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := LoadTokenizer(path) + + tokens := tok.Encode("hello") + if len(tokens) == 0 { + t.Fatal("Encode returned empty tokens") + } + // First token should be BOS + if tokens[0] != tok.BOSToken() { + t.Errorf("first token = %d, want BOS (%d)", tokens[0], tok.BOSToken()) + } + t.Logf("Encode(\"hello\") = %v", tokens) +} + +func TestDecode_SpecialTokensSkipped(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := LoadTokenizer(path) + + // Decoding BOS/EOS should produce empty string + text := tok.Decode([]int32{100, 101}) + if text != "" { + t.Errorf("Decode(BOS, EOS) = %q, want empty", text) + } +} + +func TestDecode_RegularTokens(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := LoadTokenizer(path) + + // Decode known vocab entries + text := tok.Decode([]int32{5, 6, 3}) // "he" + "ll" + "o" + if text != "hello" { + t.Errorf("Decode = %q, want %q", text, "hello") + } +} + +func TestDecodeToken_Regular(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := LoadTokenizer(path) + + // "he" = token 5 + text := tok.DecodeToken(5) + if text != "he" { + t.Errorf("DecodeToken(5) = %q, want %q", text, "he") + } +} + +func TestDecodeToken_Special(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := LoadTokenizer(path) + + // Special tokens should return empty + text := tok.DecodeToken(100) + if text != "" { + t.Errorf("DecodeToken(BOS) = %q, want empty", text) + } +} + +func TestDecodeToken_SentencePieceSpace(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := LoadTokenizer(path) + + // "▁h" = token 7, should decode to " h" (space prefix) + text := tok.DecodeToken(7) + if text != " h" { + t.Errorf("DecodeToken(7) = %q, want %q", text, " h") + } +} + +func TestDecodeToken_Unknown(t *testing.T) { + path := writeTestTokenizer(t) + tok, _ := LoadTokenizer(path) + + text := tok.DecodeToken(9999) + if text != "" { + t.Errorf("DecodeToken(unknown) = %q, want empty", text) + } +} + +func TestFormatGemmaPrompt(t *testing.T) { + got := FormatGemmaPrompt("What is 2+2?") + want := "user\nWhat is 2+2?\nmodel\n" + if got != want { + t.Errorf("FormatGemmaPrompt = %q, want %q", got, want) + } +} + +// --- GPT-2 byte maps --- + +func TestBuildGPT2ByteMaps(t *testing.T) { + decoder, encoder := buildGPT2ByteMaps() + + // All 256 bytes must be mapped + if len(encoder) != 256 { + t.Errorf("encoder has %d entries, want 256", len(encoder)) + } + if len(decoder) != 256 { + t.Errorf("decoder has %d entries, want 256", len(decoder)) + } + + // Round-trip: every byte should survive encode → decode + for b := 0; b < 256; b++ { + r := encoder[byte(b)] + got := decoder[r] + if got != byte(b) { + t.Errorf("byte %d: encode→decode = %d, want %d", b, got, b) + } + } +} + +func TestBuildGPT2ByteMaps_PrintableASCII(t *testing.T) { + _, encoder := buildGPT2ByteMaps() + + // Printable ASCII (33-126) should self-map + for b := 33; b <= 126; b++ { + if encoder[byte(b)] != rune(b) { + t.Errorf("byte %d (%c): expected self-map, got %c", b, b, encoder[byte(b)]) + } + } +} + +func TestBuildGPT2ByteMaps_ControlChars(t *testing.T) { + _, encoder := buildGPT2ByteMaps() + + // Space (32) and control chars (0-31) should NOT self-map + if encoder[byte(32)] == rune(32) { + t.Error("space (32) should not self-map in GPT-2 encoding") + } + if encoder[byte(0)] == rune(0) { + t.Error("null (0) should not self-map in GPT-2 encoding") + } +}