feat(bridge): enforce additional schema constraints

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-01 16:50:29 +00:00
parent bfa80e3a27
commit c6034031a3
2 changed files with 249 additions and 1 deletions

152
bridge.go
View file

@ -13,7 +13,9 @@ import (
"net"
"net/http"
"reflect"
"regexp"
"strconv"
"unicode/utf8"
"github.com/gin-gonic/gin"
)
@ -297,6 +299,9 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
return fmt.Errorf("%s contains unknown field %q", displayPath(path), name)
}
}
if err := validateObjectConstraints(obj, schema, path); err != nil {
return err
}
case "array":
arr, ok := value.([]any)
if !ok {
@ -309,10 +314,17 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
}
}
}
if err := validateArrayConstraints(arr, schema, path); err != nil {
return err
}
case "string":
if _, ok := value.(string); !ok {
str, ok := value.(string)
if !ok {
return typeError(path, "string", value)
}
if err := validateStringConstraints(str, schema, path); err != nil {
return err
}
case "boolean":
if _, ok := value.(bool); !ok {
return typeError(path, "boolean", value)
@ -321,10 +333,16 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
if !isIntegerValue(value) {
return typeError(path, "integer", value)
}
if err := validateNumericConstraints(value, schema, path); err != nil {
return err
}
case "number":
if !isNumberValue(value) {
return typeError(path, "number", value)
}
if err := validateNumericConstraints(value, schema, path); err != nil {
return err
}
}
}
@ -347,6 +365,138 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
return nil
}
func validateStringConstraints(value string, schema map[string]any, path string) error {
length := utf8.RuneCountInString(value)
if minLength, ok := schemaInt(schema["minLength"]); ok && length < minLength {
return fmt.Errorf("%s must be at least %d characters long", displayPath(path), minLength)
}
if maxLength, ok := schemaInt(schema["maxLength"]); ok && length > maxLength {
return fmt.Errorf("%s must be at most %d characters long", displayPath(path), maxLength)
}
if pattern, ok := schema["pattern"].(string); ok && pattern != "" {
re, err := regexp.Compile(pattern)
if err != nil {
return fmt.Errorf("%s has an invalid pattern %q: %w", displayPath(path), pattern, err)
}
if !re.MatchString(value) {
return fmt.Errorf("%s does not match pattern %q", displayPath(path), pattern)
}
}
return nil
}
func validateNumericConstraints(value any, schema map[string]any, path string) error {
if minimum, ok := schemaFloat(schema["minimum"]); ok && numericLessThan(value, minimum) {
return fmt.Errorf("%s must be greater than or equal to %v", displayPath(path), minimum)
}
if maximum, ok := schemaFloat(schema["maximum"]); ok && numericGreaterThan(value, maximum) {
return fmt.Errorf("%s must be less than or equal to %v", displayPath(path), maximum)
}
return nil
}
func validateArrayConstraints(value []any, schema map[string]any, path string) error {
if minItems, ok := schemaInt(schema["minItems"]); ok && len(value) < minItems {
return fmt.Errorf("%s must contain at least %d items", displayPath(path), minItems)
}
if maxItems, ok := schemaInt(schema["maxItems"]); ok && len(value) > maxItems {
return fmt.Errorf("%s must contain at most %d items", displayPath(path), maxItems)
}
return nil
}
func validateObjectConstraints(value map[string]any, schema map[string]any, path string) error {
if minProps, ok := schemaInt(schema["minProperties"]); ok && len(value) < minProps {
return fmt.Errorf("%s must contain at least %d properties", displayPath(path), minProps)
}
if maxProps, ok := schemaInt(schema["maxProperties"]); ok && len(value) > maxProps {
return fmt.Errorf("%s must contain at most %d properties", displayPath(path), maxProps)
}
return nil
}
func schemaInt(value any) (int, bool) {
switch v := value.(type) {
case int:
return v, true
case int8:
return int(v), true
case int16:
return int(v), true
case int32:
return int(v), true
case int64:
return int(v), true
case uint:
return int(v), true
case uint8:
return int(v), true
case uint16:
return int(v), true
case uint32:
return int(v), true
case uint64:
return int(v), true
case float64:
if v == float64(int(v)) {
return int(v), true
}
case json.Number:
if n, err := v.Int64(); err == nil {
return int(n), true
}
}
return 0, false
}
func schemaFloat(value any) (float64, bool) {
switch v := value.(type) {
case float64:
return v, true
case float32:
return float64(v), true
case int:
return float64(v), true
case int8:
return float64(v), true
case int16:
return float64(v), true
case int32:
return float64(v), true
case int64:
return float64(v), true
case uint:
return float64(v), true
case uint8:
return float64(v), true
case uint16:
return float64(v), true
case uint32:
return float64(v), true
case uint64:
return float64(v), true
case json.Number:
if n, err := v.Float64(); err == nil {
return n, true
}
}
return 0, false
}
func numericLessThan(value any, limit float64) bool {
if n, ok := numericValue(value); ok {
return n < limit
}
return false
}
func numericGreaterThan(value any, limit float64) bool {
if n, ok := numericValue(value); ok {
return n > limit
}
return false
}
type toolResponseRecorder struct {
gin.ResponseWriter
headers http.Header

View file

@ -452,6 +452,104 @@ func TestToolBridge_Bad_RejectsAdditionalProperties(t *testing.T) {
}
}
func TestToolBridge_Good_EnforcesStringConstraints(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
bridge := api.NewToolBridge("/tools")
bridge.Add(api.ToolDescriptor{
Name: "publish_code",
Description: "Publish a code",
Group: "items",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"code": map[string]any{
"type": "string",
"minLength": 3,
"maxLength": 5,
"pattern": "^[A-Z]+$",
},
},
"required": []any{"code"},
},
}, func(c *gin.Context) {
c.JSON(http.StatusOK, api.OK("accepted"))
})
rg := engine.Group(bridge.BasePath())
bridge.RegisterRoutes(rg)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/tools/publish_code", bytes.NewBufferString(`{"code":"ABC"}`))
engine.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
}
func TestToolBridge_Bad_RejectsNumericAndCollectionConstraints(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
bridge := api.NewToolBridge("/tools")
bridge.Add(api.ToolDescriptor{
Name: "quota_check",
Description: "Check quotas",
Group: "items",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"count": map[string]any{
"type": "integer",
"minimum": 1,
"maximum": 3,
},
"labels": map[string]any{
"type": "array",
"minItems": 2,
"maxItems": 4,
"items": map[string]any{
"type": "string",
},
},
"payload": map[string]any{
"type": "object",
"minProperties": 1,
"maxProperties": 2,
"additionalProperties": true,
},
},
"required": []any{"count", "labels", "payload"},
},
}, func(c *gin.Context) {
c.JSON(http.StatusOK, api.OK("accepted"))
})
rg := engine.Group(bridge.BasePath())
bridge.RegisterRoutes(rg)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/tools/quota_check", bytes.NewBufferString(`{"count":0,"labels":["one"],"payload":{}}`))
engine.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for numeric/collection constraint failure, got %d", w.Code)
}
var resp api.Response[any]
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if resp.Success {
t.Fatal("expected Success=false")
}
if resp.Error == nil || resp.Error.Code != "invalid_request_body" {
t.Fatalf("expected invalid_request_body error, got %#v", resp.Error)
}
}
func TestToolBridge_Good_ToolsAccessor(t *testing.T) {
bridge := api.NewToolBridge("/tools")
bridge.Add(api.ToolDescriptor{Name: "alpha", Description: "Tool A", Group: "a"}, func(c *gin.Context) {})