feat: add WithCache response caching middleware
Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
64a8b16ca2
commit
0ab962a258
3 changed files with 390 additions and 0 deletions
127
cache.go
Normal file
127
cache.go
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// cacheEntry holds a cached response body, status code, headers, and expiry.
|
||||
type cacheEntry struct {
|
||||
status int
|
||||
headers http.Header
|
||||
body []byte
|
||||
expires time.Time
|
||||
}
|
||||
|
||||
// cacheStore is a simple thread-safe in-memory cache keyed by request URL.
|
||||
type cacheStore struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]*cacheEntry
|
||||
}
|
||||
|
||||
// newCacheStore creates an empty cache store.
|
||||
func newCacheStore() *cacheStore {
|
||||
return &cacheStore{
|
||||
entries: make(map[string]*cacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// get retrieves a non-expired entry for the given key.
|
||||
// Returns nil if the key is missing or expired.
|
||||
func (s *cacheStore) get(key string) *cacheEntry {
|
||||
s.mu.RLock()
|
||||
entry, ok := s.entries[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if time.Now().After(entry.expires) {
|
||||
s.mu.Lock()
|
||||
delete(s.entries, key)
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
// set stores a cache entry with the given TTL.
|
||||
func (s *cacheStore) set(key string, entry *cacheEntry) {
|
||||
s.mu.Lock()
|
||||
s.entries[key] = entry
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// cacheWriter intercepts writes to capture the response body and status.
|
||||
type cacheWriter struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
}
|
||||
|
||||
func (w *cacheWriter) Write(data []byte) (int, error) {
|
||||
w.body.Write(data)
|
||||
return w.ResponseWriter.Write(data)
|
||||
}
|
||||
|
||||
func (w *cacheWriter) WriteString(s string) (int, error) {
|
||||
w.body.WriteString(s)
|
||||
return w.ResponseWriter.WriteString(s)
|
||||
}
|
||||
|
||||
// cacheMiddleware returns Gin middleware that caches GET responses in memory.
|
||||
// Only successful responses (2xx) are cached. Non-GET methods pass through.
|
||||
func cacheMiddleware(store *cacheStore, ttl time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Only cache GET requests.
|
||||
if c.Request.Method != http.MethodGet {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
key := c.Request.URL.RequestURI()
|
||||
|
||||
// Serve from cache if a valid entry exists.
|
||||
if entry := store.get(key); entry != nil {
|
||||
for k, vals := range entry.headers {
|
||||
for _, v := range vals {
|
||||
c.Writer.Header().Set(k, v)
|
||||
}
|
||||
}
|
||||
c.Writer.Header().Set("X-Cache", "HIT")
|
||||
c.Writer.WriteHeader(entry.status)
|
||||
_, _ = c.Writer.Write(entry.body)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Wrap the writer to capture the response.
|
||||
cw := &cacheWriter{
|
||||
ResponseWriter: c.Writer,
|
||||
body: &bytes.Buffer{},
|
||||
}
|
||||
c.Writer = cw
|
||||
|
||||
c.Next()
|
||||
|
||||
// Only cache successful responses.
|
||||
status := cw.ResponseWriter.Status()
|
||||
if status >= 200 && status < 300 {
|
||||
headers := make(http.Header)
|
||||
for k, vals := range cw.ResponseWriter.Header() {
|
||||
headers[k] = vals
|
||||
}
|
||||
store.set(key, &cacheEntry{
|
||||
status: status,
|
||||
headers: headers,
|
||||
body: cw.body.Bytes(),
|
||||
expires: time.Now().Add(ttl),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
252
cache_test.go
Normal file
252
cache_test.go
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package api_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
api "forge.lthn.ai/core/go-api"
|
||||
)
|
||||
|
||||
// cacheCounterGroup registers routes that increment a counter on each call,
|
||||
// allowing tests to distinguish cached from uncached responses.
|
||||
type cacheCounterGroup struct {
|
||||
counter atomic.Int64
|
||||
}
|
||||
|
||||
func (g *cacheCounterGroup) Name() string { return "cache-test" }
|
||||
func (g *cacheCounterGroup) BasePath() string { return "/cache" }
|
||||
func (g *cacheCounterGroup) RegisterRoutes(rg *gin.RouterGroup) {
|
||||
rg.GET("/counter", func(c *gin.Context) {
|
||||
n := g.counter.Add(1)
|
||||
c.JSON(http.StatusOK, api.OK(fmt.Sprintf("call-%d", n)))
|
||||
})
|
||||
rg.GET("/other", func(c *gin.Context) {
|
||||
n := g.counter.Add(1)
|
||||
c.JSON(http.StatusOK, api.OK(fmt.Sprintf("other-%d", n)))
|
||||
})
|
||||
rg.POST("/counter", func(c *gin.Context) {
|
||||
n := g.counter.Add(1)
|
||||
c.JSON(http.StatusOK, api.OK(fmt.Sprintf("post-%d", n)))
|
||||
})
|
||||
}
|
||||
|
||||
// ── WithCache ───────────────────────────────────────────────────────────
|
||||
|
||||
func TestWithCache_Good_CachesGETResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
grp := &cacheCounterGroup{}
|
||||
e, _ := api.New(api.WithCache(5 * time.Second))
|
||||
e.Register(grp)
|
||||
|
||||
h := e.Handler()
|
||||
|
||||
// First request — cache MISS.
|
||||
w1 := httptest.NewRecorder()
|
||||
req1, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil)
|
||||
h.ServeHTTP(w1, req1)
|
||||
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w1.Code)
|
||||
}
|
||||
|
||||
body1 := w1.Body.String()
|
||||
if !strings.Contains(body1, "call-1") {
|
||||
t.Fatalf("expected body to contain %q, got %q", "call-1", body1)
|
||||
}
|
||||
|
||||
// Second request — should be a cache HIT returning the same body.
|
||||
w2 := httptest.NewRecorder()
|
||||
req2, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil)
|
||||
h.ServeHTTP(w2, req2)
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w2.Code)
|
||||
}
|
||||
|
||||
body2 := w2.Body.String()
|
||||
if body1 != body2 {
|
||||
t.Fatalf("expected cached body %q, got %q", body1, body2)
|
||||
}
|
||||
|
||||
cacheHeader := w2.Header().Get("X-Cache")
|
||||
if cacheHeader != "HIT" {
|
||||
t.Fatalf("expected X-Cache=HIT, got %q", cacheHeader)
|
||||
}
|
||||
|
||||
// Counter should still be 1 (handler was not called again).
|
||||
if grp.counter.Load() != 1 {
|
||||
t.Fatalf("expected counter=1 (cached), got %d", grp.counter.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithCache_Good_POSTNotCached(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
grp := &cacheCounterGroup{}
|
||||
e, _ := api.New(api.WithCache(5 * time.Second))
|
||||
e.Register(grp)
|
||||
|
||||
h := e.Handler()
|
||||
|
||||
// First POST request.
|
||||
w1 := httptest.NewRecorder()
|
||||
req1, _ := http.NewRequest(http.MethodPost, "/cache/counter", nil)
|
||||
h.ServeHTTP(w1, req1)
|
||||
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w1.Code)
|
||||
}
|
||||
|
||||
var resp1 api.Response[string]
|
||||
if err := json.Unmarshal(w1.Body.Bytes(), &resp1); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
if resp1.Data != "post-1" {
|
||||
t.Fatalf("expected Data=%q, got %q", "post-1", resp1.Data)
|
||||
}
|
||||
|
||||
// Second POST request — should NOT be cached, counter increments.
|
||||
w2 := httptest.NewRecorder()
|
||||
req2, _ := http.NewRequest(http.MethodPost, "/cache/counter", nil)
|
||||
h.ServeHTTP(w2, req2)
|
||||
|
||||
var resp2 api.Response[string]
|
||||
if err := json.Unmarshal(w2.Body.Bytes(), &resp2); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
if resp2.Data != "post-2" {
|
||||
t.Fatalf("expected Data=%q, got %q", "post-2", resp2.Data)
|
||||
}
|
||||
|
||||
// Counter should be 2 — both POST requests hit the handler.
|
||||
if grp.counter.Load() != 2 {
|
||||
t.Fatalf("expected counter=2, got %d", grp.counter.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithCache_Good_DifferentPathsSeparatelyCached(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
grp := &cacheCounterGroup{}
|
||||
e, _ := api.New(api.WithCache(5 * time.Second))
|
||||
e.Register(grp)
|
||||
|
||||
h := e.Handler()
|
||||
|
||||
// Request to /cache/counter.
|
||||
w1 := httptest.NewRecorder()
|
||||
req1, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil)
|
||||
h.ServeHTTP(w1, req1)
|
||||
|
||||
body1 := w1.Body.String()
|
||||
if !strings.Contains(body1, "call-1") {
|
||||
t.Fatalf("expected body to contain %q, got %q", "call-1", body1)
|
||||
}
|
||||
|
||||
// Request to /cache/other — different path, should miss cache.
|
||||
w2 := httptest.NewRecorder()
|
||||
req2, _ := http.NewRequest(http.MethodGet, "/cache/other", nil)
|
||||
h.ServeHTTP(w2, req2)
|
||||
|
||||
body2 := w2.Body.String()
|
||||
if !strings.Contains(body2, "other-2") {
|
||||
t.Fatalf("expected body to contain %q, got %q", "other-2", body2)
|
||||
}
|
||||
|
||||
// Counter is 2 — both paths hit the handler.
|
||||
if grp.counter.Load() != 2 {
|
||||
t.Fatalf("expected counter=2, got %d", grp.counter.Load())
|
||||
}
|
||||
|
||||
// Re-request /cache/counter — should serve cached "call-1".
|
||||
w3 := httptest.NewRecorder()
|
||||
req3, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil)
|
||||
h.ServeHTTP(w3, req3)
|
||||
|
||||
body3 := w3.Body.String()
|
||||
if body1 != body3 {
|
||||
t.Fatalf("expected cached body %q, got %q", body1, body3)
|
||||
}
|
||||
|
||||
// Counter unchanged — served from cache.
|
||||
if grp.counter.Load() != 2 {
|
||||
t.Fatalf("expected counter=2 (cached), got %d", grp.counter.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithCache_Good_CombinesWithOtherMiddleware(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
grp := &cacheCounterGroup{}
|
||||
e, _ := api.New(
|
||||
api.WithRequestID(),
|
||||
api.WithCache(5*time.Second),
|
||||
)
|
||||
e.Register(grp)
|
||||
|
||||
h := e.Handler()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// RequestID middleware should still set X-Request-ID.
|
||||
rid := w.Header().Get("X-Request-ID")
|
||||
if rid == "" {
|
||||
t.Fatal("expected X-Request-ID header from WithRequestID")
|
||||
}
|
||||
|
||||
// Body should contain the expected response.
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "call-1") {
|
||||
t.Fatalf("expected body to contain %q, got %q", "call-1", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithCache_Good_ExpiredCacheMisses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
grp := &cacheCounterGroup{}
|
||||
e, _ := api.New(api.WithCache(50 * time.Millisecond))
|
||||
e.Register(grp)
|
||||
|
||||
h := e.Handler()
|
||||
|
||||
// First request — populates cache.
|
||||
w1 := httptest.NewRecorder()
|
||||
req1, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil)
|
||||
h.ServeHTTP(w1, req1)
|
||||
|
||||
body1 := w1.Body.String()
|
||||
if !strings.Contains(body1, "call-1") {
|
||||
t.Fatalf("expected body to contain %q, got %q", "call-1", body1)
|
||||
}
|
||||
|
||||
// Wait for cache to expire.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Second request — cache expired, handler called again.
|
||||
w2 := httptest.NewRecorder()
|
||||
req2, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil)
|
||||
h.ServeHTTP(w2, req2)
|
||||
|
||||
body2 := w2.Body.String()
|
||||
if !strings.Contains(body2, "call-2") {
|
||||
t.Fatalf("expected body to contain %q after expiry, got %q", "call-2", body2)
|
||||
}
|
||||
|
||||
// Counter should be 2 — both requests hit the handler.
|
||||
if grp.counter.Load() != 2 {
|
||||
t.Fatalf("expected counter=2, got %d", grp.counter.Load())
|
||||
}
|
||||
}
|
||||
11
options.go
11
options.go
|
|
@ -195,3 +195,14 @@ func WithTimeout(d time.Duration) Option {
|
|||
func timeoutResponse(c *gin.Context) {
|
||||
c.JSON(http.StatusGatewayTimeout, Fail("timeout", "Request timed out"))
|
||||
}
|
||||
|
||||
// WithCache adds in-memory response caching middleware for GET requests.
|
||||
// Successful (2xx) GET responses are cached for the given TTL and served
|
||||
// with an X-Cache: HIT header on subsequent requests. Non-GET methods
|
||||
// and error responses pass through uncached.
|
||||
func WithCache(ttl time.Duration) Option {
|
||||
return func(e *Engine) {
|
||||
store := newCacheStore()
|
||||
e.middlewares = append(e.middlewares, cacheMiddleware(store, ttl))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue