From d1fb26d51e62b2766e913c3160b7e2c15dc4ea8b Mon Sep 17 00:00:00 2001 From: Snider Date: Thu, 19 Feb 2026 23:44:07 +0000 Subject: [PATCH] docs: expand package doc with workflow examples Cover generate, chat, classify, batch generate, metrics, model info, discovery, and Metal memory controls. Co-Authored-By: Claude Opus 4.6 --- TODO.md | 2 +- mlx.go | 83 ++++++++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 80 insertions(+), 5 deletions(-) diff --git a/TODO.md b/TODO.md index 75abb01..69523ec 100644 --- a/TODO.md +++ b/TODO.md @@ -41,7 +41,7 @@ Implementation plan: `docs/plans/2026-02-19-backend-abstraction-plan.md` - [x] **Integration tests** — 7 tests for public API (backend registration, options, LoadModel paths). - [x] **Error handling audit** — ✅ `checkError()` replaced with `lastError() error` (reads + clears C-level error string). Added `Eval(...*Array) error` and `EvalAsync(...*Array) error` as error-returning variants of Materialize. Generate loop propagates errors via `m.lastErr`. `LoadAllSafetensors` returns `(map, error)`. Model loaders (gemma3, qwen3) check `lastError()` after safetensors load. grad.go/lora.go now surface real MLX error messages. 4 new tests in error_test.go. - [x] **Memory management — deterministic cleanup** — ✅ `Model.Close()` now walks the full model tree (GemmaModel/Qwen3Model) and explicitly frees all weight arrays via `Free()`. Helpers: `freeLinear`, `freeEmbedding`, `freeRMSNorm`, `freeCaches`, `closeGemma`, `closeQwen3` in close.go. Handles tied output weights (skip double-free), nil safety, idempotent Close(). 8 new tests in close_test.go. -- [ ] **Documentation** — Public API has godoc but needs examples for common workflows. +- [x] **Documentation** — ✅ Package docs expanded with examples for all common workflows: Generate, Chat, Classify, BatchGenerate, Metrics, ModelInfo, Discover, memory controls. Both go-mlx and go-inference package docs updated with godoc heading sections. 
## Phase 5: Ecosystem Integration (Virgil wishlist) diff --git a/mlx.go b/mlx.go index d226e8d..3cc8eaf 100644 --- a/mlx.go +++ b/mlx.go @@ -2,23 +2,98 @@ // // This package implements the [inference.Backend] interface from // forge.lthn.ai/core/go-inference for Apple Silicon (M1-M4) GPUs. +// Import it blank to register the "metal" backend automatically: +// +// import _ "forge.lthn.ai/core/go-mlx" // // Build mlx-c before use: // // go generate ./... // -// Load a model and generate text: -// -// import "forge.lthn.ai/core/go-inference" -// import _ "forge.lthn.ai/core/go-mlx" // register Metal backend +// # Generate text // // m, err := inference.LoadModel("/path/to/model/") // if err != nil { log.Fatal(err) } // defer m.Close() // +// ctx := context.Background() // for tok := range m.Generate(ctx, "What is 2+2?", inference.WithMaxTokens(128)) { // fmt.Print(tok.Text) // } +// if err := m.Err(); err != nil { log.Fatal(err) } +// +// # Multi-turn chat +// +// Chat applies the model's native template (Gemma3, Qwen3, Llama3): +// +// for tok := range m.Chat(ctx, []inference.Message{ +// {Role: "system", Content: "You are a helpful assistant."}, +// {Role: "user", Content: "Translate 'hello' to French."}, +// }, inference.WithMaxTokens(64)) { +// fmt.Print(tok.Text) +// } +// +// # Batch classification +// +// Classify runs a single forward pass per prompt (prefill only, no decoding): +// +// results, err := m.Classify(ctx, []string{ +// "Bonjour, comment allez-vous?", +// "The quarterly report shows growth.", +// }, inference.WithTemperature(0)) +// for i, r := range results { +// fmt.Printf("prompt %d → %q\n", i, r.Token.Text) +// } +// +// # Batch generation +// +// results, err := m.BatchGenerate(ctx, []string{ +// "The capital of France is", +// "Water boils at", +// }, inference.WithMaxTokens(32)) +// for i, r := range results { +// for _, tok := range r.Tokens { +// fmt.Print(tok.Text) +// } +// fmt.Println() +// } +// +// # Performance metrics +// +// After 
any inference call, retrieve timing and memory statistics: +// +// for tok := range m.Generate(ctx, prompt, inference.WithMaxTokens(128)) { +// fmt.Print(tok.Text) +// } +// met := m.Metrics() +// fmt.Printf("decode: %.0f tok/s, peak GPU: %d MB\n", +// met.DecodeTokensPerSec, met.PeakMemoryBytes/1024/1024) +// +// # Model info +// +// info := m.Info() +// fmt.Printf("%s %d-layer, %d-bit quantised\n", +// info.Architecture, info.NumLayers, info.QuantBits) +// +// # Model discovery +// +// models, err := inference.Discover("/path/to/models/") +// for _, d := range models { +// fmt.Printf("%s (%s, %d-bit)\n", d.Path, d.ModelType, d.QuantBits) +// } +// +// # Metal memory controls +// +// These package-level functions control the Metal allocator directly, not individual models: +// +// mlx.SetCacheLimit(4 << 30) // 4 GB cache limit +// mlx.SetMemoryLimit(32 << 30) // 32 GB hard limit +// +// // Between chat turns, reclaim prompt cache memory: +// mlx.ClearCache() +// +// fmt.Printf("active: %d MB, peak: %d MB\n", +// mlx.GetActiveMemory()/1024/1024, mlx.GetPeakMemory()/1024/1024) package mlx //go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release