From 73b5274443cd3ef70ee8d30d707f8fdf805b7ad2 Mon Sep 17 00:00:00 2001 From: jif-oai Date: Sun, 25 Jan 2026 15:57:22 +0100 Subject: [PATCH] feat: cap number of agents (#9855) Adding more guards to agent: * Max depth or 1 (i.e. a sub-agent can't spawn another one) * Max 12 sub-agents in total --- codex-rs/core/src/agent/guards.rs | 45 ++++++++++++ codex-rs/core/src/agent/mod.rs | 4 +- codex-rs/core/src/config/mod.rs | 10 +-- codex-rs/core/src/tools/handlers/collab.rs | 84 +++++++++++++++++++++- codex-rs/protocol/src/protocol.rs | 12 +++- 5 files changed, 143 insertions(+), 12 deletions(-) diff --git a/codex-rs/core/src/agent/guards.rs b/codex-rs/core/src/agent/guards.rs index c384ed7cd..2f146f2f8 100644 --- a/codex-rs/core/src/agent/guards.rs +++ b/codex-rs/core/src/agent/guards.rs @@ -1,6 +1,8 @@ use crate::error::CodexErr; use crate::error::Result; use codex_protocol::ThreadId; +use codex_protocol::protocol::SessionSource; +use codex_protocol::protocol::SubAgentSource; use std::collections::HashSet; use std::sync::Arc; use std::sync::Mutex; @@ -19,6 +21,25 @@ pub(crate) struct Guards { total_count: AtomicUsize, } +/// Initial agent is depth 0. +pub(crate) const MAX_THREAD_SPAWN_DEPTH: i32 = 1; + +fn session_depth(session_source: &SessionSource) -> i32 { + match session_source { + SessionSource::SubAgent(SubAgentSource::ThreadSpawn { depth, .. }) => *depth, + SessionSource::SubAgent(_) => 0, + _ => 0, + } +} + +pub(crate) fn next_thread_spawn_depth(session_source: &SessionSource) -> i32 { + session_depth(session_source).saturating_add(1) +} + +pub(crate) fn exceeds_thread_spawn_depth_limit(depth: i32) -> bool { + depth > MAX_THREAD_SPAWN_DEPTH +} + impl Guards { pub(crate) fn reserve_spawn_slot( self: &Arc, @@ -102,6 +123,30 @@ mod tests { use super::*; use pretty_assertions::assert_eq; + #[test] + fn session_depth_defaults_to_zero_for_root_sources() { + assert_eq!(session_depth(&SessionSource::Cli), 0); + } + + #[test] + fn thread_spawn_depth_increments_and_enforces_limit() { + let session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id: ThreadId::new(), + depth: 1, + }); + let child_depth = next_thread_spawn_depth(&session_source); + assert_eq!(child_depth, 2); + assert!(exceeds_thread_spawn_depth_limit(child_depth)); + } + + #[test] + fn non_thread_spawn_subagents_default_to_depth_zero() { + let session_source = SessionSource::SubAgent(SubAgentSource::Review); + assert_eq!(session_depth(&session_source), 0); + assert_eq!(next_thread_spawn_depth(&session_source), 1); + assert!(!exceeds_thread_spawn_depth_limit(1)); + } + #[test] fn reservation_drop_releases_slot() { let guards = Arc::new(Guards::default()); diff --git a/codex-rs/core/src/agent/mod.rs b/codex-rs/core/src/agent/mod.rs index 180f70dbe..03652e43e 100644 --- a/codex-rs/core/src/agent/mod.rs +++ b/codex-rs/core/src/agent/mod.rs @@ -1,10 +1,12 @@ pub(crate) mod control; -// Do not put in `pub` or `pub(crate)`. This code should not be used somewhere else. mod guards; pub(crate) mod role; pub(crate) mod status; pub(crate) use codex_protocol::protocol::AgentStatus; pub(crate) use control::AgentControl; +pub(crate) use guards::MAX_THREAD_SPAWN_DEPTH; +pub(crate) use guards::exceeds_thread_spawn_depth_limit; +pub(crate) use guards::next_thread_spawn_depth; pub(crate) use role::AgentRole; pub(crate) use status::agent_status_from_event; diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs index b3fb84c5f..02fc0121f 100644 --- a/codex-rs/core/src/config/mod.rs +++ b/codex-rs/core/src/config/mod.rs @@ -89,7 +89,7 @@ pub use codex_git::GhostSnapshotConfig; /// files are *silently truncated* to this size so we do not take up too much of /// the context window. pub(crate) const PROJECT_DOC_MAX_BYTES: usize = 32 * 1024; // 32 KiB -pub(crate) const DEFAULT_AGENT_MAX_THREADS: Option = None; +pub(crate) const DEFAULT_AGENT_MAX_THREADS: Option = Some(12); pub const CONFIG_TOML_FILE: &str = "config.toml"; @@ -3693,7 +3693,7 @@ model_verbosity = "high" project_doc_max_bytes: PROJECT_DOC_MAX_BYTES, project_doc_fallback_filenames: Vec::new(), tool_output_token_limit: None, - agent_max_threads: None, + agent_max_threads: DEFAULT_AGENT_MAX_THREADS, codex_home: fixture.codex_home(), config_layer_stack: Default::default(), history: History::default(), @@ -3775,7 +3775,7 @@ model_verbosity = "high" project_doc_max_bytes: PROJECT_DOC_MAX_BYTES, project_doc_fallback_filenames: Vec::new(), tool_output_token_limit: None, - agent_max_threads: None, + agent_max_threads: DEFAULT_AGENT_MAX_THREADS, codex_home: fixture.codex_home(), config_layer_stack: Default::default(), history: History::default(), @@ -3872,7 +3872,7 @@ model_verbosity = "high" project_doc_max_bytes: PROJECT_DOC_MAX_BYTES, project_doc_fallback_filenames: Vec::new(), tool_output_token_limit: None, - agent_max_threads: None, + agent_max_threads: DEFAULT_AGENT_MAX_THREADS, codex_home: fixture.codex_home(), config_layer_stack: Default::default(), history: History::default(), @@ -3955,7 +3955,7 @@ model_verbosity = "high" project_doc_max_bytes: PROJECT_DOC_MAX_BYTES, project_doc_fallback_filenames: Vec::new(), tool_output_token_limit: None, - agent_max_threads: None, + agent_max_threads: DEFAULT_AGENT_MAX_THREADS, codex_home: fixture.codex_home(), config_layer_stack: Default::default(), history: History::default(), diff --git a/codex-rs/core/src/tools/handlers/collab.rs b/codex-rs/core/src/tools/handlers/collab.rs index 6bdebf6a2..83ee53905 100644 --- a/codex-rs/core/src/tools/handlers/collab.rs +++ b/codex-rs/core/src/tools/handlers/collab.rs @@ -78,6 +78,9 @@ impl ToolHandler for CollabHandler { mod spawn { use super::*; use crate::agent::AgentRole; + use crate::agent::MAX_THREAD_SPAWN_DEPTH; + use crate::agent::exceeds_thread_spawn_depth_limit; + use crate::agent::next_thread_spawn_depth; use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::SubAgentSource; use std::sync::Arc; @@ -107,6 +110,13 @@ mod spawn { "Empty message can't be sent to an agent".to_string(), )); } + let session_source = turn.client.get_session_source(); + let child_depth = next_thread_spawn_depth(&session_source); + if exceeds_thread_spawn_depth_limit(child_depth) { + return Err(FunctionCallError::RespondToModel(format!( + "agent depth limit reached: max depth is {MAX_THREAD_SPAWN_DEPTH}" + ))); + } session .send_event( &turn, @@ -132,6 +142,7 @@ mod spawn { prompt.clone(), Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { parent_thread_id: session.conversation_id, + depth: child_depth, })), ) .await @@ -581,7 +592,6 @@ fn build_agent_spawn_config( config.model_reasoning_summary = turn.client.get_reasoning_summary(); config.developer_instructions = turn.developer_instructions.clone(); config.compact_prompt = turn.compact_prompt.clone(); - config.user_instructions = turn.user_instructions.clone(); config.shell_environment_policy = turn.shell_environment_policy.clone(); config.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone(); config.cwd = turn.cwd.clone(); @@ -605,13 +615,17 @@ mod tests { use super::*; use crate::CodexAuth; use crate::ThreadManager; + use crate::agent::MAX_THREAD_SPAWN_DEPTH; use crate::built_in_model_providers; + use crate::client::ModelClient; use crate::codex::make_session_and_context; use crate::config::types::ShellEnvironmentPolicy; use crate::function_tool::FunctionCallError; use crate::protocol::AskForApproval; use crate::protocol::Op; use crate::protocol::SandboxPolicy; + use crate::protocol::SessionSource; + use crate::protocol::SubAgentSource; use crate::turn_diff_tracker::TurnDiffTracker; use codex_protocol::ThreadId; use pretty_assertions::assert_eq; @@ -731,6 +745,45 @@ mod tests { ); } + #[tokio::test] + async fn spawn_agent_rejects_when_depth_limit_exceeded() { + let (mut session, mut turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + + let session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id: session.conversation_id, + depth: MAX_THREAD_SPAWN_DEPTH, + }); + turn.client = ModelClient::new( + turn.client.config(), + Some(session.services.auth_manager.clone()), + turn.client.get_model_info(), + turn.client.get_otel_manager(), + turn.client.get_provider(), + turn.client.get_reasoning_effort(), + turn.client.get_reasoning_summary(), + session.conversation_id, + session_source, + ); + + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + function_payload(json!({"message": "hello"})), + ); + let Err(err) = CollabHandler.handle(invocation).await else { + panic!("spawn should fail when depth limit exceeded"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel(format!( + "agent depth limit reached: max depth is {MAX_THREAD_SPAWN_DEPTH}" + )) + ); + } + #[tokio::test] async fn send_input_rejects_empty_message() { let (session, turn) = make_session_and_context().await; @@ -1081,7 +1134,6 @@ mod tests { }; turn.developer_instructions = Some("dev".to_string()); turn.compact_prompt = Some("compact".to_string()); - turn.user_instructions = Some("user".to_string()); turn.shell_environment_policy = ShellEnvironmentPolicy { use_profile: true, ..ShellEnvironmentPolicy::default() @@ -1101,7 +1153,6 @@ mod tests { expected.model_reasoning_summary = turn.client.get_reasoning_summary(); expected.developer_instructions = turn.developer_instructions.clone(); expected.compact_prompt = turn.compact_prompt.clone(); - expected.user_instructions = turn.user_instructions.clone(); expected.shell_environment_policy = turn.shell_environment_policy.clone(); expected.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone(); expected.cwd = turn.cwd.clone(); @@ -1115,4 +1166,31 @@ mod tests { .expect("sandbox policy set"); assert_eq!(config, expected); } + + #[tokio::test] + async fn build_agent_spawn_config_preserves_base_user_instructions() { + let (session, mut turn) = make_session_and_context().await; + let session_source = turn.client.get_session_source(); + let mut base_config = (*turn.client.config()).clone(); + base_config.user_instructions = Some("base-user".to_string()); + turn.user_instructions = Some("resolved-user".to_string()); + turn.client = ModelClient::new( + Arc::new(base_config.clone()), + Some(session.services.auth_manager.clone()), + turn.client.get_model_info(), + turn.client.get_otel_manager(), + turn.client.get_provider(), + turn.client.get_reasoning_effort(), + turn.client.get_reasoning_summary(), + session.conversation_id, + session_source, + ); + let base_instructions = BaseInstructions { + text: "base".to_string(), + }; + + let config = build_agent_spawn_config(&base_instructions, &turn).expect("spawn config"); + + assert_eq!(config.user_instructions, base_config.user_instructions); + } } diff --git a/codex-rs/protocol/src/protocol.rs b/codex-rs/protocol/src/protocol.rs index 7b1c39d06..7bf3c90ea 100644 --- a/codex-rs/protocol/src/protocol.rs +++ b/codex-rs/protocol/src/protocol.rs @@ -1518,7 +1518,10 @@ pub enum SessionSource { pub enum SubAgentSource { Review, Compact, - ThreadSpawn { parent_thread_id: ThreadId }, + ThreadSpawn { + parent_thread_id: ThreadId, + depth: i32, + }, Other(String), } @@ -1540,8 +1543,11 @@ impl fmt::Display for SubAgentSource { match self { SubAgentSource::Review => f.write_str("review"), SubAgentSource::Compact => f.write_str("compact"), - SubAgentSource::ThreadSpawn { parent_thread_id } => { - write!(f, "thread_spawn_{parent_thread_id}") + SubAgentSource::ThreadSpawn { + parent_thread_id, + depth, + } => { + write!(f, "thread_spawn_{parent_thread_id}_d{depth}") } SubAgentSource::Other(other) => f.write_str(other), }