fix(mlx): add DecodeToken for correct streaming word boundaries
The Decode method strips the SentencePiece leading space from every token, which loses word boundaries during streaming. DecodeToken preserves the space (it represents the word boundary) and only the first token of each generation has its leading space stripped. Fixes Gemma3 space prefix appearing in chat UI output. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
2d870385f9
commit
98749c66f2
2 changed files with 29 additions and 1 deletions
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/mlx"
|
"forge.lthn.ai/core/go-ai/mlx"
|
||||||
|
|
@ -85,6 +86,7 @@ func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts,
|
||||||
}
|
}
|
||||||
|
|
||||||
var output []int32
|
var output []int32
|
||||||
|
firstToken := true
|
||||||
for i := 0; i < maxTokens; i++ {
|
for i := 0; i < maxTokens; i++ {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
|
@ -108,7 +110,12 @@ func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts,
|
||||||
|
|
||||||
// Stream the token text to the callback
|
// Stream the token text to the callback
|
||||||
if cb != nil {
|
if cb != nil {
|
||||||
tokenText := b.tok.Decode([]int32{nextToken})
|
tokenText := b.tok.DecodeToken(nextToken)
|
||||||
|
// Strip the SentencePiece leading space only on the first token
|
||||||
|
if firstToken {
|
||||||
|
tokenText = strings.TrimLeft(tokenText, " ")
|
||||||
|
firstToken = false
|
||||||
|
}
|
||||||
if err := cb(tokenText); err != nil {
|
if err := cb(tokenText); err != nil {
|
||||||
runtime.GC()
|
runtime.GC()
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
|
|
|
||||||
|
|
@ -252,6 +252,7 @@ func (t *Tokenizer) encodeGPT2(text string) []int32 {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode converts token IDs back to text.
|
// Decode converts token IDs back to text.
|
||||||
|
// For full-sequence decoding, the SentencePiece leading space is stripped.
|
||||||
func (t *Tokenizer) Decode(tokens []int32) string {
|
func (t *Tokenizer) Decode(tokens []int32) string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for _, id := range tokens {
|
for _, id := range tokens {
|
||||||
|
|
@ -277,6 +278,26 @@ func (t *Tokenizer) Decode(tokens []int32) string {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DecodeToken converts a single token ID to text for streaming.
|
||||||
|
// Unlike Decode, it preserves the leading space (word boundary) so that
|
||||||
|
// token-by-token output maintains correct spacing between words.
|
||||||
|
func (t *Tokenizer) DecodeToken(id int32) string {
|
||||||
|
text, ok := t.invVocab[id]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if _, isSpecial := t.special[text]; isSpecial {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.isGPT2BPE {
|
||||||
|
return t.decodeGPT2Bytes(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SentencePiece: replace ▁ with space but keep it (it's the word boundary)
|
||||||
|
return strings.ReplaceAll(text, "▁", " ")
|
||||||
|
}
|
||||||
|
|
||||||
// decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes.
|
// decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes.
|
||||||
func (t *Tokenizer) decodeGPT2Bytes(s string) string {
|
func (t *Tokenizer) decodeGPT2Bytes(s string) string {
|
||||||
var buf []byte
|
var buf []byte
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue