feat: add GenerateMetrics type and Metrics() to TextModel

Expose prefill/decode timing, token counts, throughput, and GPU memory
stats from the last inference operation. Same retrieval pattern as Err().

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Snider 2026-02-19 23:34:31 +00:00
parent 2517b774b8
commit df176765e7

View file

@ -26,6 +26,7 @@ import (
"fmt"
"iter"
"sync"
"time"
)
// Token represents a single generated token for streaming.
@ -52,6 +53,27 @@ type BatchResult struct {
Err error // Per-prompt error (context cancel, OOM, etc.)
}
// GenerateMetrics holds performance metrics from the last inference operation.
// Retrieved via TextModel.Metrics() after Generate, Chat, Classify, or BatchGenerate.
type GenerateMetrics struct {
// Token counts
PromptTokens int // Input tokens (sum across batch for batch ops)
GeneratedTokens int // Output tokens generated
// Timing
PrefillDuration time.Duration // Time to process the prompt(s)
DecodeDuration time.Duration // Time for autoregressive decoding
TotalDuration time.Duration // Wall-clock time for the full operation
// Throughput (computed)
PrefillTokensPerSec float64 // PromptTokens / PrefillDuration
DecodeTokensPerSec float64 // GeneratedTokens / DecodeDuration
// Memory (Metal/GPU)
PeakMemoryBytes uint64 // Peak GPU memory during this operation
ActiveMemoryBytes uint64 // Active GPU memory after operation
}
// TextModel generates text from a loaded model.
type TextModel interface {
// Generate streams tokens for the given prompt.
@ -73,6 +95,10 @@ type TextModel interface {
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
ModelType() string
// Metrics returns performance metrics from the last inference operation.
// Valid after Generate (iterator exhausted), Chat, Classify, or BatchGenerate.
Metrics() GenerateMetrics
// Err returns the error from the last Generate/Chat call, if any.
// Check this after the iterator stops to distinguish EOS from errors.
Err() error