From b3765a07e8b57fdd21919ae5dde781d70230e7ca Mon Sep 17 00:00:00 2001 From: Casey Chow Date: Fri, 6 Mar 2026 10:02:42 -0500 Subject: [PATCH] [rmcp-client] Recover from streamable HTTP 404 sessions (#13514) ## Summary - add one-time session recovery in `RmcpClient` for streamable HTTP MCP `404` session expiry - rebuild the transport and retry the failed operation once after reinitializing the client state - extend the test server and integration coverage for `404`, `401`, single-retry, and non-session failure scenarios ## Testing - just fmt - cargo test -p codex-rmcp-client (the post-rebase run lost its final summary in the terminal; the suite had passed earlier before the rebase) - just fix -p codex-rmcp-client --- codex-rs/Cargo.lock | 2 + codex-rs/rmcp-client/Cargo.toml | 2 + .../src/bin/test_streamable_http_server.rs | 84 +- codex-rs/rmcp-client/src/rmcp_client.rs | 875 ++++++++++++++---- codex-rs/rmcp-client/src/utils.rs | 28 +- .../tests/streamable_http_recovery.rs | 268 ++++++ 6 files changed, 1046 insertions(+), 213 deletions(-) create mode 100644 codex-rs/rmcp-client/tests/streamable_http_recovery.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index ac94100af..d8d000533 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2302,7 +2302,9 @@ dependencies = [ "serde_json", "serial_test", "sha2", + "sse-stream", "tempfile", + "thiserror 2.0.18", "tiny_http", "tokio", "tracing", diff --git a/codex-rs/rmcp-client/Cargo.toml b/codex-rs/rmcp-client/Cargo.toml index 951bb0dd1..7393368b3 100644 --- a/codex-rs/rmcp-client/Cargo.toml +++ b/codex-rs/rmcp-client/Cargo.toml @@ -40,6 +40,8 @@ schemars = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } sha2 = { workspace = true } +sse-stream = "0.2.1" +thiserror = { workspace = true } tiny_http = { workspace = true } tokio = { workspace = true, features = [ "io-util", diff --git a/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs b/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs index 821850d2a..05ba6089d 100644 --- a/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs +++ b/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs @@ -6,7 +6,9 @@ use std::sync::Arc; use axum::Router; use axum::body::Body; +use axum::extract::Json; use axum::extract::State; +use axum::http::Method; use axum::http::Request; use axum::http::StatusCode; use axum::http::header::AUTHORIZATION; @@ -15,6 +17,7 @@ use axum::middleware; use axum::middleware::Next; use axum::response::Response; use axum::routing::get; +use axum::routing::post; use rmcp::ErrorData as McpError; use rmcp::handler::server::ServerHandler; use rmcp::model::CallToolRequestParams; @@ -39,6 +42,7 @@ use rmcp::transport::StreamableHttpService; use rmcp::transport::streamable_http_server::session::local::LocalSessionManager; use serde::Deserialize; use serde_json::json; +use tokio::sync::Mutex; use tokio::task; #[derive(Clone)] @@ -50,6 +54,8 @@ struct TestToolServer { const MEMO_URI: &str = "memo://codex/example-note"; const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server."; +const MCP_SESSION_ID_HEADER: &str = "mcp-session-id"; +const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure"; impl TestToolServer { fn new() -> Self { @@ -116,6 +122,23 @@ impl TestToolServer { } } +#[derive(Clone, Default)] +struct SessionFailureState { + armed_failure: Arc>>, +} + +#[derive(Clone, Debug)] +struct ArmedFailure { + status: StatusCode, + remaining: usize, +} + +#[derive(Debug, Deserialize)] +struct ArmSessionPostFailureRequest { + status: u16, + remaining: usize, +} + #[derive(Deserialize)] struct EchoArgs { message: String, @@ -251,6 +274,7 @@ fn parse_bind_addr() -> Result> { #[tokio::main] async fn main() -> Result<(), Box> { let bind_addr = parse_bind_addr()?; + let session_failure_state = SessionFailureState::default(); let listener = match tokio::net::TcpListener::bind(&bind_addr).await { Ok(listener) => listener, Err(err) if err.kind() == ErrorKind::PermissionDenied => { @@ -264,6 +288,10 @@ async fn main() -> Result<(), Box> { eprintln!("starting rmcp streamable http test server on http://{bind_addr}/mcp"); let router = Router::new() + .route( + SESSION_POST_FAILURE_CONTROL_PATH, + post(arm_session_post_failure), + ) .route( "/.well-known/oauth-authorization-server/mcp", get({ @@ -291,7 +319,12 @@ async fn main() -> Result<(), Box> { Arc::new(LocalSessionManager::default()), StreamableHttpServerConfig::default(), ), - ); + ) + .layer(middleware::from_fn_with_state( + session_failure_state.clone(), + fail_session_post_when_armed, + )) + .with_state(session_failure_state); let router = if let Ok(token) = std::env::var("MCP_EXPECT_BEARER") { let expected = Arc::new(format!("Bearer {token}")); @@ -323,3 +356,52 @@ async fn require_bearer( Err(StatusCode::UNAUTHORIZED) } } + +async fn arm_session_post_failure( + State(state): State, + Json(request): Json, +) -> Result { + let status = StatusCode::from_u16(request.status).map_err(|_| StatusCode::BAD_REQUEST)?; + let armed_failure = if request.remaining == 0 { + None + } else { + Some(ArmedFailure { + status, + remaining: request.remaining, + }) + }; + *state.armed_failure.lock().await = armed_failure; + Ok(StatusCode::NO_CONTENT) +} + +async fn fail_session_post_when_armed( + State(state): State, + request: Request, + next: Next, +) -> Response { + if request.uri().path() != "/mcp" + || request.method() != Method::POST + || !request.headers().contains_key(MCP_SESSION_ID_HEADER) + { + return next.run(request).await; + } + + let mut armed_failure = state.armed_failure.lock().await; + if let Some(failure) = armed_failure.as_mut() + && failure.remaining > 0 + { + failure.remaining -= 1; + let status = failure.status; + if failure.remaining == 0 { + *armed_failure = None; + } + let mut response = Response::new(Body::from(format!( + "forced session failure with status {status}" + ))); + *response.status_mut() = status; + return response; + } + + drop(armed_failure); + next.run(request).await +} diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 6dbe46dfd..c3bb6be0e 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::collections::HashMap; use std::ffi::OsString; use std::io; @@ -9,10 +10,15 @@ use std::time::Duration; use anyhow::Result; use anyhow::anyhow; use futures::FutureExt; +use futures::StreamExt; use futures::future::BoxFuture; +use futures::stream::BoxStream; use oauth2::TokenResponse; +use reqwest::header::ACCEPT; use reqwest::header::AUTHORIZATION; +use reqwest::header::CONTENT_TYPE; use reqwest::header::HeaderMap; +use reqwest::header::WWW_AUTHENTICATE; use rmcp::model::CallToolRequestParams; use rmcp::model::CallToolResult; use rmcp::model::ClientNotification; @@ -42,10 +48,16 @@ use rmcp::transport::auth::AuthClient; use rmcp::transport::auth::AuthError; use rmcp::transport::auth::OAuthState; use rmcp::transport::child_process::TokioChildProcess; +use rmcp::transport::streamable_http_client::AuthRequiredError; +use rmcp::transport::streamable_http_client::StreamableHttpClient; use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; +use rmcp::transport::streamable_http_client::StreamableHttpError; +use rmcp::transport::streamable_http_client::StreamableHttpPostResponse; use serde::Deserialize; use serde::Serialize; use serde_json::Value; +use sse_stream::Sse; +use sse_stream::SseStream; use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; use tokio::process::Command; @@ -63,7 +75,225 @@ use crate::program_resolver; use crate::utils::apply_default_headers; use crate::utils::build_default_headers; use crate::utils::create_env_for_mcp_server; -use crate::utils::run_with_timeout; + +const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream"; +const JSON_MIME_TYPE: &str = "application/json"; +const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; +const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; +const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192; + +#[derive(Clone)] +struct StreamableHttpResponseClient { + inner: reqwest::Client, +} + +impl StreamableHttpResponseClient { + fn new(inner: reqwest::Client) -> Self { + Self { inner } + } + + fn reqwest_error( + error: reqwest::Error, + ) -> StreamableHttpError { + StreamableHttpError::Client(StreamableHttpResponseClientError::from(error)) + } +} + +#[derive(Debug, thiserror::Error)] +enum StreamableHttpResponseClientError { + #[error("streamable HTTP session expired with 404 Not Found")] + SessionExpired404, + #[error(transparent)] + Reqwest(#[from] reqwest::Error), +} + +impl StreamableHttpClient for StreamableHttpResponseClient { + type Error = StreamableHttpResponseClientError; + + async fn post_message( + &self, + uri: Arc, + message: rmcp::model::ClientJsonRpcMessage, + session_id: Option>, + auth_token: Option, + ) -> std::result::Result> { + let mut request = self + .inner + .post(uri.as_ref()) + .header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", ")); + if let Some(auth_header) = auth_token { + request = request.bearer_auth(auth_header); + } + if let Some(session_id_value) = session_id.as_ref() { + request = request.header(HEADER_SESSION_ID, session_id_value.as_ref()); + } + + let response = request + .json(&message) + .send() + .await + .map_err(StreamableHttpResponseClient::reqwest_error)?; + if response.status() == reqwest::StatusCode::NOT_FOUND && session_id.is_some() { + return Err(StreamableHttpError::Client( + StreamableHttpResponseClientError::SessionExpired404, + )); + } + if response.status() == reqwest::StatusCode::UNAUTHORIZED + && let Some(header) = response.headers().get(WWW_AUTHENTICATE) + { + let header = header + .to_str() + .map_err(|_| { + StreamableHttpError::UnexpectedServerResponse(Cow::Borrowed( + "invalid www-authenticate header value", + )) + })? + .to_string(); + return Err(StreamableHttpError::AuthRequired(AuthRequiredError { + www_authenticate_header: header, + })); + } + + let status = response.status(); + if matches!( + status, + reqwest::StatusCode::ACCEPTED | reqwest::StatusCode::NO_CONTENT + ) { + return Ok(StreamableHttpPostResponse::Accepted); + } + + let content_type = response + .headers() + .get(CONTENT_TYPE) + .and_then(|value| value.to_str().ok()) + .map(str::to_string); + let session_id = response + .headers() + .get(HEADER_SESSION_ID) + .and_then(|value| value.to_str().ok()) + .map(str::to_string); + + match content_type.as_deref() { + Some(ct) if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) => { + let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); + Ok(StreamableHttpPostResponse::Sse(event_stream, session_id)) + } + Some(ct) if ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) => { + let message = response + .json() + .await + .map_err(StreamableHttpResponseClient::reqwest_error)?; + Ok(StreamableHttpPostResponse::Json(message, session_id)) + } + _ => { + let body = response + .text() + .await + .map_err(StreamableHttpResponseClient::reqwest_error)?; + let mut body_preview = body; + let body_len = body_preview.len(); + if body_len > NON_JSON_RESPONSE_BODY_PREVIEW_BYTES { + let mut boundary = NON_JSON_RESPONSE_BODY_PREVIEW_BYTES; + while !body_preview.is_char_boundary(boundary) { + boundary = boundary.saturating_sub(1); + } + body_preview.truncate(boundary); + body_preview.push_str(&format!( + "... (truncated {} bytes)", + body_len.saturating_sub(boundary) + )); + } + + let content_type = content_type.unwrap_or_else(|| "missing-content-type".into()); + Err(StreamableHttpError::UnexpectedContentType(Some(format!( + "{content_type}; body: {body_preview}" + )))) + } + } + } + + async fn delete_session( + &self, + uri: Arc, + session: Arc, + auth_token: Option, + ) -> std::result::Result<(), StreamableHttpError> { + let mut request_builder = self.inner.delete(uri.as_ref()); + if let Some(auth_header) = auth_token { + request_builder = request_builder.bearer_auth(auth_header); + } + let response = request_builder + .header(HEADER_SESSION_ID, session.as_ref()) + .send() + .await + .map_err(StreamableHttpResponseClient::reqwest_error)?; + + if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED { + return Ok(()); + } + + response + .error_for_status() + .map_err(StreamableHttpResponseClient::reqwest_error)?; + Ok(()) + } + + async fn get_stream( + &self, + uri: Arc, + session_id: Arc, + last_event_id: Option, + auth_token: Option, + ) -> std::result::Result< + BoxStream<'static, std::result::Result>, + StreamableHttpError, + > { + let mut request_builder = self + .inner + .get(uri.as_ref()) + .header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", ")) + .header(HEADER_SESSION_ID, session_id.as_ref()); + if let Some(last_event_id) = last_event_id { + request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id); + } + if let Some(auth_header) = auth_token { + request_builder = request_builder.bearer_auth(auth_header); + } + + let response = request_builder + .send() + .await + .map_err(StreamableHttpResponseClient::reqwest_error)?; + if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED { + return Err(StreamableHttpError::ServerDoesNotSupportSse); + } + if response.status() == reqwest::StatusCode::NOT_FOUND { + return Err(StreamableHttpError::Client( + StreamableHttpResponseClientError::SessionExpired404, + )); + } + + let response = response + .error_for_status() + .map_err(StreamableHttpResponseClient::reqwest_error)?; + match response.headers().get(CONTENT_TYPE) { + Some(ct) + if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) + || ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) => {} + Some(ct) => { + return Err(StreamableHttpError::UnexpectedContentType(Some( + String::from_utf8_lossy(ct.as_bytes()).to_string(), + ))); + } + None => { + return Err(StreamableHttpError::UnexpectedContentType(None)); + } + } + + let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); + Ok(event_stream) + } +} enum PendingTransport { ChildProcess { @@ -71,10 +301,10 @@ enum PendingTransport { process_group_guard: Option, }, StreamableHttp { - transport: StreamableHttpClientTransport, + transport: StreamableHttpClientTransport, }, StreamableHttpWithOAuth { - transport: StreamableHttpClientTransport>, + transport: StreamableHttpClientTransport>, oauth_persistor: OAuthPersistor, }, } @@ -149,6 +379,39 @@ impl Drop for ProcessGroupGuard { } } +#[derive(Clone)] +enum TransportRecipe { + Stdio { + program: OsString, + args: Vec, + env: Option>, + env_vars: Vec, + cwd: Option, + }, + StreamableHttp { + server_name: String, + url: String, + bearer_token: Option, + http_headers: Option>, + env_http_headers: Option>, + store_mode: OAuthCredentialsStoreMode, + }, +} + +#[derive(Clone)] +struct InitializeContext { + timeout: Option, + handler: LoggingClientHandler, +} + +#[derive(Debug, thiserror::Error)] +enum ClientOperationError { + #[error(transparent)] + Service(#[from] rmcp::service::ServiceError), + #[error("timed out awaiting {label} after {duration:?}")] + Timeout { label: String, duration: Duration }, +} + pub type Elicitation = CreateElicitationRequestParams; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -199,6 +462,9 @@ pub struct ListToolsWithConnectorIdResult { /// https://github.com/modelcontextprotocol/rust-sdk pub struct RmcpClient { state: Mutex, + transport_recipe: TransportRecipe, + initialize_context: Mutex>, + session_recovery_lock: Mutex<()>, } impl RmcpClient { @@ -209,58 +475,24 @@ impl RmcpClient { env_vars: &[String], cwd: Option, ) -> io::Result { - let program_name = program.to_string_lossy().into_owned(); - - // Build environment for program resolution and subprocess - let envs = create_env_for_mcp_server(env, env_vars); - - // Resolve program to executable path (platform-specific) - let resolved_program = program_resolver::resolve(program, &envs)?; - - let mut command = Command::new(resolved_program); - command - .kill_on_drop(true) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .env_clear() - .envs(envs) - .args(&args); - #[cfg(unix)] - command.process_group(0); - if let Some(cwd) = cwd { - command.current_dir(cwd); - } - - let (transport, stderr) = TokioChildProcess::builder(command) - .stderr(Stdio::piped()) - .spawn()?; - let process_group_guard = transport.id().map(ProcessGroupGuard::new); - - if let Some(stderr) = stderr { - tokio::spawn(async move { - let mut reader = BufReader::new(stderr).lines(); - loop { - match reader.next_line().await { - Ok(Some(line)) => { - info!("MCP server stderr ({program_name}): {line}"); - } - Ok(None) => break, - Err(error) => { - warn!("Failed to read MCP server stderr ({program_name}): {error}"); - break; - } - } - } - }); - } + let transport_recipe = TransportRecipe::Stdio { + program, + args, + env, + env_vars: env_vars.to_vec(), + cwd, + }; + let transport = Self::create_pending_transport(&transport_recipe) + .await + .map_err(io::Error::other)?; Ok(Self { state: Mutex::new(ClientState::Connecting { - transport: Some(PendingTransport::ChildProcess { - transport, - process_group_guard, - }), + transport: Some(transport), }), + transport_recipe, + initialize_context: Mutex::new(None), + session_recovery_lock: Mutex::new(()), }) } @@ -273,77 +505,22 @@ impl RmcpClient { env_http_headers: Option>, store_mode: OAuthCredentialsStoreMode, ) -> Result { - let default_headers = build_default_headers(http_headers, env_http_headers)?; - - let initial_oauth_tokens = - if bearer_token.is_none() && !default_headers.contains_key(AUTHORIZATION) { - match load_oauth_tokens(server_name, url, store_mode) { - Ok(tokens) => tokens, - Err(err) => { - warn!("failed to read tokens for server `{server_name}`: {err}"); - None - } - } - } else { - None - }; - - let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() { - match create_oauth_transport_and_runtime( - server_name, - url, - initial_tokens.clone(), - store_mode, - default_headers.clone(), - ) - .await - { - Ok((transport, oauth_persistor)) => PendingTransport::StreamableHttpWithOAuth { - transport, - oauth_persistor, - }, - Err(err) - if err.downcast_ref::().is_some_and(|auth_err| { - matches!(auth_err, AuthError::NoAuthorizationSupport) - }) => - { - let access_token = initial_tokens - .token_response - .0 - .access_token() - .secret() - .to_string(); - warn!( - "OAuth metadata discovery is unavailable for MCP server `{server_name}`; falling back to stored bearer token authentication" - ); - let http_config = - StreamableHttpClientTransportConfig::with_uri(url.to_string()) - .auth_header(access_token); - let http_client = - apply_default_headers(reqwest::Client::builder(), &default_headers) - .build()?; - let transport = - StreamableHttpClientTransport::with_client(http_client, http_config); - PendingTransport::StreamableHttp { transport } - } - Err(err) => return Err(err), - } - } else { - let mut http_config = StreamableHttpClientTransportConfig::with_uri(url.to_string()); - if let Some(bearer_token) = bearer_token.clone() { - http_config = http_config.auth_header(bearer_token); - } - - let http_client = - apply_default_headers(reqwest::Client::builder(), &default_headers).build()?; - - let transport = StreamableHttpClientTransport::with_client(http_client, http_config); - PendingTransport::StreamableHttp { transport } + let transport_recipe = TransportRecipe::StreamableHttp { + server_name: server_name.to_string(), + url: url.to_string(), + bearer_token, + http_headers, + env_http_headers, + store_mode, }; + let transport = Self::create_pending_transport(&transport_recipe).await?; Ok(Self { state: Mutex::new(ClientState::Connecting { transport: Some(transport), }), + transport_recipe, + initialize_context: Mutex::new(None), + session_recovery_lock: Mutex::new(()), }) } @@ -356,47 +533,20 @@ impl RmcpClient { send_elicitation: SendElicitation, ) -> Result { let client_handler = LoggingClientHandler::new(params.clone(), send_elicitation); - - let (transport, oauth_persistor, process_group_guard) = { + let pending_transport = { let mut guard = self.state.lock().await; match &mut *guard { ClientState::Connecting { transport } => match transport.take() { - Some(PendingTransport::ChildProcess { - transport, - process_group_guard, - }) => ( - service::serve_client(client_handler.clone(), transport).boxed(), - None, - process_group_guard, - ), - Some(PendingTransport::StreamableHttp { transport }) => ( - service::serve_client(client_handler.clone(), transport).boxed(), - None, - None, - ), - Some(PendingTransport::StreamableHttpWithOAuth { - transport, - oauth_persistor, - }) => ( - service::serve_client(client_handler.clone(), transport).boxed(), - Some(oauth_persistor), - None, - ), + Some(transport) => transport, None => return Err(anyhow!("client already initializing")), }, ClientState::Ready { .. } => return Err(anyhow!("client already initialized")), } }; - let service = match timeout { - Some(duration) => time::timeout(duration, transport) - .await - .map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))? - .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, - None => transport - .await - .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, - }; + let (service, oauth_persistor, process_group_guard) = + Self::connect_pending_transport(pending_transport, client_handler.clone(), timeout) + .await?; let initialize_result_rmcp = service .peer() @@ -404,11 +554,19 @@ impl RmcpClient { .ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?; let initialize_result = initialize_result_rmcp.clone(); + { + let mut initialize_context = self.initialize_context.lock().await; + *initialize_context = Some(InitializeContext { + timeout, + handler: client_handler, + }); + } + { let mut guard = self.state.lock().await; *guard = ClientState::Ready { _process_group_guard: process_group_guard, - service: Arc::new(service), + service, oauth: oauth_persistor.clone(), }; } @@ -428,9 +586,12 @@ impl RmcpClient { timeout: Option, ) -> Result { self.refresh_oauth_if_needed().await; - let service = self.service().await?; - let fut = service.list_tools(params); - let result = run_with_timeout(fut, timeout, "tools/list").await?; + let result = self + .run_service_operation("tools/list", timeout, move |service| { + let params = params.clone(); + async move { service.list_tools(params).await }.boxed() + }) + .await?; self.persist_oauth_tokens().await; Ok(result) } @@ -441,10 +602,12 @@ impl RmcpClient { timeout: Option, ) -> Result { self.refresh_oauth_if_needed().await; - let service = self.service().await?; - - let fut = service.list_tools(params); - let result = run_with_timeout(fut, timeout, "tools/list").await?; + let result = self + .run_service_operation("tools/list", timeout, move |service| { + let params = params.clone(); + async move { service.list_tools(params).await }.boxed() + }) + .await?; let tools = result .tools .into_iter() @@ -481,10 +644,12 @@ impl RmcpClient { timeout: Option, ) -> Result { self.refresh_oauth_if_needed().await; - let service = self.service().await?; - - let fut = service.list_resources(params); - let result = run_with_timeout(fut, timeout, "resources/list").await?; + let result = self + .run_service_operation("resources/list", timeout, move |service| { + let params = params.clone(); + async move { service.list_resources(params).await }.boxed() + }) + .await?; self.persist_oauth_tokens().await; Ok(result) } @@ -495,10 +660,12 @@ impl RmcpClient { timeout: Option, ) -> Result { self.refresh_oauth_if_needed().await; - let service = self.service().await?; - - let fut = service.list_resource_templates(params); - let result = run_with_timeout(fut, timeout, "resources/templates/list").await?; + let result = self + .run_service_operation("resources/templates/list", timeout, move |service| { + let params = params.clone(); + async move { service.list_resource_templates(params).await }.boxed() + }) + .await?; self.persist_oauth_tokens().await; Ok(result) } @@ -509,9 +676,12 @@ impl RmcpClient { timeout: Option, ) -> Result { self.refresh_oauth_if_needed().await; - let service = self.service().await?; - let fut = service.read_resource(params); - let result = run_with_timeout(fut, timeout, "resources/read").await?; + let result = self + .run_service_operation("resources/read", timeout, move |service| { + let params = params.clone(); + async move { service.read_resource(params).await }.boxed() + }) + .await?; self.persist_oauth_tokens().await; Ok(result) } @@ -523,7 +693,6 @@ impl RmcpClient { timeout: Option, ) -> Result { self.refresh_oauth_if_needed().await; - let service = self.service().await?; let arguments = match arguments { Some(Value::Object(map)) => Some(map), Some(other) => { @@ -539,8 +708,12 @@ impl RmcpClient { arguments, task: None, }; - let fut = service.call_tool(rmcp_params); - let result = run_with_timeout(fut, timeout, "tools/call").await?; + let result = self + .run_service_operation("tools/call", timeout, move |service| { + let rmcp_params = rmcp_params.clone(); + async move { service.call_tool(rmcp_params).await }.boxed() + }) + .await?; self.persist_oauth_tokens().await; Ok(result) } @@ -550,14 +723,22 @@ impl RmcpClient { method: &str, params: Option, ) -> Result<()> { - let service: Arc> = self.service().await?; - service - .send_notification(ClientNotification::CustomNotification(CustomNotification { - method: method.to_string(), - params, - extensions: Extensions::new(), - })) - .await?; + self.refresh_oauth_if_needed().await; + self.run_service_operation("notifications/custom", None, move |service| { + let params = params.clone(); + async move { + service + .send_notification(ClientNotification::CustomNotification(CustomNotification { + method: method.to_string(), + params, + extensions: Extensions::new(), + })) + .await + } + .boxed() + }) + .await?; + self.persist_oauth_tokens().await; Ok(()) } @@ -566,12 +747,21 @@ impl RmcpClient { method: &str, params: Option, ) -> Result { - let service: Arc> = self.service().await?; - let response = service - .send_request(ClientRequest::CustomRequest(CustomRequest::new( - method, params, - ))) + self.refresh_oauth_if_needed().await; + let response = self + .run_service_operation("requests/custom", None, move |service| { + let params = params.clone(); + async move { + service + .send_request(ClientRequest::CustomRequest(CustomRequest::new( + method, params, + ))) + .await + } + .boxed() + }) .await?; + self.persist_oauth_tokens().await; Ok(response) } @@ -611,6 +801,319 @@ impl RmcpClient { warn!("failed to refresh OAuth tokens: {error}"); } } + + async fn create_pending_transport( + transport_recipe: &TransportRecipe, + ) -> Result { + match transport_recipe { + TransportRecipe::Stdio { + program, + args, + env, + env_vars, + cwd, + } => { + let program_name = program.to_string_lossy().into_owned(); + let envs = create_env_for_mcp_server(env.clone(), env_vars); + let resolved_program = program_resolver::resolve(program.clone(), &envs)?; + + let mut command = Command::new(resolved_program); + command + .kill_on_drop(true) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .env_clear() + .envs(envs) + .args(args); + #[cfg(unix)] + command.process_group(0); + if let Some(cwd) = cwd { + command.current_dir(cwd); + } + + let (transport, stderr) = TokioChildProcess::builder(command) + .stderr(Stdio::piped()) + .spawn()?; + let process_group_guard = transport.id().map(ProcessGroupGuard::new); + + if let Some(stderr) = stderr { + tokio::spawn(async move { + let mut reader = BufReader::new(stderr).lines(); + loop { + match reader.next_line().await { + Ok(Some(line)) => { + info!("MCP server stderr ({program_name}): {line}"); + } + Ok(None) => break, + Err(error) => { + warn!( + "Failed to read MCP server stderr ({program_name}): {error}" + ); + break; + } + } + } + }); + } + + Ok(PendingTransport::ChildProcess { + transport, + process_group_guard, + }) + } + TransportRecipe::StreamableHttp { + server_name, + url, + bearer_token, + http_headers, + env_http_headers, + store_mode, + } => { + let default_headers = + build_default_headers(http_headers.clone(), env_http_headers.clone())?; + + let initial_oauth_tokens = + if bearer_token.is_none() && !default_headers.contains_key(AUTHORIZATION) { + match load_oauth_tokens(server_name, url, *store_mode) { + Ok(tokens) => tokens, + Err(err) => { + warn!("failed to read tokens for server `{server_name}`: {err}"); + None + } + } + } else { + None + }; + + if let Some(initial_tokens) = initial_oauth_tokens.clone() { + match create_oauth_transport_and_runtime( + server_name, + url, + initial_tokens.clone(), + *store_mode, + default_headers.clone(), + ) + .await + { + Ok((transport, oauth_persistor)) => { + Ok(PendingTransport::StreamableHttpWithOAuth { + transport, + oauth_persistor, + }) + } + Err(err) + if err.downcast_ref::().is_some_and(|auth_err| { + matches!(auth_err, AuthError::NoAuthorizationSupport) + }) => + { + let access_token = initial_tokens + .token_response + .0 + .access_token() + .secret() + .to_string(); + warn!( + "OAuth metadata discovery is unavailable for MCP server `{server_name}`; falling back to stored bearer token authentication" + ); + let http_config = + StreamableHttpClientTransportConfig::with_uri(url.clone()) + .auth_header(access_token); + let http_client = + apply_default_headers(reqwest::Client::builder(), &default_headers) + .build()?; + let transport = StreamableHttpClientTransport::with_client( + StreamableHttpResponseClient::new(http_client), + http_config, + ); + Ok(PendingTransport::StreamableHttp { transport }) + } + Err(err) => Err(err), + } + } else { + let mut http_config = + StreamableHttpClientTransportConfig::with_uri(url.clone()); + if let Some(bearer_token) = bearer_token.clone() { + http_config = http_config.auth_header(bearer_token); + } + + let http_client = + apply_default_headers(reqwest::Client::builder(), &default_headers) + .build()?; + + let transport = StreamableHttpClientTransport::with_client( + StreamableHttpResponseClient::new(http_client), + http_config, + ); + Ok(PendingTransport::StreamableHttp { transport }) + } + } + } + } + + async fn connect_pending_transport( + pending_transport: PendingTransport, + client_handler: LoggingClientHandler, + timeout: Option, + ) -> Result<( + Arc>, + Option, + Option, + )> { + let (transport, oauth_persistor, process_group_guard) = match pending_transport { + PendingTransport::ChildProcess { + transport, + process_group_guard, + } => ( + service::serve_client(client_handler, transport).boxed(), + None, + process_group_guard, + ), + PendingTransport::StreamableHttp { transport } => ( + service::serve_client(client_handler, transport).boxed(), + None, + None, + ), + PendingTransport::StreamableHttpWithOAuth { + transport, + oauth_persistor, + } => ( + service::serve_client(client_handler, transport).boxed(), + Some(oauth_persistor), + None, + ), + }; + + let service = match timeout { + Some(duration) => time::timeout(duration, transport) + .await + .map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))? + .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, + None => transport + .await + .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, + }; + + Ok((Arc::new(service), oauth_persistor, process_group_guard)) + } + + async fn run_service_operation( + &self, + label: &str, + timeout: Option, + operation: F, + ) -> Result + where + F: Fn(Arc>) -> Fut, + Fut: std::future::Future>, + { + let service = self.service().await?; + match Self::run_service_operation_once(Arc::clone(&service), label, timeout, &operation) + .await + { + Ok(result) => Ok(result), + Err(error) if Self::is_session_expired_404(&error) => { + self.reinitialize_after_session_expiry(&service).await?; + let recovered_service = self.service().await?; + Self::run_service_operation_once(recovered_service, label, timeout, &operation) + .await + .map_err(Into::into) + } + Err(error) => Err(error.into()), + } + } + + async fn run_service_operation_once( + service: Arc>, + label: &str, + timeout: Option, + operation: &F, + ) -> std::result::Result + where + F: Fn(Arc>) -> Fut, + Fut: std::future::Future>, + { + match timeout { + Some(duration) => time::timeout(duration, operation(service)) + .await + .map_err(|_| ClientOperationError::Timeout { + label: label.to_string(), + duration, + })? + .map_err(ClientOperationError::from), + None => operation(service).await.map_err(ClientOperationError::from), + } + } + + fn is_session_expired_404(error: &ClientOperationError) -> bool { + let ClientOperationError::Service(rmcp::service::ServiceError::TransportSend(error)) = + error + else { + return false; + }; + + error + .error + .downcast_ref::>() + .is_some_and(|error| { + matches!( + error, + StreamableHttpError::Client( + StreamableHttpResponseClientError::SessionExpired404 + ) + ) + }) + } + + async fn reinitialize_after_session_expiry( + &self, + failed_service: &Arc>, + ) -> Result<()> { + let _recovery_guard = self.session_recovery_lock.lock().await; + + { + let guard = self.state.lock().await; + match &*guard { + ClientState::Ready { service, .. } if !Arc::ptr_eq(service, failed_service) => { + return Ok(()); + } + ClientState::Ready { .. } => {} + ClientState::Connecting { .. } => { + return Err(anyhow!("MCP client not initialized")); + } + } + } + + let initialize_context = self + .initialize_context + .lock() + .await + .clone() + .ok_or_else(|| anyhow!("MCP client cannot recover before initialize succeeds"))?; + let pending_transport = Self::create_pending_transport(&self.transport_recipe).await?; + let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport( + pending_transport, + initialize_context.handler, + initialize_context.timeout, + ) + .await?; + + { + let mut guard = self.state.lock().await; + *guard = ClientState::Ready { + _process_group_guard: process_group_guard, + service, + oauth: oauth_persistor.clone(), + }; + } + + if let Some(runtime) = oauth_persistor + && let Err(error) = runtime.persist_if_needed().await + { + warn!("failed to persist OAuth tokens after session recovery: {error}"); + } + + Ok(()) + } } async fn create_oauth_transport_and_runtime( @@ -620,7 +1123,7 @@ async fn create_oauth_transport_and_runtime( credentials_store: OAuthCredentialsStoreMode, default_headers: HeaderMap, ) -> Result<( - StreamableHttpClientTransport>, + StreamableHttpClientTransport>, OAuthPersistor, )> { let http_client = @@ -642,7 +1145,7 @@ async fn create_oauth_transport_and_runtime( } }; - let auth_client = AuthClient::new(http_client, manager); + let auth_client = AuthClient::new(StreamableHttpResponseClient::new(http_client), manager); let auth_manager = auth_client.auth_manager.clone(); let transport = StreamableHttpClientTransport::with_client( diff --git a/codex-rs/rmcp-client/src/utils.rs b/codex-rs/rmcp-client/src/utils.rs index e47c1d14b..df09b0f8b 100644 --- a/codex-rs/rmcp-client/src/utils.rs +++ b/codex-rs/rmcp-client/src/utils.rs @@ -1,34 +1,10 @@ -use std::collections::HashMap; -use std::env; -use std::time::Duration; - -use anyhow::Context; use anyhow::Result; -use anyhow::anyhow; use reqwest::ClientBuilder; use reqwest::header::HeaderMap; use reqwest::header::HeaderName; use reqwest::header::HeaderValue; -use rmcp::service::ServiceError; -use tokio::time; - -pub(crate) async fn run_with_timeout( - fut: F, - timeout: Option, - label: &str, -) -> Result -where - F: std::future::Future>, -{ - if let Some(duration) = timeout { - let result = time::timeout(duration, fut) - .await - .with_context(|| anyhow!("timed out awaiting {label} after {duration:?}"))?; - result.map_err(|err| anyhow!("{label} failed: {err}")) - } else { - fut.await.map_err(|err| anyhow!("{label} failed: {err}")) - } -} +use std::collections::HashMap; +use std::env; pub(crate) fn create_env_for_mcp_server( extra_env: Option>, diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs new file mode 100644 index 000000000..4710fdf78 --- /dev/null +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -0,0 +1,268 @@ +use std::net::TcpListener; +use std::path::PathBuf; +use std::time::Duration; +use std::time::Instant; + +use codex_rmcp_client::ElicitationAction; +use codex_rmcp_client::ElicitationResponse; +use codex_rmcp_client::OAuthCredentialsStoreMode; +use codex_rmcp_client::RmcpClient; +use codex_utils_cargo_bin::CargoBinError; +use futures::FutureExt as _; +use pretty_assertions::assert_eq; +use rmcp::model::CallToolResult; +use rmcp::model::ClientCapabilities; +use rmcp::model::ElicitationCapability; +use rmcp::model::FormElicitationCapability; +use rmcp::model::Implementation; +use rmcp::model::InitializeRequestParams; +use rmcp::model::ProtocolVersion; +use serde_json::json; +use tokio::net::TcpStream; +use tokio::process::Child; +use tokio::process::Command; +use tokio::time::sleep; + +const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure"; + +fn streamable_http_server_bin() -> Result { + codex_utils_cargo_bin::cargo_bin("test_streamable_http_server") +} + +fn init_params() -> InitializeRequestParams { + InitializeRequestParams { + meta: None, + capabilities: ClientCapabilities { + experimental: None, + extensions: None, + roots: None, + sampling: None, + elicitation: Some(ElicitationCapability { + form: Some(FormElicitationCapability { + schema_validation: None, + }), + url: None, + }), + tasks: None, + }, + client_info: Implementation { + name: "codex-test".into(), + version: "0.0.0-test".into(), + title: Some("Codex rmcp recovery test".into()), + description: None, + icons: None, + website_url: None, + }, + protocol_version: ProtocolVersion::V_2025_06_18, + } +} + +fn expected_echo_result(message: &str) -> CallToolResult { + CallToolResult { + content: Vec::new(), + structured_content: Some(json!({ + "echo": format!("ECHOING: {message}"), + "env": null, + })), + is_error: Some(false), + meta: None, + } +} + +async fn create_client(base_url: &str) -> anyhow::Result { + let client = RmcpClient::new_streamable_http_client( + "test-streamable-http", + &format!("{base_url}/mcp"), + Some("test-bearer".to_string()), + None, + None, + OAuthCredentialsStoreMode::File, + ) + .await?; + + client + .initialize( + init_params(), + Some(Duration::from_secs(5)), + Box::new(|_, _| { + async { + Ok(ElicitationResponse { + action: ElicitationAction::Accept, + content: Some(json!({})), + meta: None, + }) + } + .boxed() + }), + ) + .await?; + + Ok(client) +} + +async fn call_echo_tool(client: &RmcpClient, message: &str) -> anyhow::Result { + client + .call_tool( + "echo".to_string(), + Some(json!({ "message": message })), + Some(Duration::from_secs(5)), + ) + .await +} + +async fn arm_session_post_failure( + base_url: &str, + status: u16, + remaining: usize, +) -> anyhow::Result<()> { + let response = reqwest::Client::new() + .post(format!("{base_url}{SESSION_POST_FAILURE_CONTROL_PATH}")) + .json(&json!({ + "status": status, + "remaining": remaining, + })) + .send() + .await?; + + assert_eq!(response.status(), reqwest::StatusCode::NO_CONTENT); + Ok(()) +} + +async fn spawn_streamable_http_server() -> anyhow::Result<(Child, String)> { + let listener = TcpListener::bind("127.0.0.1:0")?; + let port = listener.local_addr()?.port(); + drop(listener); + + let bind_addr = format!("127.0.0.1:{port}"); + let base_url = format!("http://{bind_addr}"); + let mut child = Command::new(streamable_http_server_bin()?) + .kill_on_drop(true) + .env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr) + .spawn()?; + + wait_for_streamable_http_server(&mut child, &bind_addr, Duration::from_secs(5)).await?; + Ok((child, base_url)) +} + +async fn wait_for_streamable_http_server( + server_child: &mut Child, + address: &str, + timeout: Duration, +) -> anyhow::Result<()> { + let deadline = Instant::now() + timeout; + + loop { + if let Some(status) = server_child.try_wait()? { + return Err(anyhow::anyhow!( + "streamable HTTP server exited early with status {status}" + )); + } + + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + return Err(anyhow::anyhow!( + "timed out waiting for streamable HTTP server at {address}: deadline reached" + )); + } + + match tokio::time::timeout(remaining, TcpStream::connect(address)).await { + Ok(Ok(_)) => return Ok(()), + Ok(Err(error)) => { + if Instant::now() >= deadline { + return Err(anyhow::anyhow!( + "timed out waiting for streamable HTTP server at {address}: {error}" + )); + } + } + Err(_) => { + return Err(anyhow::anyhow!( + "timed out waiting for streamable HTTP server at {address}: connect call timed out" + )); + } + } + + sleep(Duration::from_millis(50)).await; + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_404_session_expiry_recovers_and_retries_once() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + let client = create_client(&base_url).await?; + + let warmup = call_echo_tool(&client, "warmup").await?; + assert_eq!(warmup, expected_echo_result("warmup")); + + arm_session_post_failure(&base_url, 404, 1).await?; + + let recovered = call_echo_tool(&client, "recovered").await?; + assert_eq!(recovered, expected_echo_result("recovered")); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_401_does_not_trigger_recovery() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + let client = create_client(&base_url).await?; + + let warmup = call_echo_tool(&client, "warmup").await?; + assert_eq!(warmup, expected_echo_result("warmup")); + + arm_session_post_failure(&base_url, 401, 2).await?; + + let first_error = call_echo_tool(&client, "unauthorized").await.unwrap_err(); + assert!(first_error.to_string().contains("401")); + + let second_error = call_echo_tool(&client, "still-unauthorized") + .await + .unwrap_err(); + assert!(second_error.to_string().contains("401")); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_404_recovery_only_retries_once() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + let client = create_client(&base_url).await?; + + let warmup = call_echo_tool(&client, "warmup").await?; + assert_eq!(warmup, expected_echo_result("warmup")); + + arm_session_post_failure(&base_url, 404, 2).await?; + + let error = call_echo_tool(&client, "double-404").await.unwrap_err(); + assert!( + error + .to_string() + .contains("handshaking with MCP server failed") + || error.to_string().contains("Transport channel closed") + ); + + let recovered = call_echo_tool(&client, "after-double-404").await?; + assert_eq!(recovered, expected_echo_result("after-double-404")); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_non_session_failure_does_not_trigger_recovery() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + let client = create_client(&base_url).await?; + + let warmup = call_echo_tool(&client, "warmup").await?; + assert_eq!(warmup, expected_echo_result("warmup")); + + arm_session_post_failure(&base_url, 500, 2).await?; + + let first_error = call_echo_tool(&client, "server-error").await.unwrap_err(); + assert!(first_error.to_string().contains("500")); + + let second_error = call_echo_tool(&client, "still-server-error") + .await + .unwrap_err(); + assert!(second_error.to_string().contains("500")); + + Ok(()) +}