diff --git a/api.go b/api.go index 09cc2c3..59ddcb0 100644 --- a/api.go +++ b/api.go @@ -43,6 +43,7 @@ type Engine struct { swaggerTitle string swaggerDesc string swaggerVersion string + swaggerPath string swaggerTermsOfService string swaggerServers []string swaggerContactName string @@ -243,6 +244,7 @@ func (e *Engine) build() *gin.Engine { } registerSwagger( r, + resolveSwaggerPath(e.swaggerPath), e.swaggerTitle, e.swaggerDesc, e.swaggerVersion, diff --git a/middleware.go b/middleware.go index 4ad6d66..d555fa3 100644 --- a/middleware.go +++ b/middleware.go @@ -38,10 +38,10 @@ func recoveryMiddleware() gin.HandlerFunc { // bearerAuthMiddleware validates the Authorization: Bearer header. // Requests to paths in the skip list are allowed through without authentication. // Returns 401 with Fail("unauthorised", ...) on missing or invalid tokens. -func bearerAuthMiddleware(token string, skip []string) gin.HandlerFunc { +func bearerAuthMiddleware(token string, skip func() []string) gin.HandlerFunc { return func(c *gin.Context) { // Check whether the request path should bypass authentication. - for _, path := range skip { + for _, path := range skip() { if isPublicPath(c.Request.URL.Path, path) { c.Next() return diff --git a/options.go b/options.go index 153eb9a..108d2cb 100644 --- a/options.go +++ b/options.go @@ -40,11 +40,16 @@ func WithAddr(addr string) Option { } // WithBearerAuth adds bearer token authentication middleware. -// Requests to /health and paths starting with /swagger are exempt. +// Requests to /health and the Swagger UI path are exempt. func WithBearerAuth(token string) Option { return func(e *Engine) { - skip := []string{"/health", "/swagger"} - e.middlewares = append(e.middlewares, bearerAuthMiddleware(token, skip)) + e.middlewares = append(e.middlewares, bearerAuthMiddleware(token, func() []string { + skip := []string{"/health"} + if swaggerPath := resolveSwaggerPath(e.swaggerPath); swaggerPath != "" { + skip = append(skip, swaggerPath) + } + return skip + })) } } @@ -140,7 +145,7 @@ func WithSunset(sunsetDate, replacement string) Option { } } -// WithSwagger enables the Swagger UI at /swagger/. +// WithSwagger enables the Swagger UI at /swagger/ by default. // The title, description, and version populate the OpenAPI info block. func WithSwagger(title, description, version string) Option { return func(e *Engine) { @@ -151,6 +156,14 @@ func WithSwagger(title, description, version string) Option { } } +// WithSwaggerPath sets a custom URL path for the Swagger UI. +// The default path is "/swagger". +func WithSwaggerPath(path string) Option { + return func(e *Engine) { + e.swaggerPath = normaliseSwaggerPath(path) + } +} + // WithSwaggerTermsOfService adds the terms of service URL to the generated Swagger spec. // Empty strings are ignored. // diff --git a/swagger.go b/swagger.go index 3007c4c..b398002 100644 --- a/swagger.go +++ b/swagger.go @@ -4,6 +4,7 @@ package api import ( "fmt" + "strings" "sync" "sync/atomic" @@ -18,6 +19,9 @@ import ( // (common in tests) do not collide in the global swag registry. var swaggerSeq atomic.Uint64 +// defaultSwaggerPath is the URL path where the Swagger UI is mounted. +const defaultSwaggerPath = "/swagger" + // swaggerSpec wraps SpecBuilder to satisfy the swag.Spec interface. // The spec is built once on first access and cached. type swaggerSpec struct { @@ -48,7 +52,8 @@ func (s *swaggerSpec) ReadDoc() string { } // registerSwagger mounts the Swagger UI and doc.json endpoint. -func registerSwagger(g *gin.Engine, title, description, version, graphqlPath, ssePath, termsOfService, contactName, contactURL, contactEmail string, servers []string, licenseName, licenseURL, externalDocsDescription, externalDocsURL string, groups []RouteGroup) { +func registerSwagger(g *gin.Engine, swaggerPath, title, description, version, graphqlPath, ssePath, termsOfService, contactName, contactURL, contactEmail string, servers []string, licenseName, licenseURL, externalDocsDescription, externalDocsURL string, groups []RouteGroup) { + swaggerPath = resolveSwaggerPath(swaggerPath) spec := newSwaggerSpec(&SpecBuilder{ Title: title, Description: description, @@ -67,5 +72,30 @@ func registerSwagger(g *gin.Engine, title, description, version, graphqlPath, ss }, groups) name := fmt.Sprintf("swagger_%d", swaggerSeq.Add(1)) swag.Register(name, spec) - g.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.NewHandler(), ginSwagger.InstanceName(name))) + g.GET(swaggerPath+"/*any", ginSwagger.WrapHandler(swaggerFiles.NewHandler(), ginSwagger.InstanceName(name))) +} + +// normaliseSwaggerPath coerces custom Swagger paths into a stable form. +// The path always begins with a single slash and never ends with one. +func normaliseSwaggerPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return defaultSwaggerPath + } + + path = "/" + strings.Trim(path, "/") + if path == "/" { + return defaultSwaggerPath + } + + return path +} + +// resolveSwaggerPath returns the configured Swagger path or the default path +// when no override has been provided. +func resolveSwaggerPath(path string) string { + if strings.TrimSpace(path) == "" { + return defaultSwaggerPath + } + return normaliseSwaggerPath(path) } diff --git a/swagger_test.go b/swagger_test.go index adbb0d5..245d4b2 100644 --- a/swagger_test.go +++ b/swagger_test.go @@ -65,6 +65,52 @@ func TestSwaggerEndpoint_Good(t *testing.T) { } } +func TestSwaggerEndpoint_Good_CustomPath(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithSwagger("Test API", "A test API service", "1.0.0"), + api.WithSwaggerPath("/docs"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/docs/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read body: %v", err) + } + if len(body) == 0 { + t.Fatal("expected non-empty response body") + } + + var doc map[string]any + if err := json.Unmarshal(body, &doc); err != nil { + t.Fatalf("expected valid JSON, got unmarshal error: %v", err) + } + + info, ok := doc["info"].(map[string]any) + if !ok { + t.Fatal("expected 'info' object in swagger doc") + } + if info["title"] != "Test API" { + t.Fatalf("expected title=%q, got %q", "Test API", info["title"]) + } +} + func TestSwaggerDisabledByDefault_Good(t *testing.T) { gin.SetMode(gin.TestMode) @@ -81,6 +127,32 @@ func TestSwaggerDisabledByDefault_Good(t *testing.T) { } } +func TestSwaggerAuth_Good_CustomPathBypassesBearerAuth(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithBearerAuth("secret"), + api.WithSwagger("Test API", "A test API service", "1.0.0"), + api.WithSwaggerPath("/docs"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/docs/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200 for custom swagger path without auth, got %d", resp.StatusCode) + } +} + func TestSwagger_Good_SpecNotEmpty(t *testing.T) { gin.SetMode(gin.TestMode)