feat: detect server crash before Generate/Chat calls
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
d963cbf787
commit
2c4966e652
3 changed files with 75 additions and 0 deletions
15
model.go
15
model.go
|
|
@ -4,6 +4,7 @@ package rocm
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"iter"
|
||||
"sync"
|
||||
|
||||
|
|
@ -22,6 +23,13 @@ type rocmModel struct {
|
|||
|
||||
// Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
|
||||
func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
|
||||
if !m.srv.alive() {
|
||||
m.mu.Lock()
|
||||
m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
|
||||
m.mu.Unlock()
|
||||
return func(yield func(inference.Token) bool) {}
|
||||
}
|
||||
|
||||
cfg := inference.ApplyGenerateOpts(opts)
|
||||
|
||||
req := llamacpp.CompletionRequest{
|
||||
|
|
@ -51,6 +59,13 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
|
|||
|
||||
// Chat streams tokens from a multi-turn conversation via llama-server's /v1/chat/completions endpoint.
|
||||
func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
|
||||
if !m.srv.alive() {
|
||||
m.mu.Lock()
|
||||
m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
|
||||
m.mu.Unlock()
|
||||
return func(yield func(inference.Token) bool) {}
|
||||
}
|
||||
|
||||
cfg := inference.ApplyGenerateOpts(opts)
|
||||
|
||||
chatMsgs := make([]llamacpp.ChatMessage, len(messages))
|
||||
|
|
|
|||
10
server.go
10
server.go
|
|
@ -25,6 +25,16 @@ type server struct {
|
|||
exitErr error
|
||||
}
|
||||
|
||||
// alive reports whether the llama-server process is still running.
|
||||
func (s *server) alive() bool {
|
||||
select {
|
||||
case <-s.exited:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// findLlamaServer locates the llama-server binary.
|
||||
// Checks ROCM_LLAMA_SERVER_PATH first, then PATH.
|
||||
func findLlamaServer() (string, error) {
|
||||
|
|
|
|||
|
|
@ -3,10 +3,13 @@
|
|||
package rocm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
|
@ -95,3 +98,50 @@ func TestGuessModelType(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerAlive_Running(t *testing.T) {
|
||||
s := &server{exited: make(chan struct{})}
|
||||
assert.True(t, s.alive())
|
||||
}
|
||||
|
||||
func TestServerAlive_Exited(t *testing.T) {
|
||||
exited := make(chan struct{})
|
||||
close(exited)
|
||||
s := &server{exited: exited, exitErr: fmt.Errorf("process killed")}
|
||||
assert.False(t, s.alive())
|
||||
}
|
||||
|
||||
func TestGenerate_ServerDead(t *testing.T) {
|
||||
exited := make(chan struct{})
|
||||
close(exited)
|
||||
s := &server{
|
||||
exited: exited,
|
||||
exitErr: fmt.Errorf("process killed"),
|
||||
}
|
||||
m := &rocmModel{srv: s}
|
||||
|
||||
var count int
|
||||
for range m.Generate(context.Background(), "hello") {
|
||||
count++
|
||||
}
|
||||
assert.Equal(t, 0, count)
|
||||
assert.ErrorContains(t, m.Err(), "server has exited")
|
||||
}
|
||||
|
||||
func TestChat_ServerDead(t *testing.T) {
|
||||
exited := make(chan struct{})
|
||||
close(exited)
|
||||
s := &server{
|
||||
exited: exited,
|
||||
exitErr: fmt.Errorf("process killed"),
|
||||
}
|
||||
m := &rocmModel{srv: s}
|
||||
|
||||
msgs := []inference.Message{{Role: "user", Content: "hello"}}
|
||||
var count int
|
||||
for range m.Chat(context.Background(), msgs) {
|
||||
count++
|
||||
}
|
||||
assert.Equal(t, 0, count)
|
||||
assert.ErrorContains(t, m.Err(), "server has exited")
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue