diff --git a/pkg/mcp/brain/brain.go b/pkg/mcp/brain/brain.go index f9386cb..b19f993 100644 --- a/pkg/mcp/brain/brain.go +++ b/pkg/mcp/brain/brain.go @@ -7,6 +7,7 @@ package brain import ( "context" + coremcp "dappco.re/go/mcp/pkg/mcp" "dappco.re/go/mcp/pkg/mcp/ide" coreerr "forge.lthn.ai/core/go-log" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -26,7 +27,13 @@ type Subsystem struct { // New creates a brain subsystem that uses the given IDE bridge for Laravel communication. // Pass nil if headless (tools will return errBridgeNotAvailable). func New(bridge *ide.Bridge) *Subsystem { - return &Subsystem{bridge: bridge} + s := &Subsystem{bridge: bridge} + if bridge != nil { + bridge.AddObserver(func(msg ide.BridgeMessage) { + s.handleBridgeMessage(msg) + }) + } + return s } // Name implements mcp.Subsystem. @@ -47,6 +54,31 @@ func (s *Subsystem) RegisterTools(server *mcp.Server) { s.registerBrainTools(server) } +func (s *Subsystem) handleBridgeMessage(msg ide.BridgeMessage) { + if msg.Type != "brain_recall" { + return + } + + payload := map[string]any{} + if data, ok := msg.Data.(map[string]any); ok { + for _, key := range []string{"query", "project", "type", "agent_id"} { + if value, ok := data[key]; ok { + payload[key] = value + } + } + if count, ok := data["count"]; ok { + payload["count"] = count + } else if memories, ok := data["memories"].([]any); ok { + payload["count"] = len(memories) + } + } + if _, ok := payload["count"]; !ok { + payload["count"] = 0 + } + + s.emitChannel(context.Background(), coremcp.ChannelBrainRecallDone, payload) +} + // Shutdown implements mcp.SubsystemWithShutdown. func (s *Subsystem) Shutdown(_ context.Context) error { return nil diff --git a/pkg/mcp/brain/brain_test.go b/pkg/mcp/brain/brain_test.go index bf71cc5..3641dfc 100644 --- a/pkg/mcp/brain/brain_test.go +++ b/pkg/mcp/brain/brain_test.go @@ -7,8 +7,20 @@ import ( "encoding/json" "testing" "time" + + "dappco.re/go/mcp/pkg/mcp/ide" ) +type recordingNotifier struct { + channel string + data any +} + +func (r *recordingNotifier) ChannelSend(_ context.Context, channel string, data any) { + r.channel = channel + r.data = data +} + // --- Nil bridge tests (headless mode) --- func TestBrainRemember_Bad_NilBridge(t *testing.T) { @@ -68,6 +80,38 @@ func TestSubsystem_Good_ShutdownNoop(t *testing.T) { } } +func TestSubsystem_Good_BridgeRecallNotification(t *testing.T) { + sub := New(nil) + notifier := &recordingNotifier{} + sub.notifier = notifier + + sub.handleBridgeMessage(ide.BridgeMessage{ + Type: "brain_recall", + Data: map[string]any{ + "query": "how does scoring work?", + "memories": []any{ + map[string]any{"id": "m1"}, + map[string]any{"id": "m2"}, + }, + }, + }) + + if notifier.channel != "brain.recall.complete" { + t.Fatalf("expected brain.recall.complete, got %q", notifier.channel) + } + + payload, ok := notifier.data.(map[string]any) + if !ok { + t.Fatalf("expected payload map, got %T", notifier.data) + } + if payload["count"] != 2 { + t.Fatalf("expected count 2, got %v", payload["count"]) + } + if payload["query"] != "how does scoring work?" { + t.Fatalf("expected query to be forwarded, got %v", payload["query"]) + } +} + // --- Struct round-trip tests --- func TestRememberInput_Good_RoundTrip(t *testing.T) { diff --git a/pkg/mcp/brain/provider.go b/pkg/mcp/brain/provider.go index c279860..d2a1baa 100644 --- a/pkg/mcp/brain/provider.go +++ b/pkg/mcp/brain/provider.go @@ -31,10 +31,16 @@ var ( // NewProvider creates a brain provider that proxies to Laravel via the IDE bridge. // The WS hub is used to emit brain events. Pass nil for hub if not needed. func NewProvider(bridge *ide.Bridge, hub *ws.Hub) *BrainProvider { - return &BrainProvider{ + p := &BrainProvider{ bridge: bridge, hub: hub, } + if bridge != nil { + bridge.AddObserver(func(msg ide.BridgeMessage) { + p.handleBridgeMessage(msg) + }) + } + return p } // Name implements api.RouteGroup. @@ -246,10 +252,6 @@ func (p *BrainProvider) recall(c *gin.Context) { return } - p.emitEvent(coremcp.ChannelBrainRecallDone, map[string]any{ - "query": input.Query, - }) - c.JSON(http.StatusOK, api.OK(RecallOutput{ Success: true, Memories: []Memory{}, @@ -348,3 +350,28 @@ func (p *BrainProvider) emitEvent(channel string, data any) { Data: data, }) } + +func (p *BrainProvider) handleBridgeMessage(msg ide.BridgeMessage) { + if msg.Type != "brain_recall" { + return + } + + payload := map[string]any{} + if data, ok := msg.Data.(map[string]any); ok { + for _, key := range []string{"query", "project", "type", "agent_id"} { + if value, ok := data[key]; ok { + payload[key] = value + } + } + if count, ok := data["count"]; ok { + payload["count"] = count + } else if memories, ok := data["memories"].([]any); ok { + payload["count"] = len(memories) + } + } + if _, ok := payload["count"]; !ok { + payload["count"] = 0 + } + + p.emitEvent(coremcp.ChannelBrainRecallDone, payload) +} diff --git a/pkg/mcp/brain/tools.go b/pkg/mcp/brain/tools.go index c31cea0..cf54f64 100644 --- a/pkg/mcp/brain/tools.go +++ b/pkg/mcp/brain/tools.go @@ -179,11 +179,6 @@ func (s *Subsystem) brainRecall(ctx context.Context, _ *mcp.CallToolRequest, inp return nil, RecallOutput{}, coreerr.E("brain.recall", "failed to send brain_recall", err) } - s.emitChannel(ctx, coremcp.ChannelBrainRecallDone, map[string]any{ - "query": input.Query, - "count": 0, - }) - return nil, RecallOutput{ Success: true, Memories: []Memory{}, diff --git a/pkg/mcp/ide/bridge.go b/pkg/mcp/ide/bridge.go index 43de426..af91437 100644 --- a/pkg/mcp/ide/bridge.go +++ b/pkg/mcp/ide/bridge.go @@ -31,7 +31,7 @@ type Bridge struct { mu sync.Mutex connected bool cancel context.CancelFunc - onMessage func(BridgeMessage) + observers []func(BridgeMessage) } // NewBridge creates a bridge that will connect to the Laravel backend and @@ -44,7 +44,22 @@ func NewBridge(hub *ws.Hub, cfg Config) *Bridge { func (b *Bridge) SetObserver(fn func(BridgeMessage)) { b.mu.Lock() defer b.mu.Unlock() - b.onMessage = fn + if fn == nil { + b.observers = nil + return + } + b.observers = []func(BridgeMessage){fn} +} + +// AddObserver registers an additional bridge observer. +// Observers are invoked in registration order after each inbound message. +func (b *Bridge) AddObserver(fn func(BridgeMessage)) { + if fn == nil { + return + } + b.mu.Lock() + defer b.mu.Unlock() + b.observers = append(b.observers, fn) } // Start begins the connection loop in a background goroutine. @@ -169,15 +184,24 @@ func (b *Bridge) readLoop(ctx context.Context) { } b.dispatch(msg) - b.mu.Lock() - observer := b.onMessage - b.mu.Unlock() - if observer != nil { + for _, observer := range b.snapshotObservers() { observer(msg) } } } +func (b *Bridge) snapshotObservers() []func(BridgeMessage) { + b.mu.Lock() + defer b.mu.Unlock() + + if len(b.observers) == 0 { + return nil + } + observers := make([]func(BridgeMessage), len(b.observers)) + copy(observers, b.observers) + return observers +} + // dispatch routes an incoming message to the appropriate ws.Hub channel. func (b *Bridge) dispatch(msg BridgeMessage) { if b.hub == nil { diff --git a/pkg/mcp/ide/bridge_test.go b/pkg/mcp/ide/bridge_test.go index ad51959..d732739 100644 --- a/pkg/mcp/ide/bridge_test.go +++ b/pkg/mcp/ide/bridge_test.go @@ -164,6 +164,71 @@ func TestBridge_Good_MessageDispatch(t *testing.T) { // This confirms the dispatch path ran without error. } +func TestBridge_Good_MultipleObservers(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := testUpgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + msg := BridgeMessage{ + Type: "brain_recall", + Data: map[string]any{ + "query": "test query", + "count": 3, + }, + } + data, _ := json.Marshal(msg) + _ = conn.WriteMessage(websocket.TextMessage, data) + + for { + if _, _, err := conn.ReadMessage(); err != nil { + break + } + } + })) + defer ts.Close() + + hub := ws.NewHub() + ctx := t.Context() + go hub.Run(ctx) + + cfg := DefaultConfig() + cfg.LaravelWSURL = wsURL(ts) + cfg.ReconnectInterval = 100 * time.Millisecond + + bridge := NewBridge(hub, cfg) + + first := make(chan struct{}, 1) + second := make(chan struct{}, 1) + bridge.AddObserver(func(msg BridgeMessage) { + if msg.Type == "brain_recall" { + first <- struct{}{} + } + }) + bridge.AddObserver(func(msg BridgeMessage) { + if msg.Type == "brain_recall" { + second <- struct{}{} + } + }) + + bridge.Start(ctx) + waitConnected(t, bridge, 2*time.Second) + + select { + case <-first: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first observer") + } + + select { + case <-second: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for second observer") + } +} + func TestBridge_Good_Reconnect(t *testing.T) { // Use atomic counter to avoid data race between HTTP handler goroutine // and the test goroutine. diff --git a/pkg/mcp/ide/ide.go b/pkg/mcp/ide/ide.go index 376f8cb..5e90099 100644 --- a/pkg/mcp/ide/ide.go +++ b/pkg/mcp/ide/ide.go @@ -51,7 +51,7 @@ func New(hub *ws.Hub, cfg Config) *Subsystem { } if hub != nil { s.bridge = NewBridge(hub, cfg) - s.bridge.SetObserver(func(msg BridgeMessage) { + s.bridge.AddObserver(func(msg BridgeMessage) { s.handleBridgeMessage(msg) }) }