diff --git a/pkg/mlx/tokenizer/tokenizer.go b/pkg/mlx/tokenizer/tokenizer.go
index 4a1258a..9dd9450 100644
--- a/pkg/mlx/tokenizer/tokenizer.go
+++ b/pkg/mlx/tokenizer/tokenizer.go
@@ -29,10 +29,10 @@ type mergePair struct {
 // tokenizerJSON is the HuggingFace tokenizer.json format.
 type tokenizerJSON struct {
 	Model struct {
-		Type         string          `json:"type"`
-		Vocab        json.RawMessage `json:"vocab"`
-		Merges       []string        `json:"merges"`
-		ByteFallback bool            `json:"byte_fallback"`
+		Type         string          `json:"type"`
+		Vocab        json.RawMessage `json:"vocab"`
+		Merges       json.RawMessage `json:"merges"`
+		ByteFallback bool            `json:"byte_fallback"`
 	} `json:"model"`
 	AddedTokens []struct {
 		ID      int32  `json:"id"`
@@ -69,11 +69,27 @@ func Load(path string) (*Tokenizer, error) {
 		t.invVocab[v] = k
 	}
 
-	// Parse merges
-	for rank, merge := range tj.Model.Merges {
-		parts := strings.SplitN(merge, " ", 2)
-		if len(parts) == 2 {
-			t.merges = append(t.merges, mergePair{a: parts[0], b: parts[1], rank: rank})
+	// Parse merges — supports both ["a b", ...] and [["a","b"], ...] formats
+	if len(tj.Model.Merges) > 0 {
+		// Try array-of-strings first
+		var stringMerges []string
+		if err := json.Unmarshal(tj.Model.Merges, &stringMerges); err == nil {
+			for rank, merge := range stringMerges {
+				parts := strings.SplitN(merge, " ", 2)
+				if len(parts) == 2 {
+					t.merges = append(t.merges, mergePair{a: parts[0], b: parts[1], rank: rank})
+				}
+			}
+		} else {
+			// Try array-of-arrays: [["a","b"], ...]
+			var arrayMerges [][]string
+			if err := json.Unmarshal(tj.Model.Merges, &arrayMerges); err == nil {
+				for rank, pair := range arrayMerges {
+					if len(pair) == 2 {
+						t.merges = append(t.merges, mergePair{a: pair[0], b: pair[1], rank: rank})
+					}
+				}
+			}
 		}
 	}
 