fix: make EscalateServer public and remove shell escalation wrappers (#12724)
## Why `codex-shell-escalation` exposed a `codex-core`-specific adapter layer (`ShellActionProvider`, `ShellPolicyFactory`, and `run_escalate_server`) that existed only to bridge `codex-core` to `EscalateServer`. That indirection increased API surface and obscured crate ownership without adding behavior. This change moves orchestration into `codex-core` so boundaries are clearer: `codex-shell-escalation` provides reusable escalation primitives, and `codex-core` provides shell-tool policy decisions. Admittedly, @pakrym rightfully requested this sort of cleanup as part of https://github.com/openai/codex/pull/12649, though this avoids moving all of `codex-shell-escalation` into `codex-core`. ## What changed - Made `EscalateServer` public and exported it from `shell-escalation`. - Removed the adapter layer from `shell-escalation`: - deleted `shell-escalation/src/unix/core_shell_escalation.rs` - removed exports for `ShellActionProvider`, `ShellPolicyFactory`, `EscalationPolicyFactory`, and `run_escalate_server` - Updated `core/src/tools/runtimes/shell/unix_escalation.rs` to: - create `Stopwatch`/cancellation in `codex-core` - instantiate `EscalateServer` directly - implement `EscalationPolicy` directly on `CoreShellActionProvider` Net effect: same escalation flow with fewer wrappers and a smaller public API. ## Verification - Manually reviewed the old vs. new escalation call flow to confirm timeout/cancellation behavior and approval policy decisions are preserved while removing wrapper types.
This commit is contained in:
parent
8da40c9251
commit
3d356723c4
7 changed files with 47 additions and 157 deletions
1
codex-rs/Cargo.lock
generated
1
codex-rs/Cargo.lock
generated
|
|
@ -2241,7 +2241,6 @@ dependencies = [
|
|||
"anyhow",
|
||||
"async-trait",
|
||||
"clap",
|
||||
"codex-execpolicy",
|
||||
"libc",
|
||||
"path-absolutize",
|
||||
"pretty_assertions",
|
||||
|
|
|
|||
|
|
@ -24,13 +24,12 @@ use codex_protocol::protocol::SandboxPolicy;
|
|||
use codex_shell_command::bash::parse_shell_lc_plain_commands;
|
||||
use codex_shell_command::bash::parse_shell_lc_single_command_prefix;
|
||||
use codex_shell_escalation::EscalateAction;
|
||||
use codex_shell_escalation::EscalateServer;
|
||||
use codex_shell_escalation::EscalationPolicy;
|
||||
use codex_shell_escalation::ExecParams;
|
||||
use codex_shell_escalation::ExecResult;
|
||||
use codex_shell_escalation::ShellActionProvider;
|
||||
use codex_shell_escalation::ShellCommandExecutor;
|
||||
use codex_shell_escalation::ShellPolicyFactory;
|
||||
use codex_shell_escalation::Stopwatch;
|
||||
use codex_shell_escalation::run_escalate_server;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
|
@ -107,30 +106,39 @@ pub(super) async fn try_run_zsh_fork(
|
|||
justification,
|
||||
arg0,
|
||||
};
|
||||
let exec_result = run_escalate_server(
|
||||
ExecParams {
|
||||
command: script,
|
||||
workdir: req.cwd.to_string_lossy().to_string(),
|
||||
timeout_ms: Some(effective_timeout.as_millis() as u64),
|
||||
login: Some(login),
|
||||
},
|
||||
shell_zsh_path.clone(),
|
||||
shell_execve_wrapper().map_err(|err| ToolError::Rejected(format!("{err}")))?,
|
||||
exec_policy.clone(),
|
||||
ShellPolicyFactory::new(CoreShellActionProvider {
|
||||
policy: Arc::clone(&exec_policy),
|
||||
session: Arc::clone(&ctx.session),
|
||||
turn: Arc::clone(&ctx.turn),
|
||||
call_id: ctx.call_id.clone(),
|
||||
approval_policy: ctx.turn.approval_policy.value(),
|
||||
sandbox_policy: attempt.policy.clone(),
|
||||
sandbox_permissions: req.sandbox_permissions,
|
||||
}),
|
||||
effective_timeout,
|
||||
&command_executor,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| ToolError::Rejected(err.to_string()))?;
|
||||
|
||||
let exec_params = ExecParams {
|
||||
command: script,
|
||||
workdir: req.cwd.to_string_lossy().to_string(),
|
||||
timeout_ms: Some(effective_timeout.as_millis() as u64),
|
||||
login: Some(login),
|
||||
};
|
||||
let execve_wrapper =
|
||||
shell_execve_wrapper().map_err(|err| ToolError::Rejected(format!("{err}")))?;
|
||||
|
||||
// Note that Stopwatch starts immediately upon creation, so currently we try
|
||||
// to minimize the time between creating the Stopwatch and starting the
|
||||
// escalation server.
|
||||
let stopwatch = Stopwatch::new(effective_timeout);
|
||||
let cancel_token = stopwatch.cancellation_token();
|
||||
let escalation_policy = CoreShellActionProvider {
|
||||
policy: Arc::clone(&exec_policy),
|
||||
session: Arc::clone(&ctx.session),
|
||||
turn: Arc::clone(&ctx.turn),
|
||||
call_id: ctx.call_id.clone(),
|
||||
approval_policy: ctx.turn.approval_policy.value(),
|
||||
sandbox_policy: attempt.policy.clone(),
|
||||
sandbox_permissions: req.sandbox_permissions,
|
||||
stopwatch: stopwatch.clone(),
|
||||
};
|
||||
|
||||
let escalate_server =
|
||||
EscalateServer::new(shell_zsh_path.clone(), execve_wrapper, escalation_policy);
|
||||
|
||||
let exec_result = escalate_server
|
||||
.exec(exec_params, cancel_token, &command_executor)
|
||||
.await
|
||||
.map_err(|err| ToolError::Rejected(err.to_string()))?;
|
||||
|
||||
map_exec_result(attempt.sandbox, exec_result).map(Some)
|
||||
}
|
||||
|
|
@ -143,6 +151,7 @@ struct CoreShellActionProvider {
|
|||
approval_policy: AskForApproval,
|
||||
sandbox_policy: SandboxPolicy,
|
||||
sandbox_permissions: SandboxPermissions,
|
||||
stopwatch: Stopwatch,
|
||||
}
|
||||
|
||||
impl CoreShellActionProvider {
|
||||
|
|
@ -186,13 +195,12 @@ impl CoreShellActionProvider {
|
|||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ShellActionProvider for CoreShellActionProvider {
|
||||
impl EscalationPolicy for CoreShellActionProvider {
|
||||
async fn determine_action(
|
||||
&self,
|
||||
file: &Path,
|
||||
argv: &[String],
|
||||
workdir: &Path,
|
||||
stopwatch: &Stopwatch,
|
||||
) -> anyhow::Result<EscalateAction> {
|
||||
let command = std::iter::once(file.to_string_lossy().to_string())
|
||||
.chain(argv.iter().cloned())
|
||||
|
|
@ -240,7 +248,7 @@ impl ShellActionProvider for CoreShellActionProvider {
|
|||
reason: Some("Execution forbidden by policy".to_string()),
|
||||
}
|
||||
} else {
|
||||
match self.prompt(&command, workdir, stopwatch).await? {
|
||||
match self.prompt(&command, workdir, &self.stopwatch).await? {
|
||||
ReviewDecision::Approved
|
||||
| ReviewDecision::ApprovedExecpolicyAmendment { .. }
|
||||
| ReviewDecision::ApprovedForSession => {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
[package]
|
||||
name = "codex-shell-escalation"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
name = "codex-shell-escalation"
|
||||
version.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "codex-execve-wrapper"
|
||||
|
|
@ -12,11 +12,10 @@ path = "src/bin/main_execve_wrapper.rs"
|
|||
anyhow = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
codex-execpolicy = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
path-absolutize = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
socket2 = { workspace = true, features = ["all"] }
|
||||
tokio = { workspace = true, features = [
|
||||
"io-std",
|
||||
|
|
|
|||
|
|
@ -4,24 +4,18 @@ mod unix;
|
|||
#[cfg(unix)]
|
||||
pub use unix::EscalateAction;
|
||||
#[cfg(unix)]
|
||||
pub use unix::EscalationPolicy;
|
||||
pub use unix::EscalateServer;
|
||||
#[cfg(unix)]
|
||||
pub use unix::EscalationPolicyFactory;
|
||||
pub use unix::EscalationPolicy;
|
||||
#[cfg(unix)]
|
||||
pub use unix::ExecParams;
|
||||
#[cfg(unix)]
|
||||
pub use unix::ExecResult;
|
||||
#[cfg(unix)]
|
||||
pub use unix::ShellActionProvider;
|
||||
#[cfg(unix)]
|
||||
pub use unix::ShellCommandExecutor;
|
||||
#[cfg(unix)]
|
||||
pub use unix::ShellPolicyFactory;
|
||||
#[cfg(unix)]
|
||||
pub use unix::Stopwatch;
|
||||
#[cfg(unix)]
|
||||
pub use unix::main_execve_wrapper;
|
||||
#[cfg(unix)]
|
||||
pub use unix::run_escalate_server;
|
||||
#[cfg(unix)]
|
||||
pub use unix::run_shell_escalation_execve_wrapper;
|
||||
|
|
|
|||
|
|
@ -1,73 +0,0 @@
|
|||
use async_trait::async_trait;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use super::escalate_protocol::EscalateAction;
|
||||
use super::escalate_server::EscalationPolicyFactory;
|
||||
use super::escalation_policy::EscalationPolicy;
|
||||
use super::stopwatch::Stopwatch;
|
||||
use codex_execpolicy::Policy;
|
||||
|
||||
#[async_trait]
|
||||
pub trait ShellActionProvider: Send + Sync {
|
||||
async fn determine_action(
|
||||
&self,
|
||||
file: &Path,
|
||||
argv: &[String],
|
||||
workdir: &Path,
|
||||
stopwatch: &Stopwatch,
|
||||
) -> anyhow::Result<EscalateAction>;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ShellPolicyFactory {
|
||||
provider: Arc<dyn ShellActionProvider>,
|
||||
}
|
||||
|
||||
impl ShellPolicyFactory {
|
||||
pub fn new<P>(provider: P) -> Self
|
||||
where
|
||||
P: ShellActionProvider + 'static,
|
||||
{
|
||||
Self {
|
||||
provider: Arc::new(provider),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_provider(provider: Arc<dyn ShellActionProvider>) -> Self {
|
||||
Self { provider }
|
||||
}
|
||||
}
|
||||
|
||||
/// Public only because it is the associated `Policy` type in the public
|
||||
/// `EscalationPolicyFactory` impl for `ShellPolicyFactory`.
|
||||
pub struct ShellEscalationPolicy {
|
||||
provider: Arc<dyn ShellActionProvider>,
|
||||
stopwatch: Stopwatch,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EscalationPolicy for ShellEscalationPolicy {
|
||||
async fn determine_action(
|
||||
&self,
|
||||
file: &Path,
|
||||
argv: &[String],
|
||||
workdir: &Path,
|
||||
) -> anyhow::Result<EscalateAction> {
|
||||
self.provider
|
||||
.determine_action(file, argv, workdir, &self.stopwatch)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl EscalationPolicyFactory for ShellPolicyFactory {
|
||||
type Policy = ShellEscalationPolicy;
|
||||
|
||||
fn create_policy(&self, _policy: Arc<RwLock<Policy>>, stopwatch: Stopwatch) -> Self::Policy {
|
||||
ShellEscalationPolicy {
|
||||
provider: Arc::clone(&self.provider),
|
||||
stopwatch,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,16 +1,13 @@
|
|||
use std::collections::HashMap;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use codex_execpolicy::Policy;
|
||||
use path_absolutize::Absolutize as _;
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::unix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
|
||||
|
|
@ -24,7 +21,6 @@ use crate::unix::escalate_protocol::SuperExecResult;
|
|||
use crate::unix::escalation_policy::EscalationPolicy;
|
||||
use crate::unix::socket::AsyncDatagramSocket;
|
||||
use crate::unix::socket::AsyncSocket;
|
||||
use crate::unix::stopwatch::Stopwatch;
|
||||
|
||||
/// Adapter for running the shell command after the escalation server has been set up.
|
||||
///
|
||||
|
|
@ -66,14 +62,14 @@ pub struct ExecResult {
|
|||
pub timed_out: bool,
|
||||
}
|
||||
|
||||
struct EscalateServer {
|
||||
pub struct EscalateServer {
|
||||
bash_path: PathBuf,
|
||||
execve_wrapper: PathBuf,
|
||||
policy: Arc<dyn EscalationPolicy>,
|
||||
}
|
||||
|
||||
impl EscalateServer {
|
||||
fn new<P>(bash_path: PathBuf, execve_wrapper: PathBuf, policy: P) -> Self
|
||||
pub fn new<P>(bash_path: PathBuf, execve_wrapper: PathBuf, policy: P) -> Self
|
||||
where
|
||||
P: EscalationPolicy + Send + Sync + 'static,
|
||||
{
|
||||
|
|
@ -84,7 +80,7 @@ impl EscalateServer {
|
|||
}
|
||||
}
|
||||
|
||||
async fn exec(
|
||||
pub async fn exec(
|
||||
&self,
|
||||
params: ExecParams,
|
||||
cancel_rx: CancellationToken,
|
||||
|
|
@ -126,35 +122,6 @@ impl EscalateServer {
|
|||
}
|
||||
}
|
||||
|
||||
/// Factory for creating escalation policy instances for a single shell run.
|
||||
pub trait EscalationPolicyFactory {
|
||||
type Policy: EscalationPolicy + Send + Sync + 'static;
|
||||
|
||||
fn create_policy(&self, policy: Arc<RwLock<Policy>>, stopwatch: Stopwatch) -> Self::Policy;
|
||||
}
|
||||
|
||||
pub async fn run_escalate_server(
|
||||
exec_params: ExecParams,
|
||||
shell_program: impl AsRef<Path>,
|
||||
execve_wrapper: impl AsRef<Path>,
|
||||
policy: Arc<RwLock<Policy>>,
|
||||
escalation_policy_factory: impl EscalationPolicyFactory,
|
||||
effective_timeout: Duration,
|
||||
command_executor: &dyn ShellCommandExecutor,
|
||||
) -> anyhow::Result<ExecResult> {
|
||||
let stopwatch = Stopwatch::new(effective_timeout);
|
||||
let cancel_token = stopwatch.cancellation_token();
|
||||
let escalate_server = EscalateServer::new(
|
||||
shell_program.as_ref().to_path_buf(),
|
||||
execve_wrapper.as_ref().to_path_buf(),
|
||||
escalation_policy_factory.create_policy(policy, stopwatch),
|
||||
);
|
||||
|
||||
escalate_server
|
||||
.exec(exec_params, cancel_token, command_executor)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn escalate_task(
|
||||
socket: AsyncDatagramSocket,
|
||||
policy: Arc<dyn EscalationPolicy>,
|
||||
|
|
|
|||
|
|
@ -53,7 +53,6 @@
|
|||
//! | |
|
||||
//! o<-----x
|
||||
//!
|
||||
pub mod core_shell_escalation;
|
||||
pub mod escalate_client;
|
||||
pub mod escalate_protocol;
|
||||
pub mod escalate_server;
|
||||
|
|
@ -62,15 +61,12 @@ pub mod execve_wrapper;
|
|||
pub mod socket;
|
||||
pub mod stopwatch;
|
||||
|
||||
pub use self::core_shell_escalation::ShellActionProvider;
|
||||
pub use self::core_shell_escalation::ShellPolicyFactory;
|
||||
pub use self::escalate_client::run_shell_escalation_execve_wrapper;
|
||||
pub use self::escalate_protocol::EscalateAction;
|
||||
pub use self::escalate_server::EscalationPolicyFactory;
|
||||
pub use self::escalate_server::EscalateServer;
|
||||
pub use self::escalate_server::ExecParams;
|
||||
pub use self::escalate_server::ExecResult;
|
||||
pub use self::escalate_server::ShellCommandExecutor;
|
||||
pub use self::escalate_server::run_escalate_server;
|
||||
pub use self::escalation_policy::EscalationPolicy;
|
||||
pub use self::execve_wrapper::main_execve_wrapper;
|
||||
pub use self::stopwatch::Stopwatch;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue