feat(bridge): support schema composition keywords

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-01 18:19:15 +00:00
parent 6e878778dc
commit 5da281c431
2 changed files with 167 additions and 0 deletions

View file

@ -362,6 +362,53 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
}
}
if err := validateSchemaCombinators(value, schema, path); err != nil {
return err
}
return nil
}
func validateSchemaCombinators(value any, schema map[string]any, path string) error {
if subschemas := schemaObjects(schema["allOf"]); len(subschemas) > 0 {
for _, subschema := range subschemas {
if err := validateSchemaNode(value, subschema, path); err != nil {
return err
}
}
}
if subschemas := schemaObjects(schema["anyOf"]); len(subschemas) > 0 {
for _, subschema := range subschemas {
if err := validateSchemaNode(value, subschema, path); err == nil {
goto anyOfMatched
}
}
return fmt.Errorf("%s must match at least one schema in anyOf", displayPath(path))
}
anyOfMatched:
if subschemas := schemaObjects(schema["oneOf"]); len(subschemas) > 0 {
matches := 0
for _, subschema := range subschemas {
if err := validateSchemaNode(value, subschema, path); err == nil {
matches++
}
}
if matches != 1 {
if matches == 0 {
return fmt.Errorf("%s must match exactly one schema in oneOf", displayPath(path))
}
return fmt.Errorf("%s matches multiple schemas in oneOf", displayPath(path))
}
}
if subschema, ok := schema["not"].(map[string]any); ok && subschema != nil {
if err := validateSchemaNode(value, subschema, path); err == nil {
return fmt.Errorf("%s must not match the forbidden schema", displayPath(path))
}
}
return nil
}
@ -624,6 +671,23 @@ func schemaMap(value any) map[string]any {
return m
}
func schemaObjects(value any) []map[string]any {
switch raw := value.(type) {
case []any:
out := make([]map[string]any, 0, len(raw))
for _, item := range raw {
if schema := schemaMap(item); schema != nil {
out = append(out, schema)
}
}
return out
case []map[string]any:
return append([]map[string]any(nil), raw...)
default:
return nil
}
}
func stringList(value any) []string {
switch raw := value.(type) {
case []any:

View file

@ -408,6 +408,109 @@ func TestToolBridge_Bad_RejectsInvalidEnumValues(t *testing.T) {
}
}
func TestToolBridge_Good_ValidatesSchemaCombinators(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
bridge := api.NewToolBridge("/tools")
bridge.Add(api.ToolDescriptor{
Name: "route_choice",
Description: "Choose a route",
Group: "items",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"choice": map[string]any{
"oneOf": []any{
map[string]any{
"type": "string",
"allOf": []any{
map[string]any{"minLength": 2},
map[string]any{"pattern": "^[A-Z]+$"},
},
},
map[string]any{
"type": "string",
"pattern": "^A",
},
},
},
},
"required": []any{"choice"},
},
}, func(c *gin.Context) {
c.JSON(http.StatusOK, api.OK("accepted"))
})
rg := engine.Group(bridge.BasePath())
bridge.RegisterRoutes(rg)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/tools/route_choice", bytes.NewBufferString(`{"choice":"BC"}`))
engine.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
}
func TestToolBridge_Bad_RejectsAmbiguousOneOfMatches(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
bridge := api.NewToolBridge("/tools")
bridge.Add(api.ToolDescriptor{
Name: "route_choice",
Description: "Choose a route",
Group: "items",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"choice": map[string]any{
"oneOf": []any{
map[string]any{
"type": "string",
"allOf": []any{
map[string]any{"minLength": 1},
map[string]any{"pattern": "^[A-Z]+$"},
},
},
map[string]any{
"type": "string",
"pattern": "^A",
},
},
},
},
"required": []any{"choice"},
},
}, func(c *gin.Context) {
c.JSON(http.StatusOK, api.OK("accepted"))
})
rg := engine.Group(bridge.BasePath())
bridge.RegisterRoutes(rg)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/tools/route_choice", bytes.NewBufferString(`{"choice":"A"}`))
engine.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
var resp api.Response[any]
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if resp.Success {
t.Fatal("expected Success=false")
}
if resp.Error == nil || resp.Error.Code != "invalid_request_body" {
t.Fatalf("expected invalid_request_body error, got %#v", resp.Error)
}
}
func TestToolBridge_Bad_RejectsAdditionalProperties(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()