refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64
Remove the manual -tags mlx requirement. MLX is now automatically compiled on darwin/arm64 via build constraints. Stubs remain for other platforms. No functional change. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
d0cbd5065e
commit
92c6282d50
19 changed files with 655 additions and 153 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
package ml
|
package ml
|
||||||
|
|
||||||
|
|
@ -16,9 +16,9 @@ import (
|
||||||
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MLXBackend implements Backend for native Metal inference via mlx-c.
|
// MLXBackend implements Backend and StreamingBackend for native Metal inference.
|
||||||
type MLXBackend struct {
|
type MLXBackend struct {
|
||||||
model *model.GemmaModel
|
model model.Model
|
||||||
tok *tokenizer.Tokenizer
|
tok *tokenizer.Tokenizer
|
||||||
caches []cache.Cache
|
caches []cache.Cache
|
||||||
sampler sample.Sampler
|
sampler sample.Sampler
|
||||||
|
|
@ -26,6 +26,9 @@ type MLXBackend struct {
|
||||||
modelBytes uint64 // model size at load time, for memory budget
|
modelBytes uint64 // model size at load time, for memory budget
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compile-time check that MLXBackend satisfies StreamingBackend.
|
||||||
|
var _ StreamingBackend = (*MLXBackend)(nil)
|
||||||
|
|
||||||
// NewMLXBackend loads a model from a safetensors directory and creates
|
// NewMLXBackend loads a model from a safetensors directory and creates
|
||||||
// a native Metal inference backend.
|
// a native Metal inference backend.
|
||||||
func NewMLXBackend(modelPath string) (*MLXBackend, error) {
|
func NewMLXBackend(modelPath string) (*MLXBackend, error) {
|
||||||
|
|
@ -34,13 +37,12 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("mlx: loading model", "path", modelPath)
|
slog.Info("mlx: loading model", "path", modelPath)
|
||||||
m, err := model.LoadGemma3(modelPath)
|
m, err := model.LoadModel(modelPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("mlx: load model: %w", err)
|
return nil, fmt.Errorf("mlx: load model: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cap Metal memory: cache limit for allocator reuse, memory limit as hard ceiling.
|
// Cap Metal memory: cache limit for allocator reuse, memory limit as hard ceiling.
|
||||||
// This prevents runaway memory growth from killing the system.
|
|
||||||
mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB allocator cache
|
mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB allocator cache
|
||||||
mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap
|
mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap
|
||||||
|
|
||||||
|
|
@ -54,31 +56,27 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) {
|
||||||
model: m,
|
model: m,
|
||||||
tok: m.Tokenizer(),
|
tok: m.Tokenizer(),
|
||||||
caches: m.NewCache(),
|
caches: m.NewCache(),
|
||||||
sampler: sample.New(0.1, 0, 0, 0), // default low temp
|
sampler: sample.New(0.1, 0, 0, 0),
|
||||||
modelBytes: mlx.GetActiveMemory(),
|
modelBytes: mlx.GetActiveMemory(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate produces text from a prompt using native Metal inference.
|
// generate is the core token generation loop. If cb is non-nil, each token's
|
||||||
func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) {
|
// text is sent to it (streaming mode). Returns the full output text.
|
||||||
|
func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts, cb TokenCallback) (string, error) {
|
||||||
b.mu.Lock()
|
b.mu.Lock()
|
||||||
defer b.mu.Unlock()
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
// Reset caches for new generation
|
|
||||||
for _, c := range b.caches {
|
for _, c := range b.caches {
|
||||||
c.Reset()
|
c.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up sampler based on opts
|
|
||||||
temp := float32(opts.Temperature)
|
temp := float32(opts.Temperature)
|
||||||
if temp == 0 {
|
if temp == 0 {
|
||||||
temp = 0.1
|
temp = 0.1
|
||||||
}
|
}
|
||||||
sampler := sample.New(temp, 0, 0, 0)
|
sampler := sample.New(temp, 0, 0, 0)
|
||||||
|
|
||||||
// Tokenize
|
|
||||||
formatted := tokenizer.FormatGemmaPrompt(prompt)
|
|
||||||
tokens := b.tok.Encode(formatted)
|
|
||||||
input := mlx.FromValues(tokens, 1, len(tokens))
|
input := mlx.FromValues(tokens, 1, len(tokens))
|
||||||
|
|
||||||
maxTokens := opts.MaxTokens
|
maxTokens := opts.MaxTokens
|
||||||
|
|
@ -86,8 +84,6 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
|
||||||
maxTokens = 2048
|
maxTokens = 2048
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generation loop — force Go GC every 4 tokens so finalizers release
|
|
||||||
// intermediate C array handles that Go GC cannot see as memory pressure.
|
|
||||||
var output []int32
|
var output []int32
|
||||||
for i := 0; i < maxTokens; i++ {
|
for i := 0; i < maxTokens; i++ {
|
||||||
select {
|
select {
|
||||||
|
|
@ -110,20 +106,58 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
|
||||||
output = append(output, nextToken)
|
output = append(output, nextToken)
|
||||||
input = mlx.FromValues([]int32{nextToken}, 1, 1)
|
input = mlx.FromValues([]int32{nextToken}, 1, 1)
|
||||||
|
|
||||||
// Force GC to collect intermediate arrays + release Metal allocator cache
|
// Stream the token text to the callback
|
||||||
|
if cb != nil {
|
||||||
|
tokenText := b.tok.Decode([]int32{nextToken})
|
||||||
|
if err := cb(tokenText); err != nil {
|
||||||
|
runtime.GC()
|
||||||
|
mlx.ClearCache()
|
||||||
|
return b.tok.Decode(output), err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if i%4 == 3 {
|
if i%4 == 3 {
|
||||||
runtime.GC()
|
runtime.GC()
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cleanup between requests
|
|
||||||
runtime.GC()
|
runtime.GC()
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
b.checkMemory()
|
b.checkMemory()
|
||||||
return b.tok.Decode(output), nil
|
return b.tok.Decode(output), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate produces text from a prompt using native Metal inference.
|
||||||
|
func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) {
|
||||||
|
formatted := formatPrompt(b.model.ModelType(), prompt)
|
||||||
|
tokens := b.tok.Encode(formatted)
|
||||||
|
return b.generate(ctx, tokens, opts, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chat formats messages and generates a response.
|
||||||
|
func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) {
|
||||||
|
prompt := formatChat(b.model.ModelType(), messages)
|
||||||
|
tokens := b.tok.Encode(prompt)
|
||||||
|
return b.generate(ctx, tokens, opts, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateStream streams tokens from a single prompt via the callback.
|
||||||
|
func (b *MLXBackend) GenerateStream(ctx context.Context, prompt string, opts GenOpts, cb TokenCallback) error {
|
||||||
|
formatted := formatPrompt(b.model.ModelType(), prompt)
|
||||||
|
tokens := b.tok.Encode(formatted)
|
||||||
|
_, err := b.generate(ctx, tokens, opts, cb)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatStream streams tokens from a chat conversation via the callback.
|
||||||
|
func (b *MLXBackend) ChatStream(ctx context.Context, messages []Message, opts GenOpts, cb TokenCallback) error {
|
||||||
|
prompt := formatChat(b.model.ModelType(), messages)
|
||||||
|
tokens := b.tok.Encode(prompt)
|
||||||
|
_, err := b.generate(ctx, tokens, opts, cb)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// lastPosition extracts the last sequence position from [B, L, V] logits → [B, V].
|
// lastPosition extracts the last sequence position from [B, L, V] logits → [B, V].
|
||||||
func lastPosition(logits *mlx.Array) *mlx.Array {
|
func lastPosition(logits *mlx.Array) *mlx.Array {
|
||||||
shape := logits.Shape()
|
shape := logits.Shape()
|
||||||
|
|
@ -137,9 +171,49 @@ func lastPosition(logits *mlx.Array) *mlx.Array {
|
||||||
return logits
|
return logits
|
||||||
}
|
}
|
||||||
|
|
||||||
// Chat formats messages and generates a response.
|
// checkMemory logs Metal memory usage and forces cleanup if it exceeds budget.
|
||||||
func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) {
|
func (b *MLXBackend) checkMemory() {
|
||||||
// Format as Gemma chat
|
active := mlx.GetActiveMemory()
|
||||||
|
budget := b.modelBytes * 3
|
||||||
|
if active > budget {
|
||||||
|
slog.Warn("mlx: memory over budget, forcing cleanup",
|
||||||
|
"active_mb", active/1024/1024,
|
||||||
|
"model_mb", b.modelBytes/1024/1024,
|
||||||
|
"peak_mb", mlx.GetPeakMemory()/1024/1024,
|
||||||
|
)
|
||||||
|
runtime.GC()
|
||||||
|
runtime.GC()
|
||||||
|
mlx.ClearCache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the backend identifier.
|
||||||
|
func (b *MLXBackend) Name() string { return "mlx" }
|
||||||
|
|
||||||
|
// Available reports whether Metal GPU is ready.
|
||||||
|
func (b *MLXBackend) Available() bool { return mlx.MetalAvailable() }
|
||||||
|
|
||||||
|
// formatPrompt wraps a raw prompt in the model's chat template for single-turn generation.
|
||||||
|
func formatPrompt(modelType, prompt string) string {
|
||||||
|
switch modelType {
|
||||||
|
case "qwen3":
|
||||||
|
return fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n", prompt)
|
||||||
|
default:
|
||||||
|
return tokenizer.FormatGemmaPrompt(prompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatChat builds a multi-turn chat prompt from messages using the model's template.
|
||||||
|
func formatChat(modelType string, messages []Message) string {
|
||||||
|
switch modelType {
|
||||||
|
case "qwen3":
|
||||||
|
return formatQwen3Chat(messages)
|
||||||
|
default:
|
||||||
|
return formatGemmaChat(messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatGemmaChat(messages []Message) string {
|
||||||
var prompt string
|
var prompt string
|
||||||
for _, msg := range messages {
|
for _, msg := range messages {
|
||||||
switch msg.Role {
|
switch msg.Role {
|
||||||
|
|
@ -152,83 +226,21 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
prompt += "<start_of_turn>model\n"
|
prompt += "<start_of_turn>model\n"
|
||||||
|
return prompt
|
||||||
// Use raw prompt (already formatted)
|
|
||||||
b.mu.Lock()
|
|
||||||
defer b.mu.Unlock()
|
|
||||||
|
|
||||||
for _, c := range b.caches {
|
|
||||||
c.Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
temp := float32(opts.Temperature)
|
|
||||||
if temp == 0 {
|
|
||||||
temp = 0.1
|
|
||||||
}
|
|
||||||
sampler := sample.New(temp, 0, 0, 0)
|
|
||||||
|
|
||||||
tokens := b.tok.Encode(prompt)
|
|
||||||
input := mlx.FromValues(tokens, 1, len(tokens))
|
|
||||||
|
|
||||||
maxTokens := opts.MaxTokens
|
|
||||||
if maxTokens == 0 {
|
|
||||||
maxTokens = 2048
|
|
||||||
}
|
|
||||||
|
|
||||||
var output []int32
|
|
||||||
for i := 0; i < maxTokens; i++ {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
runtime.GC()
|
|
||||||
mlx.ClearCache()
|
|
||||||
return b.tok.Decode(output), ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
logits := b.model.Forward(input, b.caches)
|
|
||||||
logits = lastPosition(logits)
|
|
||||||
next := sampler.Sample(logits)
|
|
||||||
mlx.Materialize(next)
|
|
||||||
|
|
||||||
nextToken := int32(next.Int())
|
|
||||||
if nextToken == b.tok.EOSToken() {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
output = append(output, nextToken)
|
|
||||||
input = mlx.FromValues([]int32{nextToken}, 1, 1)
|
|
||||||
|
|
||||||
// Force GC to collect intermediate arrays + release Metal allocator cache
|
|
||||||
if i%4 == 3 {
|
|
||||||
runtime.GC()
|
|
||||||
mlx.ClearCache()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cleanup between requests
|
|
||||||
runtime.GC()
|
|
||||||
mlx.ClearCache()
|
|
||||||
b.checkMemory()
|
|
||||||
return b.tok.Decode(output), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkMemory logs Metal memory usage and forces cleanup if it exceeds budget.
|
func formatQwen3Chat(messages []Message) string {
|
||||||
func (b *MLXBackend) checkMemory() {
|
var prompt string
|
||||||
active := mlx.GetActiveMemory()
|
for _, msg := range messages {
|
||||||
budget := b.modelBytes * 3 // 3× model size = danger zone
|
switch msg.Role {
|
||||||
if active > budget {
|
case "system":
|
||||||
slog.Warn("mlx: memory over budget, forcing cleanup",
|
prompt += fmt.Sprintf("<|im_start|>system\n%s<|im_end|>\n", msg.Content)
|
||||||
"active_mb", active/1024/1024,
|
case "user":
|
||||||
"model_mb", b.modelBytes/1024/1024,
|
prompt += fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n", msg.Content)
|
||||||
"peak_mb", mlx.GetPeakMemory()/1024/1024,
|
case "assistant":
|
||||||
)
|
prompt += fmt.Sprintf("<|im_start|>assistant\n%s<|im_end|>\n", msg.Content)
|
||||||
runtime.GC()
|
}
|
||||||
runtime.GC() // double GC to run finalizers
|
|
||||||
mlx.ClearCache()
|
|
||||||
}
|
}
|
||||||
|
prompt += "<|im_start|>assistant\n"
|
||||||
|
return prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name returns the backend identifier.
|
|
||||||
func (b *MLXBackend) Name() string { return "mlx" }
|
|
||||||
|
|
||||||
// Available reports whether Metal GPU is ready.
|
|
||||||
func (b *MLXBackend) Available() bool { return mlx.MetalAvailable() }
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
package mlx
|
package mlx
|
||||||
|
|
||||||
|
|
|
||||||
2
mlx/cache/cache.go
vendored
2
mlx/cache/cache.go
vendored
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
// Package cache provides KV cache implementations for transformer inference.
|
// Package cache provides KV cache implementations for transformer inference.
|
||||||
package cache
|
package cache
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
package mlx
|
package mlx
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
package mlx
|
package mlx
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
package mlx
|
package mlx
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
package mlx
|
package mlx
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
|
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
|
||||||
//
|
//
|
||||||
|
|
@ -6,9 +6,9 @@
|
||||||
//
|
//
|
||||||
// cd pkg/mlx && go generate ./...
|
// cd pkg/mlx && go generate ./...
|
||||||
//
|
//
|
||||||
// Build with MLX enabled:
|
// Build (MLX is auto-enabled on darwin/arm64):
|
||||||
//
|
//
|
||||||
// go build -tags mlx -o core .
|
// go build -o core .
|
||||||
package mlx
|
package mlx
|
||||||
|
|
||||||
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
|
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build !(darwin && arm64 && mlx)
|
//go:build !(darwin && arm64)
|
||||||
|
|
||||||
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
|
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
|
||||||
// This stub file is used on non-darwin/non-arm64 platforms or when the
|
// This stub file is used on non-darwin/non-arm64 platforms or when the
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
// Package model provides transformer model architectures for MLX inference.
|
// Package model provides transformer model architectures for MLX inference.
|
||||||
package model
|
package model
|
||||||
|
|
@ -16,12 +16,6 @@ import (
|
||||||
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
// QuantizationConfig holds quantization parameters from config.json.
|
|
||||||
type QuantizationConfig struct {
|
|
||||||
GroupSize int `json:"group_size"`
|
|
||||||
Bits int `json:"bits"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TextConfig holds Gemma 3 text model configuration.
|
// TextConfig holds Gemma 3 text model configuration.
|
||||||
type TextConfig struct {
|
type TextConfig struct {
|
||||||
HiddenSize int32 `json:"hidden_size"`
|
HiddenSize int32 `json:"hidden_size"`
|
||||||
|
|
@ -168,17 +162,6 @@ func parseConfig(data []byte) (*TextConfig, error) {
|
||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolveWeight looks up a weight with optional "language_model." prefix.
|
|
||||||
func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
|
|
||||||
if w, ok := weights[name]; ok {
|
|
||||||
return w
|
|
||||||
}
|
|
||||||
if w, ok := weights["language_model."+name]; ok {
|
|
||||||
return w
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadGemma3 loads a Gemma 3 text model from a directory.
|
// LoadGemma3 loads a Gemma 3 text model from a directory.
|
||||||
func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
||||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||||
|
|
@ -428,3 +411,6 @@ func (m *GemmaModel) NumLayers() int { return len(m.Layers) }
|
||||||
|
|
||||||
// Tokenizer returns the model's tokenizer.
|
// Tokenizer returns the model's tokenizer.
|
||||||
func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
|
func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
|
||||||
|
|
||||||
|
// ModelType returns the architecture identifier.
|
||||||
|
func (m *GemmaModel) ModelType() string { return "gemma3" }
|
||||||
|
|
|
||||||
74
mlx/model/model.go
Normal file
74
mlx/model/model.go
Normal file
|
|
@ -0,0 +1,74 @@
|
||||||
|
//go:build darwin && arm64
|
||||||
|
|
||||||
|
// Package model provides transformer model architectures for MLX inference.
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"forge.lthn.ai/core/go-ai/mlx"
|
||||||
|
"forge.lthn.ai/core/go-ai/mlx/cache"
|
||||||
|
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Model is the common interface for all transformer model architectures.
|
||||||
|
type Model interface {
|
||||||
|
// Forward runs the model forward pass on token IDs with KV caches.
|
||||||
|
Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||||
|
|
||||||
|
// NewCache creates per-layer KV caches for generation.
|
||||||
|
NewCache() []cache.Cache
|
||||||
|
|
||||||
|
// NumLayers returns the number of transformer layers.
|
||||||
|
NumLayers() int
|
||||||
|
|
||||||
|
// Tokenizer returns the model's tokenizer.
|
||||||
|
Tokenizer() *tokenizer.Tokenizer
|
||||||
|
|
||||||
|
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
|
||||||
|
ModelType() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuantizationConfig holds quantization parameters from config.json.
|
||||||
|
type QuantizationConfig struct {
|
||||||
|
GroupSize int `json:"group_size"`
|
||||||
|
Bits int `json:"bits"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveWeight looks up a weight with optional "language_model." prefix.
|
||||||
|
func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
|
||||||
|
if w, ok := weights[name]; ok {
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
if w, ok := weights["language_model."+name]; ok {
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadModel auto-detects the model architecture from config.json and loads it.
|
||||||
|
func LoadModel(modelPath string) (Model, error) {
|
||||||
|
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("model: load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var probe struct {
|
||||||
|
ModelType string `json:"model_type"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &probe); err != nil {
|
||||||
|
return nil, fmt.Errorf("model: parse model_type: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch probe.ModelType {
|
||||||
|
case "qwen3":
|
||||||
|
return LoadQwen3(modelPath)
|
||||||
|
case "gemma3", "gemma2":
|
||||||
|
return LoadGemma3(modelPath)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("model: unsupported architecture %q", probe.ModelType)
|
||||||
|
}
|
||||||
|
}
|
||||||
305
mlx/model/qwen3.go
Normal file
305
mlx/model/qwen3.go
Normal file
|
|
@ -0,0 +1,305 @@
|
||||||
|
//go:build darwin && arm64
|
||||||
|
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"math"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"forge.lthn.ai/core/go-ai/mlx"
|
||||||
|
"forge.lthn.ai/core/go-ai/mlx/cache"
|
||||||
|
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Qwen3Config holds Qwen 3 model configuration.
|
||||||
|
type Qwen3Config struct {
|
||||||
|
HiddenSize int32 `json:"hidden_size"`
|
||||||
|
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||||
|
IntermediateSize int32 `json:"intermediate_size"`
|
||||||
|
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||||
|
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||||
|
HeadDim int32 `json:"head_dim"`
|
||||||
|
VocabSize int32 `json:"vocab_size"`
|
||||||
|
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||||
|
|
||||||
|
Quantization *QuantizationConfig `json:"-"`
|
||||||
|
Scale float32 `json:"-"` // 1/sqrt(head_dim)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Qwen3Model is the Qwen 3 text model.
|
||||||
|
type Qwen3Model struct {
|
||||||
|
EmbedTokens *mlx.Embedding
|
||||||
|
Layers []*Qwen3DecoderLayer
|
||||||
|
Norm *mlx.RMSNormModule
|
||||||
|
Output *mlx.Linear
|
||||||
|
|
||||||
|
Tok *tokenizer.Tokenizer
|
||||||
|
Cfg *Qwen3Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Qwen3DecoderLayer is a single transformer block.
|
||||||
|
// Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add.
|
||||||
|
type Qwen3DecoderLayer struct {
|
||||||
|
InputNorm *mlx.RMSNormModule // Pre-attention norm
|
||||||
|
PostAttnNorm *mlx.RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm)
|
||||||
|
Attention *Qwen3Attention
|
||||||
|
MLP *Qwen3MLP
|
||||||
|
}
|
||||||
|
|
||||||
|
// Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization.
|
||||||
|
type Qwen3Attention struct {
|
||||||
|
QProj *mlx.Linear
|
||||||
|
KProj *mlx.Linear
|
||||||
|
VProj *mlx.Linear
|
||||||
|
OProj *mlx.Linear
|
||||||
|
QNorm *mlx.RMSNormModule
|
||||||
|
KNorm *mlx.RMSNormModule
|
||||||
|
}
|
||||||
|
|
||||||
|
// Qwen3MLP is the SwiGLU feed-forward network: down(silu(gate(x)) * up(x)).
|
||||||
|
type Qwen3MLP struct {
|
||||||
|
GateProj *mlx.Linear
|
||||||
|
UpProj *mlx.Linear
|
||||||
|
DownProj *mlx.Linear
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseQwen3Config(data []byte) (*Qwen3Config, error) {
|
||||||
|
var cfg Qwen3Config
|
||||||
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Top-level quantization
|
||||||
|
var wrapper struct {
|
||||||
|
Quantization *QuantizationConfig `json:"quantization"`
|
||||||
|
}
|
||||||
|
json.Unmarshal(data, &wrapper)
|
||||||
|
cfg.Quantization = wrapper.Quantization
|
||||||
|
|
||||||
|
// Compute scale
|
||||||
|
if cfg.HeadDim == 0 {
|
||||||
|
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||||
|
}
|
||||||
|
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||||
|
|
||||||
|
// Defaults
|
||||||
|
if cfg.RopeTheta == 0 {
|
||||||
|
cfg.RopeTheta = 1000000
|
||||||
|
}
|
||||||
|
if cfg.RMSNormEps == 0 {
|
||||||
|
cfg.RMSNormEps = 1e-6
|
||||||
|
}
|
||||||
|
if cfg.VocabSize == 0 {
|
||||||
|
cfg.VocabSize = 151936
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadQwen3 loads a Qwen 3 model from a safetensors directory.
|
||||||
|
func LoadQwen3(modelPath string) (*Qwen3Model, error) {
|
||||||
|
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("qwen3: load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := parseQwen3Config(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("qwen3: parse config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("qwen3: load tokenizer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load weights from all safetensors files
|
||||||
|
weights := make(map[string]*mlx.Array)
|
||||||
|
matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors"))
|
||||||
|
for _, path := range matches {
|
||||||
|
for name, arr := range mlx.LoadSafetensors(path) {
|
||||||
|
weights[name] = arr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
|
||||||
|
|
||||||
|
// Quantization setup
|
||||||
|
q := cfg.Quantization
|
||||||
|
if q != nil {
|
||||||
|
slog.Info("qwen3: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
|
||||||
|
}
|
||||||
|
linear := func(prefix string) *mlx.Linear {
|
||||||
|
weight := w(prefix + ".weight")
|
||||||
|
scales := w(prefix + ".scales")
|
||||||
|
biases := w(prefix + ".biases")
|
||||||
|
bias := w(prefix + ".bias")
|
||||||
|
if scales != nil && q != nil {
|
||||||
|
return mlx.NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits)
|
||||||
|
}
|
||||||
|
return mlx.NewLinear(weight, bias)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Embedding
|
||||||
|
embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
|
||||||
|
if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
|
||||||
|
embed.Scales = embedScales
|
||||||
|
embed.Biases = w("model.embed_tokens.biases")
|
||||||
|
embed.GroupSize = q.GroupSize
|
||||||
|
embed.Bits = q.Bits
|
||||||
|
}
|
||||||
|
|
||||||
|
m := &Qwen3Model{
|
||||||
|
EmbedTokens: embed,
|
||||||
|
Layers: make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers),
|
||||||
|
Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")},
|
||||||
|
Tok: tok,
|
||||||
|
Cfg: cfg,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||||
|
p := fmt.Sprintf("model.layers.%d", i)
|
||||||
|
m.Layers[i] = &Qwen3DecoderLayer{
|
||||||
|
InputNorm: &mlx.RMSNormModule{Weight: w(p + ".input_layernorm.weight")},
|
||||||
|
PostAttnNorm: &mlx.RMSNormModule{Weight: w(p + ".post_attention_layernorm.weight")},
|
||||||
|
Attention: &Qwen3Attention{
|
||||||
|
QProj: linear(p + ".self_attn.q_proj"),
|
||||||
|
KProj: linear(p + ".self_attn.k_proj"),
|
||||||
|
VProj: linear(p + ".self_attn.v_proj"),
|
||||||
|
OProj: linear(p + ".self_attn.o_proj"),
|
||||||
|
QNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.q_norm.weight")},
|
||||||
|
KNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.k_norm.weight")},
|
||||||
|
},
|
||||||
|
MLP: &Qwen3MLP{
|
||||||
|
GateProj: linear(p + ".mlp.gate_proj"),
|
||||||
|
UpProj: linear(p + ".mlp.up_proj"),
|
||||||
|
DownProj: linear(p + ".mlp.down_proj"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output head — Qwen 3 has tie_word_embeddings=false, so lm_head is separate
|
||||||
|
lmHeadWeight := w("lm_head.weight")
|
||||||
|
if lmHeadWeight != nil {
|
||||||
|
lmHeadScales := w("lm_head.scales")
|
||||||
|
if lmHeadScales != nil && q != nil {
|
||||||
|
m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
|
||||||
|
} else {
|
||||||
|
m.Output = mlx.NewLinear(lmHeadWeight, nil)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
m.Output = m.EmbedTokens.AsLinear()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Materialise all weights onto Metal
|
||||||
|
var allArrays []*mlx.Array
|
||||||
|
for _, a := range weights {
|
||||||
|
allArrays = append(allArrays, a)
|
||||||
|
}
|
||||||
|
mlx.Materialize(allArrays...)
|
||||||
|
|
||||||
|
slog.Info("qwen3: model loaded",
|
||||||
|
"layers", cfg.NumHiddenLayers,
|
||||||
|
"hidden", cfg.HiddenSize,
|
||||||
|
"heads", cfg.NumAttentionHeads,
|
||||||
|
"kv_heads", cfg.NumKeyValueHeads,
|
||||||
|
"head_dim", cfg.HeadDim,
|
||||||
|
"vocab", cfg.VocabSize,
|
||||||
|
)
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward runs the Qwen 3 forward pass.
|
||||||
|
// Unlike Gemma, Qwen does NOT scale embeddings by sqrt(hidden_size).
|
||||||
|
func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||||
|
shape := tokens.Shape()
|
||||||
|
B, L := shape[0], shape[1]
|
||||||
|
|
||||||
|
h := m.EmbedTokens.Forward(tokens)
|
||||||
|
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
h = layer.forward(h, caches[i], B, L, m.Cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Qwen3DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
|
||||||
|
// Pre-attention norm → attention → residual add
|
||||||
|
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
||||||
|
attnOut := l.Attention.forward(normed, c, B, L, cfg)
|
||||||
|
h := mlx.Add(x, attnOut)
|
||||||
|
|
||||||
|
// Pre-MLP norm → MLP → residual add
|
||||||
|
normed = l.PostAttnNorm.Forward(h, cfg.RMSNormEps)
|
||||||
|
mlpOut := l.MLP.forward(normed)
|
||||||
|
return mlx.Add(h, mlpOut)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
|
||||||
|
q := a.QProj.Forward(x)
|
||||||
|
k := a.KProj.Forward(x)
|
||||||
|
v := a.VProj.Forward(x)
|
||||||
|
|
||||||
|
// Reshape to [B, num_heads, L, head_dim]
|
||||||
|
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||||
|
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||||
|
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||||
|
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||||
|
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||||
|
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||||
|
|
||||||
|
// Q/K RMS normalization (Qwen 3 has this)
|
||||||
|
q = a.QNorm.Forward(q, cfg.RMSNormEps)
|
||||||
|
k = a.KNorm.Forward(k, cfg.RMSNormEps)
|
||||||
|
|
||||||
|
// RoPE — single theta for all layers (no sliding window)
|
||||||
|
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||||
|
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||||
|
|
||||||
|
// Update KV cache
|
||||||
|
k, v = c.Update(k, v, int(L))
|
||||||
|
|
||||||
|
// GQA: repeat K/V heads to match Q heads
|
||||||
|
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||||
|
if repeatFactor > 1 {
|
||||||
|
k = mlx.RepeatKV(k, repeatFactor)
|
||||||
|
v = mlx.RepeatKV(v, repeatFactor)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scaled dot-product attention
|
||||||
|
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||||
|
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||||
|
return a.OProj.Forward(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// forward computes SwiGLU: down(silu(gate(x)) * up(x)).
|
||||||
|
func (m *Qwen3MLP) forward(x *mlx.Array) *mlx.Array {
|
||||||
|
gate := mlx.SiLU(m.GateProj.Forward(x))
|
||||||
|
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCache creates per-layer KV caches. Qwen 3 uses global attention only.
|
||||||
|
func (m *Qwen3Model) NewCache() []cache.Cache {
|
||||||
|
caches := make([]cache.Cache, len(m.Layers))
|
||||||
|
for i := range caches {
|
||||||
|
caches[i] = cache.NewKVCache()
|
||||||
|
}
|
||||||
|
return caches
|
||||||
|
}
|
||||||
|
|
||||||
|
// NumLayers returns the number of transformer layers.
|
||||||
|
func (m *Qwen3Model) NumLayers() int { return len(m.Layers) }
|
||||||
|
|
||||||
|
// Tokenizer returns the model's tokenizer.
|
||||||
|
func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
|
||||||
|
|
||||||
|
// ModelType returns the architecture identifier.
|
||||||
|
func (m *Qwen3Model) ModelType() string { return "qwen3" }
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
package mlx
|
package mlx
|
||||||
|
|
||||||
|
|
|
||||||
14
mlx/ops.go
14
mlx/ops.go
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
package mlx
|
package mlx
|
||||||
|
|
||||||
|
|
@ -68,6 +68,18 @@ func Exp(a *Array) *Array {
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sigmoid returns element-wise 1/(1+exp(-a)).
|
||||||
|
func Sigmoid(a *Array) *Array {
|
||||||
|
out := New("SIGMOID", a)
|
||||||
|
C.mlx_sigmoid(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// SiLU returns element-wise x * sigmoid(x) (Swish activation).
|
||||||
|
func SiLU(a *Array) *Array {
|
||||||
|
return Mul(a, Sigmoid(a))
|
||||||
|
}
|
||||||
|
|
||||||
// Tanh returns element-wise tanh(a).
|
// Tanh returns element-wise tanh(a).
|
||||||
func Tanh(a *Array) *Array {
|
func Tanh(a *Array) *Array {
|
||||||
out := New("TANH", a)
|
out := New("TANH", a)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
package mlx
|
package mlx
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
// Package sample provides composable token sampling strategies.
|
// Package sample provides composable token sampling strategies.
|
||||||
package sample
|
package sample
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
package mlx
|
package mlx
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
package mlx
|
package mlx
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
//go:build darwin && arm64 && mlx
|
//go:build darwin && arm64
|
||||||
|
|
||||||
// Package tokenizer provides BPE/SentencePiece tokenization for Gemma models.
|
// Package tokenizer provides BPE tokenization for transformer models.
|
||||||
package tokenizer
|
package tokenizer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
@ -19,6 +19,11 @@ type Tokenizer struct {
|
||||||
|
|
||||||
bosToken int32
|
bosToken int32
|
||||||
eosToken int32
|
eosToken int32
|
||||||
|
|
||||||
|
// GPT-2 byte-level BPE support (used by Qwen, GPT, Llama, etc.)
|
||||||
|
isGPT2BPE bool
|
||||||
|
gpt2Decoder map[rune]byte // Unicode char → original byte
|
||||||
|
gpt2Encoder map[byte]rune // original byte → Unicode char
|
||||||
}
|
}
|
||||||
|
|
||||||
type mergePair struct {
|
type mergePair struct {
|
||||||
|
|
@ -32,7 +37,7 @@ type tokenizerJSON struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Vocab json.RawMessage `json:"vocab"`
|
Vocab json.RawMessage `json:"vocab"`
|
||||||
Merges json.RawMessage `json:"merges"`
|
Merges json.RawMessage `json:"merges"`
|
||||||
ByteFallback bool `json:"byte_fallback"`
|
ByteFallback bool `json:"byte_fallback"`
|
||||||
} `json:"model"`
|
} `json:"model"`
|
||||||
AddedTokens []struct {
|
AddedTokens []struct {
|
||||||
ID int32 `json:"id"`
|
ID int32 `json:"id"`
|
||||||
|
|
@ -71,7 +76,6 @@ func Load(path string) (*Tokenizer, error) {
|
||||||
|
|
||||||
// Parse merges — supports both ["a b", ...] and [["a","b"], ...] formats
|
// Parse merges — supports both ["a b", ...] and [["a","b"], ...] formats
|
||||||
if len(tj.Model.Merges) > 0 {
|
if len(tj.Model.Merges) > 0 {
|
||||||
// Try array-of-strings first
|
|
||||||
var stringMerges []string
|
var stringMerges []string
|
||||||
if err := json.Unmarshal(tj.Model.Merges, &stringMerges); err == nil {
|
if err := json.Unmarshal(tj.Model.Merges, &stringMerges); err == nil {
|
||||||
for rank, merge := range stringMerges {
|
for rank, merge := range stringMerges {
|
||||||
|
|
@ -81,7 +85,6 @@ func Load(path string) (*Tokenizer, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Try array-of-arrays: [["a","b"], ...]
|
|
||||||
var arrayMerges [][]string
|
var arrayMerges [][]string
|
||||||
if err := json.Unmarshal(tj.Model.Merges, &arrayMerges); err == nil {
|
if err := json.Unmarshal(tj.Model.Merges, &arrayMerges); err == nil {
|
||||||
for rank, pair := range arrayMerges {
|
for rank, pair := range arrayMerges {
|
||||||
|
|
@ -102,37 +105,77 @@ func Load(path string) (*Tokenizer, error) {
|
||||||
t.invVocab[tok.ID] = tok.Content
|
t.invVocab[tok.ID] = tok.Content
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set BOS/EOS
|
// Detect GPT-2 byte-level BPE (Qwen, GPT, Llama use Ġ for space)
|
||||||
|
if _, ok := t.vocab["Ġ"]; ok {
|
||||||
|
t.isGPT2BPE = true
|
||||||
|
t.gpt2Decoder, t.gpt2Encoder = buildGPT2ByteMaps()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set BOS/EOS — detect model family from special tokens
|
||||||
if id, ok := t.special["<bos>"]; ok {
|
if id, ok := t.special["<bos>"]; ok {
|
||||||
t.bosToken = id
|
t.bosToken = id
|
||||||
}
|
}
|
||||||
if id, ok := t.special["<eos>"]; ok {
|
if id, ok := t.special["<eos>"]; ok {
|
||||||
t.eosToken = id
|
t.eosToken = id
|
||||||
}
|
}
|
||||||
|
// Gemma: <end_of_turn> is the generation stop token
|
||||||
if id, ok := t.special["<end_of_turn>"]; ok {
|
if id, ok := t.special["<end_of_turn>"]; ok {
|
||||||
t.eosToken = id // Gemma uses end_of_turn as EOS
|
t.eosToken = id
|
||||||
|
}
|
||||||
|
// Qwen3: <|im_end|> is the generation stop token
|
||||||
|
if id, ok := t.special["<|im_end|>"]; ok {
|
||||||
|
t.eosToken = id
|
||||||
|
}
|
||||||
|
// Qwen3 BOS: <|im_start|>
|
||||||
|
if id, ok := t.special["<|im_start|>"]; ok {
|
||||||
|
t.bosToken = id
|
||||||
}
|
}
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buildGPT2ByteMaps creates the GPT-2 byte-level BPE encoding/decoding maps.
|
||||||
|
// GPT-2 maps all 256 bytes to printable Unicode characters to avoid control chars
|
||||||
|
// in the vocabulary. Printable ASCII + Latin-1 Supplement map to themselves;
|
||||||
|
// everything else (0-32, 127-160, 173) maps to U+0100 onwards.
|
||||||
|
func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) {
|
||||||
|
encoder = make(map[byte]rune, 256)
|
||||||
|
decoder = make(map[rune]byte, 256)
|
||||||
|
|
||||||
|
// Self-mapping ranges: printable ASCII + Latin-1 Supplement
|
||||||
|
// Use int loop variable to avoid byte overflow at 255.
|
||||||
|
selfMap := func(lo, hi int) {
|
||||||
|
for b := lo; b <= hi; b++ {
|
||||||
|
encoder[byte(b)] = rune(b)
|
||||||
|
decoder[rune(b)] = byte(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
selfMap(33, 126) // ! through ~
|
||||||
|
selfMap(161, 172) // ¡ through ¬
|
||||||
|
selfMap(174, 255) // ® through ÿ
|
||||||
|
|
||||||
|
// Non-self-mapping: control chars, space, DEL, and gaps
|
||||||
|
n := 0
|
||||||
|
for b := 0; b < 256; b++ {
|
||||||
|
if _, ok := encoder[byte(b)]; !ok {
|
||||||
|
r := rune(256 + n)
|
||||||
|
encoder[byte(b)] = r
|
||||||
|
decoder[r] = byte(b)
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Encode converts text to token IDs. Prepends BOS token.
|
// Encode converts text to token IDs. Prepends BOS token.
|
||||||
func (t *Tokenizer) Encode(text string) []int32 {
|
func (t *Tokenizer) Encode(text string) []int32 {
|
||||||
tokens := []int32{t.bosToken}
|
tokens := []int32{t.bosToken}
|
||||||
|
|
||||||
// Simple BPE encoding — split into characters then merge
|
if t.isGPT2BPE {
|
||||||
// This is a simplified version. Full implementation handles
|
return t.encodeGPT2(text)
|
||||||
// Unicode, byte fallback, and efficient BPE merging.
|
|
||||||
chars := []string{}
|
|
||||||
for _, r := range text {
|
|
||||||
s := string(r)
|
|
||||||
if s == " " {
|
|
||||||
s = "▁" // SentencePiece space marker
|
|
||||||
}
|
|
||||||
chars = append(chars, s)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for special tokens first
|
// SentencePiece style encoding
|
||||||
remaining := text
|
remaining := text
|
||||||
for remaining != "" {
|
for remaining != "" {
|
||||||
found := false
|
found := false
|
||||||
|
|
@ -145,7 +188,6 @@ func (t *Tokenizer) Encode(text string) []int32 {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
// Encode character by character (simplified BPE)
|
|
||||||
r := []rune(remaining)
|
r := []rune(remaining)
|
||||||
ch := "▁" + string(r[0])
|
ch := "▁" + string(r[0])
|
||||||
if id, ok := t.vocab[ch]; ok {
|
if id, ok := t.vocab[ch]; ok {
|
||||||
|
|
@ -160,24 +202,95 @@ func (t *Tokenizer) Encode(text string) []int32 {
|
||||||
return tokens
|
return tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// encodeGPT2 encodes text using GPT-2 byte-level BPE.
|
||||||
|
func (t *Tokenizer) encodeGPT2(text string) []int32 {
|
||||||
|
tokens := []int32{t.bosToken}
|
||||||
|
|
||||||
|
// Convert text bytes to GPT-2 Unicode representation
|
||||||
|
var encoded strings.Builder
|
||||||
|
for _, b := range []byte(text) {
|
||||||
|
if r, ok := t.gpt2Encoder[b]; ok {
|
||||||
|
encoded.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
gpt2Text := encoded.String()
|
||||||
|
|
||||||
|
// Scan for special tokens and regular text
|
||||||
|
remaining := gpt2Text
|
||||||
|
for remaining != "" {
|
||||||
|
// Check special tokens (these are stored as-is, not byte-encoded)
|
||||||
|
found := false
|
||||||
|
for tok, id := range t.special {
|
||||||
|
// Special tokens in GPT-2 tokenizers are stored in their original form
|
||||||
|
// Convert the special token to GPT-2 encoding for matching
|
||||||
|
var encTok strings.Builder
|
||||||
|
for _, b := range []byte(tok) {
|
||||||
|
if r, ok := t.gpt2Encoder[b]; ok {
|
||||||
|
encTok.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
encStr := encTok.String()
|
||||||
|
if strings.HasPrefix(remaining, encStr) {
|
||||||
|
tokens = append(tokens, id)
|
||||||
|
remaining = remaining[len(encStr):]
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
// Character-by-character lookup (simplified BPE)
|
||||||
|
r := []rune(remaining)
|
||||||
|
ch := string(r[0])
|
||||||
|
if id, ok := t.vocab[ch]; ok {
|
||||||
|
tokens = append(tokens, id)
|
||||||
|
}
|
||||||
|
remaining = string(r[1:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
// Decode converts token IDs back to text.
|
// Decode converts token IDs back to text.
|
||||||
func (t *Tokenizer) Decode(tokens []int32) string {
|
func (t *Tokenizer) Decode(tokens []int32) string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for _, id := range tokens {
|
for _, id := range tokens {
|
||||||
if text, ok := t.invVocab[id]; ok {
|
if text, ok := t.invVocab[id]; ok {
|
||||||
// Replace SentencePiece space marker
|
// Skip special tokens in decode output
|
||||||
text = strings.ReplaceAll(text, "▁", " ")
|
if _, isSpecial := t.special[text]; isSpecial {
|
||||||
|
continue
|
||||||
|
}
|
||||||
sb.WriteString(text)
|
sb.WriteString(text)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result := sb.String()
|
raw := sb.String()
|
||||||
// Trim leading space from SentencePiece encoding
|
|
||||||
|
if t.isGPT2BPE {
|
||||||
|
return t.decodeGPT2Bytes(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SentencePiece style
|
||||||
|
result := strings.ReplaceAll(raw, "▁", " ")
|
||||||
if strings.HasPrefix(result, " ") {
|
if strings.HasPrefix(result, " ") {
|
||||||
result = result[1:]
|
result = result[1:]
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes.
|
||||||
|
func (t *Tokenizer) decodeGPT2Bytes(s string) string {
|
||||||
|
var buf []byte
|
||||||
|
for _, r := range s {
|
||||||
|
if b, ok := t.gpt2Decoder[r]; ok {
|
||||||
|
buf = append(buf, b)
|
||||||
|
} else {
|
||||||
|
// Non-mapped runes pass through as UTF-8
|
||||||
|
buf = append(buf, []byte(string(r))...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(buf)
|
||||||
|
}
|
||||||
|
|
||||||
// BOSToken returns the beginning-of-sequence token ID.
|
// BOSToken returns the beginning-of-sequence token ID.
|
||||||
func (t *Tokenizer) BOSToken() int32 { return t.bosToken }
|
func (t *Tokenizer) BOSToken() int32 { return t.bosToken }
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue