From 10fc9559fa58d9a4ca3e3311661b7fab688cc1c3 Mon Sep 17 00:00:00 2001 From: Virgil Date: Wed, 1 Apr 2026 05:13:44 +0000 Subject: [PATCH] fix(api): make SSE drain wait for clients Drain() now signals SSE clients to disconnect, closes the event stream under the broker lock, and waits for handler goroutines to exit before returning. Added a regression test to verify that active clients disconnect and the broker reaches zero client count. Co-Authored-By: Virgil --- sse.go | 63 +++++++++++++++++++++++++++++++++++++++++------------ sse_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 14 deletions(-) diff --git a/sse.go b/sse.go index 9adf7ee..a1e4d02 100644 --- a/sse.go +++ b/sse.go @@ -17,14 +17,17 @@ import ( // subscribe to a specific channel via the ?channel= query parameter. type SSEBroker struct { mu sync.RWMutex + wg sync.WaitGroup clients map[*sseClient]struct{} } // sseClient represents a single connected SSE consumer. type sseClient struct { - channel string - events chan sseEvent - done chan struct{} + channel string + events chan sseEvent + done chan struct{} + doneOnce sync.Once + eventsOnce sync.Once } // sseEvent is an internal representation of a single SSE message. @@ -60,6 +63,11 @@ func (b *SSEBroker) Publish(channel, event string, data any) { for client := range b.clients { // Send to clients on the matching channel, or clients with no channel filter. if client.channel == "" || client.channel == channel { + select { + case <-client.done: + continue + default: + } select { case client.events <- msg: case <-client.done: @@ -85,13 +93,15 @@ func (b *SSEBroker) Handler() gin.HandlerFunc { b.mu.Lock() b.clients[client] = struct{}{} + b.wg.Add(1) b.mu.Unlock() defer func() { - close(client.done) b.mu.Lock() + client.signalDone() delete(b.clients, client) b.mu.Unlock() + b.wg.Done() }() // Set SSE headers. @@ -108,7 +118,20 @@ func (b *SSEBroker) Handler() gin.HandlerFunc { select { case <-ctx.Done(): return - case evt := <-client.events: + case <-client.done: + return + default: + } + + select { + case <-ctx.Done(): + return + case <-client.done: + return + case evt, ok := <-client.events: + if !ok { + return + } _, err := fmt.Fprintf(c.Writer, "event: %s\ndata: %s\n\n", evt.Event, evt.Data) if err != nil { return @@ -129,17 +152,29 @@ func (b *SSEBroker) ClientCount() int { return len(b.clients) } -// Drain closes all connected clients by writing an empty response. -// Useful for graceful shutdown. +// Drain signals all connected clients to disconnect and waits for their +// handler goroutines to exit. Useful for graceful shutdown. func (b *SSEBroker) Drain() { b.mu.Lock() - defer b.mu.Unlock() for client := range b.clients { - select { - case <-client.done: - default: - // Write EOF to trigger client disconnect via their event loop. - close(client.events) - } + client.signalDone() + client.closeEvents() } + b.mu.Unlock() + + b.wg.Wait() +} + +// signalDone closes the client done channel once. +func (c *sseClient) signalDone() { + c.doneOnce.Do(func() { + close(c.done) + }) +} + +// closeEvents closes the client event channel once. +func (c *sseClient) closeEvents() { + c.eventsOnce.Do(func() { + close(c.events) + }) } diff --git a/sse_test.go b/sse_test.go index 7467b38..a2769e3 100644 --- a/sse_test.go +++ b/sse_test.go @@ -269,6 +269,58 @@ func TestWithSSE_Good_MultipleClients(t *testing.T) { wg.Wait() } +func TestWithSSE_Good_DrainDisconnectsClients(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New(api.WithSSE(broker)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/events") + if err != nil { + t.Fatalf("request failed: %v", err) + } + + waitForClients(t, broker, 1) + + streamDone := make(chan struct{}) + go func() { + defer close(streamDone) + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + } + }() + + drainDone := make(chan struct{}) + go func() { + broker.Drain() + close(drainDone) + }() + + select { + case <-drainDone: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for SSE drain to complete") + } + + select { + case <-streamDone: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for SSE client to disconnect") + } + + if got := broker.ClientCount(); got != 0 { + t.Fatalf("expected 0 connected SSE clients after drain, got %d", got) + } + + _ = resp.Body.Close() +} + // ── No SSE broker ──────────────────────────────────────────────────────── func TestNoSSEBroker_Good(t *testing.T) {