feat: cap number of agents (#9855)
Adding more guards to agent: * Max depth or 1 (i.e. a sub-agent can't spawn another one) * Max 12 sub-agents in total
This commit is contained in:
parent
a748600c42
commit
73b5274443
5 changed files with 143 additions and 12 deletions
|
|
@ -1,6 +1,8 @@
|
|||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
|
@ -19,6 +21,25 @@ pub(crate) struct Guards {
|
|||
total_count: AtomicUsize,
|
||||
}
|
||||
|
||||
/// Initial agent is depth 0.
|
||||
pub(crate) const MAX_THREAD_SPAWN_DEPTH: i32 = 1;
|
||||
|
||||
fn session_depth(session_source: &SessionSource) -> i32 {
|
||||
match session_source {
|
||||
SessionSource::SubAgent(SubAgentSource::ThreadSpawn { depth, .. }) => *depth,
|
||||
SessionSource::SubAgent(_) => 0,
|
||||
_ => 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn next_thread_spawn_depth(session_source: &SessionSource) -> i32 {
|
||||
session_depth(session_source).saturating_add(1)
|
||||
}
|
||||
|
||||
pub(crate) fn exceeds_thread_spawn_depth_limit(depth: i32) -> bool {
|
||||
depth > MAX_THREAD_SPAWN_DEPTH
|
||||
}
|
||||
|
||||
impl Guards {
|
||||
pub(crate) fn reserve_spawn_slot(
|
||||
self: &Arc<Self>,
|
||||
|
|
@ -102,6 +123,30 @@ mod tests {
|
|||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn session_depth_defaults_to_zero_for_root_sources() {
|
||||
assert_eq!(session_depth(&SessionSource::Cli), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn thread_spawn_depth_increments_and_enforces_limit() {
|
||||
let session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
|
||||
parent_thread_id: ThreadId::new(),
|
||||
depth: 1,
|
||||
});
|
||||
let child_depth = next_thread_spawn_depth(&session_source);
|
||||
assert_eq!(child_depth, 2);
|
||||
assert!(exceeds_thread_spawn_depth_limit(child_depth));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_thread_spawn_subagents_default_to_depth_zero() {
|
||||
let session_source = SessionSource::SubAgent(SubAgentSource::Review);
|
||||
assert_eq!(session_depth(&session_source), 0);
|
||||
assert_eq!(next_thread_spawn_depth(&session_source), 1);
|
||||
assert!(!exceeds_thread_spawn_depth_limit(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reservation_drop_releases_slot() {
|
||||
let guards = Arc::new(Guards::default());
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
pub(crate) mod control;
|
||||
// Do not put in `pub` or `pub(crate)`. This code should not be used somewhere else.
|
||||
mod guards;
|
||||
pub(crate) mod role;
|
||||
pub(crate) mod status;
|
||||
|
||||
pub(crate) use codex_protocol::protocol::AgentStatus;
|
||||
pub(crate) use control::AgentControl;
|
||||
pub(crate) use guards::MAX_THREAD_SPAWN_DEPTH;
|
||||
pub(crate) use guards::exceeds_thread_spawn_depth_limit;
|
||||
pub(crate) use guards::next_thread_spawn_depth;
|
||||
pub(crate) use role::AgentRole;
|
||||
pub(crate) use status::agent_status_from_event;
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ pub use codex_git::GhostSnapshotConfig;
|
|||
/// files are *silently truncated* to this size so we do not take up too much of
|
||||
/// the context window.
|
||||
pub(crate) const PROJECT_DOC_MAX_BYTES: usize = 32 * 1024; // 32 KiB
|
||||
pub(crate) const DEFAULT_AGENT_MAX_THREADS: Option<usize> = None;
|
||||
pub(crate) const DEFAULT_AGENT_MAX_THREADS: Option<usize> = Some(12);
|
||||
|
||||
pub const CONFIG_TOML_FILE: &str = "config.toml";
|
||||
|
||||
|
|
@ -3693,7 +3693,7 @@ model_verbosity = "high"
|
|||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
tool_output_token_limit: None,
|
||||
agent_max_threads: None,
|
||||
agent_max_threads: DEFAULT_AGENT_MAX_THREADS,
|
||||
codex_home: fixture.codex_home(),
|
||||
config_layer_stack: Default::default(),
|
||||
history: History::default(),
|
||||
|
|
@ -3775,7 +3775,7 @@ model_verbosity = "high"
|
|||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
tool_output_token_limit: None,
|
||||
agent_max_threads: None,
|
||||
agent_max_threads: DEFAULT_AGENT_MAX_THREADS,
|
||||
codex_home: fixture.codex_home(),
|
||||
config_layer_stack: Default::default(),
|
||||
history: History::default(),
|
||||
|
|
@ -3872,7 +3872,7 @@ model_verbosity = "high"
|
|||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
tool_output_token_limit: None,
|
||||
agent_max_threads: None,
|
||||
agent_max_threads: DEFAULT_AGENT_MAX_THREADS,
|
||||
codex_home: fixture.codex_home(),
|
||||
config_layer_stack: Default::default(),
|
||||
history: History::default(),
|
||||
|
|
@ -3955,7 +3955,7 @@ model_verbosity = "high"
|
|||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
tool_output_token_limit: None,
|
||||
agent_max_threads: None,
|
||||
agent_max_threads: DEFAULT_AGENT_MAX_THREADS,
|
||||
codex_home: fixture.codex_home(),
|
||||
config_layer_stack: Default::default(),
|
||||
history: History::default(),
|
||||
|
|
|
|||
|
|
@ -78,6 +78,9 @@ impl ToolHandler for CollabHandler {
|
|||
mod spawn {
|
||||
use super::*;
|
||||
use crate::agent::AgentRole;
|
||||
use crate::agent::MAX_THREAD_SPAWN_DEPTH;
|
||||
use crate::agent::exceeds_thread_spawn_depth_limit;
|
||||
use crate::agent::next_thread_spawn_depth;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use std::sync::Arc;
|
||||
|
|
@ -107,6 +110,13 @@ mod spawn {
|
|||
"Empty message can't be sent to an agent".to_string(),
|
||||
));
|
||||
}
|
||||
let session_source = turn.client.get_session_source();
|
||||
let child_depth = next_thread_spawn_depth(&session_source);
|
||||
if exceeds_thread_spawn_depth_limit(child_depth) {
|
||||
return Err(FunctionCallError::RespondToModel(format!(
|
||||
"agent depth limit reached: max depth is {MAX_THREAD_SPAWN_DEPTH}"
|
||||
)));
|
||||
}
|
||||
session
|
||||
.send_event(
|
||||
&turn,
|
||||
|
|
@ -132,6 +142,7 @@ mod spawn {
|
|||
prompt.clone(),
|
||||
Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
|
||||
parent_thread_id: session.conversation_id,
|
||||
depth: child_depth,
|
||||
})),
|
||||
)
|
||||
.await
|
||||
|
|
@ -581,7 +592,6 @@ fn build_agent_spawn_config(
|
|||
config.model_reasoning_summary = turn.client.get_reasoning_summary();
|
||||
config.developer_instructions = turn.developer_instructions.clone();
|
||||
config.compact_prompt = turn.compact_prompt.clone();
|
||||
config.user_instructions = turn.user_instructions.clone();
|
||||
config.shell_environment_policy = turn.shell_environment_policy.clone();
|
||||
config.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone();
|
||||
config.cwd = turn.cwd.clone();
|
||||
|
|
@ -605,13 +615,17 @@ mod tests {
|
|||
use super::*;
|
||||
use crate::CodexAuth;
|
||||
use crate::ThreadManager;
|
||||
use crate::agent::MAX_THREAD_SPAWN_DEPTH;
|
||||
use crate::built_in_model_providers;
|
||||
use crate::client::ModelClient;
|
||||
use crate::codex::make_session_and_context;
|
||||
use crate::config::types::ShellEnvironmentPolicy;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::Op;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::protocol::SessionSource;
|
||||
use crate::protocol::SubAgentSource;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use codex_protocol::ThreadId;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
|
@ -731,6 +745,45 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn spawn_agent_rejects_when_depth_limit_exceeded() {
|
||||
let (mut session, mut turn) = make_session_and_context().await;
|
||||
let manager = thread_manager();
|
||||
session.services.agent_control = manager.agent_control();
|
||||
|
||||
let session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
|
||||
parent_thread_id: session.conversation_id,
|
||||
depth: MAX_THREAD_SPAWN_DEPTH,
|
||||
});
|
||||
turn.client = ModelClient::new(
|
||||
turn.client.config(),
|
||||
Some(session.services.auth_manager.clone()),
|
||||
turn.client.get_model_info(),
|
||||
turn.client.get_otel_manager(),
|
||||
turn.client.get_provider(),
|
||||
turn.client.get_reasoning_effort(),
|
||||
turn.client.get_reasoning_summary(),
|
||||
session.conversation_id,
|
||||
session_source,
|
||||
);
|
||||
|
||||
let invocation = invocation(
|
||||
Arc::new(session),
|
||||
Arc::new(turn),
|
||||
"spawn_agent",
|
||||
function_payload(json!({"message": "hello"})),
|
||||
);
|
||||
let Err(err) = CollabHandler.handle(invocation).await else {
|
||||
panic!("spawn should fail when depth limit exceeded");
|
||||
};
|
||||
assert_eq!(
|
||||
err,
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"agent depth limit reached: max depth is {MAX_THREAD_SPAWN_DEPTH}"
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_input_rejects_empty_message() {
|
||||
let (session, turn) = make_session_and_context().await;
|
||||
|
|
@ -1081,7 +1134,6 @@ mod tests {
|
|||
};
|
||||
turn.developer_instructions = Some("dev".to_string());
|
||||
turn.compact_prompt = Some("compact".to_string());
|
||||
turn.user_instructions = Some("user".to_string());
|
||||
turn.shell_environment_policy = ShellEnvironmentPolicy {
|
||||
use_profile: true,
|
||||
..ShellEnvironmentPolicy::default()
|
||||
|
|
@ -1101,7 +1153,6 @@ mod tests {
|
|||
expected.model_reasoning_summary = turn.client.get_reasoning_summary();
|
||||
expected.developer_instructions = turn.developer_instructions.clone();
|
||||
expected.compact_prompt = turn.compact_prompt.clone();
|
||||
expected.user_instructions = turn.user_instructions.clone();
|
||||
expected.shell_environment_policy = turn.shell_environment_policy.clone();
|
||||
expected.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone();
|
||||
expected.cwd = turn.cwd.clone();
|
||||
|
|
@ -1115,4 +1166,31 @@ mod tests {
|
|||
.expect("sandbox policy set");
|
||||
assert_eq!(config, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn build_agent_spawn_config_preserves_base_user_instructions() {
|
||||
let (session, mut turn) = make_session_and_context().await;
|
||||
let session_source = turn.client.get_session_source();
|
||||
let mut base_config = (*turn.client.config()).clone();
|
||||
base_config.user_instructions = Some("base-user".to_string());
|
||||
turn.user_instructions = Some("resolved-user".to_string());
|
||||
turn.client = ModelClient::new(
|
||||
Arc::new(base_config.clone()),
|
||||
Some(session.services.auth_manager.clone()),
|
||||
turn.client.get_model_info(),
|
||||
turn.client.get_otel_manager(),
|
||||
turn.client.get_provider(),
|
||||
turn.client.get_reasoning_effort(),
|
||||
turn.client.get_reasoning_summary(),
|
||||
session.conversation_id,
|
||||
session_source,
|
||||
);
|
||||
let base_instructions = BaseInstructions {
|
||||
text: "base".to_string(),
|
||||
};
|
||||
|
||||
let config = build_agent_spawn_config(&base_instructions, &turn).expect("spawn config");
|
||||
|
||||
assert_eq!(config.user_instructions, base_config.user_instructions);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1518,7 +1518,10 @@ pub enum SessionSource {
|
|||
pub enum SubAgentSource {
|
||||
Review,
|
||||
Compact,
|
||||
ThreadSpawn { parent_thread_id: ThreadId },
|
||||
ThreadSpawn {
|
||||
parent_thread_id: ThreadId,
|
||||
depth: i32,
|
||||
},
|
||||
Other(String),
|
||||
}
|
||||
|
||||
|
|
@ -1540,8 +1543,11 @@ impl fmt::Display for SubAgentSource {
|
|||
match self {
|
||||
SubAgentSource::Review => f.write_str("review"),
|
||||
SubAgentSource::Compact => f.write_str("compact"),
|
||||
SubAgentSource::ThreadSpawn { parent_thread_id } => {
|
||||
write!(f, "thread_spawn_{parent_thread_id}")
|
||||
SubAgentSource::ThreadSpawn {
|
||||
parent_thread_id,
|
||||
depth,
|
||||
} => {
|
||||
write!(f, "thread_spawn_{parent_thread_id}_d{depth}")
|
||||
}
|
||||
SubAgentSource::Other(other) => f.write_str(other),
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue