From c6034031a310e8113d77ff4bafa29e7df9e87939 Mon Sep 17 00:00:00 2001 From: Virgil Date: Wed, 1 Apr 2026 16:50:29 +0000 Subject: [PATCH] feat(bridge): enforce additional schema constraints Co-Authored-By: Virgil --- bridge.go | 152 ++++++++++++++++++++++++++++++++++++++++++++++++- bridge_test.go | 98 +++++++++++++++++++++++++++++++ 2 files changed, 249 insertions(+), 1 deletion(-) diff --git a/bridge.go b/bridge.go index 97ecf9b..9e1e201 100644 --- a/bridge.go +++ b/bridge.go @@ -13,7 +13,9 @@ import ( "net" "net/http" "reflect" + "regexp" "strconv" + "unicode/utf8" "github.com/gin-gonic/gin" ) @@ -297,6 +299,9 @@ func validateSchemaNode(value any, schema map[string]any, path string) error { return fmt.Errorf("%s contains unknown field %q", displayPath(path), name) } } + if err := validateObjectConstraints(obj, schema, path); err != nil { + return err + } case "array": arr, ok := value.([]any) if !ok { @@ -309,10 +314,17 @@ func validateSchemaNode(value any, schema map[string]any, path string) error { } } } + if err := validateArrayConstraints(arr, schema, path); err != nil { + return err + } case "string": - if _, ok := value.(string); !ok { + str, ok := value.(string) + if !ok { return typeError(path, "string", value) } + if err := validateStringConstraints(str, schema, path); err != nil { + return err + } case "boolean": if _, ok := value.(bool); !ok { return typeError(path, "boolean", value) @@ -321,10 +333,16 @@ func validateSchemaNode(value any, schema map[string]any, path string) error { if !isIntegerValue(value) { return typeError(path, "integer", value) } + if err := validateNumericConstraints(value, schema, path); err != nil { + return err + } case "number": if !isNumberValue(value) { return typeError(path, "number", value) } + if err := validateNumericConstraints(value, schema, path); err != nil { + return err + } } } @@ -347,6 +365,138 @@ func validateSchemaNode(value any, schema map[string]any, path string) error { return nil } +func validateStringConstraints(value string, schema map[string]any, path string) error { + length := utf8.RuneCountInString(value) + if minLength, ok := schemaInt(schema["minLength"]); ok && length < minLength { + return fmt.Errorf("%s must be at least %d characters long", displayPath(path), minLength) + } + if maxLength, ok := schemaInt(schema["maxLength"]); ok && length > maxLength { + return fmt.Errorf("%s must be at most %d characters long", displayPath(path), maxLength) + } + if pattern, ok := schema["pattern"].(string); ok && pattern != "" { + re, err := regexp.Compile(pattern) + if err != nil { + return fmt.Errorf("%s has an invalid pattern %q: %w", displayPath(path), pattern, err) + } + if !re.MatchString(value) { + return fmt.Errorf("%s does not match pattern %q", displayPath(path), pattern) + } + } + return nil +} + +func validateNumericConstraints(value any, schema map[string]any, path string) error { + if minimum, ok := schemaFloat(schema["minimum"]); ok && numericLessThan(value, minimum) { + return fmt.Errorf("%s must be greater than or equal to %v", displayPath(path), minimum) + } + if maximum, ok := schemaFloat(schema["maximum"]); ok && numericGreaterThan(value, maximum) { + return fmt.Errorf("%s must be less than or equal to %v", displayPath(path), maximum) + } + return nil +} + +func validateArrayConstraints(value []any, schema map[string]any, path string) error { + if minItems, ok := schemaInt(schema["minItems"]); ok && len(value) < minItems { + return fmt.Errorf("%s must contain at least %d items", displayPath(path), minItems) + } + if maxItems, ok := schemaInt(schema["maxItems"]); ok && len(value) > maxItems { + return fmt.Errorf("%s must contain at most %d items", displayPath(path), maxItems) + } + return nil +} + +func validateObjectConstraints(value map[string]any, schema map[string]any, path string) error { + if minProps, ok := schemaInt(schema["minProperties"]); ok && len(value) < minProps { + return fmt.Errorf("%s must contain at least %d properties", displayPath(path), minProps) + } + if maxProps, ok := schemaInt(schema["maxProperties"]); ok && len(value) > maxProps { + return fmt.Errorf("%s must contain at most %d properties", displayPath(path), maxProps) + } + return nil +} + +func schemaInt(value any) (int, bool) { + switch v := value.(type) { + case int: + return v, true + case int8: + return int(v), true + case int16: + return int(v), true + case int32: + return int(v), true + case int64: + return int(v), true + case uint: + return int(v), true + case uint8: + return int(v), true + case uint16: + return int(v), true + case uint32: + return int(v), true + case uint64: + return int(v), true + case float64: + if v == float64(int(v)) { + return int(v), true + } + case json.Number: + if n, err := v.Int64(); err == nil { + return int(n), true + } + } + return 0, false +} + +func schemaFloat(value any) (float64, bool) { + switch v := value.(type) { + case float64: + return v, true + case float32: + return float64(v), true + case int: + return float64(v), true + case int8: + return float64(v), true + case int16: + return float64(v), true + case int32: + return float64(v), true + case int64: + return float64(v), true + case uint: + return float64(v), true + case uint8: + return float64(v), true + case uint16: + return float64(v), true + case uint32: + return float64(v), true + case uint64: + return float64(v), true + case json.Number: + if n, err := v.Float64(); err == nil { + return n, true + } + } + return 0, false +} + +func numericLessThan(value any, limit float64) bool { + if n, ok := numericValue(value); ok { + return n < limit + } + return false +} + +func numericGreaterThan(value any, limit float64) bool { + if n, ok := numericValue(value); ok { + return n > limit + } + return false +} + type toolResponseRecorder struct { gin.ResponseWriter headers http.Header diff --git a/bridge_test.go b/bridge_test.go index 227b3da..baa8e26 100644 --- a/bridge_test.go +++ b/bridge_test.go @@ -452,6 +452,104 @@ func TestToolBridge_Bad_RejectsAdditionalProperties(t *testing.T) { } } +func TestToolBridge_Good_EnforcesStringConstraints(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "publish_code", + Description: "Publish a code", + Group: "items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "code": map[string]any{ + "type": "string", + "minLength": 3, + "maxLength": 5, + "pattern": "^[A-Z]+$", + }, + }, + "required": []any{"code"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("accepted")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/publish_code", bytes.NewBufferString(`{"code":"ABC"}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } +} + +func TestToolBridge_Bad_RejectsNumericAndCollectionConstraints(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "quota_check", + Description: "Check quotas", + Group: "items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{ + "type": "integer", + "minimum": 1, + "maximum": 3, + }, + "labels": map[string]any{ + "type": "array", + "minItems": 2, + "maxItems": 4, + "items": map[string]any{ + "type": "string", + }, + }, + "payload": map[string]any{ + "type": "object", + "minProperties": 1, + "maxProperties": 2, + "additionalProperties": true, + }, + }, + "required": []any{"count", "labels", "payload"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("accepted")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/quota_check", bytes.NewBufferString(`{"count":0,"labels":["one"],"payload":{}}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for numeric/collection constraint failure, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false") + } + if resp.Error == nil || resp.Error.Code != "invalid_request_body" { + t.Fatalf("expected invalid_request_body error, got %#v", resp.Error) + } +} + func TestToolBridge_Good_ToolsAccessor(t *testing.T) { bridge := api.NewToolBridge("/tools") bridge.Add(api.ToolDescriptor{Name: "alpha", Description: "Tool A", Group: "a"}, func(c *gin.Context) {})