From 402bdc2205bd213aa4d0b0f4736db21f29e458b1 Mon Sep 17 00:00:00 2001
From: Claude
Date: Tue, 24 Feb 2026 18:52:10 +0000
Subject: [PATCH] test: add integration tests for Classify, BatchGenerate,
 Info, Metrics

Verified on RX 7800 XT (gfx1100, ROCm 7.2):
- Classify: greedy single-token via max_tokens=1
- BatchGenerate: sequential multi-prompt generation
- Info: GGUF metadata (gemma3, 26 layers, Q5_K_M)
- Metrics: 250 tok/s decode, 3244 MiB VRAM

Co-Authored-By: Virgil
---
 rocm_integration_test.go | 94 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 94 insertions(+)

diff --git a/rocm_integration_test.go b/rocm_integration_test.go
index 9f4cdab..1856036 100644
--- a/rocm_integration_test.go
+++ b/rocm_integration_test.go
@@ -201,6 +201,100 @@ func TestROCm_ConcurrentRequests(t *testing.T) {
 	}
 }
 
+func TestROCm_Classify(t *testing.T) {
+	skipIfNoROCm(t)
+	skipIfNoModel(t)
+
+	b := &rocmBackend{}
+	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
+	require.NoError(t, err)
+	defer m.Close()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	prompts := []string{
+		"The capital of France is",
+		"2 + 2 =",
+	}
+
+	results, err := m.Classify(ctx, prompts)
+	require.NoError(t, err)
+	require.Len(t, results, 2)
+
+	for i, r := range results {
+		assert.NotEmpty(t, r.Token.Text, "classify result %d should have a token", i)
+		t.Logf("Classify %d: %q", i, r.Token.Text)
+	}
+}
+
+func TestROCm_BatchGenerate(t *testing.T) {
+	skipIfNoROCm(t)
+	skipIfNoModel(t)
+
+	b := &rocmBackend{}
+	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
+	require.NoError(t, err)
+	defer m.Close()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+	defer cancel()
+
+	prompts := []string{
+		"The capital of France is",
+		"The capital of Germany is",
+	}
+
+	results, err := m.BatchGenerate(ctx, prompts, inference.WithMaxTokens(8))
+	require.NoError(t, err)
+	require.Len(t, results, 2)
+
+	for i, r := range results {
+		require.NoError(t, r.Err, "batch result %d error", i)
+		assert.NotEmpty(t, r.Tokens, "batch result %d should have tokens", i)
+
+		var sb strings.Builder
+		for _, tok := range r.Tokens {
+			sb.WriteString(tok.Text)
+		}
+		t.Logf("Batch %d: %s", i, sb.String())
+	}
+}
+
+func TestROCm_InfoAndMetrics(t *testing.T) {
+	skipIfNoROCm(t)
+	skipIfNoModel(t)
+
+	b := &rocmBackend{}
+	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
+	require.NoError(t, err)
+	defer m.Close()
+
+	// Info should be populated from GGUF metadata.
+	info := m.Info()
+	assert.Equal(t, "gemma3", info.Architecture)
+	assert.Greater(t, info.NumLayers, 0, "expected non-zero layer count")
+	assert.Greater(t, info.QuantBits, 0, "expected non-zero quant bits")
+	t.Logf("Info: arch=%s layers=%d quant=%d-bit group=%d",
+		info.Architecture, info.NumLayers, info.QuantBits, info.QuantGroup)
+
+	// Generate some tokens to populate metrics.
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	for range m.Generate(ctx, "Hello", inference.WithMaxTokens(4)) {
+	}
+	require.NoError(t, m.Err())
+
+	met := m.Metrics()
+	assert.Greater(t, met.GeneratedTokens, 0, "expected generated tokens")
+	assert.Greater(t, met.TotalDuration, time.Duration(0), "expected non-zero duration")
+	assert.Greater(t, met.DecodeTokensPerSec, float64(0), "expected non-zero decode throughput")
+	t.Logf("Metrics: gen=%d tok, total=%s, decode=%.1f tok/s, vram=%d MiB",
+		met.GeneratedTokens, met.TotalDuration, met.DecodeTokensPerSec,
+		met.ActiveMemoryBytes/(1024*1024))
+}
+
 func TestROCm_DiscoverModels(t *testing.T) {
 	dir := filepath.Dir(testModel)
 	if _, err := os.Stat(dir); err != nil {