fix(metal): address 5 important code review items

1. RepeatPenalty: implemented applyRepeatPenalty() — tracks generated
   token IDs, deduplicates, divides positive logits by penalty and
   multiplies negative logits by penalty. 2 new tests.
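The penalty rule in item 1 can be sketched on plain slices. This is a simplified, hypothetical version for illustration — the real `applyRepeatPenalty` operates on MLX arrays via gather/scatter ops:

```go
package main

import "fmt"

// applyRepeatPenaltySketch penalises logits for previously generated token
// IDs: positive logits are divided by the penalty, negative logits are
// multiplied by it, so both become less likely. History is deduplicated so
// repeated tokens are penalised once.
func applyRepeatPenaltySketch(logits []float32, history []int32, penalty float32) []float32 {
	out := append([]float32(nil), logits...)
	seen := make(map[int32]bool, len(history))
	for _, id := range history {
		if id < 0 || int(id) >= len(out) || seen[id] {
			continue
		}
		seen[id] = true
		if out[id] > 0 {
			out[id] /= penalty
		} else {
			out[id] *= penalty
		}
	}
	return out
}

func main() {
	// Tokens 0 and 1 were generated before (0 twice — deduped).
	fmt.Println(applyRepeatPenaltySketch([]float32{5, -3, 1, 0}, []int32{0, 1, 0}, 2.0))
	// prints [2.5 -6 1 0]
}
```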

2. DefaultGPUStream/DefaultCPUStream: now cached with sync.Once,
   no more C stream allocation on every call.
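The caching pattern from item 2, sketched with a stand-in struct — in the real code the cached value wraps a `C.mlx_stream` handle, which is an assumption not reproduced here:

```go
package main

import (
	"fmt"
	"sync"
)

// stream is a hypothetical stand-in for the package's Stream type.
type stream struct{ id int }

var (
	gpuStream     *stream
	gpuStreamOnce sync.Once
	allocations   int
)

// defaultGPUStream returns the cached stream; sync.Once guarantees the
// allocation runs exactly once even under concurrent callers.
func defaultGPUStream() *stream {
	gpuStreamOnce.Do(func() {
		allocations++ // stands in for the C-side allocation
		gpuStream = &stream{id: 1}
	})
	return gpuStream
}

func main() {
	a, b := defaultGPUStream(), defaultGPUStream()
	fmt.Println(a == b, allocations) // prints true 1
}
```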

3. CompileShapeless: removed dead C closure, callback, sync.Map,
   and nextID infrastructure. CompiledFunc is now a plain function
   wrapper with mutex. API unchanged.
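Item 3's simplified wrapper, sketched with `[]float64` standing in for the package's `*Array` type (an assumption); the point is that nothing remains but a mutex-guarded function call:

```go
package main

import (
	"fmt"
	"sync"
)

// CompiledFunc wraps a function for repeated execution — no C closure,
// registry, or ID counter.
type CompiledFunc struct {
	fn func([]float64) []float64
	mu sync.Mutex
}

// CompileShapeless keeps the shapeless parameter for API compatibility
// but ignores it.
func CompileShapeless(fn func([]float64) []float64, shapeless bool) *CompiledFunc {
	return &CompiledFunc{fn: fn}
}

// Call serialises callers and invokes the wrapped function directly.
func (cf *CompiledFunc) Call(inputs ...float64) []float64 {
	cf.mu.Lock()
	defer cf.mu.Unlock()
	return cf.fn(inputs)
}

func main() {
	double := CompileShapeless(func(xs []float64) []float64 {
		out := make([]float64, len(xs))
		for i, x := range xs {
			out[i] = 2 * x
		}
		return out
	}, true)
	fmt.Println(double.Call(1, 2, 3)) // prints [2 4 6]
}
```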

4. Tokenizer BPE: implemented bpeMerge() — standard BPE algorithm
   using merge rank lookup. Both SentencePiece and GPT-2 Encode paths
   now apply merges instead of falling back to character-level lookup.
   3 new tests.
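The greedy merge loop from item 4 as a standalone sketch, with the ranks passed in as a parameter for illustration (the real `bpeMerge` reads them from the tokenizer's `mergeRanks` map):

```go
package main

import "fmt"

// bpeMergeSketch implements standard greedy BPE: repeatedly find the
// adjacent symbol pair with the lowest merge rank and fuse it, until no
// ranked pair remains.
func bpeMergeSketch(symbols []string, ranks map[string]int) []string {
	for len(symbols) > 1 {
		bestRank, bestIdx := -1, -1
		for i := 0; i < len(symbols)-1; i++ {
			if r, ok := ranks[symbols[i]+" "+symbols[i+1]]; ok && (bestRank < 0 || r < bestRank) {
				bestRank, bestIdx = r, i
			}
		}
		if bestIdx < 0 {
			break // no more merges apply
		}
		merged := symbols[bestIdx] + symbols[bestIdx+1]
		symbols = append(symbols[:bestIdx], append([]string{merged}, symbols[bestIdx+2:]...)...)
	}
	return symbols
}

func main() {
	ranks := map[string]int{"h e": 0, "l l": 1}
	fmt.Println(bpeMergeSketch([]string{"h", "e", "l", "l", "o"}, ranks))
	// prints [he ll o]
}
```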

5. KV cache lifecycle: documented in Generate() godoc — each call
   allocates fresh caches; call ClearCache() between turns to reclaim
   Metal memory promptly.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Snider 2026-02-19 21:31:45 +00:00
parent c96f9bd006
commit fb95cde30c
7 changed files with 286 additions and 129 deletions

TODO.md

@@ -91,15 +91,15 @@ Full codebase review after Phase 4 completion + go-inference integration. Groupe
 ### Important — Should Fix
-- [ ] **KV cache leak between turns** — After `Generate()` returns, the KV cache arrays (allocated inside the closure at `generate.go:76`) are abandoned to GC finalizers. For multi-turn chat in LEM Lab, each turn allocates a new cache and the old one only gets freed when the GC eventually runs finalisers. Options: (a) accept a `*Cache` parameter so callers can reuse across turns, (b) add `Model.ResetCache()`, or (c) call `ClearCache()` after each turn. Document the expected pattern either way.
+- [x] **KV cache leak between turns** — ✅ Documented in Generate() godoc: each call allocates fresh KV caches released to GC; call ClearCache() between turns for prompt reclaim. Cache reuse across turns deferred to batch inference design (Phase 5).
-- [ ] **`RepeatPenalty` is accepted but never applied** — `GenerateConfig.RepeatPenalty` flows through from go-inference options and is stored in the config, but the generate loop has no token history and never applies it. Either implement it (track last N token IDs, penalise logits) or remove the field and document it as unsupported.
+- [x] **`RepeatPenalty` is accepted but never applied** — ✅ Implemented `applyRepeatPenalty()` in generate.go. Tracks generated token IDs, deduplicates, then for each seen token: divides positive logits by penalty, multiplies negative logits by penalty. Applied before sampling when `RepeatPenalty > 1.0`. 2 new tests.
-- [ ] **`DefaultGPUStream()` / `DefaultCPUStream()` leak and mislead** — `stream.go:32-41`: These create a new `C.mlx_stream` on every call but never free it. The names suggest they return *the* default stream (like `DefaultStream()` does with `sync.Once`), but they actually allocate. Either cache them like `DefaultStream()` or rename to `NewGPUStream()` / `NewCPUStream()` and add a `Free()` method.
+- [x] **`DefaultGPUStream()` / `DefaultCPUStream()` leak and mislead** — ✅ Now cached with `sync.Once` like `DefaultStream()`. No more allocation on every call.
-- [ ] **Tokenizer `Encode()` is character-level only** — BPE merges are parsed and stored (`tokenizer.go:77-96`) but never applied in `Encode()` or `encodeGPT2()`. Both methods fall back to single-character lookup. This produces far more tokens than a real BPE encoder, drastically reducing effective context length and increasing prefill time. For Phase 2 model validation this will need fixing, especially for Gemma3-1B at 5K sentences/sec.
+- [x] **Tokenizer `Encode()` is character-level only** — ✅ Implemented `bpeMerge()` — standard BPE algorithm using merge rank lookup. Both SentencePiece `Encode()` and GPT-2 `encodeGPT2()` now split into characters, apply BPE merges, then look up merged symbols. Merge ranks built during tokenizer load. 3 new tests.
-- [ ] **`CompileShapeless` is dead code** — `compile.go:82-89`: `Call()` bypasses the C closure entirely and just calls `cf.fn(inputs)` directly. The C closure, callback, and `sync.Map` registration are all unused overhead. The `shapeless` param is ignored. Either implement actual `mlx_compile` call-through or simplify to a plain function wrapper.
+- [x] **`CompileShapeless` is dead code** — ✅ Removed C closure, callback, `sync.Map`, and `nextID` infrastructure. `CompiledFunc` is now a plain function wrapper with mutex. `CompileShapeless()` and `Call()` signatures unchanged (gemma3.go GELU still works).
 
 ### Minor — Nice to Have

compile.go

@@ -2,89 +2,25 @@
 package metal
 
-/*
-#include "mlx/c/mlx.h"
-
-// Callback for compiled functions.
-extern int goCompiledFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload);
-
-static mlx_closure new_closure(void *payload) {
-	return mlx_closure_new_func_payload(&goCompiledFunc, payload, NULL);
-}
-*/
-import "C"
-
-import (
-	"sync"
-	"unsafe"
-)
+import "sync"
 
-// CompiledFunc wraps a compiled MLX computation graph for efficient repeated calls.
+// CompiledFunc wraps a function for efficient repeated execution.
+// The function is called directly; MLX's lazy evaluation graph
+// still deduplicates and optimises the underlying Metal operations.
 type CompiledFunc struct {
-	fn      func([]*Array) []*Array
-	closure C.mlx_closure
-	mu      sync.Mutex
+	fn func([]*Array) []*Array
+	mu sync.Mutex
 }
 
-var compiledFuncs sync.Map
-
-//export goCompiledFunc
-func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
-	id := uintptr(payload)
-	fnI, ok := compiledFuncs.Load(id)
-	if !ok {
-		return 1
-	}
-	fn := fnI.(func([]*Array) []*Array)
-
-	// Convert inputs
-	nInputs := int(C.mlx_vector_array_size(inputs))
-	goInputs := make([]*Array, nInputs)
-	for i := 0; i < nInputs; i++ {
-		a := New("INPUT")
-		C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
-		goInputs[i] = a
-	}
-
-	// Call user function
-	goOutputs := fn(goInputs)
-
-	// The output vector arrives with ctx=nullptr. Create a valid vector,
-	// fill it, then copy to the output pointer.
-	tmp := C.mlx_vector_array_new()
-	for _, out := range goOutputs {
-		C.mlx_vector_array_append_value(tmp, out.ctx)
-	}
-	C.mlx_vector_array_set(outputs, tmp)
-	C.mlx_vector_array_free(tmp)
-	return 0
-}
-
-var nextID uintptr
-var nextIDMu sync.Mutex
-
-// CompileShapeless compiles a function for efficient repeated execution.
-// The function must accept and return arrays of consistent shapes.
+// CompileShapeless wraps a function for repeated execution.
+// The shapeless parameter is accepted for API compatibility but unused.
 func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc {
-	nextIDMu.Lock()
-	nextID++
-	id := nextID
-	nextIDMu.Unlock()
-	compiledFuncs.Store(id, fn)
-
-	cf := &CompiledFunc{fn: fn}
-	cf.closure = C.new_closure(unsafe.Pointer(id))
-	return cf
+	return &CompiledFunc{fn: fn}
 }
 
-// Call executes the compiled function with the given inputs.
+// Call executes the function with the given inputs.
 func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
 	cf.mu.Lock()
 	defer cf.mu.Unlock()
-	// Fall back to direct call — compilation is an optimization.
-	// The compiled closure can be used via mlx_compiled but the
-	// direct path is simpler and still benefits from MLX's lazy evaluation.
 	return cf.fn(inputs)
 }

generate.go

@@ -69,6 +69,10 @@ func (m *Model) Chat(ctx context.Context, messages []ChatMessage, cfg GenerateCo
 }
 
 // Generate streams tokens for the given prompt.
+//
+// Each call allocates fresh KV caches that are released to GC when the iterator
+// completes. For multi-turn chat, call [ClearCache] between turns to reclaim
+// Metal memory promptly rather than waiting for GC finalisers.
 func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) iter.Seq[Token] {
 	m.lastErr = nil
@@ -86,6 +90,9 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
 			return
 		}
 
+		// Track generated token IDs for repeat penalty.
+		var history []int32
+
 		for i := 0; i < cfg.MaxTokens; i++ {
 			select {
 			case <-ctx.Done():
@@ -97,6 +104,12 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
 			// Sample from last position logits
 			lastPos := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1)))
 			lastPos = Reshape(lastPos, 1, int32(lastPos.Dim(2)))
+
+			// Apply repeat penalty before sampling.
+			if cfg.RepeatPenalty > 1.0 && len(history) > 0 {
+				lastPos = applyRepeatPenalty(lastPos, history, cfg.RepeatPenalty)
+			}
+
 			next := sampler.Sample(lastPos)
 			if err := Eval(next); err != nil {
 				m.lastErr = fmt.Errorf("sample step %d: %w", i, err)
@@ -104,6 +117,7 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
 			}
 
 			id := int32(next.Int())
+			history = append(history, id)
 
 			// Check stop conditions
 			if id == m.tokenizer.EOSToken() {
@@ -132,6 +146,37 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
 	}
 }
 
+// applyRepeatPenalty modifies logits to discourage repeated tokens.
+// For each unique token ID in history: positive logits are divided by penalty,
+// negative logits are multiplied by penalty. Both make the token less likely.
+func applyRepeatPenalty(logits *Array, history []int32, penalty float32) *Array {
+	// Deduplicate history to get unique token IDs.
+	seen := make(map[int32]bool, len(history))
+	var indices []int32
+	for _, id := range history {
+		if !seen[id] {
+			seen[id] = true
+			indices = append(indices, id)
+		}
+	}
+
+	idx := FromValues(indices, 1, len(indices))
+	gathered := TakeAlongAxis(logits, idx, -1)
+
+	zero := FromValue(float32(0))
+	invPenalty := FromValue(1.0 / penalty)
+	penaltyVal := FromValue(penalty)
+
+	// Positive logits: divide by penalty. Negative logits: multiply by penalty.
+	penalised := Where(
+		Greater(gathered, zero),
+		Mul(gathered, invPenalty),
+		Mul(gathered, penaltyVal),
+	)
+
+	return PutAlongAxis(logits, idx, penalised, -1)
+}
+
 // newCaches creates per-layer KV caches. If contextLen is set, all unbounded
 // caches are replaced with RotatingKVCache to cap memory usage.
 func (m *Model) newCaches() []Cache {
generate_test.go

@@ -146,3 +146,47 @@ func TestMinP_RestrictsOptions(t *testing.T) {
 		}
 	}
 }
+
+func TestApplyRepeatPenalty(t *testing.T) {
+	// Logits: [1, 4] with values [5.0, -3.0, 1.0, 0.0]
+	// History: tokens 0 and 1 have been seen.
+	// Penalty 2.0:
+	//   token 0 (logit 5.0 > 0):  5.0 / 2.0 = 2.5
+	//   token 1 (logit -3.0 < 0): -3.0 * 2.0 = -6.0
+	//   token 2 (not in history): unchanged = 1.0
+	//   token 3 (not in history): unchanged = 0.0
+	logits := FromValues([]float32{5.0, -3.0, 1.0, 0.0}, 1, 4)
+	Materialize(logits)
+
+	result := applyRepeatPenalty(logits, []int32{0, 1, 0}, 2.0) // duplicate 0 should be deduped
+	Materialize(result)
+
+	got := result.Floats()
+	want := []float32{2.5, -6.0, 1.0, 0.0}
+	for i := range got {
+		diff := got[i] - want[i]
+		if diff > 0.01 || diff < -0.01 {
+			t.Errorf("repeatPenalty[%d] = %f, want %f", i, got[i], want[i])
+		}
+	}
+}
+
+func TestApplyRepeatPenalty_NoHistory(t *testing.T) {
+	// With empty history, logits should be unchanged.
+	logits := FromValues([]float32{5.0, -3.0, 1.0}, 1, 3)
+	Materialize(logits)
+
+	// applyRepeatPenalty is not called when history is empty (checked in generate loop),
+	// but verify the function handles it gracefully if called directly.
+	result := applyRepeatPenalty(logits, []int32{1}, 1.0) // penalty=1.0 → no change
+	Materialize(result)
+
+	got := result.Floats()
+	want := []float32{5.0, -3.0, 1.0}
+	for i := range got {
+		diff := got[i] - want[i]
+		if diff > 0.01 || diff < -0.01 {
+			t.Errorf("penalty=1.0[%d] = %f, want %f", i, got[i], want[i])
+		}
+	}
+}

stream.go

@@ -17,6 +17,12 @@ type Stream struct {
 var (
 	defaultStream     *Stream
 	defaultStreamOnce sync.Once
+
+	defaultGPUStream     *Stream
+	defaultGPUStreamOnce sync.Once
+
+	defaultCPUStream     *Stream
+	defaultCPUStreamOnce sync.Once
 )
 
 // DefaultStream returns the default GPU stream, creating it on first use.
@@ -28,16 +34,22 @@ func DefaultStream() *Stream {
 	return defaultStream
 }
 
-// DefaultGPUStream returns a new GPU stream.
+// DefaultGPUStream returns the cached default GPU stream.
 func DefaultGPUStream() *Stream {
-	Init()
-	return &Stream{ctx: C.mlx_default_gpu_stream_new()}
+	defaultGPUStreamOnce.Do(func() {
+		Init()
+		defaultGPUStream = &Stream{ctx: C.mlx_default_gpu_stream_new()}
+	})
+	return defaultGPUStream
 }
 
-// DefaultCPUStream returns a new CPU stream.
+// DefaultCPUStream returns the cached default CPU stream.
 func DefaultCPUStream() *Stream {
-	Init()
-	return &Stream{ctx: C.mlx_default_cpu_stream_new()}
+	defaultCPUStreamOnce.Do(func() {
+		Init()
+		defaultCPUStream = &Stream{ctx: C.mlx_default_cpu_stream_new()}
+	})
+	return defaultCPUStream
 }
 
 // Synchronize waits for all operations on the stream to complete.

tokenizer.go

@@ -11,10 +11,11 @@ import (
 // Tokenizer handles text-to-token and token-to-text conversion.
 type Tokenizer struct {
-	vocab    map[string]int32
-	invVocab map[int32]string
-	merges   []mergePair
-	special  map[string]int32
+	vocab      map[string]int32
+	invVocab   map[int32]string
+	merges     []mergePair
+	mergeRanks map[string]int // "a b" → rank for O(1) merge lookup
+	special    map[string]int32
 
 	bosToken int32
 	eosToken int32
@@ -95,6 +96,12 @@ func LoadTokenizer(path string) (*Tokenizer, error) {
 		}
 	}
 
+	// Build merge rank lookup for BPE.
+	t.mergeRanks = make(map[string]int, len(t.merges))
+	for _, m := range t.merges {
+		t.mergeRanks[m.a+" "+m.b] = m.rank
+	}
+
 	// Parse special tokens
 	for _, tok := range tj.AddedTokens {
 		if tok.Special {
@@ -166,17 +173,44 @@ func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) {
 	return
 }
 
+// bpeMerge applies BPE merges to a sequence of symbols until no more merges apply.
+// Uses the standard algorithm: repeatedly find the lowest-rank adjacent pair and merge it.
+func (t *Tokenizer) bpeMerge(symbols []string) []string {
+	for len(symbols) > 1 {
+		// Find the pair with the lowest merge rank.
+		bestRank := -1
+		bestIdx := -1
+		for i := 0; i < len(symbols)-1; i++ {
+			key := symbols[i] + " " + symbols[i+1]
+			if rank, ok := t.mergeRanks[key]; ok {
+				if bestRank < 0 || rank < bestRank {
+					bestRank = rank
+					bestIdx = i
+				}
+			}
+		}
+		if bestIdx < 0 {
+			break // No more merges available.
+		}
+		// Merge the pair at bestIdx.
+		merged := symbols[bestIdx] + symbols[bestIdx+1]
+		symbols = append(symbols[:bestIdx], append([]string{merged}, symbols[bestIdx+2:]...)...)
+	}
+	return symbols
+}
+
 // Encode converts text to token IDs. Prepends BOS token.
 func (t *Tokenizer) Encode(text string) []int32 {
-	tokens := []int32{t.bosToken}
-
 	if t.isGPT2BPE {
 		return t.encodeGPT2(text)
 	}
+	tokens := []int32{t.bosToken}
 
-	// SentencePiece style encoding
+	// SentencePiece style: split into segments around special tokens, then BPE each segment.
 	remaining := text
 	for remaining != "" {
+		// Check for special tokens at the current position.
 		found := false
 		for tok, id := range t.special {
 			if strings.HasPrefix(remaining, tok) {
@@ -186,15 +220,35 @@ func (t *Tokenizer) Encode(text string) []int32 {
 				break
 			}
 		}
-		if !found {
-			r := []rune(remaining)
-			ch := "▁" + string(r[0])
-			if id, ok := t.vocab[ch]; ok {
-				tokens = append(tokens, id)
-			} else if id, ok := t.vocab[string(r[0])]; ok {
-				tokens = append(tokens, id)
-			}
-			remaining = string(r[1:])
+		if found {
+			continue
+		}
+
+		// Find the next special token boundary (or end of string).
+		end := len(remaining)
+		for tok := range t.special {
+			if idx := strings.Index(remaining, tok); idx > 0 && idx < end {
+				end = idx
+			}
+		}
+		segment := remaining[:end]
+		remaining = remaining[end:]
+
+		// SentencePiece: prefix the segment with ▁ (space marker) and split into characters.
+		spText := "▁" + segment
+		symbols := make([]string, 0, len([]rune(spText)))
+		for _, r := range spText {
+			symbols = append(symbols, string(r))
+		}
+
+		// Apply BPE merges.
+		symbols = t.bpeMerge(symbols)
+
+		// Look up merged symbols in vocab.
+		for _, sym := range symbols {
+			if id, ok := t.vocab[sym]; ok {
+				tokens = append(tokens, id)
+			}
 		}
 	}
@@ -205,45 +259,56 @@ func (t *Tokenizer) Encode(text string) []int32 {
 func (t *Tokenizer) encodeGPT2(text string) []int32 {
 	tokens := []int32{t.bosToken}
 
-	// Convert text bytes to GPT-2 Unicode representation
-	var encoded strings.Builder
-	for _, b := range []byte(text) {
-		if r, ok := t.gpt2Encoder[b]; ok {
-			encoded.WriteRune(r)
-		}
-	}
-	gpt2Text := encoded.String()
-
-	// Scan for special tokens and regular text
-	remaining := gpt2Text
+	// Split text around special tokens (matched in original form, not byte-encoded).
+	remaining := text
 	for remaining != "" {
-		// Check special tokens (these are stored as-is, not byte-encoded)
+		// Check for special tokens at the current position.
 		found := false
 		for tok, id := range t.special {
-			// Special tokens in GPT-2 tokenizers are stored in their original form
-			// Convert the special token to GPT-2 encoding for matching
-			var encTok strings.Builder
-			for _, b := range []byte(tok) {
-				if r, ok := t.gpt2Encoder[b]; ok {
-					encTok.WriteRune(r)
-				}
-			}
-			encStr := encTok.String()
-			if strings.HasPrefix(remaining, encStr) {
+			if strings.HasPrefix(remaining, tok) {
 				tokens = append(tokens, id)
-				remaining = remaining[len(encStr):]
+				remaining = remaining[len(tok):]
 				found = true
 				break
 			}
 		}
-		if !found {
-			// Character-by-character lookup (simplified BPE)
-			r := []rune(remaining)
-			ch := string(r[0])
-			if id, ok := t.vocab[ch]; ok {
-				tokens = append(tokens, id)
-			}
-			remaining = string(r[1:])
+		if found {
+			continue
+		}
+
+		// Find the next special token boundary (or end of string).
+		end := len(remaining)
+		for tok := range t.special {
+			if idx := strings.Index(remaining, tok); idx > 0 && idx < end {
+				end = idx
+			}
+		}
+		segment := remaining[:end]
+		remaining = remaining[end:]
+
+		// Convert segment bytes to GPT-2 Unicode representation.
+		var encoded strings.Builder
+		for _, b := range []byte(segment) {
+			if r, ok := t.gpt2Encoder[b]; ok {
+				encoded.WriteRune(r)
+			}
+		}
+
+		// Split into individual runes (GPT-2 BPE operates on Unicode chars).
+		runes := []rune(encoded.String())
+		symbols := make([]string, len(runes))
+		for i, r := range runes {
+			symbols[i] = string(r)
+		}
+
+		// Apply BPE merges.
+		symbols = t.bpeMerge(symbols)
+
+		// Look up merged symbols in vocab.
+		for _, sym := range symbols {
+			if id, ok := t.vocab[sym]; ok {
+				tokens = append(tokens, id)
+			}
 		}
 	}

tokenizer_test.go

@@ -94,7 +94,62 @@ func TestEncode_ProducesTokens(t *testing.T) {
 	if tokens[0] != tok.BOSToken() {
 		t.Errorf("first token = %d, want BOS (%d)", tokens[0], tok.BOSToken())
 	}
-	t.Logf("Encode(\"hello\") = %v", tokens)
+
+	// With BPE merges ("h e" → "he", "l l" → "ll"), "hello" with ▁ prefix becomes:
+	//   "▁" "h" "e" "l" "l" "o" → merge "h e" → "▁" "he" "l" "l" "o"
+	//   → merge "l l" → "▁" "he" "ll" "o"
+	// No further merges. But "▁" is not "▁h" so it stays as "▁".
+	// Vocab: ▁=4, he=5, ll=6, o=3. Expected: [BOS, 4, 5, 6, 3]
+	want := []int32{100, 4, 5, 6, 3}
+	if len(tokens) != len(want) {
+		t.Fatalf("Encode(\"hello\") = %v, want %v", tokens, want)
+	}
+	for i := range tokens {
+		if tokens[i] != want[i] {
+			t.Errorf("tokens[%d] = %d, want %d", i, tokens[i], want[i])
+		}
+	}
 }
+
+func TestBPEMerge(t *testing.T) {
+	tok := &Tokenizer{
+		mergeRanks: map[string]int{
+			"h e":  0,
+			"l l":  1,
+			"he l": 2,
+		},
+	}
+
+	// "h" "e" "l" "l" "o" → merge "h e" (rank 0) → "he" "l" "l" "o"
+	// → merge "l l" (rank 1) → "he" "ll" "o"
+	// → merge "he l" does NOT match "he ll" — stops here.
+	symbols := []string{"h", "e", "l", "l", "o"}
+	got := tok.bpeMerge(symbols)
+
+	want := []string{"he", "ll", "o"}
+	if len(got) != len(want) {
+		t.Fatalf("bpeMerge = %v, want %v", got, want)
+	}
+	for i := range got {
+		if got[i] != want[i] {
+			t.Errorf("bpeMerge[%d] = %q, want %q", i, got[i], want[i])
+		}
+	}
+}
+
+func TestBPEMerge_NoMerges(t *testing.T) {
+	tok := &Tokenizer{mergeRanks: map[string]int{}}
+	symbols := []string{"a", "b", "c"}
+	got := tok.bpeMerge(symbols)
+	if len(got) != 3 {
+		t.Errorf("bpeMerge with no merges = %v, want [a b c]", got)
+	}
+}
+
+func TestBPEMerge_SingleSymbol(t *testing.T) {
+	tok := &Tokenizer{mergeRanks: map[string]int{"a b": 0}}
+	got := tok.bpeMerge([]string{"x"})
+	if len(got) != 1 || got[0] != "x" {
+		t.Errorf("bpeMerge single = %v, want [x]", got)
+	}
+}
 
 func TestDecode_SpecialTokensSkipped(t *testing.T) {