feat(api): define public TextModel, Backend, and options interfaces

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-19 19:32:00 +00:00
parent f7a57657d4
commit 643df757e0
3 changed files with 183 additions and 0 deletions

66
backend.go Normal file
View file

@ -0,0 +1,66 @@
package mlx
import (
	"errors"
	"fmt"
	"sync"
)
// Backend is a named inference engine that can load models.
//
// Implementations are added to the package registry with Register and
// looked up by name with Get, or chosen automatically with Default.
type Backend interface {
	// Name returns the backend identifier (e.g. "metal", "mlx_lm").
	// It is used as the registry key, so it should be stable and unique.
	Name() string
	// LoadModel loads a model from the given path. The opts are the same
	// LoadOption values the caller passed to the package-level LoadModel.
	LoadModel(path string, opts ...LoadOption) (TextModel, error)
}
var (
	// backendsMu guards backends: Register takes the write lock,
	// lookups (Get, Default) take read locks.
	backendsMu sync.RWMutex
	// backends maps a backend's Name() to its registered instance.
	backends = map[string]Backend{}
)
// Register adds a backend to the registry, keyed by b.Name().
// Registering a second backend under the same name replaces the first.
func Register(b Backend) {
	name := b.Name()
	backendsMu.Lock()
	defer backendsMu.Unlock()
	backends[name] = b
}
// Get returns a registered backend by name. The boolean result reports
// whether a backend with that name has been registered.
func Get(name string) (b Backend, ok bool) {
	backendsMu.RLock()
	b, ok = backends[name]
	backendsMu.RUnlock()
	return
}
// Default returns the backend to use when the caller did not pick one.
//
// It prefers "metal" when registered. Otherwise it returns the registered
// backend with the lexicographically smallest name, so the fallback is
// deterministic across calls (Go map iteration order is randomized).
// It returns an error when no backends are registered.
func Default() (Backend, error) {
	backendsMu.RLock()
	defer backendsMu.RUnlock()
	if b, ok := backends["metal"]; ok {
		return b, nil
	}
	// Track the smallest name instead of returning the first map entry,
	// which would make the chosen backend vary from call to call.
	var (
		best     Backend
		bestName string
	)
	for name, b := range backends {
		if best == nil || name < bestName {
			best, bestName = b, name
		}
	}
	if best == nil {
		return nil, errors.New("mlx: no backends registered")
	}
	return best, nil
}
// LoadModel loads a model using the backend named via WithBackend, falling
// back to Default when no backend was specified. The full opts slice is
// forwarded to the chosen backend unchanged.
func LoadModel(path string, opts ...LoadOption) (TextModel, error) {
	cfg := ApplyLoadOpts(opts)

	// Resolve the backend first, then load through it exactly once.
	var (
		b   Backend
		err error
	)
	switch {
	case cfg.Backend == "":
		b, err = Default()
	default:
		var ok bool
		if b, ok = Get(cfg.Backend); !ok {
			err = fmt.Errorf("mlx: backend %q not registered", cfg.Backend)
		}
	}
	if err != nil {
		return nil, err
	}
	return b.LoadModel(path, opts...)
}

77
options.go Normal file
View file

