diff --git a/docs/pkg/cli/daemon.md b/docs/pkg/cli/daemon.md index 05c16d9..c1409da 100644 --- a/docs/pkg/cli/daemon.md +++ b/docs/pkg/cli/daemon.md @@ -42,6 +42,34 @@ func runDaemon(cmd *cli.Command, args []string) error { } ``` +## Daemon Helper + +Use `cli.NewDaemon()` when you want a helper that writes a PID file and serves +basic `/health` and `/ready` probes: + +```go +daemon := cli.NewDaemon(cli.DaemonOptions{ + PIDFile: "/tmp/core.pid", + HealthAddr: "127.0.0.1:8080", + HealthCheck: func() bool { + return true + }, + ReadyCheck: func() bool { + return true + }, +}) + +if err := daemon.Start(context.Background()); err != nil { + return err +} +defer func() { + _ = daemon.Stop(context.Background()) +}() +``` + +`Start()` writes the current process ID to the configured file, and `Stop()` +removes it after shutting the probe server down. + ## Shutdown with Timeout The daemon stop logic sends SIGTERM and waits up to 30 seconds. If the process has not exited by then, it sends SIGKILL and removes the PID file. diff --git a/docs/pkg/cli/index.md b/docs/pkg/cli/index.md index b1ed2fa..882fd5d 100644 --- a/docs/pkg/cli/index.md +++ b/docs/pkg/cli/index.md @@ -52,6 +52,7 @@ The framework has three layers: | `TreeNode` | Tree structure with box-drawing connectors | | `TaskTracker` | Concurrent task display with live spinners | | `CheckBuilder` | Fluent API for pass/fail/skip result lines | +| `Daemon` | PID file and probe helper for background processes | | `AnsiStyle` | Terminal text styling (bold, dim, colour) | ## Built-in Services diff --git a/pkg/cli/daemon_process.go b/pkg/cli/daemon_process.go new file mode 100644 index 0000000..0ec9a7d --- /dev/null +++ b/pkg/cli/daemon_process.go @@ -0,0 +1,219 @@ +package cli + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "os" + "path/filepath" + "strconv" + "sync" + "time" +) + +// DaemonOptions configures a background process helper. +type DaemonOptions struct { + // PIDFile stores the current process ID on Start and removes it on Stop. + PIDFile string + + // HealthAddr binds the HTTP health server. + // Pass an empty string to disable the server. + HealthAddr string + + // HealthPath serves the liveness probe endpoint. + HealthPath string + + // ReadyPath serves the readiness probe endpoint. + ReadyPath string + + // HealthCheck reports whether the process is healthy. + // Defaults to true when nil. + HealthCheck func() bool + + // ReadyCheck reports whether the process is ready to serve traffic. + // Defaults to HealthCheck when nil, or true when both are nil. + ReadyCheck func() bool +} + +// Daemon manages a PID file and optional HTTP health endpoints. +type Daemon struct { + opts DaemonOptions + + mu sync.Mutex + listener net.Listener + server *http.Server + addr string + started bool +} + +// NewDaemon creates a daemon helper with sensible defaults. +func NewDaemon(opts DaemonOptions) *Daemon { + if opts.HealthPath == "" { + opts.HealthPath = "/health" + } + if opts.ReadyPath == "" { + opts.ReadyPath = "/ready" + } + return &Daemon{opts: opts} +} + +// Start writes the PID file and starts the health server, if configured. +func (d *Daemon) Start(ctx context.Context) error { + if ctx == nil { + ctx = context.Background() + } + + d.mu.Lock() + defer d.mu.Unlock() + + if d.started { + return nil + } + + if err := d.writePIDFile(); err != nil { + return err + } + + if d.opts.HealthAddr != "" { + if err := d.startHealthServer(ctx); err != nil { + _ = d.removePIDFile() + return err + } + } + + d.started = true + return nil +} + +// Stop shuts down the health server and removes the PID file. +func (d *Daemon) Stop(ctx context.Context) error { + if ctx == nil { + ctx = context.Background() + } + + d.mu.Lock() + server := d.server + listener := d.listener + d.server = nil + d.listener = nil + d.addr = "" + d.started = false + d.mu.Unlock() + + var firstErr error + + if server != nil { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := server.Shutdown(shutdownCtx); err != nil && !isClosedServerError(err) { + firstErr = err + } + } + + if listener != nil { + if err := listener.Close(); err != nil && !isListenerClosedError(err) && firstErr == nil { + firstErr = err + } + } + + if err := d.removePIDFile(); err != nil && firstErr == nil { + firstErr = err + } + + return firstErr +} + +// HealthAddr returns the bound health server address, if running. +func (d *Daemon) HealthAddr() string { + d.mu.Lock() + defer d.mu.Unlock() + if d.addr != "" { + return d.addr + } + return d.opts.HealthAddr +} + +func (d *Daemon) writePIDFile() error { + if d.opts.PIDFile == "" { + return nil + } + + if err := os.MkdirAll(filepath.Dir(d.opts.PIDFile), 0o755); err != nil { + return err + } + return os.WriteFile(d.opts.PIDFile, []byte(strconv.Itoa(os.Getpid())+"\n"), 0o644) +} + +func (d *Daemon) removePIDFile() error { + if d.opts.PIDFile == "" { + return nil + } + if err := os.Remove(d.opts.PIDFile); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} + +func (d *Daemon) startHealthServer(ctx context.Context) error { + mux := http.NewServeMux() + healthCheck := d.opts.HealthCheck + if healthCheck == nil { + healthCheck = func() bool { return true } + } + readyCheck := d.opts.ReadyCheck + if readyCheck == nil { + readyCheck = healthCheck + } + + mux.HandleFunc(d.opts.HealthPath, func(w http.ResponseWriter, r *http.Request) { + writeProbe(w, healthCheck()) + }) + mux.HandleFunc(d.opts.ReadyPath, func(w http.ResponseWriter, r *http.Request) { + writeProbe(w, readyCheck()) + }) + + listener, err := net.Listen("tcp", d.opts.HealthAddr) + if err != nil { + return err + } + + server := &http.Server{ + Handler: mux, + BaseContext: func(net.Listener) context.Context { + return ctx + }, + } + + d.listener = listener + d.server = server + d.addr = listener.Addr().String() + + go func() { + err := server.Serve(listener) + if err != nil && !isClosedServerError(err) { + _ = err + } + }() + + return nil +} + +func writeProbe(w http.ResponseWriter, ok bool) { + if ok { + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "ok\n") + return + } + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = io.WriteString(w, "unhealthy\n") +} + +func isClosedServerError(err error) bool { + return err == nil || err == http.ErrServerClosed +} + +func isListenerClosedError(err error) bool { + return err == nil || errors.Is(err, net.ErrClosed) +} diff --git a/pkg/cli/daemon_process_test.go b/pkg/cli/daemon_process_test.go new file mode 100644 index 0000000..1e5c8fc --- /dev/null +++ b/pkg/cli/daemon_process_test.go @@ -0,0 +1,90 @@ +package cli + +import ( + "context" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDaemon_StartStop(t *testing.T) { + tmp := t.TempDir() + pidFile := filepath.Join(tmp, "daemon.pid") + ready := false + + daemon := NewDaemon(DaemonOptions{ + PIDFile: pidFile, + HealthAddr: "127.0.0.1:0", + HealthCheck: func() bool { + return true + }, + ReadyCheck: func() bool { + return ready + }, + }) + + require.NoError(t, daemon.Start(context.Background())) + defer func() { + require.NoError(t, daemon.Stop(context.Background())) + }() + + rawPID, err := os.ReadFile(pidFile) + require.NoError(t, err) + assert.Equal(t, strconv.Itoa(os.Getpid()), strings.TrimSpace(string(rawPID))) + + addr := daemon.HealthAddr() + require.NotEmpty(t, addr) + + client := &http.Client{Timeout: 2 * time.Second} + + resp, err := client.Get("http://" + addr + "/health") + require.NoError(t, err) + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "ok\n", string(body)) + + resp, err = client.Get("http://" + addr + "/ready") + require.NoError(t, err) + body, err = io.ReadAll(resp.Body) + resp.Body.Close() + require.NoError(t, err) + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Equal(t, "unhealthy\n", string(body)) + + ready = true + + resp, err = client.Get("http://" + addr + "/ready") + require.NoError(t, err) + body, err = io.ReadAll(resp.Body) + resp.Body.Close() + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "ok\n", string(body)) +} + +func TestDaemon_StopRemovesPIDFile(t *testing.T) { + tmp := t.TempDir() + pidFile := filepath.Join(tmp, "daemon.pid") + + daemon := NewDaemon(DaemonOptions{PIDFile: pidFile}) + require.NoError(t, daemon.Start(context.Background())) + + _, err := os.Stat(pidFile) + require.NoError(t, err) + + require.NoError(t, daemon.Stop(context.Background())) + + _, err = os.Stat(pidFile) + require.Error(t, err) + assert.True(t, os.IsNotExist(err)) +}