feat: add WithLocation reverse proxy header detection middleware

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Snider 2026-02-21 00:04:28 +00:00
parent 8fc307b454
commit a612d85dba
4 changed files with 197 additions and 0 deletions

1
go.mod
View file

@ -10,6 +10,7 @@ require (
github.com/gin-contrib/cors v1.7.6
github.com/gin-contrib/gzip v1.2.5
github.com/gin-contrib/httpsign v1.0.3
github.com/gin-contrib/location/v2 v2.0.0
github.com/gin-contrib/secure v1.1.2
github.com/gin-contrib/sessions v1.0.4
github.com/gin-contrib/slog v1.2.0

2
go.sum
View file

@ -38,6 +38,8 @@ github.com/gin-contrib/gzip v1.2.5 h1:fIZs0S+l17pIu1P5XRJOo/YNqfIuPCrZZ3TWB7pjck
github.com/gin-contrib/gzip v1.2.5/go.mod h1:aomRgR7ftdZV3uWY0gW/m8rChfxau0n8YVvwlOHONzw=
github.com/gin-contrib/httpsign v1.0.3 h1:NpeDQjmUV0qFjGCm/rkXSp3HH0hU7r84q1v+VtTiI5I=
github.com/gin-contrib/httpsign v1.0.3/go.mod h1:n4GC7StmHNBhIzWzuW2njKbZMeEWh4tDbmn3bD1ab+k=
github.com/gin-contrib/location/v2 v2.0.0 h1:iLx5RatHQHSxgC0tm2AG0sIuQKecI7FhREessVd6RWY=
github.com/gin-contrib/location/v2 v2.0.0/go.mod h1:276TDNr25NENBA/NQZUuEIlwxy/I5CYVFIr/d2TgOdU=
github.com/gin-contrib/secure v1.1.2 h1:6G8/NCOTSywWY7TeaH/0Yfaa6bfkE5ukkqtIm7lK11U=
github.com/gin-contrib/secure v1.1.2/go.mod h1:xI3jI5/BpOYMCBtjgmIVrMA3kI7y9LwCFxs+eLf5S3w=
github.com/gin-contrib/sessions v1.0.4 h1:ha6CNdpYiTOK/hTp05miJLbpTSNfOnFg5Jm2kbcqy8U=

180
location_test.go Normal file
View file

@ -0,0 +1,180 @@
// SPDX-License-Identifier: EUPL-1.2
package api_test
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-contrib/location/v2"
"github.com/gin-gonic/gin"
api "forge.lthn.ai/core/go-api"
)
// ── Helpers ─────────────────────────────────────────────────────────────
// locationTestGroup exposes a route that returns the detected location.
type locationTestGroup struct{}
func (l *locationTestGroup) Name() string { return "loc" }
func (l *locationTestGroup) BasePath() string { return "/loc" }
func (l *locationTestGroup) RegisterRoutes(rg *gin.RouterGroup) {
rg.GET("/info", func(c *gin.Context) {
url := location.Get(c)
c.JSON(http.StatusOK, api.OK(map[string]string{
"scheme": url.Scheme,
"host": url.Host,
}))
})
}
// locationResponse is the typed response envelope for location info tests.
type locationResponse struct {
Success bool `json:"success"`
Data map[string]string `json:"data"`
}
// ── WithLocation ────────────────────────────────────────────────────────
func TestWithLocation_Good_DetectsForwardedHost(t *testing.T) {
gin.SetMode(gin.TestMode)
e, _ := api.New(api.WithLocation())
e.Register(&locationTestGroup{})
h := e.Handler()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/loc/info", nil)
req.Header.Set("X-Forwarded-Host", "api.example.com")
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp locationResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if resp.Data["host"] != "api.example.com" {
t.Fatalf("expected host=%q, got %q", "api.example.com", resp.Data["host"])
}
}
func TestWithLocation_Good_DetectsForwardedProto(t *testing.T) {
gin.SetMode(gin.TestMode)
e, _ := api.New(api.WithLocation())
e.Register(&locationTestGroup{})
h := e.Handler()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/loc/info", nil)
req.Header.Set("X-Forwarded-Proto", "https")
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp locationResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if resp.Data["scheme"] != "https" {
t.Fatalf("expected scheme=%q, got %q", "https", resp.Data["scheme"])
}
}
func TestWithLocation_Good_FallsBackToRequestHost(t *testing.T) {
gin.SetMode(gin.TestMode)
e, _ := api.New(api.WithLocation())
e.Register(&locationTestGroup{})
h := e.Handler()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/loc/info", nil)
// No X-Forwarded-* headers — middleware should fall back to defaults.
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp locationResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
// Without forwarded headers the middleware falls back to its default
// scheme ("http"). The host will be either the request Host header
// value or the configured default; either way it must not be empty.
if resp.Data["scheme"] != "http" {
t.Fatalf("expected fallback scheme=%q, got %q", "http", resp.Data["scheme"])
}
if resp.Data["host"] == "" {
t.Fatal("expected a non-empty host in fallback mode")
}
}
func TestWithLocation_Good_CombinesWithOtherMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode)
e, _ := api.New(
api.WithLocation(),
api.WithRequestID(),
)
e.Register(&locationTestGroup{})
h := e.Handler()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/loc/info", nil)
req.Header.Set("X-Forwarded-Host", "proxy.example.com")
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
// Location middleware should populate the detected host.
var resp locationResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if resp.Data["host"] != "proxy.example.com" {
t.Fatalf("expected host=%q, got %q", "proxy.example.com", resp.Data["host"])
}
// RequestID middleware should also have run.
if w.Header().Get("X-Request-ID") == "" {
t.Fatal("expected X-Request-ID header from WithRequestID")
}
}
func TestWithLocation_Good_BothHeadersCombined(t *testing.T) {
gin.SetMode(gin.TestMode)
e, _ := api.New(api.WithLocation())
e.Register(&locationTestGroup{})
h := e.Handler()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/loc/info", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "secure.example.com")
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp locationResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if resp.Data["scheme"] != "https" {
t.Fatalf("expected scheme=%q, got %q", "https", resp.Data["scheme"])
}
if resp.Data["host"] != "secure.example.com" {
t.Fatalf("expected host=%q, got %q", "secure.example.com", resp.Data["host"])
}
}

View file

@ -13,6 +13,7 @@ import (
"github.com/gin-contrib/cors"
gingzip "github.com/gin-contrib/gzip"
"github.com/gin-contrib/httpsign"
"github.com/gin-contrib/location/v2"
"github.com/gin-contrib/secure"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
@ -263,3 +264,16 @@ func WithSSE(broker *SSEBroker) Option {
e.sseBroker = broker
}
}
// WithLocation adds reverse proxy header detection middleware via
// gin-contrib/location. It inspects X-Forwarded-Proto and X-Forwarded-Host
// headers to determine the original scheme and host when the server runs
// behind a TLS-terminating reverse proxy such as Traefik.
//
// After this middleware runs, handlers can call location.Get(c) to retrieve
// a *url.URL with the detected scheme, host, and base path.
func WithLocation() Option {
return func(e *Engine) {
e.middlewares = append(e.middlewares, location.Default())
}
}