Compare commits
1 commit
dev
...
ax/review-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
43abce034e |
42 changed files with 1069 additions and 312 deletions
66
api.go
66
api.go
|
|
@ -6,12 +6,12 @@ package api
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"iter"
|
||||
"net/http"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
coreerr "dappco.re/go/core/log"
|
||||
"github.com/gin-contrib/expvar"
|
||||
"github.com/gin-contrib/pprof"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
|
@ -41,6 +41,10 @@ type Engine struct {
|
|||
|
||||
// New creates an Engine with the given options.
|
||||
// The default listen address is ":8080".
|
||||
//
|
||||
// engine, _ := api.New(api.WithAddr(":9090"), api.WithCORS("*"))
|
||||
// engine.Register(myGroup)
|
||||
// engine.Serve(ctx)
|
||||
func New(opts ...Option) (*Engine, error) {
|
||||
e := &Engine{
|
||||
addr: defaultAddr,
|
||||
|
|
@ -52,6 +56,9 @@ func New(opts ...Option) (*Engine, error) {
|
|||
}
|
||||
|
||||
// Addr returns the configured listen address.
|
||||
//
|
||||
// engine, _ := api.New(api.WithAddr(":9090"))
|
||||
// addr := engine.Addr() // ":9090"
|
||||
func (e *Engine) Addr() string {
|
||||
return e.addr
|
||||
}
|
||||
|
|
@ -67,6 +74,9 @@ func (e *Engine) GroupsIter() iter.Seq[RouteGroup] {
|
|||
}
|
||||
|
||||
// Register adds a route group to the engine.
|
||||
//
|
||||
// engine.Register(api.NewToolBridge("/tools"))
|
||||
// engine.Register(myRouteGroup)
|
||||
func (e *Engine) Register(group RouteGroup) {
|
||||
e.groups = append(e.groups, group)
|
||||
}
|
||||
|
|
@ -75,9 +85,9 @@ func (e *Engine) Register(group RouteGroup) {
|
|||
// Groups that do not implement StreamGroup are silently skipped.
|
||||
func (e *Engine) Channels() []string {
|
||||
var channels []string
|
||||
for _, g := range e.groups {
|
||||
if sg, ok := g.(StreamGroup); ok {
|
||||
channels = append(channels, sg.Channels()...)
|
||||
for _, group := range e.groups {
|
||||
if streamGroup, ok := group.(StreamGroup); ok {
|
||||
channels = append(channels, streamGroup.Channels()...)
|
||||
}
|
||||
}
|
||||
return channels
|
||||
|
|
@ -86,10 +96,10 @@ func (e *Engine) Channels() []string {
|
|||
// ChannelsIter returns an iterator over WebSocket channel names from registered StreamGroups.
|
||||
func (e *Engine) ChannelsIter() iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
for _, g := range e.groups {
|
||||
if sg, ok := g.(StreamGroup); ok {
|
||||
for _, c := range sg.Channels() {
|
||||
if !yield(c) {
|
||||
for _, group := range e.groups {
|
||||
if streamGroup, ok := group.(StreamGroup); ok {
|
||||
for _, channelName := range streamGroup.Channels() {
|
||||
if !yield(channelName) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
@ -100,6 +110,8 @@ func (e *Engine) ChannelsIter() iter.Seq[string] {
|
|||
|
||||
// Handler builds the Gin engine and returns it as an http.Handler.
|
||||
// Each call produces a fresh handler reflecting the current set of groups.
|
||||
//
|
||||
// http.ListenAndServe(":8080", engine.Handler())
|
||||
func (e *Engine) Handler() http.Handler {
|
||||
return e.build()
|
||||
}
|
||||
|
|
@ -107,14 +119,14 @@ func (e *Engine) Handler() http.Handler {
|
|||
// Serve starts the HTTP server and blocks until the context is cancelled,
|
||||
// then performs a graceful shutdown allowing in-flight requests to complete.
|
||||
func (e *Engine) Serve(ctx context.Context) error {
|
||||
srv := &http.Server{
|
||||
server := &http.Server{
|
||||
Addr: e.addr,
|
||||
Handler: e.build(),
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
if err := server.ListenAndServe(); err != nil && !coreerr.Is(err, http.ErrServerClosed) {
|
||||
errCh <- err
|
||||
}
|
||||
close(errCh)
|
||||
|
|
@ -124,10 +136,10 @@ func (e *Engine) Serve(ctx context.Context) error {
|
|||
<-ctx.Done()
|
||||
|
||||
// Graceful shutdown with timeout.
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
shutdownContext, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
if err := server.Shutdown(shutdownContext); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -138,54 +150,54 @@ func (e *Engine) Serve(ctx context.Context) error {
|
|||
// build creates a configured Gin engine with recovery middleware,
|
||||
// user-supplied middleware, the health endpoint, and all registered route groups.
|
||||
func (e *Engine) build() *gin.Engine {
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
router := gin.New()
|
||||
router.Use(gin.Recovery())
|
||||
|
||||
// Apply user-supplied middleware after recovery but before routes.
|
||||
for _, mw := range e.middlewares {
|
||||
r.Use(mw)
|
||||
for _, middleware := range e.middlewares {
|
||||
router.Use(middleware)
|
||||
}
|
||||
|
||||
// Built-in health check.
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
router.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, OK("healthy"))
|
||||
})
|
||||
|
||||
// Mount each registered group at its base path.
|
||||
for _, g := range e.groups {
|
||||
rg := r.Group(g.BasePath())
|
||||
g.RegisterRoutes(rg)
|
||||
for _, group := range e.groups {
|
||||
routerGroup := router.Group(group.BasePath())
|
||||
group.RegisterRoutes(routerGroup)
|
||||
}
|
||||
|
||||
// Mount WebSocket handler if configured.
|
||||
if e.wsHandler != nil {
|
||||
r.GET("/ws", wrapWSHandler(e.wsHandler))
|
||||
router.GET("/ws", wrapWSHandler(e.wsHandler))
|
||||
}
|
||||
|
||||
// Mount SSE endpoint if configured.
|
||||
if e.sseBroker != nil {
|
||||
r.GET("/events", e.sseBroker.Handler())
|
||||
router.GET("/events", e.sseBroker.Handler())
|
||||
}
|
||||
|
||||
// Mount GraphQL endpoint if configured.
|
||||
if e.graphql != nil {
|
||||
mountGraphQL(r, e.graphql)
|
||||
mountGraphQL(router, e.graphql)
|
||||
}
|
||||
|
||||
// Mount Swagger UI if enabled.
|
||||
if e.swaggerEnabled {
|
||||
registerSwagger(r, e.swaggerTitle, e.swaggerDesc, e.swaggerVersion, e.groups)
|
||||
registerSwagger(router, e.swaggerTitle, e.swaggerDesc, e.swaggerVersion, e.groups)
|
||||
}
|
||||
|
||||
// Mount pprof profiling endpoints if enabled.
|
||||
if e.pprofEnabled {
|
||||
pprof.Register(r)
|
||||
pprof.Register(router)
|
||||
}
|
||||
|
||||
// Mount expvar runtime metrics endpoint if enabled.
|
||||
if e.expvarEnabled {
|
||||
r.GET("/debug/vars", expvar.Handler())
|
||||
router.GET("/debug/vars", expvar.Handler())
|
||||
}
|
||||
|
||||
return r
|
||||
return router
|
||||
}
|
||||
|
|
|
|||
18
api_test.go
18
api_test.go
|
|
@ -202,3 +202,21 @@ func TestServe_Good_GracefulShutdown(t *testing.T) {
|
|||
t.Fatal("Serve did not return within 5 seconds after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_Ugly_MultipleOptionsDontPanic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("New with many options panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Applying many options at once should not panic.
|
||||
_, err := api.New(
|
||||
api.WithAddr(":0"),
|
||||
api.WithRequestID(),
|
||||
api.WithCORS("*"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
54
authentik.go
54
authentik.go
|
|
@ -6,9 +6,9 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"dappco.re/go/core"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
|
@ -43,6 +43,9 @@ type AuthentikUser struct {
|
|||
}
|
||||
|
||||
// HasGroup reports whether the user belongs to the named group.
|
||||
//
|
||||
// user := api.GetUser(c)
|
||||
// if user.HasGroup("admins") { /* allow */ }
|
||||
func (u *AuthentikUser) HasGroup(group string) bool {
|
||||
return slices.Contains(u.Groups, group)
|
||||
}
|
||||
|
|
@ -53,6 +56,9 @@ const authentikUserKey = "authentik_user"
|
|||
// GetUser retrieves the AuthentikUser from the Gin context.
|
||||
// Returns nil when no user has been set (unauthenticated request or
|
||||
// middleware not active).
|
||||
//
|
||||
// user := api.GetUser(c)
|
||||
// if user == nil { c.AbortWithStatus(401); return }
|
||||
func GetUser(c *gin.Context) *AuthentikUser {
|
||||
val, exists := c.Get(authentikUserKey)
|
||||
if !exists {
|
||||
|
|
@ -78,28 +84,28 @@ func getOIDCProvider(ctx context.Context, issuer string) (*oidc.Provider, error)
|
|||
oidcProviderMu.Lock()
|
||||
defer oidcProviderMu.Unlock()
|
||||
|
||||
if p, ok := oidcProviders[issuer]; ok {
|
||||
return p, nil
|
||||
if provider, ok := oidcProviders[issuer]; ok {
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
p, err := oidc.NewProvider(ctx, issuer)
|
||||
provider, err := oidc.NewProvider(ctx, issuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oidcProviders[issuer] = p
|
||||
return p, nil
|
||||
oidcProviders[issuer] = provider
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// validateJWT verifies a raw JWT against the configured OIDC issuer and
|
||||
// extracts user claims on success.
|
||||
func validateJWT(ctx context.Context, cfg AuthentikConfig, rawToken string) (*AuthentikUser, error) {
|
||||
provider, err := getOIDCProvider(ctx, cfg.Issuer)
|
||||
func validateJWT(ctx context.Context, config AuthentikConfig, rawToken string) (*AuthentikUser, error) {
|
||||
provider, err := getOIDCProvider(ctx, config.Issuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
verifier := provider.Verifier(&oidc.Config{ClientID: cfg.ClientID})
|
||||
verifier := provider.Verifier(&oidc.Config{ClientID: config.ClientID})
|
||||
|
||||
idToken, err := verifier.Verify(ctx, rawToken)
|
||||
if err != nil {
|
||||
|
|
@ -134,28 +140,28 @@ func validateJWT(ctx context.Context, cfg AuthentikConfig, rawToken string) (*Au
|
|||
// The middleware is PERMISSIVE: it populates the context when credentials are
|
||||
// present but never rejects unauthenticated requests. Downstream handlers
|
||||
// use GetUser to check authentication.
|
||||
func authentikMiddleware(cfg AuthentikConfig) gin.HandlerFunc {
|
||||
func authentikMiddleware(config AuthentikConfig) gin.HandlerFunc {
|
||||
// Build the set of public paths that skip header extraction entirely.
|
||||
public := map[string]bool{
|
||||
"/health": true,
|
||||
"/swagger": true,
|
||||
}
|
||||
for _, p := range cfg.PublicPaths {
|
||||
public[p] = true
|
||||
for _, publicPath := range config.PublicPaths {
|
||||
public[publicPath] = true
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// Skip public paths.
|
||||
path := c.Request.URL.Path
|
||||
for p := range public {
|
||||
if strings.HasPrefix(path, p) {
|
||||
for publicPath := range public {
|
||||
if core.HasPrefix(path, publicPath) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Block 1: Extract user from X-authentik-* forward-auth headers.
|
||||
if cfg.TrustedProxy {
|
||||
if config.TrustedProxy {
|
||||
username := c.GetHeader("X-authentik-username")
|
||||
if username != "" {
|
||||
user := &AuthentikUser{
|
||||
|
|
@ -167,10 +173,10 @@ func authentikMiddleware(cfg AuthentikConfig) gin.HandlerFunc {
|
|||
}
|
||||
|
||||
if groups := c.GetHeader("X-authentik-groups"); groups != "" {
|
||||
user.Groups = strings.Split(groups, "|")
|
||||
user.Groups = core.Split(groups, "|")
|
||||
}
|
||||
if ent := c.GetHeader("X-authentik-entitlements"); ent != "" {
|
||||
user.Entitlements = strings.Split(ent, "|")
|
||||
user.Entitlements = core.Split(ent, "|")
|
||||
}
|
||||
|
||||
c.Set(authentikUserKey, user)
|
||||
|
|
@ -179,10 +185,10 @@ func authentikMiddleware(cfg AuthentikConfig) gin.HandlerFunc {
|
|||
|
||||
// Block 2: Attempt JWT validation for direct API clients.
|
||||
// Only when OIDC is configured and no user was extracted from headers.
|
||||
if cfg.Issuer != "" && cfg.ClientID != "" && GetUser(c) == nil {
|
||||
if auth := c.GetHeader("Authorization"); strings.HasPrefix(auth, "Bearer ") {
|
||||
rawToken := strings.TrimPrefix(auth, "Bearer ")
|
||||
if user, err := validateJWT(c.Request.Context(), cfg, rawToken); err == nil {
|
||||
if config.Issuer != "" && config.ClientID != "" && GetUser(c) == nil {
|
||||
if auth := c.GetHeader("Authorization"); core.HasPrefix(auth, "Bearer ") {
|
||||
rawToken := core.TrimPrefix(auth, "Bearer ")
|
||||
if user, err := validateJWT(c.Request.Context(), config, rawToken); err == nil {
|
||||
c.Set(authentikUserKey, user)
|
||||
}
|
||||
// On failure: continue without user (fail open / permissive).
|
||||
|
|
@ -196,6 +202,9 @@ func authentikMiddleware(cfg AuthentikConfig) gin.HandlerFunc {
|
|||
// RequireAuth is Gin middleware that rejects unauthenticated requests.
|
||||
// It checks for a user set by the Authentik middleware and returns 401
|
||||
// when none is present.
|
||||
//
|
||||
// rg := router.Group("/api", api.RequireAuth())
|
||||
// rg.GET("/profile", profileHandler)
|
||||
func RequireAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if GetUser(c) == nil {
|
||||
|
|
@ -210,6 +219,9 @@ func RequireAuth() gin.HandlerFunc {
|
|||
// RequireGroup is Gin middleware that rejects requests from users who do
|
||||
// not belong to the specified group. Returns 401 when no user is present
|
||||
// and 403 when the user lacks the required group membership.
|
||||
//
|
||||
// rg := router.Group("/admin", api.RequireGroup("admins"))
|
||||
// rg.DELETE("/users/:id", deleteUserHandler)
|
||||
func RequireGroup(group string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
user := GetUser(c)
|
||||
|
|
|
|||
|
|
@ -3,16 +3,14 @@
|
|||
package api_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"dappco.re/go/core"
|
||||
|
||||
api "dappco.re/go/core/api"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
|
@ -43,58 +41,68 @@ func getClientCredentialsToken(t *testing.T, issuer, clientID, clientSecret stri
|
|||
t.Helper()
|
||||
|
||||
// Discover token endpoint.
|
||||
disc := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
|
||||
resp, err := http.Get(disc)
|
||||
discoveryURL := core.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
|
||||
resp, err := http.Get(discoveryURL) //nolint:noctx
|
||||
if err != nil {
|
||||
t.Fatalf("OIDC discovery failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var config struct {
|
||||
discoveryBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read discovery body: %v", err)
|
||||
}
|
||||
|
||||
var oidcConfig struct {
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&config); err != nil {
|
||||
t.Fatalf("decode discovery: %v", err)
|
||||
if result := core.JSONUnmarshal(discoveryBody, &oidcConfig); !result.OK {
|
||||
t.Fatalf("decode discovery: %v", result.Value)
|
||||
}
|
||||
|
||||
// Request token.
|
||||
data := url.Values{
|
||||
formData := url.Values{
|
||||
"grant_type": {"client_credentials"},
|
||||
"client_id": {clientID},
|
||||
"client_secret": {clientSecret},
|
||||
"scope": {"openid email profile entitlements"},
|
||||
}
|
||||
resp, err = http.PostForm(config.TokenEndpoint, data)
|
||||
tokenResp, err := http.PostForm(oidcConfig.TokenEndpoint, formData) //nolint:noctx
|
||||
if err != nil {
|
||||
t.Fatalf("token request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer tokenResp.Body.Close()
|
||||
|
||||
var tokenResp struct {
|
||||
tokenBody, err := io.ReadAll(tokenResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read token body: %v", err)
|
||||
}
|
||||
|
||||
var tokenResult struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
Error string `json:"error"`
|
||||
ErrorDesc string `json:"error_description"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
t.Fatalf("decode token response: %v", err)
|
||||
if result := core.JSONUnmarshal(tokenBody, &tokenResult); !result.OK {
|
||||
t.Fatalf("decode token response: %v", result.Value)
|
||||
}
|
||||
if tokenResp.Error != "" {
|
||||
t.Fatalf("token error: %s — %s", tokenResp.Error, tokenResp.ErrorDesc)
|
||||
if tokenResult.Error != "" {
|
||||
t.Fatalf("token error: %s — %s", tokenResult.Error, tokenResult.ErrorDesc)
|
||||
}
|
||||
|
||||
return tokenResp.AccessToken, tokenResp.IDToken
|
||||
return tokenResult.AccessToken, tokenResult.IDToken
|
||||
}
|
||||
|
||||
func TestAuthentikIntegration(t *testing.T) {
|
||||
func TestAuthentikIntegration_Good_LiveTokenFlow(t *testing.T) {
|
||||
// Skip unless explicitly enabled — requires live Authentik at auth.lthn.io.
|
||||
if os.Getenv("AUTHENTIK_INTEGRATION") != "1" {
|
||||
if core.Env("AUTHENTIK_INTEGRATION") != "1" {
|
||||
t.Skip("set AUTHENTIK_INTEGRATION=1 to run live Authentik tests")
|
||||
}
|
||||
|
||||
issuer := envOr("AUTHENTIK_ISSUER", "https://auth.lthn.io/application/o/core-api/")
|
||||
clientID := envOr("AUTHENTIK_CLIENT_ID", "core-api")
|
||||
clientSecret := os.Getenv("AUTHENTIK_CLIENT_SECRET")
|
||||
issuer := envOrDefault("AUTHENTIK_ISSUER", "https://auth.lthn.io/application/o/core-api/")
|
||||
clientID := envOrDefault("AUTHENTIK_CLIENT_ID", "core-api")
|
||||
clientSecret := core.Env("AUTHENTIK_CLIENT_SECRET")
|
||||
if clientSecret == "" {
|
||||
t.Fatal("AUTHENTIK_CLIENT_SECRET is required")
|
||||
}
|
||||
|
|
@ -126,60 +134,60 @@ func TestAuthentikIntegration(t *testing.T) {
|
|||
t.Fatalf("engine: %v", err)
|
||||
}
|
||||
engine.Register(&testAuthRoutes{})
|
||||
ts := httptest.NewServer(engine.Handler())
|
||||
defer ts.Close()
|
||||
testServer := httptest.NewServer(engine.Handler())
|
||||
defer testServer.Close()
|
||||
|
||||
accessToken, _ := getClientCredentialsToken(t, issuer, clientID, clientSecret)
|
||||
|
||||
t.Run("Health_NoAuth", func(t *testing.T) {
|
||||
resp := get(t, ts.URL+"/health", "")
|
||||
assertStatus(t, resp, 200)
|
||||
body := readBody(t, resp)
|
||||
resp := getWithBearer(t, testServer.URL+"/health", "")
|
||||
assertStatusCode(t, resp, 200)
|
||||
body := readResponseBody(t, resp)
|
||||
t.Logf("health: %s", body)
|
||||
})
|
||||
|
||||
t.Run("Public_NoAuth", func(t *testing.T) {
|
||||
resp := get(t, ts.URL+"/v1/public", "")
|
||||
assertStatus(t, resp, 200)
|
||||
body := readBody(t, resp)
|
||||
resp := getWithBearer(t, testServer.URL+"/v1/public", "")
|
||||
assertStatusCode(t, resp, 200)
|
||||
body := readResponseBody(t, resp)
|
||||
t.Logf("public: %s", body)
|
||||
})
|
||||
|
||||
t.Run("Whoami_NoToken_401", func(t *testing.T) {
|
||||
resp := get(t, ts.URL+"/v1/whoami", "")
|
||||
assertStatus(t, resp, 401)
|
||||
resp := getWithBearer(t, testServer.URL+"/v1/whoami", "")
|
||||
assertStatusCode(t, resp, 401)
|
||||
})
|
||||
|
||||
t.Run("Whoami_WithAccessToken", func(t *testing.T) {
|
||||
resp := get(t, ts.URL+"/v1/whoami", accessToken)
|
||||
assertStatus(t, resp, 200)
|
||||
body := readBody(t, resp)
|
||||
resp := getWithBearer(t, testServer.URL+"/v1/whoami", accessToken)
|
||||
assertStatusCode(t, resp, 200)
|
||||
body := readResponseBody(t, resp)
|
||||
t.Logf("whoami (access_token): %s", body)
|
||||
|
||||
// Parse response and verify user fields.
|
||||
var envelope struct {
|
||||
Data api.AuthentikUser `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(body), &envelope); err != nil {
|
||||
t.Fatalf("parse whoami: %v", err)
|
||||
if result := core.JSONUnmarshalString(body, &envelope); !result.OK {
|
||||
t.Fatalf("parse whoami: %v", result.Value)
|
||||
}
|
||||
if envelope.Data.UID == "" {
|
||||
t.Error("expected non-empty UID")
|
||||
}
|
||||
if !strings.Contains(envelope.Data.Username, "client_credentials") {
|
||||
if !core.Contains(envelope.Data.Username, "client_credentials") {
|
||||
t.Logf("username: %s (service account)", envelope.Data.Username)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Admin_ServiceAccount_403", func(t *testing.T) {
|
||||
// Service account has no groups — should get 403.
|
||||
resp := get(t, ts.URL+"/v1/admin", accessToken)
|
||||
assertStatus(t, resp, 403)
|
||||
resp := getWithBearer(t, testServer.URL+"/v1/admin", accessToken)
|
||||
assertStatusCode(t, resp, 403)
|
||||
})
|
||||
|
||||
t.Run("Whoami_ForwardAuthHeaders", func(t *testing.T) {
|
||||
// Simulate what Traefik sends after forward auth.
|
||||
req, _ := http.NewRequest("GET", ts.URL+"/v1/whoami", nil)
|
||||
req, _ := http.NewRequest("GET", testServer.URL+"/v1/whoami", nil)
|
||||
req.Header.Set("X-authentik-username", "akadmin")
|
||||
req.Header.Set("X-authentik-email", "mafiafire@proton.me")
|
||||
req.Header.Set("X-authentik-name", "Admin User")
|
||||
|
|
@ -192,16 +200,16 @@ func TestAuthentikIntegration(t *testing.T) {
|
|||
t.Fatalf("request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
assertStatus(t, resp, 200)
|
||||
assertStatusCode(t, resp, 200)
|
||||
|
||||
body := readBody(t, resp)
|
||||
body := readResponseBody(t, resp)
|
||||
t.Logf("whoami (forward auth): %s", body)
|
||||
|
||||
var envelope struct {
|
||||
Data api.AuthentikUser `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(body), &envelope); err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
if result := core.JSONUnmarshalString(body, &envelope); !result.OK {
|
||||
t.Fatalf("parse: %v", result.Value)
|
||||
}
|
||||
if envelope.Data.Username != "akadmin" {
|
||||
t.Errorf("expected username akadmin, got %s", envelope.Data.Username)
|
||||
|
|
@ -212,7 +220,7 @@ func TestAuthentikIntegration(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Admin_ForwardAuth_Admins_200", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", ts.URL+"/v1/admin", nil)
|
||||
req, _ := http.NewRequest("GET", testServer.URL+"/v1/admin", nil)
|
||||
req.Header.Set("X-authentik-username", "akadmin")
|
||||
req.Header.Set("X-authentik-email", "mafiafire@proton.me")
|
||||
req.Header.Set("X-authentik-name", "Admin User")
|
||||
|
|
@ -224,72 +232,72 @@ func TestAuthentikIntegration(t *testing.T) {
|
|||
t.Fatalf("request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
assertStatus(t, resp, 200)
|
||||
t.Logf("admin (forward auth): %s", readBody(t, resp))
|
||||
assertStatusCode(t, resp, 200)
|
||||
t.Logf("admin (forward auth): %s", readResponseBody(t, resp))
|
||||
})
|
||||
|
||||
t.Run("InvalidJWT_FailOpen", func(t *testing.T) {
|
||||
// Invalid token on a public endpoint — should still work (permissive).
|
||||
resp := get(t, ts.URL+"/v1/public", "not-a-real-token")
|
||||
assertStatus(t, resp, 200)
|
||||
resp := getWithBearer(t, testServer.URL+"/v1/public", "not-a-real-token")
|
||||
assertStatusCode(t, resp, 200)
|
||||
})
|
||||
|
||||
t.Run("InvalidJWT_Protected_401", func(t *testing.T) {
|
||||
// Invalid token on a protected endpoint — no user extracted, RequireAuth returns 401.
|
||||
resp := get(t, ts.URL+"/v1/whoami", "not-a-real-token")
|
||||
assertStatus(t, resp, 401)
|
||||
resp := getWithBearer(t, testServer.URL+"/v1/whoami", "not-a-real-token")
|
||||
assertStatusCode(t, resp, 401)
|
||||
})
|
||||
}
|
||||
|
||||
func get(t *testing.T, url, bearerToken string) *http.Response {
|
||||
func getWithBearer(t *testing.T, requestURL, bearerToken string) *http.Response {
|
||||
t.Helper()
|
||||
req, _ := http.NewRequest("GET", url, nil)
|
||||
req, _ := http.NewRequest("GET", requestURL, nil)
|
||||
if bearerToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+bearerToken)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("GET %s: %v", url, err)
|
||||
t.Fatalf("GET %s: %v", requestURL, err)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func readBody(t *testing.T, resp *http.Response) string {
|
||||
func readResponseBody(t *testing.T, resp *http.Response) string {
|
||||
t.Helper()
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
responseBytes, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("read body: %v", err)
|
||||
}
|
||||
return string(b)
|
||||
return string(responseBytes)
|
||||
}
|
||||
|
||||
func assertStatus(t *testing.T, resp *http.Response, want int) {
|
||||
func assertStatusCode(t *testing.T, resp *http.Response, want int) {
|
||||
t.Helper()
|
||||
if resp.StatusCode != want {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
responseBytes, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
t.Fatalf("want status %d, got %d: %s", want, resp.StatusCode, string(b))
|
||||
t.Fatalf("want status %d, got %d: %s", want, resp.StatusCode, string(responseBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func envOr(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
func envOrDefault(key, fallback string) string {
|
||||
if value := core.Env(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// TestOIDCDiscovery validates that the OIDC discovery endpoint is reachable.
|
||||
func TestOIDCDiscovery(t *testing.T) {
|
||||
if os.Getenv("AUTHENTIK_INTEGRATION") != "1" {
|
||||
// TestOIDCDiscovery_Good_EndpointReachable validates that the OIDC discovery endpoint is reachable.
|
||||
func TestOIDCDiscovery_Good_EndpointReachable(t *testing.T) {
|
||||
if core.Env("AUTHENTIK_INTEGRATION") != "1" {
|
||||
t.Skip("set AUTHENTIK_INTEGRATION=1 to run live Authentik tests")
|
||||
}
|
||||
|
||||
issuer := envOr("AUTHENTIK_ISSUER", "https://auth.lthn.io/application/o/core-api/")
|
||||
disc := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
|
||||
issuer := envOrDefault("AUTHENTIK_ISSUER", "https://auth.lthn.io/application/o/core-api/")
|
||||
discoveryURL := core.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
|
||||
|
||||
resp, err := http.Get(disc)
|
||||
resp, err := http.Get(discoveryURL) //nolint:noctx
|
||||
if err != nil {
|
||||
t.Fatalf("discovery request: %v", err)
|
||||
}
|
||||
|
|
@ -299,39 +307,70 @@ func TestOIDCDiscovery(t *testing.T) {
|
|||
t.Fatalf("discovery status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var config map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(&config); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
discoveryBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read discovery body: %v", err)
|
||||
}
|
||||
|
||||
var discoveryConfig map[string]any
|
||||
if result := core.JSONUnmarshal(discoveryBody, &discoveryConfig); !result.OK {
|
||||
t.Fatalf("decode: %v", result.Value)
|
||||
}
|
||||
|
||||
// Verify essential fields.
|
||||
for _, field := range []string{"issuer", "token_endpoint", "jwks_uri", "authorization_endpoint"} {
|
||||
if config[field] == nil {
|
||||
if discoveryConfig[field] == nil {
|
||||
t.Errorf("missing field: %s", field)
|
||||
}
|
||||
}
|
||||
|
||||
if config["issuer"] != issuer {
|
||||
t.Errorf("issuer mismatch: got %v, want %s", config["issuer"], issuer)
|
||||
if discoveryConfig["issuer"] != issuer {
|
||||
t.Errorf("issuer mismatch: got %v, want %s", discoveryConfig["issuer"], issuer)
|
||||
}
|
||||
|
||||
// Verify grant types include client_credentials.
|
||||
grants, ok := config["grant_types_supported"].([]any)
|
||||
grants, ok := discoveryConfig["grant_types_supported"].([]any)
|
||||
if !ok {
|
||||
t.Fatal("missing grant_types_supported")
|
||||
}
|
||||
found := false
|
||||
for _, g := range grants {
|
||||
if g == "client_credentials" {
|
||||
found = true
|
||||
clientCredentialsFound := false
|
||||
for _, grantType := range grants {
|
||||
if grantType == "client_credentials" {
|
||||
clientCredentialsFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if !clientCredentialsFound {
|
||||
t.Error("client_credentials grant not supported")
|
||||
}
|
||||
|
||||
fmt.Printf(" OIDC discovery OK — issuer: %s\n", config["issuer"])
|
||||
fmt.Printf(" Token endpoint: %s\n", config["token_endpoint"])
|
||||
fmt.Printf(" JWKS URI: %s\n", config["jwks_uri"])
|
||||
t.Logf("OIDC discovery OK — issuer: %s", discoveryConfig["issuer"])
|
||||
t.Logf("Token endpoint: %s", discoveryConfig["token_endpoint"])
|
||||
t.Logf("JWKS URI: %s", discoveryConfig["jwks_uri"])
|
||||
}
|
||||
|
||||
// TestOIDCDiscovery_Bad_SkipsWithoutEnvVar verifies the test skips without AUTHENTIK_INTEGRATION=1.
|
||||
func TestOIDCDiscovery_Bad_SkipsWithoutEnvVar(t *testing.T) {
|
||||
// This test always runs; it verifies no network call is made without the env var.
|
||||
// Since we cannot unset env vars safely in parallel tests, we verify the skip logic
|
||||
// by running this in an environment where AUTHENTIK_INTEGRATION is not "1".
|
||||
if core.Env("AUTHENTIK_INTEGRATION") == "1" {
|
||||
t.Skip("skipping skip-check test when integration env is set")
|
||||
}
|
||||
// No network call should happen — test passes if we reach here.
|
||||
}
|
||||
|
||||
// TestOIDCDiscovery_Ugly_MalformedIssuerHandled verifies the discovery helper does not panic on bad issuer.
|
||||
func TestOIDCDiscovery_Ugly_MalformedIssuerHandled(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("malformed issuer panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// envOrDefault returns fallback on empty — verify it does not panic on empty key.
|
||||
result := envOrDefault("", "fallback")
|
||||
if result != "fallback" {
|
||||
t.Errorf("expected fallback, got %q", result)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -458,3 +458,41 @@ func (g *groupRequireGroup) RegisterRoutes(rg *gin.RouterGroup) {
|
|||
c.JSON(200, api.OK("admin panel"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthentikUser_Ugly_EmptyGroupsDontPanic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("HasGroup on empty groups panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
u := api.AuthentikUser{}
|
||||
// HasGroup on a zero-value user (nil Groups slice) must not panic.
|
||||
if u.HasGroup("admins") {
|
||||
t.Fatal("expected HasGroup to return false for empty user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUser_Ugly_WrongTypeInContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
// Inject a wrong type under the authentik key — GetUser must return nil.
|
||||
c.Set("authentik_user", "not-a-user-struct")
|
||||
user := api.GetUser(c)
|
||||
if user != nil {
|
||||
c.JSON(http.StatusInternalServerError, api.Fail("error", "unexpected user"))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, api.OK("nil as expected"))
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, _ := http.NewRequest(http.MethodGet, "/test", nil)
|
||||
router.ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
50
bridge.go
50
bridge.go
|
|
@ -31,6 +31,10 @@ type boundTool struct {
|
|||
}
|
||||
|
||||
// NewToolBridge creates a bridge that mounts tool endpoints at basePath.
|
||||
//
|
||||
// bridge := api.NewToolBridge("/tools")
|
||||
// bridge.Add(api.ToolDescriptor{Name: "file_read"}, fileReadHandler)
|
||||
// engine.Register(bridge)
|
||||
func NewToolBridge(basePath string) *ToolBridge {
|
||||
return &ToolBridge{
|
||||
basePath: basePath,
|
||||
|
|
@ -39,6 +43,8 @@ func NewToolBridge(basePath string) *ToolBridge {
|
|||
}
|
||||
|
||||
// Add registers a tool with its HTTP handler.
|
||||
//
|
||||
// bridge.Add(api.ToolDescriptor{Name: "file_read", Group: "files"}, fileReadHandler)
|
||||
func (b *ToolBridge) Add(desc ToolDescriptor, handler gin.HandlerFunc) {
|
||||
b.tools = append(b.tools, boundTool{descriptor: desc, handler: handler})
|
||||
}
|
||||
|
|
@ -51,27 +57,27 @@ func (b *ToolBridge) BasePath() string { return b.basePath }
|
|||
|
||||
// RegisterRoutes mounts POST /{tool_name} for each registered tool.
|
||||
func (b *ToolBridge) RegisterRoutes(rg *gin.RouterGroup) {
|
||||
for _, t := range b.tools {
|
||||
rg.POST("/"+t.descriptor.Name, t.handler)
|
||||
for _, tool := range b.tools {
|
||||
rg.POST("/"+tool.descriptor.Name, tool.handler)
|
||||
}
|
||||
}
|
||||
|
||||
// Describe returns OpenAPI route descriptions for all registered tools.
|
||||
func (b *ToolBridge) Describe() []RouteDescription {
|
||||
descs := make([]RouteDescription, 0, len(b.tools))
|
||||
for _, t := range b.tools {
|
||||
tags := []string{t.descriptor.Group}
|
||||
if t.descriptor.Group == "" {
|
||||
for _, tool := range b.tools {
|
||||
tags := []string{tool.descriptor.Group}
|
||||
if tool.descriptor.Group == "" {
|
||||
tags = []string{b.name}
|
||||
}
|
||||
descs = append(descs, RouteDescription{
|
||||
Method: "POST",
|
||||
Path: "/" + t.descriptor.Name,
|
||||
Summary: t.descriptor.Description,
|
||||
Description: t.descriptor.Description,
|
||||
Path: "/" + tool.descriptor.Name,
|
||||
Summary: tool.descriptor.Description,
|
||||
Description: tool.descriptor.Description,
|
||||
Tags: tags,
|
||||
RequestBody: t.descriptor.InputSchema,
|
||||
Response: t.descriptor.OutputSchema,
|
||||
RequestBody: tool.descriptor.InputSchema,
|
||||
Response: tool.descriptor.OutputSchema,
|
||||
})
|
||||
}
|
||||
return descs
|
||||
|
|
@ -80,19 +86,19 @@ func (b *ToolBridge) Describe() []RouteDescription {
|
|||
// DescribeIter returns an iterator over OpenAPI route descriptions for all registered tools.
|
||||
func (b *ToolBridge) DescribeIter() iter.Seq[RouteDescription] {
|
||||
return func(yield func(RouteDescription) bool) {
|
||||
for _, t := range b.tools {
|
||||
tags := []string{t.descriptor.Group}
|
||||
if t.descriptor.Group == "" {
|
||||
for _, tool := range b.tools {
|
||||
tags := []string{tool.descriptor.Group}
|
||||
if tool.descriptor.Group == "" {
|
||||
tags = []string{b.name}
|
||||
}
|
||||
rd := RouteDescription{
|
||||
Method: "POST",
|
||||
Path: "/" + t.descriptor.Name,
|
||||
Summary: t.descriptor.Description,
|
||||
Description: t.descriptor.Description,
|
||||
Path: "/" + tool.descriptor.Name,
|
||||
Summary: tool.descriptor.Description,
|
||||
Description: tool.descriptor.Description,
|
||||
Tags: tags,
|
||||
RequestBody: t.descriptor.InputSchema,
|
||||
Response: t.descriptor.OutputSchema,
|
||||
RequestBody: tool.descriptor.InputSchema,
|
||||
Response: tool.descriptor.OutputSchema,
|
||||
}
|
||||
if !yield(rd) {
|
||||
return
|
||||
|
|
@ -104,8 +110,8 @@ func (b *ToolBridge) DescribeIter() iter.Seq[RouteDescription] {
|
|||
// Tools returns all registered tool descriptors.
|
||||
func (b *ToolBridge) Tools() []ToolDescriptor {
|
||||
descs := make([]ToolDescriptor, len(b.tools))
|
||||
for i, t := range b.tools {
|
||||
descs[i] = t.descriptor
|
||||
for i, tool := range b.tools {
|
||||
descs[i] = tool.descriptor
|
||||
}
|
||||
return descs
|
||||
}
|
||||
|
|
@ -113,8 +119,8 @@ func (b *ToolBridge) Tools() []ToolDescriptor {
|
|||
// ToolsIter returns an iterator over all registered tool descriptors.
|
||||
func (b *ToolBridge) ToolsIter() iter.Seq[ToolDescriptor] {
|
||||
return func(yield func(ToolDescriptor) bool) {
|
||||
for _, t := range b.tools {
|
||||
if !yield(t.descriptor) {
|
||||
for _, tool := range b.tools {
|
||||
if !yield(tool.descriptor) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -232,3 +232,20 @@ func TestToolBridge_Good_IntegrationWithEngine(t *testing.T) {
|
|||
t.Fatalf("expected Data=%q, got %q", "pong", resp.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolBridge_Ugly_NilHandlerDoesNotPanic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("Add with nil handler panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
// Adding a tool with a nil handler should not panic on Add itself.
|
||||
bridge.Add(api.ToolDescriptor{Name: "noop", Group: "test"}, nil)
|
||||
|
||||
tools := bridge.Tools()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"dappco.re/go/core"
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
|
@ -47,7 +47,7 @@ func newBrotliHandler(level int) *brotliHandler {
|
|||
|
||||
// Handle is the Gin middleware function that compresses responses with Brotli.
|
||||
func (h *brotliHandler) Handle(c *gin.Context) {
|
||||
if !strings.Contains(c.Request.Header.Get("Accept-Encoding"), "br") {
|
||||
if !core.Contains(c.Request.Header.Get("Accept-Encoding"), "br") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -130,3 +130,27 @@ func TestWithBrotli_Good_CombinesWithOtherMiddleware(t *testing.T) {
|
|||
t.Fatal("expected X-Request-ID header from WithRequestID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithBrotli_Ugly_InvalidLevelClampsToDefault(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("WithBrotli with invalid level panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
// A level out of range should silently clamp to default, not panic.
|
||||
e, err := api.New(api.WithBrotli(999))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, _ := http.NewRequest(http.MethodGet, "/health", nil)
|
||||
request.Header.Set("Accept-Encoding", "br")
|
||||
e.Handler().ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
6
cache.go
6
cache.go
|
|
@ -89,9 +89,9 @@ func cacheMiddleware(store *cacheStore, ttl time.Duration) gin.HandlerFunc {
|
|||
|
||||
// Serve from cache if a valid entry exists.
|
||||
if entry := store.get(key); entry != nil {
|
||||
for k, vals := range entry.headers {
|
||||
for _, v := range vals {
|
||||
c.Writer.Header().Set(k, v)
|
||||
for headerName, headerValues := range entry.headers {
|
||||
for _, headerValue := range headerValues {
|
||||
c.Writer.Header().Set(headerName, headerValue)
|
||||
}
|
||||
}
|
||||
c.Writer.Header().Set("X-Cache", "HIT")
|
||||
|
|
|
|||
|
|
@ -250,3 +250,31 @@ func TestWithCache_Good_ExpiredCacheMisses(t *testing.T) {
|
|||
t.Fatalf("expected counter=2, got %d", grp.counter.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithCache_Ugly_ConcurrentGetsDontDeadlock(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
grp := &cacheCounterGroup{}
|
||||
engine, _ := api.New(api.WithCache(50 * time.Millisecond))
|
||||
engine.Register(grp)
|
||||
handler := engine.Handler()
|
||||
|
||||
// Fire many concurrent GET requests; none should deadlock.
|
||||
done := make(chan struct{})
|
||||
for requestIndex := 0; requestIndex < 20; requestIndex++ {
|
||||
go func() {
|
||||
recorder := httptest.NewRecorder()
|
||||
request, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil)
|
||||
handler.ServeHTTP(recorder, request)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
}
|
||||
|
||||
for requestIndex := 0; requestIndex < 20; requestIndex++ {
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("concurrent requests deadlocked")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,14 +4,12 @@ package api
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
|
||||
"dappco.re/go/core"
|
||||
coreio "dappco.re/go/core/io"
|
||||
coreerr "dappco.re/go/core/log"
|
||||
corelog "dappco.re/go/core/log"
|
||||
|
||||
goapi "dappco.re/go/core/api"
|
||||
)
|
||||
|
|
@ -26,7 +24,7 @@ func addSDKCommand(parent *cli.Command) {
|
|||
|
||||
cmd := cli.NewCommand("sdk", "Generate client SDKs from OpenAPI spec", "", func(cmd *cli.Command, args []string) error {
|
||||
if lang == "" {
|
||||
return coreerr.E("sdk.Generate", "--lang is required. Supported: "+strings.Join(goapi.SupportedLanguages(), ", "), nil)
|
||||
return coreerr.E("sdk.Generate", "--lang is required. Supported: "+core.Join(", ", goapi.SupportedLanguages()...), nil)
|
||||
}
|
||||
|
||||
// If no spec file provided, generate one to a temp file.
|
||||
|
|
@ -40,44 +38,47 @@ func addSDKCommand(parent *cli.Command) {
|
|||
bridge := goapi.NewToolBridge("/tools")
|
||||
groups := []goapi.RouteGroup{bridge}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "openapi-*.json")
|
||||
tmpPath := core.Path(core.Env("DIR_TMP"), "openapi-spec.json")
|
||||
writer, err := coreio.Local.Create(tmpPath)
|
||||
if err != nil {
|
||||
return coreerr.E("sdk.Generate", "create temp spec file", err)
|
||||
}
|
||||
defer coreio.Local.Delete(tmpFile.Name())
|
||||
|
||||
if err := goapi.ExportSpec(tmpFile, "json", builder, groups); err != nil {
|
||||
tmpFile.Close()
|
||||
if err := goapi.ExportSpec(writer, "json", builder, groups); err != nil {
|
||||
writer.Close()
|
||||
return coreerr.E("sdk.Generate", "generate spec", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
specFile = tmpFile.Name()
|
||||
writer.Close()
|
||||
defer coreio.Local.Delete(tmpPath)
|
||||
specFile = tmpPath
|
||||
}
|
||||
|
||||
gen := &goapi.SDKGenerator{
|
||||
SpecPath: specFile,
|
||||
OutputDir: output,
|
||||
PackageName: packageName,
|
||||
Stdout: cmd.OutOrStdout(),
|
||||
Stderr: cmd.ErrOrStderr(),
|
||||
}
|
||||
|
||||
if !gen.Available() {
|
||||
fmt.Fprintln(os.Stderr, "openapi-generator-cli not found. Install with:")
|
||||
fmt.Fprintln(os.Stderr, " brew install openapi-generator (macOS)")
|
||||
fmt.Fprintln(os.Stderr, " npm install @openapitools/openapi-generator-cli -g")
|
||||
corelog.Error("openapi-generator-cli not found. Install with:")
|
||||
corelog.Error(" brew install openapi-generator (macOS)")
|
||||
corelog.Error(" npm install @openapitools/openapi-generator-cli -g")
|
||||
return coreerr.E("sdk.Generate", "openapi-generator-cli not installed", nil)
|
||||
}
|
||||
|
||||
// Generate for each language.
|
||||
for l := range strings.SplitSeq(lang, ",") {
|
||||
l = strings.TrimSpace(l)
|
||||
if l == "" {
|
||||
for _, language := range core.Split(lang, ",") {
|
||||
language = core.Trim(language)
|
||||
if language == "" {
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Generating %s SDK...\n", l)
|
||||
if err := gen.Generate(context.Background(), l); err != nil {
|
||||
return coreerr.E("sdk.Generate", "generate "+l, err)
|
||||
corelog.Info("generating " + language + " SDK...")
|
||||
if err := gen.Generate(context.Background(), language); err != nil {
|
||||
return coreerr.E("sdk.Generate", "generate "+language, err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, " Done: %s/%s/\n", output, l)
|
||||
corelog.Info("done: " + output + "/" + language + "/")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -3,10 +3,8 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
corelog "dappco.re/go/core/log"
|
||||
|
||||
goapi "dappco.re/go/core/api"
|
||||
)
|
||||
|
|
@ -38,11 +36,11 @@ func addSpecCommand(parent *cli.Command) {
|
|||
if err := goapi.ExportSpecToFile(output, format, builder, groups); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Spec written to %s\n", output)
|
||||
corelog.Info("spec written to " + output)
|
||||
return nil
|
||||
}
|
||||
|
||||
return goapi.ExportSpec(os.Stdout, format, builder, groups)
|
||||
return goapi.ExportSpec(cmd.OutOrStdout(), format, builder, groups)
|
||||
})
|
||||
|
||||
cli.StringFlag(cmd, &output, "output", "o", "", "Write spec to file instead of stdout")
|
||||
|
|
|
|||
52
codegen.go
52
codegen.go
|
|
@ -4,16 +4,16 @@ package api
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"maps"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"dappco.re/go/core"
|
||||
coreio "dappco.re/go/core/io"
|
||||
coreerr "dappco.re/go/core/log"
|
||||
coreexec "dappco.re/go/core/process/exec"
|
||||
coreprocess "dappco.re/go/core/process"
|
||||
)
|
||||
|
||||
// Supported SDK target languages.
|
||||
|
|
@ -32,6 +32,9 @@ var supportedLanguages = map[string]string{
|
|||
}
|
||||
|
||||
// SDKGenerator wraps openapi-generator-cli for SDK generation.
|
||||
//
|
||||
// gen := &api.SDKGenerator{SpecPath: "./openapi.json", OutputDir: "./sdk", PackageName: "myapi"}
|
||||
// if gen.Available() { gen.Generate(ctx, "go") }
|
||||
type SDKGenerator struct {
|
||||
// SpecPath is the path to the OpenAPI spec file (JSON or YAML).
|
||||
SpecPath string
|
||||
|
|
@ -41,29 +44,48 @@ type SDKGenerator struct {
|
|||
|
||||
// PackageName is the name used for the generated package/module.
|
||||
PackageName string
|
||||
|
||||
// Stdout receives command output (defaults to io.Discard when nil).
|
||||
Stdout io.Writer
|
||||
|
||||
// Stderr receives command error output (defaults to io.Discard when nil).
|
||||
Stderr io.Writer
|
||||
}
|
||||
|
||||
// Generate creates an SDK for the given language using openapi-generator-cli.
|
||||
// The language must be one of the supported languages returned by SupportedLanguages().
|
||||
//
|
||||
// err := gen.Generate(ctx, "go")
|
||||
// err := gen.Generate(ctx, "python")
|
||||
func (g *SDKGenerator) Generate(ctx context.Context, language string) error {
|
||||
generator, ok := supportedLanguages[language]
|
||||
if !ok {
|
||||
return coreerr.E("SDKGenerator.Generate", fmt.Sprintf("unsupported language %q: supported languages are %v", language, SupportedLanguages()), nil)
|
||||
return coreerr.E("SDKGenerator.Generate", core.Sprintf("unsupported language %q: supported languages are %v", language, SupportedLanguages()), nil)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(g.SpecPath); os.IsNotExist(err) {
|
||||
if !coreio.Local.IsFile(g.SpecPath) {
|
||||
return coreerr.E("SDKGenerator.Generate", "spec file not found: "+g.SpecPath, nil)
|
||||
}
|
||||
|
||||
outputDir := filepath.Join(g.OutputDir, language)
|
||||
outputDir := core.Path(g.OutputDir, language)
|
||||
if err := coreio.Local.EnsureDir(outputDir); err != nil {
|
||||
return coreerr.E("SDKGenerator.Generate", "create output directory", err)
|
||||
}
|
||||
|
||||
args := g.buildArgs(generator, outputDir)
|
||||
cmd := exec.CommandContext(ctx, "openapi-generator-cli", args...)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
stdout := g.Stdout
|
||||
if stdout == nil {
|
||||
stdout = io.Discard
|
||||
}
|
||||
stderr := g.Stderr
|
||||
if stderr == nil {
|
||||
stderr = io.Discard
|
||||
}
|
||||
|
||||
cmd := coreexec.Command(ctx, "openapi-generator-cli", args...).
|
||||
WithStdout(stdout).
|
||||
WithStderr(stderr)
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return coreerr.E("SDKGenerator.Generate", "openapi-generator-cli failed for "+language, err)
|
||||
|
|
@ -87,18 +109,24 @@ func (g *SDKGenerator) buildArgs(generator, outputDir string) []string {
|
|||
}
|
||||
|
||||
// Available checks if openapi-generator-cli is installed and accessible.
|
||||
//
|
||||
// if gen.Available() { gen.Generate(ctx, "go") }
|
||||
func (g *SDKGenerator) Available() bool {
|
||||
_, err := exec.LookPath("openapi-generator-cli")
|
||||
return err == nil
|
||||
prog := &coreprocess.Program{Name: "openapi-generator-cli"}
|
||||
return prog.Find() == nil
|
||||
}
|
||||
|
||||
// SupportedLanguages returns the list of supported SDK target languages
|
||||
// in sorted order for deterministic output.
|
||||
//
|
||||
// langs := api.SupportedLanguages() // ["csharp", "go", "java", ...]
|
||||
func SupportedLanguages() []string {
|
||||
return slices.Sorted(maps.Keys(supportedLanguages))
|
||||
}
|
||||
|
||||
// SupportedLanguagesIter returns an iterator over supported SDK target languages in sorted order.
|
||||
//
|
||||
// for lang := range api.SupportedLanguagesIter() { fmt.Println(lang) }
|
||||
func SupportedLanguagesIter() iter.Seq[string] {
|
||||
return slices.Values(SupportedLanguages())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -92,3 +92,24 @@ func TestSDKGenerator_Good_Available(t *testing.T) {
|
|||
// Just verify it returns a bool and does not panic.
|
||||
_ = gen.Available()
|
||||
}
|
||||
|
||||
func TestSupportedLanguages_Ugly_CalledRepeatedly(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("SupportedLanguages called repeatedly panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Calling multiple times should always return the same sorted slice.
|
||||
first := api.SupportedLanguages()
|
||||
second := api.SupportedLanguages()
|
||||
|
||||
if len(first) != len(second) {
|
||||
t.Fatalf("inconsistent results: %d vs %d", len(first), len(second))
|
||||
}
|
||||
for languageIndex, language := range first {
|
||||
if language != second[languageIndex] {
|
||||
t.Fatalf("mismatch at index %d: %q vs %q", languageIndex, language, second[languageIndex])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
28
export.go
28
export.go
|
|
@ -3,13 +3,11 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"dappco.re/go/core"
|
||||
coreio "dappco.re/go/core/io"
|
||||
coreerr "dappco.re/go/core/log"
|
||||
)
|
||||
|
|
@ -29,15 +27,16 @@ func ExportSpec(w io.Writer, format string, builder *SpecBuilder, groups []Route
|
|||
case "yaml":
|
||||
// Unmarshal JSON then re-marshal as YAML.
|
||||
var obj any
|
||||
if err := json.Unmarshal(data, &obj); err != nil {
|
||||
return coreerr.E("ExportSpec", "unmarshal spec", err)
|
||||
result := core.JSONUnmarshal(data, &obj)
|
||||
if !result.OK {
|
||||
return coreerr.E("ExportSpec", "unmarshal spec", result.Value.(error))
|
||||
}
|
||||
enc := yaml.NewEncoder(w)
|
||||
enc.SetIndent(2)
|
||||
if err := enc.Encode(obj); err != nil {
|
||||
encoder := yaml.NewEncoder(w)
|
||||
encoder.SetIndent(2)
|
||||
if err := encoder.Encode(obj); err != nil {
|
||||
return coreerr.E("ExportSpec", "encode yaml", err)
|
||||
}
|
||||
return enc.Close()
|
||||
return encoder.Close()
|
||||
default:
|
||||
return coreerr.E("ExportSpec", "unsupported format "+format+": use \"json\" or \"yaml\"", nil)
|
||||
}
|
||||
|
|
@ -45,14 +44,17 @@ func ExportSpec(w io.Writer, format string, builder *SpecBuilder, groups []Route
|
|||
|
||||
// ExportSpecToFile writes the spec to the given path.
|
||||
// The parent directory is created if it does not exist.
|
||||
//
|
||||
// err := api.ExportSpecToFile("./docs/openapi.json", "json", builder, groups)
|
||||
// err := api.ExportSpecToFile("./docs/openapi.yaml", "yaml", builder, groups)
|
||||
func ExportSpecToFile(path, format string, builder *SpecBuilder, groups []RouteGroup) error {
|
||||
if err := coreio.Local.EnsureDir(filepath.Dir(path)); err != nil {
|
||||
if err := coreio.Local.EnsureDir(core.PathDir(path)); err != nil {
|
||||
return coreerr.E("ExportSpecToFile", "create directory", err)
|
||||
}
|
||||
f, err := os.Create(path)
|
||||
writer, err := coreio.Local.Create(path)
|
||||
if err != nil {
|
||||
return coreerr.E("ExportSpecToFile", "create file", err)
|
||||
}
|
||||
defer f.Close()
|
||||
return ExportSpec(f, format, builder, groups)
|
||||
defer writer.Close()
|
||||
return ExportSpec(writer, format, builder, groups)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -164,3 +164,19 @@ func TestExportSpec_Good_WithToolBridge(t *testing.T) {
|
|||
t.Fatal("expected /tools/metrics_query path in spec")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportSpec_Ugly_EmptyFormatDoesNotPanic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("ExportSpec with empty format panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
builder := &api.SpecBuilder{Title: "Test", Version: "1.0.0"}
|
||||
var output strings.Builder
|
||||
// Unknown format should return an error, not panic.
|
||||
err := api.ExportSpec(&output, "xml", builder, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unsupported format, got nil")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -139,3 +139,27 @@ func TestWithExpvar_Bad_NotMountedWithoutOption(t *testing.T) {
|
|||
t.Fatalf("expected 404 for /debug/vars without WithExpvar, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithExpvar_Ugly_DoubleRegistrationDoesNotPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("double WithExpvar panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Registering expvar twice should not panic.
|
||||
engine, err := api.New(api.WithExpvar(), api.WithExpvar())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, _ := http.NewRequest(http.MethodGet, "/health", nil)
|
||||
engine.Handler().ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
7
go.mod
7
go.mod
|
|
@ -3,8 +3,10 @@ module dappco.re/go/core/api
|
|||
go 1.26.0
|
||||
|
||||
require (
|
||||
dappco.re/go/core/io v0.1.7
|
||||
dappco.re/go/core/log v0.0.4
|
||||
dappco.re/go/core v0.8.0-alpha.1
|
||||
dappco.re/go/core/io v0.2.0
|
||||
dappco.re/go/core/log v0.1.0
|
||||
dappco.re/go/core/process v0.0.0-00010101000000-000000000000
|
||||
forge.lthn.ai/core/cli v0.3.7
|
||||
github.com/99designs/gqlgen v0.17.88
|
||||
github.com/andybalholm/brotli v1.2.0
|
||||
|
|
@ -134,4 +136,5 @@ replace (
|
|||
dappco.re/go/core/i18n => ../go-i18n
|
||||
dappco.re/go/core/io => ../go-io
|
||||
dappco.re/go/core/log => ../go-log
|
||||
dappco.re/go/core/process => ../go-process
|
||||
)
|
||||
|
|
|
|||
28
graphql.go
28
graphql.go
|
|
@ -26,38 +26,38 @@ type GraphQLOption func(*graphqlConfig)
|
|||
|
||||
// WithPlayground enables the GraphQL Playground UI at {path}/playground.
|
||||
func WithPlayground() GraphQLOption {
|
||||
return func(cfg *graphqlConfig) {
|
||||
cfg.playground = true
|
||||
return func(config *graphqlConfig) {
|
||||
config.playground = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithGraphQLPath sets a custom URL path for the GraphQL endpoint.
|
||||
// The default path is "/graphql".
|
||||
func WithGraphQLPath(path string) GraphQLOption {
|
||||
return func(cfg *graphqlConfig) {
|
||||
cfg.path = path
|
||||
return func(config *graphqlConfig) {
|
||||
config.path = path
|
||||
}
|
||||
}
|
||||
|
||||
// mountGraphQL registers the GraphQL handler and optional playground on the Gin engine.
|
||||
func mountGraphQL(r *gin.Engine, cfg *graphqlConfig) {
|
||||
srv := handler.NewDefaultServer(cfg.schema)
|
||||
graphqlHandler := gin.WrapH(srv)
|
||||
func mountGraphQL(router *gin.Engine, config *graphqlConfig) {
|
||||
graphqlServer := handler.NewDefaultServer(config.schema)
|
||||
graphqlHandler := gin.WrapH(graphqlServer)
|
||||
|
||||
// Mount the GraphQL endpoint for all HTTP methods (POST for queries/mutations,
|
||||
// GET for playground redirects and introspection).
|
||||
r.Any(cfg.path, graphqlHandler)
|
||||
router.Any(config.path, graphqlHandler)
|
||||
|
||||
if cfg.playground {
|
||||
playgroundPath := cfg.path + "/playground"
|
||||
playgroundHandler := playground.Handler("GraphQL", cfg.path)
|
||||
r.GET(playgroundPath, wrapHTTPHandler(playgroundHandler))
|
||||
if config.playground {
|
||||
playgroundPath := config.path + "/playground"
|
||||
playgroundHandler := playground.Handler("GraphQL", config.path)
|
||||
router.GET(playgroundPath, wrapHTTPHandler(playgroundHandler))
|
||||
}
|
||||
}
|
||||
|
||||
// wrapHTTPHandler adapts a standard http.Handler to a Gin handler function.
|
||||
func wrapHTTPHandler(h http.Handler) gin.HandlerFunc {
|
||||
func wrapHTTPHandler(handler http.Handler) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
h.ServeHTTP(c.Writer, c.Request)
|
||||
handler.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -232,3 +232,30 @@ func TestWithGraphQL_Good_CombinesWithOtherMiddleware(t *testing.T) {
|
|||
t.Fatalf("expected response containing name:test, got %q", string(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithGraphQL_Ugly_DoubleRegistrationDoesNotPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("double WithGraphQL panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
schema := newTestSchema()
|
||||
// Registering two GraphQL schemas with different paths must not panic.
|
||||
engine, err := api.New(
|
||||
api.WithGraphQL(schema, api.WithGraphQLPath("/graphql")),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, _ := http.NewRequest(http.MethodGet, "/health", nil)
|
||||
engine.Handler().ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -224,3 +224,28 @@ func TestDescribableGroup_Bad_NilSchemas(t *testing.T) {
|
|||
t.Fatalf("expected nil Response, got %v", descs[0].Response)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteGroup_Ugly_EmptyBasePathDoesNotPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("Register with empty BasePath panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// A group with an empty base path should mount at root without panicking.
|
||||
engine, err := api.New()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
engine.Register(&stubGroup{})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, _ := http.NewRequest(http.MethodGet, "/health", nil)
|
||||
engine.Handler().ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
24
gzip_test.go
24
gzip_test.go
|
|
@ -131,3 +131,27 @@ func TestWithGzip_Good_CombinesWithOtherMiddleware(t *testing.T) {
|
|||
t.Fatal("expected X-Request-ID header from WithRequestID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithGzip_Ugly_NilBodyDoesNotPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("gzip handler panicked on nil body: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
engine, err := api.New(api.WithGzip())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, _ := http.NewRequest(http.MethodGet, "/health", nil)
|
||||
request.Header.Set("Accept-Encoding", "gzip")
|
||||
engine.Handler().ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
4
i18n.go
4
i18n.go
|
|
@ -50,8 +50,8 @@ func WithI18n(cfg ...I18nConfig) Option {
|
|||
|
||||
// Build the language.Matcher from supported locales.
|
||||
tags := []language.Tag{language.Make(config.DefaultLocale)}
|
||||
for _, s := range config.Supported {
|
||||
tag := language.Make(s)
|
||||
for _, supportedLocale := range config.Supported {
|
||||
tag := language.Make(supportedLocale)
|
||||
// Avoid duplicating the default if it also appears in Supported.
|
||||
if tag != tags[0] {
|
||||
tags = append(tags, tag)
|
||||
|
|
|
|||
28
i18n_test.go
28
i18n_test.go
|
|
@ -224,3 +224,31 @@ func TestWithI18n_Good_LooksUpMessage(t *testing.T) {
|
|||
t.Fatalf("expected message=%q, got %q", "Hello", respEn.Data.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithI18n_Ugly_MalformedAcceptLanguageDoesNotPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("i18n middleware panicked on malformed Accept-Language: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
engine, err := api.New(api.WithI18n(api.I18nConfig{
|
||||
DefaultLocale: "en",
|
||||
Supported: []string{"en", "fr"},
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
// Gibberish Accept-Language should fall back to default, not panic.
|
||||
request, _ := http.NewRequest(http.MethodGet, "/health", nil)
|
||||
request.Header.Set("Accept-Language", ";;;invalid;;;")
|
||||
engine.Handler().ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -178,3 +178,27 @@ func TestWithLocation_Good_BothHeadersCombined(t *testing.T) {
|
|||
t.Fatalf("expected host=%q, got %q", "secure.example.com", resp.Data["host"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLocation_Ugly_MissingHeadersDoesNotPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("location middleware panicked on missing headers: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
engine, err := api.New(api.WithLocation())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
// No X-Forwarded-Proto or X-Forwarded-Host headers — should not panic.
|
||||
request, _ := http.NewRequest(http.MethodGet, "/health", nil)
|
||||
engine.Handler().ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import (
|
|||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"dappco.re/go/core"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
|
@ -18,7 +18,7 @@ func bearerAuthMiddleware(token string, skip []string) gin.HandlerFunc {
|
|||
return func(c *gin.Context) {
|
||||
// Check whether the request path should bypass authentication.
|
||||
for _, path := range skip {
|
||||
if strings.HasPrefix(c.Request.URL.Path, path) {
|
||||
if core.HasPrefix(c.Request.URL.Path, path) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
|
@ -30,8 +30,8 @@ func bearerAuthMiddleware(token string, skip []string) gin.HandlerFunc {
|
|||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(header, " ", 2)
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") || parts[1] != token {
|
||||
parts := core.SplitN(header, " ", 2)
|
||||
if len(parts) != 2 || core.Lower(parts[0]) != "bearer" || parts[1] != token {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, Fail("unauthorised", "invalid bearer token"))
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@
|
|||
package api_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"dappco.re/go/core"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
api "dappco.re/go/core/api"
|
||||
|
|
@ -43,8 +43,8 @@ func TestBearerAuth_Bad_MissingToken(t *testing.T) {
|
|||
}
|
||||
|
||||
var resp api.Response[any]
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
if result := core.JSONUnmarshal(w.Body.Bytes(), &resp); !result.OK {
|
||||
t.Fatalf("unmarshal error: %v", result.Value)
|
||||
}
|
||||
if resp.Error == nil || resp.Error.Code != "unauthorised" {
|
||||
t.Fatalf("expected error code=%q, got %+v", "unauthorised", resp.Error)
|
||||
|
|
@ -67,8 +67,8 @@ func TestBearerAuth_Bad_WrongToken(t *testing.T) {
|
|||
}
|
||||
|
||||
var resp api.Response[any]
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
if result := core.JSONUnmarshal(w.Body.Bytes(), &resp); !result.OK {
|
||||
t.Fatalf("unmarshal error: %v", result.Value)
|
||||
}
|
||||
if resp.Error == nil || resp.Error.Code != "unauthorised" {
|
||||
t.Fatalf("expected error code=%q, got %+v", "unauthorised", resp.Error)
|
||||
|
|
@ -91,8 +91,8 @@ func TestBearerAuth_Good_CorrectToken(t *testing.T) {
|
|||
}
|
||||
|
||||
var resp api.Response[string]
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
if result := core.JSONUnmarshal(w.Body.Bytes(), &resp); !result.OK {
|
||||
t.Fatalf("unmarshal error: %v", result.Value)
|
||||
}
|
||||
if resp.Data != "classified" {
|
||||
t.Fatalf("expected Data=%q, got %q", "classified", resp.Data)
|
||||
|
|
@ -218,3 +218,29 @@ func TestCORS_Bad_DisallowedOrigin(t *testing.T) {
|
|||
t.Fatalf("expected no Access-Control-Allow-Origin for disallowed origin, got %q", origin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBearerAuth_Ugly_MalformedAuthHeaderDoesNotPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("bearerAuth panicked on malformed header: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
engine, err := api.New(api.WithBearerAuth("secret"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
engine.Register(&mwTestGroup{})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
// Only one word — no space — should return 401, not panic.
|
||||
request, _ := http.NewRequest(http.MethodGet, "/v1/secret", nil)
|
||||
request.Header.Set("Authorization", "BearerNOSPACE")
|
||||
engine.Handler().ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,14 +9,22 @@ import (
|
|||
api "dappco.re/go/core/api"
|
||||
)
|
||||
|
||||
func TestEngine_GroupsIter(t *testing.T) {
|
||||
e, _ := api.New()
|
||||
g1 := &healthGroup{}
|
||||
e.Register(g1)
|
||||
type streamGroupStub struct {
|
||||
healthGroup
|
||||
channels []string
|
||||
}
|
||||
|
||||
func (s *streamGroupStub) Channels() []string { return s.channels }
|
||||
|
||||
// ── GroupsIter ────────────────────────────────────────────────────────
|
||||
|
||||
func TestModernization_GroupsIter_Good(t *testing.T) {
|
||||
engine, _ := api.New()
|
||||
engine.Register(&healthGroup{})
|
||||
|
||||
var groups []api.RouteGroup
|
||||
for g := range e.GroupsIter() {
|
||||
groups = append(groups, g)
|
||||
for group := range engine.GroupsIter() {
|
||||
groups = append(groups, group)
|
||||
}
|
||||
|
||||
if len(groups) != 1 {
|
||||
|
|
@ -27,23 +35,42 @@ func TestEngine_GroupsIter(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type streamGroupStub struct {
|
||||
healthGroup
|
||||
channels []string
|
||||
func TestModernization_GroupsIter_Bad(t *testing.T) {
|
||||
engine, _ := api.New()
|
||||
// No groups registered — iterator should yield nothing.
|
||||
var groups []api.RouteGroup
|
||||
for group := range engine.GroupsIter() {
|
||||
groups = append(groups, group)
|
||||
}
|
||||
if len(groups) != 0 {
|
||||
t.Fatalf("expected 0 groups with no registration, got %d", len(groups))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *streamGroupStub) Channels() []string { return s.channels }
|
||||
func TestModernization_GroupsIter_Ugly(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("GroupsIter on nil groups panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
func TestEngine_ChannelsIter(t *testing.T) {
|
||||
e, _ := api.New()
|
||||
g1 := &streamGroupStub{channels: []string{"ch1", "ch2"}}
|
||||
g2 := &streamGroupStub{channels: []string{"ch3"}}
|
||||
e.Register(g1)
|
||||
e.Register(g2)
|
||||
engine, _ := api.New()
|
||||
// Iterating immediately without any Register call must not panic.
|
||||
for range engine.GroupsIter() {
|
||||
t.Fatal("expected no iterations")
|
||||
}
|
||||
}
|
||||
|
||||
// ── ChannelsIter ──────────────────────────────────────────────────────
|
||||
|
||||
func TestModernization_ChannelsIter_Good(t *testing.T) {
|
||||
engine, _ := api.New()
|
||||
engine.Register(&streamGroupStub{channels: []string{"ch1", "ch2"}})
|
||||
engine.Register(&streamGroupStub{channels: []string{"ch3"}})
|
||||
|
||||
var channels []string
|
||||
for ch := range e.ChannelsIter() {
|
||||
channels = append(channels, ch)
|
||||
for channelName := range engine.ChannelsIter() {
|
||||
channels = append(channels, channelName)
|
||||
}
|
||||
|
||||
expected := []string{"ch1", "ch2", "ch3"}
|
||||
|
|
@ -52,42 +79,134 @@ func TestEngine_ChannelsIter(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestToolBridge_Iterators(t *testing.T) {
|
||||
b := api.NewToolBridge("/tools")
|
||||
desc := api.ToolDescriptor{Name: "test", Group: "g1"}
|
||||
b.Add(desc, nil)
|
||||
func TestModernization_ChannelsIter_Bad(t *testing.T) {
|
||||
engine, _ := api.New()
|
||||
// Register a group that has no Channels() — ChannelsIter must skip it.
|
||||
engine.Register(&healthGroup{})
|
||||
|
||||
var channels []string
|
||||
for channelName := range engine.ChannelsIter() {
|
||||
channels = append(channels, channelName)
|
||||
}
|
||||
|
||||
if len(channels) != 0 {
|
||||
t.Fatalf("expected 0 channels for non-StreamGroup, got %v", channels)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModernization_ChannelsIter_Ugly(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("ChannelsIter panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
engine, _ := api.New()
|
||||
// Group with empty channel list must not panic during iteration.
|
||||
engine.Register(&streamGroupStub{channels: []string{}})
|
||||
for range engine.ChannelsIter() {
|
||||
t.Fatal("expected no iterations for empty channel list")
|
||||
}
|
||||
}
|
||||
|
||||
// ── ToolBridge iterators ──────────────────────────────────────────────
|
||||
|
||||
func TestModernization_ToolBridgeIterators_Good(t *testing.T) {
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
bridge.Add(api.ToolDescriptor{Name: "test", Group: "g1"}, nil)
|
||||
|
||||
// Test ToolsIter
|
||||
var tools []api.ToolDescriptor
|
||||
for t := range b.ToolsIter() {
|
||||
tools = append(tools, t)
|
||||
for tool := range bridge.ToolsIter() {
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
if len(tools) != 1 || tools[0].Name != "test" {
|
||||
t.Errorf("ToolsIter failed, got %v", tools)
|
||||
}
|
||||
|
||||
// Test DescribeIter
|
||||
var descs []api.RouteDescription
|
||||
for d := range b.DescribeIter() {
|
||||
descs = append(descs, d)
|
||||
for desc := range bridge.DescribeIter() {
|
||||
descs = append(descs, desc)
|
||||
}
|
||||
if len(descs) != 1 || descs[0].Path != "/test" {
|
||||
t.Errorf("DescribeIter failed, got %v", descs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodegen_SupportedLanguagesIter(t *testing.T) {
|
||||
func TestModernization_ToolBridgeIterators_Bad(t *testing.T) {
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
// Empty bridge — iterators must yield nothing.
|
||||
for range bridge.ToolsIter() {
|
||||
t.Fatal("expected no iterations on empty bridge (ToolsIter)")
|
||||
}
|
||||
for range bridge.DescribeIter() {
|
||||
t.Fatal("expected no iterations on empty bridge (DescribeIter)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModernization_ToolBridgeIterators_Ugly(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("ToolBridge iterator with nil handler panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
bridge.Add(api.ToolDescriptor{Name: "noop"}, nil)
|
||||
|
||||
var toolCount int
|
||||
for range bridge.ToolsIter() {
|
||||
toolCount++
|
||||
}
|
||||
if toolCount != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", toolCount)
|
||||
}
|
||||
}
|
||||
|
||||
// ── SupportedLanguagesIter ────────────────────────────────────────────
|
||||
|
||||
func TestModernization_SupportedLanguagesIter_Good(t *testing.T) {
|
||||
var langs []string
|
||||
for l := range api.SupportedLanguagesIter() {
|
||||
langs = append(langs, l)
|
||||
for language := range api.SupportedLanguagesIter() {
|
||||
langs = append(langs, language)
|
||||
}
|
||||
|
||||
if !slices.Contains(langs, "go") {
|
||||
t.Errorf("SupportedLanguagesIter missing 'go'")
|
||||
}
|
||||
|
||||
// Should be sorted
|
||||
if !slices.IsSorted(langs) {
|
||||
t.Errorf("SupportedLanguagesIter should be sorted, got %v", langs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModernization_SupportedLanguagesIter_Bad(t *testing.T) {
|
||||
// Iterator and slice function must agree on count.
|
||||
iterCount := 0
|
||||
for range api.SupportedLanguagesIter() {
|
||||
iterCount++
|
||||
}
|
||||
sliceCount := len(api.SupportedLanguages())
|
||||
if iterCount != sliceCount {
|
||||
t.Fatalf("SupportedLanguagesIter count %d != SupportedLanguages count %d", iterCount, sliceCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModernization_SupportedLanguagesIter_Ugly(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("SupportedLanguagesIter panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Calling multiple times concurrently should not panic.
|
||||
done := make(chan struct{}, 5)
|
||||
for goroutineIndex := 0; goroutineIndex < 5; goroutineIndex++ {
|
||||
go func() {
|
||||
for range api.SupportedLanguagesIter() {
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
}
|
||||
for goroutineIndex := 0; goroutineIndex < 5; goroutineIndex++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
|
|
|||
25
openapi.go
25
openapi.go
|
|
@ -2,10 +2,7 @@
|
|||
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
import "dappco.re/go/core"
|
||||
|
||||
// SpecBuilder constructs an OpenAPI 3.1 specification from registered RouteGroups.
|
||||
type SpecBuilder struct {
|
||||
|
|
@ -54,7 +51,11 @@ func (sb *SpecBuilder) Build(groups []RouteGroup) ([]byte, error) {
|
|||
},
|
||||
}
|
||||
|
||||
return json.MarshalIndent(spec, "", " ")
|
||||
result := core.JSONMarshal(spec)
|
||||
if !result.OK {
|
||||
return nil, result.Value.(error)
|
||||
}
|
||||
return result.Value.([]byte), nil
|
||||
}
|
||||
|
||||
// buildPaths generates the paths object from all DescribableGroups.
|
||||
|
|
@ -80,14 +81,14 @@ func (sb *SpecBuilder) buildPaths(groups []RouteGroup) map[string]any {
|
|||
},
|
||||
}
|
||||
|
||||
for _, g := range groups {
|
||||
dg, ok := g.(DescribableGroup)
|
||||
for _, group := range groups {
|
||||
describableGroup, ok := group.(DescribableGroup)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, rd := range dg.Describe() {
|
||||
fullPath := g.BasePath() + rd.Path
|
||||
method := strings.ToLower(rd.Method)
|
||||
for _, rd := range describableGroup.Describe() {
|
||||
fullPath := group.BasePath() + rd.Path
|
||||
method := core.Lower(rd.Method)
|
||||
|
||||
operation := map[string]any{
|
||||
"summary": rd.Summary,
|
||||
|
|
@ -146,8 +147,8 @@ func (sb *SpecBuilder) buildTags(groups []RouteGroup) []map[string]any {
|
|||
}
|
||||
seen := map[string]bool{"system": true}
|
||||
|
||||
for _, g := range groups {
|
||||
name := g.Name()
|
||||
for _, group := range groups {
|
||||
name := group.Name()
|
||||
if !seen[name] {
|
||||
tags = append(tags, map[string]any{
|
||||
"name": name,
|
||||
|
|
|
|||
|
|
@ -401,3 +401,21 @@ func TestSpecBuilder_Bad_InfoFields(t *testing.T) {
|
|||
t.Fatalf("expected version=1.0.0, got %v", info["version"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpecBuilder_Ugly_NilGroupsDoesNotPanic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("Build with nil groups panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
builder := &api.SpecBuilder{Title: "Test", Version: "0.0.1"}
|
||||
// Passing nil as groups should return a valid spec without panicking.
|
||||
data, err := builder.Build(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
t.Fatal("expected non-empty spec data")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
16
options.go
16
options.go
|
|
@ -58,20 +58,20 @@ func WithRequestID() Option {
|
|||
// headers (Authorization, Content-Type, X-Request-ID) are permitted.
|
||||
func WithCORS(allowOrigins ...string) Option {
|
||||
return func(e *Engine) {
|
||||
cfg := cors.Config{
|
||||
corsConfig := cors.Config{
|
||||
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
|
||||
AllowHeaders: []string{"Authorization", "Content-Type", "X-Request-ID"},
|
||||
MaxAge: 12 * time.Hour,
|
||||
}
|
||||
|
||||
if slices.Contains(allowOrigins, "*") {
|
||||
cfg.AllowAllOrigins = true
|
||||
corsConfig.AllowAllOrigins = true
|
||||
}
|
||||
if !cfg.AllowAllOrigins {
|
||||
cfg.AllowOrigins = allowOrigins
|
||||
if !corsConfig.AllowAllOrigins {
|
||||
corsConfig.AllowOrigins = allowOrigins
|
||||
}
|
||||
|
||||
e.middlewares = append(e.middlewares, cors.New(cfg))
|
||||
e.middlewares = append(e.middlewares, cors.New(corsConfig))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -313,13 +313,13 @@ func WithLocation() Option {
|
|||
// )
|
||||
func WithGraphQL(schema graphql.ExecutableSchema, opts ...GraphQLOption) Option {
|
||||
return func(e *Engine) {
|
||||
cfg := &graphqlConfig{
|
||||
graphqlCfg := &graphqlConfig{
|
||||
schema: schema,
|
||||
path: defaultGraphQLPath,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
opt(graphqlCfg)
|
||||
}
|
||||
e.graphql = cfg
|
||||
e.graphql = graphqlCfg
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -122,3 +122,27 @@ func TestWithPprof_Good_CmdlineEndpointExists(t *testing.T) {
|
|||
t.Fatalf("expected 200 for /debug/pprof/cmdline, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithPprof_Ugly_DoubleRegistrationDoesNotPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("double WithPprof panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Registering pprof twice should not panic on engine construction.
|
||||
engine, err := api.New(api.WithPprof(), api.WithPprof())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, _ := http.NewRequest(http.MethodGet, "/health", nil)
|
||||
engine.Handler().ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,9 @@ type Meta struct {
|
|||
}
|
||||
|
||||
// OK wraps data in a successful response envelope.
|
||||
//
|
||||
// c.JSON(http.StatusOK, api.OK(user))
|
||||
// c.JSON(http.StatusOK, api.OK("healthy"))
|
||||
func OK[T any](data T) Response[T] {
|
||||
return Response[T]{
|
||||
Success: true,
|
||||
|
|
@ -35,6 +38,8 @@ func OK[T any](data T) Response[T] {
|
|||
}
|
||||
|
||||
// Fail creates an error response with the given code and message.
|
||||
//
|
||||
// c.AbortWithStatusJSON(http.StatusUnauthorized, api.Fail("unauthorised", "token expired"))
|
||||
func Fail(code, message string) Response[any] {
|
||||
return Response[any]{
|
||||
Success: false,
|
||||
|
|
@ -46,6 +51,8 @@ func Fail(code, message string) Response[any] {
|
|||
}
|
||||
|
||||
// FailWithDetails creates an error response with additional detail payload.
|
||||
//
|
||||
// c.JSON(http.StatusBadRequest, api.FailWithDetails("validation", "invalid input", fieldErrors))
|
||||
func FailWithDetails(code, message string, details any) Response[any] {
|
||||
return Response[any]{
|
||||
Success: false,
|
||||
|
|
@ -58,6 +65,8 @@ func FailWithDetails(code, message string, details any) Response[any] {
|
|||
}
|
||||
|
||||
// Paginated wraps data in a successful response with pagination metadata.
|
||||
//
|
||||
// c.JSON(http.StatusOK, api.Paginated(users, page, 20, totalCount))
|
||||
func Paginated[T any](data T, page, perPage, total int) Response[T] {
|
||||
return Response[T]{
|
||||
Success: true,
|
||||
|
|
|
|||
|
|
@ -203,3 +203,26 @@ func TestPaginated_Good_JSONIncludesMeta(t *testing.T) {
|
|||
t.Fatalf("expected total=50, got %v", meta["total"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponse_Ugly_ZeroValuesDontPanic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("Response zero value caused panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// A zero-value Response[any] should be safe to use.
|
||||
var zeroResponse api.Response[any]
|
||||
if zeroResponse.Success {
|
||||
t.Fatal("expected zero-value Success=false")
|
||||
}
|
||||
if zeroResponse.Error != nil {
|
||||
t.Fatal("expected nil Error in zero value")
|
||||
}
|
||||
|
||||
// Paginated with zero values should not panic.
|
||||
paginated := api.Paginated[[]string](nil, 0, 0, 0)
|
||||
if !paginated.Success {
|
||||
t.Fatal("expected Paginated to return Success=true")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
31
sse.go
31
sse.go
|
|
@ -3,11 +3,10 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"dappco.re/go/core"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
|
@ -34,6 +33,10 @@ type sseEvent struct {
|
|||
}
|
||||
|
||||
// NewSSEBroker creates a ready-to-use SSE broker.
|
||||
//
|
||||
// broker := api.NewSSEBroker()
|
||||
// engine, _ := api.New(api.WithSSE(broker))
|
||||
// broker.Publish("updates", "item.created", item)
|
||||
func NewSSEBroker() *SSEBroker {
|
||||
return &SSEBroker{
|
||||
clients: make(map[*sseClient]struct{}),
|
||||
|
|
@ -43,15 +46,15 @@ func NewSSEBroker() *SSEBroker {
|
|||
// Publish sends an event to all clients subscribed to the given channel.
|
||||
// Clients subscribed to an empty channel (no ?channel= param) receive
|
||||
// events on every channel. The data value is JSON-encoded before sending.
|
||||
//
|
||||
// broker.Publish("orders", "order.placed", order)
|
||||
// broker.Publish("", "ping", nil) // broadcasts to all clients
|
||||
func (b *SSEBroker) Publish(channel, event string, data any) {
|
||||
encoded, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
encoded := core.JSONMarshalString(data)
|
||||
|
||||
msg := sseEvent{
|
||||
Event: event,
|
||||
Data: string(encoded),
|
||||
Data: encoded,
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
|
|
@ -103,19 +106,19 @@ func (b *SSEBroker) Handler() gin.HandlerFunc {
|
|||
c.Writer.Flush()
|
||||
|
||||
// Stream events until client disconnects.
|
||||
ctx := c.Request.Context()
|
||||
requestCtx := c.Request.Context()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-requestCtx.Done():
|
||||
return
|
||||
case evt := <-client.events:
|
||||
_, err := fmt.Fprintf(c.Writer, "event: %s\ndata: %s\n\n", evt.Event, evt.Data)
|
||||
if err != nil {
|
||||
case event := <-client.events:
|
||||
line := core.Sprintf("event: %s\ndata: %s\n\n", event.Event, event.Data)
|
||||
if _, err := c.Writer.WriteString(line); err != nil {
|
||||
return
|
||||
}
|
||||
// Flush to ensure the event is sent immediately.
|
||||
if f, ok := c.Writer.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
17
sse_test.go
17
sse_test.go
|
|
@ -306,3 +306,20 @@ func waitForClients(t *testing.T, broker *api.SSEBroker, want int) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEBroker_Ugly_PublishToEmptyBrokerDoesNotPanic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("Publish on empty broker panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
broker := api.NewSSEBroker()
|
||||
// Publishing with no connected clients should be a no-op, not panic.
|
||||
broker.Publish("channel", "event", map[string]string{"key": "value"})
|
||||
broker.Publish("", "ping", nil)
|
||||
|
||||
if broker.ClientCount() != 0 {
|
||||
t.Fatalf("expected 0 clients, got %d", broker.ClientCount())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -162,3 +162,27 @@ func TestWithStatic_Good_MultipleStaticDirs(t *testing.T) {
|
|||
t.Fatalf("css: expected body=%q, got %q", "body{}", w2.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithStatic_Ugly_NonexistentRootDoesNotPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("WithStatic with nonexistent root panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// A nonexistent root path should not panic on engine construction.
|
||||
engine, err := api.New(api.WithStatic("/assets", "/nonexistent/path/that/does/not/exist"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, _ := http.NewRequest(http.MethodGet, "/health", nil)
|
||||
engine.Handler().ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,10 +3,10 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"dappco.re/go/core"
|
||||
"github.com/gin-gonic/gin"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
ginSwagger "github.com/swaggo/gin-swagger"
|
||||
|
|
@ -40,7 +40,7 @@ func (s *swaggerSpec) ReadDoc() string {
|
|||
}
|
||||
|
||||
// registerSwagger mounts the Swagger UI and doc.json endpoint.
|
||||
func registerSwagger(g *gin.Engine, title, description, version string, groups []RouteGroup) {
|
||||
func registerSwagger(router *gin.Engine, title, description, version string, groups []RouteGroup) {
|
||||
spec := &swaggerSpec{
|
||||
builder: &SpecBuilder{
|
||||
Title: title,
|
||||
|
|
@ -49,7 +49,7 @@ func registerSwagger(g *gin.Engine, title, description, version string, groups [
|
|||
},
|
||||
groups: groups,
|
||||
}
|
||||
name := fmt.Sprintf("swagger_%d", swaggerSeq.Add(1))
|
||||
name := core.Sprintf("swagger_%d", swaggerSeq.Add(1))
|
||||
swag.Register(name, spec)
|
||||
g.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.NewHandler(), ginSwagger.InstanceName(name)))
|
||||
router.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.NewHandler(), ginSwagger.InstanceName(name)))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -317,3 +317,30 @@ func (h *swaggerSpecHelper) ReadDoc() string {
|
|||
h.cache = string(data)
|
||||
return h.cache
|
||||
}
|
||||
|
||||
func TestRegisterSwagger_Ugly_MultipleEnginesDoNotCollide(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("multiple Swagger engines panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Creating multiple engines with Swagger enabled should not panic due
|
||||
// to global swag registry collisions (each instance uses a unique name).
|
||||
for engineIndex := 0; engineIndex < 5; engineIndex++ {
|
||||
engine, err := api.New(api.WithSwagger("Test API", "desc", "v1"))
|
||||
if err != nil {
|
||||
t.Fatalf("engine %d: unexpected error: %v", engineIndex, err)
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, _ := http.NewRequest(http.MethodGet, "/health", nil)
|
||||
engine.Handler().ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("engine %d: expected 200, got %d", engineIndex, recorder.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -250,3 +250,30 @@ func TestWithTracing_Good_ServiceNameInSpan(t *testing.T) {
|
|||
t.Errorf("expected server.address=%q, got %q", serviceName, kv.Value.AsString())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTracing_Ugly_DoubleRegistrationDoesNotPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("double WithTracing panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Registering tracing twice should not panic.
|
||||
engine, err := api.New(
|
||||
api.WithTracing("service-a"),
|
||||
api.WithTracing("service-b"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, _ := http.NewRequest(http.MethodGet, "/health", nil)
|
||||
engine.Handler().ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -114,3 +114,27 @@ func TestChannelListing_Good(t *testing.T) {
|
|||
t.Fatalf("expected channels[1]=%q, got %q", "wsstub.updates", channels[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithWSHandler_Ugly_NilHandlerDoesNotPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("WithWSHandler with nil panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Passing nil as the WS handler should not crash engine construction.
|
||||
engine, err := api.New(api.WithWSHandler(nil))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, _ := http.NewRequest(http.MethodGet, "/health", nil)
|
||||
engine.Handler().ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue