refactor: decouple shell-escalation from codex-core (#12638)

## Why

After removing `exec-server`, the next step is to wire a new shell tool
to `codex-rs/shell-escalation` directly.

That is blocked while `codex-shell-escalation` depends on `codex-core`,
because the new integration would require `codex-core` to depend on
`codex-shell-escalation` and create a dependency cycle.

This change ports the reusable pieces from the earlier prep work, but
drops the old compatibility shim because `exec-server`/MCP support is
already gone.

## What Changed

### Decouple `shell-escalation` from `codex-core`

- Introduce a crate-local `SandboxState` in `shell-escalation`
- Introduce a `ShellCommandExecutor` trait so callers provide process
execution/sandbox integration
- Update `EscalateServer::exec(...)` and `run_escalate_server(...)` to
use the injected executor
- Remove the direct `codex_core::exec::process_exec_tool_call(...)` call
from `shell-escalation`
- Remove the `codex-core` dependency from `codex-shell-escalation`

### Restore reusable policy adapter exports

- Re-enable `unix::core_shell_escalation`
- Export `ShellActionProvider` and `ShellPolicyFactory` from
`shell-escalation`
- Keep the crate root API simple (no `legacy_api` compatibility layer)

### Port socket fixes from the earlier prep commit

- Use `socket2::Socket::pair_raw(...)` for AF_UNIX socketpairs and
restore `CLOEXEC` explicitly on both endpoints
- Keep `CLOEXEC` cleared only on the single datagram client FD that is
intentionally passed across `exec`
- Clean up `tokio::io::unix::AsyncFdReadyGuard::try_io(...)` error handling in the
socket helpers

## Verification

- `cargo shear`
- `cargo clippy -p codex-shell-escalation --tests`
- `cargo test -p codex-shell-escalation`
This commit is contained in:
Michael Bolin 2026-02-23 20:58:24 -08:00 committed by GitHub
parent 38f84b6b29
commit af215eb390
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 87 additions and 75 deletions

2
codex-rs/Cargo.lock generated
View file

@ -2238,9 +2238,7 @@ dependencies = [
"anyhow",
"async-trait",
"clap",
"codex-core",
"codex-execpolicy",
"codex-protocol",
"libc",
"path-absolutize",
"pretty_assertions",

View file

@ -12,20 +12,20 @@ path = "src/bin/main_execve_wrapper.rs"
anyhow = { workspace = true }
async-trait = { workspace = true }
clap = { workspace = true, features = ["derive"] }
codex-core = { workspace = true }
codex-execpolicy = { workspace = true }
codex-protocol = { workspace = true }
libc = { workspace = true }
serde_json = { workspace = true }
path-absolutize = { workspace = true }
serde = { workspace = true, features = ["derive"] }
socket2 = { workspace = true }
socket2 = { workspace = true, features = ["all"] }
tokio = { workspace = true, features = [
"io-std",
"net",
"macros",
"process",
"rt-multi-thread",
"signal",
"time",
] }
tokio-util = { workspace = true }
tracing = { workspace = true }

View file

@ -12,6 +12,12 @@ pub use unix::ExecParams;
#[cfg(unix)]
pub use unix::ExecResult;
#[cfg(unix)]
pub use unix::ShellActionProvider;
#[cfg(unix)]
pub use unix::ShellCommandExecutor;
#[cfg(unix)]
pub use unix::ShellPolicyFactory;
#[cfg(unix)]
pub use unix::Stopwatch;
#[cfg(unix)]
pub use unix::main_execve_wrapper;

View file

@ -3,11 +3,11 @@ use std::path::Path;
use std::sync::Arc;
use tokio::sync::RwLock;
use codex_execpolicy::Policy;
use super::escalate_protocol::EscalateAction;
use super::escalate_server::EscalationPolicyFactory;
use super::escalation_policy::EscalationPolicy;
use super::stopwatch::Stopwatch;
use codex_execpolicy::Policy;
#[async_trait]
pub trait ShellActionProvider: Send + Sync {
@ -40,7 +40,10 @@ impl ShellPolicyFactory {
}
}
struct ShellEscalationPolicy {
/// Public only because it is the associated `Policy` type in the public
/// `EscalationPolicyFactory` impl for `ShellPolicyFactory`.
#[doc(hidden)]
pub struct ShellEscalationPolicy {
provider: Arc<dyn ShellActionProvider>,
stopwatch: Stopwatch,
}

View file

@ -7,7 +7,6 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::Context as _;
use codex_core::SandboxState;
use codex_execpolicy::Policy;
use path_absolutize::Absolutize as _;
use tokio::process::Command;
@ -27,6 +26,23 @@ use crate::unix::socket::AsyncDatagramSocket;
use crate::unix::socket::AsyncSocket;
use crate::unix::stopwatch::Stopwatch;
/// Adapter for running the shell command after the escalation server has been set up.
///
/// This lets `shell-escalation` own the Unix escalation protocol while the caller
/// keeps control over process spawning, output capture, and sandbox integration.
/// Implementations can capture any sandbox state they need.
///
/// `EscalateServer::exec` invokes [`ShellCommandExecutor::run`] once per call, after
/// it has created the escalation datagram socket pair and spawned the escalation
/// task. The value returned by `run` is handed back to the caller of `exec`
/// unchanged.
///
/// The trait is taken as `&dyn ShellCommandExecutor`, and the `Send + Sync`
/// supertraits are part of its contract.
///
/// NOTE(review): `exec` clears `CLOEXEC` on the client end of the escalation socket
/// before calling `run`, with a comment that this endpoint should cross `exec` into
/// the wrapper process. Presumably an implementation must therefore let the spawned
/// process inherit that descriptor; confirm against the concrete executor.
#[async_trait::async_trait]
pub trait ShellCommandExecutor: Send + Sync {
/// Runs the requested shell command and returns the captured result.
///
/// # Arguments
///
/// * `command` - Full argv to spawn. As built by `EscalateServer::exec` this is
///   `[shell path, "-c" or "-lc", script]`; `-c` is used only when the request's
///   `login` field is `Some(false)`.
/// * `cwd` - Working directory for the command, built from the request's
///   `workdir` via `PathBuf::from`.
/// * `env` - Environment for the command. `EscalateServer::exec` starts from
///   `std::env::vars()` and inserts additional entries before passing it here.
/// * `cancel_rx` - Cancellation token for the run. In `run_escalate_server` it
///   comes from a `Stopwatch` created with the effective timeout.
///   NOTE(review): implementations are presumably expected to stop the command
///   when this token is cancelled; confirm.
///
/// # Errors
///
/// Any error returned here is propagated by `EscalateServer::exec` with `?`.
async fn run(
&self,
command: Vec<String>,
cwd: PathBuf,
env: HashMap<String, String>,
cancel_rx: CancellationToken,
) -> anyhow::Result<ExecResult>;
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
pub struct ExecParams {
/// The bash string to execute.
@ -70,12 +86,12 @@ impl EscalateServer {
&self,
params: ExecParams,
cancel_rx: CancellationToken,
sandbox_state: &SandboxState,
command_executor: &dyn ShellCommandExecutor,
) -> anyhow::Result<ExecResult> {
let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?;
let client_socket = escalate_client.into_inner();
// Only the client endpoint should cross exec into the wrapper process.
client_socket.set_cloexec(false)?;
let escalate_task = tokio::spawn(escalate_task(escalate_server, self.policy.clone()));
let mut env = std::env::vars().collect::<HashMap<String, String>>();
env.insert(
@ -91,47 +107,20 @@ impl EscalateServer {
self.execve_wrapper.to_string_lossy().to_string(),
);
let ExecParams {
command,
workdir,
timeout_ms: _,
login,
} = params;
let result = codex_core::exec::process_exec_tool_call(
codex_core::exec::ExecParams {
command: vec![
self.bash_path.to_string_lossy().to_string(),
if login == Some(false) {
"-c".to_string()
} else {
"-lc".to_string()
},
command,
],
cwd: PathBuf::from(&workdir),
expiration: codex_core::exec::ExecExpiration::Cancellation(cancel_rx),
env,
network: None,
sandbox_permissions: codex_core::sandboxing::SandboxPermissions::UseDefault,
windows_sandbox_level: codex_protocol::config_types::WindowsSandboxLevel::Disabled,
justification: None,
arg0: None,
let command = vec![
self.bash_path.to_string_lossy().to_string(),
if params.login == Some(false) {
"-c".to_string()
} else {
"-lc".to_string()
},
&sandbox_state.sandbox_policy,
&sandbox_state.sandbox_cwd,
&sandbox_state.codex_linux_sandbox_exe,
sandbox_state.use_linux_sandbox_bwrap,
None,
)
.await?;
params.command,
];
let result = command_executor
.run(command, PathBuf::from(&params.workdir), env, cancel_rx)
.await?;
escalate_task.abort();
Ok(ExecResult {
exit_code: result.exit_code,
output: result.aggregated_output.text,
duration: result.duration,
timed_out: result.timed_out,
})
Ok(result)
}
}
@ -144,12 +133,12 @@ pub trait EscalationPolicyFactory {
pub async fn run_escalate_server(
exec_params: ExecParams,
sandbox_state: &SandboxState,
shell_program: impl AsRef<Path>,
execve_wrapper: impl AsRef<Path>,
policy: Arc<RwLock<Policy>>,
escalation_policy_factory: impl EscalationPolicyFactory,
effective_timeout: Duration,
command_executor: &dyn ShellCommandExecutor,
) -> anyhow::Result<ExecResult> {
let stopwatch = Stopwatch::new(effective_timeout);
let cancel_token = stopwatch.cancellation_token();
@ -160,7 +149,7 @@ pub async fn run_escalate_server(
);
escalate_server
.exec(exec_params, cancel_token, sandbox_state)
.exec(exec_params, cancel_token, command_executor)
.await
}

View file

@ -53,6 +53,7 @@
//! | |
//! o<-----x
//!
pub mod core_shell_escalation;
pub mod escalate_client;
pub mod escalate_protocol;
pub mod escalate_server;
@ -61,11 +62,14 @@ pub mod execve_wrapper;
pub mod socket;
pub mod stopwatch;
pub use self::core_shell_escalation::ShellActionProvider;
pub use self::core_shell_escalation::ShellPolicyFactory;
pub use self::escalate_client::run;
pub use self::escalate_protocol::EscalateAction;
pub use self::escalate_server::EscalationPolicyFactory;
pub use self::escalate_server::ExecParams;
pub use self::escalate_server::ExecResult;
pub use self::escalate_server::ShellCommandExecutor;
pub use self::escalate_server::run_escalate_server;
pub use self::escalation_policy::EscalationPolicy;
pub use self::execve_wrapper::main_execve_wrapper;

View file

@ -96,8 +96,8 @@ async fn read_frame_header(
while filled < LENGTH_PREFIX_SIZE {
let mut guard = async_socket.readable().await?;
// The first read should come with a control message containing any FDs.
let result = if !captured_control {
guard.try_io(|inner| {
let read = if !captured_control {
match guard.try_io(|inner| {
let mut bufs = [MaybeUninitSlice::new(&mut header[filled..])];
let (read, control_len) = {
let mut msg = MsgHdrMut::new()
@ -109,16 +109,18 @@ async fn read_frame_header(
control.truncate(control_len);
captured_control = true;
Ok(read)
})
}) {
Ok(Ok(read)) => read,
Ok(Err(err)) => return Err(err),
Err(_would_block) => continue,
}
} else {
guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..]))
match guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..])) {
Ok(Ok(read)) => read,
Ok(Err(err)) => return Err(err),
Err(_would_block) => continue,
}
};
let Ok(result) = result else {
// Would block, try again.
continue;
};
let read = result?;
if read == 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
@ -150,12 +152,11 @@ async fn read_frame_payload(
let mut filled = 0;
while filled < message_len {
let mut guard = async_socket.readable().await?;
let result = guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..]));
let Ok(result) = result else {
// Would block, try again.
continue;
let read = match guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..])) {
Ok(Ok(read)) => read,
Ok(Err(err)) => return Err(err),
Err(_would_block) => continue,
};
let read = result?;
if read == 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
@ -261,7 +262,13 @@ impl AsyncSocket {
}
pub fn pair() -> std::io::Result<(AsyncSocket, AsyncSocket)> {
let (server, client) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
// `socket2::Socket::pair()` also applies "common flags" (including
// `SO_NOSIGPIPE` on Apple platforms), which can fail for AF_UNIX sockets.
// Use `pair_raw()` to avoid those side effects, then restore `CLOEXEC`
// explicitly on both endpoints.
let (server, client) = Socket::pair_raw(Domain::UNIX, Type::STREAM, None)?;
server.set_cloexec(true)?;
client.set_cloexec(true)?;
Ok((AsyncSocket::new(server)?, AsyncSocket::new(client)?))
}
@ -314,11 +321,10 @@ async fn send_stream_frame(
let mut include_fds = !fds.is_empty();
while written < frame.len() {
let mut guard = socket.writable().await?;
let result = guard.try_io(|inner| {
send_stream_chunk(inner.get_ref(), &frame[written..], fds, include_fds)
});
let bytes_written = match result {
Ok(bytes_written) => bytes_written?,
let bytes_written = match guard
.try_io(|inner| send_stream_chunk(inner.get_ref(), &frame[written..], fds, include_fds))
{
Ok(result) => result?,
Err(_would_block) => continue,
};
if bytes_written == 0 {
@ -370,7 +376,13 @@ impl AsyncDatagramSocket {
}
pub fn pair() -> std::io::Result<(Self, Self)> {
let (server, client) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
// `socket2::Socket::pair()` also applies "common flags" (including
// `SO_NOSIGPIPE` on Apple platforms), which can fail for AF_UNIX sockets.
// Use `pair_raw()` to avoid those side effects, then restore `CLOEXEC`
// explicitly on both endpoints.
let (server, client) = Socket::pair_raw(Domain::UNIX, Type::DGRAM, None)?;
server.set_cloexec(true)?;
client.set_cloexec(true)?;
Ok((Self::new(server)?, Self::new(client)?))
}
@ -472,7 +484,7 @@ mod tests {
#[test]
fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
let (socket, _peer) = Socket::pair_raw(Domain::UNIX, Type::DGRAM, None)?;
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err();
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
@ -481,7 +493,7 @@ mod tests {
#[test]
fn send_stream_chunk_rejects_excessive_fd_counts() -> std::io::Result<()> {
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
let (socket, _peer) = Socket::pair_raw(Domain::UNIX, Type::STREAM, None)?;
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
let err = send_stream_chunk(&socket, b"hello", &fds, true).unwrap_err();
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());