diff --git a/pkg/mcp/notify.go b/pkg/mcp/notify.go index 042df59..6ed9d78 100644 --- a/pkg/mcp/notify.go +++ b/pkg/mcp/notify.go @@ -157,6 +157,14 @@ func (s *Service) SendNotificationToSession(ctx context.Context, session *mcp.Se s.sendLoggingNotificationToSession(ctx, session, level, logger, data) } +// SendNotificationToClient sends a log-level notification to one connected +// MCP client. +// +// s.SendNotificationToClient(ctx, client, "info", "monitor", data) +func (s *Service) SendNotificationToClient(ctx context.Context, client *mcp.ServerSession, level mcp.LoggingLevel, logger string, data any) { + s.SendNotificationToSession(ctx, client, level, logger, data) +} + func (s *Service) sendLoggingNotificationToSession(ctx context.Context, session *mcp.ServerSession, level mcp.LoggingLevel, logger string, data any) { if s == nil || s.server == nil || session == nil { return @@ -206,6 +214,13 @@ func (s *Service) ChannelSendToSession(ctx context.Context, session *mcp.ServerS } } +// ChannelSendToClient pushes a channel event to one connected MCP client. +// +// s.ChannelSendToClient(ctx, client, "agent.progress", progressData) +func (s *Service) ChannelSendToClient(ctx context.Context, client *mcp.ServerSession, channel string, data any) { + s.ChannelSendToSession(ctx, client, channel, data) +} + // Sessions returns an iterator over all connected MCP sessions. // // for session := range s.Sessions() { diff --git a/pkg/mcp/notify_test.go b/pkg/mcp/notify_test.go index d4f0b39..8b7cf3d 100644 --- a/pkg/mcp/notify_test.go +++ b/pkg/mcp/notify_test.go @@ -306,6 +306,102 @@ func TestChannelSendToSession_Good_CustomNotification(t *testing.T) { } } +func TestChannelSendToClient_Good_CustomNotification(t *testing.T) { + svc, err := New(Options{}) + if err != nil { + t.Fatalf("New() failed: %v", err) + } + + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + session, err := svc.server.Connect(ctx, &connTransport{conn: serverConn}, nil) + if err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer session.Close() + + clientConn.SetDeadline(time.Now().Add(5 * time.Second)) + + read := readNotificationMessageUntil(t, clientConn, func(msg map[string]any) bool { + return msg["method"] == channelNotificationMethod + }) + + sent := make(chan struct{}) + go func() { + svc.ChannelSendToClient(ctx, session, ChannelBuildComplete, map[string]any{ + "repo": "go-io", + }) + close(sent) + }() + + select { + case <-sent: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for notification send to complete") + } + + res := <-read + if res.err != nil { + t.Fatalf("failed to read custom notification: %v", res.err) + } + msg := res.msg + if msg["method"] != channelNotificationMethod { + t.Fatalf("expected method %q, got %v", channelNotificationMethod, msg["method"]) + } +} + +func TestSendNotificationToClient_Good_CustomNotification(t *testing.T) { + svc, err := New(Options{}) + if err != nil { + t.Fatalf("New() failed: %v", err) + } + + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + session, err := svc.server.Connect(ctx, &connTransport{conn: serverConn}, nil) + if err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer session.Close() + + clientConn.SetDeadline(time.Now().Add(5 * time.Second)) + + read := readNotificationMessageUntil(t, clientConn, func(msg map[string]any) bool { + return msg["method"] == loggingNotificationMethod + }) + + sent := make(chan struct{}) + go func() { + svc.SendNotificationToClient(ctx, session, "info", "test", map[string]any{ + "event": ChannelBuildComplete, + }) + close(sent) + }() + + select { + case <-sent: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for notification send to complete") + } + + res := <-read + if res.err != nil { + t.Fatalf("failed to read notification: %v", res.err) + } + msg := res.msg + if msg["method"] != loggingNotificationMethod { + t.Fatalf("expected method %q, got %v", loggingNotificationMethod, msg["method"]) + } +} + func TestChannelCapability_Good(t *testing.T) { caps := channelCapability() raw, ok := caps["claude/channel"]