test: graceful shutdown and concurrent request integration tests

Clear lastErr at the start of each Generate/Chat call so that Err()
reflects the most recent call, not a stale cancellation from a prior one.

Add two integration tests:
- GracefulShutdown: cancel mid-stream then generate again on the same
  model, verifying the server survives cancellation.
- ConcurrentRequests: three goroutines calling Generate() simultaneously,
  verifying no panics or deadlocks (llama-server serialises via slots).

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Claude 2026-02-19 21:50:47 +00:00
parent 954c57071a
commit a6e647c5b7
No known key found for this signature in database
GPG key ID: AF404715446AEB41
2 changed files with 90 additions and 0 deletions

View file

@ -23,6 +23,10 @@ type rocmModel struct {
// Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
m.mu.Lock()
m.lastErr = nil
m.mu.Unlock()
if !m.srv.alive() {
m.mu.Lock()
if m.srv.exitErr != nil {
@ -63,6 +67,10 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
// Chat streams tokens from a multi-turn conversation via llama-server's /v1/chat/completions endpoint.
func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
m.mu.Lock()
m.lastErr = nil
m.mu.Unlock()
if !m.srv.alive() {
m.mu.Lock()
if m.srv.exitErr != nil {

View file

@ -5,6 +5,8 @@ package rocm
import (
"context"
"os"
"strings"
"sync"
"testing"
"time"
@ -117,3 +119,83 @@ func TestROCm_ContextCancellation(t *testing.T) {
t.Logf("Got %d tokens before cancel", count)
assert.GreaterOrEqual(t, count, 3)
}
// TestROCm_GracefulShutdown cancels a generation mid-stream and then runs a
// second generation on the same loaded model, verifying that llama-server
// survives the cancellation and keeps serving.
func TestROCm_GracefulShutdown(t *testing.T) {
	skipIfNoROCm(t)
	skipIfNoModel(t)

	b := &rocmBackend{}
	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
	require.NoError(t, err)
	defer m.Close()

	// Cancel mid-stream. The defer guarantees the context is released even
	// when the stream yields fewer than 5 tokens and the in-loop cancel1()
	// is never reached — without it the context would leak (the pattern
	// go vet's lostcancel check flags). Calling cancel twice is safe.
	ctx1, cancel1 := context.WithCancel(context.Background())
	defer cancel1()

	var count1 int
	for tok := range m.Generate(ctx1, "Write a long story about space exploration", inference.WithMaxTokens(256)) {
		_ = tok
		count1++
		if count1 >= 5 {
			cancel1()
		}
	}
	t.Logf("First generation: %d tokens before cancel", count1)

	// Generate again on the same model — server should still be alive.
	ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel2()

	var count2 int
	for tok := range m.Generate(ctx2, "The capital of France is", inference.WithMaxTokens(16)) {
		_ = tok
		count2++
	}
	require.NoError(t, m.Err())
	assert.Greater(t, count2, 0, "expected tokens from second generation after cancel")
	t.Logf("Second generation: %d tokens", count2)
}
// TestROCm_ConcurrentRequests fires several Generate calls in parallel against
// a single loaded model and checks every goroutine finishes with output —
// i.e. no panics and no deadlocks (llama-server serialises requests via its
// slot mechanism).
func TestROCm_ConcurrentRequests(t *testing.T) {
	skipIfNoROCm(t)
	skipIfNoModel(t)

	backend := &rocmBackend{}
	model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
	require.NoError(t, err)
	defer model.Close()

	prompts := []string{
		"The capital of France is",
		"The capital of Germany is",
		"The capital of Italy is",
	}
	// One slot per prompt; each goroutine writes only its own index, so no
	// extra synchronisation beyond the WaitGroup is needed.
	outputs := make([]string, len(prompts))

	var wg sync.WaitGroup
	for i, prompt := range prompts {
		wg.Add(1)
		go func() {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
			defer cancel()
			var text strings.Builder
			for tok := range model.Generate(ctx, prompt, inference.WithMaxTokens(16)) {
				text.WriteString(tok.Text)
			}
			outputs[i] = text.String()
		}()
	}
	wg.Wait()

	for i, out := range outputs {
		t.Logf("Goroutine %d: %s", i, out)
		assert.NotEmpty(t, out, "goroutine %d produced no output", i)
	}
}