Two new TextModel methods: Classify (prefill-only, fast path for classification) and BatchGenerate (autoregressive, multi-prompt). Adds attention masking for padded batches. Primary consumer: go-i18n Phase 2a domain classification at ~5K sentences/sec. Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
# Batch Inference API Design
Goal: Enable batched inference (multiple prompts in one forward pass) for both classification (prefill-only) and generation (autoregressive).
Primary consumer: go-i18n Phase 2a — ~5K sentences/sec domain classification through Gemma3-1B.
## API Surface (go-inference)

Two new methods on `TextModel`, two new result types:
```go
type ClassifyResult struct {
	Token  Token     // Sampled token at the last prompt position
	Logits []float32 // Raw vocab-sized logits (when WithLogits() is set)
}

type BatchResult struct {
	Tokens []Token // All generated tokens for this prompt
	Err    error   // Per-prompt error (context cancel, OOM, etc.)
}

// New methods on TextModel:
Classify(ctx context.Context, prompts []string, opts ...GenerateOption) ([]ClassifyResult, error)
BatchGenerate(ctx context.Context, prompts []string, opts ...GenerateOption) ([]BatchResult, error)
```
New option: `WithLogits()` — include raw logits in `ClassifyResult` (off by default to save memory).
## Architecture

### Classify (prefill-only, fast path)

1. Encode all prompts via the tokenizer
2. Pad to the max sequence length with the pad token (0)
3. Build the attention mask: `[N, 1, maxLen, maxLen]`, combining causal + padding
4. Single forward pass: `[N, maxLen]` → `[N, maxLen, vocab]`
5. Gather logits at each prompt's last real token position
6. Sample (greedy/temperature) per prompt
7. Return `[]ClassifyResult`
No KV caches needed — single pass, no autoregressive decode.
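The gather-and-sample steps can be sketched on plain slices, assuming logits for the right-padded batch are already computed; the real implementation would operate on the backend's array type rather than Go slices, and `gatherAndGreedy` is a hypothetical helper name.

```go
package main

import "fmt"

// gatherAndGreedy picks, for each prompt in a right-padded batch, the logits
// at its last real token position and greedy-samples the argmax token.
// logits has shape [N][maxLen][vocab]; promptLens holds the real lengths.
func gatherAndGreedy(logits [][][]float32, promptLens []int) []int {
	tokens := make([]int, len(logits))
	for b, seq := range logits {
		last := seq[promptLens[b]-1] // last real token position, not the padded end
		best := 0
		for v := 1; v < len(last); v++ {
			if last[v] > last[best] {
				best = v
			}
		}
		tokens[b] = best
	}
	return tokens
}

func main() {
	// Batch of 2, maxLen 3, vocab 4; prompt 1 has length 2 (one pad slot).
	logits := [][][]float32{
		{{0, 1, 0, 0}, {0, 0, 2, 0}, {3, 0, 0, 0}}, // len 3 → gather at pos 2
		{{0, 0, 0, 1}, {0, 0, 0, 5}, {9, 0, 0, 0}}, // len 2 → gather at pos 1
	}
	fmt.Println(gatherAndGreedy(logits, []int{3, 2})) // [0 3]
}
```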
### BatchGenerate (autoregressive)

1. Prefill: same as Classify steps 1-5
2. Decode loop (up to MaxTokens iterations):
   - Sample the next token for all active sequences
   - Check stop conditions per sequence (EOS, stop tokens)
   - Mark finished sequences; continue until all are done
   - Append the `[N, 1]` next tokens and forward through the model with KV caches
3. Collect per-sequence token lists into `[]BatchResult`
Simple batch management: keep all sequences in the batch even after some finish (pad with zeros). This wastes some compute but avoids complex batch compaction.
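The bookkeeping for this policy can be sketched as follows. `batchDecode` and the scripted `sample` callback are illustrative stand-ins for the forward pass and the per-sequence sampler; only the keep-everything, pad-with-zeros policy is taken from the design.

```go
package main

import "fmt"

// batchDecode sketches the simple batch-management policy: all n sequences
// stay in the batch on every step; once a sequence hits eos it is marked
// finished and its slot in the next-token batch is padded with token 0.
// sample stands in for "forward pass + per-sequence sampling".
func batchDecode(sample func(step, b int) int, n, maxTokens, eos int) [][]int {
	out := make([][]int, n)
	finished := make([]bool, n)
	for step := 0; step < maxTokens; step++ {
		allDone := true
		next := make([]int, n) // the [N, 1] tokens fed back into the model
		for b := 0; b < n; b++ {
			if finished[b] {
				next[b] = 0 // pad slot for a finished sequence (wasted compute)
				continue
			}
			tok := sample(step, b)
			if tok == eos {
				finished[b] = true
			} else {
				out[b] = append(out[b], tok)
				allDone = false
			}
			next[b] = tok
		}
		if allDone {
			break
		}
	}
	return out
}

func main() {
	// Sequence 0 emits 5, 5 then EOS (-1); sequence 1 emits 7 then EOS.
	scripts := [][]int{{5, 5, -1}, {7, -1, -1}}
	sample := func(step, b int) int { return scripts[b][step] }
	fmt.Println(batchDecode(sample, 2, 8, -1)) // [[5 5] [7]]
}
```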
## Attention Mask

For padded batches, SDPA needs a combined causal + padding mask:

```
mask[b, 0, i, j] = 0.0   if j <= i AND j < promptLen[b]   (attend)
mask[b, 0, i, j] = -inf  otherwise                        (don't attend)
```

Shape: `[N, 1, maxLen, maxLen]` — broadcasts across attention heads.

Uses the existing `ScaledDotProductAttentionWithMask`.
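A reference construction of this mask on plain slices, useful as an oracle for the unit tests; the production version would build it directly as an `Array`, and `buildMask` is a hypothetical helper name.

```go
package main

import (
	"fmt"
	"math"
)

// buildMask constructs the combined causal + padding mask: shape
// [N][1][maxLen][maxLen], 0.0 where query position i may attend to key
// position j (j <= i and j is a real token of batch b), -inf elsewhere.
func buildMask(n, maxLen int, promptLens []int) [][][][]float32 {
	negInf := float32(math.Inf(-1))
	mask := make([][][][]float32, n)
	for b := 0; b < n; b++ {
		mask[b] = [][][]float32{make([][]float32, maxLen)}
		for i := 0; i < maxLen; i++ {
			row := make([]float32, maxLen)
			for j := 0; j < maxLen; j++ {
				if j <= i && j < promptLens[b] {
					row[j] = 0.0 // attend
				} else {
					row[j] = negInf // future position or padding: don't attend
				}
			}
			mask[b][0][i] = row
		}
	}
	return mask
}

func main() {
	// One prompt of real length 2 in a maxLen-3 batch: column 2 is padding.
	for _, row := range buildMask(1, 3, []int{2})[0][0] {
		fmt.Println(row)
	}
}
```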
## Model Forward Changes

Add `ForwardMasked(tokens, mask *Array, caches []Cache) *Array` to `InternalModel`. It threads the mask through to each attention layer. The existing `Forward` is unchanged (causal-only, no mask).
Changes to the attention layers in `gemma3.go` and `qwen3.go`:
- Accept an optional mask parameter
- Use `ScaledDotProductAttentionWithMask` when a mask is provided
- Fall back to `ScaledDotProductAttention` with the causal bool when the mask is nil
## Padding and Tokenisation

- Pad token: 0 (standard for most models; verified for Gemma3/Qwen/Llama)
- Left-pad or right-pad? Right-pad — simpler, and the causal mask handles it. The last real token position is `promptLen[b] - 1`.
- Sequences are sorted by length descending before batching (reduces padding waste)
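These padding rules can be sketched with a small helper operating on already-tokenised prompts; `padBatch` is a hypothetical name, and returning the original indices is one way to map sorted results back to input order.

```go
package main

import (
	"fmt"
	"sort"
)

// padBatch sorts tokenised prompts by length descending (to reduce padding
// waste) and right-pads each to the batch max with pad token 0. It returns
// the padded rows, the real lengths (last real position = lens[i] - 1), and
// the original index of each row so results can be mapped back.
func padBatch(prompts [][]int) (padded [][]int, lens []int, order []int) {
	order = make([]int, len(prompts))
	for i := range order {
		order[i] = i
	}
	sort.SliceStable(order, func(a, b int) bool {
		return len(prompts[order[a]]) > len(prompts[order[b]])
	})
	maxLen := 0
	for _, p := range prompts {
		if len(p) > maxLen {
			maxLen = len(p)
		}
	}
	for _, idx := range order {
		row := make([]int, maxLen) // zero-filled, so the tail is pad token 0
		copy(row, prompts[idx])
		padded = append(padded, row)
		lens = append(lens, len(prompts[idx]))
	}
	return padded, lens, order
}

func main() {
	padded, lens, order := padBatch([][]int{{7}, {4, 5, 6}, {1, 2}})
	fmt.Println(padded, lens, order) // [[4 5 6] [1 2 0] [7 0 0]] [3 2 1] [1 2 0]
}
```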
## Implementation Location

- `go-inference`: types (`ClassifyResult`, `BatchResult`), option (`WithLogits`), interface methods
- `internal/metal/batch.go`: `BatchClassify`, `BatchGenerate` implementation
- `internal/metal/generate.go`: wire through `Model.Classify`, `Model.BatchGenerate`
- `internal/metal/gemma3.go`, `qwen3.go`: add the `ForwardMasked` method
## Testing
- Unit: mask construction (causal + padding) with known shapes
- Classify: batch of 4 prompts through Gemma3-1B, verify each gets a sensible token
- BatchGenerate: batch of 2-3 prompts, verify coherent output per prompt
- Throughput: benchmark Classify at batch sizes 1, 8, 32, measure sentences/sec