feat(bridge): validate tool request bodies

This commit is contained in:
Virgil 2026-04-01 06:23:58 +00:00
parent 10fc9559fa
commit 6ef194754e
2 changed files with 319 additions and 0 deletions

231
bridge.go
View file

@ -3,7 +3,13 @@
package api
import (
"bytes"
"encoding/json"
"fmt"
"io"
"iter"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
)
@ -40,6 +46,9 @@ func NewToolBridge(basePath string) *ToolBridge {
// Add registers a tool with its HTTP handler.
func (b *ToolBridge) Add(desc ToolDescriptor, handler gin.HandlerFunc) {
if validator := newToolInputValidator(desc.InputSchema); validator != nil {
handler = wrapToolHandler(handler, validator)
}
b.tools = append(b.tools, boundTool{descriptor: desc, handler: handler})
}
@ -120,3 +129,225 @@ func (b *ToolBridge) ToolsIter() iter.Seq[ToolDescriptor] {
}
}
}
func wrapToolHandler(handler gin.HandlerFunc, validator *toolInputValidator) gin.HandlerFunc {
return func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, FailWithDetails(
"invalid_request_body",
"Unable to read request body",
map[string]any{"error": err.Error()},
))
return
}
if err := validator.Validate(body); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, FailWithDetails(
"invalid_request_body",
"Request body does not match the declared tool schema",
map[string]any{"error": err.Error()},
))
return
}
c.Request.Body = io.NopCloser(bytes.NewReader(body))
handler(c)
}
}
type toolInputValidator struct {
schema map[string]any
}
func newToolInputValidator(schema map[string]any) *toolInputValidator {
if len(schema) == 0 {
return nil
}
return &toolInputValidator{schema: schema}
}
func (v *toolInputValidator) Validate(body []byte) error {
if len(bytes.TrimSpace(body)) == 0 {
return fmt.Errorf("request body is required")
}
dec := json.NewDecoder(bytes.NewReader(body))
dec.UseNumber()
var payload any
if err := dec.Decode(&payload); err != nil {
return fmt.Errorf("invalid JSON: %w", err)
}
var extra any
if err := dec.Decode(&extra); err != io.EOF {
return fmt.Errorf("request body must contain a single JSON value")
}
return validateSchemaNode(payload, v.schema, "")
}
func validateSchemaNode(value any, schema map[string]any, path string) error {
if len(schema) == 0 {
return nil
}
if schemaType, _ := schema["type"].(string); schemaType != "" {
switch schemaType {
case "object":
obj, ok := value.(map[string]any)
if !ok {
return typeError(path, "object", value)
}
for _, name := range stringList(schema["required"]) {
if _, ok := obj[name]; !ok {
return fmt.Errorf("%s is missing required field %q", displayPath(path), name)
}
}
for name, rawChild := range schemaMap(schema["properties"]) {
childSchema, ok := rawChild.(map[string]any)
if !ok {
continue
}
childValue, ok := obj[name]
if !ok {
continue
}
if err := validateSchemaNode(childValue, childSchema, joinPath(path, name)); err != nil {
return err
}
}
return nil
case "array":
arr, ok := value.([]any)
if !ok {
return typeError(path, "array", value)
}
if items := schemaMap(schema["items"]); len(items) > 0 {
for i, item := range arr {
if err := validateSchemaNode(item, items, joinPath(path, strconv.Itoa(i))); err != nil {
return err
}
}
}
return nil
case "string":
if _, ok := value.(string); !ok {
return typeError(path, "string", value)
}
return nil
case "boolean":
if _, ok := value.(bool); !ok {
return typeError(path, "boolean", value)
}
return nil
case "integer":
if !isIntegerValue(value) {
return typeError(path, "integer", value)
}
return nil
case "number":
if !isNumberValue(value) {
return typeError(path, "number", value)
}
return nil
}
}
if props := schemaMap(schema["properties"]); len(props) > 0 {
return validateSchemaNode(value, map[string]any{
"type": "object",
"properties": props,
"required": schema["required"],
}, path)
}
return nil
}
func typeError(path, want string, value any) error {
return fmt.Errorf("%s must be %s, got %s", displayPath(path), want, describeJSONValue(value))
}
func displayPath(path string) string {
if path == "" {
return "request body"
}
return "request body." + path
}
func joinPath(parent, child string) string {
if parent == "" {
return child
}
return parent + "." + child
}
func schemaMap(value any) map[string]any {
if value == nil {
return nil
}
m, _ := value.(map[string]any)
return m
}
func stringList(value any) []string {
switch raw := value.(type) {
case []any:
out := make([]string, 0, len(raw))
for _, item := range raw {
name, ok := item.(string)
if !ok {
continue
}
out = append(out, name)
}
return out
case []string:
return append([]string(nil), raw...)
default:
return nil
}
}
func isIntegerValue(value any) bool {
switch v := value.(type) {
case json.Number:
_, err := v.Int64()
return err == nil
case float64:
return v == float64(int64(v))
default:
return false
}
}
func isNumberValue(value any) bool {
switch value.(type) {
case json.Number, float64:
return true
default:
return false
}
}
func describeJSONValue(value any) string {
switch value.(type) {
case nil:
return "null"
case string:
return "string"
case bool:
return "boolean"
case json.Number, float64:
return "number"
case map[string]any:
return "object"
case []any:
return "array"
default:
return fmt.Sprintf("%T", value)
}
}

View file

@ -3,6 +3,7 @@
package api_test
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
@ -153,6 +154,93 @@ func TestToolBridge_Good_Describe(t *testing.T) {
}
}
func TestToolBridge_Good_ValidatesRequestBody(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
bridge := api.NewToolBridge("/tools")
bridge.Add(api.ToolDescriptor{
Name: "file_read",
Description: "Read a file from disk",
Group: "files",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{"type": "string"},
},
"required": []any{"path"},
},
}, func(c *gin.Context) {
var payload map[string]any
if err := json.NewDecoder(c.Request.Body).Decode(&payload); err != nil {
t.Fatalf("handler could not read validated body: %v", err)
}
c.JSON(http.StatusOK, api.OK(payload["path"]))
})
rg := engine.Group(bridge.BasePath())
bridge.RegisterRoutes(rg)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", bytes.NewBufferString(`{"path":"/tmp/file.txt"}`))
engine.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp api.Response[string]
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if resp.Data != "/tmp/file.txt" {
t.Fatalf("expected validated payload to reach handler, got %q", resp.Data)
}
}
func TestToolBridge_Bad_InvalidRequestBody(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
bridge := api.NewToolBridge("/tools")
bridge.Add(api.ToolDescriptor{
Name: "file_read",
Description: "Read a file from disk",
Group: "files",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{"type": "string"},
},
"required": []any{"path"},
},
}, func(c *gin.Context) {
c.JSON(http.StatusOK, api.OK("should not run"))
})
rg := engine.Group(bridge.BasePath())
bridge.RegisterRoutes(rg)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", bytes.NewBufferString(`{"path":123}`))
engine.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
var resp api.Response[any]
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if resp.Success {
t.Fatal("expected Success=false")
}
if resp.Error == nil || resp.Error.Code != "invalid_request_body" {
t.Fatalf("expected invalid_request_body error, got %#v", resp.Error)
}
}
func TestToolBridge_Good_ToolsAccessor(t *testing.T) {
bridge := api.NewToolBridge("/tools")
bridge.Add(api.ToolDescriptor{Name: "alpha", Description: "Tool A", Group: "a"}, func(c *gin.Context) {})