feat: cap number of agents (#9855)

Adding more guards to agent:
* Max depth or 1 (i.e. a sub-agent can't spawn another one)
* Max 12 sub-agents in total
This commit is contained in:
jif-oai 2026-01-25 15:57:22 +01:00 committed by GitHub
parent a748600c42
commit 73b5274443
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 143 additions and 12 deletions

View file

@ -1,6 +1,8 @@
use crate::error::CodexErr;
use crate::error::Result;
use codex_protocol::ThreadId;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::Mutex;
@ -19,6 +21,25 @@ pub(crate) struct Guards {
total_count: AtomicUsize,
}
/// Initial agent is depth 0.
pub(crate) const MAX_THREAD_SPAWN_DEPTH: i32 = 1;
fn session_depth(session_source: &SessionSource) -> i32 {
match session_source {
SessionSource::SubAgent(SubAgentSource::ThreadSpawn { depth, .. }) => *depth,
SessionSource::SubAgent(_) => 0,
_ => 0,
}
}
pub(crate) fn next_thread_spawn_depth(session_source: &SessionSource) -> i32 {
session_depth(session_source).saturating_add(1)
}
pub(crate) fn exceeds_thread_spawn_depth_limit(depth: i32) -> bool {
depth > MAX_THREAD_SPAWN_DEPTH
}
impl Guards {
pub(crate) fn reserve_spawn_slot(
self: &Arc<Self>,
@ -102,6 +123,30 @@ mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn session_depth_defaults_to_zero_for_root_sources() {
assert_eq!(session_depth(&SessionSource::Cli), 0);
}
#[test]
fn thread_spawn_depth_increments_and_enforces_limit() {
let session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
parent_thread_id: ThreadId::new(),
depth: 1,
});
let child_depth = next_thread_spawn_depth(&session_source);
assert_eq!(child_depth, 2);
assert!(exceeds_thread_spawn_depth_limit(child_depth));
}
#[test]
fn non_thread_spawn_subagents_default_to_depth_zero() {
let session_source = SessionSource::SubAgent(SubAgentSource::Review);
assert_eq!(session_depth(&session_source), 0);
assert_eq!(next_thread_spawn_depth(&session_source), 1);
assert!(!exceeds_thread_spawn_depth_limit(1));
}
#[test]
fn reservation_drop_releases_slot() {
let guards = Arc::new(Guards::default());

View file

@ -1,10 +1,12 @@
pub(crate) mod control;
// Do not put in `pub` or `pub(crate)`. This code should not be used somewhere else.
mod guards;
pub(crate) mod role;
pub(crate) mod status;
pub(crate) use codex_protocol::protocol::AgentStatus;
pub(crate) use control::AgentControl;
pub(crate) use guards::MAX_THREAD_SPAWN_DEPTH;
pub(crate) use guards::exceeds_thread_spawn_depth_limit;
pub(crate) use guards::next_thread_spawn_depth;
pub(crate) use role::AgentRole;
pub(crate) use status::agent_status_from_event;

View file

@ -89,7 +89,7 @@ pub use codex_git::GhostSnapshotConfig;
/// files are *silently truncated* to this size so we do not take up too much of
/// the context window.
pub(crate) const PROJECT_DOC_MAX_BYTES: usize = 32 * 1024; // 32 KiB
pub(crate) const DEFAULT_AGENT_MAX_THREADS: Option<usize> = None;
pub(crate) const DEFAULT_AGENT_MAX_THREADS: Option<usize> = Some(12);
pub const CONFIG_TOML_FILE: &str = "config.toml";
@ -3693,7 +3693,7 @@ model_verbosity = "high"
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
project_doc_fallback_filenames: Vec::new(),
tool_output_token_limit: None,
agent_max_threads: None,
agent_max_threads: DEFAULT_AGENT_MAX_THREADS,
codex_home: fixture.codex_home(),
config_layer_stack: Default::default(),
history: History::default(),
@ -3775,7 +3775,7 @@ model_verbosity = "high"
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
project_doc_fallback_filenames: Vec::new(),
tool_output_token_limit: None,
agent_max_threads: None,
agent_max_threads: DEFAULT_AGENT_MAX_THREADS,
codex_home: fixture.codex_home(),
config_layer_stack: Default::default(),
history: History::default(),
@ -3872,7 +3872,7 @@ model_verbosity = "high"
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
project_doc_fallback_filenames: Vec::new(),
tool_output_token_limit: None,
agent_max_threads: None,
agent_max_threads: DEFAULT_AGENT_MAX_THREADS,
codex_home: fixture.codex_home(),
config_layer_stack: Default::default(),
history: History::default(),
@ -3955,7 +3955,7 @@ model_verbosity = "high"
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
project_doc_fallback_filenames: Vec::new(),
tool_output_token_limit: None,
agent_max_threads: None,
agent_max_threads: DEFAULT_AGENT_MAX_THREADS,
codex_home: fixture.codex_home(),
config_layer_stack: Default::default(),
history: History::default(),

View file

@ -78,6 +78,9 @@ impl ToolHandler for CollabHandler {
mod spawn {
use super::*;
use crate::agent::AgentRole;
use crate::agent::MAX_THREAD_SPAWN_DEPTH;
use crate::agent::exceeds_thread_spawn_depth_limit;
use crate::agent::next_thread_spawn_depth;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use std::sync::Arc;
@ -107,6 +110,13 @@ mod spawn {
"Empty message can't be sent to an agent".to_string(),
));
}
let session_source = turn.client.get_session_source();
let child_depth = next_thread_spawn_depth(&session_source);
if exceeds_thread_spawn_depth_limit(child_depth) {
return Err(FunctionCallError::RespondToModel(format!(
"agent depth limit reached: max depth is {MAX_THREAD_SPAWN_DEPTH}"
)));
}
session
.send_event(
&turn,
@ -132,6 +142,7 @@ mod spawn {
prompt.clone(),
Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
parent_thread_id: session.conversation_id,
depth: child_depth,
})),
)
.await
@ -581,7 +592,6 @@ fn build_agent_spawn_config(
config.model_reasoning_summary = turn.client.get_reasoning_summary();
config.developer_instructions = turn.developer_instructions.clone();
config.compact_prompt = turn.compact_prompt.clone();
config.user_instructions = turn.user_instructions.clone();
config.shell_environment_policy = turn.shell_environment_policy.clone();
config.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone();
config.cwd = turn.cwd.clone();
@ -605,13 +615,17 @@ mod tests {
use super::*;
use crate::CodexAuth;
use crate::ThreadManager;
use crate::agent::MAX_THREAD_SPAWN_DEPTH;
use crate::built_in_model_providers;
use crate::client::ModelClient;
use crate::codex::make_session_and_context;
use crate::config::types::ShellEnvironmentPolicy;
use crate::function_tool::FunctionCallError;
use crate::protocol::AskForApproval;
use crate::protocol::Op;
use crate::protocol::SandboxPolicy;
use crate::protocol::SessionSource;
use crate::protocol::SubAgentSource;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_protocol::ThreadId;
use pretty_assertions::assert_eq;
@ -731,6 +745,45 @@ mod tests {
);
}
#[tokio::test]
async fn spawn_agent_rejects_when_depth_limit_exceeded() {
let (mut session, mut turn) = make_session_and_context().await;
let manager = thread_manager();
session.services.agent_control = manager.agent_control();
let session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
parent_thread_id: session.conversation_id,
depth: MAX_THREAD_SPAWN_DEPTH,
});
turn.client = ModelClient::new(
turn.client.config(),
Some(session.services.auth_manager.clone()),
turn.client.get_model_info(),
turn.client.get_otel_manager(),
turn.client.get_provider(),
turn.client.get_reasoning_effort(),
turn.client.get_reasoning_summary(),
session.conversation_id,
session_source,
);
let invocation = invocation(
Arc::new(session),
Arc::new(turn),
"spawn_agent",
function_payload(json!({"message": "hello"})),
);
let Err(err) = CollabHandler.handle(invocation).await else {
panic!("spawn should fail when depth limit exceeded");
};
assert_eq!(
err,
FunctionCallError::RespondToModel(format!(
"agent depth limit reached: max depth is {MAX_THREAD_SPAWN_DEPTH}"
))
);
}
#[tokio::test]
async fn send_input_rejects_empty_message() {
let (session, turn) = make_session_and_context().await;
@ -1081,7 +1134,6 @@ mod tests {
};
turn.developer_instructions = Some("dev".to_string());
turn.compact_prompt = Some("compact".to_string());
turn.user_instructions = Some("user".to_string());
turn.shell_environment_policy = ShellEnvironmentPolicy {
use_profile: true,
..ShellEnvironmentPolicy::default()
@ -1101,7 +1153,6 @@ mod tests {
expected.model_reasoning_summary = turn.client.get_reasoning_summary();
expected.developer_instructions = turn.developer_instructions.clone();
expected.compact_prompt = turn.compact_prompt.clone();
expected.user_instructions = turn.user_instructions.clone();
expected.shell_environment_policy = turn.shell_environment_policy.clone();
expected.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone();
expected.cwd = turn.cwd.clone();
@ -1115,4 +1166,31 @@ mod tests {
.expect("sandbox policy set");
assert_eq!(config, expected);
}
#[tokio::test]
async fn build_agent_spawn_config_preserves_base_user_instructions() {
let (session, mut turn) = make_session_and_context().await;
let session_source = turn.client.get_session_source();
let mut base_config = (*turn.client.config()).clone();
base_config.user_instructions = Some("base-user".to_string());
turn.user_instructions = Some("resolved-user".to_string());
turn.client = ModelClient::new(
Arc::new(base_config.clone()),
Some(session.services.auth_manager.clone()),
turn.client.get_model_info(),
turn.client.get_otel_manager(),
turn.client.get_provider(),
turn.client.get_reasoning_effort(),
turn.client.get_reasoning_summary(),
session.conversation_id,
session_source,
);
let base_instructions = BaseInstructions {
text: "base".to_string(),
};
let config = build_agent_spawn_config(&base_instructions, &turn).expect("spawn config");
assert_eq!(config.user_instructions, base_config.user_instructions);
}
}

View file

@ -1518,7 +1518,10 @@ pub enum SessionSource {
pub enum SubAgentSource {
Review,
Compact,
ThreadSpawn { parent_thread_id: ThreadId },
ThreadSpawn {
parent_thread_id: ThreadId,
depth: i32,
},
Other(String),
}
@ -1540,8 +1543,11 @@ impl fmt::Display for SubAgentSource {
match self {
SubAgentSource::Review => f.write_str("review"),
SubAgentSource::Compact => f.write_str("compact"),
SubAgentSource::ThreadSpawn { parent_thread_id } => {
write!(f, "thread_spawn_{parent_thread_id}")
SubAgentSource::ThreadSpawn {
parent_thread_id,
depth,
} => {
write!(f, "thread_spawn_{parent_thread_id}_d{depth}")
}
SubAgentSource::Other(other) => f.write_str(other),
}