diff --git a/middleware.go b/middleware.go index 55fe8ae..df5307d 100644 --- a/middleware.go +++ b/middleware.go @@ -11,6 +11,9 @@ import ( "github.com/gin-gonic/gin" ) +// requestIDContextKey is the Gin context key used by requestIDMiddleware. +const requestIDContextKey = "request_id" + // bearerAuthMiddleware validates the Authorization: Bearer header. // Requests to paths in the skip list are allowed through without authentication. // Returns 401 with Fail("unauthorised", ...) on missing or invalid tokens. @@ -52,8 +55,19 @@ func requestIDMiddleware() gin.HandlerFunc { id = hex.EncodeToString(b) } - c.Set("request_id", id) + c.Set(requestIDContextKey, id) c.Header("X-Request-ID", id) c.Next() } } + +// GetRequestID returns the request ID assigned by requestIDMiddleware. +// Returns an empty string when the middleware was not applied. +func GetRequestID(c *gin.Context) string { + if v, ok := c.Get(requestIDContextKey); ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} diff --git a/middleware_test.go b/middleware_test.go index a44da53..19ed777 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -26,6 +26,19 @@ func (m *mwTestGroup) RegisterRoutes(rg *gin.RouterGroup) { }) } +type requestIDTestGroup struct { + gotID *string +} + +func (g requestIDTestGroup) Name() string { return "request-id" } +func (g requestIDTestGroup) BasePath() string { return "/v1" } +func (g requestIDTestGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/secret", func(c *gin.Context) { + *g.gotID = api.GetRequestID(c) + c.JSON(http.StatusOK, api.OK("classified")) + }) +} + // ── Bearer auth ───────────────────────────────────────────────────────── func TestBearerAuth_Bad_MissingToken(t *testing.T) { @@ -151,6 +164,31 @@ func TestRequestID_Good_PreservesClientID(t *testing.T) { } } +func TestRequestID_Good_ContextAccessor(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRequestID()) + + var gotID string + e.Register(requestIDTestGroup{gotID: &gotID}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/secret", nil) + req.Header.Set("X-Request-ID", "client-id-xyz") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + if gotID == "" { + t.Fatal("expected GetRequestID to return the request ID inside the handler") + } + if gotID != "client-id-xyz" { + t.Fatalf("expected GetRequestID=%q, got %q", "client-id-xyz", gotID) + } +} + // ── CORS ──────────────────────────────────────────────────────────────── func TestCORS_Good_PreflightAllOrigins(t *testing.T) {