diff --git a/bridge.go b/bridge.go index 9e1e201..eb2b5d0 100644 --- a/bridge.go +++ b/bridge.go @@ -362,6 +362,53 @@ func validateSchemaNode(value any, schema map[string]any, path string) error { } } + if err := validateSchemaCombinators(value, schema, path); err != nil { + return err + } + + return nil +} + +func validateSchemaCombinators(value any, schema map[string]any, path string) error { + if subschemas := schemaObjects(schema["allOf"]); len(subschemas) > 0 { + for _, subschema := range subschemas { + if err := validateSchemaNode(value, subschema, path); err != nil { + return err + } + } + } + + if subschemas := schemaObjects(schema["anyOf"]); len(subschemas) > 0 { + for _, subschema := range subschemas { + if err := validateSchemaNode(value, subschema, path); err == nil { + goto anyOfMatched + } + } + return fmt.Errorf("%s must match at least one schema in anyOf", displayPath(path)) + } + +anyOfMatched: + if subschemas := schemaObjects(schema["oneOf"]); len(subschemas) > 0 { + matches := 0 + for _, subschema := range subschemas { + if err := validateSchemaNode(value, subschema, path); err == nil { + matches++ + } + } + if matches != 1 { + if matches == 0 { + return fmt.Errorf("%s must match exactly one schema in oneOf", displayPath(path)) + } + return fmt.Errorf("%s matches multiple schemas in oneOf", displayPath(path)) + } + } + + if subschema, ok := schema["not"].(map[string]any); ok && subschema != nil { + if err := validateSchemaNode(value, subschema, path); err == nil { + return fmt.Errorf("%s must not match the forbidden schema", displayPath(path)) + } + } + return nil } @@ -624,6 +671,23 @@ func schemaMap(value any) map[string]any { return m } +func schemaObjects(value any) []map[string]any { + switch raw := value.(type) { + case []any: + out := make([]map[string]any, 0, len(raw)) + for _, item := range raw { + if schema := schemaMap(item); schema != nil { + out = append(out, schema) + } + } + return out + case []map[string]any: + return append([]map[string]any(nil), raw...) + default: + return nil + } +} + func stringList(value any) []string { switch raw := value.(type) { case []any: diff --git a/bridge_test.go b/bridge_test.go index baa8e26..706f7eb 100644 --- a/bridge_test.go +++ b/bridge_test.go @@ -408,6 +408,109 @@ func TestToolBridge_Bad_RejectsInvalidEnumValues(t *testing.T) { } } +func TestToolBridge_Good_ValidatesSchemaCombinators(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "route_choice", + Description: "Choose a route", + Group: "items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "choice": map[string]any{ + "oneOf": []any{ + map[string]any{ + "type": "string", + "allOf": []any{ + map[string]any{"minLength": 2}, + map[string]any{"pattern": "^[A-Z]+$"}, + }, + }, + map[string]any{ + "type": "string", + "pattern": "^A", + }, + }, + }, + }, + "required": []any{"choice"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("accepted")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/route_choice", bytes.NewBufferString(`{"choice":"BC"}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } +} + +func TestToolBridge_Bad_RejectsAmbiguousOneOfMatches(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "route_choice", + Description: "Choose a route", + Group: "items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "choice": map[string]any{ + "oneOf": []any{ + map[string]any{ + "type": "string", + "allOf": []any{ + map[string]any{"minLength": 1}, + map[string]any{"pattern": "^[A-Z]+$"}, + }, + }, + map[string]any{ + "type": "string", + "pattern": "^A", + }, + }, + }, + }, + "required": []any{"choice"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("accepted")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/route_choice", bytes.NewBufferString(`{"choice":"A"}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false") + } + if resp.Error == nil || resp.Error.Code != "invalid_request_body" { + t.Fatalf("expected invalid_request_body error, got %#v", resp.Error) + } +} + func TestToolBridge_Bad_RejectsAdditionalProperties(t *testing.T) { gin.SetMode(gin.TestMode) engine := gin.New()