diff --git a/docs/architecture.md b/docs/architecture.md index feaf595..65dde23 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -153,7 +153,7 @@ They execute after `gin.Recovery()` but before any route handler. The `Option` t | `WithRequestID()` | `X-Request-ID` propagation | Preserves client-supplied IDs; generates 16-byte hex otherwise | | `WithResponseMeta()` | Request metadata in JSON envelopes | Merges `request_id` and `duration` into standard responses | | `WithCORS(origins...)` | CORS policy | `"*"` enables `AllowAllOrigins`; 12-hour `MaxAge` | -| `WithRateLimit(limit)` | Per-IP token-bucket rate limiting | `429 Too Many Requests`; `Retry-After` on rejection; zero or negative disables | +| `WithRateLimit(limit)` | Per-IP token-bucket rate limiting | `429 Too Many Requests`; `X-RateLimit-*` on success; `Retry-After` on rejection; zero or negative disables | | `WithMiddleware(mw...)` | Arbitrary Gin middleware | Escape hatch for custom middleware | | `WithStatic(prefix, root)` | Static file serving | Directory listing disabled | | `WithWSHandler(h)` | WebSocket at `/ws` | Wraps any `http.Handler` | diff --git a/openapi.go b/openapi.go index e4cdc01..9d2ecfd 100644 --- a/openapi.go +++ b/openapi.go @@ -170,7 +170,7 @@ func operationResponses(dataSchema map[string]any) map[string]any { "schema": envelopeSchema(dataSchema), }, }, - "headers": standardResponseHeaders(), + "headers": mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders()), }, "400": map[string]any{ "description": "Bad request", @@ -231,7 +231,7 @@ func healthResponses() map[string]any { "schema": envelopeSchema(map[string]any{"type": "string"}), }, }, - "headers": standardResponseHeaders(), + "headers": mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders()), }, "429": map[string]any{ "description": "Too many requests", @@ -316,6 +316,27 @@ func envelopeSchema(dataSchema map[string]any) map[string]any { // rejects a request. func rateLimitHeaders() map[string]any { return map[string]any{ + "X-RateLimit-Limit": map[string]any{ + "description": "Maximum number of requests allowed in the current window", + "schema": map[string]any{ + "type": "integer", + "minimum": 1, + }, + }, + "X-RateLimit-Remaining": map[string]any{ + "description": "Number of requests remaining in the current window", + "schema": map[string]any{ + "type": "integer", + "minimum": 0, + }, + }, + "X-RateLimit-Reset": map[string]any{ + "description": "Unix timestamp when the rate limit window resets", + "schema": map[string]any{ + "type": "integer", + "minimum": 1, + }, + }, "Retry-After": map[string]any{ "description": "Seconds until the rate limit resets", "schema": map[string]any{ @@ -326,6 +347,34 @@ func rateLimitHeaders() map[string]any { } } +// rateLimitSuccessHeaders documents the response headers emitted on +// successful requests when rate limiting is enabled. +func rateLimitSuccessHeaders() map[string]any { + return map[string]any{ + "X-RateLimit-Limit": map[string]any{ + "description": "Maximum number of requests allowed in the current window", + "schema": map[string]any{ + "type": "integer", + "minimum": 1, + }, + }, + "X-RateLimit-Remaining": map[string]any{ + "description": "Number of requests remaining in the current window", + "schema": map[string]any{ + "type": "integer", + "minimum": 0, + }, + }, + "X-RateLimit-Reset": map[string]any{ + "description": "Unix timestamp when the rate limit window resets", + "schema": map[string]any{ + "type": "integer", + "minimum": 1, + }, + }, + } +} + // standardResponseHeaders documents headers emitted by the response envelope // middleware on all responses when request IDs are enabled. func standardResponseHeaders() map[string]any { diff --git a/openapi_test.go b/openapi_test.go index 6c1c11a..3c31701 100644 --- a/openapi_test.go +++ b/openapi_test.go @@ -76,6 +76,15 @@ func TestSpecBuilder_Good_EmptyGroups(t *testing.T) { if _, ok := headers["X-Request-ID"]; !ok { t.Fatal("expected X-Request-ID header on /health 429 response") } + if _, ok := headers["X-RateLimit-Limit"]; !ok { + t.Fatal("expected X-RateLimit-Limit header on /health 429 response") + } + if _, ok := headers["X-RateLimit-Remaining"]; !ok { + t.Fatal("expected X-RateLimit-Remaining header on /health 429 response") + } + if _, ok := headers["X-RateLimit-Reset"]; !ok { + t.Fatal("expected X-RateLimit-Reset header on /health 429 response") + } // Verify system tag exists. tags := spec["tags"].([]any) @@ -256,6 +265,15 @@ func TestSpecBuilder_Good_SecuredResponses(t *testing.T) { if _, ok := headers["X-Request-ID"]; !ok { t.Fatal("expected X-Request-ID header in secured operation 429 response") } + if _, ok := headers["X-RateLimit-Limit"]; !ok { + t.Fatal("expected X-RateLimit-Limit header in secured operation 429 response") + } + if _, ok := headers["X-RateLimit-Remaining"]; !ok { + t.Fatal("expected X-RateLimit-Remaining header in secured operation 429 response") + } + if _, ok := headers["X-RateLimit-Reset"]; !ok { + t.Fatal("expected X-RateLimit-Reset header in secured operation 429 response") + } } func TestSpecBuilder_Good_EnvelopeWrapping(t *testing.T) { @@ -302,6 +320,15 @@ func TestSpecBuilder_Good_EnvelopeWrapping(t *testing.T) { if _, ok := headers["X-Request-ID"]; !ok { t.Fatal("expected X-Request-ID header on 200 response") } + if _, ok := headers["X-RateLimit-Limit"]; !ok { + t.Fatal("expected X-RateLimit-Limit header on 200 response") + } + if _, ok := headers["X-RateLimit-Remaining"]; !ok { + t.Fatal("expected X-RateLimit-Remaining header on 200 response") + } + if _, ok := headers["X-RateLimit-Reset"]; !ok { + t.Fatal("expected X-RateLimit-Reset header on 200 response") + } content := resp200["content"].(map[string]any) appJSON := content["application/json"].(map[string]any) schema := appJSON["schema"].(map[string]any) diff --git a/options.go b/options.go index dc0f4d4..36819b6 100644 --- a/options.go +++ b/options.go @@ -260,8 +260,10 @@ func WithCache(ttl time.Duration, maxEntries ...int) Option { } // WithRateLimit adds per-IP token-bucket rate limiting middleware. -// Requests exceeding the configured limit per second are rejected with -// 429 Too Many Requests and the standard Fail() error envelope. +// Requests that pass are annotated with X-RateLimit-Limit, +// X-RateLimit-Remaining, and X-RateLimit-Reset headers. Requests +// exceeding the configured limit are rejected with 429 Too Many +// Requests, Retry-After, and the standard Fail() error envelope. // A zero or negative limit disables rate limiting. func WithRateLimit(limit int) Option { return func(e *Engine) { diff --git a/ratelimit.go b/ratelimit.go index 8e78999..37aad0b 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -3,6 +3,7 @@ package api import ( + "math" "net/http" "strconv" "sync" @@ -30,6 +31,14 @@ type rateLimitBucket struct { lastSeen time.Time } +type rateLimitDecision struct { + allowed bool + retryAfter time.Duration + limit int + remaining int + resetAt time.Time +} + func newRateLimitStore(limit int) *rateLimitStore { now := time.Now() return &rateLimitStore{ @@ -39,7 +48,7 @@ func newRateLimitStore(limit int) *rateLimitStore { } } -func (s *rateLimitStore) allow(key string) (bool, time.Duration) { +func (s *rateLimitStore) allow(key string) rateLimitDecision { now := time.Now() s.mu.Lock() @@ -81,7 +90,12 @@ func (s *rateLimitStore) allow(key string) (bool, time.Duration) { if bucket.tokens >= 1 { bucket.tokens-- - return true, 0 + return rateLimitDecision{ + allowed: true, + limit: s.limit, + remaining: int(math.Floor(bucket.tokens)), + resetAt: now.Add(timeUntilFull(bucket.tokens, s.limit)), + } } deficit := 1 - bucket.tokens @@ -93,7 +107,13 @@ func (s *rateLimitStore) allow(key string) (bool, time.Duration) { } } - return false, wait + return rateLimitDecision{ + allowed: false, + retryAfter: wait, + limit: s.limit, + remaining: 0, + resetAt: now.Add(wait), + } } func rateLimitMiddleware(limit int) gin.HandlerFunc { @@ -107,15 +127,16 @@ func rateLimitMiddleware(limit int) gin.HandlerFunc { return func(c *gin.Context) { key := clientRateLimitKey(c) - allowed, retryAfter := store.allow(key) - if !allowed { - secs := int(retryAfter / time.Second) - if retryAfter%time.Second != 0 { + decision := store.allow(key) + if !decision.allowed { + secs := int(decision.retryAfter / time.Second) + if decision.retryAfter%time.Second != 0 { secs++ } if secs < 1 { secs = 1 } + setRateLimitHeaders(c, decision.limit, decision.remaining, decision.resetAt) c.Header("Retry-After", strconv.Itoa(secs)) c.AbortWithStatusJSON(http.StatusTooManyRequests, Fail( "rate_limit_exceeded", @@ -125,9 +146,42 @@ func rateLimitMiddleware(limit int) gin.HandlerFunc { } c.Next() + setRateLimitHeaders(c, decision.limit, decision.remaining, decision.resetAt) } } +func setRateLimitHeaders(c *gin.Context, limit, remaining int, resetAt time.Time) { + if limit > 0 { + c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) + } + if remaining < 0 { + remaining = 0 + } + c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) + if !resetAt.IsZero() { + reset := resetAt.Unix() + if reset <= time.Now().Unix() { + reset = time.Now().Add(time.Second).Unix() + } + c.Header("X-RateLimit-Reset", strconv.FormatInt(reset, 10)) + } +} + +func timeUntilFull(tokens float64, limit int) time.Duration { + if limit <= 0 { + return 0 + } + missing := float64(limit) - tokens + if missing <= 0 { + return 0 + } + seconds := missing / float64(limit) + if seconds <= 0 { + return 0 + } + return time.Duration(math.Ceil(seconds * float64(time.Second))) +} + func clientRateLimitKey(c *gin.Context) string { if ip := c.ClientIP(); ip != "" { return ip diff --git a/ratelimit_test.go b/ratelimit_test.go index b9c72ab..ef19482 100644 --- a/ratelimit_test.go +++ b/ratelimit_test.go @@ -38,6 +38,15 @@ func TestWithRateLimit_Good_AllowsBurstThenRejects(t *testing.T) { if w1.Code != http.StatusOK { t.Fatalf("expected first request to succeed, got %d", w1.Code) } + if got := w1.Header().Get("X-RateLimit-Limit"); got != "2" { + t.Fatalf("expected X-RateLimit-Limit=2, got %q", got) + } + if got := w1.Header().Get("X-RateLimit-Remaining"); got != "1" { + t.Fatalf("expected X-RateLimit-Remaining=1, got %q", got) + } + if got := w1.Header().Get("X-RateLimit-Reset"); got == "" { + t.Fatal("expected X-RateLimit-Reset on successful response") + } w2 := httptest.NewRecorder() req2, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) @@ -58,6 +67,15 @@ func TestWithRateLimit_Good_AllowsBurstThenRejects(t *testing.T) { if got := w3.Header().Get("Retry-After"); got == "" { t.Fatal("expected Retry-After header on 429 response") } + if got := w3.Header().Get("X-RateLimit-Limit"); got != "2" { + t.Fatalf("expected X-RateLimit-Limit=2 on 429, got %q", got) + } + if got := w3.Header().Get("X-RateLimit-Remaining"); got != "0" { + t.Fatalf("expected X-RateLimit-Remaining=0 on 429, got %q", got) + } + if got := w3.Header().Get("X-RateLimit-Reset"); got == "" { + t.Fatal("expected X-RateLimit-Reset on 429 response") + } var resp api.Response[any] if err := json.Unmarshal(w3.Body.Bytes(), &resp); err != nil { @@ -85,6 +103,9 @@ func TestWithRateLimit_Good_IsolatesPerIP(t *testing.T) { if w1.Code != http.StatusOK { t.Fatalf("expected first IP to succeed, got %d", w1.Code) } + if got := w1.Header().Get("X-RateLimit-Limit"); got != "1" { + t.Fatalf("expected X-RateLimit-Limit=1, got %q", got) + } w2 := httptest.NewRecorder() req2, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil)