docs(plan): fold Virgil review into design and implementation plan

Virgil review items integrated:
- context.Context on Generate/Chat (required for HTTP cancellation)
- Err() error on TextModel (distinguish EOS from OOM)
- Chat() on TextModel (model owns its chat template)
- Memory control functions exposed at root package level
- Functional options convention confirmed
- pkg/process confirmed — no changes needed for mlxlm

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Author: Snider 2026-02-19 19:25:05 +00:00
Parent: 28e2a07316
Commit: 97d9041455
2 changed files with 102 additions and 8 deletions


@@ -1,7 +1,7 @@
# Backend Abstraction Design
**Date:** 2026-02-19
**Status:** Approved
**Status:** Approved (Virgil-reviewed 19 Feb 2026)
**Author:** GoLand Claude (domain expert, go-mlx)
## Problem
@@ -87,10 +87,29 @@ type Token struct {
}
type TextModel interface {
Generate(prompt string, opts ...GenerateOption) iter.Seq[Token]
// Generate streams tokens for the given prompt. Respects ctx cancellation.
Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token]
// Chat formats messages using the model's native template, then generates.
// Deferred to Phase 5 if needed — model owns its chat template.
Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token]
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
ModelType() string
// Err returns the last generation error (OOM, C-level failure, etc.).
// Distinguishes normal stop (EOS, max tokens) from errors.
Err() error
// Close releases all resources (GPU memory, caches, subprocess).
Close() error
}
// Message represents a chat turn for Chat().
type Message struct {
Role string // "user", "assistant", "system"
Content string
}
```
### Entry point
@@ -98,6 +117,14 @@ type TextModel interface {
```go
func LoadModel(path string, opts ...LoadOption) (TextModel, error)
func MetalAvailable() bool
// Hardware-level memory controls (delegate to internal/metal).
// These are not model-level — they control the Metal allocator directly.
func SetCacheLimit(limit uint64) uint64
func SetMemoryLimit(limit uint64) uint64
func GetActiveMemory() uint64
func GetPeakMemory() uint64
func ClearCache()
```
### Functional options


@@ -10,6 +10,12 @@
**Design doc:** `docs/plans/2026-02-19-backend-abstraction-design.md`
**Virgil review (19 Feb 2026) — folded into plan:**
1. **`context.Context` on Generate()** — REQUIRED. HTTP handlers need cancellation. Check `ctx.Done()` in decode loop.
2. **`Err() error` on TextModel** — Distinguish EOS/max-tokens from OOM/C-errors. Check after iteration.
3. **`Chat()` on TextModel** — Model owns its chat template. Can defer to Phase 5 but recommended now.
4. **Memory control functions at root** — `SetCacheLimit`, `SetMemoryLimit`, `GetActiveMemory`, `GetPeakMemory`, `ClearCache` stay public, delegate to `internal/metal`.
---
## Pre-flight
@@ -61,7 +67,10 @@ Define the public API types. No CGO, no build tags. This file will eventually re
```go
package mlx
import "iter"
import (
"context"
"iter"
)
// Token represents a single generated token for streaming.
type Token struct {
@@ -69,14 +78,30 @@ type Token struct {
Text string
}
// Message represents a chat turn for Chat().
type Message struct {
Role string // "user", "assistant", "system"
Content string
}
// TextModel generates text from a loaded model.
type TextModel interface {
// Generate streams tokens for the given prompt.
Generate(prompt string, opts ...GenerateOption) iter.Seq[Token]
// Respects ctx cancellation (HTTP handlers, timeouts, graceful shutdown).
Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token]
// Chat formats messages using the model's native chat template, then generates.
// The model owns its template — callers don't need to know Gemma vs Qwen formatting.
Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token]
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
ModelType() string
// Err returns the error from the last Generate/Chat call, if any.
// Distinguishes normal stop (EOS, max tokens) from failures (OOM, C-level error).
// Returns nil if generation completed normally.
Err() error
// Close releases all resources (GPU memory, caches, subprocess).
Close() error
}
@@ -529,6 +554,15 @@ func init() {
// MetalAvailable reports whether native Metal inference is available.
func MetalAvailable() bool { return true }
// Hardware-level memory controls — delegate to internal/metal.
// These are not model-level; they control the Metal allocator directly.
func SetCacheLimit(limit uint64) uint64 { return metal.SetCacheLimit(limit) }
func SetMemoryLimit(limit uint64) uint64 { return metal.SetMemoryLimit(limit) }
func GetActiveMemory() uint64 { return metal.GetActiveMemory() }
func GetPeakMemory() uint64 { return metal.GetPeakMemory() }
func ClearCache() { metal.ClearCache() }
```
**Step 2: Write final `mlx.go`**
@@ -606,10 +640,14 @@ func TestGenerate_Greedy(t *testing.T) {
}
defer m.Close()
ctx := context.Background()
var tokens []mlx.Token
for tok := range m.Generate("What is 2+2?", mlx.WithMaxTokens(16)) {
for tok := range m.Generate(ctx, "What is 2+2?", mlx.WithMaxTokens(16)) {
tokens = append(tokens, tok)
}
if err := m.Err(); err != nil {
t.Fatalf("Generate error: %v", err)
}
if len(tokens) == 0 {
t.Fatal("Generate produced no tokens")
@@ -651,9 +689,11 @@ type metalModel struct {
model InternalModel
tokenizer *Tokenizer
modelType string
lastErr error // set by Generate/Chat on failure
}
func (m *metalModel) ModelType() string { return m.modelType }
func (m *metalModel) Err() error { return m.lastErr }
func (m *metalModel) Close() error {
// TODO: explicit Free() on all model weight arrays and caches
@@ -661,8 +701,14 @@ func (m *metalModel) Close() error {
return nil
}
func (m *metalModel) Generate(prompt string, opts ...mlx.GenerateOption) iter.Seq[mlx.Token] {
func (m *metalModel) Chat(ctx context.Context, messages []mlx.Message, opts ...mlx.GenerateOption) iter.Seq[mlx.Token] {
prompt := m.formatChat(messages) // model-specific template
return m.Generate(ctx, prompt, opts...)
}
func (m *metalModel) Generate(ctx context.Context, prompt string, opts ...mlx.GenerateOption) iter.Seq[mlx.Token] {
cfg := mlx.ApplyGenerateOpts(opts)
m.lastErr = nil // reset per-generation
return func(yield func(mlx.Token) bool) {
tokens := m.tokenizer.Encode(prompt)
@@ -676,6 +722,14 @@ func (m *metalModel) Generate(prompt string, opts ...mlx.GenerateOption) iter.Seq[mlx.Token] {
Materialize(logits)
for i := 0; i < cfg.MaxTokens; i++ {
// Check context cancellation (HTTP timeout, shutdown)
select {
case <-ctx.Done():
m.lastErr = ctx.Err()
return
default:
}
// Sample from last position logits
lastPos := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1)))
lastPos = Reshape(lastPos, 1, int32(lastPos.Dim(2))) // [1, vocab]
@@ -799,7 +853,8 @@ func TestMemory_GenerateDoesNotLeak(t *testing.T) {
ClearCache()
beforeMem := GetActiveMemory()
for tok := range m.Generate("Tell me a story",
ctx := context.Background()
for tok := range m.Generate(ctx, "Tell me a story",
mlx.WithMaxTokens(100), mlx.WithTemperature(0.7)) {
_ = tok
}
@@ -881,8 +936,9 @@ func TestLoadModel_MetalBackend(t *testing.T) {
t.Errorf("ModelType = %q, want gemma3", m.ModelType())
}
ctx := context.Background()
var count int
for tok := range m.Generate("What is 2+2?", mlx.WithMaxTokens(16)) {
for tok := range m.Generate(ctx, "What is 2+2?", mlx.WithMaxTokens(16)) {
count++
t.Logf("[%d] %q", tok.ID, tok.Text)
}
@@ -968,3 +1024,14 @@ git commit -m "docs: update project docs for backend abstraction"
**Critical checkpoint:** After Task 7, all 148 tests must pass in `internal/metal/`. If they don't, stop and fix before continuing.
**Model-dependent tests** (Tasks 9, 10, 11) require Gemma3-1B at `/Volumes/Data/lem/safetensors/gemma-3/`. They use `t.Skip()` when the model isn't available.
## Virgil Review Items (folded in)
| Item | Where | Status |
|------|-------|--------|
| `context.Context` on Generate/Chat | Task 2 (textmodel.go), Task 9 (generate.go) | Integrated |
| `Err() error` on TextModel | Task 2 (textmodel.go), Task 9 (generate.go) | Integrated |
| `Chat()` on TextModel | Task 2 (textmodel.go), Task 9 (generate.go) | Integrated |
| Memory control functions at root | Task 8 (register_metal.go) | Integrated |
| Functional options convention | Confirmed by Virgil — no conflict with core/go | N/A |
| pkg/process for mlxlm | Confirmed by Virgil — no changes needed | N/A |