diff --git a/pkg/api/provider.go b/pkg/api/provider.go index 7717458..1e5442e 100644 --- a/pkg/api/provider.go +++ b/pkg/api/provider.go @@ -293,7 +293,7 @@ func (p *ProcessProvider) Describe() []api.RouteDescription { Method: "POST", Path: "/processes/:id/kill", Summary: "Kill a managed process", - Description: "Sends SIGKILL to the managed process identified by ID.", + Description: "Sends SIGKILL to the managed process identified by ID, or to a raw OS PID when the path value is numeric.", Tags: []string{"process"}, Response: map[string]any{ "type": "object", @@ -306,7 +306,7 @@ func (p *ProcessProvider) Describe() []api.RouteDescription { Method: "POST", Path: "/processes/:id/signal", Summary: "Signal a managed process", - Description: "Sends a Unix signal to the managed process identified by ID.", + Description: "Sends a Unix signal to the managed process identified by ID, or to a raw OS PID when the path value is numeric.", Tags: []string{"process"}, RequestBody: map[string]any{ "type": "object", @@ -578,7 +578,16 @@ func (p *ProcessProvider) killProcess(c *gin.Context) { return } - if err := p.service.Kill(c.Param("id")); err != nil { + id := c.Param("id") + if err := p.service.Kill(id); err != nil { + if pid, ok := pidFromString(id); ok { + if pidErr := p.service.KillPID(pid); pidErr == nil { + c.JSON(http.StatusOK, api.OK(map[string]any{"killed": true})) + return + } else { + err = pidErr + } + } status := http.StatusInternalServerError if err == process.ErrProcessNotFound { status = http.StatusNotFound @@ -612,7 +621,16 @@ func (p *ProcessProvider) signalProcess(c *gin.Context) { return } - if err := p.service.Signal(c.Param("id"), sig); err != nil { + id := c.Param("id") + if err := p.service.Signal(id, sig); err != nil { + if pid, ok := pidFromString(id); ok { + if pidErr := p.service.SignalPID(pid, sig); pidErr == nil { + c.JSON(http.StatusOK, api.OK(map[string]any{"signalled": true})) + return + } else { + err = pidErr + } + } status := http.StatusInternalServerError if err == process.ErrProcessNotFound || err == process.ErrProcessNotRunning { status = http.StatusNotFound @@ -723,6 +741,14 @@ func intParam(c *gin.Context, name string) int { return v } +func pidFromString(value string) (int, bool) { + pid, err := strconv.Atoi(strings.TrimSpace(value)) + if err != nil || pid <= 0 { + return 0, false + } + return pid, true +} + func parseSignal(value string) (syscall.Signal, error) { trimmed := strings.TrimSpace(strings.ToUpper(value)) if trimmed == "" { diff --git a/pkg/api/provider_test.go b/pkg/api/provider_test.go index 165910b..53cae02 100644 --- a/pkg/api/provider_test.go +++ b/pkg/api/provider_test.go @@ -8,6 +8,8 @@ import ( "net/http" "net/http/httptest" "os" + "os/exec" + "strconv" "strings" "testing" "time" @@ -474,6 +476,50 @@ func TestProcessProvider_KillProcess_Good(t *testing.T) { assert.Equal(t, process.StatusKilled, proc.Status) } +func TestProcessProvider_KillProcess_ByPID_Good(t *testing.T) { + svc := newTestProcessService(t) + p := processapi.NewProvider(nil, svc, nil) + r := setupRouter(p) + + cmd := exec.Command("sleep", "60") + require.NoError(t, cmd.Start()) + + waitCh := make(chan error, 1) + go func() { + waitCh <- cmd.Wait() + }() + + t.Cleanup(func() { + if cmd.ProcessState == nil && cmd.Process != nil { + _ = cmd.Process.Kill() + } + select { + case <-waitCh: + case <-time.After(2 * time.Second): + } + }) + + w := httptest.NewRecorder() + req, err := http.NewRequest("POST", "/api/process/processes/"+strconv.Itoa(cmd.Process.Pid)+"/kill", nil) + require.NoError(t, err) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var resp goapi.Response[map[string]any] + err = json.Unmarshal(w.Body.Bytes(), &resp) + require.NoError(t, err) + require.True(t, resp.Success) + assert.Equal(t, true, resp.Data["killed"]) + + select { + case err := <-waitCh: + require.Error(t, err) + case <-time.After(5 * time.Second): + t.Fatal("unmanaged process should have been killed by PID") + } +} + func TestProcessProvider_SignalProcess_Good(t *testing.T) { svc := newTestProcessService(t) proc, err := svc.Start(context.Background(), "sleep", "60") @@ -504,6 +550,51 @@ func TestProcessProvider_SignalProcess_Good(t *testing.T) { assert.Equal(t, process.StatusKilled, proc.Status) } +func TestProcessProvider_SignalProcess_ByPID_Good(t *testing.T) { + svc := newTestProcessService(t) + p := processapi.NewProvider(nil, svc, nil) + r := setupRouter(p) + + cmd := exec.Command("sleep", "60") + require.NoError(t, cmd.Start()) + + waitCh := make(chan error, 1) + go func() { + waitCh <- cmd.Wait() + }() + + t.Cleanup(func() { + if cmd.ProcessState == nil && cmd.Process != nil { + _ = cmd.Process.Kill() + } + select { + case <-waitCh: + case <-time.After(2 * time.Second): + } + }) + + w := httptest.NewRecorder() + req, err := http.NewRequest("POST", "/api/process/processes/"+strconv.Itoa(cmd.Process.Pid)+"/signal", strings.NewReader(`{"signal":"SIGTERM"}`)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var resp goapi.Response[map[string]any] + err = json.Unmarshal(w.Body.Bytes(), &resp) + require.NoError(t, err) + require.True(t, resp.Success) + assert.Equal(t, true, resp.Data["signalled"]) + + select { + case err := <-waitCh: + require.Error(t, err) + case <-time.After(5 * time.Second): + t.Fatal("unmanaged process should have been signalled by PID") + } +} + func TestProcessProvider_SignalProcess_InvalidSignal_Bad(t *testing.T) { svc := newTestProcessService(t) proc, err := svc.Start(context.Background(), "sleep", "60")