diff --git a/backend.go b/backend.go index b32c6b6..27b354a 100644 --- a/backend.go +++ b/backend.go @@ -49,7 +49,7 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe ctxLen = int(min(meta.ContextLength, 4096)) } - srv, err := startServer(binary, path, cfg.GPULayers, ctxLen) + srv, err := startServer(binary, path, cfg.GPULayers, ctxLen, cfg.ParallelSlots) if err != nil { return nil, err } diff --git a/server.go b/server.go index 4947df2..c5b8b7b 100644 --- a/server.go +++ b/server.go @@ -81,7 +81,7 @@ func serverEnv() []string { // startServer spawns llama-server and waits for it to become ready. // It selects a free port automatically, retrying up to 3 times if the // process exits during startup (e.g. port conflict). -func startServer(binary, modelPath string, gpuLayers, ctxSize int) (*server, error) { +func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int) (*server, error) { if gpuLayers < 0 { gpuLayers = 999 } @@ -104,6 +104,9 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize int) (*server, err if ctxSize > 0 { args = append(args, "--ctx-size", strconv.Itoa(ctxSize)) } + if parallelSlots > 0 { + args = append(args, "--parallel", strconv.Itoa(parallelSlots)) + } cmd := exec.Command(binary, args...) cmd.Env = serverEnv() diff --git a/server_test.go b/server_test.go index 5d73045..75a157f 100644 --- a/server_test.go +++ b/server_test.go @@ -114,7 +114,7 @@ func TestGenerate_ServerDead(t *testing.T) { func TestStartServer_RetriesOnProcessExit(t *testing.T) { // /bin/false starts successfully but exits immediately with code 1. // startServer should retry up to 3 times, then fail. - _, err := startServer("/bin/false", "/nonexistent/model.gguf", 999, 0) + _, err := startServer("/bin/false", "/nonexistent/model.gguf", 999, 0, 0) require.Error(t, err) assert.Contains(t, err.Error(), "failed after 3 attempts") }