diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index dbde7a4e2..8bafc13aa 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -9,8 +9,6 @@ use crate::client_common::REVIEW_PROMPT; use crate::compact; use crate::features::Feature; use crate::function_tool::FunctionCallError; -use crate::mcp::auth::McpAuthStatusEntry; -use crate::mcp_connection_manager::DEFAULT_STARTUP_TIMEOUT; use crate::parse_command::parse_command; use crate::parse_turn_item; use crate::response_processing::process_items; @@ -45,6 +43,7 @@ use mcp_types::ReadResourceResult; use serde_json; use serde_json::Value; use tokio::sync::Mutex; +use tokio::sync::RwLock; use tokio::sync::oneshot; use tokio_util::sync::CancellationToken; use tracing::debug; @@ -57,7 +56,6 @@ use crate::client::ModelClient; use crate::client_common::Prompt; use crate::client_common::ResponseEvent; use crate::config::Config; -use crate::config::types::McpServerTransportConfig; use crate::config::types::ShellEnvironmentPolicy; use crate::context_manager::ContextManager; use crate::environment_context::EnvironmentContext; @@ -476,21 +474,13 @@ impl Session { ), }; - // Error messages to dispatch after SessionConfigured is sent. - let mut post_session_configured_events = Vec::::new(); - // Kick off independent async setup tasks in parallel to reduce startup latency. // // - initialize RolloutRecorder with new or resumed session info - // - spin up MCP connection manager // - perform default shell discovery // - load history metadata let rollout_fut = RolloutRecorder::new(&config, rollout_params); - let mcp_fut = McpConnectionManager::new( - config.mcp_servers.clone(), - config.mcp_oauth_credentials_store_mode, - ); let default_shell_fut = shell::default_user_shell(); let history_meta_fut = crate::message_history::history_metadata(&config); let auth_statuses_fut = compute_auth_statuses( @@ -499,15 +489,8 @@ impl Session { ); // Join all independent futures. - let ( - rollout_recorder, - mcp_res, - default_shell, - (history_log_id, history_entry_count), - auth_statuses, - ) = tokio::join!( + let (rollout_recorder, default_shell, (history_log_id, history_entry_count), auth_statuses) = tokio::join!( rollout_fut, - mcp_fut, default_shell_fut, history_meta_fut, auth_statuses_fut @@ -519,34 +502,7 @@ impl Session { })?; let rollout_path = rollout_recorder.rollout_path.clone(); - // Handle MCP manager result and record any startup failures. - let (mcp_connection_manager, failed_clients) = match mcp_res { - Ok((mgr, failures)) => (mgr, failures), - Err(e) => { - let message = format!("Failed to create MCP connection manager: {e:#}"); - error!("{message}"); - post_session_configured_events.push(Event { - id: INITIAL_SUBMIT_ID.to_owned(), - msg: EventMsg::Error(ErrorEvent { message }), - }); - (McpConnectionManager::default(), Default::default()) - } - }; - - // Surface individual client start-up failures to the user. - if !failed_clients.is_empty() { - for (server_name, err) in failed_clients { - let auth_entry = auth_statuses.get(&server_name); - let display_message = mcp_init_error_display(&server_name, auth_entry, &err); - warn!("MCP client for `{server_name}` failed to start: {err:#}"); - post_session_configured_events.push(Event { - id: INITIAL_SUBMIT_ID.to_owned(), - msg: EventMsg::Error(ErrorEvent { - message: display_message, - }), - }); - } - } + let mut post_session_configured_events = Vec::::new(); for (alias, feature) in session_configuration.features.legacy_feature_usages() { let canonical = feature.key(); @@ -595,7 +551,8 @@ impl Session { warm_model_cache(&session_configuration.model); let services = SessionServices { - mcp_connection_manager, + mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())), + mcp_startup_cancellation_token: CancellationToken::new(), unified_exec_manager: UnifiedExecSessionManager::default(), notifier: UserNotifier::new(config.notify.clone()), rollout: Mutex::new(Some(rollout_recorder)), @@ -635,6 +592,18 @@ impl Session { for event in events { sess.send_event_raw(event).await; } + sess.services + .mcp_connection_manager + .write() + .await + .initialize( + config.mcp_servers.clone(), + config.mcp_oauth_credentials_store_mode, + auth_statuses.clone(), + tx_event.clone(), + sess.services.mcp_startup_cancellation_token.clone(), + ) + .await; // record_initial_history can emit events. We record only after the SessionConfiguredEvent is emitted. sess.record_initial_history(initial_history).await; @@ -1258,6 +1227,8 @@ impl Session { ) -> anyhow::Result { self.services .mcp_connection_manager + .read() + .await .list_resources(server, params) .await } @@ -1269,6 +1240,8 @@ impl Session { ) -> anyhow::Result { self.services .mcp_connection_manager + .read() + .await .list_resource_templates(server, params) .await } @@ -1280,6 +1253,8 @@ impl Session { ) -> anyhow::Result { self.services .mcp_connection_manager + .read() + .await .read_resource(server, params) .await } @@ -1292,19 +1267,29 @@ impl Session { ) -> anyhow::Result { self.services .mcp_connection_manager + .read() + .await .call_tool(server, tool, arguments) .await } - pub(crate) fn parse_mcp_tool_name(&self, tool_name: &str) -> Option<(String, String)> { + pub(crate) async fn parse_mcp_tool_name(&self, tool_name: &str) -> Option<(String, String)> { self.services .mcp_connection_manager + .read() + .await .parse_tool_name(tool_name) + .await } pub async fn interrupt_task(self: &Arc) { info!("interrupt received: abort current task, if any"); - self.abort_all_tasks(TurnAbortReason::Interrupted).await; + let has_active_turn = { self.active_turn.lock().await.is_some() }; + if has_active_turn { + self.abort_all_tasks(TurnAbortReason::Interrupted).await; + } else { + self.cancel_mcp_startup().await; + } } pub(crate) fn notifier(&self) -> &UserNotifier { @@ -1318,6 +1303,10 @@ impl Session { fn show_raw_agent_reasoning(&self) -> bool { self.services.show_raw_agent_reasoning } + + async fn cancel_mcp_startup(&self) { + self.services.mcp_startup_cancellation_token.cancel(); + } } async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiver) { @@ -1575,17 +1564,15 @@ mod handlers { } pub async fn list_mcp_tools(sess: &Session, config: &Arc, sub_id: String) { - // This is a cheap lookup from the connection manager's cache. - let tools = sess.services.mcp_connection_manager.list_all_tools(); - let (auth_status_entries, resources, resource_templates) = tokio::join!( + let mcp_connection_manager = sess.services.mcp_connection_manager.read().await; + let (tools, auth_status_entries, resources, resource_templates) = tokio::join!( + mcp_connection_manager.list_all_tools(), compute_auth_statuses( config.mcp_servers.iter(), config.mcp_oauth_credentials_store_mode, ), - sess.services.mcp_connection_manager.list_all_resources(), - sess.services - .mcp_connection_manager - .list_all_resource_templates() + mcp_connection_manager.list_all_resources(), + mcp_connection_manager.list_all_resource_templates(), ); let auth_statuses = auth_status_entries .iter() @@ -1594,7 +1581,10 @@ mod handlers { let event = Event { id: sub_id, msg: EventMsg::McpListToolsResponse(crate::protocol::McpListToolsResponseEvent { - tools, + tools: tools + .into_iter() + .map(|(name, tool)| (name, tool.tool)) + .collect(), resources, resource_templates, auth_statuses, @@ -1924,10 +1914,22 @@ async fn run_turn( input: Vec, cancellation_token: CancellationToken, ) -> CodexResult { - let mcp_tools = sess.services.mcp_connection_manager.list_all_tools(); + let mcp_tools = sess + .services + .mcp_connection_manager + .read() + .await + .list_all_tools() + .or_cancel(&cancellation_token) + .await?; let router = Arc::new(ToolRouter::from_config( &turn_context.tools_config, - Some(mcp_tools), + Some( + mcp_tools + .into_iter() + .map(|(name, tool)| (name, tool.tool)) + .collect(), + ), )); let model_supports_parallel = turn_context @@ -2096,7 +2098,7 @@ async fn try_run_turn( ResponseEvent::Created => {} ResponseEvent::OutputItemDone(item) => { let previously_active_item = active_item.take(); - match ToolRouter::build_tool_call(sess.as_ref(), item.clone()) { + match ToolRouter::build_tool_call(sess.as_ref(), item.clone()).await { Ok(Some(call)) => { let payload_preview = call.payload.log_payload().into_owned(); tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview); @@ -2307,59 +2309,6 @@ pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) - }) } -fn mcp_init_error_display( - server_name: &str, - entry: Option<&McpAuthStatusEntry>, - err: &anyhow::Error, -) -> String { - if let Some(McpServerTransportConfig::StreamableHttp { - url, - bearer_token_env_var, - http_headers, - .. - }) = &entry.map(|entry| &entry.config.transport) - && url == "https://api.githubcopilot.com/mcp/" - && bearer_token_env_var.is_none() - && http_headers.as_ref().map(HashMap::is_empty).unwrap_or(true) - { - // GitHub only supports OAUth for first party MCP clients. - // That means that the user has to specify a personal access token either via bearer_token_env_var or http_headers. - // https://github.com/github/github-mcp-server/issues/921#issuecomment-3221026448 - format!( - "GitHub MCP does not support OAuth. Log in by adding a personal access token (https://github.com/settings/personal-access-tokens) to your environment and config.toml:\n[mcp_servers.{server_name}]\nbearer_token_env_var = CODEX_GITHUB_PERSONAL_ACCESS_TOKEN" - ) - } else if is_mcp_client_auth_required_error(err) { - format!( - "The {server_name} MCP server is not logged in. Run `codex mcp login {server_name}`." - ) - } else if is_mcp_client_startup_timeout_error(err) { - let startup_timeout_secs = match entry { - Some(entry) => match entry.config.startup_timeout_sec { - Some(timeout) => timeout, - None => DEFAULT_STARTUP_TIMEOUT, - }, - None => DEFAULT_STARTUP_TIMEOUT, - } - .as_secs(); - format!( - "MCP client for `{server_name}` timed out after {startup_timeout_secs} seconds. Add or adjust `startup_timeout_sec` in your config.toml:\n[mcp_servers.{server_name}]\nstartup_timeout_sec = XX" - ) - } else { - format!("MCP client for `{server_name}` failed to start: {err:#}") - } -} - -fn is_mcp_client_auth_required_error(error: &anyhow::Error) -> bool { - // StreamableHttpError::AuthRequired from the MCP SDK. - error.to_string().contains("Auth required") -} - -fn is_mcp_client_startup_timeout_error(error: &anyhow::Error) -> bool { - let error_message = error.to_string(); - error_message.contains("request timed out") - || error_message.contains("timed out handshaking with MCP server") -} - use crate::features::Features; #[cfg(test)] pub(crate) use tests::make_session_and_context; @@ -2369,10 +2318,7 @@ mod tests { use super::*; use crate::config::ConfigOverrides; use crate::config::ConfigToml; - use crate::config::types::McpServerConfig; - use crate::config::types::McpServerTransportConfig; use crate::exec::ExecToolCallOutput; - use crate::mcp::auth::McpAuthStatusEntry; use crate::tools::format_exec_output_str; use crate::protocol::CompactedItem; @@ -2392,7 +2338,6 @@ mod tests { use codex_app_server_protocol::AuthMode; use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; - use codex_protocol::protocol::McpAuthStatus; use std::time::Duration; use tokio::time::sleep; @@ -2606,7 +2551,8 @@ mod tests { let state = SessionState::new(session_configuration.clone()); let services = SessionServices { - mcp_connection_manager: McpConnectionManager::default(), + mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())), + mcp_startup_cancellation_token: CancellationToken::new(), unified_exec_manager: UnifiedExecSessionManager::default(), notifier: UserNotifier::new(None), rollout: Mutex::new(None), @@ -2682,7 +2628,8 @@ mod tests { let state = SessionState::new(session_configuration.clone()); let services = SessionServices { - mcp_connection_manager: McpConnectionManager::default(), + mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())), + mcp_startup_cancellation_token: CancellationToken::new(), unified_exec_manager: UnifiedExecSessionManager::default(), notifier: UserNotifier::new(None), rollout: Mutex::new(None), @@ -2863,9 +2810,23 @@ mod tests { #[tokio::test] async fn fatal_tool_error_stops_turn_and_reports_error() { let (session, turn_context, _rx) = make_session_and_context_with_rx(); + let tools = { + session + .services + .mcp_connection_manager + .read() + .await + .list_all_tools() + .await + }; let router = ToolRouter::from_config( &turn_context.tools_config, - Some(session.services.mcp_connection_manager.list_all_tools()), + Some( + tools + .into_iter() + .map(|(name, tool)| (name, tool.tool)) + .collect(), + ), ); let item = ResponseItem::CustomToolCall { id: None, @@ -2876,6 +2837,7 @@ mod tests { }; let call = ToolRouter::build_tool_call(session.as_ref(), item.clone()) + .await .expect("build tool call") .expect("tool call present"); let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); @@ -3125,7 +3087,6 @@ mod tests { pretty_assertions::assert_eq!(exec_output.metadata, ResponseExecMetadata { exit_code: 0 }); assert!(exec_output.output.contains("hi")); } - #[tokio::test] async fn unified_exec_rejects_escalated_permissions_when_policy_not_on_request() { use crate::protocol::AskForApproval; @@ -3167,89 +3128,4 @@ mod tests { pretty_assertions::assert_eq!(output, expected); } - - #[test] - fn mcp_init_error_display_prompts_for_github_pat() { - let server_name = "github"; - let entry = McpAuthStatusEntry { - config: McpServerConfig { - transport: McpServerTransportConfig::StreamableHttp { - url: "https://api.githubcopilot.com/mcp/".to_string(), - bearer_token_env_var: None, - http_headers: None, - env_http_headers: None, - }, - enabled: true, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - }, - auth_status: McpAuthStatus::Unsupported, - }; - let err = anyhow::anyhow!("OAuth is unsupported"); - - let display = mcp_init_error_display(server_name, Some(&entry), &err); - - let expected = format!( - "GitHub MCP does not support OAuth. Log in by adding a personal access token (https://github.com/settings/personal-access-tokens) to your environment and config.toml:\n[mcp_servers.{server_name}]\nbearer_token_env_var = CODEX_GITHUB_PERSONAL_ACCESS_TOKEN" - ); - - assert_eq!(expected, display); - } - - #[test] - fn mcp_init_error_display_prompts_for_login_when_auth_required() { - let server_name = "example"; - let err = anyhow::anyhow!("Auth required for server"); - - let display = mcp_init_error_display(server_name, None, &err); - - let expected = format!( - "The {server_name} MCP server is not logged in. Run `codex mcp login {server_name}`." - ); - - assert_eq!(expected, display); - } - - #[test] - fn mcp_init_error_display_reports_generic_errors() { - let server_name = "custom"; - let entry = McpAuthStatusEntry { - config: McpServerConfig { - transport: McpServerTransportConfig::StreamableHttp { - url: "https://example.com".to_string(), - bearer_token_env_var: Some("TOKEN".to_string()), - http_headers: None, - env_http_headers: None, - }, - enabled: true, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - }, - auth_status: McpAuthStatus::Unsupported, - }; - let err = anyhow::anyhow!("boom"); - - let display = mcp_init_error_display(server_name, Some(&entry), &err); - - let expected = format!("MCP client for `{server_name}` failed to start: {err:#}"); - - assert_eq!(expected, display); - } - - #[test] - fn mcp_init_error_display_includes_startup_timeout_hint() { - let server_name = "slow"; - let err = anyhow::anyhow!("request timed out"); - - let display = mcp_init_error_display(server_name, None, &err); - - assert_eq!( - "MCP client for `slow` timed out after 10 seconds. Add or adjust `startup_timeout_sec` in your config.toml:\n[mcp_servers.slow]\nstartup_timeout_sec = XX", - display - ); - } } diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 11a73c0b0..d8869e5e9 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -13,11 +13,24 @@ use std::ffi::OsString; use std::sync::Arc; use std::time::Duration; +use crate::mcp::auth::McpAuthStatusEntry; use anyhow::Context; use anyhow::Result; use anyhow::anyhow; +use async_channel::Sender; +use codex_async_utils::CancelErr; +use codex_async_utils::OrCancelExt; +use codex_protocol::protocol::Event; +use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::McpStartupCompleteEvent; +use codex_protocol::protocol::McpStartupFailure; +use codex_protocol::protocol::McpStartupStatus; +use codex_protocol::protocol::McpStartupUpdateEvent; use codex_rmcp_client::OAuthCredentialsStoreMode; use codex_rmcp_client::RmcpClient; +use futures::future::BoxFuture; +use futures::future::FutureExt; +use futures::future::Shared; use mcp_types::ClientCapabilities; use mcp_types::Implementation; use mcp_types::ListResourceTemplatesRequestParams; @@ -34,9 +47,10 @@ use serde_json::json; use sha1::Digest; use sha1::Sha1; use tokio::task::JoinSet; -use tracing::info; +use tokio_util::sync::CancellationToken; use tracing::warn; +use crate::codex::INITIAL_SUBMIT_ID; use crate::config::types::McpServerConfig; use crate::config::types::McpServerTransportConfig; @@ -54,11 +68,10 @@ pub const DEFAULT_STARTUP_TIMEOUT: Duration = Duration::from_secs(10); /// Default timeout for individual tool calls. const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_secs(60); -/// Map that holds a startup error for every MCP server that could **not** be -/// spawned successfully. -pub type ClientStartErrors = HashMap; - -fn qualify_tools(tools: Vec) -> HashMap { +fn qualify_tools(tools: I) -> HashMap +where + I: IntoIterator, +{ let mut used_names = HashSet::new(); let mut qualified_tools = HashMap::new(); for tool in tools { @@ -90,222 +103,166 @@ fn qualify_tools(tools: Vec) -> HashMap { qualified_tools } -struct ToolInfo { - server_name: String, - tool_name: String, - tool: Tool, +#[derive(Clone)] +pub(crate) struct ToolInfo { + pub(crate) server_name: String, + pub(crate) tool_name: String, + pub(crate) tool: Tool, } +#[derive(Clone)] struct ManagedClient { client: Arc, - startup_timeout: Duration, + tools: Vec, + tool_filter: ToolFilter, tool_timeout: Option, } +#[derive(Clone)] +struct AsyncManagedClient { + client: Shared>>, +} + +impl AsyncManagedClient { + fn new( + server_name: String, + config: McpServerConfig, + store_mode: OAuthCredentialsStoreMode, + cancel_token: CancellationToken, + ) -> Self { + let tool_filter = ToolFilter::from_config(&config); + let fut = start_server_task( + server_name, + config.transport, + store_mode, + config + .startup_timeout_sec + .unwrap_or(DEFAULT_STARTUP_TIMEOUT), + config.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT), + tool_filter, + cancel_token, + ); + Self { + client: fut.boxed().shared(), + } + } + + async fn client(&self) -> Result { + self.client.clone().await + } +} + /// A thin wrapper around a set of running [`RmcpClient`] instances. #[derive(Default)] pub(crate) struct McpConnectionManager { - /// Server-name -> client instance. - /// - /// The server name originates from the keys of the `mcp_servers` map in - /// the user configuration. - clients: HashMap, - - /// Fully qualified tool name -> tool instance. - tools: HashMap, - - /// Server-name -> configured tool filters. - tool_filters: HashMap, + clients: HashMap, } impl McpConnectionManager { - /// Spawn a [`RmcpClient`] for each configured server. - /// - /// * `mcp_servers` – Map loaded from the user configuration where *keys* - /// are human-readable server identifiers and *values* are the spawn - /// instructions. - /// - /// Servers that fail to start are reported in `ClientStartErrors`: the - /// user should be informed about these errors. - pub async fn new( + pub async fn initialize( + &mut self, mcp_servers: HashMap, store_mode: OAuthCredentialsStoreMode, - ) -> Result<(Self, ClientStartErrors)> { - // Early exit if no servers are configured. - if mcp_servers.is_empty() { - return Ok((Self::default(), ClientStartErrors::default())); + auth_entries: HashMap, + tx_event: Sender, + cancel_token: CancellationToken, + ) { + if cancel_token.is_cancelled() { + return; } - - // Launch all configured servers concurrently. + let mut clients = HashMap::new(); let mut join_set = JoinSet::new(); - let mut errors = ClientStartErrors::new(); - let mut tool_filters: HashMap = HashMap::new(); - - for (server_name, cfg) in mcp_servers { - // Validate server name before spawning - if !is_valid_mcp_server_name(&server_name) { - let error = anyhow::anyhow!( - "invalid server name '{server_name}': must match pattern ^[a-zA-Z0-9_-]+$" - ); - errors.insert(server_name, error); - continue; - } - - if !cfg.enabled { - tool_filters.insert(server_name, ToolFilter::from_config(&cfg)); - continue; - } - - let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT); - let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT); - tool_filters.insert(server_name.clone(), ToolFilter::from_config(&cfg)); - - let resolved_bearer_token = match &cfg.transport { - McpServerTransportConfig::StreamableHttp { - bearer_token_env_var, - .. - } => resolve_bearer_token(&server_name, bearer_token_env_var.as_deref()), - _ => Ok(None), - }; - + for (server_name, cfg) in mcp_servers.into_iter().filter(|(_, cfg)| cfg.enabled) { + let cancel_token = cancel_token.child_token(); + let _ = emit_update( + &tx_event, + McpStartupUpdateEvent { + server: server_name.clone(), + status: McpStartupStatus::Starting, + }, + ) + .await; + let async_managed_client = + AsyncManagedClient::new(server_name.clone(), cfg, store_mode, cancel_token.clone()); + clients.insert(server_name.clone(), async_managed_client.clone()); + let tx_event = tx_event.clone(); + let auth_entry = auth_entries.get(&server_name).cloned(); join_set.spawn(async move { - let McpServerConfig { transport, .. } = cfg; - let params = mcp_types::InitializeRequestParams { - capabilities: ClientCapabilities { - experimental: None, - roots: None, - sampling: None, - // https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities - // indicates this should be an empty object. - elicitation: Some(json!({})), - }, - client_info: Implementation { - name: "codex-mcp-client".to_owned(), - version: env!("CARGO_PKG_VERSION").to_owned(), - title: Some("Codex".into()), - // This field is used by Codex when it is an MCP - // server: it should not be used when Codex is - // an MCP client. - user_agent: None, - }, - protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(), - }; - - let resolved_bearer_token = resolved_bearer_token.unwrap_or_default(); - let client_result = match transport { - McpServerTransportConfig::Stdio { - command, - args, - env, - env_vars, - cwd, - } => { - let command_os: OsString = command.into(); - let args_os: Vec = args.into_iter().map(Into::into).collect(); - match RmcpClient::new_stdio_client(command_os, args_os, env, &env_vars, cwd) - .await - { - Ok(client) => { - let client = Arc::new(client); - client - .initialize(params.clone(), Some(startup_timeout)) - .await - .map(|_| client) - } - Err(err) => Err(err.into()), - } - } - McpServerTransportConfig::StreamableHttp { - url, - http_headers, - env_http_headers, - .. - } => { - match RmcpClient::new_streamable_http_client( - &server_name, - &url, - resolved_bearer_token.clone(), - http_headers, - env_http_headers, - store_mode, - ) - .await - { - Ok(client) => { - let client = Arc::new(client); - client - .initialize(params.clone(), Some(startup_timeout)) - .await - .map(|_| client) - } - Err(err) => Err(err), - } + let outcome = async_managed_client.client().await; + if cancel_token.is_cancelled() { + return (server_name, Err(StartupOutcomeError::Cancelled)); + } + let status = match &outcome { + Ok(_) => McpStartupStatus::Ready, + Err(error) => { + let error_str = mcp_init_error_display( + server_name.as_str(), + auth_entry.as_ref(), + error, + ); + McpStartupStatus::Failed { error: error_str } } }; - ( - (server_name, tool_timeout), - client_result.map(|client| (client, startup_timeout)), + let _ = emit_update( + &tx_event, + McpStartupUpdateEvent { + server: server_name.clone(), + status, + }, ) + .await; + + (server_name, outcome) }); } - - let mut clients: HashMap = HashMap::with_capacity(join_set.len()); - - while let Some(res) = join_set.join_next().await { - let ((server_name, tool_timeout), client_res) = match res { - Ok(result) => result, - Err(e) => { - warn!("Task panic when starting MCP server: {e:#}"); - continue; - } - }; - - match client_res { - Ok((client, startup_timeout)) => { - clients.insert( - server_name, - ManagedClient { - client, - startup_timeout, - tool_timeout: Some(tool_timeout), - }, - ); - } - Err(e) => { - errors.insert(server_name, e); + self.clients = clients; + tokio::spawn(async move { + let outcomes = join_set.join_all().await; + let mut summary = McpStartupCompleteEvent::default(); + for (server_name, outcome) in outcomes { + match outcome { + Ok(_) => summary.ready.push(server_name), + Err(StartupOutcomeError::Cancelled) => summary.cancelled.push(server_name), + Err(StartupOutcomeError::Failed { error }) => { + summary.failed.push(McpStartupFailure { + server: server_name, + error, + }) + } } } - } + let _ = tx_event + .send(Event { + id: INITIAL_SUBMIT_ID.to_owned(), + msg: EventMsg::McpStartupComplete(summary), + }) + .await; + }); + } - let all_tools = match list_all_tools(&clients).await { - Ok(tools) => tools, - Err(e) => { - warn!("Failed to list tools from some MCP servers: {e:#}"); - Vec::new() - } - }; - - let filtered_tools = filter_tools(all_tools, &tool_filters); - let tools = qualify_tools(filtered_tools); - - Ok(( - Self { - clients, - tools, - tool_filters, - }, - errors, - )) + async fn client_by_name(&self, name: &str) -> Result { + self.clients + .get(name) + .ok_or_else(|| anyhow!("unknown MCP server '{name}'"))? + .client() + .await + .context("failed to get client") } /// Returns a single map that contains all tools. Each key is the /// fully-qualified name for the tool. - pub fn list_all_tools(&self) -> HashMap { - self.tools - .iter() - .map(|(name, tool)| (name.clone(), tool.tool.clone())) - .collect() + pub async fn list_all_tools(&self) -> HashMap { + let mut tools = HashMap::new(); + for managed_client in self.clients.values() { + if let Ok(client) = managed_client.client().await { + tools.extend(qualify_tools(filter_tools( + client.tools, + client.tool_filter, + ))); + } + } + tools } /// Returns a single map that contains all resources. Each key is the @@ -313,10 +270,15 @@ impl McpConnectionManager { pub async fn list_all_resources(&self) -> HashMap> { let mut join_set = JoinSet::new(); - for (server_name, managed_client) in &self.clients { - let server_name_cloned = server_name.clone(); - let client_clone = managed_client.client.clone(); + let clients_snapshot = &self.clients; + + for (server_name, async_managed_client) in clients_snapshot { + let server_name = server_name.clone(); + let Ok(managed_client) = async_managed_client.client().await else { + continue; + }; let timeout = managed_client.tool_timeout; + let client = managed_client.client.clone(); join_set.spawn(async move { let mut collected: Vec = Vec::new(); @@ -326,9 +288,9 @@ impl McpConnectionManager { let params = cursor.as_ref().map(|next| ListResourcesRequestParams { cursor: Some(next.clone()), }); - let response = match client_clone.list_resources(params, timeout).await { + let response = match client.list_resources(params, timeout).await { Ok(result) => result, - Err(err) => return (server_name_cloned, Err(err)), + Err(err) => return (server_name, Err(err)), }; collected.extend(response.resources); @@ -337,13 +299,13 @@ impl McpConnectionManager { Some(next) => { if cursor.as_ref() == Some(&next) { return ( - server_name_cloned, + server_name, Err(anyhow!("resources/list returned duplicate cursor")), ); } cursor = Some(next); } - None => return (server_name_cloned, Ok(collected)), + None => return (server_name, Ok(collected)), } } }); @@ -373,9 +335,14 @@ impl McpConnectionManager { pub async fn list_all_resource_templates(&self) -> HashMap> { let mut join_set = JoinSet::new(); - for (server_name, managed_client) in &self.clients { + let clients_snapshot = &self.clients; + + for (server_name, async_managed_client) in clients_snapshot { let server_name_cloned = server_name.clone(); - let client_clone = managed_client.client.clone(); + let Ok(managed_client) = async_managed_client.client().await else { + continue; + }; + let client = managed_client.client.clone(); let timeout = managed_client.tool_timeout; join_set.spawn(async move { @@ -388,8 +355,7 @@ impl McpConnectionManager { .map(|next| ListResourceTemplatesRequestParams { cursor: Some(next.clone()), }); - let response = match client_clone.list_resource_templates(params, timeout).await - { + let response = match client.list_resource_templates(params, timeout).await { Ok(result) => result, Err(err) => return (server_name_cloned, Err(err)), }; @@ -442,22 +408,16 @@ impl McpConnectionManager { tool: &str, arguments: Option, ) -> Result { - if let Some(filter) = self.tool_filters.get(server) - && !filter.allows(tool) - { + let client = self.client_by_name(server).await?; + if !client.tool_filter.allows(tool) { return Err(anyhow!( "tool '{tool}' is disabled for MCP server '{server}'" )); } - let managed = self - .clients - .get(server) - .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?; - let client = &managed.client; - let timeout = managed.tool_timeout; client - .call_tool(tool.to_string(), arguments, timeout) + .client + .call_tool(tool.to_string(), arguments, client.tool_timeout) .await .with_context(|| format!("tool call failed for `{server}/{tool}`")) } @@ -468,14 +428,11 @@ impl McpConnectionManager { server: &str, params: Option, ) -> Result { - let managed = self - .clients - .get(server) - .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?; - let client = managed.client.clone(); + let managed = self.client_by_name(server).await?; let timeout = managed.tool_timeout; - client + managed + .client .list_resources(params, timeout) .await .with_context(|| format!("resources/list failed for `{server}`")) @@ -487,10 +444,7 @@ impl McpConnectionManager { server: &str, params: Option, ) -> Result { - let managed = self - .clients - .get(server) - .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?; + let managed = self.client_by_name(server).await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; @@ -506,10 +460,7 @@ impl McpConnectionManager { server: &str, params: ReadResourceRequestParams, ) -> Result { - let managed = self - .clients - .get(server) - .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?; + let managed = self.client_by_name(server).await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; let uri = params.uri.clone(); @@ -520,18 +471,31 @@ impl McpConnectionManager { .with_context(|| format!("resources/read failed for `{server}` ({uri})")) } - pub fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> { - self.tools + pub async fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> { + self.list_all_tools() + .await .get(tool_name) .map(|tool| (tool.server_name.clone(), tool.tool_name.clone())) } } +async fn emit_update( + tx_event: &Sender, + update: McpStartupUpdateEvent, +) -> Result<(), async_channel::SendError> { + tx_event + .send(Event { + id: INITIAL_SUBMIT_ID.to_owned(), + msg: EventMsg::McpStartupUpdate(update), + }) + .await +} + /// A tool is allowed to be used if both are true: /// 1. enabled is None (no allowlist is set) or the tool is explicitly enabled. /// 2. The tool is not explicitly disabled. #[derive(Default, Clone)] -struct ToolFilter { +pub(crate) struct ToolFilter { enabled: Option>, disabled: HashSet, } @@ -562,14 +526,10 @@ impl ToolFilter { } } -fn filter_tools(tools: Vec, filters: &HashMap) -> Vec { +fn filter_tools(tools: Vec, filter: ToolFilter) -> Vec { tools .into_iter() - .filter(|tool| { - filters - .get(&tool.server_name) - .is_none_or(|filter| filter.allows(&tool.tool_name)) - }) + .filter(|tool| filter.allows(&tool.tool_name)) .collect() } @@ -600,70 +560,254 @@ fn resolve_bearer_token( } } -/// Query every server for its available tools and return a single map that -/// contains all tools. Each key is the fully-qualified name for the tool. -async fn list_all_tools(clients: &HashMap) -> Result> { - let mut join_set = JoinSet::new(); +#[derive(Debug, Clone, thiserror::Error)] +enum StartupOutcomeError { + #[error("MCP startup cancelled")] + Cancelled, + // We can't store the original error here because anyhow::Error doesn't implement + // `Clone`. + #[error("MCP startup failed: {error}")] + Failed { error: String }, +} - // Spawn one task per server so we can query them concurrently. This - // keeps the overall latency roughly at the slowest server instead of - // the cumulative latency. - for (server_name, managed_client) in clients { - let server_name_cloned = server_name.clone(); - let client_clone = managed_client.client.clone(); - let startup_timeout = managed_client.startup_timeout; - join_set.spawn(async move { - let res = client_clone.list_tools(None, Some(startup_timeout)).await; - (server_name_cloned, res) - }); - } - - let mut aggregated: Vec = Vec::with_capacity(join_set.len()); - - while let Some(join_res) = join_set.join_next().await { - let (server_name, list_result) = if let Ok(result) = join_res { - result - } else { - warn!("Task panic when listing tools for MCP server: {join_res:#?}"); - continue; - }; - - let list_result = if let Ok(result) = list_result { - result - } else { - warn!("Failed to list tools for MCP server '{server_name}': {list_result:#?}"); - continue; - }; - - for tool in list_result.tools { - let tool_info = ToolInfo { - server_name: server_name.clone(), - tool_name: tool.name.clone(), - tool, - }; - aggregated.push(tool_info); +impl From for StartupOutcomeError { + fn from(error: anyhow::Error) -> Self { + Self::Failed { + error: error.to_string(), } } - - info!( - "aggregated {} tools from {} servers", - aggregated.len(), - clients.len() - ); - - Ok(aggregated) } -fn is_valid_mcp_server_name(server_name: &str) -> bool { - !server_name.is_empty() - && server_name - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') +async fn start_server_task( + server_name: String, + transport: McpServerTransportConfig, + store_mode: OAuthCredentialsStoreMode, + startup_timeout: Duration, // TODO: cancel_token should handle this. + tool_timeout: Duration, + tool_filter: ToolFilter, + cancel_token: CancellationToken, +) -> Result { + if cancel_token.is_cancelled() { + return Err(StartupOutcomeError::Cancelled); + } + if let Err(error) = validate_mcp_server_name(&server_name) { + return Err(error.into()); + } + + match start_server_work( + server_name, + transport, + store_mode, + startup_timeout, + tool_timeout, + tool_filter, + ) + .or_cancel(&cancel_token) + .await + { + Ok(result) => result, + Err(CancelErr::Cancelled) => Err(StartupOutcomeError::Cancelled), + } } +async fn start_server_work( + server_name: String, + transport: McpServerTransportConfig, + store_mode: OAuthCredentialsStoreMode, + startup_timeout: Duration, + tool_timeout: Duration, + tool_filter: ToolFilter, +) -> Result { + let params = mcp_types::InitializeRequestParams { + capabilities: ClientCapabilities { + experimental: None, + roots: None, + sampling: None, + // https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities + // indicates this should be an empty object. + elicitation: Some(json!({})), + }, + client_info: Implementation { + name: "codex-mcp-client".to_owned(), + version: env!("CARGO_PKG_VERSION").to_owned(), + title: Some("Codex".into()), + // This field is used by Codex when it is an MCP + // server: it should not be used when Codex is + // an MCP client. + user_agent: None, + }, + protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(), + }; + + let client_result = match transport { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { + let command_os: OsString = command.into(); + let args_os: Vec = args.into_iter().map(Into::into).collect(); + match RmcpClient::new_stdio_client(command_os, args_os, env, &env_vars, cwd).await { + Ok(client) => { + let client = Arc::new(client); + client + .initialize(params.clone(), Some(startup_timeout)) + .await + .map(|_| client) + } + Err(err) => Err(err.into()), + } + } + McpServerTransportConfig::StreamableHttp { + url, + http_headers, + env_http_headers, + bearer_token_env_var, + } => { + let resolved_bearer_token = + match resolve_bearer_token(&server_name, bearer_token_env_var.as_deref()) { + Ok(token) => token, + Err(error) => return Err(error.into()), + }; + match RmcpClient::new_streamable_http_client( + &server_name, + &url, + resolved_bearer_token, + http_headers, + env_http_headers, + store_mode, + ) + .await + { + Ok(client) => { + let client = Arc::new(client); + client + .initialize(params.clone(), Some(startup_timeout)) + .await + .map(|_| client) + } + Err(err) => Err(err), + } + } + }; + + let client = match client_result { + Ok(client) => client, + Err(error) => { + return Err(error.into()); + } + }; + + let tools = match list_tools_for_client(&server_name, &client, startup_timeout).await { + Ok(tools) => tools, + Err(error) => { + return Err(error.into()); + } + }; + + let managed = ManagedClient { + client: Arc::clone(&client), + tools, + tool_timeout: Some(tool_timeout), + tool_filter, + }; + + Ok(managed) +} + +async fn list_tools_for_client( + server_name: &str, + client: &Arc, + timeout: Duration, +) -> Result> { + let resp = client.list_tools(None, Some(timeout)).await?; + Ok(resp + .tools + .into_iter() + .map(|tool| ToolInfo { + server_name: server_name.to_owned(), + tool_name: tool.name.clone(), + tool, + }) + .collect()) +} + +fn validate_mcp_server_name(server_name: &str) -> Result<()> { + let re = regex_lite::Regex::new(r"^[a-zA-Z0-9_-]+$")?; + if !re.is_match(server_name) { + return Err(anyhow!( + "Invalid MCP server name '{server_name}': must match pattern {pattern}", + pattern = re.as_str() + )); + } + Ok(()) +} + +fn mcp_init_error_display( + server_name: &str, + entry: Option<&McpAuthStatusEntry>, + err: &StartupOutcomeError, +) -> String { + if let Some(McpServerTransportConfig::StreamableHttp { + url, + bearer_token_env_var, + http_headers, + .. + }) = &entry.map(|entry| &entry.config.transport) + && url == "https://api.githubcopilot.com/mcp/" + && bearer_token_env_var.is_none() + && http_headers.as_ref().map(HashMap::is_empty).unwrap_or(true) + { + format!( + "GitHub MCP does not support OAuth. Log in by adding a personal access token (https://github.com/settings/personal-access-tokens) to your environment and config.toml:\n[mcp_servers.{server_name}]\nbearer_token_env_var = CODEX_GITHUB_PERSONAL_ACCESS_TOKEN" + ) + } else if is_mcp_client_auth_required_error(err) { + format!( + "The {server_name} MCP server is not logged in. Run `codex mcp login {server_name}`." + ) + } else if is_mcp_client_startup_timeout_error(err) { + let startup_timeout_secs = match entry { + Some(entry) => match entry.config.startup_timeout_sec { + Some(timeout) => timeout, + None => DEFAULT_STARTUP_TIMEOUT, + }, + None => DEFAULT_STARTUP_TIMEOUT, + } + .as_secs(); + format!( + "MCP client for `{server_name}` timed out after {startup_timeout_secs} seconds. Add or adjust `startup_timeout_sec` in your config.toml:\n[mcp_servers.{server_name}]\nstartup_timeout_sec = XX" + ) + } else { + format!("MCP client for `{server_name}` failed to start: {err:#}") + } +} + +fn is_mcp_client_auth_required_error(error: &StartupOutcomeError) -> bool { + match error { + StartupOutcomeError::Failed { error } => error.contains("Auth required"), + _ => false, + } +} + +fn is_mcp_client_startup_timeout_error(error: &StartupOutcomeError) -> bool { + match error { + StartupOutcomeError::Failed { error } => { + error.contains("request timed out") + || error.contains("timed out handshaking with MCP server") + } + _ => false, + } +} + +#[cfg(test)] +mod mcp_init_error_display_tests {} + #[cfg(test)] mod tests { use super::*; + use codex_protocol::protocol::McpAuthStatus; use mcp_types::ToolInputSchema; use std::collections::HashSet; @@ -792,31 +936,112 @@ mod tests { #[test] fn filter_tools_applies_per_server_filters() { - let tools = vec![ + let server1_tools = vec![ create_test_tool("server1", "tool_a"), create_test_tool("server1", "tool_b"), - create_test_tool("server2", "tool_a"), ]; - let mut filters = HashMap::new(); - filters.insert( - "server1".to_string(), - ToolFilter { - enabled: Some(HashSet::from(["tool_a".to_string(), "tool_b".to_string()])), - disabled: HashSet::from(["tool_b".to_string()]), - }, - ); - filters.insert( - "server2".to_string(), - ToolFilter { - enabled: None, - disabled: HashSet::from(["tool_a".to_string()]), - }, - ); + let server2_tools = vec![create_test_tool("server2", "tool_a")]; + let server1_filter = ToolFilter { + enabled: Some(HashSet::from(["tool_a".to_string(), "tool_b".to_string()])), + disabled: HashSet::from(["tool_b".to_string()]), + }; + let server2_filter = ToolFilter { + enabled: None, + disabled: HashSet::from(["tool_a".to_string()]), + }; - let filtered = filter_tools(tools, &filters); + let filtered: Vec<_> = filter_tools(server1_tools, server1_filter) + .into_iter() + .chain(filter_tools(server2_tools, server2_filter)) + .collect(); assert_eq!(filtered.len(), 1); assert_eq!(filtered[0].server_name, "server1"); assert_eq!(filtered[0].tool_name, "tool_a"); } + + #[test] + fn mcp_init_error_display_prompts_for_github_pat() { + let server_name = "github"; + let entry = McpAuthStatusEntry { + config: McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://api.githubcopilot.com/mcp/".to_string(), + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }, + enabled: true, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + }, + auth_status: McpAuthStatus::Unsupported, + }; + let err: StartupOutcomeError = anyhow::anyhow!("OAuth is unsupported").into(); + + let display = mcp_init_error_display(server_name, Some(&entry), &err); + + let expected = format!( + "GitHub MCP does not support OAuth. Log in by adding a personal access token (https://github.com/settings/personal-access-tokens) to your environment and config.toml:\n[mcp_servers.{server_name}]\nbearer_token_env_var = CODEX_GITHUB_PERSONAL_ACCESS_TOKEN" + ); + + assert_eq!(expected, display); + } + + #[test] + fn mcp_init_error_display_prompts_for_login_when_auth_required() { + let server_name = "example"; + let err: StartupOutcomeError = anyhow::anyhow!("Auth required for server").into(); + + let display = mcp_init_error_display(server_name, None, &err); + + let expected = format!( + "The {server_name} MCP server is not logged in. Run `codex mcp login {server_name}`." + ); + + assert_eq!(expected, display); + } + + #[test] + fn mcp_init_error_display_reports_generic_errors() { + let server_name = "custom"; + let entry = McpAuthStatusEntry { + config: McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://example.com".to_string(), + bearer_token_env_var: Some("TOKEN".to_string()), + http_headers: None, + env_http_headers: None, + }, + enabled: true, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + }, + auth_status: McpAuthStatus::Unsupported, + }; + let err: StartupOutcomeError = anyhow::anyhow!("boom").into(); + + let display = mcp_init_error_display(server_name, Some(&entry), &err); + + let expected = format!("MCP client for `{server_name}` failed to start: {err:#}"); + + assert_eq!(expected, display); + } + + #[test] + fn mcp_init_error_display_includes_startup_timeout_hint() { + let server_name = "slow"; + let err: StartupOutcomeError = anyhow::anyhow!("request timed out").into(); + + let display = mcp_init_error_display(server_name, None, &err); + + assert_eq!( + "MCP client for `slow` timed out after 10 seconds. Add or adjust `startup_timeout_sec` in your config.toml:\n[mcp_servers.slow]\nstartup_timeout_sec = XX", + display + ); + } } diff --git a/codex-rs/core/src/rollout/policy.rs b/codex-rs/core/src/rollout/policy.rs index e00883264..75fed0988 100644 --- a/codex-rs/core/src/rollout/policy.rs +++ b/codex-rs/core/src/rollout/policy.rs @@ -72,6 +72,8 @@ pub(crate) fn should_persist_event_msg(ev: &EventMsg) -> bool { | EventMsg::GetHistoryEntryResponse(_) | EventMsg::UndoStarted(_) | EventMsg::McpListToolsResponse(_) + | EventMsg::McpStartupUpdate(_) + | EventMsg::McpStartupComplete(_) | EventMsg::ListCustomPromptsResponse(_) | EventMsg::PlanUpdate(_) | EventMsg::ShutdownComplete diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs index ad6f5f90e..287fb73d2 100644 --- a/codex-rs/core/src/state/service.rs +++ b/codex-rs/core/src/state/service.rs @@ -8,9 +8,12 @@ use crate::unified_exec::UnifiedExecSessionManager; use crate::user_notification::UserNotifier; use codex_otel::otel_event_manager::OtelEventManager; use tokio::sync::Mutex; +use tokio::sync::RwLock; +use tokio_util::sync::CancellationToken; pub(crate) struct SessionServices { - pub(crate) mcp_connection_manager: McpConnectionManager, + pub(crate) mcp_connection_manager: Arc>, + pub(crate) mcp_startup_cancellation_token: CancellationToken, pub(crate) unified_exec_manager: UnifiedExecSessionManager, pub(crate) notifier: UserNotifier, pub(crate) rollout: Mutex>, diff --git a/codex-rs/core/src/tools/handlers/mcp_resource.rs b/codex-rs/core/src/tools/handlers/mcp_resource.rs index b601591ac..4dac72fbc 100644 --- a/codex-rs/core/src/tools/handlers/mcp_resource.rs +++ b/codex-rs/core/src/tools/handlers/mcp_resource.rs @@ -287,6 +287,8 @@ async fn handle_list_resources( let resources = session .services .mcp_connection_manager + .read() + .await .list_all_resources() .await; Ok(ListResourcesPayload::from_all_servers(resources)) @@ -396,6 +398,8 @@ async fn handle_list_resource_templates( let templates = session .services .mcp_connection_manager + .read() + .await .list_all_resource_templates() .await; Ok(ListResourceTemplatesPayload::from_all_servers(templates)) diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs index 19098aa80..7152d3c1e 100644 --- a/codex-rs/core/src/tools/router.rs +++ b/codex-rs/core/src/tools/router.rs @@ -54,7 +54,7 @@ impl ToolRouter { .any(|config| config.spec.name() == tool_name) } - pub fn build_tool_call( + pub async fn build_tool_call( session: &Session, item: ResponseItem, ) -> Result, FunctionCallError> { @@ -65,7 +65,7 @@ impl ToolRouter { call_id, .. } => { - if let Some((server, tool)) = session.parse_mcp_tool_name(&name) { + if let Some((server, tool)) = session.parse_mcp_tool_name(&name).await { Ok(Some(ToolCall { tool_name: name, call_id, diff --git a/codex-rs/exec/src/event_processor_with_human_output.rs b/codex-rs/exec/src/event_processor_with_human_output.rs index 93e0e493b..8c7bb6881 100644 --- a/codex-rs/exec/src/event_processor_with_human_output.rs +++ b/codex-rs/exec/src/event_processor_with_human_output.rs @@ -182,6 +182,42 @@ impl EventProcessor for EventProcessorWithHumanOutput { ts_msg!(self, " {}", details.style(self.dimmed)); } } + EventMsg::McpStartupUpdate(update) => { + let status_text = match update.status { + codex_core::protocol::McpStartupStatus::Starting => "starting".to_string(), + codex_core::protocol::McpStartupStatus::Ready => "ready".to_string(), + codex_core::protocol::McpStartupStatus::Cancelled => "cancelled".to_string(), + codex_core::protocol::McpStartupStatus::Failed { ref error } => { + format!("failed: {error}") + } + }; + ts_msg!( + self, + "{} {} {}", + "mcp:".style(self.cyan), + update.server, + status_text + ); + } + EventMsg::McpStartupComplete(summary) => { + let mut parts = Vec::new(); + if !summary.ready.is_empty() { + parts.push(format!("ready: {}", summary.ready.join(", "))); + } + if !summary.failed.is_empty() { + let servers: Vec<_> = summary.failed.iter().map(|f| f.server.clone()).collect(); + parts.push(format!("failed: {}", servers.join(", "))); + } + if !summary.cancelled.is_empty() { + parts.push(format!("cancelled: {}", summary.cancelled.join(", "))); + } + let joined = if parts.is_empty() { + "no servers".to_string() + } else { + parts.join("; ") + }; + ts_msg!(self, "{} {}", "mcp startup:".style(self.cyan), joined); + } EventMsg::BackgroundEvent(BackgroundEventEvent { message }) => { ts_msg!(self, "{}", message.style(self.dimmed)); } diff --git a/codex-rs/mcp-server/src/codex_tool_runner.rs b/codex-rs/mcp-server/src/codex_tool_runner.rs index 96e875153..be9cbaaf7 100644 --- a/codex-rs/mcp-server/src/codex_tool_runner.rs +++ b/codex-rs/mcp-server/src/codex_tool_runner.rs @@ -258,6 +258,9 @@ async fn run_codex_tool_session_inner( EventMsg::AgentReasoningDelta(_) => { // TODO: think how we want to support this in the MCP } + EventMsg::McpStartupUpdate(_) | EventMsg::McpStartupComplete(_) => { + // Ignored in MCP tool runner. + } EventMsg::AgentMessage(AgentMessageEvent { .. }) => { // TODO: think how we want to support this in the MCP } diff --git a/codex-rs/protocol/src/protocol.rs b/codex-rs/protocol/src/protocol.rs index defdd9385..4a4a61044 100644 --- a/codex-rs/protocol/src/protocol.rs +++ b/codex-rs/protocol/src/protocol.rs @@ -478,6 +478,12 @@ pub enum EventMsg { /// Ack the client's configure message. SessionConfigured(SessionConfiguredEvent), + /// Incremental MCP startup progress updates. + McpStartupUpdate(McpStartupUpdateEvent), + + /// Aggregate MCP startup completion summary. + McpStartupComplete(McpStartupCompleteEvent), + McpToolCallBegin(McpToolCallBeginEvent), McpToolCallEnd(McpToolCallEndEvent), @@ -1383,6 +1389,37 @@ pub struct McpListToolsResponseEvent { pub auth_statuses: std::collections::HashMap, } +#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)] +pub struct McpStartupUpdateEvent { + /// Server name being started. + pub server: String, + /// Current startup status. + pub status: McpStartupStatus, +} + +#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)] +#[serde(rename_all = "snake_case", tag = "state")] +#[ts(rename_all = "snake_case", tag = "state")] +pub enum McpStartupStatus { + Starting, + Ready, + Failed { error: String }, + Cancelled, +} + +#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS, Default)] +pub struct McpStartupCompleteEvent { + pub ready: Vec, + pub failed: Vec, + pub cancelled: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)] +pub struct McpStartupFailure { + pub server: String, + pub error: String, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, TS)] #[serde(rename_all = "snake_case")] #[ts(rename_all = "snake_case")] @@ -1589,4 +1626,47 @@ mod tests { assert_eq!(deserialized, event); Ok(()) } + + #[test] + fn serialize_mcp_startup_update_event() -> Result<()> { + let event = Event { + id: "init".to_string(), + msg: EventMsg::McpStartupUpdate(McpStartupUpdateEvent { + server: "srv".to_string(), + status: McpStartupStatus::Failed { + error: "boom".to_string(), + }, + }), + }; + + let value = serde_json::to_value(&event)?; + assert_eq!(value["msg"]["type"], "mcp_startup_update"); + assert_eq!(value["msg"]["server"], "srv"); + assert_eq!(value["msg"]["status"]["state"], "failed"); + assert_eq!(value["msg"]["status"]["error"], "boom"); + Ok(()) + } + + #[test] + fn serialize_mcp_startup_complete_event() -> Result<()> { + let event = Event { + id: "init".to_string(), + msg: EventMsg::McpStartupComplete(McpStartupCompleteEvent { + ready: vec!["a".to_string()], + failed: vec![McpStartupFailure { + server: "b".to_string(), + error: "bad".to_string(), + }], + cancelled: vec!["c".to_string()], + }), + }; + + let value = serde_json::to_value(&event)?; + assert_eq!(value["msg"]["type"], "mcp_startup_complete"); + assert_eq!(value["msg"]["ready"][0], "a"); + assert_eq!(value["msg"]["failed"][0]["server"], "b"); + assert_eq!(value["msg"]["failed"][0]["error"], "bad"); + assert_eq!(value["msg"]["cancelled"][0], "c"); + Ok(()) + } } diff --git a/codex-rs/tui/src/bottom_pane/mod.rs b/codex-rs/tui/src/bottom_pane/mod.rs index 685c71c87..da2efb63c 100644 --- a/codex-rs/tui/src/bottom_pane/mod.rs +++ b/codex-rs/tui/src/bottom_pane/mod.rs @@ -279,20 +279,23 @@ impl BottomPane { // esc_backtrack_hint_visible removed; hints are controlled internally. pub fn set_task_running(&mut self, running: bool) { + let was_running = self.is_task_running; self.is_task_running = running; self.composer.set_task_running(running); if running { - if self.status.is_none() { - self.status = Some(StatusIndicatorWidget::new( - self.app_event_tx.clone(), - self.frame_requester.clone(), - )); + if !was_running { + if self.status.is_none() { + self.status = Some(StatusIndicatorWidget::new( + self.app_event_tx.clone(), + self.frame_requester.clone(), + )); + } + if let Some(status) = self.status.as_mut() { + status.set_interrupt_hint_visible(true); + } + self.request_redraw(); } - if let Some(status) = self.status.as_mut() { - status.set_interrupt_hint_visible(true); - } - self.request_redraw(); } else { // Hide the status indicator when a task completes, but keep other modal views. self.hide_status_indicator(); diff --git a/codex-rs/tui/src/chatwidget.rs b/codex-rs/tui/src/chatwidget.rs index 79781f335..265b5f5fc 100644 --- a/codex-rs/tui/src/chatwidget.rs +++ b/codex-rs/tui/src/chatwidget.rs @@ -27,6 +27,9 @@ use codex_core::protocol::ExecCommandSource; use codex_core::protocol::ExitedReviewModeEvent; use codex_core::protocol::ListCustomPromptsResponseEvent; use codex_core::protocol::McpListToolsResponseEvent; +use codex_core::protocol::McpStartupCompleteEvent; +use codex_core::protocol::McpStartupStatus; +use codex_core::protocol::McpStartupUpdateEvent; use codex_core::protocol::McpToolCallBeginEvent; use codex_core::protocol::McpToolCallEndEvent; use codex_core::protocol::Op; @@ -259,6 +262,7 @@ pub(crate) struct ChatWidget { stream_controller: Option, running_commands: HashMap, task_complete_pending: bool, + mcp_startup_status: Option>, // Queue of interruptive UI events deferred during an active write cycle interrupts: InterruptManager, // Accumulates the current reasoning block text to extract a header @@ -567,8 +571,76 @@ impl ChatWidget { self.maybe_send_next_queued_input(); } - fn on_warning(&mut self, message: String) { - self.add_to_history(history_cell::new_warning_event(message)); + fn on_warning(&mut self, message: impl Into) { + self.add_to_history(history_cell::new_warning_event(message.into())); + self.request_redraw(); + } + + fn on_mcp_startup_update(&mut self, ev: McpStartupUpdateEvent) { + let mut status = self.mcp_startup_status.take().unwrap_or_default(); + if let McpStartupStatus::Failed { error } = &ev.status { + self.on_warning(error); + } + status.insert(ev.server, ev.status); + self.mcp_startup_status = Some(status); + self.bottom_pane.set_task_running(true); + if let Some(current) = &self.mcp_startup_status { + let total = current.len(); + let mut starting: Vec<_> = current + .iter() + .filter_map(|(name, state)| { + if matches!(state, McpStartupStatus::Starting) { + Some(name) + } else { + None + } + }) + .collect(); + starting.sort(); + if let Some(first) = starting.first() { + let completed = total.saturating_sub(starting.len()); + let max_to_show = 3; + let mut to_show: Vec = starting + .iter() + .take(max_to_show) + .map(ToString::to_string) + .collect(); + if starting.len() > max_to_show { + to_show.push("…".to_string()); + } + let header = if total > 1 { + format!( + "Starting MCP servers ({completed}/{total}): {}", + to_show.join(", ") + ) + } else { + format!("Booting MCP server: {first}") + }; + self.set_status_header(header); + } + } + self.request_redraw(); + } + + fn on_mcp_startup_complete(&mut self, ev: McpStartupCompleteEvent) { + let mut parts = Vec::new(); + if !ev.failed.is_empty() { + let failed_servers: Vec<_> = ev.failed.iter().map(|f| f.server.clone()).collect(); + parts.push(format!("failed: {}", failed_servers.join(", "))); + } + if !ev.cancelled.is_empty() { + self.on_warning(format!( + "MCP startup interrupted. The following servers were not initialized: {}", + ev.cancelled.join(", ") + )); + } + if !parts.is_empty() { + self.on_warning(format!("MCP startup incomplete ({})", parts.join("; "))); + } + + self.mcp_startup_status = None; + self.bottom_pane.set_task_running(false); + self.maybe_send_next_queued_input(); self.request_redraw(); } @@ -1061,6 +1133,7 @@ impl ChatWidget { stream_controller: None, running_commands: HashMap::new(), task_complete_pending: false, + mcp_startup_status: None, interrupts: InterruptManager::new(), reasoning_buffer: String::new(), full_reasoning_buffer: String::new(), @@ -1128,6 +1201,7 @@ impl ChatWidget { stream_controller: None, running_commands: HashMap::new(), task_complete_pending: false, + mcp_startup_status: None, interrupts: InterruptManager::new(), reasoning_buffer: String::new(), full_reasoning_buffer: String::new(), @@ -1540,6 +1614,8 @@ impl ChatWidget { } EventMsg::Warning(WarningEvent { message }) => self.on_warning(message), EventMsg::Error(ErrorEvent { message }) => self.on_error(message), + EventMsg::McpStartupUpdate(ev) => self.on_mcp_startup_update(ev), + EventMsg::McpStartupComplete(ev) => self.on_mcp_startup_complete(ev), EventMsg::TurnAborted(ev) => match ev.reason { TurnAbortReason::Interrupted => { self.on_interrupted_turn(ev.reason); diff --git a/codex-rs/tui/src/chatwidget/tests.rs b/codex-rs/tui/src/chatwidget/tests.rs index 6d7193bf7..554acf2f4 100644 --- a/codex-rs/tui/src/chatwidget/tests.rs +++ b/codex-rs/tui/src/chatwidget/tests.rs @@ -58,8 +58,6 @@ use tempfile::tempdir; use tokio::sync::mpsc::error::TryRecvError; use tokio::sync::mpsc::unbounded_channel; -const TEST_WARNING_MESSAGE: &str = "Heads up: Long conversations and multiple compactions can cause the model to be less accurate. Start a new conversation when possible to keep conversations small and targeted."; - fn test_config() -> Config { // Use base defaults to avoid depending on host state. Config::load_from_base_config_with_overrides( @@ -268,6 +266,7 @@ fn make_chatwidget_manual() -> ( stream_controller: None, running_commands: HashMap::new(), task_complete_pending: false, + mcp_startup_status: None, interrupts: InterruptManager::new(), reasoning_buffer: String::new(), full_reasoning_buffer: String::new(), @@ -2439,7 +2438,7 @@ fn warning_event_adds_warning_history_cell() { chat.handle_codex_event(Event { id: "sub-1".into(), msg: EventMsg::Warning(WarningEvent { - message: TEST_WARNING_MESSAGE.to_string(), + message: "test warning message".to_string(), }), }); @@ -2447,7 +2446,7 @@ fn warning_event_adds_warning_history_cell() { assert_eq!(cells.len(), 1, "expected one warning history cell"); let rendered = lines_to_single_string(&cells[0]); assert!( - rendered.contains(TEST_WARNING_MESSAGE), + rendered.contains("test warning message"), "warning cell missing content: {rendered}" ); } diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index ffefdc6e4..b026170f7 100644 --- a/codex-rs/tui/src/history_cell.rs +++ b/codex-rs/tui/src/history_cell.rs @@ -1018,10 +1018,8 @@ fn try_new_completed_mcp_tool_call_with_image_output( } #[allow(clippy::disallowed_methods)] -pub(crate) fn new_warning_event(message: String) -> PlainHistoryCell { - PlainHistoryCell { - lines: vec![vec![format!("⚠ {message}").yellow()].into()], - } +pub(crate) fn new_warning_event(message: String) -> PrefixedWrappedHistoryCell { + PrefixedWrappedHistoryCell::new(message.yellow(), "⚠ ".yellow(), " ") } #[derive(Debug)] diff --git a/codex-rs/tui/src/status_indicator_widget.rs b/codex-rs/tui/src/status_indicator_widget.rs index ea7627a4f..54979e6d6 100644 --- a/codex-rs/tui/src/status_indicator_widget.rs +++ b/codex-rs/tui/src/status_indicator_widget.rs @@ -145,7 +145,6 @@ impl Renderable for StatusIndicatorWidget { let elapsed_duration = self.elapsed_duration_at(now); let pretty_elapsed = fmt_elapsed_compact(elapsed_duration.as_secs()); - // Plain rendering: no borders or padding so the live cell is visually indistinguishable from terminal scrollback. let mut spans = Vec::with_capacity(5); spans.push(spinner(Some(self.last_resume_at))); spans.push(" ".into());