diff --git a/api.go b/api.go index 5c8d08f..ebac90f 100644 --- a/api.go +++ b/api.go @@ -38,6 +38,7 @@ type Engine struct { groups []RouteGroup middlewares []gin.HandlerFunc wsHandler http.Handler + wsPath string sseBroker *SSEBroker swaggerEnabled bool swaggerTitle string @@ -223,7 +224,7 @@ func (e *Engine) build() *gin.Engine { // Mount WebSocket handler if configured. if e.wsHandler != nil { - r.GET("/ws", wrapWSHandler(e.wsHandler)) + r.GET(resolveWSPath(e.wsPath), wrapWSHandler(e.wsHandler)) } // Mount SSE endpoint if configured. diff --git a/options.go b/options.go index 108d2cb..a608c75 100644 --- a/options.go +++ b/options.go @@ -115,6 +115,7 @@ func WithStatic(urlPrefix, root string) Option { } // WithWSHandler registers a WebSocket handler at GET /ws. +// Use WithWSPath to customise the route before mounting the handler. // Typically this wraps a go-ws Hub.Handler(). // // Example: @@ -126,6 +127,14 @@ func WithWSHandler(h http.Handler) Option { } } +// WithWSPath sets a custom URL path for the WebSocket endpoint. +// The default path is "/ws". +func WithWSPath(path string) Option { + return func(e *Engine) { + e.wsPath = normaliseWSPath(path) + } +} + // WithAuthentik adds Authentik forward-auth middleware that extracts user // identity from X-authentik-* headers set by a trusted reverse proxy. // The middleware is permissive: unauthenticated requests are allowed through. diff --git a/spec_builder_helper.go b/spec_builder_helper.go index 1bf0f38..4ec4961 100644 --- a/spec_builder_helper.go +++ b/spec_builder_helper.go @@ -36,7 +36,7 @@ func (e *Engine) OpenAPISpecBuilder() *SpecBuilder { builder.GraphQLPlayground = e.graphql.playground } if e.wsHandler != nil { - builder.WSPath = "/ws" + builder.WSPath = resolveWSPath(e.wsPath) } if e.sseBroker != nil { builder.SSEPath = resolveSSEPath(e.ssePath) diff --git a/spec_builder_helper_test.go b/spec_builder_helper_test.go index 56a829b..1f71551 100644 --- a/spec_builder_helper_test.go +++ b/spec_builder_helper_test.go @@ -24,6 +24,7 @@ func TestEngine_Good_OpenAPISpecBuilderCarriesEngineMetadata(t *testing.T) { api.WithSwaggerServers("https://api.example.com", "/", "https://api.example.com"), api.WithSwaggerLicense("EUPL-1.2", "https://eupl.eu/1.2/en/"), api.WithSwaggerExternalDocs("Developer guide", "https://example.com/docs"), + api.WithWSPath("/socket"), api.WithWSHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})), api.WithGraphQL(newTestSchema(), api.WithPlayground(), api.WithGraphQLPath("/gql")), api.WithSSE(broker), @@ -116,8 +117,8 @@ func TestEngine_Good_OpenAPISpecBuilderCarriesEngineMetadata(t *testing.T) { if _, ok := paths["/gql/playground"]; !ok { t.Fatal("expected GraphQL playground path from engine metadata in generated spec") } - if _, ok := paths["/ws"]; !ok { - t.Fatal("expected WebSocket path from engine metadata in generated spec") + if _, ok := paths["/socket"]; !ok { + t.Fatal("expected custom WebSocket path from engine metadata in generated spec") } if _, ok := paths["/events"]; !ok { t.Fatal("expected SSE path from engine metadata in generated spec") diff --git a/websocket.go b/websocket.go index 8eb7a33..fc5bedc 100644 --- a/websocket.go +++ b/websocket.go @@ -4,14 +4,43 @@ package api import ( "net/http" + "strings" "github.com/gin-gonic/gin" ) -// wrapWSHandler adapts a standard http.Handler to a Gin handler for the /ws route. +// defaultWSPath is the URL path where the WebSocket endpoint is mounted. +const defaultWSPath = "/ws" + +// wrapWSHandler adapts a standard http.Handler to a Gin handler for the WebSocket route. // The underlying handler is responsible for upgrading the connection to WebSocket. func wrapWSHandler(h http.Handler) gin.HandlerFunc { return func(c *gin.Context) { h.ServeHTTP(c.Writer, c.Request) } } + +// normaliseWSPath coerces custom WebSocket paths into a stable form. +// The path always begins with a single slash and never ends with one. +func normaliseWSPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return defaultWSPath + } + + path = "/" + strings.Trim(path, "/") + if path == "/" { + return defaultWSPath + } + + return path +} + +// resolveWSPath returns the configured WebSocket path or the default path +// when no override has been provided. +func resolveWSPath(path string) string { + if strings.TrimSpace(path) == "" { + return defaultWSPath + } + return normaliseWSPath(path) +} diff --git a/websocket_test.go b/websocket_test.go index cbad161..d287364 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -78,6 +78,46 @@ func TestWSEndpoint_Good(t *testing.T) { } } +func TestWSEndpoint_Good_CustomPath(t *testing.T) { + gin.SetMode(gin.TestMode) + + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Logf("upgrade error: %v", err) + return + } + defer conn.Close() + _ = conn.WriteMessage(websocket.TextMessage, []byte("custom")) + }) + + e, err := api.New(api.WithWSPath("/socket"), api.WithWSHandler(wsHandler)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/socket" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to dial custom WebSocket: %v", err) + } + defer conn.Close() + + _, msg, err := conn.ReadMessage() + if err != nil { + t.Fatalf("failed to read custom WebSocket message: %v", err) + } + if string(msg) != "custom" { + t.Fatalf("expected message=%q, got %q", "custom", string(msg)) + } +} + func TestNoWSHandler_Good(t *testing.T) { gin.SetMode(gin.TestMode)