feat: add WebSocket endpoint and channel listing from StreamGroups

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Snider 2026-02-20 15:53:10 +00:00
parent d21734d8d9
commit 22f8a6915c
6 changed files with 163 additions and 0 deletions

18
api.go
View file

@ -24,6 +24,7 @@ type Engine struct {
addr string
groups []RouteGroup
middlewares []gin.HandlerFunc
wsHandler http.Handler
}
// New creates an Engine with the given options.
@ -53,6 +54,18 @@ func (e *Engine) Register(group RouteGroup) {
e.groups = append(e.groups, group)
}
// Channels returns all WebSocket channel names from registered StreamGroups.
// Groups that do not implement StreamGroup are silently skipped.
func (e *Engine) Channels() []string {
var channels []string
for _, g := range e.groups {
if sg, ok := g.(StreamGroup); ok {
channels = append(channels, sg.Channels()...)
}
}
return channels
}
// Handler builds the Gin engine and returns it as an http.Handler.
// Each call produces a fresh handler reflecting the current set of groups.
func (e *Engine) Handler() http.Handler {
@ -112,5 +125,10 @@ func (e *Engine) build() *gin.Engine {
g.RegisterRoutes(rg)
}
// Mount WebSocket handler if configured.
if e.wsHandler != nil {
r.GET("/ws", wrapWSHandler(e.wsHandler))
}
return r
}

1
go.mod
View file

@ -5,6 +5,7 @@ go 1.25.5
require (
github.com/gin-contrib/cors v1.7.6
github.com/gin-gonic/gin v1.11.0
github.com/gorilla/websocket v1.5.3
)
require (

2
go.sum
View file

@ -30,6 +30,8 @@ github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7Lk
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=

View file

@ -3,6 +3,7 @@
package api
import (
"net/http"
"time"
"github.com/gin-contrib/cors"
@ -68,3 +69,11 @@ func WithMiddleware(mw ...gin.HandlerFunc) Option {
e.middlewares = append(e.middlewares, mw...)
}
}
// WithWSHandler registers a WebSocket handler at GET /ws.
// Typically this wraps a go-ws Hub.Handler().
func WithWSHandler(h http.Handler) Option {
return func(e *Engine) {
e.wsHandler = h
}
}

17
websocket.go Normal file
View file

@ -0,0 +1,17 @@
// SPDX-License-Identifier: EUPL-1.2
package api
import (
"net/http"
"github.com/gin-gonic/gin"
)
// wrapWSHandler adapts a standard http.Handler to a Gin handler for the /ws route.
// The underlying handler is responsible for upgrading the connection to WebSocket.
func wrapWSHandler(h http.Handler) gin.HandlerFunc {
return func(c *gin.Context) {
h.ServeHTTP(c.Writer, c.Request)
}
}

116
websocket_test.go Normal file
View file

@ -0,0 +1,116 @@
// SPDX-License-Identifier: EUPL-1.2
package api_test
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
api "forge.lthn.ai/core/go-api"
)
// ── Stub groups ─────────────────────────────────────────────────────────
// wsStubGroup is a basic RouteGroup for WebSocket tests.
type wsStubGroup struct{}
func (s *wsStubGroup) Name() string { return "wsstub" }
func (s *wsStubGroup) BasePath() string { return "/v1/wsstub" }
func (s *wsStubGroup) RegisterRoutes(rg *gin.RouterGroup) {
rg.GET("/ping", func(c *gin.Context) {
c.JSON(200, api.OK("pong"))
})
}
// wsStubStreamGroup embeds wsStubGroup and implements StreamGroup.
type wsStubStreamGroup struct{ wsStubGroup }
func (s *wsStubStreamGroup) Channels() []string {
return []string{"wsstub.events", "wsstub.updates"}
}
// ── WebSocket endpoint ──────────────────────────────────────────────────
func TestWSEndpoint_Good(t *testing.T) {
gin.SetMode(gin.TestMode)
// Create a WebSocket upgrader that writes "hello" to every connection.
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Logf("upgrade error: %v", err)
return
}
defer conn.Close()
_ = conn.WriteMessage(websocket.TextMessage, []byte("hello"))
})
e, err := api.New(api.WithWSHandler(wsHandler))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
srv := httptest.NewServer(e.Handler())
defer srv.Close()
// Dial the WebSocket endpoint.
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("failed to dial WebSocket: %v", err)
}
defer conn.Close()
_, msg, err := conn.ReadMessage()
if err != nil {
t.Fatalf("failed to read message: %v", err)
}
if string(msg) != "hello" {
t.Fatalf("expected message=%q, got %q", "hello", string(msg))
}
}
func TestNoWSHandler_Good(t *testing.T) {
gin.SetMode(gin.TestMode)
// Without WithWSHandler, GET /ws should return 404.
e, _ := api.New()
h := e.Handler()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/ws", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404 for /ws without handler, got %d", w.Code)
}
}
// ── Channel listing ─────────────────────────────────────────────────────
func TestChannelListing_Good(t *testing.T) {
e, _ := api.New()
// Register a plain RouteGroup (no channels) and a StreamGroup.
e.Register(&wsStubGroup{})
e.Register(&wsStubStreamGroup{})
channels := e.Channels()
if len(channels) != 2 {
t.Fatalf("expected 2 channels, got %d", len(channels))
}
if channels[0] != "wsstub.events" {
t.Fatalf("expected channels[0]=%q, got %q", "wsstub.events", channels[0])
}
if channels[1] != "wsstub.updates" {
t.Fatalf("expected channels[1]=%q, got %q", "wsstub.updates", channels[1])
}
}