diff --git a/cache.go b/cache.go index dbd2f93..19d6bea 100644 --- a/cache.go +++ b/cache.go @@ -50,27 +50,15 @@ func newCacheStore(maxEntries, maxBytes int) *cacheStore { func (s *cacheStore) get(key string) *cacheEntry { s.mu.Lock() entry, ok := s.entries[key] - if ok { - if elem, exists := s.index[key]; exists { - s.order.MoveToFront(elem) - } - } - s.mu.Unlock() - if !ok { + s.mu.Unlock() return nil } + + // Check expiry before promoting in the LRU order so we never move a stale + // entry to the front. All expiry checking and eviction happen inside the + // same critical section to avoid a TOCTOU race. if time.Now().After(entry.expires) { - s.mu.Lock() - // Re-verify the entry pointer is unchanged before evicting. Another - // goroutine may have called set() between us releasing and re-acquiring - // the lock, replacing the entry with a fresh one. Evicting the new - // entry would corrupt s.currentBytes and lose valid cached data. - currentEntry, stillExists := s.entries[key] - if !stillExists || currentEntry != entry { - s.mu.Unlock() - return nil - } if elem, exists := s.index[key]; exists { s.order.Remove(elem) delete(s.index, key) @@ -83,6 +71,12 @@ func (s *cacheStore) get(key string) *cacheEntry { s.mu.Unlock() return nil } + + // Only promote to LRU front after confirming the entry is still valid. + if elem, exists := s.index[key]; exists { + s.order.MoveToFront(elem) + } + s.mu.Unlock() return entry } @@ -201,15 +195,35 @@ func cacheMiddleware(store *cacheStore, ttl time.Duration) gin.HandlerFunc { // Serve from cache if a valid entry exists. if entry := store.get(key); entry != nil { body := entry.body + metaRewritten := false if meta := GetRequestMeta(c); meta != nil { body = refreshCachedResponseMeta(entry.body, meta) + metaRewritten = true + } + + // staleValidatorHeader returns true for headers that describe the + // exact bytes of the cached body and must be dropped when the body + // has been rewritten by refreshCachedResponseMeta. + staleValidatorHeader := func(canonical string) bool { + if !metaRewritten { + return false + } + switch canonical { + case "Etag", "Content-Md5", "Digest": + return true + } + return false } for k, vals := range entry.headers { - if http.CanonicalHeaderKey(k) == "X-Request-ID" { + canonical := http.CanonicalHeaderKey(k) + if canonical == "X-Request-Id" { continue } - if http.CanonicalHeaderKey(k) == "Content-Length" { + if canonical == "Content-Length" { + continue + } + if staleValidatorHeader(canonical) { continue } for _, v := range vals { diff --git a/ratelimit.go b/ratelimit.go index b7fb4fc..308ce7e 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -18,6 +18,14 @@ import ( const ( rateLimitCleanupInterval = time.Minute rateLimitStaleAfter = 10 * time.Minute + + // rateLimitMaxBuckets caps the total number of tracked keys to prevent + // unbounded memory growth under high-cardinality traffic (e.g. scanning + // bots cycling random IPs). When the cap is reached, new keys that cannot + // evict a stale bucket are routed to a shared overflow bucket so requests + // are still rate-limited rather than bypassing the limiter entirely. + rateLimitMaxBuckets = 100_000 + rateLimitOverflowKey = "__overflow__" ) type rateLimitStore struct { @@ -57,12 +65,39 @@ func (s *rateLimitStore) allow(key string) rateLimitDecision { s.mu.Lock() bucket, ok := s.buckets[key] if !ok || now.Sub(bucket.lastSeen) > rateLimitStaleAfter { - bucket = &rateLimitBucket{ - tokens: float64(s.limit), - last: now, - lastSeen: now, + // Enforce the bucket cap before inserting a new entry. First try to + // evict a single stale entry; if none exists and the map is full, + // route the request to the shared overflow bucket so it is still + // rate-limited rather than bypassing the limiter. + if !ok && len(s.buckets) >= rateLimitMaxBuckets { + evicted := false + for k, candidate := range s.buckets { + if now.Sub(candidate.lastSeen) > rateLimitStaleAfter { + delete(s.buckets, k) + evicted = true + break + } + } + if !evicted { + // Cap reached and no stale entry to evict: use overflow bucket. + key = rateLimitOverflowKey + if ob, exists := s.buckets[key]; exists { + bucket = ob + ok = true + } + } + } + + if !ok { + bucket = &rateLimitBucket{ + tokens: float64(s.limit), + last: now, + lastSeen: now, + } + s.buckets[key] = bucket + } else { + bucket.lastSeen = now } - s.buckets[key] = bucket } else { bucket.lastSeen = now } @@ -186,9 +221,10 @@ func timeUntilFull(tokens float64, limit int) time.Duration { } // clientRateLimitKey derives a bucket key for the request. It prefers a -// validated principal from context (set by auth middleware), then falls back -// to the client IP. Raw credential headers are hashed with SHA-256 when used -// as a last resort so that secrets are never stored in the bucket map. +// validated principal placed in context by auth middleware, then falls back to +// raw credential headers (X-API-Key or Bearer token, hashed with SHA-256 so +// secrets are never stored in the bucket map), and finally falls back to the +// client IP when no credentials are present. func clientRateLimitKey(c *gin.Context) string { // Prefer a validated principal placed in context by auth middleware. if principal, ok := c.Get("principal"); ok && principal != nil {