diff --git a/codex-rs/core/tests/suite/resume.rs b/codex-rs/core/tests/suite/resume.rs index 98a570a1d..b5889c9aa 100644 --- a/codex-rs/core/tests/suite/resume.rs +++ b/codex-rs/core/tests/suite/resume.rs @@ -13,10 +13,49 @@ use core_test_support::responses::mount_sse_sequence; use core_test_support::responses::sse; use core_test_support::responses::start_mock_server; use core_test_support::skip_if_no_network; +use core_test_support::test_codex::TestCodex; +use core_test_support::test_codex::TestCodexBuilder; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; use pretty_assertions::assert_eq; +use std::path::PathBuf; use std::sync::Arc; +use std::time::Duration; +use tempfile::TempDir; +use wiremock::MockServer; + +async fn resume_until_initial_messages( + builder: &mut TestCodexBuilder, + server: &MockServer, + home: Arc, + rollout_path: PathBuf, + predicate: impl Fn(&[EventMsg]) -> bool, +) -> Result { + let deadline = tokio::time::Instant::now() + Duration::from_secs(2); + let poll_interval = Duration::from_millis(10); + let mut last_initial_messages = "".to_string(); + + loop { + let resumed = builder + .resume(server, Arc::clone(&home), rollout_path.clone()) + .await?; + if let Some(initial_messages) = resumed.session_configured.initial_messages.as_ref() { + if predicate(initial_messages) { + return Ok(resumed); + } + last_initial_messages = format!("{initial_messages:#?}"); + } + + if tokio::time::Instant::now() >= deadline { + panic!( + "timed out waiting for rollout resume messages to stabilize: {last_initial_messages}" + ); + } + + drop(resumed); + tokio::time::sleep(poll_interval).await; + } +} #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn resume_includes_initial_messages_from_rollout_events() -> Result<()> { @@ -57,7 +96,26 @@ async fn resume_includes_initial_messages_from_rollout_events() -> Result<()> { wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await; - let resumed = builder.resume(&server, home, rollout_path).await?; + let resumed = resume_until_initial_messages( + &mut builder, + &server, + home, + rollout_path, + |initial_messages| { + matches!( + initial_messages, + [ + EventMsg::TurnStarted(_), + EventMsg::UserMessage(_), + EventMsg::TokenCount(_), + EventMsg::AgentMessage(_), + EventMsg::TokenCount(_), + EventMsg::TurnComplete(_), + ] + ) + }, + ) + .await?; let initial_messages = resumed .session_configured .initial_messages @@ -123,7 +181,28 @@ async fn resume_includes_initial_messages_from_reasoning_events() -> Result<()> wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await; - let resumed = builder.resume(&server, home, rollout_path).await?; + let resumed = resume_until_initial_messages( + &mut builder, + &server, + home, + rollout_path, + |initial_messages| { + matches!( + initial_messages, + [ + EventMsg::TurnStarted(_), + EventMsg::UserMessage(_), + EventMsg::TokenCount(_), + EventMsg::AgentReasoning(_), + EventMsg::AgentReasoningRawContent(_), + EventMsg::AgentMessage(_), + EventMsg::TokenCount(_), + EventMsg::TurnComplete(_), + ] + ) + }, + ) + .await?; let initial_messages = resumed .session_configured .initial_messages