feat(cli): add external daemon stop helper
All checks were successful
Security Scan / security (push) Successful in 19s
All checks were successful
Security Scan / security (push) Successful in 19s
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
a2f27b9af4
commit
12496ba57c
3 changed files with 209 additions and 0 deletions
|
|
@ -70,6 +70,11 @@ defer func() {
|
||||||
`Start()` writes the current process ID to the configured file, and `Stop()`
|
`Start()` writes the current process ID to the configured file, and `Stop()`
|
||||||
removes it after shutting the probe server down.
|
removes it after shutting the probe server down.
|
||||||
|
|
||||||
|
If you need to stop a daemon process from outside its own process tree, use
|
||||||
|
`cli.StopPIDFile(pidFile, timeout)`. It sends `SIGTERM`, waits up to the
|
||||||
|
timeout for exit, escalates to `SIGKILL` if needed, and removes the PID file
|
||||||
|
after the process stops.
|
||||||
|
|
||||||
## Shutdown with Timeout
|
## Shutdown with Timeout
|
||||||
|
|
||||||
The daemon stop logic sends SIGTERM and waits up to 30 seconds. If the process has not exited by then, it sends SIGKILL and removes the PID file.
|
The daemon stop logic sends SIGTERM and waits up to 30 seconds. If the process has not exited by then, it sends SIGKILL and removes the PID file.
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,16 @@ package cli
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -48,6 +51,28 @@ type Daemon struct {
|
||||||
started bool
|
started bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process-control hooks used by StopPIDFile. They are package-level variables
// rather than direct calls so that tests can replace the clock, the sleep
// function, and signal delivery with deterministic fakes.
var (
	// processPollInterval is the pause between two liveness probes.
	processPollInterval = 100 * time.Millisecond

	// processShutdownWait is the default grace period applied when the
	// caller does not supply a positive timeout; it also bounds the wait
	// that follows a SIGKILL.
	processShutdownWait = 30 * time.Second

	// processNow reports the current time.
	processNow = time.Now

	// processSleep blocks for the given duration between probes.
	processSleep = time.Sleep

	// processSignal delivers sig to the process identified by pid and
	// returns any lookup or delivery error unchanged.
	processSignal = func(pid int, sig syscall.Signal) error {
		target, findErr := os.FindProcess(pid)
		if findErr == nil {
			return target.Signal(sig)
		}
		return findErr
	}

	// processAlive reports whether pid still refers to a process. It probes
	// with signal 0; a probe rejected with EPERM is still counted as alive,
	// while any other failure is counted as gone.
	processAlive = func(pid int) bool {
		target, findErr := os.FindProcess(pid)
		if findErr != nil {
			return false
		}
		probeErr := target.Signal(syscall.Signal(0))
		if probeErr == nil {
			return true
		}
		return errors.Is(probeErr, syscall.EPERM)
	}
)
|
||||||
|
|
||||||
// NewDaemon creates a daemon helper with sensible defaults.
|
// NewDaemon creates a daemon helper with sensible defaults.
|
||||||
func NewDaemon(opts DaemonOptions) *Daemon {
|
func NewDaemon(opts DaemonOptions) *Daemon {
|
||||||
if opts.HealthPath == "" {
|
if opts.HealthPath == "" {
|
||||||
|
|
@ -135,6 +160,76 @@ func (d *Daemon) HealthAddr() string {
|
||||||
return d.opts.HealthAddr
|
return d.opts.HealthAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StopPIDFile sends SIGTERM to the process identified by pidFile, waits for it
|
||||||
|
// to exit, escalates to SIGKILL after the timeout, and then removes the file.
|
||||||
|
//
|
||||||
|
// If the PID file does not exist, StopPIDFile returns nil.
|
||||||
|
func StopPIDFile(pidFile string, timeout time.Duration) error {
|
||||||
|
if pidFile == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = processShutdownWait
|
||||||
|
}
|
||||||
|
|
||||||
|
rawPID, err := os.ReadFile(pidFile)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pid, err := parsePID(strings.TrimSpace(string(rawPID)))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse pid file %q: %w", pidFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := processSignal(pid, syscall.SIGTERM); err != nil && !isProcessGone(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := processNow().Add(timeout)
|
||||||
|
for processAlive(pid) && processNow().Before(deadline) {
|
||||||
|
processSleep(processPollInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
if processAlive(pid) {
|
||||||
|
if err := processSignal(pid, syscall.SIGKILL); err != nil && !isProcessGone(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline = processNow().Add(processShutdownWait)
|
||||||
|
for processAlive(pid) && processNow().Before(deadline) {
|
||||||
|
processSleep(processPollInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
if processAlive(pid) {
|
||||||
|
return fmt.Errorf("process %d did not exit after SIGKILL", pid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.Remove(pidFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePID converts the textual contents of a PID file into a process ID.
//
// raw is parsed as-is (callers trim surrounding whitespace first). An error is
// returned when raw is empty, is not a base-10 integer, does not fit in 32
// bits, or is zero or negative.
func parsePID(raw string) (int, error) {
	if raw == "" {
		return 0, errors.New("empty pid")
	}

	// Parse with a 32-bit bound rather than the platform int size. Unix
	// process IDs are 32-bit, so a larger value could be truncated when it
	// is handed to the kernel and end up addressing a different target
	// (4294967295 would become -1, which kill(2) treats as "every process").
	parsed, err := strconv.ParseInt(raw, 10, 32)
	if err != nil {
		return 0, err
	}
	if parsed <= 0 {
		return 0, fmt.Errorf("invalid pid %d", parsed)
	}
	return int(parsed), nil
}
|
||||||
|
|
||||||
|
// isProcessGone reports whether err (or any error it wraps) indicates that the
// signalled process no longer exists: either os.ErrProcessDone or ESRCH.
func isProcessGone(err error) bool {
	switch {
	case errors.Is(err, os.ErrProcessDone):
		return true
	case errors.Is(err, syscall.ESRCH):
		return true
	default:
		return false
	}
}
|
||||||
|
|
||||||
func (d *Daemon) writePIDFile() error {
|
func (d *Daemon) writePIDFile() error {
|
||||||
if d.opts.PIDFile == "" {
|
if d.opts.PIDFile == "" {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,8 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -88,3 +90,110 @@ func TestDaemon_StopRemovesPIDFile(t *testing.T) {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.True(t, os.IsNotExist(err))
|
assert.True(t, os.IsNotExist(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStopPIDFile_Good(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
pidFile := filepath.Join(tmp, "daemon.pid")
|
||||||
|
require.NoError(t, os.WriteFile(pidFile, []byte("1234\n"), 0o644))
|
||||||
|
|
||||||
|
originalSignal := processSignal
|
||||||
|
originalAlive := processAlive
|
||||||
|
originalNow := processNow
|
||||||
|
originalSleep := processSleep
|
||||||
|
originalPoll := processPollInterval
|
||||||
|
originalShutdownWait := processShutdownWait
|
||||||
|
t.Cleanup(func() {
|
||||||
|
processSignal = originalSignal
|
||||||
|
processAlive = originalAlive
|
||||||
|
processNow = originalNow
|
||||||
|
processSleep = originalSleep
|
||||||
|
processPollInterval = originalPoll
|
||||||
|
processShutdownWait = originalShutdownWait
|
||||||
|
})
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var signals []syscall.Signal
|
||||||
|
processSignal = func(pid int, sig syscall.Signal) error {
|
||||||
|
mu.Lock()
|
||||||
|
signals = append(signals, sig)
|
||||||
|
mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
processAlive = func(pid int) bool {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if len(signals) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return signals[len(signals)-1] != syscall.SIGTERM
|
||||||
|
}
|
||||||
|
processPollInterval = 0
|
||||||
|
processShutdownWait = 0
|
||||||
|
|
||||||
|
require.NoError(t, StopPIDFile(pidFile, time.Second))
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
require.Equal(t, []syscall.Signal{syscall.SIGTERM}, signals)
|
||||||
|
|
||||||
|
_, err := os.Stat(pidFile)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, os.IsNotExist(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStopPIDFile_Bad_Escalates(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
pidFile := filepath.Join(tmp, "daemon.pid")
|
||||||
|
require.NoError(t, os.WriteFile(pidFile, []byte("4321\n"), 0o644))
|
||||||
|
|
||||||
|
originalSignal := processSignal
|
||||||
|
originalAlive := processAlive
|
||||||
|
originalNow := processNow
|
||||||
|
originalSleep := processSleep
|
||||||
|
originalPoll := processPollInterval
|
||||||
|
originalShutdownWait := processShutdownWait
|
||||||
|
t.Cleanup(func() {
|
||||||
|
processSignal = originalSignal
|
||||||
|
processAlive = originalAlive
|
||||||
|
processNow = originalNow
|
||||||
|
processSleep = originalSleep
|
||||||
|
processPollInterval = originalPoll
|
||||||
|
processShutdownWait = originalShutdownWait
|
||||||
|
})
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var signals []syscall.Signal
|
||||||
|
current := time.Unix(0, 0)
|
||||||
|
processNow = func() time.Time {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
return current
|
||||||
|
}
|
||||||
|
processSleep = func(d time.Duration) {
|
||||||
|
mu.Lock()
|
||||||
|
current = current.Add(d)
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
processSignal = func(pid int, sig syscall.Signal) error {
|
||||||
|
mu.Lock()
|
||||||
|
signals = append(signals, sig)
|
||||||
|
mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
processAlive = func(pid int) bool {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if len(signals) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return signals[len(signals)-1] != syscall.SIGKILL
|
||||||
|
}
|
||||||
|
processPollInterval = 10 * time.Millisecond
|
||||||
|
processShutdownWait = 0
|
||||||
|
|
||||||
|
require.NoError(t, StopPIDFile(pidFile, 15*time.Millisecond))
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
require.Equal(t, []syscall.Signal{syscall.SIGTERM, syscall.SIGKILL}, signals)
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue