diff --git a/model.go b/model.go
index bc07d07..366bce3 100644
--- a/model.go
+++ b/model.go
@@ -4,6 +4,7 @@ package rocm
 
 import (
 	"context"
+	"fmt"
 	"iter"
 	"sync"
 
@@ -22,6 +23,13 @@ type rocmModel struct {
 
 // Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
 func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
+	if !m.srv.alive() {
+		m.mu.Lock()
+		m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
+		m.mu.Unlock()
+		return func(yield func(inference.Token) bool) {}
+	}
+
 	cfg := inference.ApplyGenerateOpts(opts)
 
 	req := llamacpp.CompletionRequest{
@@ -51,6 +59,13 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
 
 // Chat streams tokens from a multi-turn conversation via llama-server's /v1/chat/completions endpoint.
 func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
+	if !m.srv.alive() {
+		m.mu.Lock()
+		m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
+		m.mu.Unlock()
+		return func(yield func(inference.Token) bool) {}
+	}
+
 	cfg := inference.ApplyGenerateOpts(opts)
 
 	chatMsgs := make([]llamacpp.ChatMessage, len(messages))
diff --git a/server.go b/server.go
index b1affcd..1a1566e 100644
--- a/server.go
+++ b/server.go
@@ -25,6 +25,16 @@ type server struct {
 	exitErr error
 }
 
+// alive reports whether the llama-server process is still running.
+func (s *server) alive() bool {
+	select {
+	case <-s.exited:
+		return false
+	default:
+		return true
+	}
+}
+
 // findLlamaServer locates the llama-server binary.
 // Checks ROCM_LLAMA_SERVER_PATH first, then PATH.
 func findLlamaServer() (string, error) {
diff --git a/server_test.go b/server_test.go
index 0976322..6a4e074 100644
--- a/server_test.go
+++ b/server_test.go
@@ -3,10 +3,13 @@
 package rocm
 
 import (
+	"context"
+	"fmt"
 	"os"
 	"strings"
 	"testing"
 
+	"forge.lthn.ai/core/go-inference"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -95,3 +98,50 @@ func TestGuessModelType(t *testing.T) {
 		})
 	}
 }
+
+func TestServerAlive_Running(t *testing.T) {
+	s := &server{exited: make(chan struct{})}
+	assert.True(t, s.alive())
+}
+
+func TestServerAlive_Exited(t *testing.T) {
+	exited := make(chan struct{})
+	close(exited)
+	s := &server{exited: exited, exitErr: fmt.Errorf("process killed")}
+	assert.False(t, s.alive())
+}
+
+func TestGenerate_ServerDead(t *testing.T) {
+	exited := make(chan struct{})
+	close(exited)
+	s := &server{
+		exited:  exited,
+		exitErr: fmt.Errorf("process killed"),
+	}
+	m := &rocmModel{srv: s}
+
+	var count int
+	for range m.Generate(context.Background(), "hello") {
+		count++
+	}
+	assert.Equal(t, 0, count)
+	assert.ErrorContains(t, m.Err(), "server has exited")
+}
+
+func TestChat_ServerDead(t *testing.T) {
+	exited := make(chan struct{})
+	close(exited)
+	s := &server{
+		exited:  exited,
+		exitErr: fmt.Errorf("process killed"),
+	}
+	m := &rocmModel{srv: s}
+
+	msgs := []inference.Message{{Role: "user", Content: "hello"}}
+	var count int
+	for range m.Chat(context.Background(), msgs) {
+		count++
+	}
+	assert.Equal(t, 0, count)
+	assert.ErrorContains(t, m.Err(), "server has exited")
+}