feat(api): add configurable websocket path

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-02 03:21:28 +00:00
parent d9ccd7c49a
commit e47b010194
6 changed files with 85 additions and 5 deletions

3
api.go
View file

@ -38,6 +38,7 @@ type Engine struct {
groups []RouteGroup
middlewares []gin.HandlerFunc
wsHandler http.Handler
wsPath string
sseBroker *SSEBroker
swaggerEnabled bool
swaggerTitle string
@ -223,7 +224,7 @@ func (e *Engine) build() *gin.Engine {
// Mount WebSocket handler if configured.
if e.wsHandler != nil {
r.GET("/ws", wrapWSHandler(e.wsHandler))
r.GET(resolveWSPath(e.wsPath), wrapWSHandler(e.wsHandler))
}
// Mount SSE endpoint if configured.

View file

@ -115,6 +115,7 @@ func WithStatic(urlPrefix, root string) Option {
}
// WithWSHandler registers a WebSocket handler at GET /ws.
// Use WithWSPath to customise the route before mounting the handler.
// Typically this wraps a go-ws Hub.Handler().
//
// Example:
@ -126,6 +127,14 @@ func WithWSHandler(h http.Handler) Option {
}
}
// WithWSPath sets a custom URL path for the WebSocket endpoint.
// The default path is "/ws".
func WithWSPath(path string) Option {
return func(e *Engine) {
e.wsPath = normaliseWSPath(path)
}
}
// WithAuthentik adds Authentik forward-auth middleware that extracts user
// identity from X-authentik-* headers set by a trusted reverse proxy.
// The middleware is permissive: unauthenticated requests are allowed through.

View file

@ -36,7 +36,7 @@ func (e *Engine) OpenAPISpecBuilder() *SpecBuilder {
builder.GraphQLPlayground = e.graphql.playground
}
if e.wsHandler != nil {
builder.WSPath = "/ws"
builder.WSPath = resolveWSPath(e.wsPath)
}
if e.sseBroker != nil {
builder.SSEPath = resolveSSEPath(e.ssePath)

View file

@ -24,6 +24,7 @@ func TestEngine_Good_OpenAPISpecBuilderCarriesEngineMetadata(t *testing.T) {
api.WithSwaggerServers("https://api.example.com", "/", "https://api.example.com"),
api.WithSwaggerLicense("EUPL-1.2", "https://eupl.eu/1.2/en/"),
api.WithSwaggerExternalDocs("Developer guide", "https://example.com/docs"),
api.WithWSPath("/socket"),
api.WithWSHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})),
api.WithGraphQL(newTestSchema(), api.WithPlayground(), api.WithGraphQLPath("/gql")),
api.WithSSE(broker),
@ -116,8 +117,8 @@ func TestEngine_Good_OpenAPISpecBuilderCarriesEngineMetadata(t *testing.T) {
if _, ok := paths["/gql/playground"]; !ok {
t.Fatal("expected GraphQL playground path from engine metadata in generated spec")
}
if _, ok := paths["/ws"]; !ok {
t.Fatal("expected WebSocket path from engine metadata in generated spec")
if _, ok := paths["/socket"]; !ok {
t.Fatal("expected custom WebSocket path from engine metadata in generated spec")
}
if _, ok := paths["/events"]; !ok {
t.Fatal("expected SSE path from engine metadata in generated spec")

View file

@ -4,14 +4,43 @@ package api
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
// wrapWSHandler adapts a standard http.Handler to a Gin handler for the /ws route.
// defaultWSPath is the URL path where the WebSocket endpoint is mounted.
const defaultWSPath = "/ws"
// wrapWSHandler adapts a standard http.Handler to a Gin handler for the WebSocket route.
// The underlying handler is responsible for upgrading the connection to WebSocket.
func wrapWSHandler(h http.Handler) gin.HandlerFunc {
return func(c *gin.Context) {
h.ServeHTTP(c.Writer, c.Request)
}
}
// normaliseWSPath coerces custom WebSocket paths into a stable form.
// The path always begins with a single slash and never ends with one.
func normaliseWSPath(path string) string {
path = strings.TrimSpace(path)
if path == "" {
return defaultWSPath
}
path = "/" + strings.Trim(path, "/")
if path == "/" {
return defaultWSPath
}
return path
}
// resolveWSPath returns the configured WebSocket path or the default path
// when no override has been provided.
func resolveWSPath(path string) string {
if strings.TrimSpace(path) == "" {
return defaultWSPath
}
return normaliseWSPath(path)
}

View file

@ -78,6 +78,46 @@ func TestWSEndpoint_Good(t *testing.T) {
}
}
func TestWSEndpoint_Good_CustomPath(t *testing.T) {
gin.SetMode(gin.TestMode)
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Logf("upgrade error: %v", err)
return
}
defer conn.Close()
_ = conn.WriteMessage(websocket.TextMessage, []byte("custom"))
})
e, err := api.New(api.WithWSPath("/socket"), api.WithWSHandler(wsHandler))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
srv := httptest.NewServer(e.Handler())
defer srv.Close()
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/socket"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("failed to dial custom WebSocket: %v", err)
}
defer conn.Close()
_, msg, err := conn.ReadMessage()
if err != nil {
t.Fatalf("failed to read custom WebSocket message: %v", err)
}
if string(msg) != "custom" {
t.Fatalf("expected message=%q, got %q", "custom", string(msg))
}
}
func TestNoWSHandler_Good(t *testing.T) {
gin.SetMode(gin.TestMode)