diff --git a/bridge.go b/bridge.go index 0b74e38..a693a36 100644 --- a/bridge.go +++ b/bridge.go @@ -194,13 +194,24 @@ func describeTool(desc ToolDescriptor, defaultTag string) RouteDescription { } } +// maxToolRequestBodyBytes is the maximum request body size accepted by the +// tool bridge handler. Requests larger than this are rejected with 413. +const maxToolRequestBodyBytes = 10 << 20 // 10 MiB + func wrapToolHandler(handler gin.HandlerFunc, validator *toolInputValidator) gin.HandlerFunc { return func(c *gin.Context) { - body, err := io.ReadAll(c.Request.Body) + limited := http.MaxBytesReader(c.Writer, c.Request.Body, maxToolRequestBodyBytes) + body, err := io.ReadAll(limited) if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, FailWithDetails( + status := http.StatusBadRequest + msg := "Unable to read request body" + if err.Error() == "http: request body too large" { + status = http.StatusRequestEntityTooLarge + msg = "Request body exceeds the maximum allowed size" + } + c.AbortWithStatusJSON(status, FailWithDetails( "invalid_request_body", - "Unable to read request body", + msg, map[string]any{"error": err.Error()}, )) return @@ -289,9 +300,12 @@ func (v *toolInputValidator) ValidateResponse(body []byte) error { return coreerr.E("ToolBridge.ValidateResponse", "response is missing a successful envelope", nil) } + // data is serialised with omitempty, so a nil/zero-value payload from + // constructors like OK(nil) or OK(false) will omit the key entirely. + // Treat a missing data key as a valid nil payload for successful responses. data, ok := envelope["data"] if !ok { - return coreerr.E("ToolBridge.ValidateResponse", "response is missing data", nil) + return nil } encoded, err := json.Marshal(data) @@ -691,10 +705,17 @@ func (w *toolResponseRecorder) reset() { func (w *toolResponseRecorder) writeErrorResponse(status int, resp Response[any]) { data, err := json.Marshal(resp) if err != nil { + w.status = http.StatusInternalServerError + w.wroteHeader = true http.Error(w.ResponseWriter, "internal server error", http.StatusInternalServerError) return } + // Update recorder state so middleware observing c.Writer.Status() or + // Written() sees the correct values after an error response is emitted. + w.status = status + w.wroteHeader = true + w.ResponseWriter.Header().Set("Content-Type", "application/json") w.ResponseWriter.WriteHeader(status) _, _ = w.ResponseWriter.Write(data) diff --git a/cache.go b/cache.go index d9d86c9..9f71f4d 100644 --- a/cache.go +++ b/cache.go @@ -85,6 +85,12 @@ func (s *cacheStore) set(key string, entry *cacheEntry) { } if elem, ok := s.index[key]; ok { + // Reject an oversized replacement before touching LRU state so the + // existing entry remains intact when the new value cannot fit. + if s.maxBytes > 0 && entry.size > s.maxBytes { + s.mu.Unlock() + return + } if existing, exists := s.entries[key]; exists { s.currentBytes -= existing.size if s.currentBytes < 0 { diff --git a/cmd/api/cmd_sdk.go b/cmd/api/cmd_sdk.go index bfad014..dfd5a38 100644 --- a/cmd/api/cmd_sdk.go +++ b/cmd/api/cmd_sdk.go @@ -55,7 +55,8 @@ func addSDKCommand(parent *cli.Command) { // If no spec file was provided, generate one only after confirming the // generator is available. - if specFile == "" { + resolvedSpecFile := specFile + if resolvedSpecFile == "" { builder, err := sdkSpecBuilder(cfg) if err != nil { return err @@ -76,10 +77,10 @@ func addSDKCommand(parent *cli.Command) { if err := goapi.ExportSpecToFileIter(tmpPath, "json", builder, groups); err != nil { return coreerr.E("sdk.Generate", "generate spec", err) } - specFile = tmpPath + resolvedSpecFile = tmpPath } - gen.SpecPath = specFile + gen.SpecPath = resolvedSpecFile // Generate for each language. for _, l := range languages { diff --git a/cmd/api/cmd_test.go b/cmd/api/cmd_test.go index 0d5396d..09361b7 100644 --- a/cmd/api/cmd_test.go +++ b/cmd/api/cmd_test.go @@ -810,6 +810,13 @@ func TestAPISpecCmd_Good_ServerFlagAddsServers(t *testing.T) { } func TestAPISpecCmd_Good_RegisteredSpecGroups(t *testing.T) { + snapshot := api.RegisteredSpecGroups() + api.ResetSpecGroups() + t.Cleanup(func() { + api.ResetSpecGroups() + api.RegisterSpecGroups(snapshot...) + }) + api.RegisterSpecGroups(specCmdStubGroup{}) root := &cli.Command{Use: "root"} diff --git a/codegen.go b/codegen.go index ffa61e5..e031dea 100644 --- a/codegen.go +++ b/codegen.go @@ -93,7 +93,7 @@ func (g *SDKGenerator) Generate(ctx context.Context, language string) error { return coreerr.E("SDKGenerator.Generate", "create output directory", err) } - args := g.buildArgs(generator, outputDir) + args := g.buildArgs(specPath, generator, outputDir) cmd := exec.CommandContext(ctx, "openapi-generator-cli", args...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -106,10 +106,10 @@ func (g *SDKGenerator) Generate(ctx context.Context, language string) error { } // buildArgs constructs the openapi-generator-cli command arguments. -func (g *SDKGenerator) buildArgs(generator, outputDir string) []string { +func (g *SDKGenerator) buildArgs(specPath, generator, outputDir string) []string { args := []string{ "generate", - "-i", g.SpecPath, + "-i", specPath, "-g", generator, "-o", outputDir, } diff --git a/export.go b/export.go index 3095a6c..fadff80 100644 --- a/export.go +++ b/export.go @@ -94,21 +94,37 @@ func ExportSpecToFileIter(path, format string, builder *SpecBuilder, groups iter } func exportSpecToFile(path, op string, write func(io.Writer) error) (err error) { - if err := coreio.Local.EnsureDir(filepath.Dir(path)); err != nil { + dir := filepath.Dir(path) + if err := coreio.Local.EnsureDir(dir); err != nil { return coreerr.E(op, "create directory", err) } - f, err := os.Create(path) + + // Write to a temp file in the same directory so the rename is atomic on + // most filesystems. The destination is never truncated unless the full + // export succeeds. + f, err := os.CreateTemp(dir, ".export-*.tmp") if err != nil { - return coreerr.E(op, "create file", err) + return coreerr.E(op, "create temp file", err) } + tmpPath := f.Name() + defer func() { - if closeErr := f.Close(); closeErr != nil && err == nil { - err = coreerr.E(op, "close file", closeErr) + if err != nil { + _ = os.Remove(tmpPath) } }() - if err = write(f); err != nil { - return err + if writeErr := write(f); writeErr != nil { + _ = f.Close() + return writeErr + } + + if closeErr := f.Close(); closeErr != nil { + return coreerr.E(op, "close temp file", closeErr) + } + + if renameErr := os.Rename(tmpPath, path); renameErr != nil { + return coreerr.E(op, "rename temp file", renameErr) } return nil } diff --git a/openapi.go b/openapi.go index 82d0cc2..eaba297 100644 --- a/openapi.go +++ b/openapi.go @@ -1910,10 +1910,14 @@ func sseResponseHeaders() map[string]any { } // effectiveGraphQLPath returns the configured GraphQL path or the default -// GraphQL path when GraphQL is enabled without an explicit path. +// GraphQL path when GraphQL is enabled without an explicit path. Returns an +// empty string when neither GraphQL nor the playground is enabled. func (sb *SpecBuilder) effectiveGraphQLPath() string { + if !sb.GraphQLEnabled && !sb.GraphQLPlayground { + return "" + } graphqlPath := strings.TrimSpace(sb.GraphQLPath) - if graphqlPath == "" && (sb.GraphQLEnabled || sb.GraphQLPlayground) { + if graphqlPath == "" { return defaultGraphQLPath } return graphqlPath @@ -1940,30 +1944,42 @@ func (sb *SpecBuilder) effectiveGraphQLPlaygroundPath() string { } // effectiveSwaggerPath returns the configured Swagger UI path or the default -// path when Swagger is enabled without an explicit override. +// path when Swagger is enabled without an explicit override. Returns an empty +// string when Swagger is disabled. func (sb *SpecBuilder) effectiveSwaggerPath() string { + if !sb.SwaggerEnabled { + return "" + } swaggerPath := strings.TrimSpace(sb.SwaggerPath) - if swaggerPath == "" && sb.SwaggerEnabled { + if swaggerPath == "" { return defaultSwaggerPath } return swaggerPath } // effectiveWSPath returns the configured WebSocket path or the default path -// when WebSockets are enabled without an explicit override. +// when WebSockets are enabled without an explicit override. Returns an empty +// string when WebSockets are disabled. func (sb *SpecBuilder) effectiveWSPath() string { + if !sb.WSEnabled { + return "" + } wsPath := strings.TrimSpace(sb.WSPath) - if wsPath == "" && sb.WSEnabled { + if wsPath == "" { return defaultWSPath } return wsPath } // effectiveSSEPath returns the configured SSE path or the default path when -// SSE is enabled without an explicit override. +// SSE is enabled without an explicit override. Returns an empty string when +// SSE is disabled. func (sb *SpecBuilder) effectiveSSEPath() string { + if !sb.SSEEnabled { + return "" + } ssePath := strings.TrimSpace(sb.SSEPath) - if ssePath == "" && sb.SSEEnabled { + if ssePath == "" { return defaultSSEPath } return ssePath @@ -1992,7 +2008,10 @@ func (sb *SpecBuilder) effectiveAuthentikPublicPaths() []string { return nil } - paths := []string{"/health", "/swagger", resolveSwaggerPath(sb.SwaggerPath)} + paths := []string{"/health"} + if swaggerPath := sb.effectiveSwaggerPath(); swaggerPath != "" { + paths = append(paths, swaggerPath) + } paths = append(paths, sb.AuthentikPublicPaths...) return normalisePublicPaths(paths) } diff --git a/pkg/provider/proxy.go b/pkg/provider/proxy.go index c588b32..e2ef86b 100644 --- a/pkg/provider/proxy.go +++ b/pkg/provider/proxy.go @@ -3,6 +3,7 @@ package provider import ( + "fmt" "net/http" "net/http/httputil" "net/url" @@ -56,6 +57,16 @@ func NewProxy(cfg ProxyConfig) *ProxyProvider { } } + // url.Parse accepts inputs like "127.0.0.1:9901" without error — they + // parse without a scheme or host, which causes httputil.ReverseProxy to + // fail silently at runtime. Require both to be present. + if target.Scheme == "" || target.Host == "" { + return &ProxyProvider{ + config: cfg, + err: fmt.Errorf("upstream %q must include a scheme and host (e.g. http://127.0.0.1:9901)", cfg.Upstream), + } + } + proxy := httputil.NewSingleHostReverseProxy(target) // Preserve the original Director but strip the base path so the diff --git a/ratelimit.go b/ratelimit.go index 20ebd77..29bed2f 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -3,6 +3,8 @@ package api import ( + "crypto/sha256" + "encoding/hex" "math" "net/http" "strconv" @@ -146,8 +148,8 @@ func rateLimitMiddleware(limit int) gin.HandlerFunc { return } - c.Next() setRateLimitHeaders(c, decision.limit, decision.remaining, decision.resetAt) + c.Next() } } @@ -183,21 +185,41 @@ func timeUntilFull(tokens float64, limit int) time.Duration { return time.Duration(math.Ceil(seconds * float64(time.Second))) } -// clientRateLimitKey prefers caller-provided credentials for bucket -// isolation, then falls back to the network address. +// clientRateLimitKey derives a bucket key for the request. It prefers a +// validated principal from context (set by auth middleware), then falls back +// to the client IP. Raw credential headers are hashed with SHA-256 when used +// as a last resort so that secrets are never stored in the bucket map. func clientRateLimitKey(c *gin.Context) string { - if apiKey := strings.TrimSpace(c.GetHeader("X-API-Key")); apiKey != "" { - return "api_key:" + apiKey + // Prefer a validated principal placed in context by auth middleware. + if principal, ok := c.Get("principal"); ok && principal != nil { + if s, ok := principal.(string); ok && s != "" { + return "principal:" + s + } } - if bearer := bearerTokenFromHeader(c.GetHeader("Authorization")); bearer != "" { - return "bearer:" + bearer + if userID, ok := c.Get("userID"); ok && userID != nil { + if s, ok := userID.(string); ok && s != "" { + return "user:" + s + } } + + // Fall back to IP address. if ip := c.ClientIP(); ip != "" { return "ip:" + ip } if c.Request != nil && c.Request.RemoteAddr != "" { return "ip:" + c.Request.RemoteAddr } + + // Last resort: hash credential headers so raw secrets are not retained. + if apiKey := strings.TrimSpace(c.GetHeader("X-API-Key")); apiKey != "" { + h := sha256.Sum256([]byte(apiKey)) + return "cred:sha256:" + hex.EncodeToString(h[:]) + } + if bearer := bearerTokenFromHeader(c.GetHeader("Authorization")); bearer != "" { + h := sha256.Sum256([]byte(bearer)) + return "cred:sha256:" + hex.EncodeToString(h[:]) + } + return "ip:unknown" } diff --git a/response_meta.go b/response_meta.go index 8438a7c..74f9e8a 100644 --- a/response_meta.go +++ b/response_meta.go @@ -23,6 +23,7 @@ type responseMetaRecorder struct { gin.ResponseWriter headers http.Header body bytes.Buffer + size int status int wroteHeader bool committed bool @@ -76,7 +77,9 @@ func (w *responseMetaRecorder) Write(data []byte) (int, error) { if !w.wroteHeader { w.WriteHeader(http.StatusOK) } - return w.body.Write(data) + n, err := w.body.Write(data) + w.size += n + return n, err } func (w *responseMetaRecorder) WriteString(s string) (int, error) { @@ -89,7 +92,9 @@ func (w *responseMetaRecorder) WriteString(s string) (int, error) { if !w.wroteHeader { w.WriteHeader(http.StatusOK) } - return w.body.WriteString(s) + n, err := w.body.WriteString(s) + w.size += n + return n, err } func (w *responseMetaRecorder) Flush() { @@ -121,7 +126,10 @@ func (w *responseMetaRecorder) Status() int { } func (w *responseMetaRecorder) Size() int { - return w.body.Len() + if w.passthrough { + return w.ResponseWriter.Size() + } + return w.size } func (w *responseMetaRecorder) Written() bool { @@ -194,6 +202,7 @@ func responseMetaMiddleware() gin.HandlerFunc { recorder.body.Reset() _, _ = recorder.body.Write(body) + recorder.size = len(body) recorder.Header().Set("Content-Length", strconv.Itoa(len(body))) recorder.commit(true) } diff --git a/src/php/src/Api/Concerns/HasApiResponses.php b/src/php/src/Api/Concerns/HasApiResponses.php index 4a47c75..e63de74 100644 --- a/src/php/src/Api/Concerns/HasApiResponses.php +++ b/src/php/src/Api/Concerns/HasApiResponses.php @@ -1,5 +1,7 @@ diff --git a/src/php/src/Api/Exceptions/RateLimitExceededException.php b/src/php/src/Api/Exceptions/RateLimitExceededException.php index ac764cd..b4d768c 100644 --- a/src/php/src/Api/Exceptions/RateLimitExceededException.php +++ b/src/php/src/Api/Exceptions/RateLimitExceededException.php @@ -51,9 +51,17 @@ class RateLimitExceededException extends HttpException )->withHeaders($this->rateLimitResult->headers()); if ($request !== null) { - $origin = $request->headers->get('Origin', '*'); - $response->headers->set('Access-Control-Allow-Origin', $origin); - $response->headers->set('Vary', 'Origin'); + $origin = $request->headers->get('Origin'); + $allowedOrigins = (array) config('cors.allowed_origins', []); + if ($origin !== null && in_array($origin, $allowedOrigins, true)) { + $response->headers->set('Access-Control-Allow-Origin', $origin); + } + + $existingVary = $response->headers->get('Vary'); + $response->headers->set( + 'Vary', + $existingVary ? $existingVary.', Origin' : 'Origin' + ); } return $response; diff --git a/src/php/src/Api/Services/ApiUsageService.php b/src/php/src/Api/Services/ApiUsageService.php index e9d241d..f5d7445 100644 --- a/src/php/src/Api/Services/ApiUsageService.php +++ b/src/php/src/Api/Services/ApiUsageService.php @@ -5,6 +5,7 @@ declare(strict_types=1); namespace Core\Api\Services; use Carbon\Carbon; +use Core\Api\Models\ApiKey; use Core\Api\Models\ApiUsage; use Core\Api\Models\ApiUsageDaily; @@ -282,7 +283,7 @@ class ApiUsageService // Fetch API keys separately to avoid broken eager loading with aggregation $apiKeyIds = $aggregated->pluck('api_key_id')->filter()->unique()->all(); - $apiKeys = \Mod\Api\Models\ApiKey::whereIn('id', $apiKeyIds) + $apiKeys = ApiKey::whereIn('id', $apiKeyIds) ->select('id', 'name', 'prefix') ->get() ->keyBy('id'); diff --git a/src/php/src/Api/Services/SeoReportService.php b/src/php/src/Api/Services/SeoReportService.php index c1cc1d7..2f0fe95 100644 --- a/src/php/src/Api/Services/SeoReportService.php +++ b/src/php/src/Api/Services/SeoReportService.php @@ -20,16 +20,24 @@ class SeoReportService { /** * Analyse a URL and return a technical SEO report. + * + * @throws RuntimeException when the URL is blocked for SSRF reasons or the fetch fails. */ public function analyse(string $url): array { + $this->validateUrlForSsrf($url); + try { $response = Http::withHeaders([ 'User-Agent' => config('app.name', 'Core API').' SEO Reporter/1.0', 'Accept' => 'text/html,application/xhtml+xml', ]) ->timeout((int) config('api.seo.timeout', 10)) - ->get($url); + ->withoutRedirecting() + ->get($url) + ->throw(); + } catch (RuntimeException $exception) { + throw $exception; } catch (Throwable $exception) { throw new RuntimeException('Unable to fetch the requested URL.', 0, $exception); } @@ -349,6 +357,107 @@ class SeoReportService ]; } + /** + * Validate that a URL is safe to fetch and does not target internal/private + * network resources (SSRF protection). + * + * Blocks: + * - Non-HTTP/HTTPS schemes + * - Loopback addresses (127.0.0.0/8, ::1) + * - RFC-1918 private ranges (10/8, 172.16/12, 192.168/16) + * - Link-local ranges (169.254.0.0/16, fe80::/10) + * - IPv6 ULA (fc00::/7) + * + * @throws RuntimeException when the URL fails SSRF validation. + */ + protected function validateUrlForSsrf(string $url): void + { + $parsed = parse_url($url); + + if ($parsed === false || empty($parsed['scheme']) || empty($parsed['host'])) { + throw new RuntimeException('The supplied URL is not valid.'); + } + + if (! in_array(strtolower($parsed['scheme']), ['http', 'https'], true)) { + throw new RuntimeException('Only HTTP and HTTPS URLs are permitted.'); + } + + $host = $parsed['host']; + $records = dns_get_record($host, DNS_A | DNS_AAAA) ?: []; + + // Fall back to gethostbyname for hosts not returned by dns_get_record. + if (empty($records)) { + $resolved = gethostbyname($host); + if ($resolved !== $host) { + $records[] = ['ip' => $resolved]; + } + } + + foreach ($records as $record) { + $ip = $record['ip'] ?? $record['ipv6'] ?? null; + if ($ip === null) { + continue; + } + if ($this->isPrivateIp($ip)) { + throw new RuntimeException('The supplied URL resolves to a private or reserved address.'); + } + } + } + + /** + * Return true when an IP address falls within a private, loopback, or + * link-local range. + */ + protected function isPrivateIp(string $ip): bool + { + // inet_pton returns false for invalid addresses. + $packed = inet_pton($ip); + if ($packed === false) { + return true; // Treat unresolvable as unsafe. + } + + if (strlen($packed) === 4) { + // IPv4 checks. + $long = ip2long($ip); + if ($long === false) { + return true; + } + $privateRanges = [ + ['start' => ip2long('127.0.0.0'), 'end' => ip2long('127.255.255.255')], // loopback + ['start' => ip2long('10.0.0.0'), 'end' => ip2long('10.255.255.255')], // RFC-1918 + ['start' => ip2long('172.16.0.0'), 'end' => ip2long('172.31.255.255')], // RFC-1918 + ['start' => ip2long('192.168.0.0'), 'end' => ip2long('192.168.255.255')], // RFC-1918 + ['start' => ip2long('169.254.0.0'), 'end' => ip2long('169.254.255.255')], // link-local + ]; + foreach ($privateRanges as $range) { + if ($long >= $range['start'] && $long <= $range['end']) { + return true; + } + } + + return false; + } + + // IPv6 checks: loopback (::1), link-local (fe80::/10), ULA (fc00::/7). + if ($ip === '::1') { + return true; + } + $prefix2 = strtolower(substr(bin2hex($packed), 0, 2)); + // fe80::/10 — first byte 0xfe, second byte 0x80–0xbf + if ($prefix2 === 'fe') { + $secondNibble = hexdec(substr(bin2hex($packed), 2, 1)); + if ($secondNibble >= 8 && $secondNibble <= 11) { + return true; + } + } + // fc00::/7 — first byte 0xfc or 0xfd + if (in_array($prefix2, ['fc', 'fd'], true)) { + return true; + } + + return false; + } + /** * Quote a literal for XPath queries. */ diff --git a/src/php/src/Api/config.php b/src/php/src/Api/config.php index 220058c..d2a835d 100644 --- a/src/php/src/Api/config.php +++ b/src/php/src/Api/config.php @@ -1,5 +1,7 @@