From cd60d9030ccd1702aeb1a889df9f66f5df496742 Mon Sep 17 00:00:00 2001 From: Virgil Date: Thu, 2 Apr 2026 13:19:24 +0000 Subject: [PATCH] fix(mcp): normalize tcp defaults and notification context Make notification broadcasting tolerant of a nil context and make TCP transport fall back to DefaultTCPAddr when no address is supplied. Co-Authored-By: Virgil --- pkg/mcp/notify.go | 15 ++++++++++++++- pkg/mcp/notify_test.go | 12 ++++++++++++ pkg/mcp/transport_tcp.go | 24 ++++++++++++++++++++---- pkg/mcp/transport_tcp_test.go | 20 ++++++++++++++++++++ 4 files changed, 66 insertions(+), 5 deletions(-) diff --git a/pkg/mcp/notify.go b/pkg/mcp/notify.go index 69e5ab7..54121b1 100644 --- a/pkg/mcp/notify.go +++ b/pkg/mcp/notify.go @@ -19,6 +19,13 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) +func normalizeNotificationContext(ctx context.Context) context.Context { + if ctx == nil { + return context.Background() + } + return ctx +} + // lockedWriter wraps an io.Writer with a mutex. // Both the SDK's transport and ChannelSend use this writer, // ensuring channel notifications don't interleave with SDK messages. @@ -98,6 +105,7 @@ func (s *Service) SendNotificationToAllClients(ctx context.Context, level mcp.Lo if s == nil || s.server == nil { return } + ctx = normalizeNotificationContext(ctx) for session := range s.server.Sessions() { s.sendLoggingNotificationToSession(ctx, session, level, logger, data) } @@ -111,6 +119,7 @@ func (s *Service) SendNotificationToSession(ctx context.Context, session *mcp.Se if s == nil || s.server == nil { return } + ctx = normalizeNotificationContext(ctx) s.sendLoggingNotificationToSession(ctx, session, level, logger, data) } @@ -118,6 +127,7 @@ func (s *Service) sendLoggingNotificationToSession(ctx context.Context, session if s == nil || s.server == nil || session == nil { return } + ctx = normalizeNotificationContext(ctx) if err := sendSessionNotification(ctx, session, loggingNotificationMethod, &mcp.LoggingMessageParams{ Level: level, @@ -137,6 +147,7 @@ func (s *Service) ChannelSend(ctx context.Context, channel string, data any) { if s == nil || s.server == nil { return } + ctx = normalizeNotificationContext(ctx) payload := ChannelNotification{Channel: channel, Data: data} s.sendChannelNotificationToAllClients(ctx, payload) } @@ -148,7 +159,7 @@ func (s *Service) ChannelSendToSession(ctx context.Context, session *mcp.ServerS if s == nil || s.server == nil || session == nil { return } - + ctx = normalizeNotificationContext(ctx) payload := ChannelNotification{Channel: channel, Data: data} if err := sendSessionNotification(ctx, session, channelNotificationMethod, payload); err != nil { s.logger.Debug("channel: failed to send to session", "session", session.ID(), "error", err) @@ -171,6 +182,7 @@ func (s *Service) sendChannelNotificationToAllClients(ctx context.Context, paylo if s == nil || s.server == nil { return } + ctx = normalizeNotificationContext(ctx) for session := range s.server.Sessions() { if err := sendSessionNotification(ctx, session, channelNotificationMethod, payload); err != nil { s.logger.Debug("channel: failed to send to session", "session", session.ID(), "error", err) @@ -182,6 +194,7 @@ func sendSessionNotification(ctx context.Context, session *mcp.ServerSession, me if session == nil { return nil } + ctx = normalizeNotificationContext(ctx) conn, err := sessionConnection(session) if err != nil { diff --git a/pkg/mcp/notify_test.go b/pkg/mcp/notify_test.go index 04430bb..3a556af 100644 --- a/pkg/mcp/notify_test.go +++ b/pkg/mcp/notify_test.go @@ -50,6 +50,18 @@ func TestNotificationMethods_Good_NilServer(t *testing.T) { } } +func TestNotificationMethods_Good_NilContext(t *testing.T) { + svc, err := New(Options{}) + if err != nil { + t.Fatalf("New() failed: %v", err) + } + + svc.SendNotificationToAllClients(nil, "info", "test", map[string]any{"ok": true}) + svc.SendNotificationToSession(nil, nil, "info", "test", map[string]any{"ok": true}) + svc.ChannelSend(nil, ChannelBuildComplete, map[string]any{"ok": true}) + svc.ChannelSendToSession(nil, nil, ChannelBuildComplete, map[string]any{"ok": true}) +} + func TestSendNotificationToAllClients_Good_CustomNotification(t *testing.T) { svc, err := New(Options{}) if err != nil { diff --git a/pkg/mcp/transport_tcp.go b/pkg/mcp/transport_tcp.go index ca1a33e..24cdbfe 100644 --- a/pkg/mcp/transport_tcp.go +++ b/pkg/mcp/transport_tcp.go @@ -55,11 +55,14 @@ type TCPTransport struct { // NewTCPTransport creates a new TCP transport listener. // Defaults to 127.0.0.1 when the host component is empty (e.g. ":9100"). +// Defaults to DefaultTCPAddr when addr is empty. // Emits a security warning when explicitly binding to 0.0.0.0 (all interfaces). // // t, err := NewTCPTransport("127.0.0.1:9100") // t, err := NewTCPTransport(":9100") // defaults to 127.0.0.1:9100 func NewTCPTransport(addr string) (*TCPTransport, error) { + addr = normalizeTCPAddr(addr) + host, port, _ := net.SplitHostPort(addr) if host == "" { addr = net.JoinHostPort("127.0.0.1", port) @@ -73,6 +76,23 @@ func NewTCPTransport(addr string) (*TCPTransport, error) { return &TCPTransport{addr: addr, listener: listener}, nil } +func normalizeTCPAddr(addr string) string { + if addr == "" { + return DefaultTCPAddr + } + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return addr + } + + if host == "" { + return net.JoinHostPort("127.0.0.1", port) + } + + return addr +} + // ServeTCP starts a TCP server for the MCP service. // It accepts connections and spawns a new MCP server session for each connection. // @@ -91,10 +111,6 @@ func (s *Service) ServeTCP(ctx context.Context, addr string) error { <-ctx.Done() _ = t.listener.Close() }() - - if addr == "" { - addr = t.listener.Addr().String() - } diagPrintf("MCP TCP server listening on %s\n", t.listener.Addr().String()) for { diff --git a/pkg/mcp/transport_tcp_test.go b/pkg/mcp/transport_tcp_test.go index 109a617..85e7379 100644 --- a/pkg/mcp/transport_tcp_test.go +++ b/pkg/mcp/transport_tcp_test.go @@ -31,6 +31,26 @@ func TestNewTCPTransport_Defaults(t *testing.T) { } } +func TestNormalizeTCPAddr_Good_Defaults(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "empty", in: "", want: DefaultTCPAddr}, + {name: "missing host", in: ":9100", want: "127.0.0.1:9100"}, + {name: "explicit host", in: "127.0.0.1:9100", want: "127.0.0.1:9100"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := normalizeTCPAddr(tt.in); got != tt.want { + t.Fatalf("normalizeTCPAddr(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + func TestNewTCPTransport_Warning(t *testing.T) { // Capture warning output via setDiagWriter (mutex-protected, no race). var buf bytes.Buffer