feat(api): add configurable websocket path
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
d9ccd7c49a
commit
e47b010194
6 changed files with 85 additions and 5 deletions
3
api.go
3
api.go
|
|
@ -38,6 +38,7 @@ type Engine struct {
|
|||
groups []RouteGroup
|
||||
middlewares []gin.HandlerFunc
|
||||
wsHandler http.Handler
|
||||
wsPath string
|
||||
sseBroker *SSEBroker
|
||||
swaggerEnabled bool
|
||||
swaggerTitle string
|
||||
|
|
@ -223,7 +224,7 @@ func (e *Engine) build() *gin.Engine {
|
|||
|
||||
// Mount WebSocket handler if configured.
|
||||
if e.wsHandler != nil {
|
||||
r.GET("/ws", wrapWSHandler(e.wsHandler))
|
||||
r.GET(resolveWSPath(e.wsPath), wrapWSHandler(e.wsHandler))
|
||||
}
|
||||
|
||||
// Mount SSE endpoint if configured.
|
||||
|
|
|
|||
|
|
@ -115,6 +115,7 @@ func WithStatic(urlPrefix, root string) Option {
|
|||
}
|
||||
|
||||
// WithWSHandler registers a WebSocket handler at GET /ws.
|
||||
// Use WithWSPath to customise the route before mounting the handler.
|
||||
// Typically this wraps a go-ws Hub.Handler().
|
||||
//
|
||||
// Example:
|
||||
|
|
@ -126,6 +127,14 @@ func WithWSHandler(h http.Handler) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithWSPath sets a custom URL path for the WebSocket endpoint.
|
||||
// The default path is "/ws".
|
||||
func WithWSPath(path string) Option {
|
||||
return func(e *Engine) {
|
||||
e.wsPath = normaliseWSPath(path)
|
||||
}
|
||||
}
|
||||
|
||||
// WithAuthentik adds Authentik forward-auth middleware that extracts user
|
||||
// identity from X-authentik-* headers set by a trusted reverse proxy.
|
||||
// The middleware is permissive: unauthenticated requests are allowed through.
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ func (e *Engine) OpenAPISpecBuilder() *SpecBuilder {
|
|||
builder.GraphQLPlayground = e.graphql.playground
|
||||
}
|
||||
if e.wsHandler != nil {
|
||||
builder.WSPath = "/ws"
|
||||
builder.WSPath = resolveWSPath(e.wsPath)
|
||||
}
|
||||
if e.sseBroker != nil {
|
||||
builder.SSEPath = resolveSSEPath(e.ssePath)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ func TestEngine_Good_OpenAPISpecBuilderCarriesEngineMetadata(t *testing.T) {
|
|||
api.WithSwaggerServers("https://api.example.com", "/", "https://api.example.com"),
|
||||
api.WithSwaggerLicense("EUPL-1.2", "https://eupl.eu/1.2/en/"),
|
||||
api.WithSwaggerExternalDocs("Developer guide", "https://example.com/docs"),
|
||||
api.WithWSPath("/socket"),
|
||||
api.WithWSHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})),
|
||||
api.WithGraphQL(newTestSchema(), api.WithPlayground(), api.WithGraphQLPath("/gql")),
|
||||
api.WithSSE(broker),
|
||||
|
|
@ -116,8 +117,8 @@ func TestEngine_Good_OpenAPISpecBuilderCarriesEngineMetadata(t *testing.T) {
|
|||
if _, ok := paths["/gql/playground"]; !ok {
|
||||
t.Fatal("expected GraphQL playground path from engine metadata in generated spec")
|
||||
}
|
||||
if _, ok := paths["/ws"]; !ok {
|
||||
t.Fatal("expected WebSocket path from engine metadata in generated spec")
|
||||
if _, ok := paths["/socket"]; !ok {
|
||||
t.Fatal("expected custom WebSocket path from engine metadata in generated spec")
|
||||
}
|
||||
if _, ok := paths["/events"]; !ok {
|
||||
t.Fatal("expected SSE path from engine metadata in generated spec")
|
||||
|
|
|
|||
31
websocket.go
31
websocket.go
|
|
@ -4,14 +4,43 @@ package api
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// wrapWSHandler adapts a standard http.Handler to a Gin handler for the /ws route.
|
||||
// defaultWSPath is the URL path where the WebSocket endpoint is mounted.
|
||||
const defaultWSPath = "/ws"
|
||||
|
||||
// wrapWSHandler adapts a standard http.Handler to a Gin handler for the WebSocket route.
|
||||
// The underlying handler is responsible for upgrading the connection to WebSocket.
|
||||
func wrapWSHandler(h http.Handler) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
h.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
}
|
||||
|
||||
// normaliseWSPath coerces custom WebSocket paths into a stable form.
|
||||
// The path always begins with a single slash and never ends with one.
|
||||
func normaliseWSPath(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return defaultWSPath
|
||||
}
|
||||
|
||||
path = "/" + strings.Trim(path, "/")
|
||||
if path == "/" {
|
||||
return defaultWSPath
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
// resolveWSPath returns the configured WebSocket path or the default path
|
||||
// when no override has been provided.
|
||||
func resolveWSPath(path string) string {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return defaultWSPath
|
||||
}
|
||||
return normaliseWSPath(path)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -78,6 +78,46 @@ func TestWSEndpoint_Good(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestWSEndpoint_Good_CustomPath(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upgrader := websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Logf("upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
_ = conn.WriteMessage(websocket.TextMessage, []byte("custom"))
|
||||
})
|
||||
|
||||
e, err := api.New(api.WithWSPath("/socket"), api.WithWSHandler(wsHandler))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(e.Handler())
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/socket"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial custom WebSocket: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read custom WebSocket message: %v", err)
|
||||
}
|
||||
if string(msg) != "custom" {
|
||||
t.Fatalf("expected message=%q, got %q", "custom", string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoWSHandler_Good(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue