From 2f8f8f805e2d9a2ab4b8f907a94c145d0f7de847 Mon Sep 17 00:00:00 2001 From: Virgil Date: Wed, 1 Apr 2026 18:22:17 +0000 Subject: [PATCH] fix(api): scope rate limiting by key Co-Authored-By: Virgil --- options.go | 9 +++--- ratelimit.go | 29 ++++++++++++++++++-- ratelimit_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 7 deletions(-) diff --git a/options.go b/options.go index c0d986c..bb973a3 100644 --- a/options.go +++ b/options.go @@ -268,10 +268,11 @@ func WithCache(ttl time.Duration, maxEntries ...int) Option { } } -// WithRateLimit adds per-IP token-bucket rate limiting middleware. -// Requests that pass are annotated with X-RateLimit-Limit, -// X-RateLimit-Remaining, and X-RateLimit-Reset headers. Requests -// exceeding the configured limit are rejected with 429 Too Many +// WithRateLimit adds token-bucket rate limiting middleware. +// Requests are bucketed by API key or bearer token when present, and +// otherwise by client IP. Passing requests are annotated with +// X-RateLimit-Limit, X-RateLimit-Remaining, and X-RateLimit-Reset headers. +// Requests exceeding the configured limit are rejected with 429 Too Many // Requests, Retry-After, and the standard Fail() error envelope. // A zero or negative limit disables rate limiting. func WithRateLimit(limit int) Option { diff --git a/ratelimit.go b/ratelimit.go index 37aad0b..20ebd77 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -6,6 +6,7 @@ import ( "math" "net/http" "strconv" + "strings" "sync" "time" @@ -182,12 +183,34 @@ func timeUntilFull(tokens float64, limit int) time.Duration { return time.Duration(math.Ceil(seconds * float64(time.Second))) } +// clientRateLimitKey prefers caller-provided credentials for bucket +// isolation, then falls back to the network address. func clientRateLimitKey(c *gin.Context) string { + if apiKey := strings.TrimSpace(c.GetHeader("X-API-Key")); apiKey != "" { + return "api_key:" + apiKey + } + if bearer := bearerTokenFromHeader(c.GetHeader("Authorization")); bearer != "" { + return "bearer:" + bearer + } if ip := c.ClientIP(); ip != "" { - return ip + return "ip:" + ip } if c.Request != nil && c.Request.RemoteAddr != "" { - return c.Request.RemoteAddr + return "ip:" + c.Request.RemoteAddr } - return "unknown" + return "ip:unknown" +} + +func bearerTokenFromHeader(header string) string { + header = strings.TrimSpace(header) + if header == "" { + return "" + } + + parts := strings.SplitN(header, " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { + return "" + } + + return strings.TrimSpace(parts[1]) } diff --git a/ratelimit_test.go b/ratelimit_test.go index ef19482..24d75ed 100644 --- a/ratelimit_test.go +++ b/ratelimit_test.go @@ -116,6 +116,76 @@ func TestWithRateLimit_Good_IsolatesPerIP(t *testing.T) { } } +func TestWithRateLimit_Good_IsolatesPerAPIKey(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRateLimit(1)) + e.Register(&rateLimitTestGroup{}) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req1.RemoteAddr = "203.0.113.20:1234" + req1.Header.Set("X-API-Key", "key-a") + h.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Fatalf("expected first API key request to succeed, got %d", w1.Code) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req2.RemoteAddr = "203.0.113.20:1234" + req2.Header.Set("X-API-Key", "key-b") + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected second API key to have its own bucket, got %d", w2.Code) + } + + w3 := httptest.NewRecorder() + req3, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req3.RemoteAddr = "203.0.113.20:1234" + req3.Header.Set("X-API-Key", "key-a") + h.ServeHTTP(w3, req3) + if w3.Code != http.StatusTooManyRequests { + t.Fatalf("expected repeated API key to be rate limited, got %d", w3.Code) + } +} + +func TestWithRateLimit_Good_UsesBearerTokenWhenPresent(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRateLimit(1)) + e.Register(&rateLimitTestGroup{}) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req1.RemoteAddr = "203.0.113.30:1234" + req1.Header.Set("Authorization", "Bearer token-a") + h.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Fatalf("expected first bearer token request to succeed, got %d", w1.Code) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req2.RemoteAddr = "203.0.113.30:1234" + req2.Header.Set("Authorization", "Bearer token-b") + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected second bearer token to have its own bucket, got %d", w2.Code) + } + + w3 := httptest.NewRecorder() + req3, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req3.RemoteAddr = "203.0.113.30:1234" + req3.Header.Set("Authorization", "Bearer token-a") + h.ServeHTTP(w3, req3) + if w3.Code != http.StatusTooManyRequests { + t.Fatalf("expected repeated bearer token to be rate limited, got %d", w3.Code) + } +} + func TestWithRateLimit_Good_RefillsOverTime(t *testing.T) { gin.SetMode(gin.TestMode) e, _ := api.New(api.WithRateLimit(1))