go-rocm/rocm_integration_test.go
Claude d7db2d6e95
docs: Phase 3 complete — GGUF metadata, discovery, auto context
Integration test verifies model discovery on real GGUF files.
All 9 models in /data/lem/gguf/ discovered with correct metadata.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-19 22:24:52 +00:00

220 lines
5.1 KiB
Go

//go:build rocm

package rocm

import (
	"context"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"testing"
	"time"

	"forge.lthn.ai/core/go-inference"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
// testModel is the GGUF model every integration test in this file loads.
// It lives on a network mount and may be absent; tests call skipIfNoModel
// rather than failing when it is missing.
const testModel = "/data/lem/gguf/LEK-Gemma3-1B-layered-v2-Q5_K_M.gguf"
// skipIfNoModel skips the calling test when the GGUF test model file is
// absent from disk (e.g. the SMB share it lives on is not mounted).
// Other stat errors (permissions, I/O) deliberately do not skip.
func skipIfNoModel(t *testing.T) {
	t.Helper()
	_, statErr := os.Stat(testModel)
	if os.IsNotExist(statErr) {
		t.Skip("test model not available (SMB mount down?)")
	}
}
// skipIfNoROCm skips the calling test unless both the ROCm kernel driver
// device node (/dev/kfd) and a llama-server binary can be found.
func skipIfNoROCm(t *testing.T) {
	t.Helper()
	// Hardware check: /dev/kfd exists only when the amdgpu/KFD driver is loaded.
	_, kfdErr := os.Stat("/dev/kfd")
	if kfdErr != nil {
		t.Skip("no ROCm hardware")
	}
	// Binary check: the backend shells out to llama-server.
	_, srvErr := findLlamaServer()
	if srvErr != nil {
		t.Skip("llama-server not found")
	}
}
// TestROCm_LoadAndGenerate loads the test model on real ROCm hardware,
// checks the detected model type, and verifies that plain-text generation
// streams at least one token within the timeout.
func TestROCm_LoadAndGenerate(t *testing.T) {
	skipIfNoROCm(t)
	skipIfNoModel(t)

	b := &rocmBackend{}
	require.True(t, b.Available())

	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
	require.NoError(t, err)
	defer m.Close()

	assert.Equal(t, "gemma3", m.ModelType())

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	var tokens []string
	for tok := range m.Generate(ctx, "The capital of France is", inference.WithMaxTokens(16)) {
		tokens = append(tokens, tok.Text)
	}
	require.NoError(t, m.Err())
	require.NotEmpty(t, tokens, "expected at least one token")

	// Join once rather than += in a loop (avoids quadratic re-copying).
	t.Logf("Generated: %s", strings.Join(tokens, ""))
}
// TestROCm_Chat sends a single-turn chat request and verifies that tokens
// stream back within the timeout.
func TestROCm_Chat(t *testing.T) {
	skipIfNoROCm(t)
	skipIfNoModel(t)

	b := &rocmBackend{}
	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
	require.NoError(t, err)
	defer m.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	messages := []inference.Message{
		{Role: "user", Content: "Say hello in exactly three words."},
	}
	var tokens []string
	for tok := range m.Chat(ctx, messages, inference.WithMaxTokens(32)) {
		tokens = append(tokens, tok.Text)
	}
	require.NoError(t, m.Err())
	require.NotEmpty(t, tokens, "expected at least one token")

	// Join once rather than += in a loop (avoids quadratic re-copying).
	t.Logf("Chat response: %s", strings.Join(tokens, ""))
}
// TestROCm_ContextCancellation verifies that cancelling the context
// mid-stream causes the token channel to close promptly.
func TestROCm_ContextCancellation(t *testing.T) {
	skipIfNoROCm(t)
	skipIfNoModel(t)

	b := &rocmBackend{}
	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
	require.NoError(t, err)
	defer m.Close()

	ctx, cancel := context.WithCancel(context.Background())
	// Always release the context: if the stream ends before three tokens,
	// the in-loop cancel never fires (fixes go vet "lostcancel").
	defer cancel()

	var count int
	for range m.Generate(ctx, "Write a very long story about dragons", inference.WithMaxTokens(256)) {
		count++
		if count >= 3 {
			cancel()
		}
	}
	t.Logf("Got %d tokens before cancel", count)
	assert.GreaterOrEqual(t, count, 3)
}
// TestROCm_GracefulShutdown cancels one generation mid-stream and then runs
// a second generation on the same model, proving the backing server survives
// an aborted request.
func TestROCm_GracefulShutdown(t *testing.T) {
	skipIfNoROCm(t)
	skipIfNoModel(t)

	b := &rocmBackend{}
	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
	require.NoError(t, err)
	defer m.Close()

	// Cancel mid-stream. Defer the cancel too: if the stream ends before
	// five tokens, the in-loop cancel never fires (go vet "lostcancel").
	ctx1, cancel1 := context.WithCancel(context.Background())
	defer cancel1()
	var count1 int
	for range m.Generate(ctx1, "Write a long story about space exploration", inference.WithMaxTokens(256)) {
		count1++
		if count1 >= 5 {
			cancel1()
		}
	}
	t.Logf("First generation: %d tokens before cancel", count1)

	// Generate again on the same model — server should still be alive.
	ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel2()
	var count2 int
	for range m.Generate(ctx2, "The capital of France is", inference.WithMaxTokens(16)) {
		count2++
	}
	require.NoError(t, m.Err())
	assert.Greater(t, count2, 0, "expected tokens from second generation after cancel")
	t.Logf("Second generation: %d tokens", count2)
}
// TestROCm_ConcurrentRequests fires several generations at the same loaded
// model in parallel and checks that every goroutine produced some output.
func TestROCm_ConcurrentRequests(t *testing.T) {
	skipIfNoROCm(t)
	skipIfNoModel(t)

	b := &rocmBackend{}
	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
	require.NoError(t, err)
	defer m.Close()

	prompts := []string{
		"The capital of France is",
		"The capital of Germany is",
		"The capital of Italy is",
	}
	// One result slot per goroutine; slots are disjoint, so no lock needed.
	results := make([]string, len(prompts))

	var wg sync.WaitGroup
	for i, prompt := range prompts {
		wg.Add(1)
		go func(idx int, p string) {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
			defer cancel()
			var out strings.Builder
			for tok := range m.Generate(ctx, p, inference.WithMaxTokens(16)) {
				out.WriteString(tok.Text)
			}
			results[idx] = out.String()
		}(i, prompt)
	}
	wg.Wait()

	for i, got := range results {
		t.Logf("Goroutine %d: %s", i, got)
		assert.NotEmpty(t, got, "goroutine %d produced no output", i)
	}
}
// TestROCm_DiscoverModels runs GGUF discovery over the directory that holds
// the test model and checks each discovered entry carries basic metadata
// (architecture, name, non-zero file size).
func TestROCm_DiscoverModels(t *testing.T) {
	modelDir := filepath.Dir(testModel)
	if _, err := os.Stat(modelDir); err != nil {
		t.Skip("model directory not available")
	}

	found, err := DiscoverModels(modelDir)
	require.NoError(t, err)
	require.NotEmpty(t, found, "expected at least one model in %s", modelDir)

	for _, info := range found {
		t.Logf("Found: %s (%s %s %s, ctx=%d)", filepath.Base(info.Path), info.Architecture, info.Parameters, info.Quantisation, info.ContextLen)
		assert.NotEmpty(t, info.Architecture)
		assert.NotEmpty(t, info.Name)
		assert.Greater(t, info.FileSize, int64(0))
	}
}