go-inference/inference.go
Snider 07cd917259 feat: define shared TextModel, Backend, Token, Message interfaces
Zero-dependency interface package for the Core inference ecosystem.
Backends (go-mlx, go-rocm) implement these interfaces.
Consumers (go-ml, go-ai, go-i18n) import them.

Includes:
- TextModel: Generate, Chat, Err, Close (with context.Context)
- Backend: Named engine registry with platform preference
- Functional options: WithMaxTokens, WithTemperature, WithTopK, etc.
- LoadModel: Auto-selects best available backend

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-19 19:37:27 +00:00

145 lines
3.9 KiB
Go

// Package inference defines shared interfaces for text generation backends.
//
// This package is the contract between GPU-specific backends (go-mlx, go-rocm)
// and consumers (go-ml, go-ai, go-i18n). It has zero dependencies and compiles
// on all platforms.
//
// Backend implementations register via init() with build tags:
//
// // go-mlx: //go:build darwin && arm64
// func init() { inference.Register(metal.NewBackend()) }
//
// // go-rocm: //go:build linux && amd64
// func init() { inference.Register(rocm.NewBackend()) }
//
// Consumers load models via the registry:
//
// m, err := inference.LoadModel("/path/to/model/")
// defer m.Close()
// for tok := range m.Generate(ctx, "prompt", inference.WithMaxTokens(128)) {
// fmt.Print(tok.Text)
// }
package inference
import (
	"context"
	"errors"
	"fmt"
	"iter"
	"maps"
	"slices"
	"sync"
)
// Token represents a single generated token for streaming.
// Backends yield a sequence of these from Generate and Chat.
type Token struct {
	// ID is the token's numeric identifier as assigned by the backend's
	// tokenizer (presumably the model vocabulary index — confirm per backend).
	ID int32
	// Text is the decoded text fragment for this token.
	Text string
}
// Message represents a chat message for multi-turn conversation,
// passed to TextModel.Chat. The backend applies the model's native
// chat template to the message list.
type Message struct {
	// Role identifies the speaker: "system", "user", or "assistant".
	Role string
	// Content is the message text.
	Content string
}
// TextModel generates text from a loaded model.
//
// Implementations are produced by Backend.LoadModel. Because Generate and
// Chat report failures through Err rather than in-band, callers should
// check Err after each iterator finishes. Close must be called when the
// model is no longer needed.
type TextModel interface {
	// Generate streams tokens for the given prompt.
	// Cancel ctx to stop generation early.
	Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token]
	// Chat streams tokens from a multi-turn conversation.
	// The model applies its native chat template.
	Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token]
	// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
	ModelType() string
	// Err returns the error from the last Generate/Chat call, if any.
	// Check this after the iterator stops to distinguish EOS from errors.
	Err() error
	// Close releases all resources (GPU memory, caches, subprocess).
	Close() error
}
// Backend is a named inference engine that can load models.
//
// Backend packages register an implementation via Register, typically
// from an init() guarded by build tags so only platform-appropriate
// engines compile in.
type Backend interface {
	// Name returns the backend identifier (e.g. "metal", "rocm", "llama_cpp").
	// This is the key used by Register/Get.
	Name() string
	// LoadModel loads a model from the given path.
	LoadModel(path string, opts ...LoadOption) (TextModel, error)
	// Available reports whether this backend can run on the current hardware.
	// A backend may be compiled in and registered yet still unavailable
	// (e.g. required GPU or driver missing at runtime).
	Available() bool
}
// Package-level backend registry. All reads and writes of backends
// must hold backendsMu (RLock for reads, Lock for writes).
var (
	backendsMu sync.RWMutex
	backends   = map[string]Backend{}
)
// Register adds a backend to the registry, keyed by its Name().
// It is typically called from a backend package's init().
// Registering a backend whose name is already present replaces the
// previous entry.
func Register(b Backend) {
	backendsMu.Lock()
	backends[b.Name()] = b
	backendsMu.Unlock()
}
// Get looks up a backend by its registered name. The boolean result
// reports whether a backend with that name exists in the registry.
func Get(name string) (Backend, bool) {
	backendsMu.RLock()
	defer backendsMu.RUnlock()
	backend, found := backends[name]
	return backend, found
}
// List returns the names of all registered backends in sorted order.
//
// Sorting makes the result deterministic across calls; Go map iteration
// order is randomized, so returning names in raw iteration order would
// make output (logs, error messages, tests) flap from run to run.
func List() []string {
	backendsMu.RLock()
	defer backendsMu.RUnlock()
	return slices.Sorted(maps.Keys(backends))
}
// Default returns the first available backend.
// Prefers "metal" on macOS, "rocm" on Linux, then any registered backend.
//
// The returned error distinguishes an empty registry (no backend package
// imported) from registered backends that cannot run on the current
// hardware, so the caller gets an actionable message either way.
func Default() (Backend, error) {
	backendsMu.RLock()
	defer backendsMu.RUnlock()
	// Empty registry: the fix is to import a backend package, so say that.
	if len(backends) == 0 {
		return nil, errors.New("inference: no backends registered (import a backend package)")
	}
	// Platform preference order: GPU-native engines first, then the
	// portable fallback.
	for _, name := range []string{"metal", "rocm", "llama_cpp"} {
		if b, ok := backends[name]; ok && b.Available() {
			return b, nil
		}
	}
	// Fall back to any backend that reports itself available.
	for _, b := range backends {
		if b.Available() {
			return b, nil
		}
	}
	// Backends exist but none can run here; name them (sorted for
	// deterministic output) instead of wrongly claiming none registered.
	names := slices.Sorted(maps.Keys(backends))
	return nil, fmt.Errorf("inference: no registered backend is available on this hardware (registered: %v)", names)
}
// LoadModel loads a model using the specified or default backend.
// If the load options name a backend, that backend is looked up and must
// be available; otherwise the best backend per Default is used.
func LoadModel(path string, opts ...LoadOption) (TextModel, error) {
	cfg := ApplyLoadOpts(opts)
	if cfg.Backend == "" {
		// No explicit backend requested: pick the platform default.
		backend, err := Default()
		if err != nil {
			return nil, err
		}
		return backend.LoadModel(path, opts...)
	}
	backend, ok := Get(cfg.Backend)
	if !ok {
		return nil, fmt.Errorf("inference: backend %q not registered", cfg.Backend)
	}
	if !backend.Available() {
		return nil, fmt.Errorf("inference: backend %q not available on this hardware", cfg.Backend)
	}
	return backend.LoadModel(path, opts...)
}