diff --git a/codex-rs/core/src/shell_snapshot.rs b/codex-rs/core/src/shell_snapshot.rs index 4df54997b..2c4c423f5 100644 --- a/codex-rs/core/src/shell_snapshot.rs +++ b/codex-rs/core/src/shell_snapshot.rs @@ -367,6 +367,10 @@ mod tests { #[tokio::test] async fn timed_out_snapshot_shell_is_terminated() -> Result<()> { use std::process::Stdio; + use tokio::time::Duration as TokioDuration; + use tokio::time::Instant; + use tokio::time::sleep; + let dir = tempdir()?; let shell_path = dir.path().join("hanging-shell.sh"); let pid_path = dir.path().join("pid"); @@ -402,16 +406,22 @@ mod tests { .trim() .parse::()?; - let kill_status = StdCommand::new("kill") - .arg("-0") - .arg(pid.to_string()) - .stderr(Stdio::null()) - .stdout(Stdio::null()) - .status()?; - assert!( - !kill_status.success(), - "timed out snapshot shell should be terminated" - ); + let deadline = Instant::now() + TokioDuration::from_secs(1); + loop { + let kill_status = StdCommand::new("kill") + .arg("-0") + .arg(pid.to_string()) + .stderr(Stdio::null()) + .stdout(Stdio::null()) + .status()?; + if !kill_status.success() { + break; + } + if Instant::now() >= deadline { + panic!("timed out snapshot shell is still alive after grace period"); + } + sleep(TokioDuration::from_millis(50)).await; + } Ok(()) } diff --git a/codex-rs/core/src/unified_exec/async_watcher.rs b/codex-rs/core/src/unified_exec/async_watcher.rs index 7412d2972..19d91dbcc 100644 --- a/codex-rs/core/src/unified_exec/async_watcher.rs +++ b/codex-rs/core/src/unified_exec/async_watcher.rs @@ -1,9 +1,11 @@ use std::path::PathBuf; +use std::pin::Pin; use std::sync::Arc; use tokio::sync::Mutex; use tokio::time::Duration; use tokio::time::Instant; +use tokio::time::Sleep; use crate::codex::Session; use crate::codex::TurnContext; @@ -21,6 +23,8 @@ use super::CommandTranscript; use super::UnifiedExecContext; use super::session::UnifiedExecSession; +pub(crate) const TRAILING_OUTPUT_GRACE: Duration = Duration::from_millis(100); + /// Spawn a 
background task that continuously reads from the PTY, appends to the /// shared transcript, and emits ExecCommandOutputDelta events on UTF‑8 /// boundaries. @@ -30,39 +34,58 @@ pub(crate) fn start_streaming_output( transcript: Arc>, ) { let mut receiver = session.output_receiver(); + let output_drained = session.output_drained_notify(); + let exit_token = session.cancellation_token(); + let session_ref = Arc::clone(&context.session); let turn_ref = Arc::clone(&context.turn); let call_id = context.call_id.clone(); - let cancellation_token = session.cancellation_token(); tokio::spawn(async move { - let mut pending: Vec = Vec::new(); + use tokio::sync::broadcast::error::RecvError; + + let mut pending = Vec::::new(); + + let mut grace_sleep: Option>> = None; + loop { tokio::select! { - _ = cancellation_token.cancelled() => break, - result = receiver.recv() => match result { - Ok(chunk) => { - pending.extend_from_slice(&chunk); - while let Some(prefix) = split_valid_utf8_prefix(&mut pending) { - { - let mut guard = transcript.lock().await; - guard.append(&prefix); - } - - let event = ExecCommandOutputDeltaEvent { - call_id: call_id.clone(), - stream: ExecOutputStream::Stdout, - chunk: prefix, - }; - session_ref - .send_event(turn_ref.as_ref(), EventMsg::ExecCommandOutputDelta(event)) - .await; - } - } - Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, - Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + _ = exit_token.cancelled(), if grace_sleep.is_none() => { + let deadline = Instant::now() + TRAILING_OUTPUT_GRACE; + grace_sleep.replace(Box::pin(tokio::time::sleep_until(deadline))); } - }; + + _ = async { + if let Some(sleep) = grace_sleep.as_mut() { + sleep.as_mut().await; + } + }, if grace_sleep.is_some() => { + output_drained.notify_one(); + break; + } + + received = receiver.recv() => { + let chunk = match received { + Ok(chunk) => chunk, + Err(RecvError::Lagged(_)) => { + continue; + }, + Err(RecvError::Closed) => { + 
output_drained.notify_one(); + break; + } + }; + + process_chunk( + &mut pending, + &transcript, + &call_id, + &session_ref, + &turn_ref, + chunk, + ).await; + } + } } }); } @@ -82,9 +105,11 @@ pub(crate) fn spawn_exit_watcher( started_at: Instant, ) { let exit_token = session.cancellation_token(); + let output_drained = session.output_drained_notify(); tokio::spawn(async move { exit_token.cancelled().await; + output_drained.notified().await; let exit_code = session.exit_code().unwrap_or(-1); let duration = Instant::now().saturating_duration_since(started_at); @@ -104,6 +129,32 @@ pub(crate) fn spawn_exit_watcher( }); } +async fn process_chunk( + pending: &mut Vec, + transcript: &Arc>, + call_id: &str, + session_ref: &Arc, + turn_ref: &Arc, + chunk: Vec, +) { + pending.extend_from_slice(&chunk); + while let Some(prefix) = split_valid_utf8_prefix(pending) { + { + let mut guard = transcript.lock().await; + guard.append(&prefix); + } + + let event = ExecCommandOutputDeltaEvent { + call_id: call_id.to_string(), + stream: ExecOutputStream::Stdout, + chunk: prefix, + }; + session_ref + .send_event(turn_ref.as_ref(), EventMsg::ExecCommandOutputDelta(event)) + .await; + } +} + /// Emit an ExecCommandEnd event for a unified exec session, using the transcript /// as the primary source of aggregated_output and falling back to the provided /// text when the transcript is empty. 
diff --git a/codex-rs/core/src/unified_exec/session.rs b/codex-rs/core/src/unified_exec/session.rs index 51ebbd356..4973a1a64 100644 --- a/codex-rs/core/src/unified_exec/session.rs +++ b/codex-rs/core/src/unified_exec/session.rs @@ -79,6 +79,7 @@ pub(crate) struct UnifiedExecSession { output_buffer: OutputBuffer, output_notify: Arc, cancellation_token: CancellationToken, + output_drained: Arc, output_task: JoinHandle<()>, sandbox_type: SandboxType, } @@ -92,27 +93,21 @@ impl UnifiedExecSession { let output_buffer = Arc::new(Mutex::new(OutputBufferState::default())); let output_notify = Arc::new(Notify::new()); let cancellation_token = CancellationToken::new(); + let output_drained = Arc::new(Notify::new()); let mut receiver = initial_output_rx; let buffer_clone = Arc::clone(&output_buffer); let notify_clone = Arc::clone(&output_notify); - let cancellation_token_clone = cancellation_token.clone(); let output_task = tokio::spawn(async move { loop { - tokio::select! { - _ = cancellation_token_clone.cancelled() => break, - result = receiver.recv() => match result { - Ok(chunk) => { - let mut guard = buffer_clone.lock().await; - guard.push_chunk(chunk); - drop(guard); - notify_clone.notify_waiters(); - } - Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, - Err(tokio::sync::broadcast::error::RecvError::Closed) => { - cancellation_token_clone.cancel(); - break; - } + match receiver.recv().await { + Ok(chunk) => { + let mut guard = buffer_clone.lock().await; + guard.push_chunk(chunk); + drop(guard); + notify_clone.notify_waiters(); } + Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, }; } }); @@ -122,6 +117,7 @@ impl UnifiedExecSession { output_buffer, output_notify, cancellation_token, + output_drained, output_task, sandbox_type, } @@ -147,6 +143,10 @@ impl UnifiedExecSession { self.cancellation_token.clone() } + pub(super) fn output_drained_notify(&self) -> Arc { + 
Arc::clone(&self.output_drained) + } + pub(super) fn has_exited(&self) -> bool { self.session.has_exited() } diff --git a/codex-rs/core/tests/suite/shell_snapshot.rs b/codex-rs/core/tests/suite/shell_snapshot.rs index f50e153dd..cc9d4ee77 100644 --- a/codex-rs/core/tests/suite/shell_snapshot.rs +++ b/codex-rs/core/tests/suite/shell_snapshot.rs @@ -132,6 +132,7 @@ fn assert_posix_snapshot_sections(snapshot: &str) { async fn linux_unified_exec_uses_shell_snapshot() -> Result<()> { let command = "echo snapshot-linux"; let run = run_snapshot_command(command).await?; + let stdout = normalize_newlines(&run.end.stdout); let shell_path = run .begin @@ -150,8 +151,11 @@ async fn linux_unified_exec_uses_shell_snapshot() -> Result<()> { assert!(run.snapshot_path.starts_with(&run.codex_home)); assert_posix_snapshot_sections(&run.snapshot_content); - assert_eq!(normalize_newlines(&run.end.stdout).trim(), "snapshot-linux"); assert_eq!(run.end.exit_code, 0); + assert!( + stdout.contains("snapshot-linux"), + "stdout should contain snapshot marker; stdout={stdout:?}" + ); Ok(()) } diff --git a/codex-rs/core/tests/suite/unified_exec.rs b/codex-rs/core/tests/suite/unified_exec.rs index 5def7aadb..e2dcb0c56 100644 --- a/codex-rs/core/tests/suite/unified_exec.rs +++ b/codex-rs/core/tests/suite/unified_exec.rs @@ -893,7 +893,7 @@ async fn unified_exec_terminal_interaction_captures_delayed_output() -> Result<( let open_call_id = "uexec-delayed-open"; let open_args = json!({ - "cmd": "sleep 5 && echo MARKER1 && sleep 5 && echo MARKER2", + "cmd": "sleep 3 && echo MARKER1 && sleep 3 && echo MARKER2", "yield_time_ms": 10, }); @@ -910,14 +910,14 @@ async fn unified_exec_terminal_interaction_captures_delayed_output() -> Result<( let second_poll_args = json!({ "chars": "", 
"session_id": 1000, - "yield_time_ms": 6000, + "yield_time_ms": 4000, }); let third_poll_call_id = "uexec-delayed-poll-3"; let third_poll_args = json!({ "chars": "", "session_id": 1000, - "yield_time_ms": 10000, + "yield_time_ms": 6000, }); let responses = vec![ @@ -984,6 +984,7 @@ async fn unified_exec_terminal_interaction_captures_delayed_output() -> Result<( let mut begin_event = None; let mut end_event = None; + let mut task_completed = false; let mut terminal_events = Vec::new(); let mut delta_text = String::new(); @@ -1003,8 +1004,13 @@ async fn unified_exec_terminal_interaction_captures_delayed_output() -> Result<( EventMsg::ExecCommandEnd(ev) if ev.call_id == open_call_id => { end_event = Some(ev); } - EventMsg::TaskComplete(_) => break, + EventMsg::TaskComplete(_) => { + task_completed = true; + } _ => {} + }; + if task_completed && end_event.is_some() { + break; } }