feat(api): validate ToolBridge output schemas
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
9aa7c644ef
commit
5d5ca8aa51
2 changed files with 251 additions and 0 deletions
165
bridge.go
165
bridge.go
|
|
@ -3,11 +3,14 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
|
|
@ -46,6 +49,9 @@ func NewToolBridge(basePath string) *ToolBridge {
|
|||
|
||||
// Add registers a tool with its HTTP handler.
|
||||
func (b *ToolBridge) Add(desc ToolDescriptor, handler gin.HandlerFunc) {
|
||||
if validator := newToolInputValidator(desc.OutputSchema); validator != nil {
|
||||
handler = wrapToolResponseHandler(handler, validator)
|
||||
}
|
||||
if validator := newToolInputValidator(desc.InputSchema); validator != nil {
|
||||
handler = wrapToolHandler(handler, validator)
|
||||
}
|
||||
|
|
@ -156,6 +162,29 @@ func wrapToolHandler(handler gin.HandlerFunc, validator *toolInputValidator) gin
|
|||
}
|
||||
}
|
||||
|
||||
func wrapToolResponseHandler(handler gin.HandlerFunc, validator *toolInputValidator) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
recorder := newToolResponseRecorder(c.Writer)
|
||||
c.Writer = recorder
|
||||
|
||||
handler(c)
|
||||
|
||||
if recorder.Status() >= 200 && recorder.Status() < 300 {
|
||||
if err := validator.ValidateResponse(recorder.body.Bytes()); err != nil {
|
||||
recorder.reset()
|
||||
recorder.writeErrorResponse(http.StatusInternalServerError, FailWithDetails(
|
||||
"invalid_tool_response",
|
||||
"Tool response does not match the declared output schema",
|
||||
map[string]any{"error": err.Error()},
|
||||
))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
recorder.commit()
|
||||
}
|
||||
}
|
||||
|
||||
type toolInputValidator struct {
|
||||
schema map[string]any
|
||||
}
|
||||
|
|
@ -187,6 +216,41 @@ func (v *toolInputValidator) Validate(body []byte) error {
|
|||
return validateSchemaNode(payload, v.schema, "")
|
||||
}
|
||||
|
||||
func (v *toolInputValidator) ValidateResponse(body []byte) error {
|
||||
if len(v.schema) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var envelope map[string]any
|
||||
if err := json.Unmarshal(body, &envelope); err != nil {
|
||||
return fmt.Errorf("invalid JSON response: %w", err)
|
||||
}
|
||||
|
||||
success, _ := envelope["success"].(bool)
|
||||
if !success {
|
||||
return fmt.Errorf("response is missing a successful envelope")
|
||||
}
|
||||
|
||||
data, ok := envelope["data"]
|
||||
if !ok {
|
||||
return fmt.Errorf("response is missing data")
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode response data: %w", err)
|
||||
}
|
||||
|
||||
var payload any
|
||||
dec := json.NewDecoder(bytes.NewReader(encoded))
|
||||
dec.UseNumber()
|
||||
if err := dec.Decode(&payload); err != nil {
|
||||
return fmt.Errorf("decode response data: %w", err)
|
||||
}
|
||||
|
||||
return validateSchemaNode(payload, v.schema, "")
|
||||
}
|
||||
|
||||
func validateSchemaNode(value any, schema map[string]any, path string) error {
|
||||
if len(schema) == 0 {
|
||||
return nil
|
||||
|
|
@ -267,6 +331,107 @@ func validateSchemaNode(value any, schema map[string]any, path string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
type toolResponseRecorder struct {
|
||||
gin.ResponseWriter
|
||||
headers http.Header
|
||||
body bytes.Buffer
|
||||
status int
|
||||
wroteHeader bool
|
||||
}
|
||||
|
||||
func newToolResponseRecorder(w gin.ResponseWriter) *toolResponseRecorder {
|
||||
headers := make(http.Header)
|
||||
for k, vals := range w.Header() {
|
||||
headers[k] = append([]string(nil), vals...)
|
||||
}
|
||||
return &toolResponseRecorder{
|
||||
ResponseWriter: w,
|
||||
headers: headers,
|
||||
status: http.StatusOK,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *toolResponseRecorder) Header() http.Header {
|
||||
return w.headers
|
||||
}
|
||||
|
||||
func (w *toolResponseRecorder) WriteHeader(code int) {
|
||||
w.status = code
|
||||
w.wroteHeader = true
|
||||
}
|
||||
|
||||
func (w *toolResponseRecorder) WriteHeaderNow() {
|
||||
w.wroteHeader = true
|
||||
}
|
||||
|
||||
func (w *toolResponseRecorder) Write(data []byte) (int, error) {
|
||||
if !w.wroteHeader {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return w.body.Write(data)
|
||||
}
|
||||
|
||||
func (w *toolResponseRecorder) WriteString(s string) (int, error) {
|
||||
if !w.wroteHeader {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return w.body.WriteString(s)
|
||||
}
|
||||
|
||||
func (w *toolResponseRecorder) Flush() {
|
||||
}
|
||||
|
||||
func (w *toolResponseRecorder) Status() int {
|
||||
if w.wroteHeader {
|
||||
return w.status
|
||||
}
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (w *toolResponseRecorder) Size() int {
|
||||
return w.body.Len()
|
||||
}
|
||||
|
||||
func (w *toolResponseRecorder) Written() bool {
|
||||
return w.wroteHeader
|
||||
}
|
||||
|
||||
func (w *toolResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return nil, nil, errors.New("response hijacking is not supported by ToolBridge output validation")
|
||||
}
|
||||
|
||||
func (w *toolResponseRecorder) commit() {
|
||||
for k := range w.ResponseWriter.Header() {
|
||||
w.ResponseWriter.Header().Del(k)
|
||||
}
|
||||
for k, vals := range w.headers {
|
||||
for _, v := range vals {
|
||||
w.ResponseWriter.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.ResponseWriter.WriteHeader(w.Status())
|
||||
_, _ = w.ResponseWriter.Write(w.body.Bytes())
|
||||
}
|
||||
|
||||
func (w *toolResponseRecorder) reset() {
|
||||
w.headers = make(http.Header)
|
||||
w.body.Reset()
|
||||
w.status = http.StatusInternalServerError
|
||||
w.wroteHeader = false
|
||||
}
|
||||
|
||||
func (w *toolResponseRecorder) writeErrorResponse(status int, resp Response[any]) {
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
http.Error(w.ResponseWriter, "internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
w.ResponseWriter.WriteHeader(status)
|
||||
_, _ = w.ResponseWriter.Write(data)
|
||||
}
|
||||
|
||||
func typeError(path, want string, value any) error {
|
||||
return fmt.Errorf("%s must be %s, got %s", displayPath(path), want, describeJSONValue(value))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -198,6 +198,92 @@ func TestToolBridge_Good_ValidatesRequestBody(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestToolBridge_Good_ValidatesResponseBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
engine := gin.New()
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
bridge.Add(api.ToolDescriptor{
|
||||
Name: "file_read",
|
||||
Description: "Read a file from disk",
|
||||
Group: "files",
|
||||
OutputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"path"},
|
||||
},
|
||||
}, func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.OK(map[string]any{"path": "/tmp/file.txt"}))
|
||||
})
|
||||
|
||||
rg := engine.Group(bridge.BasePath())
|
||||
bridge.RegisterRoutes(rg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", nil)
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.Response[map[string]any]
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
if !resp.Success {
|
||||
t.Fatal("expected Success=true")
|
||||
}
|
||||
if resp.Data["path"] != "/tmp/file.txt" {
|
||||
t.Fatalf("expected validated response data to reach client, got %v", resp.Data["path"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolBridge_Bad_InvalidResponseBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
engine := gin.New()
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
bridge.Add(api.ToolDescriptor{
|
||||
Name: "file_read",
|
||||
Description: "Read a file from disk",
|
||||
Group: "files",
|
||||
OutputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"path"},
|
||||
},
|
||||
}, func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.OK(map[string]any{"path": 123}))
|
||||
})
|
||||
|
||||
rg := engine.Group(bridge.BasePath())
|
||||
bridge.RegisterRoutes(rg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", nil)
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.Response[any]
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
if resp.Success {
|
||||
t.Fatal("expected Success=false")
|
||||
}
|
||||
if resp.Error == nil || resp.Error.Code != "invalid_tool_response" {
|
||||
t.Fatalf("expected invalid_tool_response error, got %#v", resp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolBridge_Bad_InvalidRequestBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
engine := gin.New()
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue