From d21734d8d9cf61af01370b0036f53c7c02b83e02 Mon Sep 17 00:00:00 2001 From: Snider Date: Fri, 20 Feb 2026 15:49:35 +0000 Subject: [PATCH] feat: add bearer auth, request ID, and CORS middleware Co-Authored-By: Virgil Co-Authored-By: Claude Opus 4.6 --- api.go | 12 ++- go.mod | 11 ++- go.sum | 13 +-- middleware.go | 59 ++++++++++++ middleware_test.go | 220 +++++++++++++++++++++++++++++++++++++++++++++ options.go | 57 ++++++++++++ 6 files changed, 360 insertions(+), 12 deletions(-) create mode 100644 middleware.go create mode 100644 middleware_test.go diff --git a/api.go b/api.go index d18a9c3..2d8a763 100644 --- a/api.go +++ b/api.go @@ -21,8 +21,9 @@ const shutdownTimeout = 10 * time.Second // Engine is the central API server managing route groups and middleware. type Engine struct { - addr string - groups []RouteGroup + addr string + groups []RouteGroup + middlewares []gin.HandlerFunc } // New creates an Engine with the given options. @@ -90,11 +91,16 @@ func (e *Engine) Serve(ctx context.Context) error { } // build creates a configured Gin engine with recovery middleware, -// the health endpoint, and all registered route groups. +// user-supplied middleware, the health endpoint, and all registered route groups. func (e *Engine) build() *gin.Engine { r := gin.New() r.Use(gin.Recovery()) + // Apply user-supplied middleware after recovery but before routes. + for _, mw := range e.middlewares { + r.Use(mw) + } + // Built-in health check. r.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, OK("healthy")) diff --git a/go.mod b/go.mod index a1e51be..adf9a07 100644 --- a/go.mod +++ b/go.mod @@ -2,24 +2,27 @@ module forge.lthn.ai/core/go-api go 1.25.5 -require github.com/gin-gonic/gin v1.11.0 +require ( + github.com/gin-contrib/cors v1.7.6 + github.com/gin-gonic/gin v1.11.0 +) require ( github.com/bytedance/sonic v1.14.0 // indirect github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect - github.com/gabriel-vasile/mimetype v1.4.8 // indirect + github.com/gabriel-vasile/mimetype v1.4.9 // indirect github.com/gin-contrib/sse v1.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.27.0 // indirect - github.com/goccy/go-json v0.10.2 // indirect + github.com/goccy/go-json v0.10.5 // indirect github.com/goccy/go-yaml v1.18.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/quic-go/qpack v0.5.1 // indirect diff --git a/go.sum b/go.sum index a25fff5..7539074 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,10 @@ github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gE github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= -github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= +github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY= +github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok= +github.com/gin-contrib/cors v1.7.6 h1:3gQ8GMzs1Ylpf70y8bMw4fVpycXIeX1ZemuSQIsnQQY= +github.com/gin-contrib/cors v1.7.6/go.mod h1:Ulcl+xN4jel9t1Ry8vqph23a60FwH9xVLd+3ykmTjOk= github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk= @@ -21,8 +23,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4= github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= -github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= -github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -36,8 +38,9 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..55fe8ae --- /dev/null +++ b/middleware.go @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "crypto/rand" + "encoding/hex" + "net/http" + "strings" + + "github.com/gin-gonic/gin" +) + +// bearerAuthMiddleware validates the Authorization: Bearer header. +// Requests to paths in the skip list are allowed through without authentication. +// Returns 401 with Fail("unauthorised", ...) on missing or invalid tokens. +func bearerAuthMiddleware(token string, skip []string) gin.HandlerFunc { + return func(c *gin.Context) { + // Check whether the request path should bypass authentication. + for _, path := range skip { + if strings.HasPrefix(c.Request.URL.Path, path) { + c.Next() + return + } + } + + header := c.GetHeader("Authorization") + if header == "" { + c.AbortWithStatusJSON(http.StatusUnauthorized, Fail("unauthorised", "missing authorization header")) + return + } + + parts := strings.SplitN(header, " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") || parts[1] != token { + c.AbortWithStatusJSON(http.StatusUnauthorized, Fail("unauthorised", "invalid bearer token")) + return + } + + c.Next() + } +} + +// requestIDMiddleware ensures every response carries an X-Request-ID header. +// If the client sends one, it is preserved; otherwise a random 16-byte hex +// string is generated. The ID is also stored in the Gin context as "request_id". +func requestIDMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + id := c.GetHeader("X-Request-ID") + if id == "" { + b := make([]byte, 16) + _, _ = rand.Read(b) + id = hex.EncodeToString(b) + } + + c.Set("request_id", id) + c.Header("X-Request-ID", id) + c.Next() + } +} diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 0000000..3049192 --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + + api "forge.lthn.ai/core/go-api" +) + +// ── Helpers ───────────────────────────────────────────────────────────── + +// mwTestGroup provides a simple /v1/secret endpoint for middleware tests. +type mwTestGroup struct{} + +func (m *mwTestGroup) Name() string { return "mw-test" } +func (m *mwTestGroup) BasePath() string { return "/v1" } +func (m *mwTestGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/secret", func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("classified")) + }) +} + +// ── Bearer auth ───────────────────────────────────────────────────────── + +func TestBearerAuth_Bad_MissingToken(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithBearerAuth("s3cret")) + e.Register(&mwTestGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/secret", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Error == nil || resp.Error.Code != "unauthorised" { + t.Fatalf("expected error code=%q, got %+v", "unauthorised", resp.Error) + } +} + +func TestBearerAuth_Bad_WrongToken(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithBearerAuth("s3cret")) + e.Register(&mwTestGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/secret", nil) + req.Header.Set("Authorization", "Bearer wrong-token") + h.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Error == nil || resp.Error.Code != "unauthorised" { + t.Fatalf("expected error code=%q, got %+v", "unauthorised", resp.Error) + } +} + +func TestBearerAuth_Good_CorrectToken(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithBearerAuth("s3cret")) + e.Register(&mwTestGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/secret", nil) + req.Header.Set("Authorization", "Bearer s3cret") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp api.Response[string] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Data != "classified" { + t.Fatalf("expected Data=%q, got %q", "classified", resp.Data) + } +} + +func TestBearerAuth_Good_HealthBypassesAuth(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithBearerAuth("s3cret")) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/health", nil) + // No Authorization header. + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200 for /health, got %d", w.Code) + } +} + +// ── Request ID ────────────────────────────────────────────────────────── + +func TestRequestID_Good_GeneratedWhenMissing(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRequestID()) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/health", nil) + h.ServeHTTP(w, req) + + id := w.Header().Get("X-Request-ID") + if id == "" { + t.Fatal("expected X-Request-ID header to be set") + } + // 16 bytes = 32 hex characters. + if len(id) != 32 { + t.Fatalf("expected 32-char hex ID, got %d chars: %q", len(id), id) + } +} + +func TestRequestID_Good_PreservesClientID(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRequestID()) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/health", nil) + req.Header.Set("X-Request-ID", "client-id-abc") + h.ServeHTTP(w, req) + + id := w.Header().Get("X-Request-ID") + if id != "client-id-abc" { + t.Fatalf("expected X-Request-ID=%q, got %q", "client-id-abc", id) + } +} + +// ── CORS ──────────────────────────────────────────────────────────────── + +func TestCORS_Good_PreflightAllOrigins(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithCORS("*")) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodOptions, "/health", nil) + req.Header.Set("Origin", "https://example.com") + req.Header.Set("Access-Control-Request-Method", "GET") + req.Header.Set("Access-Control-Request-Headers", "Authorization") + h.ServeHTTP(w, req) + + if w.Code != http.StatusNoContent && w.Code != http.StatusOK { + t.Fatalf("expected 200 or 204 for preflight, got %d", w.Code) + } + + origin := w.Header().Get("Access-Control-Allow-Origin") + if origin != "*" { + t.Fatalf("expected Access-Control-Allow-Origin=%q, got %q", "*", origin) + } + + methods := w.Header().Get("Access-Control-Allow-Methods") + if methods == "" { + t.Fatal("expected Access-Control-Allow-Methods to be set") + } + + headers := w.Header().Get("Access-Control-Allow-Headers") + if headers == "" { + t.Fatal("expected Access-Control-Allow-Headers to be set") + } +} + +func TestCORS_Good_SpecificOrigin(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithCORS("https://app.example.com")) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodOptions, "/health", nil) + req.Header.Set("Origin", "https://app.example.com") + req.Header.Set("Access-Control-Request-Method", "POST") + h.ServeHTTP(w, req) + + origin := w.Header().Get("Access-Control-Allow-Origin") + if origin != "https://app.example.com" { + t.Fatalf("expected origin=%q, got %q", "https://app.example.com", origin) + } +} + +func TestCORS_Bad_DisallowedOrigin(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithCORS("https://allowed.example.com")) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodOptions, "/health", nil) + req.Header.Set("Origin", "https://evil.example.com") + req.Header.Set("Access-Control-Request-Method", "GET") + h.ServeHTTP(w, req) + + origin := w.Header().Get("Access-Control-Allow-Origin") + if origin != "" { + t.Fatalf("expected no Access-Control-Allow-Origin for disallowed origin, got %q", origin) + } +} diff --git a/options.go b/options.go index bd628c4..8b41625 100644 --- a/options.go +++ b/options.go @@ -2,6 +2,13 @@ package api +import ( + "time" + + "github.com/gin-contrib/cors" + "github.com/gin-gonic/gin" +) + // Option configures an Engine during construction. type Option func(*Engine) @@ -11,3 +18,53 @@ func WithAddr(addr string) Option { e.addr = addr } } + +// WithBearerAuth adds bearer token authentication middleware. +// Requests to /health and paths starting with /swagger are exempt. +func WithBearerAuth(token string) Option { + return func(e *Engine) { + skip := []string{"/health", "/swagger"} + e.middlewares = append(e.middlewares, bearerAuthMiddleware(token, skip)) + } +} + +// WithRequestID adds middleware that assigns an X-Request-ID to every response. +// Client-provided IDs are preserved; otherwise a random hex ID is generated. +func WithRequestID() Option { + return func(e *Engine) { + e.middlewares = append(e.middlewares, requestIDMiddleware()) + } +} + +// WithCORS configures Cross-Origin Resource Sharing via gin-contrib/cors. +// Pass "*" to allow all origins, or supply specific origin URLs. +// Standard methods (GET, POST, PUT, PATCH, DELETE, OPTIONS) and common +// headers (Authorization, Content-Type, X-Request-ID) are permitted. +func WithCORS(allowOrigins ...string) Option { + return func(e *Engine) { + cfg := cors.Config{ + AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}, + AllowHeaders: []string{"Authorization", "Content-Type", "X-Request-ID"}, + MaxAge: 12 * time.Hour, + } + + for _, o := range allowOrigins { + if o == "*" { + cfg.AllowAllOrigins = true + break + } + } + if !cfg.AllowAllOrigins { + cfg.AllowOrigins = allowOrigins + } + + e.middlewares = append(e.middlewares, cors.New(cfg)) + } +} + +// WithMiddleware appends arbitrary Gin middleware to the engine. +func WithMiddleware(mw ...gin.HandlerFunc) Option { + return func(e *Engine) { + e.middlewares = append(e.middlewares, mw...) + } +}