From 2517b774b8a31f783643eefdf49e5215a9b271d7 Mon Sep 17 00:00:00 2001 From: Snider Date: Thu, 19 Feb 2026 23:27:53 +0000 Subject: [PATCH] feat: add batch inference API (Classify, BatchGenerate) Add ClassifyResult, BatchResult types and Classify/BatchGenerate methods to TextModel for batched prefill-only and autoregressive inference. Add WithLogits option for returning raw vocab logits. Co-Authored-By: Virgil Co-Authored-By: Claude Opus 4.6 --- inference.go | 21 +++++++++++++++++++++ options.go | 16 +++++++++++----- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/inference.go b/inference.go index 8724714..af86f91 100644 --- a/inference.go +++ b/inference.go @@ -40,6 +40,18 @@ type Message struct { Content string } +// ClassifyResult holds the output for a single prompt in a batch classification. +type ClassifyResult struct { + Token Token // Sampled/greedy token at last prompt position + Logits []float32 // Raw vocab-sized logits (only when WithLogits is set) +} + +// BatchResult holds the output for a single prompt in batch generation. +type BatchResult struct { + Tokens []Token // All generated tokens for this prompt + Err error // Per-prompt error (context cancel, OOM, etc.) +} + // TextModel generates text from a loaded model. type TextModel interface { // Generate streams tokens for the given prompt. @@ -49,6 +61,15 @@ type TextModel interface { // The model applies its native chat template. Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token] + // Classify runs batched prefill-only inference. Each prompt gets a single + // forward pass and the token at the last position is sampled. This is the + // fast path for classification tasks (e.g. domain labelling). + Classify(ctx context.Context, prompts []string, opts ...GenerateOption) ([]ClassifyResult, error) + + // BatchGenerate runs batched autoregressive generation. Each prompt is + // decoded up to MaxTokens. Returns all generated tokens per prompt. + BatchGenerate(ctx context.Context, prompts []string, opts ...GenerateOption) ([]BatchResult, error) + // ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3"). ModelType() string diff --git a/options.go b/options.go index 5546a11..c5f48e9 100644 --- a/options.go +++ b/options.go @@ -2,12 +2,13 @@ package inference // GenerateConfig holds generation parameters. type GenerateConfig struct { - MaxTokens int - Temperature float32 - TopK int - TopP float32 - StopTokens []int32 + MaxTokens int + Temperature float32 + TopK int + TopP float32 + StopTokens []int32 RepeatPenalty float32 + ReturnLogits bool // Return raw logits in ClassifyResult (default false) } // DefaultGenerateConfig returns sensible defaults. @@ -51,6 +52,11 @@ func WithRepeatPenalty(p float32) GenerateOption { return func(c *GenerateConfig) { c.RepeatPenalty = p } } +// WithLogits requests raw logits in ClassifyResult. Off by default to save memory. +func WithLogits() GenerateOption { + return func(c *GenerateConfig) { c.ReturnLogits = true } +} + // ApplyGenerateOpts builds a GenerateConfig from options. func ApplyGenerateOpts(opts []GenerateOption) GenerateConfig { cfg := DefaultGenerateConfig()