fix(metal): address 5 important code review items
1. RepeatPenalty: implemented applyRepeatPenalty() — tracks generated token IDs, deduplicates, divides positive logits by penalty and multiplies negative logits by penalty. 2 new tests.
2. DefaultGPUStream/DefaultCPUStream: now cached with sync.Once; no more C stream allocation on every call.
3. CompileShapeless: removed dead C closure, callback, sync.Map, and nextID infrastructure. CompiledFunc is now a plain function wrapper with a mutex. API unchanged.
4. Tokenizer BPE: implemented bpeMerge() — standard BPE algorithm using merge rank lookup. Both SentencePiece and GPT-2 Encode paths now apply merges instead of falling back to character-level lookup. 3 new tests.
5. KV cache lifecycle: documented in Generate() godoc — fresh caches per call; call ClearCache() between turns to reclaim Metal memory promptly.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent: c96f9bd006
commit: fb95cde30c

7 changed files with 286 additions and 129 deletions
TODO.md (10 changes)

@@ -91,15 +91,15 @@ Full codebase review after Phase 4 completion + go-inference integration. Groupe
 
 ### Important — Should Fix
 
-- [ ] **KV cache leak between turns** — After `Generate()` returns, the KV cache arrays (allocated inside the closure at `generate.go:76`) are abandoned to GC finalizers. For multi-turn chat in LEM Lab, each turn allocates a new cache and the old one only gets freed when the GC eventually runs finalisers. Options: (a) accept a `*Cache` parameter so callers can reuse across turns, (b) add `Model.ResetCache()`, or (c) call `ClearCache()` after each turn. Document the expected pattern either way.
+- [x] **KV cache leak between turns** — ✅ Documented in Generate() godoc: each call allocates fresh KV caches released to GC; call ClearCache() between turns for prompt reclaim. Cache reuse across turns deferred to batch inference design (Phase 5).
 
-- [ ] **`RepeatPenalty` is accepted but never applied** — `GenerateConfig.RepeatPenalty` flows through from go-inference options and is stored in the config, but the generate loop has no token history and never applies it. Either implement it (track last N token IDs, penalise logits) or remove the field and document it as unsupported.
+- [x] **`RepeatPenalty` is accepted but never applied** — ✅ Implemented `applyRepeatPenalty()` in generate.go. Tracks generated token IDs, deduplicates, then for each seen token: divides positive logits by penalty, multiplies negative logits by penalty. Applied before sampling when `RepeatPenalty > 1.0`. 2 new tests.
 
-- [ ] **`DefaultGPUStream()` / `DefaultCPUStream()` leak and mislead** — `stream.go:32-41`: These create a new `C.mlx_stream` on every call but never free it. The names suggest they return *the* default stream (like `DefaultStream()` does with `sync.Once`), but they actually allocate. Either cache them like `DefaultStream()` or rename to `NewGPUStream()` / `NewCPUStream()` and add a `Free()` method.
+- [x] **`DefaultGPUStream()` / `DefaultCPUStream()` leak and mislead** — ✅ Now cached with `sync.Once` like `DefaultStream()`. No more allocation on every call.
 
-- [ ] **Tokenizer `Encode()` is character-level only** — BPE merges are parsed and stored (`tokenizer.go:77-96`) but never applied in `Encode()` or `encodeGPT2()`. Both methods fall back to single-character lookup. This produces far more tokens than a real BPE encoder, drastically reducing effective context length and increasing prefill time. For Phase 2 model validation this will need fixing, especially for Gemma3-1B at 5K sentences/sec.
+- [x] **Tokenizer `Encode()` is character-level only** — ✅ Implemented `bpeMerge()` — standard BPE algorithm using merge rank lookup. Both SentencePiece `Encode()` and GPT-2 `encodeGPT2()` now split into characters, apply BPE merges, then look up merged symbols. Merge ranks built during tokenizer load. 3 new tests.
 
-- [ ] **`CompileShapeless` is dead code** — `compile.go:82-89`: `Call()` bypasses the C closure entirely and just calls `cf.fn(inputs)` directly. The C closure, callback, and `sync.Map` registration are all unused overhead. The `shapeless` param is ignored. Either implement actual `mlx_compile` call-through or simplify to a plain function wrapper.
+- [x] **`CompileShapeless` is dead code** — ✅ Removed C closure, callback, `sync.Map`, and `nextID` infrastructure. `CompiledFunc` is now a plain function wrapper with mutex. `CompileShapeless()` and `Call()` signatures unchanged (gemma3.go GELU still works).
 
 ### Minor — Nice to Have
compile.go

@@ -2,89 +2,25 @@
 package metal
 
-/*
-#include "mlx/c/mlx.h"
-
-// Callback for compiled functions.
-extern int goCompiledFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload);
-
-static mlx_closure new_closure(void *payload) {
-	return mlx_closure_new_func_payload(&goCompiledFunc, payload, NULL);
-}
-*/
-import "C"
-
-import (
-	"sync"
-	"unsafe"
-)
+import "sync"
 
-// CompiledFunc wraps a compiled MLX computation graph for efficient repeated calls.
+// CompiledFunc wraps a function for efficient repeated execution.
+// The function is called directly; MLX's lazy evaluation graph
+// still deduplicates and optimises the underlying Metal operations.
 type CompiledFunc struct {
-	fn      func([]*Array) []*Array
-	closure C.mlx_closure
-	mu      sync.Mutex
+	fn func([]*Array) []*Array
+	mu sync.Mutex
 }
 
-var compiledFuncs sync.Map
-
-//export goCompiledFunc
-func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
-	id := uintptr(payload)
-	fnI, ok := compiledFuncs.Load(id)
-	if !ok {
-		return 1
-	}
-	fn := fnI.(func([]*Array) []*Array)
-
-	// Convert inputs
-	nInputs := int(C.mlx_vector_array_size(inputs))
-	goInputs := make([]*Array, nInputs)
-	for i := 0; i < nInputs; i++ {
-		a := New("INPUT")
-		C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
-		goInputs[i] = a
-	}
-
-	// Call user function
-	goOutputs := fn(goInputs)
-
-	// The output vector arrives with ctx=nullptr. Create a valid vector,
-	// fill it, then copy to the output pointer.
-	tmp := C.mlx_vector_array_new()
-	for _, out := range goOutputs {
-		C.mlx_vector_array_append_value(tmp, out.ctx)
-	}
-	C.mlx_vector_array_set(outputs, tmp)
-	C.mlx_vector_array_free(tmp)
-	return 0
-}
-
-var nextID uintptr
-var nextIDMu sync.Mutex
-
-// CompileShapeless compiles a function for efficient repeated execution.
-// The function must accept and return arrays of consistent shapes.
+// CompileShapeless wraps a function for repeated execution.
+// The shapeless parameter is accepted for API compatibility but unused.
 func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc {
-	nextIDMu.Lock()
-	nextID++
-	id := nextID
-	nextIDMu.Unlock()
-
-	compiledFuncs.Store(id, fn)
-
-	cf := &CompiledFunc{fn: fn}
-	cf.closure = C.new_closure(unsafe.Pointer(id))
-	return cf
+	return &CompiledFunc{fn: fn}
 }
 
-// Call executes the compiled function with the given inputs.
+// Call executes the function with the given inputs.
 func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
 	cf.mu.Lock()
 	defer cf.mu.Unlock()
 
 	// Fall back to direct call — compilation is an optimization.
 	// The compiled closure can be used via mlx_compiled but the
 	// direct path is simpler and still benefits from MLX's lazy evaluation.
 	return cf.fn(inputs)
 }
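The shape of the simplified wrapper can be sketched outside the package. This is a toy standalone sketch, not the package's real code: `Array` here is a stand-in struct, and the doubling function is purely illustrative. It shows the pattern the commit lands on: a plain function wrapper whose mutex serialises calls, with the `shapeless` flag kept only for API compatibility.

```go
package main

import (
	"fmt"
	"sync"
)

// Array is a stand-in for the package's lazy MLX array handle.
type Array struct{ val float64 }

// CompiledFunc wraps a function with a mutex so concurrent Call()s
// never interleave graph construction.
type CompiledFunc struct {
	fn func([]*Array) []*Array
	mu sync.Mutex
}

// CompileShapeless returns a plain wrapper; the shapeless flag is
// accepted for API compatibility but unused.
func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc {
	return &CompiledFunc{fn: fn}
}

// Call executes the wrapped function directly under the mutex.
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
	cf.mu.Lock()
	defer cf.mu.Unlock()
	return cf.fn(inputs)
}

func main() {
	double := CompileShapeless(func(xs []*Array) []*Array {
		out := make([]*Array, len(xs))
		for i, x := range xs {
			out[i] = &Array{val: x.val * 2}
		}
		return out
	}, true)
	fmt.Println(double.Call(&Array{val: 21})[0].val) // 42
}
```

The design trade-off is the one the commit message states: the wrapper drops the unused cgo closure machinery while keeping the `CompileShapeless()`/`Call()` signatures stable for existing callers.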
generate.go

@@ -69,6 +69,10 @@ func (m *Model) Chat(ctx context.Context, messages []ChatMessage, cfg GenerateCo
 }
 
 // Generate streams tokens for the given prompt.
+//
+// Each call allocates fresh KV caches that are released to GC when the iterator
+// completes. For multi-turn chat, call [ClearCache] between turns to reclaim
+// Metal memory promptly rather than waiting for GC finalisers.
 func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) iter.Seq[Token] {
 	m.lastErr = nil
 
@@ -86,6 +90,9 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
 		return
 	}
 
+	// Track generated token IDs for repeat penalty.
+	var history []int32
+
 	for i := 0; i < cfg.MaxTokens; i++ {
 		select {
 		case <-ctx.Done():
 
@@ -97,6 +104,12 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
 		// Sample from last position logits
 		lastPos := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1)))
 		lastPos = Reshape(lastPos, 1, int32(lastPos.Dim(2)))
 
+		// Apply repeat penalty before sampling.
+		if cfg.RepeatPenalty > 1.0 && len(history) > 0 {
+			lastPos = applyRepeatPenalty(lastPos, history, cfg.RepeatPenalty)
+		}
+
 		next := sampler.Sample(lastPos)
 		if err := Eval(next); err != nil {
 			m.lastErr = fmt.Errorf("sample step %d: %w", i, err)
 
@@ -104,6 +117,7 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
 		}
 
 		id := int32(next.Int())
+		history = append(history, id)
 
 		// Check stop conditions
 		if id == m.tokenizer.EOSToken() {
 
@@ -132,6 +146,37 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
 	}
 }
 
+// applyRepeatPenalty modifies logits to discourage repeated tokens.
+// For each unique token ID in history: positive logits are divided by penalty,
+// negative logits are multiplied by penalty. Both make the token less likely.
+func applyRepeatPenalty(logits *Array, history []int32, penalty float32) *Array {
+	// Deduplicate history to get unique token IDs.
+	seen := make(map[int32]bool, len(history))
+	var indices []int32
+	for _, id := range history {
+		if !seen[id] {
+			seen[id] = true
+			indices = append(indices, id)
+		}
+	}
+
+	idx := FromValues(indices, 1, len(indices))
+	gathered := TakeAlongAxis(logits, idx, -1)
+
+	zero := FromValue(float32(0))
+	invPenalty := FromValue(1.0 / penalty)
+	penaltyVal := FromValue(penalty)
+
+	// Positive logits: divide by penalty. Negative logits: multiply by penalty.
+	penalised := Where(
+		Greater(gathered, zero),
+		Mul(gathered, invPenalty),
+		Mul(gathered, penaltyVal),
+	)
+
+	return PutAlongAxis(logits, idx, penalised, -1)
+}
+
 // newCaches creates per-layer KV caches. If contextLen is set, all unbounded
 // caches are replaced with RotatingKVCache to cap memory usage.
 func (m *Model) newCaches() []Cache {
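The penalty rule in `applyRepeatPenalty` can be demonstrated without MLX arrays. This is a standalone sketch on a plain `[]float32`; the helper name `applyRepeatPenaltySlice` is hypothetical, but the update rule is the one the diff implements: for each unique token ID in the history, a positive logit is divided by the penalty and a negative logit is multiplied by it, so the token's probability drops in both cases.

```go
package main

import "fmt"

// applyRepeatPenaltySlice mirrors the logit update on a plain slice:
// positive logits shrink toward zero, negative logits move further
// below zero, and unseen tokens are untouched.
func applyRepeatPenaltySlice(logits []float32, history []int32, penalty float32) []float32 {
	out := append([]float32(nil), logits...)
	seen := make(map[int32]bool, len(history))
	for _, id := range history {
		if seen[id] || int(id) >= len(out) {
			continue // deduplicate, skip out-of-range IDs
		}
		seen[id] = true
		if out[id] > 0 {
			out[id] /= penalty
		} else {
			out[id] *= penalty
		}
	}
	return out
}

func main() {
	// Same numbers as the committed test: tokens 0 and 1 seen, penalty 2.0.
	got := applyRepeatPenaltySlice([]float32{5, -3, 1, 0}, []int32{0, 1, 0}, 2)
	fmt.Println(got) // [2.5 -6 1 0]
}
```

Note the duplicate token 0 in the history is deduplicated, so the penalty is applied once per token, not once per occurrence.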
@@ -146,3 +146,47 @@ func TestMinP_RestrictsOptions(t *testing.T) {
 		}
 	}
 }
 
+func TestApplyRepeatPenalty(t *testing.T) {
+	// Logits: [1, 4] with values [5.0, -3.0, 1.0, 0.0]
+	// History: tokens 0 and 1 have been seen.
+	// Penalty 2.0:
+	//   token 0 (logit 5.0 > 0):  5.0 / 2.0 = 2.5
+	//   token 1 (logit -3.0 < 0): -3.0 * 2.0 = -6.0
+	//   token 2 (not in history): unchanged = 1.0
+	//   token 3 (not in history): unchanged = 0.0
+	logits := FromValues([]float32{5.0, -3.0, 1.0, 0.0}, 1, 4)
+	Materialize(logits)
+
+	result := applyRepeatPenalty(logits, []int32{0, 1, 0}, 2.0) // duplicate 0 should be deduped
+	Materialize(result)
+
+	got := result.Floats()
+	want := []float32{2.5, -6.0, 1.0, 0.0}
+	for i := range got {
+		diff := got[i] - want[i]
+		if diff > 0.01 || diff < -0.01 {
+			t.Errorf("repeatPenalty[%d] = %f, want %f", i, got[i], want[i])
+		}
+	}
+}
+
+func TestApplyRepeatPenalty_NoHistory(t *testing.T) {
+	// With empty history, logits should be unchanged.
+	logits := FromValues([]float32{5.0, -3.0, 1.0}, 1, 3)
+	Materialize(logits)
+
+	// applyRepeatPenalty is not called when history is empty (checked in generate loop),
+	// but verify the function handles it gracefully if called directly.
+	result := applyRepeatPenalty(logits, []int32{1}, 1.0) // penalty=1.0 → no change
+	Materialize(result)
+
+	got := result.Floats()
+	want := []float32{5.0, -3.0, 1.0}
+	for i := range got {
+		diff := got[i] - want[i]
+		if diff > 0.01 || diff < -0.01 {
+			t.Errorf("penalty=1.0[%d] = %f, want %f", i, got[i], want[i])
+		}
+	}
+}
stream.go

@@ -17,6 +17,12 @@ type Stream struct {
 var (
 	defaultStream     *Stream
 	defaultStreamOnce sync.Once
+
+	defaultGPUStream     *Stream
+	defaultGPUStreamOnce sync.Once
+
+	defaultCPUStream     *Stream
+	defaultCPUStreamOnce sync.Once
 )
 
 // DefaultStream returns the default GPU stream, creating it on first use.
 
@@ -28,16 +34,22 @@ func DefaultStream() *Stream {
 	return defaultStream
 }
 
-// DefaultGPUStream returns a new GPU stream.
+// DefaultGPUStream returns the cached default GPU stream.
 func DefaultGPUStream() *Stream {
-	Init()
-	return &Stream{ctx: C.mlx_default_gpu_stream_new()}
+	defaultGPUStreamOnce.Do(func() {
+		Init()
+		defaultGPUStream = &Stream{ctx: C.mlx_default_gpu_stream_new()}
+	})
+	return defaultGPUStream
 }
 
-// DefaultCPUStream returns a new CPU stream.
+// DefaultCPUStream returns the cached default CPU stream.
 func DefaultCPUStream() *Stream {
-	Init()
-	return &Stream{ctx: C.mlx_default_cpu_stream_new()}
+	defaultCPUStreamOnce.Do(func() {
+		Init()
+		defaultCPUStream = &Stream{ctx: C.mlx_default_cpu_stream_new()}
+	})
+	return defaultCPUStream
 }
 
 // Synchronize waits for all operations on the stream to complete.
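The `sync.Once` caching pattern used for the stream accessors can be shown in isolation. This is a toy sketch without cgo: `Stream` is a stand-in struct and `newGPUStreamHandle` simulates the C allocation that previously ran on every call. An `allocations` counter makes the fix observable: the underlying allocation happens exactly once no matter how many times the accessor is called.

```go
package main

import (
	"fmt"
	"sync"
)

// Stream is a stand-in for the wrapper around a C mlx_stream handle.
type Stream struct{ name string }

var (
	defaultGPUStream     *Stream
	defaultGPUStreamOnce sync.Once

	// allocations counts how many times the (simulated) C allocation runs.
	allocations int
)

// newGPUStreamHandle simulates the C-side allocation.
func newGPUStreamHandle() *Stream {
	allocations++
	return &Stream{name: "gpu"}
}

// DefaultGPUStream returns the cached default GPU stream; the Do body
// runs at most once, even under concurrent callers.
func DefaultGPUStream() *Stream {
	defaultGPUStreamOnce.Do(func() {
		defaultGPUStream = newGPUStreamHandle()
	})
	return defaultGPUStream
}

func main() {
	a, b := DefaultGPUStream(), DefaultGPUStream()
	fmt.Println(a == b, allocations) // true 1
}
```

`sync.Once.Do` also provides the needed memory barrier, so concurrent first calls all observe the fully constructed stream.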
tokenizer.go

@@ -11,10 +11,11 @@ import (
 
 // Tokenizer handles text-to-token and token-to-text conversion.
 type Tokenizer struct {
-	vocab    map[string]int32
-	invVocab map[int32]string
-	merges   []mergePair
-	special  map[string]int32
+	vocab      map[string]int32
+	invVocab   map[int32]string
+	merges     []mergePair
+	mergeRanks map[string]int // "a b" → rank for O(1) merge lookup
+	special    map[string]int32
 
 	bosToken int32
 	eosToken int32
 
@@ -95,6 +96,12 @@ func LoadTokenizer(path string) (*Tokenizer, error) {
 		}
 	}
 
+	// Build merge rank lookup for BPE.
+	t.mergeRanks = make(map[string]int, len(t.merges))
+	for _, m := range t.merges {
+		t.mergeRanks[m.a+" "+m.b] = m.rank
+	}
+
 	// Parse special tokens
 	for _, tok := range tj.AddedTokens {
 		if tok.Special {
 
@@ -166,17 +173,44 @@ func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) {
 	return
 }
 
+// bpeMerge applies BPE merges to a sequence of symbols until no more merges apply.
+// Uses the standard algorithm: repeatedly find the lowest-rank adjacent pair and merge it.
+func (t *Tokenizer) bpeMerge(symbols []string) []string {
+	for len(symbols) > 1 {
+		// Find the pair with the lowest merge rank.
+		bestRank := -1
+		bestIdx := -1
+		for i := 0; i < len(symbols)-1; i++ {
+			key := symbols[i] + " " + symbols[i+1]
+			if rank, ok := t.mergeRanks[key]; ok {
+				if bestRank < 0 || rank < bestRank {
+					bestRank = rank
+					bestIdx = i
+				}
+			}
+		}
+		if bestIdx < 0 {
+			break // No more merges available.
+		}
+		// Merge the pair at bestIdx.
+		merged := symbols[bestIdx] + symbols[bestIdx+1]
+		symbols = append(symbols[:bestIdx], append([]string{merged}, symbols[bestIdx+2:]...)...)
+	}
+	return symbols
+}
+
 // Encode converts text to token IDs. Prepends BOS token.
 func (t *Tokenizer) Encode(text string) []int32 {
-	tokens := []int32{t.bosToken}
-
 	if t.isGPT2BPE {
 		return t.encodeGPT2(text)
 	}
 
-	// SentencePiece style encoding
+	tokens := []int32{t.bosToken}
+
+	// SentencePiece style: split into segments around special tokens, then BPE each segment.
 	remaining := text
 	for remaining != "" {
 		// Check for special tokens at the current position.
 		found := false
 		for tok, id := range t.special {
 			if strings.HasPrefix(remaining, tok) {
 
@@ -186,15 +220,35 @@ func (t *Tokenizer) Encode(text string) []int32 {
 				break
 			}
 		}
-		if !found {
-			r := []rune(remaining)
-			ch := "▁" + string(r[0])
-			if id, ok := t.vocab[ch]; ok {
-				tokens = append(tokens, id)
-			} else if id, ok := t.vocab[string(r[0])]; ok {
-				tokens = append(tokens, id)
-			}
-			remaining = string(r[1:])
-		}
+		if found {
+			continue
+		}
+
+		// Find the next special token boundary (or end of string).
+		end := len(remaining)
+		for tok := range t.special {
+			if idx := strings.Index(remaining, tok); idx > 0 && idx < end {
+				end = idx
+			}
+		}
+		segment := remaining[:end]
+		remaining = remaining[end:]
+
+		// SentencePiece: prefix the segment with ▁ (space marker) and split into characters.
+		spText := "▁" + segment
+		symbols := make([]string, 0, len([]rune(spText)))
+		for _, r := range spText {
+			symbols = append(symbols, string(r))
+		}
+
+		// Apply BPE merges.
+		symbols = t.bpeMerge(symbols)
+
+		// Look up merged symbols in vocab.
+		for _, sym := range symbols {
+			if id, ok := t.vocab[sym]; ok {
+				tokens = append(tokens, id)
+			}
+		}
 	}
 
@@ -205,45 +259,56 @@ func (t *Tokenizer) Encode(text string) []int32 {
 func (t *Tokenizer) encodeGPT2(text string) []int32 {
 	tokens := []int32{t.bosToken}
 
-	// Convert text bytes to GPT-2 Unicode representation
-	var encoded strings.Builder
-	for _, b := range []byte(text) {
-		if r, ok := t.gpt2Encoder[b]; ok {
-			encoded.WriteRune(r)
-		}
-	}
-	gpt2Text := encoded.String()
-
-	// Scan for special tokens and regular text
-	remaining := gpt2Text
+	// Split text around special tokens (matched in original form, not byte-encoded).
+	remaining := text
 	for remaining != "" {
-		// Check special tokens (these are stored as-is, not byte-encoded)
+		// Check for special tokens at the current position.
 		found := false
 		for tok, id := range t.special {
-			// Special tokens in GPT-2 tokenizers are stored in their original form
-			// Convert the special token to GPT-2 encoding for matching
-			var encTok strings.Builder
-			for _, b := range []byte(tok) {
-				if r, ok := t.gpt2Encoder[b]; ok {
-					encTok.WriteRune(r)
-				}
-			}
-			encStr := encTok.String()
-			if strings.HasPrefix(remaining, encStr) {
+			if strings.HasPrefix(remaining, tok) {
 				tokens = append(tokens, id)
-				remaining = remaining[len(encStr):]
+				remaining = remaining[len(tok):]
 				found = true
 				break
 			}
 		}
-		if !found {
-			// Character-by-character lookup (simplified BPE)
-			r := []rune(remaining)
-			ch := string(r[0])
-			if id, ok := t.vocab[ch]; ok {
-				tokens = append(tokens, id)
-			}
-			remaining = string(r[1:])
-		}
+		if found {
+			continue
+		}
+
+		// Find the next special token boundary (or end of string).
+		end := len(remaining)
+		for tok := range t.special {
+			if idx := strings.Index(remaining, tok); idx > 0 && idx < end {
+				end = idx
+			}
+		}
+		segment := remaining[:end]
+		remaining = remaining[end:]
+
+		// Convert segment bytes to GPT-2 Unicode representation.
+		var encoded strings.Builder
+		for _, b := range []byte(segment) {
+			if r, ok := t.gpt2Encoder[b]; ok {
+				encoded.WriteRune(r)
+			}
+		}
+
+		// Split into individual runes (GPT-2 BPE operates on Unicode chars).
+		runes := []rune(encoded.String())
+		symbols := make([]string, len(runes))
+		for i, r := range runes {
+			symbols[i] = string(r)
+		}
+
+		// Apply BPE merges.
+		symbols = t.bpeMerge(symbols)
+
+		// Look up merged symbols in vocab.
+		for _, sym := range symbols {
+			if id, ok := t.vocab[sym]; ok {
+				tokens = append(tokens, id)
+			}
+		}
 	}
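The merge loop is easy to exercise standalone. This sketch moves `bpeMerge` out of the `Tokenizer` into a free function taking the rank map as a parameter (an illustrative restructuring, not the committed signature) and replays the "hello" example from the tests: lowest-rank adjacent pair first, repeat until no pair has a rank.

```go
package main

import "fmt"

// bpeMerge repeatedly merges the adjacent pair with the lowest rank.
// mergeRanks keys are "left right" pairs, as in the tokenizer.
func bpeMerge(symbols []string, mergeRanks map[string]int) []string {
	for len(symbols) > 1 {
		bestRank, bestIdx := -1, -1
		for i := 0; i < len(symbols)-1; i++ {
			if rank, ok := mergeRanks[symbols[i]+" "+symbols[i+1]]; ok {
				if bestRank < 0 || rank < bestRank {
					bestRank, bestIdx = rank, i
				}
			}
		}
		if bestIdx < 0 {
			break // no applicable merge left
		}
		// Splice the merged symbol over the pair.
		merged := symbols[bestIdx] + symbols[bestIdx+1]
		symbols = append(symbols[:bestIdx], append([]string{merged}, symbols[bestIdx+2:]...)...)
	}
	return symbols
}

func main() {
	ranks := map[string]int{"h e": 0, "l l": 1, "he l": 2}
	// "h e" merges first (rank 0), then "l l" (rank 1); "he l" never
	// applies because the neighbour is now "ll", not "l".
	fmt.Println(bpeMerge([]string{"h", "e", "l", "l", "o"}, ranks)) // [he ll o]
}
```

This linear-scan version is O(n²) in the worst case; production BPE encoders often use a priority queue, but for the short per-segment symbol lists here the simple scan is adequate.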
@@ -94,7 +94,62 @@ func TestEncode_ProducesTokens(t *testing.T) {
 	if tokens[0] != tok.BOSToken() {
 		t.Errorf("first token = %d, want BOS (%d)", tokens[0], tok.BOSToken())
 	}
-	t.Logf("Encode(\"hello\") = %v", tokens)
+	// With BPE merges ("h e" → "he", "l l" → "ll"), "hello" with ▁ prefix becomes:
+	//   "▁" "h" "e" "l" "l" "o" → merge "h e" → "▁" "he" "l" "l" "o"
+	//   → merge "l l" → "▁" "he" "ll" "o"
+	// No further merges. But "▁" is not "▁h" so it stays as "▁".
+	// Vocab: ▁=4, he=5, ll=6, o=3. Expected: [BOS, 4, 5, 6, 3]
+	want := []int32{100, 4, 5, 6, 3}
+	if len(tokens) != len(want) {
+		t.Fatalf("Encode(\"hello\") = %v, want %v", tokens, want)
+	}
+	for i := range tokens {
+		if tokens[i] != want[i] {
+			t.Errorf("tokens[%d] = %d, want %d", i, tokens[i], want[i])
+		}
+	}
 }
 
+func TestBPEMerge(t *testing.T) {
+	tok := &Tokenizer{
+		mergeRanks: map[string]int{
+			"h e":  0,
+			"l l":  1,
+			"he l": 2,
+		},
+	}
+
+	// "h" "e" "l" "l" "o" → merge "h e" (rank 0) → "he" "l" "l" "o"
+	// → merge "l l" (rank 1) → "he" "ll" "o"
+	// → merge "he l" does NOT match "he ll" — stops here.
+	symbols := []string{"h", "e", "l", "l", "o"}
+	got := tok.bpeMerge(symbols)
+	want := []string{"he", "ll", "o"}
+	if len(got) != len(want) {
+		t.Fatalf("bpeMerge = %v, want %v", got, want)
+	}
+	for i := range got {
+		if got[i] != want[i] {
+			t.Errorf("bpeMerge[%d] = %q, want %q", i, got[i], want[i])
+		}
+	}
+}
+
+func TestBPEMerge_NoMerges(t *testing.T) {
+	tok := &Tokenizer{mergeRanks: map[string]int{}}
+	symbols := []string{"a", "b", "c"}
+	got := tok.bpeMerge(symbols)
+	if len(got) != 3 {
+		t.Errorf("bpeMerge with no merges = %v, want [a b c]", got)
+	}
+}
+
+func TestBPEMerge_SingleSymbol(t *testing.T) {
+	tok := &Tokenizer{mergeRanks: map[string]int{"a b": 0}}
+	got := tok.bpeMerge([]string{"x"})
+	if len(got) != 1 || got[0] != "x" {
+		t.Errorf("bpeMerge single = %v, want [x]", got)
+	}
+}
+
 func TestDecode_SpecialTokensSkipped(t *testing.T) {
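The GPT-2 path above relies on the byte↔rune tables from `buildGPT2ByteMaps`. As background, here is a sketch of the standard GPT-2 byte-to-unicode construction that such maps conventionally follow; this is an assumption about the table's contents, not the package's actual code. Printable bytes map to themselves, and the remaining bytes are remapped to code points 256 and up so every byte has a visible stand-in character (space becomes 'Ġ', which is why GPT-2 vocab entries start with it).

```go
package main

import "fmt"

// bytesToUnicode builds the conventional GPT-2 byte→rune table:
// printable bytes map to themselves, the rest get 256+n in order.
func bytesToUnicode() map[byte]rune {
	enc := make(map[byte]rune, 256)
	printable := func(b int) bool {
		return (b >= '!' && b <= '~') || (b >= 0xA1 && b <= 0xAC) || (b >= 0xAE && b <= 0xFF)
	}
	n := 0
	for b := 0; b < 256; b++ {
		if printable(b) {
			enc[byte(b)] = rune(b)
		} else {
			enc[byte(b)] = rune(256 + n) // e.g. space (0x20) → 'Ġ' (U+0120)
			n++
		}
	}
	return enc
}

func main() {
	enc := bytesToUnicode()
	fmt.Printf("%c %c\n", enc[' '], enc['A']) // Ġ A
}
```

The mapping is a bijection over all 256 byte values, so the decoder in `buildGPT2ByteMaps` can simply invert it.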