feat(api): validate ToolBridge output schemas

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-01 13:18:10 +00:00
parent 9aa7c644ef
commit 5d5ca8aa51
2 changed files with 251 additions and 0 deletions

165
bridge.go
View file

@ -3,11 +3,14 @@
package api
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"iter"
"net"
"net/http"
"strconv"
@ -46,6 +49,9 @@ func NewToolBridge(basePath string) *ToolBridge {
// Add registers a tool with its HTTP handler.
func (b *ToolBridge) Add(desc ToolDescriptor, handler gin.HandlerFunc) {
if validator := newToolInputValidator(desc.OutputSchema); validator != nil {
handler = wrapToolResponseHandler(handler, validator)
}
if validator := newToolInputValidator(desc.InputSchema); validator != nil {
handler = wrapToolHandler(handler, validator)
}
@ -156,6 +162,29 @@ func wrapToolHandler(handler gin.HandlerFunc, validator *toolInputValidator) gin
}
}
func wrapToolResponseHandler(handler gin.HandlerFunc, validator *toolInputValidator) gin.HandlerFunc {
return func(c *gin.Context) {
recorder := newToolResponseRecorder(c.Writer)
c.Writer = recorder
handler(c)
if recorder.Status() >= 200 && recorder.Status() < 300 {
if err := validator.ValidateResponse(recorder.body.Bytes()); err != nil {
recorder.reset()
recorder.writeErrorResponse(http.StatusInternalServerError, FailWithDetails(
"invalid_tool_response",
"Tool response does not match the declared output schema",
map[string]any{"error": err.Error()},
))
return
}
}
recorder.commit()
}
}
type toolInputValidator struct {
schema map[string]any
}
@ -187,6 +216,41 @@ func (v *toolInputValidator) Validate(body []byte) error {
return validateSchemaNode(payload, v.schema, "")
}
func (v *toolInputValidator) ValidateResponse(body []byte) error {
if len(v.schema) == 0 {
return nil
}
var envelope map[string]any
if err := json.Unmarshal(body, &envelope); err != nil {
return fmt.Errorf("invalid JSON response: %w", err)
}
success, _ := envelope["success"].(bool)
if !success {
return fmt.Errorf("response is missing a successful envelope")
}
data, ok := envelope["data"]
if !ok {
return fmt.Errorf("response is missing data")
}
encoded, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("encode response data: %w", err)
}
var payload any
dec := json.NewDecoder(bytes.NewReader(encoded))
dec.UseNumber()
if err := dec.Decode(&payload); err != nil {
return fmt.Errorf("decode response data: %w", err)
}
return validateSchemaNode(payload, v.schema, "")
}
func validateSchemaNode(value any, schema map[string]any, path string) error {
if len(schema) == 0 {
return nil
@ -267,6 +331,107 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
return nil
}
type toolResponseRecorder struct {
gin.ResponseWriter
headers http.Header
body bytes.Buffer
status int
wroteHeader bool
}
func newToolResponseRecorder(w gin.ResponseWriter) *toolResponseRecorder {
headers := make(http.Header)
for k, vals := range w.Header() {
headers[k] = append([]string(nil), vals...)
}
return &toolResponseRecorder{
ResponseWriter: w,
headers: headers,
status: http.StatusOK,
}
}
func (w *toolResponseRecorder) Header() http.Header {
return w.headers
}
func (w *toolResponseRecorder) WriteHeader(code int) {
w.status = code
w.wroteHeader = true
}
func (w *toolResponseRecorder) WriteHeaderNow() {
w.wroteHeader = true
}
func (w *toolResponseRecorder) Write(data []byte) (int, error) {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
return w.body.Write(data)
}
func (w *toolResponseRecorder) WriteString(s string) (int, error) {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
return w.body.WriteString(s)
}
func (w *toolResponseRecorder) Flush() {
}
func (w *toolResponseRecorder) Status() int {
if w.wroteHeader {
return w.status
}
return http.StatusOK
}
func (w *toolResponseRecorder) Size() int {
return w.body.Len()
}
func (w *toolResponseRecorder) Written() bool {
return w.wroteHeader
}
func (w *toolResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errors.New("response hijacking is not supported by ToolBridge output validation")
}
func (w *toolResponseRecorder) commit() {
for k := range w.ResponseWriter.Header() {
w.ResponseWriter.Header().Del(k)
}
for k, vals := range w.headers {
for _, v := range vals {
w.ResponseWriter.Header().Add(k, v)
}
}
w.ResponseWriter.WriteHeader(w.Status())
_, _ = w.ResponseWriter.Write(w.body.Bytes())
}
func (w *toolResponseRecorder) reset() {
w.headers = make(http.Header)
w.body.Reset()
w.status = http.StatusInternalServerError
w.wroteHeader = false
}
func (w *toolResponseRecorder) writeErrorResponse(status int, resp Response[any]) {
data, err := json.Marshal(resp)
if err != nil {
http.Error(w.ResponseWriter, "internal server error", http.StatusInternalServerError)
return
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
w.ResponseWriter.WriteHeader(status)
_, _ = w.ResponseWriter.Write(data)
}
func typeError(path, want string, value any) error {
return fmt.Errorf("%s must be %s, got %s", displayPath(path), want, describeJSONValue(value))
}

View file

@ -198,6 +198,92 @@ func TestToolBridge_Good_ValidatesRequestBody(t *testing.T) {
}
}
func TestToolBridge_Good_ValidatesResponseBody(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
bridge := api.NewToolBridge("/tools")
bridge.Add(api.ToolDescriptor{
Name: "file_read",
Description: "Read a file from disk",
Group: "files",
OutputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{"type": "string"},
},
"required": []any{"path"},
},
}, func(c *gin.Context) {
c.JSON(http.StatusOK, api.OK(map[string]any{"path": "/tmp/file.txt"}))
})
rg := engine.Group(bridge.BasePath())
bridge.RegisterRoutes(rg)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", nil)
engine.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp api.Response[map[string]any]
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if !resp.Success {
t.Fatal("expected Success=true")
}
if resp.Data["path"] != "/tmp/file.txt" {
t.Fatalf("expected validated response data to reach client, got %v", resp.Data["path"])
}
}
func TestToolBridge_Bad_InvalidResponseBody(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
bridge := api.NewToolBridge("/tools")
bridge.Add(api.ToolDescriptor{
Name: "file_read",
Description: "Read a file from disk",
Group: "files",
OutputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{"type": "string"},
},
"required": []any{"path"},
},
}, func(c *gin.Context) {
c.JSON(http.StatusOK, api.OK(map[string]any{"path": 123}))
})
rg := engine.Group(bridge.BasePath())
bridge.RegisterRoutes(rg)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", nil)
engine.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", w.Code)
}
var resp api.Response[any]
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if resp.Success {
t.Fatal("expected Success=false")
}
if resp.Error == nil || resp.Error.Code != "invalid_tool_response" {
t.Fatalf("expected invalid_tool_response error, got %#v", resp.Error)
}
}
func TestToolBridge_Bad_InvalidRequestBody(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()