diff --git a/cache.go b/cache.go index 9830aa6..01fdd91 100644 --- a/cache.go +++ b/cache.go @@ -125,10 +125,18 @@ func cacheMiddleware(store *cacheStore, ttl time.Duration) gin.HandlerFunc { // Serve from cache if a valid entry exists. if entry := store.get(key); entry != nil { for k, vals := range entry.headers { + if http.CanonicalHeaderKey(k) == "X-Request-ID" { + continue + } for _, v := range vals { c.Writer.Header().Set(k, v) } } + if requestID := GetRequestID(c); requestID != "" { + c.Writer.Header().Set("X-Request-ID", requestID) + } else if requestID := c.GetHeader("X-Request-ID"); requestID != "" { + c.Writer.Header().Set("X-Request-ID", requestID) + } c.Writer.Header().Set("X-Cache", "HIT") c.Writer.WriteHeader(entry.status) _, _ = c.Writer.Write(entry.body) diff --git a/cache_test.go b/cache_test.go index 3e48349..fa88610 100644 --- a/cache_test.go +++ b/cache_test.go @@ -214,6 +214,47 @@ func TestWithCache_Good_CombinesWithOtherMiddleware(t *testing.T) { } } +func TestWithCache_Good_PreservesCurrentRequestIDOnHit(t *testing.T) { + gin.SetMode(gin.TestMode) + grp := &cacheCounterGroup{} + e, _ := api.New( + api.WithRequestID(), + api.WithCache(5*time.Second), + ) + e.Register(grp) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + req1.Header.Set("X-Request-ID", "first-request-id") + h.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w1.Code) + } + if got := w1.Header().Get("X-Request-ID"); got != "first-request-id" { + t.Fatalf("expected first response request ID %q, got %q", "first-request-id", got) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + req2.Header.Set("X-Request-ID", "second-request-id") + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w2.Code) + } + + if got := w2.Header().Get("X-Request-ID"); got != "second-request-id" { + t.Fatalf("expected cached response to preserve current request ID %q, got %q", "second-request-id", got) + } + if got := w2.Header().Get("X-Cache"); got != "HIT" { + t.Fatalf("expected X-Cache=HIT, got %q", got) + } + if body1, body2 := w1.Body.String(), w2.Body.String(); body1 != body2 { + t.Fatalf("expected cached body %q, got %q", body1, body2) + } +} + func TestWithCache_Good_ExpiredCacheMisses(t *testing.T) { gin.SetMode(gin.TestMode) grp := &cacheCounterGroup{}