feat(api): add configurable Swagger path
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
39bf094b51
commit
ef641c7547
5 changed files with 125 additions and 8 deletions
2
api.go
2
api.go
|
|
@ -43,6 +43,7 @@ type Engine struct {
|
|||
swaggerTitle string
|
||||
swaggerDesc string
|
||||
swaggerVersion string
|
||||
swaggerPath string
|
||||
swaggerTermsOfService string
|
||||
swaggerServers []string
|
||||
swaggerContactName string
|
||||
|
|
@ -243,6 +244,7 @@ func (e *Engine) build() *gin.Engine {
|
|||
}
|
||||
registerSwagger(
|
||||
r,
|
||||
resolveSwaggerPath(e.swaggerPath),
|
||||
e.swaggerTitle,
|
||||
e.swaggerDesc,
|
||||
e.swaggerVersion,
|
||||
|
|
|
|||
|
|
@ -38,10 +38,10 @@ func recoveryMiddleware() gin.HandlerFunc {
|
|||
// bearerAuthMiddleware validates the Authorization: Bearer <token> header.
|
||||
// Requests to paths in the skip list are allowed through without authentication.
|
||||
// Returns 401 with Fail("unauthorised", ...) on missing or invalid tokens.
|
||||
func bearerAuthMiddleware(token string, skip []string) gin.HandlerFunc {
|
||||
func bearerAuthMiddleware(token string, skip func() []string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Check whether the request path should bypass authentication.
|
||||
for _, path := range skip {
|
||||
for _, path := range skip() {
|
||||
if isPublicPath(c.Request.URL.Path, path) {
|
||||
c.Next()
|
||||
return
|
||||
|
|
|
|||
21
options.go
21
options.go
|
|
@ -40,11 +40,16 @@ func WithAddr(addr string) Option {
|
|||
}
|
||||
|
||||
// WithBearerAuth adds bearer token authentication middleware.
|
||||
// Requests to /health and paths starting with /swagger are exempt.
|
||||
// Requests to /health and the Swagger UI path are exempt.
|
||||
func WithBearerAuth(token string) Option {
|
||||
return func(e *Engine) {
|
||||
skip := []string{"/health", "/swagger"}
|
||||
e.middlewares = append(e.middlewares, bearerAuthMiddleware(token, skip))
|
||||
e.middlewares = append(e.middlewares, bearerAuthMiddleware(token, func() []string {
|
||||
skip := []string{"/health"}
|
||||
if swaggerPath := resolveSwaggerPath(e.swaggerPath); swaggerPath != "" {
|
||||
skip = append(skip, swaggerPath)
|
||||
}
|
||||
return skip
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -140,7 +145,7 @@ func WithSunset(sunsetDate, replacement string) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithSwagger enables the Swagger UI at /swagger/.
|
||||
// WithSwagger enables the Swagger UI at /swagger/ by default.
|
||||
// The title, description, and version populate the OpenAPI info block.
|
||||
func WithSwagger(title, description, version string) Option {
|
||||
return func(e *Engine) {
|
||||
|
|
@ -151,6 +156,14 @@ func WithSwagger(title, description, version string) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithSwaggerPath sets a custom URL path for the Swagger UI.
|
||||
// The default path is "/swagger".
|
||||
func WithSwaggerPath(path string) Option {
|
||||
return func(e *Engine) {
|
||||
e.swaggerPath = normaliseSwaggerPath(path)
|
||||
}
|
||||
}
|
||||
|
||||
// WithSwaggerTermsOfService adds the terms of service URL to the generated Swagger spec.
|
||||
// Empty strings are ignored.
|
||||
//
|
||||
|
|
|
|||
34
swagger.go
34
swagger.go
|
|
@ -4,6 +4,7 @@ package api
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
|
|
@ -18,6 +19,9 @@ import (
|
|||
// (common in tests) do not collide in the global swag registry.
|
||||
var swaggerSeq atomic.Uint64
|
||||
|
||||
// defaultSwaggerPath is the URL path where the Swagger UI is mounted.
|
||||
const defaultSwaggerPath = "/swagger"
|
||||
|
||||
// swaggerSpec wraps SpecBuilder to satisfy the swag.Spec interface.
|
||||
// The spec is built once on first access and cached.
|
||||
type swaggerSpec struct {
|
||||
|
|
@ -48,7 +52,8 @@ func (s *swaggerSpec) ReadDoc() string {
|
|||
}
|
||||
|
||||
// registerSwagger mounts the Swagger UI and doc.json endpoint.
|
||||
func registerSwagger(g *gin.Engine, title, description, version, graphqlPath, ssePath, termsOfService, contactName, contactURL, contactEmail string, servers []string, licenseName, licenseURL, externalDocsDescription, externalDocsURL string, groups []RouteGroup) {
|
||||
func registerSwagger(g *gin.Engine, swaggerPath, title, description, version, graphqlPath, ssePath, termsOfService, contactName, contactURL, contactEmail string, servers []string, licenseName, licenseURL, externalDocsDescription, externalDocsURL string, groups []RouteGroup) {
|
||||
swaggerPath = resolveSwaggerPath(swaggerPath)
|
||||
spec := newSwaggerSpec(&SpecBuilder{
|
||||
Title: title,
|
||||
Description: description,
|
||||
|
|
@ -67,5 +72,30 @@ func registerSwagger(g *gin.Engine, title, description, version, graphqlPath, ss
|
|||
}, groups)
|
||||
name := fmt.Sprintf("swagger_%d", swaggerSeq.Add(1))
|
||||
swag.Register(name, spec)
|
||||
g.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.NewHandler(), ginSwagger.InstanceName(name)))
|
||||
g.GET(swaggerPath+"/*any", ginSwagger.WrapHandler(swaggerFiles.NewHandler(), ginSwagger.InstanceName(name)))
|
||||
}
|
||||
|
||||
// normaliseSwaggerPath coerces custom Swagger paths into a stable form.
|
||||
// The path always begins with a single slash and never ends with one.
|
||||
func normaliseSwaggerPath(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return defaultSwaggerPath
|
||||
}
|
||||
|
||||
path = "/" + strings.Trim(path, "/")
|
||||
if path == "/" {
|
||||
return defaultSwaggerPath
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
// resolveSwaggerPath returns the configured Swagger path or the default path
|
||||
// when no override has been provided.
|
||||
func resolveSwaggerPath(path string) string {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return defaultSwaggerPath
|
||||
}
|
||||
return normaliseSwaggerPath(path)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -65,6 +65,52 @@ func TestSwaggerEndpoint_Good(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSwaggerEndpoint_Good_CustomPath(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
e, err := api.New(
|
||||
api.WithSwagger("Test API", "A test API service", "1.0.0"),
|
||||
api.WithSwaggerPath("/docs"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(e.Handler())
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := http.Get(srv.URL + "/docs/doc.json")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read body: %v", err)
|
||||
}
|
||||
if len(body) == 0 {
|
||||
t.Fatal("expected non-empty response body")
|
||||
}
|
||||
|
||||
var doc map[string]any
|
||||
if err := json.Unmarshal(body, &doc); err != nil {
|
||||
t.Fatalf("expected valid JSON, got unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
info, ok := doc["info"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected 'info' object in swagger doc")
|
||||
}
|
||||
if info["title"] != "Test API" {
|
||||
t.Fatalf("expected title=%q, got %q", "Test API", info["title"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwaggerDisabledByDefault_Good(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
|
@ -81,6 +127,32 @@ func TestSwaggerDisabledByDefault_Good(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSwaggerAuth_Good_CustomPathBypassesBearerAuth(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
e, err := api.New(
|
||||
api.WithBearerAuth("secret"),
|
||||
api.WithSwagger("Test API", "A test API service", "1.0.0"),
|
||||
api.WithSwaggerPath("/docs"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(e.Handler())
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := http.Get(srv.URL + "/docs/doc.json")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for custom swagger path without auth, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwagger_Good_SpecNotEmpty(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue