From a612d85dbabed03566e8499116fbc39725943c60 Mon Sep 17 00:00:00 2001 From: Snider Date: Sat, 21 Feb 2026 00:04:28 +0000 Subject: [PATCH] feat: add WithLocation reverse proxy header detection middleware Co-Authored-By: Virgil Co-Authored-By: Claude Opus 4.6 --- go.mod | 1 + go.sum | 2 + location_test.go | 180 +++++++++++++++++++++++++++++++++++++++++++++++ options.go | 14 ++++ 4 files changed, 197 insertions(+) create mode 100644 location_test.go diff --git a/go.mod b/go.mod index 32b3a53..1f62987 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/gin-contrib/cors v1.7.6 github.com/gin-contrib/gzip v1.2.5 github.com/gin-contrib/httpsign v1.0.3 + github.com/gin-contrib/location/v2 v2.0.0 github.com/gin-contrib/secure v1.1.2 github.com/gin-contrib/sessions v1.0.4 github.com/gin-contrib/slog v1.2.0 diff --git a/go.sum b/go.sum index 8feecb2..c2347fb 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/gin-contrib/gzip v1.2.5 h1:fIZs0S+l17pIu1P5XRJOo/YNqfIuPCrZZ3TWB7pjck github.com/gin-contrib/gzip v1.2.5/go.mod h1:aomRgR7ftdZV3uWY0gW/m8rChfxau0n8YVvwlOHONzw= github.com/gin-contrib/httpsign v1.0.3 h1:NpeDQjmUV0qFjGCm/rkXSp3HH0hU7r84q1v+VtTiI5I= github.com/gin-contrib/httpsign v1.0.3/go.mod h1:n4GC7StmHNBhIzWzuW2njKbZMeEWh4tDbmn3bD1ab+k= +github.com/gin-contrib/location/v2 v2.0.0 h1:iLx5RatHQHSxgC0tm2AG0sIuQKecI7FhREessVd6RWY= +github.com/gin-contrib/location/v2 v2.0.0/go.mod h1:276TDNr25NENBA/NQZUuEIlwxy/I5CYVFIr/d2TgOdU= github.com/gin-contrib/secure v1.1.2 h1:6G8/NCOTSywWY7TeaH/0Yfaa6bfkE5ukkqtIm7lK11U= github.com/gin-contrib/secure v1.1.2/go.mod h1:xI3jI5/BpOYMCBtjgmIVrMA3kI7y9LwCFxs+eLf5S3w= github.com/gin-contrib/sessions v1.0.4 h1:ha6CNdpYiTOK/hTp05miJLbpTSNfOnFg5Jm2kbcqy8U= diff --git a/location_test.go b/location_test.go new file mode 100644 index 0000000..8672958 --- /dev/null +++ b/location_test.go @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-contrib/location/v2" + "github.com/gin-gonic/gin" + + api "forge.lthn.ai/core/go-api" +) + +// ── Helpers ───────────────────────────────────────────────────────────── + +// locationTestGroup exposes a route that returns the detected location. +type locationTestGroup struct{} + +func (l *locationTestGroup) Name() string { return "loc" } +func (l *locationTestGroup) BasePath() string { return "/loc" } +func (l *locationTestGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/info", func(c *gin.Context) { + url := location.Get(c) + c.JSON(http.StatusOK, api.OK(map[string]string{ + "scheme": url.Scheme, + "host": url.Host, + })) + }) +} + +// locationResponse is the typed response envelope for location info tests. +type locationResponse struct { + Success bool `json:"success"` + Data map[string]string `json:"data"` +} + +// ── WithLocation ──────────────────────────────────────────────────────── + +func TestWithLocation_Good_DetectsForwardedHost(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithLocation()) + e.Register(&locationTestGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/loc/info", nil) + req.Header.Set("X-Forwarded-Host", "api.example.com") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp locationResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Data["host"] != "api.example.com" { + t.Fatalf("expected host=%q, got %q", "api.example.com", resp.Data["host"]) + } +} + +func TestWithLocation_Good_DetectsForwardedProto(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithLocation()) + e.Register(&locationTestGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/loc/info", nil) + req.Header.Set("X-Forwarded-Proto", "https") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp locationResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Data["scheme"] != "https" { + t.Fatalf("expected scheme=%q, got %q", "https", resp.Data["scheme"]) + } +} + +func TestWithLocation_Good_FallsBackToRequestHost(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithLocation()) + e.Register(&locationTestGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/loc/info", nil) + // No X-Forwarded-* headers — middleware should fall back to defaults. + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp locationResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + // Without forwarded headers the middleware falls back to its default + // scheme ("http"). The host will be either the request Host header + // value or the configured default; either way it must not be empty. + if resp.Data["scheme"] != "http" { + t.Fatalf("expected fallback scheme=%q, got %q", "http", resp.Data["scheme"]) + } + if resp.Data["host"] == "" { + t.Fatal("expected a non-empty host in fallback mode") + } +} + +func TestWithLocation_Good_CombinesWithOtherMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New( + api.WithLocation(), + api.WithRequestID(), + ) + e.Register(&locationTestGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/loc/info", nil) + req.Header.Set("X-Forwarded-Host", "proxy.example.com") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + // Location middleware should populate the detected host. + var resp locationResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Data["host"] != "proxy.example.com" { + t.Fatalf("expected host=%q, got %q", "proxy.example.com", resp.Data["host"]) + } + + // RequestID middleware should also have run. + if w.Header().Get("X-Request-ID") == "" { + t.Fatal("expected X-Request-ID header from WithRequestID") + } +} + +func TestWithLocation_Good_BothHeadersCombined(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithLocation()) + e.Register(&locationTestGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/loc/info", nil) + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("X-Forwarded-Host", "secure.example.com") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp locationResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Data["scheme"] != "https" { + t.Fatalf("expected scheme=%q, got %q", "https", resp.Data["scheme"]) + } + if resp.Data["host"] != "secure.example.com" { + t.Fatalf("expected host=%q, got %q", "secure.example.com", resp.Data["host"]) + } +} diff --git a/options.go b/options.go index 1a8d922..282b501 100644 --- a/options.go +++ b/options.go @@ -13,6 +13,7 @@ import ( "github.com/gin-contrib/cors" gingzip "github.com/gin-contrib/gzip" "github.com/gin-contrib/httpsign" + "github.com/gin-contrib/location/v2" "github.com/gin-contrib/secure" "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" @@ -263,3 +264,16 @@ func WithSSE(broker *SSEBroker) Option { e.sseBroker = broker } } + +// WithLocation adds reverse proxy header detection middleware via +// gin-contrib/location. It inspects X-Forwarded-Proto and X-Forwarded-Host +// headers to determine the original scheme and host when the server runs +// behind a TLS-terminating reverse proxy such as Traefik. +// +// After this middleware runs, handlers can call location.Get(c) to retrieve +// a *url.URL with the detected scheme, host, and base path. +func WithLocation() Option { + return func(e *Engine) { + e.middlewares = append(e.middlewares, location.Default()) + } +}