// Package inference defines shared interfaces for text generation backends. // // This package is the contract between GPU-specific backends (go-mlx, go-rocm) // and consumers (go-ml, go-ai, go-i18n). It has zero dependencies and compiles // on all platforms. // // # Backend registration // // Backend implementations register via init() with build tags: // // // go-mlx: //go:build darwin && arm64 // func init() { inference.Register(metal.NewBackend()) } // // // go-rocm: //go:build linux && amd64 // func init() { inference.Register(rocm.NewBackend()) } // // # Loading and generating // // m, err := inference.LoadModel("/path/to/model/") // defer m.Close() // // ctx := context.Background() // for tok := range m.Generate(ctx, "prompt", inference.WithMaxTokens(128)) { // fmt.Print(tok.Text) // } // if err := m.Err(); err != nil { log.Fatal(err) } // // # Chat, classify, and batch generate // // [TextModel] supports multi-turn chat (with model-native templates), // batch classification (prefill-only, fast path), and batch generation: // // // Chat // for tok := range m.Chat(ctx, []inference.Message{ // {Role: "user", Content: "Hello"}, // }, inference.WithMaxTokens(64)) { // fmt.Print(tok.Text) // } // // // Classify — single forward pass per prompt // results, _ := m.Classify(ctx, prompts, inference.WithTemperature(0)) // // // Batch generate — parallel autoregressive decoding // batched, _ := m.BatchGenerate(ctx, prompts, inference.WithMaxTokens(32)) // // # Functional options // // Generation and loading are configured via functional options: // // inference.WithMaxTokens(256) // cap output length // inference.WithTemperature(0.7) // sampling temperature // inference.WithTopK(40) // top-k sampling // inference.WithRepeatPenalty(1.1) // discourage repetition // inference.WithContextLen(4096) // limit KV cache memory // // # Model discovery // // [Discover] scans a directory for model directories (config.json + *.safetensors): // // models, _ := inference.Discover("/path/to/models/") // for _, d := range models { // fmt.Printf("%s (%s)\n", d.Path, d.ModelType) // } package inference import ( "context" "fmt" "iter" "maps" "slices" "sync" "time" ) // Token represents a single generated token for streaming. type Token struct { ID int32 Text string } // Message represents a chat message for multi-turn conversation. type Message struct { Role string `json:"role"` // "system", "user", "assistant" Content string `json:"content"` } // ClassifyResult holds the output for a single prompt in a batch classification. type ClassifyResult struct { Token Token // Sampled/greedy token at last prompt position Logits []float32 // Raw vocab-sized logits (only when WithLogits is set) } // BatchResult holds the output for a single prompt in batch generation. type BatchResult struct { Tokens []Token // All generated tokens for this prompt Err error // Per-prompt error (context cancel, OOM, etc.) } // GenerateMetrics holds performance metrics from the last inference operation. // Retrieved via TextModel.Metrics() after Generate, Chat, Classify, or BatchGenerate. type GenerateMetrics struct { // Token counts PromptTokens int // Input tokens (sum across batch for batch ops) GeneratedTokens int // Output tokens generated // Timing PrefillDuration time.Duration // Time to process the prompt(s) DecodeDuration time.Duration // Time for autoregressive decoding TotalDuration time.Duration // Wall-clock time for the full operation // Throughput (computed) PrefillTokensPerSec float64 // PromptTokens / PrefillDuration DecodeTokensPerSec float64 // GeneratedTokens / DecodeDuration // Memory (Metal/GPU) PeakMemoryBytes uint64 // Peak GPU memory during this operation ActiveMemoryBytes uint64 // Active GPU memory after operation } // ModelInfo holds metadata about a loaded model. type ModelInfo struct { Architecture string // e.g. "gemma3", "qwen3", "llama" VocabSize int // Vocabulary size NumLayers int // Number of transformer layers HiddenSize int // Hidden dimension QuantBits int // Quantisation bits (0 = unquantised, 4 = 4-bit, 8 = 8-bit) QuantGroup int // Quantisation group size (0 if unquantised) } // AttentionSnapshot holds K vectors extracted from the KV cache after prefill. // Keys is indexed [layer][head][position*head_dim] — flattened per head. type AttentionSnapshot struct { NumLayers int `json:"num_layers"` NumHeads int `json:"num_heads"` // num_kv_heads (may differ from query heads in GQA) SeqLen int `json:"seq_len"` // number of tokens in the prompt HeadDim int `json:"head_dim"` Keys [][][]float32 `json:"keys"` // [layer][head] → flat float32 of len seq_len*head_dim Architecture string `json:"architecture"` } // AttentionInspector is an optional interface that backends may implement // to expose attention-level data for Q/K Bone Orientation analysis. // Use type assertion: if inspector, ok := model.(AttentionInspector); ok { ... } type AttentionInspector interface { InspectAttention(ctx context.Context, prompt string, opts ...GenerateOption) (*AttentionSnapshot, error) } // TextModel generates text from a loaded model. type TextModel interface { // Generate streams tokens for the given prompt. Generate(ctx context.Context, prompt string, opts ...GenerateOption) iter.Seq[Token] // Chat streams tokens from a multi-turn conversation. // The model applies its native chat template. Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token] // Classify runs batched prefill-only inference. Each prompt gets a single // forward pass and the token at the last position is sampled. This is the // fast path for classification tasks (e.g. domain labelling). Classify(ctx context.Context, prompts []string, opts ...GenerateOption) ([]ClassifyResult, error) // BatchGenerate runs batched autoregressive generation. Each prompt is // decoded up to MaxTokens. Returns all generated tokens per prompt. BatchGenerate(ctx context.Context, prompts []string, opts ...GenerateOption) ([]BatchResult, error) // ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3"). ModelType() string // Info returns metadata about the loaded model (architecture, quantisation, etc.). Info() ModelInfo // Metrics returns performance metrics from the last inference operation. // Valid after Generate (iterator exhausted), Chat, Classify, or BatchGenerate. Metrics() GenerateMetrics // Err returns the error from the last Generate/Chat call, if any. // Check this after the iterator stops to distinguish EOS from errors. Err() error // Close releases all resources (GPU memory, caches, subprocess). Close() error } // Backend is a named inference engine that can load models. type Backend interface { // Name returns the backend identifier (e.g. "metal", "rocm", "llama_cpp"). Name() string // LoadModel loads a model from the given path. LoadModel(path string, opts ...LoadOption) (TextModel, error) // Available reports whether this backend can run on the current hardware. Available() bool } var ( backendsMu sync.RWMutex backends = map[string]Backend{} ) // Register adds a backend to the registry. Typically called from init(). func Register(b Backend) { backendsMu.Lock() defer backendsMu.Unlock() backends[b.Name()] = b } // Get returns a registered backend by name. func Get(name string) (Backend, bool) { backendsMu.RLock() defer backendsMu.RUnlock() b, ok := backends[name] return b, ok } // List returns the names of all registered backends in alphabetical order. func List() []string { backendsMu.RLock() defer backendsMu.RUnlock() return slices.Sorted(maps.Keys(backends)) } // All returns an iterator over all registered backends. func All() iter.Seq2[string, Backend] { return func(yield func(string, Backend) bool) { backendsMu.RLock() defer backendsMu.RUnlock() for k, v := range backends { if !yield(k, v) { return } } } } // Default returns the first available backend. // Prefers "metal" on macOS, "rocm" on Linux, then any registered backend. func Default() (Backend, error) { backendsMu.RLock() defer backendsMu.RUnlock() // Platform preference order for _, name := range []string{"metal", "rocm", "llama_cpp"} { if b, ok := backends[name]; ok && b.Available() { return b, nil } } // Fall back to any available for b := range maps.Values(backends) { if b.Available() { return b, nil } } return nil, fmt.Errorf("inference: no backends registered (import a backend package)") } // LoadModel loads a model using the specified or default backend. func LoadModel(path string, opts ...LoadOption) (TextModel, error) { cfg := ApplyLoadOpts(opts) if cfg.Backend != "" { b, ok := Get(cfg.Backend) if !ok { return nil, fmt.Errorf("inference: backend %q not registered", cfg.Backend) } if !b.Available() { return nil, fmt.Errorf("inference: backend %q not available on this hardware", cfg.Backend) } return b.LoadModel(path, opts...) } b, err := Default() if err != nil { return nil, err } return b.LoadModel(path, opts...) }