From 98749c66f2c1c2572bc5ac40ab63ed5e6cd1f064 Mon Sep 17 00:00:00 2001
From: Snider
Date: Tue, 17 Feb 2026 19:18:27 +0000
Subject: [PATCH] fix(mlx): add DecodeToken for correct streaming word boundaries

The Decode method strips the SentencePiece leading space from every
token, which loses word boundaries during streaming. DecodeToken
preserves the space (it represents the word boundary), and only the
first token of each generation has its leading space stripped.

Fixes Gemma3 space prefix appearing in chat UI output.

Co-Authored-By: Virgil
---
 ml/backend_mlx.go          |  9 ++++++++-
 mlx/tokenizer/tokenizer.go | 21 +++++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/ml/backend_mlx.go b/ml/backend_mlx.go
index 4a0e7d6..8783f9d 100644
--- a/ml/backend_mlx.go
+++ b/ml/backend_mlx.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"log/slog"
 	"runtime"
+	"strings"
 	"sync"
 
 	"forge.lthn.ai/core/go-ai/mlx"
@@ -85,6 +86,7 @@ func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts,
 	}
 
 	var output []int32
+	firstToken := true
 	for i := 0; i < maxTokens; i++ {
 		select {
 		case <-ctx.Done():
@@ -108,7 +110,12 @@ func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts,
 
 		// Stream the token text to the callback
 		if cb != nil {
-			tokenText := b.tok.Decode([]int32{nextToken})
+			tokenText := b.tok.DecodeToken(nextToken)
+			// Strip the SentencePiece leading space only on the first token
+			if firstToken {
+				tokenText = strings.TrimLeft(tokenText, " ")
+				firstToken = false
+			}
 			if err := cb(tokenText); err != nil {
 				runtime.GC()
 				mlx.ClearCache()
diff --git a/mlx/tokenizer/tokenizer.go b/mlx/tokenizer/tokenizer.go
index 3537d6b..947a1a5 100644
--- a/mlx/tokenizer/tokenizer.go
+++ b/mlx/tokenizer/tokenizer.go
@@ -252,6 +252,7 @@ func (t *Tokenizer) encodeGPT2(text string) []int32 {
 }
 
 // Decode converts token IDs back to text.
+// For full-sequence decoding, the SentencePiece leading space is stripped.
 func (t *Tokenizer) Decode(tokens []int32) string {
 	var sb strings.Builder
 	for _, id := range tokens {
@@ -277,6 +278,26 @@ func (t *Tokenizer) Decode(tokens []int32) string {
 	return result
 }
 
+// DecodeToken converts a single token ID to text for streaming.
+// Unlike Decode, it preserves the leading space (word boundary) so that
+// token-by-token output maintains correct spacing between words.
+func (t *Tokenizer) DecodeToken(id int32) string {
+	text, ok := t.invVocab[id]
+	if !ok {
+		return ""
+	}
+	if _, isSpecial := t.special[text]; isSpecial {
+		return ""
+	}
+
+	if t.isGPT2BPE {
+		return t.decodeGPT2Bytes(text)
+	}
+
+	// SentencePiece: replace ▁ with space but keep it (it's the word boundary)
+	return strings.ReplaceAll(text, "▁", " ")
+}
+
 // decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes.
 func (t *Tokenizer) decodeGPT2Bytes(s string) string {
 	var buf []byte
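
Illustrative usage (not part of the patch): a minimal sketch of how a streaming
caller combines DecodeToken with a one-time TrimLeft, mirroring the logic added
to MLXBackend.generate above. The streamTokens helper and the emit callback are
hypothetical names for this example only, and the tokenizer import path is
inferred from the forge.lthn.ai/core/go-ai module path seen in backend_mlx.go.

package example

import (
	"strings"

	// Path is an assumption based on the module path used in backend_mlx.go;
	// adjust if the tokenizer package is imported differently.
	"forge.lthn.ai/core/go-ai/mlx/tokenizer"
)

// streamTokens decodes each sampled token with DecodeToken so the
// SentencePiece word-boundary space survives, and strips that space
// only on the very first token of a generation.
func streamTokens(t *tokenizer.Tokenizer, ids []int32, emit func(string) error) error {
	first := true
	for _, id := range ids {
		text := t.DecodeToken(id) // keeps the leading " " that marks a word boundary
		if first {
			text = strings.TrimLeft(text, " ") // no space wanted before the first word
			first = false
		}
		if err := emit(text); err != nil {
			return err
		}
	}
	return nil
}

Decode remains the right call for whole sequences, where the retained leading
space would otherwise appear as a stray prefix (the Gemma3 symptom this patch fixes).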