docs: update project docs for backend abstraction completion
- CLAUDE.md: new architecture diagram, public API examples
- TODO.md: Phase 4 marked complete, remaining items noted
- FINDINGS.md: migration completion notes, import cycle resolution

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent eb8dee31bf, commit 95d92fffff
3 changed files with 136 additions and 82 deletions
126 CLAUDE.md
@@ -8,33 +8,30 @@ Pure Go + CGO package that wraps Apple's [MLX framework](https://github.com/ml-e

## Platform

**darwin/arm64 only.** All files carry `//go:build darwin && arm64`. A stub (`mlx_stub.go`) provides `MetalAvailable() bool` returning false on other platforms.
**darwin/arm64 only.** All CGO files carry `//go:build darwin && arm64`. A stub (`mlx_stub.go`) provides `MetalAvailable() bool` returning false on other platforms.

## Build

```bash
# Step 1: Build mlx-c C library via CMake (fetches mlx-c v0.4.1)
cd /tmp/core-go-mlx && go generate ./...
go generate ./...

# Step 2: Run tests (must be on Apple Silicon)
go test ./...

# Step 3: Use in another module
# Import "forge.lthn.ai/core/go-mlx" and "forge.lthn.ai/core/go-mlx/model"
```

### CGO Flags (auto-set via go:generate)

CMake installs to `dist/` inside the package directory. The `#cgo` directives in `mlx.go` reference:
- `CPPFLAGS: -I${SRCDIR}/dist/include`
- `LDFLAGS: -L${SRCDIR}/dist/lib -lmlxc -lmlx`
CMake installs to `dist/` inside the package directory. The `#cgo` directives in `internal/metal/metal.go` reference:
- `CPPFLAGS: -I${SRCDIR}/../../dist/include`
- `LDFLAGS: -L${SRCDIR}/../../dist/lib -lmlxc -lmlx`
- Frameworks: Foundation, Metal, Accelerate

### CMake Config

`CMakeLists.txt` fetches mlx-c v0.4.1 from GitHub. Key settings:
- `MLX_BUILD_SAFETENSORS=ON` (model loading)
- `MLX_BUILD_GGUF=OFF` (GGUF is for llama.cpp backend in go-ai/ml/)
- `MLX_BUILD_GGUF=OFF`
- `BUILD_SHARED_LIBS=ON`
- macOS deployment target: 26.0
@@ -42,61 +39,74 @@ CMake installs to `dist/` inside the package directory. The `#cgo` directives in

```
go-mlx/
├── Core Layer (Metal GPU)
│   ├── mlx.go — Init, Materialize, MetalAvailable (CGO bridge)
│   ├── mlx_stub.go — Non-Apple fallback
│   ├── array.go — MLX array wrapper (create, reshape, data access)
│   ├── dtype.go — Data types (Float16, Float32, BFloat16, Int32, etc.)
│   ├── stream.go — Metal stream/queue management
│   └── CMakeLists.txt — Fetches mlx-c v0.4.1
├── Public API (compiles on all platforms)
│   ├── mlx.go — Package doc + go:generate CMake directives
│   ├── textmodel.go — TextModel interface, Token, Message types
│   ├── options.go — GenerateOption, LoadOption functional options
│   ├── backend.go — Backend interface, Register/Get/Default/LoadModel
│   ├── register_metal.go — //go:build darwin && arm64 — auto-registers metal
│   ├── mlx_stub.go — //go:build !darwin || !arm64 — MetalAvailable() false
│   └── mlx_test.go — Integration tests (public API)
│
├── Operations
│   ├── ops.go — Element-wise: Add, Multiply, MatMul, Softmax, etc.
│   ├── fast.go — Fused Metal kernels: RMSNorm, RoPE, ScaledDotProductAttention
│   ├── nn.go — Neural network layers: Linear, Embedding, RMSNorm
│   ├── compile.go — Compiled function closures for kernel fusion
│   ├── slice.go — Array slicing/indexing
│   └── random.go — Categorical sampling
├── internal/metal/ — All CGO code (darwin/arm64 only)
│   ├── metal.go — Init, Materialize, error handler, stream
│   ├── array.go — Array type, creation, data access
│   ├── dtype.go — DType constants
│   ├── stream.go — Metal stream/queue, memory controls
│   ├── ops.go — Element-wise, reduction, shape ops
│   ├── fast.go — Fused Metal kernels (RMSNorm, RoPE, SDPA)
│   ├── nn.go — Linear, Embedding, RMSNormModule
│   ├── compile.go — CompiledFunc
│   ├── slice.go — Array slicing
│   ├── random.go — RandomCategorical, RandomUniform
│   ├── io.go — Safetensors loading
│   ├── model.go — InternalModel interface + architecture dispatch
│   ├── gemma3.go — Gemma3 decoder
│   ├── qwen3.go — Qwen3 decoder
│   ├── cache.go — KVCache + RotatingKVCache
│   ├── sample.go — Sampling chain (greedy, temp, topK, topP)
│   ├── tokenizer.go — BPE tokenizer
│   ├── grad.go — VJP
│   ├── lora.go — LoRA adapters
│   ├── optim.go — AdamW
│   ├── generate.go — Autoregressive generation loop + chat templates
│   └── backend.go — LoadAndInit entry point
│
├── Training
│   ├── grad.go — VJP (Vector-Jacobian Product) gradient computation
│   ├── lora.go — LoRA adapter (rank decomposition fine-tuning)
│   └── optim.go — AdamW optimiser
├── cpp/ — CLion Claude workspace (C++ research)
│   ├── CMakeLists.txt
│   ├── CLAUDE.md
│   ├── TODO.md
│   └── FINDINGS.md
│
├── Model Support
│   ├── model/
│   │   ├── model.go — Base model interface (LoadModel, Generate)
│   │   ├── gemma3.go — Google Gemma3 decoder
│   │   └── qwen3.go — Alibaba Qwen3 decoder
│   ├── tokenizer/
│   │   └── tokenizer.go — BPE tokenizer (sentencepiece format)
│   ├── sample/
│   │   └── sample.go — Sampling strategies (temperature, top-k, top-p)
│   └── cache/
│       └── cache.go — KV cache for autoregressive inference
│
└── I/O
    └── io.go — Safetensors model loading
└── docs/plans/ — Design and implementation docs
```

## Key Interfaces
## Public API

```go
// model/model.go — base interface for all models
type Model interface {
	Forward(x *mlx.Array, cache *cache.KVCache) *mlx.Array
import "forge.lthn.ai/core/go-mlx"

// Load and generate
m, err := mlx.LoadModel("/path/to/model/")
if err != nil { log.Fatal(err) }
defer m.Close()

ctx := context.Background()
for tok := range m.Generate(ctx, "What is 2+2?", mlx.WithMaxTokens(128)) {
	fmt.Print(tok.Text)
}
if err := m.Err(); err != nil { log.Fatal(err) }

// Chat with template formatting
for tok := range m.Chat(ctx, []mlx.Message{
	{Role: "user", Content: "Hello"},
}, mlx.WithMaxTokens(64)) {
	fmt.Print(tok.Text)
}

// Top-level — GPU materialisation
mlx.Materialize(outputs...)      // Sync GPU eval
mlx.MaterializeAsync(outputs...) // Async GPU eval
mlx.MetalAvailable() bool        // Check Metal GPU

// Array operations
mlx.NewArray(data, shape, dtype)
mlx.MatMul(a, b)
mlx.Softmax(x, axis)
mlx.Add(a, b), mlx.Multiply(a, b), mlx.Divide(a, b)

// Memory controls
mlx.SetCacheLimit(4 * 1024 * 1024 * 1024) // 4GB
mlx.ClearCache()
```

## Dependencies
@@ -107,8 +117,10 @@ mlx.Add(a, b), mlx.Multiply(a, b), mlx.Divide(a, b)

## Downstream Consumers

- `forge.lthn.ai/core/go-ai/ml` — `backend_mlx.go` uses this for Metal inference
- `forge.lthn.ai/core/go-i18n` — Phase 2a needs Gemma3-1B inference for domain classification (blocked on this package being importable)
- `forge.lthn.ai/core/go-ml` — `backend_mlx.go` uses `mlx.LoadModel()` + `m.Generate()`
- `forge.lthn.ai/core/go-i18n` — Phase 2a needs Gemma3-1B inference for domain classification

The old API (`Array`, `MatMul`, `model.LoadModel`, etc.) is no longer public — all CGO is in `internal/metal/`. Consumers use the clean `TextModel` interface.

## Model Format
43 FINDINGS.md
@@ -180,37 +180,56 @@ This is a data correctness issue — silent wrong results, not a crash.

---

## 2026-02-19: Backend Abstraction — API Breaking Change
## 2026-02-19: Backend Abstraction — COMPLETED

**Design doc:** `docs/plans/2026-02-19-backend-abstraction-design.md`
**Implementation plan:** `docs/plans/2026-02-19-backend-abstraction-plan.md`

### What's changing
### What changed

The entire public API is being replaced. All CGO code moves to `internal/metal/`. The root package becomes a clean interface layer:
The entire public API has been replaced. All CGO code is now in `internal/metal/`. The root package is a clean interface layer:

```go
m, _ := mlx.LoadModel("/path/to/model/")
defer m.Close()
for tok := range m.Generate("prompt", mlx.WithMaxTokens(128)) {
ctx := context.Background()
for tok := range m.Generate(ctx, "prompt", mlx.WithMaxTokens(128)) {
	fmt.Print(tok.Text)
}
if err := m.Err(); err != nil { log.Fatal(err) }
```

The old API (`Array`, `MatMul`, `model.LoadModel`, `model.Model`, etc.) will no longer be public.
The old API (`Array`, `MatMul`, `model.LoadModel`, `model.Model`, sub-packages `model/`, `tokenizer/`, `sample/`, `cache/`) is no longer public. All moved to `internal/metal/`.

### Impact on go-ai
### Architecture note: import cycle resolution

`backend_mlx.go` currently imports root-level Array, ops, model types. These move to `internal/metal/` and become inaccessible. Migration: replace direct tensor manipulation with `mlx.LoadModel()` + `mlx.TextModel.Generate()`.
`internal/metal/` cannot import the root package (circular dependency). Solution: internal/metal defines its own concrete types (`metal.Token`, `metal.GenerateConfig`, `metal.Model`), and `register_metal.go` in root provides a thin adapter (`metalAdapter`) that converts between root types (`mlx.Token`) and metal types.
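The adapter shape described above can be sketched in a few lines. In the real layout `metalToken` would be `metal.Token` in `internal/metal` and `Token` would live in the root package; here both sit in one file so the sketch compiles standalone, and the field set is an assumption.

```go
package main

import "fmt"

// metalToken stands in for internal/metal's own concrete token type.
type metalToken struct {
	ID   int
	Text string
}

// Token stands in for the root package's public token type.
type Token struct {
	ID   int
	Text string
}

// metalAdapter converts internal types to public types at the package
// boundary, so internal/metal never needs to import the root package.
type metalAdapter struct{}

func (metalAdapter) convert(t metalToken) Token {
	return Token{ID: t.ID, Text: t.Text}
}

func main() {
	a := metalAdapter{}
	tok := a.convert(metalToken{ID: 42, Text: "hi"})
	fmt.Println(tok.ID, tok.Text) // 42 hi
}
```

The conversion is trivial by design: the cost of the adapter is one struct copy per token, which keeps the dependency arrow pointing one way (root → internal) without a shared types package.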

### Impact on go-ml

`backend_mlx.go` must migrate from direct tensor manipulation to:
```go
m, _ := mlx.LoadModel(path)
ctx := context.Background()
for tok := range m.Generate(ctx, prompt, mlx.WithMaxTokens(n)) { ... }
if err := m.Err(); err != nil { ... }
```
253 LOC → ~60 LOC. Memory controls: `mlx.SetCacheLimit()`, `mlx.ClearCache()`, etc.

### Impact on go-i18n

The API for Gemma3-1B domain classification will be:
```go
m, _ := mlx.LoadModel("/path/to/gemma-3-1b/")
for tok := range m.Generate(sentence, mlx.WithMaxTokens(32)) { ... }
ctx := context.Background()
for tok := range m.Generate(ctx, sentence, mlx.WithMaxTokens(32)) { ... }
```
Streaming via `iter.Seq[Token]`. No tokenisation or sampling to handle.

### Memory leak fix included
### Memory management status

The refactor includes deterministic memory management — `TextModel.Close()` for model weights and per-step intermediate cleanup during generation. This addresses the current production blocker.
`Close()` stub is in place but does not yet explicitly free model weights. Per-step intermediate cleanup (`ClearCache()` per decode step) is implemented in the generate loop. Full deterministic cleanup awaits CLion Claude research on `mlx_array_free` safety (see `cpp/TODO.md`).
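The per-step cleanup amounts to clearing the buffer cache inside the decode loop. A minimal sketch, with `clearCache` standing in for the real `mlx.ClearCache()` and the loop body elided:

```go
package main

import "fmt"

// generate is an illustrative decode loop: once each step's intermediates
// have been consumed, the buffer cache is cleared so temporaries don't
// accumulate across steps. clearCache is injected so the sketch runs off-GPU.
func generate(maxTokens int, clearCache func()) int {
	steps := 0
	for i := 0; i < maxTokens; i++ {
		// ... forward pass, sample next token, yield it ...
		clearCache() // per-step intermediate cleanup
		steps++
	}
	return steps
}

func main() {
	cleared := 0
	n := generate(4, func() { cleared++ })
	fmt.Println(n, cleared) // 4 4
}
```

This bounds transient memory to roughly one decode step's worth, but it is not the same as deterministically freeing the weights themselves — that is the part still waiting on the `mlx_array_free` research.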

### Test results

- 148 existing tests moved to `internal/metal/` — all pass
- 7 new integration tests for public API — all pass
- Total: 155 tests passing
49 TODO.md
@@ -13,7 +13,7 @@ Dispatched from core/go orchestration. Pick up tasks in order.

## Phase 2: Model Support

- [ ] **Gemma3-1B inference validation** — The go-i18n Phase 2a needs 1B model inference for domain classification at ~5K sentences/sec. Validate Gemma3-1B loads and generates correctly via `model.LoadModel()` + `model.Generate()`. Report tokens/sec.
- [ ] **Gemma3-1B inference validation** — The go-i18n Phase 2a needs 1B model inference for domain classification at ~5K sentences/sec. Validate Gemma3-1B loads and generates correctly via `mlx.LoadModel()` + `m.Generate()`. Report tokens/sec.
- [ ] **Model loading robustness** — Test with missing files, corrupted safetensors, wrong dtype. Currently no error handling tests for `io.go`.
- [ ] **Add Llama model support** — Only Gemma3 and Qwen3 exist. Llama architecture would cover Meta's model family (Llama 3, CodeLlama).
@@ -23,31 +23,54 @@ Dispatched from core/go orchestration. Pick up tasks in order.

- [ ] **Gradient checkpointing** — `grad.go` has VJP but large models will OOM without checkpointing. Add selective recomputation.
- [ ] **Mixed precision training** — MLX supports BFloat16/Float16. Add dtype selection for training (currently inference uses model's native dtype).

## Phase 4: API Polish
## Phase 4: Backend Abstraction — ✅ COMPLETE (19 Feb 2026)

- [ ] **Error handling audit** — `checkError()` only logs. Should return errors to callers instead of silent logging. The C error handler stores last error but Go code doesn't propagate it.
- [ ] **Memory management audit** — Array finalizers use `runtime.SetFinalizer`. Verify no leaks under sustained inference (1000+ tokens). Check C-side deallocation.
- [ ] **Documentation** — Public API has minimal godoc. Add examples for common workflows: load model, generate text, fine-tune with LoRA.
Design doc: `docs/plans/2026-02-19-backend-abstraction-design.md`
Implementation plan: `docs/plans/2026-02-19-backend-abstraction-plan.md`

**All Virgil review items implemented:**

- [x] **`context.Context` on `TextModel.Generate()`** — `Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token]`. Checks `ctx.Done()` in the decode loop.
- [x] **`Err() error` on `TextModel`** — Distinguishes normal stop (EOS, max tokens) from errors (OOM, ctx cancelled).
- [x] **`Chat()` on `TextModel`** — Model owns its chat template. Gemma3 and Qwen3 templates implemented.
- [x] **Memory control functions at root** — `SetCacheLimit`, `SetMemoryLimit`, `GetActiveMemory`, `GetPeakMemory`, `ClearCache` delegate to `internal/metal`.
- [x] **Backend registration** — `register_metal.go` auto-registers via build-tagged `init()`.
- [x] **All CGO moved to `internal/metal/`** — 19 source files, 10 test files, 148 tests passing.
- [x] **Public API: `TextModel`, `Backend`, functional options** — Clean root package, compiles on all platforms.
- [x] **Integration tests** — 7 tests for public API (backend registration, options, LoadModel paths).
- [ ] **Error handling audit** — `checkError()` still logs and swallows errors. Needs conversion to error returns. Low priority — existing behaviour, not a regression.
- [ ] **Memory management — deterministic cleanup** — `Close()` stub in place. Needs CLion Claude research on `mlx_array_free` safety before implementing per-step cleanup. See `cpp/TODO.md`.
- [ ] **Documentation** — Public API has godoc but needs examples for common workflows.
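The build-tagged auto-registration mentioned in the checklist can be sketched as a registry plus an `init()` in the tagged file. The registry shape here is an assumption for illustration; the real `Register`/`Get` signatures live in `backend.go`.

```go
package main

import "fmt"

// Backend is a cut-down stand-in for the real interface in backend.go.
type Backend interface{ Name() string }

// backends is the package-level registry; assumed shape.
var backends = map[string]Backend{}

func Register(name string, b Backend) { backends[name] = b }

func Get(name string) (Backend, bool) {
	b, ok := backends[name]
	return b, ok
}

type metalBackend struct{}

func (metalBackend) Name() string { return "metal" }

// In the real package this init() sits in register_metal.go behind
// //go:build darwin && arm64, so the metal backend is only registered
// when the build actually includes the CGO code.
func init() { Register("metal", metalBackend{}) }

func main() {
	if b, ok := Get("metal"); ok {
		fmt.Println(b.Name()) // metal
	}
}
```

The point of the pattern is that the root package never imports `internal/metal` unconditionally: on non-Apple platforms the tagged file is simply excluded and `Get("metal")` reports the backend as absent.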

## Phase 5: Ecosystem Integration (Virgil wishlist)

- [ ] **Streaming token iterator** — `model.Generate()` should return an `iter.Seq[Token]` (or channel) for token-by-token streaming. LEM Lab chat UI needs this for real-time output. Currently the generate path is batch-only.
- [ ] **Batch inference API** — go-i18n Phase 2a wants ~5K sentences/sec through Gemma3-1B. Single-request inference won't hit that. Add a batch API that processes N prompts concurrently on the GPU. Even batching 8-16 at a time would help throughput significantly.
- [ ] **Inference metrics** — Expose tokens/sec, peak memory, GPU utilisation as structured data. LEM Lab dashboard and go-ai scoring engine both want this. A simple `Stats` struct returned alongside generation results.
- [ ] **Model quantisation awareness** — MLX supports 4-bit and 8-bit quantised models. The `model.LoadModel()` path should detect and handle quantised safetensors. This matters for running 27B models on 32GB Macs.
- [ ] **Embed-friendly model loading** — For the Core native app (FrankenPHP + Wails), models are on disk but the path discovery needs to be clean. Add `model.Discover(baseDir)` that scans for available models and returns metadata (name, params, quant level, size).
- [ ] **Batch inference API** — go-i18n Phase 2a wants ~5K sentences/sec through Gemma3-1B. Single-prompt `Generate(..., WithMaxTokens(1))` works functionally for classification but won't hit 5K/sec. True batch inference (multiple prompts through one forward pass) is needed.
- [ ] **Inference metrics** — Expose tokens/sec, peak memory, GPU utilisation as structured data. LEM Lab dashboard and go-ai scoring engine both want this.
- [ ] **Model quantisation awareness** — MLX supports 4-bit and 8-bit quantised models. The loader already handles quantised safetensors (GroupSize, Bits in config).
- [ ] **Embed-friendly model loading** — Add `Discover(baseDir)` that scans for available models and returns metadata.
- [ ] **mlxlm/ backend** — Python subprocess wrapper via `core/go/pkg/process`. Implements `mlx.Backend` for mlx_lm compatibility.

## Phase 6: Go 1.26 Modernisation

- [x] **Evaluate Go 1.26 features** — ✅ Documented in FINDINGS.md. Key wins: CGO ~30% faster (free), Green Tea GC default (10-40% less overhead, helps Array finalisers), slice stack alloc. Range-over-func already stable since 1.23.
- [ ] **Range-over-func for Array** — If 1.26 stabilises range-over-func, `Array.Iter()` returning `iter.Seq[float32]` (or typed variant) would be cleaner than the current index-based access. Measure overhead vs direct C pointer access.
- [x] **Evaluate Go 1.26 features** — ✅ Documented in FINDINGS.md. Key wins: CGO ~30% faster (free), Green Tea GC default (10-40% less overhead, helps Array finalisers), slice stack alloc.
- [ ] **Range-over-func for Array** — `Array.Iter()` returning `iter.Seq[float32]` for cleaner iteration. Measure overhead vs direct C pointer access.

---

## Upstream Dependencies

- **go-i18n Phase 2a** is blocked on this package providing working Gemma3-1B inference
- **go-ai/ml/backend_mlx.go** imports this package — after split, go-ai needs `replace` directive or published module
- **go-ml/backend_mlx.go** needs updating to use the new API: `mlx.LoadModel()` + `m.Generate()`. Old imports (`mlx.Array`, `model.LoadModel`, sub-packages) are now internal.
- **go-ai** has a `replace` directive pointing at `../go-mlx`. No code changes needed in go-ai itself.
- **LEM Lab** uses `MLXBackend` via go-ml. Migration transparent once go-ml updates.

## Functional Options Convention

Virgil confirms: the `WithMaxTokens(n)` functional option pattern is the right call for this package.
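For reference, the pattern is the standard closure-over-config shape. A minimal sketch — the `config` struct, its field names, and the default value are assumptions; only `GenerateOption` and `WithMaxTokens` appear in this package's docs:

```go
package main

import "fmt"

// config is an assumed unexported options struct, as in options.go.
type config struct {
	maxTokens int
}

// GenerateOption mutates the config; callers compose options variadically.
type GenerateOption func(*config)

// WithMaxTokens caps generation length, matching the documented option name.
func WithMaxTokens(n int) GenerateOption {
	return func(c *config) { c.maxTokens = n }
}

// apply folds options over a default config (default value assumed).
func apply(opts ...GenerateOption) config {
	c := config{maxTokens: 256}
	for _, o := range opts {
		o(&c)
	}
	return c
}

func main() {
	c := apply(WithMaxTokens(128))
	fmt.Println(c.maxTokens) // 128
}
```

The appeal for this package is that new knobs (sampling temperature, stop tokens) can be added later without breaking the `Generate(ctx, prompt, opts...)` signature.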

## core/go/pkg/process (for mlxlm backend, Phase 5)

Virgil confirms: no changes needed. The process package provides everything needed for the mlxlm subprocess backend.

## Workflow