diff --git a/pkg/mcp/notify_test.go b/pkg/mcp/notify_test.go index be4ed4a..7801758 100644 --- a/pkg/mcp/notify_test.go +++ b/pkg/mcp/notify_test.go @@ -9,8 +9,56 @@ import ( "slices" "testing" "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" ) +type notificationReadResult struct { + msg map[string]any + err error +} + +func connectNotificationSession(t *testing.T, svc *Service) (context.CancelFunc, *mcp.ServerSession, net.Conn) { + t.Helper() + + serverConn, clientConn := net.Pipe() + ctx, cancel := context.WithCancel(context.Background()) + + session, err := svc.server.Connect(ctx, &connTransport{conn: serverConn}, nil) + if err != nil { + cancel() + clientConn.Close() + t.Fatalf("Connect() failed: %v", err) + } + + return cancel, session, clientConn +} + +func readNotificationMessage(t *testing.T, conn net.Conn) <-chan notificationReadResult { + t.Helper() + + resultCh := make(chan notificationReadResult, 1) + go func() { + scanner := bufio.NewScanner(conn) + scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) + + if !scanner.Scan() { + resultCh <- notificationReadResult{err: scanner.Err()} + return + } + + var msg map[string]any + if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil { + resultCh <- notificationReadResult{err: err} + return + } + + resultCh <- notificationReadResult{msg: msg} + }() + + return resultCh +} + func TestSendNotificationToAllClients_Good(t *testing.T) { svc, err := New(Options{}) if err != nil { @@ -279,3 +327,63 @@ func TestChannelCapability_Good_PublicHelpers(t *testing.T) { t.Fatalf("expected public channel list to match internal definition: got %v want %v", gotChannels, wantChannels) } } + +func TestSendNotificationToAllClients_Good_BroadcastsToMultipleSessions(t *testing.T) { + svc, err := New(Options{}) + if err != nil { + t.Fatalf("New() failed: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cancel1, session1, clientConn1 := connectNotificationSession(t, svc) + defer cancel1() + defer session1.Close() + defer clientConn1.Close() + + cancel2, session2, clientConn2 := connectNotificationSession(t, svc) + defer cancel2() + defer session2.Close() + defer clientConn2.Close() + + read1 := readNotificationMessage(t, clientConn1) + read2 := readNotificationMessage(t, clientConn2) + + sent := make(chan struct{}) + go func() { + svc.SendNotificationToAllClients(ctx, "info", "test", map[string]any{ + "event": ChannelBuildComplete, + }) + close(sent) + }() + + select { + case <-sent: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for broadcast to complete") + } + + res1 := <-read1 + if res1.err != nil { + t.Fatalf("failed to read notification from session 1: %v", res1.err) + } + res2 := <-read2 + if res2.err != nil { + t.Fatalf("failed to read notification from session 2: %v", res2.err) + } + + for idx, res := range []notificationReadResult{res1, res2} { + if res.msg["method"] != loggingNotificationMethod { + t.Fatalf("session %d: expected method %q, got %v", idx+1, loggingNotificationMethod, res.msg["method"]) + } + + params, ok := res.msg["params"].(map[string]any) + if !ok { + t.Fatalf("session %d: expected params object, got %T", idx+1, res.msg["params"]) + } + if params["logger"] != "test" { + t.Fatalf("session %d: expected logger test, got %v", idx+1, params["logger"]) + } + } +}