feat(api): add request ID accessor helper

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-01 13:31:19 +00:00
parent c9cf407530
commit 825b61c113
2 changed files with 53 additions and 1 deletions

View file

@ -11,6 +11,9 @@ import (
"github.com/gin-gonic/gin"
)
// requestIDContextKey is the Gin context key used by requestIDMiddleware.
const requestIDContextKey = "request_id"
// bearerAuthMiddleware validates the Authorization: Bearer <token> header.
// Requests to paths in the skip list are allowed through without authentication.
// Returns 401 with Fail("unauthorised", ...) on missing or invalid tokens.
@ -52,8 +55,19 @@ func requestIDMiddleware() gin.HandlerFunc {
id = hex.EncodeToString(b)
}
c.Set("request_id", id)
c.Set(requestIDContextKey, id)
c.Header("X-Request-ID", id)
c.Next()
}
}
// GetRequestID returns the request ID assigned by requestIDMiddleware.
// Returns an empty string when the middleware was not applied.
func GetRequestID(c *gin.Context) string {
if v, ok := c.Get(requestIDContextKey); ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}

View file

@ -26,6 +26,19 @@ func (m *mwTestGroup) RegisterRoutes(rg *gin.RouterGroup) {
})
}
type requestIDTestGroup struct {
gotID *string
}
func (g requestIDTestGroup) Name() string { return "request-id" }
func (g requestIDTestGroup) BasePath() string { return "/v1" }
func (g requestIDTestGroup) RegisterRoutes(rg *gin.RouterGroup) {
rg.GET("/secret", func(c *gin.Context) {
*g.gotID = api.GetRequestID(c)
c.JSON(http.StatusOK, api.OK("classified"))
})
}
// ── Bearer auth ─────────────────────────────────────────────────────────
func TestBearerAuth_Bad_MissingToken(t *testing.T) {
@ -151,6 +164,31 @@ func TestRequestID_Good_PreservesClientID(t *testing.T) {
}
}
func TestRequestID_Good_ContextAccessor(t *testing.T) {
gin.SetMode(gin.TestMode)
e, _ := api.New(api.WithRequestID())
var gotID string
e.Register(requestIDTestGroup{gotID: &gotID})
h := e.Handler()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/v1/secret", nil)
req.Header.Set("X-Request-ID", "client-id-xyz")
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
if gotID == "" {
t.Fatal("expected GetRequestID to return the request ID inside the handler")
}
if gotID != "client-id-xyz" {
t.Fatalf("expected GetRequestID=%q, got %q", "client-id-xyz", gotID)
}
}
// ── CORS ────────────────────────────────────────────────────────────────
func TestCORS_Good_PreflightAllOrigins(t *testing.T) {