From 0ab962a258ccd2f27cfa437387427c589da1ae51 Mon Sep 17 00:00:00 2001 From: Snider Date: Fri, 20 Feb 2026 23:41:48 +0000 Subject: [PATCH] feat: add WithCache response caching middleware Co-Authored-By: Virgil Co-Authored-By: Claude Opus 4.6 --- cache.go | 127 +++++++++++++++++++++++++ cache_test.go | 252 ++++++++++++++++++++++++++++++++++++++++++++++++++ options.go | 11 +++ 3 files changed, 390 insertions(+) create mode 100644 cache.go create mode 100644 cache_test.go diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..54b5511 --- /dev/null +++ b/cache.go @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "bytes" + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +// cacheEntry holds a cached response body, status code, headers, and expiry. +type cacheEntry struct { + status int + headers http.Header + body []byte + expires time.Time +} + +// cacheStore is a simple thread-safe in-memory cache keyed by request URL. +type cacheStore struct { + mu sync.RWMutex + entries map[string]*cacheEntry +} + +// newCacheStore creates an empty cache store. +func newCacheStore() *cacheStore { + return &cacheStore{ + entries: make(map[string]*cacheEntry), + } +} + +// get retrieves a non-expired entry for the given key. +// Returns nil if the key is missing or expired. +func (s *cacheStore) get(key string) *cacheEntry { + s.mu.RLock() + entry, ok := s.entries[key] + s.mu.RUnlock() + + if !ok { + return nil + } + if time.Now().After(entry.expires) { + s.mu.Lock() + delete(s.entries, key) + s.mu.Unlock() + return nil + } + return entry +} + +// set stores a cache entry with the given TTL. +func (s *cacheStore) set(key string, entry *cacheEntry) { + s.mu.Lock() + s.entries[key] = entry + s.mu.Unlock() +} + +// cacheWriter intercepts writes to capture the response body and status. +type cacheWriter struct { + gin.ResponseWriter + body *bytes.Buffer +} + +func (w *cacheWriter) Write(data []byte) (int, error) { + w.body.Write(data) + return w.ResponseWriter.Write(data) +} + +func (w *cacheWriter) WriteString(s string) (int, error) { + w.body.WriteString(s) + return w.ResponseWriter.WriteString(s) +} + +// cacheMiddleware returns Gin middleware that caches GET responses in memory. +// Only successful responses (2xx) are cached. Non-GET methods pass through. +func cacheMiddleware(store *cacheStore, ttl time.Duration) gin.HandlerFunc { + return func(c *gin.Context) { + // Only cache GET requests. + if c.Request.Method != http.MethodGet { + c.Next() + return + } + + key := c.Request.URL.RequestURI() + + // Serve from cache if a valid entry exists. + if entry := store.get(key); entry != nil { + for k, vals := range entry.headers { + for _, v := range vals { + c.Writer.Header().Set(k, v) + } + } + c.Writer.Header().Set("X-Cache", "HIT") + c.Writer.WriteHeader(entry.status) + _, _ = c.Writer.Write(entry.body) + c.Abort() + return + } + + // Wrap the writer to capture the response. + cw := &cacheWriter{ + ResponseWriter: c.Writer, + body: &bytes.Buffer{}, + } + c.Writer = cw + + c.Next() + + // Only cache successful responses. + status := cw.ResponseWriter.Status() + if status >= 200 && status < 300 { + headers := make(http.Header) + for k, vals := range cw.ResponseWriter.Header() { + headers[k] = vals + } + store.set(key, &cacheEntry{ + status: status, + headers: headers, + body: cw.body.Bytes(), + expires: time.Now().Add(ttl), + }) + } + } +} diff --git a/cache_test.go b/cache_test.go new file mode 100644 index 0000000..e2dfc4e --- /dev/null +++ b/cache_test.go @@ -0,0 +1,252 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/gin-gonic/gin" + + api "forge.lthn.ai/core/go-api" +) + +// cacheCounterGroup registers routes that increment a counter on each call, +// allowing tests to distinguish cached from uncached responses. +type cacheCounterGroup struct { + counter atomic.Int64 +} + +func (g *cacheCounterGroup) Name() string { return "cache-test" } +func (g *cacheCounterGroup) BasePath() string { return "/cache" } +func (g *cacheCounterGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/counter", func(c *gin.Context) { + n := g.counter.Add(1) + c.JSON(http.StatusOK, api.OK(fmt.Sprintf("call-%d", n))) + }) + rg.GET("/other", func(c *gin.Context) { + n := g.counter.Add(1) + c.JSON(http.StatusOK, api.OK(fmt.Sprintf("other-%d", n))) + }) + rg.POST("/counter", func(c *gin.Context) { + n := g.counter.Add(1) + c.JSON(http.StatusOK, api.OK(fmt.Sprintf("post-%d", n))) + }) +} + +// ── WithCache ─────────────────────────────────────────────────────────── + +func TestWithCache_Good_CachesGETResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + grp := &cacheCounterGroup{} + e, _ := api.New(api.WithCache(5 * time.Second)) + e.Register(grp) + + h := e.Handler() + + // First request — cache MISS. + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + h.ServeHTTP(w1, req1) + + if w1.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w1.Code) + } + + body1 := w1.Body.String() + if !strings.Contains(body1, "call-1") { + t.Fatalf("expected body to contain %q, got %q", "call-1", body1) + } + + // Second request — should be a cache HIT returning the same body. + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + h.ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w2.Code) + } + + body2 := w2.Body.String() + if body1 != body2 { + t.Fatalf("expected cached body %q, got %q", body1, body2) + } + + cacheHeader := w2.Header().Get("X-Cache") + if cacheHeader != "HIT" { + t.Fatalf("expected X-Cache=HIT, got %q", cacheHeader) + } + + // Counter should still be 1 (handler was not called again). + if grp.counter.Load() != 1 { + t.Fatalf("expected counter=1 (cached), got %d", grp.counter.Load()) + } +} + +func TestWithCache_Good_POSTNotCached(t *testing.T) { + gin.SetMode(gin.TestMode) + grp := &cacheCounterGroup{} + e, _ := api.New(api.WithCache(5 * time.Second)) + e.Register(grp) + + h := e.Handler() + + // First POST request. + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodPost, "/cache/counter", nil) + h.ServeHTTP(w1, req1) + + if w1.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w1.Code) + } + + var resp1 api.Response[string] + if err := json.Unmarshal(w1.Body.Bytes(), &resp1); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp1.Data != "post-1" { + t.Fatalf("expected Data=%q, got %q", "post-1", resp1.Data) + } + + // Second POST request — should NOT be cached, counter increments. + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodPost, "/cache/counter", nil) + h.ServeHTTP(w2, req2) + + var resp2 api.Response[string] + if err := json.Unmarshal(w2.Body.Bytes(), &resp2); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp2.Data != "post-2" { + t.Fatalf("expected Data=%q, got %q", "post-2", resp2.Data) + } + + // Counter should be 2 — both POST requests hit the handler. + if grp.counter.Load() != 2 { + t.Fatalf("expected counter=2, got %d", grp.counter.Load()) + } +} + +func TestWithCache_Good_DifferentPathsSeparatelyCached(t *testing.T) { + gin.SetMode(gin.TestMode) + grp := &cacheCounterGroup{} + e, _ := api.New(api.WithCache(5 * time.Second)) + e.Register(grp) + + h := e.Handler() + + // Request to /cache/counter. + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + h.ServeHTTP(w1, req1) + + body1 := w1.Body.String() + if !strings.Contains(body1, "call-1") { + t.Fatalf("expected body to contain %q, got %q", "call-1", body1) + } + + // Request to /cache/other — different path, should miss cache. + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/cache/other", nil) + h.ServeHTTP(w2, req2) + + body2 := w2.Body.String() + if !strings.Contains(body2, "other-2") { + t.Fatalf("expected body to contain %q, got %q", "other-2", body2) + } + + // Counter is 2 — both paths hit the handler. + if grp.counter.Load() != 2 { + t.Fatalf("expected counter=2, got %d", grp.counter.Load()) + } + + // Re-request /cache/counter — should serve cached "call-1". + w3 := httptest.NewRecorder() + req3, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + h.ServeHTTP(w3, req3) + + body3 := w3.Body.String() + if body1 != body3 { + t.Fatalf("expected cached body %q, got %q", body1, body3) + } + + // Counter unchanged — served from cache. + if grp.counter.Load() != 2 { + t.Fatalf("expected counter=2 (cached), got %d", grp.counter.Load()) + } +} + +func TestWithCache_Good_CombinesWithOtherMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + grp := &cacheCounterGroup{} + e, _ := api.New( + api.WithRequestID(), + api.WithCache(5*time.Second), + ) + e.Register(grp) + + h := e.Handler() + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + // RequestID middleware should still set X-Request-ID. + rid := w.Header().Get("X-Request-ID") + if rid == "" { + t.Fatal("expected X-Request-ID header from WithRequestID") + } + + // Body should contain the expected response. + body := w.Body.String() + if !strings.Contains(body, "call-1") { + t.Fatalf("expected body to contain %q, got %q", "call-1", body) + } +} + +func TestWithCache_Good_ExpiredCacheMisses(t *testing.T) { + gin.SetMode(gin.TestMode) + grp := &cacheCounterGroup{} + e, _ := api.New(api.WithCache(50 * time.Millisecond)) + e.Register(grp) + + h := e.Handler() + + // First request — populates cache. + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + h.ServeHTTP(w1, req1) + + body1 := w1.Body.String() + if !strings.Contains(body1, "call-1") { + t.Fatalf("expected body to contain %q, got %q", "call-1", body1) + } + + // Wait for cache to expire. + time.Sleep(100 * time.Millisecond) + + // Second request — cache expired, handler called again. + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + h.ServeHTTP(w2, req2) + + body2 := w2.Body.String() + if !strings.Contains(body2, "call-2") { + t.Fatalf("expected body to contain %q after expiry, got %q", "call-2", body2) + } + + // Counter should be 2 — both requests hit the handler. + if grp.counter.Load() != 2 { + t.Fatalf("expected counter=2, got %d", grp.counter.Load()) + } +} diff --git a/options.go b/options.go index 634107d..8ea490f 100644 --- a/options.go +++ b/options.go @@ -195,3 +195,14 @@ func WithTimeout(d time.Duration) Option { func timeoutResponse(c *gin.Context) { c.JSON(http.StatusGatewayTimeout, Fail("timeout", "Request timed out")) } + +// WithCache adds in-memory response caching middleware for GET requests. +// Successful (2xx) GET responses are cached for the given TTL and served +// with an X-Cache: HIT header on subsequent requests. Non-GET methods +// and error responses pass through uncached. +func WithCache(ttl time.Duration) Option { + return func(e *Engine) { + store := newCacheStore() + e.middlewares = append(e.middlewares, cacheMiddleware(store, ttl)) + } +}