diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 4139a582a..3c6948354 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -4,6 +4,7 @@ use std::ffi::OsString; use std::fs; use std::net::TcpListener; use std::path::Path; +use std::sync::Arc; use std::time::Duration; use std::time::SystemTime; use std::time::UNIX_EPOCH; @@ -36,11 +37,13 @@ use core_test_support::skip_if_no_network; use core_test_support::stdio_server_bin; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; +use core_test_support::wait_for_event_with_timeout; +use reqwest::Client; +use reqwest::StatusCode; use serde_json::Value; use serde_json::json; use serial_test::serial; use tempfile::tempdir; -use tokio::net::TcpStream; use tokio::process::Child; use tokio::process::Command; use tokio::time::Instant; @@ -263,7 +266,7 @@ async fn stdio_image_responses_round_trip() -> anyhow::Result<()> { let tools_ready_deadline = Instant::now() + Duration::from_secs(30); loop { fixture.codex.submit(Op::ListMcpTools).await?; - let list_event = core_test_support::wait_for_event_with_timeout( + let list_event = wait_for_event_with_timeout( &fixture.codex, |ev| matches!(ev, EventMsg::McpListToolsResponse(_)), Duration::from_secs(10), @@ -853,8 +856,8 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> { /// This test writes to a fallback credentials file in CODEX_HOME. /// Ideally, we wouldn't need to serialize the test but it's much more cumbersome to wire CODEX_HOME through the code. -#[serial(codex_home)] #[test] +#[serial(codex_home)] fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> { const TEST_STACK_SIZE_BYTES: usize = 8 * 1024 * 1024; @@ -936,8 +939,8 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> { wait_for_streamable_http_server(&mut http_server_child, &bind_addr, Duration::from_secs(5)) .await?; - let temp_home = tempdir()?; - let _guard = EnvVarGuard::set("CODEX_HOME", temp_home.path().as_os_str()); + let temp_home = Arc::new(tempdir()?); + let _codex_home_guard = EnvVarGuard::set("CODEX_HOME", temp_home.path().as_os_str()); write_fallback_oauth_tokens( temp_home.path(), server_name, @@ -948,10 +951,10 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> { )?; let fixture = test_codex() + .with_home(temp_home.clone()) .with_config(move |config| { - // This test seeds OAuth tokens in CODEX_HOME/.credentials.json and - // validates file-backed OAuth loading. Force file mode so Linux - // keyring backend quirks do not affect this test. + // Keep OAuth credentials isolated to this test home because Bazel + // runs the full core suite in one process. config.mcp_oauth_credentials_store_mode = serde_json::from_value(json!("file")) .expect("`file` should deserialize as OAuthCredentialsStoreMode"); let mut servers = config.mcp_servers.get().clone(); @@ -984,6 +987,31 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> { .await?; let session_model = fixture.session_configured.model.clone(); + let tools_ready_deadline = Instant::now() + Duration::from_secs(30); + loop { + fixture.codex.submit(Op::ListMcpTools).await?; + let list_event = wait_for_event_with_timeout( + &fixture.codex, + |ev| matches!(ev, EventMsg::McpListToolsResponse(_)), + Duration::from_secs(10), + ) + .await; + let EventMsg::McpListToolsResponse(tool_list) = list_event else { + unreachable!("event guard guarantees McpListToolsResponse"); + }; + if tool_list.tools.contains_key(&tool_name) { + break; + } + + let available_tools: Vec<&str> = tool_list.tools.keys().map(String::as_str).collect(); + if Instant::now() >= tools_ready_deadline { + panic!( + "timed out waiting for MCP tool {tool_name} to become available; discovered tools: {available_tools:?}" + ); + } + sleep(Duration::from_millis(200)).await; + } + fixture .codex .submit(Op::UserTurn { @@ -1078,7 +1106,8 @@ async fn wait_for_streamable_http_server( timeout: Duration, ) -> anyhow::Result<()> { let deadline = Instant::now() + timeout; - + let metadata_url = format!("http://{address}/.well-known/oauth-authorization-server/mcp"); + let client = Client::builder().no_proxy().build()?; loop { if let Some(status) = server_child.try_wait()? { return Err(anyhow::anyhow!( @@ -1090,22 +1119,30 @@ async fn wait_for_streamable_http_server( if remaining.is_zero() { return Err(anyhow::anyhow!( - "timed out waiting for streamable HTTP server at {address}: deadline reached" + "timed out waiting for streamable HTTP server metadata at {metadata_url}: deadline reached" )); } - match tokio::time::timeout(remaining, TcpStream::connect(address)).await { - Ok(Ok(_)) => return Ok(()), + match tokio::time::timeout(remaining, client.get(&metadata_url).send()).await { + Ok(Ok(response)) if response.status() == StatusCode::OK => return Ok(()), + Ok(Ok(response)) => { + if Instant::now() >= deadline { + return Err(anyhow::anyhow!( + "timed out waiting for streamable HTTP server metadata at {metadata_url}: HTTP {}", + response.status() + )); + } + } Ok(Err(error)) => { if Instant::now() >= deadline { return Err(anyhow::anyhow!( - "timed out waiting for streamable HTTP server at {address}: {error}" + "timed out waiting for streamable HTTP server metadata at {metadata_url}: {error}" )); } } Err(_) => { return Err(anyhow::anyhow!( - "timed out waiting for streamable HTTP server at {address}: connect call timed out" + "timed out waiting for streamable HTTP server metadata at {metadata_url}: request timed out" )); } } diff --git a/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs b/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs index 05ba6089d..284e1194c 100644 --- a/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs +++ b/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use std::io::ErrorKind; use std::net::SocketAddr; use std::sync::Arc; +use std::time::Duration; use axum::Router; use axum::body::Body; @@ -44,6 +45,7 @@ use serde::Deserialize; use serde_json::json; use tokio::sync::Mutex; use tokio::task; +use tokio::time::sleep; #[derive(Clone)] struct TestToolServer { @@ -275,15 +277,25 @@ fn parse_bind_addr() -> Result> { async fn main() -> Result<(), Box> { let bind_addr = parse_bind_addr()?; let session_failure_state = SessionFailureState::default(); - let listener = match tokio::net::TcpListener::bind(&bind_addr).await { - Ok(listener) => listener, - Err(err) if err.kind() == ErrorKind::PermissionDenied => { - eprintln!( - "failed to bind to {bind_addr}: {err}. make sure the process has network access" - ); - return Ok(()); + const MAX_BIND_RETRIES: u32 = 20; + const BIND_RETRY_DELAY: Duration = Duration::from_millis(50); + + let mut bind_retries = 0; + let listener = loop { + match tokio::net::TcpListener::bind(&bind_addr).await { + Ok(listener) => break listener, + Err(err) if err.kind() == ErrorKind::PermissionDenied => { + eprintln!( + "failed to bind to {bind_addr}: {err}. make sure the process has network access" + ); + return Ok(()); + } + Err(err) if err.kind() == ErrorKind::AddrInUse && bind_retries < MAX_BIND_RETRIES => { + bind_retries += 1; + sleep(BIND_RETRY_DELAY).await; + } + Err(err) => return Err(err.into()), } - Err(err) => return Err(err.into()), }; eprintln!("starting rmcp streamable http test server on http://{bind_addr}/mcp");