Forward session and turn headers to MCP HTTP requests (#15011)
## Summary - forward request-scoped task headers through MCP tool metadata lookups and tool calls - apply those headers to streamable HTTP initialize, tools/list, and tools/call requests - update affected rmcp/core tests for the new request_headers plumbing ## Testing - cargo test -p codex-rmcp-client - cargo test -p codex-core (fails on pre-existing unrelated error in core/src/auth_env_telemetry.rs: missing websocket_connect_timeout_ms in ModelProviderInfo initializer) - just fix -p codex-rmcp-client - just fix -p codex-core (hits the same unrelated auth_env_telemetry.rs error) - just fmt --------- Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
parent
20f2a216df
commit
b14689df3b
6 changed files with 157 additions and 12 deletions
|
|
@ -125,6 +125,8 @@ use futures::future::BoxFuture;
|
|||
use futures::future::Shared;
|
||||
use futures::prelude::*;
|
||||
use futures::stream::FuturesOrdered;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderValue;
|
||||
use rmcp::model::ListResourceTemplatesResult;
|
||||
use rmcp::model::ListResourcesResult;
|
||||
use rmcp::model::PaginatedRequestParams;
|
||||
|
|
@ -3950,6 +3952,45 @@ impl Session {
|
|||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn sync_mcp_request_headers_for_turn(&self, turn_context: &TurnContext) {
|
||||
let mut request_headers = HeaderMap::new();
|
||||
let session_id = self.conversation_id.to_string();
|
||||
if let Ok(value) = HeaderValue::from_str(&session_id) {
|
||||
request_headers.insert("session_id", value.clone());
|
||||
request_headers.insert("x-client-request-id", value);
|
||||
}
|
||||
if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value()
|
||||
&& let Ok(value) = HeaderValue::from_str(&turn_metadata)
|
||||
{
|
||||
request_headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value);
|
||||
}
|
||||
|
||||
let request_headers = if request_headers.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(request_headers)
|
||||
};
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.set_request_headers_for_server(
|
||||
crate::mcp::CODEX_APPS_MCP_SERVER_NAME,
|
||||
request_headers,
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) async fn clear_mcp_request_headers(&self) {
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.set_request_headers_for_server(
|
||||
crate::mcp::CODEX_APPS_MCP_SERVER_NAME,
|
||||
/*request_headers*/ None,
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) async fn parse_mcp_tool_name(
|
||||
&self,
|
||||
name: &str,
|
||||
|
|
|
|||
|
|
@ -423,6 +423,7 @@ impl ManagedClient {
|
|||
#[derive(Clone)]
|
||||
struct AsyncManagedClient {
|
||||
client: Shared<BoxFuture<'static, Result<ManagedClient, StartupOutcomeError>>>,
|
||||
request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
|
||||
startup_snapshot: Option<Vec<ToolInfo>>,
|
||||
startup_complete: Arc<AtomicBool>,
|
||||
tool_plugin_provenance: Arc<ToolPluginProvenance>,
|
||||
|
|
@ -448,17 +449,26 @@ impl AsyncManagedClient {
|
|||
codex_apps_tools_cache_context.as_ref(),
|
||||
)
|
||||
.map(|tools| filter_tools(tools, &tool_filter));
|
||||
let request_headers = Arc::new(StdMutex::new(None));
|
||||
let startup_tool_filter = tool_filter;
|
||||
let startup_complete = Arc::new(AtomicBool::new(false));
|
||||
let startup_complete_for_fut = Arc::clone(&startup_complete);
|
||||
let request_headers_for_client = Arc::clone(&request_headers);
|
||||
let fut = async move {
|
||||
let outcome = async {
|
||||
if let Err(error) = validate_mcp_server_name(&server_name) {
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
let client =
|
||||
Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?);
|
||||
let client = Arc::new(
|
||||
make_rmcp_client(
|
||||
&server_name,
|
||||
config.transport,
|
||||
store_mode,
|
||||
request_headers_for_client,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
match start_server_task(
|
||||
server_name,
|
||||
client,
|
||||
|
|
@ -495,6 +505,7 @@ impl AsyncManagedClient {
|
|||
|
||||
Self {
|
||||
client,
|
||||
request_headers,
|
||||
startup_snapshot,
|
||||
startup_complete,
|
||||
tool_plugin_provenance,
|
||||
|
|
@ -576,6 +587,14 @@ impl AsyncManagedClient {
|
|||
let managed = self.client().await?;
|
||||
managed.notify_sandbox_state_change(sandbox_state).await
|
||||
}
|
||||
|
||||
fn set_request_headers(&self, request_headers: Option<reqwest::header::HeaderMap>) {
|
||||
let mut guard = self
|
||||
.request_headers
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
*guard = request_headers;
|
||||
}
|
||||
}
|
||||
|
||||
pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state";
|
||||
|
|
@ -1046,6 +1065,16 @@ impl McpConnectionManager {
|
|||
})
|
||||
}
|
||||
|
||||
pub(crate) fn set_request_headers_for_server(
|
||||
&self,
|
||||
server_name: &str,
|
||||
request_headers: Option<reqwest::header::HeaderMap>,
|
||||
) {
|
||||
if let Some(client) = self.clients.get(server_name) {
|
||||
client.set_request_headers(request_headers);
|
||||
}
|
||||
}
|
||||
|
||||
/// List resources from the specified server.
|
||||
pub async fn list_resources(
|
||||
&self,
|
||||
|
|
@ -1429,6 +1458,7 @@ async fn make_rmcp_client(
|
|||
server_name: &str,
|
||||
transport: McpServerTransportConfig,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
|
||||
) -> Result<RmcpClient, StartupOutcomeError> {
|
||||
match transport {
|
||||
McpServerTransportConfig::Stdio {
|
||||
|
|
@ -1462,6 +1492,7 @@ async fn make_rmcp_client(
|
|||
http_headers,
|
||||
env_http_headers,
|
||||
store_mode,
|
||||
request_headers,
|
||||
)
|
||||
.await
|
||||
.map_err(StartupOutcomeError::from)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ use codex_protocol::protocol::McpAuthStatus;
|
|||
use rmcp::model::JsonObject;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo {
|
||||
|
|
@ -413,6 +414,7 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() {
|
|||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
request_headers: Arc::new(StdMutex::new(None)),
|
||||
startup_snapshot: Some(startup_tools),
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
|
|
@ -438,6 +440,7 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot(
|
|||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
request_headers: Arc::new(StdMutex::new(None)),
|
||||
startup_snapshot: None,
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
|
|
@ -460,6 +463,7 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty(
|
|||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: pending_client,
|
||||
request_headers: Arc::new(StdMutex::new(None)),
|
||||
startup_snapshot: Some(Vec::new()),
|
||||
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
|
|
@ -492,6 +496,7 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() {
|
|||
CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
AsyncManagedClient {
|
||||
client: failed_client,
|
||||
request_headers: Arc::new(StdMutex::new(None)),
|
||||
startup_snapshot: Some(startup_tools),
|
||||
startup_complete,
|
||||
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
|
||||
|
|
|
|||
|
|
@ -153,6 +153,8 @@ impl Session {
|
|||
) {
|
||||
self.abort_all_tasks(TurnAbortReason::Replaced).await;
|
||||
self.clear_connector_selection().await;
|
||||
self.sync_mcp_request_headers_for_turn(turn_context.as_ref())
|
||||
.await;
|
||||
|
||||
let task: Arc<dyn SessionTask> = Arc::new(task);
|
||||
let task_kind = task.kind();
|
||||
|
|
@ -233,6 +235,7 @@ impl Session {
|
|||
// in-flight approval wait can surface as a model-visible rejection before TurnAborted.
|
||||
active_turn.clear_pending().await;
|
||||
}
|
||||
self.clear_mcp_request_headers().await;
|
||||
}
|
||||
|
||||
pub async fn on_task_finished(
|
||||
|
|
@ -262,6 +265,9 @@ impl Session {
|
|||
*active = None;
|
||||
}
|
||||
drop(active);
|
||||
if should_clear_active_turn {
|
||||
self.clear_mcp_request_headers().await;
|
||||
}
|
||||
if !pending_input.is_empty() {
|
||||
for pending_input_item in pending_input {
|
||||
match inspect_pending_input(self, &turn_context, pending_input_item).await {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use std::io;
|
|||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
|
|
@ -22,6 +23,7 @@ use reqwest::header::HeaderMap;
|
|||
use reqwest::header::WWW_AUTHENTICATE;
|
||||
use rmcp::model::CallToolRequestParams;
|
||||
use rmcp::model::CallToolResult;
|
||||
use rmcp::model::ClientJsonRpcMessage;
|
||||
use rmcp::model::ClientNotification;
|
||||
use rmcp::model::ClientRequest;
|
||||
use rmcp::model::CreateElicitationRequestParams;
|
||||
|
|
@ -83,14 +85,45 @@ const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id";
|
|||
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
|
||||
const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192;
|
||||
|
||||
fn message_uses_request_scoped_headers(message: &ClientJsonRpcMessage) -> bool {
|
||||
matches!(
|
||||
message,
|
||||
ClientJsonRpcMessage::Request(request)
|
||||
if request.request.method() == "tools/call"
|
||||
)
|
||||
}
|
||||
|
||||
fn apply_request_scoped_headers(
|
||||
mut request: reqwest::RequestBuilder,
|
||||
request_headers_state: &Arc<StdMutex<Option<HeaderMap>>>,
|
||||
) -> reqwest::RequestBuilder {
|
||||
let extra_headers = request_headers_state
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.clone();
|
||||
if let Some(extra_headers) = extra_headers {
|
||||
for (name, value) in &extra_headers {
|
||||
request = request.header(name, value.clone());
|
||||
}
|
||||
}
|
||||
request
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct StreamableHttpResponseClient {
|
||||
inner: reqwest::Client,
|
||||
request_headers_state: Arc<StdMutex<Option<HeaderMap>>>,
|
||||
}
|
||||
|
||||
impl StreamableHttpResponseClient {
|
||||
fn new(inner: reqwest::Client) -> Self {
|
||||
Self { inner }
|
||||
fn new(
|
||||
inner: reqwest::Client,
|
||||
request_headers_state: Arc<StdMutex<Option<HeaderMap>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
request_headers_state,
|
||||
}
|
||||
}
|
||||
|
||||
fn reqwest_error(
|
||||
|
|
@ -133,6 +166,9 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
|
|||
if let Some(session_id_value) = session_id.as_ref() {
|
||||
request = request.header(HEADER_SESSION_ID, session_id_value.as_ref());
|
||||
}
|
||||
if message_uses_request_scoped_headers(&message) {
|
||||
request = apply_request_scoped_headers(request, &self.request_headers_state);
|
||||
}
|
||||
|
||||
let response = request
|
||||
.json(&message)
|
||||
|
|
@ -472,6 +508,7 @@ pub struct RmcpClient {
|
|||
transport_recipe: TransportRecipe,
|
||||
initialize_context: Mutex<Option<InitializeContext>>,
|
||||
session_recovery_lock: Mutex<()>,
|
||||
request_headers: Option<Arc<StdMutex<Option<HeaderMap>>>>,
|
||||
}
|
||||
|
||||
impl RmcpClient {
|
||||
|
|
@ -489,9 +526,10 @@ impl RmcpClient {
|
|||
env_vars: env_vars.to_vec(),
|
||||
cwd,
|
||||
};
|
||||
let transport = Self::create_pending_transport(&transport_recipe)
|
||||
.await
|
||||
.map_err(io::Error::other)?;
|
||||
let transport =
|
||||
Self::create_pending_transport(&transport_recipe, /*request_headers*/ None)
|
||||
.await
|
||||
.map_err(io::Error::other)?;
|
||||
|
||||
Ok(Self {
|
||||
state: Mutex::new(ClientState::Connecting {
|
||||
|
|
@ -500,6 +538,7 @@ impl RmcpClient {
|
|||
transport_recipe,
|
||||
initialize_context: Mutex::new(None),
|
||||
session_recovery_lock: Mutex::new(()),
|
||||
request_headers: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -511,6 +550,7 @@ impl RmcpClient {
|
|||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
request_headers: Arc<StdMutex<Option<HeaderMap>>>,
|
||||
) -> Result<Self> {
|
||||
let transport_recipe = TransportRecipe::StreamableHttp {
|
||||
server_name: server_name.to_string(),
|
||||
|
|
@ -520,7 +560,9 @@ impl RmcpClient {
|
|||
env_http_headers,
|
||||
store_mode,
|
||||
};
|
||||
let transport = Self::create_pending_transport(&transport_recipe).await?;
|
||||
let transport =
|
||||
Self::create_pending_transport(&transport_recipe, Some(Arc::clone(&request_headers)))
|
||||
.await?;
|
||||
Ok(Self {
|
||||
state: Mutex::new(ClientState::Connecting {
|
||||
transport: Some(transport),
|
||||
|
|
@ -528,6 +570,7 @@ impl RmcpClient {
|
|||
transport_recipe,
|
||||
initialize_context: Mutex::new(None),
|
||||
session_recovery_lock: Mutex::new(()),
|
||||
request_headers: Some(request_headers),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -830,6 +873,7 @@ impl RmcpClient {
|
|||
|
||||
async fn create_pending_transport(
|
||||
transport_recipe: &TransportRecipe,
|
||||
request_headers: Option<Arc<StdMutex<Option<HeaderMap>>>>,
|
||||
) -> Result<PendingTransport> {
|
||||
match transport_recipe {
|
||||
TransportRecipe::Stdio {
|
||||
|
|
@ -946,7 +990,12 @@ impl RmcpClient {
|
|||
.auth_header(access_token);
|
||||
let http_client = build_http_client(&default_headers)?;
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
StreamableHttpResponseClient::new(http_client),
|
||||
StreamableHttpResponseClient::new(
|
||||
http_client,
|
||||
request_headers
|
||||
.clone()
|
||||
.unwrap_or_else(|| Arc::new(StdMutex::new(None))),
|
||||
),
|
||||
http_config,
|
||||
);
|
||||
Ok(PendingTransport::StreamableHttp { transport })
|
||||
|
|
@ -963,7 +1012,12 @@ impl RmcpClient {
|
|||
let http_client = build_http_client(&default_headers)?;
|
||||
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
StreamableHttpResponseClient::new(http_client),
|
||||
StreamableHttpResponseClient::new(
|
||||
http_client,
|
||||
request_headers
|
||||
.clone()
|
||||
.unwrap_or_else(|| Arc::new(StdMutex::new(None))),
|
||||
),
|
||||
http_config,
|
||||
);
|
||||
Ok(PendingTransport::StreamableHttp { transport })
|
||||
|
|
@ -1111,7 +1165,9 @@ impl RmcpClient {
|
|||
.await
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("MCP client cannot recover before initialize succeeds"))?;
|
||||
let pending_transport = Self::create_pending_transport(&self.transport_recipe).await?;
|
||||
let pending_transport =
|
||||
Self::create_pending_transport(&self.transport_recipe, self.request_headers.clone())
|
||||
.await?;
|
||||
let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport(
|
||||
pending_transport,
|
||||
initialize_context.handler,
|
||||
|
|
@ -1166,7 +1222,10 @@ async fn create_oauth_transport_and_runtime(
|
|||
}
|
||||
};
|
||||
|
||||
let auth_client = AuthClient::new(StreamableHttpResponseClient::new(http_client), manager);
|
||||
let auth_client = AuthClient::new(
|
||||
StreamableHttpResponseClient::new(http_client, Arc::new(StdMutex::new(None))),
|
||||
manager,
|
||||
);
|
||||
let auth_manager = auth_client.auth_manager.clone();
|
||||
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
use std::net::TcpListener;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
|
|
@ -77,6 +79,7 @@ async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient> {
|
|||
None,
|
||||
None,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
Arc::new(StdMutex::new(None)),
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue