The Decode method strips the SentencePiece leading space from every token, which loses word boundaries during streaming. DecodeToken preserves the space (it represents the word boundary); only the first token of each generation has its leading space stripped. Fixes the Gemma 3 space prefix appearing in chat UI output.

Co-Authored-By: Virgil <virgil@lethean.io>
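Illustrative before/after (hypothetical IDs for the Gemma-style pieces "▁Hello" and "▁world"; the first-token trim happens in the caller's generation loop, not in this file):

    Decode([]int32{idHello}) + Decode([]int32{idWorld})  // "Hello" + "world"   => "Helloworld"
    DecodeToken(idHello) + DecodeToken(idWorld)          // " Hello" + " world" => " Hello world"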
//go:build darwin && arm64

// Package tokenizer provides BPE tokenization for transformer models.
package tokenizer

import (
	"encoding/json"
	"fmt"
	"os"
	"strings"
)

// Tokenizer handles text-to-token and token-to-text conversion.
type Tokenizer struct {
	vocab    map[string]int32
	invVocab map[int32]string
	merges   []mergePair
	special  map[string]int32

	bosToken int32
	eosToken int32

	// GPT-2 byte-level BPE support (used by Qwen, GPT-2, Llama 3, etc.)
	isGPT2BPE   bool
	gpt2Decoder map[rune]byte // Unicode char → original byte
	gpt2Encoder map[byte]rune // original byte → Unicode char
}

type mergePair struct {
	a, b string
	rank int
}

// tokenizerJSON mirrors the HuggingFace tokenizer.json format.
type tokenizerJSON struct {
	Model struct {
		Type         string          `json:"type"`
		Vocab        json.RawMessage `json:"vocab"`
		Merges       json.RawMessage `json:"merges"`
		ByteFallback bool            `json:"byte_fallback"`
	} `json:"model"`
	AddedTokens []struct {
		ID      int32  `json:"id"`
		Content string `json:"content"`
		Special bool   `json:"special"`
	} `json:"added_tokens"`
}

// Load reads a tokenizer.json file and creates a Tokenizer.
func Load(path string) (*Tokenizer, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("tokenizer: read %s: %w", path, err)
	}

	var tj tokenizerJSON
	if err := json.Unmarshal(data, &tj); err != nil {
		return nil, fmt.Errorf("tokenizer: parse: %w", err)
	}

	t := &Tokenizer{
		invVocab: make(map[int32]string),
		special:  make(map[string]int32),
	}

	// Parse vocab.
	if err := json.Unmarshal(tj.Model.Vocab, &t.vocab); err != nil {
		return nil, fmt.Errorf("tokenizer: parse vocab: %w", err)
	}
	for k, v := range t.vocab {
		t.invVocab[v] = k
	}

	// Parse merges — supports both ["a b", ...] and [["a","b"], ...] formats.
	if len(tj.Model.Merges) > 0 {
		var stringMerges []string
		if err := json.Unmarshal(tj.Model.Merges, &stringMerges); err == nil {
			for rank, merge := range stringMerges {
				parts := strings.SplitN(merge, " ", 2)
				if len(parts) == 2 {
					t.merges = append(t.merges, mergePair{a: parts[0], b: parts[1], rank: rank})
				}
			}
		} else {
			var arrayMerges [][]string
			if err := json.Unmarshal(tj.Model.Merges, &arrayMerges); err == nil {
				for rank, pair := range arrayMerges {
					if len(pair) == 2 {
						t.merges = append(t.merges, mergePair{a: pair[0], b: pair[1], rank: rank})
					}
				}
			}
		}
	}

	// Register added tokens; special ones are tracked separately so they can
	// be matched verbatim during encoding and skipped during decoding.
	for _, tok := range tj.AddedTokens {
		if tok.Special {
			t.special[tok.Content] = tok.ID
		}
		t.vocab[tok.Content] = tok.ID
		t.invVocab[tok.ID] = tok.Content
	}

	// Detect GPT-2 byte-level BPE (Qwen, GPT-2, and Llama 3 use Ġ for space).
	if _, ok := t.vocab["Ġ"]; ok {
		t.isGPT2BPE = true
		t.gpt2Decoder, t.gpt2Encoder = buildGPT2ByteMaps()
	}

	// Set BOS/EOS — detect the model family from its special tokens.
	if id, ok := t.special["<bos>"]; ok {
		t.bosToken = id
	}
	if id, ok := t.special["<eos>"]; ok {
		t.eosToken = id
	}
	// Gemma: <end_of_turn> is the generation stop token.
	if id, ok := t.special["<end_of_turn>"]; ok {
		t.eosToken = id
	}
	// Qwen3: <|im_end|> is the generation stop token.
	if id, ok := t.special["<|im_end|>"]; ok {
		t.eosToken = id
	}
	// Qwen3 BOS: <|im_start|>.
	if id, ok := t.special["<|im_start|>"]; ok {
		t.bosToken = id
	}

	return t, nil
}
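
// Usage sketch (illustrative; the path and input are hypothetical):
//
//	tok, err := tokenizer.Load("models/gemma3/tokenizer.json")
//	if err != nil {
//		log.Fatal(err)
//	}
//	ids := tok.Encode("Hello world") // [BOS, ...piece IDs...]
//	text := tok.Decode(ids)          // "Hello world"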

// buildGPT2ByteMaps creates the GPT-2 byte-level BPE encoding/decoding maps.
// GPT-2 maps all 256 bytes to printable Unicode characters to avoid control
// chars in the vocabulary. Printable ASCII + Latin-1 Supplement map to
// themselves; everything else (0–32, 127–160, 173) maps to U+0100 onwards.
func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) {
	encoder = make(map[byte]rune, 256)
	decoder = make(map[rune]byte, 256)

	// Self-mapping ranges: printable ASCII + Latin-1 Supplement.
	// Use an int loop variable to avoid byte overflow at 255.
	selfMap := func(lo, hi int) {
		for b := lo; b <= hi; b++ {
			encoder[byte(b)] = rune(b)
			decoder[rune(b)] = byte(b)
		}
	}
	selfMap(33, 126)  // ! through ~
	selfMap(161, 172) // ¡ through ¬
	selfMap(174, 255) // ® through ÿ

	// Non-self-mapping: control chars, space, DEL, and the gaps above.
	n := 0
	for b := 0; b < 256; b++ {
		if _, ok := encoder[byte(b)]; !ok {
			r := rune(256 + n)
			encoder[byte(b)] = r
			decoder[r] = byte(b)
			n++
		}
	}
	return
}
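
// exampleGPT2ByteMapping is an illustrative sketch (not called by the
// tokenizer itself): bytes 0–32 are the first non-self-mapping bytes, so the
// space byte 0x20 is remapped to rune 256+32 = U+0120, which is 'Ġ' — the
// character the vocab detection in Load looks for.
func exampleGPT2ByteMapping() {
	decoder, encoder := buildGPT2ByteMaps()
	fmt.Printf("space encodes to %c (U+%04X)\n", encoder[' '], encoder[' ']) // Ġ (U+0120)
	fmt.Printf("Ġ decodes to byte %d\n", decoder['Ġ'])                       // 32
}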

// Encode converts text to token IDs, prepending the BOS token.
func (t *Tokenizer) Encode(text string) []int32 {
	if t.isGPT2BPE {
		return t.encodeGPT2(text)
	}

	tokens := []int32{t.bosToken}

	// SentencePiece-style encoding (simplified greedy matching; merges are
	// not applied). Spaces mark word boundaries: the first character of the
	// text, and any character that follows a space, is looked up with the ▁
	// prefix first.
	remaining := text
	wordStart := true
	for remaining != "" {
		// Special tokens match verbatim.
		found := false
		for tok, id := range t.special {
			if strings.HasPrefix(remaining, tok) {
				tokens = append(tokens, id)
				remaining = remaining[len(tok):]
				found = true
				break
			}
		}
		if found {
			continue
		}

		r := []rune(remaining)
		if r[0] == ' ' {
			// The space itself is not emitted; it becomes the ▁ prefix of
			// the next token.
			wordStart = true
			remaining = string(r[1:])
			continue
		}
		if wordStart {
			if id, ok := t.vocab["▁"+string(r[0])]; ok {
				tokens = append(tokens, id)
				wordStart = false
				remaining = string(r[1:])
				continue
			}
		}
		if id, ok := t.vocab[string(r[0])]; ok {
			tokens = append(tokens, id)
		}
		wordStart = false
		remaining = string(r[1:])
	}

	return tokens
}
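
// exampleEncodeSpecial is an illustrative sketch: a special token embedded in
// the input (here Gemma's "<end_of_turn>") is matched verbatim and emitted as
// a single ID instead of being split into per-character pieces.
func exampleEncodeSpecial(t *Tokenizer) []int32 {
	// For "hi<end_of_turn>" the result is [BOS, piece("▁h"), piece("i"),
	// id("<end_of_turn>")], assuming a Gemma-style vocab containing "▁h" and "i".
	return t.Encode("hi<end_of_turn>")
}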

// encodeGPT2 encodes text using GPT-2 byte-level BPE.
func (t *Tokenizer) encodeGPT2(text string) []int32 {
	tokens := []int32{t.bosToken}

	// Convert text bytes to the GPT-2 Unicode representation.
	var encoded strings.Builder
	for _, b := range []byte(text) {
		if r, ok := t.gpt2Encoder[b]; ok {
			encoded.WriteRune(r)
		}
	}
	gpt2Text := encoded.String()

	// Special tokens are stored in their original form, so convert each one
	// to its GPT-2 encoding once, up front, for matching against the text.
	encSpecial := make(map[string]int32, len(t.special))
	for tok, id := range t.special {
		var encTok strings.Builder
		for _, b := range []byte(tok) {
			if r, ok := t.gpt2Encoder[b]; ok {
				encTok.WriteRune(r)
			}
		}
		encSpecial[encTok.String()] = id
	}

	// Scan for special tokens and regular text.
	remaining := gpt2Text
	for remaining != "" {
		found := false
		for encStr, id := range encSpecial {
			if strings.HasPrefix(remaining, encStr) {
				tokens = append(tokens, id)
				remaining = remaining[len(encStr):]
				found = true
				break
			}
		}
		if !found {
			// Character-by-character lookup (simplified BPE; merges are not applied).
			r := []rune(remaining)
			if id, ok := t.vocab[string(r[0])]; ok {
				tokens = append(tokens, id)
			}
			remaining = string(r[1:])
		}
	}

	return tokens
}
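
// Worked example for the special-token re-encoding above: the GPT-2 byte map
// is the identity on printable ASCII (bytes 33–126), so an ASCII-only special
// token such as "<|im_end|>" encodes to itself and is matched verbatim; the
// re-encoding only changes special tokens containing bytes outside that range.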

// Decode converts token IDs back to text.
// For full-sequence decoding, the SentencePiece leading space is stripped.
func (t *Tokenizer) Decode(tokens []int32) string {
	var sb strings.Builder
	for _, id := range tokens {
		if text, ok := t.invVocab[id]; ok {
			// Skip special tokens in decode output.
			if _, isSpecial := t.special[text]; isSpecial {
				continue
			}
			sb.WriteString(text)
		}
	}
	raw := sb.String()

	if t.isGPT2BPE {
		return t.decodeGPT2Bytes(raw)
	}

	// SentencePiece style: ▁ marks word boundaries.
	result := strings.ReplaceAll(raw, "▁", " ")
	return strings.TrimPrefix(result, " ")
}

// DecodeToken converts a single token ID to text for streaming.
// Unlike Decode, it preserves the leading space (the word boundary) so that
// token-by-token output maintains correct spacing between words.
func (t *Tokenizer) DecodeToken(id int32) string {
	text, ok := t.invVocab[id]
	if !ok {
		return ""
	}
	if _, isSpecial := t.special[text]; isSpecial {
		return ""
	}

	if t.isGPT2BPE {
		return t.decodeGPT2Bytes(text)
	}

	// SentencePiece: replace ▁ with a space but keep it — it is the word boundary.
	return strings.ReplaceAll(text, "▁", " ")
}
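
// exampleStreamingDecode is an illustrative caller-side sketch. Per the change
// description, the leading space of the first token of each generation is
// stripped by the caller, not by DecodeToken itself; a generation loop might
// do it like this (hypothetical helper, not part of this package's API):
func exampleStreamingDecode(t *Tokenizer, ids []int32) string {
	var sb strings.Builder
	for i, id := range ids {
		piece := t.DecodeToken(id)
		if i == 0 {
			// First token of the generation: trim the boundary space.
			piece = strings.TrimPrefix(piece, " ")
		}
		sb.WriteString(piece)
	}
	return sb.String()
}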

// decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes.
func (t *Tokenizer) decodeGPT2Bytes(s string) string {
	var buf []byte
	for _, r := range s {
		if b, ok := t.gpt2Decoder[r]; ok {
			buf = append(buf, b)
		} else {
			// Non-mapped runes pass through as UTF-8.
			buf = append(buf, []byte(string(r))...)
		}
	}
	return string(buf)
}
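
// Worked example: the vocab piece "Ġhello" decodes to " hello" — 'Ġ' (U+0120)
// maps back to the space byte 0x20 via the decoder table, and the printable
// ASCII letters map to themselves.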

// BOSToken returns the beginning-of-sequence token ID.
func (t *Tokenizer) BOSToken() int32 { return t.bosToken }

// EOSToken returns the end-of-sequence token ID.
func (t *Tokenizer) EOSToken() int32 { return t.eosToken }

// FormatGemmaPrompt applies the Gemma 3 chat template.
func FormatGemmaPrompt(prompt string) string {
	return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
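
// For example, FormatGemmaPrompt("Why is the sky blue?") produces:
//
//	<start_of_turn>user
//	Why is the sky blue?<end_of_turn>
//	<start_of_turn>model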