From a0caa6918c16d27fdf08e05983fed39a88ee0aa7 Mon Sep 17 00:00:00 2001 From: Virgil Date: Thu, 2 Apr 2026 13:13:04 +0000 Subject: [PATCH] fix(mcp): make notifications nil-safe Guard the notification broadcast helpers against nil Service and nil server values so they degrade to no-ops instead of panicking. Add coverage for zero-value and nil-server use. Co-Authored-By: Virgil --- pkg/mcp/notify.go | 19 +++++++++++++++++-- pkg/mcp/notify_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/pkg/mcp/notify.go b/pkg/mcp/notify.go index 6340430..69e5ab7 100644 --- a/pkg/mcp/notify.go +++ b/pkg/mcp/notify.go @@ -95,6 +95,9 @@ type ChannelNotification struct { // // s.SendNotificationToAllClients(ctx, "info", "monitor", map[string]any{"event": "build complete"}) func (s *Service) SendNotificationToAllClients(ctx context.Context, level mcp.LoggingLevel, logger string, data any) { + if s == nil || s.server == nil { + return + } for session := range s.server.Sessions() { s.sendLoggingNotificationToSession(ctx, session, level, logger, data) } @@ -105,11 +108,14 @@ func (s *Service) SendNotificationToAllClients(ctx context.Context, level mcp.Lo // // s.SendNotificationToSession(ctx, session, "info", "monitor", data) func (s *Service) SendNotificationToSession(ctx context.Context, session *mcp.ServerSession, level mcp.LoggingLevel, logger string, data any) { + if s == nil || s.server == nil { + return + } s.sendLoggingNotificationToSession(ctx, session, level, logger, data) } func (s *Service) sendLoggingNotificationToSession(ctx context.Context, session *mcp.ServerSession, level mcp.LoggingLevel, logger string, data any) { - if session == nil { + if s == nil || s.server == nil || session == nil { return } @@ -128,6 +134,9 @@ func (s *Service) sendLoggingNotificationToSession(ctx context.Context, session // s.ChannelSend(ctx, "agent.complete", map[string]any{"repo": "go-io", "workspace": "go-io-123"}) // s.ChannelSend(ctx, "build.failed", map[string]any{"repo": "core", "error": "test timeout"}) func (s *Service) ChannelSend(ctx context.Context, channel string, data any) { + if s == nil || s.server == nil { + return + } payload := ChannelNotification{Channel: channel, Data: data} s.sendChannelNotificationToAllClients(ctx, payload) } @@ -136,7 +145,7 @@ func (s *Service) ChannelSend(ctx context.Context, channel string, data any) { // // s.ChannelSendToSession(ctx, session, "agent.progress", progressData) func (s *Service) ChannelSendToSession(ctx context.Context, session *mcp.ServerSession, channel string, data any) { - if session == nil { + if s == nil || s.server == nil || session == nil { return } @@ -152,10 +161,16 @@ func (s *Service) ChannelSendToSession(ctx context.Context, session *mcp.ServerS // s.ChannelSendToSession(ctx, session, "status", data) // } func (s *Service) Sessions() iter.Seq[*mcp.ServerSession] { + if s == nil || s.server == nil { + return func(yield func(*mcp.ServerSession) bool) {} + } return s.server.Sessions() } func (s *Service) sendChannelNotificationToAllClients(ctx context.Context, payload ChannelNotification) { + if s == nil || s.server == nil { + return + } for session := range s.server.Sessions() { if err := sendSessionNotification(ctx, session, channelNotificationMethod, payload); err != nil { s.logger.Debug("channel: failed to send to session", "session", session.ID(), "error", err) diff --git a/pkg/mcp/notify_test.go b/pkg/mcp/notify_test.go index 8fcd345..04430bb 100644 --- a/pkg/mcp/notify_test.go +++ b/pkg/mcp/notify_test.go @@ -22,6 +22,34 @@ func TestSendNotificationToAllClients_Good(t *testing.T) { }) } +func TestNotificationMethods_Good_NilService(t *testing.T) { + var svc *Service + + ctx := context.Background() + svc.SendNotificationToAllClients(ctx, "info", "test", map[string]any{"ok": true}) + svc.SendNotificationToSession(ctx, nil, "info", "test", map[string]any{"ok": true}) + svc.ChannelSend(ctx, ChannelBuildComplete, map[string]any{"ok": true}) + svc.ChannelSendToSession(ctx, nil, ChannelBuildComplete, map[string]any{"ok": true}) + + for range svc.Sessions() { + t.Fatal("expected no sessions from nil service") + } +} + +func TestNotificationMethods_Good_NilServer(t *testing.T) { + svc := &Service{} + + ctx := context.Background() + svc.SendNotificationToAllClients(ctx, "info", "test", map[string]any{"ok": true}) + svc.SendNotificationToSession(ctx, nil, "info", "test", map[string]any{"ok": true}) + svc.ChannelSend(ctx, ChannelBuildComplete, map[string]any{"ok": true}) + svc.ChannelSendToSession(ctx, nil, ChannelBuildComplete, map[string]any{"ok": true}) + + for range svc.Sessions() { + t.Fatal("expected no sessions from service without a server") + } +} + func TestSendNotificationToAllClients_Good_CustomNotification(t *testing.T) { svc, err := New(Options{}) if err != nil {