Allow unified_exec to early exit (if the process terminates before yield_time_ms) (#6867)
Thread an `exit_notify` tokio `Notify` through to the `UnifiedExecSession` so that we can return early if the command terminates before `yield_time_ms`. As Codex review correctly pointed out below 🙌, we also need an `exit_signaled` flag so that commands which finish before we start waiting can also exit early. Since the default `yield_time_ms` is now 10s, this means that we don't have to wait 10s for trivial commands like `ls`, `sed`, etc. (which are the majority of agent commands 😅) --------- Co-authored-by: jif-oai <jif@openai.com>
This commit is contained in:
parent
54e6e4ac32
commit
b5dd189067
3 changed files with 198 additions and 52 deletions
|
|
@ -2,13 +2,13 @@
|
|||
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot::error::TryRecvError;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::Duration;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::exec::ExecToolCallOutput;
|
||||
use crate::exec::SandboxType;
|
||||
|
|
@ -67,13 +67,18 @@ impl OutputBufferState {
|
|||
}
|
||||
|
||||
pub(crate) type OutputBuffer = Arc<Mutex<OutputBufferState>>;
|
||||
pub(crate) type OutputHandles = (OutputBuffer, Arc<Notify>);
|
||||
/// Cloneable handles a caller needs to observe a session's output without
/// holding a reference to the session itself.
///
/// Built by `UnifiedExecSession::output_handles`, which clones each field
/// from the session.
pub(crate) struct OutputHandles {
|
||||
/// Shared output buffer state (`Arc<Mutex<OutputBufferState>>`).
pub(crate) output_buffer: OutputBuffer,
|
||||
/// Notifier the session's output task signals via `notify_waiters()`
/// (presumably after new output is buffered; the surrounding lines are
/// not visible in this hunk).
pub(crate) output_notify: Arc<Notify>,
|
||||
/// Token that is cancelled once the output channel closes or the process
/// exit is observed; waiters use it to stop before their deadline.
pub(crate) cancellation_token: CancellationToken,
|
||||
}
|
||||
|
||||
/// A running unified-exec command together with the shared state used to
/// collect its output.
#[derive(Debug)]
|
||||
pub(crate) struct UnifiedExecSession {
|
||||
// Underlying command session.
session: ExecCommandSession,
|
||||
// Shared buffer state; also handed out through `output_handles()`.
output_buffer: OutputBuffer,
|
||||
// Signalled by the output task via `notify_waiters()`.
output_notify: Arc<Notify>,
|
||||
// Cancelled when the output channel closes or the process exit is seen,
// so output collection can return before `yield_time_ms` elapses.
cancellation_token: CancellationToken,
|
||||
// Handle of the spawned task that reads from the output receiver.
output_task: JoinHandle<()>,
|
||||
sandbox_type: SandboxType,
|
||||
}
|
||||
|
|
@ -86,9 +91,11 @@ impl UnifiedExecSession {
|
|||
) -> Self {
|
||||
let output_buffer = Arc::new(Mutex::new(OutputBufferState::default()));
|
||||
let output_notify = Arc::new(Notify::new());
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let mut receiver = initial_output_rx;
|
||||
let buffer_clone = Arc::clone(&output_buffer);
|
||||
let notify_clone = Arc::clone(&output_notify);
|
||||
let cancellation_token_clone = cancellation_token.clone();
|
||||
let output_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match receiver.recv().await {
|
||||
|
|
@ -99,7 +106,10 @@ impl UnifiedExecSession {
|
|||
notify_clone.notify_waiters();
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
|
||||
cancellation_token_clone.cancel();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
|
@ -108,6 +118,7 @@ impl UnifiedExecSession {
|
|||
session,
|
||||
output_buffer,
|
||||
output_notify,
|
||||
cancellation_token,
|
||||
output_task,
|
||||
sandbox_type,
|
||||
}
|
||||
|
|
@ -118,10 +129,11 @@ impl UnifiedExecSession {
|
|||
}
|
||||
|
||||
pub(super) fn output_handles(&self) -> OutputHandles {
|
||||
(
|
||||
Arc::clone(&self.output_buffer),
|
||||
Arc::clone(&self.output_notify),
|
||||
)
|
||||
OutputHandles {
|
||||
output_buffer: Arc::clone(&self.output_buffer),
|
||||
output_notify: Arc::clone(&self.output_notify),
|
||||
cancellation_token: self.cancellation_token.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn has_exited(&self) -> bool {
|
||||
|
|
@ -199,20 +211,34 @@ impl UnifiedExecSession {
|
|||
};
|
||||
|
||||
if exit_ready {
|
||||
managed.signal_exit();
|
||||
managed.check_for_sandbox_denial().await?;
|
||||
return Ok(managed);
|
||||
}
|
||||
|
||||
tokio::pin!(exit_rx);
|
||||
if tokio::time::timeout(Duration::from_millis(50), &mut exit_rx)
|
||||
.await
|
||||
.is_ok()
|
||||
{
|
||||
managed.signal_exit();
|
||||
managed.check_for_sandbox_denial().await?;
|
||||
return Ok(managed);
|
||||
}
|
||||
|
||||
tokio::spawn({
|
||||
let cancellation_token = managed.cancellation_token.clone();
|
||||
async move {
|
||||
let _ = exit_rx.await;
|
||||
cancellation_token.cancel();
|
||||
}
|
||||
});
|
||||
|
||||
Ok(managed)
|
||||
}
|
||||
|
||||
/// Records that the process has exited by cancelling the session's
/// cancellation token, which wakes anything awaiting `cancelled()` on a
/// clone of it.
fn signal_exit(&self) {
|
||||
self.cancellation_token.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for UnifiedExecSession {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use tokio::sync::Notify;
|
|||
use tokio::sync::mpsc;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
|
|
@ -40,8 +41,20 @@ use super::clamp_yield_time;
|
|||
use super::generate_chunk_id;
|
||||
use super::resolve_max_tokens;
|
||||
use super::session::OutputBuffer;
|
||||
use super::session::OutputHandles;
|
||||
use super::session::UnifiedExecSession;
|
||||
|
||||
/// Everything `prepare_session_handles` extracts from a stored session entry,
/// returned as named fields instead of a positional tuple.
struct PreparedSessionHandles {
|
||||
/// Sender obtained from the session's `writer_sender()`.
writer_tx: mpsc::Sender<Vec<u8>>,
|
||||
/// Shared output buffer taken from the session's `OutputHandles`.
output_buffer: OutputBuffer,
|
||||
/// Output notifier taken from the session's `OutputHandles`.
output_notify: Arc<Notify>,
|
||||
/// Exit/cancellation token taken from the session's `OutputHandles`.
cancellation_token: CancellationToken,
|
||||
/// Clone of the entry's `session_ref`.
session_ref: Arc<Session>,
|
||||
/// Clone of the entry's `turn_ref`.
turn_ref: Arc<TurnContext>,
|
||||
/// Clone of the command stored on the entry.
command: Vec<String>,
|
||||
/// Clone of the working directory stored on the entry.
cwd: PathBuf,
|
||||
}
|
||||
|
||||
impl UnifiedExecSessionManager {
|
||||
pub(crate) async fn exec_command(
|
||||
&self,
|
||||
|
|
@ -67,10 +80,19 @@ impl UnifiedExecSessionManager {
|
|||
let yield_time_ms = clamp_yield_time(request.yield_time_ms);
|
||||
|
||||
let start = Instant::now();
|
||||
let (output_buffer, output_notify) = session.output_handles();
|
||||
let OutputHandles {
|
||||
output_buffer,
|
||||
output_notify,
|
||||
cancellation_token,
|
||||
} = session.output_handles();
|
||||
let deadline = start + Duration::from_millis(yield_time_ms);
|
||||
let collected =
|
||||
Self::collect_output_until_deadline(&output_buffer, &output_notify, deadline).await;
|
||||
let collected = Self::collect_output_until_deadline(
|
||||
&output_buffer,
|
||||
&output_notify,
|
||||
&cancellation_token,
|
||||
deadline,
|
||||
)
|
||||
.await;
|
||||
let wall_time = Instant::now().saturating_duration_since(start);
|
||||
|
||||
let text = String::from_utf8_lossy(&collected).to_string();
|
||||
|
|
@ -129,15 +151,16 @@ impl UnifiedExecSessionManager {
|
|||
) -> Result<UnifiedExecResponse, UnifiedExecError> {
|
||||
let session_id = request.session_id;
|
||||
|
||||
let (
|
||||
let PreparedSessionHandles {
|
||||
writer_tx,
|
||||
output_buffer,
|
||||
output_notify,
|
||||
cancellation_token,
|
||||
session_ref,
|
||||
turn_ref,
|
||||
session_command,
|
||||
session_cwd,
|
||||
) = self.prepare_session_handles(session_id).await?;
|
||||
command: session_command,
|
||||
cwd: session_cwd,
|
||||
} = self.prepare_session_handles(session_id).await?;
|
||||
|
||||
let interaction_emitter = ToolEmitter::unified_exec(
|
||||
&session_command,
|
||||
|
|
@ -176,8 +199,13 @@ impl UnifiedExecSessionManager {
|
|||
let yield_time_ms = clamp_yield_time(request.yield_time_ms);
|
||||
let start = Instant::now();
|
||||
let deadline = start + Duration::from_millis(yield_time_ms);
|
||||
let collected =
|
||||
Self::collect_output_until_deadline(&output_buffer, &output_notify, deadline).await;
|
||||
let collected = Self::collect_output_until_deadline(
|
||||
&output_buffer,
|
||||
&output_notify,
|
||||
&cancellation_token,
|
||||
deadline,
|
||||
)
|
||||
.await;
|
||||
let wall_time = Instant::now().saturating_duration_since(start);
|
||||
|
||||
let text = String::from_utf8_lossy(&collected).to_string();
|
||||
|
|
@ -265,44 +293,27 @@ impl UnifiedExecSessionManager {
|
|||
async fn prepare_session_handles(
|
||||
&self,
|
||||
session_id: i32,
|
||||
) -> Result<
|
||||
(
|
||||
mpsc::Sender<Vec<u8>>,
|
||||
OutputBuffer,
|
||||
Arc<Notify>,
|
||||
Arc<Session>,
|
||||
Arc<TurnContext>,
|
||||
Vec<String>,
|
||||
PathBuf,
|
||||
),
|
||||
UnifiedExecError,
|
||||
> {
|
||||
) -> Result<PreparedSessionHandles, UnifiedExecError> {
|
||||
let sessions = self.sessions.lock().await;
|
||||
let (output_buffer, output_notify, writer_tx, session, turn, command, cwd) =
|
||||
if let Some(entry) = sessions.get(&session_id) {
|
||||
let (buffer, notify) = entry.session.output_handles();
|
||||
(
|
||||
buffer,
|
||||
notify,
|
||||
entry.session.writer_sender(),
|
||||
Arc::clone(&entry.session_ref),
|
||||
Arc::clone(&entry.turn_ref),
|
||||
entry.command.clone(),
|
||||
entry.cwd.clone(),
|
||||
)
|
||||
} else {
|
||||
return Err(UnifiedExecError::UnknownSessionId { session_id });
|
||||
};
|
||||
|
||||
Ok((
|
||||
writer_tx,
|
||||
let entry = sessions
|
||||
.get(&session_id)
|
||||
.ok_or(UnifiedExecError::UnknownSessionId { session_id })?;
|
||||
let OutputHandles {
|
||||
output_buffer,
|
||||
output_notify,
|
||||
session,
|
||||
turn,
|
||||
command,
|
||||
cwd,
|
||||
))
|
||||
cancellation_token,
|
||||
} = entry.session.output_handles();
|
||||
|
||||
Ok(PreparedSessionHandles {
|
||||
writer_tx: entry.session.writer_sender(),
|
||||
output_buffer,
|
||||
output_notify,
|
||||
cancellation_token,
|
||||
session_ref: Arc::clone(&entry.session_ref),
|
||||
turn_ref: Arc::clone(&entry.turn_ref),
|
||||
command: entry.command.clone(),
|
||||
cwd: entry.cwd.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_input(
|
||||
|
|
@ -480,9 +491,13 @@ impl UnifiedExecSessionManager {
|
|||
pub(super) async fn collect_output_until_deadline(
|
||||
output_buffer: &OutputBuffer,
|
||||
output_notify: &Arc<Notify>,
|
||||
cancellation_token: &CancellationToken,
|
||||
deadline: Instant,
|
||||
) -> Vec<u8> {
|
||||
const POST_EXIT_OUTPUT_GRACE: Duration = Duration::from_millis(25);
|
||||
|
||||
let mut collected: Vec<u8> = Vec::with_capacity(4096);
|
||||
let mut exit_signal_received = cancellation_token.is_cancelled();
|
||||
loop {
|
||||
let drained_chunks;
|
||||
let mut wait_for_output = None;
|
||||
|
|
@ -495,15 +510,27 @@ impl UnifiedExecSessionManager {
|
|||
}
|
||||
|
||||
if drained_chunks.is_empty() {
|
||||
exit_signal_received |= cancellation_token.is_cancelled();
|
||||
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||
if remaining == Duration::ZERO {
|
||||
break;
|
||||
}
|
||||
|
||||
let notified = wait_for_output.unwrap_or_else(|| output_notify.notified());
|
||||
if exit_signal_received {
|
||||
let grace = remaining.min(POST_EXIT_OUTPUT_GRACE);
|
||||
if tokio::time::timeout(grace, notified).await.is_err() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
tokio::pin!(notified);
|
||||
let exit_notified = cancellation_token.cancelled();
|
||||
tokio::pin!(exit_notified);
|
||||
tokio::select! {
|
||||
_ = &mut notified => {}
|
||||
_ = &mut exit_notified => exit_signal_received = true,
|
||||
_ = tokio::time::sleep(remaining) => break,
|
||||
}
|
||||
continue;
|
||||
|
|
@ -513,6 +540,7 @@ impl UnifiedExecSessionManager {
|
|||
collected.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
exit_signal_received |= cancellation_token.is_cancelled();
|
||||
if Instant::now() >= deadline {
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -904,6 +904,98 @@ async fn exec_command_reports_chunk_and_exit_metadata() -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Checks that `exec_command` returns as soon as a short-lived process exits
/// instead of waiting for the full `yield_time_ms`.
///
/// Runs `sleep 0.05` with a `yield_time_ms` of 31415 and asserts that the
/// reported tool output has no session id, exit code 0, a wall time below
/// 0.75 seconds, and empty output.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_respects_early_exit_notifications() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
skip_if_sandbox!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.features.enable(Feature::UnifiedExec);
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = builder.build(&server).await?;
|
||||
|
||||
let call_id = "uexec-early-exit";
|
||||
// The command finishes in ~50ms while the yield time is over 31s, so the
// wall-time assertion below only passes if the early-exit path is taken.
let args = serde_json::json!({
|
||||
"cmd": "sleep 0.05",
|
||||
"yield_time_ms": 31415,
|
||||
});
|
||||
|
||||
// First response issues the tool call; second ends the turn with a message.
let responses = vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "exec_command", &serde_json::to_string(&args)?),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![UserInput::Text {
|
||||
text: "watch early exit timing".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// The tool output is read back from the request bodies the mock server received.
let requests = server.received_requests().await.expect("recorded requests");
|
||||
assert!(!requests.is_empty(), "expected at least one POST request");
|
||||
|
||||
let bodies = requests
|
||||
.iter()
|
||||
.map(|req| req.body_json::<Value>().expect("request json"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let outputs = collect_tool_outputs(&bodies)?;
|
||||
let output = outputs
|
||||
.get(call_id)
|
||||
.expect("missing early exit unified_exec output");
|
||||
|
||||
assert!(
|
||||
output.session_id.is_none(),
|
||||
"short-lived process should not keep a session alive"
|
||||
);
|
||||
assert_eq!(
|
||||
output.exit_code,
|
||||
Some(0),
|
||||
"short-lived process should exit successfully"
|
||||
);
|
||||
|
||||
let wall_time = output.wall_time_seconds;
|
||||
assert!(
|
||||
wall_time < 0.75,
|
||||
"wall_time should reflect early exit rather than the full yield time; got {wall_time}"
|
||||
);
|
||||
assert!(
|
||||
output.output.is_empty(),
|
||||
"sleep command should not emit output, got {:?}",
|
||||
output.output
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn write_stdin_returns_exit_metadata_and_clears_session() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue