diff --git a/authentik.go b/authentik.go index c78b146..c81ff97 100644 --- a/authentik.go +++ b/authentik.go @@ -2,6 +2,12 @@ package api +import ( + "strings" + + "github.com/gin-gonic/gin" +) + // AuthentikConfig holds settings for the Authentik forward-auth integration. type AuthentikConfig struct { // Issuer is the OIDC issuer URL (e.g. https://auth.example.com/application/o/my-app/). @@ -40,3 +46,75 @@ func (u *AuthentikUser) HasGroup(group string) bool { } return false } + +// authentikUserKey is the Gin context key used to store the authenticated user. +const authentikUserKey = "authentik_user" + +// GetUser retrieves the AuthentikUser from the Gin context. +// Returns nil when no user has been set (unauthenticated request or +// middleware not active). +func GetUser(c *gin.Context) *AuthentikUser { + val, exists := c.Get(authentikUserKey) + if !exists { + return nil + } + user, ok := val.(*AuthentikUser) + if !ok { + return nil + } + return user +} + +// authentikMiddleware returns Gin middleware that extracts user identity from +// X-authentik-* headers set by a trusted reverse proxy (e.g. Traefik with +// Authentik forward-auth). +// +// The middleware is PERMISSIVE: it populates the context when headers are +// present but never rejects unauthenticated requests. Downstream handlers +// use GetUser to check authentication. +func authentikMiddleware(cfg AuthentikConfig) gin.HandlerFunc { + // Build the set of public paths that skip header extraction entirely. + public := map[string]bool{ + "/health": true, + "/swagger": true, + } + for _, p := range cfg.PublicPaths { + public[p] = true + } + + return func(c *gin.Context) { + // Skip public paths. + path := c.Request.URL.Path + for p := range public { + if strings.HasPrefix(path, p) { + c.Next() + return + } + } + + // Only read headers when the proxy is trusted. + if cfg.TrustedProxy { + username := c.GetHeader("X-authentik-username") + if username != "" { + user := &AuthentikUser{ + Username: username, + Email: c.GetHeader("X-authentik-email"), + Name: c.GetHeader("X-authentik-name"), + UID: c.GetHeader("X-authentik-uid"), + JWT: c.GetHeader("X-authentik-jwt"), + } + + if groups := c.GetHeader("X-authentik-groups"); groups != "" { + user.Groups = strings.Split(groups, "|") + } + if ent := c.GetHeader("X-authentik-entitlements"); ent != "" { + user.Entitlements = strings.Split(ent, "|") + } + + c.Set(authentikUserKey, user) + } + } + + c.Next() + } +} diff --git a/authentik_test.go b/authentik_test.go index e2b99c7..dba810d 100644 --- a/authentik_test.go +++ b/authentik_test.go @@ -3,8 +3,12 @@ package api_test import ( + "net/http" + "net/http/httptest" "testing" + "github.com/gin-gonic/gin" + api "forge.lthn.ai/core/go-api" ) @@ -86,3 +90,170 @@ func TestAuthentikConfig_Good(t *testing.T) { t.Fatalf("expected 2 public paths, got %d", len(cfg.PublicPaths)) } } + +// ── Forward auth middleware ──────────────────────────────────────────── + +func TestForwardAuthHeaders_Good(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := api.AuthentikConfig{TrustedProxy: true} + e, _ := api.New(api.WithAuthentik(cfg)) + + var gotUser *api.AuthentikUser + e.Register(&authTestGroup{onRequest: func(c *gin.Context) { + gotUser = api.GetUser(c) + c.JSON(http.StatusOK, api.OK("ok")) + }}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/check", nil) + req.Header.Set("X-authentik-username", "bob") + req.Header.Set("X-authentik-email", "bob@example.com") + req.Header.Set("X-authentik-name", "Bob Jones") + req.Header.Set("X-authentik-uid", "uid-456") + req.Header.Set("X-authentik-jwt", "jwt.tok.en") + req.Header.Set("X-authentik-groups", "staff|admins|ops") + req.Header.Set("X-authentik-entitlements", "read|write") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if gotUser == nil { + t.Fatal("expected GetUser to return a user, got nil") + } + if gotUser.Username != "bob" { + t.Fatalf("expected Username=%q, got %q", "bob", gotUser.Username) + } + if gotUser.Email != "bob@example.com" { + t.Fatalf("expected Email=%q, got %q", "bob@example.com", gotUser.Email) + } + if gotUser.Name != "Bob Jones" { + t.Fatalf("expected Name=%q, got %q", "Bob Jones", gotUser.Name) + } + if gotUser.UID != "uid-456" { + t.Fatalf("expected UID=%q, got %q", "uid-456", gotUser.UID) + } + if gotUser.JWT != "jwt.tok.en" { + t.Fatalf("expected JWT=%q, got %q", "jwt.tok.en", gotUser.JWT) + } + if len(gotUser.Groups) != 3 { + t.Fatalf("expected 3 groups, got %d: %v", len(gotUser.Groups), gotUser.Groups) + } + if gotUser.Groups[0] != "staff" || gotUser.Groups[1] != "admins" || gotUser.Groups[2] != "ops" { + t.Fatalf("expected groups [staff admins ops], got %v", gotUser.Groups) + } + if len(gotUser.Entitlements) != 2 { + t.Fatalf("expected 2 entitlements, got %d: %v", len(gotUser.Entitlements), gotUser.Entitlements) + } + if gotUser.Entitlements[0] != "read" || gotUser.Entitlements[1] != "write" { + t.Fatalf("expected entitlements [read write], got %v", gotUser.Entitlements) + } +} + +func TestForwardAuthHeaders_Good_NoHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := api.AuthentikConfig{TrustedProxy: true} + e, _ := api.New(api.WithAuthentik(cfg)) + + var gotUser *api.AuthentikUser + e.Register(&authTestGroup{onRequest: func(c *gin.Context) { + gotUser = api.GetUser(c) + c.JSON(http.StatusOK, api.OK("ok")) + }}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/check", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if gotUser != nil { + t.Fatalf("expected GetUser to return nil without headers, got %+v", gotUser) + } +} + +func TestForwardAuthHeaders_Bad_NotTrusted(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := api.AuthentikConfig{TrustedProxy: false} + e, _ := api.New(api.WithAuthentik(cfg)) + + var gotUser *api.AuthentikUser + e.Register(&authTestGroup{onRequest: func(c *gin.Context) { + gotUser = api.GetUser(c) + c.JSON(http.StatusOK, api.OK("ok")) + }}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/check", nil) + req.Header.Set("X-authentik-username", "mallory") + req.Header.Set("X-authentik-email", "mallory@evil.com") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if gotUser != nil { + t.Fatalf("expected GetUser to return nil when TrustedProxy=false, got %+v", gotUser) + } +} + +func TestHealthBypassesAuthentik_Good(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := api.AuthentikConfig{TrustedProxy: true} + e, _ := api.New(api.WithAuthentik(cfg)) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/health", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200 for /health, got %d", w.Code) + } +} + +func TestGetUser_Good_NilContext(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Engine without WithAuthentik — GetUser should return nil. + e, _ := api.New() + + var gotUser *api.AuthentikUser + e.Register(&authTestGroup{onRequest: func(c *gin.Context) { + gotUser = api.GetUser(c) + c.JSON(http.StatusOK, api.OK("ok")) + }}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/check", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if gotUser != nil { + t.Fatalf("expected GetUser to return nil without middleware, got %+v", gotUser) + } +} + +// ── Test helpers ─────────────────────────────────────────────────────── + +// authTestGroup provides a /v1/check endpoint that calls a custom handler. +type authTestGroup struct { + onRequest func(c *gin.Context) +} + +func (a *authTestGroup) Name() string { return "auth-test" } +func (a *authTestGroup) BasePath() string { return "/v1" } +func (a *authTestGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/check", a.onRequest) +} diff --git a/options.go b/options.go index 05a08e6..c3111f6 100644 --- a/options.go +++ b/options.go @@ -78,6 +78,15 @@ func WithWSHandler(h http.Handler) Option { } } +// WithAuthentik adds Authentik forward-auth middleware that extracts user +// identity from X-authentik-* headers set by a trusted reverse proxy. +// The middleware is permissive: unauthenticated requests are allowed through. +func WithAuthentik(cfg AuthentikConfig) Option { + return func(e *Engine) { + e.middlewares = append(e.middlewares, authentikMiddleware(cfg)) + } +} + // WithSwagger enables the Swagger UI at /swagger/. // The title, description, and version populate the OpenAPI info block. func WithSwagger(title, description, version string) Option {