fix(mcp): make notifications nil-safe
Guard the notification broadcast helpers against nil Service and nil server values so they degrade to no-ops instead of panicking. Add coverage for zero-value and nil-server use. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
d498b2981a
commit
a0caa6918c
2 changed files with 45 additions and 2 deletions
|
|
@ -95,6 +95,9 @@ type ChannelNotification struct {
|
|||
//
|
||||
// s.SendNotificationToAllClients(ctx, "info", "monitor", map[string]any{"event": "build complete"})
|
||||
func (s *Service) SendNotificationToAllClients(ctx context.Context, level mcp.LoggingLevel, logger string, data any) {
|
||||
if s == nil || s.server == nil {
|
||||
return
|
||||
}
|
||||
for session := range s.server.Sessions() {
|
||||
s.sendLoggingNotificationToSession(ctx, session, level, logger, data)
|
||||
}
|
||||
|
|
@ -105,11 +108,14 @@ func (s *Service) SendNotificationToAllClients(ctx context.Context, level mcp.Lo
|
|||
//
|
||||
// s.SendNotificationToSession(ctx, session, "info", "monitor", data)
|
||||
func (s *Service) SendNotificationToSession(ctx context.Context, session *mcp.ServerSession, level mcp.LoggingLevel, logger string, data any) {
|
||||
if s == nil || s.server == nil {
|
||||
return
|
||||
}
|
||||
s.sendLoggingNotificationToSession(ctx, session, level, logger, data)
|
||||
}
|
||||
|
||||
func (s *Service) sendLoggingNotificationToSession(ctx context.Context, session *mcp.ServerSession, level mcp.LoggingLevel, logger string, data any) {
|
||||
if session == nil {
|
||||
if s == nil || s.server == nil || session == nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -128,6 +134,9 @@ func (s *Service) sendLoggingNotificationToSession(ctx context.Context, session
|
|||
// s.ChannelSend(ctx, "agent.complete", map[string]any{"repo": "go-io", "workspace": "go-io-123"})
|
||||
// s.ChannelSend(ctx, "build.failed", map[string]any{"repo": "core", "error": "test timeout"})
|
||||
func (s *Service) ChannelSend(ctx context.Context, channel string, data any) {
|
||||
if s == nil || s.server == nil {
|
||||
return
|
||||
}
|
||||
payload := ChannelNotification{Channel: channel, Data: data}
|
||||
s.sendChannelNotificationToAllClients(ctx, payload)
|
||||
}
|
||||
|
|
@ -136,7 +145,7 @@ func (s *Service) ChannelSend(ctx context.Context, channel string, data any) {
|
|||
//
|
||||
// s.ChannelSendToSession(ctx, session, "agent.progress", progressData)
|
||||
func (s *Service) ChannelSendToSession(ctx context.Context, session *mcp.ServerSession, channel string, data any) {
|
||||
if session == nil {
|
||||
if s == nil || s.server == nil || session == nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -152,10 +161,16 @@ func (s *Service) ChannelSendToSession(ctx context.Context, session *mcp.ServerS
|
|||
// s.ChannelSendToSession(ctx, session, "status", data)
|
||||
// }
|
||||
func (s *Service) Sessions() iter.Seq[*mcp.ServerSession] {
|
||||
if s == nil || s.server == nil {
|
||||
return func(yield func(*mcp.ServerSession) bool) {}
|
||||
}
|
||||
return s.server.Sessions()
|
||||
}
|
||||
|
||||
func (s *Service) sendChannelNotificationToAllClients(ctx context.Context, payload ChannelNotification) {
|
||||
if s == nil || s.server == nil {
|
||||
return
|
||||
}
|
||||
for session := range s.server.Sessions() {
|
||||
if err := sendSessionNotification(ctx, session, channelNotificationMethod, payload); err != nil {
|
||||
s.logger.Debug("channel: failed to send to session", "session", session.ID(), "error", err)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,34 @@ func TestSendNotificationToAllClients_Good(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestNotificationMethods_Good_NilService(t *testing.T) {
|
||||
var svc *Service
|
||||
|
||||
ctx := context.Background()
|
||||
svc.SendNotificationToAllClients(ctx, "info", "test", map[string]any{"ok": true})
|
||||
svc.SendNotificationToSession(ctx, nil, "info", "test", map[string]any{"ok": true})
|
||||
svc.ChannelSend(ctx, ChannelBuildComplete, map[string]any{"ok": true})
|
||||
svc.ChannelSendToSession(ctx, nil, ChannelBuildComplete, map[string]any{"ok": true})
|
||||
|
||||
for range svc.Sessions() {
|
||||
t.Fatal("expected no sessions from nil service")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotificationMethods_Good_NilServer(t *testing.T) {
|
||||
svc := &Service{}
|
||||
|
||||
ctx := context.Background()
|
||||
svc.SendNotificationToAllClients(ctx, "info", "test", map[string]any{"ok": true})
|
||||
svc.SendNotificationToSession(ctx, nil, "info", "test", map[string]any{"ok": true})
|
||||
svc.ChannelSend(ctx, ChannelBuildComplete, map[string]any{"ok": true})
|
||||
svc.ChannelSendToSession(ctx, nil, ChannelBuildComplete, map[string]any{"ok": true})
|
||||
|
||||
for range svc.Sessions() {
|
||||
t.Fatal("expected no sessions from service without a server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendNotificationToAllClients_Good_CustomNotification(t *testing.T) {
|
||||
svc, err := New(Options{})
|
||||
if err != nil {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue