fix(mcp): normalize tcp defaults and notification context

Make notification broadcasting tolerant of a nil context and make TCP transport fall back to DefaultTCPAddr when no address is supplied.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-02 13:19:24 +00:00
parent a0caa6918c
commit cd60d9030c
4 changed files with 66 additions and 5 deletions

View file

@ -19,6 +19,13 @@ import (
"github.com/modelcontextprotocol/go-sdk/mcp"
)
func normalizeNotificationContext(ctx context.Context) context.Context {
if ctx == nil {
return context.Background()
}
return ctx
}
// lockedWriter wraps an io.Writer with a mutex.
// Both the SDK's transport and ChannelSend use this writer,
// ensuring channel notifications don't interleave with SDK messages.
@ -98,6 +105,7 @@ func (s *Service) SendNotificationToAllClients(ctx context.Context, level mcp.Lo
if s == nil || s.server == nil {
return
}
ctx = normalizeNotificationContext(ctx)
for session := range s.server.Sessions() {
s.sendLoggingNotificationToSession(ctx, session, level, logger, data)
}
@ -111,6 +119,7 @@ func (s *Service) SendNotificationToSession(ctx context.Context, session *mcp.Se
if s == nil || s.server == nil {
return
}
ctx = normalizeNotificationContext(ctx)
s.sendLoggingNotificationToSession(ctx, session, level, logger, data)
}
@ -118,6 +127,7 @@ func (s *Service) sendLoggingNotificationToSession(ctx context.Context, session
if s == nil || s.server == nil || session == nil {
return
}
ctx = normalizeNotificationContext(ctx)
if err := sendSessionNotification(ctx, session, loggingNotificationMethod, &mcp.LoggingMessageParams{
Level: level,
@ -137,6 +147,7 @@ func (s *Service) ChannelSend(ctx context.Context, channel string, data any) {
if s == nil || s.server == nil {
return
}
ctx = normalizeNotificationContext(ctx)
payload := ChannelNotification{Channel: channel, Data: data}
s.sendChannelNotificationToAllClients(ctx, payload)
}
@ -148,7 +159,7 @@ func (s *Service) ChannelSendToSession(ctx context.Context, session *mcp.ServerS
if s == nil || s.server == nil || session == nil {
return
}
ctx = normalizeNotificationContext(ctx)
payload := ChannelNotification{Channel: channel, Data: data}
if err := sendSessionNotification(ctx, session, channelNotificationMethod, payload); err != nil {
s.logger.Debug("channel: failed to send to session", "session", session.ID(), "error", err)
@ -171,6 +182,7 @@ func (s *Service) sendChannelNotificationToAllClients(ctx context.Context, paylo
if s == nil || s.server == nil {
return
}
ctx = normalizeNotificationContext(ctx)
for session := range s.server.Sessions() {
if err := sendSessionNotification(ctx, session, channelNotificationMethod, payload); err != nil {
s.logger.Debug("channel: failed to send to session", "session", session.ID(), "error", err)
@ -182,6 +194,7 @@ func sendSessionNotification(ctx context.Context, session *mcp.ServerSession, me
if session == nil {
return nil
}
ctx = normalizeNotificationContext(ctx)
conn, err := sessionConnection(session)
if err != nil {

View file

@ -50,6 +50,18 @@ func TestNotificationMethods_Good_NilServer(t *testing.T) {
}
}
func TestNotificationMethods_Good_NilContext(t *testing.T) {
svc, err := New(Options{})
if err != nil {
t.Fatalf("New() failed: %v", err)
}
svc.SendNotificationToAllClients(nil, "info", "test", map[string]any{"ok": true})
svc.SendNotificationToSession(nil, nil, "info", "test", map[string]any{"ok": true})
svc.ChannelSend(nil, ChannelBuildComplete, map[string]any{"ok": true})
svc.ChannelSendToSession(nil, nil, ChannelBuildComplete, map[string]any{"ok": true})
}
func TestSendNotificationToAllClients_Good_CustomNotification(t *testing.T) {
svc, err := New(Options{})
if err != nil {

View file

@ -55,11 +55,14 @@ type TCPTransport struct {
// NewTCPTransport creates a new TCP transport listener.
// Defaults to 127.0.0.1 when the host component is empty (e.g. ":9100").
// Defaults to DefaultTCPAddr when addr is empty.
// Emits a security warning when explicitly binding to 0.0.0.0 (all interfaces).
//
// t, err := NewTCPTransport("127.0.0.1:9100")
// t, err := NewTCPTransport(":9100") // defaults to 127.0.0.1:9100
func NewTCPTransport(addr string) (*TCPTransport, error) {
addr = normalizeTCPAddr(addr)
host, port, _ := net.SplitHostPort(addr)
if host == "" {
addr = net.JoinHostPort("127.0.0.1", port)
@ -73,6 +76,23 @@ func NewTCPTransport(addr string) (*TCPTransport, error) {
return &TCPTransport{addr: addr, listener: listener}, nil
}
func normalizeTCPAddr(addr string) string {
if addr == "" {
return DefaultTCPAddr
}
host, port, err := net.SplitHostPort(addr)
if err != nil {
return addr
}
if host == "" {
return net.JoinHostPort("127.0.0.1", port)
}
return addr
}
// ServeTCP starts a TCP server for the MCP service.
// It accepts connections and spawns a new MCP server session for each connection.
//
@ -91,10 +111,6 @@ func (s *Service) ServeTCP(ctx context.Context, addr string) error {
<-ctx.Done()
_ = t.listener.Close()
}()
if addr == "" {
addr = t.listener.Addr().String()
}
diagPrintf("MCP TCP server listening on %s\n", t.listener.Addr().String())
for {

View file

@ -31,6 +31,26 @@ func TestNewTCPTransport_Defaults(t *testing.T) {
}
}
func TestNormalizeTCPAddr_Good_Defaults(t *testing.T) {
tests := []struct {
name string
in string
want string
}{
{name: "empty", in: "", want: DefaultTCPAddr},
{name: "missing host", in: ":9100", want: "127.0.0.1:9100"},
{name: "explicit host", in: "127.0.0.1:9100", want: "127.0.0.1:9100"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := normalizeTCPAddr(tt.in); got != tt.want {
t.Fatalf("normalizeTCPAddr(%q) = %q, want %q", tt.in, got, tt.want)
}
})
}
}
func TestNewTCPTransport_Warning(t *testing.T) {
// Capture warning output via setDiagWriter (mutex-protected, no race).
var buf bytes.Buffer