Project History
Module: forge.lthn.ai/core/go-mlx
Origin
go-mlx was extracted from forge.lthn.ai/core/go-ai/mlx/ on 19 February 2026. The split was motivated by three concerns:
- Platform isolation. mlx is darwin/arm64 only with CGO and CMake build requirements. Keeping it inside go-ai forced the entire AI package to carry platform-specific build complexity.
- Dependency chain. go-i18n Phase 2a needs Gemma3-1B inference for domain classification and should not pull in all of go-ai (DuckDB, Parquet, gRPC, Ollama, etc.) as a transitive dependency.
- Build tag cleanliness. Every file is `//go:build darwin && arm64`. As a standalone module this is the default; inside go-ai it was a special case requiring careful build tag management.
Initial extraction: commit cae7ef0 — 22 files, approximately 4,354 lines, comprising core MLX bindings, Gemma3 and Qwen3 model implementations, BPE tokeniser, sampling strategies, and KV cache.
Phase 1: Standalone Package Hardening
Commits: 37abc49 through 40cbdd7
- Go generate → test round-trip verified. 29/29 tests pass. CMake 3.24+, AppleClang 17.0.0, macOS SDK 26.2. Build approximately 2 minutes on M3 Ultra.
- 86 new tests across `array_test.go` (25), `ops_test.go` (44), `nn_test.go` (8), `fast_test.go` (9). Covers all scalar/array creation, shape operations, element-wise arithmetic, math functions, matrix operations, reductions, indexing, slicing, fused kernels (RMSNorm, LayerNorm, RoPE, SDPA), Linear, Embedding, RepeatKV.
- 33 new tests for cache, sample, tokeniser: KVCache and RotatingKVCache lifecycle and update, greedy/temperature/TopK sampling chain, BPE load and encode/decode, DecodeToken, SentencePiece space handling, GPT-2 byte maps.
- Non-contiguous view bug discovered: `Floats()` and `DataInt32()` read `Size()` elements from the raw C data pointer. For non-contiguous arrays (transpose, broadcast, slice views), the physical layout does not match the logical layout. Workaround documented: `Reshape(arr, totalSize)` forces a contiguous copy. Fixed in Phase 4.
- 29 benchmarks in `bench_test.go` baselined on M3 Ultra.
Phase 2: Model Support
Commits: 18e8dca through a2493e0
- Gemma3-1B inference validated. End-to-end inference: 4-bit quantised Gemma3-1B loads and generates coherently at 46 tok/s on M3 Ultra. Two bugs fixed: `model_type: "gemma3_text"` was not matched in architecture dispatch, and GPT-2 BPE was falsely detected on the 262K SentencePiece vocabulary (fixed by checking `Ġthe` rather than bare `Ġ`).
- 24 model loading robustness tests: missing/invalid config, unsupported architecture, `gemma3_text` dispatch, missing tokeniser, missing safetensors (was a nil-pointer panic — fixed with an early error return in both `LoadGemma3` and `LoadQwen3`), config parsing defaults/quantisation/nested `text_config`, `isLayerSliding`, `resolveWeight` with prefix fallback.
- Qwen2 support added. Shares the Qwen3 loader with optional Q/K RMS normalisation (Qwen3 has it, Qwen2 does not); auto-detected from weight presence. DeepSeek R1 7B: 27 tok/s on M3 Ultra.
- Llama 3 support added. Shares the Qwen3 loader (same decoder: pre-norm, SwiGLU, GQA, no Q/K norm). Model type detected from `model_type` in config.json. Llama 3 chat template and EOS token (`<|eot_id|>`, id 128009) added. Llama 3.1 8B 4-bit: 30 tok/s on M3 Ultra.
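The tokeniser-detection fix amounts to a one-token change in the heuristic. A sketch of the idea, with illustrative vocabularies (the real loader inspects the parsed tokeniser data, not a bare map):

```go
package main

import "fmt"

// isGPT2ByteLevel reports whether a vocabulary looks like GPT-2-style
// byte-level BPE. Checking for the common merged token "Ġthe" avoids the
// false positive described above: a large SentencePiece vocabulary can
// also contain a bare "Ġ" entry, but not the byte-level merge products.
func isGPT2ByteLevel(vocab map[string]int) bool {
	_, ok := vocab["Ġthe"]
	return ok
}

func main() {
	// Illustrative vocab fragments, not real tokeniser contents.
	gpt2 := map[string]int{"Ġ": 220, "Ġthe": 262}
	spm := map[string]int{"Ġ": 7, "▁the": 279} // SentencePiece uses ▁ for spaces
	fmt.Println(isGPT2ByteLevel(gpt2), isGPT2ByteLevel(spm))
}
```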
Phase 3: Training Pipeline
Commits: fb0692b through e3fbc22
- LoRA fine-tuning validated end-to-end. Load Gemma3-1B → apply LoRA (rank=8, q_proj+v_proj, 745K params) → 5 training steps with cross-entropy loss (7.15→6.31) → save adapter (2.9 MB safetensors) → reload and verify weights match. Uses ValueAndGrad + AdamW.
- Gradient checkpointing validated.
Checkpoint()wraps forward pass to recompute activations during backward. Produces correct gradients (loss 7.15→7.08 in 3 steps, matching non-checkpointed initial loss). - Mixed precision training validated.
LoRAConfig.DTypeselects training dtype for A/B matrices. BFloat16: loss 7.15→6.29 in 5 steps, matches Float32 accuracy with half parameter memory. MLX auto-promotes for cross-dtype operations.
Phase 4: Backend Abstraction
Commits: 1cf5178 through bff97cc
Design documents: docs/plans/2026-02-19-backend-abstraction-design.md, docs/plans/2026-02-19-backend-abstraction-plan.md
This phase was a full architectural restructure. All CGO code was moved to internal/metal/. The root package became a clean interface layer. Completed 19 February 2026.
- All CGO moved to `internal/metal/`. 19 source files, 10 test files.
- Public API: `TextModel`, `Backend`, functional options — all from go-inference, not go-mlx.
- `context.Context` on `Generate()`. Checks `ctx.Done()` in the decode loop.
- `Err() error` on `TextModel`. Distinguishes normal stop (EOS, max tokens) from errors (OOM, context cancelled).
- `Chat()` on `TextModel`. The model owns its chat template.
- Backend registration. `register_metal.go` auto-registers via a build-tagged `init()`.
- Integration tests. 7 tests for the public API (backend registration, options, `LoadModel` paths).
- go-inference migration. All types (`TextModel`, `Backend`, `Token`, `Message`, `GenerateConfig`, `LoadConfig`, options) moved to `forge.lthn.ai/core/go-inference`. Import `_ "forge.lthn.ai/core/go-mlx"` to register the `"metal"` backend. Completed in commit `bff97cc`.
- Error handling audit (`ff01175`): `checkError()` replaced with `lastError()` (reads and clears the C-level error string). `Eval()` and `EvalAsync()` added as error-returning variants. The generate loop propagates errors. `LoadAllSafetensors` returns `(map, error)`. Model loaders check `lastError()` after the safetensors load.
- Deterministic `Close()` (`f2ca7fe`): walks the full model tree and explicitly frees all weight arrays. Handles tied output weights (skips the double-free), nil safety, idempotent close. 8 new tests in `close_test.go`.
- Non-contiguous array fix (`df0b300`): `ensureContiguous()` added; `Floats()`, `DataInt32()`, and `Ints()` now call it automatically. `mlx_contiguous` and `_mlx_array_is_row_contiguous` bound from mlx-c.
- TopP and MinP sampling implemented (`df0b300`): previously stubs that passed logits through unchanged, now fully implemented using cumsum, argsort, and masked scattering.
- Virgil code review applied (`fb0692b` through `443347a`): 12 items across critical/important/minor categories, including a thread-safe error handler (atomic), macOS deployment target corrected (13.3), `LoadOption` propagation, KV cache leak documented, repeat penalty implemented, stream caching, BPE merge algorithm, `CompileShapeless` dead code removed, naming cleanup.
- 29 benchmarks baselined on M3 Ultra (`ff01175`).
- 4 new error handling tests in `error_test.go`.
- 148 tests total in `internal/metal/`; 11 root integration tests (159 total).
Phase 5: Ecosystem Integration
Commits: 5644857 through 757a241
- Batch inference API (`5644857`, design doc docs/plans/2026-02-19-batch-inference-design.md): `Classify` (prefill-only, 152 prompts/s on M3 Ultra) and `BatchGenerate` (autoregressive) implemented. `ForwardMasked` added to the `InternalModel` interface, threaded through the Gemma3 and Qwen3 decoders. Mask: `[N, 1, L, L]` combining causal + padding. `ClassifyResult`, `BatchResult`, `WithLogits` added to go-inference.
- Inference metrics (`a44e9f5`): `GenerateMetrics` type in go-inference with `Metrics()` on `TextModel`. Captures prefill/decode timing, token counts, throughput (tok/s), peak and active GPU memory. Instrumented `Generate`, `Classify`, `BatchGenerate`. Gemma3-1B 4-bit on M3 Ultra: prefill 246 tok/s, decode 82 tok/s, peak 6.2 GB.
- Model quantisation awareness (`ceb966b`): `ModelInfo` type in go-inference with `Info()` on `TextModel`. Exposes architecture, vocab size, layer count, hidden dimension, quantisation bits and group size.
- Embed-friendly model loading (`dd49b4a`): `Discover(baseDir)` in go-inference scans for model directories (config.json + *.safetensors). Returns `DiscoveredModel` with path, architecture, quantisation, file count.
- Package documentation expanded (`d7c8f17`): godoc examples for Generate, Chat, Classify, BatchGenerate, Metrics, ModelInfo, Discover, and memory controls.
- mlxlm subprocess backend (`757a241`): Python subprocess wrapper implementing `inference.Backend`. Communicates via JSON Lines. Auto-registers as `"mlx_lm"` via `init()`. Build tag `nomlxlm` to opt out. Full test suite uses a mock bridge script (testdata/mock_bridge.py); no mlx-lm install required for tests.
- `Array.Iter()` (`f2ca7fe`): `iter.Seq[float32]` implemented; handles non-contiguous arrays, supports early break.
Total tests after Phase 5: 176+ in internal/metal/, 11 root integration tests, 10 mlxlm tests.
Known Limitations
Per-Step Intermediate Array Accumulation
The generate loop creates approximately 500 intermediate arrays per forward pass. These rely on GC finalisers for C-side deallocation via mlx_array_free. Under sustained high-throughput inference, GC may not keep pace, so Metal memory grows until a collection cycle runs the finalisers.
Go 1.26's Green Tea GC partially mitigates this (10–40% less GC overhead, more timely finaliser execution). Explicit per-step freeing of intermediate arrays (freeIntermediates) was discussed but not implemented; the generate loop calls ClearCache() per step to reclaim the prompt cache but does not explicitly free computation intermediates.
KV Cache Per-Turn Allocation
Each call to Generate() allocates a fresh set of KV caches. Cross-turn cache reuse (for multi-turn chat without re-encoding the full history) is not implemented. Call ClearCache() between turns to reclaim Metal memory promptly.
SentencePiece BPE Merges
The merges field is parsed and the merge rank table is built during tokeniser load. For Gemma 3's SentencePiece tokeniser, the vocabulary contains pre-merged token strings, so character-level splitting followed by BPE merge application is the correct approach, and it is implemented. The correctness boundary: characters that do not appear individually in the vocabulary are silently dropped from the segment. This has not caused issues with Gemma3-1B or Gemma3-4B in practice.
mlxlm Backend Limitations
The Python subprocess backend (mlxlm) does not support Classify, BatchGenerate, or inference metrics. It requires Python 3 and mlx-lm installed in the active Python environment. There is no way to target a specific virtual environment without setting PATH before import.
macOS Version Minimum
The CMake build sets CMAKE_OSX_DEPLOYMENT_TARGET=13.3, which is MLX's stated minimum. Testing has been performed on macOS 26.2 (Tahoe beta). Behaviour on macOS 13.x or 14.x has not been validated.
Future Considerations
The following items were identified during development but deferred:
- Per-step intermediate freeing. Binding `mlx_array_free` directly and calling it on known-dead intermediates at the end of each decode step would reduce peak Metal memory and GC pressure. Requires careful bookkeeping to avoid double-free.
- Cross-turn KV cache reuse. The `Cache` interface already supports `Reset()` and `State()`. A generation session object that preserves caches across `Chat` turns would enable prefix caching without re-encoding.
- Go 1.26 `testing.ArtifactDir()` for storing benchmark results alongside test runs.
- go-ml backend migration. `go-ai/ml/backend_mlx.go` (253 LOC) should be updated to use `inference.LoadModel()` and `m.Generate()` rather than direct tensor manipulation, reducing it to approximately 60 LOC.
- mlxlm virtual environment support. Accept a Python interpreter path in a `LoadOption` to target a specific venv.
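The session idea in the cache-reuse item could look roughly like this. Everything below is hypothetical (none of these types exist in go-mlx today): a wrapper that keeps a token offset per cache so each turn encodes only the suffix past the cached prefix.

```go
package main

import "fmt"

// kvCache stands in for a per-layer KV cache with the Reset/State shape
// mentioned above; offset counts tokens already encoded into the cache.
type kvCache struct{ offset int }

func (c *kvCache) Reset()     { c.offset = 0 }
func (c *kvCache) State() int { return c.offset }

// session is the hypothetical generation-session object: it survives
// across Chat turns, so a turn only encodes tokens beyond the cached
// prefix instead of re-encoding the full history.
type session struct {
	cache   *kvCache
	history []int
}

// turn appends the new tokens and returns how many actually had to be
// encoded (the suffix past the cache offset).
func (s *session) turn(tokens []int) int {
	s.history = append(s.history, tokens...)
	newTokens := len(s.history) - s.cache.State()
	s.cache.offset = len(s.history)
	return newTokens
}

func main() {
	s := &session{cache: &kvCache{}}
	fmt.Println(s.turn([]int{1, 2, 3, 4})) // first turn encodes everything
	fmt.Println(s.turn([]int{5, 6}))       // later turns encode only the delta
}
```

A real implementation would also have to invalidate the cache when the shared prefix changes (e.g. an edited system prompt), which is where `Reset()` comes in.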