diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index 2c75f90..a7d7a1a 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -174,14 +174,40 @@ func (s *Service) ToolsSeq() iter.Seq[ToolRecord] { // defer cancel() // if err := svc.Shutdown(ctx); err != nil { log.Fatal(err) } func (s *Service) Shutdown(ctx context.Context) error { + var shutdownErr error + for _, sub := range s.subsystems { if sh, ok := sub.(SubsystemWithShutdown); ok { if err := sh.Shutdown(ctx); err != nil { - return log.E("mcp.Shutdown", "shutdown "+sub.Name(), err) + if shutdownErr == nil { + shutdownErr = log.E("mcp.Shutdown", "shutdown "+sub.Name(), err) + } } } } - return nil + + if s.wsServer != nil { + s.wsMu.Lock() + server := s.wsServer + s.wsMu.Unlock() + + if err := server.Shutdown(ctx); err != nil && shutdownErr == nil { + shutdownErr = log.E("mcp.Shutdown", "shutdown websocket server", err) + } + + s.wsMu.Lock() + if s.wsServer == server { + s.wsServer = nil + s.wsAddr = "" + } + s.wsMu.Unlock() + } + + if err := closeWebviewConnection(); err != nil && shutdownErr == nil { + shutdownErr = log.E("mcp.Shutdown", "close webview connection", err) + } + + return shutdownErr } // WSHub returns the WebSocket hub, or nil if not configured. diff --git a/pkg/mcp/tools_process.go b/pkg/mcp/tools_process.go index ccc948e..1c9abb8 100644 --- a/pkg/mcp/tools_process.go +++ b/pkg/mcp/tools_process.go @@ -221,10 +221,10 @@ func (s *Service) processStop(ctx context.Context, req *mcp.CallToolRequest, inp return nil, ProcessStopOutput{}, log.E("processStop", "process not found", err) } - // For graceful stop, we use Kill() which sends SIGKILL - // A more sophisticated implementation could use SIGTERM first - if err := proc.Kill(); err != nil { - log.Error("mcp: process stop kill failed", "id", input.ID, "err", err) + // Use the process service's graceful shutdown path first so callers get + // a real stop signal before we fall back to a hard kill internally. + if err := proc.Shutdown(); err != nil { + log.Error("mcp: process stop failed", "id", input.ID, "err", err) return nil, ProcessStopOutput{}, log.E("processStop", "failed to stop process", err) } diff --git a/pkg/mcp/tools_webview.go b/pkg/mcp/tools_webview.go index b4aee9c..e52eae1 100644 --- a/pkg/mcp/tools_webview.go +++ b/pkg/mcp/tools_webview.go @@ -3,6 +3,7 @@ package mcp import ( "context" "encoding/base64" + "strings" "sync" "time" @@ -25,6 +26,20 @@ var ( errSelectorRequired = log.E("webview", "selector is required", nil) ) +// closeWebviewConnection closes and clears the shared browser connection. +func closeWebviewConnection() error { + webviewMu.Lock() + defer webviewMu.Unlock() + + if webviewInstance == nil { + return nil + } + + err := webviewInstance.Close() + webviewInstance = nil + return err +} + // WebviewConnectInput contains parameters for connecting to Chrome DevTools. // // input := WebviewConnectInput{DebugURL: "http://localhost:9222", Timeout: 10} @@ -562,7 +577,15 @@ func (s *Service) webviewWait(ctx context.Context, req *mcp.CallToolRequest, inp return nil, WebviewWaitOutput{}, errSelectorRequired } - if err := webviewInstance.WaitForSelector(input.Selector); err != nil { + timeout := time.Duration(input.Timeout) * time.Second + if timeout <= 0 { + timeout = 30 * time.Second + } + + if err := waitForSelector(ctx, timeout, input.Selector, func(selector string) error { + _, err := webviewInstance.QuerySelector(selector) + return err + }); err != nil { log.Error("mcp: webview wait failed", "selector", input.Selector, "err", err) return nil, WebviewWaitOutput{}, log.E("webviewWait", "failed to wait for selector", err) } @@ -572,3 +595,34 @@ func (s *Service) webviewWait(ctx context.Context, req *mcp.CallToolRequest, inp Message: core.Sprintf("Element found: %s", input.Selector), }, nil } + +// waitForSelector polls until the selector exists or the timeout elapses. +// Query helpers in go-webview report "element not found" as an error, so we +// keep retrying until we see the element or hit the deadline. +func waitForSelector(ctx context.Context, timeout time.Duration, selector string, query func(string) error) error { + if timeout <= 0 { + timeout = 30 * time.Second + } + + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for { + err := query(selector) + if err == nil { + return nil + } + if !strings.Contains(err.Error(), "element not found") { + return err + } + + select { + case <-waitCtx.Done(): + return log.E("webviewWait", "timed out waiting for selector", waitCtx.Err()) + case <-ticker.C: + } + } +} diff --git a/pkg/mcp/tools_webview_test.go b/pkg/mcp/tools_webview_test.go index 1849430..0428ed7 100644 --- a/pkg/mcp/tools_webview_test.go +++ b/pkg/mcp/tools_webview_test.go @@ -1,6 +1,8 @@ package mcp import ( + "context" + "errors" "testing" "time" @@ -215,6 +217,41 @@ func TestWebviewWaitInput_Good(t *testing.T) { } } +func TestWaitForSelector_Good(t *testing.T) { + ctx := context.Background() + + attempts := 0 + err := waitForSelector(ctx, 200*time.Millisecond, "#ready", func(selector string) error { + attempts++ + if attempts < 3 { + return errors.New("element not found: " + selector) + } + return nil + }) + + if err != nil { + t.Fatalf("waitForSelector failed: %v", err) + } + if attempts != 3 { + t.Fatalf("expected 3 attempts, got %d", attempts) + } +} + +func TestWaitForSelector_Bad_Timeout(t *testing.T) { + ctx := context.Background() + + start := time.Now() + err := waitForSelector(ctx, 50*time.Millisecond, "#missing", func(selector string) error { + return errors.New("element not found: " + selector) + }) + if err == nil { + t.Fatal("expected waitForSelector to time out") + } + if time.Since(start) < 50*time.Millisecond { + t.Fatal("expected waitForSelector to honor timeout") + } +} + // TestWebviewConnectOutput_Good verifies the WebviewConnectOutput struct has expected fields. func TestWebviewConnectOutput_Good(t *testing.T) { output := WebviewConnectOutput{