feat: retry port selection in startServer on process failure
Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
c07f37afe9
commit
c50a8e9e9b
3 changed files with 53 additions and 39 deletions
|
|
@ -3,7 +3,6 @@
|
||||||
package rocm
|
package rocm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
@ -37,12 +36,7 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
port, err := freePort()
|
srv, err := startServer(binary, path, cfg.GPULayers, cfg.ContextLen)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("rocm: find free port: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv, err := startServer(binary, path, port, cfg.GPULayers, cfg.ContextLen)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
76
server.go
76
server.go
|
|
@ -79,51 +79,63 @@ func serverEnv() []string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// startServer spawns llama-server and waits for it to become ready.
|
// startServer spawns llama-server and waits for it to become ready.
|
||||||
func startServer(binary, modelPath string, port, gpuLayers, ctxSize int) (*server, error) {
|
// It selects a free port automatically, retrying up to 3 times if the
|
||||||
|
// process exits during startup (e.g. port conflict).
|
||||||
|
func startServer(binary, modelPath string, gpuLayers, ctxSize int) (*server, error) {
|
||||||
if gpuLayers < 0 {
|
if gpuLayers < 0 {
|
||||||
gpuLayers = 999
|
gpuLayers = 999
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []string{
|
const maxAttempts = 3
|
||||||
"--model", modelPath,
|
var lastErr error
|
||||||
"--host", "127.0.0.1",
|
|
||||||
"--port", strconv.Itoa(port),
|
|
||||||
"--n-gpu-layers", strconv.Itoa(gpuLayers),
|
|
||||||
}
|
|
||||||
if ctxSize > 0 {
|
|
||||||
args = append(args, "--ctx-size", strconv.Itoa(ctxSize))
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command(binary, args...)
|
for attempt := range maxAttempts {
|
||||||
cmd.Env = serverEnv()
|
port, err := freePort()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("rocm: find free port: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := cmd.Start(); err != nil {
|
args := []string{
|
||||||
return nil, fmt.Errorf("start llama-server: %w", err)
|
"--model", modelPath,
|
||||||
}
|
"--host", "127.0.0.1",
|
||||||
|
"--port", strconv.Itoa(port),
|
||||||
|
"--n-gpu-layers", strconv.Itoa(gpuLayers),
|
||||||
|
}
|
||||||
|
if ctxSize > 0 {
|
||||||
|
args = append(args, "--ctx-size", strconv.Itoa(ctxSize))
|
||||||
|
}
|
||||||
|
|
||||||
s := &server{
|
cmd := exec.Command(binary, args...)
|
||||||
cmd: cmd,
|
cmd.Env = serverEnv()
|
||||||
port: port,
|
|
||||||
client: llamacpp.NewClient(fmt.Sprintf("http://127.0.0.1:%d", port)),
|
|
||||||
exited: make(chan struct{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Goroutine to detect process exit.
|
if err := cmd.Start(); err != nil {
|
||||||
go func() {
|
return nil, fmt.Errorf("start llama-server: %w", err)
|
||||||
s.exitErr = cmd.Wait()
|
}
|
||||||
close(s.exited)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Wait for the health endpoint with a 60s timeout.
|
s := &server{
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
cmd: cmd,
|
||||||
defer cancel()
|
port: port,
|
||||||
|
client: llamacpp.NewClient(fmt.Sprintf("http://127.0.0.1:%d", port)),
|
||||||
|
exited: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
s.exitErr = cmd.Wait()
|
||||||
|
close(s.exited)
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||||
|
err = s.waitReady(ctx)
|
||||||
|
cancel()
|
||||||
|
if err == nil {
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.waitReady(ctx); err != nil {
|
|
||||||
_ = s.stop()
|
_ = s.stop()
|
||||||
return nil, fmt.Errorf("llama-server not ready: %w", err)
|
lastErr = fmt.Errorf("attempt %d: %w", attempt+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s, nil
|
return nil, fmt.Errorf("rocm: server failed after %d attempts: %w", maxAttempts, lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// waitReady polls the health endpoint until the server is ready.
|
// waitReady polls the health endpoint until the server is ready.
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,14 @@ func TestGenerate_ServerDead(t *testing.T) {
|
||||||
assert.ErrorContains(t, m.Err(), "server has exited")
|
assert.ErrorContains(t, m.Err(), "server has exited")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStartServer_RetriesOnProcessExit(t *testing.T) {
|
||||||
|
// /bin/false starts successfully but exits immediately with code 1.
|
||||||
|
// startServer should retry up to 3 times, then fail.
|
||||||
|
_, err := startServer("/bin/false", "/nonexistent/model.gguf", 999, 0)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed after 3 attempts")
|
||||||
|
}
|
||||||
|
|
||||||
func TestChat_ServerDead(t *testing.T) {
|
func TestChat_ServerDead(t *testing.T) {
|
||||||
exited := make(chan struct{})
|
exited := make(chan struct{})
|
||||||
close(exited)
|
close(exited)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue