feat(api): validate openapi client requests and responses

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-01 17:48:49 +00:00
parent 1ec5bf4062
commit f6349145bc
2 changed files with 228 additions and 5 deletions

121
client.go
View file

@ -36,6 +36,8 @@ type openAPIOperation struct {
method string
pathTemplate string
hasRequestBody bool
requestSchema map[string]any
responseSchema map[string]any
}
// OpenAPIClientOption configures a runtime OpenAPI client.
@ -117,11 +119,22 @@ func (c *OpenAPIClient) Call(operationID string, params any) (any, error) {
return nil, err
}
req, err := http.NewRequest(op.method, requestURL, body)
if op.requestSchema != nil && len(body) > 0 {
if err := validateOpenAPISchema(body, op.requestSchema, "request body"); err != nil {
return nil, err
}
}
var bodyReader io.Reader
if len(body) > 0 {
bodyReader = bytes.NewReader(body)
}
req, err := http.NewRequest(op.method, requestURL, bodyReader)
if err != nil {
return nil, err
}
if body != nil {
if bodyReader != nil {
req.Header.Set("Content-Type", "application/json")
}
if c.bearerToken != "" {
@ -143,6 +156,12 @@ func (c *OpenAPIClient) Call(operationID string, params any) (any, error) {
return nil, fmt.Errorf("openapi call %s returned %s: %s", operationID, resp.Status, strings.TrimSpace(string(payload)))
}
if op.responseSchema != nil && len(bytes.TrimSpace(payload)) > 0 {
if err := validateOpenAPIResponse(payload, op.responseSchema, operationID); err != nil {
return nil, err
}
}
if len(bytes.TrimSpace(payload)) == 0 {
return nil, nil
}
@ -213,6 +232,8 @@ func (c *OpenAPIClient) loadSpec() error {
method: strings.ToUpper(method),
pathTemplate: pathTemplate,
hasRequestBody: operation["requestBody"] != nil,
requestSchema: requestBodySchema(operation),
responseSchema: firstSuccessResponseSchema(operation),
}
}
}
@ -302,7 +323,7 @@ func (c *OpenAPIClient) buildURL(op openAPIOperation, params map[string]any) (st
return fullURL, nil
}
func (c *OpenAPIClient) buildBody(op openAPIOperation, params map[string]any) (io.Reader, error) {
func (c *OpenAPIClient) buildBody(op openAPIOperation, params map[string]any) ([]byte, error) {
if explicitBody, ok := params["body"]; ok {
return encodeJSONBody(explicitBody)
}
@ -342,12 +363,12 @@ func (c *OpenAPIClient) buildBody(op openAPIOperation, params map[string]any) (i
return encodeJSONBody(payload)
}
func encodeJSONBody(v any) (io.Reader, error) {
func encodeJSONBody(v any) ([]byte, error) {
data, err := json.Marshal(v)
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
return data, nil
}
func normaliseParams(params any) (map[string]any, error) {
@ -467,3 +488,93 @@ func isAbsoluteBaseURL(raw string) bool {
}
return u.Scheme != "" && u.Host != ""
}
func requestBodySchema(operation map[string]any) map[string]any {
rawRequestBody, ok := operation["requestBody"].(map[string]any)
if !ok {
return nil
}
content, ok := rawRequestBody["content"].(map[string]any)
if !ok {
return nil
}
rawJSON, ok := content["application/json"].(map[string]any)
if !ok {
return nil
}
schema, _ := rawJSON["schema"].(map[string]any)
return schema
}
func firstSuccessResponseSchema(operation map[string]any) map[string]any {
responses, ok := operation["responses"].(map[string]any)
if !ok {
return nil
}
for _, code := range []string{"200", "201", "202", "203", "204", "205", "206", "207", "208", "226"} {
rawResp, ok := responses[code].(map[string]any)
if !ok {
continue
}
content, ok := rawResp["content"].(map[string]any)
if !ok {
continue
}
rawJSON, ok := content["application/json"].(map[string]any)
if !ok {
continue
}
schema, _ := rawJSON["schema"].(map[string]any)
if len(schema) > 0 {
return schema
}
}
return nil
}
func validateOpenAPISchema(body []byte, schema map[string]any, label string) error {
if len(bytes.TrimSpace(body)) == 0 {
return nil
}
var payload any
dec := json.NewDecoder(bytes.NewReader(body))
dec.UseNumber()
if err := dec.Decode(&payload); err != nil {
return fmt.Errorf("validate %s: invalid JSON: %w", label, err)
}
var extra any
if err := dec.Decode(&extra); err != io.EOF {
return fmt.Errorf("validate %s: expected a single JSON value", label)
}
if err := validateSchemaNode(payload, schema, ""); err != nil {
return fmt.Errorf("validate %s: %w", label, err)
}
return nil
}
func validateOpenAPIResponse(payload []byte, schema map[string]any, operationID string) error {
var decoded any
dec := json.NewDecoder(bytes.NewReader(payload))
dec.UseNumber()
if err := dec.Decode(&decoded); err != nil {
return fmt.Errorf("openapi call %s returned invalid JSON: %w", operationID, err)
}
var extra any
if err := dec.Decode(&extra); err != io.EOF {
return fmt.Errorf("openapi call %s returned multiple JSON values", operationID)
}
if err := validateSchemaNode(decoded, schema, ""); err != nil {
return fmt.Errorf("openapi call %s response does not match spec: %w", operationID, err)
}
return nil
}

View file

@ -310,6 +310,118 @@ paths:
}
}
func TestOpenAPIClient_Bad_ValidatesRequestBodyAgainstSchema(t *testing.T) {
called := make(chan struct{}, 1)
mux := http.NewServeMux()
mux.HandleFunc("/users", func(w http.ResponseWriter, r *http.Request) {
called <- struct{}{}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"success":true,"data":{"id":"123"}}`))
})
srv := httptest.NewServer(mux)
defer srv.Close()
specPath := writeTempSpec(t, `openapi: 3.1.0
info:
title: Test API
version: 1.0.0
paths:
/users:
post:
operationId: create_user
requestBody:
required: true
content:
application/json:
schema:
type: object
required: [name]
properties:
name:
type: string
responses:
"200":
description: OK
content:
application/json:
schema:
type: object
properties:
success:
type: boolean
data:
type: object
properties:
id:
type: string
`)
client := api.NewOpenAPIClient(
api.WithSpec(specPath),
api.WithBaseURL(srv.URL),
)
if _, err := client.Call("create_user", map[string]any{
"body": map[string]any{},
}); err == nil {
t.Fatal("expected request body validation error, got nil")
}
select {
case <-called:
t.Fatal("expected request validation to fail before the HTTP call")
default:
}
}
func TestOpenAPIClient_Bad_ValidatesResponseAgainstSchema(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/users", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"success":true,"data":{"id":123}}`))
})
srv := httptest.NewServer(mux)
defer srv.Close()
specPath := writeTempSpec(t, `openapi: 3.1.0
info:
title: Test API
version: 1.0.0
paths:
/users:
get:
operationId: list_users
responses:
"200":
description: OK
content:
application/json:
schema:
type: object
required: [success, data]
properties:
success:
type: boolean
data:
type: object
required: [id]
properties:
id:
type: string
`)
client := api.NewOpenAPIClient(
api.WithSpec(specPath),
api.WithBaseURL(srv.URL),
)
if _, err := client.Call("list_users", nil); err == nil {
t.Fatal("expected response validation error, got nil")
}
}
func TestOpenAPIClient_Bad_MissingOperation(t *testing.T) {
specPath := writeTempSpec(t, `openapi: 3.1.0
info: