diff --git a/sse.go b/sse.go index 9adf7ee..a1e4d02 100644 --- a/sse.go +++ b/sse.go @@ -17,14 +17,17 @@ import ( // subscribe to a specific channel via the ?channel= query parameter. type SSEBroker struct { mu sync.RWMutex + wg sync.WaitGroup clients map[*sseClient]struct{} } // sseClient represents a single connected SSE consumer. type sseClient struct { - channel string - events chan sseEvent - done chan struct{} + channel string + events chan sseEvent + done chan struct{} + doneOnce sync.Once + eventsOnce sync.Once } // sseEvent is an internal representation of a single SSE message. @@ -60,6 +63,11 @@ func (b *SSEBroker) Publish(channel, event string, data any) { for client := range b.clients { // Send to clients on the matching channel, or clients with no channel filter. if client.channel == "" || client.channel == channel { + select { + case <-client.done: + continue + default: + } select { case client.events <- msg: case <-client.done: @@ -85,13 +93,15 @@ func (b *SSEBroker) Handler() gin.HandlerFunc { b.mu.Lock() b.clients[client] = struct{}{} + b.wg.Add(1) b.mu.Unlock() defer func() { - close(client.done) b.mu.Lock() + client.signalDone() delete(b.clients, client) b.mu.Unlock() + b.wg.Done() }() // Set SSE headers. @@ -108,7 +118,20 @@ func (b *SSEBroker) Handler() gin.HandlerFunc { select { case <-ctx.Done(): return - case evt := <-client.events: + case <-client.done: + return + default: + } + + select { + case <-ctx.Done(): + return + case <-client.done: + return + case evt, ok := <-client.events: + if !ok { + return + } _, err := fmt.Fprintf(c.Writer, "event: %s\ndata: %s\n\n", evt.Event, evt.Data) if err != nil { return @@ -129,17 +152,29 @@ func (b *SSEBroker) ClientCount() int { return len(b.clients) } -// Drain closes all connected clients by writing an empty response. -// Useful for graceful shutdown. +// Drain signals all connected clients to disconnect and waits for their +// handler goroutines to exit. Useful for graceful shutdown. func (b *SSEBroker) Drain() { b.mu.Lock() - defer b.mu.Unlock() for client := range b.clients { - select { - case <-client.done: - default: - // Write EOF to trigger client disconnect via their event loop. - close(client.events) - } + client.signalDone() + client.closeEvents() } + b.mu.Unlock() + + b.wg.Wait() +} + +// signalDone closes the client done channel once. +func (c *sseClient) signalDone() { + c.doneOnce.Do(func() { + close(c.done) + }) +} + +// closeEvents closes the client event channel once. +func (c *sseClient) closeEvents() { + c.eventsOnce.Do(func() { + close(c.events) + }) } diff --git a/sse_test.go b/sse_test.go index 7467b38..a2769e3 100644 --- a/sse_test.go +++ b/sse_test.go @@ -269,6 +269,58 @@ func TestWithSSE_Good_MultipleClients(t *testing.T) { wg.Wait() } +func TestWithSSE_Good_DrainDisconnectsClients(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New(api.WithSSE(broker)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/events") + if err != nil { + t.Fatalf("request failed: %v", err) + } + + waitForClients(t, broker, 1) + + streamDone := make(chan struct{}) + go func() { + defer close(streamDone) + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + } + }() + + drainDone := make(chan struct{}) + go func() { + broker.Drain() + close(drainDone) + }() + + select { + case <-drainDone: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for SSE drain to complete") + } + + select { + case <-streamDone: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for SSE client to disconnect") + } + + if got := broker.ClientCount(); got != 0 { + t.Fatalf("expected 0 connected SSE clients after drain, got %d", got) + } + + _ = resp.Body.Close() +} + // ── No SSE broker ──────────────────────────────────────────────────────── func TestNoSSEBroker_Good(t *testing.T) {