docs(plan): fold Virgil review into design and implementation plan
Virgil review items integrated:

- context.Context on Generate/Chat (required for HTTP cancellation)
- Err() error on TextModel (distinguish EOS from OOM)
- Chat() on TextModel (model owns its chat template)
- Memory control functions exposed at root package level
- Functional options convention confirmed
- pkg/process confirmed — no changes needed for mlxlm

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent 28e2a07316
commit 97d9041455

2 changed files with 102 additions and 8 deletions
@@ -1,7 +1,7 @@
 # Backend Abstraction Design
 
 **Date:** 2026-02-19
-**Status:** Approved
+**Status:** Approved (Virgil-reviewed 19 Feb 2026)
 **Author:** GoLand Claude (domain expert, go-mlx)
 
 ## Problem
@@ -87,10 +87,29 @@ type Token struct {
 }
 
 type TextModel interface {
-	Generate(prompt string, opts ...GenerateOption) iter.Seq[Token]
+	// Generate streams tokens for the given prompt. Respects ctx cancellation.
+	Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token]
+
+	// Chat formats messages using the model's native template, then generates.
+	// Deferred to Phase 5 if needed — model owns its chat template.
+	Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token]
+
+	// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
 	ModelType() string
+
+	// Err returns the last generation error (OOM, C-level failure, etc.).
+	// Distinguishes normal stop (EOS, max tokens) from errors.
+	Err() error
+
+	// Close releases all resources (GPU memory, caches, subprocess).
 	Close() error
 }
+
+// Message represents a chat turn for Chat().
+type Message struct {
+	Role    string // "user", "assistant", "system"
+	Content string
+}
 ```
 
 ### Entry point
@@ -98,6 +117,14 @@ type TextModel interface {
 ```go
 func LoadModel(path string, opts ...LoadOption) (TextModel, error)
 func MetalAvailable() bool
+
+// Hardware-level memory controls (delegate to internal/metal).
+// These are not model-level — they control the Metal allocator directly.
+func SetCacheLimit(limit uint64) uint64
+func SetMemoryLimit(limit uint64) uint64
+func GetActiveMemory() uint64
+func GetPeakMemory() uint64
+func ClearCache()
 ```
 
 ### Functional options
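A sketch of how root-level wrappers of this shape typically behave. The `uint64` return of the setters is read here as the previous limit (matching the semantics of MLX's own `set_cache_limit`); the toy allocator below stands in for `internal/metal`, and its fields and 1 GiB default are invented for illustration:

```go
package main

import "fmt"

// allocator is a toy stand-in for internal/metal; the 1 GiB default
// is invented for illustration.
type allocator struct{ cacheLimit uint64 }

func (a *allocator) SetCacheLimit(limit uint64) uint64 {
	prev := a.cacheLimit
	a.cacheLimit = limit
	return prev
}

var metal = &allocator{cacheLimit: 1 << 30}

// SetCacheLimit mirrors the proposed root-level wrapper: a thin
// delegate that returns the previous limit.
func SetCacheLimit(limit uint64) uint64 { return metal.SetCacheLimit(limit) }

func main() {
	prev := SetCacheLimit(512 << 20) // cap the cache at 512 MiB
	defer SetCacheLimit(prev)        // restore on the way out
	fmt.Printf("previous limit: %d MiB\n", prev>>20)
}
```

Returning the previous value lets callers scope a limit to one operation and restore it afterwards, as the `defer` above does.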
@@ -10,6 +10,12 @@
 
 **Design doc:** `docs/plans/2026-02-19-backend-abstraction-design.md`
 
+**Virgil review (19 Feb 2026) — folded into plan:**
+1. **`context.Context` on Generate()** — REQUIRED. HTTP handlers need cancellation. Check `ctx.Done()` in the decode loop.
+2. **`Err() error` on TextModel** — Distinguishes EOS/max-tokens from OOM/C-level errors. Check after iteration.
+3. **`Chat()` on TextModel** — The model owns its chat template. Could defer to Phase 5, but recommended now.
+4. **Memory control functions at root** — `SetCacheLimit`, `SetMemoryLimit`, `GetActiveMemory`, `GetPeakMemory`, `ClearCache` stay public and delegate to `internal/metal`.
+
 ---
 
 ## Pre-flight
@@ -61,7 +67,10 @@ Define the public API types. No CGO, no build tags. This file will eventually re
 ```go
 package mlx
 
-import "iter"
+import (
+	"context"
+	"iter"
+)
 
 // Token represents a single generated token for streaming.
 type Token struct {
@@ -69,14 +78,30 @@ type Token struct {
 	Text string
 }
 
+// Message represents a chat turn for Chat().
+type Message struct {
+	Role    string // "user", "assistant", "system"
+	Content string
+}
+
 // TextModel generates text from a loaded model.
 type TextModel interface {
 	// Generate streams tokens for the given prompt.
-	Generate(prompt string, opts ...GenerateOption) iter.Seq[Token]
+	// Respects ctx cancellation (HTTP handlers, timeouts, graceful shutdown).
+	Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token]
+
+	// Chat formats messages using the model's native chat template, then generates.
+	// The model owns its template — callers don't need to know Gemma vs Qwen formatting.
+	Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token]
+
 	// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
 	ModelType() string
 
+	// Err returns the error from the last Generate/Chat call, if any.
+	// Distinguishes normal stop (EOS, max tokens) from failures (OOM, C-level error).
+	// Returns nil if generation completed normally.
+	Err() error
+
 	// Close releases all resources (GPU memory, caches, subprocess).
 	Close() error
 }
@@ -529,6 +554,15 @@ func init() {
 
 // MetalAvailable reports whether native Metal inference is available.
 func MetalAvailable() bool { return true }
+
+// Hardware-level memory controls — delegate to internal/metal.
+// These are not model-level; they control the Metal allocator directly.
+
+func SetCacheLimit(limit uint64) uint64  { return metal.SetCacheLimit(limit) }
+func SetMemoryLimit(limit uint64) uint64 { return metal.SetMemoryLimit(limit) }
+func GetActiveMemory() uint64            { return metal.GetActiveMemory() }
+func GetPeakMemory() uint64              { return metal.GetPeakMemory() }
+func ClearCache()                        { metal.ClearCache() }
 ```
 
 **Step 2: Write final `mlx.go`**
@@ -606,10 +640,14 @@ func TestGenerate_Greedy(t *testing.T) {
 	}
 	defer m.Close()
 
+	ctx := context.Background()
 	var tokens []mlx.Token
-	for tok := range m.Generate("What is 2+2?", mlx.WithMaxTokens(16)) {
+	for tok := range m.Generate(ctx, "What is 2+2?", mlx.WithMaxTokens(16)) {
 		tokens = append(tokens, tok)
 	}
+	if err := m.Err(); err != nil {
+		t.Fatalf("Generate error: %v", err)
+	}
 
 	if len(tokens) == 0 {
 		t.Fatal("Generate produced no tokens")
@@ -651,9 +689,11 @@ type metalModel struct {
 	model     InternalModel
 	tokenizer *Tokenizer
 	modelType string
+	lastErr   error // set by Generate/Chat on failure
 }
 
 func (m *metalModel) ModelType() string { return m.modelType }
+func (m *metalModel) Err() error        { return m.lastErr }
 
 func (m *metalModel) Close() error {
 	// TODO: explicit Free() on all model weight arrays and caches
@@ -661,8 +701,14 @@ func (m *metalModel) Close() error {
 	return nil
 }
 
-func (m *metalModel) Generate(prompt string, opts ...mlx.GenerateOption) iter.Seq[mlx.Token] {
+func (m *metalModel) Chat(ctx context.Context, messages []mlx.Message, opts ...mlx.GenerateOption) iter.Seq[mlx.Token] {
+	prompt := m.formatChat(messages) // model-specific template
+	return m.Generate(ctx, prompt, opts...)
+}
+
+func (m *metalModel) Generate(ctx context.Context, prompt string, opts ...mlx.GenerateOption) iter.Seq[mlx.Token] {
 	cfg := mlx.ApplyGenerateOpts(opts)
+	m.lastErr = nil // reset per-generation
 
 	return func(yield func(mlx.Token) bool) {
 		tokens := m.tokenizer.Encode(prompt)
@@ -676,6 +722,14 @@ func (m *metalModel) Generate(prompt string, opts ...mlx.GenerateOption) iter.Se
 		Materialize(logits)
 
 		for i := 0; i < cfg.MaxTokens; i++ {
+			// Check context cancellation (HTTP timeout, shutdown)
+			select {
+			case <-ctx.Done():
+				m.lastErr = ctx.Err()
+				return
+			default:
+			}
+
 			// Sample from last position logits
 			lastPos := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1)))
 			lastPos = Reshape(lastPos, 1, int32(lastPos.Dim(2))) // [1, vocab]
@@ -799,7 +853,8 @@ func TestMemory_GenerateDoesNotLeak(t *testing.T) {
 	ClearCache()
 	beforeMem := GetActiveMemory()
 
-	for tok := range m.Generate("Tell me a story",
+	ctx := context.Background()
+	for tok := range m.Generate(ctx, "Tell me a story",
 		mlx.WithMaxTokens(100), mlx.WithTemperature(0.7)) {
 		_ = tok
 	}
@@ -881,8 +936,9 @@ func TestLoadModel_MetalBackend(t *testing.T) {
 		t.Errorf("ModelType = %q, want gemma3", m.ModelType())
 	}
 
+	ctx := context.Background()
 	var count int
-	for tok := range m.Generate("What is 2+2?", mlx.WithMaxTokens(16)) {
+	for tok := range m.Generate(ctx, "What is 2+2?", mlx.WithMaxTokens(16)) {
 		count++
 		t.Logf("[%d] %q", tok.ID, tok.Text)
 	}
@@ -968,3 +1024,14 @@ git commit -m "docs: update project docs for backend abstraction"
 **Critical checkpoint:** After Task 7, all 148 tests must pass in `internal/metal/`. If they don't, stop and fix before continuing.
 
 **Model-dependent tests** (Tasks 9, 10, 11) require Gemma3-1B at `/Volumes/Data/lem/safetensors/gemma-3/`. They use `t.Skip()` when the model isn't available.
+
+## Virgil Review Items (folded in)
+
+| Item | Where | Status |
+|------|-------|--------|
+| `context.Context` on Generate/Chat | Task 2 (textmodel.go), Task 9 (generate.go) | Integrated |
+| `Err() error` on TextModel | Task 2 (textmodel.go), Task 9 (generate.go) | Integrated |
+| `Chat()` on TextModel | Task 2 (textmodel.go), Task 9 (generate.go) | Integrated |
+| Memory control functions at root | Task 8 (register_metal.go) | Integrated |
+| Functional options convention | Confirmed by Virgil — no conflict with core/go | N/A |
+| pkg/process for mlxlm | Confirmed by Virgil — no changes needed | N/A |