fix(mcp): normalize tcp defaults and notification context
Make notification broadcasting tolerant of a nil context and make TCP transport fall back to DefaultTCPAddr when no address is supplied. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
a0caa6918c
commit
cd60d9030c
4 changed files with 66 additions and 5 deletions
|
|
@ -19,6 +19,13 @@ import (
|
|||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
func normalizeNotificationContext(ctx context.Context) context.Context {
|
||||
if ctx == nil {
|
||||
return context.Background()
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// lockedWriter wraps an io.Writer with a mutex.
|
||||
// Both the SDK's transport and ChannelSend use this writer,
|
||||
// ensuring channel notifications don't interleave with SDK messages.
|
||||
|
|
@ -98,6 +105,7 @@ func (s *Service) SendNotificationToAllClients(ctx context.Context, level mcp.Lo
|
|||
if s == nil || s.server == nil {
|
||||
return
|
||||
}
|
||||
ctx = normalizeNotificationContext(ctx)
|
||||
for session := range s.server.Sessions() {
|
||||
s.sendLoggingNotificationToSession(ctx, session, level, logger, data)
|
||||
}
|
||||
|
|
@ -111,6 +119,7 @@ func (s *Service) SendNotificationToSession(ctx context.Context, session *mcp.Se
|
|||
if s == nil || s.server == nil {
|
||||
return
|
||||
}
|
||||
ctx = normalizeNotificationContext(ctx)
|
||||
s.sendLoggingNotificationToSession(ctx, session, level, logger, data)
|
||||
}
|
||||
|
||||
|
|
@ -118,6 +127,7 @@ func (s *Service) sendLoggingNotificationToSession(ctx context.Context, session
|
|||
if s == nil || s.server == nil || session == nil {
|
||||
return
|
||||
}
|
||||
ctx = normalizeNotificationContext(ctx)
|
||||
|
||||
if err := sendSessionNotification(ctx, session, loggingNotificationMethod, &mcp.LoggingMessageParams{
|
||||
Level: level,
|
||||
|
|
@ -137,6 +147,7 @@ func (s *Service) ChannelSend(ctx context.Context, channel string, data any) {
|
|||
if s == nil || s.server == nil {
|
||||
return
|
||||
}
|
||||
ctx = normalizeNotificationContext(ctx)
|
||||
payload := ChannelNotification{Channel: channel, Data: data}
|
||||
s.sendChannelNotificationToAllClients(ctx, payload)
|
||||
}
|
||||
|
|
@ -148,7 +159,7 @@ func (s *Service) ChannelSendToSession(ctx context.Context, session *mcp.ServerS
|
|||
if s == nil || s.server == nil || session == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx = normalizeNotificationContext(ctx)
|
||||
payload := ChannelNotification{Channel: channel, Data: data}
|
||||
if err := sendSessionNotification(ctx, session, channelNotificationMethod, payload); err != nil {
|
||||
s.logger.Debug("channel: failed to send to session", "session", session.ID(), "error", err)
|
||||
|
|
@ -171,6 +182,7 @@ func (s *Service) sendChannelNotificationToAllClients(ctx context.Context, paylo
|
|||
if s == nil || s.server == nil {
|
||||
return
|
||||
}
|
||||
ctx = normalizeNotificationContext(ctx)
|
||||
for session := range s.server.Sessions() {
|
||||
if err := sendSessionNotification(ctx, session, channelNotificationMethod, payload); err != nil {
|
||||
s.logger.Debug("channel: failed to send to session", "session", session.ID(), "error", err)
|
||||
|
|
@ -182,6 +194,7 @@ func sendSessionNotification(ctx context.Context, session *mcp.ServerSession, me
|
|||
if session == nil {
|
||||
return nil
|
||||
}
|
||||
ctx = normalizeNotificationContext(ctx)
|
||||
|
||||
conn, err := sessionConnection(session)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -50,6 +50,18 @@ func TestNotificationMethods_Good_NilServer(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNotificationMethods_Good_NilContext(t *testing.T) {
|
||||
svc, err := New(Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
|
||||
svc.SendNotificationToAllClients(nil, "info", "test", map[string]any{"ok": true})
|
||||
svc.SendNotificationToSession(nil, nil, "info", "test", map[string]any{"ok": true})
|
||||
svc.ChannelSend(nil, ChannelBuildComplete, map[string]any{"ok": true})
|
||||
svc.ChannelSendToSession(nil, nil, ChannelBuildComplete, map[string]any{"ok": true})
|
||||
}
|
||||
|
||||
func TestSendNotificationToAllClients_Good_CustomNotification(t *testing.T) {
|
||||
svc, err := New(Options{})
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -55,11 +55,14 @@ type TCPTransport struct {
|
|||
|
||||
// NewTCPTransport creates a new TCP transport listener.
|
||||
// Defaults to 127.0.0.1 when the host component is empty (e.g. ":9100").
|
||||
// Defaults to DefaultTCPAddr when addr is empty.
|
||||
// Emits a security warning when explicitly binding to 0.0.0.0 (all interfaces).
|
||||
//
|
||||
// t, err := NewTCPTransport("127.0.0.1:9100")
|
||||
// t, err := NewTCPTransport(":9100") // defaults to 127.0.0.1:9100
|
||||
func NewTCPTransport(addr string) (*TCPTransport, error) {
|
||||
addr = normalizeTCPAddr(addr)
|
||||
|
||||
host, port, _ := net.SplitHostPort(addr)
|
||||
if host == "" {
|
||||
addr = net.JoinHostPort("127.0.0.1", port)
|
||||
|
|
@ -73,6 +76,23 @@ func NewTCPTransport(addr string) (*TCPTransport, error) {
|
|||
return &TCPTransport{addr: addr, listener: listener}, nil
|
||||
}
|
||||
|
||||
func normalizeTCPAddr(addr string) string {
|
||||
if addr == "" {
|
||||
return DefaultTCPAddr
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
return net.JoinHostPort("127.0.0.1", port)
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
// ServeTCP starts a TCP server for the MCP service.
|
||||
// It accepts connections and spawns a new MCP server session for each connection.
|
||||
//
|
||||
|
|
@ -91,10 +111,6 @@ func (s *Service) ServeTCP(ctx context.Context, addr string) error {
|
|||
<-ctx.Done()
|
||||
_ = t.listener.Close()
|
||||
}()
|
||||
|
||||
if addr == "" {
|
||||
addr = t.listener.Addr().String()
|
||||
}
|
||||
diagPrintf("MCP TCP server listening on %s\n", t.listener.Addr().String())
|
||||
|
||||
for {
|
||||
|
|
|
|||
|
|
@ -31,6 +31,26 @@ func TestNewTCPTransport_Defaults(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNormalizeTCPAddr_Good_Defaults(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "empty", in: "", want: DefaultTCPAddr},
|
||||
{name: "missing host", in: ":9100", want: "127.0.0.1:9100"},
|
||||
{name: "explicit host", in: "127.0.0.1:9100", want: "127.0.0.1:9100"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := normalizeTCPAddr(tt.in); got != tt.want {
|
||||
t.Fatalf("normalizeTCPAddr(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTCPTransport_Warning(t *testing.T) {
|
||||
// Capture warning output via setDiagWriter (mutex-protected, no race).
|
||||
var buf bytes.Buffer
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue