fix(api): make SSE drain wait for clients

Drain() now signals SSE clients to disconnect, closes the event stream under the broker lock, and waits for handler goroutines to exit before returning. Added a regression test to verify that active clients disconnect and the broker reaches zero client count.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-01 05:13:44 +00:00
parent 6fc1767d31
commit 10fc9559fa
2 changed files with 101 additions and 14 deletions

63
sse.go
View file

@ -17,14 +17,17 @@ import (
// subscribe to a specific channel via the ?channel= query parameter.
type SSEBroker struct {
mu sync.RWMutex
wg sync.WaitGroup
clients map[*sseClient]struct{}
}
// sseClient represents a single connected SSE consumer.
type sseClient struct {
channel string
events chan sseEvent
done chan struct{}
channel string
events chan sseEvent
done chan struct{}
doneOnce sync.Once
eventsOnce sync.Once
}
// sseEvent is an internal representation of a single SSE message.
@ -60,6 +63,11 @@ func (b *SSEBroker) Publish(channel, event string, data any) {
for client := range b.clients {
// Send to clients on the matching channel, or clients with no channel filter.
if client.channel == "" || client.channel == channel {
select {
case <-client.done:
continue
default:
}
select {
case client.events <- msg:
case <-client.done:
@ -85,13 +93,15 @@ func (b *SSEBroker) Handler() gin.HandlerFunc {
b.mu.Lock()
b.clients[client] = struct{}{}
b.wg.Add(1)
b.mu.Unlock()
defer func() {
close(client.done)
b.mu.Lock()
client.signalDone()
delete(b.clients, client)
b.mu.Unlock()
b.wg.Done()
}()
// Set SSE headers.
@ -108,7 +118,20 @@ func (b *SSEBroker) Handler() gin.HandlerFunc {
select {
case <-ctx.Done():
return
case evt := <-client.events:
case <-client.done:
return
default:
}
select {
case <-ctx.Done():
return
case <-client.done:
return
case evt, ok := <-client.events:
if !ok {
return
}
_, err := fmt.Fprintf(c.Writer, "event: %s\ndata: %s\n\n", evt.Event, evt.Data)
if err != nil {
return
@ -129,17 +152,29 @@ func (b *SSEBroker) ClientCount() int {
return len(b.clients)
}
// Drain closes all connected clients by writing an empty response.
// Useful for graceful shutdown.
// Drain signals all connected clients to disconnect and waits for their
// handler goroutines to exit. Useful for graceful shutdown.
func (b *SSEBroker) Drain() {
b.mu.Lock()
defer b.mu.Unlock()
for client := range b.clients {
select {
case <-client.done:
default:
// Write EOF to trigger client disconnect via their event loop.
close(client.events)
}
client.signalDone()
client.closeEvents()
}
b.mu.Unlock()
b.wg.Wait()
}
// signalDone closes the client done channel once.
func (c *sseClient) signalDone() {
c.doneOnce.Do(func() {
close(c.done)
})
}
// closeEvents closes the client event channel once.
func (c *sseClient) closeEvents() {
c.eventsOnce.Do(func() {
close(c.events)
})
}

View file

@ -269,6 +269,58 @@ func TestWithSSE_Good_MultipleClients(t *testing.T) {
wg.Wait()
}
func TestWithSSE_Good_DrainDisconnectsClients(t *testing.T) {
gin.SetMode(gin.TestMode)
broker := api.NewSSEBroker()
e, err := api.New(api.WithSSE(broker))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
srv := httptest.NewServer(e.Handler())
defer srv.Close()
resp, err := http.Get(srv.URL + "/events")
if err != nil {
t.Fatalf("request failed: %v", err)
}
waitForClients(t, broker, 1)
streamDone := make(chan struct{})
go func() {
defer close(streamDone)
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
}
}()
drainDone := make(chan struct{})
go func() {
broker.Drain()
close(drainDone)
}()
select {
case <-drainDone:
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for SSE drain to complete")
}
select {
case <-streamDone:
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for SSE client to disconnect")
}
if got := broker.ClientCount(); got != 0 {
t.Fatalf("expected 0 connected SSE clients after drain, got %d", got)
}
_ = resp.Body.Close()
}
// ── No SSE broker ────────────────────────────────────────────────────────
func TestNoSSEBroker_Good(t *testing.T) {