feat: add WebSocket endpoint and channel listing from StreamGroups
Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d21734d8d9
commit
22f8a6915c
6 changed files with 163 additions and 0 deletions
18
api.go
18
api.go
|
|
@ -24,6 +24,7 @@ type Engine struct {
|
|||
addr string
|
||||
groups []RouteGroup
|
||||
middlewares []gin.HandlerFunc
|
||||
wsHandler http.Handler
|
||||
}
|
||||
|
||||
// New creates an Engine with the given options.
|
||||
|
|
@ -53,6 +54,18 @@ func (e *Engine) Register(group RouteGroup) {
|
|||
e.groups = append(e.groups, group)
|
||||
}
|
||||
|
||||
// Channels returns all WebSocket channel names from registered StreamGroups.
|
||||
// Groups that do not implement StreamGroup are silently skipped.
|
||||
func (e *Engine) Channels() []string {
|
||||
var channels []string
|
||||
for _, g := range e.groups {
|
||||
if sg, ok := g.(StreamGroup); ok {
|
||||
channels = append(channels, sg.Channels()...)
|
||||
}
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
// Handler builds the Gin engine and returns it as an http.Handler.
|
||||
// Each call produces a fresh handler reflecting the current set of groups.
|
||||
func (e *Engine) Handler() http.Handler {
|
||||
|
|
@ -112,5 +125,10 @@ func (e *Engine) build() *gin.Engine {
|
|||
g.RegisterRoutes(rg)
|
||||
}
|
||||
|
||||
// Mount WebSocket handler if configured.
|
||||
if e.wsHandler != nil {
|
||||
r.GET("/ws", wrapWSHandler(e.wsHandler))
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
|
|
|||
1
go.mod
1
go.mod
|
|
@ -5,6 +5,7 @@ go 1.25.5
|
|||
require (
|
||||
github.com/gin-contrib/cors v1.7.6
|
||||
github.com/gin-gonic/gin v1.11.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
)
|
||||
|
||||
require (
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -30,6 +30,8 @@ github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7Lk
|
|||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
|
|
@ -68,3 +69,11 @@ func WithMiddleware(mw ...gin.HandlerFunc) Option {
|
|||
e.middlewares = append(e.middlewares, mw...)
|
||||
}
|
||||
}
|
||||
|
||||
// WithWSHandler registers a WebSocket handler at GET /ws.
|
||||
// Typically this wraps a go-ws Hub.Handler().
|
||||
func WithWSHandler(h http.Handler) Option {
|
||||
return func(e *Engine) {
|
||||
e.wsHandler = h
|
||||
}
|
||||
}
|
||||
|
|
|
|||
17
websocket.go
Normal file
17
websocket.go
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// wrapWSHandler adapts a standard http.Handler to a Gin handler for the /ws route.
|
||||
// The underlying handler is responsible for upgrading the connection to WebSocket.
|
||||
func wrapWSHandler(h http.Handler) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
h.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
}
|
||||
116
websocket_test.go
Normal file
116
websocket_test.go
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package api_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
api "forge.lthn.ai/core/go-api"
|
||||
)
|
||||
|
||||
// ── Stub groups ─────────────────────────────────────────────────────────
|
||||
|
||||
// wsStubGroup is a basic RouteGroup for WebSocket tests.
|
||||
type wsStubGroup struct{}
|
||||
|
||||
func (s *wsStubGroup) Name() string { return "wsstub" }
|
||||
func (s *wsStubGroup) BasePath() string { return "/v1/wsstub" }
|
||||
func (s *wsStubGroup) RegisterRoutes(rg *gin.RouterGroup) {
|
||||
rg.GET("/ping", func(c *gin.Context) {
|
||||
c.JSON(200, api.OK("pong"))
|
||||
})
|
||||
}
|
||||
|
||||
// wsStubStreamGroup embeds wsStubGroup and implements StreamGroup.
|
||||
type wsStubStreamGroup struct{ wsStubGroup }
|
||||
|
||||
func (s *wsStubStreamGroup) Channels() []string {
|
||||
return []string{"wsstub.events", "wsstub.updates"}
|
||||
}
|
||||
|
||||
// ── WebSocket endpoint ──────────────────────────────────────────────────
|
||||
|
||||
func TestWSEndpoint_Good(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Create a WebSocket upgrader that writes "hello" to every connection.
|
||||
upgrader := websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Logf("upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
_ = conn.WriteMessage(websocket.TextMessage, []byte("hello"))
|
||||
})
|
||||
|
||||
e, err := api.New(api.WithWSHandler(wsHandler))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(e.Handler())
|
||||
defer srv.Close()
|
||||
|
||||
// Dial the WebSocket endpoint.
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial WebSocket: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read message: %v", err)
|
||||
}
|
||||
if string(msg) != "hello" {
|
||||
t.Fatalf("expected message=%q, got %q", "hello", string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoWSHandler_Good(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Without WithWSHandler, GET /ws should return 404.
|
||||
e, _ := api.New()
|
||||
|
||||
h := e.Handler()
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/ws", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404 for /ws without handler, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Channel listing ─────────────────────────────────────────────────────
|
||||
|
||||
func TestChannelListing_Good(t *testing.T) {
|
||||
e, _ := api.New()
|
||||
|
||||
// Register a plain RouteGroup (no channels) and a StreamGroup.
|
||||
e.Register(&wsStubGroup{})
|
||||
e.Register(&wsStubStreamGroup{})
|
||||
|
||||
channels := e.Channels()
|
||||
if len(channels) != 2 {
|
||||
t.Fatalf("expected 2 channels, got %d", len(channels))
|
||||
}
|
||||
if channels[0] != "wsstub.events" {
|
||||
t.Fatalf("expected channels[0]=%q, got %q", "wsstub.events", channels[0])
|
||||
}
|
||||
if channels[1] != "wsstub.updates" {
|
||||
t.Fatalf("expected channels[1]=%q, got %q", "wsstub.updates", channels[1])
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue