From ebad4c397d76fa60c26e5c86659595f9c36f9640 Mon Sep 17 00:00:00 2001 From: Virgil Date: Wed, 1 Apr 2026 19:50:41 +0000 Subject: [PATCH] feat(client): support header and cookie parameters Co-Authored-By: Virgil --- client.go | 174 +++++++++++++++++++++++++++++++++++++++++++++++-- client_test.go | 97 +++++++++++++++++++++++++++ 2 files changed, 266 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 26219d1..0eda237 100644 --- a/client.go +++ b/client.go @@ -36,10 +36,16 @@ type openAPIOperation struct { method string pathTemplate string hasRequestBody bool + parameters []openAPIParameter requestSchema map[string]any responseSchema map[string]any } +type openAPIParameter struct { + name string + in string +} + // OpenAPIClientOption configures a runtime OpenAPI client. type OpenAPIClientOption func(*OpenAPIClient) @@ -88,9 +94,9 @@ func NewOpenAPIClient(opts ...OpenAPIClientOption) *OpenAPIClient { // Call invokes the operation with the given operationId. // // The params argument may be a map, struct, or nil. For convenience, a map may -// include "path", "query", and "body" keys to explicitly control where the -// values are sent. When no explicit body is provided, requests with a declared -// requestBody send the remaining parameters as JSON. +// include "path", "query", "header", "cookie", and "body" keys to explicitly +// control where the values are sent. When no explicit body is provided, +// requests with a declared requestBody send the remaining parameters as JSON. func (c *OpenAPIClient) Call(operationID string, params any) (any, error) { if err := c.load(); err != nil { return nil, err @@ -140,6 +146,7 @@ func (c *OpenAPIClient) Call(operationID string, params any) (any, error) { if c.bearerToken != "" { req.Header.Set("Authorization", "Bearer "+c.bearerToken) } + applyRequestParameters(req, op, merged) resp, err := c.httpClient.Do(req) if err != nil { @@ -228,10 +235,12 @@ func (c *OpenAPIClient) loadSpec() error { if operationID == "" { continue } + params := parseOperationParameters(operation) operations[operationID] = openAPIOperation{ method: strings.ToUpper(method), pathTemplate: pathTemplate, hasRequestBody: operation["requestBody"] != nil, + parameters: params, requestSchema: requestBodySchema(operation), responseSchema: firstSuccessResponseSchema(operation), } @@ -303,12 +312,15 @@ func (c *OpenAPIClient) buildURL(op openAPIOperation, params map[string]any) (st } if op.method == http.MethodGet || (op.method == http.MethodHead && !op.hasRequestBody) { for key, value := range params { - if key == "path" || key == "body" || key == "query" { + if key == "path" || key == "body" || key == "query" || key == "header" || key == "cookie" { continue } if containsString(pathKeys, key) { continue } + if operationParameterLocation(op, key) == "header" || operationParameterLocation(op, key) == "cookie" { + continue + } if _, exists := query[key]; exists { continue } @@ -346,12 +358,16 @@ func (c *OpenAPIClient) buildBody(op openAPIOperation, params map[string]any) ([ payload := make(map[string]any, len(params)) for key, value := range params { - if key == "path" || key == "query" || key == "body" { + if key == "path" || key == "query" || key == "body" || key == "header" || key == "cookie" { continue } if containsString(pathKeys, key) { continue } + switch operationParameterLocation(op, key) { + case "header", "cookie", "query": + continue + } if _, exists := queryKeys[key]; exists { continue } @@ -363,6 +379,154 @@ func (c *OpenAPIClient) buildBody(op openAPIOperation, params map[string]any) ([ return encodeJSONBody(payload) } +func applyRequestParameters(req *http.Request, op openAPIOperation, params map[string]any) { + explicitHeaders, hasExplicitHeaders := nestedMap(params, "header") + explicitCookies, hasExplicitCookies := nestedMap(params, "cookie") + + if hasExplicitHeaders { + applyHeaderValues(req.Header, explicitHeaders) + } + + applyTopLevelHeaderParameters(req.Header, op, params, explicitHeaders, hasExplicitHeaders) + + if hasExplicitCookies { + applyCookieValues(req, explicitCookies) + } + applyTopLevelCookieParameters(req, op, params, explicitCookies, hasExplicitCookies) +} + +func applyTopLevelHeaderParameters(headers http.Header, op openAPIOperation, params, explicit map[string]any, hasExplicit bool) { + for key, value := range params { + if key == "path" || key == "query" || key == "body" || key == "header" || key == "cookie" { + continue + } + if operationParameterLocation(op, key) != "header" { + continue + } + if hasExplicit { + if _, ok := explicit[key]; ok { + continue + } + } + applyHeaderValue(headers, key, value) + } +} + +func applyTopLevelCookieParameters(req *http.Request, op openAPIOperation, params, explicit map[string]any, hasExplicit bool) { + for key, value := range params { + if key == "path" || key == "query" || key == "body" || key == "header" || key == "cookie" { + continue + } + if operationParameterLocation(op, key) != "cookie" { + continue + } + if hasExplicit { + if _, ok := explicit[key]; ok { + continue + } + } + applyCookieValue(req, key, value) + } +} + +func applyHeaderValues(headers http.Header, values map[string]any) { + for key, value := range values { + applyHeaderValue(headers, key, value) + } +} + +func applyHeaderValue(headers http.Header, key string, value any) { + switch v := value.(type) { + case nil: + return + case []string: + for _, item := range v { + headers.Add(key, item) + } + return + case []any: + for _, item := range v { + headers.Add(key, fmt.Sprint(item)) + } + return + } + + rv := reflect.ValueOf(value) + if rv.IsValid() && (rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array) && !(rv.Type().Elem().Kind() == reflect.Uint8) { + for i := 0; i < rv.Len(); i++ { + headers.Add(key, fmt.Sprint(rv.Index(i).Interface())) + } + return + } + + headers.Set(key, fmt.Sprint(value)) +} + +func applyCookieValues(req *http.Request, values map[string]any) { + for key, value := range values { + applyCookieValue(req, key, value) + } +} + +func applyCookieValue(req *http.Request, key string, value any) { + switch v := value.(type) { + case nil: + return + case []string: + for _, item := range v { + req.AddCookie(&http.Cookie{Name: key, Value: item}) + } + return + case []any: + for _, item := range v { + req.AddCookie(&http.Cookie{Name: key, Value: fmt.Sprint(item)}) + } + return + } + + rv := reflect.ValueOf(value) + if rv.IsValid() && (rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array) && !(rv.Type().Elem().Kind() == reflect.Uint8) { + for i := 0; i < rv.Len(); i++ { + req.AddCookie(&http.Cookie{Name: key, Value: fmt.Sprint(rv.Index(i).Interface())}) + } + return + } + + req.AddCookie(&http.Cookie{Name: key, Value: fmt.Sprint(value)}) +} + +func parseOperationParameters(operation map[string]any) []openAPIParameter { + rawParams, ok := operation["parameters"].([]any) + if !ok { + return nil + } + + params := make([]openAPIParameter, 0, len(rawParams)) + for _, rawParam := range rawParams { + param, ok := rawParam.(map[string]any) + if !ok { + continue + } + name, _ := param["name"].(string) + in, _ := param["in"].(string) + if name == "" || in == "" { + continue + } + params = append(params, openAPIParameter{name: name, in: in}) + } + + return params +} + +func operationParameterLocation(op openAPIOperation, name string) string { + for _, param := range op.parameters { + if param.name == name { + return param.in + } + } + return "" +} + func encodeJSONBody(v any) ([]byte, error) { data, err := json.Marshal(v) if err != nil { diff --git a/client_test.go b/client_test.go index 4cc3a7a..ecbdda2 100644 --- a/client_test.go +++ b/client_test.go @@ -257,6 +257,103 @@ paths: } } +func TestOpenAPIClient_Good_UsesHeaderAndCookieParameters(t *testing.T) { + errCh := make(chan error, 1) + mux := http.NewServeMux() + mux.HandleFunc("/inspect", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("expected GET, got %s", r.Method) + w.WriteHeader(http.StatusInternalServerError) + return + } + if got := r.Header.Get("X-Trace-ID"); got != "trace-123" { + errCh <- fmt.Errorf("expected X-Trace-ID=trace-123, got %q", got) + w.WriteHeader(http.StatusInternalServerError) + return + } + if got := r.Header.Get("X-Custom-Header"); got != "custom-value" { + errCh <- fmt.Errorf("expected X-Custom-Header=custom-value, got %q", got) + w.WriteHeader(http.StatusInternalServerError) + return + } + session, err := r.Cookie("session_id") + if err != nil { + errCh <- fmt.Errorf("expected session_id cookie: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if session.Value != "cookie-123" { + errCh <- fmt.Errorf("expected session_id=cookie-123, got %q", session.Value) + w.WriteHeader(http.StatusInternalServerError) + return + } + pref, err := r.Cookie("pref") + if err != nil { + errCh <- fmt.Errorf("expected pref cookie: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if pref.Value != "dark" { + errCh <- fmt.Errorf("expected pref=dark, got %q", pref.Value) + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"ok":true}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /inspect: + get: + operationId: inspect_request + parameters: + - name: X-Trace-ID + in: header + - name: session_id + in: cookie +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL(srv.URL), + ) + + result, err := client.Call("inspect_request", map[string]any{ + "X-Trace-ID": "trace-123", + "session_id": "cookie-123", + "header": map[string]any{ + "X-Custom-Header": "custom-value", + }, + "cookie": map[string]any{ + "pref": "dark", + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } + + decoded, ok := result.(map[string]any) + if !ok { + t.Fatalf("expected map result, got %T", result) + } + if okValue, ok := decoded["ok"].(bool); !ok || !okValue { + t.Fatalf("expected ok=true, got %#v", decoded["ok"]) + } +} + func TestOpenAPIClient_Good_UsesFirstAbsoluteServer(t *testing.T) { errCh := make(chan error, 1) mux := http.NewServeMux()