feat(mcp): fan out bridge observers for brain recall
Allow the IDE bridge to register multiple observers so the IDE and brain subsystems can both react to inbound Laravel messages. Brain recall notifications now fire from the bridge callback with the real result count instead of the request path, and the brain provider follows the same async notification flow. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
e138af6635
commit
981ad9f7da
7 changed files with 205 additions and 18 deletions
|
|
@ -7,6 +7,7 @@ package brain
|
|||
import (
|
||||
"context"
|
||||
|
||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||
"dappco.re/go/mcp/pkg/mcp/ide"
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
|
@ -26,7 +27,13 @@ type Subsystem struct {
|
|||
// New creates a brain subsystem that uses the given IDE bridge for Laravel communication.
|
||||
// Pass nil if headless (tools will return errBridgeNotAvailable).
|
||||
func New(bridge *ide.Bridge) *Subsystem {
|
||||
return &Subsystem{bridge: bridge}
|
||||
s := &Subsystem{bridge: bridge}
|
||||
if bridge != nil {
|
||||
bridge.AddObserver(func(msg ide.BridgeMessage) {
|
||||
s.handleBridgeMessage(msg)
|
||||
})
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Name implements mcp.Subsystem.
|
||||
|
|
@ -47,6 +54,31 @@ func (s *Subsystem) RegisterTools(server *mcp.Server) {
|
|||
s.registerBrainTools(server)
|
||||
}
|
||||
|
||||
func (s *Subsystem) handleBridgeMessage(msg ide.BridgeMessage) {
|
||||
if msg.Type != "brain_recall" {
|
||||
return
|
||||
}
|
||||
|
||||
payload := map[string]any{}
|
||||
if data, ok := msg.Data.(map[string]any); ok {
|
||||
for _, key := range []string{"query", "project", "type", "agent_id"} {
|
||||
if value, ok := data[key]; ok {
|
||||
payload[key] = value
|
||||
}
|
||||
}
|
||||
if count, ok := data["count"]; ok {
|
||||
payload["count"] = count
|
||||
} else if memories, ok := data["memories"].([]any); ok {
|
||||
payload["count"] = len(memories)
|
||||
}
|
||||
}
|
||||
if _, ok := payload["count"]; !ok {
|
||||
payload["count"] = 0
|
||||
}
|
||||
|
||||
s.emitChannel(context.Background(), coremcp.ChannelBrainRecallDone, payload)
|
||||
}
|
||||
|
||||
// Shutdown implements mcp.SubsystemWithShutdown.
|
||||
func (s *Subsystem) Shutdown(_ context.Context) error {
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -7,8 +7,20 @@ import (
|
|||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"dappco.re/go/mcp/pkg/mcp/ide"
|
||||
)
|
||||
|
||||
type recordingNotifier struct {
|
||||
channel string
|
||||
data any
|
||||
}
|
||||
|
||||
func (r *recordingNotifier) ChannelSend(_ context.Context, channel string, data any) {
|
||||
r.channel = channel
|
||||
r.data = data
|
||||
}
|
||||
|
||||
// --- Nil bridge tests (headless mode) ---
|
||||
|
||||
func TestBrainRemember_Bad_NilBridge(t *testing.T) {
|
||||
|
|
@ -68,6 +80,38 @@ func TestSubsystem_Good_ShutdownNoop(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSubsystem_Good_BridgeRecallNotification(t *testing.T) {
|
||||
sub := New(nil)
|
||||
notifier := &recordingNotifier{}
|
||||
sub.notifier = notifier
|
||||
|
||||
sub.handleBridgeMessage(ide.BridgeMessage{
|
||||
Type: "brain_recall",
|
||||
Data: map[string]any{
|
||||
"query": "how does scoring work?",
|
||||
"memories": []any{
|
||||
map[string]any{"id": "m1"},
|
||||
map[string]any{"id": "m2"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if notifier.channel != "brain.recall.complete" {
|
||||
t.Fatalf("expected brain.recall.complete, got %q", notifier.channel)
|
||||
}
|
||||
|
||||
payload, ok := notifier.data.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected payload map, got %T", notifier.data)
|
||||
}
|
||||
if payload["count"] != 2 {
|
||||
t.Fatalf("expected count 2, got %v", payload["count"])
|
||||
}
|
||||
if payload["query"] != "how does scoring work?" {
|
||||
t.Fatalf("expected query to be forwarded, got %v", payload["query"])
|
||||
}
|
||||
}
|
||||
|
||||
// --- Struct round-trip tests ---
|
||||
|
||||
func TestRememberInput_Good_RoundTrip(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -31,10 +31,16 @@ var (
|
|||
// NewProvider creates a brain provider that proxies to Laravel via the IDE bridge.
|
||||
// The WS hub is used to emit brain events. Pass nil for hub if not needed.
|
||||
func NewProvider(bridge *ide.Bridge, hub *ws.Hub) *BrainProvider {
|
||||
return &BrainProvider{
|
||||
p := &BrainProvider{
|
||||
bridge: bridge,
|
||||
hub: hub,
|
||||
}
|
||||
if bridge != nil {
|
||||
bridge.AddObserver(func(msg ide.BridgeMessage) {
|
||||
p.handleBridgeMessage(msg)
|
||||
})
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Name implements api.RouteGroup.
|
||||
|
|
@ -246,10 +252,6 @@ func (p *BrainProvider) recall(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
p.emitEvent(coremcp.ChannelBrainRecallDone, map[string]any{
|
||||
"query": input.Query,
|
||||
})
|
||||
|
||||
c.JSON(http.StatusOK, api.OK(RecallOutput{
|
||||
Success: true,
|
||||
Memories: []Memory{},
|
||||
|
|
@ -348,3 +350,28 @@ func (p *BrainProvider) emitEvent(channel string, data any) {
|
|||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
func (p *BrainProvider) handleBridgeMessage(msg ide.BridgeMessage) {
|
||||
if msg.Type != "brain_recall" {
|
||||
return
|
||||
}
|
||||
|
||||
payload := map[string]any{}
|
||||
if data, ok := msg.Data.(map[string]any); ok {
|
||||
for _, key := range []string{"query", "project", "type", "agent_id"} {
|
||||
if value, ok := data[key]; ok {
|
||||
payload[key] = value
|
||||
}
|
||||
}
|
||||
if count, ok := data["count"]; ok {
|
||||
payload["count"] = count
|
||||
} else if memories, ok := data["memories"].([]any); ok {
|
||||
payload["count"] = len(memories)
|
||||
}
|
||||
}
|
||||
if _, ok := payload["count"]; !ok {
|
||||
payload["count"] = 0
|
||||
}
|
||||
|
||||
p.emitEvent(coremcp.ChannelBrainRecallDone, payload)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -179,11 +179,6 @@ func (s *Subsystem) brainRecall(ctx context.Context, _ *mcp.CallToolRequest, inp
|
|||
return nil, RecallOutput{}, coreerr.E("brain.recall", "failed to send brain_recall", err)
|
||||
}
|
||||
|
||||
s.emitChannel(ctx, coremcp.ChannelBrainRecallDone, map[string]any{
|
||||
"query": input.Query,
|
||||
"count": 0,
|
||||
})
|
||||
|
||||
return nil, RecallOutput{
|
||||
Success: true,
|
||||
Memories: []Memory{},
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ type Bridge struct {
|
|||
mu sync.Mutex
|
||||
connected bool
|
||||
cancel context.CancelFunc
|
||||
onMessage func(BridgeMessage)
|
||||
observers []func(BridgeMessage)
|
||||
}
|
||||
|
||||
// NewBridge creates a bridge that will connect to the Laravel backend and
|
||||
|
|
@ -44,7 +44,22 @@ func NewBridge(hub *ws.Hub, cfg Config) *Bridge {
|
|||
func (b *Bridge) SetObserver(fn func(BridgeMessage)) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
b.onMessage = fn
|
||||
if fn == nil {
|
||||
b.observers = nil
|
||||
return
|
||||
}
|
||||
b.observers = []func(BridgeMessage){fn}
|
||||
}
|
||||
|
||||
// AddObserver registers an additional bridge observer.
|
||||
// Observers are invoked in registration order after each inbound message.
|
||||
func (b *Bridge) AddObserver(fn func(BridgeMessage)) {
|
||||
if fn == nil {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
b.observers = append(b.observers, fn)
|
||||
}
|
||||
|
||||
// Start begins the connection loop in a background goroutine.
|
||||
|
|
@ -169,15 +184,24 @@ func (b *Bridge) readLoop(ctx context.Context) {
|
|||
}
|
||||
|
||||
b.dispatch(msg)
|
||||
b.mu.Lock()
|
||||
observer := b.onMessage
|
||||
b.mu.Unlock()
|
||||
if observer != nil {
|
||||
for _, observer := range b.snapshotObservers() {
|
||||
observer(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bridge) snapshotObservers() []func(BridgeMessage) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if len(b.observers) == 0 {
|
||||
return nil
|
||||
}
|
||||
observers := make([]func(BridgeMessage), len(b.observers))
|
||||
copy(observers, b.observers)
|
||||
return observers
|
||||
}
|
||||
|
||||
// dispatch routes an incoming message to the appropriate ws.Hub channel.
|
||||
func (b *Bridge) dispatch(msg BridgeMessage) {
|
||||
if b.hub == nil {
|
||||
|
|
|
|||
|
|
@ -164,6 +164,71 @@ func TestBridge_Good_MessageDispatch(t *testing.T) {
|
|||
// This confirms the dispatch path ran without error.
|
||||
}
|
||||
|
||||
func TestBridge_Good_MultipleObservers(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := testUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
msg := BridgeMessage{
|
||||
Type: "brain_recall",
|
||||
Data: map[string]any{
|
||||
"query": "test query",
|
||||
"count": 3,
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(msg)
|
||||
_ = conn.WriteMessage(websocket.TextMessage, data)
|
||||
|
||||
for {
|
||||
if _, _, err := conn.ReadMessage(); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
hub := ws.NewHub()
|
||||
ctx := t.Context()
|
||||
go hub.Run(ctx)
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.LaravelWSURL = wsURL(ts)
|
||||
cfg.ReconnectInterval = 100 * time.Millisecond
|
||||
|
||||
bridge := NewBridge(hub, cfg)
|
||||
|
||||
first := make(chan struct{}, 1)
|
||||
second := make(chan struct{}, 1)
|
||||
bridge.AddObserver(func(msg BridgeMessage) {
|
||||
if msg.Type == "brain_recall" {
|
||||
first <- struct{}{}
|
||||
}
|
||||
})
|
||||
bridge.AddObserver(func(msg BridgeMessage) {
|
||||
if msg.Type == "brain_recall" {
|
||||
second <- struct{}{}
|
||||
}
|
||||
})
|
||||
|
||||
bridge.Start(ctx)
|
||||
waitConnected(t, bridge, 2*time.Second)
|
||||
|
||||
select {
|
||||
case <-first:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for first observer")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-second:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for second observer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridge_Good_Reconnect(t *testing.T) {
|
||||
// Use atomic counter to avoid data race between HTTP handler goroutine
|
||||
// and the test goroutine.
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ func New(hub *ws.Hub, cfg Config) *Subsystem {
|
|||
}
|
||||
if hub != nil {
|
||||
s.bridge = NewBridge(hub, cfg)
|
||||
s.bridge.SetObserver(func(msg BridgeMessage) {
|
||||
s.bridge.AddObserver(func(msg BridgeMessage) {
|
||||
s.handleBridgeMessage(msg)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue