# Architecture

Module: `forge.lthn.ai/core/go-mlx`

Native Apple Metal GPU inference via mlx-c bindings, implementing the `inference.Backend` interface from `forge.lthn.ai/core/go-inference` for Apple Silicon (M1-M4).

---

## Package Layout

```
go-mlx/
├── mlx.go                 # Package doc + go:generate CMake directives
├── mlx_stub.go            # !darwin || !arm64: MetalAvailable() = false
├── register_metal.go      # darwin && arm64: registers "metal" backend via init()
├── mlx_test.go            # Integration tests (public API via go-inference)
│
├── internal/metal/        # All CGO code (darwin && arm64 only)
│   ├── metal.go           # Init, error handler, Eval/EvalAsync/Materialize
│   ├── array.go           # Array type, creation, data access, Iter()
│   ├── dtype.go           # DType constants
│   ├── stream.go          # Metal stream/queue, memory controls
│   ├── ops.go             # Element-wise, reduction, matrix, shape ops
│   ├── fast.go            # Fused Metal kernels: RMSNorm, LayerNorm, RoPE, SDPA
│   ├── nn.go              # Linear, Embedding, RMSNormModule, RepeatKV
│   ├── compile.go         # CompiledFunc (shapeless function compilation)
│   ├── slice.go           # Array slicing, update-in-place
│   ├── random.go          # RandomCategorical, RandomUniform, RandomNormal
│   ├── io.go              # Safetensors load/save
│   ├── model.go           # InternalModel interface + architecture dispatch
│   ├── gemma3.go          # Gemma 3 decoder
│   ├── qwen3.go           # Qwen 2/3 and Llama 3 decoder
│   ├── cache.go           # KVCache + RotatingKVCache
│   ├── sample.go          # Sampling chain: greedy, temperature, TopK, TopP, MinP
│   ├── tokenizer.go       # BPE tokenizer (SentencePiece + GPT-2 byte-level)
│   ├── grad.go            # VJP, JVP, ValueAndGrad, Checkpoint, loss functions
│   ├── lora.go            # LoRA adapters, random normal, safetensors save
│   ├── optim.go           # AdamW optimiser
│   ├── generate.go        # Model, Generate, Chat, batch inference, metrics
│   ├── close.go           # Deterministic weight cleanup
│   └── backend.go         # LoadAndInit entry point
│
└── mlxlm/                 # Python subprocess backend
    ├── backend.go         # mlxlmBackend implementing inference.Backend
    └── bridge.py          # Python script (embedded via //go:embed)
```

---

## CGO / mlx-c Binding

### Build Chain

The native layer depends on mlx-c v0.4.1, a C API wrapping Apple's MLX C++ framework. `go generate ./...` fetches and builds it via CMake:

```
go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist ...
go:generate cmake --build build --parallel
go:generate cmake --install build
```

CMake installs headers to `dist/include/` and shared libraries to `dist/lib/`. The `#cgo` directives in `internal/metal/metal.go` reference those paths:

```
CPPFLAGS: -I${SRCDIR}/../../dist/include
LDFLAGS:  -L${SRCDIR}/../../dist/lib -lmlxc -lmlx
darwin:   -framework Foundation -framework Metal -framework Accelerate
          -Wl,-rpath,${SRCDIR}/../../dist/lib
```

Every Go source file in `internal/metal/` carries `//go:build darwin && arm64`. The root package compiles on all platforms; only the blank import of `_ "forge.lthn.ai/core/go-mlx"` triggers the Metal backend on supported hardware.

### Error Handling

mlx-c reports errors through a registered C callback. The handler stores the error string in a C atomic variable using `atomic_store_explicit` with release ordering. `lastError()` reads and atomically clears it with acquire ordering, returning a Go error.

`Eval()` checks the mlx return code and calls `lastError()` to surface real MLX messages. `Materialize()` wraps `Eval()` and logs on error without returning; callers that need propagation call `Eval()` directly.

### Evaluation Model

MLX uses lazy evaluation: operations build a computation graph without executing. Execution is triggered by `mlx_eval` or `mlx_async_eval`, which dispatch the graph to the Metal GPU. Go wrappers:

- `Eval(...*Array) error`: synchronous, returns error
- `EvalAsync(...*Array) error`: queues for async execution
- `Materialize(...*Array)`: synchronous, logs error (used in test helpers and weight loading)

---

## Array Type

`Array` wraps an `mlx_array` (a C-side opaque handle).
Arrays are reference-counted on the C side; Go uses `runtime.SetFinalizer` to call `mlx_array_free` when the Go object is collected. Go 1.26's Green Tea GC reduces finaliser latency under sustained inference.

Key operations:

- Creation: `newArray()`, `FromValue()`, `FromValues()`, `Zeros()`, `Ones()`
- Data access: `Floats()`, `DataInt32()`, `Int()`. All of these call `ensureContiguous()` first to handle view arrays (transpose, broadcast, slice views) that have non-contiguous physical layouts. Previously, reading views returned silently incorrect data.
- Shape: `Shape()`, `Dim()`, `Size()`
- Iteration: `Array.Iter() iter.Seq[float32]`, a range-over-func iterator (stable since Go 1.23) that handles non-contiguous arrays

---

## Memory Management

The Metal allocator (separate from system RAM) is controlled via functions exposed at the root package level:

| Function | Purpose |
|----------|---------|
| `SetCacheLimit(bytes)` | Soft limit on allocator cache |
| `SetMemoryLimit(bytes)` | Hard limit |
| `SetWiredLimit(bytes)` | Wired memory limit |
| `GetActiveMemory()` | Current live allocations |
| `GetPeakMemory()` | High-water mark since last reset |
| `GetCacheMemory()` | Cached (not yet freed) memory |
| `ClearCache()` | Release cached memory to OS |
| `ResetPeakMemory()` | Reset high-water mark |

`Model.Close()` walks the full model tree and explicitly frees all weight arrays via `Free()`, without relying on GC finalisers. Tied output weights (shared with the embedding table) are detected and skipped to prevent double-free. `Close()` is idempotent.

During generation, each call allocates fresh KV caches that are released to GC at iterator completion. Call `ClearCache()` between turns of a multi-turn chat to reclaim that memory promptly rather than waiting for GC.
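
The tied-weight skip and idempotent `Close()` described above can be sketched as follows. This is a minimal illustration only: the `tensor` and `model` types are hypothetical stand-ins, and pointer identity is assumed as the way shared weights are detected, which may differ from what `close.go` actually does.

```go
package main

import "fmt"

// tensor is a hypothetical stand-in for a GPU-backed weight array.
type tensor struct {
	name  string
	freed bool
}

// Free marks the tensor as released and panics on a second release,
// mimicking how a double free on a native handle would be fatal.
func (t *tensor) Free() {
	if t.freed {
		panic("double free: " + t.name)
	}
	t.freed = true
}

// model is a hypothetical model whose output head may alias the embedding.
type model struct {
	embed  *tensor
	layers []*tensor
	output *tensor
	closed bool
}

// Close frees each distinct tensor once and returns how many were freed.
// Repeated calls are no-ops.
func (m *model) Close() int {
	if m.closed {
		return 0
	}
	m.closed = true
	seen := make(map[*tensor]bool)
	freed := 0
	release := func(t *tensor) {
		if t == nil || seen[t] {
			return // nil, or already freed through another reference (tied weights)
		}
		seen[t] = true
		t.Free()
		freed++
	}
	release(m.embed)
	for _, w := range m.layers {
		release(w)
	}
	release(m.output)
	return freed
}

func main() {
	emb := &tensor{name: "embed_tokens"}
	m := &model{
		embed:  emb,
		layers: []*tensor{{name: "layers.0"}, {name: "layers.1"}},
		output: emb, // tied: same pointer as the embedding table
	}
	fmt.Println(m.Close()) // 3 distinct tensors freed
	fmt.Println(m.Close()) // 0, already closed
}
```

Tracking visited pointers keeps the walk correct even if other weights turn out to be shared, at the cost of a small map allocated once per `Close()`.
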

---

## Model Architectures

All architectures implement the `InternalModel` interface:

```go
type InternalModel interface {
    Forward(tokens *Array, caches []Cache) *Array
    ForwardMasked(tokens *Array, mask *Array, caches []Cache) *Array
    NewCache() []Cache
    NumLayers() int
    Tokenizer() *Tokenizer
    ModelType() string
    ApplyLoRA(cfg LoRAConfig) *LoRAAdapter
}
```

Architecture is detected from `config.json` → `model_type` field:

| `model_type` values | Loader | Notes |
|---------------------|--------|-------|
| `gemma3`, `gemma3_text`, `gemma2` | `LoadGemma3` | Gemma 3 decoder |
| `qwen3`, `qwen2`, `llama` | `LoadQwen3` | Shared decoder, variant-specific features |

### Gemma 3

Decoder structure per layer (pre-norm with four norm points per block):

```
input → InputNorm → Attention → PostAttnNorm → residual add
      → PreFFNorm → MLP → PostFFNorm → residual add
```

Attention specifics:

- Q/K RMS normalisation (separate `QNorm`, `KNorm` modules)
- Alternating sliding window / global attention: sliding layers use `RopeLocalBaseFreq` (10000), global layers use `RopeTheta` (1000000). Pattern period determined by `sliding_window_pattern` (default 6)
- Rotary embeddings via fused RoPE Metal kernel with per-layer theta
- Grouped-query attention (GQA): K/V heads repeated via `RepeatKV` when `num_kv_heads < num_attention_heads`

MLP: GELU-based gate using tanh approximation. The GELU function is compiled via `CompileShapeless` (shapeless function compilation) as a singleton to avoid recompilation across calls.

Normalisation: Gemma uses `(1 + weight) * RMSNorm(x)`. The `(1 + weight)` factor is precomputed at load time (`precomputeScaledWeights`) for all seven norm points per layer to avoid repeated additions during inference.

Embedding scale: hidden states are multiplied by `sqrt(hidden_size)` after embedding lookup (Gemma-specific convention). Qwen and Llama do not apply this scale.

Output head: Gemma 3 typically ties `lm_head` weights to `embed_tokens`.
If a separate `lm_head.weight` is present in the safetensors, it is used as an independent output projection.

### Qwen 3 / Qwen 2 / Llama 3

These three architectures share one loader (`LoadQwen3`) and one decoder implementation. Distinctions:

| Feature | Qwen 3 | Qwen 2 | Llama 3 |
|---------|--------|--------|---------|
| Q/K norm | Yes | No | No |
| Sliding window | No | No | No |
| EOS token | `<\|im_end\|>` | `<\|im_end\|>` | `<\|eot_id\|>` |
| BOS token | `<\|im_start\|>` | `<\|im_start\|>` | `<\|begin_of_text\|>` |

Qwen 2 detection: if `model_type` is absent from config, weight presence of `model.layers.0.self_attn.q_norm.weight` distinguishes Qwen 3 (present) from Qwen 2 (absent).

Decoder structure per layer (standard pre-norm):

```
input → InputNorm → Attention → residual add → PostAttnNorm → MLP → residual add
```

MLP: SwiGLU gate, `down(silu(gate(x)) * up(x))`.

Output head: always a separate `lm_head.weight` (Qwen 3 has `tie_word_embeddings=false`).

### Weight Loading

All architectures load from HuggingFace safetensors format (not GGUF). The loader:

1. Reads `config.json` for model configuration
2. Loads `tokenizer.json` for the tokeniser
3. Glob-matches all `*.safetensors` files in the directory (multi-shard support)
4. Calls `LoadSafetensors` per shard; checks `lastError()` after each
5. Resolves weights by name, with automatic `language_model.` prefix fallback via `resolveWeight()`
6. Constructs `Linear` layers as quantised or dense based on presence of `scales` tensors
7. Calls `Materialize()` on all weight arrays to commit them to GPU memory

Quantisation is transparent: `NewQuantizedLinear` stores packed weights with `scales` and `biases`, dispatching to `QuantizedMatmul` (mlx-c grouped quantisation) in `Forward`. Quantisation parameters (`bits`, `group_size`) are read from top-level `config.json`.
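
Steps 5 and 6 of the loader can be illustrated with a small lookup sketch. The generic signature, the lookup order (bare name first, then the `language_model.` prefix) and the `isQuantised` helper are assumptions made for illustration; only the function name `resolveWeight` and the prefix come from the description above.

```go
package main

import "fmt"

// resolveWeight looks a tensor up by its canonical name and, if that is
// absent, retries with a "language_model." prefix. The lookup order is an
// assumption, not a statement about the real loader.
func resolveWeight[T any](weights map[string]T, name string) (T, bool) {
	if w, ok := weights[name]; ok {
		return w, true
	}
	w, ok := weights["language_model."+name]
	return w, ok
}

// isQuantised is a hypothetical helper: a projection is treated as
// quantised when a sibling "<prefix>.scales" tensor can be resolved.
func isQuantised[T any](weights map[string]T, prefix string) bool {
	_, ok := resolveWeight(weights, prefix+".scales")
	return ok
}

func main() {
	weights := map[string][]float32{
		"language_model.model.embed_tokens.weight": {0.1, 0.2},
		"model.layers.0.mlp.up_proj.weight":        {1, 2},
		"model.layers.0.mlp.up_proj.scales":        {0.5},
	}

	_, ok := resolveWeight(weights, "model.embed_tokens.weight")
	fmt.Println(ok) // true, found via the prefixed name

	fmt.Println(isQuantised(weights, "model.layers.0.mlp.up_proj")) // true
	fmt.Println(isQuantised(weights, "model.embed_tokens"))         // false
}
```
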

Head dimension inference: if `head_dim` is absent from `config.json` (as with some Gemma 3 variants), it is inferred as the first dimension of `q_proj.weight` divided by `num_attention_heads`.

---

## Attention Mechanism

### Virtual Transpose

Linear projections produce `[B, L, H*D]`. The reshape to `[B, H, L, D]` is implemented via `AsStrided`, a zero-copy stride manipulation that avoids a physical copy:

```
shape:   [B, H, L, D]
strides: [L*H*D, D, H*D, 1]
```

- Batch stride: `L*H*D` (jump entire sequence)
- Head stride: `D` (adjacent heads are contiguous in memory)
- Sequence stride: `H*D` (jump one full row of heads)
- Element stride: `1` (contiguous within head)

The result is a non-contiguous view used for RoPE and SDPA calls.

### Rotary Position Embeddings (RoPE)

Applied via the fused `mlx_fast_rope` Metal kernel. Parameters:

- `dims`: head dimension
- `traditional`: false (standard non-interleaved layout)
- `base`: theta (varies by layer type in Gemma 3; single value for Qwen/Llama)
- `scale`: 1.0 (no frequency scaling)
- `offset`: current KV cache offset (enables continuation from cached position)

### Scaled Dot-Product Attention (SDPA)

Implemented via the fused `mlx_fast_scaled_dot_product_attention` kernel with two variants:

- `ScaledDotProductAttention(q, k, v, scale, causal)`: causal masking handled internally by the kernel
- `ScaledDotProductAttentionWithMask(q, k, v, mask, scale)`: explicit additive mask (0 = attend, -inf = ignore), used for batched inference with padding

Scale = `1/sqrt(head_dim)`, precomputed at load time.

After SDPA, output is transposed from `[B, H, L, D]` back to `[B, L, H*D]` via `Reshape(Transpose(out, 0, 2, 1, 3), ...)` for the output projection.

---

## KV Cache

The `Cache` interface provides `Update(k, v *Array, seqLen int) (*Array, *Array)`, returning the full accumulated K/V to pass to SDPA. `Offset()` tracks total tokens processed for RoPE continuation.

### KVCache (Unbounded)

Pre-allocates in 256-token chunks, growing as needed.
On each decode step:

1. Checks whether the current buffer capacity is sufficient
2. If not, allocates a new chunk and concatenates it
3. Writes the new K/V via `SliceUpdateInplace`
4. Returns a slice view `[0:offset]` of the buffer

This amortises allocation cost while keeping the returned slice valid for the SDPA call.

### RotatingKVCache (Sliding Window)

Bounded to `maxSize` tokens. Two update paths:

- Prefill (`seqLen > 1`): concatenate, then trim the leading tokens that fall outside the window
- Decode (`seqLen == 1`): write in-place at circular index `idx % maxSize`

Used for Gemma 3 sliding-window attention layers (window size from `sliding_window` config field). Qwen and Llama use only unbounded caches.

---

## Tokeniser

`Tokenizer` supports two BPE variants detected at load time from `tokenizer.json`:

### SentencePiece BPE (Gemma 3)

- Prefix each segment with `▁` (Unicode U+2581, the SentencePiece space marker)
- Split into characters
- Apply BPE merges via `bpeMerge()` using a rank-sorted lookup table
- Look up merged symbols in the vocabulary

Detection: checks for absence of `Ġthe` in the vocabulary. Large SentencePiece vocabularies (Gemma 3 at 262K entries) may contain `Ġ` as an unrelated character, so the detection checks `Ġthe` rather than bare `Ġ`.

### GPT-2 Byte-Level BPE (Qwen, Llama, DeepSeek)

- Maps all 256 bytes to printable Unicode via `buildGPT2ByteMaps()`
- Printable ASCII (33–126) and Latin-1 Supplement (161–172, 174–255) map to themselves
- Control characters, space (32), DEL (127), and gap values (0–32, 127–160, 173) map to U+0100 onwards
- Apply BPE merges in this Unicode representation, then look up in vocabulary

Detection: presence of `Ġthe` in the vocabulary.

### BPE Merge Algorithm

`bpeMerge()` implements the standard greedy algorithm:

1. Build merge rank table from `tokenizer.json` merges field (O(1) lookup by `"a b"` key)
2. Scan all adjacent pairs; find the pair with the lowest rank
3. Merge that pair into a single symbol
4. Repeat until no merge can be applied

Merges are parsed from both `["a b", ...]` and `[["a","b"], ...]` JSON formats.

### Special Token Handling

Special tokens (BOS, EOS, chat delimiters) are matched before BPE encoding. Each architecture family uses different stop tokens:

| Family | BOS | EOS / Stop |
|--------|-----|-----------|
| Gemma 3 | `` | `` |
| Qwen 2/3 | `<\|im_start\|>` | `<\|im_end\|>` |
| Llama 3 | `<\|begin_of_text\|>` | `<\|eot_id\|>` |

---

## Generation Loop

`Model.Generate(ctx, prompt, cfg)` returns `iter.Seq[Token]` (range-over-func). The iterator:

1. Encodes the prompt via `Tokenizer.Encode()`
2. Allocates per-layer KV caches via `newCaches()`
3. Prefill: runs `model.Forward(tokens, caches)` on the full prompt in one pass; records prefill timing
4. Decode loop, up to `MaxTokens`:
   - Checks `ctx.Done()`; sets `m.lastErr = ctx.Err()` and returns on cancellation
   - Slices last-position logits via `SliceAxis`
   - Applies `applyRepeatPenalty` if `RepeatPenalty > 1.0`
   - Samples via the configured `Sampler` chain; calls `Eval()` and propagates any GPU error
   - Checks EOS token and `StopTokens` IDs
   - Yields `Token{ID, Text}` to the consumer; stops if `yield` returns false
   - Runs `model.Forward({next_token}, caches)` with the single new token
5. Records decode timing and memory metrics in `m.lastMetrics` via deferred closure

`Model.Err()` returns the error from the most recent `Generate` or `Chat` call.

### Repeat Penalty

`applyRepeatPenalty(logits, history, penalty)` deduplicates the history, gathers logits at those positions, then applies:

- Positive logits: divide by `penalty` (reduces probability)
- Negative logits: multiply by `penalty` (increases magnitude, reducing probability further)

### Chat Templates

`Model.Chat()` formats messages through `formatChat()` before calling `Generate()`:

| Architecture | Format |
|-------------|--------|
| Gemma 3 | `role\ncontent\n` |
| Qwen 2/3 | `<\|im_start\|>role\ncontent<\|im_end\|>\n` |
| Llama 3 | `<\|start_header_id\|>role<\|end_header_id\|>\n\ncontent<\|eot_id\|>` |

---

## Batch Inference

### Classify (Prefill-Only)

`Model.Classify(ctx, prompts, cfg, returnLogits)` runs a single forward pass per batch, with no decode loop. The batch is right-padded to the length of the longest prompt:

1. Tokenise all prompts
2. Sort by descending token count
3. Build a `[N, 1, L, L]` attention mask combining causal masking and padding (0 = attend, -inf = ignore)
4. Run `ForwardMasked(tokens, mask, caches)` on the padded batch
5. Extract last-position logits for each prompt
6. Sample or return raw logits per the configuration

Measured throughput on M3 Ultra: 152 prompts/s for 4-prompt batches (Gemma3-1B 4-bit).

### BatchGenerate (Autoregressive Batches)

`Model.BatchGenerate(ctx, prompts, cfg)` runs full autoregressive generation for multiple prompts using the same masking approach as `Classify`. Each decode step processes the entire batch in one `ForwardMasked` call. Returns `[]BatchResult`, each holding the generated tokens and any per-prompt error.

---

## Sampling Chain

`newSampler(temp, topP, minP, topK)` builds a composable pipeline:

```
TopP -> MinP -> TopK -> Temperature -> RandomCategorical
```

If `temp == 0`, the chain collapses to greedy (argmax). Otherwise, each filter stage masks logits before the final categorical sample.

- **Greedy**: `Argmax(logits, -1)`
- **Temperature**: multiply logits by `1/temp`
- **TopK**: mask all but the K highest logits with `-inf`
- **TopP (nucleus)**: keep the smallest set with cumulative probability exceeding `p`; implemented via argsort, cumsum, and `PutAlongAxis` scatter back to original positions
- **MinP**: mask tokens whose probability falls below `min_p * max_probability`

---

## Training Pipeline

### LoRA Fine-Tuning

`InternalModel.ApplyLoRA(cfg)` wraps target projection layers in-place. The `LoRALinear` struct:

```go
type LoRALinear struct {
    Base  *Linear  // frozen base weights (may be quantised)
    A     *Array   // [rank, in_features], Kaiming normal initialisation
    B     *Array   // [out_features, rank], zero initialisation
    Scale float32  // alpha / rank
}
```

Forward pass: `base(x) + scale * (x @ A^T) @ B^T`

B is zero-initialised, so the low-rank term is zero at the start of training and the wrapped layer initially reproduces the base output exactly.

`LoRAAdapter` collects all `LoRALinear` instances by weight path key. `AllTrainableParams()` returns A and B arrays in deterministic sorted order for use with `ValueAndGrad`. `LoRAAdapter.Save(path)` writes only the A and B matrices to safetensors (not the frozen base weights).

### Gradient Computation

Three autodiff interfaces via mlx-c:

- `VJP(fn, primals, cotangents)`: reverse mode (backward pass)
- `JVP(fn, primals, tangents)`: forward mode (directional derivative)
- `ValueAndGrad(fn, argnums)`: returns a `GradFn` that computes both value and gradients in one call

Go functions are registered as mlx-c closures via `goGradFunc` (exported CGO callback) using an atomic ID registry (`gradNextID atomic.Uintptr`).

### Gradient Checkpointing

`Checkpoint(fn)` wraps a function using `mlx_checkpoint`, which recomputes intermediate activations during the backward pass rather than storing them. This trades compute for GPU memory, which is useful for large models on constrained hardware.

### Mixed Precision

`LoRAConfig.DType` selects the dtype for A and B matrices.
`DTypeBFloat16` halves parameter memory with accuracy matching Float32 in practice (validated: loss 7.15→6.29 in 5 steps). MLX auto-promotes operands for cross-dtype operations.

### AdamW Optimiser

Standard AdamW with decoupled weight decay:

```
m = beta1*m + (1-beta1)*grad
v = beta2*v + (1-beta2)*grad^2
param = param*(1 - lr*wd) - lr * m_hat / (sqrt(v_hat) + eps)
```

Defaults: `lr=1e-5`, `beta1=0.9`, `beta2=0.999`, `eps=1e-8`, `weight_decay=0.01`.

### Loss Functions

- `CrossEntropyLoss(logits, targets)`: numerically stable via logsumexp; averaged over all positions
- `MaskedCrossEntropyLoss(logits, targets, mask)`: averaged over masked positions only
- `MSELoss(predictions, targets)`: mean squared error

---

## Fused Metal Kernels

`internal/metal/fast.go` wraps four mlx-c fused kernels:

| Kernel | Go function | Notes |
|--------|------------|-------|
| `mlx_fast_rms_norm` | `RMSNorm(x, weight, eps)` | Gemma uses pre-scaled `(1+weight)` |
| `mlx_fast_layer_norm` | `LayerNorm(x, weight, bias, eps)` | Standard layer norm |
| `mlx_fast_rope` | `RoPE(x, dims, traditional, base, scale, offset)` | Rotary position embeddings |
| `mlx_fast_scaled_dot_product_attention` | `ScaledDotProductAttention(...)` | Causal or explicit mask |

These bypass the general MLX computation graph, dispatching directly to optimised Metal compute shaders.

---

## go-inference Integration

The public API is provided entirely by `forge.lthn.ai/core/go-inference`. go-mlx exports only Metal-specific controls:

- `MetalAvailable() bool`: hardware check
- `SetCacheLimit`, `SetMemoryLimit`, `GetActiveMemory`, `GetPeakMemory`, `ClearCache`, `GetCacheMemory`, `ResetPeakMemory`, `SetWiredLimit`, `GetDeviceInfo`

`register_metal.go` auto-registers `metalBackend` via `init()` on darwin/arm64. The adapter (`metalAdapter`) converts between `inference.*` types and `metal.*` types, implementing: `Generate`, `Chat`, `Classify`, `BatchGenerate`, `Metrics`, `Info`, `ModelType`, `Err`, `Close`.

Consumer pattern:

```go
import (
    "fmt"
    "log"

    "forge.lthn.ai/core/go-inference"
    _ "forge.lthn.ai/core/go-mlx" // blank import registers the "metal" backend
)

m, err := inference.LoadModel("/path/to/model/")
if err != nil {
    log.Fatal(err)
}
// ctx is a context.Context supplied by the caller.
for tok := range m.Generate(ctx, "prompt", inference.WithMaxTokens(128)) {
    fmt.Print(tok.Text)
}
```

`inference.LoadConfig` options understood by the Metal backend:

- `ContextLen`: replaces unbounded `KVCache` with `RotatingKVCache(contextLen)` for all layers
- `GPULayers`: logged as a warning if set to 0 (Metal always uses full GPU offload)

---

## mlxlm Subprocess Backend

`mlxlm/` provides a second backend (`"mlx_lm"`) that does not require CGO. It spawns a Python 3 process running the embedded `bridge.py` script and communicates via JSON Lines over stdin/stdout.

### Protocol

Commands sent to stdin (newline-delimited JSON):

| Command | Request fields | Response |
|---------|---------------|----------|
| `load` | `path` | `{ok, model_type, vocab_size}` or `{error}` |
| `generate` | `prompt, max_tokens, temperature?, top_k?, top_p?` | stream of `{token, token_id}`, then `{done}` |
| `chat` | `messages, max_tokens, ...` | same as generate |
| `info` | (none) | `{model_type, vocab_size, layers, hidden_size}` |
| `cancel` | (none) | subprocess drains and returns `{done}` |
| `quit` | (none) | subprocess exits cleanly |

Concurrent `Generate`/`Chat` calls are serialised via `sync.Mutex` (one generation at a time per subprocess instance).

### bridge.py

Embedded via `//go:embed bridge.py`, extracted to a temp file on first use via `sync.Once`. Uses `mlx_lm.load()` and `mlx_lm.stream_generate()` from the `mlx-lm` Python package. Flushes stdout after every line (critical for streaming).

### Limitations

- `Classify` and `BatchGenerate` are not supported (return error directing caller to use the native Metal backend)
- No inference metrics (`Metrics()` returns zero values)
- Requires Python 3 and `mlx-lm` installed in the Python environment
- Build tag `nomlxlm` removes the backend entirely

---

## Downstream Consumers

| Package | Role |
|---------|------|
| `forge.lthn.ai/core/go-ml` | Imports go-inference + go-mlx for Metal backend |
| `forge.lthn.ai/core/go-i18n` | Phase 2a: Gemma3-1B domain classification |
| `forge.lthn.ai/core/go-rocm` | Sibling AMD GPU backend, same go-inference interfaces |

---

## Performance Baseline (M3 Ultra, 60-core GPU, 96 GB unified memory)

| Operation | Throughput |
|-----------|-----------|
| Gemma3-1B 4-bit prefill | 246 tok/s |
| Gemma3-1B 4-bit decode | 82 tok/s |
| Gemma3-1B 4-bit classify (4 prompts) | 152 prompts/s |
| DeepSeek R1 7B 4-bit decode | 27 tok/s |
| Llama 3.1 8B 4-bit decode | 30 tok/s |

CGO call overhead floors at approximately 170 µs per operation (Metal command buffer + CGO boundary). MatMul scales well: going from 128² to 4096² matrices is roughly 55× slower for 1024× more elements per matrix. Full sampling chain (TopP+MinP+TopK) adds approximately 560 µs over greedy per token.