feat: add batch inference API (Classify, BatchGenerate)
Add the ClassifyResult and BatchResult types, and add the Classify and BatchGenerate methods to TextModel for batched prefill-only and batched autoregressive inference. Add the WithLogits option for returning raw vocab-sized logits.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
3719734f56
commit
2517b774b8
2 changed files with 32 additions and 5 deletions
21
inference.go
21
inference.go
|
|
@ -40,6 +40,18 @@ type Message struct {
|
|||
Content string
|
||||
}
|
||||
|
||||
// ClassifyResult holds the output for a single prompt in a batch classification.
|
||||
type ClassifyResult struct {
|
||||
Token Token // Sampled/greedy token at last prompt position
|
||||
Logits []float32 // Raw vocab-sized logits (only when WithLogits is set)
|
||||
}
|
||||
|
||||
// BatchResult holds the output for a single prompt in batch generation.
|
||||
type BatchResult struct {
|
||||
Tokens []Token // All generated tokens for this prompt
|
||||
Err error // Per-prompt error (context cancel, OOM, etc.)
|
||||
}
|
||||
|
||||
// TextModel generates text from a loaded model.
|
||||
type TextModel interface {
|
||||
// Generate streams tokens for the given prompt.
|
||||
|
|
@ -49,6 +61,15 @@ type TextModel interface {
|
|||
// The model applies its native chat template.
|
||||
Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token]
|
||||
|
||||
// Classify runs batched prefill-only inference. Each prompt gets a single
|
||||
// forward pass and the token at the last position is sampled. This is the
|
||||
// fast path for classification tasks (e.g. domain labelling).
|
||||
Classify(ctx context.Context, prompts []string, opts ...GenerateOption) ([]ClassifyResult, error)
|
||||
|
||||
// BatchGenerate runs batched autoregressive generation. Each prompt is
|
||||
// decoded up to MaxTokens. Returns all generated tokens per prompt.
|
||||
BatchGenerate(ctx context.Context, prompts []string, opts ...GenerateOption) ([]BatchResult, error)
|
||||
|
||||
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
|
||||
ModelType() string
|
||||
|
||||
|
|
|
|||
16
options.go
16
options.go
|
|
@ -2,12 +2,13 @@ package inference
|
|||
|
||||
// GenerateConfig holds generation parameters.
//
// Values are set through GenerateOption functions (for example
// WithRepeatPenalty and WithLogits), each of which mutates one field.
type GenerateConfig struct {
	MaxTokens     int
	Temperature   float32
	TopK          int
	TopP          float32
	StopTokens    []int32
	RepeatPenalty float32
	ReturnLogits  bool // Return raw logits in ClassifyResult (default false)
}
|
||||
|
||||
// DefaultGenerateConfig returns sensible defaults.
|
||||
|
|
@ -51,6 +52,11 @@ func WithRepeatPenalty(p float32) GenerateOption {
|
|||
return func(c *GenerateConfig) { c.RepeatPenalty = p }
|
||||
}
|
||||
|
||||
// WithLogits requests raw logits in ClassifyResult. Off by default to save memory.
|
||||
func WithLogits() GenerateOption {
|
||||
return func(c *GenerateConfig) { c.ReturnLogits = true }
|
||||
}
|
||||
|
||||
// ApplyGenerateOpts builds a GenerateConfig from options.
|
||||
func ApplyGenerateOpts(opts []GenerateOption) GenerateConfig {
|
||||
cfg := DefaultGenerateConfig()
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue