From d27934349104640a1e03b0c650d65492805a9769 Mon Sep 17 00:00:00 2001 From: Snider Date: Fri, 20 Feb 2026 16:43:55 +0000 Subject: [PATCH] feat(authentik): add RequireAuth and RequireGroup middleware Add two route-level middleware helpers for enforcing authentication and group membership. RequireAuth returns 401 when no user is in context. RequireGroup returns 401 for unauthenticated requests and 403 when the user lacks the specified group. Both use UK English error codes ("unauthorised", "forbidden") consistent with existing bearer auth. Co-Authored-By: Virgil --- authentik.go | 35 +++++++++++++ authentik_test.go | 125 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+) diff --git a/authentik.go b/authentik.go index b80b434..7344bca 100644 --- a/authentik.go +++ b/authentik.go @@ -4,6 +4,7 @@ package api import ( "context" + "net/http" "strings" "sync" @@ -195,3 +196,37 @@ func authentikMiddleware(cfg AuthentikConfig) gin.HandlerFunc { c.Next() } } + +// RequireAuth is Gin middleware that rejects unauthenticated requests. +// It checks for a user set by the Authentik middleware and returns 401 +// when none is present. +func RequireAuth() gin.HandlerFunc { + return func(c *gin.Context) { + if GetUser(c) == nil { + c.AbortWithStatusJSON(http.StatusUnauthorized, + Fail("unauthorised", "Authentication required")) + return + } + c.Next() + } +} + +// RequireGroup is Gin middleware that rejects requests from users who do +// not belong to the specified group. Returns 401 when no user is present +// and 403 when the user lacks the required group membership. +func RequireGroup(group string) gin.HandlerFunc { + return func(c *gin.Context) { + user := GetUser(c) + if user == nil { + c.AbortWithStatusJSON(http.StatusUnauthorized, + Fail("unauthorised", "Authentication required")) + return + } + if !user.HasGroup(group) { + c.AbortWithStatusJSON(http.StatusForbidden, + Fail("forbidden", "Insufficient permissions")) + return + } + c.Next() + } +} diff --git a/authentik_test.go b/authentik_test.go index 2c10b04..0d68813 100644 --- a/authentik_test.go +++ b/authentik_test.go @@ -5,6 +5,7 @@ package api_test import ( "net/http" "net/http/httptest" + "strings" "testing" "github.com/gin-gonic/gin" @@ -321,6 +322,107 @@ func TestBearerAndAuthentikCoexist_Good(t *testing.T) { } } +// ── RequireAuth / RequireGroup ──────────────────────────────────────── + +func TestRequireAuth_Good(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := api.AuthentikConfig{TrustedProxy: true} + e, _ := api.New(api.WithAuthentik(cfg)) + e.Register(&protectedGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/protected/data", nil) + req.Header.Set("X-authentik-username", "alice") + req.Header.Set("X-authentik-email", "alice@example.com") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestRequireAuth_Bad_NoUser(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := api.AuthentikConfig{TrustedProxy: true} + e, _ := api.New(api.WithAuthentik(cfg)) + e.Register(&protectedGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/protected/data", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d: %s", w.Code, w.Body.String()) + } + body := w.Body.String() + if !strings.Contains(body, `"unauthorised"`) { + t.Fatalf("expected error code 'unauthorised' in body, got %s", body) + } +} + +func TestRequireAuth_Bad_NoAuthentikMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Engine without WithAuthentik — RequireAuth should still reject. + e, _ := api.New() + e.Register(&protectedGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/protected/data", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestRequireGroup_Good(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := api.AuthentikConfig{TrustedProxy: true} + e, _ := api.New(api.WithAuthentik(cfg)) + e.Register(&groupRequireGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/admin/panel", nil) + req.Header.Set("X-authentik-username", "admin-user") + req.Header.Set("X-authentik-groups", "admins|staff") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestRequireGroup_Bad_WrongGroup(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := api.AuthentikConfig{TrustedProxy: true} + e, _ := api.New(api.WithAuthentik(cfg)) + e.Register(&groupRequireGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/admin/panel", nil) + req.Header.Set("X-authentik-username", "dev-user") + req.Header.Set("X-authentik-groups", "developers") + h.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d: %s", w.Code, w.Body.String()) + } + body := w.Body.String() + if !strings.Contains(body, `"forbidden"`) { + t.Fatalf("expected error code 'forbidden' in body, got %s", body) + } +} + // ── Test helpers ─────────────────────────────────────────────────────── // authTestGroup provides a /v1/check endpoint that calls a custom handler. @@ -333,3 +435,26 @@ func (a *authTestGroup) BasePath() string { return "/v1" } func (a *authTestGroup) RegisterRoutes(rg *gin.RouterGroup) { rg.GET("/check", a.onRequest) } + +// protectedGroup provides a /v1/protected/data endpoint guarded by RequireAuth. +type protectedGroup struct{} + +func (g *protectedGroup) Name() string { return "protected" } +func (g *protectedGroup) BasePath() string { return "/v1/protected" } +func (g *protectedGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/data", api.RequireAuth(), func(c *gin.Context) { + user := api.GetUser(c) + c.JSON(200, api.OK(user.Username)) + }) +} + +// groupRequireGroup provides a /v1/admin/panel endpoint guarded by RequireGroup. +type groupRequireGroup struct{} + +func (g *groupRequireGroup) Name() string { return "adminonly" } +func (g *groupRequireGroup) BasePath() string { return "/v1/admin" } +func (g *groupRequireGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/panel", api.RequireGroup("admins"), func(c *gin.Context) { + c.JSON(200, api.OK("admin panel")) + }) +}