feat(bridge): validate tool request bodies
This commit is contained in:
parent
10fc9559fa
commit
6ef194754e
2 changed files with 319 additions and 0 deletions
231
bridge.go
231
bridge.go
|
|
@ -3,7 +3,13 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
|
@ -40,6 +46,9 @@ func NewToolBridge(basePath string) *ToolBridge {
|
|||
|
||||
// Add registers a tool with its HTTP handler.
|
||||
func (b *ToolBridge) Add(desc ToolDescriptor, handler gin.HandlerFunc) {
|
||||
if validator := newToolInputValidator(desc.InputSchema); validator != nil {
|
||||
handler = wrapToolHandler(handler, validator)
|
||||
}
|
||||
b.tools = append(b.tools, boundTool{descriptor: desc, handler: handler})
|
||||
}
|
||||
|
||||
|
|
@ -120,3 +129,225 @@ func (b *ToolBridge) ToolsIter() iter.Seq[ToolDescriptor] {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func wrapToolHandler(handler gin.HandlerFunc, validator *toolInputValidator) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, FailWithDetails(
|
||||
"invalid_request_body",
|
||||
"Unable to read request body",
|
||||
map[string]any{"error": err.Error()},
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
if err := validator.Validate(body); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, FailWithDetails(
|
||||
"invalid_request_body",
|
||||
"Request body does not match the declared tool schema",
|
||||
map[string]any{"error": err.Error()},
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||
handler(c)
|
||||
}
|
||||
}
|
||||
|
||||
type toolInputValidator struct {
|
||||
schema map[string]any
|
||||
}
|
||||
|
||||
func newToolInputValidator(schema map[string]any) *toolInputValidator {
|
||||
if len(schema) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &toolInputValidator{schema: schema}
|
||||
}
|
||||
|
||||
func (v *toolInputValidator) Validate(body []byte) error {
|
||||
if len(bytes.TrimSpace(body)) == 0 {
|
||||
return fmt.Errorf("request body is required")
|
||||
}
|
||||
|
||||
dec := json.NewDecoder(bytes.NewReader(body))
|
||||
dec.UseNumber()
|
||||
|
||||
var payload any
|
||||
if err := dec.Decode(&payload); err != nil {
|
||||
return fmt.Errorf("invalid JSON: %w", err)
|
||||
}
|
||||
var extra any
|
||||
if err := dec.Decode(&extra); err != io.EOF {
|
||||
return fmt.Errorf("request body must contain a single JSON value")
|
||||
}
|
||||
|
||||
return validateSchemaNode(payload, v.schema, "")
|
||||
}
|
||||
|
||||
func validateSchemaNode(value any, schema map[string]any, path string) error {
|
||||
if len(schema) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if schemaType, _ := schema["type"].(string); schemaType != "" {
|
||||
switch schemaType {
|
||||
case "object":
|
||||
obj, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return typeError(path, "object", value)
|
||||
}
|
||||
|
||||
for _, name := range stringList(schema["required"]) {
|
||||
if _, ok := obj[name]; !ok {
|
||||
return fmt.Errorf("%s is missing required field %q", displayPath(path), name)
|
||||
}
|
||||
}
|
||||
|
||||
for name, rawChild := range schemaMap(schema["properties"]) {
|
||||
childSchema, ok := rawChild.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
childValue, ok := obj[name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := validateSchemaNode(childValue, childSchema, joinPath(path, name)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case "array":
|
||||
arr, ok := value.([]any)
|
||||
if !ok {
|
||||
return typeError(path, "array", value)
|
||||
}
|
||||
if items := schemaMap(schema["items"]); len(items) > 0 {
|
||||
for i, item := range arr {
|
||||
if err := validateSchemaNode(item, items, joinPath(path, strconv.Itoa(i))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case "string":
|
||||
if _, ok := value.(string); !ok {
|
||||
return typeError(path, "string", value)
|
||||
}
|
||||
return nil
|
||||
case "boolean":
|
||||
if _, ok := value.(bool); !ok {
|
||||
return typeError(path, "boolean", value)
|
||||
}
|
||||
return nil
|
||||
case "integer":
|
||||
if !isIntegerValue(value) {
|
||||
return typeError(path, "integer", value)
|
||||
}
|
||||
return nil
|
||||
case "number":
|
||||
if !isNumberValue(value) {
|
||||
return typeError(path, "number", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if props := schemaMap(schema["properties"]); len(props) > 0 {
|
||||
return validateSchemaNode(value, map[string]any{
|
||||
"type": "object",
|
||||
"properties": props,
|
||||
"required": schema["required"],
|
||||
}, path)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func typeError(path, want string, value any) error {
|
||||
return fmt.Errorf("%s must be %s, got %s", displayPath(path), want, describeJSONValue(value))
|
||||
}
|
||||
|
||||
func displayPath(path string) string {
|
||||
if path == "" {
|
||||
return "request body"
|
||||
}
|
||||
return "request body." + path
|
||||
}
|
||||
|
||||
func joinPath(parent, child string) string {
|
||||
if parent == "" {
|
||||
return child
|
||||
}
|
||||
return parent + "." + child
|
||||
}
|
||||
|
||||
func schemaMap(value any) map[string]any {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
m, _ := value.(map[string]any)
|
||||
return m
|
||||
}
|
||||
|
||||
func stringList(value any) []string {
|
||||
switch raw := value.(type) {
|
||||
case []any:
|
||||
out := make([]string, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
name, ok := item.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, name)
|
||||
}
|
||||
return out
|
||||
case []string:
|
||||
return append([]string(nil), raw...)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func isIntegerValue(value any) bool {
|
||||
switch v := value.(type) {
|
||||
case json.Number:
|
||||
_, err := v.Int64()
|
||||
return err == nil
|
||||
case float64:
|
||||
return v == float64(int64(v))
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isNumberValue(value any) bool {
|
||||
switch value.(type) {
|
||||
case json.Number, float64:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func describeJSONValue(value any) string {
|
||||
switch value.(type) {
|
||||
case nil:
|
||||
return "null"
|
||||
case string:
|
||||
return "string"
|
||||
case bool:
|
||||
return "boolean"
|
||||
case json.Number, float64:
|
||||
return "number"
|
||||
case map[string]any:
|
||||
return "object"
|
||||
case []any:
|
||||
return "array"
|
||||
default:
|
||||
return fmt.Sprintf("%T", value)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
package api_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
|
@ -153,6 +154,93 @@ func TestToolBridge_Good_Describe(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestToolBridge_Good_ValidatesRequestBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
engine := gin.New()
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
bridge.Add(api.ToolDescriptor{
|
||||
Name: "file_read",
|
||||
Description: "Read a file from disk",
|
||||
Group: "files",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"path"},
|
||||
},
|
||||
}, func(c *gin.Context) {
|
||||
var payload map[string]any
|
||||
if err := json.NewDecoder(c.Request.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("handler could not read validated body: %v", err)
|
||||
}
|
||||
c.JSON(http.StatusOK, api.OK(payload["path"]))
|
||||
})
|
||||
|
||||
rg := engine.Group(bridge.BasePath())
|
||||
bridge.RegisterRoutes(rg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", bytes.NewBufferString(`{"path":"/tmp/file.txt"}`))
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.Response[string]
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
if resp.Data != "/tmp/file.txt" {
|
||||
t.Fatalf("expected validated payload to reach handler, got %q", resp.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolBridge_Bad_InvalidRequestBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
engine := gin.New()
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
bridge.Add(api.ToolDescriptor{
|
||||
Name: "file_read",
|
||||
Description: "Read a file from disk",
|
||||
Group: "files",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"path"},
|
||||
},
|
||||
}, func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.OK("should not run"))
|
||||
})
|
||||
|
||||
rg := engine.Group(bridge.BasePath())
|
||||
bridge.RegisterRoutes(rg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", bytes.NewBufferString(`{"path":123}`))
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.Response[any]
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
if resp.Success {
|
||||
t.Fatal("expected Success=false")
|
||||
}
|
||||
if resp.Error == nil || resp.Error.Code != "invalid_request_body" {
|
||||
t.Fatalf("expected invalid_request_body error, got %#v", resp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolBridge_Good_ToolsAccessor(t *testing.T) {
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
bridge.Add(api.ToolDescriptor{Name: "alpha", Description: "Tool A", Group: "a"}, func(c *gin.Context) {})
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue