From 12496ba57cff0090edcf735acc4fb88cc5327b34 Mon Sep 17 00:00:00 2001 From: Virgil Date: Thu, 2 Apr 2026 04:27:42 +0000 Subject: [PATCH] feat(cli): add external daemon stop helper Co-Authored-By: Virgil --- docs/pkg/cli/daemon.md | 5 ++ pkg/cli/daemon_process.go | 95 ++++++++++++++++++++++++++++ pkg/cli/daemon_process_test.go | 109 +++++++++++++++++++++++++++++++++ 3 files changed, 209 insertions(+) diff --git a/docs/pkg/cli/daemon.md b/docs/pkg/cli/daemon.md index c1409da..236d872 100644 --- a/docs/pkg/cli/daemon.md +++ b/docs/pkg/cli/daemon.md @@ -70,6 +70,11 @@ defer func() { `Start()` writes the current process ID to the configured file, and `Stop()` removes it after shutting the probe server down. +If you need to stop a daemon process from outside its own process tree, use +`cli.StopPIDFile(pidFile, timeout)`. It sends `SIGTERM`, waits up to the +timeout for exit, escalates to `SIGKILL` if needed, and removes the PID file +after the process stops. + ## Shutdown with Timeout The daemon stop logic sends SIGTERM and waits up to 30 seconds. If the process has not exited by then, it sends SIGKILL and removes the PID file. 
diff --git a/pkg/cli/daemon_process.go b/pkg/cli/daemon_process.go
index 0ec9a7d..f4d8aed 100644
--- a/pkg/cli/daemon_process.go
+++ b/pkg/cli/daemon_process.go
@@ -3,13 +3,16 @@ package cli
 import (
 	"context"
 	"errors"
+	"fmt"
 	"io"
 	"net"
 	"net/http"
 	"os"
 	"path/filepath"
 	"strconv"
+	"strings"
 	"sync"
+	"syscall"
 	"time"
 )
 
@@ -48,6 +51,36 @@ type Daemon struct {
 	started bool
 }
 
+// Process-control seams. These are package-level variables so tests can
+// replace the clock, the sleep, and signal delivery with fakes.
+var (
+	processNow   = time.Now
+	processSleep = time.Sleep
+	// processAlive reports whether pid still exists by sending signal 0.
+	// EPERM counts as alive: the process exists but may not be signalled.
+	processAlive = func(pid int) bool {
+		proc, err := os.FindProcess(pid)
+		if err != nil {
+			return false
+		}
+		err = proc.Signal(syscall.Signal(0))
+		return err == nil || errors.Is(err, syscall.EPERM)
+	}
+	// processSignal delivers sig to pid.
+	processSignal = func(pid int, sig syscall.Signal) error {
+		proc, err := os.FindProcess(pid)
+		if err != nil {
+			return err
+		}
+		return proc.Signal(sig)
+	}
+	// processPollInterval is the delay between liveness checks.
+	processPollInterval = 100 * time.Millisecond
+	// processShutdownWait is the default SIGTERM grace period and the wait
+	// applied after SIGKILL.
+	processShutdownWait = 30 * time.Second
+)
+
 // NewDaemon creates a daemon helper with sensible defaults.
 func NewDaemon(opts DaemonOptions) *Daemon {
 	if opts.HealthPath == "" {
@@ -135,6 +168,91 @@ func (d *Daemon) HealthAddr() string {
 	return d.opts.HealthAddr
 }
 
+// StopPIDFile stops the process whose ID is recorded in pidFile.
+//
+// It sends SIGTERM, polls for up to timeout for the process to exit, sends
+// SIGKILL if it is still alive, polls again for up to processShutdownWait,
+// and then removes pidFile. A non-positive timeout falls back to
+// processShutdownWait.
+//
+// StopPIDFile returns nil if pidFile is empty or does not exist. If the file
+// cannot be read or parsed, a signal cannot be delivered for a reason other
+// than the process already being gone, or the process survives SIGKILL, it
+// returns an error and leaves pidFile in place.
+func StopPIDFile(pidFile string, timeout time.Duration) error {
+	if pidFile == "" {
+		return nil
+	}
+	if timeout <= 0 {
+		timeout = processShutdownWait
+	}
+
+	rawPID, err := os.ReadFile(pidFile)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil
+		}
+		return err
+	}
+
+	pid, err := parsePID(strings.TrimSpace(string(rawPID)))
+	if err != nil {
+		return fmt.Errorf("parse pid file %q: %w", pidFile, err)
+	}
+
+	if err := processSignal(pid, syscall.SIGTERM); err != nil && !isProcessGone(err) {
+		return err
+	}
+
+	if !waitForExit(pid, timeout) {
+		if err := processSignal(pid, syscall.SIGKILL); err != nil && !isProcessGone(err) {
+			return err
+		}
+		if !waitForExit(pid, processShutdownWait) {
+			return fmt.Errorf("process %d did not exit after SIGKILL", pid)
+		}
+	}
+
+	// Daemon.Stop removes the PID file itself, so a daemon that shut down
+	// cleanly may already have deleted it; that is not a failure.
+	if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) {
+		return err
+	}
+	return nil
+}
+
+// waitForExit polls processAlive every processPollInterval until the process
+// is gone or wait has elapsed. It reports whether the process has exited.
+func waitForExit(pid int, wait time.Duration) bool {
+	deadline := processNow().Add(wait)
+	for processAlive(pid) && processNow().Before(deadline) {
+		processSleep(processPollInterval)
+	}
+	return !processAlive(pid)
+}
+
+// parsePID converts the trimmed contents of a PID file into a positive
+// process ID.
+func parsePID(raw string) (int, error) {
+	if raw == "" {
+		return 0, fmt.Errorf("empty pid")
+	}
+	pid, err := strconv.Atoi(raw)
+	if err != nil {
+		return 0, err
+	}
+	if pid <= 0 {
+		return 0, fmt.Errorf("invalid pid %d", pid)
+	}
+	return pid, nil
+}
+
+// isProcessGone reports whether err from a signal attempt means the target
+// process no longer exists.
+func isProcessGone(err error) bool {
+	return errors.Is(err, os.ErrProcessDone) || errors.Is(err, syscall.ESRCH)
+}
+
 func (d *Daemon) writePIDFile() error {
 	if d.opts.PIDFile == "" {
 		return nil
diff --git a/pkg/cli/daemon_process_test.go b/pkg/cli/daemon_process_test.go
index 1e5c8fc..511b884 100644
--- a/pkg/cli/daemon_process_test.go
+++ b/pkg/cli/daemon_process_test.go
@@ -8,6 +8,8 @@ import (
 	"path/filepath"
 	"strconv"
 	"strings"
+	"sync"
+	"syscall"
 	"testing"
 	"time"
 
@@ -88,3 +90,122 @@ func TestDaemon_StopRemovesPIDFile(t *testing.T) {
 	require.Error(t, err)
 	assert.True(t, os.IsNotExist(err))
 }
+
+// restoreProcessSeams snapshots the package-level process seams and restores
+// them when the test finishes, so tests can overwrite them freely.
+func restoreProcessSeams(t *testing.T) {
+	t.Helper()
+
+	originalSignal := processSignal
+	originalAlive := processAlive
+	originalNow := processNow
+	originalSleep := processSleep
+	originalPoll := processPollInterval
+	originalShutdownWait := processShutdownWait
+	t.Cleanup(func() {
+		processSignal = originalSignal
+		processAlive = originalAlive
+		processNow = originalNow
+		processSleep = originalSleep
+		processPollInterval = originalPoll
+		processShutdownWait = originalShutdownWait
+	})
+}
+
+func TestStopPIDFile_Good(t *testing.T) {
+	tmp := t.TempDir()
+	pidFile := filepath.Join(tmp, "daemon.pid")
+	require.NoError(t, os.WriteFile(pidFile, []byte("1234\n"), 0o644))
+	restoreProcessSeams(t)
+
+	var mu sync.Mutex
+	var signals []syscall.Signal
+	processSignal = func(pid int, sig syscall.Signal) error {
+		mu.Lock()
+		signals = append(signals, sig)
+		mu.Unlock()
+		return nil
+	}
+	processAlive = func(pid int) bool {
+		mu.Lock()
+		defer mu.Unlock()
+		if len(signals) == 0 {
+			return true
+		}
+		return signals[len(signals)-1] != syscall.SIGTERM
+	}
+	processPollInterval = 0
+	processShutdownWait = 0
+
+	require.NoError(t, StopPIDFile(pidFile, time.Second))
+
+	mu.Lock()
+	defer mu.Unlock()
+	require.Equal(t, []syscall.Signal{syscall.SIGTERM}, signals)
+
+	_, err := os.Stat(pidFile)
+	require.Error(t, err)
+	assert.True(t, os.IsNotExist(err))
+}
+
+func TestStopPIDFile_Good_PIDFileAlreadyRemoved(t *testing.T) {
+	tmp := t.TempDir()
+	pidFile := filepath.Join(tmp, "daemon.pid")
+	require.NoError(t, os.WriteFile(pidFile, []byte("1234\n"), 0o644))
+	restoreProcessSeams(t)
+
+	// Simulate a daemon that deletes its own PID file while handling SIGTERM.
+	processSignal = func(pid int, sig syscall.Signal) error {
+		return os.Remove(pidFile)
+	}
+	processAlive = func(pid int) bool { return false }
+
+	require.NoError(t, StopPIDFile(pidFile, time.Second))
+}
+
+func TestStopPIDFile_Bad_Escalates(t *testing.T) {
+	tmp := t.TempDir()
+	pidFile := filepath.Join(tmp, "daemon.pid")
+	require.NoError(t, os.WriteFile(pidFile, []byte("4321\n"), 0o644))
+	restoreProcessSeams(t)
+
+	var mu sync.Mutex
+	var signals []syscall.Signal
+	current := time.Unix(0, 0)
+	processNow = func() time.Time {
+		mu.Lock()
+		defer mu.Unlock()
+		return current
+	}
+	processSleep = func(d time.Duration) {
+		mu.Lock()
+		current = current.Add(d)
+		mu.Unlock()
+	}
+	processSignal = func(pid int, sig syscall.Signal) error {
+		mu.Lock()
+		signals = append(signals, sig)
+		mu.Unlock()
+		return nil
+	}
+	processAlive = func(pid int) bool {
+		mu.Lock()
+		defer mu.Unlock()
+		if len(signals) == 0 {
+			return true
+		}
+		return signals[len(signals)-1] != syscall.SIGKILL
+	}
+	processPollInterval = 10 * time.Millisecond
+	processShutdownWait = 0
+
+	require.NoError(t, StopPIDFile(pidFile, 15*time.Millisecond))
+
+	mu.Lock()
+	defer mu.Unlock()
+	require.Equal(t, []syscall.Signal{syscall.SIGTERM, syscall.SIGKILL}, signals)
+
+	_, err := os.Stat(pidFile)
+	require.Error(t, err)
+	assert.True(t, os.IsNotExist(err))
+}