From ac219926237fbbd4199328cd7c06c1ef29eadc34 Mon Sep 17 00:00:00 2001 From: Virgil Date: Wed, 1 Apr 2026 15:46:35 +0000 Subject: [PATCH] feat(api): enforce tool schema enums Co-Authored-By: Virgil --- bridge.go | 111 +++++++++++++++++++++++++++++++++++++++++++++++-- bridge_test.go | 79 +++++++++++++++++++++++++++++++++++ 2 files changed, 186 insertions(+), 4 deletions(-) diff --git a/bridge.go b/bridge.go index 7932720..067dd88 100644 --- a/bridge.go +++ b/bridge.go @@ -12,6 +12,7 @@ import ( "iter" "net" "net/http" + "reflect" "strconv" "github.com/gin-gonic/gin" @@ -283,6 +284,19 @@ func validateSchemaNode(value any, schema map[string]any, path string) error { return err } } + + if additionalProperties, ok := schema["additionalProperties"].(bool); ok && !additionalProperties { + properties := schemaMap(schema["properties"]) + for name := range obj { + if properties != nil { + if _, ok := properties[name]; ok { + continue + } + } + return fmt.Errorf("%s contains unknown field %q", displayPath(path), name) + } + } + return nil case "array": arr, ok := value.([]any) @@ -320,14 +334,21 @@ func validateSchemaNode(value any, schema map[string]any, path string) error { } } - if props := schemaMap(schema["properties"]); len(props) > 0 { + if props := schemaMap(schema["properties"]); len(props) > 0 || schema["required"] != nil || schema["additionalProperties"] != nil { return validateSchemaNode(value, map[string]any{ - "type": "object", - "properties": props, - "required": schema["required"], + "type": "object", + "properties": props, + "required": schema["required"], + "additionalProperties": schema["additionalProperties"], }, path) } + if rawEnum, ok := schema["enum"]; ok { + if !enumContains(value, rawEnum) { + return fmt.Errorf("%s must be one of the declared enum values", displayPath(path)) + } + } + return nil } @@ -498,6 +519,88 @@ func isNumberValue(value any) bool { } } +func enumContains(value any, rawEnum any) bool { + items := enumValues(rawEnum) + for _, candidate := range items { + if valuesEqual(value, candidate) { + return true + } + } + return false +} + +func enumValues(rawEnum any) []any { + switch values := rawEnum.(type) { + case []any: + out := make([]any, 0, len(values)) + for _, value := range values { + out = append(out, value) + } + return out + case []string: + out := make([]any, 0, len(values)) + for _, value := range values { + out = append(out, value) + } + return out + default: + return nil + } +} + +func valuesEqual(left, right any) bool { + if isNumericValue(left) && isNumericValue(right) { + lv, lok := numericValue(left) + rv, rok := numericValue(right) + return lok && rok && lv == rv + } + return reflect.DeepEqual(left, right) +} + +func isNumericValue(value any) bool { + switch value.(type) { + case json.Number, float64, float32, int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64: + return true + default: + return false + } +} + +func numericValue(value any) (float64, bool) { + switch v := value.(type) { + case json.Number: + n, err := v.Float64() + return n, err == nil + case float64: + return v, true + case float32: + return float64(v), true + case int: + return float64(v), true + case int8: + return float64(v), true + case int16: + return float64(v), true + case int32: + return float64(v), true + case int64: + return float64(v), true + case uint: + return float64(v), true + case uint8: + return float64(v), true + case uint16: + return float64(v), true + case uint32: + return float64(v), true + case uint64: + return float64(v), true + default: + return 0, false + } +} + func describeJSONValue(value any) string { switch value.(type) { case nil: diff --git a/bridge_test.go b/bridge_test.go index dc607a6..ba2e29b 100644 --- a/bridge_test.go +++ b/bridge_test.go @@ -327,6 +327,85 @@ func TestToolBridge_Bad_InvalidRequestBody(t *testing.T) { } } +func TestToolBridge_Good_ValidatesEnumValues(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "publish_item", + Description: "Publish an item", + Group: "items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "status": map[string]any{ + "type": "string", + "enum": []any{"draft", "published"}, + }, + }, + "required": []any{"status"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("published")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/publish_item", bytes.NewBufferString(`{"status":"published"}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } +} + +func TestToolBridge_Bad_RejectsAdditionalProperties(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "publish_item", + Description: "Publish an item", + Group: "items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "status": map[string]any{"type": "string"}, + }, + "required": []any{"status"}, + "additionalProperties": false, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("published")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/publish_item", bytes.NewBufferString(`{"status":"published","unexpected":true}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false") + } + if resp.Error == nil || resp.Error.Code != "invalid_request_body" { + t.Fatalf("expected invalid_request_body error, got %#v", resp.Error) + } +} + func TestToolBridge_Good_ToolsAccessor(t *testing.T) { bridge := api.NewToolBridge("/tools") bridge.Add(api.ToolDescriptor{Name: "alpha", Description: "Tool A", Group: "a"}, func(c *gin.Context) {})