diff --git a/client.go b/client.go index c436e12..7eabeed 100644 --- a/client.go +++ b/client.go @@ -10,6 +10,7 @@ import ( "net/http" "net/url" "os" + "reflect" "strings" "sync" @@ -270,7 +271,7 @@ func (c *OpenAPIClient) buildURL(op openAPIOperation, params map[string]any) (st query := url.Values{} if explicitQuery, ok := nestedMap(params, "query"); ok { for key, value := range explicitQuery { - query.Set(key, fmt.Sprint(value)) + appendQueryValue(query, key, value) } } if op.method == http.MethodGet || (op.method == http.MethodHead && !op.hasRequestBody) { @@ -284,7 +285,7 @@ func (c *OpenAPIClient) buildURL(op openAPIOperation, params map[string]any) (st if _, exists := query[key]; exists { continue } - query.Set(key, fmt.Sprint(value)) + appendQueryValue(query, key, value) } } @@ -413,3 +414,42 @@ func containsString(values []string, target string) bool { } return false } + +func appendQueryValue(query url.Values, key string, value any) { + switch v := value.(type) { + case nil: + return + case []byte: + query.Add(key, string(v)) + return + case []string: + for _, item := range v { + query.Add(key, item) + } + return + case []any: + for _, item := range v { + appendQueryValue(query, key, item) + } + return + } + + rv := reflect.ValueOf(value) + if !rv.IsValid() { + return + } + + switch rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Type().Elem().Kind() == reflect.Uint8 { + query.Add(key, string(rv.Bytes())) + return + } + for i := 0; i < rv.Len(); i++ { + appendQueryValue(query, key, rv.Index(i).Interface()) + } + return + } + + query.Add(key, fmt.Sprint(value)) +} diff --git a/client_test.go b/client_test.go index ebe7e12..884f43c 100644 --- a/client_test.go +++ b/client_test.go @@ -194,6 +194,69 @@ paths: } } +func TestOpenAPIClient_Good_CallOperationWithRepeatedQueryValues(t *testing.T) { + errCh := make(chan error, 1) + mux := http.NewServeMux() + mux.HandleFunc("/search", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("expected GET, got %s", r.Method) + w.WriteHeader(http.StatusInternalServerError) + return + } + if got := r.URL.Query()["tag"]; len(got) != 2 || got[0] != "alpha" || got[1] != "beta" { + errCh <- fmt.Errorf("expected repeated tag values [alpha beta], got %v", got) + w.WriteHeader(http.StatusInternalServerError) + return + } + if got := r.URL.Query().Get("page"); got != "2" { + errCh <- fmt.Errorf("expected page=2, got %q", got) + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"ok":true}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /search: + get: + operationId: search_items +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL(srv.URL), + ) + + result, err := client.Call("search_items", map[string]any{ + "tag": []string{"alpha", "beta"}, + "page": 2, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + select { + case err := <-errCh: + t.Fatal(err) + default: + } + + decoded, ok := result.(map[string]any) + if !ok { + t.Fatalf("expected map result, got %T", result) + } + if okValue, ok := decoded["ok"].(bool); !ok || !okValue { + t.Fatalf("expected ok=true, got %#v", decoded["ok"]) + } +} + func TestOpenAPIClient_Bad_MissingOperation(t *testing.T) { specPath := writeTempSpec(t, `openapi: 3.1.0 info: