diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index b90d702b2..8197c8cb6 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -125,8 +125,6 @@ use futures::future::BoxFuture; use futures::future::Shared; use futures::prelude::*; use futures::stream::FuturesOrdered; -use reqwest::header::HeaderMap; -use reqwest::header::HeaderValue; use rmcp::model::ListResourceTemplatesResult; use rmcp::model::ListResourcesResult; use rmcp::model::PaginatedRequestParams; @@ -3952,45 +3950,6 @@ impl Session { .await } - pub(crate) async fn sync_mcp_request_headers_for_turn(&self, turn_context: &TurnContext) { - let mut request_headers = HeaderMap::new(); - let session_id = self.conversation_id.to_string(); - if let Ok(value) = HeaderValue::from_str(&session_id) { - request_headers.insert("session_id", value.clone()); - request_headers.insert("x-client-request-id", value); - } - if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value() - && let Ok(value) = HeaderValue::from_str(&turn_metadata) - { - request_headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value); - } - - let request_headers = if request_headers.is_empty() { - None - } else { - Some(request_headers) - }; - self.services - .mcp_connection_manager - .read() - .await - .set_request_headers_for_server( - crate::mcp::CODEX_APPS_MCP_SERVER_NAME, - request_headers, - ); - } - - pub(crate) async fn clear_mcp_request_headers(&self) { - self.services - .mcp_connection_manager - .read() - .await - .set_request_headers_for_server( - crate::mcp::CODEX_APPS_MCP_SERVER_NAME, - /*request_headers*/ None, - ); - } - pub(crate) async fn parse_mcp_tool_name( &self, name: &str, diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 7c8a34307..938d6d0b2 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -423,7 +423,6 @@ impl ManagedClient { #[derive(Clone)] struct AsyncManagedClient { client: Shared>>, - request_headers: Arc>>, startup_snapshot: Option>, startup_complete: Arc, tool_plugin_provenance: Arc, @@ -449,26 +448,17 @@ impl AsyncManagedClient { codex_apps_tools_cache_context.as_ref(), ) .map(|tools| filter_tools(tools, &tool_filter)); - let request_headers = Arc::new(StdMutex::new(None)); let startup_tool_filter = tool_filter; let startup_complete = Arc::new(AtomicBool::new(false)); let startup_complete_for_fut = Arc::clone(&startup_complete); - let request_headers_for_client = Arc::clone(&request_headers); let fut = async move { let outcome = async { if let Err(error) = validate_mcp_server_name(&server_name) { return Err(error.into()); } - let client = Arc::new( - make_rmcp_client( - &server_name, - config.transport, - store_mode, - request_headers_for_client, - ) - .await?, - ); + let client = + Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?); match start_server_task( server_name, client, @@ -505,7 +495,6 @@ impl AsyncManagedClient { Self { client, - request_headers, startup_snapshot, startup_complete, tool_plugin_provenance, @@ -587,14 +576,6 @@ impl AsyncManagedClient { let managed = self.client().await?; managed.notify_sandbox_state_change(sandbox_state).await } - - fn set_request_headers(&self, request_headers: Option) { - let mut guard = self - .request_headers - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - *guard = request_headers; - } } pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state"; @@ -1065,16 +1046,6 @@ impl McpConnectionManager { }) } - pub(crate) fn set_request_headers_for_server( - &self, - server_name: &str, - request_headers: Option, - ) { - if let Some(client) = self.clients.get(server_name) { - client.set_request_headers(request_headers); - } - } - /// List resources from the specified server. pub async fn list_resources( &self, @@ -1458,7 +1429,6 @@ async fn make_rmcp_client( server_name: &str, transport: McpServerTransportConfig, store_mode: OAuthCredentialsStoreMode, - request_headers: Arc>>, ) -> Result { match transport { McpServerTransportConfig::Stdio { @@ -1492,7 +1462,6 @@ async fn make_rmcp_client( http_headers, env_http_headers, store_mode, - request_headers, ) .await .map_err(StartupOutcomeError::from) diff --git a/codex-rs/core/src/mcp_connection_manager_tests.rs b/codex-rs/core/src/mcp_connection_manager_tests.rs index 9401b379b..c5f7fc4a4 100644 --- a/codex-rs/core/src/mcp_connection_manager_tests.rs +++ b/codex-rs/core/src/mcp_connection_manager_tests.rs @@ -4,7 +4,6 @@ use codex_protocol::protocol::McpAuthStatus; use rmcp::model::JsonObject; use std::collections::HashSet; use std::sync::Arc; -use std::sync::Mutex as StdMutex; use tempfile::tempdir; fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo { @@ -414,7 +413,6 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() { CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, - request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(startup_tools), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -440,7 +438,6 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot( CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, - request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: None, startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -463,7 +460,6 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty( CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, - request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(Vec::new()), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -496,7 +492,6 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() { CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: failed_client, - request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(startup_tools), startup_complete, tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index 049ed56d4..c52e4f917 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -153,8 +153,6 @@ impl Session { ) { self.abort_all_tasks(TurnAbortReason::Replaced).await; self.clear_connector_selection().await; - self.sync_mcp_request_headers_for_turn(turn_context.as_ref()) - .await; let task: Arc = Arc::new(task); let task_kind = task.kind(); @@ -235,7 +233,6 @@ impl Session { // in-flight approval wait can surface as a model-visible rejection before TurnAborted. active_turn.clear_pending().await; } - self.clear_mcp_request_headers().await; } pub async fn on_task_finished( @@ -265,9 +262,6 @@ impl Session { *active = None; } drop(active); - if should_clear_active_turn { - self.clear_mcp_request_headers().await; - } if !pending_input.is_empty() { for pending_input_item in pending_input { match inspect_pending_input(self, &turn_context, pending_input_item).await { diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index cf4f90ad3..b898403b2 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -5,7 +5,6 @@ use std::io; use std::path::PathBuf; use std::process::Stdio; use std::sync::Arc; -use std::sync::Mutex as StdMutex; use std::time::Duration; use anyhow::Result; @@ -23,7 +22,6 @@ use reqwest::header::HeaderMap; use reqwest::header::WWW_AUTHENTICATE; use rmcp::model::CallToolRequestParams; use rmcp::model::CallToolResult; -use rmcp::model::ClientJsonRpcMessage; use rmcp::model::ClientNotification; use rmcp::model::ClientRequest; use rmcp::model::CreateElicitationRequestParams; @@ -85,45 +83,14 @@ const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192; -fn message_uses_request_scoped_headers(message: &ClientJsonRpcMessage) -> bool { - matches!( - message, - ClientJsonRpcMessage::Request(request) - if request.request.method() == "tools/call" - ) -} - -fn apply_request_scoped_headers( - mut request: reqwest::RequestBuilder, - request_headers_state: &Arc>>, -) -> reqwest::RequestBuilder { - let extra_headers = request_headers_state - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .clone(); - if let Some(extra_headers) = extra_headers { - for (name, value) in &extra_headers { - request = request.header(name, value.clone()); - } - } - request -} - #[derive(Clone)] struct StreamableHttpResponseClient { inner: reqwest::Client, - request_headers_state: Arc>>, } impl StreamableHttpResponseClient { - fn new( - inner: reqwest::Client, - request_headers_state: Arc>>, - ) -> Self { - Self { - inner, - request_headers_state, - } + fn new(inner: reqwest::Client) -> Self { + Self { inner } } fn reqwest_error( @@ -166,9 +133,6 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(session_id_value) = session_id.as_ref() { request = request.header(HEADER_SESSION_ID, session_id_value.as_ref()); } - if message_uses_request_scoped_headers(&message) { - request = apply_request_scoped_headers(request, &self.request_headers_state); - } let response = request .json(&message) @@ -508,7 +472,6 @@ pub struct RmcpClient { transport_recipe: TransportRecipe, initialize_context: Mutex>, session_recovery_lock: Mutex<()>, - request_headers: Option>>>, } impl RmcpClient { @@ -526,10 +489,9 @@ impl RmcpClient { env_vars: env_vars.to_vec(), cwd, }; - let transport = - Self::create_pending_transport(&transport_recipe, /*request_headers*/ None) - .await - .map_err(io::Error::other)?; + let transport = Self::create_pending_transport(&transport_recipe) + .await + .map_err(io::Error::other)?; Ok(Self { state: Mutex::new(ClientState::Connecting { @@ -538,7 +500,6 @@ impl RmcpClient { transport_recipe, initialize_context: Mutex::new(None), session_recovery_lock: Mutex::new(()), - request_headers: None, }) } @@ -550,7 +511,6 @@ impl RmcpClient { http_headers: Option>, env_http_headers: Option>, store_mode: OAuthCredentialsStoreMode, - request_headers: Arc>>, ) -> Result { let transport_recipe = TransportRecipe::StreamableHttp { server_name: server_name.to_string(), @@ -560,9 +520,7 @@ impl RmcpClient { env_http_headers, store_mode, }; - let transport = - Self::create_pending_transport(&transport_recipe, Some(Arc::clone(&request_headers))) - .await?; + let transport = Self::create_pending_transport(&transport_recipe).await?; Ok(Self { state: Mutex::new(ClientState::Connecting { transport: Some(transport), @@ -570,7 +528,6 @@ impl RmcpClient { transport_recipe, initialize_context: Mutex::new(None), session_recovery_lock: Mutex::new(()), - request_headers: Some(request_headers), }) } @@ -873,7 +830,6 @@ impl RmcpClient { async fn create_pending_transport( transport_recipe: &TransportRecipe, - request_headers: Option>>>, ) -> Result { match transport_recipe { TransportRecipe::Stdio { @@ -990,12 +946,7 @@ impl RmcpClient { .auth_header(access_token); let http_client = build_http_client(&default_headers)?; let transport = StreamableHttpClientTransport::with_client( - StreamableHttpResponseClient::new( - http_client, - request_headers - .clone() - .unwrap_or_else(|| Arc::new(StdMutex::new(None))), - ), + StreamableHttpResponseClient::new(http_client), http_config, ); Ok(PendingTransport::StreamableHttp { transport }) @@ -1012,12 +963,7 @@ impl RmcpClient { let http_client = build_http_client(&default_headers)?; let transport = StreamableHttpClientTransport::with_client( - StreamableHttpResponseClient::new( - http_client, - request_headers - .clone() - .unwrap_or_else(|| Arc::new(StdMutex::new(None))), - ), + StreamableHttpResponseClient::new(http_client), http_config, ); Ok(PendingTransport::StreamableHttp { transport }) @@ -1165,9 +1111,7 @@ impl RmcpClient { .await .clone() .ok_or_else(|| anyhow!("MCP client cannot recover before initialize succeeds"))?; - let pending_transport = - Self::create_pending_transport(&self.transport_recipe, self.request_headers.clone()) - .await?; + let pending_transport = Self::create_pending_transport(&self.transport_recipe).await?; let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport( pending_transport, initialize_context.handler, @@ -1222,10 +1166,7 @@ async fn create_oauth_transport_and_runtime( } }; - let auth_client = AuthClient::new( - StreamableHttpResponseClient::new(http_client, Arc::new(StdMutex::new(None))), - manager, - ); + let auth_client = AuthClient::new(StreamableHttpResponseClient::new(http_client), manager); let auth_manager = auth_client.auth_manager.clone(); let transport = StreamableHttpClientTransport::with_client( diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index 8b03da8f1..fb2fc96d2 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -1,7 +1,5 @@ use std::net::TcpListener; use std::path::PathBuf; -use std::sync::Arc; -use std::sync::Mutex as StdMutex; use std::time::Duration; use std::time::Instant; @@ -79,7 +77,6 @@ async fn create_client(base_url: &str) -> anyhow::Result { None, None, OAuthCredentialsStoreMode::File, - Arc::new(StdMutex::new(None)), ) .await?;