exec-server (#6630)
This commit is contained in:
parent
9275e93364
commit
c1391b9f94
10 changed files with 1309 additions and 0 deletions
20
codex-rs/Cargo.lock
generated
20
codex-rs/Cargo.lock
generated
|
|
@ -1182,6 +1182,26 @@ dependencies = [
|
|||
"wiremock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-exec-server"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"clap",
|
||||
"codex-core",
|
||||
"libc",
|
||||
"path-absolutize",
|
||||
"pretty_assertions",
|
||||
"rmcp",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"socket2 0.6.0",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-execpolicy"
|
||||
version = "0.0.0"
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ members = [
|
|||
"common",
|
||||
"core",
|
||||
"exec",
|
||||
"exec-server",
|
||||
"execpolicy",
|
||||
"execpolicy2",
|
||||
"keyring-store",
|
||||
|
|
@ -176,6 +177,7 @@ sha1 = "0.10.6"
|
|||
sha2 = "0.10"
|
||||
shlex = "1.3.0"
|
||||
similar = "2.7.0"
|
||||
socket2 = "0.6.0"
|
||||
starlark = "0.13.0"
|
||||
strum = "0.27.2"
|
||||
strum_macros = "0.27.2"
|
||||
|
|
|
|||
47
codex-rs/exec-server/Cargo.toml
Normal file
47
codex-rs/exec-server/Cargo.toml
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
[package]
|
||||
edition = "2024"
|
||||
name = "codex-exec-server"
|
||||
version = { workspace = true }
|
||||
|
||||
[[bin]]
|
||||
name = "codex-exec-server"
|
||||
path = "src/main.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
codex-core = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
path-absolutize = { workspace = true }
|
||||
rmcp = { workspace = true, default-features = false, features = [
|
||||
"auth",
|
||||
"elicitation",
|
||||
"base64",
|
||||
"client",
|
||||
"macros",
|
||||
"schemars",
|
||||
"server",
|
||||
"transport-child-process",
|
||||
"transport-streamable-http-client-reqwest",
|
||||
"transport-streamable-http-server",
|
||||
"transport-io",
|
||||
] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
socket2 = { workspace = true }
|
||||
tokio = { workspace = true, features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
"process",
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
11
codex-rs/exec-server/src/main.rs
Normal file
11
codex-rs/exec-server/src/main.rs
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
#[cfg(target_os = "windows")]
|
||||
fn main() {
|
||||
eprintln!("codex-exec-server is not implemented on Windows targets");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
mod posix;
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
pub use posix::main;
|
||||
164
codex-rs/exec-server/src/posix.rs
Normal file
164
codex-rs/exec-server/src/posix.rs
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
//! This is an MCP that implements an alternative `shell` tool with fine-grained privilege
|
||||
//! escalation based on a per-exec() policy.
|
||||
//!
|
||||
//! We spawn Bash process inside a sandbox. The Bash we spawn is patched to allow us to intercept
|
||||
//! every exec() call it makes by invoking a wrapper program and passing in the arguments it would
|
||||
//! have passed to exec(). The Bash process (and its descendants) inherit a communication socket
|
||||
//! from us, and we give its fd number in the CODEX_ESCALATE_SOCKET environment variable.
|
||||
//!
|
||||
//! When we intercept an exec() call, we send a message over the socket back to the main
|
||||
//! MCP process. The MCP process can then decide whether to allow the exec() call to proceed
|
||||
//! or to escalate privileges and run the requested command with elevated permissions. In the
|
||||
//! latter case, we send a message back to the child requesting that it forward its open FDs to us.
|
||||
//! We then execute the requested command on its behalf, patching in the forwarded FDs.
|
||||
//!
|
||||
//!
|
||||
//! ### The privilege escalation flow
|
||||
//!
|
||||
//! Child MCP Bash Escalate Helper
|
||||
//! |
|
||||
//! o----->o
|
||||
//! | |
|
||||
//! | o--(exec)-->o
|
||||
//! | | |
|
||||
//! |o<-(EscalateReq)--o
|
||||
//! || | |
|
||||
//! |o--(Escalate)---->o
|
||||
//! || | |
|
||||
//! |o<---------(fds)--o
|
||||
//! || | |
|
||||
//! o<-----o | |
|
||||
//! | || | |
|
||||
//! x----->o | |
|
||||
//! || | |
|
||||
//! |x--(exit code)--->o
|
||||
//! | | |
|
||||
//! | o<--(exit)--x
|
||||
//! | |
|
||||
//! o<-----x
|
||||
//!
|
||||
//! ### The non-escalation flow
|
||||
//!
|
||||
//! MCP Bash Escalate Helper Child
|
||||
//! |
|
||||
//! o----->o
|
||||
//! | |
|
||||
//! | o--(exec)-->o
|
||||
//! | | |
|
||||
//! |o<-(EscalateReq)--o
|
||||
//! || | |
|
||||
//! |o-(Run)---------->o
|
||||
//! | | |
|
||||
//! | | x--(exec)-->o
|
||||
//! | | |
|
||||
//! | o<--------------(exit)--x
|
||||
//! | |
|
||||
//! o<-----x
|
||||
//!
|
||||
use std::path::Path;
|
||||
|
||||
use clap::Parser;
|
||||
use clap::Subcommand;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use tracing_subscriber::{self};
|
||||
|
||||
use crate::posix::escalate_protocol::EscalateAction;
|
||||
use crate::posix::escalate_server::EscalateServer;
|
||||
|
||||
mod escalate_client;
|
||||
mod escalate_protocol;
|
||||
mod escalate_server;
|
||||
mod mcp;
|
||||
mod socket;
|
||||
|
||||
fn dummy_exec_policy(file: &Path, argv: &[String], _workdir: &Path) -> EscalateAction {
|
||||
// TODO: execpolicy
|
||||
if file == Path::new("/opt/homebrew/bin/gh")
|
||||
&& let [_, arg1, arg2, ..] = argv
|
||||
&& arg1 == "issue"
|
||||
&& arg2 == "list"
|
||||
{
|
||||
return EscalateAction::Escalate;
|
||||
}
|
||||
EscalateAction::Run
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(version)]
|
||||
pub struct Cli {
|
||||
#[command(subcommand)]
|
||||
subcommand: Option<Commands>,
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum Commands {
|
||||
Escalate(EscalateArgs),
|
||||
ShellExec(ShellExecArgs),
|
||||
}
|
||||
|
||||
/// Invoked from within the sandbox to (potentially) escalate permissions.
|
||||
#[derive(Parser, Debug)]
|
||||
struct EscalateArgs {
|
||||
file: String,
|
||||
|
||||
#[arg(trailing_var_arg = true)]
|
||||
argv: Vec<String>,
|
||||
}
|
||||
|
||||
impl EscalateArgs {
|
||||
/// This is the escalate client. It talks to the escalate server to determine whether to exec()
|
||||
/// the command directly or to proxy to the escalate server.
|
||||
async fn run(self) -> anyhow::Result<i32> {
|
||||
let EscalateArgs { file, argv } = self;
|
||||
escalate_client::run(file, argv).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Debugging command to emulate an MCP "shell" tool call.
|
||||
#[derive(Parser, Debug)]
|
||||
struct ShellExecArgs {
|
||||
command: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
pub async fn main() -> anyhow::Result<()> {
|
||||
let cli = Cli::parse();
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.with_writer(std::io::stderr)
|
||||
.with_ansi(false)
|
||||
.init();
|
||||
|
||||
match cli.subcommand {
|
||||
Some(Commands::Escalate(args)) => {
|
||||
std::process::exit(args.run().await?);
|
||||
}
|
||||
Some(Commands::ShellExec(args)) => {
|
||||
let bash_path = mcp::get_bash_path()?;
|
||||
let escalate_server = EscalateServer::new(bash_path, dummy_exec_policy);
|
||||
let result = escalate_server
|
||||
.exec(
|
||||
args.command.clone(),
|
||||
std::env::vars().collect(),
|
||||
std::env::current_dir()?,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
println!("{result:?}");
|
||||
std::process::exit(result.exit_code);
|
||||
}
|
||||
None => {
|
||||
let bash_path = mcp::get_bash_path()?;
|
||||
|
||||
tracing::info!("Starting MCP server");
|
||||
let service = mcp::serve(bash_path, dummy_exec_policy)
|
||||
.await
|
||||
.inspect_err(|e| {
|
||||
tracing::error!("serving error: {:?}", e);
|
||||
})?;
|
||||
|
||||
service.waiting().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
102
codex-rs/exec-server/src/posix/escalate_client.rs
Normal file
102
codex-rs/exec-server/src/posix/escalate_client.rs
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
use std::io;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd as _;
|
||||
use std::os::fd::OwnedFd;
|
||||
|
||||
use anyhow::Context as _;
|
||||
|
||||
use crate::posix::escalate_protocol::BASH_EXEC_WRAPPER_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::EscalateAction;
|
||||
use crate::posix::escalate_protocol::EscalateRequest;
|
||||
use crate::posix::escalate_protocol::EscalateResponse;
|
||||
use crate::posix::escalate_protocol::SuperExecMessage;
|
||||
use crate::posix::escalate_protocol::SuperExecResult;
|
||||
use crate::posix::socket::AsyncDatagramSocket;
|
||||
use crate::posix::socket::AsyncSocket;
|
||||
|
||||
fn get_escalate_client() -> anyhow::Result<AsyncDatagramSocket> {
|
||||
// TODO: we should defensively require only calling this once, since AsyncSocket will take ownership of the fd.
|
||||
let client_fd = std::env::var(ESCALATE_SOCKET_ENV_VAR)?.parse::<i32>()?;
|
||||
if client_fd < 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"{ESCALATE_SOCKET_ENV_VAR} is not a valid file descriptor: {client_fd}"
|
||||
));
|
||||
}
|
||||
Ok(unsafe { AsyncDatagramSocket::from_raw_fd(client_fd) }?)
|
||||
}
|
||||
|
||||
pub(crate) async fn run(file: String, argv: Vec<String>) -> anyhow::Result<i32> {
|
||||
let handshake_client = get_escalate_client()?;
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
const HANDSHAKE_MESSAGE: [u8; 1] = [0];
|
||||
handshake_client
|
||||
.send_with_fds(&HANDSHAKE_MESSAGE, &[server.into_inner().into()])
|
||||
.await
|
||||
.context("failed to send handshake datagram")?;
|
||||
let env = std::env::vars()
|
||||
.filter(|(k, _)| {
|
||||
!matches!(
|
||||
k.as_str(),
|
||||
ESCALATE_SOCKET_ENV_VAR | BASH_EXEC_WRAPPER_ENV_VAR
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
client
|
||||
.send(EscalateRequest {
|
||||
file: file.clone().into(),
|
||||
argv: argv.clone(),
|
||||
workdir: std::env::current_dir()?,
|
||||
env,
|
||||
})
|
||||
.await
|
||||
.context("failed to send EscalateRequest")?;
|
||||
let message = client.receive::<EscalateResponse>().await?;
|
||||
match message.action {
|
||||
EscalateAction::Escalate => {
|
||||
// TODO: maybe we should send ALL open FDs (except the escalate client)?
|
||||
let fds_to_send = [
|
||||
unsafe { OwnedFd::from_raw_fd(io::stdin().as_raw_fd()) },
|
||||
unsafe { OwnedFd::from_raw_fd(io::stdout().as_raw_fd()) },
|
||||
unsafe { OwnedFd::from_raw_fd(io::stderr().as_raw_fd()) },
|
||||
];
|
||||
|
||||
// TODO: also forward signals over the super-exec socket
|
||||
|
||||
client
|
||||
.send_with_fds(
|
||||
SuperExecMessage {
|
||||
fds: fds_to_send.iter().map(AsRawFd::as_raw_fd).collect(),
|
||||
},
|
||||
&fds_to_send,
|
||||
)
|
||||
.await
|
||||
.context("failed to send SuperExecMessage")?;
|
||||
let SuperExecResult { exit_code } = client.receive::<SuperExecResult>().await?;
|
||||
Ok(exit_code)
|
||||
}
|
||||
EscalateAction::Run => {
|
||||
// We avoid std::process::Command here because we want to be as transparent as
|
||||
// possible. std::os::unix::process::CommandExt has .exec() but it does some funky
|
||||
// stuff with signal masks and dup2() on its standard FDs, which we don't want.
|
||||
use std::ffi::CString;
|
||||
let file = CString::new(file).context("NUL in file")?;
|
||||
|
||||
let argv_cstrs: Vec<CString> = argv
|
||||
.iter()
|
||||
.map(|s| CString::new(s.as_str()).context("NUL in argv"))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let mut argv: Vec<*const libc::c_char> =
|
||||
argv_cstrs.iter().map(|s| s.as_ptr()).collect();
|
||||
argv.push(std::ptr::null());
|
||||
|
||||
let err = unsafe {
|
||||
libc::execv(file.as_ptr(), argv.as_ptr());
|
||||
std::io::Error::last_os_error()
|
||||
};
|
||||
|
||||
Err(err.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
49
codex-rs/exec-server/src/posix/escalate_protocol.rs
Normal file
49
codex-rs/exec-server/src/posix/escalate_protocol.rs
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
use std::collections::HashMap;
|
||||
use std::os::fd::RawFd;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
/// 'exec-server escalate' reads this to find the inherited FD for the escalate socket.
|
||||
pub(super) const ESCALATE_SOCKET_ENV_VAR: &str = "CODEX_ESCALATE_SOCKET";
|
||||
|
||||
/// The patched bash uses this to wrap exec() calls.
|
||||
pub(super) const BASH_EXEC_WRAPPER_ENV_VAR: &str = "BASH_EXEC_WRAPPER";
|
||||
|
||||
/// The client sends this to the server to request an exec() call.
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
|
||||
pub(super) struct EscalateRequest {
|
||||
/// The absolute path to the executable to run, i.e. the first arg to exec.
|
||||
pub(super) file: PathBuf,
|
||||
/// The argv, including the program name (argv[0]).
|
||||
pub(super) argv: Vec<String>,
|
||||
pub(super) workdir: PathBuf,
|
||||
pub(super) env: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// The server sends this to the client to respond to an exec() request.
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
|
||||
pub(super) struct EscalateResponse {
|
||||
pub(super) action: EscalateAction,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
|
||||
pub(super) enum EscalateAction {
|
||||
/// The command should be run directly by the client.
|
||||
Run,
|
||||
/// The command should be escalated to the server for execution.
|
||||
Escalate,
|
||||
}
|
||||
|
||||
/// The client sends this to the server to forward its open FDs.
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub(super) struct SuperExecMessage {
|
||||
pub(super) fds: Vec<RawFd>,
|
||||
}
|
||||
|
||||
/// The server responds when the exec()'d command has exited.
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub(super) struct SuperExecResult {
|
||||
pub(super) exit_code: i32,
|
||||
}
|
||||
274
codex-rs/exec-server/src/posix/escalate_server.rs
Normal file
274
codex-rs/exec-server/src/posix/escalate_server.rs
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
use std::collections::HashMap;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use path_absolutize::Absolutize as _;
|
||||
|
||||
use codex_core::exec::SandboxType;
|
||||
use codex_core::exec::process_exec_tool_call;
|
||||
use codex_core::get_platform_sandbox;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use tokio::process::Command;
|
||||
|
||||
use crate::posix::escalate_protocol::BASH_EXEC_WRAPPER_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::EscalateAction;
|
||||
use crate::posix::escalate_protocol::EscalateRequest;
|
||||
use crate::posix::escalate_protocol::EscalateResponse;
|
||||
use crate::posix::escalate_protocol::SuperExecMessage;
|
||||
use crate::posix::escalate_protocol::SuperExecResult;
|
||||
use crate::posix::socket::AsyncDatagramSocket;
|
||||
use crate::posix::socket::AsyncSocket;
|
||||
|
||||
/// This is the policy which decides how to handle an exec() call.
|
||||
///
|
||||
/// `file` is the absolute, canonical path to the executable to run, i.e. the first arg to exec.
|
||||
/// `argv` is the argv, including the program name (`argv[0]`).
|
||||
/// `workdir` is the absolute, canonical path to the working directory in which to execute the
|
||||
/// command.
|
||||
pub(crate) type ExecPolicy = fn(file: &Path, argv: &[String], workdir: &Path) -> EscalateAction;
|
||||
|
||||
pub(crate) struct EscalateServer {
|
||||
bash_path: PathBuf,
|
||||
policy: ExecPolicy,
|
||||
}
|
||||
|
||||
impl EscalateServer {
|
||||
pub fn new(bash_path: PathBuf, policy: ExecPolicy) -> Self {
|
||||
Self { bash_path, policy }
|
||||
}
|
||||
|
||||
pub async fn exec(
|
||||
&self,
|
||||
command: String,
|
||||
env: HashMap<String, String>,
|
||||
workdir: PathBuf,
|
||||
timeout_ms: Option<u64>,
|
||||
) -> anyhow::Result<ExecResult> {
|
||||
let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?;
|
||||
let client_socket = escalate_client.into_inner();
|
||||
client_socket.set_cloexec(false)?;
|
||||
|
||||
let escalate_task = tokio::spawn(escalate_task(escalate_server, self.policy));
|
||||
let mut env = env.clone();
|
||||
env.insert(
|
||||
ESCALATE_SOCKET_ENV_VAR.to_string(),
|
||||
client_socket.as_raw_fd().to_string(),
|
||||
);
|
||||
env.insert(
|
||||
BASH_EXEC_WRAPPER_ENV_VAR.to_string(),
|
||||
format!("{} escalate", std::env::current_exe()?.to_string_lossy()),
|
||||
);
|
||||
let result = process_exec_tool_call(
|
||||
codex_core::exec::ExecParams {
|
||||
command: vec![
|
||||
self.bash_path.to_string_lossy().to_string(),
|
||||
"-c".to_string(),
|
||||
command,
|
||||
],
|
||||
cwd: PathBuf::from(&workdir),
|
||||
timeout_ms,
|
||||
env,
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
arg0: None,
|
||||
},
|
||||
get_platform_sandbox().unwrap_or(SandboxType::None),
|
||||
// TODO: use the sandbox policy and cwd from the calling client
|
||||
&SandboxPolicy::ReadOnly,
|
||||
&PathBuf::from("/__NONEXISTENT__"), // This is ignored for ReadOnly
|
||||
&None,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
escalate_task.abort();
|
||||
let result = ExecResult {
|
||||
exit_code: result.exit_code,
|
||||
output: result.aggregated_output.text,
|
||||
duration: result.duration,
|
||||
timed_out: result.timed_out,
|
||||
};
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
async fn escalate_task(socket: AsyncDatagramSocket, policy: ExecPolicy) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let (_, mut fds) = socket.receive_with_fds().await?;
|
||||
if fds.len() != 1 {
|
||||
tracing::error!("expected 1 fd in datagram handshake, got {}", fds.len());
|
||||
continue;
|
||||
}
|
||||
let stream_socket = AsyncSocket::from_fd(fds.remove(0))?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = handle_escalate_session_with_policy(stream_socket, policy).await {
|
||||
tracing::error!("escalate session failed: {err:?}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ExecResult {
|
||||
pub(crate) exit_code: i32,
|
||||
pub(crate) output: String,
|
||||
pub(crate) duration: Duration,
|
||||
pub(crate) timed_out: bool,
|
||||
}
|
||||
|
||||
async fn handle_escalate_session_with_policy(
|
||||
socket: AsyncSocket,
|
||||
policy: ExecPolicy,
|
||||
) -> anyhow::Result<()> {
|
||||
let EscalateRequest {
|
||||
file,
|
||||
argv,
|
||||
workdir,
|
||||
env,
|
||||
} = socket.receive::<EscalateRequest>().await?;
|
||||
let file = PathBuf::from(&file).absolutize()?.into_owned();
|
||||
let workdir = PathBuf::from(&workdir).absolutize()?.into_owned();
|
||||
let action = policy(file.as_path(), &argv, &workdir);
|
||||
tracing::debug!("decided {action:?} for {file:?} {argv:?} {workdir:?}");
|
||||
match action {
|
||||
EscalateAction::Run => {
|
||||
socket
|
||||
.send(EscalateResponse {
|
||||
action: EscalateAction::Run,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
EscalateAction::Escalate => {
|
||||
socket
|
||||
.send(EscalateResponse {
|
||||
action: EscalateAction::Escalate,
|
||||
})
|
||||
.await?;
|
||||
let (msg, fds) = socket
|
||||
.receive_with_fds::<SuperExecMessage>()
|
||||
.await
|
||||
.context("failed to receive SuperExecMessage")?;
|
||||
if fds.len() != msg.fds.len() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"mismatched number of fds in SuperExecMessage: {} in the message, {} from the control message",
|
||||
msg.fds.len(),
|
||||
fds.len()
|
||||
));
|
||||
}
|
||||
|
||||
if msg
|
||||
.fds
|
||||
.iter()
|
||||
.any(|src_fd| fds.iter().any(|dst_fd| dst_fd.as_raw_fd() == *src_fd))
|
||||
{
|
||||
return Err(anyhow::anyhow!(
|
||||
"overlapping fds not yet supported in SuperExecMessage"
|
||||
));
|
||||
}
|
||||
|
||||
let mut command = Command::new(file);
|
||||
command
|
||||
.args(&argv[1..])
|
||||
.arg0(argv[0].clone())
|
||||
.envs(&env)
|
||||
.current_dir(&workdir)
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null());
|
||||
unsafe {
|
||||
command.pre_exec(move || {
|
||||
for (dst_fd, src_fd) in msg.fds.iter().zip(&fds) {
|
||||
libc::dup2(src_fd.as_raw_fd(), *dst_fd);
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
let mut child = command.spawn()?;
|
||||
let exit_status = child.wait().await?;
|
||||
socket
|
||||
.send(SuperExecResult {
|
||||
exit_code: exit_status.code().unwrap_or(127),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_escalate_session_respects_run_in_sandbox_decision() -> anyhow::Result<()> {
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let server_task = tokio::spawn(handle_escalate_session_with_policy(
|
||||
server,
|
||||
|_file, _argv, _workdir| EscalateAction::Run,
|
||||
));
|
||||
|
||||
client
|
||||
.send(EscalateRequest {
|
||||
file: PathBuf::from("/bin/echo"),
|
||||
argv: vec!["echo".to_string()],
|
||||
workdir: PathBuf::from("/tmp"),
|
||||
env: HashMap::new(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let response = client.receive::<EscalateResponse>().await?;
|
||||
assert_eq!(
|
||||
EscalateResponse {
|
||||
action: EscalateAction::Run,
|
||||
},
|
||||
response
|
||||
);
|
||||
server_task.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_escalate_session_executes_escalated_command() -> anyhow::Result<()> {
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let server_task = tokio::spawn(handle_escalate_session_with_policy(
|
||||
server,
|
||||
|_file, _argv, _workdir| EscalateAction::Escalate,
|
||||
));
|
||||
|
||||
client
|
||||
.send(EscalateRequest {
|
||||
file: PathBuf::from("/bin/sh"),
|
||||
argv: vec![
|
||||
"sh".to_string(),
|
||||
"-c".to_string(),
|
||||
r#"if [ "$KEY" = VALUE ]; then exit 42; else exit 1; fi"#.to_string(),
|
||||
],
|
||||
workdir: std::env::current_dir()?,
|
||||
env: HashMap::from([("KEY".to_string(), "VALUE".to_string())]),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let response = client.receive::<EscalateResponse>().await?;
|
||||
assert_eq!(
|
||||
EscalateResponse {
|
||||
action: EscalateAction::Escalate,
|
||||
},
|
||||
response
|
||||
);
|
||||
|
||||
client
|
||||
.send_with_fds(SuperExecMessage { fds: Vec::new() }, &[])
|
||||
.await?;
|
||||
|
||||
let result = client.receive::<SuperExecResult>().await?;
|
||||
assert_eq!(42, result.exit_code);
|
||||
|
||||
server_task.await?
|
||||
}
|
||||
}
|
||||
154
codex-rs/exec-server/src/posix/mcp.rs
Normal file
154
codex-rs/exec-server/src/posix/mcp.rs
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::RoleServer;
|
||||
use rmcp::ServerHandler;
|
||||
use rmcp::ServiceExt;
|
||||
use rmcp::handler::server::router::tool::ToolRouter;
|
||||
use rmcp::handler::server::wrapper::Parameters;
|
||||
use rmcp::model::*;
|
||||
use rmcp::schemars;
|
||||
use rmcp::service::RequestContext;
|
||||
use rmcp::service::RunningService;
|
||||
use rmcp::tool;
|
||||
use rmcp::tool_handler;
|
||||
use rmcp::tool_router;
|
||||
use rmcp::transport::stdio;
|
||||
|
||||
use crate::posix::escalate_server;
|
||||
use crate::posix::escalate_server::EscalateServer;
|
||||
use crate::posix::escalate_server::ExecPolicy;
|
||||
|
||||
/// Path to our patched bash.
|
||||
const CODEX_BASH_PATH_ENV_VAR: &str = "CODEX_BASH_PATH";
|
||||
|
||||
pub(crate) fn get_bash_path() -> Result<PathBuf> {
|
||||
std::env::var(CODEX_BASH_PATH_ENV_VAR)
|
||||
.map(PathBuf::from)
|
||||
.context(format!("{CODEX_BASH_PATH_ENV_VAR} must be set"))
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct ExecParams {
|
||||
/// The bash string to execute.
|
||||
pub command: String,
|
||||
/// The working directory to execute the command in. Must be an absolute path.
|
||||
pub workdir: String,
|
||||
/// The timeout for the command in milliseconds.
|
||||
pub timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize, schemars::JsonSchema)]
|
||||
pub struct ExecResult {
|
||||
pub exit_code: i32,
|
||||
pub output: String,
|
||||
pub duration: Duration,
|
||||
pub timed_out: bool,
|
||||
}
|
||||
|
||||
impl From<escalate_server::ExecResult> for ExecResult {
|
||||
fn from(result: escalate_server::ExecResult) -> Self {
|
||||
Self {
|
||||
exit_code: result.exit_code,
|
||||
output: result.output,
|
||||
duration: result.duration,
|
||||
timed_out: result.timed_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ExecTool {
|
||||
tool_router: ToolRouter<ExecTool>,
|
||||
bash_path: PathBuf,
|
||||
policy: ExecPolicy,
|
||||
}
|
||||
|
||||
#[tool_router]
|
||||
impl ExecTool {
|
||||
pub fn new(bash_path: PathBuf, policy: ExecPolicy) -> Self {
|
||||
Self {
|
||||
tool_router: Self::tool_router(),
|
||||
bash_path,
|
||||
policy,
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs a shell command and returns its output. You MUST provide the workdir as an absolute path.
|
||||
#[tool]
|
||||
async fn shell(
|
||||
&self,
|
||||
_context: RequestContext<RoleServer>,
|
||||
Parameters(params): Parameters<ExecParams>,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let escalate_server = EscalateServer::new(self.bash_path.clone(), self.policy);
|
||||
let result = escalate_server
|
||||
.exec(
|
||||
params.command,
|
||||
// TODO: use ShellEnvironmentPolicy
|
||||
std::env::vars().collect(),
|
||||
PathBuf::from(¶ms.workdir),
|
||||
params.timeout_ms,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(e.to_string(), None))?;
|
||||
Ok(CallToolResult::success(vec![Content::json(
|
||||
ExecResult::from(result),
|
||||
)?]))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
async fn prompt(
|
||||
&self,
|
||||
command: String,
|
||||
workdir: String,
|
||||
context: RequestContext<RoleServer>,
|
||||
) -> Result<CreateElicitationResult, McpError> {
|
||||
context
|
||||
.peer
|
||||
.create_elicitation(CreateElicitationRequestParam {
|
||||
message: format!("Allow Codex to run `{command:?}` in `{workdir:?}`?"),
|
||||
#[allow(clippy::expect_used)]
|
||||
requested_schema: ElicitationSchema::builder()
|
||||
.property("dummy", PrimitiveSchema::String(StringSchema::new()))
|
||||
.build()
|
||||
.expect("failed to build elicitation schema"),
|
||||
})
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(e.to_string(), None))
|
||||
}
|
||||
}
|
||||
|
||||
#[tool_handler]
|
||||
impl ServerHandler for ExecTool {
|
||||
fn get_info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
protocol_version: ProtocolVersion::V_2025_06_18,
|
||||
capabilities: ServerCapabilities::builder().enable_tools().build(),
|
||||
server_info: Implementation::from_build_env(),
|
||||
instructions: Some(
|
||||
"This server provides a tool to execute shell commands and return their output."
|
||||
.to_string(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
async fn initialize(
|
||||
&self,
|
||||
_request: InitializeRequestParam,
|
||||
_context: RequestContext<RoleServer>,
|
||||
) -> Result<InitializeResult, McpError> {
|
||||
Ok(self.get_info())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn serve(
|
||||
bash_path: PathBuf,
|
||||
policy: ExecPolicy,
|
||||
) -> Result<RunningService<RoleServer, ExecTool>, rmcp::service::ServerInitializeError> {
|
||||
let tool = ExecTool::new(bash_path, policy);
|
||||
tool.serve(stdio()).await
|
||||
}
|
||||
486
codex-rs/exec-server/src/posix/socket.rs
Normal file
486
codex-rs/exec-server/src/posix/socket.rs
Normal file
|
|
@ -0,0 +1,486 @@
|
|||
use libc::c_uint;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use socket2::Domain;
|
||||
use socket2::MaybeUninitSlice;
|
||||
use socket2::MsgHdr;
|
||||
use socket2::MsgHdrMut;
|
||||
use socket2::Socket;
|
||||
use socket2::Type;
|
||||
use std::io::IoSlice;
|
||||
use std::mem::MaybeUninit;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd;
|
||||
use std::os::fd::OwnedFd;
|
||||
use std::os::fd::RawFd;
|
||||
use tokio::io::Interest;
|
||||
use tokio::io::unix::AsyncFd;
|
||||
|
||||
const MAX_FDS_PER_MESSAGE: usize = 16;
|
||||
const LENGTH_PREFIX_SIZE: usize = size_of::<u32>();
|
||||
const MAX_DATAGRAM_SIZE: usize = 8192;
|
||||
|
||||
/// Converts a slice of MaybeUninit<T> to a slice of T.
|
||||
///
|
||||
/// The caller guarantees that every element of `buf` is initialized.
|
||||
fn assume_init<T>(buf: &[MaybeUninit<T>]) -> &[T] {
|
||||
unsafe { std::slice::from_raw_parts(buf.as_ptr().cast(), buf.len()) }
|
||||
}
|
||||
|
||||
fn assume_init_slice<T, const N: usize>(buf: &[MaybeUninit<T>; N]) -> &[T; N] {
|
||||
unsafe { &*(buf as *const [MaybeUninit<T>; N] as *const [T; N]) }
|
||||
}
|
||||
|
||||
fn assume_init_vec<T>(mut buf: Vec<MaybeUninit<T>>) -> Vec<T> {
|
||||
unsafe {
|
||||
let ptr = buf.as_mut_ptr() as *mut T;
|
||||
let len = buf.len();
|
||||
let cap = buf.capacity();
|
||||
std::mem::forget(buf);
|
||||
Vec::from_raw_parts(ptr, len, cap)
|
||||
}
|
||||
}
|
||||
|
||||
fn control_space_for_fds(count: usize) -> usize {
|
||||
unsafe { libc::CMSG_SPACE((count * size_of::<RawFd>()) as _) as usize }
|
||||
}
|
||||
|
||||
/// Extracts the FDs from a SCM_RIGHTS control message.
|
||||
fn extract_fds(control: &[u8]) -> Vec<OwnedFd> {
|
||||
let mut fds = Vec::new();
|
||||
let mut hdr: libc::msghdr = unsafe { std::mem::zeroed() };
|
||||
hdr.msg_control = control.as_ptr() as *mut libc::c_void;
|
||||
hdr.msg_controllen = control.len() as _;
|
||||
let hdr = hdr; // drop mut
|
||||
|
||||
let mut cmsg = unsafe { libc::CMSG_FIRSTHDR(&hdr) as *const libc::cmsghdr };
|
||||
while !cmsg.is_null() {
|
||||
let level = unsafe { (*cmsg).cmsg_level };
|
||||
let ty = unsafe { (*cmsg).cmsg_type };
|
||||
if level == libc::SOL_SOCKET && ty == libc::SCM_RIGHTS {
|
||||
let data_ptr = unsafe { libc::CMSG_DATA(cmsg).cast::<RawFd>() };
|
||||
let fd_count: usize = {
|
||||
let cmsg_data_len =
|
||||
unsafe { (*cmsg).cmsg_len as usize } - unsafe { libc::CMSG_LEN(0) as usize };
|
||||
cmsg_data_len / size_of::<RawFd>()
|
||||
};
|
||||
for i in 0..fd_count {
|
||||
let fd = unsafe { data_ptr.add(i).read() };
|
||||
fds.push(unsafe { OwnedFd::from_raw_fd(fd) });
|
||||
}
|
||||
}
|
||||
cmsg = unsafe { libc::CMSG_NXTHDR(&hdr, cmsg) };
|
||||
}
|
||||
fds
|
||||
}
|
||||
|
||||
/// Read a frame from a SOCK_STREAM socket.
|
||||
///
|
||||
/// A frame is a message length prefix followed by a payload. FDs may be included in the control
|
||||
/// message when receiving the frame header.
|
||||
async fn read_frame(async_socket: &AsyncFd<Socket>) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
|
||||
let (message_len, fds) = read_frame_header(async_socket).await?;
|
||||
let payload = read_frame_payload(async_socket, message_len).await?;
|
||||
Ok((payload, fds))
|
||||
}
|
||||
|
||||
/// Read the frame header (i.e. length) and any FDs from a SOCK_STREAM socket.
|
||||
async fn read_frame_header(
|
||||
async_socket: &AsyncFd<Socket>,
|
||||
) -> std::io::Result<(usize, Vec<OwnedFd>)> {
|
||||
let mut header = [MaybeUninit::<u8>::uninit(); LENGTH_PREFIX_SIZE];
|
||||
let mut filled = 0;
|
||||
let mut control = vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
|
||||
let mut captured_control = false;
|
||||
|
||||
while filled < LENGTH_PREFIX_SIZE {
|
||||
let mut guard = async_socket.readable().await?;
|
||||
// The first read should come with a control message containing any FDs.
|
||||
let result = if !captured_control {
|
||||
guard.try_io(|inner| {
|
||||
let mut bufs = [MaybeUninitSlice::new(&mut header[filled..])];
|
||||
let (read, control_len) = {
|
||||
let mut msg = MsgHdrMut::new()
|
||||
.with_buffers(&mut bufs)
|
||||
.with_control(&mut control);
|
||||
let read = inner.get_ref().recvmsg(&mut msg, 0)?;
|
||||
(read, msg.control_len())
|
||||
};
|
||||
control.truncate(control_len);
|
||||
captured_control = true;
|
||||
Ok(read)
|
||||
})
|
||||
} else {
|
||||
guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..]))
|
||||
};
|
||||
let Ok(result) = result else {
|
||||
// Would block, try again.
|
||||
continue;
|
||||
};
|
||||
|
||||
let read = result?;
|
||||
if read == 0 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"socket closed while receiving frame header",
|
||||
));
|
||||
}
|
||||
|
||||
filled += read;
|
||||
assert!(filled <= LENGTH_PREFIX_SIZE);
|
||||
if filled == LENGTH_PREFIX_SIZE {
|
||||
let len_bytes = assume_init_slice(&header);
|
||||
let payload_len = u32::from_le_bytes(*len_bytes) as usize;
|
||||
let fds = extract_fds(assume_init(&control));
|
||||
return Ok((payload_len, fds));
|
||||
}
|
||||
}
|
||||
unreachable!("header loop always returns")
|
||||
}
|
||||
|
||||
/// Read `message_len` bytes from a SOCK_STREAM socket.
|
||||
async fn read_frame_payload(
|
||||
async_socket: &AsyncFd<Socket>,
|
||||
message_len: usize,
|
||||
) -> std::io::Result<Vec<u8>> {
|
||||
if message_len == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut payload = vec![MaybeUninit::<u8>::uninit(); message_len];
|
||||
let mut filled = 0;
|
||||
while filled < message_len {
|
||||
let mut guard = async_socket.readable().await?;
|
||||
let result = guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..]));
|
||||
let Ok(result) = result else {
|
||||
// Would block, try again.
|
||||
continue;
|
||||
};
|
||||
let read = result?;
|
||||
if read == 0 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"socket closed while receiving frame payload",
|
||||
));
|
||||
}
|
||||
filled += read;
|
||||
assert!(filled <= message_len);
|
||||
if filled == message_len {
|
||||
return Ok(assume_init_vec(payload));
|
||||
}
|
||||
}
|
||||
unreachable!("loop exits only after returning payload")
|
||||
}
|
||||
|
||||
fn send_message_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
|
||||
if fds.len() > MAX_FDS_PER_MESSAGE {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
format!("too many fds: {}", fds.len()),
|
||||
));
|
||||
}
|
||||
let mut frame = Vec::with_capacity(LENGTH_PREFIX_SIZE + data.len());
|
||||
frame.extend_from_slice(&encode_length(data.len())?);
|
||||
frame.extend_from_slice(data);
|
||||
|
||||
let mut control = vec![0u8; control_space_for_fds(fds.len())];
|
||||
unsafe {
|
||||
let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
|
||||
(*cmsg).cmsg_len = libc::CMSG_LEN(size_of::<RawFd>() as c_uint * fds.len() as c_uint) as _;
|
||||
(*cmsg).cmsg_level = libc::SOL_SOCKET;
|
||||
(*cmsg).cmsg_type = libc::SCM_RIGHTS;
|
||||
let data_ptr = libc::CMSG_DATA(cmsg).cast::<RawFd>();
|
||||
for (i, fd) in fds.iter().enumerate() {
|
||||
data_ptr.add(i).write(fd.as_raw_fd());
|
||||
}
|
||||
}
|
||||
|
||||
let payload = [IoSlice::new(&frame)];
|
||||
let msg = MsgHdr::new().with_buffers(&payload).with_control(&control);
|
||||
let mut sent = socket.sendmsg(&msg, 0)?;
|
||||
while sent < frame.len() {
|
||||
let bytes = socket.send(&frame[sent..])?;
|
||||
if bytes == 0 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::WriteZero,
|
||||
"socket closed while sending frame payload",
|
||||
));
|
||||
}
|
||||
sent += bytes;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn encode_length(len: usize) -> std::io::Result<[u8; LENGTH_PREFIX_SIZE]> {
|
||||
let len_u32 = u32::try_from(len).map_err(|_| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
format!("message too large: {len}"),
|
||||
)
|
||||
})?;
|
||||
Ok(len_u32.to_le_bytes())
|
||||
}
|
||||
|
||||
pub(crate) fn send_json_message<T: Serialize>(
|
||||
socket: &Socket,
|
||||
msg: T,
|
||||
fds: &[OwnedFd],
|
||||
) -> std::io::Result<()> {
|
||||
let data = serde_json::to_vec(&msg)?;
|
||||
send_message_bytes(socket, &data, fds)
|
||||
}
|
||||
|
||||
fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
|
||||
if fds.len() > MAX_FDS_PER_MESSAGE {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
format!("too many fds: {}", fds.len()),
|
||||
));
|
||||
}
|
||||
let mut control = vec![0u8; control_space_for_fds(fds.len())];
|
||||
if !fds.is_empty() {
|
||||
unsafe {
|
||||
let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
|
||||
(*cmsg).cmsg_len =
|
||||
libc::CMSG_LEN(size_of::<RawFd>() as c_uint * fds.len() as c_uint) as _;
|
||||
(*cmsg).cmsg_level = libc::SOL_SOCKET;
|
||||
(*cmsg).cmsg_type = libc::SCM_RIGHTS;
|
||||
let data_ptr = libc::CMSG_DATA(cmsg).cast::<RawFd>();
|
||||
for (i, fd) in fds.iter().enumerate() {
|
||||
data_ptr.add(i).write(fd.as_raw_fd());
|
||||
}
|
||||
}
|
||||
}
|
||||
let payload = [IoSlice::new(data)];
|
||||
let msg = MsgHdr::new().with_buffers(&payload).with_control(&control);
|
||||
let written = socket.sendmsg(&msg, 0)?;
|
||||
if written != data.len() {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::WriteZero,
|
||||
format!(
|
||||
"short datagram write: wrote {written} bytes out of {}",
|
||||
data.len()
|
||||
),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn receive_datagram_bytes(socket: &Socket) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
|
||||
let mut buffer = vec![MaybeUninit::<u8>::uninit(); MAX_DATAGRAM_SIZE];
|
||||
let mut control = vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
|
||||
let (read, control_len) = {
|
||||
let mut bufs = [MaybeUninitSlice::new(&mut buffer)];
|
||||
let mut msg = MsgHdrMut::new()
|
||||
.with_buffers(&mut bufs)
|
||||
.with_control(&mut control);
|
||||
let read = socket.recvmsg(&mut msg, 0)?;
|
||||
(read, msg.control_len())
|
||||
};
|
||||
let data = assume_init(&buffer[..read]).to_vec();
|
||||
let fds = extract_fds(assume_init(&control[..control_len]));
|
||||
Ok((data, fds))
|
||||
}
|
||||
|
||||
pub(crate) struct AsyncSocket {
|
||||
inner: AsyncFd<Socket>,
|
||||
}
|
||||
|
||||
impl AsyncSocket {
|
||||
fn new(socket: Socket) -> std::io::Result<AsyncSocket> {
|
||||
socket.set_nonblocking(true)?;
|
||||
let async_socket = AsyncFd::new(socket)?;
|
||||
Ok(AsyncSocket {
|
||||
inner: async_socket,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_fd(fd: OwnedFd) -> std::io::Result<AsyncSocket> {
|
||||
AsyncSocket::new(Socket::from(fd))
|
||||
}
|
||||
|
||||
pub fn pair() -> std::io::Result<(AsyncSocket, AsyncSocket)> {
|
||||
let (server, client) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
|
||||
Ok((AsyncSocket::new(server)?, AsyncSocket::new(client)?))
|
||||
}
|
||||
|
||||
pub async fn send_with_fds<T: Serialize>(
|
||||
&self,
|
||||
msg: T,
|
||||
fds: &[OwnedFd],
|
||||
) -> std::io::Result<()> {
|
||||
self.inner
|
||||
.async_io(Interest::WRITABLE, |socket| {
|
||||
send_json_message(socket, &msg, fds)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn receive_with_fds<T: for<'de> Deserialize<'de>>(
|
||||
&self,
|
||||
) -> std::io::Result<(T, Vec<OwnedFd>)> {
|
||||
let (payload, fds) = read_frame(&self.inner).await?;
|
||||
let message: T = serde_json::from_slice(&payload)?;
|
||||
Ok((message, fds))
|
||||
}
|
||||
|
||||
pub async fn send<T>(&self, msg: T) -> std::io::Result<()>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
self.send_with_fds(&msg, &[]).await
|
||||
}
|
||||
|
||||
pub async fn receive<T: for<'de> Deserialize<'de>>(&self) -> std::io::Result<T> {
|
||||
let (msg, fds) = self.receive_with_fds().await?;
|
||||
if !fds.is_empty() {
|
||||
tracing::warn!("unexpected fds in receive: {}", fds.len());
|
||||
}
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> Socket {
|
||||
self.inner.into_inner()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct AsyncDatagramSocket {
|
||||
inner: AsyncFd<Socket>,
|
||||
}
|
||||
|
||||
impl AsyncDatagramSocket {
|
||||
fn new(socket: Socket) -> std::io::Result<Self> {
|
||||
socket.set_nonblocking(true)?;
|
||||
Ok(Self {
|
||||
inner: AsyncFd::new(socket)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub unsafe fn from_raw_fd(fd: RawFd) -> std::io::Result<Self> {
|
||||
Self::new(unsafe { Socket::from_raw_fd(fd) })
|
||||
}
|
||||
|
||||
pub fn pair() -> std::io::Result<(Self, Self)> {
|
||||
let (server, client) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
|
||||
Ok((Self::new(server)?, Self::new(client)?))
|
||||
}
|
||||
|
||||
pub async fn send_with_fds(&self, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
|
||||
self.inner
|
||||
.async_io(Interest::WRITABLE, |socket| {
|
||||
send_datagram_bytes(socket, data, fds)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn receive_with_fds(&self) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
|
||||
self.inner
|
||||
.async_io(Interest::READABLE, receive_datagram_bytes)
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> Socket {
|
||||
self.inner.into_inner()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::os::fd::AsFd;
|
||||
use std::os::fd::AsRawFd;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
|
||||
struct TestPayload {
|
||||
id: i32,
|
||||
label: String,
|
||||
}
|
||||
|
||||
fn fd_list(count: usize) -> std::io::Result<Vec<OwnedFd>> {
|
||||
let file = NamedTempFile::new()?;
|
||||
let mut fds = Vec::new();
|
||||
for _ in 0..count {
|
||||
fds.push(file.as_fd().try_clone_to_owned()?);
|
||||
}
|
||||
Ok(fds)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn async_socket_round_trips_payload_and_fds() -> std::io::Result<()> {
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let payload = TestPayload {
|
||||
id: 7,
|
||||
label: "round-trip".to_string(),
|
||||
};
|
||||
let send_fds = fd_list(1)?;
|
||||
|
||||
let receive_task =
|
||||
tokio::spawn(async move { server.receive_with_fds::<TestPayload>().await });
|
||||
client.send_with_fds(payload.clone(), &send_fds).await?;
|
||||
drop(send_fds);
|
||||
|
||||
let (received_payload, received_fds) = receive_task.await.unwrap()?;
|
||||
assert_eq!(payload, received_payload);
|
||||
assert_eq!(1, received_fds.len());
|
||||
let fd_status = unsafe { libc::fcntl(received_fds[0].as_raw_fd(), libc::F_GETFD) };
|
||||
assert!(
|
||||
fd_status >= 0,
|
||||
"expected received file descriptor to be valid, but got {fd_status}",
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn async_datagram_sockets_round_trip_messages() -> std::io::Result<()> {
|
||||
let (server, client) = AsyncDatagramSocket::pair()?;
|
||||
let data = b"datagram payload".to_vec();
|
||||
let send_fds = fd_list(1)?;
|
||||
let receive_task = tokio::spawn(async move { server.receive_with_fds().await });
|
||||
|
||||
client.send_with_fds(&data, &send_fds).await?;
|
||||
drop(send_fds);
|
||||
|
||||
let (received_bytes, received_fds) = receive_task.await.unwrap()?;
|
||||
assert_eq!(data, received_bytes);
|
||||
assert_eq!(1, received_fds.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_message_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
|
||||
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
|
||||
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
|
||||
let err = send_message_bytes(&socket, b"hello", &fds).unwrap_err();
|
||||
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
|
||||
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
|
||||
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
|
||||
let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err();
|
||||
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_length_errors_for_oversized_messages() {
|
||||
let err = encode_length(usize::MAX).unwrap_err();
|
||||
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_fails_when_peer_closes_before_header() {
|
||||
let (server, client) = AsyncSocket::pair().expect("failed to create socket pair");
|
||||
drop(client);
|
||||
let err = server
|
||||
.receive::<serde_json::Value>()
|
||||
.await
|
||||
.expect_err("expected read failure");
|
||||
assert_eq!(std::io::ErrorKind::UnexpectedEof, err.kind());
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue