diff --git a/pkg/mcp/bridge.go b/pkg/mcp/bridge.go index f1cb980..2b85313 100644 --- a/pkg/mcp/bridge.go +++ b/pkg/mcp/bridge.go @@ -63,7 +63,7 @@ func BridgeToAPI(svc *Service, bridge *api.ToolBridge) { if err != nil { // Body present + error = likely bad input (malformed JSON). // No body + error = tool execution failure. - if len(body) > 0 && core.Contains(err.Error(), "unmarshal") { + if errors.Is(err, errInvalidRESTInput) { c.JSON(http.StatusBadRequest, api.Fail("invalid_input", "Malformed JSON in request body")) return } diff --git a/pkg/mcp/bridge_test.go b/pkg/mcp/bridge_test.go index 87ad818..ae0a92b 100644 --- a/pkg/mcp/bridge_test.go +++ b/pkg/mcp/bridge_test.go @@ -165,13 +165,8 @@ func TestBridgeToAPI_Bad_InvalidJSON(t *testing.T) { req.Header.Set("Content-Type", "application/json") engine.ServeHTTP(w, req) - if w.Code != http.StatusInternalServerError { - // The handler unmarshals via RESTHandler which returns an error, - // but since it's a JSON parse error it ends up as tool_error. - // Check we get a non-200 with an error envelope. - if w.Code == http.StatusOK { - t.Fatalf("expected non-200 for invalid JSON, got 200") - } + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for invalid JSON, got %d: %s", w.Code, w.Body.String()) } var resp api.Response[any] diff --git a/pkg/mcp/registry.go b/pkg/mcp/registry.go index c84b72d..c6a6baf 100644 --- a/pkg/mcp/registry.go +++ b/pkg/mcp/registry.go @@ -4,6 +4,8 @@ package mcp import ( "context" + "errors" + "fmt" "reflect" "time" @@ -21,6 +23,9 @@ import ( // } type RESTHandler func(ctx context.Context, body []byte) (any, error) +// errInvalidRESTInput marks malformed JSON bodies for the REST bridge. +var errInvalidRESTInput = errors.New("invalid REST input") + // ToolRecord captures metadata about a registered MCP tool. // // for _, rec := range svc.Tools() { @@ -53,9 +58,9 @@ func AddToolRecorded[In, Out any](s *Service, server *mcp.Server, group string, if len(body) > 0 { if r := core.JSONUnmarshal(body, &input); !r.OK { if err, ok := r.Value.(error); ok { - return nil, err + return nil, fmt.Errorf("%w: %v", errInvalidRESTInput, err) } - return nil, core.E("registry.RESTHandler", "failed to unmarshal input", nil) + return nil, fmt.Errorf("%w", errInvalidRESTInput) } } // nil: REST callers have no MCP request context.