From 8fc307b4546d57103e595a222e5592f495e239c2 Mon Sep 17 00:00:00 2001 From: Snider Date: Sat, 21 Feb 2026 00:01:04 +0000 Subject: [PATCH] feat: add WithSSE server-sent events support Add SSEBroker type that manages SSE client connections and broadcasts events to subscribed clients. Clients connect via GET /events and optionally filter by channel with ?channel=. Includes Publish, ClientCount, and Drain methods for event broadcasting and lifecycle management. Co-Authored-By: Virgil --- api.go | 6 + options.go | 10 ++ sse.go | 146 +++++++++++++++++++++++++ sse_test.go | 308 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 470 insertions(+) create mode 100644 sse.go create mode 100644 sse_test.go diff --git a/api.go b/api.go index 63e9f4e..27556d4 100644 --- a/api.go +++ b/api.go @@ -25,6 +25,7 @@ type Engine struct { groups []RouteGroup middlewares []gin.HandlerFunc wsHandler http.Handler + sseBroker *SSEBroker swaggerEnabled bool swaggerTitle string swaggerDesc string @@ -134,6 +135,11 @@ func (e *Engine) build() *gin.Engine { r.GET("/ws", wrapWSHandler(e.wsHandler)) } + // Mount SSE endpoint if configured. + if e.sseBroker != nil { + r.GET("/events", e.sseBroker.Handler()) + } + // Mount Swagger UI if enabled. if e.swaggerEnabled { registerSwagger(r, e.swaggerTitle, e.swaggerDesc, e.swaggerVersion) diff --git a/options.go b/options.go index 3b0fba0..1a8d922 100644 --- a/options.go +++ b/options.go @@ -253,3 +253,13 @@ func WithHTTPSign(secrets httpsign.Secrets, opts ...httpsign.Option) Option { e.middlewares = append(e.middlewares, auth.Authenticated()) } } + +// WithSSE registers a Server-Sent Events broker at GET /events. +// Clients connect to the endpoint and receive a streaming text/event-stream +// response. The broker manages client connections and broadcasts events +// published via its Publish method. +func WithSSE(broker *SSEBroker) Option { + return func(e *Engine) { + e.sseBroker = broker + } +} diff --git a/sse.go b/sse.go new file mode 100644 index 0000000..c3701c1 --- /dev/null +++ b/sse.go @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "sync" + + "github.com/gin-gonic/gin" +) + +// SSEBroker manages Server-Sent Events connections and broadcasts events +// to subscribed clients. Clients connect via a GET endpoint and receive +// a streaming text/event-stream response. Each client may optionally +// subscribe to a specific channel via the ?channel= query parameter. +type SSEBroker struct { + mu sync.RWMutex + clients map[*sseClient]struct{} +} + +// sseClient represents a single connected SSE consumer. +type sseClient struct { + channel string + events chan sseEvent + done chan struct{} +} + +// sseEvent is an internal representation of a single SSE message. +type sseEvent struct { + Event string + Data string +} + +// NewSSEBroker creates a ready-to-use SSE broker. +func NewSSEBroker() *SSEBroker { + return &SSEBroker{ + clients: make(map[*sseClient]struct{}), + } +} + +// Publish sends an event to all clients subscribed to the given channel. +// Clients subscribed to an empty channel (no ?channel= param) receive +// events on every channel. The data value is JSON-encoded before sending. +func (b *SSEBroker) Publish(channel, event string, data any) { + encoded, err := json.Marshal(data) + if err != nil { + return + } + + msg := sseEvent{ + Event: event, + Data: string(encoded), + } + + b.mu.RLock() + defer b.mu.RUnlock() + + for client := range b.clients { + // Send to clients on the matching channel, or clients with no channel filter. + if client.channel == "" || client.channel == channel { + select { + case client.events <- msg: + case <-client.done: + default: + // Drop event if client buffer is full. + } + } + } +} + +// Handler returns a Gin handler for the SSE endpoint. Clients connect with +// a GET request and receive events as text/event-stream. An optional +// ?channel= query parameter subscribes the client to a specific channel. +func (b *SSEBroker) Handler() gin.HandlerFunc { + return func(c *gin.Context) { + channel := c.Query("channel") + + client := &sseClient{ + channel: channel, + events: make(chan sseEvent, 64), + done: make(chan struct{}), + } + + b.mu.Lock() + b.clients[client] = struct{}{} + b.mu.Unlock() + + defer func() { + close(client.done) + b.mu.Lock() + delete(b.clients, client) + b.mu.Unlock() + }() + + // Set SSE headers. + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Status(http.StatusOK) + c.Writer.Flush() + + // Stream events until client disconnects. + ctx := c.Request.Context() + for { + select { + case <-ctx.Done(): + return + case evt := <-client.events: + _, err := fmt.Fprintf(c.Writer, "event: %s\ndata: %s\n\n", evt.Event, evt.Data) + if err != nil { + return + } + // Flush to ensure the event is sent immediately. + if f, ok := c.Writer.(http.Flusher); ok { + f.Flush() + } + } + } + } +} + +// ClientCount returns the number of currently connected SSE clients. +func (b *SSEBroker) ClientCount() int { + b.mu.RLock() + defer b.mu.RUnlock() + return len(b.clients) +} + +// Drain closes all connected clients by writing an empty response. +// Useful for graceful shutdown. +func (b *SSEBroker) Drain() { + b.mu.Lock() + defer b.mu.Unlock() + for client := range b.clients { + select { + case <-client.done: + default: + // Write EOF to trigger client disconnect via their event loop. + close(client.events) + } + } +} + diff --git a/sse_test.go b/sse_test.go new file mode 100644 index 0000000..3cba62a --- /dev/null +++ b/sse_test.go @@ -0,0 +1,308 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "bufio" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/gin-gonic/gin" + + api "forge.lthn.ai/core/go-api" +) + +// ── SSE endpoint ──────────────────────────────────────────────────────── + +func TestWithSSE_Good_EndpointExists(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New(api.WithSSE(broker)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/events") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + ct := resp.Header.Get("Content-Type") + if !strings.HasPrefix(ct, "text/event-stream") { + t.Fatalf("expected Content-Type starting with text/event-stream, got %q", ct) + } +} + +func TestWithSSE_Good_ReceivesPublishedEvent(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New(api.WithSSE(broker)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/events") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + // Wait for the client to register before publishing. + waitForClients(t, broker, 1) + + // Publish an event on the default channel. + broker.Publish("test", "greeting", map[string]string{"msg": "hello"}) + + // Read SSE lines from the response body. + scanner := bufio.NewScanner(resp.Body) + var eventLine, dataLine string + + deadline := time.After(3 * time.Second) + done := make(chan struct{}) + + go func() { + defer close(done) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "event: ") { + eventLine = strings.TrimPrefix(line, "event: ") + } + if strings.HasPrefix(line, "data: ") { + dataLine = strings.TrimPrefix(line, "data: ") + return + } + } + }() + + select { + case <-done: + case <-deadline: + t.Fatal("timed out waiting for SSE event") + } + + if eventLine != "greeting" { + t.Fatalf("expected event=%q, got %q", "greeting", eventLine) + } + if !strings.Contains(dataLine, `"msg":"hello"`) { + t.Fatalf("expected data containing msg:hello, got %q", dataLine) + } +} + +func TestWithSSE_Good_ChannelFiltering(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New(api.WithSSE(broker)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + // Subscribe to channel "foo" only. + resp, err := http.Get(srv.URL + "/events?channel=foo") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + // Wait for client to register. + waitForClients(t, broker, 1) + + // Publish to "bar" (should not be received), then to "foo" (should be received). + broker.Publish("bar", "ignore", "bar-data") + // Small delay to ensure ordering. + time.Sleep(50 * time.Millisecond) + broker.Publish("foo", "match", "foo-data") + + // Read the first event from the stream. + scanner := bufio.NewScanner(resp.Body) + var eventLine string + + deadline := time.After(3 * time.Second) + done := make(chan struct{}) + + go func() { + defer close(done) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "event: ") { + eventLine = strings.TrimPrefix(line, "event: ") + // Read past the data and blank line. + scanner.Scan() // data line + return + } + } + }() + + select { + case <-done: + case <-deadline: + t.Fatal("timed out waiting for SSE event") + } + + // The first event received should be "match" (from channel foo), not "ignore" (from bar). + if eventLine != "match" { + t.Fatalf("expected event=%q, got %q (channel filtering failed)", "match", eventLine) + } +} + +func TestWithSSE_Good_CombinesWithOtherMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New( + api.WithRequestID(), + api.WithSSE(broker), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/events") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + // RequestID middleware should have injected the header. + reqID := resp.Header.Get("X-Request-ID") + if reqID == "" { + t.Fatal("expected X-Request-ID header from RequestID middleware") + } + + ct := resp.Header.Get("Content-Type") + if !strings.HasPrefix(ct, "text/event-stream") { + t.Fatalf("expected Content-Type starting with text/event-stream, got %q", ct) + } +} + +func TestWithSSE_Good_MultipleClients(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New(api.WithSSE(broker)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + // Connect two clients. + resp1, err := http.Get(srv.URL + "/events") + if err != nil { + t.Fatalf("client 1 request failed: %v", err) + } + defer resp1.Body.Close() + + resp2, err := http.Get(srv.URL + "/events") + if err != nil { + t.Fatalf("client 2 request failed: %v", err) + } + defer resp2.Body.Close() + + // Wait for both clients to register. + waitForClients(t, broker, 2) + + // Publish a single event. + broker.Publish("broadcast", "ping", "pong") + + // Both clients should receive it. + var wg sync.WaitGroup + wg.Add(2) + + readEvent := func(name string, resp *http.Response) { + defer wg.Done() + scanner := bufio.NewScanner(resp.Body) + deadline := time.After(3 * time.Second) + done := make(chan string, 1) + + go func() { + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "event: ") { + done <- strings.TrimPrefix(line, "event: ") + return + } + } + }() + + select { + case evt := <-done: + if evt != "ping" { + t.Errorf("%s: expected event=%q, got %q", name, "ping", evt) + } + case <-deadline: + t.Errorf("%s: timed out waiting for SSE event", name) + } + } + + go readEvent("client1", resp1) + go readEvent("client2", resp2) + + wg.Wait() +} + +// ── No SSE broker ──────────────────────────────────────────────────────── + +func TestNoSSEBroker_Good(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Without WithSSE, GET /events should return 404. + e, _ := api.New() + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/events", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("expected 404 for /events without broker, got %d", w.Code) + } +} + +// ── Helpers ────────────────────────────────────────────────────────────── + +// waitForClients polls the broker until the expected number of clients +// are connected or the timeout expires. +func waitForClients(t *testing.T, broker *api.SSEBroker, want int) { + t.Helper() + deadline := time.After(2 * time.Second) + for { + if broker.ClientCount() >= want { + return + } + select { + case <-deadline: + t.Fatalf("timed out waiting for %d SSE clients (have %d)", want, broker.ClientCount()) + default: + time.Sleep(10 * time.Millisecond) + } + } +}