feat: detect server crash before Generate/Chat calls

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Claude 2026-02-19 21:34:46 +00:00
parent d963cbf787
commit 2c4966e652
No known key found for this signature in database
GPG key ID: AF404715446AEB41
3 changed files with 75 additions and 0 deletions

View file

@ -4,6 +4,7 @@ package rocm
import (
"context"
"fmt"
"iter"
"sync"
@ -22,6 +23,13 @@ type rocmModel struct {
// Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
if !m.srv.alive() {
m.mu.Lock()
m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
m.mu.Unlock()
return func(yield func(inference.Token) bool) {}
}
cfg := inference.ApplyGenerateOpts(opts)
req := llamacpp.CompletionRequest{
@ -51,6 +59,13 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
// Chat streams tokens from a multi-turn conversation via llama-server's /v1/chat/completions endpoint.
func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
if !m.srv.alive() {
m.mu.Lock()
m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
m.mu.Unlock()
return func(yield func(inference.Token) bool) {}
}
cfg := inference.ApplyGenerateOpts(opts)
chatMsgs := make([]llamacpp.ChatMessage, len(messages))

View file

@ -25,6 +25,16 @@ type server struct {
exitErr error
}
// alive reports whether the llama-server process is still running.
// It does a non-blocking check of the exited channel, which is closed
// once the child process terminates.
func (s *server) alive() bool {
	// A receive on a closed channel completes immediately; the default
	// case fires while the process is still running.
	select {
	case <-s.exited:
		return false
	default:
	}
	return true
}
// findLlamaServer locates the llama-server binary.
// Checks ROCM_LLAMA_SERVER_PATH first, then PATH.
func findLlamaServer() (string, error) {

View file

@ -3,10 +3,13 @@
package rocm
import (
"context"
"fmt"
"os"
"strings"
"testing"
"forge.lthn.ai/core/go-inference"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -95,3 +98,50 @@ func TestGuessModelType(t *testing.T) {
})
}
}
// TestServerAlive_Running checks that a server whose exited channel is
// still open reports itself as alive.
func TestServerAlive_Running(t *testing.T) {
	srv := &server{exited: make(chan struct{})}
	assert.True(t, srv.alive())
}
// TestServerAlive_Exited checks that a server whose exited channel has
// been closed reports itself as dead.
func TestServerAlive_Exited(t *testing.T) {
	done := make(chan struct{})
	close(done)
	srv := &server{exited: done, exitErr: fmt.Errorf("process killed")}
	assert.False(t, srv.alive())
}
func TestGenerate_ServerDead(t *testing.T) {
exited := make(chan struct{})
close(exited)
s := &server{
exited: exited,
exitErr: fmt.Errorf("process killed"),
}
m := &rocmModel{srv: s}
var count int
for range m.Generate(context.Background(), "hello") {
count++
}
assert.Equal(t, 0, count)
assert.ErrorContains(t, m.Err(), "server has exited")
}
func TestChat_ServerDead(t *testing.T) {
exited := make(chan struct{})
close(exited)
s := &server{
exited: exited,
exitErr: fmt.Errorf("process killed"),
}
m := &rocmModel{srv: s}
msgs := []inference.Message{{Role: "user", Content: "hello"}}
var count int
for range m.Chat(context.Background(), msgs) {
count++
}
assert.Equal(t, 0, count)
assert.ErrorContains(t, m.Err(), "server has exited")
}