feat(api): enforce tool schema enums

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-01 15:46:35 +00:00
parent 4420651fcf
commit ac21992623
2 changed files with 186 additions and 4 deletions

111
bridge.go
View file

@ -12,6 +12,7 @@ import (
"iter"
"net"
"net/http"
"reflect"
"strconv"
"github.com/gin-gonic/gin"
@ -283,6 +284,19 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
return err
}
}
if additionalProperties, ok := schema["additionalProperties"].(bool); ok && !additionalProperties {
properties := schemaMap(schema["properties"])
for name := range obj {
if properties != nil {
if _, ok := properties[name]; ok {
continue
}
}
return fmt.Errorf("%s contains unknown field %q", displayPath(path), name)
}
}
return nil
case "array":
arr, ok := value.([]any)
@ -320,14 +334,21 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
}
}
if props := schemaMap(schema["properties"]); len(props) > 0 {
if props := schemaMap(schema["properties"]); len(props) > 0 || schema["required"] != nil || schema["additionalProperties"] != nil {
return validateSchemaNode(value, map[string]any{
"type": "object",
"properties": props,
"required": schema["required"],
"type": "object",
"properties": props,
"required": schema["required"],
"additionalProperties": schema["additionalProperties"],
}, path)
}
if rawEnum, ok := schema["enum"]; ok {
if !enumContains(value, rawEnum) {
return fmt.Errorf("%s must be one of the declared enum values", displayPath(path))
}
}
return nil
}
@ -498,6 +519,88 @@ func isNumberValue(value any) bool {
}
}
func enumContains(value any, rawEnum any) bool {
items := enumValues(rawEnum)
for _, candidate := range items {
if valuesEqual(value, candidate) {
return true
}
}
return false
}
func enumValues(rawEnum any) []any {
switch values := rawEnum.(type) {
case []any:
out := make([]any, 0, len(values))
for _, value := range values {
out = append(out, value)
}
return out
case []string:
out := make([]any, 0, len(values))
for _, value := range values {
out = append(out, value)
}
return out
default:
return nil
}
}
func valuesEqual(left, right any) bool {
if isNumericValue(left) && isNumericValue(right) {
lv, lok := numericValue(left)
rv, rok := numericValue(right)
return lok && rok && lv == rv
}
return reflect.DeepEqual(left, right)
}
func isNumericValue(value any) bool {
switch value.(type) {
case json.Number, float64, float32, int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64:
return true
default:
return false
}
}
func numericValue(value any) (float64, bool) {
switch v := value.(type) {
case json.Number:
n, err := v.Float64()
return n, err == nil
case float64:
return v, true
case float32:
return float64(v), true
case int:
return float64(v), true
case int8:
return float64(v), true
case int16:
return float64(v), true
case int32:
return float64(v), true
case int64:
return float64(v), true
case uint:
return float64(v), true
case uint8:
return float64(v), true
case uint16:
return float64(v), true
case uint32:
return float64(v), true
case uint64:
return float64(v), true
default:
return 0, false
}
}
func describeJSONValue(value any) string {
switch value.(type) {
case nil:

View file

@ -327,6 +327,85 @@ func TestToolBridge_Bad_InvalidRequestBody(t *testing.T) {
}
}
func TestToolBridge_Good_ValidatesEnumValues(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
bridge := api.NewToolBridge("/tools")
bridge.Add(api.ToolDescriptor{
Name: "publish_item",
Description: "Publish an item",
Group: "items",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"status": map[string]any{
"type": "string",
"enum": []any{"draft", "published"},
},
},
"required": []any{"status"},
},
}, func(c *gin.Context) {
c.JSON(http.StatusOK, api.OK("published"))
})
rg := engine.Group(bridge.BasePath())
bridge.RegisterRoutes(rg)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/tools/publish_item", bytes.NewBufferString(`{"status":"published"}`))
engine.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
}
func TestToolBridge_Bad_RejectsAdditionalProperties(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
bridge := api.NewToolBridge("/tools")
bridge.Add(api.ToolDescriptor{
Name: "publish_item",
Description: "Publish an item",
Group: "items",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"status": map[string]any{"type": "string"},
},
"required": []any{"status"},
"additionalProperties": false,
},
}, func(c *gin.Context) {
c.JSON(http.StatusOK, api.OK("published"))
})
rg := engine.Group(bridge.BasePath())
bridge.RegisterRoutes(rg)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/tools/publish_item", bytes.NewBufferString(`{"status":"published","unexpected":true}`))
engine.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
var resp api.Response[any]
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if resp.Success {
t.Fatal("expected Success=false")
}
if resp.Error == nil || resp.Error.Code != "invalid_request_body" {
t.Fatalf("expected invalid_request_body error, got %#v", resp.Error)
}
}
func TestToolBridge_Good_ToolsAccessor(t *testing.T) {
bridge := api.NewToolBridge("/tools")
bridge.Add(api.ToolDescriptor{Name: "alpha", Description: "Tool A", Group: "a"}, func(c *gin.Context) {})