diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index a916f3311..45227d813 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -125,6 +125,8 @@ use futures::future::BoxFuture; use futures::future::Shared; use futures::prelude::*; use futures::stream::FuturesOrdered; +use reqwest::header::HeaderMap; +use reqwest::header::HeaderValue; use rmcp::model::ListResourceTemplatesResult; use rmcp::model::ListResourcesResult; use rmcp::model::PaginatedRequestParams; @@ -3950,6 +3952,45 @@ impl Session { .await } + pub(crate) async fn sync_mcp_request_headers_for_turn(&self, turn_context: &TurnContext) { + let mut request_headers = HeaderMap::new(); + let session_id = self.conversation_id.to_string(); + if let Ok(value) = HeaderValue::from_str(&session_id) { + request_headers.insert("session_id", value.clone()); + request_headers.insert("x-client-request-id", value); + } + if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value() + && let Ok(value) = HeaderValue::from_str(&turn_metadata) + { + request_headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value); + } + + let request_headers = if request_headers.is_empty() { + None + } else { + Some(request_headers) + }; + self.services + .mcp_connection_manager + .read() + .await + .set_request_headers_for_server( + crate::mcp::CODEX_APPS_MCP_SERVER_NAME, + request_headers, + ); + } + + pub(crate) async fn clear_mcp_request_headers(&self) { + self.services + .mcp_connection_manager + .read() + .await + .set_request_headers_for_server( + crate::mcp::CODEX_APPS_MCP_SERVER_NAME, + /*request_headers*/ None, + ); + } + pub(crate) async fn parse_mcp_tool_name( &self, name: &str, diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 938d6d0b2..7c8a34307 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -423,6 +423,7 @@ impl ManagedClient { #[derive(Clone)] struct AsyncManagedClient { client: Shared>>, + request_headers: Arc>>, startup_snapshot: Option>, startup_complete: Arc, tool_plugin_provenance: Arc, @@ -448,17 +449,26 @@ impl AsyncManagedClient { codex_apps_tools_cache_context.as_ref(), ) .map(|tools| filter_tools(tools, &tool_filter)); + let request_headers = Arc::new(StdMutex::new(None)); let startup_tool_filter = tool_filter; let startup_complete = Arc::new(AtomicBool::new(false)); let startup_complete_for_fut = Arc::clone(&startup_complete); + let request_headers_for_client = Arc::clone(&request_headers); let fut = async move { let outcome = async { if let Err(error) = validate_mcp_server_name(&server_name) { return Err(error.into()); } - let client = - Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?); + let client = Arc::new( + make_rmcp_client( + &server_name, + config.transport, + store_mode, + request_headers_for_client, + ) + .await?, + ); match start_server_task( server_name, client, @@ -495,6 +505,7 @@ impl AsyncManagedClient { Self { client, + request_headers, startup_snapshot, startup_complete, tool_plugin_provenance, @@ -576,6 +587,14 @@ impl AsyncManagedClient { let managed = self.client().await?; managed.notify_sandbox_state_change(sandbox_state).await } + + fn set_request_headers(&self, request_headers: Option) { + let mut guard = self + .request_headers + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *guard = request_headers; + } } pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state"; @@ -1046,6 +1065,16 @@ impl McpConnectionManager { }) } + pub(crate) fn set_request_headers_for_server( + &self, + server_name: &str, + request_headers: Option, + ) { + if let Some(client) = self.clients.get(server_name) { + client.set_request_headers(request_headers); + } + } + /// List resources from the specified server. pub async fn list_resources( &self, @@ -1429,6 +1458,7 @@ async fn make_rmcp_client( server_name: &str, transport: McpServerTransportConfig, store_mode: OAuthCredentialsStoreMode, + request_headers: Arc>>, ) -> Result { match transport { McpServerTransportConfig::Stdio { @@ -1462,6 +1492,7 @@ async fn make_rmcp_client( http_headers, env_http_headers, store_mode, + request_headers, ) .await .map_err(StartupOutcomeError::from) diff --git a/codex-rs/core/src/mcp_connection_manager_tests.rs b/codex-rs/core/src/mcp_connection_manager_tests.rs index c5f7fc4a4..9401b379b 100644 --- a/codex-rs/core/src/mcp_connection_manager_tests.rs +++ b/codex-rs/core/src/mcp_connection_manager_tests.rs @@ -4,6 +4,7 @@ use codex_protocol::protocol::McpAuthStatus; use rmcp::model::JsonObject; use std::collections::HashSet; use std::sync::Arc; +use std::sync::Mutex as StdMutex; use tempfile::tempdir; fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo { @@ -413,6 +414,7 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() { CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(startup_tools), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -438,6 +440,7 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot( CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: None, startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -460,6 +463,7 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty( CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(Vec::new()), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -492,6 +496,7 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() { CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: failed_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(startup_tools), startup_complete, tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index c52e4f917..049ed56d4 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -153,6 +153,8 @@ impl Session { ) { self.abort_all_tasks(TurnAbortReason::Replaced).await; self.clear_connector_selection().await; + self.sync_mcp_request_headers_for_turn(turn_context.as_ref()) + .await; let task: Arc = Arc::new(task); let task_kind = task.kind(); @@ -233,6 +235,7 @@ impl Session { // in-flight approval wait can surface as a model-visible rejection before TurnAborted. active_turn.clear_pending().await; } + self.clear_mcp_request_headers().await; } pub async fn on_task_finished( @@ -262,6 +265,9 @@ impl Session { *active = None; } drop(active); + if should_clear_active_turn { + self.clear_mcp_request_headers().await; + } if !pending_input.is_empty() { for pending_input_item in pending_input { match inspect_pending_input(self, &turn_context, pending_input_item).await { diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index b898403b2..cf4f90ad3 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -5,6 +5,7 @@ use std::io; use std::path::PathBuf; use std::process::Stdio; use std::sync::Arc; +use std::sync::Mutex as StdMutex; use std::time::Duration; use anyhow::Result; @@ -22,6 +23,7 @@ use reqwest::header::HeaderMap; use reqwest::header::WWW_AUTHENTICATE; use rmcp::model::CallToolRequestParams; use rmcp::model::CallToolResult; +use rmcp::model::ClientJsonRpcMessage; use rmcp::model::ClientNotification; use rmcp::model::ClientRequest; use rmcp::model::CreateElicitationRequestParams; @@ -83,14 +85,45 @@ const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192; +fn message_uses_request_scoped_headers(message: &ClientJsonRpcMessage) -> bool { + matches!( + message, + ClientJsonRpcMessage::Request(request) + if request.request.method() == "tools/call" + ) +} + +fn apply_request_scoped_headers( + mut request: reqwest::RequestBuilder, + request_headers_state: &Arc>>, +) -> reqwest::RequestBuilder { + let extra_headers = request_headers_state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone(); + if let Some(extra_headers) = extra_headers { + for (name, value) in &extra_headers { + request = request.header(name, value.clone()); + } + } + request +} + #[derive(Clone)] struct StreamableHttpResponseClient { inner: reqwest::Client, + request_headers_state: Arc>>, } impl StreamableHttpResponseClient { - fn new(inner: reqwest::Client) -> Self { - Self { inner } + fn new( + inner: reqwest::Client, + request_headers_state: Arc>>, + ) -> Self { + Self { + inner, + request_headers_state, + } } fn reqwest_error( @@ -133,6 +166,9 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(session_id_value) = session_id.as_ref() { request = request.header(HEADER_SESSION_ID, session_id_value.as_ref()); } + if message_uses_request_scoped_headers(&message) { + request = apply_request_scoped_headers(request, &self.request_headers_state); + } let response = request .json(&message) @@ -472,6 +508,7 @@ pub struct RmcpClient { transport_recipe: TransportRecipe, initialize_context: Mutex>, session_recovery_lock: Mutex<()>, + request_headers: Option>>>, } impl RmcpClient { @@ -489,9 +526,10 @@ impl RmcpClient { env_vars: env_vars.to_vec(), cwd, }; - let transport = Self::create_pending_transport(&transport_recipe) - .await - .map_err(io::Error::other)?; + let transport = + Self::create_pending_transport(&transport_recipe, /*request_headers*/ None) + .await + .map_err(io::Error::other)?; Ok(Self { state: Mutex::new(ClientState::Connecting { @@ -500,6 +538,7 @@ impl RmcpClient { transport_recipe, initialize_context: Mutex::new(None), session_recovery_lock: Mutex::new(()), + request_headers: None, }) } @@ -511,6 +550,7 @@ impl RmcpClient { http_headers: Option>, env_http_headers: Option>, store_mode: OAuthCredentialsStoreMode, + request_headers: Arc>>, ) -> Result { let transport_recipe = TransportRecipe::StreamableHttp { server_name: server_name.to_string(), @@ -520,7 +560,9 @@ impl RmcpClient { env_http_headers, store_mode, }; - let transport = Self::create_pending_transport(&transport_recipe).await?; + let transport = + Self::create_pending_transport(&transport_recipe, Some(Arc::clone(&request_headers))) + .await?; Ok(Self { state: Mutex::new(ClientState::Connecting { transport: Some(transport), @@ -528,6 +570,7 @@ impl RmcpClient { transport_recipe, initialize_context: Mutex::new(None), session_recovery_lock: Mutex::new(()), + request_headers: Some(request_headers), }) } @@ -830,6 +873,7 @@ impl RmcpClient { async fn create_pending_transport( transport_recipe: &TransportRecipe, + request_headers: Option>>>, ) -> Result { match transport_recipe { TransportRecipe::Stdio { @@ -946,7 +990,12 @@ impl RmcpClient { .auth_header(access_token); let http_client = build_http_client(&default_headers)?; let transport = StreamableHttpClientTransport::with_client( - StreamableHttpResponseClient::new(http_client), + StreamableHttpResponseClient::new( + http_client, + request_headers + .clone() + .unwrap_or_else(|| Arc::new(StdMutex::new(None))), + ), http_config, ); Ok(PendingTransport::StreamableHttp { transport }) @@ -963,7 +1012,12 @@ impl RmcpClient { let http_client = build_http_client(&default_headers)?; let transport = StreamableHttpClientTransport::with_client( - StreamableHttpResponseClient::new(http_client), + StreamableHttpResponseClient::new( + http_client, + request_headers + .clone() + .unwrap_or_else(|| Arc::new(StdMutex::new(None))), + ), http_config, ); Ok(PendingTransport::StreamableHttp { transport }) @@ -1111,7 +1165,9 @@ impl RmcpClient { .await .clone() .ok_or_else(|| anyhow!("MCP client cannot recover before initialize succeeds"))?; - let pending_transport = Self::create_pending_transport(&self.transport_recipe).await?; + let pending_transport = + Self::create_pending_transport(&self.transport_recipe, self.request_headers.clone()) + .await?; let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport( pending_transport, initialize_context.handler, @@ -1166,7 +1222,10 @@ async fn create_oauth_transport_and_runtime( } }; - let auth_client = AuthClient::new(StreamableHttpResponseClient::new(http_client), manager); + let auth_client = AuthClient::new( + StreamableHttpResponseClient::new(http_client, Arc::new(StdMutex::new(None))), + manager, + ); let auth_manager = auth_client.auth_manager.clone(); let transport = StreamableHttpClientTransport::with_client( diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index fb2fc96d2..8b03da8f1 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -1,5 +1,7 @@ use std::net::TcpListener; use std::path::PathBuf; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; use std::time::Duration; use std::time::Instant; @@ -77,6 +79,7 @@ async fn create_client(base_url: &str) -> anyhow::Result { None, None, OAuthCredentialsStoreMode::File, + Arc::new(StdMutex::new(None)), ) .await?;