diff --git a/docs/architecture.md b/docs/architecture.md index db1deb4..d7cd2ad 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -152,6 +152,7 @@ They execute after `gin.Recovery()` but before any route handler. The `Option` t | `WithBearerAuth(token)` | Static bearer token authentication | Skips `/health` and `/swagger` | | `WithRequestID()` | `X-Request-ID` propagation | Preserves client-supplied IDs; generates 16-byte hex otherwise | | `WithCORS(origins...)` | CORS policy | `"*"` enables `AllowAllOrigins`; 12-hour `MaxAge` | +| `WithRateLimit(limit)` | Per-IP token-bucket rate limiting | `429 Too Many Requests`; `Retry-After` on rejection; zero or negative disables | | `WithMiddleware(mw...)` | Arbitrary Gin middleware | Escape hatch for custom middleware | | `WithStatic(prefix, root)` | Static file serving | Directory listing disabled | | `WithWSHandler(h)` | WebSocket at `/ws` | Wraps any `http.Handler` | diff --git a/docs/index.md b/docs/index.md index 3dec037..2a27d64 100644 --- a/docs/index.md +++ b/docs/index.md @@ -94,7 +94,7 @@ engine.Register(&Routes{service: svc}) | File | Purpose | |------|---------| | `api.go` | `Engine` struct, `New()`, `build()`, `Serve()`, `Handler()`, `Channels()` | -| `options.go` | All `With*()` option functions (25 options) | +| `options.go` | All `With*()` option functions (26 options) | | `group.go` | `RouteGroup`, `StreamGroup`, `DescribableGroup` interfaces; `RouteDescription` | | `response.go` | `Response[T]`, `Error`, `Meta`, `OK()`, `Fail()`, `FailWithDetails()`, `Paginated()` | | `middleware.go` | `bearerAuthMiddleware()`, `requestIDMiddleware()` | diff --git a/options.go b/options.go index bdf3f66..14c6e00 100644 --- a/options.go +++ b/options.go @@ -239,6 +239,16 @@ func WithCache(ttl time.Duration) Option { } } +// WithRateLimit adds per-IP token-bucket rate limiting middleware. +// Requests exceeding the configured limit per second are rejected with +// 429 Too Many Requests and the standard Fail() error envelope. +// A zero or negative limit disables rate limiting. +func WithRateLimit(limit int) Option { + return func(e *Engine) { + e.middlewares = append(e.middlewares, rateLimitMiddleware(limit)) + } +} + // WithSessions adds server-side session management middleware via // gin-contrib/sessions using a cookie-based store. The name parameter // sets the session cookie name (e.g. "session") and secret is the key diff --git a/ratelimit.go b/ratelimit.go new file mode 100644 index 0000000..8e78999 --- /dev/null +++ b/ratelimit.go @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "net/http" + "strconv" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +const ( + rateLimitCleanupInterval = time.Minute + rateLimitStaleAfter = 10 * time.Minute +) + +type rateLimitStore struct { + mu sync.Mutex + buckets map[string]*rateLimitBucket + limit int + lastSweep time.Time +} + +type rateLimitBucket struct { + mu sync.Mutex + tokens float64 + last time.Time + lastSeen time.Time +} + +func newRateLimitStore(limit int) *rateLimitStore { + now := time.Now() + return &rateLimitStore{ + buckets: make(map[string]*rateLimitBucket), + limit: limit, + lastSweep: now, + } +} + +func (s *rateLimitStore) allow(key string) (bool, time.Duration) { + now := time.Now() + + s.mu.Lock() + bucket, ok := s.buckets[key] + if !ok || now.Sub(bucket.lastSeen) > rateLimitStaleAfter { + bucket = &rateLimitBucket{ + tokens: float64(s.limit), + last: now, + lastSeen: now, + } + s.buckets[key] = bucket + } else { + bucket.lastSeen = now + } + + if now.Sub(s.lastSweep) >= rateLimitCleanupInterval { + for k, candidate := range s.buckets { + if now.Sub(candidate.lastSeen) > rateLimitStaleAfter { + delete(s.buckets, k) + } + } + s.lastSweep = now + } + s.mu.Unlock() + + bucket.mu.Lock() + defer bucket.mu.Unlock() + + elapsed := now.Sub(bucket.last) + if elapsed > 0 { + refill := elapsed.Seconds() * float64(s.limit) + if bucket.tokens+refill > float64(s.limit) { + bucket.tokens = float64(s.limit) + } else { + bucket.tokens += refill + } + bucket.last = now + } + + if bucket.tokens >= 1 { + bucket.tokens-- + return true, 0 + } + + deficit := 1 - bucket.tokens + wait := time.Duration(deficit / float64(s.limit) * float64(time.Second)) + if wait <= 0 { + wait = time.Second / time.Duration(s.limit) + if wait <= 0 { + wait = time.Second + } + } + + return false, wait +} + +func rateLimitMiddleware(limit int) gin.HandlerFunc { + if limit <= 0 { + return func(c *gin.Context) { + c.Next() + } + } + + store := newRateLimitStore(limit) + + return func(c *gin.Context) { + key := clientRateLimitKey(c) + allowed, retryAfter := store.allow(key) + if !allowed { + secs := int(retryAfter / time.Second) + if retryAfter%time.Second != 0 { + secs++ + } + if secs < 1 { + secs = 1 + } + c.Header("Retry-After", strconv.Itoa(secs)) + c.AbortWithStatusJSON(http.StatusTooManyRequests, Fail( + "rate_limit_exceeded", + "Too many requests", + )) + return + } + + c.Next() + } +} + +func clientRateLimitKey(c *gin.Context) string { + if ip := c.ClientIP(); ip != "" { + return ip + } + if c.Request != nil && c.Request.RemoteAddr != "" { + return c.Request.RemoteAddr + } + return "unknown" +} diff --git a/ratelimit_test.go b/ratelimit_test.go new file mode 100644 index 0000000..b9c72ab --- /dev/null +++ b/ratelimit_test.go @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + + api "dappco.re/go/core/api" +) + +type rateLimitTestGroup struct{} + +func (r *rateLimitTestGroup) Name() string { return "rate-limit" } +func (r *rateLimitTestGroup) BasePath() string { return "/rate" } +func (r *rateLimitTestGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/ping", func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("pong")) + }) +} + +func TestWithRateLimit_Good_AllowsBurstThenRejects(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRateLimit(2)) + e.Register(&rateLimitTestGroup{}) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req1.RemoteAddr = "203.0.113.10:1234" + h.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Fatalf("expected first request to succeed, got %d", w1.Code) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req2.RemoteAddr = "203.0.113.10:1234" + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected second request to succeed, got %d", w2.Code) + } + + w3 := httptest.NewRecorder() + req3, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req3.RemoteAddr = "203.0.113.10:1234" + h.ServeHTTP(w3, req3) + if w3.Code != http.StatusTooManyRequests { + t.Fatalf("expected third request to be rate limited, got %d", w3.Code) + } + + if got := w3.Header().Get("Retry-After"); got == "" { + t.Fatal("expected Retry-After header on 429 response") + } + + var resp api.Response[any] + if err := json.Unmarshal(w3.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false for rate limited response") + } + if resp.Error == nil || resp.Error.Code != "rate_limit_exceeded" { + t.Fatalf("expected rate_limit_exceeded error, got %+v", resp.Error) + } +} + +func TestWithRateLimit_Good_IsolatesPerIP(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRateLimit(1)) + e.Register(&rateLimitTestGroup{}) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req1.RemoteAddr = "203.0.113.10:1234" + h.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Fatalf("expected first IP to succeed, got %d", w1.Code) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req2.RemoteAddr = "203.0.113.11:1234" + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected second IP to have its own bucket, got %d", w2.Code) + } +} + +func TestWithRateLimit_Good_RefillsOverTime(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRateLimit(1)) + e.Register(&rateLimitTestGroup{}) + + h := e.Handler() + + req, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req.RemoteAddr = "203.0.113.12:1234" + + w1 := httptest.NewRecorder() + h.ServeHTTP(w1, req.Clone(req.Context())) + if w1.Code != http.StatusOK { + t.Fatalf("expected first request to succeed, got %d", w1.Code) + } + + w2 := httptest.NewRecorder() + req2 := req.Clone(req.Context()) + req2.RemoteAddr = req.RemoteAddr + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusTooManyRequests { + t.Fatalf("expected second request to be rate limited, got %d", w2.Code) + } + + time.Sleep(1100 * time.Millisecond) + + w3 := httptest.NewRecorder() + req3 := req.Clone(req.Context()) + req3.RemoteAddr = req.RemoteAddr + h.ServeHTTP(w3, req3) + if w3.Code != http.StatusOK { + t.Fatalf("expected bucket to refill after waiting, got %d", w3.Code) + } +} + +func TestWithRateLimit_Ugly_NonPositiveLimitDisablesMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRateLimit(0)) + e.Register(&rateLimitTestGroup{}) + + h := e.Handler() + + for i := 0; i < 3; i++ { + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req.RemoteAddr = "203.0.113.13:1234" + h.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("expected request %d to succeed with disabled limiter, got %d", i+1, w.Code) + } + } +}