go-mlx/internal/metal/compile.go
Snider fb95cde30c fix(metal): address 5 important code review items
1. RepeatPenalty: implemented applyRepeatPenalty() — tracks generated
   token IDs, deduplicates, divides positive logits by penalty and
   multiplies negative logits by penalty. 2 new tests.

2. DefaultGPUStream/DefaultCPUStream: now cached with sync.Once,
   no more C stream allocation on every call.

3. CompileShapeless: removed dead C closure, callback, sync.Map,
   and nextID infrastructure. CompiledFunc is now a plain function
   wrapper with mutex. API unchanged.

4. Tokenizer BPE: implemented bpeMerge() — standard BPE algorithm
   using merge rank lookup. Both SentencePiece and GPT-2 Encode paths
   now apply merges instead of falling back to character-level lookup.
   3 new tests.

5. KV cache lifecycle: documented in the Generate() godoc. Caches are
   created fresh per call; call ClearCache() between turns to reclaim
   Metal memory promptly.
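Item 1 describes the repeat penalty only in words. The sketch below shows that rule on plain Go slices. It assumes logits are a `[]float32` indexed by token ID and history is a `[]int32`. `applyRepeatPenaltySketch` is a hypothetical name: the repository's `applyRepeatPenalty` works on MLX arrays, and its signature is not shown here.

```go
package main

import "fmt"

// applyRepeatPenaltySketch is a hypothetical, slice-based illustration of the
// repeat penalty described in the commit message (not the go-mlx function).
// Each distinct token ID in history is penalised once. A positive logit is
// divided by penalty; a non-positive logit is multiplied by it (zero stays zero).
func applyRepeatPenaltySketch(logits []float32, history []int32, penalty float32) {
	seen := make(map[int32]struct{}, len(history))
	for _, id := range history {
		if _, dup := seen[id]; dup {
			continue // deduplicate so a repeated token is penalised only once
		}
		seen[id] = struct{}{}
		if id < 0 || int(id) >= len(logits) {
			continue // ignore IDs outside the vocabulary
		}
		if logits[id] > 0 {
			logits[id] /= penalty
		} else {
			logits[id] *= penalty
		}
	}
}

func main() {
	logits := []float32{2.0, -1.0, 0.5, 3.0}
	// Token 0 appears twice in the history but is penalised once.
	applyRepeatPenaltySketch(logits, []int32{0, 1, 0}, 2.0)
	fmt.Println(logits) // [1 -2 0.5 3]
}
```

Deduplicating first matters. Without it, a token generated n times would be divided by penalty^n, which pushes its logit much lower than a single penalty would.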

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 21:31:45 +00:00
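Item 2 caches the default streams with `sync.Once`. The sketch below illustrates that pattern. `Stream`, `newStream` and `defaultGPUStream` are hypothetical stand-ins for the C-backed stream handle and its allocator, and the `allocations` counter exists only to make the caching visible.

```go
package main

import (
	"fmt"
	"sync"
)

// Stream is a hypothetical stand-in for the C-backed MLX stream handle.
type Stream struct{ device string }

// allocations counts how often newStream runs, to show that caching works.
var allocations int

// newStream is a hypothetical stand-in for the C stream allocation.
func newStream(device string) *Stream {
	allocations++
	return &Stream{device: device}
}

var (
	gpuOnce   sync.Once
	gpuStream *Stream
)

// defaultGPUStream returns a lazily created, process-wide stream.
// sync.Once runs newStream exactly once, even under concurrent first calls,
// and every later call returns the cached pointer.
func defaultGPUStream() *Stream {
	gpuOnce.Do(func() { gpuStream = newStream("gpu") })
	return gpuStream
}

func main() {
	a := defaultGPUStream()
	b := defaultGPUStream()
	fmt.Println(a == b, allocations) // true 1
}
```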

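Item 4 refers to "standard BPE using merge rank lookup". The sketch below shows that loop. At each step it finds the adjacent pair with the lowest rank, merges every occurrence of that pair, and stops when no ranked pair remains. The name `bpeMergeSketch`, the space-joined `"left right"` key format and the toy ranks are assumptions for illustration. The real `bpeMerge` and its vocabulary format are not shown here.

```go
package main

import (
	"fmt"
	"strings"
)

// bpeMergeSketch applies byte-pair-encoding merges to a sequence of symbols.
// ranks maps a space-joined pair ("left right") to its merge priority; a lower
// rank was learned earlier and is applied first. Hypothetical signature.
func bpeMergeSketch(symbols []string, ranks map[string]int) []string {
	for len(symbols) > 1 {
		// Find the adjacent pair with the lowest merge rank.
		best, bestRank := -1, 0
		for i := 0; i+1 < len(symbols); i++ {
			r, ok := ranks[symbols[i]+" "+symbols[i+1]]
			if ok && (best == -1 || r < bestRank) {
				best, bestRank = i, r
			}
		}
		if best == -1 {
			break // no mergeable pair remains
		}
		left, right := symbols[best], symbols[best+1]
		// Merge every occurrence of the chosen pair, scanning left to right.
		merged := make([]string, 0, len(symbols))
		for i := 0; i < len(symbols); i++ {
			if i+1 < len(symbols) && symbols[i] == left && symbols[i+1] == right {
				merged = append(merged, left+right)
				i++ // skip the right half of the merged pair
			} else {
				merged = append(merged, symbols[i])
			}
		}
		symbols = merged
	}
	return symbols
}

func main() {
	ranks := map[string]int{"l o": 0, "lo w": 1, "e r": 2}
	word := strings.Split("lower", "") // start from individual characters
	fmt.Println(bpeMergeSketch(word, ranks)) // [low er]
}
```

Without merges, "lower" would stay as five single-character tokens. With merges, it becomes two, which is what the previous character-level fallback missed.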

//go:build darwin && arm64

package metal

import "sync"

// CompiledFunc wraps a function for repeated execution. The function is
// called directly rather than through MLX's compiler, so no kernel fusion
// takes place; MLX's lazy evaluation still defers the underlying Metal work
// until the outputs are evaluated.
type CompiledFunc struct {
	fn func([]*Array) []*Array
	mu sync.Mutex
}

// CompileShapeless wraps a function for repeated execution.
// The shapeless parameter is accepted for API compatibility but unused.
func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc {
	return &CompiledFunc{fn: fn}
}

// Call executes the wrapped function with the given inputs. Calls are
// serialised by a mutex, so concurrent callers run one at a time.
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
	cf.mu.Lock()
	defer cf.mu.Unlock()
	return cf.fn(inputs)
}
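The usage sketch below needs MLX to run `CompiledFunc` as written, so it copies the wrapper and swaps in a hypothetical `Array` that holds a single float. It shows two things. First, `shapeless` has no effect. Second, concurrent `Call`s are serialised by the mutex, so the wrapped function can update shared state without extra locking.

```go
package main

import (
	"fmt"
	"sync"
)

// Array is a hypothetical stand-in for metal.Array so this runs without MLX.
type Array struct{ v float64 }

// CompiledFunc mirrors the wrapper in compile.go: a plain function plus a
// mutex that serialises calls.
type CompiledFunc struct {
	fn func([]*Array) []*Array
	mu sync.Mutex
}

// CompileShapeless ignores shapeless, as in compile.go.
func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc {
	return &CompiledFunc{fn: fn}
}

// Call holds the mutex for the duration of fn.
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
	cf.mu.Lock()
	defer cf.mu.Unlock()
	return cf.fn(inputs)
}

func main() {
	calls := 0 // only touched inside fn, which always runs under cf.mu
	add := CompileShapeless(func(in []*Array) []*Array {
		calls++
		return []*Array{{v: in[0].v + in[1].v}}
	}, true)

	// Eight goroutines call concurrently; the mutex runs them one at a time.
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			add.Call(&Array{v: 1}, &Array{v: 2})
		}()
	}
	wg.Wait()

	out := add.Call(&Array{v: 1}, &Array{v: 2})
	fmt.Println(out[0].v, calls) // 3 9
}
```

The mutex trades throughput for safety. Holding it means two goroutines never build MLX graphs through the same `CompiledFunc` at once.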