diff --git a/docs/plans/2026-02-19-backend-abstraction-design.md b/docs/plans/2026-02-19-backend-abstraction-design.md index 6246adb..ddf8f6f 100644 --- a/docs/plans/2026-02-19-backend-abstraction-design.md +++ b/docs/plans/2026-02-19-backend-abstraction-design.md @@ -1,7 +1,7 @@ # Backend Abstraction Design **Date:** 2026-02-19 -**Status:** Approved +**Status:** Approved (Virgil-reviewed 19 Feb 2026) **Author:** GoLand Claude (domain expert, go-mlx) ## Problem @@ -87,10 +87,29 @@ type Token struct { } type TextModel interface { - Generate(prompt string, opts ...GenerateOption) iter.Seq[Token] + // Generate streams tokens for the given prompt. Respects ctx cancellation. + Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token] + + // Chat formats messages using the model's native template, then generates. + // The model owns its chat template; the implementation may be deferred to Phase 5 if needed. + Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token] + + // ModelType returns the architecture identifier (e.g. "gemma3", "qwen3"). ModelType() string + + // Err returns the last generation error (OOM, C-level failure, etc.). + // Distinguishes normal stop (EOS, max tokens) from errors. + Err() error + + // Close releases all resources (GPU memory, caches, subprocess). Close() error } + +// Message represents a chat turn for Chat(). +type Message struct { + Role string // "user", "assistant", "system" + Content string +} ``` ### Entry point @@ -98,6 +117,14 @@ type TextModel interface { ```go func LoadModel(path string, opts ...LoadOption) (TextModel, error) func MetalAvailable() bool + +// Hardware-level memory controls (delegate to internal/metal). +// These are not model-level — they control the Metal allocator directly. 
+func SetCacheLimit(limit uint64) uint64 +func SetMemoryLimit(limit uint64) uint64 +func GetActiveMemory() uint64 +func GetPeakMemory() uint64 +func ClearCache() ``` ### Functional options diff --git a/docs/plans/2026-02-19-backend-abstraction-plan.md b/docs/plans/2026-02-19-backend-abstraction-plan.md index 3777860..a1a8343 100644 --- a/docs/plans/2026-02-19-backend-abstraction-plan.md +++ b/docs/plans/2026-02-19-backend-abstraction-plan.md @@ -10,6 +10,12 @@ **Design doc:** `docs/plans/2026-02-19-backend-abstraction-design.md` +**Virgil review (19 Feb 2026) — folded into plan:** +1. **`context.Context` on Generate()** — REQUIRED. HTTP handlers need cancellation. Check `ctx.Done()` in decode loop. +2. **`Err() error` on TextModel** — Distinguish EOS/max-tokens from OOM/C-errors. Check after iteration. +3. **`Chat()` on TextModel** — Model owns its chat template. Can defer to Phase 5 but recommended now. +4. **Memory control functions at root** — `SetCacheLimit`, `SetMemoryLimit`, `GetActiveMemory`, `GetPeakMemory`, `ClearCache` stay public, delegate to `internal/metal`. + --- ## Pre-flight @@ -61,7 +67,10 @@ Define the public API types. No CGO, no build tags. This file will eventually re ```go package mlx -import "iter" +import ( + "context" + "iter" +) // Token represents a single generated token for streaming. type Token struct { @@ -69,14 +78,30 @@ type Token struct { Text string } +// Message represents a chat turn for Chat(). +type Message struct { + Role string // "user", "assistant", "system" + Content string +} + // TextModel generates text from a loaded model. type TextModel interface { // Generate streams tokens for the given prompt. - Generate(prompt string, opts ...GenerateOption) iter.Seq[Token] + // Respects ctx cancellation (HTTP handlers, timeouts, graceful shutdown). + Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token] + + // Chat formats messages using the model's native chat template, then generates. 
+ // The model owns its template — callers don't need to know Gemma vs Qwen formatting. + Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token] // ModelType returns the architecture identifier (e.g. "gemma3", "qwen3"). ModelType() string + // Err returns the error from the last Generate/Chat call, if any. + // Distinguishes normal stop (EOS, max tokens) from failures (OOM, C-level error). + // Returns nil if generation completed normally. + Err() error + // Close releases all resources (GPU memory, caches, subprocess). Close() error } @@ -529,6 +554,15 @@ func init() { // MetalAvailable reports whether native Metal inference is available. func MetalAvailable() bool { return true } + +// Hardware-level memory controls — delegate to internal/metal. +// These are not model-level; they control the Metal allocator directly. + +func SetCacheLimit(limit uint64) uint64 { return metal.SetCacheLimit(limit) } +func SetMemoryLimit(limit uint64) uint64 { return metal.SetMemoryLimit(limit) } +func GetActiveMemory() uint64 { return metal.GetActiveMemory() } +func GetPeakMemory() uint64 { return metal.GetPeakMemory() } +func ClearCache() { metal.ClearCache() } ``` **Step 2: Write final `mlx.go`** @@ -606,10 +640,14 @@ func TestGenerate_Greedy(t *testing.T) { } defer m.Close() + ctx := context.Background() var tokens []mlx.Token - for tok := range m.Generate("What is 2+2?", mlx.WithMaxTokens(16)) { + for tok := range m.Generate(ctx, "What is 2+2?", mlx.WithMaxTokens(16)) { tokens = append(tokens, tok) } + if err := m.Err(); err != nil { + t.Fatalf("Generate error: %v", err) + } if len(tokens) == 0 { t.Fatal("Generate produced no tokens") @@ -651,9 +689,11 @@ type metalModel struct { model InternalModel tokenizer *Tokenizer modelType string + lastErr error // set by Generate/Chat on failure } func (m *metalModel) ModelType() string { return m.modelType } +func (m *metalModel) Err() error { return m.lastErr } func (m *metalModel) Close() error { // 
TODO: explicit Free() on all model weight arrays and caches @@ -661,8 +701,14 @@ func (m *metalModel) Close() error { return nil } -func (m *metalModel) Generate(prompt string, opts ...mlx.GenerateOption) iter.Seq[mlx.Token] { +func (m *metalModel) Chat(ctx context.Context, messages []mlx.Message, opts ...mlx.GenerateOption) iter.Seq[mlx.Token] { + prompt := m.formatChat(messages) // model-specific template + return m.Generate(ctx, prompt, opts...) +} + +func (m *metalModel) Generate(ctx context.Context, prompt string, opts ...mlx.GenerateOption) iter.Seq[mlx.Token] { cfg := mlx.ApplyGenerateOpts(opts) + m.lastErr = nil // reset per-generation return func(yield func(mlx.Token) bool) { tokens := m.tokenizer.Encode(prompt) @@ -676,6 +722,14 @@ func (m *metalModel) Generate(prompt string, opts ...mlx.GenerateOption) iter.Se Materialize(logits) for i := 0; i < cfg.MaxTokens; i++ { + // Check context cancellation (HTTP timeout, shutdown) + select { + case <-ctx.Done(): + m.lastErr = ctx.Err() + return + default: + } + // Sample from last position logits lastPos := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1))) lastPos = Reshape(lastPos, 1, int32(lastPos.Dim(2))) // [1, vocab] @@ -799,7 +853,8 @@ func TestMemory_GenerateDoesNotLeak(t *testing.T) { ClearCache() beforeMem := GetActiveMemory() - for tok := range m.Generate("Tell me a story", + ctx := context.Background() + for tok := range m.Generate(ctx, "Tell me a story", mlx.WithMaxTokens(100), mlx.WithTemperature(0.7)) { _ = tok } @@ -881,8 +936,9 @@ func TestLoadModel_MetalBackend(t *testing.T) { t.Errorf("ModelType = %q, want gemma3", m.ModelType()) } + ctx := context.Background() var count int - for tok := range m.Generate("What is 2+2?", mlx.WithMaxTokens(16)) { + for tok := range m.Generate(ctx, "What is 2+2?", mlx.WithMaxTokens(16)) { count++ t.Logf("[%d] %q", tok.ID, tok.Text) } @@ -968,3 +1024,14 @@ git commit -m "docs: update project docs for backend abstraction" **Critical checkpoint:** 
After Task 7, all 148 tests must pass in `internal/metal/`. If they don't, stop and fix before continuing. **Model-dependent tests** (Tasks 9, 10, 11) require Gemma3-1B at `/Volumes/Data/lem/safetensors/gemma-3/`. They use `t.Skip()` when the model isn't available. + +## Virgil Review Items (folded in) + +| Item | Where | Status | +|------|-------|--------| +| `context.Context` on Generate/Chat | Task 2 (textmodel.go), Task 9 (generate.go) | Integrated | +| `Err() error` on TextModel | Task 2 (textmodel.go), Task 9 (generate.go) | Integrated | +| `Chat()` on TextModel | Task 2 (textmodel.go), Task 9 (generate.go) | Integrated | +| Memory control functions at root | Task 8 (register_metal.go) | Integrated | +| Functional options convention | Confirmed by Virgil — no conflict with core/go | N/A | +| pkg/process for mlxlm | Confirmed by Virgil — no changes needed | N/A |