From 0e68d71c8a6aa6384a8524e36903e93dd05220d8 Mon Sep 17 00:00:00 2001
From: Claude
Date: Thu, 19 Feb 2026 21:15:02 +0000
Subject: [PATCH] test: integration tests for full ROCm inference pipeline
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

LoadModel → Generate → Chat → Close, exercised end to end on real AMD
GPU hardware. The file is build-tagged //go:build rocm, so a normal
go test run skips it.

Co-Authored-By: Virgil
Co-Authored-By: Claude Opus 4.6
---
 rocm_integration_test.go | 125 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 125 insertions(+)
 create mode 100644 rocm_integration_test.go

diff --git a/rocm_integration_test.go b/rocm_integration_test.go
new file mode 100644
index 0000000..c166f51
--- /dev/null
+++ b/rocm_integration_test.go
@@ -0,0 +1,125 @@
+//go:build rocm
+
+package rocm
+
+import (
+	"context"
+	"os"
+	"testing"
+	"time"
+
+	inference "forge.lthn.ai/core/go-inference"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// testModel lives on the lab SMB mount; tests skip when it is absent.
+const testModel = "/data/lem/gguf/LEK-Gemma3-1B-layered-v2-Q5_K_M.gguf"
+
+func skipIfNoModel(t *testing.T) {
+	t.Helper()
+	if _, err := os.Stat(testModel); os.IsNotExist(err) {
+		t.Skip("test model not available (SMB mount down?)")
+	}
+}
+
+func skipIfNoROCm(t *testing.T) {
+	t.Helper()
+	// /dev/kfd is the AMD KFD compute device; without it there is no ROCm.
+	if _, err := os.Stat("/dev/kfd"); err != nil {
+		t.Skip("no ROCm hardware")
+	}
+	if _, err := findLlamaServer(); err != nil {
+		t.Skip("llama-server not found")
+	}
+}
+
+func TestROCm_LoadAndGenerate(t *testing.T) {
+	skipIfNoROCm(t)
+	skipIfNoModel(t)
+
+	b := &rocmBackend{}
+	require.True(t, b.Available())
+
+	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
+	require.NoError(t, err)
+	defer m.Close()
+
+	assert.Equal(t, "gemma3", m.ModelType())
+
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	var tokens []string
+	for tok := range m.Generate(ctx, "The capital of France is", inference.WithMaxTokens(16)) {
+		tokens = append(tokens, tok.Text)
+	}
+
+	require.NoError(t, m.Err())
+	require.NotEmpty(t, tokens, "expected at least one token")
+
+	full := ""
+	for _, tok := range tokens {
+		full += tok
+	}
+	t.Logf("Generated: %s", full)
+}
+
+func TestROCm_Chat(t *testing.T) {
+	skipIfNoROCm(t)
+	skipIfNoModel(t)
+
+	b := &rocmBackend{}
+	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
+	require.NoError(t, err)
+	defer m.Close()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	messages := []inference.Message{
+		{Role: "user", Content: "Say hello in exactly three words."},
+	}
+
+	var tokens []string
+	for tok := range m.Chat(ctx, messages, inference.WithMaxTokens(32)) {
+		tokens = append(tokens, tok.Text)
+	}
+
+	require.NoError(t, m.Err())
+	require.NotEmpty(t, tokens, "expected at least one token")
+
+	full := ""
+	for _, tok := range tokens {
+		full += tok
+	}
+	t.Logf("Chat response: %s", full)
+}
+
+func TestROCm_ContextCancellation(t *testing.T) {
+	skipIfNoROCm(t)
+	skipIfNoModel(t)
+
+	b := &rocmBackend{}
+	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
+	require.NoError(t, err)
+	defer m.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel() // release the context on every exit path, not only after 3 tokens
+
+	// Cancel mid-stream after a few tokens; the token channel should close promptly.
+	var count int
+	for tok := range m.Generate(ctx, "Write a very long story about dragons", inference.WithMaxTokens(256)) {
+		_ = tok
+		count++
+		if count >= 3 {
+			cancel()
+		}
+	}
+
+	t.Logf("Got %d tokens before cancel", count)
+	assert.GreaterOrEqual(t, count, 3)
+	// Staying well under the 256-token cap shows the cancel actually stopped generation.
+	assert.Less(t, count, 256, "expected cancellation to cut generation short")
+}
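
-- 
Running the suite (a usage sketch; only the stock go tool is assumed,
and the build tag and test names come from the patch above):

    go test -tags rocm -run 'TestROCm_' -v

Without -tags rocm the //go:build rocm constraint excludes this file
from compilation entirely, which is why a plain go test run skips these
tests even on a machine with ROCm hardware present.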