diff --git a/cache.go b/cache.go index d032346..9830aa6 100644 --- a/cache.go +++ b/cache.go @@ -4,6 +4,7 @@ package api import ( "bytes" + "container/list" "maps" "net/http" "sync" @@ -22,29 +23,44 @@ type cacheEntry struct { // cacheStore is a simple thread-safe in-memory cache keyed by request URL. type cacheStore struct { - mu sync.RWMutex - entries map[string]*cacheEntry + mu sync.RWMutex + entries map[string]*cacheEntry + order *list.List + index map[string]*list.Element + maxEntries int } // newCacheStore creates an empty cache store. -func newCacheStore() *cacheStore { +func newCacheStore(maxEntries int) *cacheStore { return &cacheStore{ - entries: make(map[string]*cacheEntry), + entries: make(map[string]*cacheEntry), + order: list.New(), + index: make(map[string]*list.Element), + maxEntries: maxEntries, } } // get retrieves a non-expired entry for the given key. // Returns nil if the key is missing or expired. func (s *cacheStore) get(key string) *cacheEntry { - s.mu.RLock() + s.mu.Lock() entry, ok := s.entries[key] - s.mu.RUnlock() + if ok { + if elem, exists := s.index[key]; exists { + s.order.MoveToFront(elem) + } + } + s.mu.Unlock() if !ok { return nil } if time.Now().After(entry.expires) { s.mu.Lock() + if elem, exists := s.index[key]; exists { + s.order.Remove(elem) + delete(s.index, key) + } delete(s.entries, key) s.mu.Unlock() return nil @@ -55,7 +71,26 @@ func (s *cacheStore) get(key string) *cacheEntry { // set stores a cache entry with the given TTL. func (s *cacheStore) set(key string, entry *cacheEntry) { s.mu.Lock() + if elem, ok := s.index[key]; ok { + s.order.MoveToFront(elem) + s.entries[key] = entry + s.mu.Unlock() + return + } + + if s.maxEntries > 0 && len(s.entries) >= s.maxEntries { + back := s.order.Back() + if back != nil { + oldKey := back.Value.(string) + delete(s.entries, oldKey) + delete(s.index, oldKey) + s.order.Remove(back) + } + } + s.entries[key] = entry + elem := s.order.PushFront(key) + s.index[key] = elem s.mu.Unlock() } diff --git a/cache_test.go b/cache_test.go index 58820c3..3e48349 100644 --- a/cache_test.go +++ b/cache_test.go @@ -250,3 +250,37 @@ func TestWithCache_Good_ExpiredCacheMisses(t *testing.T) { t.Fatalf("expected counter=2, got %d", grp.counter.Load()) } } + +func TestWithCache_Good_EvictsWhenCapacityReached(t *testing.T) { + gin.SetMode(gin.TestMode) + grp := &cacheCounterGroup{} + e, _ := api.New(api.WithCache(5*time.Second, 1)) + e.Register(grp) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + h.ServeHTTP(w1, req1) + if !strings.Contains(w1.Body.String(), "call-1") { + t.Fatalf("expected first response to contain %q, got %q", "call-1", w1.Body.String()) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/cache/other", nil) + h.ServeHTTP(w2, req2) + if !strings.Contains(w2.Body.String(), "other-2") { + t.Fatalf("expected second response to contain %q, got %q", "other-2", w2.Body.String()) + } + + w3 := httptest.NewRecorder() + req3, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + h.ServeHTTP(w3, req3) + if !strings.Contains(w3.Body.String(), "call-3") { + t.Fatalf("expected evicted response to contain %q, got %q", "call-3", w3.Body.String()) + } + + if grp.counter.Load() != 3 { + t.Fatalf("expected counter=3 after eviction, got %d", grp.counter.Load()) + } +} diff --git a/options.go b/options.go index 14c6e00..cf92163 100644 --- a/options.go +++ b/options.go @@ -232,9 +232,16 @@ func timeoutResponse(c *gin.Context) { // Successful (2xx) GET responses are cached for the given TTL and served // with an X-Cache: HIT header on subsequent requests. Non-GET methods // and error responses pass through uncached. -func WithCache(ttl time.Duration) Option { +// +// An optional maxEntries limit enables LRU eviction when the cache reaches +// capacity. A value <= 0 keeps the cache unbounded for backward compatibility. +func WithCache(ttl time.Duration, maxEntries ...int) Option { return func(e *Engine) { - store := newCacheStore() + limit := 0 + if len(maxEntries) > 0 { + limit = maxEntries[0] + } + store := newCacheStore(limit) e.middlewares = append(e.middlewares, cacheMiddleware(store, ttl)) } }