fix(mcp): broadcast notifications without log gating
Use the underlying MCP notification path for session broadcasts so fresh sessions receive notifications without requiring a prior log-level handshake. Add a regression test that verifies broadcast delivery on a connected session. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
a34115266c
commit
bf3ef9f595
2 changed files with 74 additions and 2 deletions
|
|
@ -39,6 +39,7 @@ func (lw *lockedWriter) Close() error { return nil }
|
|||
var sharedStdout = &lockedWriter{w: os.Stdout}
|
||||
|
||||
const channelNotificationMethod = "notifications/claude/channel"
|
||||
const loggingNotificationMethod = "notifications/message"
|
||||
|
||||
// ChannelNotification is the payload sent through the experimental channel
|
||||
// notification method.
|
||||
|
|
@ -54,7 +55,7 @@ type ChannelNotification struct {
|
|||
// s.SendNotificationToAllClients(ctx, "info", "monitor", map[string]any{"event": "build complete"})
|
||||
func (s *Service) SendNotificationToAllClients(ctx context.Context, level mcp.LoggingLevel, logger string, data any) {
|
||||
for session := range s.server.Sessions() {
|
||||
s.SendNotificationToSession(ctx, session, level, logger, data)
|
||||
s.sendLoggingNotificationToSession(ctx, session, level, logger, data)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -63,11 +64,15 @@ func (s *Service) SendNotificationToAllClients(ctx context.Context, level mcp.Lo
|
|||
//
|
||||
// s.SendNotificationToSession(ctx, session, "info", "monitor", data)
|
||||
func (s *Service) SendNotificationToSession(ctx context.Context, session *mcp.ServerSession, level mcp.LoggingLevel, logger string, data any) {
|
||||
s.sendLoggingNotificationToSession(ctx, session, level, logger, data)
|
||||
}
|
||||
|
||||
func (s *Service) sendLoggingNotificationToSession(ctx context.Context, session *mcp.ServerSession, level mcp.LoggingLevel, logger string, data any) {
|
||||
if session == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := session.Log(ctx, &mcp.LoggingMessageParams{
|
||||
if err := sendSessionNotification(ctx, session, loggingNotificationMethod, &mcp.LoggingMessageParams{
|
||||
Level: level,
|
||||
Logger: logger,
|
||||
Data: data,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,73 @@ func TestSendNotificationToAllClients_Good(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSendNotificationToAllClients_Good_CustomNotification(t *testing.T) {
|
||||
svc, err := New(Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
session, err := svc.server.Connect(ctx, &connTransport{conn: serverConn}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Connect() failed: %v", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
clientConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
scanner := bufio.NewScanner(clientConn)
|
||||
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
|
||||
|
||||
sent := make(chan struct{})
|
||||
go func() {
|
||||
svc.SendNotificationToAllClients(ctx, "info", "test", map[string]any{
|
||||
"event": "build.complete",
|
||||
})
|
||||
close(sent)
|
||||
}()
|
||||
|
||||
if !scanner.Scan() {
|
||||
t.Fatalf("failed to read notification: %v", scanner.Err())
|
||||
}
|
||||
|
||||
select {
|
||||
case <-sent:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for notification send to complete")
|
||||
}
|
||||
|
||||
var msg map[string]any
|
||||
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
|
||||
t.Fatalf("failed to unmarshal notification: %v", err)
|
||||
}
|
||||
if msg["method"] != loggingNotificationMethod {
|
||||
t.Fatalf("expected method %q, got %v", loggingNotificationMethod, msg["method"])
|
||||
}
|
||||
|
||||
params, ok := msg["params"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected params object, got %T", msg["params"])
|
||||
}
|
||||
if params["logger"] != "test" {
|
||||
t.Fatalf("expected logger test, got %v", params["logger"])
|
||||
}
|
||||
if params["level"] != "info" {
|
||||
t.Fatalf("expected level info, got %v", params["level"])
|
||||
}
|
||||
data, ok := params["data"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected data object, got %T", params["data"])
|
||||
}
|
||||
if data["event"] != "build.complete" {
|
||||
t.Fatalf("expected event build.complete, got %v", data["event"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelSend_Good(t *testing.T) {
|
||||
svc, err := New(Options{})
|
||||
if err != nil {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue