refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64

Remove the manual -tags mlx requirement. MLX is now automatically
compiled on darwin/arm64 via build constraints. Stubs remain for
other platforms. No functional change.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-17 16:57:41 +00:00
parent d0cbd5065e
commit 92c6282d50
19 changed files with 655 additions and 153 deletions

View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
package ml package ml
@ -16,9 +16,9 @@ import (
"forge.lthn.ai/core/go-ai/mlx/tokenizer" "forge.lthn.ai/core/go-ai/mlx/tokenizer"
) )
// MLXBackend implements Backend for native Metal inference via mlx-c. // MLXBackend implements Backend and StreamingBackend for native Metal inference.
type MLXBackend struct { type MLXBackend struct {
model *model.GemmaModel model model.Model
tok *tokenizer.Tokenizer tok *tokenizer.Tokenizer
caches []cache.Cache caches []cache.Cache
sampler sample.Sampler sampler sample.Sampler
@ -26,6 +26,9 @@ type MLXBackend struct {
modelBytes uint64 // model size at load time, for memory budget modelBytes uint64 // model size at load time, for memory budget
} }
// Compile-time check that MLXBackend satisfies StreamingBackend.
var _ StreamingBackend = (*MLXBackend)(nil)
// NewMLXBackend loads a model from a safetensors directory and creates // NewMLXBackend loads a model from a safetensors directory and creates
// a native Metal inference backend. // a native Metal inference backend.
func NewMLXBackend(modelPath string) (*MLXBackend, error) { func NewMLXBackend(modelPath string) (*MLXBackend, error) {
@ -34,13 +37,12 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) {
} }
slog.Info("mlx: loading model", "path", modelPath) slog.Info("mlx: loading model", "path", modelPath)
m, err := model.LoadGemma3(modelPath) m, err := model.LoadModel(modelPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("mlx: load model: %w", err) return nil, fmt.Errorf("mlx: load model: %w", err)
} }
// Cap Metal memory: cache limit for allocator reuse, memory limit as hard ceiling. // Cap Metal memory: cache limit for allocator reuse, memory limit as hard ceiling.
// This prevents runaway memory growth from killing the system.
mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB allocator cache mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB allocator cache
mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap
@ -54,31 +56,27 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) {
model: m, model: m,
tok: m.Tokenizer(), tok: m.Tokenizer(),
caches: m.NewCache(), caches: m.NewCache(),
sampler: sample.New(0.1, 0, 0, 0), // default low temp sampler: sample.New(0.1, 0, 0, 0),
modelBytes: mlx.GetActiveMemory(), modelBytes: mlx.GetActiveMemory(),
}, nil }, nil
} }
// Generate produces text from a prompt using native Metal inference. // generate is the core token generation loop. If cb is non-nil, each token's
func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) { // text is sent to it (streaming mode). Returns the full output text.
func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts, cb TokenCallback) (string, error) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
// Reset caches for new generation
for _, c := range b.caches { for _, c := range b.caches {
c.Reset() c.Reset()
} }
// Set up sampler based on opts
temp := float32(opts.Temperature) temp := float32(opts.Temperature)
if temp == 0 { if temp == 0 {
temp = 0.1 temp = 0.1
} }
sampler := sample.New(temp, 0, 0, 0) sampler := sample.New(temp, 0, 0, 0)
// Tokenize
formatted := tokenizer.FormatGemmaPrompt(prompt)
tokens := b.tok.Encode(formatted)
input := mlx.FromValues(tokens, 1, len(tokens)) input := mlx.FromValues(tokens, 1, len(tokens))
maxTokens := opts.MaxTokens maxTokens := opts.MaxTokens
@ -86,8 +84,6 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
maxTokens = 2048 maxTokens = 2048
} }
// Generation loop — force Go GC every 4 tokens so finalizers release
// intermediate C array handles that Go GC cannot see as memory pressure.
var output []int32 var output []int32
for i := 0; i < maxTokens; i++ { for i := 0; i < maxTokens; i++ {
select { select {
@ -110,20 +106,58 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
output = append(output, nextToken) output = append(output, nextToken)
input = mlx.FromValues([]int32{nextToken}, 1, 1) input = mlx.FromValues([]int32{nextToken}, 1, 1)
// Force GC to collect intermediate arrays + release Metal allocator cache // Stream the token text to the callback
if cb != nil {
tokenText := b.tok.Decode([]int32{nextToken})
if err := cb(tokenText); err != nil {
runtime.GC()
mlx.ClearCache()
return b.tok.Decode(output), err
}
}
if i%4 == 3 { if i%4 == 3 {
runtime.GC() runtime.GC()
mlx.ClearCache() mlx.ClearCache()
} }
} }
// Cleanup between requests
runtime.GC() runtime.GC()
mlx.ClearCache() mlx.ClearCache()
b.checkMemory() b.checkMemory()
return b.tok.Decode(output), nil return b.tok.Decode(output), nil
} }
// Generate produces text from a prompt using native Metal inference.
func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) {
formatted := formatPrompt(b.model.ModelType(), prompt)
tokens := b.tok.Encode(formatted)
return b.generate(ctx, tokens, opts, nil)
}
// Chat formats messages and generates a response.
func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) {
prompt := formatChat(b.model.ModelType(), messages)
tokens := b.tok.Encode(prompt)
return b.generate(ctx, tokens, opts, nil)
}
// GenerateStream streams tokens from a single prompt via the callback.
func (b *MLXBackend) GenerateStream(ctx context.Context, prompt string, opts GenOpts, cb TokenCallback) error {
formatted := formatPrompt(b.model.ModelType(), prompt)
tokens := b.tok.Encode(formatted)
_, err := b.generate(ctx, tokens, opts, cb)
return err
}
// ChatStream streams tokens from a chat conversation via the callback.
func (b *MLXBackend) ChatStream(ctx context.Context, messages []Message, opts GenOpts, cb TokenCallback) error {
prompt := formatChat(b.model.ModelType(), messages)
tokens := b.tok.Encode(prompt)
_, err := b.generate(ctx, tokens, opts, cb)
return err
}
// lastPosition extracts the last sequence position from [B, L, V] logits → [B, V]. // lastPosition extracts the last sequence position from [B, L, V] logits → [B, V].
func lastPosition(logits *mlx.Array) *mlx.Array { func lastPosition(logits *mlx.Array) *mlx.Array {
shape := logits.Shape() shape := logits.Shape()
@ -137,9 +171,49 @@ func lastPosition(logits *mlx.Array) *mlx.Array {
return logits return logits
} }
// Chat formats messages and generates a response. // checkMemory logs Metal memory usage and forces cleanup if it exceeds budget.
func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) { func (b *MLXBackend) checkMemory() {
// Format as Gemma chat active := mlx.GetActiveMemory()
budget := b.modelBytes * 3
if active > budget {
slog.Warn("mlx: memory over budget, forcing cleanup",
"active_mb", active/1024/1024,
"model_mb", b.modelBytes/1024/1024,
"peak_mb", mlx.GetPeakMemory()/1024/1024,
)
runtime.GC()
runtime.GC()
mlx.ClearCache()
}
}
// Name returns the backend identifier.
func (b *MLXBackend) Name() string { return "mlx" }
// Available reports whether Metal GPU is ready.
func (b *MLXBackend) Available() bool { return mlx.MetalAvailable() }
// formatPrompt wraps a raw prompt in the model's chat template for single-turn generation.
func formatPrompt(modelType, prompt string) string {
switch modelType {
case "qwen3":
return fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n", prompt)
default:
return tokenizer.FormatGemmaPrompt(prompt)
}
}
// formatChat builds a multi-turn chat prompt from messages using the model's template.
func formatChat(modelType string, messages []Message) string {
switch modelType {
case "qwen3":
return formatQwen3Chat(messages)
default:
return formatGemmaChat(messages)
}
}
func formatGemmaChat(messages []Message) string {
var prompt string var prompt string
for _, msg := range messages { for _, msg := range messages {
switch msg.Role { switch msg.Role {
@ -152,83 +226,21 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
} }
} }
prompt += "<start_of_turn>model\n" prompt += "<start_of_turn>model\n"
return prompt
// Use raw prompt (already formatted)
b.mu.Lock()
defer b.mu.Unlock()
for _, c := range b.caches {
c.Reset()
}
temp := float32(opts.Temperature)
if temp == 0 {
temp = 0.1
}
sampler := sample.New(temp, 0, 0, 0)
tokens := b.tok.Encode(prompt)
input := mlx.FromValues(tokens, 1, len(tokens))
maxTokens := opts.MaxTokens
if maxTokens == 0 {
maxTokens = 2048
}
var output []int32
for i := 0; i < maxTokens; i++ {
select {
case <-ctx.Done():
runtime.GC()
mlx.ClearCache()
return b.tok.Decode(output), ctx.Err()
default:
}
logits := b.model.Forward(input, b.caches)
logits = lastPosition(logits)
next := sampler.Sample(logits)
mlx.Materialize(next)
nextToken := int32(next.Int())
if nextToken == b.tok.EOSToken() {
break
}
output = append(output, nextToken)
input = mlx.FromValues([]int32{nextToken}, 1, 1)
// Force GC to collect intermediate arrays + release Metal allocator cache
if i%4 == 3 {
runtime.GC()
mlx.ClearCache()
}
}
// Cleanup between requests
runtime.GC()
mlx.ClearCache()
b.checkMemory()
return b.tok.Decode(output), nil
} }
// checkMemory logs Metal memory usage and forces cleanup if it exceeds budget. func formatQwen3Chat(messages []Message) string {
func (b *MLXBackend) checkMemory() { var prompt string
active := mlx.GetActiveMemory() for _, msg := range messages {
budget := b.modelBytes * 3 // 3× model size = danger zone switch msg.Role {
if active > budget { case "system":
slog.Warn("mlx: memory over budget, forcing cleanup", prompt += fmt.Sprintf("<|im_start|>system\n%s<|im_end|>\n", msg.Content)
"active_mb", active/1024/1024, case "user":
"model_mb", b.modelBytes/1024/1024, prompt += fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n", msg.Content)
"peak_mb", mlx.GetPeakMemory()/1024/1024, case "assistant":
) prompt += fmt.Sprintf("<|im_start|>assistant\n%s<|im_end|>\n", msg.Content)
runtime.GC() }
runtime.GC() // double GC to run finalizers
mlx.ClearCache()
} }
prompt += "<|im_start|>assistant\n"
return prompt
} }
// Name returns the backend identifier.
func (b *MLXBackend) Name() string { return "mlx" }
// Available reports whether Metal GPU is ready.
func (b *MLXBackend) Available() bool { return mlx.MetalAvailable() }

View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
package mlx package mlx

2
mlx/cache/cache.go vendored
View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
// Package cache provides KV cache implementations for transformer inference. // Package cache provides KV cache implementations for transformer inference.
package cache package cache

View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
package mlx package mlx

View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
package mlx package mlx

View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
package mlx package mlx

View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
package mlx package mlx

View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c. // Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
// //
@ -6,9 +6,9 @@
// //
// cd pkg/mlx && go generate ./... // cd pkg/mlx && go generate ./...
// //
// Build with MLX enabled: // Build (MLX is auto-enabled on darwin/arm64):
// //
// go build -tags mlx -o core . // go build -o core .
package mlx package mlx
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release //go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release

View file

@ -1,4 +1,4 @@
//go:build !(darwin && arm64 && mlx) //go:build !(darwin && arm64)
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c. // Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
// This stub file is used on non-darwin/non-arm64 platforms or when the // This stub file is used on non-darwin/non-arm64 platforms or when the

View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
// Package model provides transformer model architectures for MLX inference. // Package model provides transformer model architectures for MLX inference.
package model package model
@ -16,12 +16,6 @@ import (
"forge.lthn.ai/core/go-ai/mlx/tokenizer" "forge.lthn.ai/core/go-ai/mlx/tokenizer"
) )
// QuantizationConfig holds quantization parameters from config.json.
type QuantizationConfig struct {
GroupSize int `json:"group_size"`
Bits int `json:"bits"`
}
// TextConfig holds Gemma 3 text model configuration. // TextConfig holds Gemma 3 text model configuration.
type TextConfig struct { type TextConfig struct {
HiddenSize int32 `json:"hidden_size"` HiddenSize int32 `json:"hidden_size"`
@ -168,17 +162,6 @@ func parseConfig(data []byte) (*TextConfig, error) {
return &cfg, nil return &cfg, nil
} }
// resolveWeight looks up a weight with optional "language_model." prefix.
func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
if w, ok := weights[name]; ok {
return w
}
if w, ok := weights["language_model."+name]; ok {
return w
}
return nil
}
// LoadGemma3 loads a Gemma 3 text model from a directory. // LoadGemma3 loads a Gemma 3 text model from a directory.
func LoadGemma3(modelPath string) (*GemmaModel, error) { func LoadGemma3(modelPath string) (*GemmaModel, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
@ -428,3 +411,6 @@ func (m *GemmaModel) NumLayers() int { return len(m.Layers) }
// Tokenizer returns the model's tokenizer. // Tokenizer returns the model's tokenizer.
func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok } func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
// ModelType returns the architecture identifier.
func (m *GemmaModel) ModelType() string { return "gemma3" }

74
mlx/model/model.go Normal file
View file

@ -0,0 +1,74 @@
//go:build darwin && arm64
// Package model provides transformer model architectures for MLX inference.
package model
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"forge.lthn.ai/core/go-ai/mlx"
"forge.lthn.ai/core/go-ai/mlx/cache"
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
)
// Model is the common interface for all transformer model architectures.
type Model interface {
// Forward runs the model forward pass on token IDs with KV caches.
Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
// NewCache creates per-layer KV caches for generation.
NewCache() []cache.Cache
// NumLayers returns the number of transformer layers.
NumLayers() int
// Tokenizer returns the model's tokenizer.
Tokenizer() *tokenizer.Tokenizer
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
ModelType() string
}
// QuantizationConfig holds quantization parameters from config.json.
type QuantizationConfig struct {
GroupSize int `json:"group_size"`
Bits int `json:"bits"`
}
// resolveWeight looks up a weight with optional "language_model." prefix.
func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
if w, ok := weights[name]; ok {
return w
}
if w, ok := weights["language_model."+name]; ok {
return w
}
return nil
}
// LoadModel auto-detects the model architecture from config.json and loads it.
func LoadModel(modelPath string) (Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("model: load config: %w", err)
}
var probe struct {
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &probe); err != nil {
return nil, fmt.Errorf("model: parse model_type: %w", err)
}
switch probe.ModelType {
case "qwen3":
return LoadQwen3(modelPath)
case "gemma3", "gemma2":
return LoadGemma3(modelPath)
default:
return nil, fmt.Errorf("model: unsupported architecture %q", probe.ModelType)
}
}

305
mlx/model/qwen3.go Normal file
View file

@ -0,0 +1,305 @@
//go:build darwin && arm64
package model
import (
"encoding/json"
"fmt"
"log/slog"
"math"
"os"
"path/filepath"
"forge.lthn.ai/core/go-ai/mlx"
"forge.lthn.ai/core/go-ai/mlx/cache"
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
)
// Qwen3Config holds Qwen 3 model configuration.
type Qwen3Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
HeadDim int32 `json:"head_dim"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
Quantization *QuantizationConfig `json:"-"`
Scale float32 `json:"-"` // 1/sqrt(head_dim)
}
// Qwen3Model is the Qwen 3 text model.
type Qwen3Model struct {
EmbedTokens *mlx.Embedding
Layers []*Qwen3DecoderLayer
Norm *mlx.RMSNormModule
Output *mlx.Linear
Tok *tokenizer.Tokenizer
Cfg *Qwen3Config
}
// Qwen3DecoderLayer is a single transformer block.
// Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add.
type Qwen3DecoderLayer struct {
InputNorm *mlx.RMSNormModule // Pre-attention norm
PostAttnNorm *mlx.RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm)
Attention *Qwen3Attention
MLP *Qwen3MLP
}
// Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization.
type Qwen3Attention struct {
QProj *mlx.Linear
KProj *mlx.Linear
VProj *mlx.Linear
OProj *mlx.Linear
QNorm *mlx.RMSNormModule
KNorm *mlx.RMSNormModule
}
// Qwen3MLP is the SwiGLU feed-forward network: down(silu(gate(x)) * up(x)).
type Qwen3MLP struct {
GateProj *mlx.Linear
UpProj *mlx.Linear
DownProj *mlx.Linear
}
func parseQwen3Config(data []byte) (*Qwen3Config, error) {
var cfg Qwen3Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, err
}
// Top-level quantization
var wrapper struct {
Quantization *QuantizationConfig `json:"quantization"`
}
json.Unmarshal(data, &wrapper)
cfg.Quantization = wrapper.Quantization
// Compute scale
if cfg.HeadDim == 0 {
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
// Defaults
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 1000000
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
if cfg.VocabSize == 0 {
cfg.VocabSize = 151936
}
return &cfg, nil
}
// LoadQwen3 loads a Qwen 3 model from a safetensors directory.
func LoadQwen3(modelPath string) (*Qwen3Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("qwen3: load config: %w", err)
}
cfg, err := parseQwen3Config(data)
if err != nil {
return nil, fmt.Errorf("qwen3: parse config: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("qwen3: load tokenizer: %w", err)
}
// Load weights from all safetensors files
weights := make(map[string]*mlx.Array)
matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors"))
for _, path := range matches {
for name, arr := range mlx.LoadSafetensors(path) {
weights[name] = arr
}
}
w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
// Quantization setup
q := cfg.Quantization
if q != nil {
slog.Info("qwen3: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
}
linear := func(prefix string) *mlx.Linear {
weight := w(prefix + ".weight")
scales := w(prefix + ".scales")
biases := w(prefix + ".biases")
bias := w(prefix + ".bias")
if scales != nil && q != nil {
return mlx.NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits)
}
return mlx.NewLinear(weight, bias)
}
// Embedding
embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
embed.Scales = embedScales
embed.Biases = w("model.embed_tokens.biases")
embed.GroupSize = q.GroupSize
embed.Bits = q.Bits
}
m := &Qwen3Model{
EmbedTokens: embed,
Layers: make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers),
Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")},
Tok: tok,
Cfg: cfg,
}
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
p := fmt.Sprintf("model.layers.%d", i)
m.Layers[i] = &Qwen3DecoderLayer{
InputNorm: &mlx.RMSNormModule{Weight: w(p + ".input_layernorm.weight")},
PostAttnNorm: &mlx.RMSNormModule{Weight: w(p + ".post_attention_layernorm.weight")},
Attention: &Qwen3Attention{
QProj: linear(p + ".self_attn.q_proj"),
KProj: linear(p + ".self_attn.k_proj"),
VProj: linear(p + ".self_attn.v_proj"),
OProj: linear(p + ".self_attn.o_proj"),
QNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.q_norm.weight")},
KNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.k_norm.weight")},
},
MLP: &Qwen3MLP{
GateProj: linear(p + ".mlp.gate_proj"),
UpProj: linear(p + ".mlp.up_proj"),
DownProj: linear(p + ".mlp.down_proj"),
},
}
}
// Output head — Qwen 3 has tie_word_embeddings=false, so lm_head is separate
lmHeadWeight := w("lm_head.weight")
if lmHeadWeight != nil {
lmHeadScales := w("lm_head.scales")
if lmHeadScales != nil && q != nil {
m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
} else {
m.Output = mlx.NewLinear(lmHeadWeight, nil)
}
} else {
m.Output = m.EmbedTokens.AsLinear()
}
// Materialise all weights onto Metal
var allArrays []*mlx.Array
for _, a := range weights {
allArrays = append(allArrays, a)
}
mlx.Materialize(allArrays...)
slog.Info("qwen3: model loaded",
"layers", cfg.NumHiddenLayers,
"hidden", cfg.HiddenSize,
"heads", cfg.NumAttentionHeads,
"kv_heads", cfg.NumKeyValueHeads,
"head_dim", cfg.HeadDim,
"vocab", cfg.VocabSize,
)
return m, nil
}
// Forward runs the Qwen 3 forward pass.
// Unlike Gemma, Qwen does NOT scale embeddings by sqrt(hidden_size).
func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
shape := tokens.Shape()
B, L := shape[0], shape[1]
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
h = layer.forward(h, caches[i], B, L, m.Cfg)
}
return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps))
}
func (l *Qwen3DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
// Pre-attention norm → attention → residual add
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
attnOut := l.Attention.forward(normed, c, B, L, cfg)
h := mlx.Add(x, attnOut)
// Pre-MLP norm → MLP → residual add
normed = l.PostAttnNorm.Forward(h, cfg.RMSNormEps)
mlpOut := l.MLP.forward(normed)
return mlx.Add(h, mlpOut)
}
func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape to [B, num_heads, L, head_dim]
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
// Q/K RMS normalization (Qwen 3 has this)
q = a.QNorm.Forward(q, cfg.RMSNormEps)
k = a.KNorm.Forward(k, cfg.RMSNormEps)
// RoPE — single theta for all layers (no sliding window)
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
// Update KV cache
k, v = c.Update(k, v, int(L))
// GQA: repeat K/V heads to match Q heads
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = mlx.RepeatKV(k, repeatFactor)
v = mlx.RepeatKV(v, repeatFactor)
}
// Scaled dot-product attention
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
// forward computes SwiGLU: down(silu(gate(x)) * up(x)).
func (m *Qwen3MLP) forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(m.GateProj.Forward(x))
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
}
// NewCache creates per-layer KV caches. Qwen 3 uses global attention only.
func (m *Qwen3Model) NewCache() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
// NumLayers returns the number of transformer layers.
func (m *Qwen3Model) NumLayers() int { return len(m.Layers) }
// Tokenizer returns the model's tokenizer.
func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
// ModelType returns the architecture identifier.
func (m *Qwen3Model) ModelType() string { return "qwen3" }

View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
package mlx package mlx

View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
package mlx package mlx
@ -68,6 +68,18 @@ func Exp(a *Array) *Array {
return out return out
} }
// Sigmoid returns element-wise 1/(1+exp(-a)).
func Sigmoid(a *Array) *Array {
out := New("SIGMOID", a)
C.mlx_sigmoid(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// SiLU returns element-wise x * sigmoid(x) (Swish activation).
func SiLU(a *Array) *Array {
return Mul(a, Sigmoid(a))
}
// Tanh returns element-wise tanh(a). // Tanh returns element-wise tanh(a).
func Tanh(a *Array) *Array { func Tanh(a *Array) *Array {
out := New("TANH", a) out := New("TANH", a)

View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
package mlx package mlx

View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
// Package sample provides composable token sampling strategies. // Package sample provides composable token sampling strategies.
package sample package sample

View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
package mlx package mlx

View file

@ -1,4 +1,4 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
package mlx package mlx

View file

@ -1,6 +1,6 @@
//go:build darwin && arm64 && mlx //go:build darwin && arm64
// Package tokenizer provides BPE/SentencePiece tokenization for Gemma models. // Package tokenizer provides BPE tokenization for transformer models.
package tokenizer package tokenizer
import ( import (
@ -19,6 +19,11 @@ type Tokenizer struct {
bosToken int32 bosToken int32
eosToken int32 eosToken int32
// GPT-2 byte-level BPE support (used by Qwen, GPT, Llama, etc.)
isGPT2BPE bool
gpt2Decoder map[rune]byte // Unicode char → original byte
gpt2Encoder map[byte]rune // original byte → Unicode char
} }
type mergePair struct { type mergePair struct {
@ -32,7 +37,7 @@ type tokenizerJSON struct {
Type string `json:"type"` Type string `json:"type"`
Vocab json.RawMessage `json:"vocab"` Vocab json.RawMessage `json:"vocab"`
Merges json.RawMessage `json:"merges"` Merges json.RawMessage `json:"merges"`
ByteFallback bool `json:"byte_fallback"` ByteFallback bool `json:"byte_fallback"`
} `json:"model"` } `json:"model"`
AddedTokens []struct { AddedTokens []struct {
ID int32 `json:"id"` ID int32 `json:"id"`
@ -71,7 +76,6 @@ func Load(path string) (*Tokenizer, error) {
// Parse merges — supports both ["a b", ...] and [["a","b"], ...] formats // Parse merges — supports both ["a b", ...] and [["a","b"], ...] formats
if len(tj.Model.Merges) > 0 { if len(tj.Model.Merges) > 0 {
// Try array-of-strings first
var stringMerges []string var stringMerges []string
if err := json.Unmarshal(tj.Model.Merges, &stringMerges); err == nil { if err := json.Unmarshal(tj.Model.Merges, &stringMerges); err == nil {
for rank, merge := range stringMerges { for rank, merge := range stringMerges {
@ -81,7 +85,6 @@ func Load(path string) (*Tokenizer, error) {
} }
} }
} else { } else {
// Try array-of-arrays: [["a","b"], ...]
var arrayMerges [][]string var arrayMerges [][]string
if err := json.Unmarshal(tj.Model.Merges, &arrayMerges); err == nil { if err := json.Unmarshal(tj.Model.Merges, &arrayMerges); err == nil {
for rank, pair := range arrayMerges { for rank, pair := range arrayMerges {
@ -102,37 +105,77 @@ func Load(path string) (*Tokenizer, error) {
t.invVocab[tok.ID] = tok.Content t.invVocab[tok.ID] = tok.Content
} }
// Set BOS/EOS // Detect GPT-2 byte-level BPE (Qwen, GPT, Llama use Ġ for space)
if _, ok := t.vocab["Ġ"]; ok {
t.isGPT2BPE = true
t.gpt2Decoder, t.gpt2Encoder = buildGPT2ByteMaps()
}
// Set BOS/EOS — detect model family from special tokens
if id, ok := t.special["<bos>"]; ok { if id, ok := t.special["<bos>"]; ok {
t.bosToken = id t.bosToken = id
} }
if id, ok := t.special["<eos>"]; ok { if id, ok := t.special["<eos>"]; ok {
t.eosToken = id t.eosToken = id
} }
// Gemma: <end_of_turn> is the generation stop token
if id, ok := t.special["<end_of_turn>"]; ok { if id, ok := t.special["<end_of_turn>"]; ok {
t.eosToken = id // Gemma uses end_of_turn as EOS t.eosToken = id
}
// Qwen3: <|im_end|> is the generation stop token
if id, ok := t.special["<|im_end|>"]; ok {
t.eosToken = id
}
// Qwen3 BOS: <|im_start|>
if id, ok := t.special["<|im_start|>"]; ok {
t.bosToken = id
} }
return t, nil return t, nil
} }
// buildGPT2ByteMaps creates the GPT-2 byte-level BPE encoding/decoding maps.
// GPT-2 maps all 256 bytes to printable Unicode characters to avoid control chars
// in the vocabulary. Printable ASCII + Latin-1 Supplement map to themselves;
// everything else (0-32, 127-160, 173) maps to U+0100 onwards.
func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) {
encoder = make(map[byte]rune, 256)
decoder = make(map[rune]byte, 256)
// Self-mapping ranges: printable ASCII + Latin-1 Supplement
// Use int loop variable to avoid byte overflow at 255.
selfMap := func(lo, hi int) {
for b := lo; b <= hi; b++ {
encoder[byte(b)] = rune(b)
decoder[rune(b)] = byte(b)
}
}
selfMap(33, 126) // ! through ~
selfMap(161, 172) // ¡ through ¬
selfMap(174, 255) // ® through ÿ
// Non-self-mapping: control chars, space, DEL, and gaps
n := 0
for b := 0; b < 256; b++ {
if _, ok := encoder[byte(b)]; !ok {
r := rune(256 + n)
encoder[byte(b)] = r
decoder[r] = byte(b)
n++
}
}
return
}
// Encode converts text to token IDs. Prepends BOS token. // Encode converts text to token IDs. Prepends BOS token.
func (t *Tokenizer) Encode(text string) []int32 { func (t *Tokenizer) Encode(text string) []int32 {
tokens := []int32{t.bosToken} tokens := []int32{t.bosToken}
// Simple BPE encoding — split into characters then merge if t.isGPT2BPE {
// This is a simplified version. Full implementation handles return t.encodeGPT2(text)
// Unicode, byte fallback, and efficient BPE merging.
chars := []string{}
for _, r := range text {
s := string(r)
if s == " " {
s = "▁" // SentencePiece space marker
}
chars = append(chars, s)
} }
// Check for special tokens first // SentencePiece style encoding
remaining := text remaining := text
for remaining != "" { for remaining != "" {
found := false found := false
@ -145,7 +188,6 @@ func (t *Tokenizer) Encode(text string) []int32 {
} }
} }
if !found { if !found {
// Encode character by character (simplified BPE)
r := []rune(remaining) r := []rune(remaining)
ch := "▁" + string(r[0]) ch := "▁" + string(r[0])
if id, ok := t.vocab[ch]; ok { if id, ok := t.vocab[ch]; ok {
@ -160,24 +202,95 @@ func (t *Tokenizer) Encode(text string) []int32 {
return tokens return tokens
} }
// encodeGPT2 encodes text using GPT-2 byte-level BPE.
func (t *Tokenizer) encodeGPT2(text string) []int32 {
tokens := []int32{t.bosToken}
// Convert text bytes to GPT-2 Unicode representation
var encoded strings.Builder
for _, b := range []byte(text) {
if r, ok := t.gpt2Encoder[b]; ok {
encoded.WriteRune(r)
}
}
gpt2Text := encoded.String()
// Scan for special tokens and regular text
remaining := gpt2Text
for remaining != "" {
// Check special tokens (these are stored as-is, not byte-encoded)
found := false
for tok, id := range t.special {
// Special tokens in GPT-2 tokenizers are stored in their original form
// Convert the special token to GPT-2 encoding for matching
var encTok strings.Builder
for _, b := range []byte(tok) {
if r, ok := t.gpt2Encoder[b]; ok {
encTok.WriteRune(r)
}
}
encStr := encTok.String()
if strings.HasPrefix(remaining, encStr) {
tokens = append(tokens, id)
remaining = remaining[len(encStr):]
found = true
break
}
}
if !found {
// Character-by-character lookup (simplified BPE)
r := []rune(remaining)
ch := string(r[0])
if id, ok := t.vocab[ch]; ok {
tokens = append(tokens, id)
}
remaining = string(r[1:])
}
}
return tokens
}
// Decode converts token IDs back to text. // Decode converts token IDs back to text.
func (t *Tokenizer) Decode(tokens []int32) string { func (t *Tokenizer) Decode(tokens []int32) string {
var sb strings.Builder var sb strings.Builder
for _, id := range tokens { for _, id := range tokens {
if text, ok := t.invVocab[id]; ok { if text, ok := t.invVocab[id]; ok {
// Replace SentencePiece space marker // Skip special tokens in decode output
text = strings.ReplaceAll(text, "▁", " ") if _, isSpecial := t.special[text]; isSpecial {
continue
}
sb.WriteString(text) sb.WriteString(text)
} }
} }
result := sb.String() raw := sb.String()
// Trim leading space from SentencePiece encoding
if t.isGPT2BPE {
return t.decodeGPT2Bytes(raw)
}
// SentencePiece style
result := strings.ReplaceAll(raw, "▁", " ")
if strings.HasPrefix(result, " ") { if strings.HasPrefix(result, " ") {
result = result[1:] result = result[1:]
} }
return result return result
} }
// decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes.
func (t *Tokenizer) decodeGPT2Bytes(s string) string {
var buf []byte
for _, r := range s {
if b, ok := t.gpt2Decoder[r]; ok {
buf = append(buf, b)
} else {
// Non-mapped runes pass through as UTF-8
buf = append(buf, []byte(string(r))...)
}
}
return string(buf)
}
// BOSToken returns the beginning-of-sequence token ID. // BOSToken returns the beginning-of-sequence token ID.
func (t *Tokenizer) BOSToken() int32 { return t.bosToken } func (t *Tokenizer) BOSToken() int32 { return t.bosToken }