diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index f134c3cbb..c47be4d5a 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2238,9 +2238,7 @@ dependencies = [ "anyhow", "async-trait", "clap", - "codex-core", "codex-execpolicy", - "codex-protocol", "libc", "path-absolutize", "pretty_assertions", diff --git a/codex-rs/shell-escalation/Cargo.toml b/codex-rs/shell-escalation/Cargo.toml index e8bc81042..ec88bf82f 100644 --- a/codex-rs/shell-escalation/Cargo.toml +++ b/codex-rs/shell-escalation/Cargo.toml @@ -12,20 +12,20 @@ path = "src/bin/main_execve_wrapper.rs" anyhow = { workspace = true } async-trait = { workspace = true } clap = { workspace = true, features = ["derive"] } -codex-core = { workspace = true } codex-execpolicy = { workspace = true } -codex-protocol = { workspace = true } libc = { workspace = true } serde_json = { workspace = true } path-absolutize = { workspace = true } serde = { workspace = true, features = ["derive"] } -socket2 = { workspace = true } +socket2 = { workspace = true, features = ["all"] } tokio = { workspace = true, features = [ "io-std", + "net", "macros", "process", "rt-multi-thread", "signal", + "time", ] } tokio-util = { workspace = true } tracing = { workspace = true } diff --git a/codex-rs/shell-escalation/src/lib.rs b/codex-rs/shell-escalation/src/lib.rs index f697545f6..45ed7c799 100644 --- a/codex-rs/shell-escalation/src/lib.rs +++ b/codex-rs/shell-escalation/src/lib.rs @@ -12,6 +12,12 @@ pub use unix::ExecParams; #[cfg(unix)] pub use unix::ExecResult; #[cfg(unix)] +pub use unix::ShellActionProvider; +#[cfg(unix)] +pub use unix::ShellCommandExecutor; +#[cfg(unix)] +pub use unix::ShellPolicyFactory; +#[cfg(unix)] pub use unix::Stopwatch; #[cfg(unix)] pub use unix::main_execve_wrapper; diff --git a/codex-rs/shell-escalation/src/unix/core_shell_escalation.rs b/codex-rs/shell-escalation/src/unix/core_shell_escalation.rs index 5a67062dc..a56624cd9 100644 --- a/codex-rs/shell-escalation/src/unix/core_shell_escalation.rs +++ b/codex-rs/shell-escalation/src/unix/core_shell_escalation.rs @@ -3,11 +3,11 @@ use std::path::Path; use std::sync::Arc; use tokio::sync::RwLock; -use codex_execpolicy::Policy; use super::escalate_protocol::EscalateAction; use super::escalate_server::EscalationPolicyFactory; use super::escalation_policy::EscalationPolicy; use super::stopwatch::Stopwatch; +use codex_execpolicy::Policy; #[async_trait] pub trait ShellActionProvider: Send + Sync { @@ -40,7 +40,10 @@ impl ShellPolicyFactory { } } -struct ShellEscalationPolicy { +/// Public only because it is the associated `Policy` type in the public +/// `EscalationPolicyFactory` impl for `ShellPolicyFactory`. +#[doc(hidden)] +pub struct ShellEscalationPolicy { provider: Arc, stopwatch: Stopwatch, } diff --git a/codex-rs/shell-escalation/src/unix/escalate_server.rs b/codex-rs/shell-escalation/src/unix/escalate_server.rs index 0ee5fc27c..dbdba645b 100644 --- a/codex-rs/shell-escalation/src/unix/escalate_server.rs +++ b/codex-rs/shell-escalation/src/unix/escalate_server.rs @@ -7,7 +7,6 @@ use std::sync::Arc; use std::time::Duration; use anyhow::Context as _; -use codex_core::SandboxState; use codex_execpolicy::Policy; use path_absolutize::Absolutize as _; use tokio::process::Command; @@ -27,6 +26,23 @@ use crate::unix::socket::AsyncDatagramSocket; use crate::unix::socket::AsyncSocket; use crate::unix::stopwatch::Stopwatch; +/// Adapter for running the shell command after the escalation server has been set up. +/// +/// This lets `shell-escalation` own the Unix escalation protocol while the caller +/// keeps control over process spawning, output capture, and sandbox integration. +/// Implementations can capture any sandbox state they need. +#[async_trait::async_trait] +pub trait ShellCommandExecutor: Send + Sync { + /// Runs the requested shell command and returns the captured result. + async fn run( + &self, + command: Vec, + cwd: PathBuf, + env: HashMap, + cancel_rx: CancellationToken, + ) -> anyhow::Result; +} + #[derive(Debug, serde::Deserialize, serde::Serialize)] pub struct ExecParams { /// The bash string to execute. @@ -70,12 +86,12 @@ impl EscalateServer { &self, params: ExecParams, cancel_rx: CancellationToken, - sandbox_state: &SandboxState, + command_executor: &dyn ShellCommandExecutor, ) -> anyhow::Result { let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?; let client_socket = escalate_client.into_inner(); + // Only the client endpoint should cross exec into the wrapper process. client_socket.set_cloexec(false)?; - let escalate_task = tokio::spawn(escalate_task(escalate_server, self.policy.clone())); let mut env = std::env::vars().collect::>(); env.insert( @@ -91,47 +107,20 @@ impl EscalateServer { self.execve_wrapper.to_string_lossy().to_string(), ); - let ExecParams { - command, - workdir, - timeout_ms: _, - login, - } = params; - let result = codex_core::exec::process_exec_tool_call( - codex_core::exec::ExecParams { - command: vec![ - self.bash_path.to_string_lossy().to_string(), - if login == Some(false) { - "-c".to_string() - } else { - "-lc".to_string() - }, - command, - ], - cwd: PathBuf::from(&workdir), - expiration: codex_core::exec::ExecExpiration::Cancellation(cancel_rx), - env, - network: None, - sandbox_permissions: codex_core::sandboxing::SandboxPermissions::UseDefault, - windows_sandbox_level: codex_protocol::config_types::WindowsSandboxLevel::Disabled, - justification: None, - arg0: None, + let command = vec![ + self.bash_path.to_string_lossy().to_string(), + if params.login == Some(false) { + "-c".to_string() + } else { + "-lc".to_string() }, - &sandbox_state.sandbox_policy, - &sandbox_state.sandbox_cwd, - &sandbox_state.codex_linux_sandbox_exe, - sandbox_state.use_linux_sandbox_bwrap, - None, - ) - .await?; + params.command, + ]; + let result = command_executor + .run(command, PathBuf::from(¶ms.workdir), env, cancel_rx) + .await?; escalate_task.abort(); - - Ok(ExecResult { - exit_code: result.exit_code, - output: result.aggregated_output.text, - duration: result.duration, - timed_out: result.timed_out, - }) + Ok(result) } } @@ -144,12 +133,12 @@ pub trait EscalationPolicyFactory { pub async fn run_escalate_server( exec_params: ExecParams, - sandbox_state: &SandboxState, shell_program: impl AsRef, execve_wrapper: impl AsRef, policy: Arc>, escalation_policy_factory: impl EscalationPolicyFactory, effective_timeout: Duration, + command_executor: &dyn ShellCommandExecutor, ) -> anyhow::Result { let stopwatch = Stopwatch::new(effective_timeout); let cancel_token = stopwatch.cancellation_token(); @@ -160,7 +149,7 @@ pub async fn run_escalate_server( ); escalate_server - .exec(exec_params, cancel_token, sandbox_state) + .exec(exec_params, cancel_token, command_executor) .await } diff --git a/codex-rs/shell-escalation/src/unix/mod.rs b/codex-rs/shell-escalation/src/unix/mod.rs index 5bdd233d5..42091791f 100644 --- a/codex-rs/shell-escalation/src/unix/mod.rs +++ b/codex-rs/shell-escalation/src/unix/mod.rs @@ -53,6 +53,7 @@ //! | | //! o<-----x //! +pub mod core_shell_escalation; pub mod escalate_client; pub mod escalate_protocol; pub mod escalate_server; @@ -61,11 +62,14 @@ pub mod execve_wrapper; pub mod socket; pub mod stopwatch; +pub use self::core_shell_escalation::ShellActionProvider; +pub use self::core_shell_escalation::ShellPolicyFactory; pub use self::escalate_client::run; pub use self::escalate_protocol::EscalateAction; pub use self::escalate_server::EscalationPolicyFactory; pub use self::escalate_server::ExecParams; pub use self::escalate_server::ExecResult; +pub use self::escalate_server::ShellCommandExecutor; pub use self::escalate_server::run_escalate_server; pub use self::escalation_policy::EscalationPolicy; pub use self::execve_wrapper::main_execve_wrapper; diff --git a/codex-rs/shell-escalation/src/unix/socket.rs b/codex-rs/shell-escalation/src/unix/socket.rs index 35292367a..8325e940f 100644 --- a/codex-rs/shell-escalation/src/unix/socket.rs +++ b/codex-rs/shell-escalation/src/unix/socket.rs @@ -96,8 +96,8 @@ async fn read_frame_header( while filled < LENGTH_PREFIX_SIZE { let mut guard = async_socket.readable().await?; // The first read should come with a control message containing any FDs. - let result = if !captured_control { - guard.try_io(|inner| { + let read = if !captured_control { + match guard.try_io(|inner| { let mut bufs = [MaybeUninitSlice::new(&mut header[filled..])]; let (read, control_len) = { let mut msg = MsgHdrMut::new() @@ -109,16 +109,18 @@ async fn read_frame_header( control.truncate(control_len); captured_control = true; Ok(read) - }) + }) { + Ok(Ok(read)) => read, + Ok(Err(err)) => return Err(err), + Err(_would_block) => continue, + } } else { - guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..])) + match guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..])) { + Ok(Ok(read)) => read, + Ok(Err(err)) => return Err(err), + Err(_would_block) => continue, + } }; - let Ok(result) = result else { - // Would block, try again. - continue; - }; - - let read = result?; if read == 0 { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, @@ -150,12 +152,11 @@ async fn read_frame_payload( let mut filled = 0; while filled < message_len { let mut guard = async_socket.readable().await?; - let result = guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..])); - let Ok(result) = result else { - // Would block, try again. - continue; + let read = match guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..])) { + Ok(Ok(read)) => read, + Ok(Err(err)) => return Err(err), + Err(_would_block) => continue, }; - let read = result?; if read == 0 { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, @@ -261,7 +262,13 @@ impl AsyncSocket { } pub fn pair() -> std::io::Result<(AsyncSocket, AsyncSocket)> { - let (server, client) = Socket::pair(Domain::UNIX, Type::STREAM, None)?; + // `socket2::Socket::pair()` also applies "common flags" (including + // `SO_NOSIGPIPE` on Apple platforms), which can fail for AF_UNIX sockets. + // Use `pair_raw()` to avoid those side effects, then restore `CLOEXEC` + // explicitly on both endpoints. + let (server, client) = Socket::pair_raw(Domain::UNIX, Type::STREAM, None)?; + server.set_cloexec(true)?; + client.set_cloexec(true)?; Ok((AsyncSocket::new(server)?, AsyncSocket::new(client)?)) } @@ -314,11 +321,10 @@ async fn send_stream_frame( let mut include_fds = !fds.is_empty(); while written < frame.len() { let mut guard = socket.writable().await?; - let result = guard.try_io(|inner| { - send_stream_chunk(inner.get_ref(), &frame[written..], fds, include_fds) - }); - let bytes_written = match result { - Ok(bytes_written) => bytes_written?, + let bytes_written = match guard + .try_io(|inner| send_stream_chunk(inner.get_ref(), &frame[written..], fds, include_fds)) + { + Ok(result) => result?, Err(_would_block) => continue, }; if bytes_written == 0 { @@ -370,7 +376,13 @@ impl AsyncDatagramSocket { } pub fn pair() -> std::io::Result<(Self, Self)> { - let (server, client) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?; + // `socket2::Socket::pair()` also applies "common flags" (including + // `SO_NOSIGPIPE` on Apple platforms), which can fail for AF_UNIX sockets. + // Use `pair_raw()` to avoid those side effects, then restore `CLOEXEC` + // explicitly on both endpoints. + let (server, client) = Socket::pair_raw(Domain::UNIX, Type::DGRAM, None)?; + server.set_cloexec(true)?; + client.set_cloexec(true)?; Ok((Self::new(server)?, Self::new(client)?)) } @@ -472,7 +484,7 @@ mod tests { #[test] fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> { - let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?; + let (socket, _peer) = Socket::pair_raw(Domain::UNIX, Type::DGRAM, None)?; let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?; let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err(); assert_eq!(std::io::ErrorKind::InvalidInput, err.kind()); @@ -481,7 +493,7 @@ mod tests { #[test] fn send_stream_chunk_rejects_excessive_fd_counts() -> std::io::Result<()> { - let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?; + let (socket, _peer) = Socket::pair_raw(Domain::UNIX, Type::STREAM, None)?; let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?; let err = send_stream_chunk(&socket, b"hello", &fds, true).unwrap_err(); assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());