docs: update project docs for backend abstraction completion

- CLAUDE.md: new architecture diagram, public API examples
- TODO.md: Phase 4 marked complete, remaining items noted
- FINDINGS.md: migration completion notes, import cycle resolution

Co-Authored-By: Virgil <virgil@lethean.io>
Snider 2026-02-19 20:07:01 +00:00
parent eb8dee31bf
commit 95d92fffff
3 changed files with 136 additions and 82 deletions

CLAUDE.md

@@ -8,33 +8,30 @@ Pure Go + CGO package that wraps Apple's [MLX framework](https://github.com/ml-e
## Platform
**darwin/arm64 only.** All CGO files carry `//go:build darwin && arm64`. A stub (`mlx_stub.go`) provides `MetalAvailable() bool` returning false on other platforms.
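The split follows the standard Go build-tag pattern. A sketch of the stub side (the `MetalAvailable` contract is from this doc; the surrounding comments are illustrative):

```go
//go:build !darwin || !arm64

// mlx_stub.go — compiled on every platform except Apple Silicon, so the
// package still builds there. The Metal-backed twin carries the inverse
// constraint: //go:build darwin && arm64.
package mlx

// MetalAvailable reports whether the Metal GPU backend is usable.
// In this stub it is always false.
func MetalAvailable() bool { return false }
```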
## Build
```bash
# Step 1: Build mlx-c C library via CMake (fetches mlx-c v0.4.1)
go generate ./...
# Step 2: Run tests (must be on Apple Silicon)
go test ./...
# Step 3: Use in another module
# Import "forge.lthn.ai/core/go-mlx"
```
### CGO Flags (auto-set via go:generate)
CMake installs to `dist/` inside the package directory. The `#cgo` directives in `internal/metal/metal.go` reference:
- `CPPFLAGS: -I${SRCDIR}/../../dist/include`
- `LDFLAGS: -L${SRCDIR}/../../dist/lib -lmlxc -lmlx`
- Frameworks: Foundation, Metal, Accelerate
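Assembled from the bullets above, the directives at the top of `internal/metal/metal.go` would look roughly like this (a sketch — the real file may split or order the directives differently, and includes mlx-c headers rather than `stdlib.h`):

```go
package metal

/*
#cgo CPPFLAGS: -I${SRCDIR}/../../dist/include
#cgo LDFLAGS: -L${SRCDIR}/../../dist/lib -lmlxc -lmlx
#cgo LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
#include <stdlib.h>
*/
import "C"
```

`${SRCDIR}` expands to the directory containing the Go source file, which is why the paths climb two levels to reach `dist/` at the package root.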
### CMake Config
`CMakeLists.txt` fetches mlx-c v0.4.1 from GitHub. Key settings:
- `MLX_BUILD_SAFETENSORS=ON` (model loading)
- `MLX_BUILD_GGUF=OFF`
- `BUILD_SHARED_LIBS=ON`
- macOS deployment target: 26.0
@@ -42,61 +39,74 @@ CMake installs to `dist/` inside the package directory. The `#cgo` directives in
```
go-mlx/
├── Public API (compiles on all platforms)
│ ├── mlx.go — Package doc + go:generate CMake directives
│ ├── textmodel.go — TextModel interface, Token, Message types
│ ├── options.go — GenerateOption, LoadOption functional options
│ ├── backend.go — Backend interface, Register/Get/Default/LoadModel
│ ├── register_metal.go — //go:build darwin && arm64 — auto-registers metal
│ ├── mlx_stub.go — //go:build !darwin || !arm64 — MetalAvailable() false
│ └── mlx_test.go — Integration tests (public API)
├── internal/metal/ — All CGO code (darwin/arm64 only)
│ ├── metal.go — Init, Materialize, error handler, stream
│ ├── array.go — Array type, creation, data access
│ ├── dtype.go — DType constants
│ ├── stream.go — Metal stream/queue, memory controls
│ ├── ops.go — Element-wise, reduction, shape ops
│ ├── fast.go — Fused Metal kernels (RMSNorm, RoPE, SDPA)
│ ├── nn.go — Linear, Embedding, RMSNormModule
│ ├── compile.go — CompiledFunc
│ ├── slice.go — Array slicing
│ ├── random.go — RandomCategorical, RandomUniform
│ ├── io.go — Safetensors loading
│ ├── model.go — InternalModel interface + architecture dispatch
│ ├── gemma3.go — Gemma3 decoder
│ ├── qwen3.go — Qwen3 decoder
│ ├── cache.go — KVCache + RotatingKVCache
│ ├── sample.go — Sampling chain (greedy, temp, topK, topP)
│ ├── tokenizer.go — BPE tokenizer
│ ├── grad.go — VJP
│ ├── lora.go — LoRA adapters
│ ├── optim.go — AdamW
│ ├── generate.go — Autoregressive generation loop + chat templates
│ └── backend.go — LoadAndInit entry point
├── cpp/ — CLion Claude workspace (C++ research)
│ ├── CMakeLists.txt
│ ├── CLAUDE.md
│ ├── TODO.md
│ └── FINDINGS.md
└── docs/plans/ — Design and implementation docs
```
## Public API
```go
import "forge.lthn.ai/core/go-mlx"
// Load and generate
m, err := mlx.LoadModel("/path/to/model/")
if err != nil { log.Fatal(err) }
defer m.Close()
ctx := context.Background()
for tok := range m.Generate(ctx, "What is 2+2?", mlx.WithMaxTokens(128)) {
fmt.Print(tok.Text)
}
if err := m.Err(); err != nil { log.Fatal(err) }
// Chat with template formatting
for tok := range m.Chat(ctx, []mlx.Message{
{Role: "user", Content: "Hello"},
}, mlx.WithMaxTokens(64)) {
fmt.Print(tok.Text)
}
// Top-level — GPU materialisation
mlx.Materialize(outputs...) // Sync GPU eval
mlx.MaterializeAsync(outputs...) // Async GPU eval
mlx.MetalAvailable() bool // Check Metal GPU
// Array operations
mlx.NewArray(data, shape, dtype)
mlx.MatMul(a, b)
mlx.Softmax(x, axis)
mlx.Add(a, b), mlx.Multiply(a, b), mlx.Divide(a, b)
// Memory controls
mlx.SetCacheLimit(4 * 1024 * 1024 * 1024) // 4GB
mlx.ClearCache()
```
## Dependencies
@@ -107,8 +117,10 @@ mlx.Add(a, b), mlx.Multiply(a, b), mlx.Divide(a, b)
## Downstream Consumers
- `forge.lthn.ai/core/go-ml` — `backend_mlx.go` uses `mlx.LoadModel()` + `m.Generate()`
- `forge.lthn.ai/core/go-i18n` — Phase 2a needs Gemma3-1B inference for domain classification
The old API (`Array`, `MatMul`, `model.LoadModel`, etc.) is no longer public — all CGO is in `internal/metal/`. Consumers use the clean `TextModel` interface.
## Model Format

FINDINGS.md

@@ -180,37 +180,56 @@ This is a data correctness issue — silent wrong results, not a crash.
---
## 2026-02-19: Backend Abstraction — COMPLETED
**Design doc:** `docs/plans/2026-02-19-backend-abstraction-design.md`
**Implementation plan:** `docs/plans/2026-02-19-backend-abstraction-plan.md`
### What changed
The entire public API has been replaced. All CGO code is now in `internal/metal/`. The root package is a clean interface layer:
```go
m, _ := mlx.LoadModel("/path/to/model/")
defer m.Close()
ctx := context.Background()
for tok := range m.Generate(ctx, "prompt", mlx.WithMaxTokens(128)) {
fmt.Print(tok.Text)
}
if err := m.Err(); err != nil { log.Fatal(err) }
```
The old API (`Array`, `MatMul`, `model.LoadModel`, `model.Model`, sub-packages `model/`, `tokenizer/`, `sample/`, `cache/`) is no longer public. All moved to `internal/metal/`.
### Architecture note: import cycle resolution
`internal/metal/` cannot import the root package (circular dependency). Solution: internal/metal defines its own concrete types (`metal.Token`, `metal.GenerateConfig`, `metal.Model`), and `register_metal.go` in root provides a thin adapter (`metalAdapter`) that converts between root types (`mlx.Token`) and metal types.
### Impact on go-ml
`backend_mlx.go` must migrate from direct tensor manipulation to:
```go
m, _ := mlx.LoadModel(path)
ctx := context.Background()
for tok := range m.Generate(ctx, prompt, mlx.WithMaxTokens(n)) { ... }
if err := m.Err(); err != nil { ... }
```
253 LOC → ~60 LOC. Memory controls: `mlx.SetCacheLimit()`, `mlx.ClearCache()`, etc.
### Impact on go-i18n
The API for Gemma3-1B domain classification is:
```go
m, _ := mlx.LoadModel("/path/to/gemma-3-1b/")
ctx := context.Background()
for tok := range m.Generate(ctx, sentence, mlx.WithMaxTokens(32)) { ... }
```
Streaming via `iter.Seq[Token]`. No tokenisation or sampling to handle.
### Memory management status
`Close()` stub is in place but does not yet explicitly free model weights. Per-step intermediate cleanup (`ClearCache()` per decode step) is implemented in the generate loop. Full deterministic cleanup awaits CLion Claude research on `mlx_array_free` safety (see `cpp/TODO.md`).
### Test results
- 148 existing tests moved to `internal/metal/` — all pass
- 7 new integration tests for public API — all pass
- Total: 155 tests passing

TODO.md

@@ -13,7 +13,7 @@ Dispatched from core/go orchestration. Pick up tasks in order.
## Phase 2: Model Support
- [ ] **Gemma3-1B inference validation** — The go-i18n Phase 2a needs 1B model inference for domain classification at ~5K sentences/sec. Validate Gemma3-1B loads and generates correctly via `mlx.LoadModel()` + `m.Generate()`. Report tokens/sec.
- [ ] **Model loading robustness** — Test with missing files, corrupted safetensors, wrong dtype. Currently no error handling tests for `io.go`.
- [ ] **Add Llama model support** — Only Gemma3 and Qwen3 exist. Llama architecture would cover Meta's model family (Llama 3, CodeLlama).
@@ -23,31 +23,54 @@ Dispatched from core/go orchestration. Pick up tasks in order.
- [ ] **Gradient checkpointing** — `grad.go` has VJP but large models will OOM without checkpointing. Add selective recomputation.
- [ ] **Mixed precision training** — MLX supports BFloat16/Float16. Add dtype selection for training (currently inference uses model's native dtype).
## Phase 4: Backend Abstraction — ✅ COMPLETE (19 Feb 2026)
Design doc: `docs/plans/2026-02-19-backend-abstraction-design.md`
Implementation plan: `docs/plans/2026-02-19-backend-abstraction-plan.md`
**All Virgil review items implemented:**
- [x] **`context.Context` on `TextModel.Generate()`** — `Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token]`. Checks `ctx.Done()` in the decode loop.
- [x] **`Err() error` on `TextModel`** — Distinguishes normal stop (EOS, max tokens) from errors (OOM, ctx cancelled).
- [x] **`Chat()` on `TextModel`** — Model owns its chat template. Gemma3 and Qwen3 templates implemented.
- [x] **Memory control functions at root** — `SetCacheLimit`, `SetMemoryLimit`, `GetActiveMemory`, `GetPeakMemory`, `ClearCache` delegate to `internal/metal`.
- [x] **Backend registration** — `register_metal.go` auto-registers via build-tagged `init()`.
- [x] **All CGO moved to `internal/metal/`** — 19 source files, 10 test files, 148 tests passing.
- [x] **Public API: `TextModel`, `Backend`, functional options** — Clean root package, compiles on all platforms.
- [x] **Integration tests** — 7 tests for public API (backend registration, options, LoadModel paths).
- [ ] **Error handling audit** — `checkError()` still logs + swallows. Needs conversion to error returns. Low priority — existing behaviour, not a regression.
- [ ] **Memory management — deterministic cleanup** — `Close()` stub in place. Needs CLion Claude research on `mlx_array_free` safety before implementing per-step cleanup. See `cpp/TODO.md`.
- [ ] **Documentation** — Public API has godoc but needs examples for common workflows.
## Phase 5: Ecosystem Integration (Virgil wishlist)
- [ ] **Batch inference API** — go-i18n Phase 2a wants ~5K sentences/sec through Gemma3-1B. Single-prompt `Generate(..., WithMaxTokens(1))` works functionally for classification but won't hit 5K/sec. True batch inference (multiple prompts through one forward pass) is needed.
- [ ] **Inference metrics** — Expose tokens/sec, peak memory, GPU utilisation as structured data. LEM Lab dashboard and go-ai scoring engine both want this.
- [ ] **Model quantisation awareness** — MLX supports 4-bit and 8-bit quantised models. The loader already handles quantised safetensors (GroupSize, Bits in config).
- [ ] **Embed-friendly model loading** — Add `Discover(baseDir)` that scans for available models and returns metadata.
- [ ] **mlxlm/ backend** — Python subprocess wrapper via `core/go/pkg/process`. Implements `mlx.Backend` for mlx_lm compatibility.
## Phase 6: Go 1.26 Modernisation
- [x] **Evaluate Go 1.26 features** — ✅ Documented in FINDINGS.md. Key wins: CGO ~30% faster (free), Green Tea GC default (10-40% less overhead, helps Array finalisers), slice stack alloc.
- [ ] **Range-over-func for Array**`Array.Iter()` returning `iter.Seq[float32]` for cleaner iteration. Measure overhead vs direct C pointer access.
---
## Upstream Dependencies
- **go-ml/backend_mlx.go** needs updating to use new API: `mlx.LoadModel()` + `m.Generate()`. Old imports (`mlx.Array`, `model.LoadModel`, sub-packages) are now internal.
- **go-ai** has a `replace` directive pointing at `../go-mlx`. No code changes needed in go-ai itself.
- **LEM Lab** uses `MLXBackend` via go-ml. Migration transparent once go-ml updates.
## Functional Options Convention
Virgil confirms: the `WithMaxTokens(n)` functional option pattern is the right call for this package.
## core/go/pkg/process (for mlxlm backend, Phase 5)
Virgil confirms: no changes needed. The process package provides everything needed for the mlxlm subprocess backend.
## Workflow