fix(mlx): add DecodeToken for correct streaming word boundaries

The Decode method strips the SentencePiece leading space from every
token, which loses word boundaries during streaming. DecodeToken
preserves the space (it represents the word boundary) and only the
first token of each generation has its leading space stripped.

Fixes Gemma3 space prefix appearing in chat UI output.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-17 19:18:27 +00:00
parent 2d870385f9
commit 98749c66f2
2 changed files with 29 additions and 1 deletions

View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"runtime" "runtime"
"strings"
"sync" "sync"
"forge.lthn.ai/core/go-ai/mlx" "forge.lthn.ai/core/go-ai/mlx"
@ -85,6 +86,7 @@ func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts,
} }
var output []int32 var output []int32
firstToken := true
for i := 0; i < maxTokens; i++ { for i := 0; i < maxTokens; i++ {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -108,7 +110,12 @@ func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts,
// Stream the token text to the callback // Stream the token text to the callback
if cb != nil { if cb != nil {
tokenText := b.tok.Decode([]int32{nextToken}) tokenText := b.tok.DecodeToken(nextToken)
// Strip the SentencePiece leading space only on the first token
if firstToken {
tokenText = strings.TrimLeft(tokenText, " ")
firstToken = false
}
if err := cb(tokenText); err != nil { if err := cb(tokenText); err != nil {
runtime.GC() runtime.GC()
mlx.ClearCache() mlx.ClearCache()

View file

@ -252,6 +252,7 @@ func (t *Tokenizer) encodeGPT2(text string) []int32 {
} }
// Decode converts token IDs back to text. // Decode converts token IDs back to text.
// For full-sequence decoding, the SentencePiece leading space is stripped.
func (t *Tokenizer) Decode(tokens []int32) string { func (t *Tokenizer) Decode(tokens []int32) string {
var sb strings.Builder var sb strings.Builder
for _, id := range tokens { for _, id := range tokens {
@ -277,6 +278,26 @@ func (t *Tokenizer) Decode(tokens []int32) string {
return result return result
} }
// DecodeToken renders a single token ID as text for streaming output.
// In contrast to Decode, the SentencePiece word-boundary marker is kept
// as a leading space, so tokens emitted one at a time retain correct
// spacing between words.
func (t *Tokenizer) DecodeToken(id int32) string {
	piece, found := t.invVocab[id]
	if !found {
		return ""
	}
	// Special tokens (BOS/EOS/control pieces) produce no visible text.
	if _, special := t.special[piece]; special {
		return ""
	}
	if t.isGPT2BPE {
		return t.decodeGPT2Bytes(piece)
	}
	// SentencePiece: every ▁ becomes a literal space — kept, not stripped,
	// because it marks the word boundary.
	return strings.ReplaceAll(piece, "▁", " ")
}
// decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes. // decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes.
func (t *Tokenizer) decodeGPT2Bytes(s string) string { func (t *Tokenizer) decodeGPT2Bytes(s string) string {
var buf []byte var buf []byte