From 3d356723c409c2681bb18ff8dc280fe8df68618d Mon Sep 17 00:00:00 2001 From: Michael Bolin Date: Tue, 24 Feb 2026 16:20:08 -0800 Subject: [PATCH] fix: make EscalateServer public and remove shell escalation wrappers (#12724) ## Why `codex-shell-escalation` exposed a `codex-core`-specific adapter layer (`ShellActionProvider`, `ShellPolicyFactory`, and `run_escalate_server`) that existed only to bridge `codex-core` to `EscalateServer`. That indirection increased API surface and obscured crate ownership without adding behavior. This change moves orchestration into `codex-core` so boundaries are clearer: `codex-shell-escalation` provides reusable escalation primitives, and `codex-core` provides shell-tool policy decisions. Admittedly, @pakrym rightfully requested this sort of cleanup as part of https://github.com/openai/codex/pull/12649, though this avoids moving all of `codex-shell-escalation` into `codex-core`. ## What changed - Made `EscalateServer` public and exported it from `shell-escalation`. - Removed the adapter layer from `shell-escalation`: - deleted `shell-escalation/src/unix/core_shell_escalation.rs` - removed exports for `ShellActionProvider`, `ShellPolicyFactory`, `EscalationPolicyFactory`, and `run_escalate_server` - Updated `core/src/tools/runtimes/shell/unix_escalation.rs` to: - create `Stopwatch`/cancellation in `codex-core` - instantiate `EscalateServer` directly - implement `EscalationPolicy` directly on `CoreShellActionProvider` Net effect: same escalation flow with fewer wrappers and a smaller public API. ## Verification - Manually reviewed the old vs. new escalation call flow to confirm timeout/cancellation behavior and approval policy decisions are preserved while removing wrapper types. --- codex-rs/Cargo.lock | 1 - .../tools/runtimes/shell/unix_escalation.rs | 68 +++++++++-------- codex-rs/shell-escalation/Cargo.toml | 7 +- codex-rs/shell-escalation/src/lib.rs | 10 +-- .../src/unix/core_shell_escalation.rs | 73 ------------------- .../src/unix/escalate_server.rs | 39 +--------- codex-rs/shell-escalation/src/unix/mod.rs | 6 +- 7 files changed, 47 insertions(+), 157 deletions(-) delete mode 100644 codex-rs/shell-escalation/src/unix/core_shell_escalation.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 056eae957..91932b30f 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2241,7 +2241,6 @@ dependencies = [ "anyhow", "async-trait", "clap", - "codex-execpolicy", "libc", "path-absolutize", "pretty_assertions", diff --git a/codex-rs/core/src/tools/runtimes/shell/unix_escalation.rs b/codex-rs/core/src/tools/runtimes/shell/unix_escalation.rs index 2fd7f49e2..811b1fdfe 100644 --- a/codex-rs/core/src/tools/runtimes/shell/unix_escalation.rs +++ b/codex-rs/core/src/tools/runtimes/shell/unix_escalation.rs @@ -24,13 +24,12 @@ use codex_protocol::protocol::SandboxPolicy; use codex_shell_command::bash::parse_shell_lc_plain_commands; use codex_shell_command::bash::parse_shell_lc_single_command_prefix; use codex_shell_escalation::EscalateAction; +use codex_shell_escalation::EscalateServer; +use codex_shell_escalation::EscalationPolicy; use codex_shell_escalation::ExecParams; use codex_shell_escalation::ExecResult; -use codex_shell_escalation::ShellActionProvider; use codex_shell_escalation::ShellCommandExecutor; -use codex_shell_escalation::ShellPolicyFactory; use codex_shell_escalation::Stopwatch; -use codex_shell_escalation::run_escalate_server; use std::collections::HashMap; use std::path::Path; use std::path::PathBuf; @@ -107,30 +106,39 @@ pub(super) async fn try_run_zsh_fork( justification, arg0, }; - let exec_result = run_escalate_server( - ExecParams { - command: script, - workdir: req.cwd.to_string_lossy().to_string(), - timeout_ms: Some(effective_timeout.as_millis() as u64), - login: Some(login), - }, - shell_zsh_path.clone(), - shell_execve_wrapper().map_err(|err| ToolError::Rejected(format!("{err}")))?, - exec_policy.clone(), - ShellPolicyFactory::new(CoreShellActionProvider { - policy: Arc::clone(&exec_policy), - session: Arc::clone(&ctx.session), - turn: Arc::clone(&ctx.turn), - call_id: ctx.call_id.clone(), - approval_policy: ctx.turn.approval_policy.value(), - sandbox_policy: attempt.policy.clone(), - sandbox_permissions: req.sandbox_permissions, - }), - effective_timeout, - &command_executor, - ) - .await - .map_err(|err| ToolError::Rejected(err.to_string()))?; + + let exec_params = ExecParams { + command: script, + workdir: req.cwd.to_string_lossy().to_string(), + timeout_ms: Some(effective_timeout.as_millis() as u64), + login: Some(login), + }; + let execve_wrapper = + shell_execve_wrapper().map_err(|err| ToolError::Rejected(format!("{err}")))?; + + // Note that Stopwatch starts immediately upon creation, so currently we try + // to minimize the time between creating the Stopwatch and starting the + // escalation server. + let stopwatch = Stopwatch::new(effective_timeout); + let cancel_token = stopwatch.cancellation_token(); + let escalation_policy = CoreShellActionProvider { + policy: Arc::clone(&exec_policy), + session: Arc::clone(&ctx.session), + turn: Arc::clone(&ctx.turn), + call_id: ctx.call_id.clone(), + approval_policy: ctx.turn.approval_policy.value(), + sandbox_policy: attempt.policy.clone(), + sandbox_permissions: req.sandbox_permissions, + stopwatch: stopwatch.clone(), + }; + + let escalate_server = + EscalateServer::new(shell_zsh_path.clone(), execve_wrapper, escalation_policy); + + let exec_result = escalate_server + .exec(exec_params, cancel_token, &command_executor) + .await + .map_err(|err| ToolError::Rejected(err.to_string()))?; map_exec_result(attempt.sandbox, exec_result).map(Some) } @@ -143,6 +151,7 @@ struct CoreShellActionProvider { approval_policy: AskForApproval, sandbox_policy: SandboxPolicy, sandbox_permissions: SandboxPermissions, + stopwatch: Stopwatch, } impl CoreShellActionProvider { @@ -186,13 +195,12 @@ impl CoreShellActionProvider { } #[async_trait::async_trait] -impl ShellActionProvider for CoreShellActionProvider { +impl EscalationPolicy for CoreShellActionProvider { async fn determine_action( &self, file: &Path, argv: &[String], workdir: &Path, - stopwatch: &Stopwatch, ) -> anyhow::Result { let command = std::iter::once(file.to_string_lossy().to_string()) .chain(argv.iter().cloned()) @@ -240,7 +248,7 @@ impl ShellActionProvider for CoreShellActionProvider { reason: Some("Execution forbidden by policy".to_string()), } } else { - match self.prompt(&command, workdir, stopwatch).await? { + match self.prompt(&command, workdir, &self.stopwatch).await? { ReviewDecision::Approved | ReviewDecision::ApprovedExecpolicyAmendment { .. } | ReviewDecision::ApprovedForSession => { diff --git a/codex-rs/shell-escalation/Cargo.toml b/codex-rs/shell-escalation/Cargo.toml index ec88bf82f..f608b9608 100644 --- a/codex-rs/shell-escalation/Cargo.toml +++ b/codex-rs/shell-escalation/Cargo.toml @@ -1,8 +1,8 @@ [package] -name = "codex-shell-escalation" -version.workspace = true edition.workspace = true license.workspace = true +name = "codex-shell-escalation" +version.workspace = true [[bin]] name = "codex-execve-wrapper" @@ -12,11 +12,10 @@ path = "src/bin/main_execve_wrapper.rs" anyhow = { workspace = true } async-trait = { workspace = true } clap = { workspace = true, features = ["derive"] } -codex-execpolicy = { workspace = true } libc = { workspace = true } -serde_json = { workspace = true } path-absolutize = { workspace = true } serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } socket2 = { workspace = true, features = ["all"] } tokio = { workspace = true, features = [ "io-std", diff --git a/codex-rs/shell-escalation/src/lib.rs b/codex-rs/shell-escalation/src/lib.rs index 98d33fc59..48bc11658 100644 --- a/codex-rs/shell-escalation/src/lib.rs +++ b/codex-rs/shell-escalation/src/lib.rs @@ -4,24 +4,18 @@ mod unix; #[cfg(unix)] pub use unix::EscalateAction; #[cfg(unix)] -pub use unix::EscalationPolicy; +pub use unix::EscalateServer; #[cfg(unix)] -pub use unix::EscalationPolicyFactory; +pub use unix::EscalationPolicy; #[cfg(unix)] pub use unix::ExecParams; #[cfg(unix)] pub use unix::ExecResult; #[cfg(unix)] -pub use unix::ShellActionProvider; -#[cfg(unix)] pub use unix::ShellCommandExecutor; #[cfg(unix)] -pub use unix::ShellPolicyFactory; -#[cfg(unix)] pub use unix::Stopwatch; #[cfg(unix)] pub use unix::main_execve_wrapper; #[cfg(unix)] -pub use unix::run_escalate_server; -#[cfg(unix)] pub use unix::run_shell_escalation_execve_wrapper; diff --git a/codex-rs/shell-escalation/src/unix/core_shell_escalation.rs b/codex-rs/shell-escalation/src/unix/core_shell_escalation.rs deleted file mode 100644 index 3f2d5ad23..000000000 --- a/codex-rs/shell-escalation/src/unix/core_shell_escalation.rs +++ /dev/null @@ -1,73 +0,0 @@ -use async_trait::async_trait; -use std::path::Path; -use std::sync::Arc; -use tokio::sync::RwLock; - -use super::escalate_protocol::EscalateAction; -use super::escalate_server::EscalationPolicyFactory; -use super::escalation_policy::EscalationPolicy; -use super::stopwatch::Stopwatch; -use codex_execpolicy::Policy; - -#[async_trait] -pub trait ShellActionProvider: Send + Sync { - async fn determine_action( - &self, - file: &Path, - argv: &[String], - workdir: &Path, - stopwatch: &Stopwatch, - ) -> anyhow::Result; -} - -#[derive(Clone)] -pub struct ShellPolicyFactory { - provider: Arc, -} - -impl ShellPolicyFactory { - pub fn new

(provider: P) -> Self - where - P: ShellActionProvider + 'static, - { - Self { - provider: Arc::new(provider), - } - } - - pub fn with_provider(provider: Arc) -> Self { - Self { provider } - } -} - -/// Public only because it is the associated `Policy` type in the public -/// `EscalationPolicyFactory` impl for `ShellPolicyFactory`. -pub struct ShellEscalationPolicy { - provider: Arc, - stopwatch: Stopwatch, -} - -#[async_trait] -impl EscalationPolicy for ShellEscalationPolicy { - async fn determine_action( - &self, - file: &Path, - argv: &[String], - workdir: &Path, - ) -> anyhow::Result { - self.provider - .determine_action(file, argv, workdir, &self.stopwatch) - .await - } -} - -impl EscalationPolicyFactory for ShellPolicyFactory { - type Policy = ShellEscalationPolicy; - - fn create_policy(&self, _policy: Arc>, stopwatch: Stopwatch) -> Self::Policy { - ShellEscalationPolicy { - provider: Arc::clone(&self.provider), - stopwatch, - } - } -} diff --git a/codex-rs/shell-escalation/src/unix/escalate_server.rs b/codex-rs/shell-escalation/src/unix/escalate_server.rs index 7e373f4fb..56db822eb 100644 --- a/codex-rs/shell-escalation/src/unix/escalate_server.rs +++ b/codex-rs/shell-escalation/src/unix/escalate_server.rs @@ -1,16 +1,13 @@ use std::collections::HashMap; use std::os::fd::AsRawFd; -use std::path::Path; use std::path::PathBuf; use std::process::Stdio; use std::sync::Arc; use std::time::Duration; use anyhow::Context as _; -use codex_execpolicy::Policy; use path_absolutize::Absolutize as _; use tokio::process::Command; -use tokio::sync::RwLock; use tokio_util::sync::CancellationToken; use crate::unix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR; @@ -24,7 +21,6 @@ use crate::unix::escalate_protocol::SuperExecResult; use crate::unix::escalation_policy::EscalationPolicy; use crate::unix::socket::AsyncDatagramSocket; use crate::unix::socket::AsyncSocket; -use crate::unix::stopwatch::Stopwatch; /// Adapter for running the shell command after the escalation server has been set up. /// @@ -66,14 +62,14 @@ pub struct ExecResult { pub timed_out: bool, } -struct EscalateServer { +pub struct EscalateServer { bash_path: PathBuf, execve_wrapper: PathBuf, policy: Arc, } impl EscalateServer { - fn new

(bash_path: PathBuf, execve_wrapper: PathBuf, policy: P) -> Self + pub fn new

(bash_path: PathBuf, execve_wrapper: PathBuf, policy: P) -> Self where P: EscalationPolicy + Send + Sync + 'static, { @@ -84,7 +80,7 @@ impl EscalateServer { } } - async fn exec( + pub async fn exec( &self, params: ExecParams, cancel_rx: CancellationToken, @@ -126,35 +122,6 @@ impl EscalateServer { } } -/// Factory for creating escalation policy instances for a single shell run. -pub trait EscalationPolicyFactory { - type Policy: EscalationPolicy + Send + Sync + 'static; - - fn create_policy(&self, policy: Arc>, stopwatch: Stopwatch) -> Self::Policy; -} - -pub async fn run_escalate_server( - exec_params: ExecParams, - shell_program: impl AsRef, - execve_wrapper: impl AsRef, - policy: Arc>, - escalation_policy_factory: impl EscalationPolicyFactory, - effective_timeout: Duration, - command_executor: &dyn ShellCommandExecutor, -) -> anyhow::Result { - let stopwatch = Stopwatch::new(effective_timeout); - let cancel_token = stopwatch.cancellation_token(); - let escalate_server = EscalateServer::new( - shell_program.as_ref().to_path_buf(), - execve_wrapper.as_ref().to_path_buf(), - escalation_policy_factory.create_policy(policy, stopwatch), - ); - - escalate_server - .exec(exec_params, cancel_token, command_executor) - .await -} - async fn escalate_task( socket: AsyncDatagramSocket, policy: Arc, diff --git a/codex-rs/shell-escalation/src/unix/mod.rs b/codex-rs/shell-escalation/src/unix/mod.rs index c69c8b942..37e29e877 100644 --- a/codex-rs/shell-escalation/src/unix/mod.rs +++ b/codex-rs/shell-escalation/src/unix/mod.rs @@ -53,7 +53,6 @@ //! | | //! o<-----x //! -pub mod core_shell_escalation; pub mod escalate_client; pub mod escalate_protocol; pub mod escalate_server; @@ -62,15 +61,12 @@ pub mod execve_wrapper; pub mod socket; pub mod stopwatch; -pub use self::core_shell_escalation::ShellActionProvider; -pub use self::core_shell_escalation::ShellPolicyFactory; pub use self::escalate_client::run_shell_escalation_execve_wrapper; pub use self::escalate_protocol::EscalateAction; -pub use self::escalate_server::EscalationPolicyFactory; +pub use self::escalate_server::EscalateServer; pub use self::escalate_server::ExecParams; pub use self::escalate_server::ExecResult; pub use self::escalate_server::ShellCommandExecutor; -pub use self::escalate_server::run_escalate_server; pub use self::escalation_policy::EscalationPolicy; pub use self::execve_wrapper::main_execve_wrapper; pub use self::stopwatch::Stopwatch;