diff --git a/codex-rs/core/src/unified_exec/mod.rs b/codex-rs/core/src/unified_exec/mod.rs index 390401d78..1c9194baa 100644 --- a/codex-rs/core/src/unified_exec/mod.rs +++ b/codex-rs/core/src/unified_exec/mod.rs @@ -46,6 +46,7 @@ pub(crate) const MAX_YIELD_TIME_MS: u64 = 30_000; pub(crate) const DEFAULT_MAX_OUTPUT_TOKENS: usize = 10_000; pub(crate) const UNIFIED_EXEC_OUTPUT_MAX_BYTES: usize = 1024 * 1024; // 1 MiB pub(crate) const UNIFIED_EXEC_OUTPUT_MAX_TOKENS: usize = UNIFIED_EXEC_OUTPUT_MAX_BYTES / 4; +pub(crate) const MAX_UNIFIED_EXEC_SESSIONS: usize = 64; pub(crate) struct UnifiedExecContext { pub session: Arc, @@ -108,6 +109,7 @@ struct SessionEntry { command: Vec, cwd: PathBuf, started_at: tokio::time::Instant, + last_used: tokio::time::Instant, } pub(crate) fn clamp_yield_time(yield_time_ms: u64) -> u64 { diff --git a/codex-rs/core/src/unified_exec/session_manager.rs b/codex-rs/core/src/unified_exec/session_manager.rs index d9f99b9ea..a5cc17477 100644 --- a/codex-rs/core/src/unified_exec/session_manager.rs +++ b/codex-rs/core/src/unified_exec/session_manager.rs @@ -1,3 +1,6 @@ +use std::cmp::Reverse; +use std::collections::HashMap; +use std::collections::HashSet; use std::path::PathBuf; use std::sync::Arc; @@ -31,6 +34,7 @@ use crate::truncate::approx_token_count; use crate::truncate::formatted_truncate_text; use super::ExecCommandRequest; +use super::MAX_UNIFIED_EXEC_SESSIONS; use super::SessionEntry; use super::UnifiedExecContext; use super::UnifiedExecError; @@ -294,10 +298,11 @@ impl UnifiedExecSessionManager { &self, session_id: i32, ) -> Result { - let sessions = self.sessions.lock().await; + let mut sessions = self.sessions.lock().await; let entry = sessions - .get(&session_id) + .get_mut(&session_id) .ok_or(UnifiedExecError::UnknownSessionId { session_id })?; + entry.last_used = Instant::now(); let OutputHandles { output_buffer, output_notify, @@ -345,8 +350,11 @@ impl UnifiedExecSessionManager { command: command.to_vec(), cwd, started_at, + last_used: started_at, }; - self.sessions.lock().await.insert(session_id, entry); + let mut sessions = self.sessions.lock().await; + Self::prune_sessions_if_needed(&mut sessions); + sessions.insert(session_id, entry); session_id } @@ -548,6 +556,50 @@ impl UnifiedExecSessionManager { collected } + + fn prune_sessions_if_needed(sessions: &mut HashMap) { + if sessions.len() < MAX_UNIFIED_EXEC_SESSIONS { + return; + } + + let meta: Vec<(i32, Instant, bool)> = sessions + .iter() + .map(|(id, entry)| (*id, entry.last_used, entry.session.has_exited())) + .collect(); + + if let Some(session_id) = Self::session_id_to_prune_from_meta(&meta) { + sessions.remove(&session_id); + } + } + + // Centralized pruning policy so we can easily swap strategies later. + fn session_id_to_prune_from_meta(meta: &[(i32, Instant, bool)]) -> Option { + if meta.is_empty() { + return None; + } + + let mut by_recency = meta.to_vec(); + by_recency.sort_by_key(|(_, last_used, _)| Reverse(*last_used)); + let protected: HashSet = by_recency + .iter() + .take(8) + .map(|(session_id, _, _)| *session_id) + .collect(); + + let mut lru = meta.to_vec(); + lru.sort_by_key(|(_, last_used, _)| *last_used); + + if let Some((session_id, _, _)) = lru + .iter() + .find(|(session_id, _, exited)| !protected.contains(session_id) && *exited) + { + return Some(*session_id); + } + + lru.into_iter() + .find(|(session_id, _, _)| !protected.contains(session_id)) + .map(|(session_id, _, _)| session_id) + } } enum SessionStatus { @@ -561,3 +613,75 @@ enum SessionStatus { }, Unknown, } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use tokio::time::Duration; + use tokio::time::Instant; + + #[test] + fn pruning_prefers_exited_sessions_outside_recently_used() { + let now = Instant::now(); + let meta = vec![ + (1, now - Duration::from_secs(40), false), + (2, now - Duration::from_secs(30), true), + (3, now - Duration::from_secs(20), false), + (4, now - Duration::from_secs(19), false), + (5, now - Duration::from_secs(18), false), + (6, now - Duration::from_secs(17), false), + (7, now - Duration::from_secs(16), false), + (8, now - Duration::from_secs(15), false), + (9, now - Duration::from_secs(14), false), + (10, now - Duration::from_secs(13), false), + ]; + + let candidate = UnifiedExecSessionManager::session_id_to_prune_from_meta(&meta); + + assert_eq!(candidate, Some(2)); + } + + #[test] + fn pruning_falls_back_to_lru_when_no_exited() { + let now = Instant::now(); + let meta = vec![ + (1, now - Duration::from_secs(40), false), + (2, now - Duration::from_secs(30), false), + (3, now - Duration::from_secs(20), false), + (4, now - Duration::from_secs(19), false), + (5, now - Duration::from_secs(18), false), + (6, now - Duration::from_secs(17), false), + (7, now - Duration::from_secs(16), false), + (8, now - Duration::from_secs(15), false), + (9, now - Duration::from_secs(14), false), + (10, now - Duration::from_secs(13), false), + ]; + + let candidate = UnifiedExecSessionManager::session_id_to_prune_from_meta(&meta); + + assert_eq!(candidate, Some(1)); + } + + #[test] + fn pruning_protects_recent_sessions_even_if_exited() { + let now = Instant::now(); + let meta = vec![ + (1, now - Duration::from_secs(40), false), + (2, now - Duration::from_secs(30), false), + (3, now - Duration::from_secs(20), true), + (4, now - Duration::from_secs(19), false), + (5, now - Duration::from_secs(18), false), + (6, now - Duration::from_secs(17), false), + (7, now - Duration::from_secs(16), false), + (8, now - Duration::from_secs(15), false), + (9, now - Duration::from_secs(14), false), + (10, now - Duration::from_secs(13), true), + ]; + + let candidate = UnifiedExecSessionManager::session_id_to_prune_from_meta(&meta); + + // (10) is exited but among the last 8; we should drop the LRU outside that set. + assert_eq!(candidate, Some(1)); + } +} diff --git a/codex-rs/core/tests/suite/unified_exec.rs b/codex-rs/core/tests/suite/unified_exec.rs index aed4cecef..3019e6e1e 100644 --- a/codex-rs/core/tests/suite/unified_exec.rs +++ b/codex-rs/core/tests/suite/unified_exec.rs @@ -1760,3 +1760,160 @@ async fn unified_exec_runs_under_sandbox() -> Result<()> { Ok(()) } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn unified_exec_prunes_exited_sessions_first() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_sandbox!(Ok(())); + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.use_experimental_unified_exec_tool = true; + config.features.enable(Feature::UnifiedExec); + }); + let TestCodex { + codex, + cwd, + session_configured, + .. + } = builder.build(&server).await?; + + const MAX_SESSIONS_FOR_TEST: i32 = 64; + const FILLER_SESSIONS: i32 = MAX_SESSIONS_FOR_TEST - 1; + + let keep_call_id = "uexec-prune-keep"; + let keep_args = serde_json::json!({ + "cmd": "/bin/cat", + "yield_time_ms": 250, + }); + + let prune_call_id = "uexec-prune-target"; + let prune_args = serde_json::json!({ + "cmd": "sleep 1", + "yield_time_ms": 250, + }); + + let mut events = vec![ev_response_created("resp-prune-1")]; + events.push(ev_function_call( + keep_call_id, + "exec_command", + &serde_json::to_string(&keep_args)?, + )); + events.push(ev_function_call( + prune_call_id, + "exec_command", + &serde_json::to_string(&prune_args)?, + )); + + for idx in 0..FILLER_SESSIONS { + let filler_args = serde_json::json!({ + "cmd": format!("echo filler {idx}"), + "yield_time_ms": 250, + }); + let call_id = format!("uexec-prune-fill-{idx}"); + events.push(ev_function_call( + &call_id, + "exec_command", + &serde_json::to_string(&filler_args)?, + )); + } + + let keep_write_call_id = "uexec-prune-keep-write"; + let keep_write_args = serde_json::json!({ + "chars": "still alive\n", + "session_id": 0, + "yield_time_ms": 500, + }); + events.push(ev_function_call( + keep_write_call_id, + "write_stdin", + &serde_json::to_string(&keep_write_args)?, + )); + + let probe_call_id = "uexec-prune-probe"; + let probe_args = serde_json::json!({ + "chars": "should fail\n", + "session_id": 1, + "yield_time_ms": 500, + }); + events.push(ev_function_call( + probe_call_id, + "write_stdin", + &serde_json::to_string(&probe_args)?, + )); + + events.push(ev_completed("resp-prune-1")); + let first_response = sse(events); + let completion_response = sse(vec![ + ev_response_created("resp-prune-2"), + ev_assistant_message("msg-prune", "done"), + ev_completed("resp-prune-2"), + ]); + let response_mock = + mount_sse_sequence(&server, vec![first_response, completion_response]).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "fill session cache".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await; + + let requests = response_mock.requests(); + assert!( + !requests.is_empty(), + "expected at least one response request" + ); + + let keep_start = requests + .iter() + .find_map(|req| req.function_call_output_text(keep_call_id)) + .expect("missing initial keep session output"); + let keep_start_output = parse_unified_exec_output(&keep_start)?; + pretty_assertions::assert_eq!(keep_start_output.session_id, Some(0)); + assert!(keep_start_output.exit_code.is_none()); + + let prune_start = requests + .iter() + .find_map(|req| req.function_call_output_text(prune_call_id)) + .expect("missing initial prune session output"); + let prune_start_output = parse_unified_exec_output(&prune_start)?; + pretty_assertions::assert_eq!(prune_start_output.session_id, Some(1)); + assert!(prune_start_output.exit_code.is_none()); + + let keep_write = requests + .iter() + .find_map(|req| req.function_call_output_text(keep_write_call_id)) + .expect("missing keep write output"); + let keep_write_output = parse_unified_exec_output(&keep_write)?; + pretty_assertions::assert_eq!(keep_write_output.session_id, Some(0)); + assert!( + keep_write_output.output.contains("still alive"), + "expected cat session to echo input, got {:?}", + keep_write_output.output + ); + + let pruned_probe = requests + .iter() + .find_map(|req| req.function_call_output_text(probe_call_id)) + .expect("missing probe output"); + assert!( + pruned_probe.contains("UnknownSessionId") || pruned_probe.contains("Unknown session id"), + "expected probe to fail after pruning, got {pruned_probe:?}" + ); + + Ok(()) +}