go-mlx/docs/architecture.md


# Architecture
Module: `forge.lthn.ai/core/go-mlx`
Native Apple Metal GPU inference via mlx-c bindings, implementing the `inference.Backend` interface from `forge.lthn.ai/core/go-inference` for Apple Silicon (M1-M4).
---
## Package Layout
```
go-mlx/
├── mlx.go — Package doc + go:generate CMake directives
├── mlx_stub.go — !darwin || !arm64: MetalAvailable() = false
├── register_metal.go — darwin && arm64: registers "metal" backend via init()
├── mlx_test.go — Integration tests (public API via go-inference)
├── internal/metal/ — All CGO code (darwin && arm64 only)
│   ├── metal.go — Init, error handler, Eval/EvalAsync/Materialize
│   ├── array.go — Array type, creation, data access, Iter()
│   ├── dtype.go — DType constants
│   ├── stream.go — Metal stream/queue, memory controls
│   ├── ops.go — Element-wise, reduction, matrix, shape ops
│   ├── fast.go — Fused Metal kernels: RMSNorm, LayerNorm, RoPE, SDPA
│   ├── nn.go — Linear, Embedding, RMSNormModule, RepeatKV
│   ├── compile.go — CompiledFunc (shapeless function compilation)
│   ├── slice.go — Array slicing, update-in-place
│   ├── random.go — RandomCategorical, RandomUniform, RandomNormal
│   ├── io.go — Safetensors load/save
│   ├── model.go — InternalModel interface + architecture dispatch
│   ├── gemma3.go — Gemma 3 decoder
│   ├── qwen3.go — Qwen 2/3 and Llama 3 decoder
│   ├── cache.go — KVCache + RotatingKVCache
│   ├── sample.go — Sampling chain: greedy, temperature, TopK, TopP, MinP
│   ├── tokenizer.go — BPE tokenizer (SentencePiece + GPT-2 byte-level)
│   ├── grad.go — VJP, JVP, ValueAndGrad, Checkpoint, loss functions
│   ├── lora.go — LoRA adapters, random normal, safetensors save
│   ├── optim.go — AdamW optimiser
│   ├── generate.go — Model, Generate, Chat, batch inference, metrics
│   ├── close.go — Deterministic weight cleanup
│   └── backend.go — LoadAndInit entry point
└── mlxlm/ — Python subprocess backend
    ├── backend.go — mlxlmBackend implementing inference.Backend
    └── bridge.py — Python script (embedded via //go:embed)
```
---
## CGO / mlx-c Binding
### Build Chain
The native layer depends on mlx-c v0.4.1, a C API wrapping Apple's MLX C++ framework. `go generate ./...` fetches and builds it via CMake:
```
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist ...
//go:generate cmake --build build --parallel
//go:generate cmake --install build
```
CMake installs headers to `dist/include/` and shared libraries to `dist/lib/`. The `#cgo` directives in `internal/metal/metal.go` reference those paths:
```
CPPFLAGS: -I${SRCDIR}/../../dist/include
LDFLAGS: -L${SRCDIR}/../../dist/lib -lmlxc -lmlx
darwin: -framework Foundation -framework Metal -framework Accelerate
-Wl,-rpath,${SRCDIR}/../../dist/lib
```
Every Go source file in `internal/metal/` carries `//go:build darwin && arm64`. The root package compiles on all platforms; only the blank import of `_ "forge.lthn.ai/core/go-mlx"` triggers the Metal backend on supported hardware.
### Error Handling
mlx-c reports errors through a registered C callback. The handler stores the error string in a C atomic variable using `atomic_store_explicit` with release ordering. `lastError()` reads and atomically clears it with acquire ordering, returning a Go error. `Eval()` checks the mlx return code and calls `lastError()` to surface real MLX messages. `Materialize()` wraps `Eval()` and logs on error without returning; callers that need propagation call `Eval()` directly.
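The store/clear handshake can be sketched in pure Go, substituting `sync/atomic` for the C11 atomics (the `handler` name and error text here are illustrative, not the actual CGO callback):

```go
package main

import (
	"errors"
	"fmt"
	"sync/atomic"
)

// lastErr models the C-side atomic error slot: the mlx-c callback stores
// into it, and lastError() swaps it out, clearing the slot in one step.
var lastErr atomic.Pointer[string]

// handler is the analogue of the registered mlx-c error callback.
func handler(msg string) { lastErr.Store(&msg) }

// lastError atomically takes and clears the stored message.
func lastError() error {
	if p := lastErr.Swap(nil); p != nil {
		return errors.New(*p)
	}
	return nil
}

func main() {
	handler("metal: out of memory")
	fmt.Println(lastError()) // surfaced once
	fmt.Println(lastError()) // then cleared
}
```

The swap-to-nil read means an error is reported exactly once, matching the acquire/release pairing described above.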
### Evaluation Model
MLX uses lazy evaluation: operations build a computation graph without executing. Execution is triggered by `mlx_eval` or `mlx_async_eval`, which dispatch the graph to the Metal GPU. Go wrappers:
- `Eval(...*Array) error` — synchronous, returns error
- `EvalAsync(...*Array) error` — queues for async execution
- `Materialize(...*Array)` — synchronous, logs error (used in test helpers and weight loading)
---
## Array Type
`Array` wraps an `mlx_array` (a C-side opaque handle). Arrays are reference-counted on the C side; Go uses `runtime.SetFinalizer` to call `mlx_array_free` when the Go object is collected. Go 1.26's Green Tea GC reduces finaliser latency under sustained inference.
Key operations:
- Creation: `newArray()`, `FromValue()`, `FromValues()`, `Zeros()`, `Ones()`
- Data access: `Floats()`, `DataInt32()`, `Int()` — all call `ensureContiguous()` first to handle view arrays (transpose, broadcast, slice views) that have non-contiguous physical layouts. Previously, reading views returned silently incorrect data.
- Shape: `Shape()`, `Dim()`, `Size()`
- Iteration: `Array.Iter() iter.Seq[float32]` — range-over-func (stable since Go 1.23), handles non-contiguous arrays
---
## Memory Management
The Metal allocator (separate from system RAM) is controlled via functions exposed at the root package level:
| Function | Purpose |
|----------|---------|
| `SetCacheLimit(bytes)` | Soft limit on allocator cache |
| `SetMemoryLimit(bytes)` | Hard limit |
| `SetWiredLimit(bytes)` | Wired memory limit |
| `GetActiveMemory()` | Current live allocations |
| `GetPeakMemory()` | High-water mark since last reset |
| `GetCacheMemory()` | Cached (not yet freed) memory |
| `ClearCache()` | Release cached memory to OS |
| `ResetPeakMemory()` | Reset high-water mark |
`Model.Close()` walks the full model tree and explicitly frees all weight arrays via `Free()`, without relying on GC finalisers. Tied output weights (shared with the embedding table) are detected and skipped to prevent double-free. `Close()` is idempotent.
During generation, each call allocates fresh KV caches that are released to GC at iterator completion. Call `ClearCache()` between multi-turn chat turns to reclaim cache memory promptly rather than waiting for GC.
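The tied-weight skip can be sketched with a plain pointer-identity set (the `arr` type and `closeWeights` helper are illustrative, not the actual `Close()` code):

```go
package main

import "fmt"

// arr stands in for a GPU-backed weight; Free is the explicit release.
type arr struct{ freed bool }

func (a *arr) Free() {
	if a.freed {
		panic("double free")
	}
	a.freed = true
}

// closeWeights frees every unique array exactly once: tied weights (the
// same *arr reachable via two fields) are skipped on the second visit,
// mirroring how Model.Close() avoids double-freeing lm_head/embed_tokens.
func closeWeights(ws []*arr) {
	seen := map[*arr]bool{}
	for _, w := range ws {
		if w == nil || seen[w] {
			continue
		}
		seen[w] = true
		w.Free()
	}
}

func main() {
	embed := &arr{}
	head := embed // tied output head shares the embedding table
	closeWeights([]*arr{embed, head})
	fmt.Println(embed.freed)
}
```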
---
## Model Architectures
All architectures implement the `InternalModel` interface:
```go
type InternalModel interface {
Forward(tokens *Array, caches []Cache) *Array
ForwardMasked(tokens *Array, mask *Array, caches []Cache) *Array
NewCache() []Cache
NumLayers() int
Tokenizer() *Tokenizer
ModelType() string
ApplyLoRA(cfg LoRAConfig) *LoRAAdapter
}
```
Architecture is detected from the `model_type` field of `config.json`:
| `model_type` values | Loader | Notes |
|---------------------|--------|-------|
| `gemma3`, `gemma3_text`, `gemma2` | `LoadGemma3` | Gemma 3 decoder |
| `qwen3`, `qwen2`, `llama` | `LoadQwen3` | Shared decoder, variant-specific features |
### Gemma 3
Decoder structure per layer (pre-norm with four norm points per block):
```
input → InputNorm → Attention → PostAttnNorm → residual add
→ PreFFNorm → MLP → PostFFNorm → residual add
```
Attention specifics:
- Q/K RMS normalisation (separate `QNorm`, `KNorm` modules)
- Alternating sliding window / global attention: sliding layers use `RopeLocalBaseFreq` (10000), global layers use `RopeTheta` (1000000). Pattern period determined by `sliding_window_pattern` (default 6)
- Rotary embeddings via fused RoPE Metal kernel with per-layer theta
- Grouped-query attention (GQA): K/V heads repeated via `RepeatKV` when `num_kv_heads < num_attention_heads`
MLP: GELU-based gate using tanh approximation. The GELU function is compiled via `CompileShapeless` (shapeless function compilation) as a singleton to avoid recompilation across calls.
Normalisation: Gemma uses `(1 + weight) * RMSNorm(x)` — the `(1 + weight)` factor is precomputed at load time (`precomputeScaledWeights`) for all seven norm points per layer to avoid repeated additions during inference.
Embedding scale: hidden states are multiplied by `sqrt(hidden_size)` after embedding lookup (Gemma-specific convention). Qwen and Llama do not apply this scale.
Output head: Gemma 3 typically ties `lm_head` weights to `embed_tokens`. If a separate `lm_head.weight` is present in the safetensors, it is used as an independent output projection.
### Qwen 3 / Qwen 2 / Llama 3
These three architectures share one loader (`LoadQwen3`) and one decoder implementation. Distinctions:
| Feature | Qwen 3 | Qwen 2 | Llama 3 |
|---------|--------|--------|---------|
| Q/K norm | Yes | No | No |
| Sliding window | No | No | No |
| EOS token | `<\|im_end\|>` | `<\|im_end\|>` | `<\|eot_id\|>` |
| BOS token | `<\|im_start\|>` | `<\|im_start\|>` | `<\|begin_of_text\|>` |
Qwen 2 detection: if `model_type` is absent from the config, the presence of the weight `model.layers.0.self_attn.q_norm.weight` distinguishes Qwen 3 (present) from Qwen 2 (absent).
Decoder structure per layer (standard pre-norm):
```
input → InputNorm → Attention → residual add
→ PostAttnNorm → MLP → residual add
```
MLP: SwiGLU gate — `down(silu(gate(x)) * up(x))`.
Output head: always a separate `lm_head.weight` (Qwen 3 has `tie_word_embeddings=false`).
### Weight Loading
All architectures load from HuggingFace safetensors format (not GGUF). The loader:
1. Reads `config.json` for model configuration
2. Loads `tokenizer.json` for the tokeniser
3. Glob-matches all `*.safetensors` files in the directory (multi-shard support)
4. Calls `LoadSafetensors` per shard; checks `lastError()` after each
5. Resolves weights by name, with automatic `language_model.` prefix fallback via `resolveWeight()`
6. Constructs `Linear` layers as quantised or dense based on presence of `scales` tensors
7. Calls `Materialize()` on all weight arrays to commit them to GPU memory
Quantisation is transparent: `NewQuantizedLinear` stores packed weights with `scales` and `biases`, dispatching to `QuantizedMatmul` (mlx-c grouped quantisation) in `Forward`. Quantisation parameters (`bits`, `group_size`) are read from top-level `config.json`.
Head dimension inference: if `head_dim` is absent from `config.json` (as with some Gemma 3 variants), it is inferred from `q_proj.weight[0]` / `num_attention_heads`.
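The fallback can be sketched as follows, assuming safetensors stores `q_proj.weight` as `[out_features, in_features]` (the helper name is illustrative):

```go
package main

import "fmt"

// inferHeadDim mirrors the fallback described above: when config.json
// lacks head_dim, derive it from the q_proj weight's output dimension.
func inferHeadDim(headDim, numHeads int, qProjShape []int) int {
	if headDim > 0 {
		return headDim // config value wins when present
	}
	return qProjShape[0] / numHeads
}

func main() {
	// Illustrative numbers: 4 heads, 1024 total query dims -> 256 per head.
	fmt.Println(inferHeadDim(0, 4, []int{1024, 1152}))
}
```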
---
## Attention Mechanism
### Virtual Transpose
Linear projections produce `[B, L, H*D]`. The reshape to `[B, H, L, D]` is implemented via `AsStrided` — a zero-copy stride manipulation that avoids a physical copy:
```
shape: [B, H, L, D]
strides: [L*H*D, D, H*D, 1]
```
- Batch stride: `L*H*D` (jump entire sequence)
- Head stride: `D` (adjacent heads are contiguous in memory)
- Sequence stride: `H*D` (jump one full row of heads)
- Element stride: `1` (contiguous within head)
The result is a non-contiguous view used for RoPE and SDPA calls.
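A quick check that these strides index the original `[B, L, H*D]` buffer correctly:

```go
package main

import "fmt"

// offset maps a logical [B, H, L, D] index onto the flat [B, L, H*D]
// projection buffer using the AsStrided strides from the text:
// [L*H*D, D, H*D, 1]. No data moves; only the indexing changes.
func offset(b, h, l, d, L, H, D int) int {
	return b*(L*H*D) + h*D + l*(H*D) + d*1
}

func main() {
	L, H, D := 3, 2, 4
	// In the contiguous buffer the same element sits at row l,
	// column h*D+d; that is exactly what the strides reproduce.
	fmt.Println(offset(0, 1, 2, 3, L, H, D) == 2*(H*D)+1*D+3)
}
```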
### Rotary Position Embeddings (RoPE)
Applied via the fused `mlx_fast_rope` Metal kernel. Parameters:
- `dims`: head dimension
- `traditional`: false (standard non-interleaved layout)
- `base`: theta (varies by layer type in Gemma 3; single value for Qwen/Llama)
- `scale`: 1.0 (no frequency scaling)
- `offset`: current KV cache offset (enables continuation from cached position)
### Scaled Dot-Product Attention (SDPA)
Implemented via the fused `mlx_fast_scaled_dot_product_attention` kernel with two variants:
- `ScaledDotProductAttention(q, k, v, scale, causal)` — causal masking handled internally by the kernel
- `ScaledDotProductAttentionWithMask(q, k, v, mask, scale)` — explicit additive mask (0 = attend, -inf = ignore), used for batched inference with padding
Scale = `1/sqrt(head_dim)`, precomputed at load time.
After SDPA, output is transposed from `[B, H, L, D]` back to `[B, L, H*D]` via `Reshape(Transpose(out, 0, 2, 1, 3), ...)` for the output projection.
---
## KV Cache
The `Cache` interface provides `Update(k, v *Array, seqLen int) (*Array, *Array)`, returning the full accumulated K/V to pass to SDPA. `Offset()` tracks total tokens processed for RoPE continuation.
### KVCache (Unbounded)
Pre-allocates in 256-token chunks, growing as needed. On each decode step:
1. Checks whether the current buffer capacity is sufficient
2. If not, allocates a new chunk and concatenates it
3. Writes the new K/V via `SliceUpdateInplace`
4. Returns a slice view `[0:offset]` of the buffer
This amortises allocation cost while keeping the returned slice valid for the SDPA call.
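The growth policy can be sketched with a flat slice standing in for the K/V tensors (names and the one-float-per-token simplification are illustrative):

```go
package main

import "fmt"

const chunk = 256 // tokens per allocation, as in KVCache

// kvBuf models the unbounded cache's growth: capacity grows in 256-token
// chunks, each step writes in place, and the [0:offset] view is returned.
type kvBuf struct {
	data   []float32 // one float per token here; real K/V are [B,H,cap,D]
	offset int
}

func (c *kvBuf) update(tokens []float32) []float32 {
	need := c.offset + len(tokens)
	for len(c.data) < need {
		c.data = append(c.data, make([]float32, chunk)...) // grow a chunk
	}
	copy(c.data[c.offset:], tokens) // SliceUpdateInplace analogue
	c.offset = need
	return c.data[:c.offset] // view handed to the SDPA call
}

func main() {
	var c kvBuf
	c.update(make([]float32, 300)) // prefill: spans two chunks
	out := c.update([]float32{1}) // decode: one token, no reallocation
	fmt.Println(len(out), len(c.data))
}
```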
### RotatingKVCache (Sliding Window)
Bounded to `maxSize` tokens. Two update paths:
- Prefill (`seqLen > 1`): concatenate, then trim the leading tokens that fall outside the window
- Decode (`seqLen == 1`): write in-place at circular index `idx % maxSize`
Used for Gemma 3 sliding-window attention layers (window size from `sliding_window` config field). Qwen and Llama use only unbounded caches.
---
## Tokeniser
`Tokenizer` supports two BPE variants detected at load time from `tokenizer.json`:
### SentencePiece BPE (Gemma 3)
- Prefix each segment with `▁` (Unicode U+2581, the SentencePiece space marker)
- Split into characters
- Apply BPE merges via `bpeMerge()` using a rank-sorted lookup table
- Look up merged symbols in the vocabulary
Detection: checks for absence of `Ġthe` in the vocabulary. Large SentencePiece vocabularies (Gemma 3 at 262K entries) may contain `Ġ` as an unrelated character, so the detection checks `Ġthe` rather than bare `Ġ`.
### GPT-2 Byte-Level BPE (Qwen, Llama, DeepSeek)
- Maps all 256 bytes to printable Unicode via `buildGPT2ByteMaps()`
- Printable ASCII (33–126) and Latin-1 Supplement (161–172, 174–255) map to themselves
- The remaining gap bytes (0–32 including space, 127–160 including DEL, and 173) map to U+0100 onwards
- Apply BPE merges in this Unicode representation, then look up in vocabulary
Detection: presence of `Ġthe` in the vocabulary.
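The byte map above follows the published GPT-2 algorithm and can be reconstructed directly; note that byte 32 (space) lands on `Ġ` (U+0120), which is exactly why the `Ġthe` vocabulary check works (function name illustrative):

```go
package main

import "fmt"

// buildByteToRune reproduces the standard GPT-2 byte-level mapping:
// printable ranges map to themselves, everything else to U+0100 onwards
// in byte order.
func buildByteToRune() [256]rune {
	var m [256]rune
	isPrintable := func(b int) bool {
		return (b >= 33 && b <= 126) || (b >= 161 && b <= 172) || (b >= 174 && b <= 255)
	}
	next := rune(0x0100)
	for b := 0; b < 256; b++ {
		if isPrintable(b) {
			m[b] = rune(b)
		} else {
			m[b] = next
			next++
		}
	}
	return m
}

func main() {
	m := buildByteToRune()
	// 'A' is printable; space and NUL are remapped above U+0100.
	fmt.Printf("%c %c %c\n", m['A'], m[' '], m[0])
}
```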
### BPE Merge Algorithm
`bpeMerge()` implements the standard greedy algorithm:
1. Build merge rank table from `tokenizer.json` merges field (O(1) lookup by `"a b"` key)
2. Scan all adjacent pairs; find the pair with the lowest rank
3. Merge that pair into a single symbol
4. Repeat until no merge can be applied
Merges are parsed from both `["a b", ...]` and `[["a","b"], ...]` JSON formats.
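The steps above can be sketched as a simplified `bpeMerge` (not the production implementation, which uses a rank-sorted lookup table):

```go
package main

import (
	"fmt"
	"strings"
)

// bpeMerge applies the greedy algorithm: repeatedly merge the adjacent
// pair with the lowest rank until no mergeable pair remains.
// ranks is keyed "a b", as parsed from tokenizer.json merges.
func bpeMerge(syms []string, ranks map[string]int) []string {
	for {
		best, bestRank := -1, int(^uint(0)>>1)
		for i := 0; i+1 < len(syms); i++ {
			if r, ok := ranks[syms[i]+" "+syms[i+1]]; ok && r < bestRank {
				best, bestRank = i, r
			}
		}
		if best < 0 {
			return syms // no applicable merge left
		}
		merged := syms[best] + syms[best+1]
		syms = append(syms[:best+1], syms[best+2:]...)
		syms[best] = merged
	}
}

func main() {
	ranks := map[string]int{"t h": 0, "th e": 1}
	fmt.Println(strings.Join(bpeMerge([]string{"t", "h", "e"}, ranks), "|"))
}
```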
### Special Token Handling
Special tokens (BOS, EOS, chat delimiters) are matched before BPE encoding. Each architecture family uses different stop tokens:
| Family | BOS | EOS / Stop |
|--------|-----|-----------|
| Gemma 3 | `<bos>` | `<end_of_turn>` |
| Qwen 2/3 | `<\|im_start\|>` | `<\|im_end\|>` |
| Llama 3 | `<\|begin_of_text\|>` | `<\|eot_id\|>` |
---
## Generation Loop
`Model.Generate(ctx, prompt, cfg)` returns `iter.Seq[Token]` (range-over-func). The iterator:
1. Encodes the prompt via `Tokenizer.Encode()`
2. Allocates per-layer KV caches via `newCaches()`
3. Prefill: runs `model.Forward(tokens, caches)` on the full prompt in one pass; records prefill timing
4. Decode loop, up to `MaxTokens`:
- Checks `ctx.Done()`; sets `m.lastErr = ctx.Err()` and returns on cancellation
- Slices last-position logits via `SliceAxis`
- Applies `applyRepeatPenalty` if `RepeatPenalty > 1.0`
- Samples via the configured `Sampler` chain; calls `Eval()` and propagates any GPU error
- Checks EOS token and `StopTokens` IDs
- Yields `Token{ID, Text}` to the consumer; stops if `yield` returns false
- Runs `model.Forward({next_token}, caches)` with the single new token
5. Records decode timing and memory metrics in `m.lastMetrics` via deferred closure
`Model.Err()` returns the error from the most recent `Generate` or `Chat` call.
### Repeat Penalty
`applyRepeatPenalty(logits, history, penalty)` deduplicates the history, gathers logits at those positions, then applies:
- Positive logits: divide by `penalty` (reduces probability)
- Negative logits: multiply by `penalty` (increases magnitude, reducing probability further)
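A sketch of the rule over a plain logits slice (the real code gathers and scatters on GPU arrays):

```go
package main

import "fmt"

// applyRepeatPenalty deduplicates history, then pushes the probability of
// every seen token down: positive logits shrink by division, negative
// logits grow in magnitude by multiplication.
func applyRepeatPenalty(logits []float32, history []int, penalty float32) {
	seen := map[int]bool{}
	for _, id := range history {
		if seen[id] || id < 0 || id >= len(logits) {
			continue
		}
		seen[id] = true
		if logits[id] > 0 {
			logits[id] /= penalty
		} else {
			logits[id] *= penalty
		}
	}
}

func main() {
	logits := []float32{2.0, -1.0, 0.5}
	applyRepeatPenalty(logits, []int{0, 1, 1}, 1.5)
	fmt.Println(logits)
}
```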
### Chat Templates
`Model.Chat()` formats messages through `formatChat()` before calling `Generate()`:
| Architecture | Format |
|-------------|--------|
| Gemma 3 | `<start_of_turn>role\ncontent<end_of_turn>\n` |
| Qwen 2/3 | `<\|im_start\|>role\ncontent<\|im_end\|>\n` |
| Llama 3 | `<\|start_header_id\|>role<\|end_header_id\|>\n\ncontent<\|eot_id\|>` |
---
## Batch Inference
### Classify (Prefill-Only)
`Model.Classify(ctx, prompts, cfg, returnLogits)` runs a single forward pass per batch — no decode loop. The batch is right-padded to the length of the longest prompt:
1. Tokenise all prompts
2. Sort by descending token count
3. Build a `[N, 1, L, L]` attention mask combining causal masking and padding (0 = attend, -inf = ignore)
4. Run `ForwardMasked(tokens, mask, caches)` on the padded batch
5. Extract last-position logits for each prompt
6. Sample or return raw logits per the configuration
Measured throughput on M3 Ultra: 152 prompts/s for 4-prompt batches (Gemma3-1B 4-bit).
### BatchGenerate (Autoregressive Batches)
`Model.BatchGenerate(ctx, prompts, cfg)` runs full autoregressive generation for multiple prompts using the same masking approach as `Classify`. Each decode step processes the entire batch in one `ForwardMasked` call. Returns `[]BatchResult`, each holding the generated tokens and any per-prompt error.
---
## Sampling Chain
`newSampler(temp, topP, minP, topK)` builds a composable pipeline:
```
TopP -> MinP -> TopK -> Temperature -> RandomCategorical
```
If `temp == 0`, the chain collapses to greedy (argmax). Otherwise, each filter stage masks logits before the final categorical sample.
- **Greedy**: `Argmax(logits, -1)`
- **Temperature**: multiply logits by `1/temp`
- **TopK**: mask all but the K highest logits with `-inf`
- **TopP (nucleus)**: keep the smallest set with cumulative probability exceeding `p`; implemented via argsort, cumsum, and `PutAlongAxis` scatter back to original positions
- **MinP**: mask tokens whose probability falls below `min_p * max_probability`
---
## Training Pipeline
### LoRA Fine-Tuning
`InternalModel.ApplyLoRA(cfg)` wraps target projection layers in-place. The `LoRALinear` struct:
```go
type LoRALinear struct {
Base *Linear // frozen base weights (may be quantised)
A *Array // [rank, in_features] — Kaiming normal initialisation
B *Array // [out_features, rank] — zero initialisation
Scale float32 // alpha / rank
}
```
Forward pass: `base(x) + scale * (x @ A^T) @ B^T`
B is zero-initialised so LoRA starts as the identity transformation (no change to base output).
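A scalar sketch of why zero-initialised B makes the adapter start as the identity (row-major matrices and helper names are illustrative; with A stored `[rank, in]` and B `[out, rank]`, `B (A x)` matches the `(x @ A^T) @ B^T` form above):

```go
package main

import "fmt"

// matVec computes y = W x for W stored row-major as [out][in].
func matVec(W [][]float32, x []float32) []float32 {
	y := make([]float32, len(W))
	for i, row := range W {
		for j, w := range row {
			y[i] += w * x[j]
		}
	}
	return y
}

// loraForward computes base(x) + scale * B (A x): with B all zeros the
// adapter contributes nothing, so training starts from the base model.
func loraForward(base, A, B [][]float32, x []float32, scale float32) []float32 {
	y := matVec(base, x)
	delta := matVec(B, matVec(A, x)) // rank-r bottleneck
	for i := range y {
		y[i] += scale * delta[i]
	}
	return y
}

func main() {
	base := [][]float32{{1, 0}, {0, 1}}
	A := [][]float32{{0.5, 0.5}} // [rank=1, in=2], would be Kaiming init
	B := [][]float32{{0}, {0}}   // [out=2, rank=1], zero init
	fmt.Println(loraForward(base, A, B, []float32{3, 4}, 2.0))
}
```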
`LoRAAdapter` collects all `LoRALinear` instances by weight path key. `AllTrainableParams()` returns A and B arrays in deterministic sorted order for use with `ValueAndGrad`. `LoRAAdapter.Save(path)` writes only the A and B matrices to safetensors (not the frozen base weights).
### Gradient Computation
Three autodiff interfaces via mlx-c:
- `VJP(fn, primals, cotangents)` — reverse mode (backward pass)
- `JVP(fn, primals, tangents)` — forward mode (directional derivative)
- `ValueAndGrad(fn, argnums)` — returns a `GradFn` that computes both value and gradients in one call
Go functions are registered as mlx-c closures via `goGradFunc` (exported CGO callback) using an atomic ID registry (`gradNextID atomic.Uintptr`).
### Gradient Checkpointing
`Checkpoint(fn)` wraps a function using `mlx_checkpoint`, which recomputes intermediate activations during the backward pass rather than storing them. Trades compute for GPU memory — useful for large models on constrained hardware.
### Mixed Precision
`LoRAConfig.DType` selects the dtype for A and B matrices. `DTypeBFloat16` halves parameter memory with accuracy matching Float32 in practice (validated: loss 7.15→6.29 in 5 steps). MLX auto-promotes operands for cross-dtype operations.
### AdamW Optimiser
Standard AdamW with decoupled weight decay:
```
m = beta1*m + (1-beta1)*grad
v = beta2*v + (1-beta2)*grad^2
m_hat = m / (1 - beta1^t);  v_hat = v / (1 - beta2^t)
param = param*(1 - lr*wd) - lr * m_hat / (sqrt(v_hat) + eps)
```
Defaults: `lr=1e-5`, `beta1=0.9`, `beta2=0.999`, `eps=1e-8`, `weight_decay=0.01`.
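One update step in plain Go, including the bias-corrected moments from the formula above (a reference sketch, not the GPU implementation):

```go
package main

import (
	"fmt"
	"math"
)

// adamwStep applies one decoupled-weight-decay update in place.
// t is the 1-based step count used for bias correction.
func adamwStep(param, m, v, grad []float64, lr, beta1, beta2, eps, wd float64, t int) {
	for i := range param {
		m[i] = beta1*m[i] + (1-beta1)*grad[i]
		v[i] = beta2*v[i] + (1-beta2)*grad[i]*grad[i]
		mHat := m[i] / (1 - math.Pow(beta1, float64(t)))
		vHat := v[i] / (1 - math.Pow(beta2, float64(t)))
		// Weight decay multiplies the parameter directly (decoupled),
		// rather than being folded into the gradient as in plain Adam.
		param[i] = param[i]*(1-lr*wd) - lr*mHat/(math.Sqrt(vHat)+eps)
	}
}

func main() {
	p := []float64{1.0}
	m, v := []float64{0}, []float64{0}
	adamwStep(p, m, v, []float64{0.5}, 1e-5, 0.9, 0.999, 1e-8, 0.01, 1)
	fmt.Println(p[0]) // nudged against the gradient and decayed slightly
}
```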
### Loss Functions
- `CrossEntropyLoss(logits, targets)` — numerically stable via logsumexp; averaged over all positions
- `MaskedCrossEntropyLoss(logits, targets, mask)` — averaged over masked positions only
- `MSELoss(predictions, targets)` — mean squared error
---
## Fused Metal Kernels
`internal/metal/fast.go` wraps four mlx-c fused kernels:
| Kernel | Go function | Notes |
|--------|------------|-------|
| `mlx_fast_rms_norm` | `RMSNorm(x, weight, eps)` | Gemma uses pre-scaled `(1+weight)` |
| `mlx_fast_layer_norm` | `LayerNorm(x, weight, bias, eps)` | Standard layer norm |
| `mlx_fast_rope` | `RoPE(x, dims, traditional, base, scale, offset)` | Rotary position embeddings |
| `mlx_fast_scaled_dot_product_attention` | `ScaledDotProductAttention(...)` | Causal or explicit mask |
These bypass the general MLX computation graph, dispatching directly to optimised Metal compute shaders.
---
## go-inference Integration
The public API is provided entirely by `forge.lthn.ai/core/go-inference`. go-mlx exports only Metal-specific controls:
- `MetalAvailable() bool` — hardware check
- `SetCacheLimit`, `SetMemoryLimit`, `GetActiveMemory`, `GetPeakMemory`, `ClearCache`, `GetCacheMemory`, `ResetPeakMemory`, `SetWiredLimit`, `GetDeviceInfo`
`register_metal.go` auto-registers `metalBackend` via `init()` on darwin/arm64. The adapter (`metalAdapter`) converts between `inference.*` types and `metal.*` types, implementing: `Generate`, `Chat`, `Classify`, `BatchGenerate`, `Metrics`, `Info`, `ModelType`, `Err`, `Close`.
Consumer pattern:
```go
import (
"forge.lthn.ai/core/go-inference"
_ "forge.lthn.ai/core/go-mlx"
)
m, err := inference.LoadModel("/path/to/model/")
for tok := range m.Generate(ctx, "prompt", inference.WithMaxTokens(128)) {
fmt.Print(tok.Text)
}
```
`inference.LoadConfig` options understood by the Metal backend:
- `ContextLen` — replaces unbounded `KVCache` with `RotatingKVCache(contextLen)` for all layers
- `GPULayers` — logged as a warning if set to 0 (Metal always uses full GPU offload)
---
## mlxlm Subprocess Backend
`mlxlm/` provides a second backend (`"mlx_lm"`) that does not require CGO. It spawns a Python 3 process running the embedded `bridge.py` script and communicates via JSON Lines over stdin/stdout.
### Protocol
Commands sent to stdin (newline-delimited JSON):
| Command | Request fields | Response |
|---------|---------------|----------|
| `load` | `path` | `{ok, model_type, vocab_size}` or `{error}` |
| `generate` | `prompt, max_tokens, temperature?, top_k?, top_p?` | stream of `{token, token_id}`, then `{done}` |
| `chat` | `messages, max_tokens, ...` | same as generate |
| `info` | — | `{model_type, vocab_size, layers, hidden_size}` |
| `cancel` | — | subprocess drains and returns `{done}` |
| `quit` | — | subprocess exits cleanly |
Concurrent `Generate`/`Chat` calls are serialised via `sync.Mutex` (one generation at a time per subprocess instance).
### bridge.py
Embedded via `//go:embed bridge.py`, extracted to a temp file on first use via `sync.Once`. Uses `mlx_lm.load()` and `mlx_lm.stream_generate()` from the `mlx-lm` Python package. Flushes stdout after every line (critical for streaming).
### Limitations
- `Classify` and `BatchGenerate` are not supported (return error directing caller to use the native Metal backend)
- No inference metrics (`Metrics()` returns zero values)
- Requires Python 3 and `mlx-lm` installed in the Python environment
- Build tag `nomlxlm` removes the backend entirely
---
## Downstream Consumers
| Package | Role |
|---------|------|
| `forge.lthn.ai/core/go-ml` | Imports go-inference + go-mlx for Metal backend |
| `forge.lthn.ai/core/go-i18n` | Phase 2a: Gemma3-1B domain classification |
| `forge.lthn.ai/core/go-rocm` | Sibling AMD GPU backend, same go-inference interfaces |
---
## Performance Baseline (M3 Ultra, 60-core GPU, 96 GB unified memory)
| Operation | Throughput |
|-----------|-----------|
| Gemma3-1B 4-bit prefill | 246 tok/s |
| Gemma3-1B 4-bit decode | 82 tok/s |
| Gemma3-1B 4-bit classify (4 prompts) | 152 prompts/s |
| DeepSeek R1 7B 4-bit decode | 27 tok/s |
| Llama 3.1 8B 4-bit decode | 30 tok/s |
CGO call overhead floors at approximately 170 µs per operation (Metal command buffer + CGO boundary). MatMul scales well: 128² to 4096² is roughly 55× slower for 1024× more work. Full sampling chain (TopP+MinP+TopK) adds approximately 560 µs over greedy per token.