From 81996fcde605a452ca94662eb7028e8c8b6f9ebb Mon Sep 17 00:00:00 2001 From: starr-openai Date: Wed, 18 Mar 2026 17:30:05 -0700 Subject: [PATCH] Add exec-server stub server and protocol docs (#15089) Stacked PR 1/3. This is the initialize-only exec-server stub slice: binary/client scaffolding and protocol docs, without exec/filesystem implementation. --------- Co-authored-by: Codex --- codex-rs/Cargo.lock | 20 + codex-rs/Cargo.toml | 1 + codex-rs/exec-server/BUILD.bazel | 7 + codex-rs/exec-server/Cargo.toml | 40 ++ codex-rs/exec-server/README.md | 282 ++++++++++++++ .../exec-server/src/bin/codex-exec-server.rs | 20 + codex-rs/exec-server/src/client.rs | 267 ++++++++++++++ .../exec-server/src/client/local_backend.rs | 38 ++ codex-rs/exec-server/src/client_api.rs | 17 + codex-rs/exec-server/src/connection.rs | 275 ++++++++++++++ codex-rs/exec-server/src/lib.rs | 21 ++ codex-rs/exec-server/src/local.rs | 71 ++++ codex-rs/exec-server/src/protocol.rs | 15 + codex-rs/exec-server/src/rpc.rs | 347 ++++++++++++++++++ codex-rs/exec-server/src/server.rs | 18 + codex-rs/exec-server/src/server/handler.rs | 40 ++ codex-rs/exec-server/src/server/jsonrpc.rs | 53 +++ codex-rs/exec-server/src/server/processor.rs | 121 ++++++ codex-rs/exec-server/src/server/transport.rs | 118 ++++++ .../exec-server/src/server/transport_tests.rs | 54 +++ codex-rs/exec-server/tests/stdio_smoke.rs | 129 +++++++ codex-rs/exec-server/tests/websocket_smoke.rs | 229 ++++++++++++ 22 files changed, 2183 insertions(+) create mode 100644 codex-rs/exec-server/BUILD.bazel create mode 100644 codex-rs/exec-server/Cargo.toml create mode 100644 codex-rs/exec-server/README.md create mode 100644 codex-rs/exec-server/src/bin/codex-exec-server.rs create mode 100644 codex-rs/exec-server/src/client.rs create mode 100644 codex-rs/exec-server/src/client/local_backend.rs create mode 100644 codex-rs/exec-server/src/client_api.rs create mode 100644 codex-rs/exec-server/src/connection.rs create mode 100644 codex-rs/exec-server/src/lib.rs create mode 100644 codex-rs/exec-server/src/local.rs create mode 100644 codex-rs/exec-server/src/protocol.rs create mode 100644 codex-rs/exec-server/src/rpc.rs create mode 100644 codex-rs/exec-server/src/server.rs create mode 100644 codex-rs/exec-server/src/server/handler.rs create mode 100644 codex-rs/exec-server/src/server/jsonrpc.rs create mode 100644 codex-rs/exec-server/src/server/processor.rs create mode 100644 codex-rs/exec-server/src/server/transport.rs create mode 100644 codex-rs/exec-server/src/server/transport_tests.rs create mode 100644 codex-rs/exec-server/tests/stdio_smoke.rs create mode 100644 codex-rs/exec-server/tests/websocket_smoke.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 771b714e0..5965204ce 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2003,6 +2003,26 @@ dependencies = [ "wiremock", ] +[[package]] +name = "codex-exec-server" +version = "0.0.0" +dependencies = [ + "anyhow", + "base64 0.22.1", + "clap", + "codex-app-server-protocol", + "codex-utils-cargo-bin", + "codex-utils-pty", + "futures", + "pretty_assertions", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tokio-tungstenite", + "tracing", +] + [[package]] name = "codex-execpolicy" version = "0.0.0" diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 35ff64195..7d4b8792b 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -26,6 +26,7 @@ members = [ "hooks", "secrets", "exec", + "exec-server", "execpolicy", "execpolicy-legacy", "keyring-store", diff --git a/codex-rs/exec-server/BUILD.bazel b/codex-rs/exec-server/BUILD.bazel new file mode 100644 index 000000000..5d62c68ca --- /dev/null +++ b/codex-rs/exec-server/BUILD.bazel @@ -0,0 +1,7 @@ +load("//:defs.bzl", "codex_rust_crate") + +codex_rust_crate( + name = "exec-server", + crate_name = "codex_exec_server", + test_tags = ["no-sandbox"], +) diff --git a/codex-rs/exec-server/Cargo.toml b/codex-rs/exec-server/Cargo.toml new file mode 100644 index 000000000..7eeada396 --- /dev/null +++ b/codex-rs/exec-server/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "codex-exec-server" +version.workspace = true +edition.workspace = true +license.workspace = true + +[lib] +doctest = false + +[[bin]] +name = "codex-exec-server" +path = "src/bin/codex-exec-server.rs" + +[lints] +workspace = true + +[dependencies] +clap = { workspace = true, features = ["derive"] } +codex-app-server-protocol = { workspace = true } +futures = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = [ + "io-std", + "io-util", + "macros", + "net", + "process", + "rt-multi-thread", + "sync", + "time", +] } +tokio-tungstenite = { workspace = true } +tracing = { workspace = true } + +[dev-dependencies] +anyhow = { workspace = true } +codex-utils-cargo-bin = { workspace = true } +pretty_assertions = { workspace = true } diff --git a/codex-rs/exec-server/README.md b/codex-rs/exec-server/README.md new file mode 100644 index 000000000..c4194fda4 --- /dev/null +++ b/codex-rs/exec-server/README.md @@ -0,0 +1,282 @@ +# codex-exec-server + +`codex-exec-server` is a small standalone JSON-RPC server for spawning +and controlling subprocesses through `codex-utils-pty`. + +This PR intentionally lands only the standalone binary, client, wire protocol, +and docs. Exec and filesystem methods are stubbed server-side here and are +implemented in follow-up PRs. + +It currently provides: + +- a standalone binary: `codex-exec-server` +- a Rust client: `ExecServerClient` +- a small protocol module with shared request/response types + +This crate is intentionally narrow. It is not wired into the main Codex CLI or +unified-exec in this PR; it is only the standalone transport layer. + +## Transport + +The server speaks the shared `codex-app-server-protocol` message envelope on +the wire. + +The standalone binary supports: + +- `ws://IP:PORT` (default) +- `stdio://` + +Wire framing: + +- websocket: one JSON-RPC message per websocket text frame +- stdio: one newline-delimited JSON-RPC message per line on stdin/stdout + +## Lifecycle + +Each connection follows this sequence: + +1. Send `initialize`. +2. Wait for the `initialize` response. +3. Send `initialized`. +4. Call exec or filesystem RPCs once the follow-up implementation PRs land. + +If the server receives any notification other than `initialized`, it replies +with an error using request id `-1`. + +If the stdio connection closes, the server terminates any remaining managed +processes before exiting. + +## API + +### `initialize` + +Initial handshake request. + +Request params: + +```json +{ + "clientName": "my-client" +} +``` + +Response: + +```json +{} +``` + +### `initialized` + +Handshake acknowledgement notification sent by the client after a successful +`initialize` response. + +Params are currently ignored. Sending any other notification method is treated +as an invalid request. + +### `command/exec` + +Starts a new managed process. + +Request params: + +```json +{ + "processId": "proc-1", + "argv": ["bash", "-lc", "printf 'hello\\n'"], + "cwd": "/absolute/working/directory", + "env": { + "PATH": "/usr/bin:/bin" + }, + "tty": true, + "outputBytesCap": 16384, + "arg0": null +} +``` + +Field definitions: + +- `processId`: caller-chosen stable id for this process within the connection. +- `argv`: command vector. It must be non-empty. +- `cwd`: absolute working directory used for the child process. +- `env`: environment variables passed to the child process. +- `tty`: when `true`, spawn a PTY-backed interactive process; when `false`, + spawn a pipe-backed process with closed stdin. +- `outputBytesCap`: maximum retained stdout/stderr bytes per stream for the + in-memory buffer. Defaults to `codex_utils_pty::DEFAULT_OUTPUT_BYTES_CAP`. +- `arg0`: optional argv0 override forwarded to `codex-utils-pty`. + +Response: + +```json +{ + "processId": "proc-1", + "running": true, + "exitCode": null, + "stdout": null, + "stderr": null +} +``` + +Behavior notes: + +- Reusing an existing `processId` is rejected. +- PTY-backed processes accept later writes through `command/exec/write`. +- Pipe-backed processes are launched with stdin closed and reject writes. +- Output is streamed asynchronously via `command/exec/outputDelta`. +- Exit is reported asynchronously via `command/exec/exited`. + +### `command/exec/write` + +Writes raw bytes to a running PTY-backed process stdin. + +Request params: + +```json +{ + "processId": "proc-1", + "chunk": "aGVsbG8K" +} +``` + +`chunk` is base64-encoded raw bytes. In the example above it is `hello\n`. + +Response: + +```json +{ + "accepted": true +} +``` + +Behavior notes: + +- Writes to an unknown `processId` are rejected. +- Writes to a non-PTY process are rejected because stdin is already closed. + +### `command/exec/terminate` + +Terminates a running managed process. + +Request params: + +```json +{ + "processId": "proc-1" +} +``` + +Response: + +```json +{ + "running": true +} +``` + +If the process is already unknown or already removed, the server responds with: + +```json +{ + "running": false +} +``` + +## Notifications + +### `command/exec/outputDelta` + +Streaming output chunk from a running process. + +Params: + +```json +{ + "processId": "proc-1", + "stream": "stdout", + "chunk": "aGVsbG8K" +} +``` + +Fields: + +- `processId`: process identifier +- `stream`: `"stdout"` or `"stderr"` +- `chunk`: base64-encoded output bytes + +### `command/exec/exited` + +Final process exit notification. + +Params: + +```json +{ + "processId": "proc-1", + "exitCode": 0 +} +``` + +## Errors + +The server returns JSON-RPC errors with these codes: + +- `-32600`: invalid request +- `-32602`: invalid params +- `-32603`: internal error + +Typical error cases: + +- unknown method +- malformed params +- empty `argv` +- duplicate `processId` +- writes to unknown processes +- writes to non-PTY processes + +## Rust surface + +The crate exports: + +- `ExecServerClient` +- `ExecServerLaunchCommand` +- `ExecServerProcess` +- `ExecServerError` +- protocol structs such as `ExecParams`, `ExecResponse`, + `WriteParams`, `TerminateParams`, `ExecOutputDeltaNotification`, and + `ExecExitedNotification` +- `run_main()` for embedding the stdio server in a binary + +## Example session + +Initialize: + +```json +{"id":1,"method":"initialize","params":{"clientName":"example-client"}} +{"id":1,"result":{}} +{"method":"initialized","params":{}} +``` + +Start a process: + +```json +{"id":2,"method":"command/exec","params":{"processId":"proc-1","argv":["bash","-lc","printf 'ready\\n'; while IFS= read -r line; do printf 'echo:%s\\n' \"$line\"; done"],"cwd":"/tmp","env":{"PATH":"/usr/bin:/bin"},"tty":true,"outputBytesCap":4096,"arg0":null}} +{"id":2,"result":{"processId":"proc-1","running":true,"exitCode":null,"stdout":null,"stderr":null}} +{"method":"command/exec/outputDelta","params":{"processId":"proc-1","stream":"stdout","chunk":"cmVhZHkK"}} +``` + +Write to the process: + +```json +{"id":3,"method":"command/exec/write","params":{"processId":"proc-1","chunk":"aGVsbG8K"}} +{"id":3,"result":{"accepted":true}} +{"method":"command/exec/outputDelta","params":{"processId":"proc-1","stream":"stdout","chunk":"ZWNobzpoZWxsbwo="}} +``` + +Terminate it: + +```json +{"id":4,"method":"command/exec/terminate","params":{"processId":"proc-1"}} +{"id":4,"result":{"running":true}} +{"method":"command/exec/exited","params":{"processId":"proc-1","exitCode":0}} +``` diff --git a/codex-rs/exec-server/src/bin/codex-exec-server.rs b/codex-rs/exec-server/src/bin/codex-exec-server.rs new file mode 100644 index 000000000..7bcb14190 --- /dev/null +++ b/codex-rs/exec-server/src/bin/codex-exec-server.rs @@ -0,0 +1,20 @@ +use clap::Parser; +use codex_exec_server::ExecServerTransport; + +#[derive(Debug, Parser)] +struct ExecServerArgs { + /// Transport endpoint URL. Supported values: `ws://IP:PORT` (default), + /// `stdio://`. + #[arg( + long = "listen", + value_name = "URL", + default_value = ExecServerTransport::DEFAULT_LISTEN_URL + )] + listen: ExecServerTransport, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let args = ExecServerArgs::parse(); + codex_exec_server::run_main_with_transport(args.listen).await +} diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs new file mode 100644 index 000000000..9830771a0 --- /dev/null +++ b/codex-rs/exec-server/src/client.rs @@ -0,0 +1,267 @@ +use std::sync::Arc; +use std::time::Duration; + +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::time::timeout; +use tokio_tungstenite::connect_async; +use tracing::warn; + +use crate::client_api::ExecServerClientConnectOptions; +use crate::client_api::RemoteExecServerConnectArgs; +use crate::connection::JsonRpcConnection; +use crate::protocol::INITIALIZE_METHOD; +use crate::protocol::INITIALIZED_METHOD; +use crate::protocol::InitializeParams; +use crate::protocol::InitializeResponse; +use crate::rpc::RpcCallError; +use crate::rpc::RpcClient; +use crate::rpc::RpcClientEvent; + +mod local_backend; +use local_backend::LocalBackend; + +const CONNECT_TIMEOUT: Duration = Duration::from_secs(10); +const INITIALIZE_TIMEOUT: Duration = Duration::from_secs(10); + +impl Default for ExecServerClientConnectOptions { + fn default() -> Self { + Self { + client_name: "codex-core".to_string(), + initialize_timeout: INITIALIZE_TIMEOUT, + } + } +} + +impl From for ExecServerClientConnectOptions { + fn from(value: RemoteExecServerConnectArgs) -> Self { + Self { + client_name: value.client_name, + initialize_timeout: value.initialize_timeout, + } + } +} + +impl RemoteExecServerConnectArgs { + pub fn new(websocket_url: String, client_name: String) -> Self { + Self { + websocket_url, + client_name, + connect_timeout: CONNECT_TIMEOUT, + initialize_timeout: INITIALIZE_TIMEOUT, + } + } +} + +enum ClientBackend { + Remote(RpcClient), + InProcess(LocalBackend), +} + +impl ClientBackend { + fn as_local(&self) -> Option<&LocalBackend> { + match self { + ClientBackend::Remote(_) => None, + ClientBackend::InProcess(backend) => Some(backend), + } + } + + fn as_remote(&self) -> Option<&RpcClient> { + match self { + ClientBackend::Remote(client) => Some(client), + ClientBackend::InProcess(_) => None, + } + } +} + +struct Inner { + backend: ClientBackend, + reader_task: tokio::task::JoinHandle<()>, +} + +impl Drop for Inner { + fn drop(&mut self) { + if let Some(backend) = self.backend.as_local() + && let Ok(handle) = tokio::runtime::Handle::try_current() + { + let backend = backend.clone(); + handle.spawn(async move { + backend.shutdown().await; + }); + } + self.reader_task.abort(); + } +} + +#[derive(Clone)] +pub struct ExecServerClient { + inner: Arc, +} + +#[derive(Debug, thiserror::Error)] +pub enum ExecServerError { + #[error("failed to spawn exec-server: {0}")] + Spawn(#[source] std::io::Error), + #[error("timed out connecting to exec-server websocket `{url}` after {timeout:?}")] + WebSocketConnectTimeout { url: String, timeout: Duration }, + #[error("failed to connect to exec-server websocket `{url}`: {source}")] + WebSocketConnect { + url: String, + #[source] + source: tokio_tungstenite::tungstenite::Error, + }, + #[error("timed out waiting for exec-server initialize handshake after {timeout:?}")] + InitializeTimedOut { timeout: Duration }, + #[error("exec-server transport closed")] + Closed, + #[error("failed to serialize or deserialize exec-server JSON: {0}")] + Json(#[from] serde_json::Error), + #[error("exec-server protocol error: {0}")] + Protocol(String), + #[error("exec-server rejected request ({code}): {message}")] + Server { code: i64, message: String }, +} + +impl ExecServerClient { + pub async fn connect_in_process( + options: ExecServerClientConnectOptions, + ) -> Result { + let backend = LocalBackend::new(crate::server::ExecServerHandler::new()); + let inner = Arc::new(Inner { + backend: ClientBackend::InProcess(backend), + reader_task: tokio::spawn(async {}), + }); + let client = Self { inner }; + client.initialize(options).await?; + Ok(client) + } + + pub async fn connect_stdio( + stdin: W, + stdout: R, + options: ExecServerClientConnectOptions, + ) -> Result + where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, + { + Self::connect( + JsonRpcConnection::from_stdio(stdout, stdin, "exec-server stdio".to_string()), + options, + ) + .await + } + + pub async fn connect_websocket( + args: RemoteExecServerConnectArgs, + ) -> Result { + let websocket_url = args.websocket_url.clone(); + let connect_timeout = args.connect_timeout; + let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str())) + .await + .map_err(|_| ExecServerError::WebSocketConnectTimeout { + url: websocket_url.clone(), + timeout: connect_timeout, + })? + .map_err(|source| ExecServerError::WebSocketConnect { + url: websocket_url.clone(), + source, + })?; + + Self::connect( + JsonRpcConnection::from_websocket( + stream, + format!("exec-server websocket {websocket_url}"), + ), + args.into(), + ) + .await + } + + pub async fn initialize( + &self, + options: ExecServerClientConnectOptions, + ) -> Result { + let ExecServerClientConnectOptions { + client_name, + initialize_timeout, + } = options; + + timeout(initialize_timeout, async { + let response = if let Some(backend) = self.inner.backend.as_local() { + backend.initialize().await? + } else { + let params = InitializeParams { client_name }; + let Some(remote) = self.inner.backend.as_remote() else { + return Err(ExecServerError::Protocol( + "remote backend missing during initialize".to_string(), + )); + }; + remote.call(INITIALIZE_METHOD, ¶ms).await? + }; + self.notify_initialized().await?; + Ok(response) + }) + .await + .map_err(|_| ExecServerError::InitializeTimedOut { + timeout: initialize_timeout, + })? + } + + async fn connect( + connection: JsonRpcConnection, + options: ExecServerClientConnectOptions, + ) -> Result { + let (rpc_client, mut events_rx) = RpcClient::new(connection); + let reader_task = tokio::spawn(async move { + while let Some(event) = events_rx.recv().await { + match event { + RpcClientEvent::Notification(notification) => { + warn!( + "ignoring unexpected exec-server notification during stub phase: {}", + notification.method + ); + } + RpcClientEvent::Disconnected { reason } => { + if let Some(reason) = reason { + warn!("exec-server client transport disconnected: {reason}"); + } + return; + } + } + } + }); + + let client = Self { + inner: Arc::new(Inner { + backend: ClientBackend::Remote(rpc_client), + reader_task, + }), + }; + client.initialize(options).await?; + Ok(client) + } + + async fn notify_initialized(&self) -> Result<(), ExecServerError> { + match &self.inner.backend { + ClientBackend::Remote(client) => client + .notify(INITIALIZED_METHOD, &serde_json::json!({})) + .await + .map_err(ExecServerError::Json), + ClientBackend::InProcess(backend) => backend.initialized().await, + } + } +} + +impl From for ExecServerError { + fn from(value: RpcCallError) -> Self { + match value { + RpcCallError::Closed => Self::Closed, + RpcCallError::Json(err) => Self::Json(err), + RpcCallError::Server(error) => Self::Server { + code: error.code, + message: error.message, + }, + } + } +} diff --git a/codex-rs/exec-server/src/client/local_backend.rs b/codex-rs/exec-server/src/client/local_backend.rs new file mode 100644 index 000000000..8f9a2481f --- /dev/null +++ b/codex-rs/exec-server/src/client/local_backend.rs @@ -0,0 +1,38 @@ +use std::sync::Arc; + +use crate::protocol::InitializeResponse; +use crate::server::ExecServerHandler; + +use super::ExecServerError; + +#[derive(Clone)] +pub(super) struct LocalBackend { + handler: Arc, +} + +impl LocalBackend { + pub(super) fn new(handler: ExecServerHandler) -> Self { + Self { + handler: Arc::new(handler), + } + } + + pub(super) async fn shutdown(&self) { + self.handler.shutdown().await; + } + + pub(super) async fn initialize(&self) -> Result { + self.handler + .initialize() + .map_err(|error| ExecServerError::Server { + code: error.code, + message: error.message, + }) + } + + pub(super) async fn initialized(&self) -> Result<(), ExecServerError> { + self.handler + .initialized() + .map_err(ExecServerError::Protocol) + } +} diff --git a/codex-rs/exec-server/src/client_api.rs b/codex-rs/exec-server/src/client_api.rs new file mode 100644 index 000000000..6e8976341 --- /dev/null +++ b/codex-rs/exec-server/src/client_api.rs @@ -0,0 +1,17 @@ +use std::time::Duration; + +/// Connection options for any exec-server client transport. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ExecServerClientConnectOptions { + pub client_name: String, + pub initialize_timeout: Duration, +} + +/// WebSocket connection arguments for a remote exec-server. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RemoteExecServerConnectArgs { + pub websocket_url: String, + pub client_name: String, + pub connect_timeout: Duration, + pub initialize_timeout: Duration, +} diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs new file mode 100644 index 000000000..af03fc068 --- /dev/null +++ b/codex-rs/exec-server/src/connection.rs @@ -0,0 +1,275 @@ +use codex_app_server_protocol::JSONRPCMessage; +use futures::SinkExt; +use futures::StreamExt; +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::io::BufWriter; +use tokio::sync::mpsc; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::tungstenite::Message; + +pub(crate) const CHANNEL_CAPACITY: usize = 128; + +#[derive(Debug)] +pub(crate) enum JsonRpcConnectionEvent { + Message(JSONRPCMessage), + MalformedMessage { reason: String }, + Disconnected { reason: Option }, +} + +pub(crate) struct JsonRpcConnection { + outgoing_tx: mpsc::Sender, + incoming_rx: mpsc::Receiver, + task_handles: Vec>, +} + +impl JsonRpcConnection { + pub(crate) fn from_stdio(reader: R, writer: W, connection_label: String) -> Self + where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, + { + let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); + + let reader_label = connection_label.clone(); + let incoming_tx_for_reader = incoming_tx.clone(); + let reader_task = tokio::spawn(async move { + let mut lines = BufReader::new(reader).lines(); + loop { + match lines.next_line().await { + Ok(Some(line)) => { + if line.trim().is_empty() { + continue; + } + match serde_json::from_str::(&line) { + Ok(message) => { + if incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + break; + } + } + Err(err) => { + send_malformed_message( + &incoming_tx_for_reader, + Some(format!( + "failed to parse JSON-RPC message from {reader_label}: {err}" + )), + ) + .await; + } + } + } + Ok(None) => { + send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await; + break; + } + Err(err) => { + send_disconnected( + &incoming_tx_for_reader, + Some(format!( + "failed to read JSON-RPC message from {reader_label}: {err}" + )), + ) + .await; + break; + } + } + } + }); + + let writer_task = tokio::spawn(async move { + let mut writer = BufWriter::new(writer); + while let Some(message) = outgoing_rx.recv().await { + if let Err(err) = write_jsonrpc_line_message(&mut writer, &message).await { + send_disconnected( + &incoming_tx, + Some(format!( + "failed to write JSON-RPC message to {connection_label}: {err}" + )), + ) + .await; + break; + } + } + }); + + Self { + outgoing_tx, + incoming_rx, + task_handles: vec![reader_task, writer_task], + } + } + + pub(crate) fn from_websocket(stream: WebSocketStream, connection_label: String) -> Self + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (mut websocket_writer, mut websocket_reader) = stream.split(); + + let reader_label = connection_label.clone(); + let incoming_tx_for_reader = incoming_tx.clone(); + let reader_task = tokio::spawn(async move { + loop { + match websocket_reader.next().await { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::(text.as_ref()) { + Ok(message) => { + if incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + break; + } + } + Err(err) => { + send_malformed_message( + &incoming_tx_for_reader, + Some(format!( + "failed to parse websocket JSON-RPC message from {reader_label}: {err}" + )), + ) + .await; + } + } + } + Some(Ok(Message::Binary(bytes))) => { + match serde_json::from_slice::(bytes.as_ref()) { + Ok(message) => { + if incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + break; + } + } + Err(err) => { + send_malformed_message( + &incoming_tx_for_reader, + Some(format!( + "failed to parse websocket JSON-RPC message from {reader_label}: {err}" + )), + ) + .await; + } + } + } + Some(Ok(Message::Close(_))) => { + send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await; + break; + } + Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {} + Some(Ok(_)) => {} + Some(Err(err)) => { + send_disconnected( + &incoming_tx_for_reader, + Some(format!( + "failed to read websocket JSON-RPC message from {reader_label}: {err}" + )), + ) + .await; + break; + } + None => { + send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await; + break; + } + } + } + }); + + let writer_task = tokio::spawn(async move { + while let Some(message) = outgoing_rx.recv().await { + match serialize_jsonrpc_message(&message) { + Ok(encoded) => { + if let Err(err) = websocket_writer.send(Message::Text(encoded.into())).await + { + send_disconnected( + &incoming_tx, + Some(format!( + "failed to write websocket JSON-RPC message to {connection_label}: {err}" + )), + ) + .await; + break; + } + } + Err(err) => { + send_disconnected( + &incoming_tx, + Some(format!( + "failed to serialize JSON-RPC message for {connection_label}: {err}" + )), + ) + .await; + break; + } + } + } + }); + + Self { + outgoing_tx, + incoming_rx, + task_handles: vec![reader_task, writer_task], + } + } + + pub(crate) fn into_parts( + self, + ) -> ( + mpsc::Sender, + mpsc::Receiver, + Vec>, + ) { + (self.outgoing_tx, self.incoming_rx, self.task_handles) + } +} + +async fn send_disconnected( + incoming_tx: &mpsc::Sender, + reason: Option, +) { + let _ = incoming_tx + .send(JsonRpcConnectionEvent::Disconnected { reason }) + .await; +} + +async fn send_malformed_message( + incoming_tx: &mpsc::Sender, + reason: Option, +) { + let _ = incoming_tx + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: reason.unwrap_or_else(|| "malformed JSON-RPC message".to_string()), + }) + .await; +} + +async fn write_jsonrpc_line_message( + writer: &mut BufWriter, + message: &JSONRPCMessage, +) -> std::io::Result<()> +where + W: AsyncWrite + Unpin, +{ + let encoded = + serialize_jsonrpc_message(message).map_err(|err| std::io::Error::other(err.to_string()))?; + writer.write_all(encoded.as_bytes()).await?; + writer.write_all(b"\n").await?; + writer.flush().await +} + +fn serialize_jsonrpc_message(message: &JSONRPCMessage) -> Result { + serde_json::to_string(message) +} diff --git a/codex-rs/exec-server/src/lib.rs b/codex-rs/exec-server/src/lib.rs new file mode 100644 index 000000000..e204d6e08 --- /dev/null +++ b/codex-rs/exec-server/src/lib.rs @@ -0,0 +1,21 @@ +mod client; +mod client_api; +mod connection; +mod local; +mod protocol; +mod rpc; +mod server; + +pub use client::ExecServerClient; +pub use client::ExecServerError; +pub use client_api::ExecServerClientConnectOptions; +pub use client_api::RemoteExecServerConnectArgs; +pub use local::ExecServerLaunchCommand; +pub use local::SpawnedExecServer; +pub use local::spawn_local_exec_server; +pub use protocol::InitializeParams; +pub use protocol::InitializeResponse; +pub use server::ExecServerTransport; +pub use server::ExecServerTransportParseError; +pub use server::run_main; +pub use server::run_main_with_transport; diff --git a/codex-rs/exec-server/src/local.rs b/codex-rs/exec-server/src/local.rs new file mode 100644 index 000000000..e51c94394 --- /dev/null +++ b/codex-rs/exec-server/src/local.rs @@ -0,0 +1,71 @@ +use std::path::PathBuf; +use std::process::Stdio; +use std::sync::Mutex as StdMutex; + +use tokio::process::Child; +use tokio::process::Command; + +use crate::client::ExecServerClient; +use crate::client::ExecServerError; +use crate::client_api::ExecServerClientConnectOptions; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ExecServerLaunchCommand { + pub program: PathBuf, + pub args: Vec, +} + +pub struct SpawnedExecServer { + client: ExecServerClient, + child: StdMutex>, +} + +impl SpawnedExecServer { + pub fn client(&self) -> &ExecServerClient { + &self.client + } +} + +impl Drop for SpawnedExecServer { + fn drop(&mut self) { + if let Ok(mut child_guard) = self.child.lock() + && let Some(child) = child_guard.as_mut() + { + let _ = child.start_kill(); + } + } +} + +pub async fn spawn_local_exec_server( + command: ExecServerLaunchCommand, + options: ExecServerClientConnectOptions, +) -> Result { + let mut child = Command::new(&command.program); + child.args(&command.args); + child.args(["--listen", "stdio://"]); + child.stdin(Stdio::piped()); + child.stdout(Stdio::piped()); + child.stderr(Stdio::inherit()); + child.kill_on_drop(true); + + let mut child = child.spawn().map_err(ExecServerError::Spawn)?; + let stdin = child.stdin.take().ok_or_else(|| { + ExecServerError::Protocol("exec-server stdin was not captured".to_string()) + })?; + let stdout = child.stdout.take().ok_or_else(|| { + ExecServerError::Protocol("exec-server stdout was not captured".to_string()) + })?; + + let client = match ExecServerClient::connect_stdio(stdin, stdout, options).await { + Ok(client) => client, + Err(err) => { + let _ = child.start_kill(); + return Err(err); + } + }; + + Ok(SpawnedExecServer { + client, + child: StdMutex::new(Some(child)), + }) +} diff --git a/codex-rs/exec-server/src/protocol.rs b/codex-rs/exec-server/src/protocol.rs new file mode 100644 index 000000000..165378fb5 --- /dev/null +++ b/codex-rs/exec-server/src/protocol.rs @@ -0,0 +1,15 @@ +use serde::Deserialize; +use serde::Serialize; + +pub const INITIALIZE_METHOD: &str = "initialize"; +pub const INITIALIZED_METHOD: &str = "initialized"; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InitializeParams { + pub client_name: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InitializeResponse {} diff --git a/codex-rs/exec-server/src/rpc.rs b/codex-rs/exec-server/src/rpc.rs new file mode 100644 index 000000000..0c8b5cdf3 --- /dev/null +++ b/codex-rs/exec-server/src/rpc.rs @@ -0,0 +1,347 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::atomic::AtomicI64; +use std::sync::atomic::Ordering; + +use codex_app_server_protocol::JSONRPCError; +use codex_app_server_protocol::JSONRPCErrorError; +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::JSONRPCRequest; +use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::RequestId; +use serde::Serialize; +use serde::de::DeserializeOwned; +use serde_json::Value; +use tokio::sync::Mutex; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; +use tracing::warn; + +use crate::connection::JsonRpcConnection; +use crate::connection::JsonRpcConnectionEvent; + +type PendingRequest = oneshot::Sender>; + +#[derive(Debug)] +pub(crate) enum RpcClientEvent { + Notification(JSONRPCNotification), + Disconnected { reason: Option }, +} + +pub(crate) struct RpcClient { + write_tx: mpsc::Sender, + pending: Arc>>, + next_request_id: AtomicI64, + transport_tasks: Vec>, + reader_task: JoinHandle<()>, +} + +impl RpcClient { + pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver) { + let (write_tx, mut incoming_rx, transport_tasks) = connection.into_parts(); + let pending = Arc::new(Mutex::new(HashMap::::new())); + let (event_tx, event_rx) = mpsc::channel(128); + + let pending_for_reader = Arc::clone(&pending); + let reader_task = tokio::spawn(async move { + while let Some(event) = incoming_rx.recv().await { + match event { + JsonRpcConnectionEvent::Message(message) => { + if let Err(err) = + handle_server_message(&pending_for_reader, &event_tx, message).await + { + warn!("JSON-RPC client closing after protocol error: {err}"); + break; + } + } + JsonRpcConnectionEvent::MalformedMessage { reason } => { + warn!("JSON-RPC client closing after malformed server message: {reason}"); + let _ = event_tx + .send(RpcClientEvent::Disconnected { + reason: Some(reason), + }) + .await; + drain_pending(&pending_for_reader).await; + return; + } + JsonRpcConnectionEvent::Disconnected { reason } => { + let _ = event_tx.send(RpcClientEvent::Disconnected { reason }).await; + drain_pending(&pending_for_reader).await; + return; + } + } + } + + let _ = event_tx + .send(RpcClientEvent::Disconnected { reason: None }) + .await; + drain_pending(&pending_for_reader).await; + }); + + ( + Self { + write_tx, + pending, + next_request_id: AtomicI64::new(1), + transport_tasks, + reader_task, + }, + event_rx, + ) + } + + pub(crate) async fn notify( + &self, + method: &str, + params: &P, + ) -> Result<(), serde_json::Error> { + let params = serde_json::to_value(params)?; + self.write_tx + .send(JSONRPCMessage::Notification(JSONRPCNotification { + method: method.to_string(), + params: Some(params), + })) + .await + .map_err(|_| { + serde_json::Error::io(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "JSON-RPC transport closed", + )) + }) + } + + pub(crate) async fn call(&self, method: &str, params: &P) -> Result + where + P: Serialize, + T: DeserializeOwned, + { + let request_id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::SeqCst)); + let (response_tx, response_rx) = oneshot::channel(); + self.pending + .lock() + .await + .insert(request_id.clone(), response_tx); + + let params = match serde_json::to_value(params) { + Ok(params) => params, + Err(err) => { + self.pending.lock().await.remove(&request_id); + return Err(RpcCallError::Json(err)); + } + }; + if self + .write_tx + .send(JSONRPCMessage::Request(JSONRPCRequest { + id: request_id.clone(), + method: method.to_string(), + params: Some(params), + trace: None, + })) + .await + .is_err() + { + self.pending.lock().await.remove(&request_id); + return Err(RpcCallError::Closed); + } + + let result = response_rx.await.map_err(|_| RpcCallError::Closed)?; + let response = match result { + Ok(response) => response, + Err(error) => return Err(RpcCallError::Server(error)), + }; + serde_json::from_value(response).map_err(RpcCallError::Json) + } + + #[cfg(test)] + #[allow(dead_code)] + pub(crate) async fn pending_request_count(&self) -> usize { + self.pending.lock().await.len() + } +} + +impl Drop for RpcClient { + fn drop(&mut self) { + for task in &self.transport_tasks { + task.abort(); + } + self.reader_task.abort(); + } +} + +#[derive(Debug)] +pub(crate) enum RpcCallError { + Closed, + Json(serde_json::Error), + Server(JSONRPCErrorError), +} + +async fn handle_server_message( + pending: &Mutex>, + event_tx: &mpsc::Sender, + message: JSONRPCMessage, +) -> Result<(), String> { + match message { + JSONRPCMessage::Response(JSONRPCResponse { id, result }) => { + if let Some(pending) = pending.lock().await.remove(&id) { + let _ = pending.send(Ok(result)); + } + } + JSONRPCMessage::Error(JSONRPCError { id, error }) => { + if let Some(pending) = pending.lock().await.remove(&id) { + let _ = pending.send(Err(error)); + } + } + JSONRPCMessage::Notification(notification) => { + let _ = event_tx + .send(RpcClientEvent::Notification(notification)) + .await; + } + JSONRPCMessage::Request(request) => { + return Err(format!( + "unexpected JSON-RPC request from remote server: {}", + request.method + )); + } + } + + Ok(()) +} + +async fn drain_pending(pending: &Mutex>) { + let pending = { + let mut pending = pending.lock().await; + pending + .drain() + .map(|(_, pending)| pending) + .collect::>() + }; + for pending in pending { + let _ = pending.send(Err(JSONRPCErrorError { + code: -32000, + data: None, + message: "JSON-RPC transport closed".to_string(), + })); + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use codex_app_server_protocol::JSONRPCMessage; + use codex_app_server_protocol::JSONRPCResponse; + use pretty_assertions::assert_eq; + use tokio::io::AsyncBufReadExt; + use tokio::io::AsyncWriteExt; + use tokio::io::BufReader; + use tokio::time::timeout; + + use super::RpcClient; + use crate::connection::JsonRpcConnection; + + async fn read_jsonrpc_line(lines: &mut tokio::io::Lines>) -> JSONRPCMessage + where + R: tokio::io::AsyncRead + Unpin, + { + let next_line = timeout(Duration::from_secs(1), lines.next_line()).await; + let line_result = match next_line { + Ok(line_result) => line_result, + Err(err) => panic!("timed out waiting for JSON-RPC line: {err}"), + }; + let maybe_line = match line_result { + Ok(maybe_line) => maybe_line, + Err(err) => panic!("failed to read JSON-RPC line: {err}"), + }; + let line = match maybe_line { + Some(line) => line, + None => panic!("server connection closed before JSON-RPC line arrived"), + }; + match serde_json::from_str::(&line) { + Ok(message) => message, + Err(err) => panic!("failed to parse JSON-RPC line: {err}"), + } + } + + async fn write_jsonrpc_line(writer: &mut W, message: JSONRPCMessage) + where + W: tokio::io::AsyncWrite + Unpin, + { + let encoded = match serde_json::to_string(&message) { + Ok(encoded) => encoded, + Err(err) => panic!("failed to encode JSON-RPC message: {err}"), + }; + if let Err(err) = writer.write_all(format!("{encoded}\n").as_bytes()).await { + panic!("failed to write JSON-RPC line: {err}"); + } + } + + #[tokio::test] + async fn rpc_client_matches_out_of_order_responses_by_request_id() { + let (client_stdin, server_reader) = tokio::io::duplex(4096); + let (mut server_writer, client_stdout) = tokio::io::duplex(4096); + let (client, _events_rx) = RpcClient::new(JsonRpcConnection::from_stdio( + client_stdout, + client_stdin, + "test-rpc".to_string(), + )); + + let server = tokio::spawn(async move { + let mut lines = BufReader::new(server_reader).lines(); + + let first = read_jsonrpc_line(&mut lines).await; + let second = read_jsonrpc_line(&mut lines).await; + let (slow_request, fast_request) = match (first, second) { + ( + JSONRPCMessage::Request(first_request), + JSONRPCMessage::Request(second_request), + ) if first_request.method == "slow" && second_request.method == "fast" => { + (first_request, second_request) + } + ( + JSONRPCMessage::Request(first_request), + JSONRPCMessage::Request(second_request), + ) if first_request.method == "fast" && second_request.method == "slow" => { + (second_request, first_request) + } + _ => panic!("expected slow and fast requests"), + }; + + write_jsonrpc_line( + &mut server_writer, + JSONRPCMessage::Response(JSONRPCResponse { + id: fast_request.id, + result: serde_json::json!({ "value": "fast" }), + }), + ) + .await; + write_jsonrpc_line( + &mut server_writer, + JSONRPCMessage::Response(JSONRPCResponse { + id: slow_request.id, + result: serde_json::json!({ "value": "slow" }), + }), + ) + .await; + }); + + let slow_params = serde_json::json!({ "n": 1 }); + let fast_params = serde_json::json!({ "n": 2 }); + let (slow, fast) = tokio::join!( + client.call::<_, serde_json::Value>("slow", &slow_params), + client.call::<_, serde_json::Value>("fast", &fast_params), + ); + + let slow = slow.unwrap_or_else(|err| panic!("slow request failed: {err:?}")); + let fast = fast.unwrap_or_else(|err| panic!("fast request failed: {err:?}")); + assert_eq!(slow, serde_json::json!({ "value": "slow" })); + assert_eq!(fast, serde_json::json!({ "value": "fast" })); + + assert_eq!(client.pending_request_count().await, 0); + + if let Err(err) = server.await { + panic!("server task failed: {err}"); + } + } +} diff --git a/codex-rs/exec-server/src/server.rs b/codex-rs/exec-server/src/server.rs new file mode 100644 index 000000000..15ce8650f --- /dev/null +++ b/codex-rs/exec-server/src/server.rs @@ -0,0 +1,18 @@ +mod handler; +mod jsonrpc; +mod processor; +mod transport; + +pub(crate) use handler::ExecServerHandler; +pub use transport::ExecServerTransport; +pub use transport::ExecServerTransportParseError; + +pub async fn run_main() -> Result<(), Box> { + run_main_with_transport(ExecServerTransport::Stdio).await +} + +pub async fn run_main_with_transport( + transport: ExecServerTransport, +) -> Result<(), Box> { + transport::run_transport(transport).await +} diff --git a/codex-rs/exec-server/src/server/handler.rs b/codex-rs/exec-server/src/server/handler.rs new file mode 100644 index 000000000..838e58240 --- /dev/null +++ b/codex-rs/exec-server/src/server/handler.rs @@ -0,0 +1,40 @@ +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; + +use codex_app_server_protocol::JSONRPCErrorError; + +use crate::protocol::InitializeResponse; +use crate::server::jsonrpc::invalid_request; + +pub(crate) struct ExecServerHandler { + initialize_requested: AtomicBool, + initialized: AtomicBool, +} + +impl ExecServerHandler { + pub(crate) fn new() -> Self { + Self { + initialize_requested: AtomicBool::new(false), + initialized: AtomicBool::new(false), + } + } + + pub(crate) async fn shutdown(&self) {} + + pub(crate) fn initialize(&self) -> Result { + if self.initialize_requested.swap(true, Ordering::SeqCst) { + return Err(invalid_request( + "initialize may only be sent once per connection".to_string(), + )); + } + Ok(InitializeResponse {}) + } + + pub(crate) fn initialized(&self) -> Result<(), String> { + if !self.initialize_requested.load(Ordering::SeqCst) { + return Err("received `initialized` notification before `initialize`".into()); + } + self.initialized.store(true, Ordering::SeqCst); + Ok(()) + } +} diff --git a/codex-rs/exec-server/src/server/jsonrpc.rs b/codex-rs/exec-server/src/server/jsonrpc.rs new file mode 100644 index 000000000..f81abd06e --- /dev/null +++ b/codex-rs/exec-server/src/server/jsonrpc.rs @@ -0,0 +1,53 @@ +use codex_app_server_protocol::JSONRPCError; +use codex_app_server_protocol::JSONRPCErrorError; +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::RequestId; +use serde_json::Value; + +pub(crate) fn invalid_request(message: String) -> JSONRPCErrorError { + JSONRPCErrorError { + code: -32600, + data: None, + message, + } +} + +pub(crate) fn invalid_params(message: String) -> JSONRPCErrorError { + JSONRPCErrorError { + code: -32602, + data: None, + message, + } +} + +pub(crate) fn method_not_found(message: String) -> JSONRPCErrorError { + JSONRPCErrorError { + code: -32601, + data: None, + message, + } +} + +pub(crate) fn response_message( + request_id: RequestId, + result: Result, +) -> JSONRPCMessage { + match result { + Ok(result) => JSONRPCMessage::Response(JSONRPCResponse { + id: request_id, + result, + }), + Err(error) => JSONRPCMessage::Error(JSONRPCError { + id: request_id, + error, + }), + } +} + +pub(crate) fn invalid_request_message(reason: String) -> JSONRPCMessage { + JSONRPCMessage::Error(JSONRPCError { + id: RequestId::Integer(-1), + error: invalid_request(reason), + }) +} diff --git a/codex-rs/exec-server/src/server/processor.rs b/codex-rs/exec-server/src/server/processor.rs new file mode 100644 index 000000000..7a8ca40f0 --- /dev/null +++ b/codex-rs/exec-server/src/server/processor.rs @@ -0,0 +1,121 @@ +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::JSONRPCRequest; +use tracing::debug; + +use crate::connection::JsonRpcConnection; +use crate::connection::JsonRpcConnectionEvent; +use crate::protocol::INITIALIZE_METHOD; +use crate::protocol::INITIALIZED_METHOD; +use crate::protocol::InitializeParams; +use crate::server::ExecServerHandler; +use crate::server::jsonrpc::invalid_params; +use crate::server::jsonrpc::invalid_request_message; +use crate::server::jsonrpc::method_not_found; +use crate::server::jsonrpc::response_message; +use tracing::warn; + +pub(crate) async fn run_connection(connection: JsonRpcConnection) { + let (json_outgoing_tx, mut incoming_rx, _connection_tasks) = connection.into_parts(); + let handler = ExecServerHandler::new(); + + while let Some(event) = incoming_rx.recv().await { + match event { + JsonRpcConnectionEvent::Message(message) => { + let response = match handle_connection_message(&handler, message).await { + Ok(response) => response, + Err(err) => { + tracing::warn!( + "closing exec-server connection after protocol error: {err}" + ); + break; + } + }; + let Some(response) = response else { + continue; + }; + if json_outgoing_tx.send(response).await.is_err() { + break; + } + } + JsonRpcConnectionEvent::MalformedMessage { reason } => { + warn!("ignoring malformed exec-server message: {reason}"); + if json_outgoing_tx + .send(invalid_request_message(reason)) + .await + .is_err() + { + break; + } + } + JsonRpcConnectionEvent::Disconnected { reason } => { + if let Some(reason) = reason { + debug!("exec-server connection disconnected: {reason}"); + } + break; + } + } + } + + handler.shutdown().await; +} + +pub(crate) async fn handle_connection_message( + handler: &ExecServerHandler, + message: JSONRPCMessage, +) -> Result, String> { + match message { + JSONRPCMessage::Request(request) => Ok(Some(dispatch_request(handler, request))), + JSONRPCMessage::Notification(notification) => { + handle_notification(handler, notification)?; + Ok(None) + } + JSONRPCMessage::Response(response) => Err(format!( + "unexpected client response for request id {:?}", + response.id + )), + JSONRPCMessage::Error(error) => Err(format!( + "unexpected client error for request id {:?}", + error.id + )), + } +} + +fn dispatch_request(handler: &ExecServerHandler, request: JSONRPCRequest) -> JSONRPCMessage { + let JSONRPCRequest { + id, + method, + params, + trace: _, + } = request; + + match method.as_str() { + INITIALIZE_METHOD => { + let result = serde_json::from_value::( + params.unwrap_or(serde_json::Value::Null), + ) + .map_err(|err| invalid_params(err.to_string())) + .and_then(|_params| handler.initialize()) + .and_then(|response| { + serde_json::to_value(response).map_err(|err| invalid_params(err.to_string())) + }); + response_message(id, result) + } + other => response_message( + id, + Err(method_not_found(format!( + "exec-server stub does not implement `{other}` yet" + ))), + ), + } +} + +fn handle_notification( + handler: &ExecServerHandler, + notification: JSONRPCNotification, +) -> Result<(), String> { + match notification.method.as_str() { + INITIALIZED_METHOD => handler.initialized(), + other => Err(format!("unexpected notification method: {other}")), + } +} diff --git a/codex-rs/exec-server/src/server/transport.rs b/codex-rs/exec-server/src/server/transport.rs new file mode 100644 index 000000000..edbec7fa9 --- /dev/null +++ b/codex-rs/exec-server/src/server/transport.rs @@ -0,0 +1,118 @@ +use std::net::SocketAddr; +use std::str::FromStr; + +use tokio::net::TcpListener; +use tokio_tungstenite::accept_async; +use tracing::warn; + +use crate::connection::JsonRpcConnection; +use crate::server::processor::run_connection; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ExecServerTransport { + Stdio, + WebSocket { bind_address: SocketAddr }, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum ExecServerTransportParseError { + UnsupportedListenUrl(String), + InvalidWebSocketListenUrl(String), +} + +impl std::fmt::Display for ExecServerTransportParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ExecServerTransportParseError::UnsupportedListenUrl(listen_url) => write!( + f, + "unsupported --listen URL `{listen_url}`; expected `stdio://` or `ws://IP:PORT`" + ), + ExecServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!( + f, + "invalid websocket --listen URL `{listen_url}`; expected `ws://IP:PORT`" + ), + } + } +} + +impl std::error::Error for ExecServerTransportParseError {} + +impl ExecServerTransport { + pub const DEFAULT_LISTEN_URL: &str = "ws://127.0.0.1:0"; + + pub fn from_listen_url(listen_url: &str) -> Result { + if listen_url == "stdio://" { + return Ok(Self::Stdio); + } + + if let Some(socket_addr) = listen_url.strip_prefix("ws://") { + let bind_address = socket_addr.parse::().map_err(|_| { + ExecServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string()) + })?; + return Ok(Self::WebSocket { bind_address }); + } + + Err(ExecServerTransportParseError::UnsupportedListenUrl( + listen_url.to_string(), + )) + } +} + +impl FromStr for ExecServerTransport { + type Err = ExecServerTransportParseError; + + fn from_str(s: &str) -> Result { + Self::from_listen_url(s) + } +} + +pub(crate) async fn run_transport( + transport: ExecServerTransport, +) -> Result<(), Box> { + match transport { + ExecServerTransport::Stdio => { + run_connection(JsonRpcConnection::from_stdio( + tokio::io::stdin(), + tokio::io::stdout(), + "exec-server stdio".to_string(), + )) + .await; + Ok(()) + } + ExecServerTransport::WebSocket { bind_address } => { + run_websocket_listener(bind_address).await + } + } +} + +async fn run_websocket_listener( + bind_address: SocketAddr, +) -> Result<(), Box> { + let listener = TcpListener::bind(bind_address).await?; + let local_addr = listener.local_addr()?; + tracing::info!("codex-exec-server listening on ws://{local_addr}"); + + loop { + let (stream, peer_addr) = listener.accept().await?; + tokio::spawn(async move { + match accept_async(stream).await { + Ok(websocket) => { + run_connection(JsonRpcConnection::from_websocket( + websocket, + format!("exec-server websocket {peer_addr}"), + )) + .await; + } + Err(err) => { + warn!( + "failed to accept exec-server websocket connection from {peer_addr}: {err}" + ); + } + } + }); + } +} + +#[cfg(test)] +#[path = "transport_tests.rs"] +mod transport_tests; diff --git a/codex-rs/exec-server/src/server/transport_tests.rs b/codex-rs/exec-server/src/server/transport_tests.rs new file mode 100644 index 000000000..bc440e2aa --- /dev/null +++ b/codex-rs/exec-server/src/server/transport_tests.rs @@ -0,0 +1,54 @@ +use pretty_assertions::assert_eq; + +use super::ExecServerTransport; + +#[test] +fn exec_server_transport_parses_default_websocket_listen_url() { + let transport = ExecServerTransport::from_listen_url(ExecServerTransport::DEFAULT_LISTEN_URL) + .expect("default listen URL should parse"); + assert_eq!( + transport, + ExecServerTransport::WebSocket { + bind_address: "127.0.0.1:0".parse().expect("valid socket address"), + } + ); +} + +#[test] +fn exec_server_transport_parses_stdio_listen_url() { + let transport = + ExecServerTransport::from_listen_url("stdio://").expect("stdio listen URL should parse"); + assert_eq!(transport, ExecServerTransport::Stdio); +} + +#[test] +fn exec_server_transport_parses_websocket_listen_url() { + let transport = ExecServerTransport::from_listen_url("ws://127.0.0.1:1234") + .expect("websocket listen URL should parse"); + assert_eq!( + transport, + ExecServerTransport::WebSocket { + bind_address: "127.0.0.1:1234".parse().expect("valid socket address"), + } + ); +} + +#[test] +fn exec_server_transport_rejects_invalid_websocket_listen_url() { + let err = ExecServerTransport::from_listen_url("ws://localhost:1234") + .expect_err("hostname bind address should be rejected"); + assert_eq!( + err.to_string(), + "invalid websocket --listen URL `ws://localhost:1234`; expected `ws://IP:PORT`" + ); +} + +#[test] +fn exec_server_transport_rejects_unsupported_listen_url() { + let err = ExecServerTransport::from_listen_url("http://127.0.0.1:1234") + .expect_err("unsupported scheme should fail"); + assert_eq!( + err.to_string(), + "unsupported --listen URL `http://127.0.0.1:1234`; expected `stdio://` or `ws://IP:PORT`" + ); +} diff --git a/codex-rs/exec-server/tests/stdio_smoke.rs b/codex-rs/exec-server/tests/stdio_smoke.rs new file mode 100644 index 000000000..240180efd --- /dev/null +++ b/codex-rs/exec-server/tests/stdio_smoke.rs @@ -0,0 +1,129 @@ +#![cfg(unix)] + +use std::process::Stdio; +use std::time::Duration; + +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::JSONRPCRequest; +use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::RequestId; +use codex_exec_server::InitializeParams; +use codex_exec_server::InitializeResponse; +use codex_utils_cargo_bin::cargo_bin; +use pretty_assertions::assert_eq; +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::process::Command; +use tokio::time::timeout; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn exec_server_accepts_initialize_over_stdio() -> anyhow::Result<()> { + let binary = cargo_bin("codex-exec-server")?; + let mut child = Command::new(binary); + child.args(["--listen", "stdio://"]); + child.stdin(Stdio::piped()); + child.stdout(Stdio::piped()); + child.stderr(Stdio::inherit()); + let mut child = child.spawn()?; + + let mut stdin = child.stdin.take().expect("stdin"); + let stdout = child.stdout.take().expect("stdout"); + let mut stdout = BufReader::new(stdout).lines(); + + let initialize = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(1), + method: "initialize".to_string(), + params: Some(serde_json::to_value(InitializeParams { + client_name: "exec-server-test".to_string(), + })?), + trace: None, + }); + stdin + .write_all(format!("{}\n", serde_json::to_string(&initialize)?).as_bytes()) + .await?; + + let response_line = timeout(Duration::from_secs(5), stdout.next_line()).await??; + let response_line = response_line.expect("response line"); + let response: JSONRPCMessage = serde_json::from_str(&response_line)?; + let JSONRPCMessage::Response(JSONRPCResponse { id, result }) = response else { + panic!("expected initialize response"); + }; + assert_eq!(id, RequestId::Integer(1)); + let initialize_response: InitializeResponse = serde_json::from_value(result)?; + assert_eq!(initialize_response, InitializeResponse {}); + + let initialized = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: Some(serde_json::json!({})), + }); + stdin + .write_all(format!("{}\n", serde_json::to_string(&initialized)?).as_bytes()) + .await?; + + child.start_kill()?; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn exec_server_stubs_process_start_over_stdio() -> anyhow::Result<()> { + let binary = cargo_bin("codex-exec-server")?; + let mut child = Command::new(binary); + child.args(["--listen", "stdio://"]); + child.stdin(Stdio::piped()); + child.stdout(Stdio::piped()); + child.stderr(Stdio::inherit()); + let mut child = child.spawn()?; + + let mut stdin = child.stdin.take().expect("stdin"); + let stdout = child.stdout.take().expect("stdout"); + let mut stdout = BufReader::new(stdout).lines(); + + let initialize = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(1), + method: "initialize".to_string(), + params: Some(serde_json::to_value(InitializeParams { + client_name: "exec-server-test".to_string(), + })?), + trace: None, + }); + stdin + .write_all(format!("{}\n", serde_json::to_string(&initialize)?).as_bytes()) + .await?; + let _ = timeout(Duration::from_secs(5), stdout.next_line()).await??; + + let exec = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(2), + method: "process/start".to_string(), + params: Some(serde_json::json!({ + "processId": "proc-1", + "argv": ["true"], + "cwd": std::env::current_dir()?, + "env": {}, + "tty": false, + "arg0": null + })), + trace: None, + }); + stdin + .write_all(format!("{}\n", serde_json::to_string(&exec)?).as_bytes()) + .await?; + + let response_line = timeout(Duration::from_secs(5), stdout.next_line()).await??; + let response_line = response_line.expect("exec response line"); + let response: JSONRPCMessage = serde_json::from_str(&response_line)?; + let JSONRPCMessage::Error(codex_app_server_protocol::JSONRPCError { id, error }) = response + else { + panic!("expected process/start stub error"); + }; + assert_eq!(id, RequestId::Integer(2)); + assert_eq!(error.code, -32601); + assert_eq!( + error.message, + "exec-server stub does not implement `process/start` yet" + ); + + child.start_kill()?; + Ok(()) +} diff --git a/codex-rs/exec-server/tests/websocket_smoke.rs b/codex-rs/exec-server/tests/websocket_smoke.rs new file mode 100644 index 000000000..2a51a4d3a --- /dev/null +++ b/codex-rs/exec-server/tests/websocket_smoke.rs @@ -0,0 +1,229 @@ +#![cfg(unix)] + +use std::process::Stdio; +use std::time::Duration; + +use codex_app_server_protocol::JSONRPCError; +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::JSONRPCRequest; +use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::RequestId; +use codex_exec_server::InitializeParams; +use codex_exec_server::InitializeResponse; +use codex_utils_cargo_bin::cargo_bin; +use pretty_assertions::assert_eq; +use tokio::process::Command; +use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite::Message; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn exec_server_accepts_initialize_over_websocket() -> anyhow::Result<()> { + let binary = cargo_bin("codex-exec-server")?; + let websocket_url = reserve_websocket_url()?; + let mut child = Command::new(binary); + child.args(["--listen", &websocket_url]); + child.stdin(Stdio::null()); + child.stdout(Stdio::null()); + child.stderr(Stdio::inherit()); + let mut child = child.spawn()?; + + let (mut websocket, _) = connect_websocket_when_ready(&websocket_url).await?; + let initialize = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(1), + method: "initialize".to_string(), + params: Some(serde_json::to_value(InitializeParams { + client_name: "exec-server-test".to_string(), + })?), + trace: None, + }); + futures::SinkExt::send( + &mut websocket, + Message::Text(serde_json::to_string(&initialize)?.into()), + ) + .await?; + + let Some(Ok(Message::Text(response_text))) = futures::StreamExt::next(&mut websocket).await + else { + panic!("expected initialize response"); + }; + let response: JSONRPCMessage = serde_json::from_str(response_text.as_ref())?; + let JSONRPCMessage::Response(JSONRPCResponse { id, result }) = response else { + panic!("expected initialize response"); + }; + assert_eq!(id, RequestId::Integer(1)); + let initialize_response: InitializeResponse = serde_json::from_value(result)?; + assert_eq!(initialize_response, InitializeResponse {}); + + let initialized = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: Some(serde_json::json!({})), + }); + futures::SinkExt::send( + &mut websocket, + Message::Text(serde_json::to_string(&initialized)?.into()), + ) + .await?; + + child.start_kill()?; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn exec_server_reports_malformed_websocket_json_and_keeps_running() -> anyhow::Result<()> { + let binary = cargo_bin("codex-exec-server")?; + let websocket_url = reserve_websocket_url()?; + let mut child = Command::new(binary); + child.args(["--listen", &websocket_url]); + child.stdin(Stdio::null()); + child.stdout(Stdio::null()); + child.stderr(Stdio::inherit()); + let mut child = child.spawn()?; + + let (mut websocket, _) = connect_websocket_when_ready(&websocket_url).await?; + futures::SinkExt::send(&mut websocket, Message::Text("not-json".to_string().into())).await?; + + let Some(Ok(Message::Text(response_text))) = futures::StreamExt::next(&mut websocket).await + else { + panic!("expected malformed-message error response"); + }; + let response: JSONRPCMessage = serde_json::from_str(response_text.as_ref())?; + let JSONRPCMessage::Error(JSONRPCError { id, error }) = response else { + panic!("expected malformed-message error response"); + }; + assert_eq!(id, RequestId::Integer(-1)); + assert_eq!(error.code, -32600); + assert!( + error + .message + .starts_with("failed to parse websocket JSON-RPC message from exec-server websocket"), + "unexpected malformed-message error: {}", + error.message + ); + + let initialize = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(1), + method: "initialize".to_string(), + params: Some(serde_json::to_value(InitializeParams { + client_name: "exec-server-test".to_string(), + })?), + trace: None, + }); + futures::SinkExt::send( + &mut websocket, + Message::Text(serde_json::to_string(&initialize)?.into()), + ) + .await?; + + let Some(Ok(Message::Text(response_text))) = futures::StreamExt::next(&mut websocket).await + else { + panic!("expected initialize response after malformed input"); + }; + let response: JSONRPCMessage = serde_json::from_str(response_text.as_ref())?; + let JSONRPCMessage::Response(JSONRPCResponse { id, result }) = response else { + panic!("expected initialize response after malformed input"); + }; + assert_eq!(id, RequestId::Integer(1)); + let initialize_response: InitializeResponse = serde_json::from_value(result)?; + assert_eq!(initialize_response, InitializeResponse {}); + + child.start_kill()?; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn exec_server_stubs_process_start_over_websocket() -> anyhow::Result<()> { + let binary = cargo_bin("codex-exec-server")?; + let websocket_url = reserve_websocket_url()?; + let mut child = Command::new(binary); + child.args(["--listen", &websocket_url]); + child.stdin(Stdio::null()); + child.stdout(Stdio::null()); + child.stderr(Stdio::inherit()); + let mut child = child.spawn()?; + + let (mut websocket, _) = connect_websocket_when_ready(&websocket_url).await?; + let initialize = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(1), + method: "initialize".to_string(), + params: Some(serde_json::to_value(InitializeParams { + client_name: "exec-server-test".to_string(), + })?), + trace: None, + }); + futures::SinkExt::send( + &mut websocket, + Message::Text(serde_json::to_string(&initialize)?.into()), + ) + .await?; + let _ = futures::StreamExt::next(&mut websocket).await; + + let exec = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(2), + method: "process/start".to_string(), + params: Some(serde_json::json!({ + "processId": "proc-1", + "argv": ["true"], + "cwd": std::env::current_dir()?, + "env": {}, + "tty": false, + "arg0": null + })), + trace: None, + }); + futures::SinkExt::send( + &mut websocket, + Message::Text(serde_json::to_string(&exec)?.into()), + ) + .await?; + + let Some(Ok(Message::Text(response_text))) = futures::StreamExt::next(&mut websocket).await + else { + panic!("expected process/start error"); + }; + let response: JSONRPCMessage = serde_json::from_str(response_text.as_ref())?; + let JSONRPCMessage::Error(JSONRPCError { id, error }) = response else { + panic!("expected process/start stub error"); + }; + assert_eq!(id, RequestId::Integer(2)); + assert_eq!(error.code, -32601); + assert_eq!( + error.message, + "exec-server stub does not implement `process/start` yet" + ); + + child.start_kill()?; + Ok(()) +} + +fn reserve_websocket_url() -> anyhow::Result { + let listener = std::net::TcpListener::bind("127.0.0.1:0")?; + let addr = listener.local_addr()?; + drop(listener); + Ok(format!("ws://{addr}")) +} + +async fn connect_websocket_when_ready( + websocket_url: &str, +) -> anyhow::Result<( + tokio_tungstenite::WebSocketStream>, + tokio_tungstenite::tungstenite::handshake::client::Response, +)> { + let deadline = tokio::time::Instant::now() + Duration::from_secs(5); + loop { + match connect_async(websocket_url).await { + Ok(websocket) => return Ok(websocket), + Err(err) + if tokio::time::Instant::now() < deadline + && matches!( + err, + tokio_tungstenite::tungstenite::Error::Io(ref io_err) + if io_err.kind() == std::io::ErrorKind::ConnectionRefused + ) => + { + tokio::time::sleep(Duration::from_millis(25)).await; + } + Err(err) => return Err(err.into()), + } + } +}