feat: add batch inference API (Classify, BatchGenerate)

Add ClassifyResult, BatchResult types and Classify/BatchGenerate methods
to TextModel for batched prefill-only and autoregressive inference.
Add WithLogits option for returning raw vocab logits.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Author: Snider — 2026-02-19 23:27:53 +00:00
Parent: 3719734f56
Commit: 2517b774b8
2 changed files with 32 additions and 5 deletions

View file

@ -40,6 +40,18 @@ type Message struct {
Content string
}
// ClassifyResult holds the output for a single prompt in a batch classification.
// Produced by TextModel.Classify, one result per input prompt, in prompt order.
type ClassifyResult struct {
	Token  Token     // Sampled/greedy token at the last prompt position (single forward pass, no decoding)
	Logits []float32 // Raw vocab-sized logits for that position; nil unless WithLogits was passed
}
// BatchResult holds the output for a single prompt in batch generation.
// Produced by TextModel.BatchGenerate, one result per input prompt, in prompt order.
type BatchResult struct {
	Tokens []Token // All tokens generated for this prompt, in generation order
	Err    error   // Per-prompt error (context cancel, OOM, etc.); nil on success — check before using Tokens
}
// TextModel generates text from a loaded model.
type TextModel interface {
// Generate streams tokens for the given prompt.
@ -49,6 +61,15 @@ type TextModel interface {
// The model applies its native chat template.
Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token]
// Classify runs batched prefill-only inference. Each prompt gets a single
// forward pass and the token at the last position is sampled. This is the
// fast path for classification tasks (e.g. domain labelling).
Classify(ctx context.Context, prompts []string, opts ...GenerateOption) ([]ClassifyResult, error)
// BatchGenerate runs batched autoregressive generation. Each prompt is
// decoded up to MaxTokens. Returns all generated tokens per prompt.
BatchGenerate(ctx context.Context, prompts []string, opts ...GenerateOption) ([]BatchResult, error)
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
ModelType() string

View file

@ -2,12 +2,13 @@ package inference
// GenerateConfig holds generation parameters.
//
// Lost diff markers in this rendering left the struct with every pre-change
// field declared twice (MaxTokens, Temperature, TopK, TopP, StopTokens),
// which is a compile error in Go; this is the post-change field set declared
// exactly once.
type GenerateConfig struct {
	MaxTokens     int     // Upper bound on tokens decoded per prompt (see BatchGenerate)
	Temperature   float32 // Sampling temperature — semantics set by the sampler; TODO confirm 0 means greedy
	TopK          int     // Top-k sampling cutoff
	TopP          float32 // Nucleus (top-p) sampling cutoff
	StopTokens    []int32 // Token IDs that terminate generation early
	RepeatPenalty float32 // Penalty applied to repeated tokens
	ReturnLogits  bool    // Return raw logits in ClassifyResult (default false)
}
// DefaultGenerateConfig returns sensible defaults.
@ -51,6 +52,11 @@ func WithRepeatPenalty(p float32) GenerateOption {
return func(c *GenerateConfig) { c.RepeatPenalty = p }
}
// WithLogits requests raw logits in ClassifyResult. Off by default to save memory.
func WithLogits() GenerateOption {
	return func(cfg *GenerateConfig) {
		cfg.ReturnLogits = true
	}
}
// ApplyGenerateOpts builds a GenerateConfig from options.
func ApplyGenerateOpts(opts []GenerateOption) GenerateConfig {
cfg := DefaultGenerateConfig()