diff --git a/bridge.go b/bridge.go index 79e2e78..8dd8bc9 100644 --- a/bridge.go +++ b/bridge.go @@ -3,7 +3,13 @@ package api import ( + "bytes" + "encoding/json" + "fmt" + "io" "iter" + "net/http" + "strconv" "github.com/gin-gonic/gin" ) @@ -40,6 +46,9 @@ func NewToolBridge(basePath string) *ToolBridge { // Add registers a tool with its HTTP handler. func (b *ToolBridge) Add(desc ToolDescriptor, handler gin.HandlerFunc) { + if validator := newToolInputValidator(desc.InputSchema); validator != nil { + handler = wrapToolHandler(handler, validator) + } b.tools = append(b.tools, boundTool{descriptor: desc, handler: handler}) } @@ -120,3 +129,225 @@ func (b *ToolBridge) ToolsIter() iter.Seq[ToolDescriptor] { } } } + +func wrapToolHandler(handler gin.HandlerFunc, validator *toolInputValidator) gin.HandlerFunc { + return func(c *gin.Context) { + body, err := io.ReadAll(c.Request.Body) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, FailWithDetails( + "invalid_request_body", + "Unable to read request body", + map[string]any{"error": err.Error()}, + )) + return + } + + if err := validator.Validate(body); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, FailWithDetails( + "invalid_request_body", + "Request body does not match the declared tool schema", + map[string]any{"error": err.Error()}, + )) + return + } + + c.Request.Body = io.NopCloser(bytes.NewReader(body)) + handler(c) + } +} + +type toolInputValidator struct { + schema map[string]any +} + +func newToolInputValidator(schema map[string]any) *toolInputValidator { + if len(schema) == 0 { + return nil + } + return &toolInputValidator{schema: schema} +} + +func (v *toolInputValidator) Validate(body []byte) error { + if len(bytes.TrimSpace(body)) == 0 { + return fmt.Errorf("request body is required") + } + + dec := json.NewDecoder(bytes.NewReader(body)) + dec.UseNumber() + + var payload any + if err := dec.Decode(&payload); err != nil { + return fmt.Errorf("invalid JSON: %w", err) + } + var extra any + if err := dec.Decode(&extra); err != io.EOF { + return fmt.Errorf("request body must contain a single JSON value") + } + + return validateSchemaNode(payload, v.schema, "") +} + +func validateSchemaNode(value any, schema map[string]any, path string) error { + if len(schema) == 0 { + return nil + } + + if schemaType, _ := schema["type"].(string); schemaType != "" { + switch schemaType { + case "object": + obj, ok := value.(map[string]any) + if !ok { + return typeError(path, "object", value) + } + + for _, name := range stringList(schema["required"]) { + if _, ok := obj[name]; !ok { + return fmt.Errorf("%s is missing required field %q", displayPath(path), name) + } + } + + for name, rawChild := range schemaMap(schema["properties"]) { + childSchema, ok := rawChild.(map[string]any) + if !ok { + continue + } + childValue, ok := obj[name] + if !ok { + continue + } + if err := validateSchemaNode(childValue, childSchema, joinPath(path, name)); err != nil { + return err + } + } + return nil + case "array": + arr, ok := value.([]any) + if !ok { + return typeError(path, "array", value) + } + if items := schemaMap(schema["items"]); len(items) > 0 { + for i, item := range arr { + if err := validateSchemaNode(item, items, joinPath(path, strconv.Itoa(i))); err != nil { + return err + } + } + } + return nil + case "string": + if _, ok := value.(string); !ok { + return typeError(path, "string", value) + } + return nil + case "boolean": + if _, ok := value.(bool); !ok { + return typeError(path, "boolean", value) + } + return nil + case "integer": + if !isIntegerValue(value) { + return typeError(path, "integer", value) + } + return nil + case "number": + if !isNumberValue(value) { + return typeError(path, "number", value) + } + return nil + } + } + + if props := schemaMap(schema["properties"]); len(props) > 0 { + return validateSchemaNode(value, map[string]any{ + "type": "object", + "properties": props, + "required": schema["required"], + }, path) + } + + return nil +} + +func typeError(path, want string, value any) error { + return fmt.Errorf("%s must be %s, got %s", displayPath(path), want, describeJSONValue(value)) +} + +func displayPath(path string) string { + if path == "" { + return "request body" + } + return "request body." + path +} + +func joinPath(parent, child string) string { + if parent == "" { + return child + } + return parent + "." + child +} + +func schemaMap(value any) map[string]any { + if value == nil { + return nil + } + m, _ := value.(map[string]any) + return m +} + +func stringList(value any) []string { + switch raw := value.(type) { + case []any: + out := make([]string, 0, len(raw)) + for _, item := range raw { + name, ok := item.(string) + if !ok { + continue + } + out = append(out, name) + } + return out + case []string: + return append([]string(nil), raw...) + default: + return nil + } +} + +func isIntegerValue(value any) bool { + switch v := value.(type) { + case json.Number: + _, err := v.Int64() + return err == nil + case float64: + return v == float64(int64(v)) + default: + return false + } +} + +func isNumberValue(value any) bool { + switch value.(type) { + case json.Number, float64: + return true + default: + return false + } +} + +func describeJSONValue(value any) string { + switch value.(type) { + case nil: + return "null" + case string: + return "string" + case bool: + return "boolean" + case json.Number, float64: + return "number" + case map[string]any: + return "object" + case []any: + return "array" + default: + return fmt.Sprintf("%T", value) + } +} diff --git a/bridge_test.go b/bridge_test.go index 3c5c6c4..f2eb442 100644 --- a/bridge_test.go +++ b/bridge_test.go @@ -3,6 +3,7 @@ package api_test import ( + "bytes" "encoding/json" "net/http" "net/http/httptest" @@ -153,6 +154,93 @@ func TestToolBridge_Good_Describe(t *testing.T) { } } +func TestToolBridge_Good_ValidatesRequestBody(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "file_read", + Description: "Read a file from disk", + Group: "files", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + "required": []any{"path"}, + }, + }, func(c *gin.Context) { + var payload map[string]any + if err := json.NewDecoder(c.Request.Body).Decode(&payload); err != nil { + t.Fatalf("handler could not read validated body: %v", err) + } + c.JSON(http.StatusOK, api.OK(payload["path"])) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", bytes.NewBufferString(`{"path":"/tmp/file.txt"}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp api.Response[string] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Data != "/tmp/file.txt" { + t.Fatalf("expected validated payload to reach handler, got %q", resp.Data) + } +} + +func TestToolBridge_Bad_InvalidRequestBody(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "file_read", + Description: "Read a file from disk", + Group: "files", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + "required": []any{"path"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("should not run")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", bytes.NewBufferString(`{"path":123}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false") + } + if resp.Error == nil || resp.Error.Code != "invalid_request_body" { + t.Fatalf("expected invalid_request_body error, got %#v", resp.Error) + } +} + func TestToolBridge_Good_ToolsAccessor(t *testing.T) { bridge := api.NewToolBridge("/tools") bridge.Add(api.ToolDescriptor{Name: "alpha", Description: "Tool A", Group: "a"}, func(c *gin.Context) {})