test: graceful shutdown and concurrent request integration tests
Clear lastErr at the start of each Generate/Chat call so that Err() reflects the most recent call, not a stale cancellation from a prior one. Add two integration tests: - GracefulShutdown: cancel mid-stream then generate again on the same model, verifying the server survives cancellation. - ConcurrentRequests: three goroutines calling Generate() simultaneously, verifying no panics or deadlocks (llama-server serialises via slots). Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
954c57071a
commit
a6e647c5b7
2 changed files with 90 additions and 0 deletions
8
model.go
8
model.go
|
|
@ -23,6 +23,10 @@ type rocmModel struct {
|
|||
|
||||
// Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
|
||||
func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
|
||||
m.mu.Lock()
|
||||
m.lastErr = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
if !m.srv.alive() {
|
||||
m.mu.Lock()
|
||||
if m.srv.exitErr != nil {
|
||||
|
|
@ -63,6 +67,10 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
|
|||
|
||||
// Chat streams tokens from a multi-turn conversation via llama-server's /v1/chat/completions endpoint.
|
||||
func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
|
||||
m.mu.Lock()
|
||||
m.lastErr = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
if !m.srv.alive() {
|
||||
m.mu.Lock()
|
||||
if m.srv.exitErr != nil {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ package rocm
|
|||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -117,3 +119,83 @@ func TestROCm_ContextCancellation(t *testing.T) {
|
|||
t.Logf("Got %d tokens before cancel", count)
|
||||
assert.GreaterOrEqual(t, count, 3)
|
||||
}
|
||||
|
||||
func TestROCm_GracefulShutdown(t *testing.T) {
|
||||
skipIfNoROCm(t)
|
||||
skipIfNoModel(t)
|
||||
|
||||
b := &rocmBackend{}
|
||||
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
require.NoError(t, err)
|
||||
defer m.Close()
|
||||
|
||||
// Cancel mid-stream.
|
||||
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||
var count1 int
|
||||
for tok := range m.Generate(ctx1, "Write a long story about space exploration", inference.WithMaxTokens(256)) {
|
||||
_ = tok
|
||||
count1++
|
||||
if count1 >= 5 {
|
||||
cancel1()
|
||||
}
|
||||
}
|
||||
t.Logf("First generation: %d tokens before cancel", count1)
|
||||
|
||||
// Generate again on the same model — server should still be alive.
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
var count2 int
|
||||
for tok := range m.Generate(ctx2, "The capital of France is", inference.WithMaxTokens(16)) {
|
||||
_ = tok
|
||||
count2++
|
||||
}
|
||||
|
||||
require.NoError(t, m.Err())
|
||||
assert.Greater(t, count2, 0, "expected tokens from second generation after cancel")
|
||||
t.Logf("Second generation: %d tokens", count2)
|
||||
}
|
||||
|
||||
func TestROCm_ConcurrentRequests(t *testing.T) {
|
||||
skipIfNoROCm(t)
|
||||
skipIfNoModel(t)
|
||||
|
||||
b := &rocmBackend{}
|
||||
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
require.NoError(t, err)
|
||||
defer m.Close()
|
||||
|
||||
const numGoroutines = 3
|
||||
results := make([]string, numGoroutines)
|
||||
|
||||
prompts := []string{
|
||||
"The capital of France is",
|
||||
"The capital of Germany is",
|
||||
"The capital of Italy is",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := range numGoroutines {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var sb strings.Builder
|
||||
for tok := range m.Generate(ctx, prompts[idx], inference.WithMaxTokens(16)) {
|
||||
sb.WriteString(tok.Text)
|
||||
}
|
||||
results[idx] = sb.String()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
for i, result := range results {
|
||||
t.Logf("Goroutine %d: %s", i, result)
|
||||
assert.NotEmpty(t, result, "goroutine %d produced no output", i)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue