fix(mlx): add DecodeToken for correct streaming word boundaries
The Decode method strips the SentencePiece leading space, so calling it per token during streaming loses word boundaries. DecodeToken preserves that space (it marks the word boundary); only the first token of a generation has its leading space stripped. Fixes the Gemma3 space prefix appearing in chat UI output.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent 2d870385f9
commit 98749c66f2
2 changed files with 29 additions and 1 deletion
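For illustration, a minimal standalone sketch of the problem and the fix follows. The toy vocabulary and helper names here are assumptions for demonstration only, not the repository's code; they mimic the SentencePiece convention where "▁" marks a word boundary.

package main

import (
	"fmt"
	"strings"
)

// Toy SentencePiece-style vocab: "▁" marks a word boundary.
// This is an illustrative stand-in, not the real Gemma3 vocabulary.
var invVocab = map[int32]string{
	1: "▁Hello",
	2: "▁world",
	3: "!",
}

// decodeStripped mimics a full-sequence Decode applied per token:
// the leading space is always removed, so boundaries are lost.
func decodeStripped(id int32) string {
	return strings.TrimLeft(strings.ReplaceAll(invVocab[id], "▁", " "), " ")
}

// decodeToken keeps the space that "▁" encodes, as DecodeToken does.
func decodeToken(id int32) string {
	return strings.ReplaceAll(invVocab[id], "▁", " ")
}

func main() {
	tokens := []int32{1, 2, 3}

	var lost, kept strings.Builder
	firstToken := true
	for _, id := range tokens {
		lost.WriteString(decodeStripped(id))

		text := decodeToken(id)
		if firstToken {
			// Only the very first token of a generation drops its leading space.
			text = strings.TrimLeft(text, " ")
			firstToken = false
		}
		kept.WriteString(text)
	}

	fmt.Printf("stripped per token: %q\n", lost.String()) // "Helloworld!"
	fmt.Printf("DecodeToken style:  %q\n", kept.String()) // "Hello world!"
}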
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"log/slog"
 	"runtime"
+	"strings"
 	"sync"

 	"forge.lthn.ai/core/go-ai/mlx"
@@ -85,6 +86,7 @@ func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts,
 	}

 	var output []int32
+	firstToken := true
 	for i := 0; i < maxTokens; i++ {
 		select {
 		case <-ctx.Done():
@@ -108,7 +110,12 @@ func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts,

 		// Stream the token text to the callback
 		if cb != nil {
-			tokenText := b.tok.Decode([]int32{nextToken})
+			tokenText := b.tok.DecodeToken(nextToken)
+			// Strip the SentencePiece leading space only on the first token
+			if firstToken {
+				tokenText = strings.TrimLeft(tokenText, " ")
+				firstToken = false
+			}
 			if err := cb(tokenText); err != nil {
 				runtime.GC()
 				mlx.ClearCache()
@@ -252,6 +252,7 @@ func (t *Tokenizer) encodeGPT2(text string) []int32 {
 }

 // Decode converts token IDs back to text.
+// For full-sequence decoding, the SentencePiece leading space is stripped.
 func (t *Tokenizer) Decode(tokens []int32) string {
 	var sb strings.Builder
 	for _, id := range tokens {
@@ -277,6 +278,26 @@ func (t *Tokenizer) Decode(tokens []int32) string {
 	return result
 }

+// DecodeToken converts a single token ID to text for streaming.
+// Unlike Decode, it preserves the leading space (word boundary) so that
+// token-by-token output maintains correct spacing between words.
+func (t *Tokenizer) DecodeToken(id int32) string {
+	text, ok := t.invVocab[id]
+	if !ok {
+		return ""
+	}
+	if _, isSpecial := t.special[text]; isSpecial {
+		return ""
+	}
+
+	if t.isGPT2BPE {
+		return t.decodeGPT2Bytes(text)
+	}
+
+	// SentencePiece: replace ▁ with space but keep it (it's the word boundary)
+	return strings.ReplaceAll(text, "▁", " ")
+}
+
 // decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes.
 func (t *Tokenizer) decodeGPT2Bytes(s string) string {
 	var buf []byte