From 870ee232bfa8c906fa6c8c68c48cd69bb841f30c Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 19 Feb 2026 23:16:40 +0000 Subject: [PATCH] feat: benchmark suite for decode speed, TTFT, and concurrent throughput BenchmarkDecode, BenchmarkTTFT, BenchmarkConcurrent across 3 models. Uses testing.B with b.ReportMetric for Go benchstat integration. Co-Authored-By: Virgil --- rocm_benchmark_test.go | 181 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 181 insertions(+) create mode 100644 rocm_benchmark_test.go diff --git a/rocm_benchmark_test.go b/rocm_benchmark_test.go new file mode 100644 index 0000000..b833eaa --- /dev/null +++ b/rocm_benchmark_test.go @@ -0,0 +1,181 @@ +//go:build rocm + +package rocm + +import ( + "context" + "os" + "sync" + "testing" + "time" + + "forge.lthn.ai/core/go-inference" +) + +// benchModels lists the models to benchmark. +// Each loads ~3-9 GB of VRAM, so they run sequentially (one at a time). +var benchModels = []struct { + name string + path string +}{ + {"Gemma3-4B-Q4_K_M", "/data/lem/gguf/LEK-Gemma3-4B-Q4_K_M.gguf"}, + {"Llama3.1-8B-Q4_K_M", "/data/lem/gguf/LEK-Llama-3.1-8B-Q4_K_M.gguf"}, + {"Qwen2.5-7B-Q4_K_M", "/data/lem/gguf/LEK-Qwen-2.5-7B-Q4_K_M.gguf"}, +} + +func skipBenchIfUnavailable(b *testing.B) { + b.Helper() + if _, err := os.Stat("/dev/kfd"); err != nil { + b.Skip("no ROCm hardware") + } + if _, err := findLlamaServer(); err != nil { + b.Skip("llama-server not found") + } +} + +// loadBenchModel loads a model for benchmarking. Caller must defer m.Close(). +// Stops the benchmark timer during loading so load time is excluded. +func loadBenchModel(b *testing.B, path string, opts ...inference.LoadOption) inference.TextModel { + b.Helper() + if _, err := os.Stat(path); err != nil { + b.Skipf("model not available: %s", path) + } + + b.StopTimer() + backend := &rocmBackend{} + defaults := []inference.LoadOption{inference.WithContextLen(2048)} + m, err := backend.LoadModel(path, append(defaults, opts...)...) + if err != nil { + b.Fatalf("load model: %v", err) + } + + if vram, err := GetVRAMInfo(); err == nil { + b.Logf("VRAM after load: %d MiB used / %d MiB total", + vram.Used/(1024*1024), vram.Total/(1024*1024)) + } + + b.StartTimer() + return m +} + +// BenchmarkDecode measures token generation speed (tok/s). +// Generates 128 tokens per iteration, reports tokens/second. +func BenchmarkDecode(b *testing.B) { + skipBenchIfUnavailable(b) + + for _, model := range benchModels { + b.Run(model.name, func(b *testing.B) { + m := loadBenchModel(b, model.path) + defer m.Close() + + const maxTok = 128 + + var totalTokens int + b.ResetTimer() + for range b.N { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + var count int + for range m.Generate(ctx, "Explain the theory of relativity in detail.", inference.WithMaxTokens(maxTok)) { + count++ + } + cancel() + totalTokens += count + } + b.StopTimer() + + if totalTokens > 0 { + elapsed := b.Elapsed() + tokPerSec := float64(totalTokens) / elapsed.Seconds() + b.ReportMetric(tokPerSec, "tok/s") + } + }) + } +} + +// BenchmarkTTFT measures time-to-first-token latency. +// Reports the average time from request start to first token received. +func BenchmarkTTFT(b *testing.B) { + skipBenchIfUnavailable(b) + + for _, model := range benchModels { + b.Run(model.name, func(b *testing.B) { + m := loadBenchModel(b, model.path) + defer m.Close() + + var totalTTFT time.Duration + b.ResetTimer() + for range b.N { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + start := time.Now() + for range m.Generate(ctx, "Hello", inference.WithMaxTokens(1)) { + totalTTFT += time.Since(start) + } + cancel() + } + b.StopTimer() + + if b.N > 0 { + avgTTFT := totalTTFT / time.Duration(b.N) + b.ReportMetric(float64(avgTTFT.Microseconds()), "µs/first-tok") + } + }) + } +} + +// BenchmarkConcurrent measures throughput with multiple goroutines. +// Uses 4 parallel slots and 4 goroutines generating simultaneously. +func BenchmarkConcurrent(b *testing.B) { + skipBenchIfUnavailable(b) + + for _, model := range benchModels { + b.Run(model.name, func(b *testing.B) { + m := loadBenchModel(b, model.path, inference.WithParallelSlots(4)) + defer m.Close() + + const numWorkers = 4 + const maxTok = 32 + + prompts := []string{ + "The capital of France is", + "The capital of Germany is", + "The capital of Italy is", + "The capital of Spain is", + } + + var totalTokens int + b.ResetTimer() + for range b.N { + var wg sync.WaitGroup + var mu sync.Mutex + iterTokens := 0 + + wg.Add(numWorkers) + for i := range numWorkers { + go func(idx int) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + count := 0 + for range m.Generate(ctx, prompts[idx], inference.WithMaxTokens(maxTok)) { + count++ + } + + mu.Lock() + iterTokens += count + mu.Unlock() + }(i) + } + wg.Wait() + totalTokens += iterTokens + } + b.StopTimer() + + if totalTokens > 0 { + elapsed := b.Elapsed() + b.ReportMetric(float64(totalTokens)/elapsed.Seconds(), "tok/s-aggregate") + b.ReportMetric(float64(totalTokens)/float64(b.N), "tok/iter") + } + }) + } +}