From 22f8a6915c5e086c18e9db4d2ed9adb14e9f02af Mon Sep 17 00:00:00 2001 From: Snider Date: Fri, 20 Feb 2026 15:53:10 +0000 Subject: [PATCH] feat: add WebSocket endpoint and channel listing from StreamGroups Co-Authored-By: Virgil Co-Authored-By: Claude Opus 4.6 --- api.go | 18 +++++++ go.mod | 1 + go.sum | 2 + options.go | 9 ++++ websocket.go | 17 +++++++ websocket_test.go | 116 ++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 163 insertions(+) create mode 100644 websocket.go create mode 100644 websocket_test.go diff --git a/api.go b/api.go index 2d8a763..44d0e39 100644 --- a/api.go +++ b/api.go @@ -24,6 +24,7 @@ type Engine struct { addr string groups []RouteGroup middlewares []gin.HandlerFunc + wsHandler http.Handler } // New creates an Engine with the given options. @@ -53,6 +54,18 @@ func (e *Engine) Register(group RouteGroup) { e.groups = append(e.groups, group) } +// Channels returns all WebSocket channel names from registered StreamGroups. +// Groups that do not implement StreamGroup are silently skipped. +func (e *Engine) Channels() []string { + var channels []string + for _, g := range e.groups { + if sg, ok := g.(StreamGroup); ok { + channels = append(channels, sg.Channels()...) + } + } + return channels +} + // Handler builds the Gin engine and returns it as an http.Handler. // Each call produces a fresh handler reflecting the current set of groups. func (e *Engine) Handler() http.Handler { @@ -112,5 +125,10 @@ func (e *Engine) build() *gin.Engine { g.RegisterRoutes(rg) } + // Mount WebSocket handler if configured. + if e.wsHandler != nil { + r.GET("/ws", wrapWSHandler(e.wsHandler)) + } + return r } diff --git a/go.mod b/go.mod index adf9a07..e944b3a 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.25.5 require ( github.com/gin-contrib/cors v1.7.6 github.com/gin-gonic/gin v1.11.0 + github.com/gorilla/websocket v1.5.3 ) require ( diff --git a/go.sum b/go.sum index 7539074..f408e37 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7Lk github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= diff --git a/options.go b/options.go index 8b41625..9bcd068 100644 --- a/options.go +++ b/options.go @@ -3,6 +3,7 @@ package api import ( + "net/http" "time" "github.com/gin-contrib/cors" @@ -68,3 +69,11 @@ func WithMiddleware(mw ...gin.HandlerFunc) Option { e.middlewares = append(e.middlewares, mw...) } } + +// WithWSHandler registers a WebSocket handler at GET /ws. +// Typically this wraps a go-ws Hub.Handler(). +func WithWSHandler(h http.Handler) Option { + return func(e *Engine) { + e.wsHandler = h + } +} diff --git a/websocket.go b/websocket.go new file mode 100644 index 0000000..8eb7a33 --- /dev/null +++ b/websocket.go @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +// wrapWSHandler adapts a standard http.Handler to a Gin handler for the /ws route. +// The underlying handler is responsible for upgrading the connection to WebSocket. +func wrapWSHandler(h http.Handler) gin.HandlerFunc { + return func(c *gin.Context) { + h.ServeHTTP(c.Writer, c.Request) + } +} diff --git a/websocket_test.go b/websocket_test.go new file mode 100644 index 0000000..06e42b3 --- /dev/null +++ b/websocket_test.go @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + + api "forge.lthn.ai/core/go-api" +) + +// ── Stub groups ───────────────────────────────────────────────────────── + +// wsStubGroup is a basic RouteGroup for WebSocket tests. +type wsStubGroup struct{} + +func (s *wsStubGroup) Name() string { return "wsstub" } +func (s *wsStubGroup) BasePath() string { return "/v1/wsstub" } +func (s *wsStubGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/ping", func(c *gin.Context) { + c.JSON(200, api.OK("pong")) + }) +} + +// wsStubStreamGroup embeds wsStubGroup and implements StreamGroup. +type wsStubStreamGroup struct{ wsStubGroup } + +func (s *wsStubStreamGroup) Channels() []string { + return []string{"wsstub.events", "wsstub.updates"} +} + +// ── WebSocket endpoint ────────────────────────────────────────────────── + +func TestWSEndpoint_Good(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Create a WebSocket upgrader that writes "hello" to every connection. + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Logf("upgrade error: %v", err) + return + } + defer conn.Close() + _ = conn.WriteMessage(websocket.TextMessage, []byte("hello")) + }) + + e, err := api.New(api.WithWSHandler(wsHandler)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + // Dial the WebSocket endpoint. + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to dial WebSocket: %v", err) + } + defer conn.Close() + + _, msg, err := conn.ReadMessage() + if err != nil { + t.Fatalf("failed to read message: %v", err) + } + if string(msg) != "hello" { + t.Fatalf("expected message=%q, got %q", "hello", string(msg)) + } +} + +func TestNoWSHandler_Good(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Without WithWSHandler, GET /ws should return 404. + e, _ := api.New() + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/ws", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("expected 404 for /ws without handler, got %d", w.Code) + } +} + +// ── Channel listing ───────────────────────────────────────────────────── + +func TestChannelListing_Good(t *testing.T) { + e, _ := api.New() + + // Register a plain RouteGroup (no channels) and a StreamGroup. + e.Register(&wsStubGroup{}) + e.Register(&wsStubStreamGroup{}) + + channels := e.Channels() + if len(channels) != 2 { + t.Fatalf("expected 2 channels, got %d", len(channels)) + } + if channels[0] != "wsstub.events" { + t.Fatalf("expected channels[0]=%q, got %q", "wsstub.events", channels[0]) + } + if channels[1] != "wsstub.updates" { + t.Fatalf("expected channels[1]=%q, got %q", "wsstub.updates", channels[1]) + } +}