feat(bridge): enforce additional schema constraints
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
bfa80e3a27
commit
c6034031a3
2 changed files with 249 additions and 1 deletions
152
bridge.go
152
bridge.go
|
|
@ -13,7 +13,9 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
|
@ -297,6 +299,9 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
|
|||
return fmt.Errorf("%s contains unknown field %q", displayPath(path), name)
|
||||
}
|
||||
}
|
||||
if err := validateObjectConstraints(obj, schema, path); err != nil {
|
||||
return err
|
||||
}
|
||||
case "array":
|
||||
arr, ok := value.([]any)
|
||||
if !ok {
|
||||
|
|
@ -309,10 +314,17 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
|
|||
}
|
||||
}
|
||||
}
|
||||
if err := validateArrayConstraints(arr, schema, path); err != nil {
|
||||
return err
|
||||
}
|
||||
case "string":
|
||||
if _, ok := value.(string); !ok {
|
||||
str, ok := value.(string)
|
||||
if !ok {
|
||||
return typeError(path, "string", value)
|
||||
}
|
||||
if err := validateStringConstraints(str, schema, path); err != nil {
|
||||
return err
|
||||
}
|
||||
case "boolean":
|
||||
if _, ok := value.(bool); !ok {
|
||||
return typeError(path, "boolean", value)
|
||||
|
|
@ -321,10 +333,16 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
|
|||
if !isIntegerValue(value) {
|
||||
return typeError(path, "integer", value)
|
||||
}
|
||||
if err := validateNumericConstraints(value, schema, path); err != nil {
|
||||
return err
|
||||
}
|
||||
case "number":
|
||||
if !isNumberValue(value) {
|
||||
return typeError(path, "number", value)
|
||||
}
|
||||
if err := validateNumericConstraints(value, schema, path); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -347,6 +365,138 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func validateStringConstraints(value string, schema map[string]any, path string) error {
|
||||
length := utf8.RuneCountInString(value)
|
||||
if minLength, ok := schemaInt(schema["minLength"]); ok && length < minLength {
|
||||
return fmt.Errorf("%s must be at least %d characters long", displayPath(path), minLength)
|
||||
}
|
||||
if maxLength, ok := schemaInt(schema["maxLength"]); ok && length > maxLength {
|
||||
return fmt.Errorf("%s must be at most %d characters long", displayPath(path), maxLength)
|
||||
}
|
||||
if pattern, ok := schema["pattern"].(string); ok && pattern != "" {
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s has an invalid pattern %q: %w", displayPath(path), pattern, err)
|
||||
}
|
||||
if !re.MatchString(value) {
|
||||
return fmt.Errorf("%s does not match pattern %q", displayPath(path), pattern)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateNumericConstraints(value any, schema map[string]any, path string) error {
|
||||
if minimum, ok := schemaFloat(schema["minimum"]); ok && numericLessThan(value, minimum) {
|
||||
return fmt.Errorf("%s must be greater than or equal to %v", displayPath(path), minimum)
|
||||
}
|
||||
if maximum, ok := schemaFloat(schema["maximum"]); ok && numericGreaterThan(value, maximum) {
|
||||
return fmt.Errorf("%s must be less than or equal to %v", displayPath(path), maximum)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateArrayConstraints(value []any, schema map[string]any, path string) error {
|
||||
if minItems, ok := schemaInt(schema["minItems"]); ok && len(value) < minItems {
|
||||
return fmt.Errorf("%s must contain at least %d items", displayPath(path), minItems)
|
||||
}
|
||||
if maxItems, ok := schemaInt(schema["maxItems"]); ok && len(value) > maxItems {
|
||||
return fmt.Errorf("%s must contain at most %d items", displayPath(path), maxItems)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateObjectConstraints(value map[string]any, schema map[string]any, path string) error {
|
||||
if minProps, ok := schemaInt(schema["minProperties"]); ok && len(value) < minProps {
|
||||
return fmt.Errorf("%s must contain at least %d properties", displayPath(path), minProps)
|
||||
}
|
||||
if maxProps, ok := schemaInt(schema["maxProperties"]); ok && len(value) > maxProps {
|
||||
return fmt.Errorf("%s must contain at most %d properties", displayPath(path), maxProps)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func schemaInt(value any) (int, bool) {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v, true
|
||||
case int8:
|
||||
return int(v), true
|
||||
case int16:
|
||||
return int(v), true
|
||||
case int32:
|
||||
return int(v), true
|
||||
case int64:
|
||||
return int(v), true
|
||||
case uint:
|
||||
return int(v), true
|
||||
case uint8:
|
||||
return int(v), true
|
||||
case uint16:
|
||||
return int(v), true
|
||||
case uint32:
|
||||
return int(v), true
|
||||
case uint64:
|
||||
return int(v), true
|
||||
case float64:
|
||||
if v == float64(int(v)) {
|
||||
return int(v), true
|
||||
}
|
||||
case json.Number:
|
||||
if n, err := v.Int64(); err == nil {
|
||||
return int(n), true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func schemaFloat(value any) (float64, bool) {
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return v, true
|
||||
case float32:
|
||||
return float64(v), true
|
||||
case int:
|
||||
return float64(v), true
|
||||
case int8:
|
||||
return float64(v), true
|
||||
case int16:
|
||||
return float64(v), true
|
||||
case int32:
|
||||
return float64(v), true
|
||||
case int64:
|
||||
return float64(v), true
|
||||
case uint:
|
||||
return float64(v), true
|
||||
case uint8:
|
||||
return float64(v), true
|
||||
case uint16:
|
||||
return float64(v), true
|
||||
case uint32:
|
||||
return float64(v), true
|
||||
case uint64:
|
||||
return float64(v), true
|
||||
case json.Number:
|
||||
if n, err := v.Float64(); err == nil {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func numericLessThan(value any, limit float64) bool {
|
||||
if n, ok := numericValue(value); ok {
|
||||
return n < limit
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func numericGreaterThan(value any, limit float64) bool {
|
||||
if n, ok := numericValue(value); ok {
|
||||
return n > limit
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type toolResponseRecorder struct {
|
||||
gin.ResponseWriter
|
||||
headers http.Header
|
||||
|
|
|
|||
|
|
@ -452,6 +452,104 @@ func TestToolBridge_Bad_RejectsAdditionalProperties(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestToolBridge_Good_EnforcesStringConstraints(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
engine := gin.New()
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
bridge.Add(api.ToolDescriptor{
|
||||
Name: "publish_code",
|
||||
Description: "Publish a code",
|
||||
Group: "items",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"code": map[string]any{
|
||||
"type": "string",
|
||||
"minLength": 3,
|
||||
"maxLength": 5,
|
||||
"pattern": "^[A-Z]+$",
|
||||
},
|
||||
},
|
||||
"required": []any{"code"},
|
||||
},
|
||||
}, func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.OK("accepted"))
|
||||
})
|
||||
|
||||
rg := engine.Group(bridge.BasePath())
|
||||
bridge.RegisterRoutes(rg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/tools/publish_code", bytes.NewBufferString(`{"code":"ABC"}`))
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolBridge_Bad_RejectsNumericAndCollectionConstraints(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
engine := gin.New()
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
bridge.Add(api.ToolDescriptor{
|
||||
Name: "quota_check",
|
||||
Description: "Check quotas",
|
||||
Group: "items",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"count": map[string]any{
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"maximum": 3,
|
||||
},
|
||||
"labels": map[string]any{
|
||||
"type": "array",
|
||||
"minItems": 2,
|
||||
"maxItems": 4,
|
||||
"items": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"payload": map[string]any{
|
||||
"type": "object",
|
||||
"minProperties": 1,
|
||||
"maxProperties": 2,
|
||||
"additionalProperties": true,
|
||||
},
|
||||
},
|
||||
"required": []any{"count", "labels", "payload"},
|
||||
},
|
||||
}, func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.OK("accepted"))
|
||||
})
|
||||
|
||||
rg := engine.Group(bridge.BasePath())
|
||||
bridge.RegisterRoutes(rg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/tools/quota_check", bytes.NewBufferString(`{"count":0,"labels":["one"],"payload":{}}`))
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for numeric/collection constraint failure, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.Response[any]
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
if resp.Success {
|
||||
t.Fatal("expected Success=false")
|
||||
}
|
||||
if resp.Error == nil || resp.Error.Code != "invalid_request_body" {
|
||||
t.Fatalf("expected invalid_request_body error, got %#v", resp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolBridge_Good_ToolsAccessor(t *testing.T) {
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
bridge.Add(api.ToolDescriptor{Name: "alpha", Description: "Tool A", Group: "a"}, func(c *gin.Context) {})
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue