From e1f5b0ff4087c81d0f844f7f36dbda51e75c34df Mon Sep 17 00:00:00 2001 From: Virgil Date: Sat, 4 Apr 2026 07:00:45 +0000 Subject: [PATCH] fix(process): harden health server snapshots Co-authored-by: Virgil --- health.go | 35 +++++++++++++++++++++++++---------- health_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 10 deletions(-) diff --git a/health.go b/health.go index 0cd54ed..a5a2ca0 100644 --- a/health.go +++ b/health.go @@ -21,7 +21,7 @@ type HealthServer struct { addr string server *http.Server listener net.Listener - mu sync.Mutex + mu sync.RWMutex ready bool checks []HealthCheck } @@ -68,8 +68,8 @@ func (h *HealthServer) SetReady(ready bool) { // // publish the service // } func (h *HealthServer) Ready() bool { - h.mu.Lock() - defer h.mu.Unlock() + h.mu.RLock() + defer h.mu.RUnlock() return h.ready } @@ -82,11 +82,12 @@ func (h *HealthServer) Start() error { mux := http.NewServeMux() mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { - h.mu.Lock() - checks := h.checks - h.mu.Unlock() + checks := h.checksSnapshot() for _, check := range checks { + if check == nil { + continue + } if err := check(); err != nil { w.WriteHeader(http.StatusServiceUnavailable) _, _ = fmt.Fprintf(w, "unhealthy: %v\n", err) @@ -99,9 +100,9 @@ func (h *HealthServer) Start() error { }) mux.HandleFunc("/ready", func(w http.ResponseWriter, r *http.Request) { - h.mu.Lock() + h.mu.RLock() ready := h.ready - h.mu.Unlock() + h.mu.RUnlock() if !ready { w.WriteHeader(http.StatusServiceUnavailable) @@ -131,6 +132,20 @@ func (h *HealthServer) Start() error { return nil } +// checksSnapshot returns a stable copy of the registered health checks. +func (h *HealthServer) checksSnapshot() []HealthCheck { + h.mu.RLock() + defer h.mu.RUnlock() + + if len(h.checks) == 0 { + return nil + } + + checks := make([]HealthCheck, len(h.checks)) + copy(checks, h.checks) + return checks +} + // Stop gracefully shuts down the health server. // // Example: @@ -156,8 +171,8 @@ func (h *HealthServer) Stop(ctx context.Context) error { // // addr := server.Addr() func (h *HealthServer) Addr() string { - h.mu.Lock() - defer h.mu.Unlock() + h.mu.RLock() + defer h.mu.RUnlock() if h.listener != nil { return h.listener.Addr().String() } diff --git a/health_test.go b/health_test.go index 386b2ed..faf9b3b 100644 --- a/health_test.go +++ b/health_test.go @@ -77,6 +77,35 @@ func TestHealthServer_WithChecks(t *testing.T) { _ = resp.Body.Close() } +func TestHealthServer_NilCheckIgnored(t *testing.T) { + hs := NewHealthServer("127.0.0.1:0") + + var check HealthCheck + hs.AddCheck(check) + + err := hs.Start() + require.NoError(t, err) + defer func() { _ = hs.Stop(context.Background()) }() + + addr := hs.Addr() + + resp, err := http.Get("http://" + addr + "/health") + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() +} + +func TestHealthServer_ChecksSnapshotIsStable(t *testing.T) { + hs := NewHealthServer("127.0.0.1:0") + + hs.AddCheck(func() error { return nil }) + snapshot := hs.checksSnapshot() + hs.AddCheck(func() error { return assert.AnError }) + + require.Len(t, snapshot, 1) + require.NotNil(t, snapshot[0]) +} + func TestWaitForHealth_Reachable(t *testing.T) { hs := NewHealthServer("127.0.0.1:0") require.NoError(t, hs.Start())