exec-server (#6630)

This commit is contained in:
Jeremy Rose 2025-11-18 16:20:19 -08:00 committed by GitHub
parent 9275e93364
commit c1391b9f94
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 1309 additions and 0 deletions

20
codex-rs/Cargo.lock generated
View file

@ -1182,6 +1182,26 @@ dependencies = [
"wiremock",
]
[[package]]
name = "codex-exec-server"
version = "0.0.0"
dependencies = [
"anyhow",
"clap",
"codex-core",
"libc",
"path-absolutize",
"pretty_assertions",
"rmcp",
"serde",
"serde_json",
"socket2 0.6.0",
"tempfile",
"tokio",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "codex-execpolicy"
version = "0.0.0"

View file

@ -16,6 +16,7 @@ members = [
"common",
"core",
"exec",
"exec-server",
"execpolicy",
"execpolicy2",
"keyring-store",
@ -176,6 +177,7 @@ sha1 = "0.10.6"
sha2 = "0.10"
shlex = "1.3.0"
similar = "2.7.0"
socket2 = "0.6.0"
starlark = "0.13.0"
strum = "0.27.2"
strum_macros = "0.27.2"

View file

@ -0,0 +1,47 @@
[package]
edition = "2024"
name = "codex-exec-server"
version = { workspace = true }
[[bin]]
name = "codex-exec-server"
path = "src/main.rs"
[lints]
workspace = true
[dependencies]
anyhow = { workspace = true }
clap = { workspace = true, features = ["derive"] }
codex-core = { workspace = true }
libc = { workspace = true }
path-absolutize = { workspace = true }
rmcp = { workspace = true, default-features = false, features = [
"auth",
"elicitation",
"base64",
"client",
"macros",
"schemars",
"server",
"transport-child-process",
"transport-streamable-http-client-reqwest",
"transport-streamable-http-server",
"transport-io",
] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
socket2 = { workspace = true }
tokio = { workspace = true, features = [
"io-std",
"macros",
"process",
"rt-multi-thread",
"signal",
] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] }
[dev-dependencies]
pretty_assertions = { workspace = true }
tempfile = { workspace = true }

View file

@ -0,0 +1,11 @@
#[cfg(target_os = "windows")]
fn main() {
eprintln!("codex-exec-server is not implemented on Windows targets");
std::process::exit(1);
}
#[cfg(not(target_os = "windows"))]
mod posix;
#[cfg(not(target_os = "windows"))]
pub use posix::main;

View file

@ -0,0 +1,164 @@
//! This is an MCP that implements an alternative `shell` tool with fine-grained privilege
//! escalation based on a per-exec() policy.
//!
//! We spawn Bash process inside a sandbox. The Bash we spawn is patched to allow us to intercept
//! every exec() call it makes by invoking a wrapper program and passing in the arguments it would
//! have passed to exec(). The Bash process (and its descendants) inherit a communication socket
//! from us, and we give its fd number in the CODEX_ESCALATE_SOCKET environment variable.
//!
//! When we intercept an exec() call, we send a message over the socket back to the main
//! MCP process. The MCP process can then decide whether to allow the exec() call to proceed
//! or to escalate privileges and run the requested command with elevated permissions. In the
//! latter case, we send a message back to the child requesting that it forward its open FDs to us.
//! We then execute the requested command on its behalf, patching in the forwarded FDs.
//!
//!
//! ### The privilege escalation flow
//!
//! Child MCP Bash Escalate Helper
//! |
//! o----->o
//! | |
//! | o--(exec)-->o
//! | | |
//! |o<-(EscalateReq)--o
//! || | |
//! |o--(Escalate)---->o
//! || | |
//! |o<---------(fds)--o
//! || | |
//! o<-----o | |
//! | || | |
//! x----->o | |
//! || | |
//! |x--(exit code)--->o
//! | | |
//! | o<--(exit)--x
//! | |
//! o<-----x
//!
//! ### The non-escalation flow
//!
//! MCP Bash Escalate Helper Child
//! |
//! o----->o
//! | |
//! | o--(exec)-->o
//! | | |
//! |o<-(EscalateReq)--o
//! || | |
//! |o-(Run)---------->o
//! | | |
//! | | x--(exec)-->o
//! | | |
//! | o<--------------(exit)--x
//! | |
//! o<-----x
//!
use std::path::Path;
use clap::Parser;
use clap::Subcommand;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::{self};
use crate::posix::escalate_protocol::EscalateAction;
use crate::posix::escalate_server::EscalateServer;
mod escalate_client;
mod escalate_protocol;
mod escalate_server;
mod mcp;
mod socket;
fn dummy_exec_policy(file: &Path, argv: &[String], _workdir: &Path) -> EscalateAction {
// TODO: execpolicy
if file == Path::new("/opt/homebrew/bin/gh")
&& let [_, arg1, arg2, ..] = argv
&& arg1 == "issue"
&& arg2 == "list"
{
return EscalateAction::Escalate;
}
EscalateAction::Run
}
#[derive(Parser)]
#[command(version)]
pub struct Cli {
#[command(subcommand)]
subcommand: Option<Commands>,
}
#[derive(Subcommand)]
enum Commands {
Escalate(EscalateArgs),
ShellExec(ShellExecArgs),
}
/// Invoked from within the sandbox to (potentially) escalate permissions.
#[derive(Parser, Debug)]
struct EscalateArgs {
file: String,
#[arg(trailing_var_arg = true)]
argv: Vec<String>,
}
impl EscalateArgs {
/// This is the escalate client. It talks to the escalate server to determine whether to exec()
/// the command directly or to proxy to the escalate server.
async fn run(self) -> anyhow::Result<i32> {
let EscalateArgs { file, argv } = self;
escalate_client::run(file, argv).await
}
}
/// Debugging command to emulate an MCP "shell" tool call.
#[derive(Parser, Debug)]
struct ShellExecArgs {
command: String,
}
#[tokio::main]
pub async fn main() -> anyhow::Result<()> {
let cli = Cli::parse();
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.with_writer(std::io::stderr)
.with_ansi(false)
.init();
match cli.subcommand {
Some(Commands::Escalate(args)) => {
std::process::exit(args.run().await?);
}
Some(Commands::ShellExec(args)) => {
let bash_path = mcp::get_bash_path()?;
let escalate_server = EscalateServer::new(bash_path, dummy_exec_policy);
let result = escalate_server
.exec(
args.command.clone(),
std::env::vars().collect(),
std::env::current_dir()?,
None,
)
.await?;
println!("{result:?}");
std::process::exit(result.exit_code);
}
None => {
let bash_path = mcp::get_bash_path()?;
tracing::info!("Starting MCP server");
let service = mcp::serve(bash_path, dummy_exec_policy)
.await
.inspect_err(|e| {
tracing::error!("serving error: {:?}", e);
})?;
service.waiting().await?;
Ok(())
}
}
}

View file

@ -0,0 +1,102 @@
use std::io;
use std::os::fd::AsRawFd;
use std::os::fd::FromRawFd as _;
use std::os::fd::OwnedFd;
use anyhow::Context as _;
use crate::posix::escalate_protocol::BASH_EXEC_WRAPPER_ENV_VAR;
use crate::posix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
use crate::posix::escalate_protocol::EscalateAction;
use crate::posix::escalate_protocol::EscalateRequest;
use crate::posix::escalate_protocol::EscalateResponse;
use crate::posix::escalate_protocol::SuperExecMessage;
use crate::posix::escalate_protocol::SuperExecResult;
use crate::posix::socket::AsyncDatagramSocket;
use crate::posix::socket::AsyncSocket;
fn get_escalate_client() -> anyhow::Result<AsyncDatagramSocket> {
// TODO: we should defensively require only calling this once, since AsyncSocket will take ownership of the fd.
let client_fd = std::env::var(ESCALATE_SOCKET_ENV_VAR)?.parse::<i32>()?;
if client_fd < 0 {
return Err(anyhow::anyhow!(
"{ESCALATE_SOCKET_ENV_VAR} is not a valid file descriptor: {client_fd}"
));
}
Ok(unsafe { AsyncDatagramSocket::from_raw_fd(client_fd) }?)
}
pub(crate) async fn run(file: String, argv: Vec<String>) -> anyhow::Result<i32> {
let handshake_client = get_escalate_client()?;
let (server, client) = AsyncSocket::pair()?;
const HANDSHAKE_MESSAGE: [u8; 1] = [0];
handshake_client
.send_with_fds(&HANDSHAKE_MESSAGE, &[server.into_inner().into()])
.await
.context("failed to send handshake datagram")?;
let env = std::env::vars()
.filter(|(k, _)| {
!matches!(
k.as_str(),
ESCALATE_SOCKET_ENV_VAR | BASH_EXEC_WRAPPER_ENV_VAR
)
})
.collect();
client
.send(EscalateRequest {
file: file.clone().into(),
argv: argv.clone(),
workdir: std::env::current_dir()?,
env,
})
.await
.context("failed to send EscalateRequest")?;
let message = client.receive::<EscalateResponse>().await?;
match message.action {
EscalateAction::Escalate => {
// TODO: maybe we should send ALL open FDs (except the escalate client)?
let fds_to_send = [
unsafe { OwnedFd::from_raw_fd(io::stdin().as_raw_fd()) },
unsafe { OwnedFd::from_raw_fd(io::stdout().as_raw_fd()) },
unsafe { OwnedFd::from_raw_fd(io::stderr().as_raw_fd()) },
];
// TODO: also forward signals over the super-exec socket
client
.send_with_fds(
SuperExecMessage {
fds: fds_to_send.iter().map(AsRawFd::as_raw_fd).collect(),
},
&fds_to_send,
)
.await
.context("failed to send SuperExecMessage")?;
let SuperExecResult { exit_code } = client.receive::<SuperExecResult>().await?;
Ok(exit_code)
}
EscalateAction::Run => {
// We avoid std::process::Command here because we want to be as transparent as
// possible. std::os::unix::process::CommandExt has .exec() but it does some funky
// stuff with signal masks and dup2() on its standard FDs, which we don't want.
use std::ffi::CString;
let file = CString::new(file).context("NUL in file")?;
let argv_cstrs: Vec<CString> = argv
.iter()
.map(|s| CString::new(s.as_str()).context("NUL in argv"))
.collect::<Result<Vec<_>, _>>()?;
let mut argv: Vec<*const libc::c_char> =
argv_cstrs.iter().map(|s| s.as_ptr()).collect();
argv.push(std::ptr::null());
let err = unsafe {
libc::execv(file.as_ptr(), argv.as_ptr());
std::io::Error::last_os_error()
};
Err(err.into())
}
}
}

View file

@ -0,0 +1,49 @@
use std::collections::HashMap;
use std::os::fd::RawFd;
use std::path::PathBuf;
use serde::Deserialize;
use serde::Serialize;
/// 'exec-server escalate' reads this to find the inherited FD for the escalate socket.
pub(super) const ESCALATE_SOCKET_ENV_VAR: &str = "CODEX_ESCALATE_SOCKET";
/// The patched bash uses this to wrap exec() calls.
pub(super) const BASH_EXEC_WRAPPER_ENV_VAR: &str = "BASH_EXEC_WRAPPER";
/// The client sends this to the server to request an exec() call.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
pub(super) struct EscalateRequest {
/// The absolute path to the executable to run, i.e. the first arg to exec.
pub(super) file: PathBuf,
/// The argv, including the program name (argv[0]).
pub(super) argv: Vec<String>,
pub(super) workdir: PathBuf,
pub(super) env: HashMap<String, String>,
}
/// The server sends this to the client to respond to an exec() request.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
pub(super) struct EscalateResponse {
pub(super) action: EscalateAction,
}
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
pub(super) enum EscalateAction {
/// The command should be run directly by the client.
Run,
/// The command should be escalated to the server for execution.
Escalate,
}
/// The client sends this to the server to forward its open FDs.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub(super) struct SuperExecMessage {
pub(super) fds: Vec<RawFd>,
}
/// The server responds when the exec()'d command has exited.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub(super) struct SuperExecResult {
pub(super) exit_code: i32,
}

View file

@ -0,0 +1,274 @@
use std::collections::HashMap;
use std::os::fd::AsRawFd;
use std::path::Path;
use std::path::PathBuf;
use std::process::Stdio;
use std::time::Duration;
use anyhow::Context as _;
use path_absolutize::Absolutize as _;
use codex_core::exec::SandboxType;
use codex_core::exec::process_exec_tool_call;
use codex_core::get_platform_sandbox;
use codex_core::protocol::SandboxPolicy;
use tokio::process::Command;
use crate::posix::escalate_protocol::BASH_EXEC_WRAPPER_ENV_VAR;
use crate::posix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
use crate::posix::escalate_protocol::EscalateAction;
use crate::posix::escalate_protocol::EscalateRequest;
use crate::posix::escalate_protocol::EscalateResponse;
use crate::posix::escalate_protocol::SuperExecMessage;
use crate::posix::escalate_protocol::SuperExecResult;
use crate::posix::socket::AsyncDatagramSocket;
use crate::posix::socket::AsyncSocket;
/// This is the policy which decides how to handle an exec() call.
///
/// `file` is the absolute, canonical path to the executable to run, i.e. the first arg to exec.
/// `argv` is the argv, including the program name (`argv[0]`).
/// `workdir` is the absolute, canonical path to the working directory in which to execute the
/// command.
pub(crate) type ExecPolicy = fn(file: &Path, argv: &[String], workdir: &Path) -> EscalateAction;
pub(crate) struct EscalateServer {
bash_path: PathBuf,
policy: ExecPolicy,
}
impl EscalateServer {
pub fn new(bash_path: PathBuf, policy: ExecPolicy) -> Self {
Self { bash_path, policy }
}
pub async fn exec(
&self,
command: String,
env: HashMap<String, String>,
workdir: PathBuf,
timeout_ms: Option<u64>,
) -> anyhow::Result<ExecResult> {
let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?;
let client_socket = escalate_client.into_inner();
client_socket.set_cloexec(false)?;
let escalate_task = tokio::spawn(escalate_task(escalate_server, self.policy));
let mut env = env.clone();
env.insert(
ESCALATE_SOCKET_ENV_VAR.to_string(),
client_socket.as_raw_fd().to_string(),
);
env.insert(
BASH_EXEC_WRAPPER_ENV_VAR.to_string(),
format!("{} escalate", std::env::current_exe()?.to_string_lossy()),
);
let result = process_exec_tool_call(
codex_core::exec::ExecParams {
command: vec![
self.bash_path.to_string_lossy().to_string(),
"-c".to_string(),
command,
],
cwd: PathBuf::from(&workdir),
timeout_ms,
env,
with_escalated_permissions: None,
justification: None,
arg0: None,
},
get_platform_sandbox().unwrap_or(SandboxType::None),
// TODO: use the sandbox policy and cwd from the calling client
&SandboxPolicy::ReadOnly,
&PathBuf::from("/__NONEXISTENT__"), // This is ignored for ReadOnly
&None,
None,
)
.await?;
escalate_task.abort();
let result = ExecResult {
exit_code: result.exit_code,
output: result.aggregated_output.text,
duration: result.duration,
timed_out: result.timed_out,
};
Ok(result)
}
}
async fn escalate_task(socket: AsyncDatagramSocket, policy: ExecPolicy) -> anyhow::Result<()> {
loop {
let (_, mut fds) = socket.receive_with_fds().await?;
if fds.len() != 1 {
tracing::error!("expected 1 fd in datagram handshake, got {}", fds.len());
continue;
}
let stream_socket = AsyncSocket::from_fd(fds.remove(0))?;
tokio::spawn(async move {
if let Err(err) = handle_escalate_session_with_policy(stream_socket, policy).await {
tracing::error!("escalate session failed: {err:?}");
}
});
}
}
#[derive(Debug)]
pub(crate) struct ExecResult {
pub(crate) exit_code: i32,
pub(crate) output: String,
pub(crate) duration: Duration,
pub(crate) timed_out: bool,
}
async fn handle_escalate_session_with_policy(
socket: AsyncSocket,
policy: ExecPolicy,
) -> anyhow::Result<()> {
let EscalateRequest {
file,
argv,
workdir,
env,
} = socket.receive::<EscalateRequest>().await?;
let file = PathBuf::from(&file).absolutize()?.into_owned();
let workdir = PathBuf::from(&workdir).absolutize()?.into_owned();
let action = policy(file.as_path(), &argv, &workdir);
tracing::debug!("decided {action:?} for {file:?} {argv:?} {workdir:?}");
match action {
EscalateAction::Run => {
socket
.send(EscalateResponse {
action: EscalateAction::Run,
})
.await?;
}
EscalateAction::Escalate => {
socket
.send(EscalateResponse {
action: EscalateAction::Escalate,
})
.await?;
let (msg, fds) = socket
.receive_with_fds::<SuperExecMessage>()
.await
.context("failed to receive SuperExecMessage")?;
if fds.len() != msg.fds.len() {
return Err(anyhow::anyhow!(
"mismatched number of fds in SuperExecMessage: {} in the message, {} from the control message",
msg.fds.len(),
fds.len()
));
}
if msg
.fds
.iter()
.any(|src_fd| fds.iter().any(|dst_fd| dst_fd.as_raw_fd() == *src_fd))
{
return Err(anyhow::anyhow!(
"overlapping fds not yet supported in SuperExecMessage"
));
}
let mut command = Command::new(file);
command
.args(&argv[1..])
.arg0(argv[0].clone())
.envs(&env)
.current_dir(&workdir)
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null());
unsafe {
command.pre_exec(move || {
for (dst_fd, src_fd) in msg.fds.iter().zip(&fds) {
libc::dup2(src_fd.as_raw_fd(), *dst_fd);
}
Ok(())
});
}
let mut child = command.spawn()?;
let exit_status = child.wait().await?;
socket
.send(SuperExecResult {
exit_code: exit_status.code().unwrap_or(127),
})
.await?;
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use std::collections::HashMap;
use std::path::PathBuf;
#[tokio::test]
async fn handle_escalate_session_respects_run_in_sandbox_decision() -> anyhow::Result<()> {
let (server, client) = AsyncSocket::pair()?;
let server_task = tokio::spawn(handle_escalate_session_with_policy(
server,
|_file, _argv, _workdir| EscalateAction::Run,
));
client
.send(EscalateRequest {
file: PathBuf::from("/bin/echo"),
argv: vec!["echo".to_string()],
workdir: PathBuf::from("/tmp"),
env: HashMap::new(),
})
.await?;
let response = client.receive::<EscalateResponse>().await?;
assert_eq!(
EscalateResponse {
action: EscalateAction::Run,
},
response
);
server_task.await?
}
#[tokio::test]
async fn handle_escalate_session_executes_escalated_command() -> anyhow::Result<()> {
let (server, client) = AsyncSocket::pair()?;
let server_task = tokio::spawn(handle_escalate_session_with_policy(
server,
|_file, _argv, _workdir| EscalateAction::Escalate,
));
client
.send(EscalateRequest {
file: PathBuf::from("/bin/sh"),
argv: vec![
"sh".to_string(),
"-c".to_string(),
r#"if [ "$KEY" = VALUE ]; then exit 42; else exit 1; fi"#.to_string(),
],
workdir: std::env::current_dir()?,
env: HashMap::from([("KEY".to_string(), "VALUE".to_string())]),
})
.await?;
let response = client.receive::<EscalateResponse>().await?;
assert_eq!(
EscalateResponse {
action: EscalateAction::Escalate,
},
response
);
client
.send_with_fds(SuperExecMessage { fds: Vec::new() }, &[])
.await?;
let result = client.receive::<SuperExecResult>().await?;
assert_eq!(42, result.exit_code);
server_task.await?
}
}

View file

@ -0,0 +1,154 @@
use std::path::PathBuf;
use std::time::Duration;
use anyhow::Context as _;
use anyhow::Result;
use rmcp::ErrorData as McpError;
use rmcp::RoleServer;
use rmcp::ServerHandler;
use rmcp::ServiceExt;
use rmcp::handler::server::router::tool::ToolRouter;
use rmcp::handler::server::wrapper::Parameters;
use rmcp::model::*;
use rmcp::schemars;
use rmcp::service::RequestContext;
use rmcp::service::RunningService;
use rmcp::tool;
use rmcp::tool_handler;
use rmcp::tool_router;
use rmcp::transport::stdio;
use crate::posix::escalate_server;
use crate::posix::escalate_server::EscalateServer;
use crate::posix::escalate_server::ExecPolicy;
/// Path to our patched bash.
const CODEX_BASH_PATH_ENV_VAR: &str = "CODEX_BASH_PATH";
pub(crate) fn get_bash_path() -> Result<PathBuf> {
std::env::var(CODEX_BASH_PATH_ENV_VAR)
.map(PathBuf::from)
.context(format!("{CODEX_BASH_PATH_ENV_VAR} must be set"))
}
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct ExecParams {
/// The bash string to execute.
pub command: String,
/// The working directory to execute the command in. Must be an absolute path.
pub workdir: String,
/// The timeout for the command in milliseconds.
pub timeout_ms: Option<u64>,
}
#[derive(Debug, serde::Serialize, schemars::JsonSchema)]
pub struct ExecResult {
pub exit_code: i32,
pub output: String,
pub duration: Duration,
pub timed_out: bool,
}
impl From<escalate_server::ExecResult> for ExecResult {
fn from(result: escalate_server::ExecResult) -> Self {
Self {
exit_code: result.exit_code,
output: result.output,
duration: result.duration,
timed_out: result.timed_out,
}
}
}
#[derive(Clone)]
pub struct ExecTool {
tool_router: ToolRouter<ExecTool>,
bash_path: PathBuf,
policy: ExecPolicy,
}
#[tool_router]
impl ExecTool {
pub fn new(bash_path: PathBuf, policy: ExecPolicy) -> Self {
Self {
tool_router: Self::tool_router(),
bash_path,
policy,
}
}
/// Runs a shell command and returns its output. You MUST provide the workdir as an absolute path.
#[tool]
async fn shell(
&self,
_context: RequestContext<RoleServer>,
Parameters(params): Parameters<ExecParams>,
) -> Result<CallToolResult, McpError> {
let escalate_server = EscalateServer::new(self.bash_path.clone(), self.policy);
let result = escalate_server
.exec(
params.command,
// TODO: use ShellEnvironmentPolicy
std::env::vars().collect(),
PathBuf::from(&params.workdir),
params.timeout_ms,
)
.await
.map_err(|e| McpError::internal_error(e.to_string(), None))?;
Ok(CallToolResult::success(vec![Content::json(
ExecResult::from(result),
)?]))
}
#[allow(dead_code)]
async fn prompt(
&self,
command: String,
workdir: String,
context: RequestContext<RoleServer>,
) -> Result<CreateElicitationResult, McpError> {
context
.peer
.create_elicitation(CreateElicitationRequestParam {
message: format!("Allow Codex to run `{command:?}` in `{workdir:?}`?"),
#[allow(clippy::expect_used)]
requested_schema: ElicitationSchema::builder()
.property("dummy", PrimitiveSchema::String(StringSchema::new()))
.build()
.expect("failed to build elicitation schema"),
})
.await
.map_err(|e| McpError::internal_error(e.to_string(), None))
}
}
#[tool_handler]
impl ServerHandler for ExecTool {
fn get_info(&self) -> ServerInfo {
ServerInfo {
protocol_version: ProtocolVersion::V_2025_06_18,
capabilities: ServerCapabilities::builder().enable_tools().build(),
server_info: Implementation::from_build_env(),
instructions: Some(
"This server provides a tool to execute shell commands and return their output."
.to_string(),
),
}
}
async fn initialize(
&self,
_request: InitializeRequestParam,
_context: RequestContext<RoleServer>,
) -> Result<InitializeResult, McpError> {
Ok(self.get_info())
}
}
pub(crate) async fn serve(
bash_path: PathBuf,
policy: ExecPolicy,
) -> Result<RunningService<RoleServer, ExecTool>, rmcp::service::ServerInitializeError> {
let tool = ExecTool::new(bash_path, policy);
tool.serve(stdio()).await
}

View file

@ -0,0 +1,486 @@
use libc::c_uint;
use serde::Deserialize;
use serde::Serialize;
use socket2::Domain;
use socket2::MaybeUninitSlice;
use socket2::MsgHdr;
use socket2::MsgHdrMut;
use socket2::Socket;
use socket2::Type;
use std::io::IoSlice;
use std::mem::MaybeUninit;
use std::os::fd::AsRawFd;
use std::os::fd::FromRawFd;
use std::os::fd::OwnedFd;
use std::os::fd::RawFd;
use tokio::io::Interest;
use tokio::io::unix::AsyncFd;
const MAX_FDS_PER_MESSAGE: usize = 16;
const LENGTH_PREFIX_SIZE: usize = size_of::<u32>();
const MAX_DATAGRAM_SIZE: usize = 8192;
/// Converts a slice of MaybeUninit<T> to a slice of T.
///
/// The caller guarantees that every element of `buf` is initialized.
fn assume_init<T>(buf: &[MaybeUninit<T>]) -> &[T] {
unsafe { std::slice::from_raw_parts(buf.as_ptr().cast(), buf.len()) }
}
fn assume_init_slice<T, const N: usize>(buf: &[MaybeUninit<T>; N]) -> &[T; N] {
unsafe { &*(buf as *const [MaybeUninit<T>; N] as *const [T; N]) }
}
fn assume_init_vec<T>(mut buf: Vec<MaybeUninit<T>>) -> Vec<T> {
unsafe {
let ptr = buf.as_mut_ptr() as *mut T;
let len = buf.len();
let cap = buf.capacity();
std::mem::forget(buf);
Vec::from_raw_parts(ptr, len, cap)
}
}
fn control_space_for_fds(count: usize) -> usize {
unsafe { libc::CMSG_SPACE((count * size_of::<RawFd>()) as _) as usize }
}
/// Extracts the FDs from a SCM_RIGHTS control message.
fn extract_fds(control: &[u8]) -> Vec<OwnedFd> {
let mut fds = Vec::new();
let mut hdr: libc::msghdr = unsafe { std::mem::zeroed() };
hdr.msg_control = control.as_ptr() as *mut libc::c_void;
hdr.msg_controllen = control.len() as _;
let hdr = hdr; // drop mut
let mut cmsg = unsafe { libc::CMSG_FIRSTHDR(&hdr) as *const libc::cmsghdr };
while !cmsg.is_null() {
let level = unsafe { (*cmsg).cmsg_level };
let ty = unsafe { (*cmsg).cmsg_type };
if level == libc::SOL_SOCKET && ty == libc::SCM_RIGHTS {
let data_ptr = unsafe { libc::CMSG_DATA(cmsg).cast::<RawFd>() };
let fd_count: usize = {
let cmsg_data_len =
unsafe { (*cmsg).cmsg_len as usize } - unsafe { libc::CMSG_LEN(0) as usize };
cmsg_data_len / size_of::<RawFd>()
};
for i in 0..fd_count {
let fd = unsafe { data_ptr.add(i).read() };
fds.push(unsafe { OwnedFd::from_raw_fd(fd) });
}
}
cmsg = unsafe { libc::CMSG_NXTHDR(&hdr, cmsg) };
}
fds
}
/// Read a frame from a SOCK_STREAM socket.
///
/// A frame is a message length prefix followed by a payload. FDs may be included in the control
/// message when receiving the frame header.
async fn read_frame(async_socket: &AsyncFd<Socket>) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
let (message_len, fds) = read_frame_header(async_socket).await?;
let payload = read_frame_payload(async_socket, message_len).await?;
Ok((payload, fds))
}
/// Read the frame header (i.e. length) and any FDs from a SOCK_STREAM socket.
async fn read_frame_header(
async_socket: &AsyncFd<Socket>,
) -> std::io::Result<(usize, Vec<OwnedFd>)> {
let mut header = [MaybeUninit::<u8>::uninit(); LENGTH_PREFIX_SIZE];
let mut filled = 0;
let mut control = vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
let mut captured_control = false;
while filled < LENGTH_PREFIX_SIZE {
let mut guard = async_socket.readable().await?;
// The first read should come with a control message containing any FDs.
let result = if !captured_control {
guard.try_io(|inner| {
let mut bufs = [MaybeUninitSlice::new(&mut header[filled..])];
let (read, control_len) = {
let mut msg = MsgHdrMut::new()
.with_buffers(&mut bufs)
.with_control(&mut control);
let read = inner.get_ref().recvmsg(&mut msg, 0)?;
(read, msg.control_len())
};
control.truncate(control_len);
captured_control = true;
Ok(read)
})
} else {
guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..]))
};
let Ok(result) = result else {
// Would block, try again.
continue;
};
let read = result?;
if read == 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"socket closed while receiving frame header",
));
}
filled += read;
assert!(filled <= LENGTH_PREFIX_SIZE);
if filled == LENGTH_PREFIX_SIZE {
let len_bytes = assume_init_slice(&header);
let payload_len = u32::from_le_bytes(*len_bytes) as usize;
let fds = extract_fds(assume_init(&control));
return Ok((payload_len, fds));
}
}
unreachable!("header loop always returns")
}
/// Read `message_len` bytes from a SOCK_STREAM socket.
async fn read_frame_payload(
async_socket: &AsyncFd<Socket>,
message_len: usize,
) -> std::io::Result<Vec<u8>> {
if message_len == 0 {
return Ok(Vec::new());
}
let mut payload = vec![MaybeUninit::<u8>::uninit(); message_len];
let mut filled = 0;
while filled < message_len {
let mut guard = async_socket.readable().await?;
let result = guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..]));
let Ok(result) = result else {
// Would block, try again.
continue;
};
let read = result?;
if read == 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"socket closed while receiving frame payload",
));
}
filled += read;
assert!(filled <= message_len);
if filled == message_len {
return Ok(assume_init_vec(payload));
}
}
unreachable!("loop exits only after returning payload")
}
fn send_message_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
if fds.len() > MAX_FDS_PER_MESSAGE {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("too many fds: {}", fds.len()),
));
}
let mut frame = Vec::with_capacity(LENGTH_PREFIX_SIZE + data.len());
frame.extend_from_slice(&encode_length(data.len())?);
frame.extend_from_slice(data);
let mut control = vec![0u8; control_space_for_fds(fds.len())];
unsafe {
let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
(*cmsg).cmsg_len = libc::CMSG_LEN(size_of::<RawFd>() as c_uint * fds.len() as c_uint) as _;
(*cmsg).cmsg_level = libc::SOL_SOCKET;
(*cmsg).cmsg_type = libc::SCM_RIGHTS;
let data_ptr = libc::CMSG_DATA(cmsg).cast::<RawFd>();
for (i, fd) in fds.iter().enumerate() {
data_ptr.add(i).write(fd.as_raw_fd());
}
}
let payload = [IoSlice::new(&frame)];
let msg = MsgHdr::new().with_buffers(&payload).with_control(&control);
let mut sent = socket.sendmsg(&msg, 0)?;
while sent < frame.len() {
let bytes = socket.send(&frame[sent..])?;
if bytes == 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::WriteZero,
"socket closed while sending frame payload",
));
}
sent += bytes;
}
Ok(())
}
fn encode_length(len: usize) -> std::io::Result<[u8; LENGTH_PREFIX_SIZE]> {
let len_u32 = u32::try_from(len).map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("message too large: {len}"),
)
})?;
Ok(len_u32.to_le_bytes())
}
pub(crate) fn send_json_message<T: Serialize>(
socket: &Socket,
msg: T,
fds: &[OwnedFd],
) -> std::io::Result<()> {
let data = serde_json::to_vec(&msg)?;
send_message_bytes(socket, &data, fds)
}
fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
if fds.len() > MAX_FDS_PER_MESSAGE {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("too many fds: {}", fds.len()),
));
}
let mut control = vec![0u8; control_space_for_fds(fds.len())];
if !fds.is_empty() {
unsafe {
let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
(*cmsg).cmsg_len =
libc::CMSG_LEN(size_of::<RawFd>() as c_uint * fds.len() as c_uint) as _;
(*cmsg).cmsg_level = libc::SOL_SOCKET;
(*cmsg).cmsg_type = libc::SCM_RIGHTS;
let data_ptr = libc::CMSG_DATA(cmsg).cast::<RawFd>();
for (i, fd) in fds.iter().enumerate() {
data_ptr.add(i).write(fd.as_raw_fd());
}
}
}
let payload = [IoSlice::new(data)];
let msg = MsgHdr::new().with_buffers(&payload).with_control(&control);
let written = socket.sendmsg(&msg, 0)?;
if written != data.len() {
return Err(std::io::Error::new(
std::io::ErrorKind::WriteZero,
format!(
"short datagram write: wrote {written} bytes out of {}",
data.len()
),
));
}
Ok(())
}
fn receive_datagram_bytes(socket: &Socket) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
let mut buffer = vec![MaybeUninit::<u8>::uninit(); MAX_DATAGRAM_SIZE];
let mut control = vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
let (read, control_len) = {
let mut bufs = [MaybeUninitSlice::new(&mut buffer)];
let mut msg = MsgHdrMut::new()
.with_buffers(&mut bufs)
.with_control(&mut control);
let read = socket.recvmsg(&mut msg, 0)?;
(read, msg.control_len())
};
let data = assume_init(&buffer[..read]).to_vec();
let fds = extract_fds(assume_init(&control[..control_len]));
Ok((data, fds))
}
pub(crate) struct AsyncSocket {
inner: AsyncFd<Socket>,
}
impl AsyncSocket {
fn new(socket: Socket) -> std::io::Result<AsyncSocket> {
socket.set_nonblocking(true)?;
let async_socket = AsyncFd::new(socket)?;
Ok(AsyncSocket {
inner: async_socket,
})
}
pub fn from_fd(fd: OwnedFd) -> std::io::Result<AsyncSocket> {
AsyncSocket::new(Socket::from(fd))
}
pub fn pair() -> std::io::Result<(AsyncSocket, AsyncSocket)> {
let (server, client) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
Ok((AsyncSocket::new(server)?, AsyncSocket::new(client)?))
}
pub async fn send_with_fds<T: Serialize>(
&self,
msg: T,
fds: &[OwnedFd],
) -> std::io::Result<()> {
self.inner
.async_io(Interest::WRITABLE, |socket| {
send_json_message(socket, &msg, fds)
})
.await
}
pub async fn receive_with_fds<T: for<'de> Deserialize<'de>>(
&self,
) -> std::io::Result<(T, Vec<OwnedFd>)> {
let (payload, fds) = read_frame(&self.inner).await?;
let message: T = serde_json::from_slice(&payload)?;
Ok((message, fds))
}
pub async fn send<T>(&self, msg: T) -> std::io::Result<()>
where
T: Serialize,
{
self.send_with_fds(&msg, &[]).await
}
pub async fn receive<T: for<'de> Deserialize<'de>>(&self) -> std::io::Result<T> {
let (msg, fds) = self.receive_with_fds().await?;
if !fds.is_empty() {
tracing::warn!("unexpected fds in receive: {}", fds.len());
}
Ok(msg)
}
pub fn into_inner(self) -> Socket {
self.inner.into_inner()
}
}
pub(crate) struct AsyncDatagramSocket {
inner: AsyncFd<Socket>,
}
impl AsyncDatagramSocket {
fn new(socket: Socket) -> std::io::Result<Self> {
socket.set_nonblocking(true)?;
Ok(Self {
inner: AsyncFd::new(socket)?,
})
}
pub unsafe fn from_raw_fd(fd: RawFd) -> std::io::Result<Self> {
Self::new(unsafe { Socket::from_raw_fd(fd) })
}
pub fn pair() -> std::io::Result<(Self, Self)> {
let (server, client) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
Ok((Self::new(server)?, Self::new(client)?))
}
pub async fn send_with_fds(&self, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
self.inner
.async_io(Interest::WRITABLE, |socket| {
send_datagram_bytes(socket, data, fds)
})
.await
}
pub async fn receive_with_fds(&self) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
self.inner
.async_io(Interest::READABLE, receive_datagram_bytes)
.await
}
pub fn into_inner(self) -> Socket {
self.inner.into_inner()
}
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use serde::Deserialize;
use serde::Serialize;
use std::os::fd::AsFd;
use std::os::fd::AsRawFd;
use tempfile::NamedTempFile;
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
struct TestPayload {
id: i32,
label: String,
}
fn fd_list(count: usize) -> std::io::Result<Vec<OwnedFd>> {
let file = NamedTempFile::new()?;
let mut fds = Vec::new();
for _ in 0..count {
fds.push(file.as_fd().try_clone_to_owned()?);
}
Ok(fds)
}
#[tokio::test]
async fn async_socket_round_trips_payload_and_fds() -> std::io::Result<()> {
let (server, client) = AsyncSocket::pair()?;
let payload = TestPayload {
id: 7,
label: "round-trip".to_string(),
};
let send_fds = fd_list(1)?;
let receive_task =
tokio::spawn(async move { server.receive_with_fds::<TestPayload>().await });
client.send_with_fds(payload.clone(), &send_fds).await?;
drop(send_fds);
let (received_payload, received_fds) = receive_task.await.unwrap()?;
assert_eq!(payload, received_payload);
assert_eq!(1, received_fds.len());
let fd_status = unsafe { libc::fcntl(received_fds[0].as_raw_fd(), libc::F_GETFD) };
assert!(
fd_status >= 0,
"expected received file descriptor to be valid, but got {fd_status}",
);
Ok(())
}
#[tokio::test]
async fn async_datagram_sockets_round_trip_messages() -> std::io::Result<()> {
let (server, client) = AsyncDatagramSocket::pair()?;
let data = b"datagram payload".to_vec();
let send_fds = fd_list(1)?;
let receive_task = tokio::spawn(async move { server.receive_with_fds().await });
client.send_with_fds(&data, &send_fds).await?;
drop(send_fds);
let (received_bytes, received_fds) = receive_task.await.unwrap()?;
assert_eq!(data, received_bytes);
assert_eq!(1, received_fds.len());
Ok(())
}
#[test]
fn send_message_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
let err = send_message_bytes(&socket, b"hello", &fds).unwrap_err();
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
Ok(())
}
#[test]
fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err();
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
Ok(())
}
#[test]
fn encode_length_errors_for_oversized_messages() {
let err = encode_length(usize::MAX).unwrap_err();
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
}
#[tokio::test]
async fn receive_fails_when_peer_closes_before_header() {
let (server, client) = AsyncSocket::pair().expect("failed to create socket pair");
drop(client);
let err = server
.receive::<serde_json::Value>()
.await
.expect_err("expected read failure");
assert_eq!(std::io::ErrorKind::UnexpectedEof, err.kind());
}
}