feat: add WithSSE server-sent events support
Add SSEBroker type that manages SSE client connections and broadcasts events to subscribed clients. Clients connect via GET /events and optionally filter by channel with ?channel=<name>. Includes Publish, ClientCount, and Drain methods for event broadcasting and lifecycle management. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
7b3f99e421
commit
8fc307b454
4 changed files with 470 additions and 0 deletions
6
api.go
6
api.go
|
|
@ -25,6 +25,7 @@ type Engine struct {
|
|||
groups []RouteGroup
|
||||
middlewares []gin.HandlerFunc
|
||||
wsHandler http.Handler
|
||||
sseBroker *SSEBroker
|
||||
swaggerEnabled bool
|
||||
swaggerTitle string
|
||||
swaggerDesc string
|
||||
|
|
@ -134,6 +135,11 @@ func (e *Engine) build() *gin.Engine {
|
|||
r.GET("/ws", wrapWSHandler(e.wsHandler))
|
||||
}
|
||||
|
||||
// Mount SSE endpoint if configured.
|
||||
if e.sseBroker != nil {
|
||||
r.GET("/events", e.sseBroker.Handler())
|
||||
}
|
||||
|
||||
// Mount Swagger UI if enabled.
|
||||
if e.swaggerEnabled {
|
||||
registerSwagger(r, e.swaggerTitle, e.swaggerDesc, e.swaggerVersion)
|
||||
|
|
|
|||
10
options.go
10
options.go
|
|
@ -253,3 +253,13 @@ func WithHTTPSign(secrets httpsign.Secrets, opts ...httpsign.Option) Option {
|
|||
e.middlewares = append(e.middlewares, auth.Authenticated())
|
||||
}
|
||||
}
|
||||
|
||||
// WithSSE registers a Server-Sent Events broker at GET /events.
|
||||
// Clients connect to the endpoint and receive a streaming text/event-stream
|
||||
// response. The broker manages client connections and broadcasts events
|
||||
// published via its Publish method.
|
||||
func WithSSE(broker *SSEBroker) Option {
|
||||
return func(e *Engine) {
|
||||
e.sseBroker = broker
|
||||
}
|
||||
}
|
||||
|
|
|
|||
146
sse.go
Normal file
146
sse.go
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SSEBroker manages Server-Sent Events connections and broadcasts events
|
||||
// to subscribed clients. Clients connect via a GET endpoint and receive
|
||||
// a streaming text/event-stream response. Each client may optionally
|
||||
// subscribe to a specific channel via the ?channel= query parameter.
|
||||
type SSEBroker struct {
|
||||
mu sync.RWMutex
|
||||
clients map[*sseClient]struct{}
|
||||
}
|
||||
|
||||
// sseClient represents a single connected SSE consumer.
|
||||
type sseClient struct {
|
||||
channel string
|
||||
events chan sseEvent
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// sseEvent is an internal representation of a single SSE message.
|
||||
type sseEvent struct {
|
||||
Event string
|
||||
Data string
|
||||
}
|
||||
|
||||
// NewSSEBroker creates a ready-to-use SSE broker.
|
||||
func NewSSEBroker() *SSEBroker {
|
||||
return &SSEBroker{
|
||||
clients: make(map[*sseClient]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Publish sends an event to all clients subscribed to the given channel.
|
||||
// Clients subscribed to an empty channel (no ?channel= param) receive
|
||||
// events on every channel. The data value is JSON-encoded before sending.
|
||||
func (b *SSEBroker) Publish(channel, event string, data any) {
|
||||
encoded, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg := sseEvent{
|
||||
Event: event,
|
||||
Data: string(encoded),
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
for client := range b.clients {
|
||||
// Send to clients on the matching channel, or clients with no channel filter.
|
||||
if client.channel == "" || client.channel == channel {
|
||||
select {
|
||||
case client.events <- msg:
|
||||
case <-client.done:
|
||||
default:
|
||||
// Drop event if client buffer is full.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handler returns a Gin handler for the SSE endpoint. Clients connect with
|
||||
// a GET request and receive events as text/event-stream. An optional
|
||||
// ?channel=<name> query parameter subscribes the client to a specific channel.
|
||||
func (b *SSEBroker) Handler() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
channel := c.Query("channel")
|
||||
|
||||
client := &sseClient{
|
||||
channel: channel,
|
||||
events: make(chan sseEvent, 64),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
b.clients[client] = struct{}{}
|
||||
b.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
close(client.done)
|
||||
b.mu.Lock()
|
||||
delete(b.clients, client)
|
||||
b.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Set SSE headers.
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Status(http.StatusOK)
|
||||
c.Writer.Flush()
|
||||
|
||||
// Stream events until client disconnects.
|
||||
ctx := c.Request.Context()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case evt := <-client.events:
|
||||
_, err := fmt.Fprintf(c.Writer, "event: %s\ndata: %s\n\n", evt.Event, evt.Data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Flush to ensure the event is sent immediately.
|
||||
if f, ok := c.Writer.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ClientCount returns the number of currently connected SSE clients.
|
||||
func (b *SSEBroker) ClientCount() int {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
return len(b.clients)
|
||||
}
|
||||
|
||||
// Drain closes all connected clients by writing an empty response.
|
||||
// Useful for graceful shutdown.
|
||||
func (b *SSEBroker) Drain() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
for client := range b.clients {
|
||||
select {
|
||||
case <-client.done:
|
||||
default:
|
||||
// Write EOF to trigger client disconnect via their event loop.
|
||||
close(client.events)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
308
sse_test.go
Normal file
308
sse_test.go
Normal file
|
|
@ -0,0 +1,308 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package api_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
api "forge.lthn.ai/core/go-api"
|
||||
)
|
||||
|
||||
// ── SSE endpoint ────────────────────────────────────────────────────────
|
||||
|
||||
func TestWithSSE_Good_EndpointExists(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
broker := api.NewSSEBroker()
|
||||
e, err := api.New(api.WithSSE(broker))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(e.Handler())
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := http.Get(srv.URL + "/events")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
ct := resp.Header.Get("Content-Type")
|
||||
if !strings.HasPrefix(ct, "text/event-stream") {
|
||||
t.Fatalf("expected Content-Type starting with text/event-stream, got %q", ct)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSSE_Good_ReceivesPublishedEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
broker := api.NewSSEBroker()
|
||||
e, err := api.New(api.WithSSE(broker))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(e.Handler())
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := http.Get(srv.URL + "/events")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Wait for the client to register before publishing.
|
||||
waitForClients(t, broker, 1)
|
||||
|
||||
// Publish an event on the default channel.
|
||||
broker.Publish("test", "greeting", map[string]string{"msg": "hello"})
|
||||
|
||||
// Read SSE lines from the response body.
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
var eventLine, dataLine string
|
||||
|
||||
deadline := time.After(3 * time.Second)
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
eventLine = strings.TrimPrefix(line, "event: ")
|
||||
}
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
dataLine = strings.TrimPrefix(line, "data: ")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-deadline:
|
||||
t.Fatal("timed out waiting for SSE event")
|
||||
}
|
||||
|
||||
if eventLine != "greeting" {
|
||||
t.Fatalf("expected event=%q, got %q", "greeting", eventLine)
|
||||
}
|
||||
if !strings.Contains(dataLine, `"msg":"hello"`) {
|
||||
t.Fatalf("expected data containing msg:hello, got %q", dataLine)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSSE_Good_ChannelFiltering(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
broker := api.NewSSEBroker()
|
||||
e, err := api.New(api.WithSSE(broker))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(e.Handler())
|
||||
defer srv.Close()
|
||||
|
||||
// Subscribe to channel "foo" only.
|
||||
resp, err := http.Get(srv.URL + "/events?channel=foo")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Wait for client to register.
|
||||
waitForClients(t, broker, 1)
|
||||
|
||||
// Publish to "bar" (should not be received), then to "foo" (should be received).
|
||||
broker.Publish("bar", "ignore", "bar-data")
|
||||
// Small delay to ensure ordering.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
broker.Publish("foo", "match", "foo-data")
|
||||
|
||||
// Read the first event from the stream.
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
var eventLine string
|
||||
|
||||
deadline := time.After(3 * time.Second)
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
eventLine = strings.TrimPrefix(line, "event: ")
|
||||
// Read past the data and blank line.
|
||||
scanner.Scan() // data line
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-deadline:
|
||||
t.Fatal("timed out waiting for SSE event")
|
||||
}
|
||||
|
||||
// The first event received should be "match" (from channel foo), not "ignore" (from bar).
|
||||
if eventLine != "match" {
|
||||
t.Fatalf("expected event=%q, got %q (channel filtering failed)", "match", eventLine)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSSE_Good_CombinesWithOtherMiddleware(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
broker := api.NewSSEBroker()
|
||||
e, err := api.New(
|
||||
api.WithRequestID(),
|
||||
api.WithSSE(broker),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(e.Handler())
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := http.Get(srv.URL + "/events")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// RequestID middleware should have injected the header.
|
||||
reqID := resp.Header.Get("X-Request-ID")
|
||||
if reqID == "" {
|
||||
t.Fatal("expected X-Request-ID header from RequestID middleware")
|
||||
}
|
||||
|
||||
ct := resp.Header.Get("Content-Type")
|
||||
if !strings.HasPrefix(ct, "text/event-stream") {
|
||||
t.Fatalf("expected Content-Type starting with text/event-stream, got %q", ct)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSSE_Good_MultipleClients(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
broker := api.NewSSEBroker()
|
||||
e, err := api.New(api.WithSSE(broker))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(e.Handler())
|
||||
defer srv.Close()
|
||||
|
||||
// Connect two clients.
|
||||
resp1, err := http.Get(srv.URL + "/events")
|
||||
if err != nil {
|
||||
t.Fatalf("client 1 request failed: %v", err)
|
||||
}
|
||||
defer resp1.Body.Close()
|
||||
|
||||
resp2, err := http.Get(srv.URL + "/events")
|
||||
if err != nil {
|
||||
t.Fatalf("client 2 request failed: %v", err)
|
||||
}
|
||||
defer resp2.Body.Close()
|
||||
|
||||
// Wait for both clients to register.
|
||||
waitForClients(t, broker, 2)
|
||||
|
||||
// Publish a single event.
|
||||
broker.Publish("broadcast", "ping", "pong")
|
||||
|
||||
// Both clients should receive it.
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
readEvent := func(name string, resp *http.Response) {
|
||||
defer wg.Done()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
deadline := time.After(3 * time.Second)
|
||||
done := make(chan string, 1)
|
||||
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
done <- strings.TrimPrefix(line, "event: ")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case evt := <-done:
|
||||
if evt != "ping" {
|
||||
t.Errorf("%s: expected event=%q, got %q", name, "ping", evt)
|
||||
}
|
||||
case <-deadline:
|
||||
t.Errorf("%s: timed out waiting for SSE event", name)
|
||||
}
|
||||
}
|
||||
|
||||
go readEvent("client1", resp1)
|
||||
go readEvent("client2", resp2)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// ── No SSE broker ────────────────────────────────────────────────────────
|
||||
|
||||
func TestNoSSEBroker_Good(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Without WithSSE, GET /events should return 404.
|
||||
e, _ := api.New()
|
||||
|
||||
h := e.Handler()
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/events", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404 for /events without broker, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
// waitForClients polls the broker until the expected number of clients
|
||||
// are connected or the timeout expires.
|
||||
func waitForClients(t *testing.T, broker *api.SSEBroker, want int) {
|
||||
t.Helper()
|
||||
deadline := time.After(2 * time.Second)
|
||||
for {
|
||||
if broker.ClientCount() >= want {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-deadline:
|
||||
t.Fatalf("timed out waiting for %d SSE clients (have %d)", want, broker.ClientCount())
|
||||
default:
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue