core-agent-ide/codex-rs/exec-server/src/posix/socket.rs
Michael Bolin 82090803d9
fix: exec-server stream was erroring for large requests (#7654)
Prior to this change, large `EscalateRequest` payloads exceeded the
kernel send buffer, causing our single `sendmsg(2)` call (with attached
FDs) to be split and retried without proper control handling; this led
to `EINVAL`/broken pipe in the
`handle_escalate_session_respects_run_in_sandbox_decision()` test when
using an `env` with large contents.

**Before:** `AsyncSocket::send_with_fds()` called `send_json_message()`,
which called `send_message_bytes()`, which made one `socket.sendmsg()`
call followed by additional `socket.send()` calls, as necessary:


2e4a402521/codex-rs/exec-server/src/posix/socket.rs (L198-L209)

**After:** `AsyncSocket::send_with_fds()` now calls
`send_stream_frame()`, which calls `send_stream_chunk()` one or more
times. Each call to `send_stream_chunk()` calls `socket.sendmsg()`.

In the previous implementation, the subsequent `socket.send()` writes
had no control information associated with them, whereas in the new
`send_stream_chunk()` implementation, a fresh `MsgHdr` (using
`with_control()`, as appropriate) is created for `socket.sendmsg()` each
time.

Additionally, with this PR, stream sending attaches `SCM_RIGHTS` only on
the first chunk, and omits control data when there are no FDs, allowing
oversized payloads to deliver correctly while preserving FD limits and
error checks.
2025-12-06 10:16:47 -08:00

507 lines
17 KiB
Rust

use libc::c_uint;
use serde::Deserialize;
use serde::Serialize;
use socket2::Domain;
use socket2::MaybeUninitSlice;
use socket2::MsgHdr;
use socket2::MsgHdrMut;
use socket2::Socket;
use socket2::Type;
use std::io::IoSlice;
use std::mem::MaybeUninit;
use std::os::fd::AsRawFd;
use std::os::fd::FromRawFd;
use std::os::fd::OwnedFd;
use std::os::fd::RawFd;
use tokio::io::Interest;
use tokio::io::unix::AsyncFd;
/// Maximum number of file descriptors carried by a single message.
///
/// `make_control_message` rejects longer FD lists, and the receive paths size
/// their control buffers to hold exactly this many descriptors.
const MAX_FDS_PER_MESSAGE: usize = 16;
/// Size in bytes of the `u32` length prefix that precedes each stream frame's payload.
const LENGTH_PREFIX_SIZE: usize = size_of::<u32>();
/// Size in bytes of the buffer that `receive_datagram_bytes` reads a datagram into.
// NOTE(review): a datagram larger than this is presumably truncated by `recvmsg` — confirm.
const MAX_DATAGRAM_SIZE: usize = 8192;
/// Reinterprets a borrowed slice of `MaybeUninit<T>` as a slice of `T`.
///
/// No check is performed: the caller guarantees that every element of `buf`
/// has already been initialized.
fn assume_init<T>(buf: &[MaybeUninit<T>]) -> &[T] {
    let initialized = buf as *const [MaybeUninit<T>] as *const [T];
    // SAFETY: `MaybeUninit<T>` has the same layout as `T`, the cast keeps the
    // address and element count of `buf`, and the caller guarantees that every
    // element is initialized.
    unsafe { &*initialized }
}
/// Reinterprets a borrowed fixed-size array of `MaybeUninit<T>` as `[T; N]`.
///
/// The caller guarantees that every element of `buf` is initialized.
fn assume_init_slice<T, const N: usize>(buf: &[MaybeUninit<T>; N]) -> &[T; N] {
    let array_ptr = (buf as *const [MaybeUninit<T>; N]).cast::<[T; N]>();
    // SAFETY: `[MaybeUninit<T>; N]` and `[T; N]` share the same layout, and the
    // caller guarantees that all `N` elements are initialized.
    unsafe { &*array_ptr }
}
/// Converts an owned `Vec<MaybeUninit<T>>` into a `Vec<T>` without copying.
///
/// The existing allocation is reused as-is (same pointer, length and
/// capacity). The caller guarantees that every element within the vector's
/// length is initialized.
fn assume_init_vec<T>(buf: Vec<MaybeUninit<T>>) -> Vec<T> {
    // Suppress the destructor so the allocation is owned only by the vector
    // rebuilt below.
    let mut raw = std::mem::ManuallyDrop::new(buf);
    let (ptr, len, cap) = (raw.as_mut_ptr().cast::<T>(), raw.len(), raw.capacity());
    // SAFETY: the pointer, length and capacity come from a live `Vec` whose
    // element type has the same layout as `T`, and that `Vec` will never be
    // dropped, so ownership of the allocation transfers exactly once.
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
/// Returns the control-buffer size, as computed by `CMSG_SPACE`, for one
/// ancillary message whose payload is `count` raw file descriptors.
fn control_space_for_fds(count: usize) -> usize {
    let payload_bytes = count * size_of::<RawFd>();
    let space = unsafe { libc::CMSG_SPACE(payload_bytes as _) };
    space as usize
}
/// Extracts the FDs from the `SCM_RIGHTS` entries of a control-message buffer.
///
/// `control` must contain only the bytes that `recvmsg` actually filled in.
/// Every control message whose level/type is `SOL_SOCKET`/`SCM_RIGHTS` is
/// decoded; all other control messages are skipped.
///
/// Each descriptor found is wrapped in an [`OwnedFd`], so the returned vector
/// owns the descriptors and closes them when dropped. Call this at most once
/// per received buffer, otherwise the same descriptor would be owned twice.
fn extract_fds(control: &[u8]) -> Vec<OwnedFd> {
    let mut fds = Vec::new();
    // Build a minimal `msghdr` so the `CMSG_*` helpers can walk `control`.
    let mut hdr: libc::msghdr = unsafe { std::mem::zeroed() };
    hdr.msg_control = control.as_ptr() as *mut libc::c_void;
    hdr.msg_controllen = control.len() as _;
    let hdr = hdr; // drop mut
    let header_len = unsafe { libc::CMSG_LEN(0) } as usize;
    let mut cmsg = unsafe { libc::CMSG_FIRSTHDR(&hdr) as *const libc::cmsghdr };
    while !cmsg.is_null() {
        let level = unsafe { (*cmsg).cmsg_level };
        let ty = unsafe { (*cmsg).cmsg_type };
        if level == libc::SOL_SOCKET && ty == libc::SCM_RIGHTS {
            let cmsg_len = unsafe { (*cmsg).cmsg_len } as usize;
            // A `cmsg_len` shorter than the header itself carries no payload.
            // Saturating here avoids an underflow that would otherwise panic
            // (debug) or wrap to an enormous FD count (release).
            let fd_count = cmsg_len.saturating_sub(header_len) / size_of::<RawFd>();
            let data_ptr = unsafe { libc::CMSG_DATA(cmsg).cast::<RawFd>() };
            fds.reserve(fd_count);
            for i in 0..fd_count {
                // The buffer is a plain byte slice, so do not assume the
                // payload is aligned for `RawFd`.
                let fd = unsafe { data_ptr.add(i).read_unaligned() };
                fds.push(unsafe { OwnedFd::from_raw_fd(fd) });
            }
        }
        cmsg = unsafe { libc::CMSG_NXTHDR(&hdr, cmsg) };
    }
    fds
}
/// Reads one complete frame from a SOCK_STREAM socket.
///
/// A frame consists of a length prefix followed by that many payload bytes.
/// Any FDs are picked up while the header is being received and are returned
/// alongside the payload. If reading the payload fails, the FDs already
/// received are dropped.
async fn read_frame(async_socket: &AsyncFd<Socket>) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
    let (message_len, fds) = read_frame_header(async_socket).await?;
    read_frame_payload(async_socket, message_len)
        .await
        .map(|payload| (payload, fds))
}
/// Read the frame header (i.e. length) and any FDs from a SOCK_STREAM socket.
///
/// The first successful read uses `recvmsg` so that a control message carrying
/// FDs can be captured; if that read returns fewer than `LENGTH_PREFIX_SIZE`
/// bytes, the remainder of the header is collected with plain `recv` calls.
///
/// Returns the payload length decoded from the little-endian `u32` prefix
/// together with the received FDs.
///
/// # Errors
///
/// Returns `UnexpectedEof` if the peer closes the socket before the header is
/// complete, or any I/O error reported by the socket. FDs that were already
/// received are closed on every error path.
async fn read_frame_header(
    async_socket: &AsyncFd<Socket>,
) -> std::io::Result<(usize, Vec<OwnedFd>)> {
    let mut header = [MaybeUninit::<u8>::uninit(); LENGTH_PREFIX_SIZE];
    let mut filled = 0;
    let mut control = vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
    let mut captured_control = false;
    // FDs are wrapped in `OwnedFd` as soon as the control message arrives.
    // Leaving them as raw integers inside `control` until the header is
    // complete would leak them on any early return (EOF, I/O error) or if this
    // future were dropped part-way through.
    let mut fds: Vec<OwnedFd> = Vec::new();
    while filled < LENGTH_PREFIX_SIZE {
        let mut guard = async_socket.readable().await?;
        // The first read should come with a control message containing any FDs.
        let result = if !captured_control {
            guard.try_io(|inner| {
                let mut bufs = [MaybeUninitSlice::new(&mut header[filled..])];
                let (read, control_len) = {
                    let mut msg = MsgHdrMut::new()
                        .with_buffers(&mut bufs)
                        .with_control(&mut control);
                    // A would-block error leaves via `?` before anything is
                    // marked as captured, so the next attempt uses `recvmsg` again.
                    let read = inner.get_ref().recvmsg(&mut msg, 0)?;
                    (read, msg.control_len())
                };
                // Only the first `control_len` bytes were written by `recvmsg`.
                fds = extract_fds(assume_init(&control[..control_len]));
                captured_control = true;
                Ok(read)
            })
        } else {
            guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..]))
        };
        let Ok(result) = result else {
            // Would block, try again.
            continue;
        };
        let read = result?;
        if read == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "socket closed while receiving frame header",
            ));
        }
        filled += read;
        assert!(filled <= LENGTH_PREFIX_SIZE);
        if filled == LENGTH_PREFIX_SIZE {
            let len_bytes = assume_init_slice(&header);
            let payload_len = u32::from_le_bytes(*len_bytes) as usize;
            return Ok((payload_len, fds));
        }
    }
    unreachable!("header loop always returns")
}
/// Reads exactly `message_len` payload bytes from a SOCK_STREAM socket.
///
/// A zero-length payload returns an empty vector without touching the socket.
///
/// # Errors
///
/// Returns `UnexpectedEof` if the peer closes the socket before all bytes have
/// arrived, or any I/O error reported by the socket.
async fn read_frame_payload(
    async_socket: &AsyncFd<Socket>,
    message_len: usize,
) -> std::io::Result<Vec<u8>> {
    if message_len == 0 {
        return Ok(Vec::new());
    }
    let mut payload = vec![MaybeUninit::<u8>::uninit(); message_len];
    let mut received = 0;
    // Every pass either retries, fails, or moves `received` closer to
    // `message_len`; the function returns as soon as the buffer is full.
    while received < message_len {
        let mut ready = async_socket.readable().await?;
        let count = match ready.try_io(|fd| fd.get_ref().recv(&mut payload[received..])) {
            Ok(outcome) => outcome?,
            // Would block, try again.
            Err(_would_block) => continue,
        };
        if count == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "socket closed while receiving frame payload",
            ));
        }
        received += count;
        assert!(received <= message_len);
        if received == message_len {
            return Ok(assume_init_vec(payload));
        }
    }
    unreachable!("loop exits only after returning payload")
}
/// Sends `data` as a single message with one `sendmsg` call, attaching `fds`
/// as a control message when any are given.
///
/// # Errors
///
/// Propagates errors from `make_control_message` and `sendmsg`, and returns
/// `WriteZero` if fewer than `data.len()` bytes were written.
fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
    let control = make_control_message(fds)?;
    let payload = [IoSlice::new(data)];
    let base = MsgHdr::new().with_buffers(&payload);
    // Control data is attached only when there is something to send.
    let msg = if control.is_empty() {
        base
    } else {
        base.with_control(&control)
    };
    let written = socket.sendmsg(&msg, 0)?;
    if written == data.len() {
        return Ok(());
    }
    Err(std::io::Error::new(
        std::io::ErrorKind::WriteZero,
        format!(
            "short datagram write: wrote {written} bytes out of {}",
            data.len()
        ),
    ))
}
/// Encodes a payload length as the little-endian `u32` frame prefix.
///
/// # Errors
///
/// Returns `InvalidInput` when `len` does not fit in a `u32`.
fn encode_length(len: usize) -> std::io::Result<[u8; LENGTH_PREFIX_SIZE]> {
    match u32::try_from(len) {
        Ok(prefix) => Ok(prefix.to_le_bytes()),
        Err(_) => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("message too large: {len}"),
        )),
    }
}
/// Builds the `SCM_RIGHTS` control message used to pass `fds` with `sendmsg`.
///
/// Returns an empty vector when `fds` is empty, so callers can use
/// `is_empty()` to decide whether to attach control data at all.
///
/// # Errors
///
/// Returns `InvalidInput` when more than `MAX_FDS_PER_MESSAGE` descriptors are
/// supplied.
fn make_control_message(fds: &[OwnedFd]) -> std::io::Result<Vec<u8>> {
    let fd_count = fds.len();
    if fd_count > MAX_FDS_PER_MESSAGE {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("too many fds: {}", fd_count),
        ));
    }
    if fd_count == 0 {
        return Ok(Vec::new());
    }
    let mut control = vec![0u8; control_space_for_fds(fd_count)];
    let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
    unsafe {
        // Fill in the header, then copy the raw descriptors into its data area.
        (*cmsg).cmsg_len = libc::CMSG_LEN(size_of::<RawFd>() as c_uint * fd_count as c_uint) as _;
        (*cmsg).cmsg_level = libc::SOL_SOCKET;
        (*cmsg).cmsg_type = libc::SCM_RIGHTS;
        let fd_slots = libc::CMSG_DATA(cmsg).cast::<RawFd>();
        for (slot, fd) in fds.iter().enumerate() {
            fd_slots.add(slot).write(fd.as_raw_fd());
        }
    }
    Ok(control)
}
/// Receives one message with a single `recvmsg` call, returning its bytes and
/// any FDs that arrived in the control message.
///
/// The data buffer holds `MAX_DATAGRAM_SIZE` bytes and the control buffer has
/// room for `MAX_FDS_PER_MESSAGE` descriptors.
fn receive_datagram_bytes(socket: &Socket) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
    let mut data_buf = vec![MaybeUninit::<u8>::uninit(); MAX_DATAGRAM_SIZE];
    let mut control_buf =
        vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
    // Scope the message header so its borrows of both buffers end before the
    // buffers are read back.
    let (data_len, control_len) = {
        let mut io_bufs = [MaybeUninitSlice::new(&mut data_buf)];
        let mut header = MsgHdrMut::new()
            .with_buffers(&mut io_bufs)
            .with_control(&mut control_buf);
        let data_len = socket.recvmsg(&mut header, 0)?;
        (data_len, header.control_len())
    };
    // Only the prefixes reported by `recvmsg` were written.
    let received = &data_buf[..data_len];
    let ancillary = &control_buf[..control_len];
    Ok((
        assume_init(received).to_vec(),
        extract_fds(assume_init(ancillary)),
    ))
}
/// Non-blocking socket that exchanges length-prefixed JSON messages,
/// optionally accompanied by file descriptors.
pub(crate) struct AsyncSocket {
    inner: AsyncFd<Socket>,
}

impl AsyncSocket {
    /// Switches `socket` to non-blocking mode and wraps it in an `AsyncFd`.
    fn new(socket: Socket) -> std::io::Result<AsyncSocket> {
        socket.set_nonblocking(true)?;
        AsyncFd::new(socket).map(|inner| AsyncSocket { inner })
    }

    /// Wraps an already-open socket file descriptor.
    pub fn from_fd(fd: OwnedFd) -> std::io::Result<AsyncSocket> {
        let socket = Socket::from(fd);
        AsyncSocket::new(socket)
    }

    /// Creates a connected pair of UNIX-domain stream sockets.
    pub fn pair() -> std::io::Result<(AsyncSocket, AsyncSocket)> {
        let (server, client) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
        let server = AsyncSocket::new(server)?;
        let client = AsyncSocket::new(client)?;
        Ok((server, client))
    }

    /// Serializes `msg` as JSON and sends it as one length-prefixed frame,
    /// passing `fds` along with it.
    pub async fn send_with_fds<T: Serialize>(
        &self,
        msg: T,
        fds: &[OwnedFd],
    ) -> std::io::Result<()> {
        let body = serde_json::to_vec(&msg)?;
        let mut frame = Vec::with_capacity(LENGTH_PREFIX_SIZE + body.len());
        let prefix = encode_length(body.len())?;
        frame.extend_from_slice(&prefix);
        frame.extend_from_slice(&body);
        send_stream_frame(&self.inner, &frame, fds).await
    }

    /// Receives one frame, deserializes its JSON payload into `T`, and returns
    /// it together with any FDs that arrived with the frame.
    pub async fn receive_with_fds<T: for<'de> Deserialize<'de>>(
        &self,
    ) -> std::io::Result<(T, Vec<OwnedFd>)> {
        let (payload, fds) = read_frame(&self.inner).await?;
        let message = serde_json::from_slice::<T>(&payload)?;
        Ok((message, fds))
    }

    /// Sends `msg` without attaching any file descriptors.
    pub async fn send<T>(&self, msg: T) -> std::io::Result<()>
    where
        T: Serialize,
    {
        self.send_with_fds(&msg, &[]).await
    }

    /// Receives a message, logging a warning if FDs arrived with it.
    ///
    /// Unexpected FDs are not returned; they are dropped when this call ends.
    pub async fn receive<T: for<'de> Deserialize<'de>>(&self) -> std::io::Result<T> {
        let (msg, fds) = self.receive_with_fds().await?;
        let stray = fds.len();
        if stray > 0 {
            tracing::warn!("unexpected fds in receive: {}", stray);
        }
        Ok(msg)
    }

    /// Consumes the wrapper and returns the underlying socket.
    pub fn into_inner(self) -> Socket {
        let AsyncSocket { inner } = self;
        inner.into_inner()
    }
}
/// Writes all of `frame` to a stream socket, looping over `send_stream_chunk`
/// until every byte has been accepted.
///
/// `fds` are offered on each attempt until one chunk is written successfully;
/// every later chunk is sent without control data.
///
/// # Errors
///
/// Returns `WriteZero` if a write reports zero bytes, or any I/O error from
/// the socket.
async fn send_stream_frame(
    socket: &AsyncFd<Socket>,
    frame: &[u8],
    fds: &[OwnedFd],
) -> std::io::Result<()> {
    let mut offset = 0;
    let mut fds_pending = !fds.is_empty();
    while offset < frame.len() {
        let mut ready = socket.writable().await?;
        let remaining = &frame[offset..];
        let attempt =
            ready.try_io(|fd| send_stream_chunk(fd.get_ref(), remaining, fds, fds_pending));
        let Ok(outcome) = attempt else {
            // Would block: wait for writability again and retry the same chunk.
            continue;
        };
        let sent = outcome?;
        if sent == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::WriteZero,
                "socket closed while sending frame payload",
            ));
        }
        offset += sent;
        // The FDs went out with the chunk that was just written.
        fds_pending = false;
    }
    Ok(())
}
/// Issues one `sendmsg` call for `frame` and returns how many bytes were
/// written.
///
/// A control message carrying `fds` is attached only when `include_fds` is
/// true and `fds` is non-empty; otherwise the message has no control data.
fn send_stream_chunk(
    socket: &Socket,
    frame: &[u8],
    fds: &[OwnedFd],
    include_fds: bool,
) -> std::io::Result<usize> {
    let mut control = Vec::new();
    if include_fds {
        control = make_control_message(fds)?;
    }
    let payload = [IoSlice::new(frame)];
    let base = MsgHdr::new().with_buffers(&payload);
    let msg = if control.is_empty() {
        base
    } else {
        base.with_control(&control)
    };
    socket.sendmsg(&msg, 0)
}
/// Non-blocking socket that sends and receives whole messages of raw bytes,
/// optionally accompanied by file descriptors.
pub(crate) struct AsyncDatagramSocket {
    inner: AsyncFd<Socket>,
}

