feat(api): enforce tool schema enums
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
4420651fcf
commit
ac21992623
2 changed files with 186 additions and 4 deletions
111
bridge.go
111
bridge.go
|
|
@ -12,6 +12,7 @@ import (
|
|||
"iter"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
|
@ -283,6 +284,19 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
|
|||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if additionalProperties, ok := schema["additionalProperties"].(bool); ok && !additionalProperties {
|
||||
properties := schemaMap(schema["properties"])
|
||||
for name := range obj {
|
||||
if properties != nil {
|
||||
if _, ok := properties[name]; ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("%s contains unknown field %q", displayPath(path), name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
case "array":
|
||||
arr, ok := value.([]any)
|
||||
|
|
@ -320,14 +334,21 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
|
|||
}
|
||||
}
|
||||
|
||||
if props := schemaMap(schema["properties"]); len(props) > 0 {
|
||||
if props := schemaMap(schema["properties"]); len(props) > 0 || schema["required"] != nil || schema["additionalProperties"] != nil {
|
||||
return validateSchemaNode(value, map[string]any{
|
||||
"type": "object",
|
||||
"properties": props,
|
||||
"required": schema["required"],
|
||||
"type": "object",
|
||||
"properties": props,
|
||||
"required": schema["required"],
|
||||
"additionalProperties": schema["additionalProperties"],
|
||||
}, path)
|
||||
}
|
||||
|
||||
if rawEnum, ok := schema["enum"]; ok {
|
||||
if !enumContains(value, rawEnum) {
|
||||
return fmt.Errorf("%s must be one of the declared enum values", displayPath(path))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -498,6 +519,88 @@ func isNumberValue(value any) bool {
|
|||
}
|
||||
}
|
||||
|
||||
func enumContains(value any, rawEnum any) bool {
|
||||
items := enumValues(rawEnum)
|
||||
for _, candidate := range items {
|
||||
if valuesEqual(value, candidate) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func enumValues(rawEnum any) []any {
|
||||
switch values := rawEnum.(type) {
|
||||
case []any:
|
||||
out := make([]any, 0, len(values))
|
||||
for _, value := range values {
|
||||
out = append(out, value)
|
||||
}
|
||||
return out
|
||||
case []string:
|
||||
out := make([]any, 0, len(values))
|
||||
for _, value := range values {
|
||||
out = append(out, value)
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func valuesEqual(left, right any) bool {
|
||||
if isNumericValue(left) && isNumericValue(right) {
|
||||
lv, lok := numericValue(left)
|
||||
rv, rok := numericValue(right)
|
||||
return lok && rok && lv == rv
|
||||
}
|
||||
return reflect.DeepEqual(left, right)
|
||||
}
|
||||
|
||||
func isNumericValue(value any) bool {
|
||||
switch value.(type) {
|
||||
case json.Number, float64, float32, int, int8, int16, int32, int64,
|
||||
uint, uint8, uint16, uint32, uint64:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func numericValue(value any) (float64, bool) {
|
||||
switch v := value.(type) {
|
||||
case json.Number:
|
||||
n, err := v.Float64()
|
||||
return n, err == nil
|
||||
case float64:
|
||||
return v, true
|
||||
case float32:
|
||||
return float64(v), true
|
||||
case int:
|
||||
return float64(v), true
|
||||
case int8:
|
||||
return float64(v), true
|
||||
case int16:
|
||||
return float64(v), true
|
||||
case int32:
|
||||
return float64(v), true
|
||||
case int64:
|
||||
return float64(v), true
|
||||
case uint:
|
||||
return float64(v), true
|
||||
case uint8:
|
||||
return float64(v), true
|
||||
case uint16:
|
||||
return float64(v), true
|
||||
case uint32:
|
||||
return float64(v), true
|
||||
case uint64:
|
||||
return float64(v), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func describeJSONValue(value any) string {
|
||||
switch value.(type) {
|
||||
case nil:
|
||||
|
|
|
|||
|
|
@ -327,6 +327,85 @@ func TestToolBridge_Bad_InvalidRequestBody(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestToolBridge_Good_ValidatesEnumValues(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
engine := gin.New()
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
bridge.Add(api.ToolDescriptor{
|
||||
Name: "publish_item",
|
||||
Description: "Publish an item",
|
||||
Group: "items",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"status": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []any{"draft", "published"},
|
||||
},
|
||||
},
|
||||
"required": []any{"status"},
|
||||
},
|
||||
}, func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.OK("published"))
|
||||
})
|
||||
|
||||
rg := engine.Group(bridge.BasePath())
|
||||
bridge.RegisterRoutes(rg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/tools/publish_item", bytes.NewBufferString(`{"status":"published"}`))
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolBridge_Bad_RejectsAdditionalProperties(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
engine := gin.New()
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
bridge.Add(api.ToolDescriptor{
|
||||
Name: "publish_item",
|
||||
Description: "Publish an item",
|
||||
Group: "items",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"status": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"status"},
|
||||
"additionalProperties": false,
|
||||
},
|
||||
}, func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.OK("published"))
|
||||
})
|
||||
|
||||
rg := engine.Group(bridge.BasePath())
|
||||
bridge.RegisterRoutes(rg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/tools/publish_item", bytes.NewBufferString(`{"status":"published","unexpected":true}`))
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.Response[any]
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
if resp.Success {
|
||||
t.Fatal("expected Success=false")
|
||||
}
|
||||
if resp.Error == nil || resp.Error.Code != "invalid_request_body" {
|
||||
t.Fatalf("expected invalid_request_body error, got %#v", resp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolBridge_Good_ToolsAccessor(t *testing.T) {
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
bridge.Add(api.ToolDescriptor{Name: "alpha", Description: "Tool A", Group: "a"}, func(c *gin.Context) {})
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue