diff --git a/authentik.go b/authentik.go index b80b434..7344bca 100644 --- a/authentik.go +++ b/authentik.go @@ -4,6 +4,7 @@ package api import ( "context" + "net/http" "strings" "sync" @@ -195,3 +196,37 @@ func authentikMiddleware(cfg AuthentikConfig) gin.HandlerFunc { c.Next() } } + +// RequireAuth is Gin middleware that rejects unauthenticated requests. +// It checks for a user set by the Authentik middleware and returns 401 +// when none is present. +func RequireAuth() gin.HandlerFunc { + return func(c *gin.Context) { + if GetUser(c) == nil { + c.AbortWithStatusJSON(http.StatusUnauthorized, + Fail("unauthorised", "Authentication required")) + return + } + c.Next() + } +} + +// RequireGroup is Gin middleware that rejects requests from users who do +// not belong to the specified group. Returns 401 when no user is present +// and 403 when the user lacks the required group membership. +func RequireGroup(group string) gin.HandlerFunc { + return func(c *gin.Context) { + user := GetUser(c) + if user == nil { + c.AbortWithStatusJSON(http.StatusUnauthorized, + Fail("unauthorised", "Authentication required")) + return + } + if !user.HasGroup(group) { + c.AbortWithStatusJSON(http.StatusForbidden, + Fail("forbidden", "Insufficient permissions")) + return + } + c.Next() + } +} diff --git a/authentik_test.go b/authentik_test.go index 2c10b04..0d68813 100644 --- a/authentik_test.go +++ b/authentik_test.go @@ -5,6 +5,7 @@ package api_test import ( "net/http" "net/http/httptest" + "strings" "testing" "github.com/gin-gonic/gin" @@ -321,6 +322,107 @@ func TestBearerAndAuthentikCoexist_Good(t *testing.T) { } } +// ── RequireAuth / RequireGroup ──────────────────────────────────────── + +func TestRequireAuth_Good(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := api.AuthentikConfig{TrustedProxy: true} + e, _ := api.New(api.WithAuthentik(cfg)) + e.Register(&protectedGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/protected/data", nil) + req.Header.Set("X-authentik-username", "alice") + req.Header.Set("X-authentik-email", "alice@example.com") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestRequireAuth_Bad_NoUser(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := api.AuthentikConfig{TrustedProxy: true} + e, _ := api.New(api.WithAuthentik(cfg)) + e.Register(&protectedGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/protected/data", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d: %s", w.Code, w.Body.String()) + } + body := w.Body.String() + if !strings.Contains(body, `"unauthorised"`) { + t.Fatalf("expected error code 'unauthorised' in body, got %s", body) + } +} + +func TestRequireAuth_Bad_NoAuthentikMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Engine without WithAuthentik — RequireAuth should still reject. + e, _ := api.New() + e.Register(&protectedGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/protected/data", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestRequireGroup_Good(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := api.AuthentikConfig{TrustedProxy: true} + e, _ := api.New(api.WithAuthentik(cfg)) + e.Register(&groupRequireGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/admin/panel", nil) + req.Header.Set("X-authentik-username", "admin-user") + req.Header.Set("X-authentik-groups", "admins|staff") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestRequireGroup_Bad_WrongGroup(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := api.AuthentikConfig{TrustedProxy: true} + e, _ := api.New(api.WithAuthentik(cfg)) + e.Register(&groupRequireGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/admin/panel", nil) + req.Header.Set("X-authentik-username", "dev-user") + req.Header.Set("X-authentik-groups", "developers") + h.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d: %s", w.Code, w.Body.String()) + } + body := w.Body.String() + if !strings.Contains(body, `"forbidden"`) { + t.Fatalf("expected error code 'forbidden' in body, got %s", body) + } +} + // ── Test helpers ─────────────────────────────────────────────────────── // authTestGroup provides a /v1/check endpoint that calls a custom handler. @@ -333,3 +435,26 @@ func (a *authTestGroup) BasePath() string { return "/v1" } func (a *authTestGroup) RegisterRoutes(rg *gin.RouterGroup) { rg.GET("/check", a.onRequest) } + +// protectedGroup provides a /v1/protected/data endpoint guarded by RequireAuth. +type protectedGroup struct{} + +func (g *protectedGroup) Name() string { return "protected" } +func (g *protectedGroup) BasePath() string { return "/v1/protected" } +func (g *protectedGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/data", api.RequireAuth(), func(c *gin.Context) { + user := api.GetUser(c) + c.JSON(200, api.OK(user.Username)) + }) +} + +// groupRequireGroup provides a /v1/admin/panel endpoint guarded by RequireGroup. +type groupRequireGroup struct{} + +func (g *groupRequireGroup) Name() string { return "adminonly" } +func (g *groupRequireGroup) BasePath() string { return "/v1/admin" } +func (g *groupRequireGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/panel", api.RequireGroup("admins"), func(c *gin.Context) { + c.JSON(200, api.OK("admin panel")) + }) +}