From b5dd1890671d58dd92c29dc4ffa6432ecb48d54c Mon Sep 17 00:00:00 2001 From: hanson-openai Date: Thu, 20 Nov 2025 04:34:41 -0800 Subject: [PATCH] Allow unified_exec to early exit (if the process terminates before yield_time_ms) (#6867) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Thread through an `exit_notify` tokio `Notify` through to the `UnifiedExecSession` so that we can return early if the command terminates before `yield_time_ms`. As Codex review correctly pointed out below 🙌 we also need a `exit_signaled` flag so that commands which finish before we start waiting can also exit early. Since the default `yield_time_ms` is now 10s, this means that we don't have to wait 10s for trivial commands like ls, sed, etc (which are the majority of agent commands 😅) --------- Co-authored-by: jif-oai --- codex-rs/core/src/unified_exec/session.rs | 42 +++++-- .../core/src/unified_exec/session_manager.rs | 116 +++++++++++------- codex-rs/core/tests/suite/unified_exec.rs | 92 ++++++++++++++ 3 files changed, 198 insertions(+), 52 deletions(-) diff --git a/codex-rs/core/src/unified_exec/session.rs b/codex-rs/core/src/unified_exec/session.rs index b37a9cdb5..710334c80 100644 --- a/codex-rs/core/src/unified_exec/session.rs +++ b/codex-rs/core/src/unified_exec/session.rs @@ -2,13 +2,13 @@ use std::collections::VecDeque; use std::sync::Arc; - use tokio::sync::Mutex; use tokio::sync::Notify; use tokio::sync::mpsc; use tokio::sync::oneshot::error::TryRecvError; use tokio::task::JoinHandle; use tokio::time::Duration; +use tokio_util::sync::CancellationToken; use crate::exec::ExecToolCallOutput; use crate::exec::SandboxType; @@ -67,13 +67,18 @@ impl OutputBufferState { } pub(crate) type OutputBuffer = Arc>; -pub(crate) type OutputHandles = (OutputBuffer, Arc); +pub(crate) struct OutputHandles { + pub(crate) output_buffer: OutputBuffer, + pub(crate) output_notify: Arc, + pub(crate) cancellation_token: CancellationToken, +} #[derive(Debug)] pub(crate) struct UnifiedExecSession { session: ExecCommandSession, output_buffer: OutputBuffer, output_notify: Arc, + cancellation_token: CancellationToken, output_task: JoinHandle<()>, sandbox_type: SandboxType, } @@ -86,9 +91,11 @@ impl UnifiedExecSession { ) -> Self { let output_buffer = Arc::new(Mutex::new(OutputBufferState::default())); let output_notify = Arc::new(Notify::new()); + let cancellation_token = CancellationToken::new(); let mut receiver = initial_output_rx; let buffer_clone = Arc::clone(&output_buffer); let notify_clone = Arc::clone(&output_notify); + let cancellation_token_clone = cancellation_token.clone(); let output_task = tokio::spawn(async move { loop { match receiver.recv().await { @@ -99,7 +106,10 @@ impl UnifiedExecSession { notify_clone.notify_waiters(); } Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, - Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + cancellation_token_clone.cancel(); + break; + } } } }); @@ -108,6 +118,7 @@ impl UnifiedExecSession { session, output_buffer, output_notify, + cancellation_token, output_task, sandbox_type, } @@ -118,10 +129,11 @@ impl UnifiedExecSession { } pub(super) fn output_handles(&self) -> OutputHandles { - ( - Arc::clone(&self.output_buffer), - Arc::clone(&self.output_notify), - ) + OutputHandles { + output_buffer: Arc::clone(&self.output_buffer), + output_notify: Arc::clone(&self.output_notify), + cancellation_token: self.cancellation_token.clone(), + } } pub(super) fn has_exited(&self) -> bool { @@ -199,20 +211,34 @@ impl UnifiedExecSession { }; if exit_ready { + managed.signal_exit(); managed.check_for_sandbox_denial().await?; return Ok(managed); } - tokio::pin!(exit_rx); if tokio::time::timeout(Duration::from_millis(50), &mut exit_rx) .await .is_ok() { + managed.signal_exit(); managed.check_for_sandbox_denial().await?; + return Ok(managed); } + tokio::spawn({ + let cancellation_token = managed.cancellation_token.clone(); + async move { + let _ = exit_rx.await; + cancellation_token.cancel(); + } + }); + Ok(managed) } + + fn signal_exit(&self) { + self.cancellation_token.cancel(); + } } impl Drop for UnifiedExecSession { diff --git a/codex-rs/core/src/unified_exec/session_manager.rs b/codex-rs/core/src/unified_exec/session_manager.rs index 93340bb2d..d9f99b9ea 100644 --- a/codex-rs/core/src/unified_exec/session_manager.rs +++ b/codex-rs/core/src/unified_exec/session_manager.rs @@ -5,6 +5,7 @@ use tokio::sync::Notify; use tokio::sync::mpsc; use tokio::time::Duration; use tokio::time::Instant; +use tokio_util::sync::CancellationToken; use crate::codex::Session; use crate::codex::TurnContext; @@ -40,8 +41,20 @@ use super::clamp_yield_time; use super::generate_chunk_id; use super::resolve_max_tokens; use super::session::OutputBuffer; +use super::session::OutputHandles; use super::session::UnifiedExecSession; +struct PreparedSessionHandles { + writer_tx: mpsc::Sender>, + output_buffer: OutputBuffer, + output_notify: Arc, + cancellation_token: CancellationToken, + session_ref: Arc, + turn_ref: Arc, + command: Vec, + cwd: PathBuf, +} + impl UnifiedExecSessionManager { pub(crate) async fn exec_command( &self, @@ -67,10 +80,19 @@ impl UnifiedExecSessionManager { let yield_time_ms = clamp_yield_time(request.yield_time_ms); let start = Instant::now(); - let (output_buffer, output_notify) = session.output_handles(); + let OutputHandles { + output_buffer, + output_notify, + cancellation_token, + } = session.output_handles(); let deadline = start + Duration::from_millis(yield_time_ms); - let collected = - Self::collect_output_until_deadline(&output_buffer, &output_notify, deadline).await; + let collected = Self::collect_output_until_deadline( + &output_buffer, + &output_notify, + &cancellation_token, + deadline, + ) + .await; let wall_time = Instant::now().saturating_duration_since(start); let text = String::from_utf8_lossy(&collected).to_string(); @@ -129,15 +151,16 @@ impl UnifiedExecSessionManager { ) -> Result { let session_id = request.session_id; - let ( + let PreparedSessionHandles { writer_tx, output_buffer, output_notify, + cancellation_token, session_ref, turn_ref, - session_command, - session_cwd, - ) = self.prepare_session_handles(session_id).await?; + command: session_command, + cwd: session_cwd, + } = self.prepare_session_handles(session_id).await?; let interaction_emitter = ToolEmitter::unified_exec( &session_command, @@ -176,8 +199,13 @@ impl UnifiedExecSessionManager { let yield_time_ms = clamp_yield_time(request.yield_time_ms); let start = Instant::now(); let deadline = start + Duration::from_millis(yield_time_ms); - let collected = - Self::collect_output_until_deadline(&output_buffer, &output_notify, deadline).await; + let collected = Self::collect_output_until_deadline( + &output_buffer, + &output_notify, + &cancellation_token, + deadline, + ) + .await; let wall_time = Instant::now().saturating_duration_since(start); let text = String::from_utf8_lossy(&collected).to_string(); @@ -265,44 +293,27 @@ impl UnifiedExecSessionManager { async fn prepare_session_handles( &self, session_id: i32, - ) -> Result< - ( - mpsc::Sender>, - OutputBuffer, - Arc, - Arc, - Arc, - Vec, - PathBuf, - ), - UnifiedExecError, - > { + ) -> Result { let sessions = self.sessions.lock().await; - let (output_buffer, output_notify, writer_tx, session, turn, command, cwd) = - if let Some(entry) = sessions.get(&session_id) { - let (buffer, notify) = entry.session.output_handles(); - ( - buffer, - notify, - entry.session.writer_sender(), - Arc::clone(&entry.session_ref), - Arc::clone(&entry.turn_ref), - entry.command.clone(), - entry.cwd.clone(), - ) - } else { - return Err(UnifiedExecError::UnknownSessionId { session_id }); - }; - - Ok(( - writer_tx, + let entry = sessions + .get(&session_id) + .ok_or(UnifiedExecError::UnknownSessionId { session_id })?; + let OutputHandles { output_buffer, output_notify, - session, - turn, - command, - cwd, - )) + cancellation_token, + } = entry.session.output_handles(); + + Ok(PreparedSessionHandles { + writer_tx: entry.session.writer_sender(), + output_buffer, + output_notify, + cancellation_token, + session_ref: Arc::clone(&entry.session_ref), + turn_ref: Arc::clone(&entry.turn_ref), + command: entry.command.clone(), + cwd: entry.cwd.clone(), + }) } async fn send_input( @@ -480,9 +491,13 @@ impl UnifiedExecSessionManager { pub(super) async fn collect_output_until_deadline( output_buffer: &OutputBuffer, output_notify: &Arc, + cancellation_token: &CancellationToken, deadline: Instant, ) -> Vec { + const POST_EXIT_OUTPUT_GRACE: Duration = Duration::from_millis(25); + let mut collected: Vec = Vec::with_capacity(4096); + let mut exit_signal_received = cancellation_token.is_cancelled(); loop { let drained_chunks; let mut wait_for_output = None; @@ -495,15 +510,27 @@ impl UnifiedExecSessionManager { } if drained_chunks.is_empty() { + exit_signal_received |= cancellation_token.is_cancelled(); let remaining = deadline.saturating_duration_since(Instant::now()); if remaining == Duration::ZERO { break; } let notified = wait_for_output.unwrap_or_else(|| output_notify.notified()); + if exit_signal_received { + let grace = remaining.min(POST_EXIT_OUTPUT_GRACE); + if tokio::time::timeout(grace, notified).await.is_err() { + break; + } + continue; + } + tokio::pin!(notified); + let exit_notified = cancellation_token.cancelled(); + tokio::pin!(exit_notified); tokio::select! { _ = &mut notified => {} + _ = &mut exit_notified => exit_signal_received = true, _ = tokio::time::sleep(remaining) => break, } continue; @@ -513,6 +540,7 @@ impl UnifiedExecSessionManager { collected.extend_from_slice(&chunk); } + exit_signal_received |= cancellation_token.is_cancelled(); if Instant::now() >= deadline { break; } diff --git a/codex-rs/core/tests/suite/unified_exec.rs b/codex-rs/core/tests/suite/unified_exec.rs index 07a0b21a7..d3059aba4 100644 --- a/codex-rs/core/tests/suite/unified_exec.rs +++ b/codex-rs/core/tests/suite/unified_exec.rs @@ -904,6 +904,98 @@ async fn exec_command_reports_chunk_and_exit_metadata() -> Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn unified_exec_respects_early_exit_notifications() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_sandbox!(Ok(())); + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.features.enable(Feature::UnifiedExec); + }); + let TestCodex { + codex, + cwd, + session_configured, + .. + } = builder.build(&server).await?; + + let call_id = "uexec-early-exit"; + let args = serde_json::json!({ + "cmd": "sleep 0.05", + "yield_time_ms": 31415, + }); + + let responses = vec![ + sse(vec![ + ev_response_created("resp-1"), + ev_function_call(call_id, "exec_command", &serde_json::to_string(&args)?), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]), + ]; + mount_sse_sequence(&server, responses).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "watch early exit timing".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await; + + let requests = server.received_requests().await.expect("recorded requests"); + assert!(!requests.is_empty(), "expected at least one POST request"); + + let bodies = requests + .iter() + .map(|req| req.body_json::().expect("request json")) + .collect::>(); + + let outputs = collect_tool_outputs(&bodies)?; + let output = outputs + .get(call_id) + .expect("missing early exit unified_exec output"); + + assert!( + output.session_id.is_none(), + "short-lived process should not keep a session alive" + ); + assert_eq!( + output.exit_code, + Some(0), + "short-lived process should exit successfully" + ); + + let wall_time = output.wall_time_seconds; + assert!( + wall_time < 0.75, + "wall_time should reflect early exit rather than the full yield time; got {wall_time}" + ); + assert!( + output.output.is_empty(), + "sleep command should not emit output, got {:?}", + output.output + ); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn write_stdin_returns_exit_metadata_and_clears_session() -> Result<()> { skip_if_no_network!(Ok(()));