feat: sub-agent injection (#12152)
This PR adds parent-thread sub-agent completion notifications and changes the model's prompt to prevent it from being confused.
This commit is contained in:
parent
f298c48cc6
commit
2daa3fd44f
8 changed files with 420 additions and 22 deletions
|
|
@ -1,11 +1,14 @@
|
|||
use crate::agent::AgentStatus;
|
||||
use crate::agent::guards::Guards;
|
||||
use crate::agent::status::is_final;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::session_prefix::format_subagent_notification_message;
|
||||
use crate::thread_manager::ThreadManagerState;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::protocol::Op;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use std::path::PathBuf;
|
||||
|
|
@ -46,6 +49,7 @@ impl AgentControl {
|
|||
) -> CodexResult<ThreadId> {
|
||||
let state = self.upgrade()?;
|
||||
let reservation = self.state.reserve_spawn_slot(config.agent_max_threads)?;
|
||||
let notification_source = session_source.clone();
|
||||
|
||||
// The same `AgentControl` is sent to spawn the thread.
|
||||
let new_thread = match session_source {
|
||||
|
|
@ -64,6 +68,7 @@ impl AgentControl {
|
|||
state.notify_thread_created(new_thread.thread_id);
|
||||
|
||||
self.send_input(new_thread.thread_id, items).await?;
|
||||
self.maybe_start_completion_watcher(new_thread.thread_id, notification_source);
|
||||
|
||||
Ok(new_thread.thread_id)
|
||||
}
|
||||
|
|
@ -77,6 +82,7 @@ impl AgentControl {
|
|||
) -> CodexResult<ThreadId> {
|
||||
let state = self.upgrade()?;
|
||||
let reservation = self.state.reserve_spawn_slot(config.agent_max_threads)?;
|
||||
let notification_source = session_source.clone();
|
||||
|
||||
let resumed_thread = state
|
||||
.resume_thread_from_rollout_with_source(
|
||||
|
|
@ -90,6 +96,7 @@ impl AgentControl {
|
|||
// Resumed threads are re-registered in-memory and need the same listener
|
||||
// attachment path as freshly spawned threads.
|
||||
state.notify_thread_created(resumed_thread.thread_id);
|
||||
self.maybe_start_completion_watcher(resumed_thread.thread_id, Some(notification_source));
|
||||
|
||||
Ok(resumed_thread.thread_id)
|
||||
}
|
||||
|
|
@ -164,13 +171,60 @@ impl AgentControl {
|
|||
thread.total_token_usage().await
|
||||
}
|
||||
|
||||
/// Starts a detached watcher for sub-agents spawned from another thread.
|
||||
///
|
||||
/// This is only enabled for `SubAgentSource::ThreadSpawn`, where a parent thread exists and
|
||||
/// can receive completion notifications.
|
||||
fn maybe_start_completion_watcher(
|
||||
&self,
|
||||
child_thread_id: ThreadId,
|
||||
session_source: Option<SessionSource>,
|
||||
) {
|
||||
let Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
|
||||
parent_thread_id, ..
|
||||
})) = session_source
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let control = self.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut status_rx = match control.subscribe_status(child_thread_id).await {
|
||||
Ok(rx) => rx,
|
||||
Err(_) => return,
|
||||
};
|
||||
let mut status = status_rx.borrow().clone();
|
||||
while !is_final(&status) {
|
||||
if status_rx.changed().await.is_err() {
|
||||
status = control.get_status(child_thread_id).await;
|
||||
break;
|
||||
}
|
||||
status = status_rx.borrow().clone();
|
||||
}
|
||||
if !is_final(&status) {
|
||||
return;
|
||||
}
|
||||
|
||||
let Ok(state) = control.upgrade() else {
|
||||
return;
|
||||
};
|
||||
let Ok(parent_thread) = state.get_thread(parent_thread_id).await else {
|
||||
return;
|
||||
};
|
||||
parent_thread
|
||||
.inject_user_message_without_turn(format_subagent_notification_message(
|
||||
&child_thread_id.to_string(),
|
||||
&status,
|
||||
))
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
fn upgrade(&self) -> CodexResult<Arc<ThreadManagerState>> {
|
||||
self.manager
|
||||
.upgrade()
|
||||
.ok_or_else(|| CodexErr::UnsupportedOperation("thread manager dropped".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -180,16 +234,24 @@ mod tests {
|
|||
use crate::agent::agent_status_from_event;
|
||||
use crate::config::Config;
|
||||
use crate::config::ConfigBuilder;
|
||||
use crate::session_prefix::SUBAGENT_NOTIFICATION_OPEN_TAG;
|
||||
use assert_matches::assert_matches;
|
||||
use codex_protocol::config_types::ModeKind;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::ErrorEvent;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::TurnAbortedEvent;
|
||||
use codex_protocol::protocol::TurnCompleteEvent;
|
||||
use codex_protocol::protocol::TurnStartedEvent;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
use tokio::time::timeout;
|
||||
use toml::Value as TomlValue;
|
||||
|
||||
async fn test_config_with_cli_overrides(
|
||||
|
|
@ -250,6 +312,42 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
fn has_subagent_notification(history_items: &[ResponseItem]) -> bool {
|
||||
history_items.iter().any(|item| {
|
||||
let ResponseItem::Message { role, content, .. } = item else {
|
||||
return false;
|
||||
};
|
||||
if role != "user" {
|
||||
return false;
|
||||
}
|
||||
content.iter().any(|content_item| match content_item {
|
||||
ContentItem::InputText { text } | ContentItem::OutputText { text } => {
|
||||
text.contains(SUBAGENT_NOTIFICATION_OPEN_TAG)
|
||||
}
|
||||
ContentItem::InputImage { .. } => false,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
async fn wait_for_subagent_notification(parent_thread: &Arc<CodexThread>) -> bool {
|
||||
let wait = async {
|
||||
loop {
|
||||
let history_items = parent_thread
|
||||
.codex
|
||||
.session
|
||||
.clone_history()
|
||||
.await
|
||||
.raw_items()
|
||||
.to_vec();
|
||||
if has_subagent_notification(&history_items) {
|
||||
return true;
|
||||
}
|
||||
sleep(Duration::from_millis(25)).await;
|
||||
}
|
||||
};
|
||||
timeout(Duration::from_secs(2), wait).await.is_ok()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_input_errors_when_manager_dropped() {
|
||||
let control = AgentControl::default();
|
||||
|
|
@ -683,4 +781,35 @@ mod tests {
|
|||
.await
|
||||
.expect("shutdown resumed thread");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn spawn_child_completion_notifies_parent_history() {
|
||||
let harness = AgentControlHarness::new().await;
|
||||
let (parent_thread_id, parent_thread) = harness.start_thread().await;
|
||||
|
||||
let child_thread_id = harness
|
||||
.control
|
||||
.spawn_agent(
|
||||
harness.config.clone(),
|
||||
text_input("hello child"),
|
||||
Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
|
||||
parent_thread_id,
|
||||
depth: 1,
|
||||
})),
|
||||
)
|
||||
.await
|
||||
.expect("child spawn should succeed");
|
||||
|
||||
let child_thread = harness
|
||||
.manager
|
||||
.get_thread(child_thread_id)
|
||||
.await
|
||||
.expect("child thread should exist");
|
||||
let _ = child_thread
|
||||
.submit(Op::Shutdown {})
|
||||
.await
|
||||
.expect("child shutdown should submit");
|
||||
|
||||
assert_eq!(wait_for_subagent_notification(&parent_thread).await, true);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,9 @@ use crate::protocol::Event;
|
|||
use crate::protocol::Op;
|
||||
use crate::protocol::Submission;
|
||||
use codex_protocol::config_types::Personality;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::openai_models::ReasoningEffort;
|
||||
use codex_protocol::protocol::AskForApproval;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
|
|
@ -32,7 +35,7 @@ pub struct ThreadConfigSnapshot {
|
|||
}
|
||||
|
||||
pub struct CodexThread {
|
||||
codex: Codex,
|
||||
pub(crate) codex: Codex,
|
||||
rollout_path: Option<PathBuf>,
|
||||
_watch_registration: WatchRegistration,
|
||||
}
|
||||
|
|
@ -85,6 +88,33 @@ impl CodexThread {
|
|||
self.codex.session.total_token_usage().await
|
||||
}
|
||||
|
||||
/// Records a user-role session-prefix message without creating a new user turn boundary.
|
||||
pub(crate) async fn inject_user_message_without_turn(&self, message: String) {
|
||||
let pending_item = ResponseInputItem::Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText { text: message }],
|
||||
};
|
||||
let pending_items = vec![pending_item];
|
||||
let Err(items_without_active_turn) = self
|
||||
.codex
|
||||
.session
|
||||
.inject_response_items(pending_items)
|
||||
.await
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let turn_context = self.codex.session.new_default_turn().await;
|
||||
let items: Vec<ResponseItem> = items_without_active_turn
|
||||
.into_iter()
|
||||
.map(ResponseItem::from)
|
||||
.collect();
|
||||
self.codex
|
||||
.session
|
||||
.record_conversation_items(turn_context.as_ref(), &items)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub fn rollout_path(&self) -> Option<PathBuf> {
|
||||
self.rollout_path.clone()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -571,6 +571,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
|
|||
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
|
||||
),
|
||||
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
|
||||
user_input_text_msg(
|
||||
"<subagent_notification>{\"agent_id\":\"a\",\"status\":\"completed\"}</subagent_notification>",
|
||||
),
|
||||
user_input_text_msg("turn 1 user"),
|
||||
assistant_msg("turn 1 assistant"),
|
||||
user_input_text_msg("turn 2 user"),
|
||||
|
|
@ -591,6 +594,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
|
|||
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
|
||||
),
|
||||
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
|
||||
user_input_text_msg(
|
||||
"<subagent_notification>{\"agent_id\":\"a\",\"status\":\"completed\"}</subagent_notification>",
|
||||
),
|
||||
user_input_text_msg("turn 1 user"),
|
||||
assistant_msg("turn 1 assistant"),
|
||||
];
|
||||
|
|
@ -610,6 +616,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
|
|||
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
|
||||
),
|
||||
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
|
||||
user_input_text_msg(
|
||||
"<subagent_notification>{\"agent_id\":\"a\",\"status\":\"completed\"}</subagent_notification>",
|
||||
),
|
||||
];
|
||||
|
||||
let mut history = create_history_with_items(vec![
|
||||
|
|
@ -622,6 +631,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
|
|||
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
|
||||
),
|
||||
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
|
||||
user_input_text_msg(
|
||||
"<subagent_notification>{\"agent_id\":\"a\",\"status\":\"completed\"}</subagent_notification>",
|
||||
),
|
||||
user_input_text_msg("turn 1 user"),
|
||||
assistant_msg("turn 1 assistant"),
|
||||
user_input_text_msg("turn 2 user"),
|
||||
|
|
@ -640,6 +652,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
|
|||
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
|
||||
),
|
||||
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
|
||||
user_input_text_msg(
|
||||
"<subagent_notification>{\"agent_id\":\"a\",\"status\":\"completed\"}</subagent_notification>",
|
||||
),
|
||||
user_input_text_msg("turn 1 user"),
|
||||
assistant_msg("turn 1 assistant"),
|
||||
user_input_text_msg("turn 2 user"),
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
use codex_protocol::protocol::AgentStatus;
|
||||
|
||||
/// Helpers for identifying model-visible "session prefix" messages.
|
||||
///
|
||||
/// A session prefix is a user-role message that carries configuration or state needed by
|
||||
|
|
@ -6,10 +8,41 @@
|
|||
/// boundaries.
|
||||
pub(crate) const ENVIRONMENT_CONTEXT_OPEN_TAG: &str = "<environment_context>";
|
||||
pub(crate) const TURN_ABORTED_OPEN_TAG: &str = "<turn_aborted>";
|
||||
pub(crate) const SUBAGENT_NOTIFICATION_OPEN_TAG: &str = "<subagent_notification>";
|
||||
pub(crate) const SUBAGENT_NOTIFICATION_CLOSE_TAG: &str = "</subagent_notification>";
|
||||
|
||||
/// Returns true when `text` begins with `prefix`, comparing ASCII letters case-insensitively.
///
/// Non-ASCII bytes must match exactly. The comparison works on bytes and allocates nothing;
/// a `text` shorter than `prefix` never matches.
fn starts_with_ascii_case_insensitive(text: &str, prefix: &str) -> bool {
    let needle = prefix.as_bytes();
    text.as_bytes()
        .get(..needle.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(needle))
}
|
||||
|
||||
/// Returns true if `text` starts with a session prefix marker (case-insensitive).
|
||||
pub(crate) fn is_session_prefix(text: &str) -> bool {
|
||||
let trimmed = text.trim_start();
|
||||
let lowered = trimmed.to_ascii_lowercase();
|
||||
lowered.starts_with(ENVIRONMENT_CONTEXT_OPEN_TAG) || lowered.starts_with(TURN_ABORTED_OPEN_TAG)
|
||||
starts_with_ascii_case_insensitive(trimmed, ENVIRONMENT_CONTEXT_OPEN_TAG)
|
||||
|| starts_with_ascii_case_insensitive(trimmed, TURN_ABORTED_OPEN_TAG)
|
||||
|| starts_with_ascii_case_insensitive(trimmed, SUBAGENT_NOTIFICATION_OPEN_TAG)
|
||||
}
|
||||
|
||||
pub(crate) fn format_subagent_notification_message(agent_id: &str, status: &AgentStatus) -> String {
|
||||
let payload_json = serde_json::json!({
|
||||
"agent_id": agent_id,
|
||||
"status": status,
|
||||
})
|
||||
.to_string();
|
||||
format!("{SUBAGENT_NOTIFICATION_OPEN_TAG}\n{payload_json}\n{SUBAGENT_NOTIFICATION_CLOSE_TAG}")
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// An upper-cased open tag must still be recognized as a session prefix.
    #[test]
    fn is_session_prefix_is_case_insensitive() {
        let message = "<SUBAGENT_NOTIFICATION>{}</subagent_notification>";
        assert_eq!(is_session_prefix(message), true);
    }
}
|
||||
|
|
|
|||
|
|
@ -427,7 +427,7 @@ mod resume_agent {
|
|||
}
|
||||
}
|
||||
|
||||
mod wait {
|
||||
pub(crate) mod wait {
|
||||
use super::*;
|
||||
use crate::agent::status::is_final;
|
||||
use futures::FutureExt;
|
||||
|
|
@ -447,10 +447,10 @@ mod wait {
|
|||
timeout_ms: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct WaitResult {
|
||||
status: HashMap<ThreadId, AgentStatus>,
|
||||
timed_out: bool,
|
||||
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
|
||||
pub(crate) struct WaitResult {
|
||||
pub(crate) status: HashMap<ThreadId, AgentStatus>,
|
||||
pub(crate) timed_out: bool,
|
||||
}
|
||||
|
||||
pub async fn handle(
|
||||
|
|
@ -1462,12 +1462,6 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, PartialEq, Eq)]
|
||||
struct WaitResult {
|
||||
status: HashMap<ThreadId, AgentStatus>,
|
||||
timed_out: bool,
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_rejects_non_positive_timeout() {
|
||||
let (session, turn) = make_session_and_context().await;
|
||||
|
|
@ -1553,11 +1547,11 @@ mod tests {
|
|||
else {
|
||||
panic!("expected function output");
|
||||
};
|
||||
let result: WaitResult =
|
||||
let result: wait::WaitResult =
|
||||
serde_json::from_str(&content).expect("wait result should be json");
|
||||
assert_eq!(
|
||||
result,
|
||||
WaitResult {
|
||||
wait::WaitResult {
|
||||
status: HashMap::from([
|
||||
(id_a, AgentStatus::NotFound),
|
||||
(id_b, AgentStatus::NotFound),
|
||||
|
|
@ -1597,11 +1591,11 @@ mod tests {
|
|||
else {
|
||||
panic!("expected function output");
|
||||
};
|
||||
let result: WaitResult =
|
||||
let result: wait::WaitResult =
|
||||
serde_json::from_str(&content).expect("wait result should be json");
|
||||
assert_eq!(
|
||||
result,
|
||||
WaitResult {
|
||||
wait::WaitResult {
|
||||
status: HashMap::new(),
|
||||
timed_out: true
|
||||
}
|
||||
|
|
@ -1694,11 +1688,11 @@ mod tests {
|
|||
else {
|
||||
panic!("expected function output");
|
||||
};
|
||||
let result: WaitResult =
|
||||
let result: wait::WaitResult =
|
||||
serde_json::from_str(&content).expect("wait result should be json");
|
||||
assert_eq!(
|
||||
result,
|
||||
WaitResult {
|
||||
wait::WaitResult {
|
||||
status: HashMap::from([(agent_id, AgentStatus::Shutdown)]),
|
||||
timed_out: false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -646,7 +646,7 @@ fn create_wait_tool() -> ToolSpec {
|
|||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "wait".to_string(),
|
||||
description: "Wait for agents to reach a final status. Completed statuses may include the agent's final message. Returns empty status when timed out."
|
||||
description: "Wait for agents to reach a final status. Completed statuses may include the agent's final message. Returns empty status when timed out. Once the agent reaches his final status, a notification message will be received containing the same completed status."
|
||||
.to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
|
|
|
|||
|
|
@ -114,6 +114,7 @@ mod skills;
|
|||
mod sqlite_state;
|
||||
mod stream_error_allows_next_turn;
|
||||
mod stream_no_completed;
|
||||
mod subagent_notifications;
|
||||
mod text_encoding_fix;
|
||||
mod tool_harness;
|
||||
mod tool_parallelism;
|
||||
|
|
|
|||
196
codex-rs/core/tests/suite/subagent_notifications.rs
Normal file
196
codex-rs/core/tests/suite/subagent_notifications.rs
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
use anyhow::Result;
|
||||
use codex_core::features::Feature;
|
||||
use core_test_support::responses::ResponsesRequest;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_response_once_match;
|
||||
use core_test_support::responses::mount_sse_once_match;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::sse_response;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use serde_json::json;
|
||||
use std::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::sleep;
|
||||
use wiremock::MockServer;
|
||||
|
||||
const SPAWN_CALL_ID: &str = "spawn-call-1";
|
||||
const TURN_1_PROMPT: &str = "spawn a child and continue";
|
||||
const TURN_2_NO_WAIT_PROMPT: &str = "follow up without wait";
|
||||
const CHILD_PROMPT: &str = "child: do work";
|
||||
|
||||
fn body_contains(req: &wiremock::Request, text: &str) -> bool {
|
||||
let is_zstd = req
|
||||
.headers
|
||||
.get("content-encoding")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.is_some_and(|value| {
|
||||
value
|
||||
.split(',')
|
||||
.any(|entry| entry.trim().eq_ignore_ascii_case("zstd"))
|
||||
});
|
||||
let bytes = if is_zstd {
|
||||
zstd::stream::decode_all(std::io::Cursor::new(&req.body)).ok()
|
||||
} else {
|
||||
Some(req.body.clone())
|
||||
};
|
||||
bytes
|
||||
.and_then(|body| String::from_utf8(body).ok())
|
||||
.is_some_and(|body| body.contains(text))
|
||||
}
|
||||
|
||||
fn has_subagent_notification(req: &ResponsesRequest) -> bool {
|
||||
req.message_input_texts("user")
|
||||
.iter()
|
||||
.any(|text| text.contains("<subagent_notification>"))
|
||||
}
|
||||
|
||||
async fn wait_for_spawned_thread_id(test: &TestCodex) -> Result<String> {
|
||||
let deadline = Instant::now() + Duration::from_secs(2);
|
||||
loop {
|
||||
let ids = test.thread_manager.list_thread_ids().await;
|
||||
if let Some(spawned_id) = ids
|
||||
.iter()
|
||||
.find(|id| **id != test.session_configured.session_id)
|
||||
{
|
||||
return Ok(spawned_id.to_string());
|
||||
}
|
||||
if Instant::now() >= deadline {
|
||||
anyhow::bail!("timed out waiting for spawned thread id");
|
||||
}
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_for_requests(
|
||||
mock: &core_test_support::responses::ResponseMock,
|
||||
) -> Result<Vec<ResponsesRequest>> {
|
||||
let deadline = Instant::now() + Duration::from_secs(2);
|
||||
loop {
|
||||
let requests = mock.requests();
|
||||
if !requests.is_empty() {
|
||||
return Ok(requests);
|
||||
}
|
||||
if Instant::now() >= deadline {
|
||||
anyhow::bail!("expected at least 1 request, got {}", requests.len());
|
||||
}
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn setup_turn_one_with_spawned_child(
|
||||
server: &MockServer,
|
||||
child_response_delay: Option<Duration>,
|
||||
) -> Result<(TestCodex, String)> {
|
||||
let spawn_args = serde_json::to_string(&json!({
|
||||
"message": CHILD_PROMPT,
|
||||
}))?;
|
||||
|
||||
mount_sse_once_match(
|
||||
server,
|
||||
|req: &wiremock::Request| body_contains(req, TURN_1_PROMPT),
|
||||
sse(vec![
|
||||
ev_response_created("resp-turn1-1"),
|
||||
ev_function_call(SPAWN_CALL_ID, "spawn_agent", &spawn_args),
|
||||
ev_completed("resp-turn1-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let child_sse = sse(vec![
|
||||
ev_response_created("resp-child-1"),
|
||||
ev_assistant_message("msg-child-1", "child done"),
|
||||
ev_completed("resp-child-1"),
|
||||
]);
|
||||
let child_request_log = if let Some(delay) = child_response_delay {
|
||||
mount_response_once_match(
|
||||
server,
|
||||
|req: &wiremock::Request| {
|
||||
body_contains(req, CHILD_PROMPT) && !body_contains(req, SPAWN_CALL_ID)
|
||||
},
|
||||
sse_response(child_sse).set_delay(delay),
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
mount_sse_once_match(
|
||||
server,
|
||||
|req: &wiremock::Request| {
|
||||
body_contains(req, CHILD_PROMPT) && !body_contains(req, SPAWN_CALL_ID)
|
||||
},
|
||||
child_sse,
|
||||
)
|
||||
.await
|
||||
};
|
||||
|
||||
let _turn1_followup = mount_sse_once_match(
|
||||
server,
|
||||
|req: &wiremock::Request| body_contains(req, SPAWN_CALL_ID),
|
||||
sse(vec![
|
||||
ev_response_created("resp-turn1-2"),
|
||||
ev_assistant_message("msg-turn1-2", "parent done"),
|
||||
ev_completed("resp-turn1-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.features.enable(Feature::Collab);
|
||||
});
|
||||
let test = builder.build(server).await?;
|
||||
test.submit_turn(TURN_1_PROMPT).await?;
|
||||
if child_response_delay.is_none() {
|
||||
let _ = wait_for_requests(&child_request_log).await?;
|
||||
let rollout_path = test
|
||||
.codex
|
||||
.rollout_path()
|
||||
.ok_or_else(|| anyhow::anyhow!("expected parent rollout path"))?;
|
||||
let deadline = Instant::now() + Duration::from_secs(6);
|
||||
loop {
|
||||
let has_notification = tokio::fs::read_to_string(&rollout_path)
|
||||
.await
|
||||
.is_ok_and(|rollout| rollout.contains("<subagent_notification>"));
|
||||
if has_notification {
|
||||
break;
|
||||
}
|
||||
if Instant::now() >= deadline {
|
||||
anyhow::bail!(
|
||||
"timed out waiting for parent rollout to include subagent notification"
|
||||
);
|
||||
}
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
}
|
||||
let spawned_id = wait_for_spawned_thread_id(&test).await?;
|
||||
|
||||
Ok((test, spawned_id))
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn subagent_notification_is_included_without_wait() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let (test, _spawned_id) = setup_turn_one_with_spawned_child(&server, None).await?;
|
||||
|
||||
let turn2 = mount_sse_once_match(
|
||||
&server,
|
||||
|req: &wiremock::Request| body_contains(req, TURN_2_NO_WAIT_PROMPT),
|
||||
sse(vec![
|
||||
ev_response_created("resp-turn2-1"),
|
||||
ev_assistant_message("msg-turn2-1", "no wait path"),
|
||||
ev_completed("resp-turn2-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
test.submit_turn(TURN_2_NO_WAIT_PROMPT).await?;
|
||||
|
||||
let turn2_requests = wait_for_requests(&turn2).await?;
|
||||
assert!(turn2_requests.iter().any(has_subagent_notification));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue