diff --git a/pkg/mcp/notify.go b/pkg/mcp/notify.go index f8b8c31..d1b1d41 100644 --- a/pkg/mcp/notify.go +++ b/pkg/mcp/notify.go @@ -90,6 +90,34 @@ var channelCapabilityList = []string{ ChannelTestResult, } +// ChannelCapabilitySpec describes the experimental claude/channel capability. +// +// spec := ChannelCapabilitySpec{ +// Version: "1", +// Description: "Push events into client sessions via named channels", +// Channels: ChannelCapabilityChannels(), +// } +type ChannelCapabilitySpec struct { + Version string `json:"version"` + Description string `json:"description"` + Channels []string `json:"channels"` +} + +// Map converts the typed capability into the wire-format map expected by the SDK. +// +// caps := ChannelCapabilitySpec{ +// Version: "1", +// Description: "Push events into client sessions via named channels", +// Channels: ChannelCapabilityChannels(), +// }.Map() +func (c ChannelCapabilitySpec) Map() map[string]any { + return map[string]any{ + "version": c.Version, + "description": c.Description, + "channels": slices.Clone(c.Channels), + } +} + // ChannelNotification is the payload sent through the experimental channel // notification method. // @@ -280,11 +308,18 @@ func (e *notificationError) Error() string { // for claude/channel, registered during New(). func channelCapability() map[string]any { return map[string]any{ - "claude/channel": map[string]any{ - "version": "1", - "description": "Push events into client sessions via named channels", - "channels": channelCapabilityChannels(), - }, + "claude/channel": ClaudeChannelCapability().Map(), + } +} + +// ClaudeChannelCapability returns the typed experimental capability descriptor. +// +// cap := ClaudeChannelCapability() +func ClaudeChannelCapability() ChannelCapabilitySpec { + return ChannelCapabilitySpec{ + Version: "1", + Description: "Push events into client sessions via named channels", + Channels: channelCapabilityChannels(), } } diff --git a/pkg/mcp/notify_test.go b/pkg/mcp/notify_test.go index 7801758..d4f0b39 100644 --- a/pkg/mcp/notify_test.go +++ b/pkg/mcp/notify_test.go @@ -59,6 +59,36 @@ func readNotificationMessage(t *testing.T, conn net.Conn) <-chan notificationRea return resultCh } +func readNotificationMessageUntil(t *testing.T, conn net.Conn, match func(map[string]any) bool) <-chan notificationReadResult { + t.Helper() + + resultCh := make(chan notificationReadResult, 1) + scanner := bufio.NewScanner(conn) + scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) + + go func() { + for scanner.Scan() { + var msg map[string]any + if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil { + resultCh <- notificationReadResult{err: err} + return + } + if match(msg) { + resultCh <- notificationReadResult{msg: msg} + return + } + } + + if err := scanner.Err(); err != nil { + resultCh <- notificationReadResult{err: err} + return + } + resultCh <- notificationReadResult{err: context.DeadlineExceeded} + }() + + return resultCh +} + func TestSendNotificationToAllClients_Good(t *testing.T) { svc, err := New(Options{}) if err != nil { @@ -130,8 +160,10 @@ func TestSendNotificationToAllClients_Good_CustomNotification(t *testing.T) { defer session.Close() clientConn.SetDeadline(time.Now().Add(5 * time.Second)) - scanner := bufio.NewScanner(clientConn) - scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) + + read := readNotificationMessageUntil(t, clientConn, func(msg map[string]any) bool { + return msg["method"] == loggingNotificationMethod + }) sent := make(chan struct{}) go func() { @@ -141,20 +173,17 @@ func TestSendNotificationToAllClients_Good_CustomNotification(t *testing.T) { close(sent) }() - if !scanner.Scan() { - t.Fatalf("failed to read notification: %v", scanner.Err()) - } - select { case <-sent: case <-time.After(5 * time.Second): t.Fatal("timed out waiting for notification send to complete") } - var msg map[string]any - if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil { - t.Fatalf("failed to unmarshal notification: %v", err) + res := <-read + if res.err != nil { + t.Fatalf("failed to read notification: %v", res.err) } + msg := res.msg if msg["method"] != loggingNotificationMethod { t.Fatalf("expected method %q, got %v", loggingNotificationMethod, msg["method"]) } @@ -233,8 +262,10 @@ func TestChannelSendToSession_Good_CustomNotification(t *testing.T) { defer session.Close() clientConn.SetDeadline(time.Now().Add(5 * time.Second)) - scanner := bufio.NewScanner(clientConn) - scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) + + read := readNotificationMessageUntil(t, clientConn, func(msg map[string]any) bool { + return msg["method"] == channelNotificationMethod + }) sent := make(chan struct{}) go func() { @@ -244,20 +275,17 @@ func TestChannelSendToSession_Good_CustomNotification(t *testing.T) { close(sent) }() - if !scanner.Scan() { - t.Fatalf("failed to read custom notification: %v", scanner.Err()) - } - select { case <-sent: case <-time.After(5 * time.Second): t.Fatal("timed out waiting for notification send to complete") } - var msg map[string]any - if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil { - t.Fatalf("failed to unmarshal notification: %v", err) + res := <-read + if res.err != nil { + t.Fatalf("failed to read custom notification: %v", res.err) } + msg := res.msg if msg["method"] != channelNotificationMethod { t.Fatalf("expected method %q, got %v", channelNotificationMethod, msg["method"]) } @@ -321,6 +349,20 @@ func TestChannelCapability_Good_PublicHelpers(t *testing.T) { t.Fatalf("expected public capability helper to match internal definition") } + spec := ClaudeChannelCapability() + if spec.Version != "1" { + t.Fatalf("expected typed capability version 1, got %q", spec.Version) + } + if spec.Description == "" { + t.Fatal("expected typed capability description to be populated") + } + if !slices.Equal(spec.Channels, channelCapabilityChannels()) { + t.Fatalf("expected typed capability channels to match: got %v want %v", spec.Channels, channelCapabilityChannels()) + } + if !reflect.DeepEqual(spec.Map(), want["claude/channel"].(map[string]any)) { + t.Fatal("expected typed capability map to match wire-format descriptor") + } + gotChannels := ChannelCapabilityChannels() wantChannels := channelCapabilityChannels() if !slices.Equal(gotChannels, wantChannels) {