diff --git a/CLAUDE.md b/CLAUDE.md index 4de07f7..fb84ac5 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -8,33 +8,30 @@ Pure Go + CGO package that wraps Apple's [MLX framework](https://github.com/ml-e ## Platform -**darwin/arm64 only.** All files carry `//go:build darwin && arm64`. A stub (`mlx_stub.go`) provides `MetalAvailable() bool` returning false on other platforms. +**darwin/arm64 only.** All CGO files carry `//go:build darwin && arm64`. A stub (`mlx_stub.go`) provides `MetalAvailable() bool` returning false on other platforms. ## Build ```bash # Step 1: Build mlx-c C library via CMake (fetches mlx-c v0.4.1) -cd /tmp/core-go-mlx && go generate ./... +go generate ./... # Step 2: Run tests (must be on Apple Silicon) go test ./... - -# Step 3: Use in another module -# Import "forge.lthn.ai/core/go-mlx" and "forge.lthn.ai/core/go-mlx/model" ``` ### CGO Flags (auto-set via go:generate) -CMake installs to `dist/` inside the package directory. The `#cgo` directives in `mlx.go` reference: -- `CPPFLAGS: -I${SRCDIR}/dist/include` -- `LDFLAGS: -L${SRCDIR}/dist/lib -lmlxc -lmlx` +CMake installs to `dist/` inside the package directory. The `#cgo` directives in `internal/metal/metal.go` reference: +- `CPPFLAGS: -I${SRCDIR}/../../dist/include` +- `LDFLAGS: -L${SRCDIR}/../../dist/lib -lmlxc -lmlx` - Frameworks: Foundation, Metal, Accelerate ### CMake Config `CMakeLists.txt` fetches mlx-c v0.4.1 from GitHub. Key settings: - `MLX_BUILD_SAFETENSORS=ON` (model loading) -- `MLX_BUILD_GGUF=OFF` (GGUF is for llama.cpp backend in go-ai/ml/) +- `MLX_BUILD_GGUF=OFF` - `BUILD_SHARED_LIBS=ON` - macOS deployment target: 26.0 @@ -42,61 +39,74 @@ CMake installs to `dist/` inside the package directory. 
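The stub pattern above can be sketched as a single file. This is illustrative only: in the real package the stub sits behind `//go:build !darwin || !arm64` in `mlx_stub.go`; here the constraint is shown in a comment so the sketch compiles anywhere.

```go
package main

import "fmt"

// Sketch of the mlx_stub.go pattern described above. The real file
// carries the build constraint //go:build !darwin || !arm64, so it
// only compiles where the CGO Metal backend is unavailable.

// MetalAvailable mirrors the stub: off Apple Silicon it always
// reports false, letting callers degrade gracefully instead of
// failing to build.
func MetalAvailable() bool { return false }

func main() {
	if !MetalAvailable() {
		fmt.Println("Metal unavailable: choose a fallback backend")
	}
}
```

The build-tagged stub is what lets the root package compile on all platforms while keeping every CGO file darwin/arm64-only.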
The `#cgo` directives in ``` go-mlx/ -├── Core Layer (Metal GPU) -│ ├── mlx.go — Init, Materialize, MetalAvailable (CGO bridge) -│ ├── mlx_stub.go — Non-Apple fallback -│ ├── array.go — MLX array wrapper (create, reshape, data access) -│ ├── dtype.go — Data types (Float16, Float32, BFloat16, Int32, etc.) -│ ├── stream.go — Metal stream/queue management -│ └── CMakeLists.txt — Fetches mlx-c v0.4.1 +├── Public API (compiles on all platforms) +│ ├── mlx.go — Package doc + go:generate CMake directives +│ ├── textmodel.go — TextModel interface, Token, Message types +│ ├── options.go — GenerateOption, LoadOption functional options +│ ├── backend.go — Backend interface, Register/Get/Default/LoadModel +│ ├── register_metal.go — //go:build darwin && arm64 — auto-registers metal +│ ├── mlx_stub.go — //go:build !darwin || !arm64 — MetalAvailable() false +│ └── mlx_test.go — Integration tests (public API) │ -├── Operations -│ ├── ops.go — Element-wise: Add, Multiply, MatMul, Softmax, etc. -│ ├── fast.go — Fused Metal kernels: RMSNorm, RoPE, ScaledDotProductAttention -│ ├── nn.go — Neural network layers: Linear, Embedding, RMSNorm -│ ├── compile.go — Compiled function closures for kernel fusion -│ ├── slice.go — Array slicing/indexing -│ └── random.go — Categorical sampling +├── internal/metal/ — All CGO code (darwin/arm64 only) +│ ├── metal.go — Init, Materialize, error handler, stream +│ ├── array.go — Array type, creation, data access +│ ├── dtype.go — DType constants +│ ├── stream.go — Metal stream/queue, memory controls +│ ├── ops.go — Element-wise, reduction, shape ops +│ ├── fast.go — Fused Metal kernels (RMSNorm, RoPE, SDPA) +│ ├── nn.go — Linear, Embedding, RMSNormModule +│ ├── compile.go — CompiledFunc +│ ├── slice.go — Array slicing +│ ├── random.go — RandomCategorical, RandomUniform +│ ├── io.go — Safetensors loading +│ ├── model.go — InternalModel interface + architecture dispatch +│ ├── gemma3.go — Gemma3 decoder +│ ├── qwen3.go — Qwen3 decoder +│ ├── cache.go — 
KVCache + RotatingKVCache +│ ├── sample.go — Sampling chain (greedy, temp, topK, topP) +│ ├── tokenizer.go — BPE tokenizer +│ ├── grad.go — VJP +│ ├── lora.go — LoRA adapters +│ ├── optim.go — AdamW +│ ├── generate.go — Autoregressive generation loop + chat templates +│ └── backend.go — LoadAndInit entry point │ -├── Training -│ ├── grad.go — VJP (Vector-Jacobian Product) gradient computation -│ ├── lora.go — LoRA adapter (rank decomposition fine-tuning) -│ └── optim.go — AdamW optimiser +├── cpp/ — CLion Claude workspace (C++ research) +│ ├── CMakeLists.txt +│ ├── CLAUDE.md +│ ├── TODO.md +│ └── FINDINGS.md │ -├── Model Support -│ ├── model/ -│ │ ├── model.go — Base model interface (LoadModel, Generate) -│ │ ├── gemma3.go — Google Gemma3 decoder -│ │ └── qwen3.go — Alibaba Qwen3 decoder -│ ├── tokenizer/ -│ │ └── tokenizer.go — BPE tokenizer (sentencepiece format) -│ ├── sample/ -│ │ └── sample.go — Sampling strategies (temperature, top-k, top-p) -│ └── cache/ -│ └── cache.go — KV cache for autoregressive inference -│ -└── I/O - └── io.go — Safetensors model loading +└── docs/plans/ — Design and implementation docs ``` -## Key Interfaces +## Public API ```go -// model/model.go — base interface for all models -type Model interface { - Forward(x *mlx.Array, cache *cache.KVCache) *mlx.Array +import "forge.lthn.ai/core/go-mlx" + +// Load and generate +m, err := mlx.LoadModel("/path/to/model/") +if err != nil { log.Fatal(err) } +defer m.Close() + +ctx := context.Background() +for tok := range m.Generate(ctx, "What is 2+2?", mlx.WithMaxTokens(128)) { + fmt.Print(tok.Text) +} +if err := m.Err(); err != nil { log.Fatal(err) } + +// Chat with template formatting +for tok := range m.Chat(ctx, []mlx.Message{ + {Role: "user", Content: "Hello"}, +}, mlx.WithMaxTokens(64)) { + fmt.Print(tok.Text) } -// Top-level — GPU materialisation -mlx.Materialize(outputs...) // Sync GPU eval -mlx.MaterializeAsync(outputs...) 
// Async GPU eval -mlx.MetalAvailable() bool // Check Metal GPU - -// Array operations -mlx.NewArray(data, shape, dtype) -mlx.MatMul(a, b) -mlx.Softmax(x, axis) -mlx.Add(a, b), mlx.Multiply(a, b), mlx.Divide(a, b) +// Memory controls +mlx.SetCacheLimit(4 * 1024 * 1024 * 1024) // 4GB +mlx.ClearCache() ``` ## Dependencies @@ -107,8 +117,10 @@ mlx.Add(a, b), mlx.Multiply(a, b), mlx.Divide(a, b) ## Downstream Consumers -- `forge.lthn.ai/core/go-ai/ml` — `backend_mlx.go` uses this for Metal inference -- `forge.lthn.ai/core/go-i18n` — Phase 2a needs Gemma3-1B inference for domain classification (blocked on this package being importable) +- `forge.lthn.ai/core/go-ml` — `backend_mlx.go` uses `mlx.LoadModel()` + `m.Generate()` +- `forge.lthn.ai/core/go-i18n` — Phase 2a needs Gemma3-1B inference for domain classification + +The old API (`Array`, `MatMul`, `model.LoadModel`, etc.) is no longer public — all CGO is in `internal/metal/`. Consumers use the clean `TextModel` interface. ## Model Format diff --git a/FINDINGS.md b/FINDINGS.md index 31b482d..9eec527 100644 --- a/FINDINGS.md +++ b/FINDINGS.md @@ -180,37 +180,56 @@ This is a data correctness issue — silent wrong results, not a crash. --- -## 2026-02-19: Backend Abstraction — API Breaking Change +## 2026-02-19: Backend Abstraction — COMPLETED **Design doc:** `docs/plans/2026-02-19-backend-abstraction-design.md` +**Implementation plan:** `docs/plans/2026-02-19-backend-abstraction-plan.md` -### What's changing +### What changed -The entire public API is being replaced. All CGO code moves to `internal/metal/`. The root package becomes a clean interface layer: +The entire public API has been replaced. All CGO code is now in `internal/metal/`. 
The root package is a clean interface layer: ```go m, _ := mlx.LoadModel("/path/to/model/") defer m.Close() -for tok := range m.Generate("prompt", mlx.WithMaxTokens(128)) { +ctx := context.Background() +for tok := range m.Generate(ctx, "prompt", mlx.WithMaxTokens(128)) { fmt.Print(tok.Text) } +if err := m.Err(); err != nil { log.Fatal(err) } ``` -The old API (`Array`, `MatMul`, `model.LoadModel`, `model.Model`, etc.) will no longer be public. +The old API (`Array`, `MatMul`, `model.LoadModel`, `model.Model`, sub-packages `model/`, `tokenizer/`, `sample/`, `cache/`) is no longer public. All moved to `internal/metal/`. -### Impact on go-ai +### Architecture note: import cycle resolution -`backend_mlx.go` currently imports root-level Array, ops, model types. These move to `internal/metal/` and become inaccessible. Migration: replace direct tensor manipulation with `mlx.LoadModel()` + `mlx.TextModel.Generate()`. +`internal/metal/` cannot import the root package (circular dependency). Solution: internal/metal defines its own concrete types (`metal.Token`, `metal.GenerateConfig`, `metal.Model`), and `register_metal.go` in root provides a thin adapter (`metalAdapter`) that converts between root types (`mlx.Token`) and metal types. + +### Impact on go-ml + +`backend_mlx.go` must migrate from direct tensor manipulation to: +```go +m, _ := mlx.LoadModel(path) +ctx := context.Background() +for tok := range m.Generate(ctx, prompt, mlx.WithMaxTokens(n)) { ... } +if err := m.Err(); err != nil { ... } +``` +253 LOC → ~60 LOC. Memory controls: `mlx.SetCacheLimit()`, `mlx.ClearCache()`, etc. ### Impact on go-i18n -The API for Gemma3-1B domain classification will be: ```go m, _ := mlx.LoadModel("/path/to/gemma-3-1b/") -for tok := range m.Generate(sentence, mlx.WithMaxTokens(32)) { ... } +ctx := context.Background() +for tok := range m.Generate(ctx, sentence, mlx.WithMaxTokens(32)) { ... } ``` -Streaming via `iter.Seq[Token]`. No tokenisation or sampling to handle. 
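The adapter approach described in the architecture note above can be sketched in miniature. Both type sets are collapsed into one file here for illustration; in the real package `metal.Token` lives in `internal/metal/` and `mlx.Token` in the root, with `register_metal.go` holding the adapter.

```go
package main

import "fmt"

// metalToken stands in for internal/metal's concrete token type.
// internal/metal defines its own types so it never has to import
// the root package, which would create a cycle.
type metalToken struct {
	ID   int
	Text string
}

// Token stands in for the root package's public mlx.Token.
type Token struct {
	ID   int
	Text string
}

// metalAdapter is the thin conversion layer: it translates the
// internal type vocabulary into the public one.
type metalAdapter struct{}

func (metalAdapter) adapt(t metalToken) Token {
	return Token{ID: t.ID, Text: t.Text}
}

func main() {
	tok := metalAdapter{}.adapt(metalToken{ID: 42, Text: "hi"})
	fmt.Println(tok.Text) // prints "hi"
}
```

The conversion is mechanical, which is the point: the cycle is broken purely by duplicating a few small value types, not by restructuring the backend.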
-### Memory leak fix included +### Memory management status -The refactor includes deterministic memory management — `TextModel.Close()` for model weights and per-step intermediate cleanup during generation. This addresses the current production blocker. +`Close()` stub is in place but does not yet explicitly free model weights. Per-step intermediate cleanup (`ClearCache()` per decode step) is implemented in the generate loop. Full deterministic cleanup awaits CLion Claude research on `mlx_array_free` safety (see `cpp/TODO.md`). + +### Test results + +- 148 existing tests moved to `internal/metal/` — all pass +- 7 new integration tests for public API — all pass +- Total: 155 tests passing diff --git a/TODO.md b/TODO.md index 22759cc..75710f7 100644 --- a/TODO.md +++ b/TODO.md @@ -13,7 +13,7 @@ Dispatched from core/go orchestration. Pick up tasks in order. ## Phase 2: Model Support -- [ ] **Gemma3-1B inference validation** — The go-i18n Phase 2a needs 1B model inference for domain classification at ~5K sentences/sec. Validate Gemma3-1B loads and generates correctly via `model.LoadModel()` + `model.Generate()`. Report tokens/sec. +- [ ] **Gemma3-1B inference validation** — The go-i18n Phase 2a needs 1B model inference for domain classification at ~5K sentences/sec. Validate Gemma3-1B loads and generates correctly via `mlx.LoadModel()` + `m.Generate()`. Report tokens/sec. - [ ] **Model loading robustness** — Test with missing files, corrupted safetensors, wrong dtype. Currently no error handling tests for `io.go`. - [ ] **Add Llama model support** — Only Gemma3 and Qwen3 exist. Llama architecture would cover Meta's model family (Llama 3, CodeLlama). @@ -23,31 +23,54 @@ Dispatched from core/go orchestration. Pick up tasks in order. - [ ] **Gradient checkpointing** — `grad.go` has VJP but large models will OOM without checkpointing. Add selective recomputation. - [ ] **Mixed precision training** — MLX supports BFloat16/Float16. 
Add dtype selection for training (currently inference uses model's native dtype). -## Phase 4: API Polish +## Phase 4: Backend Abstraction — ✅ COMPLETE (19 Feb 2026) -- [ ] **Error handling audit** — `checkError()` only logs. Should return errors to callers instead of silent logging. The C error handler stores last error but Go code doesn't propagate it. -- [ ] **Memory management audit** — Array finalizers use `runtime.SetFinalizer`. Verify no leaks under sustained inference (1000+ tokens). Check C-side deallocation. -- [ ] **Documentation** — Public API has minimal godoc. Add examples for common workflows: load model, generate text, fine-tune with LoRA. +Design doc: `docs/plans/2026-02-19-backend-abstraction-design.md` +Implementation plan: `docs/plans/2026-02-19-backend-abstraction-plan.md` + +**All Virgil review items implemented:** + +- [x] **`context.Context` on `TextModel.Generate()`** — `Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token]`. Checks `ctx.Done()` in the decode loop. +- [x] **`Err() error` on `TextModel`** — Distinguishes normal stop (EOS, max tokens) from errors (OOM, ctx cancelled). +- [x] **`Chat()` on `TextModel`** — Model owns its chat template. Gemma3 and Qwen3 templates implemented. +- [x] **Memory control functions at root** — `SetCacheLimit`, `SetMemoryLimit`, `GetActiveMemory`, `GetPeakMemory`, `ClearCache` delegate to `internal/metal`. +- [x] **Backend registration** — `register_metal.go` auto-registers via build-tagged `init()`. +- [x] **All CGO moved to `internal/metal/`** — 19 source files, 10 test files, 148 tests passing. +- [x] **Public API: `TextModel`, `Backend`, functional options** — Clean root package, compiles on all platforms. +- [x] **Integration tests** — 7 tests for public API (backend registration, options, LoadModel paths). +- [ ] **Error handling audit** — `checkError()` still logs + swallows. Needs conversion to error returns. Low priority — existing behaviour, not a regression. 
+- [ ] **Memory management — deterministic cleanup** — `Close()` stub in place. Needs CLion Claude research on `mlx_array_free` safety before implementing per-step cleanup. See `cpp/TODO.md`. +- [ ] **Documentation** — Public API has godoc but needs examples for common workflows. ## Phase 5: Ecosystem Integration (Virgil wishlist) -- [ ] **Streaming token iterator** — `model.Generate()` should return an `iter.Seq[Token]` (or channel) for token-by-token streaming. LEM Lab chat UI needs this for real-time output. Currently the generate path is batch-only. -- [ ] **Batch inference API** — go-i18n Phase 2a wants ~5K sentences/sec through Gemma3-1B. Single-request inference won't hit that. Add a batch API that processes N prompts concurrently on the GPU. Even batching 8-16 at a time would help throughput significantly. -- [ ] **Inference metrics** — Expose tokens/sec, peak memory, GPU utilisation as structured data. LEM Lab dashboard and go-ai scoring engine both want this. A simple `Stats` struct returned alongside generation results. -- [ ] **Model quantisation awareness** — MLX supports 4-bit and 8-bit quantised models. The `model.LoadModel()` path should detect and handle quantised safetensors. This matters for running 27B models on 32GB Macs. -- [ ] **Embed-friendly model loading** — For the Core native app (FrankenPHP + Wails), models are on disk but the path discovery needs to be clean. Add `model.Discover(baseDir)` that scans for available models and returns metadata (name, params, quant level, size). +- [ ] **Batch inference API** — go-i18n Phase 2a wants ~5K sentences/sec through Gemma3-1B. Single-prompt `Generate(..., WithMaxTokens(1))` works functionally for classification but won't hit 5K/sec. True batch inference (multiple prompts through one forward pass) is needed. +- [ ] **Inference metrics** — Expose tokens/sec, peak memory, GPU utilisation as structured data. LEM Lab dashboard and go-ai scoring engine both want this. 
+- [ ] **Model quantisation awareness** — MLX supports 4-bit and 8-bit quantised models. The loader already handles quantised safetensors (GroupSize, Bits in config). +- [ ] **Embed-friendly model loading** — Add `Discover(baseDir)` that scans for available models and returns metadata. +- [ ] **mlxlm/ backend** — Python subprocess wrapper via `core/go/pkg/process`. Implements `mlx.Backend` for mlx_lm compatibility. ## Phase 6: Go 1.26 Modernisation -- [x] **Evaluate Go 1.26 features** — ✅ Documented in FINDINGS.md. Key wins: CGO ~30% faster (free), Green Tea GC default (10-40% less overhead, helps Array finalisers), slice stack alloc. Range-over-func already stable since 1.23. -- [ ] **Range-over-func for Array** — If 1.26 stabilises range-over-func, `Array.Iter()` returning `iter.Seq[float32]` (or typed variant) would be cleaner than the current index-based access. Measure overhead vs direct C pointer access. +- [x] **Evaluate Go 1.26 features** — ✅ Documented in FINDINGS.md. Key wins: CGO ~30% faster (free), Green Tea GC default (10-40% less overhead, helps Array finalisers), slice stack alloc. +- [ ] **Range-over-func for Array** — `Array.Iter()` returning `iter.Seq[float32]` for cleaner iteration. Measure overhead vs direct C pointer access. --- ## Upstream Dependencies - **go-i18n Phase 2a** is blocked on this package providing working Gemma3-1B inference -- **go-ai/ml/backend_mlx.go** imports this package — after split, go-ai needs `replace` directive or published module +- **go-ml/backend_mlx.go** needs updating to use new API: `mlx.LoadModel()` + `m.Generate()`. Old imports (`mlx.Array`, `model.LoadModel`, sub-packages) are now internal. +- **go-ai** has a `replace` directive pointing at `../go-mlx`. No code changes needed in go-ai itself. +- **LEM Lab** uses `MLXBackend` via go-ml. Migration transparent once go-ml updates. 
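The planned mlxlm backend above would plug into the registry that `backend.go` exposes via Register/Get. The sketch below collapses that pattern into one file; the signatures and the map-plus-mutex layout are assumptions for illustration, not the package's actual API.

```go
package main

import (
	"fmt"
	"sync"
)

// Backend is a stand-in for the mlx.Backend interface.
type Backend interface {
	Name() string
}

var (
	mu       sync.Mutex
	backends = map[string]Backend{}
)

// Register makes a backend available by name. A build-tagged init()
// (as register_metal.go does for metal) would call this on supported
// platforms, so registration costs nothing elsewhere.
func Register(name string, b Backend) {
	mu.Lock()
	defer mu.Unlock()
	backends[name] = b
}

// Get returns a registered backend, or false if none exists.
func Get(name string) (Backend, bool) {
	mu.Lock()
	defer mu.Unlock()
	b, ok := backends[name]
	return b, ok
}

// subprocessBackend is a hypothetical stand-in for the Phase 5
// mlxlm Python-subprocess backend.
type subprocessBackend struct{}

func (subprocessBackend) Name() string { return "mlxlm" }

func main() {
	Register("mlxlm", subprocessBackend{})
	if b, ok := Get("mlxlm"); ok {
		fmt.Println(b.Name()) // prints "mlxlm"
	}
}
```

Because registration is name-keyed, the mlxlm subprocess backend can coexist with the metal backend and be selected explicitly without either importing the other.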
+ +## Functional Options Convention + +Virgil confirms: the `WithMaxTokens(n)` functional option pattern is the right call for this package. + +## core/go/pkg/process (for mlxlm backend, Phase 5) + +Virgil confirms: no changes needed. The process package provides everything needed for the mlxlm subprocess backend. ## Workflow
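The functional-options convention confirmed above can be sketched as follows. `WithMaxTokens` is the package's real option name; the config struct, its field, and the default value here are assumptions for illustration.

```go
package main

import "fmt"

// generateConfig holds generation settings; field names and the
// default below are hypothetical.
type generateConfig struct {
	maxTokens int
}

// GenerateOption mutates the generation config.
type GenerateOption func(*generateConfig)

// WithMaxTokens caps the number of tokens generated.
func WithMaxTokens(n int) GenerateOption {
	return func(c *generateConfig) { c.maxTokens = n }
}

// apply folds options over an assumed default config.
func apply(opts ...GenerateOption) generateConfig {
	cfg := generateConfig{maxTokens: 256}
	for _, o := range opts {
		o(&cfg)
	}
	return cfg
}

func main() {
	cfg := apply(WithMaxTokens(128))
	fmt.Println(cfg.maxTokens) // prints 128
}
```

The pattern keeps `Generate(ctx, prompt, opts...)` stable as new knobs are added: each new setting is one more `With*` function, no signature change.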