diff --git a/cache.go b/cache.go index 20cc048..d9d86c9 100644 --- a/cache.go +++ b/cache.go @@ -19,25 +19,29 @@ type cacheEntry struct { status int headers http.Header body []byte + size int expires time.Time } // cacheStore is a simple thread-safe in-memory cache keyed by request URL. type cacheStore struct { - mu sync.RWMutex - entries map[string]*cacheEntry - order *list.List - index map[string]*list.Element - maxEntries int + mu sync.RWMutex + entries map[string]*cacheEntry + order *list.List + index map[string]*list.Element + maxEntries int + maxBytes int + currentBytes int } // newCacheStore creates an empty cache store. -func newCacheStore(maxEntries int) *cacheStore { +func newCacheStore(maxEntries, maxBytes int) *cacheStore { return &cacheStore{ entries: make(map[string]*cacheEntry), order: list.New(), index: make(map[string]*list.Element), maxEntries: maxEntries, + maxBytes: maxBytes, } } @@ -62,6 +66,10 @@ func (s *cacheStore) get(key string) *cacheEntry { s.order.Remove(elem) delete(s.index, key) } + s.currentBytes -= entry.size + if s.currentBytes < 0 { + s.currentBytes = 0 + } delete(s.entries, key) s.mu.Unlock() return nil @@ -72,29 +80,81 @@ func (s *cacheStore) get(key string) *cacheEntry { // set stores a cache entry with the given TTL. 
func (s *cacheStore) set(key string, entry *cacheEntry) { s.mu.Lock() + if entry.size <= 0 { + entry.size = cacheEntrySize(entry.headers, entry.body) + } + if elem, ok := s.index[key]; ok { + if existing, exists := s.entries[key]; exists { + s.currentBytes -= existing.size + if s.currentBytes < 0 { + s.currentBytes = 0 + } + } s.order.MoveToFront(elem) s.entries[key] = entry + s.currentBytes += entry.size + s.evictBySizeLocked() s.mu.Unlock() return } - if s.maxEntries > 0 && len(s.entries) >= s.maxEntries { - back := s.order.Back() - if back != nil { - oldKey := back.Value.(string) - delete(s.entries, oldKey) - delete(s.index, oldKey) - s.order.Remove(back) + if s.maxBytes > 0 && entry.size > s.maxBytes { + s.mu.Unlock() + return + } + + for (s.maxEntries > 0 && len(s.entries) >= s.maxEntries) || s.wouldExceedBytesLocked(entry.size) { + if !s.evictOldestLocked() { + break } } + if s.maxBytes > 0 && s.wouldExceedBytesLocked(entry.size) { + s.mu.Unlock() + return + } + s.entries[key] = entry elem := s.order.PushFront(key) s.index[key] = elem + s.currentBytes += entry.size s.mu.Unlock() } +func (s *cacheStore) wouldExceedBytesLocked(nextSize int) bool { + if s.maxBytes <= 0 { + return false + } + return s.currentBytes+nextSize > s.maxBytes +} + +func (s *cacheStore) evictBySizeLocked() { + for s.maxBytes > 0 && s.currentBytes > s.maxBytes { + if !s.evictOldestLocked() { + return + } + } +} + +func (s *cacheStore) evictOldestLocked() bool { + back := s.order.Back() + if back == nil { + return false + } + oldKey := back.Value.(string) + if existing, ok := s.entries[oldKey]; ok { + s.currentBytes -= existing.size + if s.currentBytes < 0 { + s.currentBytes = 0 + } + } + delete(s.entries, oldKey) + delete(s.index, oldKey) + s.order.Remove(back) + return true +} + // cacheWriter intercepts writes to capture the response body and status. 
type cacheWriter struct { gin.ResponseWriter @@ -172,6 +232,7 @@ func cacheMiddleware(store *cacheStore, ttl time.Duration) gin.HandlerFunc { status: status, headers: headers, body: cw.body.Bytes(), + size: cacheEntrySize(headers, cw.body.Bytes()), expires: time.Now().Add(ttl), }) } @@ -185,3 +246,14 @@ func cacheMiddleware(store *cacheStore, ttl time.Duration) gin.HandlerFunc { func refreshCachedResponseMeta(body []byte, meta *Meta) []byte { return refreshResponseMetaBody(body, meta) } + +func cacheEntrySize(headers http.Header, body []byte) int { + size := len(body) + for key, vals := range headers { + size += len(key) + for _, val := range vals { + size += len(val) + } + } + return size +} diff --git a/cache_test.go b/cache_test.go index e485444..aec4989 100644 --- a/cache_test.go +++ b/cache_test.go @@ -40,6 +40,23 @@ func (g *cacheCounterGroup) RegisterRoutes(rg *gin.RouterGroup) { }) } +type cacheSizedGroup struct { + counter atomic.Int64 +} + +func (g *cacheSizedGroup) Name() string { return "cache-sized" } +func (g *cacheSizedGroup) BasePath() string { return "/cache" } +func (g *cacheSizedGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/small", func(c *gin.Context) { + n := g.counter.Add(1) + c.JSON(http.StatusOK, api.OK(fmt.Sprintf("small-%d-%s", n, strings.Repeat("a", 96)))) + }) + rg.GET("/large", func(c *gin.Context) { + n := g.counter.Add(1) + c.JSON(http.StatusOK, api.OK(fmt.Sprintf("large-%d-%s", n, strings.Repeat("b", 96)))) + }) +} + // ── WithCache ─────────────────────────────────────────────────────────── func TestWithCache_Good_CachesGETResponse(t *testing.T) { @@ -467,3 +484,41 @@ func TestWithCache_Good_EvictsWhenCapacityReached(t *testing.T) { t.Fatalf("expected counter=3 after eviction, got %d", grp.counter.Load()) } } + +func TestWithCache_Good_EvictsWhenSizeLimitReached(t *testing.T) { + gin.SetMode(gin.TestMode) + grp := &cacheSizedGroup{} + e, _ := api.New(api.WithCache(5*time.Second, 10, 250)) + e.Register(grp) + + h := 
e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/cache/small", nil) + h.ServeHTTP(w1, req1) + if !strings.Contains(w1.Body.String(), "small-1") { + t.Fatalf("expected first response to contain %q, got %q", "small-1", w1.Body.String()) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/cache/large", nil) + h.ServeHTTP(w2, req2) + if !strings.Contains(w2.Body.String(), "large-2") { + t.Fatalf("expected second response to contain %q, got %q", "large-2", w2.Body.String()) + } + + w3 := httptest.NewRecorder() + req3, _ := http.NewRequest(http.MethodGet, "/cache/small", nil) + h.ServeHTTP(w3, req3) + if !strings.Contains(w3.Body.String(), "small-3") { + t.Fatalf("expected size-limited cache to evict the oldest entry, got %q", w3.Body.String()) + } + + if got := w3.Header().Get("X-Cache"); got != "" { + t.Fatalf("expected re-executed response to miss the cache, got X-Cache=%q", got) + } + + if grp.counter.Load() != 3 { + t.Fatalf("expected counter=3 after size-based eviction, got %d", grp.counter.Load()) + } +} diff --git a/docs/history.md index 823e360..f2a6f81 100644 --- a/docs/history.md +++ b/docs/history.md @@ -169,11 +169,12 @@ At the end of Phase 3, the module has 176 tests. ## Known Limitations -### 1. Cache has no size limit +### 1. Cache remains in-memory -`WithCache(ttl)` stores all successful GET responses in memory with no maximum entry count or -total size bound. For a server receiving requests to many distinct URLs, the cache will grow -without bound. A LRU eviction policy or a configurable maximum is the natural next step. +`WithCache(ttl, maxEntries, maxBytes)` can now bound the cache by entry count and approximate +payload size, but it still stores responses in memory. Workloads with very large cached bodies, +or long-lived processes that cache many distinct URLs, will still consume RAM, so a disk-backed +cache would be the next step if that becomes a concern. ### 2. 
SDK codegen requires an external binary diff --git a/options.go b/options.go index f5e38fd..698bf91 100644 --- a/options.go +++ b/options.go @@ -350,23 +350,32 @@ func timeoutResponse(c *gin.Context) { // with an X-Cache: HIT header on subsequent requests. Non-GET methods // and error responses pass through uncached. // -// An optional maxEntries limit enables LRU eviction when the cache reaches -// capacity. A value <= 0 keeps the cache unbounded for backward compatibility. -// A non-positive TTL disables the middleware entirely. +// Optional integer limits enable LRU eviction: +//   - maxEntries limits the number of cached responses +//   - maxBytes limits the approximate total cached payload size; a single +//     response larger than maxBytes is never cached +// +// Pass a non-positive value to either limit to leave that dimension +// unbounded for backward compatibility. A non-positive TTL disables the +// middleware entirely. // // Example: // -//	engine, _ := api.New(api.WithCache(5*time.Minute, 100)) +//	engine, _ := api.New(api.WithCache(5*time.Minute, 100, 10<<20)) func WithCache(ttl time.Duration, maxEntries ...int) Option { return func(e *Engine) { if ttl <= 0 { return } - limit := 0 + entryLimit := 0 + byteLimit := 0 if len(maxEntries) > 0 { - limit = maxEntries[0] + entryLimit = maxEntries[0] } - store := newCacheStore(limit) + if len(maxEntries) > 1 { + byteLimit = maxEntries[1] + } + store := newCacheStore(entryLimit, byteLimit) e.middlewares = append(e.middlewares, cacheMiddleware(store, ttl)) } }