From 92c6282d50eae80ac0ab4c88e8c37488fec8b79f Mon Sep 17 00:00:00 2001 From: Snider Date: Tue, 17 Feb 2026 16:57:41 +0000 Subject: [PATCH] refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64 Remove the manual -tags mlx requirement. MLX is now automatically compiled on darwin/arm64 via build constraints. Stubs remain for other platforms. No functional change. Co-Authored-By: Virgil --- ml/backend_mlx.go | 204 +++++++++++++------------ mlx/array.go | 2 +- mlx/cache/cache.go | 2 +- mlx/compile.go | 2 +- mlx/dtype.go | 2 +- mlx/fast.go | 2 +- mlx/io.go | 2 +- mlx/mlx.go | 6 +- mlx/mlx_stub.go | 2 +- mlx/model/gemma3.go | 22 +-- mlx/model/model.go | 74 +++++++++ mlx/model/qwen3.go | 305 +++++++++++++++++++++++++++++++++++++ mlx/nn.go | 2 +- mlx/ops.go | 14 +- mlx/random.go | 2 +- mlx/sample/sample.go | 2 +- mlx/slice.go | 2 +- mlx/stream.go | 2 +- mlx/tokenizer/tokenizer.go | 159 ++++++++++++++++--- 19 files changed, 655 insertions(+), 153 deletions(-) create mode 100644 mlx/model/model.go create mode 100644 mlx/model/qwen3.go diff --git a/ml/backend_mlx.go b/ml/backend_mlx.go index 96b8b71..4a0e7d6 100644 --- a/ml/backend_mlx.go +++ b/ml/backend_mlx.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 package ml @@ -16,9 +16,9 @@ import ( "forge.lthn.ai/core/go-ai/mlx/tokenizer" ) -// MLXBackend implements Backend for native Metal inference via mlx-c. +// MLXBackend implements Backend and StreamingBackend for native Metal inference. type MLXBackend struct { - model *model.GemmaModel + model model.Model tok *tokenizer.Tokenizer caches []cache.Cache sampler sample.Sampler @@ -26,6 +26,9 @@ type MLXBackend struct { modelBytes uint64 // model size at load time, for memory budget } +// Compile-time check that MLXBackend satisfies StreamingBackend. +var _ StreamingBackend = (*MLXBackend)(nil) + // NewMLXBackend loads a model from a safetensors directory and creates // a native Metal inference backend. func NewMLXBackend(modelPath string) (*MLXBackend, error) { @@ -34,13 +37,12 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) { } slog.Info("mlx: loading model", "path", modelPath) - m, err := model.LoadGemma3(modelPath) + m, err := model.LoadModel(modelPath) if err != nil { return nil, fmt.Errorf("mlx: load model: %w", err) } // Cap Metal memory: cache limit for allocator reuse, memory limit as hard ceiling. - // This prevents runaway memory growth from killing the system. mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB allocator cache mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap @@ -54,31 +56,27 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) { model: m, tok: m.Tokenizer(), caches: m.NewCache(), - sampler: sample.New(0.1, 0, 0, 0), // default low temp + sampler: sample.New(0.1, 0, 0, 0), modelBytes: mlx.GetActiveMemory(), }, nil } -// Generate produces text from a prompt using native Metal inference. -func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) { +// generate is the core token generation loop. If cb is non-nil, each token's +// text is sent to it (streaming mode). Returns the full output text. 
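+//
+// Callers reach it through Generate/Chat (cb == nil) or the streaming
+// variants below. A minimal streaming sketch (the backend variable, prompt
+// and handler are illustrative, not part of this change):
+//
+//	err := backend.GenerateStream(ctx, "Hello", GenOpts{MaxTokens: 64},
+//		func(tok string) error {
+//			fmt.Print(tok) // emit each token as it arrives
+//			return nil     // a non-nil error stops generation early
+//		})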
+func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts, cb TokenCallback) (string, error) { b.mu.Lock() defer b.mu.Unlock() - // Reset caches for new generation for _, c := range b.caches { c.Reset() } - // Set up sampler based on opts temp := float32(opts.Temperature) if temp == 0 { temp = 0.1 } sampler := sample.New(temp, 0, 0, 0) - // Tokenize - formatted := tokenizer.FormatGemmaPrompt(prompt) - tokens := b.tok.Encode(formatted) input := mlx.FromValues(tokens, 1, len(tokens)) maxTokens := opts.MaxTokens @@ -86,8 +84,6 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) maxTokens = 2048 } - // Generation loop — force Go GC every 4 tokens so finalizers release - // intermediate C array handles that Go GC cannot see as memory pressure. var output []int32 for i := 0; i < maxTokens; i++ { select { @@ -110,20 +106,58 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) output = append(output, nextToken) input = mlx.FromValues([]int32{nextToken}, 1, 1) - // Force GC to collect intermediate arrays + release Metal allocator cache + // Stream the token text to the callback + if cb != nil { + tokenText := b.tok.Decode([]int32{nextToken}) + if err := cb(tokenText); err != nil { + runtime.GC() + mlx.ClearCache() + return b.tok.Decode(output), err + } + } + if i%4 == 3 { runtime.GC() mlx.ClearCache() } } - // Cleanup between requests runtime.GC() mlx.ClearCache() b.checkMemory() return b.tok.Decode(output), nil } +// Generate produces text from a prompt using native Metal inference. +func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) { + formatted := formatPrompt(b.model.ModelType(), prompt) + tokens := b.tok.Encode(formatted) + return b.generate(ctx, tokens, opts, nil) +} + +// Chat formats messages and generates a response. +func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) { + prompt := formatChat(b.model.ModelType(), messages) + tokens := b.tok.Encode(prompt) + return b.generate(ctx, tokens, opts, nil) +} + +// GenerateStream streams tokens from a single prompt via the callback. +func (b *MLXBackend) GenerateStream(ctx context.Context, prompt string, opts GenOpts, cb TokenCallback) error { + formatted := formatPrompt(b.model.ModelType(), prompt) + tokens := b.tok.Encode(formatted) + _, err := b.generate(ctx, tokens, opts, cb) + return err +} + +// ChatStream streams tokens from a chat conversation via the callback. +func (b *MLXBackend) ChatStream(ctx context.Context, messages []Message, opts GenOpts, cb TokenCallback) error { + prompt := formatChat(b.model.ModelType(), messages) + tokens := b.tok.Encode(prompt) + _, err := b.generate(ctx, tokens, opts, cb) + return err +} + // lastPosition extracts the last sequence position from [B, L, V] logits → [B, V]. func lastPosition(logits *mlx.Array) *mlx.Array { shape := logits.Shape() @@ -137,9 +171,49 @@ func lastPosition(logits *mlx.Array) *mlx.Array { return logits } -// Chat formats messages and generates a response. -func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) { - // Format as Gemma chat +// checkMemory logs Metal memory usage and forces cleanup if it exceeds budget. 
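+// The budget is 3× the model's load-time footprint; going over it triggers a
+// double GC (so array finalizers run) plus a Metal allocator cache clear.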
+func (b *MLXBackend) checkMemory() { + active := mlx.GetActiveMemory() + budget := b.modelBytes * 3 + if active > budget { + slog.Warn("mlx: memory over budget, forcing cleanup", + "active_mb", active/1024/1024, + "model_mb", b.modelBytes/1024/1024, + "peak_mb", mlx.GetPeakMemory()/1024/1024, + ) + runtime.GC() + runtime.GC() + mlx.ClearCache() + } +} + +// Name returns the backend identifier. +func (b *MLXBackend) Name() string { return "mlx" } + +// Available reports whether Metal GPU is ready. +func (b *MLXBackend) Available() bool { return mlx.MetalAvailable() } + +// formatPrompt wraps a raw prompt in the model's chat template for single-turn generation. +func formatPrompt(modelType, prompt string) string { + switch modelType { + case "qwen3": + return fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n", prompt) + default: + return tokenizer.FormatGemmaPrompt(prompt) + } +} + +// formatChat builds a multi-turn chat prompt from messages using the model's template. +func formatChat(modelType string, messages []Message) string { + switch modelType { + case "qwen3": + return formatQwen3Chat(messages) + default: + return formatGemmaChat(messages) + } +} + +func formatGemmaChat(messages []Message) string { var prompt string for _, msg := range messages { switch msg.Role { @@ -152,83 +226,21 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) } } prompt += "model\n" - - // Use raw prompt (already formatted) - b.mu.Lock() - defer b.mu.Unlock() - - for _, c := range b.caches { - c.Reset() - } - - temp := float32(opts.Temperature) - if temp == 0 { - temp = 0.1 - } - sampler := sample.New(temp, 0, 0, 0) - - tokens := b.tok.Encode(prompt) - input := mlx.FromValues(tokens, 1, len(tokens)) - - maxTokens := opts.MaxTokens - if maxTokens == 0 { - maxTokens = 2048 - } - - var output []int32 - for i := 0; i < maxTokens; i++ { - select { - case <-ctx.Done(): - runtime.GC() - mlx.ClearCache() - return b.tok.Decode(output), ctx.Err() - default: - } - - logits := b.model.Forward(input, b.caches) - logits = lastPosition(logits) - next := sampler.Sample(logits) - mlx.Materialize(next) - - nextToken := int32(next.Int()) - if nextToken == b.tok.EOSToken() { - break - } - output = append(output, nextToken) - input = mlx.FromValues([]int32{nextToken}, 1, 1) - - // Force GC to collect intermediate arrays + release Metal allocator cache - if i%4 == 3 { - runtime.GC() - mlx.ClearCache() - } - } - - // Cleanup between requests - runtime.GC() - mlx.ClearCache() - b.checkMemory() - return b.tok.Decode(output), nil + return prompt } -// checkMemory logs Metal memory usage and forces cleanup if it exceeds budget. 
-func (b *MLXBackend) checkMemory() { - active := mlx.GetActiveMemory() - budget := b.modelBytes * 3 // 3× model size = danger zone - if active > budget { - slog.Warn("mlx: memory over budget, forcing cleanup", - "active_mb", active/1024/1024, - "model_mb", b.modelBytes/1024/1024, - "peak_mb", mlx.GetPeakMemory()/1024/1024, - ) - runtime.GC() - runtime.GC() // double GC to run finalizers - mlx.ClearCache() +func formatQwen3Chat(messages []Message) string { + var prompt string + for _, msg := range messages { + switch msg.Role { + case "system": + prompt += fmt.Sprintf("<|im_start|>system\n%s<|im_end|>\n", msg.Content) + case "user": + prompt += fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n", msg.Content) + case "assistant": + prompt += fmt.Sprintf("<|im_start|>assistant\n%s<|im_end|>\n", msg.Content) + } } + prompt += "<|im_start|>assistant\n" + return prompt } - -// Name returns the backend identifier. -func (b *MLXBackend) Name() string { return "mlx" } - -// Available reports whether Metal GPU is ready. -func (b *MLXBackend) Available() bool { return mlx.MetalAvailable() } diff --git a/mlx/array.go b/mlx/array.go index 6d36df2..ee4f4a8 100644 --- a/mlx/array.go +++ b/mlx/array.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 package mlx diff --git a/mlx/cache/cache.go b/mlx/cache/cache.go index 3945b78..ced8b39 100644 --- a/mlx/cache/cache.go +++ b/mlx/cache/cache.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 // Package cache provides KV cache implementations for transformer inference. package cache diff --git a/mlx/compile.go b/mlx/compile.go index 7727344..f62426f 100644 --- a/mlx/compile.go +++ b/mlx/compile.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 package mlx diff --git a/mlx/dtype.go b/mlx/dtype.go index 8692f95..eae583f 100644 --- a/mlx/dtype.go +++ b/mlx/dtype.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 package mlx diff --git a/mlx/fast.go b/mlx/fast.go index 936c64a..e1abeba 100644 --- a/mlx/fast.go +++ b/mlx/fast.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 package mlx diff --git a/mlx/io.go b/mlx/io.go index c7247b2..7e35773 100644 --- a/mlx/io.go +++ b/mlx/io.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 package mlx diff --git a/mlx/mlx.go b/mlx/mlx.go index 31445dd..470e1f7 100644 --- a/mlx/mlx.go +++ b/mlx/mlx.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 // Package mlx provides Go bindings for Apple's MLX framework via mlx-c. // @@ -6,9 +6,9 @@ // // cd pkg/mlx && go generate ./... // -// Build with MLX enabled: +// Build (MLX is auto-enabled on darwin/arm64): // -// go build -tags mlx -o core . +// go build -o core . package mlx //go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release diff --git a/mlx/mlx_stub.go b/mlx/mlx_stub.go index 9b6b5cb..281c8cb 100644 --- a/mlx/mlx_stub.go +++ b/mlx/mlx_stub.go @@ -1,4 +1,4 @@ -//go:build !(darwin && arm64 && mlx) +//go:build !(darwin && arm64) // Package mlx provides Go bindings for Apple's MLX framework via mlx-c. 
// This stub file is used on non-darwin/non-arm64 platforms or when the diff --git a/mlx/model/gemma3.go b/mlx/model/gemma3.go index f448f8c..11d3ae1 100644 --- a/mlx/model/gemma3.go +++ b/mlx/model/gemma3.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 // Package model provides transformer model architectures for MLX inference. package model @@ -16,12 +16,6 @@ import ( "forge.lthn.ai/core/go-ai/mlx/tokenizer" ) -// QuantizationConfig holds quantization parameters from config.json. -type QuantizationConfig struct { - GroupSize int `json:"group_size"` - Bits int `json:"bits"` -} - // TextConfig holds Gemma 3 text model configuration. type TextConfig struct { HiddenSize int32 `json:"hidden_size"` @@ -168,17 +162,6 @@ func parseConfig(data []byte) (*TextConfig, error) { return &cfg, nil } -// resolveWeight looks up a weight with optional "language_model." prefix. -func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array { - if w, ok := weights[name]; ok { - return w - } - if w, ok := weights["language_model."+name]; ok { - return w - } - return nil -} - // LoadGemma3 loads a Gemma 3 text model from a directory. func LoadGemma3(modelPath string) (*GemmaModel, error) { data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) @@ -428,3 +411,6 @@ func (m *GemmaModel) NumLayers() int { return len(m.Layers) } // Tokenizer returns the model's tokenizer. func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok } + +// ModelType returns the architecture identifier. +func (m *GemmaModel) ModelType() string { return "gemma3" } diff --git a/mlx/model/model.go b/mlx/model/model.go new file mode 100644 index 0000000..5e5481d --- /dev/null +++ b/mlx/model/model.go @@ -0,0 +1,74 @@ +//go:build darwin && arm64 + +// Package model provides transformer model architectures for MLX inference. +package model + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "forge.lthn.ai/core/go-ai/mlx" + "forge.lthn.ai/core/go-ai/mlx/cache" + "forge.lthn.ai/core/go-ai/mlx/tokenizer" +) + +// Model is the common interface for all transformer model architectures. +type Model interface { + // Forward runs the model forward pass on token IDs with KV caches. + Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array + + // NewCache creates per-layer KV caches for generation. + NewCache() []cache.Cache + + // NumLayers returns the number of transformer layers. + NumLayers() int + + // Tokenizer returns the model's tokenizer. + Tokenizer() *tokenizer.Tokenizer + + // ModelType returns the architecture identifier (e.g. "gemma3", "qwen3"). + ModelType() string +} + +// QuantizationConfig holds quantization parameters from config.json. +type QuantizationConfig struct { + GroupSize int `json:"group_size"` + Bits int `json:"bits"` +} + +// resolveWeight looks up a weight with optional "language_model." prefix. +func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array { + if w, ok := weights[name]; ok { + return w + } + if w, ok := weights["language_model."+name]; ok { + return w + } + return nil +} + +// LoadModel auto-detects the model architecture from config.json and loads it. 
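+// It probes the "model_type" field and dispatches to the matching loader:
+// "qwen3" → LoadQwen3, "gemma3"/"gemma2" → LoadGemma3.
+//
+// Illustrative call (the path is hypothetical):
+//
+//	m, err := model.LoadModel("/models/qwen3-4b-mlx")
+//	if err != nil {
+//		return err
+//	}
+//	slog.Info("loaded", "type", m.ModelType(), "layers", m.NumLayers())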
+func LoadModel(modelPath string) (Model, error) { + data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) + if err != nil { + return nil, fmt.Errorf("model: load config: %w", err) + } + + var probe struct { + ModelType string `json:"model_type"` + } + if err := json.Unmarshal(data, &probe); err != nil { + return nil, fmt.Errorf("model: parse model_type: %w", err) + } + + switch probe.ModelType { + case "qwen3": + return LoadQwen3(modelPath) + case "gemma3", "gemma2": + return LoadGemma3(modelPath) + default: + return nil, fmt.Errorf("model: unsupported architecture %q", probe.ModelType) + } +} diff --git a/mlx/model/qwen3.go b/mlx/model/qwen3.go new file mode 100644 index 0000000..3c15578 --- /dev/null +++ b/mlx/model/qwen3.go @@ -0,0 +1,305 @@ +//go:build darwin && arm64 + +package model + +import ( + "encoding/json" + "fmt" + "log/slog" + "math" + "os" + "path/filepath" + + "forge.lthn.ai/core/go-ai/mlx" + "forge.lthn.ai/core/go-ai/mlx/cache" + "forge.lthn.ai/core/go-ai/mlx/tokenizer" +) + +// Qwen3Config holds Qwen 3 model configuration. +type Qwen3Config struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + HeadDim int32 `json:"head_dim"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + MaxPositionEmbeddings int32 `json:"max_position_embeddings"` + + Quantization *QuantizationConfig `json:"-"` + Scale float32 `json:"-"` // 1/sqrt(head_dim) +} + +// Qwen3Model is the Qwen 3 text model. +type Qwen3Model struct { + EmbedTokens *mlx.Embedding + Layers []*Qwen3DecoderLayer + Norm *mlx.RMSNormModule + Output *mlx.Linear + + Tok *tokenizer.Tokenizer + Cfg *Qwen3Config +} + +// Qwen3DecoderLayer is a single transformer block. +// Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add. +type Qwen3DecoderLayer struct { + InputNorm *mlx.RMSNormModule // Pre-attention norm + PostAttnNorm *mlx.RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm) + Attention *Qwen3Attention + MLP *Qwen3MLP +} + +// Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization. +type Qwen3Attention struct { + QProj *mlx.Linear + KProj *mlx.Linear + VProj *mlx.Linear + OProj *mlx.Linear + QNorm *mlx.RMSNormModule + KNorm *mlx.RMSNormModule +} + +// Qwen3MLP is the SwiGLU feed-forward network: down(silu(gate(x)) * up(x)). +type Qwen3MLP struct { + GateProj *mlx.Linear + UpProj *mlx.Linear + DownProj *mlx.Linear +} + +func parseQwen3Config(data []byte) (*Qwen3Config, error) { + var cfg Qwen3Config + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, err + } + + // Top-level quantization + var wrapper struct { + Quantization *QuantizationConfig `json:"quantization"` + } + json.Unmarshal(data, &wrapper) + cfg.Quantization = wrapper.Quantization + + // Compute scale + if cfg.HeadDim == 0 { + cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads + } + cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) + + // Defaults + if cfg.RopeTheta == 0 { + cfg.RopeTheta = 1000000 + } + if cfg.RMSNormEps == 0 { + cfg.RMSNormEps = 1e-6 + } + if cfg.VocabSize == 0 { + cfg.VocabSize = 151936 + } + + return &cfg, nil +} + +// LoadQwen3 loads a Qwen 3 model from a safetensors directory. 
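+// The directory is expected to contain config.json, tokenizer.json and one or
+// more *.safetensors shards. A "quantization" block in config.json switches the
+// projections to quantized linears with the given group size and bit width.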
+func LoadQwen3(modelPath string) (*Qwen3Model, error) { + data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) + if err != nil { + return nil, fmt.Errorf("qwen3: load config: %w", err) + } + + cfg, err := parseQwen3Config(data) + if err != nil { + return nil, fmt.Errorf("qwen3: parse config: %w", err) + } + + tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) + if err != nil { + return nil, fmt.Errorf("qwen3: load tokenizer: %w", err) + } + + // Load weights from all safetensors files + weights := make(map[string]*mlx.Array) + matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors")) + for _, path := range matches { + for name, arr := range mlx.LoadSafetensors(path) { + weights[name] = arr + } + } + + w := func(name string) *mlx.Array { return resolveWeight(weights, name) } + + // Quantization setup + q := cfg.Quantization + if q != nil { + slog.Info("qwen3: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize) + } + linear := func(prefix string) *mlx.Linear { + weight := w(prefix + ".weight") + scales := w(prefix + ".scales") + biases := w(prefix + ".biases") + bias := w(prefix + ".bias") + if scales != nil && q != nil { + return mlx.NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits) + } + return mlx.NewLinear(weight, bias) + } + + // Embedding + embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")} + if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil { + embed.Scales = embedScales + embed.Biases = w("model.embed_tokens.biases") + embed.GroupSize = q.GroupSize + embed.Bits = q.Bits + } + + m := &Qwen3Model{ + EmbedTokens: embed, + Layers: make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers), + Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")}, + Tok: tok, + Cfg: cfg, + } + + for i := int32(0); i < cfg.NumHiddenLayers; i++ { + p := fmt.Sprintf("model.layers.%d", i) + m.Layers[i] = &Qwen3DecoderLayer{ + InputNorm: &mlx.RMSNormModule{Weight: w(p + ".input_layernorm.weight")}, + PostAttnNorm: &mlx.RMSNormModule{Weight: w(p + ".post_attention_layernorm.weight")}, + Attention: &Qwen3Attention{ + QProj: linear(p + ".self_attn.q_proj"), + KProj: linear(p + ".self_attn.k_proj"), + VProj: linear(p + ".self_attn.v_proj"), + OProj: linear(p + ".self_attn.o_proj"), + QNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.q_norm.weight")}, + KNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.k_norm.weight")}, + }, + MLP: &Qwen3MLP{ + GateProj: linear(p + ".mlp.gate_proj"), + UpProj: linear(p + ".mlp.up_proj"), + DownProj: linear(p + ".mlp.down_proj"), + }, + } + } + + // Output head — Qwen 3 has tie_word_embeddings=false, so lm_head is separate + lmHeadWeight := w("lm_head.weight") + if lmHeadWeight != nil { + lmHeadScales := w("lm_head.scales") + if lmHeadScales != nil && q != nil { + m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits) + } else { + m.Output = mlx.NewLinear(lmHeadWeight, nil) + } + } else { + m.Output = m.EmbedTokens.AsLinear() + } + + // Materialise all weights onto Metal + var allArrays []*mlx.Array + for _, a := range weights { + allArrays = append(allArrays, a) + } + mlx.Materialize(allArrays...) + + slog.Info("qwen3: model loaded", + "layers", cfg.NumHiddenLayers, + "hidden", cfg.HiddenSize, + "heads", cfg.NumAttentionHeads, + "kv_heads", cfg.NumKeyValueHeads, + "head_dim", cfg.HeadDim, + "vocab", cfg.VocabSize, + ) + + return m, nil +} + +// Forward runs the Qwen 3 forward pass. 
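+// The pass is embed → N × (norm → attention → add, norm → SwiGLU MLP → add)
+// → final RMSNorm → lm_head, producing [B, L, vocab_size] logits.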
+// Unlike Gemma, Qwen does NOT scale embeddings by sqrt(hidden_size). +func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { + shape := tokens.Shape() + B, L := shape[0], shape[1] + + h := m.EmbedTokens.Forward(tokens) + + for i, layer := range m.Layers { + h = layer.forward(h, caches[i], B, L, m.Cfg) + } + + return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps)) +} + +func (l *Qwen3DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array { + // Pre-attention norm → attention → residual add + normed := l.InputNorm.Forward(x, cfg.RMSNormEps) + attnOut := l.Attention.forward(normed, c, B, L, cfg) + h := mlx.Add(x, attnOut) + + // Pre-MLP norm → MLP → residual add + normed = l.PostAttnNorm.Forward(h, cfg.RMSNormEps) + mlpOut := l.MLP.forward(normed) + return mlx.Add(h, mlpOut) +} + +func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array { + q := a.QProj.Forward(x) + k := a.KProj.Forward(x) + v := a.VProj.Forward(x) + + // Reshape to [B, num_heads, L, head_dim] + q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) + k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) + v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) + + // Q/K RMS normalization (Qwen 3 has this) + q = a.QNorm.Forward(q, cfg.RMSNormEps) + k = a.KNorm.Forward(k, cfg.RMSNormEps) + + // RoPE — single theta for all layers (no sliding window) + q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) + k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) + + // Update KV cache + k, v = c.Update(k, v, int(L)) + + // GQA: repeat K/V heads to match Q heads + repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads + if repeatFactor > 1 { + k = mlx.RepeatKV(k, repeatFactor) + v = mlx.RepeatKV(v, repeatFactor) + } + + // Scaled dot-product attention + out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) + out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) + return a.OProj.Forward(out) +} + +// forward computes SwiGLU: down(silu(gate(x)) * up(x)). +func (m *Qwen3MLP) forward(x *mlx.Array) *mlx.Array { + gate := mlx.SiLU(m.GateProj.Forward(x)) + return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x))) +} + +// NewCache creates per-layer KV caches. Qwen 3 uses global attention only. +func (m *Qwen3Model) NewCache() []cache.Cache { + caches := make([]cache.Cache, len(m.Layers)) + for i := range caches { + caches[i] = cache.NewKVCache() + } + return caches +} + +// NumLayers returns the number of transformer layers. +func (m *Qwen3Model) NumLayers() int { return len(m.Layers) } + +// Tokenizer returns the model's tokenizer. +func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer { return m.Tok } + +// ModelType returns the architecture identifier. 
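+// It matches the "model_type" value probed by LoadModel and selects the Qwen
+// chat template in ml/backend_mlx.go.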
+func (m *Qwen3Model) ModelType() string { return "qwen3" } diff --git a/mlx/nn.go b/mlx/nn.go index f06aada..ccd0c9e 100644 --- a/mlx/nn.go +++ b/mlx/nn.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 package mlx diff --git a/mlx/ops.go b/mlx/ops.go index 7c388f9..c743ce9 100644 --- a/mlx/ops.go +++ b/mlx/ops.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 package mlx @@ -68,6 +68,18 @@ func Exp(a *Array) *Array { return out } +// Sigmoid returns element-wise 1/(1+exp(-a)). +func Sigmoid(a *Array) *Array { + out := New("SIGMOID", a) + C.mlx_sigmoid(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +// SiLU returns element-wise x * sigmoid(x) (Swish activation). +func SiLU(a *Array) *Array { + return Mul(a, Sigmoid(a)) +} + // Tanh returns element-wise tanh(a). func Tanh(a *Array) *Array { out := New("TANH", a) diff --git a/mlx/random.go b/mlx/random.go index bfadada..f7e09b2 100644 --- a/mlx/random.go +++ b/mlx/random.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 package mlx diff --git a/mlx/sample/sample.go b/mlx/sample/sample.go index ff8f19d..8a49c04 100644 --- a/mlx/sample/sample.go +++ b/mlx/sample/sample.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 // Package sample provides composable token sampling strategies. package sample diff --git a/mlx/slice.go b/mlx/slice.go index da5ff74..5bb7a66 100644 --- a/mlx/slice.go +++ b/mlx/slice.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 package mlx diff --git a/mlx/stream.go b/mlx/stream.go index 261ea93..248a224 100644 --- a/mlx/stream.go +++ b/mlx/stream.go @@ -1,4 +1,4 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 package mlx diff --git a/mlx/tokenizer/tokenizer.go b/mlx/tokenizer/tokenizer.go index 9dd9450..3537d6b 100644 --- a/mlx/tokenizer/tokenizer.go +++ b/mlx/tokenizer/tokenizer.go @@ -1,6 +1,6 @@ -//go:build darwin && arm64 && mlx +//go:build darwin && arm64 -// Package tokenizer provides BPE/SentencePiece tokenization for Gemma models. +// Package tokenizer provides BPE tokenization for transformer models. package tokenizer import ( @@ -19,6 +19,11 @@ type Tokenizer struct { bosToken int32 eosToken int32 + + // GPT-2 byte-level BPE support (used by Qwen, GPT, Llama, etc.) + isGPT2BPE bool + gpt2Decoder map[rune]byte // Unicode char → original byte + gpt2Encoder map[byte]rune // original byte → Unicode char } type mergePair struct { @@ -32,7 +37,7 @@ type tokenizerJSON struct { Type string `json:"type"` Vocab json.RawMessage `json:"vocab"` Merges json.RawMessage `json:"merges"` - ByteFallback bool `json:"byte_fallback"` + ByteFallback bool `json:"byte_fallback"` } `json:"model"` AddedTokens []struct { ID int32 `json:"id"` @@ -71,7 +76,6 @@ func Load(path string) (*Tokenizer, error) { // Parse merges — supports both ["a b", ...] and [["a","b"], ...] formats if len(tj.Model.Merges) > 0 { - // Try array-of-strings first var stringMerges []string if err := json.Unmarshal(tj.Model.Merges, &stringMerges); err == nil { for rank, merge := range stringMerges { @@ -81,7 +85,6 @@ func Load(path string) (*Tokenizer, error) { } } } else { - // Try array-of-arrays: [["a","b"], ...] 
var arrayMerges [][]string if err := json.Unmarshal(tj.Model.Merges, &arrayMerges); err == nil { for rank, pair := range arrayMerges { @@ -102,37 +105,77 @@ func Load(path string) (*Tokenizer, error) { t.invVocab[tok.ID] = tok.Content } - // Set BOS/EOS + // Detect GPT-2 byte-level BPE (Qwen, GPT, Llama use Ġ for space) + if _, ok := t.vocab["Ġ"]; ok { + t.isGPT2BPE = true + t.gpt2Decoder, t.gpt2Encoder = buildGPT2ByteMaps() + } + + // Set BOS/EOS — detect model family from special tokens if id, ok := t.special[""]; ok { t.bosToken = id } if id, ok := t.special[""]; ok { t.eosToken = id } + // Gemma: is the generation stop token if id, ok := t.special[""]; ok { - t.eosToken = id // Gemma uses end_of_turn as EOS + t.eosToken = id + } + // Qwen3: <|im_end|> is the generation stop token + if id, ok := t.special["<|im_end|>"]; ok { + t.eosToken = id + } + // Qwen3 BOS: <|im_start|> + if id, ok := t.special["<|im_start|>"]; ok { + t.bosToken = id } return t, nil } +// buildGPT2ByteMaps creates the GPT-2 byte-level BPE encoding/decoding maps. +// GPT-2 maps all 256 bytes to printable Unicode characters to avoid control chars +// in the vocabulary. Printable ASCII + Latin-1 Supplement map to themselves; +// everything else (0-32, 127-160, 173) maps to U+0100 onwards. +func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) { + encoder = make(map[byte]rune, 256) + decoder = make(map[rune]byte, 256) + + // Self-mapping ranges: printable ASCII + Latin-1 Supplement + // Use int loop variable to avoid byte overflow at 255. + selfMap := func(lo, hi int) { + for b := lo; b <= hi; b++ { + encoder[byte(b)] = rune(b) + decoder[rune(b)] = byte(b) + } + } + selfMap(33, 126) // ! through ~ + selfMap(161, 172) // ¡ through ¬ + selfMap(174, 255) // ® through ÿ + + // Non-self-mapping: control chars, space, DEL, and gaps + n := 0 + for b := 0; b < 256; b++ { + if _, ok := encoder[byte(b)]; !ok { + r := rune(256 + n) + encoder[byte(b)] = r + decoder[r] = byte(b) + n++ + } + } + return +} + // Encode converts text to token IDs. Prepends BOS token. func (t *Tokenizer) Encode(text string) []int32 { tokens := []int32{t.bosToken} - // Simple BPE encoding — split into characters then merge - // This is a simplified version. Full implementation handles - // Unicode, byte fallback, and efficient BPE merging. - chars := []string{} - for _, r := range text { - s := string(r) - if s == " " { - s = "▁" // SentencePiece space marker - } - chars = append(chars, s) + if t.isGPT2BPE { + return t.encodeGPT2(text) } - // Check for special tokens first + // SentencePiece style encoding remaining := text for remaining != "" { found := false @@ -145,7 +188,6 @@ func (t *Tokenizer) Encode(text string) []int32 { } } if !found { - // Encode character by character (simplified BPE) r := []rune(remaining) ch := "▁" + string(r[0]) if id, ok := t.vocab[ch]; ok { @@ -160,24 +202,95 @@ func (t *Tokenizer) Encode(text string) []int32 { return tokens } +// encodeGPT2 encodes text using GPT-2 byte-level BPE. 
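+// Input bytes are first mapped to the printable GPT-2 alphabet (a space becomes
+// "Ġ", so "hi there" is looked up as "hiĠthere"), special tokens such as
+// <|im_start|> are matched in their encoded form, and anything left falls back
+// to per-rune vocab lookup (a simplified, merge-free pass).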
+func (t *Tokenizer) encodeGPT2(text string) []int32 { + tokens := []int32{t.bosToken} + + // Convert text bytes to GPT-2 Unicode representation + var encoded strings.Builder + for _, b := range []byte(text) { + if r, ok := t.gpt2Encoder[b]; ok { + encoded.WriteRune(r) + } + } + gpt2Text := encoded.String() + + // Scan for special tokens and regular text + remaining := gpt2Text + for remaining != "" { + // Check special tokens (these are stored as-is, not byte-encoded) + found := false + for tok, id := range t.special { + // Special tokens in GPT-2 tokenizers are stored in their original form + // Convert the special token to GPT-2 encoding for matching + var encTok strings.Builder + for _, b := range []byte(tok) { + if r, ok := t.gpt2Encoder[b]; ok { + encTok.WriteRune(r) + } + } + encStr := encTok.String() + if strings.HasPrefix(remaining, encStr) { + tokens = append(tokens, id) + remaining = remaining[len(encStr):] + found = true + break + } + } + if !found { + // Character-by-character lookup (simplified BPE) + r := []rune(remaining) + ch := string(r[0]) + if id, ok := t.vocab[ch]; ok { + tokens = append(tokens, id) + } + remaining = string(r[1:]) + } + } + + return tokens +} + // Decode converts token IDs back to text. func (t *Tokenizer) Decode(tokens []int32) string { var sb strings.Builder for _, id := range tokens { if text, ok := t.invVocab[id]; ok { - // Replace SentencePiece space marker - text = strings.ReplaceAll(text, "▁", " ") + // Skip special tokens in decode output + if _, isSpecial := t.special[text]; isSpecial { + continue + } sb.WriteString(text) } } - result := sb.String() - // Trim leading space from SentencePiece encoding + raw := sb.String() + + if t.isGPT2BPE { + return t.decodeGPT2Bytes(raw) + } + + // SentencePiece style + result := strings.ReplaceAll(raw, "▁", " ") if strings.HasPrefix(result, " ") { result = result[1:] } return result } +// decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes. +func (t *Tokenizer) decodeGPT2Bytes(s string) string { + var buf []byte + for _, r := range s { + if b, ok := t.gpt2Decoder[r]; ok { + buf = append(buf, b) + } else { + // Non-mapped runes pass through as UTF-8 + buf = append(buf, []byte(string(r))...) + } + } + return string(buf) +} + // BOSToken returns the beginning-of-sequence token ID. func (t *Tokenizer) BOSToken() int32 { return t.bosToken }