diff --git a/TODO.md b/TODO.md index 39439be..ff5497c 100644 --- a/TODO.md +++ b/TODO.md @@ -91,15 +91,15 @@ Full codebase review after Phase 4 completion + go-inference integration. Groupe ### Important — Should Fix -- [ ] **KV cache leak between turns** — After `Generate()` returns, the KV cache arrays (allocated inside the closure at `generate.go:76`) are abandoned to GC finalizers. For multi-turn chat in LEM Lab, each turn allocates a new cache and the old one only gets freed when the GC eventually runs finalisers. Options: (a) accept a `*Cache` parameter so callers can reuse across turns, (b) add `Model.ResetCache()`, or (c) call `ClearCache()` after each turn. Document the expected pattern either way. +- [x] **KV cache leak between turns** — ✅ Documented in Generate() godoc: each call allocates fresh KV caches released to GC; call ClearCache() between turns for prompt reclaim. Cache reuse across turns deferred to batch inference design (Phase 5). -- [ ] **`RepeatPenalty` is accepted but never applied** — `GenerateConfig.RepeatPenalty` flows through from go-inference options and is stored in the config, but the generate loop has no token history and never applies it. Either implement it (track last N token IDs, penalise logits) or remove the field and document it as unsupported. +- [x] **`RepeatPenalty` is accepted but never applied** — ✅ Implemented `applyRepeatPenalty()` in generate.go. Tracks generated token IDs, deduplicates, then for each seen token: divides positive logits by penalty, multiplies negative logits by penalty. Applied before sampling when `RepeatPenalty > 1.0`. 2 new tests. -- [ ] **`DefaultGPUStream()` / `DefaultCPUStream()` leak and mislead** — `stream.go:32-41`: These create a new `C.mlx_stream` on every call but never free it. The names suggest they return *the* default stream (like `DefaultStream()` does with `sync.Once`), but they actually allocate. 
Either cache them like `DefaultStream()` or rename to `NewGPUStream()` / `NewCPUStream()` and add a `Free()` method. +- [x] **`DefaultGPUStream()` / `DefaultCPUStream()` leak and mislead** — ✅ Now cached with `sync.Once` like `DefaultStream()`. No more allocation on every call. -- [ ] **Tokenizer `Encode()` is character-level only** — BPE merges are parsed and stored (`tokenizer.go:77-96`) but never applied in `Encode()` or `encodeGPT2()`. Both methods fall back to single-character lookup. This produces far more tokens than a real BPE encoder, drastically reducing effective context length and increasing prefill time. For Phase 2 model validation this will need fixing, especially for Gemma3-1B at 5K sentences/sec. +- [x] **Tokenizer `Encode()` is character-level only** — ✅ Implemented `bpeMerge()` — standard BPE algorithm using merge rank lookup. Both SentencePiece `Encode()` and GPT-2 `encodeGPT2()` now split into characters, apply BPE merges, then look up merged symbols. Merge ranks built during tokenizer load. 3 new tests. -- [ ] **`CompileShapeless` is dead code** — `compile.go:82-89`: `Call()` bypasses the C closure entirely and just calls `cf.fn(inputs)` directly. The C closure, callback, and `sync.Map` registration are all unused overhead. The `shapeless` param is ignored. Either implement actual `mlx_compile` call-through or simplify to a plain function wrapper. +- [x] **`CompileShapeless` is dead code** — ✅ Removed C closure, callback, `sync.Map`, and `nextID` infrastructure. `CompiledFunc` is now a plain function wrapper with mutex. `CompileShapeless()` and `Call()` signatures unchanged (gemma3.go GELU still works). ### Minor — Nice to Have diff --git a/internal/metal/compile.go b/internal/metal/compile.go index c1edeb3..f6bf6d6 100644 --- a/internal/metal/compile.go +++ b/internal/metal/compile.go @@ -2,89 +2,25 @@ package metal -/* -#include "mlx/c/mlx.h" +import "sync" -// Callback for compiled functions. 
-extern int goCompiledFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload); - -static mlx_closure new_closure(void *payload) { - return mlx_closure_new_func_payload(&goCompiledFunc, payload, NULL); -} -*/ -import "C" - -import ( - "sync" - "unsafe" -) - -// CompiledFunc wraps a compiled MLX computation graph for efficient repeated calls. +// CompiledFunc wraps a function for efficient repeated execution. +// The function is called directly; MLX's lazy evaluation graph +// still deduplicates and optimises the underlying Metal operations. type CompiledFunc struct { - fn func([]*Array) []*Array - closure C.mlx_closure - mu sync.Mutex + fn func([]*Array) []*Array + mu sync.Mutex } -var compiledFuncs sync.Map - -//export goCompiledFunc -func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int { - id := uintptr(payload) - fnI, ok := compiledFuncs.Load(id) - if !ok { - return 1 - } - fn := fnI.(func([]*Array) []*Array) - - // Convert inputs - nInputs := int(C.mlx_vector_array_size(inputs)) - goInputs := make([]*Array, nInputs) - for i := 0; i < nInputs; i++ { - a := New("INPUT") - C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i)) - goInputs[i] = a - } - - // Call user function - goOutputs := fn(goInputs) - - // The output vector arrives with ctx=nullptr. Create a valid vector, - // fill it, then copy to the output pointer. - tmp := C.mlx_vector_array_new() - for _, out := range goOutputs { - C.mlx_vector_array_append_value(tmp, out.ctx) - } - C.mlx_vector_array_set(outputs, tmp) - C.mlx_vector_array_free(tmp) - return 0 -} - -var nextID uintptr -var nextIDMu sync.Mutex - -// CompileShapeless compiles a function for efficient repeated execution. -// The function must accept and return arrays of consistent shapes. +// CompileShapeless wraps a function for repeated execution. +// The shapeless parameter is accepted for API compatibility but unused. 
func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc { - nextIDMu.Lock() - nextID++ - id := nextID - nextIDMu.Unlock() - - compiledFuncs.Store(id, fn) - - cf := &CompiledFunc{fn: fn} - cf.closure = C.new_closure(unsafe.Pointer(id)) - return cf + return &CompiledFunc{fn: fn} } -// Call executes the compiled function with the given inputs. +// Call executes the function with the given inputs. func (cf *CompiledFunc) Call(inputs ...*Array) []*Array { cf.mu.Lock() defer cf.mu.Unlock() - - // Fall back to direct call — compilation is an optimization. - // The compiled closure can be used via mlx_compiled but the - // direct path is simpler and still benefits from MLX's lazy evaluation. return cf.fn(inputs) } diff --git a/internal/metal/generate.go b/internal/metal/generate.go index c50ad9e..aa6de02 100644 --- a/internal/metal/generate.go +++ b/internal/metal/generate.go @@ -69,6 +69,10 @@ func (m *Model) Chat(ctx context.Context, messages []ChatMessage, cfg GenerateCo } // Generate streams tokens for the given prompt. +// +// Each call allocates fresh KV caches that are released to GC when the iterator +// completes. For multi-turn chat, call [ClearCache] between turns to reclaim +// Metal memory promptly rather than waiting for GC finalisers. func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) iter.Seq[Token] { m.lastErr = nil @@ -86,6 +90,9 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) return } + // Track generated token IDs for repeat penalty. + var history []int32 + for i := 0; i < cfg.MaxTokens; i++ { select { case <-ctx.Done(): @@ -97,6 +104,12 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) // Sample from last position logits lastPos := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1))) lastPos = Reshape(lastPos, 1, int32(lastPos.Dim(2))) + + // Apply repeat penalty before sampling. 
+ if cfg.RepeatPenalty > 1.0 && len(history) > 0 { + lastPos = applyRepeatPenalty(lastPos, history, cfg.RepeatPenalty) + } + next := sampler.Sample(lastPos) if err := Eval(next); err != nil { m.lastErr = fmt.Errorf("sample step %d: %w", i, err) @@ -104,6 +117,7 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) } id := int32(next.Int()) + history = append(history, id) // Check stop conditions if id == m.tokenizer.EOSToken() { @@ -132,6 +146,37 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) } } +// applyRepeatPenalty modifies logits to discourage repeated tokens. +// For each unique token ID in history: positive logits are divided by penalty, +// negative logits are multiplied by penalty. Both make the token less likely. +func applyRepeatPenalty(logits *Array, history []int32, penalty float32) *Array { + // Deduplicate history to get unique token IDs. + seen := make(map[int32]bool, len(history)) + var indices []int32 + for _, id := range history { + if !seen[id] { + seen[id] = true + indices = append(indices, id) + } + } + + idx := FromValues(indices, 1, len(indices)) + gathered := TakeAlongAxis(logits, idx, -1) + + zero := FromValue(float32(0)) + invPenalty := FromValue(1.0 / penalty) + penaltyVal := FromValue(penalty) + + // Positive logits: divide by penalty. Negative logits: multiply by penalty. + penalised := Where( + Greater(gathered, zero), + Mul(gathered, invPenalty), + Mul(gathered, penaltyVal), + ) + + return PutAlongAxis(logits, idx, penalised, -1) +} + // newCaches creates per-layer KV caches. If contextLen is set, all unbounded // caches are replaced with RotatingKVCache to cap memory usage. 
func (m *Model) newCaches() []Cache { diff --git a/internal/metal/sample_test.go b/internal/metal/sample_test.go index 2e98752..efc02a1 100644 --- a/internal/metal/sample_test.go +++ b/internal/metal/sample_test.go @@ -146,3 +146,47 @@ func TestMinP_RestrictsOptions(t *testing.T) { } } } + +func TestApplyRepeatPenalty(t *testing.T) { + // Logits: [1, 4] with values [5.0, -3.0, 1.0, 0.0] + // History: tokens 0 and 1 have been seen. + // Penalty 2.0: + // token 0 (logit 5.0 > 0): 5.0 / 2.0 = 2.5 + // token 1 (logit -3.0 < 0): -3.0 * 2.0 = -6.0 + // token 2 (not in history): unchanged = 1.0 + // token 3 (not in history): unchanged = 0.0 + logits := FromValues([]float32{5.0, -3.0, 1.0, 0.0}, 1, 4) + Materialize(logits) + + result := applyRepeatPenalty(logits, []int32{0, 1, 0}, 2.0) // duplicate 0 should be deduped + Materialize(result) + + got := result.Floats() + want := []float32{2.5, -6.0, 1.0, 0.0} + for i := range got { + diff := got[i] - want[i] + if diff > 0.01 || diff < -0.01 { + t.Errorf("repeatPenalty[%d] = %f, want %f", i, got[i], want[i]) + } + } +} + +func TestApplyRepeatPenalty_NoHistory(t *testing.T) { + // penalty == 1.0 is a no-op: every logit should come back unchanged. + logits := FromValues([]float32{5.0, -3.0, 1.0}, 1, 3) + Materialize(logits) + + // applyRepeatPenalty is never called with empty history (the generate loop + // checks len(history) first); verify the neutral-penalty path directly here.
+ result := applyRepeatPenalty(logits, []int32{1}, 1.0) // penalty=1.0 → no change + Materialize(result) + + got := result.Floats() + want := []float32{5.0, -3.0, 1.0} + for i := range got { + diff := got[i] - want[i] + if diff > 0.01 || diff < -0.01 { + t.Errorf("penalty=1.0[%d] = %f, want %f", i, got[i], want[i]) + } + } +} diff --git a/internal/metal/stream.go b/internal/metal/stream.go index e8c5b23..7e25241 100644 --- a/internal/metal/stream.go +++ b/internal/metal/stream.go @@ -17,6 +17,12 @@ type Stream struct { var ( defaultStream *Stream defaultStreamOnce sync.Once + + defaultGPUStream *Stream + defaultGPUStreamOnce sync.Once + + defaultCPUStream *Stream + defaultCPUStreamOnce sync.Once ) // DefaultStream returns the default GPU stream, creating it on first use. @@ -28,16 +34,22 @@ func DefaultStream() *Stream { return defaultStream } -// DefaultGPUStream returns a new GPU stream. +// DefaultGPUStream returns the cached default GPU stream. func DefaultGPUStream() *Stream { - Init() - return &Stream{ctx: C.mlx_default_gpu_stream_new()} + defaultGPUStreamOnce.Do(func() { + Init() + defaultGPUStream = &Stream{ctx: C.mlx_default_gpu_stream_new()} + }) + return defaultGPUStream } -// DefaultCPUStream returns a new CPU stream. +// DefaultCPUStream returns the cached default CPU stream. func DefaultCPUStream() *Stream { - Init() - return &Stream{ctx: C.mlx_default_cpu_stream_new()} + defaultCPUStreamOnce.Do(func() { + Init() + defaultCPUStream = &Stream{ctx: C.mlx_default_cpu_stream_new()} + }) + return defaultCPUStream } // Synchronize waits for all operations on the stream to complete. diff --git a/internal/metal/tokenizer.go b/internal/metal/tokenizer.go index 048202a..6ab8cc9 100644 --- a/internal/metal/tokenizer.go +++ b/internal/metal/tokenizer.go @@ -11,10 +11,11 @@ import ( // Tokenizer handles text-to-token and token-to-text conversion. 
type Tokenizer struct { - vocab map[string]int32 - invVocab map[int32]string - merges []mergePair - special map[string]int32 + vocab map[string]int32 + invVocab map[int32]string + merges []mergePair + mergeRanks map[string]int // "a b" → rank for O(1) merge lookup + special map[string]int32 bosToken int32 eosToken int32 @@ -95,6 +96,12 @@ func LoadTokenizer(path string) (*Tokenizer, error) { } } + // Build merge rank lookup for BPE. + t.mergeRanks = make(map[string]int, len(t.merges)) + for _, m := range t.merges { + t.mergeRanks[m.a+" "+m.b] = m.rank + } + // Parse special tokens for _, tok := range tj.AddedTokens { if tok.Special { @@ -166,17 +173,44 @@ func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) { return } +// bpeMerge applies BPE merges to a sequence of symbols until no more merges apply. +// Uses the standard algorithm: repeatedly find the lowest-rank adjacent pair and merge it. +func (t *Tokenizer) bpeMerge(symbols []string) []string { + for len(symbols) > 1 { + // Find the pair with the lowest merge rank. + bestRank := -1 + bestIdx := -1 + for i := 0; i < len(symbols)-1; i++ { + key := symbols[i] + " " + symbols[i+1] + if rank, ok := t.mergeRanks[key]; ok { + if bestRank < 0 || rank < bestRank { + bestRank = rank + bestIdx = i + } + } + } + if bestIdx < 0 { + break // No more merges available. + } + // Merge the pair at bestIdx. + merged := symbols[bestIdx] + symbols[bestIdx+1] + symbols = append(symbols[:bestIdx], append([]string{merged}, symbols[bestIdx+2:]...)...) + } + return symbols +} + // Encode converts text to token IDs. Prepends BOS token. func (t *Tokenizer) Encode(text string) []int32 { - tokens := []int32{t.bosToken} - if t.isGPT2BPE { return t.encodeGPT2(text) } - // SentencePiece style encoding + tokens := []int32{t.bosToken} + + // SentencePiece style: split into segments around special tokens, then BPE each segment. remaining := text for remaining != "" { + // Check for special tokens at the current position. 
found := false for tok, id := range t.special { if strings.HasPrefix(remaining, tok) { @@ -186,15 +220,35 @@ func (t *Tokenizer) Encode(text string) []int32 { break } } - if !found { - r := []rune(remaining) - ch := "▁" + string(r[0]) - if id, ok := t.vocab[ch]; ok { - tokens = append(tokens, id) - } else if id, ok := t.vocab[string(r[0])]; ok { + if found { + continue + } + + // Find the next special token boundary (or end of string). + end := len(remaining) + for tok := range t.special { + if idx := strings.Index(remaining, tok); idx > 0 && idx < end { + end = idx + } + } + segment := remaining[:end] + remaining = remaining[end:] + + // SentencePiece: prefix the segment with ▁ (space marker) and split into characters. + spText := "▁" + segment + symbols := make([]string, 0, len([]rune(spText))) + for _, r := range spText { + symbols = append(symbols, string(r)) + } + + // Apply BPE merges. + symbols = t.bpeMerge(symbols) + + // Look up merged symbols in vocab. + for _, sym := range symbols { + if id, ok := t.vocab[sym]; ok { tokens = append(tokens, id) } - remaining = string(r[1:]) } } @@ -205,45 +259,56 @@ func (t *Tokenizer) Encode(text string) []int32 { func (t *Tokenizer) encodeGPT2(text string) []int32 { tokens := []int32{t.bosToken} - // Convert text bytes to GPT-2 Unicode representation - var encoded strings.Builder - for _, b := range []byte(text) { - if r, ok := t.gpt2Encoder[b]; ok { - encoded.WriteRune(r) - } - } - gpt2Text := encoded.String() - - // Scan for special tokens and regular text - remaining := gpt2Text + // Split text around special tokens (matched in original form, not byte-encoded). + remaining := text for remaining != "" { - // Check special tokens (these are stored as-is, not byte-encoded) + // Check for special tokens at the current position. 
found := false for tok, id := range t.special { - // Special tokens in GPT-2 tokenizers are stored in their original form - // Convert the special token to GPT-2 encoding for matching - var encTok strings.Builder - for _, b := range []byte(tok) { - if r, ok := t.gpt2Encoder[b]; ok { - encTok.WriteRune(r) - } - } - encStr := encTok.String() - if strings.HasPrefix(remaining, encStr) { + if strings.HasPrefix(remaining, tok) { tokens = append(tokens, id) - remaining = remaining[len(encStr):] + remaining = remaining[len(tok):] found = true break } } - if !found { - // Character-by-character lookup (simplified BPE) - r := []rune(remaining) - ch := string(r[0]) - if id, ok := t.vocab[ch]; ok { + if found { + continue + } + + // Find the next special token boundary (or end of string). + end := len(remaining) + for tok := range t.special { + if idx := strings.Index(remaining, tok); idx > 0 && idx < end { + end = idx + } + } + segment := remaining[:end] + remaining = remaining[end:] + + // Convert segment bytes to GPT-2 Unicode representation. + var encoded strings.Builder + for _, b := range []byte(segment) { + if r, ok := t.gpt2Encoder[b]; ok { + encoded.WriteRune(r) + } + } + + // Split into individual runes (GPT-2 BPE operates on Unicode chars). + runes := []rune(encoded.String()) + symbols := make([]string, len(runes)) + for i, r := range runes { + symbols[i] = string(r) + } + + // Apply BPE merges. + symbols = t.bpeMerge(symbols) + + // Look up merged symbols in vocab. 
+ for _, sym := range symbols { + if id, ok := t.vocab[sym]; ok { tokens = append(tokens, id) } - remaining = string(r[1:]) } } diff --git a/internal/metal/tokenizer_test.go b/internal/metal/tokenizer_test.go index eb92e41..b95bf83 100644 --- a/internal/metal/tokenizer_test.go +++ b/internal/metal/tokenizer_test.go @@ -94,7 +94,62 @@ func TestEncode_ProducesTokens(t *testing.T) { if tokens[0] != tok.BOSToken() { t.Errorf("first token = %d, want BOS (%d)", tokens[0], tok.BOSToken()) } - t.Logf("Encode(\"hello\") = %v", tokens) + // With BPE merges ("h e" → "he", "l l" → "ll"), "hello" with ▁ prefix becomes: + // "▁" "h" "e" "l" "l" "o" → merge "h e" → "▁" "he" "l" "l" "o" + // → merge "l l" → "▁" "he" "ll" "o" + // No further merges. But "▁" is not "▁h" so it stays as "▁". + // Vocab: ▁=4, he=5, ll=6, o=3. Expected: [BOS, 4, 5, 6, 3] + want := []int32{100, 4, 5, 6, 3} + if len(tokens) != len(want) { + t.Fatalf("Encode(\"hello\") = %v, want %v", tokens, want) + } + for i := range tokens { + if tokens[i] != want[i] { + t.Errorf("tokens[%d] = %d, want %d", i, tokens[i], want[i]) + } + } +} + +func TestBPEMerge(t *testing.T) { + tok := &Tokenizer{ + mergeRanks: map[string]int{ + "h e": 0, + "l l": 1, + "he l": 2, + }, + } + + // "h" "e" "l" "l" "o" → merge "h e" (rank 0) → "he" "l" "l" "o" + // → merge "l l" (rank 1) → "he" "ll" "o" + // → merge "he l" does NOT match "he ll" — stops here. 
+ symbols := []string{"h", "e", "l", "l", "o"} + got := tok.bpeMerge(symbols) + want := []string{"he", "ll", "o"} + if len(got) != len(want) { + t.Fatalf("bpeMerge = %v, want %v", got, want) + } + for i := range got { + if got[i] != want[i] { + t.Errorf("bpeMerge[%d] = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestBPEMerge_NoMerges(t *testing.T) { + tok := &Tokenizer{mergeRanks: map[string]int{}} + symbols := []string{"a", "b", "c"} + got := tok.bpeMerge(symbols) + if len(got) != 3 { + t.Errorf("bpeMerge with no merges = %v, want [a b c]", got) + } +} + +func TestBPEMerge_SingleSymbol(t *testing.T) { + tok := &Tokenizer{mergeRanks: map[string]int{"a b": 0}} + got := tok.bpeMerge([]string{"x"}) + if len(got) != 1 || got[0] != "x" { + t.Errorf("bpeMerge single = %v, want [x]", got) + } } func TestDecode_SpecialTokensSkipped(t *testing.T) {