impl AsyncDatagramSocket {
    /// Switches `socket` to non-blocking mode and wraps it in an `AsyncFd`.
    fn new(socket: Socket) -> std::io::Result<Self> {
        socket.set_nonblocking(true)?;
        let inner = AsyncFd::new(socket)?;
        Ok(Self { inner })
    }

    /// Takes ownership of a raw file descriptor and wraps it.
    ///
    /// # Safety
    ///
    /// `fd` is handed to `Socket::from_raw_fd`, so the caller must uphold that
    /// function's requirements.
    pub unsafe fn from_raw_fd(fd: RawFd) -> std::io::Result<Self> {
        let socket = unsafe { Socket::from_raw_fd(fd) };
        Self::new(socket)
    }

    /// Creates a connected pair of UNIX-domain datagram sockets.
    pub fn pair() -> std::io::Result<(Self, Self)> {
        let (server, client) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
        let server = Self::new(server)?;
        let client = Self::new(client)?;
        Ok((server, client))
    }

    /// Waits until the socket is writable, then sends `data` and `fds` as one
    /// message.
    pub async fn send_with_fds(&self, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
        let send_once = |socket: &Socket| send_datagram_bytes(socket, data, fds);
        self.inner.async_io(Interest::WRITABLE, send_once).await
    }

    /// Waits until the socket is readable, then receives one message and any
    /// FDs attached to it.
    pub async fn receive_with_fds(&self) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
        let interest = Interest::READABLE;
        self.inner.async_io(interest, receive_datagram_bytes).await
    }

    /// Consumes the wrapper and returns the underlying socket.
    pub fn into_inner(self) -> Socket {
        let Self { inner } = self;
        inner.into_inner()
    }
}
// Unit tests for the stream and datagram socket wrappers and their helpers.
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use serde::Deserialize;
use serde::Serialize;
use std::os::fd::AsFd;
use std::os::fd::AsRawFd;
use tempfile::NamedTempFile;
/// Small JSON-serializable message used for round-trip checks.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
struct TestPayload {
id: i32,
label: String,
}
/// Returns `count` owned descriptors, each a duplicate of one temp file's FD.
fn fd_list(count: usize) -> std::io::Result<Vec<OwnedFd>> {
let file = NamedTempFile::new()?;
let mut fds = Vec::new();
for _ in 0..count {
fds.push(file.as_fd().try_clone_to_owned()?);
}
Ok(fds)
}
/// A JSON payload and one FD sent over a stream pair arrive intact, and the
/// received FD is still usable after the sender's copy is dropped.
#[tokio::test]
async fn async_socket_round_trips_payload_and_fds() -> std::io::Result<()> {
let (server, client) = AsyncSocket::pair()?;
let payload = TestPayload {
id: 7,
label: "round-trip".to_string(),
};
let send_fds = fd_list(1)?;
let receive_task =
tokio::spawn(async move { server.receive_with_fds::<TestPayload>().await });
client.send_with_fds(payload.clone(), &send_fds).await?;
drop(send_fds);
let (received_payload, received_fds) = receive_task.await.unwrap()?;
assert_eq!(payload, received_payload);
assert_eq!(1, received_fds.len());
// `F_GETFD` succeeds only for an open descriptor.
let fd_status = unsafe { libc::fcntl(received_fds[0].as_raw_fd(), libc::F_GETFD) };
assert!(
fd_status >= 0,
"expected received file descriptor to be valid, but got {fd_status}",
);
Ok(())
}
/// A 10,000-element byte vector round-trips over a stream pair.
#[tokio::test]
async fn async_socket_handles_large_payload() -> std::io::Result<()> {
let (server, client) = AsyncSocket::pair()?;
let payload = vec![b'A'; 10_000];
let receive_task = tokio::spawn(async move { server.receive::<Vec<u8>>().await });
client.send(payload.clone()).await?;
let received_payload = receive_task.await.unwrap()?;
assert_eq!(payload, received_payload);
Ok(())
}
/// Raw bytes and one FD round-trip over a datagram pair.
#[tokio::test]
async fn async_datagram_sockets_round_trip_messages() -> std::io::Result<()> {
let (server, client) = AsyncDatagramSocket::pair()?;
let data = b"datagram payload".to_vec();
let send_fds = fd_list(1)?;
let receive_task = tokio::spawn(async move { server.receive_with_fds().await });
client.send_with_fds(&data, &send_fds).await?;
drop(send_fds);
let (received_bytes, received_fds) = receive_task.await.unwrap()?;
assert_eq!(data, received_bytes);
assert_eq!(1, received_fds.len());
Ok(())
}
/// Sending one more FD than `MAX_FDS_PER_MESSAGE` on a datagram socket
/// fails with `InvalidInput`.
#[test]
fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err();
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
Ok(())
}
/// Sending one more FD than `MAX_FDS_PER_MESSAGE` in a stream chunk fails
/// with `InvalidInput`.
#[test]
fn send_stream_chunk_rejects_excessive_fd_counts() -> std::io::Result<()> {
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
let err = send_stream_chunk(&socket, b"hello", &fds, true).unwrap_err();
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
Ok(())
}
/// A length of `usize::MAX` is rejected by `encode_length` with `InvalidInput`.
#[test]
fn encode_length_errors_for_oversized_messages() {
let err = encode_length(usize::MAX).unwrap_err();
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
}
/// Receiving after the peer has been dropped yields `UnexpectedEof`.
#[tokio::test]
async fn receive_fails_when_peer_closes_before_header() {
let (server, client) = AsyncSocket::pair().expect("failed to create socket pair");
drop(client);
let err = server
.receive::<serde_json::Value>()
.await
.expect_err("expected read failure");
assert_eq!(std::io::ErrorKind::UnexpectedEof, err.kind());
}
}