@ -0,0 +1,77 @@
package mlx
// GenerateConfig holds generation parameters. Build one from options with
// ApplyGenerateOpts; see DefaultGenerateConfig for the package defaults.
type GenerateConfig struct {
	// MaxTokens caps how many tokens may be generated.
	MaxTokens int
	// Temperature is the sampling temperature; 0 means greedy decoding.
	Temperature float32
	// TopK enables top-k sampling; 0 disables it.
	TopK int
	// TopP enables nucleus sampling with the given threshold; 0 disables it.
	TopP float32
	// StopTokens lists token IDs that stop generation.
	StopTokens []int32
}
// DefaultGenerateConfig returns sensible defaults: a 256-token limit and
// greedy (temperature-0) sampling, with top-k and top-p disabled.
func DefaultGenerateConfig() GenerateConfig {
	var cfg GenerateConfig
	cfg.MaxTokens = 256
	cfg.Temperature = 0.0
	return cfg
}
// GenerateOption configures text generation. Options are applied in order
// by ApplyGenerateOpts, so later options override earlier ones.
type GenerateOption func(*GenerateConfig)
// WithMaxTokens sets the maximum number of tokens to generate.
func WithMaxTokens(n int) GenerateOption {
	return func(cfg *GenerateConfig) {
		cfg.MaxTokens = n
	}
}
// WithTemperature sets the sampling temperature. 0 = greedy.
func WithTemperature(t float32) GenerateOption {
	return func(cfg *GenerateConfig) {
		cfg.Temperature = t
	}
}
// WithTopK sets top-k sampling. 0 = disabled.
func WithTopK(k int) GenerateOption {
	return func(cfg *GenerateConfig) {
		cfg.TopK = k
	}
}
// WithTopP sets nucleus sampling threshold. 0 = disabled.
func WithTopP(p float32) GenerateOption {
	return func(cfg *GenerateConfig) {
		cfg.TopP = p
	}
}
// WithStopTokens sets token IDs that stop generation.
func WithStopTokens(ids ...int32) GenerateOption {
	return func(cfg *GenerateConfig) {
		cfg.StopTokens = ids
	}
}
// LoadConfig holds model loading parameters. Build one from options with
// ApplyLoadOpts; the zero value selects the default backend.
type LoadConfig struct {
	// Backend names the inference backend to use; empty selects Default.
	Backend string // "metal" (default), "mlx_lm"
}
// LoadOption configures model loading. Options are applied in order by
// ApplyLoadOpts, so later options override earlier ones.
type LoadOption func(*LoadConfig)
// WithBackend selects a specific inference backend by name.
func WithBackend(name string) LoadOption {
	return func(cfg *LoadConfig) {
		cfg.Backend = name
	}
}
// ApplyGenerateOpts builds a GenerateConfig from options, starting from
// DefaultGenerateConfig and applying each option in order.
func ApplyGenerateOpts(opts []GenerateOption) GenerateConfig {
	cfg := DefaultGenerateConfig()
	for _, apply := range opts {
		apply(&cfg)
	}
	return cfg
}
// ApplyLoadOpts builds a LoadConfig from options, starting from the zero
// value and applying each option in order.
func ApplyLoadOpts(opts []LoadOption) LoadConfig {
	cfg := LoadConfig{}
	for _, apply := range opts {
		apply(&cfg)
	}
	return cfg
}

40
textmodel.go Normal file
View file

@ -0,0 +1,40 @@
package mlx
import (
"context"
"iter"
)
// Token represents a single generated token for streaming.
type Token struct {
	// ID is the numeric token identifier.
	ID int32
	// Text is the text associated with this token.
	Text string
}
// Message represents a chat turn for Chat().
type Message struct {
	// Role identifies the speaker: "user", "assistant", or "system".
	Role string // "user", "assistant", "system"
	// Content is the text of the turn.
	Content string
}
// TextModel generates text from a loaded model.
//
// NOTE(review): Generate/Chat return iter.Seq rather than (iter.Seq, error),
// so failures surface through Err after iteration — callers should check it
// once the sequence ends.
type TextModel interface {
	// Generate streams tokens for the given prompt.
	// Respects ctx cancellation (HTTP handlers, timeouts, graceful shutdown).
	Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token]
	// Chat formats messages using the model's native chat template, then generates.
	// The model owns its template — callers don't need to know Gemma vs Qwen formatting.
	Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token]
	// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
	ModelType() string
	// Err returns the error from the last Generate/Chat call, if any.
	// Distinguishes normal stop (EOS, max tokens) from failures (OOM, C-level error).
	// Returns nil if generation completed normally.
	Err() error
	// Close releases all resources (GPU memory, caches, subprocess).
	Close() error
}