diff --git a/inference.go b/inference.go
index 8724714..af86f91 100644
--- a/inference.go
+++ b/inference.go
@@ -40,6 +40,18 @@ type Message struct {
 	Content string
 }
 
+// ClassifyResult holds the Classify output for a single prompt in a batch.
+type ClassifyResult struct {
+	Token  Token     // Token sampled (or greedily picked) at the last prompt position
+	Logits []float32 // Raw vocab-sized logits; populated only when WithLogits is passed
+}
+
+// BatchResult holds the BatchGenerate output for a single prompt in a batch.
+type BatchResult struct {
+	Tokens []Token // All tokens generated for this prompt
+	Err    error   // Per-prompt error, e.g. context cancellation or out-of-memory
+}
+
 // TextModel generates text from a loaded model.
 type TextModel interface {
 	// Generate streams tokens for the given prompt.
@@ -49,6 +61,15 @@ type TextModel interface {
 	// The model applies its native chat template.
 	Chat(ctx context.Context, messages []Message, opts ...GenerateOption) iter.Seq[Token]
 
+	// Classify runs batched prefill-only inference: one forward pass per prompt,
+	// sampling the token at the last prompt position. It is the fast path for
+	// classification tasks (e.g. domain labelling); pass WithLogits for raw logits.
+	Classify(ctx context.Context, prompts []string, opts ...GenerateOption) ([]ClassifyResult, error)
+
+	// BatchGenerate runs batched autoregressive generation, decoding each prompt
+	// up to MaxTokens; per-prompt tokens and errors are reported in BatchResult.
+	BatchGenerate(ctx context.Context, prompts []string, opts ...GenerateOption) ([]BatchResult, error)
+
 	// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
 	ModelType() string
 
diff --git a/options.go b/options.go
index 5546a11..c5f48e9 100644
--- a/options.go
+++ b/options.go
@@ -2,12 +2,13 @@ package inference
 
 // GenerateConfig holds generation parameters.
 type GenerateConfig struct {
-	MaxTokens   int
-	Temperature float32
-	TopK        int
-	TopP        float32
-	StopTokens  []int32
+	MaxTokens     int
+	Temperature   float32
+	TopK          int
+	TopP          float32
+	StopTokens    []int32
 	RepeatPenalty float32
+	ReturnLogits  bool // Include raw logits in each ClassifyResult; enabled via WithLogits (default false)
 }
 
 // DefaultGenerateConfig returns sensible defaults.
@@ -51,6 +52,11 @@ func WithRepeatPenalty(p float32) GenerateOption {
 	return func(c *GenerateConfig) { c.RepeatPenalty = p }
 }
 
+// WithLogits sets ReturnLogits so ClassifyResult carries raw logits; off by default to save memory.
+func WithLogits() GenerateOption {
+	return func(c *GenerateConfig) { c.ReturnLogits = true }
+}
+
 // ApplyGenerateOpts builds a GenerateConfig from options.
 func ApplyGenerateOpts(opts []GenerateOption) GenerateConfig {
 	cfg := DefaultGenerateConfig()