// SPDX-License-Identifier: EUPL-1.2 package api import ( "bufio" "bytes" "encoding/json" "errors" "fmt" "io" "iter" "net" "net/http" "strconv" "github.com/gin-gonic/gin" ) // ToolDescriptor describes a tool that can be exposed as a REST endpoint. type ToolDescriptor struct { Name string // Tool name, e.g. "file_read" (becomes POST path segment) Description string // Human-readable description Group string // OpenAPI tag group, e.g. "files" InputSchema map[string]any // JSON Schema for request body OutputSchema map[string]any // JSON Schema for response data (optional) } // ToolBridge converts tool descriptors into REST endpoints and OpenAPI paths. // It implements both RouteGroup and DescribableGroup. type ToolBridge struct { basePath string name string tools []boundTool } type boundTool struct { descriptor ToolDescriptor handler gin.HandlerFunc } // NewToolBridge creates a bridge that mounts tool endpoints at basePath. func NewToolBridge(basePath string) *ToolBridge { return &ToolBridge{ basePath: basePath, name: "tools", } } // Add registers a tool with its HTTP handler. func (b *ToolBridge) Add(desc ToolDescriptor, handler gin.HandlerFunc) { if validator := newToolInputValidator(desc.OutputSchema); validator != nil { handler = wrapToolResponseHandler(handler, validator) } if validator := newToolInputValidator(desc.InputSchema); validator != nil { handler = wrapToolHandler(handler, validator) } b.tools = append(b.tools, boundTool{descriptor: desc, handler: handler}) } // Name returns the bridge identifier. func (b *ToolBridge) Name() string { return b.name } // BasePath returns the URL prefix for all tool endpoints. func (b *ToolBridge) BasePath() string { return b.basePath } // RegisterRoutes mounts POST /{tool_name} for each registered tool. func (b *ToolBridge) RegisterRoutes(rg *gin.RouterGroup) { for _, t := range b.tools { rg.POST("/"+t.descriptor.Name, t.handler) } } // Describe returns OpenAPI route descriptions for all registered tools. 
func (b *ToolBridge) Describe() []RouteDescription {
	descs := make([]RouteDescription, 0, len(b.tools))
	for _, t := range b.tools {
		descs = append(descs, b.describeTool(t.descriptor))
	}
	return descs
}

// DescribeIter returns an iterator over OpenAPI route descriptions for all
// registered tools. It yields the same values Describe returns, in
// registration order, and stops as soon as the consumer stops.
func (b *ToolBridge) DescribeIter() iter.Seq[RouteDescription] {
	return func(yield func(RouteDescription) bool) {
		for _, t := range b.tools {
			if !yield(b.describeTool(t.descriptor)) {
				return
			}
		}
	}
}

// describeTool builds the OpenAPI route description for one tool.
//
// Describe and DescribeIter both go through this helper so they cannot
// disagree. A tool with an empty Group is tagged with the bridge name.
// The Description feeds both Summary and Description.
func (b *ToolBridge) describeTool(desc ToolDescriptor) RouteDescription {
	tags := []string{desc.Group}
	if desc.Group == "" {
		tags = []string{b.name}
	}
	return RouteDescription{
		Method:      "POST",
		Path:        "/" + desc.Name,
		Summary:     desc.Description,
		Description: desc.Description,
		Tags:        tags,
		RequestBody: desc.InputSchema,
		Response:    desc.OutputSchema,
	}
}

// Tools returns a copy of every registered tool descriptor, in registration
// order. The schema maps inside each descriptor are shared, not deep-copied.
func (b *ToolBridge) Tools() []ToolDescriptor {
	descs := make([]ToolDescriptor, len(b.tools))
	for i, t := range b.tools {
		descs[i] = t.descriptor
	}
	return descs
}

// ToolsIter returns an iterator over all registered tool descriptors.
func (b *ToolBridge) ToolsIter() iter.Seq[ToolDescriptor] { return func(yield func(ToolDescriptor) bool) { for _, t := range b.tools { if !yield(t.descriptor) { return } } } } func wrapToolHandler(handler gin.HandlerFunc, validator *toolInputValidator) gin.HandlerFunc { return func(c *gin.Context) { body, err := io.ReadAll(c.Request.Body) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, FailWithDetails( "invalid_request_body", "Unable to read request body", map[string]any{"error": err.Error()}, )) return } if err := validator.Validate(body); err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, FailWithDetails( "invalid_request_body", "Request body does not match the declared tool schema", map[string]any{"error": err.Error()}, )) return } c.Request.Body = io.NopCloser(bytes.NewReader(body)) handler(c) } } func wrapToolResponseHandler(handler gin.HandlerFunc, validator *toolInputValidator) gin.HandlerFunc { return func(c *gin.Context) { recorder := newToolResponseRecorder(c.Writer) c.Writer = recorder handler(c) if recorder.Status() >= 200 && recorder.Status() < 300 { if err := validator.ValidateResponse(recorder.body.Bytes()); err != nil { recorder.reset() recorder.writeErrorResponse(http.StatusInternalServerError, FailWithDetails( "invalid_tool_response", "Tool response does not match the declared output schema", map[string]any{"error": err.Error()}, )) return } } recorder.commit() } } type toolInputValidator struct { schema map[string]any } func newToolInputValidator(schema map[string]any) *toolInputValidator { if len(schema) == 0 { return nil } return &toolInputValidator{schema: schema} } func (v *toolInputValidator) Validate(body []byte) error { if len(bytes.TrimSpace(body)) == 0 { return fmt.Errorf("request body is required") } dec := json.NewDecoder(bytes.NewReader(body)) dec.UseNumber() var payload any if err := dec.Decode(&payload); err != nil { return fmt.Errorf("invalid JSON: %w", err) } var extra any if err := dec.Decode(&extra); err != 
io.EOF { return fmt.Errorf("request body must contain a single JSON value") } return validateSchemaNode(payload, v.schema, "") } func (v *toolInputValidator) ValidateResponse(body []byte) error { if len(v.schema) == 0 { return nil } var envelope map[string]any if err := json.Unmarshal(body, &envelope); err != nil { return fmt.Errorf("invalid JSON response: %w", err) } success, _ := envelope["success"].(bool) if !success { return fmt.Errorf("response is missing a successful envelope") } data, ok := envelope["data"] if !ok { return fmt.Errorf("response is missing data") } encoded, err := json.Marshal(data) if err != nil { return fmt.Errorf("encode response data: %w", err) } var payload any dec := json.NewDecoder(bytes.NewReader(encoded)) dec.UseNumber() if err := dec.Decode(&payload); err != nil { return fmt.Errorf("decode response data: %w", err) } return validateSchemaNode(payload, v.schema, "") } func validateSchemaNode(value any, schema map[string]any, path string) error { if len(schema) == 0 { return nil } if schemaType, _ := schema["type"].(string); schemaType != "" { switch schemaType { case "object": obj, ok := value.(map[string]any) if !ok { return typeError(path, "object", value) } for _, name := range stringList(schema["required"]) { if _, ok := obj[name]; !ok { return fmt.Errorf("%s is missing required field %q", displayPath(path), name) } } for name, rawChild := range schemaMap(schema["properties"]) { childSchema, ok := rawChild.(map[string]any) if !ok { continue } childValue, ok := obj[name] if !ok { continue } if err := validateSchemaNode(childValue, childSchema, joinPath(path, name)); err != nil { return err } } return nil case "array": arr, ok := value.([]any) if !ok { return typeError(path, "array", value) } if items := schemaMap(schema["items"]); len(items) > 0 { for i, item := range arr { if err := validateSchemaNode(item, items, joinPath(path, strconv.Itoa(i))); err != nil { return err } } } return nil case "string": if _, ok := value.(string); 
!ok { return typeError(path, "string", value) } return nil case "boolean": if _, ok := value.(bool); !ok { return typeError(path, "boolean", value) } return nil case "integer": if !isIntegerValue(value) { return typeError(path, "integer", value) } return nil case "number": if !isNumberValue(value) { return typeError(path, "number", value) } return nil } } if props := schemaMap(schema["properties"]); len(props) > 0 { return validateSchemaNode(value, map[string]any{ "type": "object", "properties": props, "required": schema["required"], }, path) } return nil } type toolResponseRecorder struct { gin.ResponseWriter headers http.Header body bytes.Buffer status int wroteHeader bool } func newToolResponseRecorder(w gin.ResponseWriter) *toolResponseRecorder { headers := make(http.Header) for k, vals := range w.Header() { headers[k] = append([]string(nil), vals...) } return &toolResponseRecorder{ ResponseWriter: w, headers: headers, status: http.StatusOK, } } func (w *toolResponseRecorder) Header() http.Header { return w.headers } func (w *toolResponseRecorder) WriteHeader(code int) { w.status = code w.wroteHeader = true } func (w *toolResponseRecorder) WriteHeaderNow() { w.wroteHeader = true } func (w *toolResponseRecorder) Write(data []byte) (int, error) { if !w.wroteHeader { w.WriteHeader(http.StatusOK) } return w.body.Write(data) } func (w *toolResponseRecorder) WriteString(s string) (int, error) { if !w.wroteHeader { w.WriteHeader(http.StatusOK) } return w.body.WriteString(s) } func (w *toolResponseRecorder) Flush() { } func (w *toolResponseRecorder) Status() int { if w.wroteHeader { return w.status } return http.StatusOK } func (w *toolResponseRecorder) Size() int { return w.body.Len() } func (w *toolResponseRecorder) Written() bool { return w.wroteHeader } func (w *toolResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { return nil, nil, errors.New("response hijacking is not supported by ToolBridge output validation") } func (w *toolResponseRecorder) 
commit() { for k := range w.ResponseWriter.Header() { w.ResponseWriter.Header().Del(k) } for k, vals := range w.headers { for _, v := range vals { w.ResponseWriter.Header().Add(k, v) } } w.ResponseWriter.WriteHeader(w.Status()) _, _ = w.ResponseWriter.Write(w.body.Bytes()) } func (w *toolResponseRecorder) reset() { w.headers = make(http.Header) w.body.Reset() w.status = http.StatusInternalServerError w.wroteHeader = false } func (w *toolResponseRecorder) writeErrorResponse(status int, resp Response[any]) { data, err := json.Marshal(resp) if err != nil { http.Error(w.ResponseWriter, "internal server error", http.StatusInternalServerError) return } w.ResponseWriter.Header().Set("Content-Type", "application/json") w.ResponseWriter.WriteHeader(status) _, _ = w.ResponseWriter.Write(data) } func typeError(path, want string, value any) error { return fmt.Errorf("%s must be %s, got %s", displayPath(path), want, describeJSONValue(value)) } func displayPath(path string) string { if path == "" { return "request body" } return "request body." + path } func joinPath(parent, child string) string { if parent == "" { return child } return parent + "." + child } func schemaMap(value any) map[string]any { if value == nil { return nil } m, _ := value.(map[string]any) return m } func stringList(value any) []string { switch raw := value.(type) { case []any: out := make([]string, 0, len(raw)) for _, item := range raw { name, ok := item.(string) if !ok { continue } out = append(out, name) } return out case []string: return append([]string(nil), raw...) 
default: return nil } } func isIntegerValue(value any) bool { switch v := value.(type) { case json.Number: _, err := v.Int64() return err == nil case float64: return v == float64(int64(v)) default: return false } } func isNumberValue(value any) bool { switch value.(type) { case json.Number, float64: return true default: return false } } func describeJSONValue(value any) string { switch value.(type) { case nil: return "null" case string: return "string" case bool: return "boolean" case json.Number, float64: return "number" case map[string]any: return "object" case []any: return "array" default: return fmt.Sprintf("%T", value) } }