[rmcp-client] Recover from streamable HTTP 404 sessions (#13514)

## Summary
- add one-time session recovery in `RmcpClient` for streamable HTTP MCP
`404` session expiry
- rebuild the transport and retry the failed operation once after
reinitializing the client state
- extend the test server and integration coverage for `404`, `401`,
single-retry, and non-session failure scenarios

## Testing
- just fmt
- cargo test -p codex-rmcp-client (the post-rebase run lost its final
summary in the terminal; the suite had passed before the rebase)
- just fix -p codex-rmcp-client
This commit is contained in:
Casey Chow 2026-03-06 10:02:42 -05:00 committed by GitHub
parent 5d4303510c
commit b3765a07e8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 1046 additions and 213 deletions

2
codex-rs/Cargo.lock generated
View file

@ -2302,7 +2302,9 @@ dependencies = [
"serde_json",
"serial_test",
"sha2",
"sse-stream",
"tempfile",
"thiserror 2.0.18",
"tiny_http",
"tokio",
"tracing",

View file

@ -40,6 +40,8 @@ schemars = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
sha2 = { workspace = true }
sse-stream = "0.2.1"
thiserror = { workspace = true }
tiny_http = { workspace = true }
tokio = { workspace = true, features = [
"io-util",

View file

@ -6,7 +6,9 @@ use std::sync::Arc;
use axum::Router;
use axum::body::Body;
use axum::extract::Json;
use axum::extract::State;
use axum::http::Method;
use axum::http::Request;
use axum::http::StatusCode;
use axum::http::header::AUTHORIZATION;
@ -15,6 +17,7 @@ use axum::middleware;
use axum::middleware::Next;
use axum::response::Response;
use axum::routing::get;
use axum::routing::post;
use rmcp::ErrorData as McpError;
use rmcp::handler::server::ServerHandler;
use rmcp::model::CallToolRequestParams;
@ -39,6 +42,7 @@ use rmcp::transport::StreamableHttpService;
use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
use serde::Deserialize;
use serde_json::json;
use tokio::sync::Mutex;
use tokio::task;
#[derive(Clone)]
@ -50,6 +54,8 @@ struct TestToolServer {
const MEMO_URI: &str = "memo://codex/example-note";
const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
/// Header name checked by `fail_session_post_when_armed`; only requests that
/// carry this header are eligible for a forced failure.
const MCP_SESSION_ID_HEADER: &str = "mcp-session-id";
/// Route on which `arm_session_post_failure` is mounted (POST) so callers can
/// arm or clear the forced failure.
const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure";
impl TestToolServer {
fn new() -> Self {
@ -116,6 +122,23 @@ impl TestToolServer {
}
}
/// Shared router state holding the currently armed forced failure, if any.
///
/// Clones share the same `Arc`, so the control handler and the failure
/// middleware observe the same slot. `None` means nothing is armed.
#[derive(Clone, Default)]
struct SessionFailureState {
armed_failure: Arc<Mutex<Option<ArmedFailure>>>,
}
/// A forced failure waiting to be applied to matching requests.
#[derive(Clone, Debug)]
struct ArmedFailure {
// Status code written on each forced failure response.
status: StatusCode,
// Number of matching requests that should still be failed; the middleware
// decrements it per forced failure and clears the slot when it reaches zero.
remaining: usize,
}
/// JSON body accepted by the `arm_session_post_failure` control handler.
#[derive(Debug, Deserialize)]
struct ArmSessionPostFailureRequest {
// Raw HTTP status code; the handler rejects values that
// `StatusCode::from_u16` does not accept.
status: u16,
// How many matching requests to fail; `0` clears any armed failure.
remaining: usize,
}
#[derive(Deserialize)]
struct EchoArgs {
message: String,
@ -251,6 +274,7 @@ fn parse_bind_addr() -> Result<SocketAddr, Box<dyn std::error::Error>> {
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let bind_addr = parse_bind_addr()?;
let session_failure_state = SessionFailureState::default();
let listener = match tokio::net::TcpListener::bind(&bind_addr).await {
Ok(listener) => listener,
Err(err) if err.kind() == ErrorKind::PermissionDenied => {
@ -264,6 +288,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
eprintln!("starting rmcp streamable http test server on http://{bind_addr}/mcp");
let router = Router::new()
.route(
SESSION_POST_FAILURE_CONTROL_PATH,
post(arm_session_post_failure),
)
.route(
"/.well-known/oauth-authorization-server/mcp",
get({
@ -291,7 +319,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default(),
),
);
)
.layer(middleware::from_fn_with_state(
session_failure_state.clone(),
fail_session_post_when_armed,
))
.with_state(session_failure_state);
let router = if let Ok(token) = std::env::var("MCP_EXPECT_BEARER") {
let expected = Arc::new(format!("Bearer {token}"));
@ -323,3 +356,52 @@ async fn require_bearer(
Err(StatusCode::UNAUTHORIZED)
}
}
async fn arm_session_post_failure(
State(state): State<SessionFailureState>,
Json(request): Json<ArmSessionPostFailureRequest>,
) -> Result<StatusCode, StatusCode> {
let status = StatusCode::from_u16(request.status).map_err(|_| StatusCode::BAD_REQUEST)?;
let armed_failure = if request.remaining == 0 {
None
} else {
Some(ArmedFailure {
status,
remaining: request.remaining,
})
};
*state.armed_failure.lock().await = armed_failure;
Ok(StatusCode::NO_CONTENT)
}
/// Middleware that short-circuits session-bound `POST /mcp` requests with the
/// armed status while a failure is armed.
///
/// Requests that are not `POST`, do not target `/mcp`, or lack the
/// `mcp-session-id` header always pass through untouched. Each forced failure
/// consumes one unit of `remaining`; once it reaches zero the slot is cleared
/// and later requests reach the inner service again.
async fn fail_session_post_when_armed(
    State(state): State<SessionFailureState>,
    request: Request<Body>,
    next: Next,
) -> Response {
    let is_session_post = request.uri().path() == "/mcp"
        && request.method() == Method::POST
        && request.headers().contains_key(MCP_SESSION_ID_HEADER);
    if !is_session_post {
        return next.run(request).await;
    }

    let mut slot = state.armed_failure.lock().await;
    // Consume one armed failure (if any) and remember how many are left.
    let consumed = match slot.as_mut() {
        Some(failure) if failure.remaining > 0 => {
            failure.remaining -= 1;
            Some((failure.status, failure.remaining))
        }
        _ => None,
    };

    let Some((status, left)) = consumed else {
        // Release the lock before handing off to the inner service.
        drop(slot);
        return next.run(request).await;
    };

    if left == 0 {
        *slot = None;
    }

    let body = Body::from(format!("forced session failure with status {status}"));
    let mut response = Response::new(body);
    *response.status_mut() = status;
    response
}

File diff suppressed because it is too large Load diff

View file

@ -1,34 +1,10 @@
use std::collections::HashMap;
use std::env;
use std::time::Duration;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use reqwest::ClientBuilder;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderName;
use reqwest::header::HeaderValue;
use rmcp::service::ServiceError;
use tokio::time;
/// Awaits `fut`, optionally bounded by `timeout`, and converts failures into
/// `anyhow` errors that mention `label`.
///
/// * With `Some(duration)`, an elapsed timer yields an error with the context
///   "timed out awaiting {label} after {duration:?}".
/// * A `ServiceError` from the future is reported as "{label} failed: {err}".
pub(crate) async fn run_with_timeout<F, T>(
    fut: F,
    timeout: Option<Duration>,
    label: &str,
) -> Result<T>
where
    F: std::future::Future<Output = Result<T, ServiceError>>,
{
    let outcome = match timeout {
        Some(duration) => time::timeout(duration, fut)
            .await
            .with_context(|| anyhow!("timed out awaiting {label} after {duration:?}"))?,
        None => fut.await,
    };
    outcome.map_err(|err| anyhow!("{label} failed: {err}"))
}
use std::collections::HashMap;
use std::env;
pub(crate) fn create_env_for_mcp_server(
extra_env: Option<HashMap<String, String>>,

View file

@ -0,0 +1,268 @@
use std::net::TcpListener;
use std::path::PathBuf;
use std::time::Duration;
use std::time::Instant;
use codex_rmcp_client::ElicitationAction;
use codex_rmcp_client::ElicitationResponse;
use codex_rmcp_client::OAuthCredentialsStoreMode;
use codex_rmcp_client::RmcpClient;
use codex_utils_cargo_bin::CargoBinError;
use futures::FutureExt as _;
use pretty_assertions::assert_eq;
use rmcp::model::CallToolResult;
use rmcp::model::ClientCapabilities;
use rmcp::model::ElicitationCapability;
use rmcp::model::FormElicitationCapability;
use rmcp::model::Implementation;
use rmcp::model::InitializeRequestParams;
use rmcp::model::ProtocolVersion;
use serde_json::json;
use tokio::net::TcpStream;
use tokio::process::Child;
use tokio::process::Command;
use tokio::time::sleep;
// Path appended to the server base URL when arming forced failures.
// NOTE(review): presumably must match the control route exposed by the
// `test_streamable_http_server` binary — confirm if either side changes.
const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure";
/// Looks up the path of the `test_streamable_http_server` binary through
/// `codex_utils_cargo_bin::cargo_bin`, forwarding its error unchanged.
fn streamable_http_server_bin() -> Result<PathBuf, CargoBinError> {
codex_utils_cargo_bin::cargo_bin("test_streamable_http_server")
}
/// Builds the initialize request used by every test client.
///
/// Only the form elicitation capability is advertised; all other optional
/// capability fields are left unset.
fn init_params() -> InitializeRequestParams {
    let elicitation = ElicitationCapability {
        form: Some(FormElicitationCapability {
            schema_validation: None,
        }),
        url: None,
    };

    let capabilities = ClientCapabilities {
        experimental: None,
        extensions: None,
        roots: None,
        sampling: None,
        elicitation: Some(elicitation),
        tasks: None,
    };

    let client_info = Implementation {
        name: "codex-test".into(),
        version: "0.0.0-test".into(),
        title: Some("Codex rmcp recovery test".into()),
        description: None,
        icons: None,
        website_url: None,
    };

    InitializeRequestParams {
        meta: None,
        capabilities,
        client_info,
        protocol_version: ProtocolVersion::V_2025_06_18,
    }
}
/// Returns the `CallToolResult` the tests expect from the `echo` tool for
/// `message`: empty content, a structured payload with the prefixed echo and
/// a null `env`, and `is_error` set to `false`.
fn expected_echo_result(message: &str) -> CallToolResult {
    let echoed = format!("ECHOING: {message}");
    let structured = json!({
        "echo": echoed,
        "env": null,
    });

    CallToolResult {
        content: Vec::new(),
        structured_content: Some(structured),
        is_error: Some(false),
        meta: None,
    }
}
/// Creates a streamable HTTP client pointed at `{base_url}/mcp` and runs the
/// initialize handshake with a five second timeout.
///
/// The elicitation callback accepts every request with an empty JSON object.
async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient> {
    let endpoint = format!("{base_url}/mcp");
    let client = RmcpClient::new_streamable_http_client(
        "test-streamable-http",
        &endpoint,
        Some("test-bearer".to_string()),
        None,
        None,
        OAuthCredentialsStoreMode::File,
    )
    .await?;

    let handshake_timeout = Some(Duration::from_secs(5));
    client
        .initialize(
            init_params(),
            handshake_timeout,
            // Kept inline so the closure's parameter types are inferred from
            // the callee's expected callback type.
            Box::new(|_, _| {
                async {
                    Ok(ElicitationResponse {
                        action: ElicitationAction::Accept,
                        content: Some(json!({})),
                        meta: None,
                    })
                }
                .boxed()
            }),
        )
        .await?;

    Ok(client)
}
/// Invokes the `echo` tool with `message` as its only argument, using a five
/// second timeout, and returns whatever `RmcpClient::call_tool` yields.
async fn call_echo_tool(client: &RmcpClient, message: &str) -> anyhow::Result<CallToolResult> {
    let arguments = json!({ "message": message });
    let timeout = Duration::from_secs(5);
    client
        .call_tool("echo".to_string(), Some(arguments), Some(timeout))
        .await
}
/// POSTs `{status, remaining}` as JSON to the server's control path.
///
/// Panics (via `assert_eq!`) unless the server answers `204 No Content`;
/// transport errors are returned to the caller.
async fn arm_session_post_failure(
    base_url: &str,
    status: u16,
    remaining: usize,
) -> anyhow::Result<()> {
    let control_url = format!("{base_url}{SESSION_POST_FAILURE_CONTROL_PATH}");
    let payload = json!({
        "status": status,
        "remaining": remaining,
    });

    let http = reqwest::Client::new();
    let response = http.post(control_url).json(&payload).send().await?;

    assert_eq!(response.status(), reqwest::StatusCode::NO_CONTENT);
    Ok(())
}
/// Starts the streamable HTTP test server on a free loopback port and waits
/// (up to five seconds) until it accepts TCP connections.
///
/// Returns the child handle, which is killed on drop, together with the
/// `http://127.0.0.1:{port}` base URL.
async fn spawn_streamable_http_server() -> anyhow::Result<(Child, String)> {
    // Ask the OS for a free port, then release it so the child can bind it.
    let port = {
        let probe = TcpListener::bind("127.0.0.1:0")?;
        probe.local_addr()?.port()
    };
    let bind_addr = format!("127.0.0.1:{port}");

    let server_bin = streamable_http_server_bin()?;
    let mut child = Command::new(server_bin)
        .kill_on_drop(true)
        .env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
        .spawn()?;

    let startup_budget = Duration::from_secs(5);
    wait_for_streamable_http_server(&mut child, &bind_addr, startup_budget).await?;

    let base_url = format!("http://{bind_addr}");
    Ok((child, base_url))
}
async fn wait_for_streamable_http_server(
server_child: &mut Child,
address: &str,
timeout: Duration,
) -> anyhow::Result<()> {
let deadline = Instant::now() + timeout;
loop {
if let Some(status) = server_child.try_wait()? {
return Err(anyhow::anyhow!(
"streamable HTTP server exited early with status {status}"
));
}
let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() {
return Err(anyhow::anyhow!(
"timed out waiting for streamable HTTP server at {address}: deadline reached"
));
}
match tokio::time::timeout(remaining, TcpStream::connect(address)).await {
Ok(Ok(_)) => return Ok(()),
Ok(Err(error)) => {
if Instant::now() >= deadline {
return Err(anyhow::anyhow!(
"timed out waiting for streamable HTTP server at {address}: {error}"
));
}
}
Err(_) => {
return Err(anyhow::anyhow!(
"timed out waiting for streamable HTTP server at {address}: connect call timed out"
));
}
}
sleep(Duration::from_millis(50)).await;
}
}
/// A single forced `404` on a session-bound POST must be invisible to the
/// caller: the tool call still returns the expected echo result.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_404_session_expiry_recovers_and_retries_once() -> anyhow::Result<()> {
    let (_server, base_url) = spawn_streamable_http_server().await?;
    let client = create_client(&base_url).await?;

    // Confirm the session works before injecting the failure.
    assert_eq!(
        call_echo_tool(&client, "warmup").await?,
        expected_echo_result("warmup")
    );

    arm_session_post_failure(&base_url, 404, 1).await?;

    assert_eq!(
        call_echo_tool(&client, "recovered").await?,
        expected_echo_result("recovered")
    );

    Ok(())
}
/// Forced `401` responses must surface to the caller as errors mentioning
/// `401`, on two consecutive calls.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_401_does_not_trigger_recovery() -> anyhow::Result<()> {
    let (_server, base_url) = spawn_streamable_http_server().await?;
    let client = create_client(&base_url).await?;

    // Confirm the session works before injecting the failures.
    assert_eq!(
        call_echo_tool(&client, "warmup").await?,
        expected_echo_result("warmup")
    );

    arm_session_post_failure(&base_url, 401, 2).await?;

    for message in ["unauthorized", "still-unauthorized"] {
        let failure = call_echo_tool(&client, message).await.unwrap_err();
        assert!(failure.to_string().contains("401"));
    }

    Ok(())
}
/// With two forced `404`s armed, the first tool call must fail (handshake or
/// closed-transport error) and the following call must succeed again.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_404_recovery_only_retries_once() -> anyhow::Result<()> {
    let (_server, base_url) = spawn_streamable_http_server().await?;
    let client = create_client(&base_url).await?;

    // Confirm the session works before injecting the failures.
    assert_eq!(
        call_echo_tool(&client, "warmup").await?,
        expected_echo_result("warmup")
    );

    arm_session_post_failure(&base_url, 404, 2).await?;

    let failure = call_echo_tool(&client, "double-404").await.unwrap_err();
    let failure_text = failure.to_string();
    assert!(
        failure_text.contains("handshaking with MCP server failed")
            || failure_text.contains("Transport channel closed")
    );

    assert_eq!(
        call_echo_tool(&client, "after-double-404").await?,
        expected_echo_result("after-double-404")
    );

    Ok(())
}
/// Forced `500` responses must surface to the caller as errors mentioning
/// `500`, on two consecutive calls.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_non_session_failure_does_not_trigger_recovery() -> anyhow::Result<()> {
    let (_server, base_url) = spawn_streamable_http_server().await?;
    let client = create_client(&base_url).await?;

    // Confirm the session works before injecting the failures.
    assert_eq!(
        call_echo_tool(&client, "warmup").await?,
        expected_echo_result("warmup")
    );

    arm_session_post_failure(&base_url, 500, 2).await?;

    for message in ["server-error", "still-server-error"] {
        let failure = call_echo_tool(&client, message).await.unwrap_err();
        assert!(failure.to_string().contains("500"));
    }

    Ok(())
}