diff --git a/bridge.go b/bridge.go index 8dd8bc9..7932720 100644 --- a/bridge.go +++ b/bridge.go @@ -3,11 +3,14 @@ package api import ( + "bufio" "bytes" "encoding/json" + "errors" "fmt" "io" "iter" + "net" "net/http" "strconv" @@ -46,6 +49,9 @@ func NewToolBridge(basePath string) *ToolBridge { // Add registers a tool with its HTTP handler. func (b *ToolBridge) Add(desc ToolDescriptor, handler gin.HandlerFunc) { + if validator := newToolInputValidator(desc.OutputSchema); validator != nil { + handler = wrapToolResponseHandler(handler, validator) + } if validator := newToolInputValidator(desc.InputSchema); validator != nil { handler = wrapToolHandler(handler, validator) } @@ -156,6 +162,29 @@ func wrapToolHandler(handler gin.HandlerFunc, validator *toolInputValidator) gin } } +func wrapToolResponseHandler(handler gin.HandlerFunc, validator *toolInputValidator) gin.HandlerFunc { + return func(c *gin.Context) { + recorder := newToolResponseRecorder(c.Writer) + c.Writer = recorder + + handler(c) + + if recorder.Status() >= 200 && recorder.Status() < 300 { + if err := validator.ValidateResponse(recorder.body.Bytes()); err != nil { + recorder.reset() + recorder.writeErrorResponse(http.StatusInternalServerError, FailWithDetails( + "invalid_tool_response", + "Tool response does not match the declared output schema", + map[string]any{"error": err.Error()}, + )) + return + } + } + + recorder.commit() + } +} + type toolInputValidator struct { schema map[string]any } @@ -187,6 +216,41 @@ func (v *toolInputValidator) Validate(body []byte) error { return validateSchemaNode(payload, v.schema, "") } +func (v *toolInputValidator) ValidateResponse(body []byte) error { + if len(v.schema) == 0 { + return nil + } + + var envelope map[string]any + if err := json.Unmarshal(body, &envelope); err != nil { + return fmt.Errorf("invalid JSON response: %w", err) + } + + success, _ := envelope["success"].(bool) + if !success { + return fmt.Errorf("response is missing a successful envelope") + } + + data, ok := envelope["data"] + if !ok { + return fmt.Errorf("response is missing data") + } + + encoded, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("encode response data: %w", err) + } + + var payload any + dec := json.NewDecoder(bytes.NewReader(encoded)) + dec.UseNumber() + if err := dec.Decode(&payload); err != nil { + return fmt.Errorf("decode response data: %w", err) + } + + return validateSchemaNode(payload, v.schema, "") +} + func validateSchemaNode(value any, schema map[string]any, path string) error { if len(schema) == 0 { return nil @@ -267,6 +331,107 @@ func validateSchemaNode(value any, schema map[string]any, path string) error { return nil } +type toolResponseRecorder struct { + gin.ResponseWriter + headers http.Header + body bytes.Buffer + status int + wroteHeader bool +} + +func newToolResponseRecorder(w gin.ResponseWriter) *toolResponseRecorder { + headers := make(http.Header) + for k, vals := range w.Header() { + headers[k] = append([]string(nil), vals...) + } + return &toolResponseRecorder{ + ResponseWriter: w, + headers: headers, + status: http.StatusOK, + } +} + +func (w *toolResponseRecorder) Header() http.Header { + return w.headers +} + +func (w *toolResponseRecorder) WriteHeader(code int) { + w.status = code + w.wroteHeader = true +} + +func (w *toolResponseRecorder) WriteHeaderNow() { + w.wroteHeader = true +} + +func (w *toolResponseRecorder) Write(data []byte) (int, error) { + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + return w.body.Write(data) +} + +func (w *toolResponseRecorder) WriteString(s string) (int, error) { + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + return w.body.WriteString(s) +} + +func (w *toolResponseRecorder) Flush() { +} + +func (w *toolResponseRecorder) Status() int { + if w.wroteHeader { + return w.status + } + return http.StatusOK +} + +func (w *toolResponseRecorder) Size() int { + return w.body.Len() +} + +func (w *toolResponseRecorder) Written() bool { + return w.wroteHeader +} + +func (w *toolResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, errors.New("response hijacking is not supported by ToolBridge output validation") +} + +func (w *toolResponseRecorder) commit() { + for k := range w.ResponseWriter.Header() { + w.ResponseWriter.Header().Del(k) + } + for k, vals := range w.headers { + for _, v := range vals { + w.ResponseWriter.Header().Add(k, v) + } + } + w.ResponseWriter.WriteHeader(w.Status()) + _, _ = w.ResponseWriter.Write(w.body.Bytes()) +} + +func (w *toolResponseRecorder) reset() { + w.headers = make(http.Header) + w.body.Reset() + w.status = http.StatusInternalServerError + w.wroteHeader = false +} + +func (w *toolResponseRecorder) writeErrorResponse(status int, resp Response[any]) { + data, err := json.Marshal(resp) + if err != nil { + http.Error(w.ResponseWriter, "internal server error", http.StatusInternalServerError) + return + } + + w.ResponseWriter.Header().Set("Content-Type", "application/json") + w.ResponseWriter.WriteHeader(status) + _, _ = w.ResponseWriter.Write(data) +} + func typeError(path, want string, value any) error { return fmt.Errorf("%s must be %s, got %s", displayPath(path), want, describeJSONValue(value)) } diff --git a/bridge_test.go b/bridge_test.go index f2eb442..dc607a6 100644 --- a/bridge_test.go +++ b/bridge_test.go @@ -198,6 +198,92 @@ func TestToolBridge_Good_ValidatesRequestBody(t *testing.T) { } } +func TestToolBridge_Good_ValidatesResponseBody(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "file_read", + Description: "Read a file from disk", + Group: "files", + OutputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + "required": []any{"path"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK(map[string]any{"path": "/tmp/file.txt"})) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", nil) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp api.Response[map[string]any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if !resp.Success { + t.Fatal("expected Success=true") + } + if resp.Data["path"] != "/tmp/file.txt" { + t.Fatalf("expected validated response data to reach client, got %v", resp.Data["path"]) + } +} + +func TestToolBridge_Bad_InvalidResponseBody(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "file_read", + Description: "Read a file from disk", + Group: "files", + OutputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + "required": []any{"path"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK(map[string]any{"path": 123})) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", nil) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("expected 500, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false") + } + if resp.Error == nil || resp.Error.Code != "invalid_tool_response" { + t.Fatalf("expected invalid_tool_response error, got %#v", resp.Error) + } +} + func TestToolBridge_Bad_InvalidRequestBody(t *testing.T) { gin.SetMode(gin.TestMode) engine := gin.New()