feat(api): validate openapi client requests and responses
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
1ec5bf4062
commit
f6349145bc
2 changed files with 228 additions and 5 deletions
121
client.go
121
client.go
|
|
@ -36,6 +36,8 @@ type openAPIOperation struct {
|
|||
method string
|
||||
pathTemplate string
|
||||
hasRequestBody bool
|
||||
requestSchema map[string]any
|
||||
responseSchema map[string]any
|
||||
}
|
||||
|
||||
// OpenAPIClientOption configures a runtime OpenAPI client.
|
||||
|
|
@ -117,11 +119,22 @@ func (c *OpenAPIClient) Call(operationID string, params any) (any, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(op.method, requestURL, body)
|
||||
if op.requestSchema != nil && len(body) > 0 {
|
||||
if err := validateOpenAPISchema(body, op.requestSchema, "request body"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var bodyReader io.Reader
|
||||
if len(body) > 0 {
|
||||
bodyReader = bytes.NewReader(body)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(op.method, requestURL, bodyReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if body != nil {
|
||||
if bodyReader != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
if c.bearerToken != "" {
|
||||
|
|
@ -143,6 +156,12 @@ func (c *OpenAPIClient) Call(operationID string, params any) (any, error) {
|
|||
return nil, fmt.Errorf("openapi call %s returned %s: %s", operationID, resp.Status, strings.TrimSpace(string(payload)))
|
||||
}
|
||||
|
||||
if op.responseSchema != nil && len(bytes.TrimSpace(payload)) > 0 {
|
||||
if err := validateOpenAPIResponse(payload, op.responseSchema, operationID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(bytes.TrimSpace(payload)) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
@ -213,6 +232,8 @@ func (c *OpenAPIClient) loadSpec() error {
|
|||
method: strings.ToUpper(method),
|
||||
pathTemplate: pathTemplate,
|
||||
hasRequestBody: operation["requestBody"] != nil,
|
||||
requestSchema: requestBodySchema(operation),
|
||||
responseSchema: firstSuccessResponseSchema(operation),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -302,7 +323,7 @@ func (c *OpenAPIClient) buildURL(op openAPIOperation, params map[string]any) (st
|
|||
return fullURL, nil
|
||||
}
|
||||
|
||||
func (c *OpenAPIClient) buildBody(op openAPIOperation, params map[string]any) (io.Reader, error) {
|
||||
func (c *OpenAPIClient) buildBody(op openAPIOperation, params map[string]any) ([]byte, error) {
|
||||
if explicitBody, ok := params["body"]; ok {
|
||||
return encodeJSONBody(explicitBody)
|
||||
}
|
||||
|
|
@ -342,12 +363,12 @@ func (c *OpenAPIClient) buildBody(op openAPIOperation, params map[string]any) (i
|
|||
return encodeJSONBody(payload)
|
||||
}
|
||||
|
||||
func encodeJSONBody(v any) (io.Reader, error) {
|
||||
func encodeJSONBody(v any) ([]byte, error) {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func normaliseParams(params any) (map[string]any, error) {
|
||||
|
|
@ -467,3 +488,93 @@ func isAbsoluteBaseURL(raw string) bool {
|
|||
}
|
||||
return u.Scheme != "" && u.Host != ""
|
||||
}
|
||||
|
||||
func requestBodySchema(operation map[string]any) map[string]any {
|
||||
rawRequestBody, ok := operation["requestBody"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
content, ok := rawRequestBody["content"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
rawJSON, ok := content["application/json"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
schema, _ := rawJSON["schema"].(map[string]any)
|
||||
return schema
|
||||
}
|
||||
|
||||
func firstSuccessResponseSchema(operation map[string]any) map[string]any {
|
||||
responses, ok := operation["responses"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, code := range []string{"200", "201", "202", "203", "204", "205", "206", "207", "208", "226"} {
|
||||
rawResp, ok := responses[code].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := rawResp["content"].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
rawJSON, ok := content["application/json"].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
schema, _ := rawJSON["schema"].(map[string]any)
|
||||
if len(schema) > 0 {
|
||||
return schema
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateOpenAPISchema(body []byte, schema map[string]any, label string) error {
|
||||
if len(bytes.TrimSpace(body)) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var payload any
|
||||
dec := json.NewDecoder(bytes.NewReader(body))
|
||||
dec.UseNumber()
|
||||
if err := dec.Decode(&payload); err != nil {
|
||||
return fmt.Errorf("validate %s: invalid JSON: %w", label, err)
|
||||
}
|
||||
var extra any
|
||||
if err := dec.Decode(&extra); err != io.EOF {
|
||||
return fmt.Errorf("validate %s: expected a single JSON value", label)
|
||||
}
|
||||
|
||||
if err := validateSchemaNode(payload, schema, ""); err != nil {
|
||||
return fmt.Errorf("validate %s: %w", label, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateOpenAPIResponse(payload []byte, schema map[string]any, operationID string) error {
|
||||
var decoded any
|
||||
dec := json.NewDecoder(bytes.NewReader(payload))
|
||||
dec.UseNumber()
|
||||
if err := dec.Decode(&decoded); err != nil {
|
||||
return fmt.Errorf("openapi call %s returned invalid JSON: %w", operationID, err)
|
||||
}
|
||||
var extra any
|
||||
if err := dec.Decode(&extra); err != io.EOF {
|
||||
return fmt.Errorf("openapi call %s returned multiple JSON values", operationID)
|
||||
}
|
||||
|
||||
if err := validateSchemaNode(decoded, schema, ""); err != nil {
|
||||
return fmt.Errorf("openapi call %s response does not match spec: %w", operationID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
112
client_test.go
112
client_test.go
|
|
@ -310,6 +310,118 @@ paths:
|
|||
}
|
||||
}
|
||||
|
||||
func TestOpenAPIClient_Bad_ValidatesRequestBodyAgainstSchema(t *testing.T) {
|
||||
called := make(chan struct{}, 1)
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/users", func(w http.ResponseWriter, r *http.Request) {
|
||||
called <- struct{}{}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"success":true,"data":{"id":"123"}}`))
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
specPath := writeTempSpec(t, `openapi: 3.1.0
|
||||
info:
|
||||
title: Test API
|
||||
version: 1.0.0
|
||||
paths:
|
||||
/users:
|
||||
post:
|
||||
operationId: create_user
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required: [name]
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
success:
|
||||
type: boolean
|
||||
data:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
`)
|
||||
|
||||
client := api.NewOpenAPIClient(
|
||||
api.WithSpec(specPath),
|
||||
api.WithBaseURL(srv.URL),
|
||||
)
|
||||
|
||||
if _, err := client.Call("create_user", map[string]any{
|
||||
"body": map[string]any{},
|
||||
}); err == nil {
|
||||
t.Fatal("expected request body validation error, got nil")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-called:
|
||||
t.Fatal("expected request validation to fail before the HTTP call")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAPIClient_Bad_ValidatesResponseAgainstSchema(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/users", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"success":true,"data":{"id":123}}`))
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
specPath := writeTempSpec(t, `openapi: 3.1.0
|
||||
info:
|
||||
title: Test API
|
||||
version: 1.0.0
|
||||
paths:
|
||||
/users:
|
||||
get:
|
||||
operationId: list_users
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required: [success, data]
|
||||
properties:
|
||||
success:
|
||||
type: boolean
|
||||
data:
|
||||
type: object
|
||||
required: [id]
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
`)
|
||||
|
||||
client := api.NewOpenAPIClient(
|
||||
api.WithSpec(specPath),
|
||||
api.WithBaseURL(srv.URL),
|
||||
)
|
||||
|
||||
if _, err := client.Call("list_users", nil); err == nil {
|
||||
t.Fatal("expected response validation error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAPIClient_Bad_MissingOperation(t *testing.T) {
|
||||
specPath := writeTempSpec(t, `openapi: 3.1.0
|
||||
info:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue