diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 31f12774a..a7ba2a4c6 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1182,6 +1182,26 @@ dependencies = [ "wiremock", ] +[[package]] +name = "codex-exec-server" +version = "0.0.0" +dependencies = [ + "anyhow", + "clap", + "codex-core", + "libc", + "path-absolutize", + "pretty_assertions", + "rmcp", + "serde", + "serde_json", + "socket2 0.6.0", + "tempfile", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "codex-execpolicy" version = "0.0.0" diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index b19bf7660..2e88aba71 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -16,6 +16,7 @@ members = [ "common", "core", "exec", + "exec-server", "execpolicy", "execpolicy2", "keyring-store", @@ -176,6 +177,7 @@ sha1 = "0.10.6" sha2 = "0.10" shlex = "1.3.0" similar = "2.7.0" +socket2 = "0.6.0" starlark = "0.13.0" strum = "0.27.2" strum_macros = "0.27.2" diff --git a/codex-rs/exec-server/Cargo.toml b/codex-rs/exec-server/Cargo.toml new file mode 100644 index 000000000..24ee460f5 --- /dev/null +++ b/codex-rs/exec-server/Cargo.toml @@ -0,0 +1,47 @@ +[package] +edition = "2024" +name = "codex-exec-server" +version = { workspace = true } + +[[bin]] +name = "codex-exec-server" +path = "src/main.rs" + +[lints] +workspace = true + +[dependencies] +anyhow = { workspace = true } +clap = { workspace = true, features = ["derive"] } +codex-core = { workspace = true } +libc = { workspace = true } +path-absolutize = { workspace = true } +rmcp = { workspace = true, default-features = false, features = [ + "auth", + "elicitation", + "base64", + "client", + "macros", + "schemars", + "server", + "transport-child-process", + "transport-streamable-http-client-reqwest", + "transport-streamable-http-server", + "transport-io", +] } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +socket2 = { workspace = true } +tokio = { workspace = true, features = [ + "io-std", + "macros", + "process", + "rt-multi-thread", + "signal", +] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] } + +[dev-dependencies] +pretty_assertions = { workspace = true } +tempfile = { workspace = true } diff --git a/codex-rs/exec-server/src/main.rs b/codex-rs/exec-server/src/main.rs new file mode 100644 index 000000000..23a18b252 --- /dev/null +++ b/codex-rs/exec-server/src/main.rs @@ -0,0 +1,11 @@ +#[cfg(target_os = "windows")] +fn main() { + eprintln!("codex-exec-server is not implemented on Windows targets"); + std::process::exit(1); +} + +#[cfg(not(target_os = "windows"))] +mod posix; + +#[cfg(not(target_os = "windows"))] +pub use posix::main; diff --git a/codex-rs/exec-server/src/posix.rs b/codex-rs/exec-server/src/posix.rs new file mode 100644 index 000000000..179fc7a6f --- /dev/null +++ b/codex-rs/exec-server/src/posix.rs @@ -0,0 +1,164 @@ +//! This is an MCP that implements an alternative `shell` tool with fine-grained privilege +//! escalation based on a per-exec() policy. +//! +//! We spawn Bash process inside a sandbox. The Bash we spawn is patched to allow us to intercept +//! every exec() call it makes by invoking a wrapper program and passing in the arguments it would +//! have passed to exec(). The Bash process (and its descendants) inherit a communication socket +//! from us, and we give its fd number in the CODEX_ESCALATE_SOCKET environment variable. +//! +//! When we intercept an exec() call, we send a message over the socket back to the main +//! MCP process. The MCP process can then decide whether to allow the exec() call to proceed +//! or to escalate privileges and run the requested command with elevated permissions. In the +//! latter case, we send a message back to the child requesting that it forward its open FDs to us. +//! We then execute the requested command on its behalf, patching in the forwarded FDs. +//! +//! +//! ### The privilege escalation flow +//! +//! Child MCP Bash Escalate Helper +//! | +//! o----->o +//! | | +//! | o--(exec)-->o +//! | | | +//! |o<-(EscalateReq)--o +//! || | | +//! |o--(Escalate)---->o +//! || | | +//! |o<---------(fds)--o +//! || | | +//! o<-----o | | +//! | || | | +//! x----->o | | +//! || | | +//! |x--(exit code)--->o +//! | | | +//! | o<--(exit)--x +//! | | +//! o<-----x +//! +//! ### The non-escalation flow +//! +//! MCP Bash Escalate Helper Child +//! | +//! o----->o +//! | | +//! | o--(exec)-->o +//! | | | +//! |o<-(EscalateReq)--o +//! || | | +//! |o-(Run)---------->o +//! | | | +//! | | x--(exec)-->o +//! | | | +//! | o<--------------(exit)--x +//! | | +//! o<-----x +//! +use std::path::Path; + +use clap::Parser; +use clap::Subcommand; +use tracing_subscriber::EnvFilter; +use tracing_subscriber::{self}; + +use crate::posix::escalate_protocol::EscalateAction; +use crate::posix::escalate_server::EscalateServer; + +mod escalate_client; +mod escalate_protocol; +mod escalate_server; +mod mcp; +mod socket; + +fn dummy_exec_policy(file: &Path, argv: &[String], _workdir: &Path) -> EscalateAction { + // TODO: execpolicy + if file == Path::new("/opt/homebrew/bin/gh") + && let [_, arg1, arg2, ..] = argv + && arg1 == "issue" + && arg2 == "list" + { + return EscalateAction::Escalate; + } + EscalateAction::Run +} + +#[derive(Parser)] +#[command(version)] +pub struct Cli { + #[command(subcommand)] + subcommand: Option, +} + +#[derive(Subcommand)] +enum Commands { + Escalate(EscalateArgs), + ShellExec(ShellExecArgs), +} + +/// Invoked from within the sandbox to (potentially) escalate permissions. +#[derive(Parser, Debug)] +struct EscalateArgs { + file: String, + + #[arg(trailing_var_arg = true)] + argv: Vec, +} + +impl EscalateArgs { + /// This is the escalate client. It talks to the escalate server to determine whether to exec() + /// the command directly or to proxy to the escalate server. + async fn run(self) -> anyhow::Result { + let EscalateArgs { file, argv } = self; + escalate_client::run(file, argv).await + } +} + +/// Debugging command to emulate an MCP "shell" tool call. +#[derive(Parser, Debug)] +struct ShellExecArgs { + command: String, +} + +#[tokio::main] +pub async fn main() -> anyhow::Result<()> { + let cli = Cli::parse(); + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .with_writer(std::io::stderr) + .with_ansi(false) + .init(); + + match cli.subcommand { + Some(Commands::Escalate(args)) => { + std::process::exit(args.run().await?); + } + Some(Commands::ShellExec(args)) => { + let bash_path = mcp::get_bash_path()?; + let escalate_server = EscalateServer::new(bash_path, dummy_exec_policy); + let result = escalate_server + .exec( + args.command.clone(), + std::env::vars().collect(), + std::env::current_dir()?, + None, + ) + .await?; + println!("{result:?}"); + std::process::exit(result.exit_code); + } + None => { + let bash_path = mcp::get_bash_path()?; + + tracing::info!("Starting MCP server"); + let service = mcp::serve(bash_path, dummy_exec_policy) + .await + .inspect_err(|e| { + tracing::error!("serving error: {:?}", e); + })?; + + service.waiting().await?; + Ok(()) + } + } +} diff --git a/codex-rs/exec-server/src/posix/escalate_client.rs b/codex-rs/exec-server/src/posix/escalate_client.rs new file mode 100644 index 000000000..2add1dade --- /dev/null +++ b/codex-rs/exec-server/src/posix/escalate_client.rs @@ -0,0 +1,102 @@ +use std::io; +use std::os::fd::AsRawFd; +use std::os::fd::FromRawFd as _; +use std::os::fd::OwnedFd; + +use anyhow::Context as _; + +use crate::posix::escalate_protocol::BASH_EXEC_WRAPPER_ENV_VAR; +use crate::posix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR; +use crate::posix::escalate_protocol::EscalateAction; +use crate::posix::escalate_protocol::EscalateRequest; +use crate::posix::escalate_protocol::EscalateResponse; +use crate::posix::escalate_protocol::SuperExecMessage; +use crate::posix::escalate_protocol::SuperExecResult; +use crate::posix::socket::AsyncDatagramSocket; +use crate::posix::socket::AsyncSocket; + +fn get_escalate_client() -> anyhow::Result { + // TODO: we should defensively require only calling this once, since AsyncSocket will take ownership of the fd. + let client_fd = std::env::var(ESCALATE_SOCKET_ENV_VAR)?.parse::()?; + if client_fd < 0 { + return Err(anyhow::anyhow!( + "{ESCALATE_SOCKET_ENV_VAR} is not a valid file descriptor: {client_fd}" + )); + } + Ok(unsafe { AsyncDatagramSocket::from_raw_fd(client_fd) }?) +} + +pub(crate) async fn run(file: String, argv: Vec) -> anyhow::Result { + let handshake_client = get_escalate_client()?; + let (server, client) = AsyncSocket::pair()?; + const HANDSHAKE_MESSAGE: [u8; 1] = [0]; + handshake_client + .send_with_fds(&HANDSHAKE_MESSAGE, &[server.into_inner().into()]) + .await + .context("failed to send handshake datagram")?; + let env = std::env::vars() + .filter(|(k, _)| { + !matches!( + k.as_str(), + ESCALATE_SOCKET_ENV_VAR | BASH_EXEC_WRAPPER_ENV_VAR + ) + }) + .collect(); + client + .send(EscalateRequest { + file: file.clone().into(), + argv: argv.clone(), + workdir: std::env::current_dir()?, + env, + }) + .await + .context("failed to send EscalateRequest")?; + let message = client.receive::().await?; + match message.action { + EscalateAction::Escalate => { + // TODO: maybe we should send ALL open FDs (except the escalate client)? + let fds_to_send = [ + unsafe { OwnedFd::from_raw_fd(io::stdin().as_raw_fd()) }, + unsafe { OwnedFd::from_raw_fd(io::stdout().as_raw_fd()) }, + unsafe { OwnedFd::from_raw_fd(io::stderr().as_raw_fd()) }, + ]; + + // TODO: also forward signals over the super-exec socket + + client + .send_with_fds( + SuperExecMessage { + fds: fds_to_send.iter().map(AsRawFd::as_raw_fd).collect(), + }, + &fds_to_send, + ) + .await + .context("failed to send SuperExecMessage")?; + let SuperExecResult { exit_code } = client.receive::().await?; + Ok(exit_code) + } + EscalateAction::Run => { + // We avoid std::process::Command here because we want to be as transparent as + // possible. std::os::unix::process::CommandExt has .exec() but it does some funky + // stuff with signal masks and dup2() on its standard FDs, which we don't want. + use std::ffi::CString; + let file = CString::new(file).context("NUL in file")?; + + let argv_cstrs: Vec = argv + .iter() + .map(|s| CString::new(s.as_str()).context("NUL in argv")) + .collect::, _>>()?; + + let mut argv: Vec<*const libc::c_char> = + argv_cstrs.iter().map(|s| s.as_ptr()).collect(); + argv.push(std::ptr::null()); + + let err = unsafe { + libc::execv(file.as_ptr(), argv.as_ptr()); + std::io::Error::last_os_error() + }; + + Err(err.into()) + } + } +} diff --git a/codex-rs/exec-server/src/posix/escalate_protocol.rs b/codex-rs/exec-server/src/posix/escalate_protocol.rs new file mode 100644 index 000000000..09d33fc98 --- /dev/null +++ b/codex-rs/exec-server/src/posix/escalate_protocol.rs @@ -0,0 +1,49 @@ +use std::collections::HashMap; +use std::os::fd::RawFd; +use std::path::PathBuf; + +use serde::Deserialize; +use serde::Serialize; + +/// 'exec-server escalate' reads this to find the inherited FD for the escalate socket. +pub(super) const ESCALATE_SOCKET_ENV_VAR: &str = "CODEX_ESCALATE_SOCKET"; + +/// The patched bash uses this to wrap exec() calls. +pub(super) const BASH_EXEC_WRAPPER_ENV_VAR: &str = "BASH_EXEC_WRAPPER"; + +/// The client sends this to the server to request an exec() call. +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)] +pub(super) struct EscalateRequest { + /// The absolute path to the executable to run, i.e. the first arg to exec. + pub(super) file: PathBuf, + /// The argv, including the program name (argv[0]). + pub(super) argv: Vec, + pub(super) workdir: PathBuf, + pub(super) env: HashMap, +} + +/// The server sends this to the client to respond to an exec() request. +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)] +pub(super) struct EscalateResponse { + pub(super) action: EscalateAction, +} + +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)] +pub(super) enum EscalateAction { + /// The command should be run directly by the client. + Run, + /// The command should be escalated to the server for execution. + Escalate, +} + +/// The client sends this to the server to forward its open FDs. +#[derive(Clone, Serialize, Deserialize, Debug)] +pub(super) struct SuperExecMessage { + pub(super) fds: Vec, +} + +/// The server responds when the exec()'d command has exited. +#[derive(Clone, Serialize, Deserialize, Debug)] +pub(super) struct SuperExecResult { + pub(super) exit_code: i32, +} diff --git a/codex-rs/exec-server/src/posix/escalate_server.rs b/codex-rs/exec-server/src/posix/escalate_server.rs new file mode 100644 index 000000000..6d058ff0f --- /dev/null +++ b/codex-rs/exec-server/src/posix/escalate_server.rs @@ -0,0 +1,274 @@ +use std::collections::HashMap; +use std::os::fd::AsRawFd; +use std::path::Path; +use std::path::PathBuf; +use std::process::Stdio; +use std::time::Duration; + +use anyhow::Context as _; +use path_absolutize::Absolutize as _; + +use codex_core::exec::SandboxType; +use codex_core::exec::process_exec_tool_call; +use codex_core::get_platform_sandbox; +use codex_core::protocol::SandboxPolicy; +use tokio::process::Command; + +use crate::posix::escalate_protocol::BASH_EXEC_WRAPPER_ENV_VAR; +use crate::posix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR; +use crate::posix::escalate_protocol::EscalateAction; +use crate::posix::escalate_protocol::EscalateRequest; +use crate::posix::escalate_protocol::EscalateResponse; +use crate::posix::escalate_protocol::SuperExecMessage; +use crate::posix::escalate_protocol::SuperExecResult; +use crate::posix::socket::AsyncDatagramSocket; +use crate::posix::socket::AsyncSocket; + +/// This is the policy which decides how to handle an exec() call. +/// +/// `file` is the absolute, canonical path to the executable to run, i.e. the first arg to exec. +/// `argv` is the argv, including the program name (`argv[0]`). +/// `workdir` is the absolute, canonical path to the working directory in which to execute the +/// command. +pub(crate) type ExecPolicy = fn(file: &Path, argv: &[String], workdir: &Path) -> EscalateAction; + +pub(crate) struct EscalateServer { + bash_path: PathBuf, + policy: ExecPolicy, +} + +impl EscalateServer { + pub fn new(bash_path: PathBuf, policy: ExecPolicy) -> Self { + Self { bash_path, policy } + } + + pub async fn exec( + &self, + command: String, + env: HashMap, + workdir: PathBuf, + timeout_ms: Option, + ) -> anyhow::Result { + let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?; + let client_socket = escalate_client.into_inner(); + client_socket.set_cloexec(false)?; + + let escalate_task = tokio::spawn(escalate_task(escalate_server, self.policy)); + let mut env = env.clone(); + env.insert( + ESCALATE_SOCKET_ENV_VAR.to_string(), + client_socket.as_raw_fd().to_string(), + ); + env.insert( + BASH_EXEC_WRAPPER_ENV_VAR.to_string(), + format!("{} escalate", std::env::current_exe()?.to_string_lossy()), + ); + let result = process_exec_tool_call( + codex_core::exec::ExecParams { + command: vec![ + self.bash_path.to_string_lossy().to_string(), + "-c".to_string(), + command, + ], + cwd: PathBuf::from(&workdir), + timeout_ms, + env, + with_escalated_permissions: None, + justification: None, + arg0: None, + }, + get_platform_sandbox().unwrap_or(SandboxType::None), + // TODO: use the sandbox policy and cwd from the calling client + &SandboxPolicy::ReadOnly, + &PathBuf::from("/__NONEXISTENT__"), // This is ignored for ReadOnly + &None, + None, + ) + .await?; + escalate_task.abort(); + let result = ExecResult { + exit_code: result.exit_code, + output: result.aggregated_output.text, + duration: result.duration, + timed_out: result.timed_out, + }; + Ok(result) + } +} + +async fn escalate_task(socket: AsyncDatagramSocket, policy: ExecPolicy) -> anyhow::Result<()> { + loop { + let (_, mut fds) = socket.receive_with_fds().await?; + if fds.len() != 1 { + tracing::error!("expected 1 fd in datagram handshake, got {}", fds.len()); + continue; + } + let stream_socket = AsyncSocket::from_fd(fds.remove(0))?; + tokio::spawn(async move { + if let Err(err) = handle_escalate_session_with_policy(stream_socket, policy).await { + tracing::error!("escalate session failed: {err:?}"); + } + }); + } +} + +#[derive(Debug)] +pub(crate) struct ExecResult { + pub(crate) exit_code: i32, + pub(crate) output: String, + pub(crate) duration: Duration, + pub(crate) timed_out: bool, +} + +async fn handle_escalate_session_with_policy( + socket: AsyncSocket, + policy: ExecPolicy, +) -> anyhow::Result<()> { + let EscalateRequest { + file, + argv, + workdir, + env, + } = socket.receive::().await?; + let file = PathBuf::from(&file).absolutize()?.into_owned(); + let workdir = PathBuf::from(&workdir).absolutize()?.into_owned(); + let action = policy(file.as_path(), &argv, &workdir); + tracing::debug!("decided {action:?} for {file:?} {argv:?} {workdir:?}"); + match action { + EscalateAction::Run => { + socket + .send(EscalateResponse { + action: EscalateAction::Run, + }) + .await?; + } + EscalateAction::Escalate => { + socket + .send(EscalateResponse { + action: EscalateAction::Escalate, + }) + .await?; + let (msg, fds) = socket + .receive_with_fds::() + .await + .context("failed to receive SuperExecMessage")?; + if fds.len() != msg.fds.len() { + return Err(anyhow::anyhow!( + "mismatched number of fds in SuperExecMessage: {} in the message, {} from the control message", + msg.fds.len(), + fds.len() + )); + } + + if msg + .fds + .iter() + .any(|src_fd| fds.iter().any(|dst_fd| dst_fd.as_raw_fd() == *src_fd)) + { + return Err(anyhow::anyhow!( + "overlapping fds not yet supported in SuperExecMessage" + )); + } + + let mut command = Command::new(file); + command + .args(&argv[1..]) + .arg0(argv[0].clone()) + .envs(&env) + .current_dir(&workdir) + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + unsafe { + command.pre_exec(move || { + for (dst_fd, src_fd) in msg.fds.iter().zip(&fds) { + libc::dup2(src_fd.as_raw_fd(), *dst_fd); + } + Ok(()) + }); + } + let mut child = command.spawn()?; + let exit_status = child.wait().await?; + socket + .send(SuperExecResult { + exit_code: exit_status.code().unwrap_or(127), + }) + .await?; + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use std::collections::HashMap; + use std::path::PathBuf; + + #[tokio::test] + async fn handle_escalate_session_respects_run_in_sandbox_decision() -> anyhow::Result<()> { + let (server, client) = AsyncSocket::pair()?; + let server_task = tokio::spawn(handle_escalate_session_with_policy( + server, + |_file, _argv, _workdir| EscalateAction::Run, + )); + + client + .send(EscalateRequest { + file: PathBuf::from("/bin/echo"), + argv: vec!["echo".to_string()], + workdir: PathBuf::from("/tmp"), + env: HashMap::new(), + }) + .await?; + + let response = client.receive::().await?; + assert_eq!( + EscalateResponse { + action: EscalateAction::Run, + }, + response + ); + server_task.await? + } + + #[tokio::test] + async fn handle_escalate_session_executes_escalated_command() -> anyhow::Result<()> { + let (server, client) = AsyncSocket::pair()?; + let server_task = tokio::spawn(handle_escalate_session_with_policy( + server, + |_file, _argv, _workdir| EscalateAction::Escalate, + )); + + client + .send(EscalateRequest { + file: PathBuf::from("/bin/sh"), + argv: vec![ + "sh".to_string(), + "-c".to_string(), + r#"if [ "$KEY" = VALUE ]; then exit 42; else exit 1; fi"#.to_string(), + ], + workdir: std::env::current_dir()?, + env: HashMap::from([("KEY".to_string(), "VALUE".to_string())]), + }) + .await?; + + let response = client.receive::().await?; + assert_eq!( + EscalateResponse { + action: EscalateAction::Escalate, + }, + response + ); + + client + .send_with_fds(SuperExecMessage { fds: Vec::new() }, &[]) + .await?; + + let result = client.receive::().await?; + assert_eq!(42, result.exit_code); + + server_task.await? + } +} diff --git a/codex-rs/exec-server/src/posix/mcp.rs b/codex-rs/exec-server/src/posix/mcp.rs new file mode 100644 index 000000000..f4e2b19d4 --- /dev/null +++ b/codex-rs/exec-server/src/posix/mcp.rs @@ -0,0 +1,154 @@ +use std::path::PathBuf; +use std::time::Duration; + +use anyhow::Context as _; +use anyhow::Result; +use rmcp::ErrorData as McpError; +use rmcp::RoleServer; +use rmcp::ServerHandler; +use rmcp::ServiceExt; +use rmcp::handler::server::router::tool::ToolRouter; +use rmcp::handler::server::wrapper::Parameters; +use rmcp::model::*; +use rmcp::schemars; +use rmcp::service::RequestContext; +use rmcp::service::RunningService; +use rmcp::tool; +use rmcp::tool_handler; +use rmcp::tool_router; +use rmcp::transport::stdio; + +use crate::posix::escalate_server; +use crate::posix::escalate_server::EscalateServer; +use crate::posix::escalate_server::ExecPolicy; + +/// Path to our patched bash. +const CODEX_BASH_PATH_ENV_VAR: &str = "CODEX_BASH_PATH"; + +pub(crate) fn get_bash_path() -> Result { + std::env::var(CODEX_BASH_PATH_ENV_VAR) + .map(PathBuf::from) + .context(format!("{CODEX_BASH_PATH_ENV_VAR} must be set")) +} + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +pub struct ExecParams { + /// The bash string to execute. + pub command: String, + /// The working directory to execute the command in. Must be an absolute path. + pub workdir: String, + /// The timeout for the command in milliseconds. + pub timeout_ms: Option, +} + +#[derive(Debug, serde::Serialize, schemars::JsonSchema)] +pub struct ExecResult { + pub exit_code: i32, + pub output: String, + pub duration: Duration, + pub timed_out: bool, +} + +impl From for ExecResult { + fn from(result: escalate_server::ExecResult) -> Self { + Self { + exit_code: result.exit_code, + output: result.output, + duration: result.duration, + timed_out: result.timed_out, + } + } +} + +#[derive(Clone)] +pub struct ExecTool { + tool_router: ToolRouter, + bash_path: PathBuf, + policy: ExecPolicy, +} + +#[tool_router] +impl ExecTool { + pub fn new(bash_path: PathBuf, policy: ExecPolicy) -> Self { + Self { + tool_router: Self::tool_router(), + bash_path, + policy, + } + } + + /// Runs a shell command and returns its output. You MUST provide the workdir as an absolute path. + #[tool] + async fn shell( + &self, + _context: RequestContext, + Parameters(params): Parameters, + ) -> Result { + let escalate_server = EscalateServer::new(self.bash_path.clone(), self.policy); + let result = escalate_server + .exec( + params.command, + // TODO: use ShellEnvironmentPolicy + std::env::vars().collect(), + PathBuf::from(¶ms.workdir), + params.timeout_ms, + ) + .await + .map_err(|e| McpError::internal_error(e.to_string(), None))?; + Ok(CallToolResult::success(vec![Content::json( + ExecResult::from(result), + )?])) + } + + #[allow(dead_code)] + async fn prompt( + &self, + command: String, + workdir: String, + context: RequestContext, + ) -> Result { + context + .peer + .create_elicitation(CreateElicitationRequestParam { + message: format!("Allow Codex to run `{command:?}` in `{workdir:?}`?"), + #[allow(clippy::expect_used)] + requested_schema: ElicitationSchema::builder() + .property("dummy", PrimitiveSchema::String(StringSchema::new())) + .build() + .expect("failed to build elicitation schema"), + }) + .await + .map_err(|e| McpError::internal_error(e.to_string(), None)) + } +} + +#[tool_handler] +impl ServerHandler for ExecTool { + fn get_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::V_2025_06_18, + capabilities: ServerCapabilities::builder().enable_tools().build(), + server_info: Implementation::from_build_env(), + instructions: Some( + "This server provides a tool to execute shell commands and return their output." + .to_string(), + ), + } + } + + async fn initialize( + &self, + _request: InitializeRequestParam, + _context: RequestContext, + ) -> Result { + Ok(self.get_info()) + } +} + +pub(crate) async fn serve( + bash_path: PathBuf, + policy: ExecPolicy, +) -> Result, rmcp::service::ServerInitializeError> { + let tool = ExecTool::new(bash_path, policy); + tool.serve(stdio()).await +} diff --git a/codex-rs/exec-server/src/posix/socket.rs b/codex-rs/exec-server/src/posix/socket.rs new file mode 100644 index 000000000..92c93dcc7 --- /dev/null +++ b/codex-rs/exec-server/src/posix/socket.rs @@ -0,0 +1,486 @@ +use libc::c_uint; +use serde::Deserialize; +use serde::Serialize; +use socket2::Domain; +use socket2::MaybeUninitSlice; +use socket2::MsgHdr; +use socket2::MsgHdrMut; +use socket2::Socket; +use socket2::Type; +use std::io::IoSlice; +use std::mem::MaybeUninit; +use std::os::fd::AsRawFd; +use std::os::fd::FromRawFd; +use std::os::fd::OwnedFd; +use std::os::fd::RawFd; +use tokio::io::Interest; +use tokio::io::unix::AsyncFd; + +const MAX_FDS_PER_MESSAGE: usize = 16; +const LENGTH_PREFIX_SIZE: usize = size_of::(); +const MAX_DATAGRAM_SIZE: usize = 8192; + +/// Converts a slice of MaybeUninit to a slice of T. +/// +/// The caller guarantees that every element of `buf` is initialized. +fn assume_init(buf: &[MaybeUninit]) -> &[T] { + unsafe { std::slice::from_raw_parts(buf.as_ptr().cast(), buf.len()) } +} + +fn assume_init_slice(buf: &[MaybeUninit; N]) -> &[T; N] { + unsafe { &*(buf as *const [MaybeUninit; N] as *const [T; N]) } +} + +fn assume_init_vec(mut buf: Vec>) -> Vec { + unsafe { + let ptr = buf.as_mut_ptr() as *mut T; + let len = buf.len(); + let cap = buf.capacity(); + std::mem::forget(buf); + Vec::from_raw_parts(ptr, len, cap) + } +} + +fn control_space_for_fds(count: usize) -> usize { + unsafe { libc::CMSG_SPACE((count * size_of::()) as _) as usize } +} + +/// Extracts the FDs from a SCM_RIGHTS control message. +fn extract_fds(control: &[u8]) -> Vec { + let mut fds = Vec::new(); + let mut hdr: libc::msghdr = unsafe { std::mem::zeroed() }; + hdr.msg_control = control.as_ptr() as *mut libc::c_void; + hdr.msg_controllen = control.len() as _; + let hdr = hdr; // drop mut + + let mut cmsg = unsafe { libc::CMSG_FIRSTHDR(&hdr) as *const libc::cmsghdr }; + while !cmsg.is_null() { + let level = unsafe { (*cmsg).cmsg_level }; + let ty = unsafe { (*cmsg).cmsg_type }; + if level == libc::SOL_SOCKET && ty == libc::SCM_RIGHTS { + let data_ptr = unsafe { libc::CMSG_DATA(cmsg).cast::() }; + let fd_count: usize = { + let cmsg_data_len = + unsafe { (*cmsg).cmsg_len as usize } - unsafe { libc::CMSG_LEN(0) as usize }; + cmsg_data_len / size_of::() + }; + for i in 0..fd_count { + let fd = unsafe { data_ptr.add(i).read() }; + fds.push(unsafe { OwnedFd::from_raw_fd(fd) }); + } + } + cmsg = unsafe { libc::CMSG_NXTHDR(&hdr, cmsg) }; + } + fds +} + +/// Read a frame from a SOCK_STREAM socket. +/// +/// A frame is a message length prefix followed by a payload. FDs may be included in the control +/// message when receiving the frame header. +async fn read_frame(async_socket: &AsyncFd) -> std::io::Result<(Vec, Vec)> { + let (message_len, fds) = read_frame_header(async_socket).await?; + let payload = read_frame_payload(async_socket, message_len).await?; + Ok((payload, fds)) +} + +/// Read the frame header (i.e. length) and any FDs from a SOCK_STREAM socket. +async fn read_frame_header( + async_socket: &AsyncFd, +) -> std::io::Result<(usize, Vec)> { + let mut header = [MaybeUninit::::uninit(); LENGTH_PREFIX_SIZE]; + let mut filled = 0; + let mut control = vec![MaybeUninit::::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)]; + let mut captured_control = false; + + while filled < LENGTH_PREFIX_SIZE { + let mut guard = async_socket.readable().await?; + // The first read should come with a control message containing any FDs. + let result = if !captured_control { + guard.try_io(|inner| { + let mut bufs = [MaybeUninitSlice::new(&mut header[filled..])]; + let (read, control_len) = { + let mut msg = MsgHdrMut::new() + .with_buffers(&mut bufs) + .with_control(&mut control); + let read = inner.get_ref().recvmsg(&mut msg, 0)?; + (read, msg.control_len()) + }; + control.truncate(control_len); + captured_control = true; + Ok(read) + }) + } else { + guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..])) + }; + let Ok(result) = result else { + // Would block, try again. + continue; + }; + + let read = result?; + if read == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "socket closed while receiving frame header", + )); + } + + filled += read; + assert!(filled <= LENGTH_PREFIX_SIZE); + if filled == LENGTH_PREFIX_SIZE { + let len_bytes = assume_init_slice(&header); + let payload_len = u32::from_le_bytes(*len_bytes) as usize; + let fds = extract_fds(assume_init(&control)); + return Ok((payload_len, fds)); + } + } + unreachable!("header loop always returns") +} + +/// Read `message_len` bytes from a SOCK_STREAM socket. +async fn read_frame_payload( + async_socket: &AsyncFd, + message_len: usize, +) -> std::io::Result> { + if message_len == 0 { + return Ok(Vec::new()); + } + let mut payload = vec![MaybeUninit::::uninit(); message_len]; + let mut filled = 0; + while filled < message_len { + let mut guard = async_socket.readable().await?; + let result = guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..])); + let Ok(result) = result else { + // Would block, try again. + continue; + }; + let read = result?; + if read == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "socket closed while receiving frame payload", + )); + } + filled += read; + assert!(filled <= message_len); + if filled == message_len { + return Ok(assume_init_vec(payload)); + } + } + unreachable!("loop exits only after returning payload") +} + +fn send_message_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> { + if fds.len() > MAX_FDS_PER_MESSAGE { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("too many fds: {}", fds.len()), + )); + } + let mut frame = Vec::with_capacity(LENGTH_PREFIX_SIZE + data.len()); + frame.extend_from_slice(&encode_length(data.len())?); + frame.extend_from_slice(data); + + let mut control = vec![0u8; control_space_for_fds(fds.len())]; + unsafe { + let cmsg = control.as_mut_ptr().cast::(); + (*cmsg).cmsg_len = libc::CMSG_LEN(size_of::() as c_uint * fds.len() as c_uint) as _; + (*cmsg).cmsg_level = libc::SOL_SOCKET; + (*cmsg).cmsg_type = libc::SCM_RIGHTS; + let data_ptr = libc::CMSG_DATA(cmsg).cast::(); + for (i, fd) in fds.iter().enumerate() { + data_ptr.add(i).write(fd.as_raw_fd()); + } + } + + let payload = [IoSlice::new(&frame)]; + let msg = MsgHdr::new().with_buffers(&payload).with_control(&control); + let mut sent = socket.sendmsg(&msg, 0)?; + while sent < frame.len() { + let bytes = socket.send(&frame[sent..])?; + if bytes == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::WriteZero, + "socket closed while sending frame payload", + )); + } + sent += bytes; + } + Ok(()) +} + +fn encode_length(len: usize) -> std::io::Result<[u8; LENGTH_PREFIX_SIZE]> { + let len_u32 = u32::try_from(len).map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("message too large: {len}"), + ) + })?; + Ok(len_u32.to_le_bytes()) +} + +pub(crate) fn send_json_message( + socket: &Socket, + msg: T, + fds: &[OwnedFd], +) -> std::io::Result<()> { + let data = serde_json::to_vec(&msg)?; + send_message_bytes(socket, &data, fds) +} + +fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> { + if fds.len() > MAX_FDS_PER_MESSAGE { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("too many fds: {}", fds.len()), + )); + } + let mut control = vec![0u8; control_space_for_fds(fds.len())]; + if !fds.is_empty() { + unsafe { + let cmsg = control.as_mut_ptr().cast::(); + (*cmsg).cmsg_len = + libc::CMSG_LEN(size_of::() as c_uint * fds.len() as c_uint) as _; + (*cmsg).cmsg_level = libc::SOL_SOCKET; + (*cmsg).cmsg_type = libc::SCM_RIGHTS; + let data_ptr = libc::CMSG_DATA(cmsg).cast::(); + for (i, fd) in fds.iter().enumerate() { + data_ptr.add(i).write(fd.as_raw_fd()); + } + } + } + let payload = [IoSlice::new(data)]; + let msg = MsgHdr::new().with_buffers(&payload).with_control(&control); + let written = socket.sendmsg(&msg, 0)?; + if written != data.len() { + return Err(std::io::Error::new( + std::io::ErrorKind::WriteZero, + format!( + "short datagram write: wrote {written} bytes out of {}", + data.len() + ), + )); + } + Ok(()) +} + +fn receive_datagram_bytes(socket: &Socket) -> std::io::Result<(Vec, Vec)> { + let mut buffer = vec![MaybeUninit::::uninit(); MAX_DATAGRAM_SIZE]; + let mut control = vec![MaybeUninit::::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)]; + let (read, control_len) = { + let mut bufs = [MaybeUninitSlice::new(&mut buffer)]; + let mut msg = MsgHdrMut::new() + .with_buffers(&mut bufs) + .with_control(&mut control); + let read = socket.recvmsg(&mut msg, 0)?; + (read, msg.control_len()) + }; + let data = assume_init(&buffer[..read]).to_vec(); + let fds = extract_fds(assume_init(&control[..control_len])); + Ok((data, fds)) +} + +pub(crate) struct AsyncSocket { + inner: AsyncFd, +} + +impl AsyncSocket { + fn new(socket: Socket) -> std::io::Result { + socket.set_nonblocking(true)?; + let async_socket = AsyncFd::new(socket)?; + Ok(AsyncSocket { + inner: async_socket, + }) + } + + pub fn from_fd(fd: OwnedFd) -> std::io::Result { + AsyncSocket::new(Socket::from(fd)) + } + + pub fn pair() -> std::io::Result<(AsyncSocket, AsyncSocket)> { + let (server, client) = Socket::pair(Domain::UNIX, Type::STREAM, None)?; + Ok((AsyncSocket::new(server)?, AsyncSocket::new(client)?)) + } + + pub async fn send_with_fds( + &self, + msg: T, + fds: &[OwnedFd], + ) -> std::io::Result<()> { + self.inner + .async_io(Interest::WRITABLE, |socket| { + send_json_message(socket, &msg, fds) + }) + .await + } + + pub async fn receive_with_fds Deserialize<'de>>( + &self, + ) -> std::io::Result<(T, Vec)> { + let (payload, fds) = read_frame(&self.inner).await?; + let message: T = serde_json::from_slice(&payload)?; + Ok((message, fds)) + } + + pub async fn send(&self, msg: T) -> std::io::Result<()> + where + T: Serialize, + { + self.send_with_fds(&msg, &[]).await + } + + pub async fn receive Deserialize<'de>>(&self) -> std::io::Result { + let (msg, fds) = self.receive_with_fds().await?; + if !fds.is_empty() { + tracing::warn!("unexpected fds in receive: {}", fds.len()); + } + Ok(msg) + } + + pub fn into_inner(self) -> Socket { + self.inner.into_inner() + } +} + +pub(crate) struct AsyncDatagramSocket { + inner: AsyncFd, +} + +impl AsyncDatagramSocket { + fn new(socket: Socket) -> std::io::Result { + socket.set_nonblocking(true)?; + Ok(Self { + inner: AsyncFd::new(socket)?, + }) + } + + pub unsafe fn from_raw_fd(fd: RawFd) -> std::io::Result { + Self::new(unsafe { Socket::from_raw_fd(fd) }) + } + + pub fn pair() -> std::io::Result<(Self, Self)> { + let (server, client) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?; + Ok((Self::new(server)?, Self::new(client)?)) + } + + pub async fn send_with_fds(&self, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> { + self.inner + .async_io(Interest::WRITABLE, |socket| { + send_datagram_bytes(socket, data, fds) + }) + .await + } + + pub async fn receive_with_fds(&self) -> std::io::Result<(Vec, Vec)> { + self.inner + .async_io(Interest::READABLE, receive_datagram_bytes) + .await + } + + pub fn into_inner(self) -> Socket { + self.inner.into_inner() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use serde::Deserialize; + use serde::Serialize; + use std::os::fd::AsFd; + use std::os::fd::AsRawFd; + use tempfile::NamedTempFile; + + #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] + struct TestPayload { + id: i32, + label: String, + } + + fn fd_list(count: usize) -> std::io::Result> { + let file = NamedTempFile::new()?; + let mut fds = Vec::new(); + for _ in 0..count { + fds.push(file.as_fd().try_clone_to_owned()?); + } + Ok(fds) + } + + #[tokio::test] + async fn async_socket_round_trips_payload_and_fds() -> std::io::Result<()> { + let (server, client) = AsyncSocket::pair()?; + let payload = TestPayload { + id: 7, + label: "round-trip".to_string(), + }; + let send_fds = fd_list(1)?; + + let receive_task = + tokio::spawn(async move { server.receive_with_fds::().await }); + client.send_with_fds(payload.clone(), &send_fds).await?; + drop(send_fds); + + let (received_payload, received_fds) = receive_task.await.unwrap()?; + assert_eq!(payload, received_payload); + assert_eq!(1, received_fds.len()); + let fd_status = unsafe { libc::fcntl(received_fds[0].as_raw_fd(), libc::F_GETFD) }; + assert!( + fd_status >= 0, + "expected received file descriptor to be valid, but got {fd_status}", + ); + Ok(()) + } + + #[tokio::test] + async fn async_datagram_sockets_round_trip_messages() -> std::io::Result<()> { + let (server, client) = AsyncDatagramSocket::pair()?; + let data = b"datagram payload".to_vec(); + let send_fds = fd_list(1)?; + let receive_task = tokio::spawn(async move { server.receive_with_fds().await }); + + client.send_with_fds(&data, &send_fds).await?; + drop(send_fds); + + let (received_bytes, received_fds) = receive_task.await.unwrap()?; + assert_eq!(data, received_bytes); + assert_eq!(1, received_fds.len()); + Ok(()) + } + + #[test] + fn send_message_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> { + let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?; + let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?; + let err = send_message_bytes(&socket, b"hello", &fds).unwrap_err(); + assert_eq!(std::io::ErrorKind::InvalidInput, err.kind()); + Ok(()) + } + + #[test] + fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> { + let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?; + let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?; + let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err(); + assert_eq!(std::io::ErrorKind::InvalidInput, err.kind()); + Ok(()) + } + + #[test] + fn encode_length_errors_for_oversized_messages() { + let err = encode_length(usize::MAX).unwrap_err(); + assert_eq!(std::io::ErrorKind::InvalidInput, err.kind()); + } + + #[tokio::test] + async fn receive_fails_when_peer_closes_before_header() { + let (server, client) = AsyncSocket::pair().expect("failed to create socket pair"); + drop(client); + let err = server + .receive::() + .await + .expect_err("expected read failure"); + assert_eq!(std::io::ErrorKind::UnexpectedEof, err.kind()); + } +}