feat: close unified_exec at end of turn (#8052)
This commit is contained in:
parent
cf44511e77
commit
ae57e18947
4 changed files with 173 additions and 2 deletions
|
|
@ -159,6 +159,7 @@ impl Session {
|
|||
for task in self.take_all_running_tasks().await {
|
||||
self.handle_task_abort(task, reason.clone()).await;
|
||||
}
|
||||
self.close_unified_exec_sessions().await;
|
||||
}
|
||||
|
||||
pub async fn on_task_finished(
|
||||
|
|
@ -167,12 +168,18 @@ impl Session {
|
|||
last_agent_message: Option<String>,
|
||||
) {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
if let Some(at) = active.as_mut()
|
||||
let should_close_sessions = if let Some(at) = active.as_mut()
|
||||
&& at.remove_task(&turn_context.sub_id)
|
||||
{
|
||||
*active = None;
|
||||
}
|
||||
true
|
||||
} else {
|
||||
false
|
||||
};
|
||||
drop(active);
|
||||
if should_close_sessions {
|
||||
self.close_unified_exec_sessions().await;
|
||||
}
|
||||
let event = EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message });
|
||||
self.send_event(turn_context.as_ref(), event).await;
|
||||
}
|
||||
|
|
@ -196,6 +203,13 @@ impl Session {
|
|||
}
|
||||
}
|
||||
|
||||
async fn close_unified_exec_sessions(&self) {
|
||||
self.services
|
||||
.unified_exec_manager
|
||||
.terminate_all_sessions()
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn handle_task_abort(self: &Arc<Self>, task: RunningTask, reason: TurnAbortReason) {
|
||||
let sub_id = task.turn_context.sub_id.clone();
|
||||
if task.cancellation_token.is_cancelled() {
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ use std::path::PathBuf;
|
|||
#[cfg(target_os = "linux")]
|
||||
use assert_cmd::cargo::cargo_bin;
|
||||
|
||||
pub mod process;
|
||||
pub mod responses;
|
||||
pub mod streaming_sse;
|
||||
pub mod test_codex;
|
||||
|
|
|
|||
48
codex-rs/core/tests/common/process.rs
Normal file
48
codex-rs/core/tests/common/process.rs
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
use anyhow::Context;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
pub async fn wait_for_pid_file(path: &Path) -> anyhow::Result<String> {
|
||||
let pid = tokio::time::timeout(Duration::from_secs(2), async {
|
||||
loop {
|
||||
if let Ok(contents) = fs::read_to_string(path) {
|
||||
let trimmed = contents.trim();
|
||||
if !trimmed.is_empty() {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.context("timed out waiting for pid file")?;
|
||||
|
||||
Ok(pid)
|
||||
}
|
||||
|
||||
pub fn process_is_alive(pid: &str) -> anyhow::Result<bool> {
|
||||
let status = std::process::Command::new("kill")
|
||||
.args(["-0", pid])
|
||||
.status()
|
||||
.context("failed to probe process liveness with kill -0")?;
|
||||
Ok(status.success())
|
||||
}
|
||||
|
||||
async fn wait_for_process_exit_inner(pid: String) -> anyhow::Result<()> {
|
||||
loop {
|
||||
if !process_is_alive(&pid)? {
|
||||
return Ok(());
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn wait_for_process_exit(pid: &str) -> anyhow::Result<()> {
|
||||
let pid = pid.to_string();
|
||||
tokio::time::timeout(Duration::from_secs(2), wait_for_process_exit_inner(pid))
|
||||
.await
|
||||
.context("timed out waiting for process to exit")??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -14,6 +14,8 @@ use codex_core::protocol::SandboxPolicy;
|
|||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::assert_regex_match;
|
||||
use core_test_support::process::wait_for_pid_file;
|
||||
use core_test_support::process::wait_for_process_exit;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
|
|
@ -31,6 +33,7 @@ use core_test_support::test_codex::test_codex;
|
|||
use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_match;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
use pretty_assertions::assert_eq;
|
||||
use regex_lite::Regex;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
|
|
@ -1640,6 +1643,111 @@ async fn unified_exec_emits_end_event_when_session_dies_via_stdin() -> Result<()
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_closes_long_running_session_at_turn_end() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
skip_if_sandbox!(Ok(()));
|
||||
skip_if_windows!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.use_experimental_unified_exec_tool = true;
|
||||
config.features.enable(Feature::UnifiedExec);
|
||||
});
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = builder.build(&server).await?;
|
||||
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let pid_path = temp_dir.path().join("uexec_pid");
|
||||
let pid_path_str = pid_path.to_string_lossy();
|
||||
|
||||
let call_id = "uexec-long-running";
|
||||
let command = format!("printf '%s' $$ > '{pid_path_str}' && exec sleep 3000");
|
||||
let args = json!({
|
||||
"cmd": command,
|
||||
"yield_time_ms": 250,
|
||||
});
|
||||
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "exec_command", &serde_json::to_string(&args)?),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
];
|
||||
mount_sse_sequence(&server, responses).await;
|
||||
|
||||
let session_model = session_configured.model.clone();
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![UserInput::Text {
|
||||
text: "close unified exec sessions on turn end".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let begin_event = wait_for_event_match(&codex, |msg| match msg {
|
||||
EventMsg::ExecCommandBegin(ev) if ev.call_id == call_id => Some(ev.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
|
||||
let begin_process_id = begin_event
|
||||
.process_id
|
||||
.clone()
|
||||
.expect("expected process_id for long-running unified exec session");
|
||||
|
||||
let pid = wait_for_pid_file(&pid_path).await?;
|
||||
assert!(
|
||||
pid.chars().all(|ch| ch.is_ascii_digit()),
|
||||
"expected numeric pid, got {pid:?}"
|
||||
);
|
||||
|
||||
let mut end_event = None;
|
||||
let mut task_complete = false;
|
||||
loop {
|
||||
let msg = wait_for_event(&codex, |_| true).await;
|
||||
match msg {
|
||||
EventMsg::ExecCommandEnd(ev) if ev.call_id == call_id => end_event = Some(ev),
|
||||
EventMsg::TaskComplete(_) => task_complete = true,
|
||||
_ => {}
|
||||
}
|
||||
if task_complete && end_event.is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let end_event = end_event.expect("expected ExecCommandEnd event for unified exec session");
|
||||
assert_eq!(end_event.call_id, call_id);
|
||||
let end_process_id = end_event
|
||||
.process_id
|
||||
.clone()
|
||||
.expect("expected process_id in unified exec end event");
|
||||
assert_eq!(end_process_id, begin_process_id);
|
||||
|
||||
wait_for_process_exit(&pid).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue