diff --git a/go.mod b/go.mod index 08b6d20..7b01a70 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/gin-contrib/cors v1.7.6 github.com/gin-contrib/secure v1.1.2 github.com/gin-contrib/slog v1.2.0 + github.com/gin-contrib/timeout v1.1.0 github.com/gin-gonic/gin v1.11.0 github.com/gorilla/websocket v1.5.3 github.com/swaggo/files v1.0.1 diff --git a/go.sum b/go.sum index d3993c3..1ce9a76 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/gin-contrib/slog v1.2.0 h1:vAxZfr7knD1ZYK5+pMJLP52sZXIkJXkcRPa/0dx9hS github.com/gin-contrib/slog v1.2.0/go.mod h1:vYK6YltmpsEFkO0zfRMLTKHrWS3DwUSn0TMpT+kMagI= github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= +github.com/gin-contrib/timeout v1.1.0 h1:WAmWseo5gfBUbMrMJu5hJxDclehfSJUmK2wGwCC/EFw= +github.com/gin-contrib/timeout v1.1.0/go.mod h1:NpRo4gd1Ad8ZQ4T6bQLVFDqiplCmPRs2nvfckxS2Fw4= github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk= github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls= github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= diff --git a/options.go b/options.go index 281aec0..e8e18d3 100644 --- a/options.go +++ b/options.go @@ -10,6 +10,7 @@ import ( "github.com/gin-contrib/cors" "github.com/gin-contrib/secure" ginslog "github.com/gin-contrib/slog" + "github.com/gin-contrib/timeout" "github.com/gin-gonic/gin" ) @@ -134,3 +135,25 @@ func WithSlog(logger *slog.Logger) Option { )) } } + +// WithTimeout adds per-request timeout middleware via gin-contrib/timeout. +// If a handler exceeds the given duration, the request is aborted with a +// 504 Gateway Timeout carrying the standard error envelope: +// +// {"success":false,"error":{"code":"timeout","message":"Request timed out"}} +// +// A zero or negative duration effectively disables the timeout (the handler +// runs without a deadline) — this is safe and will not panic. +func WithTimeout(d time.Duration) Option { + return func(e *Engine) { + e.middlewares = append(e.middlewares, timeout.New( + timeout.WithTimeout(d), + timeout.WithResponse(timeoutResponse), + )) + } +} + +// timeoutResponse writes a 504 Gateway Timeout with the standard error envelope. +func timeoutResponse(c *gin.Context) { + c.JSON(http.StatusGatewayTimeout, Fail("timeout", "Request timed out")) +} diff --git a/timeout_test.go b/timeout_test.go new file mode 100644 index 0000000..d761602 --- /dev/null +++ b/timeout_test.go @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + + api "forge.lthn.ai/core/go-api" +) + +// ── Helpers ───────────────────────────────────────────────────────────── + +// slowGroup provides a route that sleeps longer than the test timeout. +type slowGroup struct{} + +func (s *slowGroup) Name() string { return "slow" } +func (s *slowGroup) BasePath() string { return "/v1/slow" } +func (s *slowGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/wait", func(c *gin.Context) { + time.Sleep(200 * time.Millisecond) + c.JSON(http.StatusOK, api.OK("done")) + }) +} + +// ── WithTimeout ───────────────────────────────────────────────────────── + +func TestWithTimeout_Good_FastRequestSucceeds(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithTimeout(500 * time.Millisecond)) + e.Register(&stubGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/stub/ping", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp api.Response[string] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if !resp.Success { + t.Fatal("expected Success=true") + } + if resp.Data != "pong" { + t.Fatalf("expected Data=%q, got %q", "pong", resp.Data) + } +} + +func TestWithTimeout_Good_SlowRequestTimesOut(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithTimeout(50 * time.Millisecond)) + e.Register(&slowGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/slow/wait", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusGatewayTimeout { + t.Fatalf("expected 504, got %d", w.Code) + } +} + +func TestWithTimeout_Good_TimeoutResponseEnvelope(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithTimeout(50 * time.Millisecond)) + e.Register(&slowGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/slow/wait", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusGatewayTimeout { + t.Fatalf("expected 504, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false") + } + if resp.Error == nil { + t.Fatal("expected Error to be non-nil") + } + if resp.Error.Code != "timeout" { + t.Fatalf("expected error code=%q, got %q", "timeout", resp.Error.Code) + } + if resp.Error.Message != "Request timed out" { + t.Fatalf("expected error message=%q, got %q", "Request timed out", resp.Error.Message) + } +} + +func TestWithTimeout_Good_CombinesWithOtherMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New( + api.WithRequestID(), + api.WithTimeout(500*time.Millisecond), + ) + e.Register(&stubGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/stub/ping", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + // WithRequestID should still set the header. + id := w.Header().Get("X-Request-ID") + if id == "" { + t.Fatal("expected X-Request-ID header to be set") + } + + var resp api.Response[string] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Data != "pong" { + t.Fatalf("expected Data=%q, got %q", "pong", resp.Data) + } +} + +func TestWithTimeout_Ugly_ZeroDurationDoesNotPanic(t *testing.T) { + gin.SetMode(gin.TestMode) + + defer func() { + if r := recover(); r != nil { + t.Fatalf("WithTimeout(0) panicked: %v", r) + } + }() + + e, err := api.New(api.WithTimeout(0)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + e.Register(&stubGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/stub/ping", nil) + h.ServeHTTP(w, req) + + // We only care that it did not panic. Status may vary with zero timeout. +}