From 28f9540fa8add733e69f8a9e5c0ede93e407261c Mon Sep 17 00:00:00 2001 From: Virgil Date: Wed, 1 Apr 2026 15:54:32 +0000 Subject: [PATCH] fix(bridge): enforce tool schema enum validation Co-Authored-By: Virgil --- bridge.go | 13 ++++--------- bridge_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/bridge.go b/bridge.go index 067dd88..97ecf9b 100644 --- a/bridge.go +++ b/bridge.go @@ -257,7 +257,8 @@ func validateSchemaNode(value any, schema map[string]any, path string) error { return nil } - if schemaType, _ := schema["type"].(string); schemaType != "" { + schemaType, _ := schema["type"].(string) + if schemaType != "" { switch schemaType { case "object": obj, ok := value.(map[string]any) @@ -296,8 +297,6 @@ func validateSchemaNode(value any, schema map[string]any, path string) error { return fmt.Errorf("%s contains unknown field %q", displayPath(path), name) } } - - return nil case "array": arr, ok := value.([]any) if !ok { @@ -310,31 +309,27 @@ func validateSchemaNode(value any, schema map[string]any, path string) error { } } } - return nil case "string": if _, ok := value.(string); !ok { return typeError(path, "string", value) } - return nil case "boolean": if _, ok := value.(bool); !ok { return typeError(path, "boolean", value) } - return nil case "integer": if !isIntegerValue(value) { return typeError(path, "integer", value) } - return nil case "number": if !isNumberValue(value) { return typeError(path, "number", value) } - return nil } } - if props := schemaMap(schema["properties"]); len(props) > 0 || schema["required"] != nil || schema["additionalProperties"] != nil { + if schemaType == "" && (len(schemaMap(schema["properties"])) > 0 || schema["required"] != nil || schema["additionalProperties"] != nil) { + props := schemaMap(schema["properties"]) return validateSchemaNode(value, map[string]any{ "type": "object", "properties": props, diff --git a/bridge_test.go b/bridge_test.go index ba2e29b..227b3da 100644 --- a/bridge_test.go +++ b/bridge_test.go @@ -362,6 +362,52 @@ func TestToolBridge_Good_ValidatesEnumValues(t *testing.T) { } } +func TestToolBridge_Bad_RejectsInvalidEnumValues(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "publish_item", + Description: "Publish an item", + Group: "items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "status": map[string]any{ + "type": "string", + "enum": []any{"draft", "published"}, + }, + }, + "required": []any{"status"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("published")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/publish_item", bytes.NewBufferString(`{"status":"archived"}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false") + } + if resp.Error == nil || resp.Error.Code != "invalid_request_body" { + t.Fatalf("expected invalid_request_body error, got %#v", resp.Error) + } +} + func TestToolBridge_Bad_RejectsAdditionalProperties(t *testing.T) { gin.SetMode(gin.TestMode) engine := gin.New()