go-inference/inference.go

// Package inference defines shared interfaces for text generation backends.
//
// This package is the contract between GPU-specific backends (go-mlx, go-rocm)
// and consumers (go-ml, go-ai, go-i18n). It has zero dependencies and compiles
// on all platforms.
//
// # Backend registration
//
// Backend implementations register via init() with build tags:
//
//	// go-mlx: //go:build darwin && arm64
//	func init() { inference.Register(metal.NewBackend()) }
//
//	// go-rocm: //go:build linux && amd64
//	func init() { inference.Register(rocm.NewBackend()) }
//
// # Loading and generating
//
//	m, err := inference.LoadModel("/path/to/model/")
//	defer m.Close()
//
//	ctx := context.Background()
//	for tok := range m.Generate(ctx, "prompt", inference.WithMaxTokens(128)) {
//		fmt.Print(tok.Text)
//	}
//	if err := m.Err(); err != nil { log.Fatal(err) }
//
// # Chat, classify, and batch generate
//
// [TextModel] supports multi-turn chat (with model-native templates),
// batch classification (prefill-only, fast path), and batch generation:
//
//	// Chat
//	for tok := range m.Chat(ctx, []inference.Message{
//		{Role: "user", Content: "Hello"},
//	}, inference.WithMaxTokens(64)) {
//		fmt.Print(tok.Text)
//	}
//
//	// Classify — single forward pass per prompt
//	results, _ := m.Classify(ctx, prompts, inference.WithTemperature(0))
//
//	// Batch generate — parallel autoregressive decoding
//	batched, _ := m.BatchGenerate(ctx, prompts, inference.WithMaxTokens(32))
//
// # Functional options
//
// Generation and loading are configured via functional options:
//
//	inference.WithMaxTokens(256)     // cap output length
//	inference.WithTemperature(0.7)   // sampling temperature
//	inference.WithTopK(40)           // top-k sampling
//	inference.WithRepeatPenalty(1.1) // discourage repetition
//	inference.WithContextLen(4096)   // limit KV cache memory
//
// # Model discovery
//
// [Discover] scans a directory for model directories (config.json + *.safetensors):
//
//	models, _ := inference.Discover("/path/to/models/")
//	for _, d := range models {
//		fmt.Printf("%s (%s)\n", d.Path, d.ModelType)
//	}
package inference

import (
	"context"
	"fmt"
	"iter"
	"sync"
	"time"
)

// Token represents a single generated token for streaming.
type Token struct {
	ID   int32
	Text string
}

// Message represents a chat message for multi-turn conversation.
type Message struct {
	Role    string `json:"role"` // "system", "user", "assistant"
	Content string `json:"content"`
}

// ClassifyResult holds the output for a single prompt in a batch classification.
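//
// A minimal consumption sketch (assuming results come back index-aligned
// with the input prompts):
//
//	results, err := m.Classify(ctx, prompts, inference.WithTemperature(0))
//	if err != nil {
//		log.Fatal(err)
//	}
//	for i, r := range results {
//		fmt.Printf("%s -> %s\n", prompts[i], r.Token.Text)
//	}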
type ClassifyResult struct {
	Token  Token     // Sampled/greedy token at last prompt position
	Logits []float32 // Raw vocab-sized logits (only when WithLogits is set)
}

// BatchResult holds the output for a single prompt in batch generation.
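//
// A minimal consumption sketch (assuming results come back index-aligned with
// the input prompts; Err is checked per prompt, since the field suggests one
// prompt can fail without failing the whole batch):
//
//	batched, err := m.BatchGenerate(ctx, prompts, inference.WithMaxTokens(32))
//	if err != nil {
//		log.Fatal(err)
//	}
//	for i, r := range batched {
//		if r.Err != nil {
//			log.Printf("prompt %d: %v", i, r.Err)
//			continue
//		}
//		for _, tok := range r.Tokens {
//			fmt.Print(tok.Text)
//		}
//	}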
type BatchResult struct {
	Tokens []Token // All generated tokens for this prompt
	Err    error   // Per-prompt error (context cancel, OOM, etc.)
}

// GenerateMetrics holds performance metrics from the last inference operation.
// Retrieved via TextModel.Metrics() after Generate, Chat, Classify, or BatchGenerate.
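//
// A minimal sketch of reading metrics after streaming (assumes the Generate
// iterator has been fully consumed so the figures are final):
//
//	for range m.Generate(ctx, prompt, inference.WithMaxTokens(128)) {
//	}
//	met := m.Metrics()
//	fmt.Printf("prefill %.1f tok/s, decode %.1f tok/s, peak %d MiB\n",
//		met.PrefillTokensPerSec, met.DecodeTokensPerSec, met.PeakMemoryBytes>>20)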
type GenerateMetrics struct {
	// Token counts
	PromptTokens    int // Input tokens (sum across batch for batch ops)
	GeneratedTokens int // Output tokens generated

	// Timing
	PrefillDuration time.Duration // Time to process the prompt(s)
	DecodeDuration  time.Duration // Time for autoregressive decoding
	TotalDuration   time.Duration // Wall-clock time for the full operation

	// Throughput (computed)
	PrefillTokensPerSec float64 // PromptTokens / PrefillDuration
	DecodeTokensPerSec  float64 // GeneratedTokens / DecodeDuration

	// Memory (Metal/GPU)
	PeakMemoryBytes   uint64 // Peak GPU memory during this operation
	ActiveMemoryBytes uint64 // Active GPU memory after operation
}

// ModelInfo holds metadata about a loaded model.
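//
// A small sketch of inspecting a loaded model via TextModel.Info (values are
// backend- and model-dependent):
//
//	info := m.Info()
//	if info.QuantBits == 0 {
//		fmt.Printf("%s: %d layers, unquantised\n", info.Architecture, info.NumLayers)
//	} else {
//		fmt.Printf("%s: %d layers, %d-bit (group %d)\n",
//			info.Architecture, info.NumLayers, info.QuantBits, info.QuantGroup)
//	}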
type ModelInfo struct {
	Architecture string // e.g. "gemma3", "qwen3", "llama"
	VocabSize    int    // Vocabulary size
	NumLayers    int    // Number of transformer layers
	HiddenSize   int    // Hidden dimension
	QuantBits    int    // Quantisation bits (0 = unquantised, 4 = 4-bit, 8 = 8-bit)
	QuantGroup   int    // Quantisation group size (0 if unquantised)
}

// TextModel generates text from a loaded model.
type TextModel interface {
	// Generate streams tokens for the given prompt.
	Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token]

	// Chat streams tokens from a multi-turn conversation.
	// The model applies its native chat template.
	Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token]

	// Classify runs batched prefill-only inference. Each prompt gets a single
	// forward pass and the token at the last position is sampled. This is the
	// fast path for classification tasks (e.g. domain labelling).
	Classify(ctx context.Context, prompts []string, opts ...GenerateOption) ([]ClassifyResult, error)

	// BatchGenerate runs batched autoregressive generation. Each prompt is
	// decoded up to MaxTokens. Returns all generated tokens per prompt.
	BatchGenerate(ctx context.Context, prompts []string, opts ...GenerateOption) ([]BatchResult, error)

	// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
	ModelType() string

	// Info returns metadata about the loaded model (architecture, quantisation, etc.).
	Info() ModelInfo

	// Metrics returns performance metrics from the last inference operation.
	// Valid after Generate (iterator exhausted), Chat, Classify, or BatchGenerate.
	Metrics() GenerateMetrics

	// Err returns the error from the last Generate/Chat call, if any.
	// Check this after the iterator stops to distinguish EOS from errors.
	Err() error

	// Close releases all resources (GPU memory, caches, subprocess).
	Close() error
}

// Backend is a named inference engine that can load models.
type Backend interface {
	// Name returns the backend identifier (e.g. "metal", "rocm", "llama_cpp").
	Name() string

	// LoadModel loads a model from the given path.
	LoadModel(path string, opts ...LoadOption) (TextModel, error)

	// Available reports whether this backend can run on the current hardware.
	Available() bool
}

var (
	backendsMu sync.RWMutex
	backends   = map[string]Backend{}
)

// Register adds a backend to the registry. Typically called from init().
func Register(b Backend) {
	backendsMu.Lock()
	defer backendsMu.Unlock()
	backends[b.Name()] = b
}

// Get returns a registered backend by name.
func Get(name string) (Backend, bool) {
	backendsMu.RLock()
	defer backendsMu.RUnlock()
	b, ok := backends[name]
	return b, ok
}

// List returns the names of all registered backends.
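//
// Combined with [Get], this gives a quick availability report (sketch):
//
//	for _, name := range inference.List() {
//		if b, ok := inference.Get(name); ok {
//			fmt.Printf("%s available=%v\n", name, b.Available())
//		}
//	}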
func List() []string {
	backendsMu.RLock()
	defer backendsMu.RUnlock()
	names := make([]string, 0, len(backends))
	for name := range backends {
		names = append(names, name)
	}
	return names
}

// Default returns the first available backend.
// Preference order is "metal", then "rocm", then "llama_cpp", then any other
// available registered backend.
func Default() (Backend, error) {
	backendsMu.RLock()
	defer backendsMu.RUnlock()

	// Platform preference order
	for _, name := range []string{"metal", "rocm", "llama_cpp"} {
		if b, ok := backends[name]; ok && b.Available() {
			return b, nil
		}
	}

	// Fall back to any available
	for _, b := range backends {
		if b.Available() {
			return b, nil
		}
	}

	return nil, fmt.Errorf("inference: no backends registered (import a backend package)")
}

// LoadModel loads a model using the specified or default backend.
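//
// To pin a specific backend rather than rely on [Default], it can also be
// resolved from the registry directly (sketch using only Get and Available):
//
//	b, ok := inference.Get("rocm")
//	if !ok || !b.Available() {
//		log.Fatal("rocm backend not usable on this machine")
//	}
//	m, err := b.LoadModel("/path/to/model/")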
func LoadModel(path string, opts ...LoadOption) (TextModel, error) {
	cfg := ApplyLoadOpts(opts)
	if cfg.Backend != "" {
		b, ok := Get(cfg.Backend)
		if !ok {
			return nil, fmt.Errorf("inference: backend %q not registered", cfg.Backend)
		}
		if !b.Available() {
			return nil, fmt.Errorf("inference: backend %q not available on this hardware", cfg.Backend)
		}
		return b.LoadModel(path, opts...)
	}

	b, err := Default()
	if err != nil {
		return nil, err
	}
	return b.LoadModel(path, opts...)
}