diff --git a/codex-rs/core/src/agent/control.rs b/codex-rs/core/src/agent/control.rs index 2a95ca6eb..847bb7f76 100644 --- a/codex-rs/core/src/agent/control.rs +++ b/codex-rs/core/src/agent/control.rs @@ -1,11 +1,14 @@ use crate::agent::AgentStatus; use crate::agent::guards::Guards; +use crate::agent::status::is_final; use crate::error::CodexErr; use crate::error::Result as CodexResult; +use crate::session_prefix::format_subagent_notification_message; use crate::thread_manager::ThreadManagerState; use codex_protocol::ThreadId; use codex_protocol::protocol::Op; use codex_protocol::protocol::SessionSource; +use codex_protocol::protocol::SubAgentSource; use codex_protocol::protocol::TokenUsage; use codex_protocol::user_input::UserInput; use std::path::PathBuf; @@ -46,6 +49,7 @@ impl AgentControl { ) -> CodexResult { let state = self.upgrade()?; let reservation = self.state.reserve_spawn_slot(config.agent_max_threads)?; + let notification_source = session_source.clone(); // The same `AgentControl` is sent to spawn the thread. let new_thread = match session_source { @@ -64,6 +68,7 @@ impl AgentControl { state.notify_thread_created(new_thread.thread_id); self.send_input(new_thread.thread_id, items).await?; + self.maybe_start_completion_watcher(new_thread.thread_id, notification_source); Ok(new_thread.thread_id) } @@ -77,6 +82,7 @@ impl AgentControl { ) -> CodexResult { let state = self.upgrade()?; let reservation = self.state.reserve_spawn_slot(config.agent_max_threads)?; + let notification_source = session_source.clone(); let resumed_thread = state .resume_thread_from_rollout_with_source( @@ -90,6 +96,7 @@ impl AgentControl { // Resumed threads are re-registered in-memory and need the same listener // attachment path as freshly spawned threads. state.notify_thread_created(resumed_thread.thread_id); + self.maybe_start_completion_watcher(resumed_thread.thread_id, Some(notification_source)); Ok(resumed_thread.thread_id) } @@ -164,13 +171,60 @@ impl AgentControl { thread.total_token_usage().await } + /// Starts a detached watcher for sub-agents spawned from another thread. + /// + /// This is only enabled for `SubAgentSource::ThreadSpawn`, where a parent thread exists and + /// can receive completion notifications. + fn maybe_start_completion_watcher( + &self, + child_thread_id: ThreadId, + session_source: Option, + ) { + let Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, .. + })) = session_source + else { + return; + }; + let control = self.clone(); + tokio::spawn(async move { + let mut status_rx = match control.subscribe_status(child_thread_id).await { + Ok(rx) => rx, + Err(_) => return, + }; + let mut status = status_rx.borrow().clone(); + while !is_final(&status) { + if status_rx.changed().await.is_err() { + status = control.get_status(child_thread_id).await; + break; + } + status = status_rx.borrow().clone(); + } + if !is_final(&status) { + return; + } + + let Ok(state) = control.upgrade() else { + return; + }; + let Ok(parent_thread) = state.get_thread(parent_thread_id).await else { + return; + }; + parent_thread + .inject_user_message_without_turn(format_subagent_notification_message( + &child_thread_id.to_string(), + &status, + )) + .await; + }); + } + fn upgrade(&self) -> CodexResult> { self.manager .upgrade() .ok_or_else(|| CodexErr::UnsupportedOperation("thread manager dropped".to_string())) } } - #[cfg(test)] mod tests { use super::*; @@ -180,16 +234,24 @@ mod tests { use crate::agent::agent_status_from_event; use crate::config::Config; use crate::config::ConfigBuilder; + use crate::session_prefix::SUBAGENT_NOTIFICATION_OPEN_TAG; use assert_matches::assert_matches; use codex_protocol::config_types::ModeKind; + use codex_protocol::models::ContentItem; + use codex_protocol::models::ResponseItem; use codex_protocol::protocol::ErrorEvent; use codex_protocol::protocol::EventMsg; + use codex_protocol::protocol::SessionSource; + use codex_protocol::protocol::SubAgentSource; use codex_protocol::protocol::TurnAbortReason; use codex_protocol::protocol::TurnAbortedEvent; use codex_protocol::protocol::TurnCompleteEvent; use codex_protocol::protocol::TurnStartedEvent; use pretty_assertions::assert_eq; use tempfile::TempDir; + use tokio::time::Duration; + use tokio::time::sleep; + use tokio::time::timeout; use toml::Value as TomlValue; async fn test_config_with_cli_overrides( @@ -250,6 +312,42 @@ mod tests { } } + fn has_subagent_notification(history_items: &[ResponseItem]) -> bool { + history_items.iter().any(|item| { + let ResponseItem::Message { role, content, .. } = item else { + return false; + }; + if role != "user" { + return false; + } + content.iter().any(|content_item| match content_item { + ContentItem::InputText { text } | ContentItem::OutputText { text } => { + text.contains(SUBAGENT_NOTIFICATION_OPEN_TAG) + } + ContentItem::InputImage { .. } => false, + }) + }) + } + + async fn wait_for_subagent_notification(parent_thread: &Arc) -> bool { + let wait = async { + loop { + let history_items = parent_thread + .codex + .session + .clone_history() + .await + .raw_items() + .to_vec(); + if has_subagent_notification(&history_items) { + return true; + } + sleep(Duration::from_millis(25)).await; + } + }; + timeout(Duration::from_secs(2), wait).await.is_ok() + } + #[tokio::test] async fn send_input_errors_when_manager_dropped() { let control = AgentControl::default(); @@ -683,4 +781,35 @@ mod tests { .await .expect("shutdown resumed thread"); } + + #[tokio::test] + async fn spawn_child_completion_notifies_parent_history() { + let harness = AgentControlHarness::new().await; + let (parent_thread_id, parent_thread) = harness.start_thread().await; + + let child_thread_id = harness + .control + .spawn_agent( + harness.config.clone(), + text_input("hello child"), + Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + })), + ) + .await + .expect("child spawn should succeed"); + + let child_thread = harness + .manager + .get_thread(child_thread_id) + .await + .expect("child thread should exist"); + let _ = child_thread + .submit(Op::Shutdown {}) + .await + .expect("child shutdown should submit"); + + assert_eq!(wait_for_subagent_notification(&parent_thread).await, true); + } } diff --git a/codex-rs/core/src/codex_thread.rs b/codex-rs/core/src/codex_thread.rs index c98dd6977..b493075d4 100644 --- a/codex-rs/core/src/codex_thread.rs +++ b/codex-rs/core/src/codex_thread.rs @@ -8,6 +8,9 @@ use crate::protocol::Event; use crate::protocol::Op; use crate::protocol::Submission; use codex_protocol::config_types::Personality; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseInputItem; +use codex_protocol::models::ResponseItem; use codex_protocol::openai_models::ReasoningEffort; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::SandboxPolicy; @@ -32,7 +35,7 @@ pub struct ThreadConfigSnapshot { } pub struct CodexThread { - codex: Codex, + pub(crate) codex: Codex, rollout_path: Option, _watch_registration: WatchRegistration, } @@ -85,6 +88,33 @@ impl CodexThread { self.codex.session.total_token_usage().await } + /// Records a user-role session-prefix message without creating a new user turn boundary. + pub(crate) async fn inject_user_message_without_turn(&self, message: String) { + let pending_item = ResponseInputItem::Message { + role: "user".to_string(), + content: vec![ContentItem::InputText { text: message }], + }; + let pending_items = vec![pending_item]; + let Err(items_without_active_turn) = self + .codex + .session + .inject_response_items(pending_items) + .await + else { + return; + }; + + let turn_context = self.codex.session.new_default_turn().await; + let items: Vec = items_without_active_turn + .into_iter() + .map(ResponseItem::from) + .collect(); + self.codex + .session + .record_conversation_items(turn_context.as_ref(), &items) + .await; + } + pub fn rollout_path(&self) -> Option { self.rollout_path.clone() } diff --git a/codex-rs/core/src/context_manager/history_tests.rs b/codex-rs/core/src/context_manager/history_tests.rs index 28c157b1e..dbecaa9ef 100644 --- a/codex-rs/core/src/context_manager/history_tests.rs +++ b/codex-rs/core/src/context_manager/history_tests.rs @@ -571,6 +571,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() { "\ndemo\nskills/demo/SKILL.md\nbody\n", ), user_input_text_msg("echo 42"), + user_input_text_msg( + "{\"agent_id\":\"a\",\"status\":\"completed\"}", + ), user_input_text_msg("turn 1 user"), assistant_msg("turn 1 assistant"), user_input_text_msg("turn 2 user"), @@ -591,6 +594,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() { "\ndemo\nskills/demo/SKILL.md\nbody\n", ), user_input_text_msg("echo 42"), + user_input_text_msg( + "{\"agent_id\":\"a\",\"status\":\"completed\"}", + ), user_input_text_msg("turn 1 user"), assistant_msg("turn 1 assistant"), ]; @@ -610,6 +616,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() { "\ndemo\nskills/demo/SKILL.md\nbody\n", ), user_input_text_msg("echo 42"), + user_input_text_msg( + "{\"agent_id\":\"a\",\"status\":\"completed\"}", + ), ]; let mut history = create_history_with_items(vec![ @@ -622,6 +631,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() { "\ndemo\nskills/demo/SKILL.md\nbody\n", ), user_input_text_msg("echo 42"), + user_input_text_msg( + "{\"agent_id\":\"a\",\"status\":\"completed\"}", + ), user_input_text_msg("turn 1 user"), assistant_msg("turn 1 assistant"), user_input_text_msg("turn 2 user"), @@ -640,6 +652,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() { "\ndemo\nskills/demo/SKILL.md\nbody\n", ), user_input_text_msg("echo 42"), + user_input_text_msg( + "{\"agent_id\":\"a\",\"status\":\"completed\"}", + ), user_input_text_msg("turn 1 user"), assistant_msg("turn 1 assistant"), user_input_text_msg("turn 2 user"), diff --git a/codex-rs/core/src/session_prefix.rs b/codex-rs/core/src/session_prefix.rs index 99283082b..5f8516ba9 100644 --- a/codex-rs/core/src/session_prefix.rs +++ b/codex-rs/core/src/session_prefix.rs @@ -1,3 +1,5 @@ +use codex_protocol::protocol::AgentStatus; + /// Helpers for identifying model-visible "session prefix" messages. /// /// A session prefix is a user-role message that carries configuration or state needed by @@ -6,10 +8,41 @@ /// boundaries. pub(crate) const ENVIRONMENT_CONTEXT_OPEN_TAG: &str = ""; pub(crate) const TURN_ABORTED_OPEN_TAG: &str = ""; +pub(crate) const SUBAGENT_NOTIFICATION_OPEN_TAG: &str = ""; +pub(crate) const SUBAGENT_NOTIFICATION_CLOSE_TAG: &str = ""; + +fn starts_with_ascii_case_insensitive(text: &str, prefix: &str) -> bool { + text.get(..prefix.len()) + .is_some_and(|candidate| candidate.eq_ignore_ascii_case(prefix)) +} /// Returns true if `text` starts with a session prefix marker (case-insensitive). pub(crate) fn is_session_prefix(text: &str) -> bool { let trimmed = text.trim_start(); - let lowered = trimmed.to_ascii_lowercase(); - lowered.starts_with(ENVIRONMENT_CONTEXT_OPEN_TAG) || lowered.starts_with(TURN_ABORTED_OPEN_TAG) + starts_with_ascii_case_insensitive(trimmed, ENVIRONMENT_CONTEXT_OPEN_TAG) + || starts_with_ascii_case_insensitive(trimmed, TURN_ABORTED_OPEN_TAG) + || starts_with_ascii_case_insensitive(trimmed, SUBAGENT_NOTIFICATION_OPEN_TAG) +} + +pub(crate) fn format_subagent_notification_message(agent_id: &str, status: &AgentStatus) -> String { + let payload_json = serde_json::json!({ + "agent_id": agent_id, + "status": status, + }) + .to_string(); + format!("{SUBAGENT_NOTIFICATION_OPEN_TAG}\n{payload_json}\n{SUBAGENT_NOTIFICATION_CLOSE_TAG}") +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn is_session_prefix_is_case_insensitive() { + assert_eq!( + is_session_prefix("{}"), + true + ); + } } diff --git a/codex-rs/core/src/tools/handlers/multi_agents.rs b/codex-rs/core/src/tools/handlers/multi_agents.rs index f005c3f3c..7a2b62d56 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents.rs @@ -427,7 +427,7 @@ mod resume_agent { } } -mod wait { +pub(crate) mod wait { use super::*; use crate::agent::status::is_final; use futures::FutureExt; @@ -447,10 +447,10 @@ mod wait { timeout_ms: Option, } - #[derive(Debug, Serialize)] - struct WaitResult { - status: HashMap, - timed_out: bool, + #[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] + pub(crate) struct WaitResult { + pub(crate) status: HashMap, + pub(crate) timed_out: bool, } pub async fn handle( @@ -1462,12 +1462,6 @@ mod tests { ); } - #[derive(Debug, Deserialize, PartialEq, Eq)] - struct WaitResult { - status: HashMap, - timed_out: bool, - } - #[tokio::test] async fn wait_rejects_non_positive_timeout() { let (session, turn) = make_session_and_context().await; @@ -1553,11 +1547,11 @@ mod tests { else { panic!("expected function output"); }; - let result: WaitResult = + let result: wait::WaitResult = serde_json::from_str(&content).expect("wait result should be json"); assert_eq!( result, - WaitResult { + wait::WaitResult { status: HashMap::from([ (id_a, AgentStatus::NotFound), (id_b, AgentStatus::NotFound), @@ -1597,11 +1591,11 @@ mod tests { else { panic!("expected function output"); }; - let result: WaitResult = + let result: wait::WaitResult = serde_json::from_str(&content).expect("wait result should be json"); assert_eq!( result, - WaitResult { + wait::WaitResult { status: HashMap::new(), timed_out: true } @@ -1694,11 +1688,11 @@ mod tests { else { panic!("expected function output"); }; - let result: WaitResult = + let result: wait::WaitResult = serde_json::from_str(&content).expect("wait result should be json"); assert_eq!( result, - WaitResult { + wait::WaitResult { status: HashMap::from([(agent_id, AgentStatus::Shutdown)]), timed_out: false } diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index d0e372be2..b17a3a285 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -646,7 +646,7 @@ fn create_wait_tool() -> ToolSpec { ToolSpec::Function(ResponsesApiTool { name: "wait".to_string(), - description: "Wait for agents to reach a final status. Completed statuses may include the agent's final message. Returns empty status when timed out." + description: "Wait for agents to reach a final status. Completed statuses may include the agent's final message. Returns empty status when timed out. Once the agent reaches his final status, a notification message will be received containing the same completed status." .to_string(), strict: false, parameters: JsonSchema::Object { diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index e420d12f9..5e3d65799 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -114,6 +114,7 @@ mod skills; mod sqlite_state; mod stream_error_allows_next_turn; mod stream_no_completed; +mod subagent_notifications; mod text_encoding_fix; mod tool_harness; mod tool_parallelism; diff --git a/codex-rs/core/tests/suite/subagent_notifications.rs b/codex-rs/core/tests/suite/subagent_notifications.rs new file mode 100644 index 000000000..422510f32 --- /dev/null +++ b/codex-rs/core/tests/suite/subagent_notifications.rs @@ -0,0 +1,196 @@ +use anyhow::Result; +use codex_core::features::Feature; +use core_test_support::responses::ResponsesRequest; +use core_test_support::responses::ev_assistant_message; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_function_call; +use core_test_support::responses::ev_response_created; +use core_test_support::responses::mount_response_once_match; +use core_test_support::responses::mount_sse_once_match; +use core_test_support::responses::sse; +use core_test_support::responses::sse_response; +use core_test_support::responses::start_mock_server; +use core_test_support::skip_if_no_network; +use core_test_support::test_codex::TestCodex; +use core_test_support::test_codex::test_codex; +use serde_json::json; +use std::time::Duration; +use tokio::time::Instant; +use tokio::time::sleep; +use wiremock::MockServer; + +const SPAWN_CALL_ID: &str = "spawn-call-1"; +const TURN_1_PROMPT: &str = "spawn a child and continue"; +const TURN_2_NO_WAIT_PROMPT: &str = "follow up without wait"; +const CHILD_PROMPT: &str = "child: do work"; + +fn body_contains(req: &wiremock::Request, text: &str) -> bool { + let is_zstd = req + .headers + .get("content-encoding") + .and_then(|value| value.to_str().ok()) + .is_some_and(|value| { + value + .split(',') + .any(|entry| entry.trim().eq_ignore_ascii_case("zstd")) + }); + let bytes = if is_zstd { + zstd::stream::decode_all(std::io::Cursor::new(&req.body)).ok() + } else { + Some(req.body.clone()) + }; + bytes + .and_then(|body| String::from_utf8(body).ok()) + .is_some_and(|body| body.contains(text)) +} + +fn has_subagent_notification(req: &ResponsesRequest) -> bool { + req.message_input_texts("user") + .iter() + .any(|text| text.contains("")) +} + +async fn wait_for_spawned_thread_id(test: &TestCodex) -> Result { + let deadline = Instant::now() + Duration::from_secs(2); + loop { + let ids = test.thread_manager.list_thread_ids().await; + if let Some(spawned_id) = ids + .iter() + .find(|id| **id != test.session_configured.session_id) + { + return Ok(spawned_id.to_string()); + } + if Instant::now() >= deadline { + anyhow::bail!("timed out waiting for spawned thread id"); + } + sleep(Duration::from_millis(10)).await; + } +} + +async fn wait_for_requests( + mock: &core_test_support::responses::ResponseMock, +) -> Result> { + let deadline = Instant::now() + Duration::from_secs(2); + loop { + let requests = mock.requests(); + if !requests.is_empty() { + return Ok(requests); + } + if Instant::now() >= deadline { + anyhow::bail!("expected at least 1 request, got {}", requests.len()); + } + sleep(Duration::from_millis(10)).await; + } +} + +async fn setup_turn_one_with_spawned_child( + server: &MockServer, + child_response_delay: Option, +) -> Result<(TestCodex, String)> { + let spawn_args = serde_json::to_string(&json!({ + "message": CHILD_PROMPT, + }))?; + + mount_sse_once_match( + server, + |req: &wiremock::Request| body_contains(req, TURN_1_PROMPT), + sse(vec![ + ev_response_created("resp-turn1-1"), + ev_function_call(SPAWN_CALL_ID, "spawn_agent", &spawn_args), + ev_completed("resp-turn1-1"), + ]), + ) + .await; + + let child_sse = sse(vec![ + ev_response_created("resp-child-1"), + ev_assistant_message("msg-child-1", "child done"), + ev_completed("resp-child-1"), + ]); + let child_request_log = if let Some(delay) = child_response_delay { + mount_response_once_match( + server, + |req: &wiremock::Request| { + body_contains(req, CHILD_PROMPT) && !body_contains(req, SPAWN_CALL_ID) + }, + sse_response(child_sse).set_delay(delay), + ) + .await + } else { + mount_sse_once_match( + server, + |req: &wiremock::Request| { + body_contains(req, CHILD_PROMPT) && !body_contains(req, SPAWN_CALL_ID) + }, + child_sse, + ) + .await + }; + + let _turn1_followup = mount_sse_once_match( + server, + |req: &wiremock::Request| body_contains(req, SPAWN_CALL_ID), + sse(vec![ + ev_response_created("resp-turn1-2"), + ev_assistant_message("msg-turn1-2", "parent done"), + ev_completed("resp-turn1-2"), + ]), + ) + .await; + + let mut builder = test_codex().with_config(|config| { + config.features.enable(Feature::Collab); + }); + let test = builder.build(server).await?; + test.submit_turn(TURN_1_PROMPT).await?; + if child_response_delay.is_none() { + let _ = wait_for_requests(&child_request_log).await?; + let rollout_path = test + .codex + .rollout_path() + .ok_or_else(|| anyhow::anyhow!("expected parent rollout path"))?; + let deadline = Instant::now() + Duration::from_secs(6); + loop { + let has_notification = tokio::fs::read_to_string(&rollout_path) + .await + .is_ok_and(|rollout| rollout.contains("")); + if has_notification { + break; + } + if Instant::now() >= deadline { + anyhow::bail!( + "timed out waiting for parent rollout to include subagent notification" + ); + } + sleep(Duration::from_millis(10)).await; + } + } + let spawned_id = wait_for_spawned_thread_id(&test).await?; + + Ok((test, spawned_id)) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn subagent_notification_is_included_without_wait() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let (test, _spawned_id) = setup_turn_one_with_spawned_child(&server, None).await?; + + let turn2 = mount_sse_once_match( + &server, + |req: &wiremock::Request| body_contains(req, TURN_2_NO_WAIT_PROMPT), + sse(vec![ + ev_response_created("resp-turn2-1"), + ev_assistant_message("msg-turn2-1", "no wait path"), + ev_completed("resp-turn2-1"), + ]), + ) + .await; + test.submit_turn(TURN_2_NO_WAIT_PROMPT).await?; + + let turn2_requests = wait_for_requests(&turn2).await?; + assert!(turn2_requests.iter().any(has_subagent_notification)); + + Ok(()) +}