diff --git a/pkg/mcp/brain/tools.go b/pkg/mcp/brain/tools.go index ced8601..dc0a4d6 100644 --- a/pkg/mcp/brain/tools.go +++ b/pkg/mcp/brain/tools.go @@ -5,6 +5,7 @@ package brain import ( "context" "time" + "unicode/utf8" coreerr "dappco.re/go/log" coremcp "dappco.re/go/mcp/pkg/mcp" @@ -12,6 +13,8 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) +const brainOrgMaxLength = 128 + // emitChannel pushes a brain event through the shared notifier. func (s *Subsystem) emitChannel(ctx context.Context, channel string, data any) { if s.notifier != nil { @@ -128,6 +131,25 @@ type ListOutput struct { Memories []Memory `json:"memories"` } +func validateBrainOrg(org string) error { + if utf8.RuneCountInString(org) > brainOrgMaxLength { + return coreerr.E("brain.validate", "org exceeds maximum length of 128 characters", nil) + } + return nil +} + +func validateRememberInput(input RememberInput) error { + return validateBrainOrg(input.Org) +} + +func validateRecallInput(input RecallInput) error { + return validateBrainOrg(input.Filter.Org) +} + +func validateListInput(input ListInput) error { + return validateBrainOrg(input.Org) +} + // -- Tool registration -------------------------------------------------------- func (s *Subsystem) registerBrainTools(svc *coremcp.Service) { @@ -156,6 +178,9 @@ func (s *Subsystem) registerBrainTools(svc *coremcp.Service) { // -- Tool handlers ------------------------------------------------------------ func (s *Subsystem) brainRemember(ctx context.Context, _ *mcp.CallToolRequest, input RememberInput) (*mcp.CallToolResult, RememberOutput, error) { + if err := validateRememberInput(input); err != nil { + return nil, RememberOutput{}, err + } if s.bridge == nil { return nil, RememberOutput{}, errBridgeNotAvailable } @@ -190,6 +215,9 @@ func (s *Subsystem) brainRemember(ctx context.Context, _ *mcp.CallToolRequest, i } func (s *Subsystem) brainRecall(ctx context.Context, _ *mcp.CallToolRequest, input RecallInput) (*mcp.CallToolResult, RecallOutput, error) { + if err := validateRecallInput(input); err != nil { + return nil, RecallOutput{}, err + } if s.bridge == nil { return nil, RecallOutput{}, errBridgeNotAvailable } @@ -240,6 +268,9 @@ func (s *Subsystem) brainForget(ctx context.Context, _ *mcp.CallToolRequest, inp } func (s *Subsystem) brainList(ctx context.Context, _ *mcp.CallToolRequest, input ListInput) (*mcp.CallToolResult, ListOutput, error) { + if err := validateListInput(input); err != nil { + return nil, ListOutput{}, err + } if s.bridge == nil { return nil, ListOutput{}, errBridgeNotAvailable } diff --git a/pkg/mcp/brain/tools_test.go b/pkg/mcp/brain/tools_test.go new file mode 100644 index 0000000..6dd121b --- /dev/null +++ b/pkg/mcp/brain/tools_test.go @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package brain + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "dappco.re/go/mcp/pkg/mcp/ide" + "dappco.re/go/ws" + "github.com/gorilla/websocket" +) + +var brainToolTestUpgrader = websocket.Upgrader{ + CheckOrigin: func(_ *http.Request) bool { return true }, +} + +func newConnectedBrainToolSubsystem(t *testing.T) (*Subsystem, <-chan ide.BridgeMessage) { + t.Helper() + + messages := make(chan ide.BridgeMessage, 8) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := brainToolTestUpgrader.Upgrade(w, r, nil) + if err != nil { + t.Logf("upgrade error: %v", err) + return + } + defer conn.Close() + + for { + var msg ide.BridgeMessage + if err := conn.ReadJSON(&msg); err != nil { + return + } + messages <- msg + } + })) + + ctx, cancel := context.WithCancel(context.Background()) + hub := ws.NewHub() + go hub.Run(ctx) + + cfg := ide.DefaultConfig() + cfg.LaravelWSURL = "ws" + strings.TrimPrefix(srv.URL, "http") + cfg.ReconnectInterval = 10 * time.Millisecond + cfg.MaxReconnectInterval = 10 * time.Millisecond + + bridge := ide.NewBridge(hub, cfg) + bridge.Start(ctx) + waitBrainToolBridgeConnected(t, bridge) + + t.Cleanup(func() { + bridge.Shutdown() + cancel() + srv.Close() + }) + + return New(bridge), messages +} + +func waitBrainToolBridgeConnected(t *testing.T, bridge *ide.Bridge) { + t.Helper() + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if bridge.Connected() { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("bridge did not connect within timeout") +} + +func readBrainToolBridgeMessage(t *testing.T, messages <-chan ide.BridgeMessage) ide.BridgeMessage { + t.Helper() + + select { + case msg := <-messages: + return msg + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for bridge message") + return ide.BridgeMessage{} + } +} + +func assertBrainOrgValidationError(t *testing.T, err error) { + t.Helper() + + if err == nil { + t.Fatal("expected org validation error") + } + if !strings.Contains(err.Error(), "org exceeds maximum length of 128 characters") { + t.Fatalf("expected org length error, got %v", err) + } +} + +func TestBrainRemember_Good_OrgLengthBoundary(t *testing.T) { + sub, messages := newConnectedBrainToolSubsystem(t) + + for _, tc := range []struct { + name string + org string + }{ + {name: "non_empty", org: "core"}, + {name: "empty", org: ""}, + {name: "boundary", org: strings.Repeat("a", brainOrgMaxLength)}, + } { + t.Run(tc.name, func(t *testing.T) { + _, out, err := sub.brainRemember(context.Background(), nil, RememberInput{ + Content: "test memory", + Type: "observation", + Org: tc.org, + }) + if err != nil { + t.Fatalf("brainRemember failed: %v", err) + } + if !out.Success { + t.Fatal("expected success=true") + } + + msg := readBrainToolBridgeMessage(t, messages) + if msg.Type != "brain_remember" { + t.Fatalf("expected brain_remember message, got %q", msg.Type) + } + data, ok := msg.Data.(map[string]any) + if !ok { + t.Fatalf("expected bridge data map, got %T", msg.Data) + } + if data["org"] != tc.org { + t.Fatalf("expected org %q, got %v", tc.org, data["org"]) + } + }) + } +} + +func TestBrainRemember_Bad_OrgTooLong(t *testing.T) { + sub := New(nil) + + _, _, err := sub.brainRemember(context.Background(), nil, RememberInput{ + Content: "test memory", + Type: "observation", + Org: strings.Repeat("a", brainOrgMaxLength+1), + }) + + assertBrainOrgValidationError(t, err) +} + +func TestBrainOrgValidation_Bad_RecallAndListRejectBeforeBridge(t *testing.T) { + sub := New(nil) + tooLong := strings.Repeat("a", brainOrgMaxLength+1) + + _, _, err := sub.brainRecall(context.Background(), nil, RecallInput{ + Query: "test", + Filter: RecallFilter{Org: tooLong}, + }) + assertBrainOrgValidationError(t, err) + + _, _, err = sub.brainList(context.Background(), nil, ListInput{ + Org: tooLong, + }) + assertBrainOrgValidationError(t, err) +}