From 006a065ea04aef65c8a459388e12053769d6f274 Mon Sep 17 00:00:00 2001 From: Virgil Date: Thu, 2 Apr 2026 02:33:31 +0000 Subject: [PATCH] feat(openapi): document WebSocket endpoint Co-Authored-By: Virgil --- openapi.go | 122 ++++++++++++++++++++++++++++++++++++ openapi_test.go | 53 ++++++++++++++++ spec_builder_helper.go | 3 + spec_builder_helper_test.go | 5 ++ 4 files changed, 183 insertions(+) diff --git a/openapi.go b/openapi.go index af327f7..7ceeb48 100644 --- a/openapi.go +++ b/openapi.go @@ -26,6 +26,7 @@ type SpecBuilder struct { Version string GraphQLPath string GraphQLPlayground bool + WSPath string SSEPath string TermsOfService string ContactName string @@ -199,6 +200,11 @@ func (sb *SpecBuilder) buildPaths(groups []preparedRouteGroup) map[string]any { } } + if wsPath := strings.TrimSpace(sb.WSPath); wsPath != "" { + wsPath = normaliseOpenAPIPath(wsPath) + paths[wsPath] = wsPathItem(wsPath, operationIDs) + } + if ssePath := strings.TrimSpace(sb.SSEPath); ssePath != "" { ssePath = normaliseOpenAPIPath(ssePath) paths[ssePath] = ssePathItem(ssePath, operationIDs) @@ -715,6 +721,23 @@ func graphqlPlaygroundPathItem(path string, operationIDs map[string]int) map[str } } +func wsPathItem(path string, operationIDs map[string]int) map[string]any { + return map[string]any{ + "get": map[string]any{ + "summary": "WebSocket connection", + "description": "Upgrades the connection to a WebSocket stream", + "tags": []string{"system"}, + "operationId": operationID("get", path, operationIDs), + "security": []any{ + map[string]any{ + "bearerAuth": []any{}, + }, + }, + "responses": wsResponses(), + }, + } +} + func ssePathItem(path string, operationIDs map[string]int) map[string]any { return map[string]any{ "get": map[string]any{ @@ -743,6 +766,105 @@ func ssePathItem(path string, operationIDs map[string]int) map[string]any { } } +func wsResponses() map[string]any { + successHeaders := mergeHeaders( + standardResponseHeaders(), + rateLimitSuccessHeaders(), + wsUpgradeHeaders(), + ) + errorHeaders := mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders()) + + return map[string]any{ + "101": map[string]any{ + "description": "Switching protocols", + "headers": successHeaders, + }, + "401": map[string]any{ + "description": "Unauthorised", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "403": map[string]any{ + "description": "Forbidden", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "429": map[string]any{ + "description": "Too many requests", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": mergeHeaders(standardResponseHeaders(), rateLimitHeaders()), + }, + "500": map[string]any{ + "description": "Internal server error", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "504": map[string]any{ + "description": "Gateway timeout", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + } +} + +func wsUpgradeHeaders() map[string]any { + return map[string]any{ + "Upgrade": map[string]any{ + "description": "Indicates that the connection has switched to WebSocket", + "schema": map[string]any{ + "type": "string", + }, + }, + "Connection": map[string]any{ + "description": "Keeps the upgraded connection open", + "schema": map[string]any{ + "type": "string", + }, + }, + "Sec-WebSocket-Accept": map[string]any{ + "description": "Validates the WebSocket handshake", + "schema": map[string]any{ + "type": "string", + }, + }, + } +} + func pprofPathItem(operationIDs map[string]int) map[string]any { return map[string]any{ "get": map[string]any{ diff --git a/openapi_test.go b/openapi_test.go index d233eba..099e9a9 100644 --- a/openapi_test.go +++ b/openapi_test.go @@ -304,6 +304,59 @@ func TestSpecBuilder_Good_GraphQLPlaygroundEndpoint(t *testing.T) { } } +func TestSpecBuilder_Good_WebSocketEndpoint(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + WSPath: "/ws", + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + tags := spec["tags"].([]any) + found := false + for _, tag := range tags { + tm := tag.(map[string]any) + if tm["name"] == "system" { + found = true + break + } + } + if !found { + t.Fatal("expected system tag in spec") + } + + paths := spec["paths"].(map[string]any) + pathItem, ok := paths["/ws"].(map[string]any) + if !ok { + t.Fatal("expected /ws path in spec") + } + + getOp := pathItem["get"].(map[string]any) + if getOp["operationId"] != "get_ws" { + t.Fatalf("expected WebSocket operationId to be get_ws, got %v", getOp["operationId"]) + } + if getOp["summary"] != "WebSocket connection" { + t.Fatalf("expected WebSocket summary, got %v", getOp["summary"]) + } + + responses := getOp["responses"].(map[string]any) + if _, ok := responses["101"]; !ok { + t.Fatal("expected 101 response on /ws") + } + if _, ok := responses["429"]; !ok { + t.Fatal("expected 429 response on /ws") + } +} + func TestSpecBuilder_Good_ServerSentEventsEndpoint(t *testing.T) { sb := &api.SpecBuilder{ Title: "Test", diff --git a/spec_builder_helper.go b/spec_builder_helper.go index 72ea0a6..b165acd 100644 --- a/spec_builder_helper.go +++ b/spec_builder_helper.go @@ -34,6 +34,9 @@ func (e *Engine) OpenAPISpecBuilder() *SpecBuilder { builder.GraphQLPath = e.graphql.path builder.GraphQLPlayground = e.graphql.playground } + if e.wsHandler != nil { + builder.WSPath = "/ws" + } if e.sseBroker != nil { builder.SSEPath = resolveSSEPath(e.ssePath) } diff --git a/spec_builder_helper_test.go b/spec_builder_helper_test.go index 2cd28cb..4f339d2 100644 --- a/spec_builder_helper_test.go +++ b/spec_builder_helper_test.go @@ -4,6 +4,7 @@ package api_test import ( "encoding/json" + "net/http" "testing" "github.com/gin-gonic/gin" @@ -22,6 +23,7 @@ func TestEngine_Good_OpenAPISpecBuilderCarriesEngineMetadata(t *testing.T) { api.WithSwaggerServers("https://api.example.com", "/", "https://api.example.com"), api.WithSwaggerLicense("EUPL-1.2", "https://eupl.eu/1.2/en/"), api.WithSwaggerExternalDocs("Developer guide", "https://example.com/docs"), + api.WithWSHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})), api.WithGraphQL(newTestSchema(), api.WithPlayground(), api.WithGraphQLPath("/gql")), api.WithSSE(broker), api.WithSSEPath("/events"), @@ -109,6 +111,9 @@ func TestEngine_Good_OpenAPISpecBuilderCarriesEngineMetadata(t *testing.T) { if _, ok := paths["/gql/playground"]; !ok { t.Fatal("expected GraphQL playground path from engine metadata in generated spec") } + if _, ok := paths["/ws"]; !ok { + t.Fatal("expected WebSocket path from engine metadata in generated spec") + } if _, ok := paths["/events"]; !ok { t.Fatal("expected SSE path from engine metadata in generated spec") }