refactor(mcp): add typed channel capability helper
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
af3cf3c8e3
commit
fa9a5eed28
2 changed files with 100 additions and 23 deletions
|
|
@ -90,6 +90,34 @@ var channelCapabilityList = []string{
|
|||
ChannelTestResult,
|
||||
}
|
||||
|
||||
// ChannelCapabilitySpec describes the experimental claude/channel capability.
|
||||
//
|
||||
// spec := ChannelCapabilitySpec{
|
||||
// Version: "1",
|
||||
// Description: "Push events into client sessions via named channels",
|
||||
// Channels: ChannelCapabilityChannels(),
|
||||
// }
|
||||
type ChannelCapabilitySpec struct {
|
||||
Version string `json:"version"`
|
||||
Description string `json:"description"`
|
||||
Channels []string `json:"channels"`
|
||||
}
|
||||
|
||||
// Map converts the typed capability into the wire-format map expected by the SDK.
|
||||
//
|
||||
// caps := ChannelCapabilitySpec{
|
||||
// Version: "1",
|
||||
// Description: "Push events into client sessions via named channels",
|
||||
// Channels: ChannelCapabilityChannels(),
|
||||
// }.Map()
|
||||
func (c ChannelCapabilitySpec) Map() map[string]any {
|
||||
return map[string]any{
|
||||
"version": c.Version,
|
||||
"description": c.Description,
|
||||
"channels": slices.Clone(c.Channels),
|
||||
}
|
||||
}
|
||||
|
||||
// ChannelNotification is the payload sent through the experimental channel
|
||||
// notification method.
|
||||
//
|
||||
|
|
@ -280,11 +308,18 @@ func (e *notificationError) Error() string {
|
|||
// for claude/channel, registered during New().
|
||||
func channelCapability() map[string]any {
|
||||
return map[string]any{
|
||||
"claude/channel": map[string]any{
|
||||
"version": "1",
|
||||
"description": "Push events into client sessions via named channels",
|
||||
"channels": channelCapabilityChannels(),
|
||||
},
|
||||
"claude/channel": ClaudeChannelCapability().Map(),
|
||||
}
|
||||
}
|
||||
|
||||
// ClaudeChannelCapability returns the typed experimental capability descriptor.
|
||||
//
|
||||
// cap := ClaudeChannelCapability()
|
||||
func ClaudeChannelCapability() ChannelCapabilitySpec {
|
||||
return ChannelCapabilitySpec{
|
||||
Version: "1",
|
||||
Description: "Push events into client sessions via named channels",
|
||||
Channels: channelCapabilityChannels(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -59,6 +59,36 @@ func readNotificationMessage(t *testing.T, conn net.Conn) <-chan notificationRea
|
|||
return resultCh
|
||||
}
|
||||
|
||||
func readNotificationMessageUntil(t *testing.T, conn net.Conn, match func(map[string]any) bool) <-chan notificationReadResult {
|
||||
t.Helper()
|
||||
|
||||
resultCh := make(chan notificationReadResult, 1)
|
||||
scanner := bufio.NewScanner(conn)
|
||||
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
|
||||
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
var msg map[string]any
|
||||
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
|
||||
resultCh <- notificationReadResult{err: err}
|
||||
return
|
||||
}
|
||||
if match(msg) {
|
||||
resultCh <- notificationReadResult{msg: msg}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
resultCh <- notificationReadResult{err: err}
|
||||
return
|
||||
}
|
||||
resultCh <- notificationReadResult{err: context.DeadlineExceeded}
|
||||
}()
|
||||
|
||||
return resultCh
|
||||
}
|
||||
|
||||
func TestSendNotificationToAllClients_Good(t *testing.T) {
|
||||
svc, err := New(Options{})
|
||||
if err != nil {
|
||||
|
|
@ -130,8 +160,10 @@ func TestSendNotificationToAllClients_Good_CustomNotification(t *testing.T) {
|
|||
defer session.Close()
|
||||
|
||||
clientConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
scanner := bufio.NewScanner(clientConn)
|
||||
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
|
||||
|
||||
read := readNotificationMessageUntil(t, clientConn, func(msg map[string]any) bool {
|
||||
return msg["method"] == loggingNotificationMethod
|
||||
})
|
||||
|
||||
sent := make(chan struct{})
|
||||
go func() {
|
||||
|
|
@ -141,20 +173,17 @@ func TestSendNotificationToAllClients_Good_CustomNotification(t *testing.T) {
|
|||
close(sent)
|
||||
}()
|
||||
|
||||
if !scanner.Scan() {
|
||||
t.Fatalf("failed to read notification: %v", scanner.Err())
|
||||
}
|
||||
|
||||
select {
|
||||
case <-sent:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for notification send to complete")
|
||||
}
|
||||
|
||||
var msg map[string]any
|
||||
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
|
||||
t.Fatalf("failed to unmarshal notification: %v", err)
|
||||
res := <-read
|
||||
if res.err != nil {
|
||||
t.Fatalf("failed to read notification: %v", res.err)
|
||||
}
|
||||
msg := res.msg
|
||||
if msg["method"] != loggingNotificationMethod {
|
||||
t.Fatalf("expected method %q, got %v", loggingNotificationMethod, msg["method"])
|
||||
}
|
||||
|
|
@ -233,8 +262,10 @@ func TestChannelSendToSession_Good_CustomNotification(t *testing.T) {
|
|||
defer session.Close()
|
||||
|
||||
clientConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
scanner := bufio.NewScanner(clientConn)
|
||||
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
|
||||
|
||||
read := readNotificationMessageUntil(t, clientConn, func(msg map[string]any) bool {
|
||||
return msg["method"] == channelNotificationMethod
|
||||
})
|
||||
|
||||
sent := make(chan struct{})
|
||||
go func() {
|
||||
|
|
@ -244,20 +275,17 @@ func TestChannelSendToSession_Good_CustomNotification(t *testing.T) {
|
|||
close(sent)
|
||||
}()
|
||||
|
||||
if !scanner.Scan() {
|
||||
t.Fatalf("failed to read custom notification: %v", scanner.Err())
|
||||
}
|
||||
|
||||
select {
|
||||
case <-sent:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for notification send to complete")
|
||||
}
|
||||
|
||||
var msg map[string]any
|
||||
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
|
||||
t.Fatalf("failed to unmarshal notification: %v", err)
|
||||
res := <-read
|
||||
if res.err != nil {
|
||||
t.Fatalf("failed to read custom notification: %v", res.err)
|
||||
}
|
||||
msg := res.msg
|
||||
if msg["method"] != channelNotificationMethod {
|
||||
t.Fatalf("expected method %q, got %v", channelNotificationMethod, msg["method"])
|
||||
}
|
||||
|
|
@ -321,6 +349,20 @@ func TestChannelCapability_Good_PublicHelpers(t *testing.T) {
|
|||
t.Fatalf("expected public capability helper to match internal definition")
|
||||
}
|
||||
|
||||
spec := ClaudeChannelCapability()
|
||||
if spec.Version != "1" {
|
||||
t.Fatalf("expected typed capability version 1, got %q", spec.Version)
|
||||
}
|
||||
if spec.Description == "" {
|
||||
t.Fatal("expected typed capability description to be populated")
|
||||
}
|
||||
if !slices.Equal(spec.Channels, channelCapabilityChannels()) {
|
||||
t.Fatalf("expected typed capability channels to match: got %v want %v", spec.Channels, channelCapabilityChannels())
|
||||
}
|
||||
if !reflect.DeepEqual(spec.Map(), want["claude/channel"].(map[string]any)) {
|
||||
t.Fatal("expected typed capability map to match wire-format descriptor")
|
||||
}
|
||||
|
||||
gotChannels := ChannelCapabilityChannels()
|
||||
wantChannels := channelCapabilityChannels()
|
||||
if !slices.Equal(gotChannels, wantChannels) {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue