diff --git a/inference.go b/inference.go
index af86f91..1fbf744 100644
--- a/inference.go
+++ b/inference.go
@@ -26,6 +26,7 @@ import (
 	"fmt"
 	"iter"
 	"sync"
+	"time"
 )
 
 // Token represents a single generated token for streaming.
@@ -52,6 +53,27 @@ type BatchResult struct {
 	Err error // Per-prompt error (context cancel, OOM, etc.)
 }
 
+// GenerateMetrics holds performance metrics from the last inference operation.
+// Retrieved via TextModel.Metrics() after Generate, Chat, Classify, or BatchGenerate.
+type GenerateMetrics struct {
+	// Token counts
+	PromptTokens    int // Input tokens (sum across batch for batch ops)
+	GeneratedTokens int // Output tokens generated
+
+	// Timing
+	PrefillDuration time.Duration // Time to process the prompt(s)
+	DecodeDuration  time.Duration // Time for autoregressive decoding
+	TotalDuration   time.Duration // Wall-clock time for the full operation
+
+	// Throughput (computed)
+	PrefillTokensPerSec float64 // PromptTokens / PrefillDuration
+	DecodeTokensPerSec  float64 // GeneratedTokens / DecodeDuration
+
+	// Memory (Metal/GPU)
+	PeakMemoryBytes   uint64 // Peak GPU memory during this operation
+	ActiveMemoryBytes uint64 // Active GPU memory after operation
+}
+
 // TextModel generates text from a loaded model.
 type TextModel interface {
 	// Generate streams tokens for the given prompt.
@@ -73,6 +95,10 @@ type TextModel interface {
 	// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
 	ModelType() string
 
+	// Metrics returns performance metrics from the last inference operation.
+	// Valid after Generate (iterator exhausted), Chat, Classify, or BatchGenerate.
+	Metrics() GenerateMetrics
+
 	// Err returns the error from the last Generate/Chat call, if any.
 	// Check this after the iterator stops to distinguish EOS from errors.
 	Err() error