// SPDX-License-Identifier: EUPL-1.2 package api_test import ( "encoding/json" "net/http" "net/http/httptest" "testing" "github.com/gin-gonic/gin" api "forge.lthn.ai/core/go-api" ) // ── Helpers ───────────────────────────────────────────────────────────── // mwTestGroup provides a simple /v1/secret endpoint for middleware tests. type mwTestGroup struct{} func (m *mwTestGroup) Name() string { return "mw-test" } func (m *mwTestGroup) BasePath() string { return "/v1" } func (m *mwTestGroup) RegisterRoutes(rg *gin.RouterGroup) { rg.GET("/secret", func(c *gin.Context) { c.JSON(http.StatusOK, api.OK("classified")) }) } // ── Bearer auth ───────────────────────────────────────────────────────── func TestBearerAuth_Bad_MissingToken(t *testing.T) { gin.SetMode(gin.TestMode) e, _ := api.New(api.WithBearerAuth("s3cret")) e.Register(&mwTestGroup{}) h := e.Handler() w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/v1/secret", nil) h.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Fatalf("expected 401, got %d", w.Code) } var resp api.Response[any] if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatalf("unmarshal error: %v", err) } if resp.Error == nil || resp.Error.Code != "unauthorised" { t.Fatalf("expected error code=%q, got %+v", "unauthorised", resp.Error) } } func TestBearerAuth_Bad_WrongToken(t *testing.T) { gin.SetMode(gin.TestMode) e, _ := api.New(api.WithBearerAuth("s3cret")) e.Register(&mwTestGroup{}) h := e.Handler() w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/v1/secret", nil) req.Header.Set("Authorization", "Bearer wrong-token") h.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Fatalf("expected 401, got %d", w.Code) } var resp api.Response[any] if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatalf("unmarshal error: %v", err) } if resp.Error == nil || resp.Error.Code != "unauthorised" { t.Fatalf("expected error code=%q, got %+v", "unauthorised", resp.Error) } } func TestBearerAuth_Good_CorrectToken(t *testing.T) { gin.SetMode(gin.TestMode) e, _ := api.New(api.WithBearerAuth("s3cret")) e.Register(&mwTestGroup{}) h := e.Handler() w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/v1/secret", nil) req.Header.Set("Authorization", "Bearer s3cret") h.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d", w.Code) } var resp api.Response[string] if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatalf("unmarshal error: %v", err) } if resp.Data != "classified" { t.Fatalf("expected Data=%q, got %q", "classified", resp.Data) } } func TestBearerAuth_Good_HealthBypassesAuth(t *testing.T) { gin.SetMode(gin.TestMode) e, _ := api.New(api.WithBearerAuth("s3cret")) h := e.Handler() w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/health", nil) // No Authorization header. h.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200 for /health, got %d", w.Code) } } // ── Request ID ────────────────────────────────────────────────────────── func TestRequestID_Good_GeneratedWhenMissing(t *testing.T) { gin.SetMode(gin.TestMode) e, _ := api.New(api.WithRequestID()) h := e.Handler() w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/health", nil) h.ServeHTTP(w, req) id := w.Header().Get("X-Request-ID") if id == "" { t.Fatal("expected X-Request-ID header to be set") } // 16 bytes = 32 hex characters. if len(id) != 32 { t.Fatalf("expected 32-char hex ID, got %d chars: %q", len(id), id) } } func TestRequestID_Good_PreservesClientID(t *testing.T) { gin.SetMode(gin.TestMode) e, _ := api.New(api.WithRequestID()) h := e.Handler() w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/health", nil) req.Header.Set("X-Request-ID", "client-id-abc") h.ServeHTTP(w, req) id := w.Header().Get("X-Request-ID") if id != "client-id-abc" { t.Fatalf("expected X-Request-ID=%q, got %q", "client-id-abc", id) } } // ── CORS ──────────────────────────────────────────────────────────────── func TestCORS_Good_PreflightAllOrigins(t *testing.T) { gin.SetMode(gin.TestMode) e, _ := api.New(api.WithCORS("*")) h := e.Handler() w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodOptions, "/health", nil) req.Header.Set("Origin", "https://example.com") req.Header.Set("Access-Control-Request-Method", "GET") req.Header.Set("Access-Control-Request-Headers", "Authorization") h.ServeHTTP(w, req) if w.Code != http.StatusNoContent && w.Code != http.StatusOK { t.Fatalf("expected 200 or 204 for preflight, got %d", w.Code) } origin := w.Header().Get("Access-Control-Allow-Origin") if origin != "*" { t.Fatalf("expected Access-Control-Allow-Origin=%q, got %q", "*", origin) } methods := w.Header().Get("Access-Control-Allow-Methods") if methods == "" { t.Fatal("expected Access-Control-Allow-Methods to be set") } headers := w.Header().Get("Access-Control-Allow-Headers") if headers == "" { t.Fatal("expected Access-Control-Allow-Headers to be set") } } func TestCORS_Good_SpecificOrigin(t *testing.T) { gin.SetMode(gin.TestMode) e, _ := api.New(api.WithCORS("https://app.example.com")) h := e.Handler() w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodOptions, "/health", nil) req.Header.Set("Origin", "https://app.example.com") req.Header.Set("Access-Control-Request-Method", "POST") h.ServeHTTP(w, req) origin := w.Header().Get("Access-Control-Allow-Origin") if origin != "https://app.example.com" { t.Fatalf("expected origin=%q, got %q", "https://app.example.com", origin) } } func TestCORS_Bad_DisallowedOrigin(t *testing.T) { gin.SetMode(gin.TestMode) e, _ := api.New(api.WithCORS("https://allowed.example.com")) h := e.Handler() w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodOptions, "/health", nil) req.Header.Set("Origin", "https://evil.example.com") req.Header.Set("Access-Control-Request-Method", "GET") h.ServeHTTP(w, req) origin := w.Header().Get("Access-Control-Allow-Origin") if origin != "" { t.Fatalf("expected no Access-Control-Allow-Origin for disallowed origin, got %q", origin) } }