From f6349145bc87eaa8cf44569ccdba5ce7c45e6a6f Mon Sep 17 00:00:00 2001 From: Virgil Date: Wed, 1 Apr 2026 17:48:49 +0000 Subject: [PATCH] feat(api): validate openapi client requests and responses Co-Authored-By: Virgil --- client.go | 121 +++++++++++++++++++++++++++++++++++++++++++++++-- client_test.go | 112 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 228 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index d5ef909..26219d1 100644 --- a/client.go +++ b/client.go @@ -36,6 +36,8 @@ type openAPIOperation struct { method string pathTemplate string hasRequestBody bool + requestSchema map[string]any + responseSchema map[string]any } // OpenAPIClientOption configures a runtime OpenAPI client. @@ -117,11 +119,22 @@ func (c *OpenAPIClient) Call(operationID string, params any) (any, error) { return nil, err } - req, err := http.NewRequest(op.method, requestURL, body) + if op.requestSchema != nil && len(body) > 0 { + if err := validateOpenAPISchema(body, op.requestSchema, "request body"); err != nil { + return nil, err + } + } + + var bodyReader io.Reader + if len(body) > 0 { + bodyReader = bytes.NewReader(body) + } + + req, err := http.NewRequest(op.method, requestURL, bodyReader) if err != nil { return nil, err } - if body != nil { + if bodyReader != nil { req.Header.Set("Content-Type", "application/json") } if c.bearerToken != "" { @@ -143,6 +156,12 @@ func (c *OpenAPIClient) Call(operationID string, params any) (any, error) { return nil, fmt.Errorf("openapi call %s returned %s: %s", operationID, resp.Status, strings.TrimSpace(string(payload))) } + if op.responseSchema != nil && len(bytes.TrimSpace(payload)) > 0 { + if err := validateOpenAPIResponse(payload, op.responseSchema, operationID); err != nil { + return nil, err + } + } + if len(bytes.TrimSpace(payload)) == 0 { return nil, nil } @@ -213,6 +232,8 @@ func (c *OpenAPIClient) loadSpec() error { method: strings.ToUpper(method), pathTemplate: pathTemplate, hasRequestBody: operation["requestBody"] != nil, + requestSchema: requestBodySchema(operation), + responseSchema: firstSuccessResponseSchema(operation), } } } @@ -302,7 +323,7 @@ func (c *OpenAPIClient) buildURL(op openAPIOperation, params map[string]any) (st return fullURL, nil } -func (c *OpenAPIClient) buildBody(op openAPIOperation, params map[string]any) (io.Reader, error) { +func (c *OpenAPIClient) buildBody(op openAPIOperation, params map[string]any) ([]byte, error) { if explicitBody, ok := params["body"]; ok { return encodeJSONBody(explicitBody) } @@ -342,12 +363,12 @@ func (c *OpenAPIClient) buildBody(op openAPIOperation, params map[string]any) (i return encodeJSONBody(payload) } -func encodeJSONBody(v any) (io.Reader, error) { +func encodeJSONBody(v any) ([]byte, error) { data, err := json.Marshal(v) if err != nil { return nil, err } - return bytes.NewReader(data), nil + return data, nil } func normaliseParams(params any) (map[string]any, error) { @@ -467,3 +488,93 @@ func isAbsoluteBaseURL(raw string) bool { } return u.Scheme != "" && u.Host != "" } + +func requestBodySchema(operation map[string]any) map[string]any { + rawRequestBody, ok := operation["requestBody"].(map[string]any) + if !ok { + return nil + } + + content, ok := rawRequestBody["content"].(map[string]any) + if !ok { + return nil + } + + rawJSON, ok := content["application/json"].(map[string]any) + if !ok { + return nil + } + + schema, _ := rawJSON["schema"].(map[string]any) + return schema +} + +func firstSuccessResponseSchema(operation map[string]any) map[string]any { + responses, ok := operation["responses"].(map[string]any) + if !ok { + return nil + } + + for _, code := range []string{"200", "201", "202", "203", "204", "205", "206", "207", "208", "226"} { + rawResp, ok := responses[code].(map[string]any) + if !ok { + continue + } + content, ok := rawResp["content"].(map[string]any) + if !ok { + continue + } + rawJSON, ok := content["application/json"].(map[string]any) + if !ok { + continue + } + schema, _ := rawJSON["schema"].(map[string]any) + if len(schema) > 0 { + return schema + } + } + + return nil +} + +func validateOpenAPISchema(body []byte, schema map[string]any, label string) error { + if len(bytes.TrimSpace(body)) == 0 { + return nil + } + + var payload any + dec := json.NewDecoder(bytes.NewReader(body)) + dec.UseNumber() + if err := dec.Decode(&payload); err != nil { + return fmt.Errorf("validate %s: invalid JSON: %w", label, err) + } + var extra any + if err := dec.Decode(&extra); err != io.EOF { + return fmt.Errorf("validate %s: expected a single JSON value", label) + } + + if err := validateSchemaNode(payload, schema, ""); err != nil { + return fmt.Errorf("validate %s: %w", label, err) + } + + return nil +} + +func validateOpenAPIResponse(payload []byte, schema map[string]any, operationID string) error { + var decoded any + dec := json.NewDecoder(bytes.NewReader(payload)) + dec.UseNumber() + if err := dec.Decode(&decoded); err != nil { + return fmt.Errorf("openapi call %s returned invalid JSON: %w", operationID, err) + } + var extra any + if err := dec.Decode(&extra); err != io.EOF { + return fmt.Errorf("openapi call %s returned multiple JSON values", operationID) + } + + if err := validateSchemaNode(decoded, schema, ""); err != nil { + return fmt.Errorf("openapi call %s response does not match spec: %w", operationID, err) + } + + return nil +} diff --git a/client_test.go b/client_test.go index 1c8e9e9..4cc3a7a 100644 --- a/client_test.go +++ b/client_test.go @@ -310,6 +310,118 @@ paths: } } +func TestOpenAPIClient_Bad_ValidatesRequestBodyAgainstSchema(t *testing.T) { + called := make(chan struct{}, 1) + mux := http.NewServeMux() + mux.HandleFunc("/users", func(w http.ResponseWriter, r *http.Request) { + called <- struct{}{} + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"id":"123"}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /users: + post: + operationId: create_user + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [name] + properties: + name: + type: string + responses: + "200": + description: OK + content: + application/json: + schema: + type: object + properties: + success: + type: boolean + data: + type: object + properties: + id: + type: string +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL(srv.URL), + ) + + if _, err := client.Call("create_user", map[string]any{ + "body": map[string]any{}, + }); err == nil { + t.Fatal("expected request body validation error, got nil") + } + + select { + case <-called: + t.Fatal("expected request validation to fail before the HTTP call") + default: + } +} + +func TestOpenAPIClient_Bad_ValidatesResponseAgainstSchema(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/users", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"id":123}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /users: + get: + operationId: list_users + responses: + "200": + description: OK + content: + application/json: + schema: + type: object + required: [success, data] + properties: + success: + type: boolean + data: + type: object + required: [id] + properties: + id: + type: string +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL(srv.URL), + ) + + if _, err := client.Call("list_users", nil); err == nil { + t.Fatal("expected response validation error, got nil") + } +} + func TestOpenAPIClient_Bad_MissingOperation(t *testing.T) { specPath := writeTempSpec(t, `openapi: 3.1.0 info: