From 643df757e0a1c4871fd331613f44384feb22459e Mon Sep 17 00:00:00 2001
From: Snider
Date: Thu, 19 Feb 2026 19:32:00 +0000
Subject: [PATCH] feat(api): define public TextModel, Backend, and options
 interfaces

Co-Authored-By: Virgil
---
 backend.go   | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 options.go   | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 textmodel.go | 40 ++++++++++++++++++++++++++++
 3 files changed, 192 insertions(+)
 create mode 100644 backend.go
 create mode 100644 options.go
 create mode 100644 textmodel.go

diff --git a/backend.go b/backend.go
new file mode 100644
index 0000000..ec2efc3
--- /dev/null
+++ b/backend.go
@@ -0,0 +1,75 @@
+package mlx
+
+import (
+	"fmt"
+	"sort"
+	"sync"
+)
+
+// Backend is a named inference engine that can load models.
+type Backend interface {
+	// Name returns the backend identifier (e.g. "metal", "mlx_lm").
+	Name() string
+
+	// LoadModel loads a model from the given path.
+	LoadModel(path string, opts ...LoadOption) (TextModel, error)
+}
+
+var (
+	backendsMu sync.RWMutex
+	backends   = map[string]Backend{}
+)
+
+// Register adds a backend to the registry.
+func Register(b Backend) {
+	backendsMu.Lock()
+	defer backendsMu.Unlock()
+	backends[b.Name()] = b
+}
+
+// Get returns a registered backend by name.
+func Get(name string) (Backend, bool) {
+	backendsMu.RLock()
+	defer backendsMu.RUnlock()
+	b, ok := backends[name]
+	return b, ok
+}
+
+// Default returns an available backend.
+// Prefers "metal" if registered; otherwise returns the registered backend
+// with the lexicographically smallest name, so the choice is deterministic.
+func Default() (Backend, error) {
+	backendsMu.RLock()
+	defer backendsMu.RUnlock()
+	if b, ok := backends["metal"]; ok {
+		return b, nil
+	}
+	// Map iteration order is random in Go; sort the names so repeated
+	// calls always return the same fallback backend.
+	names := make([]string, 0, len(backends))
+	for name := range backends {
+		names = append(names, name)
+	}
+	if len(names) > 0 {
+		sort.Strings(names)
+		return backends[names[0]], nil
+	}
+	return nil, fmt.Errorf("mlx: no backends registered")
+}
+
+// LoadModel loads a model using the specified or default backend.
+func LoadModel(path string, opts ...LoadOption) (TextModel, error) {
+	cfg := ApplyLoadOpts(opts)
+	if cfg.Backend != "" {
+		b, ok := Get(cfg.Backend)
+		if !ok {
+			return nil, fmt.Errorf("mlx: backend %q not registered", cfg.Backend)
+		}
+		return b.LoadModel(path, opts...)
+	}
+	b, err := Default()
+	if err != nil {
+		return nil, err
+	}
+	return b.LoadModel(path, opts...)
+}
diff --git a/options.go b/options.go
new file mode 100644
index 0000000..b8b867e
--- /dev/null
+++ b/options.go
@@ -0,0 +1,77 @@
+package mlx
+
+// GenerateConfig holds generation parameters.
+type GenerateConfig struct {
+	MaxTokens   int
+	Temperature float32
+	TopK        int
+	TopP        float32
+	StopTokens  []int32
+}
+
+// DefaultGenerateConfig returns sensible defaults.
+func DefaultGenerateConfig() GenerateConfig {
+	return GenerateConfig{
+		MaxTokens:   256,
+		Temperature: 0.0,
+	}
+}
+
+// GenerateOption configures text generation.
+type GenerateOption func(*GenerateConfig)
+
+// WithMaxTokens sets the maximum number of tokens to generate.
+func WithMaxTokens(n int) GenerateOption {
+	return func(c *GenerateConfig) { c.MaxTokens = n }
+}
+
+// WithTemperature sets the sampling temperature. 0 = greedy.
+func WithTemperature(t float32) GenerateOption {
+	return func(c *GenerateConfig) { c.Temperature = t }
+}
+
+// WithTopK sets top-k sampling. 0 = disabled.
+func WithTopK(k int) GenerateOption {
+	return func(c *GenerateConfig) { c.TopK = k }
+}
+
+// WithTopP sets nucleus sampling threshold. 0 = disabled.
+func WithTopP(p float32) GenerateOption {
+	return func(c *GenerateConfig) { c.TopP = p }
+}
+
+// WithStopTokens sets token IDs that stop generation.
+func WithStopTokens(ids ...int32) GenerateOption {
+	return func(c *GenerateConfig) { c.StopTokens = ids }
+}
+
+// LoadConfig holds model loading parameters.
+type LoadConfig struct {
+	Backend string // "metal" (default), "mlx_lm"
+}
+
+// LoadOption configures model loading.
+type LoadOption func(*LoadConfig)
+
+// WithBackend selects a specific inference backend by name.
+func WithBackend(name string) LoadOption {
+	return func(c *LoadConfig) { c.Backend = name }
+}
+
+// ApplyGenerateOpts builds a GenerateConfig from options.
+func ApplyGenerateOpts(opts []GenerateOption) GenerateConfig {
+	cfg := DefaultGenerateConfig()
+	for _, o := range opts {
+		o(&cfg)
+	}
+	return cfg
+}
+
+// ApplyLoadOpts builds a LoadConfig from options.
+func ApplyLoadOpts(opts []LoadOption) LoadConfig {
+	var cfg LoadConfig
+	for _, o := range opts {
+		o(&cfg)
+	}
+	return cfg
+}
diff --git a/textmodel.go b/textmodel.go
new file mode 100644
index 0000000..ab87ce6
--- /dev/null
+++ b/textmodel.go
@@ -0,0 +1,40 @@
+package mlx
+
+import (
+	"context"
+	"iter"
+)
+
+// Token represents a single generated token for streaming.
+type Token struct {
+	ID   int32
+	Text string
+}
+
+// Message represents a chat turn for Chat().
+type Message struct {
+	Role    string // "user", "assistant", "system"
+	Content string
+}
+
+// TextModel generates text from a loaded model.
+type TextModel interface {
+	// Generate streams tokens for the given prompt.
+	// Respects ctx cancellation (HTTP handlers, timeouts, graceful shutdown).
+	Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token]
+
+	// Chat formats messages using the model's native chat template, then generates.
+	// The model owns its template — callers don't need to know Gemma vs Qwen formatting.
+	Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token]
+
+	// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
+	ModelType() string
+
+	// Err returns the error from the last Generate/Chat call, if any.
+	// Distinguishes normal stop (EOS, max tokens) from failures (OOM, C-level error).
+	// Returns nil if generation completed normally.
+	Err() error
+
+	// Close releases all resources (GPU memory, caches, subprocess).
+	Close() error
+}