feat: sub-agent injection (#12152)

This PR adds parent-thread sub-agent completion notifications and change
the prompt of the model to prevent if from being confused
This commit is contained in:
jif-oai 2026-02-19 11:32:10 +00:00 committed by GitHub
parent f298c48cc6
commit 2daa3fd44f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 420 additions and 22 deletions

View file

@ -1,11 +1,14 @@
use crate::agent::AgentStatus;
use crate::agent::guards::Guards;
use crate::agent::status::is_final;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::session_prefix::format_subagent_notification_message;
use crate::thread_manager::ThreadManagerState;
use codex_protocol::ThreadId;
use codex_protocol::protocol::Op;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_protocol::protocol::TokenUsage;
use codex_protocol::user_input::UserInput;
use std::path::PathBuf;
@ -46,6 +49,7 @@ impl AgentControl {
) -> CodexResult<ThreadId> {
let state = self.upgrade()?;
let reservation = self.state.reserve_spawn_slot(config.agent_max_threads)?;
let notification_source = session_source.clone();
// The same `AgentControl` is sent to spawn the thread.
let new_thread = match session_source {
@ -64,6 +68,7 @@ impl AgentControl {
state.notify_thread_created(new_thread.thread_id);
self.send_input(new_thread.thread_id, items).await?;
self.maybe_start_completion_watcher(new_thread.thread_id, notification_source);
Ok(new_thread.thread_id)
}
@ -77,6 +82,7 @@ impl AgentControl {
) -> CodexResult<ThreadId> {
let state = self.upgrade()?;
let reservation = self.state.reserve_spawn_slot(config.agent_max_threads)?;
let notification_source = session_source.clone();
let resumed_thread = state
.resume_thread_from_rollout_with_source(
@ -90,6 +96,7 @@ impl AgentControl {
// Resumed threads are re-registered in-memory and need the same listener
// attachment path as freshly spawned threads.
state.notify_thread_created(resumed_thread.thread_id);
self.maybe_start_completion_watcher(resumed_thread.thread_id, Some(notification_source));
Ok(resumed_thread.thread_id)
}
@ -164,13 +171,60 @@ impl AgentControl {
thread.total_token_usage().await
}
/// Starts a detached watcher for sub-agents spawned from another thread.
///
/// This is only enabled for `SubAgentSource::ThreadSpawn`, where a parent thread exists and
/// can receive completion notifications.
fn maybe_start_completion_watcher(
&self,
child_thread_id: ThreadId,
session_source: Option<SessionSource>,
) {
let Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
parent_thread_id, ..
})) = session_source
else {
return;
};
let control = self.clone();
tokio::spawn(async move {
let mut status_rx = match control.subscribe_status(child_thread_id).await {
Ok(rx) => rx,
Err(_) => return,
};
let mut status = status_rx.borrow().clone();
while !is_final(&status) {
if status_rx.changed().await.is_err() {
status = control.get_status(child_thread_id).await;
break;
}
status = status_rx.borrow().clone();
}
if !is_final(&status) {
return;
}
let Ok(state) = control.upgrade() else {
return;
};
let Ok(parent_thread) = state.get_thread(parent_thread_id).await else {
return;
};
parent_thread
.inject_user_message_without_turn(format_subagent_notification_message(
&child_thread_id.to_string(),
&status,
))
.await;
});
}
fn upgrade(&self) -> CodexResult<Arc<ThreadManagerState>> {
self.manager
.upgrade()
.ok_or_else(|| CodexErr::UnsupportedOperation("thread manager dropped".to_string()))
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -180,16 +234,24 @@ mod tests {
use crate::agent::agent_status_from_event;
use crate::config::Config;
use crate::config::ConfigBuilder;
use crate::session_prefix::SUBAGENT_NOTIFICATION_OPEN_TAG;
use assert_matches::assert_matches;
use codex_protocol::config_types::ModeKind;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::ErrorEvent;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_protocol::protocol::TurnAbortReason;
use codex_protocol::protocol::TurnAbortedEvent;
use codex_protocol::protocol::TurnCompleteEvent;
use codex_protocol::protocol::TurnStartedEvent;
use pretty_assertions::assert_eq;
use tempfile::TempDir;
use tokio::time::Duration;
use tokio::time::sleep;
use tokio::time::timeout;
use toml::Value as TomlValue;
async fn test_config_with_cli_overrides(
@ -250,6 +312,42 @@ mod tests {
}
}
fn has_subagent_notification(history_items: &[ResponseItem]) -> bool {
history_items.iter().any(|item| {
let ResponseItem::Message { role, content, .. } = item else {
return false;
};
if role != "user" {
return false;
}
content.iter().any(|content_item| match content_item {
ContentItem::InputText { text } | ContentItem::OutputText { text } => {
text.contains(SUBAGENT_NOTIFICATION_OPEN_TAG)
}
ContentItem::InputImage { .. } => false,
})
})
}
async fn wait_for_subagent_notification(parent_thread: &Arc<CodexThread>) -> bool {
let wait = async {
loop {
let history_items = parent_thread
.codex
.session
.clone_history()
.await
.raw_items()
.to_vec();
if has_subagent_notification(&history_items) {
return true;
}
sleep(Duration::from_millis(25)).await;
}
};
timeout(Duration::from_secs(2), wait).await.is_ok()
}
#[tokio::test]
async fn send_input_errors_when_manager_dropped() {
let control = AgentControl::default();
@ -683,4 +781,35 @@ mod tests {
.await
.expect("shutdown resumed thread");
}
#[tokio::test]
async fn spawn_child_completion_notifies_parent_history() {
let harness = AgentControlHarness::new().await;
let (parent_thread_id, parent_thread) = harness.start_thread().await;
let child_thread_id = harness
.control
.spawn_agent(
harness.config.clone(),
text_input("hello child"),
Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
parent_thread_id,
depth: 1,
})),
)
.await
.expect("child spawn should succeed");
let child_thread = harness
.manager
.get_thread(child_thread_id)
.await
.expect("child thread should exist");
let _ = child_thread
.submit(Op::Shutdown {})
.await
.expect("child shutdown should submit");
assert_eq!(wait_for_subagent_notification(&parent_thread).await, true);
}
}

View file

@ -8,6 +8,9 @@ use crate::protocol::Event;
use crate::protocol::Op;
use crate::protocol::Submission;
use codex_protocol::config_types::Personality;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::openai_models::ReasoningEffort;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::SandboxPolicy;
@ -32,7 +35,7 @@ pub struct ThreadConfigSnapshot {
}
pub struct CodexThread {
codex: Codex,
pub(crate) codex: Codex,
rollout_path: Option<PathBuf>,
_watch_registration: WatchRegistration,
}
@ -85,6 +88,33 @@ impl CodexThread {
self.codex.session.total_token_usage().await
}
/// Records a user-role session-prefix message without creating a new user turn boundary.
pub(crate) async fn inject_user_message_without_turn(&self, message: String) {
let pending_item = ResponseInputItem::Message {
role: "user".to_string(),
content: vec![ContentItem::InputText { text: message }],
};
let pending_items = vec![pending_item];
let Err(items_without_active_turn) = self
.codex
.session
.inject_response_items(pending_items)
.await
else {
return;
};
let turn_context = self.codex.session.new_default_turn().await;
let items: Vec<ResponseItem> = items_without_active_turn
.into_iter()
.map(ResponseItem::from)
.collect();
self.codex
.session
.record_conversation_items(turn_context.as_ref(), &items)
.await;
}
pub fn rollout_path(&self) -> Option<PathBuf> {
self.rollout_path.clone()
}

View file

@ -571,6 +571,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
),
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
user_input_text_msg(
"<subagent_notification>{\"agent_id\":\"a\",\"status\":\"completed\"}</subagent_notification>",
),
user_input_text_msg("turn 1 user"),
assistant_msg("turn 1 assistant"),
user_input_text_msg("turn 2 user"),
@ -591,6 +594,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
),
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
user_input_text_msg(
"<subagent_notification>{\"agent_id\":\"a\",\"status\":\"completed\"}</subagent_notification>",
),
user_input_text_msg("turn 1 user"),
assistant_msg("turn 1 assistant"),
];
@ -610,6 +616,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
),
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
user_input_text_msg(
"<subagent_notification>{\"agent_id\":\"a\",\"status\":\"completed\"}</subagent_notification>",
),
];
let mut history = create_history_with_items(vec![
@ -622,6 +631,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
),
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
user_input_text_msg(
"<subagent_notification>{\"agent_id\":\"a\",\"status\":\"completed\"}</subagent_notification>",
),
user_input_text_msg("turn 1 user"),
assistant_msg("turn 1 assistant"),
user_input_text_msg("turn 2 user"),
@ -640,6 +652,9 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
),
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
user_input_text_msg(
"<subagent_notification>{\"agent_id\":\"a\",\"status\":\"completed\"}</subagent_notification>",
),
user_input_text_msg("turn 1 user"),
assistant_msg("turn 1 assistant"),
user_input_text_msg("turn 2 user"),

View file

@ -1,3 +1,5 @@
use codex_protocol::protocol::AgentStatus;
/// Helpers for identifying model-visible "session prefix" messages.
///
/// A session prefix is a user-role message that carries configuration or state needed by
@ -6,10 +8,41 @@
/// boundaries.
pub(crate) const ENVIRONMENT_CONTEXT_OPEN_TAG: &str = "<environment_context>";
pub(crate) const TURN_ABORTED_OPEN_TAG: &str = "<turn_aborted>";
pub(crate) const SUBAGENT_NOTIFICATION_OPEN_TAG: &str = "<subagent_notification>";
pub(crate) const SUBAGENT_NOTIFICATION_CLOSE_TAG: &str = "</subagent_notification>";
fn starts_with_ascii_case_insensitive(text: &str, prefix: &str) -> bool {
text.get(..prefix.len())
.is_some_and(|candidate| candidate.eq_ignore_ascii_case(prefix))
}
/// Returns true if `text` starts with a session prefix marker (case-insensitive).
pub(crate) fn is_session_prefix(text: &str) -> bool {
let trimmed = text.trim_start();
let lowered = trimmed.to_ascii_lowercase();
lowered.starts_with(ENVIRONMENT_CONTEXT_OPEN_TAG) || lowered.starts_with(TURN_ABORTED_OPEN_TAG)
starts_with_ascii_case_insensitive(trimmed, ENVIRONMENT_CONTEXT_OPEN_TAG)
|| starts_with_ascii_case_insensitive(trimmed, TURN_ABORTED_OPEN_TAG)
|| starts_with_ascii_case_insensitive(trimmed, SUBAGENT_NOTIFICATION_OPEN_TAG)
}
pub(crate) fn format_subagent_notification_message(agent_id: &str, status: &AgentStatus) -> String {
let payload_json = serde_json::json!({
"agent_id": agent_id,
"status": status,
})
.to_string();
format!("{SUBAGENT_NOTIFICATION_OPEN_TAG}\n{payload_json}\n{SUBAGENT_NOTIFICATION_CLOSE_TAG}")
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn is_session_prefix_is_case_insensitive() {
assert_eq!(
is_session_prefix("<SUBAGENT_NOTIFICATION>{}</subagent_notification>"),
true
);
}
}

View file

@ -427,7 +427,7 @@ mod resume_agent {
}
}
mod wait {
pub(crate) mod wait {
use super::*;
use crate::agent::status::is_final;
use futures::FutureExt;
@ -447,10 +447,10 @@ mod wait {
timeout_ms: Option<i64>,
}
#[derive(Debug, Serialize)]
struct WaitResult {
status: HashMap<ThreadId, AgentStatus>,
timed_out: bool,
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
pub(crate) struct WaitResult {
pub(crate) status: HashMap<ThreadId, AgentStatus>,
pub(crate) timed_out: bool,
}
pub async fn handle(
@ -1462,12 +1462,6 @@ mod tests {
);
}
#[derive(Debug, Deserialize, PartialEq, Eq)]
struct WaitResult {
status: HashMap<ThreadId, AgentStatus>,
timed_out: bool,
}
#[tokio::test]
async fn wait_rejects_non_positive_timeout() {
let (session, turn) = make_session_and_context().await;
@ -1553,11 +1547,11 @@ mod tests {
else {
panic!("expected function output");
};
let result: WaitResult =
let result: wait::WaitResult =
serde_json::from_str(&content).expect("wait result should be json");
assert_eq!(
result,
WaitResult {
wait::WaitResult {
status: HashMap::from([
(id_a, AgentStatus::NotFound),
(id_b, AgentStatus::NotFound),
@ -1597,11 +1591,11 @@ mod tests {
else {
panic!("expected function output");
};
let result: WaitResult =
let result: wait::WaitResult =
serde_json::from_str(&content).expect("wait result should be json");
assert_eq!(
result,
WaitResult {
wait::WaitResult {
status: HashMap::new(),
timed_out: true
}
@ -1694,11 +1688,11 @@ mod tests {
else {
panic!("expected function output");
};
let result: WaitResult =
let result: wait::WaitResult =
serde_json::from_str(&content).expect("wait result should be json");
assert_eq!(
result,
WaitResult {
wait::WaitResult {
status: HashMap::from([(agent_id, AgentStatus::Shutdown)]),
timed_out: false
}

View file

@ -646,7 +646,7 @@ fn create_wait_tool() -> ToolSpec {
ToolSpec::Function(ResponsesApiTool {
name: "wait".to_string(),
description: "Wait for agents to reach a final status. Completed statuses may include the agent's final message. Returns empty status when timed out."
description: "Wait for agents to reach a final status. Completed statuses may include the agent's final message. Returns empty status when timed out. Once the agent reaches his final status, a notification message will be received containing the same completed status."
.to_string(),
strict: false,
parameters: JsonSchema::Object {

View file

@ -114,6 +114,7 @@ mod skills;
mod sqlite_state;
mod stream_error_allows_next_turn;
mod stream_no_completed;
mod subagent_notifications;
mod text_encoding_fix;
mod tool_harness;
mod tool_parallelism;

View file

@ -0,0 +1,196 @@
use anyhow::Result;
use codex_core::features::Feature;
use core_test_support::responses::ResponsesRequest;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::mount_response_once_match;
use core_test_support::responses::mount_sse_once_match;
use core_test_support::responses::sse;
use core_test_support::responses::sse_response;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use serde_json::json;
use std::time::Duration;
use tokio::time::Instant;
use tokio::time::sleep;
use wiremock::MockServer;
const SPAWN_CALL_ID: &str = "spawn-call-1";
const TURN_1_PROMPT: &str = "spawn a child and continue";
const TURN_2_NO_WAIT_PROMPT: &str = "follow up without wait";
const CHILD_PROMPT: &str = "child: do work";
fn body_contains(req: &wiremock::Request, text: &str) -> bool {
let is_zstd = req
.headers
.get("content-encoding")
.and_then(|value| value.to_str().ok())
.is_some_and(|value| {
value
.split(',')
.any(|entry| entry.trim().eq_ignore_ascii_case("zstd"))
});
let bytes = if is_zstd {
zstd::stream::decode_all(std::io::Cursor::new(&req.body)).ok()
} else {
Some(req.body.clone())
};
bytes
.and_then(|body| String::from_utf8(body).ok())
.is_some_and(|body| body.contains(text))
}
fn has_subagent_notification(req: &ResponsesRequest) -> bool {
req.message_input_texts("user")
.iter()
.any(|text| text.contains("<subagent_notification>"))
}
async fn wait_for_spawned_thread_id(test: &TestCodex) -> Result<String> {
let deadline = Instant::now() + Duration::from_secs(2);
loop {
let ids = test.thread_manager.list_thread_ids().await;
if let Some(spawned_id) = ids
.iter()
.find(|id| **id != test.session_configured.session_id)
{
return Ok(spawned_id.to_string());
}
if Instant::now() >= deadline {
anyhow::bail!("timed out waiting for spawned thread id");
}
sleep(Duration::from_millis(10)).await;
}
}
async fn wait_for_requests(
mock: &core_test_support::responses::ResponseMock,
) -> Result<Vec<ResponsesRequest>> {
let deadline = Instant::now() + Duration::from_secs(2);
loop {
let requests = mock.requests();
if !requests.is_empty() {
return Ok(requests);
}
if Instant::now() >= deadline {
anyhow::bail!("expected at least 1 request, got {}", requests.len());
}
sleep(Duration::from_millis(10)).await;
}
}
async fn setup_turn_one_with_spawned_child(
server: &MockServer,
child_response_delay: Option<Duration>,
) -> Result<(TestCodex, String)> {
let spawn_args = serde_json::to_string(&json!({
"message": CHILD_PROMPT,
}))?;
mount_sse_once_match(
server,
|req: &wiremock::Request| body_contains(req, TURN_1_PROMPT),
sse(vec![
ev_response_created("resp-turn1-1"),
ev_function_call(SPAWN_CALL_ID, "spawn_agent", &spawn_args),
ev_completed("resp-turn1-1"),
]),
)
.await;
let child_sse = sse(vec![
ev_response_created("resp-child-1"),
ev_assistant_message("msg-child-1", "child done"),
ev_completed("resp-child-1"),
]);
let child_request_log = if let Some(delay) = child_response_delay {
mount_response_once_match(
server,
|req: &wiremock::Request| {
body_contains(req, CHILD_PROMPT) && !body_contains(req, SPAWN_CALL_ID)
},
sse_response(child_sse).set_delay(delay),
)
.await
} else {
mount_sse_once_match(
server,
|req: &wiremock::Request| {
body_contains(req, CHILD_PROMPT) && !body_contains(req, SPAWN_CALL_ID)
},
child_sse,
)
.await
};
let _turn1_followup = mount_sse_once_match(
server,
|req: &wiremock::Request| body_contains(req, SPAWN_CALL_ID),
sse(vec![
ev_response_created("resp-turn1-2"),
ev_assistant_message("msg-turn1-2", "parent done"),
ev_completed("resp-turn1-2"),
]),
)
.await;
let mut builder = test_codex().with_config(|config| {
config.features.enable(Feature::Collab);
});
let test = builder.build(server).await?;
test.submit_turn(TURN_1_PROMPT).await?;
if child_response_delay.is_none() {
let _ = wait_for_requests(&child_request_log).await?;
let rollout_path = test
.codex
.rollout_path()
.ok_or_else(|| anyhow::anyhow!("expected parent rollout path"))?;
let deadline = Instant::now() + Duration::from_secs(6);
loop {
let has_notification = tokio::fs::read_to_string(&rollout_path)
.await
.is_ok_and(|rollout| rollout.contains("<subagent_notification>"));
if has_notification {
break;
}
if Instant::now() >= deadline {
anyhow::bail!(
"timed out waiting for parent rollout to include subagent notification"
);
}
sleep(Duration::from_millis(10)).await;
}
}
let spawned_id = wait_for_spawned_thread_id(&test).await?;
Ok((test, spawned_id))
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn subagent_notification_is_included_without_wait() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let (test, _spawned_id) = setup_turn_one_with_spawned_child(&server, None).await?;
let turn2 = mount_sse_once_match(
&server,
|req: &wiremock::Request| body_contains(req, TURN_2_NO_WAIT_PROMPT),
sse(vec![
ev_response_created("resp-turn2-1"),
ev_assistant_message("msg-turn2-1", "no wait path"),
ev_completed("resp-turn2-1"),
]),
)
.await;
test.submit_turn(TURN_2_NO_WAIT_PROMPT).await?;
let turn2_requests = wait_for_requests(&turn2).await?;
assert!(turn2_requests.iter().any(has_subagent_notification));
Ok(())
}