feat(api): define public TextModel, Backend, and options interfaces
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
f7a57657d4
commit
643df757e0
3 changed files with 183 additions and 0 deletions
66
backend.go
Normal file
66
backend.go
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
package mlx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Backend is a named inference engine that can load models.
// Implementations register themselves with Register, after which they can be
// looked up by name with Get or chosen automatically by Default.
type Backend interface {
	// Name returns the backend identifier (e.g. "metal", "mlx_lm").
	// The name is used as the registry key; registering a second backend
	// with the same name replaces the first.
	Name() string

	// LoadModel loads a model from the given path.
	// The options are forwarded so the backend can honor load-time
	// configuration such as WithBackend-adjacent settings.
	LoadModel(path string, opts ...LoadOption) (TextModel, error)
}
|
||||
|
||||
var (
	// backendsMu guards backends: Register takes the write lock;
	// Get and Default take the read lock.
	backendsMu sync.RWMutex
	// backends maps a backend's Name() to its registered instance.
	backends = map[string]Backend{}
)
|
||||
|
||||
// Register adds a backend to the registry.
|
||||
func Register(b Backend) {
|
||||
backendsMu.Lock()
|
||||
defer backendsMu.Unlock()
|
||||
backends[b.Name()] = b
|
||||
}
|
||||
|
||||
// Get returns a registered backend by name.
|
||||
func Get(name string) (Backend, bool) {
|
||||
backendsMu.RLock()
|
||||
defer backendsMu.RUnlock()
|
||||
b, ok := backends[name]
|
||||
return b, ok
|
||||
}
|
||||
|
||||
// Default returns the first available backend.
|
||||
// Prefers "metal" if registered.
|
||||
func Default() (Backend, error) {
|
||||
backendsMu.RLock()
|
||||
defer backendsMu.RUnlock()
|
||||
if b, ok := backends["metal"]; ok {
|
||||
return b, nil
|
||||
}
|
||||
for _, b := range backends {
|
||||
return b, nil
|
||||
}
|
||||
return nil, fmt.Errorf("mlx: no backends registered")
|
||||
}
|
||||
|
||||
// LoadModel loads a model using the specified or default backend.
|
||||
func LoadModel(path string, opts ...LoadOption) (TextModel, error) {
|
||||
cfg := ApplyLoadOpts(opts)
|
||||
if cfg.Backend != "" {
|
||||
b, ok := Get(cfg.Backend)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("mlx: backend %q not registered", cfg.Backend)
|
||||
}
|
||||
return b.LoadModel(path, opts...)
|
||||
}
|
||||
b, err := Default()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b.LoadModel(path, opts...)
|
||||
}
|
||||
77
options.go
Normal file
77
options.go
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
package mlx
|
||||
|
||||
// GenerateConfig holds generation parameters.
// DefaultGenerateConfig supplies the common defaults; GenerateOption values
// mutate individual fields.
type GenerateConfig struct {
	// MaxTokens caps the number of tokens generated.
	MaxTokens int
	// Temperature is the sampling temperature; 0 means greedy decoding.
	Temperature float32
	// TopK enables top-k sampling when nonzero; 0 disables it.
	TopK int
	// TopP enables nucleus (top-p) sampling when nonzero; 0 disables it.
	TopP float32
	// StopTokens are token IDs that stop generation when produced.
	StopTokens []int32
}
|
||||
|
||||
// DefaultGenerateConfig returns sensible defaults.
|
||||
func DefaultGenerateConfig() GenerateConfig {
|
||||
return GenerateConfig{
|
||||
MaxTokens: 256,
|
||||
Temperature: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateOption configures text generation.
// Options are applied in order by ApplyGenerateOpts, each mutating the config.
type GenerateOption func(*GenerateConfig)
|
||||
|
||||
// WithMaxTokens sets the maximum number of tokens to generate.
|
||||
func WithMaxTokens(n int) GenerateOption {
|
||||
return func(c *GenerateConfig) { c.MaxTokens = n }
|
||||
}
|
||||
|
||||
// WithTemperature sets the sampling temperature. 0 = greedy.
|
||||
func WithTemperature(t float32) GenerateOption {
|
||||
return func(c *GenerateConfig) { c.Temperature = t }
|
||||
}
|
||||
|
||||
// WithTopK sets top-k sampling. 0 = disabled.
|
||||
func WithTopK(k int) GenerateOption {
|
||||
return func(c *GenerateConfig) { c.TopK = k }
|
||||
}
|
||||
|
||||
// WithTopP sets nucleus sampling threshold. 0 = disabled.
|
||||
func WithTopP(p float32) GenerateOption {
|
||||
return func(c *GenerateConfig) { c.TopP = p }
|
||||
}
|
||||
|
||||
// WithStopTokens sets token IDs that stop generation.
|
||||
func WithStopTokens(ids ...int32) GenerateOption {
|
||||
return func(c *GenerateConfig) { c.StopTokens = ids }
|
||||
}
|
||||
|
||||
// LoadConfig holds model loading parameters.
type LoadConfig struct {
	// Backend selects the inference backend by name, e.g. "metal" (default)
	// or "mlx_lm". Empty lets LoadModel fall back to Default().
	Backend string
}
|
||||
|
||||
// LoadOption configures model loading.
// Options are applied in order by ApplyLoadOpts, each mutating the config.
type LoadOption func(*LoadConfig)
|
||||
|
||||
// WithBackend selects a specific inference backend by name.
|
||||
func WithBackend(name string) LoadOption {
|
||||
return func(c *LoadConfig) { c.Backend = name }
|
||||
}
|
||||
|
||||
// ApplyGenerateOpts builds a GenerateConfig from options.
|
||||
func ApplyGenerateOpts(opts []GenerateOption) GenerateConfig {
|
||||
cfg := DefaultGenerateConfig()
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// ApplyLoadOpts builds a LoadConfig from options.
|
||||
func ApplyLoadOpts(opts []LoadOption) LoadConfig {
|
||||
var cfg LoadConfig
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
40
textmodel.go
Normal file
40
textmodel.go
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
package mlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"iter"
|
||||
)
|
||||
|
||||
// Token represents a single generated token for streaming.
type Token struct {
	// ID is the numeric token identifier.
	ID int32
	// Text is the decoded text fragment for this token.
	Text string
}
|
||||
|
||||
// Message represents a chat turn for Chat().
type Message struct {
	// Role identifies the speaker: "user", "assistant", or "system".
	Role string
	// Content is the text of the turn.
	Content string
}
|
||||
|
||||
// TextModel generates text from a loaded model.
// Instances are produced by Backend.LoadModel; callers must Close the model
// when done to release its resources.
type TextModel interface {
	// Generate streams tokens for the given prompt.
	// Respects ctx cancellation (HTTP handlers, timeouts, graceful shutdown).
	// After the sequence ends, check Err to distinguish failure from a
	// normal stop.
	Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token]

	// Chat formats messages using the model's native chat template, then generates.
	// The model owns its template — callers don't need to know Gemma vs Qwen formatting.
	Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token]

	// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
	ModelType() string

	// Err returns the error from the last Generate/Chat call, if any.
	// Distinguishes normal stop (EOS, max tokens) from failures (OOM, C-level error).
	// Returns nil if generation completed normally.
	Err() error

	// Close releases all resources (GPU memory, caches, subprocess).
	Close() error
}
|
||||
Loading…
Add table
Reference in a new issue