feat(authentik): add RequireAuth and RequireGroup middleware
Add two route-level middleware helpers for enforcing authentication and
group membership. RequireAuth returns 401 when no user is in context.
RequireGroup returns 401 for unauthenticated requests and 403 when the
user lacks the specified group. Both use UK English error codes
("unauthorised", "forbidden") consistent with existing bearer auth.
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
5cba2f2cd4
commit
d279343491
2 changed files with 160 additions and 0 deletions
35
authentik.go
35
authentik.go
|
|
@ -4,6 +4,7 @@ package api
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
|
|
@ -195,3 +196,37 @@ func authentikMiddleware(cfg AuthentikConfig) gin.HandlerFunc {
|
|||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAuth is Gin middleware that rejects unauthenticated requests.
|
||||
// It checks for a user set by the Authentik middleware and returns 401
|
||||
// when none is present.
|
||||
func RequireAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if GetUser(c) == nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized,
|
||||
Fail("unauthorised", "Authentication required"))
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireGroup is Gin middleware that rejects requests from users who do
|
||||
// not belong to the specified group. Returns 401 when no user is present
|
||||
// and 403 when the user lacks the required group membership.
|
||||
func RequireGroup(group string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
user := GetUser(c)
|
||||
if user == nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized,
|
||||
Fail("unauthorised", "Authentication required"))
|
||||
return
|
||||
}
|
||||
if !user.HasGroup(group) {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden,
|
||||
Fail("forbidden", "Insufficient permissions"))
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ package api_test
|
|||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
|
@ -321,6 +322,107 @@ func TestBearerAndAuthentikCoexist_Good(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// ── RequireAuth / RequireGroup ────────────────────────────────────────
|
||||
|
||||
func TestRequireAuth_Good(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := api.AuthentikConfig{TrustedProxy: true}
|
||||
e, _ := api.New(api.WithAuthentik(cfg))
|
||||
e.Register(&protectedGroup{})
|
||||
|
||||
h := e.Handler()
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/v1/protected/data", nil)
|
||||
req.Header.Set("X-authentik-username", "alice")
|
||||
req.Header.Set("X-authentik-email", "alice@example.com")
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuth_Bad_NoUser(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := api.AuthentikConfig{TrustedProxy: true}
|
||||
e, _ := api.New(api.WithAuthentik(cfg))
|
||||
e.Register(&protectedGroup{})
|
||||
|
||||
h := e.Handler()
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/v1/protected/data", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, `"unauthorised"`) {
|
||||
t.Fatalf("expected error code 'unauthorised' in body, got %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuth_Bad_NoAuthentikMiddleware(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Engine without WithAuthentik — RequireAuth should still reject.
|
||||
e, _ := api.New()
|
||||
e.Register(&protectedGroup{})
|
||||
|
||||
h := e.Handler()
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/v1/protected/data", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireGroup_Good(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := api.AuthentikConfig{TrustedProxy: true}
|
||||
e, _ := api.New(api.WithAuthentik(cfg))
|
||||
e.Register(&groupRequireGroup{})
|
||||
|
||||
h := e.Handler()
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/v1/admin/panel", nil)
|
||||
req.Header.Set("X-authentik-username", "admin-user")
|
||||
req.Header.Set("X-authentik-groups", "admins|staff")
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireGroup_Bad_WrongGroup(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := api.AuthentikConfig{TrustedProxy: true}
|
||||
e, _ := api.New(api.WithAuthentik(cfg))
|
||||
e.Register(&groupRequireGroup{})
|
||||
|
||||
h := e.Handler()
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/v1/admin/panel", nil)
|
||||
req.Header.Set("X-authentik-username", "dev-user")
|
||||
req.Header.Set("X-authentik-groups", "developers")
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, `"forbidden"`) {
|
||||
t.Fatalf("expected error code 'forbidden' in body, got %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Test helpers ───────────────────────────────────────────────────────
|
||||
|
||||
// authTestGroup provides a /v1/check endpoint that calls a custom handler.
|
||||
|
|
@ -333,3 +435,26 @@ func (a *authTestGroup) BasePath() string { return "/v1" }
|
|||
func (a *authTestGroup) RegisterRoutes(rg *gin.RouterGroup) {
|
||||
rg.GET("/check", a.onRequest)
|
||||
}
|
||||
|
||||
// protectedGroup provides a /v1/protected/data endpoint guarded by RequireAuth.
|
||||
type protectedGroup struct{}
|
||||
|
||||
func (g *protectedGroup) Name() string { return "protected" }
|
||||
func (g *protectedGroup) BasePath() string { return "/v1/protected" }
|
||||
func (g *protectedGroup) RegisterRoutes(rg *gin.RouterGroup) {
|
||||
rg.GET("/data", api.RequireAuth(), func(c *gin.Context) {
|
||||
user := api.GetUser(c)
|
||||
c.JSON(200, api.OK(user.Username))
|
||||
})
|
||||
}
|
||||
|
||||
// groupRequireGroup provides a /v1/admin/panel endpoint guarded by RequireGroup.
|
||||
type groupRequireGroup struct{}
|
||||
|
||||
func (g *groupRequireGroup) Name() string { return "adminonly" }
|
||||
func (g *groupRequireGroup) BasePath() string { return "/v1/admin" }
|
||||
func (g *groupRequireGroup) RegisterRoutes(rg *gin.RouterGroup) {
|
||||
rg.GET("/panel", api.RequireGroup("admins"), func(c *gin.Context) {
|
||||
c.JSON(200, api.OK("admin panel"))
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue