diff --git a/pkg/mcp/agentic/issue.go b/pkg/mcp/agentic/issue.go index 7bd3a09..1b66c0e 100644 --- a/pkg/mcp/agentic/issue.go +++ b/pkg/mcp/agentic/issue.go @@ -49,6 +49,12 @@ func (s *PrepSubsystem) registerIssueTools(svc *coremcp.Service) { Description: "Dispatch an agent to work on a Forge issue. Assigns the issue as a lock, prepends the issue body to TODO.md, creates an issue-specific branch, and spawns the agent.", }, s.dispatchIssue) + // agentic_issue_dispatch is the spec-aligned name for the same action. + coremcp.AddToolRecorded(svc, server, "agentic", &mcp.Tool{ + Name: "agentic_issue_dispatch", + Description: "Dispatch an agent to work on a Forge issue. Spec-aligned alias for agentic_dispatch_issue.", + }, s.dispatchIssue) + coremcp.AddToolRecorded(svc, server, "agentic", &mcp.Tool{ Name: "agentic_pr", Description: "Create a pull request from an agent workspace. Pushes the branch and creates a Forge PR linked to the tracked issue, if any.", diff --git a/pkg/mcp/ide/tools_dashboard.go b/pkg/mcp/ide/tools_dashboard.go index 0c14713..8ebee56 100644 --- a/pkg/mcp/ide/tools_dashboard.go +++ b/pkg/mcp/ide/tools_dashboard.go @@ -4,6 +4,7 @@ package ide import ( "context" + "sync" "time" coremcp "dappco.re/go/mcp/pkg/mcp" @@ -86,6 +87,46 @@ type DashboardMetricsOutput struct { Metrics DashboardMetrics `json:"metrics"` } +// DashboardStateInput is the input for ide_dashboard_state. +// +// input := DashboardStateInput{} +type DashboardStateInput struct{} + +// DashboardStateOutput is the output for ide_dashboard_state. +// +// // out.State["theme"] == "dark" +type DashboardStateOutput struct { + State map[string]any `json:"state"` // arbitrary key/value map + UpdatedAt time.Time `json:"updatedAt"` // when the state last changed +} + +// DashboardUpdateInput is the input for ide_dashboard_update. +// +// input := DashboardUpdateInput{ +// State: map[string]any{"theme": "light", "sidebar": true}, +// Replace: false, +// } +type DashboardUpdateInput struct { + State map[string]any `json:"state"` // partial or full state + Replace bool `json:"replace,omitempty"` // true to overwrite, false to merge (default) +} + +// DashboardUpdateOutput is the output for ide_dashboard_update. +// +// // out.State reflects the merged/replaced state +type DashboardUpdateOutput struct { + State map[string]any `json:"state"` // merged state after the update + UpdatedAt time.Time `json:"updatedAt"` // when the state was applied +} + +// dashboardStateStore holds the mutable dashboard UI state shared between the +// IDE frontend and MCP callers. Access is guarded by dashboardStateMu. +var ( + dashboardStateMu sync.RWMutex + dashboardStateStore = map[string]any{} + dashboardStateUpdated time.Time +) + func (s *Subsystem) registerDashboardTools(svc *coremcp.Service) { server := svc.Server() coremcp.AddToolRecorded(svc, server, "ide", &mcp.Tool{ @@ -102,6 +143,16 @@ func (s *Subsystem) registerDashboardTools(svc *coremcp.Service) { Name: "ide_dashboard_metrics", Description: "Get aggregate build and agent metrics for a time period", }, s.dashboardMetrics) + + coremcp.AddToolRecorded(svc, server, "ide", &mcp.Tool{ + Name: "ide_dashboard_state", + Description: "Get the current dashboard UI state (arbitrary key/value map shared with the IDE).", + }, s.dashboardState) + + coremcp.AddToolRecorded(svc, server, "ide", &mcp.Tool{ + Name: "ide_dashboard_update", + Description: "Update the dashboard UI state. Merges into existing state by default; set replace=true to overwrite.", + }, s.dashboardUpdate) } // dashboardOverview returns a platform overview with bridge status and @@ -211,3 +262,79 @@ func (s *Subsystem) dashboardMetrics(_ context.Context, _ *mcp.CallToolRequest, }, }, nil } + +// dashboardState returns the current dashboard UI state as a snapshot. +// +// out := s.dashboardState(ctx, nil, DashboardStateInput{}) +func (s *Subsystem) dashboardState(_ context.Context, _ *mcp.CallToolRequest, _ DashboardStateInput) (*mcp.CallToolResult, DashboardStateOutput, error) { + dashboardStateMu.RLock() + defer dashboardStateMu.RUnlock() + + snapshot := make(map[string]any, len(dashboardStateStore)) + for k, v := range dashboardStateStore { + snapshot[k] = v + } + + return nil, DashboardStateOutput{ + State: snapshot, + UpdatedAt: dashboardStateUpdated, + }, nil +} + +// dashboardUpdate merges or replaces the dashboard UI state and emits an +// activity event so the IDE can react to the change. +// +// out := s.dashboardUpdate(ctx, nil, DashboardUpdateInput{State: map[string]any{"theme": "dark"}}) +func (s *Subsystem) dashboardUpdate(ctx context.Context, _ *mcp.CallToolRequest, input DashboardUpdateInput) (*mcp.CallToolResult, DashboardUpdateOutput, error) { + now := time.Now() + + dashboardStateMu.Lock() + if input.Replace || dashboardStateStore == nil { + dashboardStateStore = make(map[string]any, len(input.State)) + } + for k, v := range input.State { + dashboardStateStore[k] = v + } + dashboardStateUpdated = now + + snapshot := make(map[string]any, len(dashboardStateStore)) + for k, v := range dashboardStateStore { + snapshot[k] = v + } + dashboardStateMu.Unlock() + + // Record the change on the activity feed so ide_dashboard_activity + // reflects state transitions alongside build/session events. + s.recordActivity("dashboard_state", "dashboard state updated") + + // Push the update over the Laravel bridge when available so web clients + // stay in sync with desktop tooling. + if s.bridge != nil { + _ = s.bridge.Send(BridgeMessage{ + Type: "dashboard_update", + Data: snapshot, + }) + } + + // Surface the change on the shared MCP notifier so connected sessions + // receive a JSON-RPC notification alongside the tool response. + if s.notifier != nil { + s.notifier.ChannelSend(ctx, "dashboard.state.updated", map[string]any{ + "state": snapshot, + "updatedAt": now, + }) + } + + return nil, DashboardUpdateOutput{ + State: snapshot, + UpdatedAt: now, + }, nil +} + +// resetDashboardState clears the shared dashboard state. Intended for tests. +func resetDashboardState() { + dashboardStateMu.Lock() + defer dashboardStateMu.Unlock() + dashboardStateStore = map[string]any{} + dashboardStateUpdated = time.Time{} +} diff --git a/pkg/mcp/ide/tools_test.go b/pkg/mcp/ide/tools_test.go index 42b9edc..c462e9f 100644 --- a/pkg/mcp/ide/tools_test.go +++ b/pkg/mcp/ide/tools_test.go @@ -949,3 +949,76 @@ func TestChatSend_Good_BridgeMessageType(t *testing.T) { t.Fatal("timed out waiting for bridge message") } } + +// TestToolsDashboard_DashboardState_Good returns an empty state when the +// store has not been touched. +func TestToolsDashboard_DashboardState_Good(t *testing.T) { + t.Cleanup(resetDashboardState) + + sub := newNilBridgeSubsystem() + _, out, err := sub.dashboardState(context.Background(), nil, DashboardStateInput{}) + if err != nil { + t.Fatalf("dashboardState failed: %v", err) + } + if len(out.State) != 0 { + t.Fatalf("expected empty state, got %v", out.State) + } +} + +// TestToolsDashboard_DashboardUpdate_Good merges the supplied state into the +// shared store and reflects it back on a subsequent dashboardState call. +func TestToolsDashboard_DashboardUpdate_Good(t *testing.T) { + t.Cleanup(resetDashboardState) + + sub := newNilBridgeSubsystem() + + _, updateOut, err := sub.dashboardUpdate(context.Background(), nil, DashboardUpdateInput{ + State: map[string]any{"theme": "dark"}, + }) + if err != nil { + t.Fatalf("dashboardUpdate failed: %v", err) + } + if updateOut.State["theme"] != "dark" { + t.Fatalf("expected theme 'dark', got %v", updateOut.State["theme"]) + } + + _, readOut, err := sub.dashboardState(context.Background(), nil, DashboardStateInput{}) + if err != nil { + t.Fatalf("dashboardState failed: %v", err) + } + if readOut.State["theme"] != "dark" { + t.Fatalf("expected persisted theme 'dark', got %v", readOut.State["theme"]) + } + if readOut.UpdatedAt.IsZero() { + t.Fatal("expected non-zero UpdatedAt after update") + } +} + +// TestToolsDashboard_DashboardUpdate_Ugly replaces (not merges) prior state +// when Replace=true. +func TestToolsDashboard_DashboardUpdate_Ugly(t *testing.T) { + t.Cleanup(resetDashboardState) + + sub := newNilBridgeSubsystem() + + _, _, err := sub.dashboardUpdate(context.Background(), nil, DashboardUpdateInput{ + State: map[string]any{"theme": "dark", "sidebar": true}, + }) + if err != nil { + t.Fatalf("seed dashboardUpdate failed: %v", err) + } + + _, out, err := sub.dashboardUpdate(context.Background(), nil, DashboardUpdateInput{ + State: map[string]any{"theme": "light"}, + Replace: true, + }) + if err != nil { + t.Fatalf("replace dashboardUpdate failed: %v", err) + } + if _, ok := out.State["sidebar"]; ok { + t.Fatal("expected sidebar to be removed after replace") + } + if out.State["theme"] != "light" { + t.Fatalf("expected theme 'light', got %v", out.State["theme"]) + } +} diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index dafed1d..7864111 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -316,6 +316,7 @@ func (s *Service) registerTools(server *mcp.Server) { s.registerProcessTools(server) s.registerWebviewTools(server) s.registerWSTools(server) + s.registerWSClientTools(server) } // Tool input/output types for MCP file operations. diff --git a/pkg/mcp/registry_test.go b/pkg/mcp/registry_test.go index 57fee51..c2019d8 100644 --- a/pkg/mcp/registry_test.go +++ b/pkg/mcp/registry_test.go @@ -71,13 +71,19 @@ func TestToolRegistry_Good_ToolCount(t *testing.T) { } tools := svc.Tools() - // Built-in tools: file_read, file_write, file_delete, file_rename, - // file_exists, file_edit, dir_list, dir_create, lang_detect, lang_list, - // metrics_record, metrics_query, rag_query, rag_ingest, rag_collections, - // webview_connect, webview_disconnect, webview_navigate, webview_click, - // webview_type, webview_query, webview_console, webview_eval, - // webview_screenshot, webview_wait - const expectedCount = 25 + // Built-in tools (no ProcessService / WSHub / Subsystems): + // files (8): file_read, file_write, file_delete, file_rename, + // file_exists, file_edit, dir_list, dir_create + // language (2): lang_detect, lang_list + // metrics (2): metrics_record, metrics_query + // rag (6): rag_query, rag_search, rag_ingest, rag_index, + // rag_retrieve, rag_collections + // webview (12): webview_connect, webview_disconnect, webview_navigate, + // webview_click, webview_type, webview_query, + // webview_console, webview_eval, webview_screenshot, + // webview_wait, webview_render, webview_update + // ws (3): ws_connect, ws_send, ws_close + const expectedCount = 33 if len(tools) != expectedCount { t.Errorf("expected %d tools, got %d", expectedCount, len(tools)) for _, tr := range tools { @@ -95,8 +101,8 @@ func TestToolRegistry_Good_GroupAssignment(t *testing.T) { fileTools := []string{"file_read", "file_write", "file_delete", "file_rename", "file_exists", "file_edit", "dir_list", "dir_create"} langTools := []string{"lang_detect", "lang_list"} metricsTools := []string{"metrics_record", "metrics_query"} - ragTools := []string{"rag_query", "rag_ingest", "rag_collections"} - webviewTools := []string{"webview_connect", "webview_disconnect", "webview_navigate", "webview_click", "webview_type", "webview_query", "webview_console", "webview_eval", "webview_screenshot", "webview_wait"} + ragTools := []string{"rag_query", "rag_search", "rag_ingest", "rag_index", "rag_retrieve", "rag_collections"} + webviewTools := []string{"webview_connect", "webview_disconnect", "webview_navigate", "webview_click", "webview_type", "webview_query", "webview_console", "webview_eval", "webview_screenshot", "webview_wait", "webview_render", "webview_update"} byName := make(map[string]ToolRecord) for _, tr := range svc.Tools() { @@ -157,6 +163,18 @@ func TestToolRegistry_Good_GroupAssignment(t *testing.T) { t.Errorf("tool %s: expected group 'webview', got %q", name, tr.Group) } } + + wsClientTools := []string{"ws_connect", "ws_send", "ws_close"} + for _, name := range wsClientTools { + tr, ok := byName[name] + if !ok { + t.Errorf("tool %s not found in registry", name) + continue + } + if tr.Group != "ws" { + t.Errorf("tool %s: expected group 'ws', got %q", name, tr.Group) + } + } } func TestToolRegistry_Good_ToolRecordFields(t *testing.T) { diff --git a/pkg/mcp/tools_process.go b/pkg/mcp/tools_process.go index 90ec2e7..a319aec 100644 --- a/pkg/mcp/tools_process.go +++ b/pkg/mcp/tools_process.go @@ -29,6 +29,32 @@ type ProcessStartInput struct { Env []string `json:"env,omitempty"` // e.g. ["CGO_ENABLED=0"] } +// ProcessRunInput contains parameters for running a command to completion +// and returning its captured output. +// +// input := ProcessRunInput{ +// Command: "go", +// Args: []string{"test", "./..."}, +// Dir: "/home/user/project", +// Env: []string{"CGO_ENABLED=0"}, +// } +type ProcessRunInput struct { + Command string `json:"command"` // e.g. "go" + Args []string `json:"args,omitempty"` // e.g. ["test", "./..."] + Dir string `json:"dir,omitempty"` // e.g. "/home/user/project" + Env []string `json:"env,omitempty"` // e.g. ["CGO_ENABLED=0"] +} + +// ProcessRunOutput contains the result of running a process to completion. +// +// // out.ID == "proc-abc123", out.ExitCode == 0, out.Output == "PASS\n..." +type ProcessRunOutput struct { + ID string `json:"id"` // e.g. "proc-abc123" + ExitCode int `json:"exitCode"` // 0 on success + Output string `json:"output"` // combined stdout/stderr + Command string `json:"command"` // e.g. "go" +} + // ProcessStartOutput contains the result of starting a process. // // // out.ID == "proc-abc123", out.PID == 54321, out.Command == "go" @@ -146,6 +172,11 @@ func (s *Service) registerProcessTools(server *mcp.Server) bool { Description: "Start a new external process. Returns process ID for tracking.", }, s.processStart) + addToolRecorded(s, server, "process", &mcp.Tool{ + Name: "process_run", + Description: "Run a command to completion and return the captured output. Blocks until the process exits.", + }, s.processRun) + addToolRecorded(s, server, "process", &mcp.Tool{ Name: "process_stop", Description: "Gracefully stop a running process by ID.", @@ -224,6 +255,63 @@ func (s *Service) processStart(ctx context.Context, req *mcp.CallToolRequest, in return nil, output, nil } +// processRun handles the process_run tool call. +// Executes the command to completion and returns the captured output. +func (s *Service) processRun(ctx context.Context, req *mcp.CallToolRequest, input ProcessRunInput) (*mcp.CallToolResult, ProcessRunOutput, error) { + if s.processService == nil { + return nil, ProcessRunOutput{}, log.E("processRun", "process service unavailable", nil) + } + + s.logger.Security("MCP tool execution", "tool", "process_run", "command", input.Command, "args", input.Args, "dir", input.Dir, "user", log.Username()) + + if input.Command == "" { + return nil, ProcessRunOutput{}, log.E("processRun", "command cannot be empty", nil) + } + + opts := process.RunOptions{ + Command: input.Command, + Args: input.Args, + Dir: s.resolveWorkspacePath(input.Dir), + Env: input.Env, + } + + proc, err := s.processService.StartWithOptions(ctx, opts) + if err != nil { + log.Error("mcp: process run start failed", "command", input.Command, "err", err) + return nil, ProcessRunOutput{}, log.E("processRun", "failed to start process", err) + } + + info := proc.Info() + s.recordProcessRuntime(proc.ID, processRuntime{ + Command: proc.Command, + Args: proc.Args, + Dir: info.Dir, + StartedAt: proc.StartedAt, + }) + s.ChannelSend(ctx, ChannelProcessStart, map[string]any{ + "id": proc.ID, + "pid": info.PID, + "command": proc.Command, + "args": proc.Args, + "dir": info.Dir, + "startedAt": proc.StartedAt, + }) + + // Wait for completion (context-aware). + select { + case <-ctx.Done(): + return nil, ProcessRunOutput{}, log.E("processRun", "cancelled", ctx.Err()) + case <-proc.Done(): + } + + return nil, ProcessRunOutput{ + ID: proc.ID, + ExitCode: proc.ExitCode, + Output: proc.Output(), + Command: proc.Command, + }, nil +} + // processStop handles the process_stop tool call. func (s *Service) processStop(ctx context.Context, req *mcp.CallToolRequest, input ProcessStopInput) (*mcp.CallToolResult, ProcessStopOutput, error) { if s.processService == nil { diff --git a/pkg/mcp/tools_process_test.go b/pkg/mcp/tools_process_test.go index ee7c2d7..05aa435 100644 --- a/pkg/mcp/tools_process_test.go +++ b/pkg/mcp/tools_process_test.go @@ -301,3 +301,57 @@ func TestRegisterProcessTools_Bad_NilService(t *testing.T) { t.Error("Expected registerProcessTools to return false when processService is nil") } } + +// TestToolsProcess_ProcessRunInput_Good exercises the process_run input DTO shape. +func TestToolsProcess_ProcessRunInput_Good(t *testing.T) { + input := ProcessRunInput{ + Command: "echo", + Args: []string{"hello"}, + Dir: "/tmp", + Env: []string{"FOO=bar"}, + } + if input.Command != "echo" { + t.Errorf("expected command 'echo', got %q", input.Command) + } + if len(input.Args) != 1 || input.Args[0] != "hello" { + t.Errorf("expected args [hello], got %v", input.Args) + } + if input.Dir != "/tmp" { + t.Errorf("expected dir '/tmp', got %q", input.Dir) + } + if len(input.Env) != 1 { + t.Errorf("expected 1 env, got %d", len(input.Env)) + } +} + +// TestToolsProcess_ProcessRunOutput_Good exercises the process_run output DTO shape. +func TestToolsProcess_ProcessRunOutput_Good(t *testing.T) { + output := ProcessRunOutput{ + ID: "proc-1", + ExitCode: 0, + Output: "hello\n", + Command: "echo", + } + if output.ID != "proc-1" { + t.Errorf("expected id 'proc-1', got %q", output.ID) + } + if output.ExitCode != 0 { + t.Errorf("expected exit code 0, got %d", output.ExitCode) + } + if output.Output != "hello\n" { + t.Errorf("expected output 'hello\\n', got %q", output.Output) + } +} + +// TestToolsProcess_ProcessRun_Bad rejects calls without a process service. +func TestToolsProcess_ProcessRun_Bad(t *testing.T) { + svc, err := New(Options{}) + if err != nil { + t.Fatal(err) + } + + _, _, err = svc.processRun(t.Context(), nil, ProcessRunInput{Command: "echo", Args: []string{"hi"}}) + if err == nil { + t.Fatal("expected error when process service is unavailable") + } +} diff --git a/pkg/mcp/tools_rag.go b/pkg/mcp/tools_rag.go index 3b68140..ab9b981 100644 --- a/pkg/mcp/tools_rag.go +++ b/pkg/mcp/tools_rag.go @@ -83,6 +83,30 @@ type RAGCollectionsInput struct { ShowStats bool `json:"show_stats,omitempty"` // true to include point counts and status } +// RAGRetrieveInput contains parameters for retrieving chunks from a specific +// document source (rather than running a semantic query). +// +// input := RAGRetrieveInput{ +// Source: "docs/services.md", +// Collection: "core-docs", +// Limit: 20, +// } +type RAGRetrieveInput struct { + Source string `json:"source"` // e.g. "docs/services.md" + Collection string `json:"collection,omitempty"` // e.g. "core-docs" (default: "hostuk-docs") + Limit int `json:"limit,omitempty"` // e.g. 20 (default: 50) +} + +// RAGRetrieveOutput contains document chunks for a specific source. +// +// // len(out.Chunks) == 12, out.Source == "docs/services.md" +type RAGRetrieveOutput struct { + Source string `json:"source"` // e.g. "docs/services.md" + Collection string `json:"collection"` // collection searched + Chunks []RAGQueryResult `json:"chunks"` // chunks for the source, ordered by chunkIndex + Count int `json:"count"` // number of chunks returned +} + // CollectionInfo contains information about a Qdrant collection. // // // ci.Name == "core-docs", ci.PointsCount == 1500, ci.Status == "green" @@ -106,11 +130,28 @@ func (s *Service) registerRAGTools(server *mcp.Server) { Description: "Query the RAG vector database for relevant documentation. Returns semantically similar content based on the query.", }, s.ragQuery) + // rag_search is the spec-aligned alias for rag_query. + addToolRecorded(s, server, "rag", &mcp.Tool{ + Name: "rag_search", + Description: "Semantic search across documents in the RAG vector database. Returns chunks ranked by similarity.", + }, s.ragQuery) + addToolRecorded(s, server, "rag", &mcp.Tool{ Name: "rag_ingest", Description: "Ingest documents into the RAG vector database. Supports both single files and directories.", }, s.ragIngest) + // rag_index is the spec-aligned alias for rag_ingest. + addToolRecorded(s, server, "rag", &mcp.Tool{ + Name: "rag_index", + Description: "Index a document or directory into the RAG vector database.", + }, s.ragIngest) + + addToolRecorded(s, server, "rag", &mcp.Tool{ + Name: "rag_retrieve", + Description: "Retrieve chunks for a specific document source from the RAG vector database.", + }, s.ragRetrieve) + addToolRecorded(s, server, "rag", &mcp.Tool{ Name: "rag_collections", Description: "List all available collections in the RAG vector database.", @@ -216,6 +257,86 @@ func (s *Service) ragIngest(ctx context.Context, req *mcp.CallToolRequest, input }, nil } +// ragRetrieve handles the rag_retrieve tool call. +// Returns chunks for a specific source path by querying the collection with +// the source path as the query text and then filtering results down to the +// matching source. This preserves the transport abstraction that the rest of +// the RAG tools use while producing the document-scoped view callers expect. +func (s *Service) ragRetrieve(ctx context.Context, req *mcp.CallToolRequest, input RAGRetrieveInput) (*mcp.CallToolResult, RAGRetrieveOutput, error) { + collection := input.Collection + if collection == "" { + collection = DefaultRAGCollection + } + limit := input.Limit + if limit <= 0 { + limit = 50 + } + + s.logger.Info("MCP tool execution", "tool", "rag_retrieve", "source", input.Source, "collection", collection, "limit", limit, "user", log.Username()) + + if input.Source == "" { + return nil, RAGRetrieveOutput{}, log.E("ragRetrieve", "source cannot be empty", nil) + } + + // Use the source path as the query text — semantically related chunks + // will rank highly, and we then keep only chunks whose Source matches. + // Over-fetch by an order of magnitude so document-level limits are met + // even when the source appears beyond the top-K of the raw query. + overfetch := limit * 10 + if overfetch < 100 { + overfetch = 100 + } + + results, err := rag.QueryDocs(ctx, input.Source, collection, overfetch) + if err != nil { + log.Error("mcp: rag retrieve query failed", "source", input.Source, "collection", collection, "err", err) + return nil, RAGRetrieveOutput{}, log.E("ragRetrieve", "failed to retrieve chunks", err) + } + + chunks := make([]RAGQueryResult, 0, limit) + for _, r := range results { + if r.Source != input.Source { + continue + } + chunks = append(chunks, RAGQueryResult{ + Content: r.Text, + Source: r.Source, + Section: r.Section, + Category: r.Category, + ChunkIndex: r.ChunkIndex, + Score: r.Score, + }) + if len(chunks) >= limit { + break + } + } + sortChunksByIndex(chunks) + + return nil, RAGRetrieveOutput{ + Source: input.Source, + Collection: collection, + Chunks: chunks, + Count: len(chunks), + }, nil +} + +// sortChunksByIndex sorts chunks in ascending order of chunk index. +// Stable ordering keeps ties by their original position. +func sortChunksByIndex(chunks []RAGQueryResult) { + if len(chunks) <= 1 { + return + } + // Insertion sort keeps the code dependency-free and is fast enough + // for the small result sets rag_retrieve is designed for. + for i := 1; i < len(chunks); i++ { + j := i + for j > 0 && chunks[j-1].ChunkIndex > chunks[j].ChunkIndex { + chunks[j-1], chunks[j] = chunks[j], chunks[j-1] + j-- + } + } +} + // ragCollections handles the rag_collections tool call. func (s *Service) ragCollections(ctx context.Context, req *mcp.CallToolRequest, input RAGCollectionsInput) (*mcp.CallToolResult, RAGCollectionsOutput, error) { s.logger.Info("MCP tool execution", "tool", "rag_collections", "show_stats", input.ShowStats, "user", log.Username()) diff --git a/pkg/mcp/tools_rag_test.go b/pkg/mcp/tools_rag_test.go index 281dbf0..57e6b05 100644 --- a/pkg/mcp/tools_rag_test.go +++ b/pkg/mcp/tools_rag_test.go @@ -171,3 +171,66 @@ func TestRAGCollectionsInput_ShowStats(t *testing.T) { t.Error("Expected ShowStats to be true") } } + +// TestToolsRag_RAGRetrieveInput_Good exercises the rag_retrieve DTO defaults. +func TestToolsRag_RAGRetrieveInput_Good(t *testing.T) { + input := RAGRetrieveInput{ + Source: "docs/index.md", + Collection: "core-docs", + Limit: 20, + } + + if input.Source != "docs/index.md" { + t.Errorf("expected source docs/index.md, got %q", input.Source) + } + if input.Limit != 20 { + t.Errorf("expected limit 20, got %d", input.Limit) + } +} + +// TestToolsRag_RAGRetrieveOutput_Good exercises the rag_retrieve output shape. +func TestToolsRag_RAGRetrieveOutput_Good(t *testing.T) { + output := RAGRetrieveOutput{ + Source: "docs/index.md", + Collection: "core-docs", + Chunks: []RAGQueryResult{ + {Content: "first", ChunkIndex: 0}, + {Content: "second", ChunkIndex: 1}, + }, + Count: 2, + } + if output.Count != 2 { + t.Fatalf("expected count 2, got %d", output.Count) + } + if output.Chunks[1].ChunkIndex != 1 { + t.Fatalf("expected chunk 1, got %d", output.Chunks[1].ChunkIndex) + } +} + +// TestToolsRag_SortChunksByIndex_Good verifies sort orders by chunk index ascending. +func TestToolsRag_SortChunksByIndex_Good(t *testing.T) { + chunks := []RAGQueryResult{ + {ChunkIndex: 3}, + {ChunkIndex: 1}, + {ChunkIndex: 2}, + } + sortChunksByIndex(chunks) + for i, want := range []int{1, 2, 3} { + if chunks[i].ChunkIndex != want { + t.Fatalf("index %d: expected chunk %d, got %d", i, want, chunks[i].ChunkIndex) + } + } +} + +// TestToolsRag_RagRetrieve_Bad rejects empty source paths. +func TestToolsRag_RagRetrieve_Bad(t *testing.T) { + svc, err := New(Options{WorkspaceRoot: t.TempDir()}) + if err != nil { + t.Fatal(err) + } + + _, _, err = svc.ragRetrieve(t.Context(), nil, RAGRetrieveInput{}) + if err == nil { + t.Fatal("expected error for empty source") + } +} diff --git a/pkg/mcp/tools_webview.go b/pkg/mcp/tools_webview.go index a0d6a8b..734c223 100644 --- a/pkg/mcp/tools_webview.go +++ b/pkg/mcp/tools_webview.go @@ -270,6 +270,18 @@ func (s *Service) registerWebviewTools(server *mcp.Server) { Name: "webview_wait", Description: "Wait for an element to appear by CSS selector.", }, s.webviewWait) + + // Embedded UI rendering — for pushing HTML/state to connected clients + // without requiring a Chrome DevTools connection. + addToolRecorded(s, server, "webview", &mcp.Tool{ + Name: "webview_render", + Description: "Render HTML in an embedded webview by ID. Broadcasts to connected clients via the webview.render channel.", + }, s.webviewRender) + + addToolRecorded(s, server, "webview", &mcp.Tool{ + Name: "webview_update", + Description: "Update the HTML, title, or state of an embedded webview by ID. Broadcasts to connected clients via the webview.update channel.", + }, s.webviewUpdate) } // webviewConnect handles the webview_connect tool call. diff --git a/pkg/mcp/tools_webview_embed.go b/pkg/mcp/tools_webview_embed.go new file mode 100644 index 0000000..ff6d336 --- /dev/null +++ b/pkg/mcp/tools_webview_embed.go @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package mcp + +import ( + "context" + "sync" + "time" + + core "dappco.re/go/core" + "dappco.re/go/core/log" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// WebviewRenderInput contains parameters for rendering an embedded +// HTML view. The named view is stored and broadcast so connected clients +// (Claude Code sessions, CoreGUI windows, HTTP/SSE subscribers) can +// display the content. +// +// input := WebviewRenderInput{ +// ViewID: "dashboard", +// HTML: "
Loading...
", +// Title: "Agent Dashboard", +// Width: 1024, +// Height: 768, +// State: map[string]any{"theme": "dark"}, +// } +type WebviewRenderInput struct { + ViewID string `json:"view_id"` // e.g. "dashboard" + HTML string `json:"html"` // rendered markup + Title string `json:"title,omitempty"` // e.g. "Agent Dashboard" + Width int `json:"width,omitempty"` // preferred width in pixels + Height int `json:"height,omitempty"` // preferred height in pixels + State map[string]any `json:"state,omitempty"` // initial view state +} + +// WebviewRenderOutput reports the result of rendering an embedded view. +// +// // out.Success == true, out.ViewID == "dashboard" +type WebviewRenderOutput struct { + Success bool `json:"success"` // true when the view was stored and broadcast + ViewID string `json:"view_id"` // echoed view identifier + UpdatedAt time.Time `json:"updatedAt"` // when the view was rendered +} + +// WebviewUpdateInput contains parameters for updating the state of an +// existing embedded view. Callers may provide HTML to replace the markup, +// patch fields in the view state, or do both. +// +// input := WebviewUpdateInput{ +// ViewID: "dashboard", +// HTML: "
Ready
", +// State: map[string]any{"count": 42}, +// Merge: true, +// } +type WebviewUpdateInput struct { + ViewID string `json:"view_id"` // e.g. "dashboard" + HTML string `json:"html,omitempty"` // replacement markup (optional) + Title string `json:"title,omitempty"` // e.g. "Agent Dashboard" + State map[string]any `json:"state,omitempty"` // partial state update + Merge bool `json:"merge,omitempty"` // merge state (default) or replace when false +} + +// WebviewUpdateOutput reports the result of updating an embedded view. +// +// // out.Success == true, out.ViewID == "dashboard" +type WebviewUpdateOutput struct { + Success bool `json:"success"` // true when the view was updated and broadcast + ViewID string `json:"view_id"` // echoed view identifier + UpdatedAt time.Time `json:"updatedAt"` // when the view was last updated +} + +// embeddedView captures the live state of a rendered UI view. Instances +// are kept per ViewID inside embeddedViewRegistry. +type embeddedView struct { + ViewID string + Title string + HTML string + Width int + Height int + State map[string]any + UpdatedAt time.Time +} + +// embeddedViewRegistry stores the most recent render/update state for each +// view so new subscribers can pick up the current UI on connection. +// Operations are guarded by embeddedViewMu. +var ( + embeddedViewMu sync.RWMutex + embeddedViewRegistry = map[string]*embeddedView{} +) + +// ChannelWebviewRender is the channel used to broadcast webview_render events. +const ChannelWebviewRender = "webview.render" + +// ChannelWebviewUpdate is the channel used to broadcast webview_update events. +const ChannelWebviewUpdate = "webview.update" + +// webviewRender handles the webview_render tool call. +func (s *Service) webviewRender(ctx context.Context, req *mcp.CallToolRequest, input WebviewRenderInput) (*mcp.CallToolResult, WebviewRenderOutput, error) { + s.logger.Info("MCP tool execution", "tool", "webview_render", "view", input.ViewID, "user", log.Username()) + + if core.Trim(input.ViewID) == "" { + return nil, WebviewRenderOutput{}, log.E("webviewRender", "view_id is required", nil) + } + + now := time.Now() + view := &embeddedView{ + ViewID: input.ViewID, + Title: input.Title, + HTML: input.HTML, + Width: input.Width, + Height: input.Height, + State: cloneStateMap(input.State), + UpdatedAt: now, + } + + embeddedViewMu.Lock() + embeddedViewRegistry[input.ViewID] = view + embeddedViewMu.Unlock() + + s.ChannelSend(ctx, ChannelWebviewRender, map[string]any{ + "view_id": view.ViewID, + "title": view.Title, + "html": view.HTML, + "width": view.Width, + "height": view.Height, + "state": cloneStateMap(view.State), + "updatedAt": view.UpdatedAt, + }) + + return nil, WebviewRenderOutput{ + Success: true, + ViewID: view.ViewID, + UpdatedAt: view.UpdatedAt, + }, nil +} + +// webviewUpdate handles the webview_update tool call. +func (s *Service) webviewUpdate(ctx context.Context, req *mcp.CallToolRequest, input WebviewUpdateInput) (*mcp.CallToolResult, WebviewUpdateOutput, error) { + s.logger.Info("MCP tool execution", "tool", "webview_update", "view", input.ViewID, "user", log.Username()) + + if core.Trim(input.ViewID) == "" { + return nil, WebviewUpdateOutput{}, log.E("webviewUpdate", "view_id is required", nil) + } + + now := time.Now() + + embeddedViewMu.Lock() + view, ok := embeddedViewRegistry[input.ViewID] + if !ok { + // Updating a view that was never rendered creates one lazily so + // clients that reconnect mid-session get a consistent snapshot. + view = &embeddedView{ViewID: input.ViewID, State: map[string]any{}} + embeddedViewRegistry[input.ViewID] = view + } + + if input.HTML != "" { + view.HTML = input.HTML + } + if input.Title != "" { + view.Title = input.Title + } + if input.State != nil { + merge := input.Merge || len(view.State) == 0 + if merge { + if view.State == nil { + view.State = map[string]any{} + } + for k, v := range input.State { + view.State[k] = v + } + } else { + view.State = cloneStateMap(input.State) + } + } + view.UpdatedAt = now + snapshot := *view + snapshot.State = cloneStateMap(view.State) + embeddedViewMu.Unlock() + + s.ChannelSend(ctx, ChannelWebviewUpdate, map[string]any{ + "view_id": snapshot.ViewID, + "title": snapshot.Title, + "html": snapshot.HTML, + "width": snapshot.Width, + "height": snapshot.Height, + "state": snapshot.State, + "updatedAt": snapshot.UpdatedAt, + }) + + return nil, WebviewUpdateOutput{ + Success: true, + ViewID: snapshot.ViewID, + UpdatedAt: snapshot.UpdatedAt, + }, nil +} + +// cloneStateMap returns a shallow copy of a state map. +// +// cloned := cloneStateMap(map[string]any{"a": 1}) // cloned["a"] == 1 +func cloneStateMap(in map[string]any) map[string]any { + if in == nil { + return nil + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +// lookupEmbeddedView returns the current snapshot of an embedded view, if any. +// +// view, ok := lookupEmbeddedView("dashboard") +func lookupEmbeddedView(id string) (*embeddedView, bool) { + embeddedViewMu.RLock() + defer embeddedViewMu.RUnlock() + view, ok := embeddedViewRegistry[id] + if !ok { + return nil, false + } + snapshot := *view + snapshot.State = cloneStateMap(view.State) + return &snapshot, true +} + +// resetEmbeddedViews clears the registry. Intended for tests. +func resetEmbeddedViews() { + embeddedViewMu.Lock() + defer embeddedViewMu.Unlock() + embeddedViewRegistry = map[string]*embeddedView{} +} diff --git a/pkg/mcp/tools_webview_embed_test.go b/pkg/mcp/tools_webview_embed_test.go new file mode 100644 index 0000000..79266b7 --- /dev/null +++ b/pkg/mcp/tools_webview_embed_test.go @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package mcp + +import ( + "context" + "testing" +) + +// TestToolsWebviewEmbed_WebviewRender_Good registers a view and verifies the +// registry keeps the rendered HTML and state. +func TestToolsWebviewEmbed_WebviewRender_Good(t *testing.T) { + t.Cleanup(resetEmbeddedViews) + + svc, err := New(Options{WorkspaceRoot: t.TempDir()}) + if err != nil { + t.Fatal(err) + } + + _, out, err := svc.webviewRender(context.Background(), nil, WebviewRenderInput{ + ViewID: "dashboard", + HTML: "

hello

", + Title: "Demo", + State: map[string]any{"count": 1}, + }) + if err != nil { + t.Fatalf("webviewRender returned error: %v", err) + } + if !out.Success { + t.Fatal("expected Success=true") + } + if out.ViewID != "dashboard" { + t.Fatalf("expected view id 'dashboard', got %q", out.ViewID) + } + if out.UpdatedAt.IsZero() { + t.Fatal("expected non-zero UpdatedAt") + } + + view, ok := lookupEmbeddedView("dashboard") + if !ok { + t.Fatal("expected view to be stored in registry") + } + if view.HTML != "

hello

" { + t.Fatalf("expected HTML '

hello

', got %q", view.HTML) + } + if view.State["count"] != 1 { + t.Fatalf("expected state.count=1, got %v", view.State["count"]) + } +} + +// TestToolsWebviewEmbed_WebviewRender_Bad ensures empty view IDs are rejected. +func TestToolsWebviewEmbed_WebviewRender_Bad(t *testing.T) { + t.Cleanup(resetEmbeddedViews) + + svc, err := New(Options{WorkspaceRoot: t.TempDir()}) + if err != nil { + t.Fatal(err) + } + + _, _, err = svc.webviewRender(context.Background(), nil, WebviewRenderInput{}) + if err == nil { + t.Fatal("expected error for empty view_id") + } +} + +// TestToolsWebviewEmbed_WebviewUpdate_Good merges a state patch into the +// previously rendered view. +func TestToolsWebviewEmbed_WebviewUpdate_Good(t *testing.T) { + t.Cleanup(resetEmbeddedViews) + + svc, err := New(Options{WorkspaceRoot: t.TempDir()}) + if err != nil { + t.Fatal(err) + } + + _, _, err = svc.webviewRender(context.Background(), nil, WebviewRenderInput{ + ViewID: "dashboard", + HTML: "

hello

", + State: map[string]any{"count": 1}, + }) + if err != nil { + t.Fatalf("seed render failed: %v", err) + } + + _, out, err := svc.webviewUpdate(context.Background(), nil, WebviewUpdateInput{ + ViewID: "dashboard", + State: map[string]any{"theme": "dark"}, + Merge: true, + }) + if err != nil { + t.Fatalf("webviewUpdate returned error: %v", err) + } + if !out.Success { + t.Fatal("expected Success=true") + } + + view, ok := lookupEmbeddedView("dashboard") + if !ok { + t.Fatal("expected view to exist after update") + } + if view.State["count"] != 1 { + t.Fatalf("expected count to persist after merge, got %v", view.State["count"]) + } + if view.State["theme"] != "dark" { + t.Fatalf("expected theme 'dark' after merge, got %v", view.State["theme"]) + } +} + +// TestToolsWebviewEmbed_WebviewUpdate_Ugly updates a view that was never +// rendered and verifies a fresh registry entry is created. +func TestToolsWebviewEmbed_WebviewUpdate_Ugly(t *testing.T) { + t.Cleanup(resetEmbeddedViews) + + svc, err := New(Options{WorkspaceRoot: t.TempDir()}) + if err != nil { + t.Fatal(err) + } + + _, out, err := svc.webviewUpdate(context.Background(), nil, WebviewUpdateInput{ + ViewID: "ghost", + HTML: "

new

", + }) + if err != nil { + t.Fatalf("webviewUpdate returned error: %v", err) + } + if !out.Success { + t.Fatal("expected Success=true for lazy-create update") + } + + view, ok := lookupEmbeddedView("ghost") + if !ok { + t.Fatal("expected ghost view to be created lazily") + } + if view.HTML != "

new

" { + t.Fatalf("expected HTML '

new

', got %q", view.HTML) + } +} diff --git a/pkg/mcp/tools_ws_client.go b/pkg/mcp/tools_ws_client.go new file mode 100644 index 0000000..1895d1a --- /dev/null +++ b/pkg/mcp/tools_ws_client.go @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package mcp + +import ( + "context" + "crypto/rand" + "encoding/hex" + "net/http" + "sync" + "time" + + core "dappco.re/go/core" + "dappco.re/go/core/log" + "github.com/gorilla/websocket" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// WSConnectInput contains parameters for opening an outbound WebSocket +// connection from the MCP server. Each connection is given a stable ID that +// subsequent ws_send and ws_close calls use to address it. +// +// input := WSConnectInput{URL: "wss://example.com/ws", Timeout: 10} +type WSConnectInput struct { + URL string `json:"url"` // e.g. "wss://example.com/ws" + Headers map[string]string `json:"headers,omitempty"` // custom request headers + Timeout int `json:"timeout,omitempty"` // handshake timeout in seconds (default: 30) +} + +// WSConnectOutput contains the result of opening a WebSocket connection. +// +// // out.Success == true, out.ID == "ws-0af3…" +type WSConnectOutput struct { + Success bool `json:"success"` // true when the handshake completed + ID string `json:"id"` // e.g. "ws-0af3…" + URL string `json:"url"` // the URL that was dialled +} + +// WSSendInput contains parameters for sending a message on an open +// WebSocket connection. +// +// input := WSSendInput{ID: "ws-0af3…", Message: "ping"} +type WSSendInput struct { + ID string `json:"id"` // e.g. "ws-0af3…" + Message string `json:"message"` // payload to send + Binary bool `json:"binary,omitempty"` // true to send a binary frame (payload is base64 text) +} + +// WSSendOutput contains the result of sending a message. +// +// // out.Success == true, out.ID == "ws-0af3…" +type WSSendOutput struct { + Success bool `json:"success"` // true when the message was written + ID string `json:"id"` // e.g. "ws-0af3…" + Bytes int `json:"bytes"` // number of bytes written +} + +// WSCloseInput contains parameters for closing a WebSocket connection. +// +// input := WSCloseInput{ID: "ws-0af3…", Reason: "done"} +type WSCloseInput struct { + ID string `json:"id"` // e.g. "ws-0af3…" + Code int `json:"code,omitempty"` // close code (default: 1000 - normal closure) + Reason string `json:"reason,omitempty"` // human-readable reason +} + +// WSCloseOutput contains the result of closing a WebSocket connection. +// +// // out.Success == true, out.ID == "ws-0af3…" +type WSCloseOutput struct { + Success bool `json:"success"` // true when the connection was closed + ID string `json:"id"` // e.g. "ws-0af3…" + Message string `json:"message,omitempty"` // e.g. "connection closed" +} + +// wsClientConn tracks an outbound WebSocket connection tied to a stable ID. +type wsClientConn struct { + ID string + URL string + conn *websocket.Conn + writeMu sync.Mutex + CreatedAt time.Time +} + +// wsClientRegistry holds all live outbound WebSocket connections keyed by ID. +// Access is guarded by wsClientMu. +var ( + wsClientMu sync.Mutex + wsClientRegistry = map[string]*wsClientConn{} +) + +// registerWSClientTools registers the outbound WebSocket client tools. +func (s *Service) registerWSClientTools(server *mcp.Server) { + addToolRecorded(s, server, "ws", &mcp.Tool{ + Name: "ws_connect", + Description: "Open an outbound WebSocket connection. Returns a connection ID for subsequent ws_send and ws_close calls.", + }, s.wsConnect) + + addToolRecorded(s, server, "ws", &mcp.Tool{ + Name: "ws_send", + Description: "Send a text or binary message on an open WebSocket connection identified by ID.", + }, s.wsSend) + + addToolRecorded(s, server, "ws", &mcp.Tool{ + Name: "ws_close", + Description: "Close an open WebSocket connection identified by ID.", + }, s.wsClose) +} + +// wsConnect handles the ws_connect tool call. +func (s *Service) wsConnect(ctx context.Context, req *mcp.CallToolRequest, input WSConnectInput) (*mcp.CallToolResult, WSConnectOutput, error) { + s.logger.Security("MCP tool execution", "tool", "ws_connect", "url", input.URL, "user", log.Username()) + + if core.Trim(input.URL) == "" { + return nil, WSConnectOutput{}, log.E("wsConnect", "url is required", nil) + } + + timeout := time.Duration(input.Timeout) * time.Second + if timeout <= 0 { + timeout = 30 * time.Second + } + + dialer := websocket.Dialer{ + HandshakeTimeout: timeout, + } + + headers := http.Header{} + for k, v := range input.Headers { + headers.Set(k, v) + } + + dialCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + conn, _, err := dialer.DialContext(dialCtx, input.URL, headers) + if err != nil { + log.Error("mcp: ws connect failed", "url", input.URL, "err", err) + return nil, WSConnectOutput{}, log.E("wsConnect", "failed to connect", err) + } + + id := newWSClientID() + client := &wsClientConn{ + ID: id, + URL: input.URL, + conn: conn, + CreatedAt: time.Now(), + } + + wsClientMu.Lock() + wsClientRegistry[id] = client + wsClientMu.Unlock() + + return nil, WSConnectOutput{ + Success: true, + ID: id, + URL: input.URL, + }, nil +} + +// wsSend handles the ws_send tool call. +func (s *Service) wsSend(ctx context.Context, req *mcp.CallToolRequest, input WSSendInput) (*mcp.CallToolResult, WSSendOutput, error) { + s.logger.Info("MCP tool execution", "tool", "ws_send", "id", input.ID, "binary", input.Binary, "user", log.Username()) + + if core.Trim(input.ID) == "" { + return nil, WSSendOutput{}, log.E("wsSend", "id is required", nil) + } + + client, ok := getWSClient(input.ID) + if !ok { + return nil, WSSendOutput{}, log.E("wsSend", "connection not found", nil) + } + + messageType := websocket.TextMessage + if input.Binary { + messageType = websocket.BinaryMessage + } + + client.writeMu.Lock() + err := client.conn.WriteMessage(messageType, []byte(input.Message)) + client.writeMu.Unlock() + if err != nil { + log.Error("mcp: ws send failed", "id", input.ID, "err", err) + return nil, WSSendOutput{}, log.E("wsSend", "failed to send message", err) + } + + return nil, WSSendOutput{ + Success: true, + ID: input.ID, + Bytes: len(input.Message), + }, nil +} + +// wsClose handles the ws_close tool call. +func (s *Service) wsClose(ctx context.Context, req *mcp.CallToolRequest, input WSCloseInput) (*mcp.CallToolResult, WSCloseOutput, error) { + s.logger.Info("MCP tool execution", "tool", "ws_close", "id", input.ID, "user", log.Username()) + + if core.Trim(input.ID) == "" { + return nil, WSCloseOutput{}, log.E("wsClose", "id is required", nil) + } + + wsClientMu.Lock() + client, ok := wsClientRegistry[input.ID] + if ok { + delete(wsClientRegistry, input.ID) + } + wsClientMu.Unlock() + + if !ok { + return nil, WSCloseOutput{}, log.E("wsClose", "connection not found", nil) + } + + code := input.Code + if code == 0 { + code = websocket.CloseNormalClosure + } + reason := input.Reason + if reason == "" { + reason = "closed" + } + + client.writeMu.Lock() + _ = client.conn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(code, reason), + time.Now().Add(5*time.Second), + ) + client.writeMu.Unlock() + _ = client.conn.Close() + + return nil, WSCloseOutput{ + Success: true, + ID: input.ID, + Message: "connection closed", + }, nil +} + +// newWSClientID returns a fresh identifier for an outbound WebSocket client. +// +// id := newWSClientID() // "ws-0af3…" +func newWSClientID() string { + var buf [8]byte + _, _ = rand.Read(buf[:]) + return "ws-" + hex.EncodeToString(buf[:]) +} + +// getWSClient returns a tracked outbound WebSocket client by ID, if any. +// +// client, ok := getWSClient("ws-0af3…") +func getWSClient(id string) (*wsClientConn, bool) { + wsClientMu.Lock() + defer wsClientMu.Unlock() + client, ok := wsClientRegistry[id] + return client, ok +} + +// resetWSClients drops all tracked outbound WebSocket clients. Intended for tests. +func resetWSClients() { + wsClientMu.Lock() + defer wsClientMu.Unlock() + for id, client := range wsClientRegistry { + _ = client.conn.Close() + delete(wsClientRegistry, id) + } +} diff --git a/pkg/mcp/tools_ws_client_test.go b/pkg/mcp/tools_ws_client_test.go new file mode 100644 index 0000000..3c3d178 --- /dev/null +++ b/pkg/mcp/tools_ws_client_test.go @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package mcp + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +// TestToolsWSClient_WSConnect_Good dials a test WebSocket server and verifies +// the handshake completes and a client ID is assigned. +func TestToolsWSClient_WSConnect_Good(t *testing.T) { + t.Cleanup(resetWSClients) + + server := startTestWSServer(t) + defer server.Close() + + svc, err := New(Options{WorkspaceRoot: t.TempDir()}) + if err != nil { + t.Fatal(err) + } + + _, out, err := svc.wsConnect(context.Background(), nil, WSConnectInput{ + URL: "ws" + strings.TrimPrefix(server.URL, "http") + "/ws", + Timeout: 5, + }) + if err != nil { + t.Fatalf("wsConnect failed: %v", err) + } + if !out.Success { + t.Fatal("expected Success=true") + } + if !strings.HasPrefix(out.ID, "ws-") { + t.Fatalf("expected ID prefix 'ws-', got %q", out.ID) + } + + _, _, err = svc.wsClose(context.Background(), nil, WSCloseInput{ID: out.ID}) + if err != nil { + t.Fatalf("wsClose failed: %v", err) + } +} + +// TestToolsWSClient_WSConnect_Bad rejects empty URLs. +func TestToolsWSClient_WSConnect_Bad(t *testing.T) { + t.Cleanup(resetWSClients) + + svc, err := New(Options{WorkspaceRoot: t.TempDir()}) + if err != nil { + t.Fatal(err) + } + + _, _, err = svc.wsConnect(context.Background(), nil, WSConnectInput{}) + if err == nil { + t.Fatal("expected error for empty URL") + } +} + +// TestToolsWSClient_WSSendClose_Good sends a message on an open connection +// and then closes it. +func TestToolsWSClient_WSSendClose_Good(t *testing.T) { + t.Cleanup(resetWSClients) + + server := startTestWSServer(t) + defer server.Close() + + svc, err := New(Options{WorkspaceRoot: t.TempDir()}) + if err != nil { + t.Fatal(err) + } + + _, conn, err := svc.wsConnect(context.Background(), nil, WSConnectInput{ + URL: "ws" + strings.TrimPrefix(server.URL, "http") + "/ws", + Timeout: 5, + }) + if err != nil { + t.Fatalf("wsConnect failed: %v", err) + } + + _, sendOut, err := svc.wsSend(context.Background(), nil, WSSendInput{ + ID: conn.ID, + Message: "ping", + }) + if err != nil { + t.Fatalf("wsSend failed: %v", err) + } + if !sendOut.Success { + t.Fatal("expected Success=true for wsSend") + } + if sendOut.Bytes != 4 { + t.Fatalf("expected 4 bytes written, got %d", sendOut.Bytes) + } + + _, closeOut, err := svc.wsClose(context.Background(), nil, WSCloseInput{ID: conn.ID}) + if err != nil { + t.Fatalf("wsClose failed: %v", err) + } + if !closeOut.Success { + t.Fatal("expected Success=true for wsClose") + } + + if _, ok := getWSClient(conn.ID); ok { + t.Fatal("expected connection to be removed after close") + } +} + +// TestToolsWSClient_WSSend_Bad rejects unknown connection IDs. +func TestToolsWSClient_WSSend_Bad(t *testing.T) { + t.Cleanup(resetWSClients) + + svc, err := New(Options{WorkspaceRoot: t.TempDir()}) + if err != nil { + t.Fatal(err) + } + + _, _, err = svc.wsSend(context.Background(), nil, WSSendInput{ID: "ws-missing", Message: "x"}) + if err == nil { + t.Fatal("expected error for unknown connection ID") + } +} + +// TestToolsWSClient_WSClose_Bad rejects closes for unknown connection IDs. +func TestToolsWSClient_WSClose_Bad(t *testing.T) { + t.Cleanup(resetWSClients) + + svc, err := New(Options{WorkspaceRoot: t.TempDir()}) + if err != nil { + t.Fatal(err) + } + + _, _, err = svc.wsClose(context.Background(), nil, WSCloseInput{ID: "ws-missing"}) + if err == nil { + t.Fatal("expected error for unknown connection ID") + } +} + +// startTestWSServer returns an httptest.Server running a minimal echo WebSocket +// handler used by the ws_connect/ws_send tests. +func startTestWSServer(t *testing.T) *httptest.Server { + t.Helper() + + upgrader := websocket.Upgrader{ + CheckOrigin: func(*http.Request) bool { return true }, + } + mux := http.NewServeMux() + mux.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + for { + _, msg, err := conn.ReadMessage() + if err != nil { + return + } + if err := conn.WriteMessage(websocket.TextMessage, msg); err != nil { + return + } + } + }) + return httptest.NewServer(mux) +} diff --git a/pkg/mcp/transport_http.go b/pkg/mcp/transport_http.go index 2fa617d..c82b703 100644 --- a/pkg/mcp/transport_http.go +++ b/pkg/mcp/transport_http.go @@ -249,10 +249,31 @@ func handleMCPDiscovery(w http.ResponseWriter, r *http.Request) { "command": "core-agent", "args": []string{"mcp"}, }, - Capabilities: []string{"tools", "resources"}, - UseWhen: []string{"Need to dispatch work to Codex/Claude/Gemini", "Need workspace status", "Need semantic search"}, + Capabilities: []string{"tools", "resources"}, + UseWhen: []string{ + "Need to dispatch work to Codex/Claude/Gemini", + "Need workspace status", + "Need semantic search", + }, RelatedServers: []string{"core-mcp"}, }, + { + ID: "core-mcp", + Name: "Core MCP", + Description: "File ops, process and build tools, RAG search, webview, dashboards — the agent-facing MCP framework.", + Connection: map[string]any{ + "type": "stdio", + "command": "core-mcp", + }, + Capabilities: []string{"tools", "resources", "logging"}, + UseWhen: []string{ + "Need to read/write files inside a workspace", + "Need to start or monitor processes", + "Need to run RAG queries or index documents", + "Need to render or update an embedded dashboard view", + }, + RelatedServers: []string{"core-agent"}, + }, }, }