diff --git a/pkg/mcp/notify.go b/pkg/mcp/notify.go index aba9afa..805cf68 100644 --- a/pkg/mcp/notify.go +++ b/pkg/mcp/notify.go @@ -39,6 +39,7 @@ func (lw *lockedWriter) Close() error { return nil } var sharedStdout = &lockedWriter{w: os.Stdout} const channelNotificationMethod = "notifications/claude/channel" +const loggingNotificationMethod = "notifications/message" // ChannelNotification is the payload sent through the experimental channel // notification method. @@ -54,7 +55,7 @@ type ChannelNotification struct { // s.SendNotificationToAllClients(ctx, "info", "monitor", map[string]any{"event": "build complete"}) func (s *Service) SendNotificationToAllClients(ctx context.Context, level mcp.LoggingLevel, logger string, data any) { for session := range s.server.Sessions() { - s.SendNotificationToSession(ctx, session, level, logger, data) + s.sendLoggingNotificationToSession(ctx, session, level, logger, data) } } @@ -63,11 +64,15 @@ func (s *Service) SendNotificationToAllClients(ctx context.Context, level mcp.Lo // // s.SendNotificationToSession(ctx, session, "info", "monitor", data) func (s *Service) SendNotificationToSession(ctx context.Context, session *mcp.ServerSession, level mcp.LoggingLevel, logger string, data any) { + s.sendLoggingNotificationToSession(ctx, session, level, logger, data) +} + +func (s *Service) sendLoggingNotificationToSession(ctx context.Context, session *mcp.ServerSession, level mcp.LoggingLevel, logger string, data any) { if session == nil { return } - if err := session.Log(ctx, &mcp.LoggingMessageParams{ + if err := sendSessionNotification(ctx, session, loggingNotificationMethod, &mcp.LoggingMessageParams{ Level: level, Logger: logger, Data: data, diff --git a/pkg/mcp/notify_test.go b/pkg/mcp/notify_test.go index def8148..03ab35f 100644 --- a/pkg/mcp/notify_test.go +++ b/pkg/mcp/notify_test.go @@ -21,6 +21,73 @@ func TestSendNotificationToAllClients_Good(t *testing.T) { }) } +func TestSendNotificationToAllClients_Good_CustomNotification(t *testing.T) { + svc, err := New(Options{}) + if err != nil { + t.Fatalf("New() failed: %v", err) + } + + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + session, err := svc.server.Connect(ctx, &connTransport{conn: serverConn}, nil) + if err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer session.Close() + + clientConn.SetDeadline(time.Now().Add(5 * time.Second)) + scanner := bufio.NewScanner(clientConn) + scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) + + sent := make(chan struct{}) + go func() { + svc.SendNotificationToAllClients(ctx, "info", "test", map[string]any{ + "event": "build.complete", + }) + close(sent) + }() + + if !scanner.Scan() { + t.Fatalf("failed to read notification: %v", scanner.Err()) + } + + select { + case <-sent: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for notification send to complete") + } + + var msg map[string]any + if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil { + t.Fatalf("failed to unmarshal notification: %v", err) + } + if msg["method"] != loggingNotificationMethod { + t.Fatalf("expected method %q, got %v", loggingNotificationMethod, msg["method"]) + } + + params, ok := msg["params"].(map[string]any) + if !ok { + t.Fatalf("expected params object, got %T", msg["params"]) + } + if params["logger"] != "test" { + t.Fatalf("expected logger test, got %v", params["logger"]) + } + if params["level"] != "info" { + t.Fatalf("expected level info, got %v", params["level"]) + } + data, ok := params["data"].(map[string]any) + if !ok { + t.Fatalf("expected data object, got %T", params["data"]) + } + if data["event"] != "build.complete" { + t.Fatalf("expected event build.complete, got %v", data["event"]) + } +} + func TestChannelSend_Good(t *testing.T) { svc, err := New(Options{}) if err != nil {