feat: benchmark suite for decode speed, TTFT, and concurrent throughput
BenchmarkDecode, BenchmarkTTFT, BenchmarkConcurrent across 3 models. Uses testing.B with b.ReportMetric for Go benchstat integration. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
72120bb200
commit
870ee232bf
1 changed file with 181 additions and 0 deletions
181
rocm_benchmark_test.go
Normal file
181
rocm_benchmark_test.go
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
//go:build rocm
|
||||
|
||||
package rocm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
)
|
||||
|
||||
// benchModel pairs a human-readable benchmark name with a GGUF file path.
type benchModel struct {
	name string
	path string
}

// benchModels enumerates the models exercised by the benchmarks below.
// Each occupies roughly 3-9 GB of VRAM, so they run sequentially
// (one loaded at a time), never in parallel.
var benchModels = []benchModel{
	{name: "Gemma3-4B-Q4_K_M", path: "/data/lem/gguf/LEK-Gemma3-4B-Q4_K_M.gguf"},
	{name: "Llama3.1-8B-Q4_K_M", path: "/data/lem/gguf/LEK-Llama-3.1-8B-Q4_K_M.gguf"},
	{name: "Qwen2.5-7B-Q4_K_M", path: "/data/lem/gguf/LEK-Qwen-2.5-7B-Q4_K_M.gguf"},
}
|
||||
|
||||
func skipBenchIfUnavailable(b *testing.B) {
|
||||
b.Helper()
|
||||
if _, err := os.Stat("/dev/kfd"); err != nil {
|
||||
b.Skip("no ROCm hardware")
|
||||
}
|
||||
if _, err := findLlamaServer(); err != nil {
|
||||
b.Skip("llama-server not found")
|
||||
}
|
||||
}
|
||||
|
||||
// loadBenchModel loads a model for benchmarking. Caller must defer m.Close().
|
||||
// Stops the benchmark timer during loading so load time is excluded.
|
||||
func loadBenchModel(b *testing.B, path string, opts ...inference.LoadOption) inference.TextModel {
|
||||
b.Helper()
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
b.Skipf("model not available: %s", path)
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
backend := &rocmBackend{}
|
||||
defaults := []inference.LoadOption{inference.WithContextLen(2048)}
|
||||
m, err := backend.LoadModel(path, append(defaults, opts...)...)
|
||||
if err != nil {
|
||||
b.Fatalf("load model: %v", err)
|
||||
}
|
||||
|
||||
if vram, err := GetVRAMInfo(); err == nil {
|
||||
b.Logf("VRAM after load: %d MiB used / %d MiB total",
|
||||
vram.Used/(1024*1024), vram.Total/(1024*1024))
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
return m
|
||||
}
|
||||
|
||||
// BenchmarkDecode measures token generation speed (tok/s).
|
||||
// Generates 128 tokens per iteration, reports tokens/second.
|
||||
func BenchmarkDecode(b *testing.B) {
|
||||
skipBenchIfUnavailable(b)
|
||||
|
||||
for _, model := range benchModels {
|
||||
b.Run(model.name, func(b *testing.B) {
|
||||
m := loadBenchModel(b, model.path)
|
||||
defer m.Close()
|
||||
|
||||
const maxTok = 128
|
||||
|
||||
var totalTokens int
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
var count int
|
||||
for range m.Generate(ctx, "Explain the theory of relativity in detail.", inference.WithMaxTokens(maxTok)) {
|
||||
count++
|
||||
}
|
||||
cancel()
|
||||
totalTokens += count
|
||||
}
|
||||
b.StopTimer()
|
||||
|
||||
if totalTokens > 0 {
|
||||
elapsed := b.Elapsed()
|
||||
tokPerSec := float64(totalTokens) / elapsed.Seconds()
|
||||
b.ReportMetric(tokPerSec, "tok/s")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkTTFT measures time-to-first-token latency.
|
||||
// Reports the average time from request start to first token received.
|
||||
func BenchmarkTTFT(b *testing.B) {
|
||||
skipBenchIfUnavailable(b)
|
||||
|
||||
for _, model := range benchModels {
|
||||
b.Run(model.name, func(b *testing.B) {
|
||||
m := loadBenchModel(b, model.path)
|
||||
defer m.Close()
|
||||
|
||||
var totalTTFT time.Duration
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
start := time.Now()
|
||||
for range m.Generate(ctx, "Hello", inference.WithMaxTokens(1)) {
|
||||
totalTTFT += time.Since(start)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
b.StopTimer()
|
||||
|
||||
if b.N > 0 {
|
||||
avgTTFT := totalTTFT / time.Duration(b.N)
|
||||
b.ReportMetric(float64(avgTTFT.Microseconds()), "µs/first-tok")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConcurrent measures throughput with multiple goroutines.
|
||||
// Uses 4 parallel slots and 4 goroutines generating simultaneously.
|
||||
func BenchmarkConcurrent(b *testing.B) {
|
||||
skipBenchIfUnavailable(b)
|
||||
|
||||
for _, model := range benchModels {
|
||||
b.Run(model.name, func(b *testing.B) {
|
||||
m := loadBenchModel(b, model.path, inference.WithParallelSlots(4))
|
||||
defer m.Close()
|
||||
|
||||
const numWorkers = 4
|
||||
const maxTok = 32
|
||||
|
||||
prompts := []string{
|
||||
"The capital of France is",
|
||||
"The capital of Germany is",
|
||||
"The capital of Italy is",
|
||||
"The capital of Spain is",
|
||||
}
|
||||
|
||||
var totalTokens int
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
iterTokens := 0
|
||||
|
||||
wg.Add(numWorkers)
|
||||
for i := range numWorkers {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
count := 0
|
||||
for range m.Generate(ctx, prompts[idx], inference.WithMaxTokens(maxTok)) {
|
||||
count++
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
iterTokens += count
|
||||
mu.Unlock()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
totalTokens += iterTokens
|
||||
}
|
||||
b.StopTimer()
|
||||
|
||||
if totalTokens > 0 {
|
||||
elapsed := b.Elapsed()
|
||||
b.ReportMetric(float64(totalTokens)/elapsed.Seconds(), "tok/s-aggregate")
|
||||
b.ReportMetric(float64(totalTokens)/float64(b.N), "tok/iter")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue