fix(api): make SSE drain wait for clients
Drain() now signals SSE clients to disconnect, closes the event stream under the broker lock, and waits for handler goroutines to exit before returning. Added a regression test to verify that active clients disconnect and the broker reaches zero client count. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
6fc1767d31
commit
10fc9559fa
2 changed files with 101 additions and 14 deletions
63
sse.go
63
sse.go
|
|
@ -17,14 +17,17 @@ import (
|
|||
// subscribe to a specific channel via the ?channel= query parameter.
|
||||
type SSEBroker struct {
|
||||
mu sync.RWMutex
|
||||
wg sync.WaitGroup
|
||||
clients map[*sseClient]struct{}
|
||||
}
|
||||
|
||||
// sseClient represents a single connected SSE consumer.
|
||||
type sseClient struct {
|
||||
channel string
|
||||
events chan sseEvent
|
||||
done chan struct{}
|
||||
channel string
|
||||
events chan sseEvent
|
||||
done chan struct{}
|
||||
doneOnce sync.Once
|
||||
eventsOnce sync.Once
|
||||
}
|
||||
|
||||
// sseEvent is an internal representation of a single SSE message.
|
||||
|
|
@ -60,6 +63,11 @@ func (b *SSEBroker) Publish(channel, event string, data any) {
|
|||
for client := range b.clients {
|
||||
// Send to clients on the matching channel, or clients with no channel filter.
|
||||
if client.channel == "" || client.channel == channel {
|
||||
select {
|
||||
case <-client.done:
|
||||
continue
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case client.events <- msg:
|
||||
case <-client.done:
|
||||
|
|
@ -85,13 +93,15 @@ func (b *SSEBroker) Handler() gin.HandlerFunc {
|
|||
|
||||
b.mu.Lock()
|
||||
b.clients[client] = struct{}{}
|
||||
b.wg.Add(1)
|
||||
b.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
close(client.done)
|
||||
b.mu.Lock()
|
||||
client.signalDone()
|
||||
delete(b.clients, client)
|
||||
b.mu.Unlock()
|
||||
b.wg.Done()
|
||||
}()
|
||||
|
||||
// Set SSE headers.
|
||||
|
|
@ -108,7 +118,20 @@ func (b *SSEBroker) Handler() gin.HandlerFunc {
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case evt := <-client.events:
|
||||
case <-client.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-client.done:
|
||||
return
|
||||
case evt, ok := <-client.events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
_, err := fmt.Fprintf(c.Writer, "event: %s\ndata: %s\n\n", evt.Event, evt.Data)
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -129,17 +152,29 @@ func (b *SSEBroker) ClientCount() int {
|
|||
return len(b.clients)
|
||||
}
|
||||
|
||||
// Drain closes all connected clients by writing an empty response.
|
||||
// Useful for graceful shutdown.
|
||||
// Drain signals all connected clients to disconnect and waits for their
|
||||
// handler goroutines to exit. Useful for graceful shutdown.
|
||||
func (b *SSEBroker) Drain() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
for client := range b.clients {
|
||||
select {
|
||||
case <-client.done:
|
||||
default:
|
||||
// Write EOF to trigger client disconnect via their event loop.
|
||||
close(client.events)
|
||||
}
|
||||
client.signalDone()
|
||||
client.closeEvents()
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
b.wg.Wait()
|
||||
}
|
||||
|
||||
// signalDone closes the client done channel once.
|
||||
func (c *sseClient) signalDone() {
|
||||
c.doneOnce.Do(func() {
|
||||
close(c.done)
|
||||
})
|
||||
}
|
||||
|
||||
// closeEvents closes the client event channel once.
|
||||
func (c *sseClient) closeEvents() {
|
||||
c.eventsOnce.Do(func() {
|
||||
close(c.events)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
52
sse_test.go
52
sse_test.go
|
|
@ -269,6 +269,58 @@ func TestWithSSE_Good_MultipleClients(t *testing.T) {
|
|||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestWithSSE_Good_DrainDisconnectsClients(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
broker := api.NewSSEBroker()
|
||||
e, err := api.New(api.WithSSE(broker))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(e.Handler())
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := http.Get(srv.URL + "/events")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
|
||||
waitForClients(t, broker, 1)
|
||||
|
||||
streamDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(streamDone)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
}
|
||||
}()
|
||||
|
||||
drainDone := make(chan struct{})
|
||||
go func() {
|
||||
broker.Drain()
|
||||
close(drainDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-drainDone:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timed out waiting for SSE drain to complete")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-streamDone:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timed out waiting for SSE client to disconnect")
|
||||
}
|
||||
|
||||
if got := broker.ClientCount(); got != 0 {
|
||||
t.Fatalf("expected 0 connected SSE clients after drain, got %d", got)
|
||||
}
|
||||
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
// ── No SSE broker ────────────────────────────────────────────────────────
|
||||
|
||||
func TestNoSSEBroker_Good(t *testing.T) {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue