diff --git a/ml/backend_mlx.go b/ml/backend_mlx.go index 4a0e7d6..8783f9d 100644 --- a/ml/backend_mlx.go +++ b/ml/backend_mlx.go @@ -7,6 +7,7 @@ import ( "fmt" "log/slog" "runtime" + "strings" "sync" "forge.lthn.ai/core/go-ai/mlx" @@ -85,6 +86,7 @@ func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts, } var output []int32 + firstToken := true for i := 0; i < maxTokens; i++ { select { case <-ctx.Done(): @@ -108,7 +110,12 @@ func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts, // Stream the token text to the callback if cb != nil { - tokenText := b.tok.Decode([]int32{nextToken}) + tokenText := b.tok.DecodeToken(nextToken) + // Strip the SentencePiece leading space only on the first token + if firstToken { + tokenText = strings.TrimLeft(tokenText, " ") + firstToken = false + } if err := cb(tokenText); err != nil { runtime.GC() mlx.ClearCache() diff --git a/mlx/tokenizer/tokenizer.go b/mlx/tokenizer/tokenizer.go index 3537d6b..947a1a5 100644 --- a/mlx/tokenizer/tokenizer.go +++ b/mlx/tokenizer/tokenizer.go @@ -252,6 +252,7 @@ func (t *Tokenizer) encodeGPT2(text string) []int32 { } // Decode converts token IDs back to text. +// For full-sequence decoding, the SentencePiece leading space is stripped. func (t *Tokenizer) Decode(tokens []int32) string { var sb strings.Builder for _, id := range tokens { @@ -277,6 +278,26 @@ func (t *Tokenizer) Decode(tokens []int32) string { return result } +// DecodeToken converts a single token ID to text for streaming. +// Unlike Decode, it preserves the leading space (word boundary) so that +// token-by-token output maintains correct spacing between words. +func (t *Tokenizer) DecodeToken(id int32) string { + text, ok := t.invVocab[id] + if !ok { + return "" + } + if _, isSpecial := t.special[text]; isSpecial { + return "" + } + + if t.isGPT2BPE { + return t.decodeGPT2Bytes(text) + } + + // SentencePiece: replace ▁ with space but keep it (it's the word boundary) + return strings.ReplaceAll(text, "▁", " ") +} + // decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes. func (t *Tokenizer) decodeGPT2Bytes(s string) string { var buf []byte