diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index 0be65403d..287e7e540 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -141,8 +141,10 @@ jobs: run: working-directory: codex-rs env: - # Speed up repeated builds across CI runs by caching compiled objects (non-Windows). - USE_SCCACHE: ${{ startsWith(matrix.runner, 'windows') && 'false' || 'true' }} + # Speed up repeated builds across CI runs by caching compiled objects, except on + # arm64 macOS runners cross-targeting x86_64 where ring/cc-rs can produce + # mixed-architecture archives under sccache. + USE_SCCACHE: ${{ (startsWith(matrix.runner, 'windows') || (matrix.runner == 'macos-15-xlarge' && matrix.target == 'x86_64-apple-darwin')) && 'false' || 'true' }} CARGO_INCREMENTAL: "0" SCCACHE_CACHE_SIZE: 10G # In rust-ci, representative release-profile checks use thin LTO for faster feedback. @@ -506,8 +508,10 @@ jobs: run: working-directory: codex-rs env: - # Speed up repeated builds across CI runs by caching compiled objects (non-Windows). - USE_SCCACHE: ${{ startsWith(matrix.runner, 'windows') && 'false' || 'true' }} + # Speed up repeated builds across CI runs by caching compiled objects, except on + # arm64 macOS runners cross-targeting x86_64 where ring/cc-rs can produce + # mixed-architecture archives under sccache. + USE_SCCACHE: ${{ (startsWith(matrix.runner, 'windows') || (matrix.runner == 'macos-15-xlarge' && matrix.target == 'x86_64-apple-darwin')) && 'false' || 'true' }} CARGO_INCREMENTAL: "0" SCCACHE_CACHE_SIZE: 10G diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index b488f94cd..32b41c022 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1998,10 +1998,12 @@ version = "0.0.0" dependencies = [ "anyhow", "async-trait", + "base64 0.22.1", "clap", "codex-app-server-protocol", "codex-utils-absolute-path", "codex-utils-cargo-bin", + "codex-utils-pty", "futures", "pretty_assertions", "serde", diff --git a/codex-rs/exec-server/Cargo.toml b/codex-rs/exec-server/Cargo.toml index 91af099eb..fac7649e4 100644 --- a/codex-rs/exec-server/Cargo.toml +++ b/codex-rs/exec-server/Cargo.toml @@ -16,9 +16,11 @@ workspace = true [dependencies] async-trait = { workspace = true } +base64 = { workspace = true } clap = { workspace = true, features = ["derive"] } codex-app-server-protocol = { workspace = true } codex-utils-absolute-path = { workspace = true } +codex-utils-pty = { workspace = true } futures = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index 4b4e69f24..a7680e73e 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -1,20 +1,65 @@ use std::sync::Arc; use std::time::Duration; +use codex_app_server_protocol::FsCopyParams; +use codex_app_server_protocol::FsCopyResponse; +use codex_app_server_protocol::FsCreateDirectoryParams; +use codex_app_server_protocol::FsCreateDirectoryResponse; +use codex_app_server_protocol::FsGetMetadataParams; +use codex_app_server_protocol::FsGetMetadataResponse; +use codex_app_server_protocol::FsReadDirectoryParams; +use codex_app_server_protocol::FsReadDirectoryResponse; +use codex_app_server_protocol::FsReadFileParams; +use codex_app_server_protocol::FsReadFileResponse; +use codex_app_server_protocol::FsRemoveParams; +use codex_app_server_protocol::FsRemoveResponse; +use codex_app_server_protocol::FsWriteFileParams; +use codex_app_server_protocol::FsWriteFileResponse; +use codex_app_server_protocol::JSONRPCNotification; +use serde_json::Value; +use tokio::sync::broadcast; +use tokio::sync::mpsc; use tokio::time::timeout; use tokio_tungstenite::connect_async; +use tracing::debug; use tracing::warn; use crate::client_api::ExecServerClientConnectOptions; +use crate::client_api::ExecServerEvent; use crate::client_api::RemoteExecServerConnectArgs; use crate::connection::JsonRpcConnection; +use crate::protocol::EXEC_EXITED_METHOD; +use crate::protocol::EXEC_METHOD; +use crate::protocol::EXEC_OUTPUT_DELTA_METHOD; +use crate::protocol::EXEC_READ_METHOD; +use crate::protocol::EXEC_TERMINATE_METHOD; +use crate::protocol::EXEC_WRITE_METHOD; +use crate::protocol::ExecExitedNotification; +use crate::protocol::ExecOutputDeltaNotification; +use crate::protocol::ExecParams; +use crate::protocol::ExecResponse; +use crate::protocol::FS_COPY_METHOD; +use crate::protocol::FS_CREATE_DIRECTORY_METHOD; +use crate::protocol::FS_GET_METADATA_METHOD; +use crate::protocol::FS_READ_DIRECTORY_METHOD; +use crate::protocol::FS_READ_FILE_METHOD; +use crate::protocol::FS_REMOVE_METHOD; +use crate::protocol::FS_WRITE_FILE_METHOD; use crate::protocol::INITIALIZE_METHOD; use crate::protocol::INITIALIZED_METHOD; use crate::protocol::InitializeParams; use crate::protocol::InitializeResponse; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; +use crate::protocol::TerminateParams; +use crate::protocol::TerminateResponse; +use crate::protocol::WriteParams; +use crate::protocol::WriteResponse; use crate::rpc::RpcCallError; use crate::rpc::RpcClient; use crate::rpc::RpcClientEvent; +use crate::rpc::RpcNotificationSender; +use crate::rpc::RpcServerOutboundMessage; mod local_backend; use local_backend::LocalBackend; @@ -74,6 +119,7 @@ impl ClientBackend { struct Inner { backend: ClientBackend, + events_tx: broadcast::Sender, reader_task: tokio::task::JoinHandle<()>, } @@ -124,11 +170,32 @@ impl ExecServerClient { pub async fn connect_in_process( options: ExecServerClientConnectOptions, ) -> Result { - let backend = LocalBackend::new(crate::server::ExecServerHandler::new()); - let inner = Arc::new(Inner { - backend: ClientBackend::InProcess(backend), - reader_task: tokio::spawn(async {}), + let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(256); + let backend = LocalBackend::new(crate::server::ExecServerHandler::new( + RpcNotificationSender::new(outgoing_tx), + )); + let inner = Arc::new_cyclic(|weak| { + let weak = weak.clone(); + let reader_task = tokio::spawn(async move { + while let Some(message) = outgoing_rx.recv().await { + if let Some(inner) = weak.upgrade() + && let Err(err) = handle_in_process_outbound_message(&inner, message).await + { + warn!( + "in-process exec-server client closing after unexpected response: {err}" + ); + return; + } + } + }); + + Inner { + backend: ClientBackend::InProcess(backend), + events_tx: broadcast::channel(256).0, + reader_task, + } }); + let client = Self { inner }; client.initialize(options).await?; Ok(client) @@ -160,6 +227,10 @@ impl ExecServerClient { .await } + pub fn event_receiver(&self) -> broadcast::Receiver { + self.inner.events_tx.subscribe() + } + pub async fn initialize( &self, options: ExecServerClientConnectOptions, @@ -190,36 +261,234 @@ impl ExecServerClient { })? } + pub async fn exec(&self, params: ExecParams) -> Result { + if let Some(backend) = self.inner.backend.as_local() { + return backend.exec(params).await; + } + let Some(remote) = self.inner.backend.as_remote() else { + return Err(ExecServerError::Protocol( + "remote backend missing during exec".to_string(), + )); + }; + remote.call(EXEC_METHOD, ¶ms).await.map_err(Into::into) + } + + pub async fn read(&self, params: ReadParams) -> Result { + if let Some(backend) = self.inner.backend.as_local() { + return backend.exec_read(params).await; + } + let Some(remote) = self.inner.backend.as_remote() else { + return Err(ExecServerError::Protocol( + "remote backend missing during read".to_string(), + )); + }; + remote + .call(EXEC_READ_METHOD, ¶ms) + .await + .map_err(Into::into) + } + + pub async fn write( + &self, + process_id: &str, + chunk: Vec, + ) -> Result { + let params = WriteParams { + process_id: process_id.to_string(), + chunk: chunk.into(), + }; + if let Some(backend) = self.inner.backend.as_local() { + return backend.exec_write(params).await; + } + let Some(remote) = self.inner.backend.as_remote() else { + return Err(ExecServerError::Protocol( + "remote backend missing during write".to_string(), + )); + }; + remote + .call(EXEC_WRITE_METHOD, ¶ms) + .await + .map_err(Into::into) + } + + pub async fn terminate(&self, process_id: &str) -> Result { + let params = TerminateParams { + process_id: process_id.to_string(), + }; + if let Some(backend) = self.inner.backend.as_local() { + return backend.terminate(params).await; + } + let Some(remote) = self.inner.backend.as_remote() else { + return Err(ExecServerError::Protocol( + "remote backend missing during terminate".to_string(), + )); + }; + remote + .call(EXEC_TERMINATE_METHOD, ¶ms) + .await + .map_err(Into::into) + } + + pub async fn fs_read_file( + &self, + params: FsReadFileParams, + ) -> Result { + if let Some(backend) = self.inner.backend.as_local() { + return backend.fs_read_file(params).await; + } + let Some(remote) = self.inner.backend.as_remote() else { + return Err(ExecServerError::Protocol( + "remote backend missing during fs/readFile".to_string(), + )); + }; + remote + .call(FS_READ_FILE_METHOD, ¶ms) + .await + .map_err(Into::into) + } + + pub async fn fs_write_file( + &self, + params: FsWriteFileParams, + ) -> Result { + if let Some(backend) = self.inner.backend.as_local() { + return backend.fs_write_file(params).await; + } + let Some(remote) = self.inner.backend.as_remote() else { + return Err(ExecServerError::Protocol( + "remote backend missing during fs/writeFile".to_string(), + )); + }; + remote + .call(FS_WRITE_FILE_METHOD, ¶ms) + .await + .map_err(Into::into) + } + + pub async fn fs_create_directory( + &self, + params: FsCreateDirectoryParams, + ) -> Result { + if let Some(backend) = self.inner.backend.as_local() { + return backend.fs_create_directory(params).await; + } + let Some(remote) = self.inner.backend.as_remote() else { + return Err(ExecServerError::Protocol( + "remote backend missing during fs/createDirectory".to_string(), + )); + }; + remote + .call(FS_CREATE_DIRECTORY_METHOD, ¶ms) + .await + .map_err(Into::into) + } + + pub async fn fs_get_metadata( + &self, + params: FsGetMetadataParams, + ) -> Result { + if let Some(backend) = self.inner.backend.as_local() { + return backend.fs_get_metadata(params).await; + } + let Some(remote) = self.inner.backend.as_remote() else { + return Err(ExecServerError::Protocol( + "remote backend missing during fs/getMetadata".to_string(), + )); + }; + remote + .call(FS_GET_METADATA_METHOD, ¶ms) + .await + .map_err(Into::into) + } + + pub async fn fs_read_directory( + &self, + params: FsReadDirectoryParams, + ) -> Result { + if let Some(backend) = self.inner.backend.as_local() { + return backend.fs_read_directory(params).await; + } + let Some(remote) = self.inner.backend.as_remote() else { + return Err(ExecServerError::Protocol( + "remote backend missing during fs/readDirectory".to_string(), + )); + }; + remote + .call(FS_READ_DIRECTORY_METHOD, ¶ms) + .await + .map_err(Into::into) + } + + pub async fn fs_remove( + &self, + params: FsRemoveParams, + ) -> Result { + if let Some(backend) = self.inner.backend.as_local() { + return backend.fs_remove(params).await; + } + let Some(remote) = self.inner.backend.as_remote() else { + return Err(ExecServerError::Protocol( + "remote backend missing during fs/remove".to_string(), + )); + }; + remote + .call(FS_REMOVE_METHOD, ¶ms) + .await + .map_err(Into::into) + } + + pub async fn fs_copy(&self, params: FsCopyParams) -> Result { + if let Some(backend) = self.inner.backend.as_local() { + return backend.fs_copy(params).await; + } + let Some(remote) = self.inner.backend.as_remote() else { + return Err(ExecServerError::Protocol( + "remote backend missing during fs/copy".to_string(), + )); + }; + remote + .call(FS_COPY_METHOD, ¶ms) + .await + .map_err(Into::into) + } + async fn connect( connection: JsonRpcConnection, options: ExecServerClientConnectOptions, ) -> Result { let (rpc_client, mut events_rx) = RpcClient::new(connection); - let reader_task = tokio::spawn(async move { - while let Some(event) = events_rx.recv().await { - match event { - RpcClientEvent::Notification(notification) => { - warn!( - "ignoring unexpected exec-server notification during stub phase: {}", - notification.method - ); - } - RpcClientEvent::Disconnected { reason } => { - if let Some(reason) = reason { - warn!("exec-server client transport disconnected: {reason}"); + let inner = Arc::new_cyclic(|weak| { + let weak = weak.clone(); + let reader_task = tokio::spawn(async move { + while let Some(event) = events_rx.recv().await { + match event { + RpcClientEvent::Notification(notification) => { + if let Some(inner) = weak.upgrade() + && let Err(err) = + handle_server_notification(&inner, notification).await + { + warn!("exec-server client closing after protocol error: {err}"); + return; + } + } + RpcClientEvent::Disconnected { reason } => { + if let Some(reason) = reason { + warn!("exec-server client transport disconnected: {reason}"); + } + return; } - return; } } + }); + + Inner { + backend: ClientBackend::Remote(rpc_client), + events_tx: broadcast::channel(256).0, + reader_task, } }); - let client = Self { - inner: Arc::new(Inner { - backend: ClientBackend::Remote(rpc_client), - reader_task, - }), - }; + let client = Self { inner }; client.initialize(options).await?; Ok(client) } @@ -247,3 +516,39 @@ impl From for ExecServerError { } } } + +async fn handle_in_process_outbound_message( + inner: &Arc, + message: RpcServerOutboundMessage, +) -> Result<(), ExecServerError> { + match message { + RpcServerOutboundMessage::Response { .. } | RpcServerOutboundMessage::Error { .. } => Err( + ExecServerError::Protocol("unexpected in-process RPC response".to_string()), + ), + RpcServerOutboundMessage::Notification(notification) => { + handle_server_notification(inner, notification).await + } + } +} + +async fn handle_server_notification( + inner: &Arc, + notification: JSONRPCNotification, +) -> Result<(), ExecServerError> { + match notification.method.as_str() { + EXEC_OUTPUT_DELTA_METHOD => { + let params: ExecOutputDeltaNotification = + serde_json::from_value(notification.params.unwrap_or(Value::Null))?; + let _ = inner.events_tx.send(ExecServerEvent::OutputDelta(params)); + } + EXEC_EXITED_METHOD => { + let params: ExecExitedNotification = + serde_json::from_value(notification.params.unwrap_or(Value::Null))?; + let _ = inner.events_tx.send(ExecServerEvent::Exited(params)); + } + other => { + debug!("ignoring unknown exec-server notification: {other}"); + } + } + Ok(()) +} diff --git a/codex-rs/exec-server/src/client/local_backend.rs b/codex-rs/exec-server/src/client/local_backend.rs index 8f9a2481f..e23a5361d 100644 --- a/codex-rs/exec-server/src/client/local_backend.rs +++ b/codex-rs/exec-server/src/client/local_backend.rs @@ -1,7 +1,29 @@ use std::sync::Arc; +use crate::protocol::ExecParams; +use crate::protocol::ExecResponse; use crate::protocol::InitializeResponse; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; +use crate::protocol::TerminateParams; +use crate::protocol::TerminateResponse; +use crate::protocol::WriteParams; +use crate::protocol::WriteResponse; use crate::server::ExecServerHandler; +use codex_app_server_protocol::FsCopyParams; +use codex_app_server_protocol::FsCopyResponse; +use codex_app_server_protocol::FsCreateDirectoryParams; +use codex_app_server_protocol::FsCreateDirectoryResponse; +use codex_app_server_protocol::FsGetMetadataParams; +use codex_app_server_protocol::FsGetMetadataResponse; +use codex_app_server_protocol::FsReadDirectoryParams; +use codex_app_server_protocol::FsReadDirectoryResponse; +use codex_app_server_protocol::FsReadFileParams; +use codex_app_server_protocol::FsReadFileResponse; +use codex_app_server_protocol::FsRemoveParams; +use codex_app_server_protocol::FsRemoveResponse; +use codex_app_server_protocol::FsWriteFileParams; +use codex_app_server_protocol::FsWriteFileResponse; use super::ExecServerError; @@ -35,4 +57,144 @@ impl LocalBackend { .initialized() .map_err(ExecServerError::Protocol) } + + pub(super) async fn exec(&self, params: ExecParams) -> Result { + self.handler + .exec(params) + .await + .map_err(|error| ExecServerError::Server { + code: error.code, + message: error.message, + }) + } + + pub(super) async fn exec_read( + &self, + params: ReadParams, + ) -> Result { + self.handler + .exec_read(params) + .await + .map_err(|error| ExecServerError::Server { + code: error.code, + message: error.message, + }) + } + + pub(super) async fn exec_write( + &self, + params: WriteParams, + ) -> Result { + self.handler + .exec_write(params) + .await + .map_err(|error| ExecServerError::Server { + code: error.code, + message: error.message, + }) + } + + pub(super) async fn terminate( + &self, + params: TerminateParams, + ) -> Result { + self.handler + .terminate(params) + .await + .map_err(|error| ExecServerError::Server { + code: error.code, + message: error.message, + }) + } + + pub(super) async fn fs_read_file( + &self, + params: FsReadFileParams, + ) -> Result { + self.handler + .fs_read_file(params) + .await + .map_err(|error| ExecServerError::Server { + code: error.code, + message: error.message, + }) + } + + pub(super) async fn fs_write_file( + &self, + params: FsWriteFileParams, + ) -> Result { + self.handler + .fs_write_file(params) + .await + .map_err(|error| ExecServerError::Server { + code: error.code, + message: error.message, + }) + } + + pub(super) async fn fs_create_directory( + &self, + params: FsCreateDirectoryParams, + ) -> Result { + self.handler + .fs_create_directory(params) + .await + .map_err(|error| ExecServerError::Server { + code: error.code, + message: error.message, + }) + } + + pub(super) async fn fs_get_metadata( + &self, + params: FsGetMetadataParams, + ) -> Result { + self.handler + .fs_get_metadata(params) + .await + .map_err(|error| ExecServerError::Server { + code: error.code, + message: error.message, + }) + } + + pub(super) async fn fs_read_directory( + &self, + params: FsReadDirectoryParams, + ) -> Result { + self.handler + .fs_read_directory(params) + .await + .map_err(|error| ExecServerError::Server { + code: error.code, + message: error.message, + }) + } + + pub(super) async fn fs_remove( + &self, + params: FsRemoveParams, + ) -> Result { + self.handler + .fs_remove(params) + .await + .map_err(|error| ExecServerError::Server { + code: error.code, + message: error.message, + }) + } + + pub(super) async fn fs_copy( + &self, + params: FsCopyParams, + ) -> Result { + self.handler + .fs_copy(params) + .await + .map_err(|error| ExecServerError::Server { + code: error.code, + message: error.message, + }) + } } diff --git a/codex-rs/exec-server/src/client_api.rs b/codex-rs/exec-server/src/client_api.rs index 6e8976341..962d3ba36 100644 --- a/codex-rs/exec-server/src/client_api.rs +++ b/codex-rs/exec-server/src/client_api.rs @@ -1,5 +1,8 @@ use std::time::Duration; +use crate::protocol::ExecExitedNotification; +use crate::protocol::ExecOutputDeltaNotification; + /// Connection options for any exec-server client transport. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ExecServerClientConnectOptions { @@ -15,3 +18,10 @@ pub struct RemoteExecServerConnectArgs { pub connect_timeout: Duration, pub initialize_timeout: Duration, } + +/// Connection-level server events. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ExecServerEvent { + OutputDelta(ExecOutputDeltaNotification), + Exited(ExecExitedNotification), +} diff --git a/codex-rs/exec-server/src/lib.rs b/codex-rs/exec-server/src/lib.rs index fdd22e163..3c50d0ec5 100644 --- a/codex-rs/exec-server/src/lib.rs +++ b/codex-rs/exec-server/src/lib.rs @@ -10,7 +10,23 @@ mod server; pub use client::ExecServerClient; pub use client::ExecServerError; pub use client_api::ExecServerClientConnectOptions; +pub use client_api::ExecServerEvent; pub use client_api::RemoteExecServerConnectArgs; +pub use codex_app_server_protocol::FsCopyParams; +pub use codex_app_server_protocol::FsCopyResponse; +pub use codex_app_server_protocol::FsCreateDirectoryParams; +pub use codex_app_server_protocol::FsCreateDirectoryResponse; +pub use codex_app_server_protocol::FsGetMetadataParams; +pub use codex_app_server_protocol::FsGetMetadataResponse; +pub use codex_app_server_protocol::FsReadDirectoryEntry; +pub use codex_app_server_protocol::FsReadDirectoryParams; +pub use codex_app_server_protocol::FsReadDirectoryResponse; +pub use codex_app_server_protocol::FsReadFileParams; +pub use codex_app_server_protocol::FsReadFileResponse; +pub use codex_app_server_protocol::FsRemoveParams; +pub use codex_app_server_protocol::FsRemoveResponse; +pub use codex_app_server_protocol::FsWriteFileParams; +pub use codex_app_server_protocol::FsWriteFileResponse; pub use environment::Environment; pub use fs::CopyOptions; pub use fs::CreateDirectoryOptions; @@ -19,8 +35,19 @@ pub use fs::FileMetadata; pub use fs::FileSystemResult; pub use fs::ReadDirectoryEntry; pub use fs::RemoveOptions; +pub use protocol::ExecExitedNotification; +pub use protocol::ExecOutputDeltaNotification; +pub use protocol::ExecOutputStream; +pub use protocol::ExecParams; +pub use protocol::ExecResponse; pub use protocol::InitializeParams; pub use protocol::InitializeResponse; +pub use protocol::ReadParams; +pub use protocol::ReadResponse; +pub use protocol::TerminateParams; +pub use protocol::TerminateResponse; +pub use protocol::WriteParams; +pub use protocol::WriteResponse; pub use server::DEFAULT_LISTEN_URL; pub use server::ExecServerListenUrlParseError; pub use server::run_main; diff --git a/codex-rs/exec-server/src/protocol.rs b/codex-rs/exec-server/src/protocol.rs index 165378fb5..4429b4ca7 100644 --- a/codex-rs/exec-server/src/protocol.rs +++ b/codex-rs/exec-server/src/protocol.rs @@ -1,8 +1,41 @@ +use std::collections::HashMap; +use std::path::PathBuf; + +use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; use serde::Deserialize; use serde::Serialize; pub const INITIALIZE_METHOD: &str = "initialize"; pub const INITIALIZED_METHOD: &str = "initialized"; +pub const EXEC_METHOD: &str = "process/start"; +pub const EXEC_READ_METHOD: &str = "process/read"; +pub const EXEC_WRITE_METHOD: &str = "process/write"; +pub const EXEC_TERMINATE_METHOD: &str = "process/terminate"; +pub const EXEC_OUTPUT_DELTA_METHOD: &str = "process/output"; +pub const EXEC_EXITED_METHOD: &str = "process/exited"; +pub const FS_READ_FILE_METHOD: &str = "fs/readFile"; +pub const FS_WRITE_FILE_METHOD: &str = "fs/writeFile"; +pub const FS_CREATE_DIRECTORY_METHOD: &str = "fs/createDirectory"; +pub const FS_GET_METADATA_METHOD: &str = "fs/getMetadata"; +pub const FS_READ_DIRECTORY_METHOD: &str = "fs/readDirectory"; +pub const FS_REMOVE_METHOD: &str = "fs/remove"; +pub const FS_COPY_METHOD: &str = "fs/copy"; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct ByteChunk(#[serde(with = "base64_bytes")] pub Vec); + +impl ByteChunk { + pub fn into_inner(self) -> Vec { + self.0 + } +} + +impl From> for ByteChunk { + fn from(value: Vec) -> Self { + Self(value) + } +} #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -13,3 +46,121 @@ pub struct InitializeParams { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct InitializeResponse {} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExecParams { + /// Client-chosen logical process handle scoped to this connection/session. + /// This is a protocol key, not an OS pid. + pub process_id: String, + pub argv: Vec, + pub cwd: PathBuf, + pub env: HashMap, + pub tty: bool, + pub arg0: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExecResponse { + pub process_id: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ReadParams { + pub process_id: String, + pub after_seq: Option, + pub max_bytes: Option, + pub wait_ms: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ProcessOutputChunk { + pub seq: u64, + pub stream: ExecOutputStream, + pub chunk: ByteChunk, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ReadResponse { + pub chunks: Vec, + pub next_seq: u64, + pub exited: bool, + pub exit_code: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct WriteParams { + pub process_id: String, + pub chunk: ByteChunk, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct WriteResponse { + pub accepted: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TerminateParams { + pub process_id: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TerminateResponse { + pub running: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum ExecOutputStream { + Stdout, + Stderr, + Pty, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExecOutputDeltaNotification { + pub process_id: String, + pub stream: ExecOutputStream, + pub chunk: ByteChunk, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExecExitedNotification { + pub process_id: String, + pub exit_code: i32, +} + +mod base64_bytes { + use super::BASE64_STANDARD; + use base64::Engine as _; + use serde::Deserialize; + use serde::Deserializer; + use serde::Serializer; + + pub fn serialize(bytes: &[u8], serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&BASE64_STANDARD.encode(bytes)) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let encoded = String::deserialize(deserializer)?; + BASE64_STANDARD + .decode(encoded) + .map_err(serde::de::Error::custom) + } +} diff --git a/codex-rs/exec-server/src/rpc.rs b/codex-rs/exec-server/src/rpc.rs index 0c8b5cdf3..8d79883c5 100644 --- a/codex-rs/exec-server/src/rpc.rs +++ b/codex-rs/exec-server/src/rpc.rs @@ -1,4 +1,6 @@ use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; @@ -23,6 +25,11 @@ use crate::connection::JsonRpcConnection; use crate::connection::JsonRpcConnectionEvent; type PendingRequest = oneshot::Sender>; +type BoxFuture = Pin + Send + 'static>>; +type RequestRoute = + Box, JSONRPCRequest) -> BoxFuture + Send + Sync>; +type NotificationRoute = + Box, JSONRPCNotification) -> BoxFuture> + Send + Sync>; #[derive(Debug)] pub(crate) enum RpcClientEvent { @@ -30,6 +37,139 @@ pub(crate) enum RpcClientEvent { Disconnected { reason: Option }, } +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum RpcServerOutboundMessage { + Response { + request_id: RequestId, + result: Value, + }, + Error { + request_id: RequestId, + error: JSONRPCErrorError, + }, + #[allow(dead_code)] + Notification(JSONRPCNotification), +} + +#[allow(dead_code)] +#[derive(Clone)] +pub(crate) struct RpcNotificationSender { + outgoing_tx: mpsc::Sender, +} + +impl RpcNotificationSender { + pub(crate) fn new(outgoing_tx: mpsc::Sender) -> Self { + Self { outgoing_tx } + } + + #[allow(dead_code)] + pub(crate) async fn notify( + &self, + method: &str, + params: &P, + ) -> Result<(), JSONRPCErrorError> { + let params = serde_json::to_value(params).map_err(|err| internal_error(err.to_string()))?; + self.outgoing_tx + .send(RpcServerOutboundMessage::Notification( + JSONRPCNotification { + method: method.to_string(), + params: Some(params), + }, + )) + .await + .map_err(|_| internal_error("RPC connection closed while sending notification".into())) + } +} + +pub(crate) struct RpcRouter { + request_routes: HashMap<&'static str, RequestRoute>, + notification_routes: HashMap<&'static str, NotificationRoute>, +} + +impl Default for RpcRouter { + fn default() -> Self { + Self { + request_routes: HashMap::new(), + notification_routes: HashMap::new(), + } + } +} + +impl RpcRouter +where + S: Send + Sync + 'static, +{ + pub(crate) fn new() -> Self { + Self::default() + } + + pub(crate) fn request(&mut self, method: &'static str, handler: F) + where + P: DeserializeOwned + Send + 'static, + R: Serialize + Send + 'static, + F: Fn(Arc, P) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + self.request_routes.insert( + method, + Box::new(move |state, request| { + let request_id = request.id; + let params = request.params; + let response = + decode_request_params::

(params).map(|params| handler(state, params)); + Box::pin(async move { + let response = match response { + Ok(response) => response.await, + Err(error) => { + return RpcServerOutboundMessage::Error { request_id, error }; + } + }; + match response { + Ok(result) => match serde_json::to_value(result) { + Ok(result) => RpcServerOutboundMessage::Response { request_id, result }, + Err(err) => RpcServerOutboundMessage::Error { + request_id, + error: internal_error(err.to_string()), + }, + }, + Err(error) => RpcServerOutboundMessage::Error { request_id, error }, + } + }) + }), + ); + } + + pub(crate) fn notification(&mut self, method: &'static str, handler: F) + where + P: DeserializeOwned + Send + 'static, + F: Fn(Arc, P) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + self.notification_routes.insert( + method, + Box::new(move |state, notification| { + let params = decode_notification_params::

(notification.params) + .map(|params| handler(state, params)); + Box::pin(async move { + let handler = match params { + Ok(handler) => handler, + Err(err) => return Err(err), + }; + handler.await + }) + }), + ); + } + + pub(crate) fn request_route(&self, method: &str) -> Option<&RequestRoute> { + self.request_routes.get(method) + } + + pub(crate) fn notification_route(&self, method: &str) -> Option<&NotificationRoute> { + self.notification_routes.get(method) + } +} + pub(crate) struct RpcClient { write_tx: mpsc::Sender, pending: Arc>>, @@ -57,14 +197,8 @@ impl RpcClient { } } JsonRpcConnectionEvent::MalformedMessage { reason } => { - warn!("JSON-RPC client closing after malformed server message: {reason}"); - let _ = event_tx - .send(RpcClientEvent::Disconnected { - reason: Some(reason), - }) - .await; - drain_pending(&pending_for_reader).await; - return; + warn!("JSON-RPC client closing after malformed message: {reason}"); + break; } JsonRpcConnectionEvent::Disconnected { reason } => { let _ = event_tx.send(RpcClientEvent::Disconnected { reason }).await; @@ -177,6 +311,91 @@ pub(crate) enum RpcCallError { Server(JSONRPCErrorError), } +pub(crate) fn encode_server_message( + message: RpcServerOutboundMessage, +) -> Result { + match message { + RpcServerOutboundMessage::Response { request_id, result } => { + Ok(JSONRPCMessage::Response(JSONRPCResponse { + id: request_id, + result, + })) + } + RpcServerOutboundMessage::Error { request_id, error } => { + Ok(JSONRPCMessage::Error(JSONRPCError { + id: request_id, + error, + })) + } + RpcServerOutboundMessage::Notification(notification) => { + Ok(JSONRPCMessage::Notification(notification)) + } + } +} + +pub(crate) fn invalid_request(message: String) -> JSONRPCErrorError { + JSONRPCErrorError { + code: -32600, + data: None, + message, + } +} + +pub(crate) fn method_not_found(message: String) -> JSONRPCErrorError { + JSONRPCErrorError { + code: -32601, + data: None, + message, + } +} + +pub(crate) fn invalid_params(message: String) -> JSONRPCErrorError { + JSONRPCErrorError { + code: -32602, + data: None, + message, + } +} + +pub(crate) fn internal_error(message: String) -> JSONRPCErrorError { + JSONRPCErrorError { + code: -32603, + data: None, + message, + } +} + +fn decode_request_params

(params: Option) -> Result +where + P: DeserializeOwned, +{ + decode_params(params).map_err(|err| invalid_params(err.to_string())) +} + +fn decode_notification_params

(params: Option) -> Result +where + P: DeserializeOwned, +{ + decode_params(params).map_err(|err| err.to_string()) +} + +fn decode_params

(params: Option) -> Result +where + P: DeserializeOwned, +{ + let params = params.unwrap_or(Value::Null); + match serde_json::from_value(params.clone()) { + Ok(params) => Ok(params), + Err(err) => { + if matches!(params, Value::Object(ref map) if map.is_empty()) { + serde_json::from_value(Value::Null).map_err(|_| err) + } else { + Err(err) + } + } + } +} + async fn handle_server_message( pending: &Mutex>, event_tx: &mpsc::Sender, diff --git a/codex-rs/exec-server/src/server.rs b/codex-rs/exec-server/src/server.rs index af1e929cf..c403b029d 100644 --- a/codex-rs/exec-server/src/server.rs +++ b/codex-rs/exec-server/src/server.rs @@ -1,6 +1,7 @@ +mod filesystem; mod handler; -mod jsonrpc; mod processor; +mod registry; mod transport; pub(crate) use handler::ExecServerHandler; diff --git a/codex-rs/exec-server/src/server/filesystem.rs b/codex-rs/exec-server/src/server/filesystem.rs new file mode 100644 index 000000000..bc3d22a4d --- /dev/null +++ b/codex-rs/exec-server/src/server/filesystem.rs @@ -0,0 +1,170 @@ +use std::io; +use std::sync::Arc; + +use base64::Engine as _; +use base64::engine::general_purpose::STANDARD; +use codex_app_server_protocol::FsCopyParams; +use codex_app_server_protocol::FsCopyResponse; +use codex_app_server_protocol::FsCreateDirectoryParams; +use codex_app_server_protocol::FsCreateDirectoryResponse; +use codex_app_server_protocol::FsGetMetadataParams; +use codex_app_server_protocol::FsGetMetadataResponse; +use codex_app_server_protocol::FsReadDirectoryEntry; +use codex_app_server_protocol::FsReadDirectoryParams; +use codex_app_server_protocol::FsReadDirectoryResponse; +use codex_app_server_protocol::FsReadFileParams; +use codex_app_server_protocol::FsReadFileResponse; +use codex_app_server_protocol::FsRemoveParams; +use codex_app_server_protocol::FsRemoveResponse; +use codex_app_server_protocol::FsWriteFileParams; +use codex_app_server_protocol::FsWriteFileResponse; +use codex_app_server_protocol::JSONRPCErrorError; + +use crate::CopyOptions; +use crate::CreateDirectoryOptions; +use crate::Environment; +use crate::ExecutorFileSystem; +use crate::RemoveOptions; +use crate::rpc::internal_error; +use crate::rpc::invalid_request; + +#[derive(Clone)] +pub(crate) struct ExecServerFileSystem { + file_system: Arc, +} + +impl Default for ExecServerFileSystem { + fn default() -> Self { + Self { + file_system: Arc::new(Environment.get_filesystem()), + } + } +} + +impl ExecServerFileSystem { + pub(crate) async fn read_file( + &self, + params: FsReadFileParams, + ) -> Result { + let bytes = self + .file_system + .read_file(¶ms.path) + .await + .map_err(map_fs_error)?; + Ok(FsReadFileResponse { + data_base64: STANDARD.encode(bytes), + }) + } + + pub(crate) async fn write_file( + &self, + params: FsWriteFileParams, + ) -> Result { + let bytes = STANDARD.decode(params.data_base64).map_err(|err| { + invalid_request(format!( + "fs/writeFile requires valid base64 dataBase64: {err}" + )) + })?; + self.file_system + .write_file(¶ms.path, bytes) + .await + .map_err(map_fs_error)?; + Ok(FsWriteFileResponse {}) + } + + pub(crate) async fn create_directory( + &self, + params: FsCreateDirectoryParams, + ) -> Result { + self.file_system + .create_directory( + ¶ms.path, + CreateDirectoryOptions { + recursive: params.recursive.unwrap_or(true), + }, + ) + .await + .map_err(map_fs_error)?; + Ok(FsCreateDirectoryResponse {}) + } + + pub(crate) async fn get_metadata( + &self, + params: FsGetMetadataParams, + ) -> Result { + let metadata = self + .file_system + .get_metadata(¶ms.path) + .await + .map_err(map_fs_error)?; + Ok(FsGetMetadataResponse { + is_directory: metadata.is_directory, + is_file: metadata.is_file, + created_at_ms: metadata.created_at_ms, + modified_at_ms: metadata.modified_at_ms, + }) + } + + pub(crate) async fn read_directory( + &self, + params: FsReadDirectoryParams, + ) -> Result { + let entries = self + .file_system + .read_directory(¶ms.path) + .await + .map_err(map_fs_error)?; + Ok(FsReadDirectoryResponse { + entries: entries + .into_iter() + .map(|entry| FsReadDirectoryEntry { + file_name: entry.file_name, + is_directory: entry.is_directory, + is_file: entry.is_file, + }) + .collect(), + }) + } + + pub(crate) async fn remove( + &self, + params: FsRemoveParams, + ) -> Result { + self.file_system + .remove( + ¶ms.path, + RemoveOptions { + recursive: params.recursive.unwrap_or(true), + force: params.force.unwrap_or(true), + }, + ) + .await + .map_err(map_fs_error)?; + Ok(FsRemoveResponse {}) + } + + pub(crate) async fn copy( + &self, + params: FsCopyParams, + ) -> Result { + self.file_system + .copy( + ¶ms.source_path, + ¶ms.destination_path, + CopyOptions { + recursive: params.recursive, + }, + ) + .await + .map_err(map_fs_error)?; + Ok(FsCopyResponse {}) + } +} + +fn map_fs_error(err: io::Error) -> JSONRPCErrorError { + if err.kind() == io::ErrorKind::InvalidInput { + invalid_request(err.to_string()) + } else { + internal_error(err.to_string()) + } +} diff --git a/codex-rs/exec-server/src/server/handler.rs b/codex-rs/exec-server/src/server/handler.rs index 838e58240..c21aeecb5 100644 --- a/codex-rs/exec-server/src/server/handler.rs +++ b/codex-rs/exec-server/src/server/handler.rs @@ -1,25 +1,112 @@ +use std::collections::HashMap; +use std::collections::VecDeque; +use std::sync::Arc; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; +use std::time::Duration; +use codex_app_server_protocol::FsCopyParams; +use codex_app_server_protocol::FsCopyResponse; +use codex_app_server_protocol::FsCreateDirectoryParams; +use codex_app_server_protocol::FsCreateDirectoryResponse; +use codex_app_server_protocol::FsGetMetadataParams; +use codex_app_server_protocol::FsGetMetadataResponse; +use codex_app_server_protocol::FsReadDirectoryParams; +use codex_app_server_protocol::FsReadDirectoryResponse; +use codex_app_server_protocol::FsReadFileParams; +use codex_app_server_protocol::FsReadFileResponse; +use codex_app_server_protocol::FsRemoveParams; +use codex_app_server_protocol::FsRemoveResponse; +use codex_app_server_protocol::FsWriteFileParams; +use codex_app_server_protocol::FsWriteFileResponse; use codex_app_server_protocol::JSONRPCErrorError; +use codex_utils_pty::ExecCommandSession; +use codex_utils_pty::TerminalSize; +use tokio::sync::Mutex; +use tokio::sync::Notify; +use tracing::warn; +use crate::protocol::ExecExitedNotification; +use crate::protocol::ExecOutputDeltaNotification; +use crate::protocol::ExecOutputStream; +use crate::protocol::ExecParams; +use crate::protocol::ExecResponse; use crate::protocol::InitializeResponse; -use crate::server::jsonrpc::invalid_request; +use crate::protocol::ProcessOutputChunk; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; +use crate::protocol::TerminateParams; +use crate::protocol::TerminateResponse; +use crate::protocol::WriteParams; +use crate::protocol::WriteResponse; +use crate::rpc::RpcNotificationSender; +use crate::rpc::internal_error; +use crate::rpc::invalid_params; +use crate::rpc::invalid_request; +use crate::server::filesystem::ExecServerFileSystem; + +const RETAINED_OUTPUT_BYTES_PER_PROCESS: usize = 1024 * 1024; +#[cfg(test)] +const EXITED_PROCESS_RETENTION: Duration = Duration::from_millis(25); +#[cfg(not(test))] +const EXITED_PROCESS_RETENTION: Duration = Duration::from_secs(30); + +#[derive(Clone)] +struct RetainedOutputChunk { + seq: u64, + stream: ExecOutputStream, + chunk: Vec, +} + +struct RunningProcess { + session: ExecCommandSession, + tty: bool, + output: VecDeque, + retained_bytes: usize, + next_seq: u64, + exit_code: Option, + output_notify: Arc, +} + +enum ProcessEntry { + Starting, + Running(Box), +} pub(crate) struct ExecServerHandler { + notifications: RpcNotificationSender, + file_system: ExecServerFileSystem, + processes: Arc>>, initialize_requested: AtomicBool, initialized: AtomicBool, } impl ExecServerHandler { - pub(crate) fn new() -> Self { + pub(crate) fn new(notifications: RpcNotificationSender) -> Self { Self { + notifications, + file_system: ExecServerFileSystem::default(), + processes: Arc::new(Mutex::new(HashMap::new())), initialize_requested: AtomicBool::new(false), initialized: AtomicBool::new(false), } } - pub(crate) async fn shutdown(&self) {} + pub(crate) async fn shutdown(&self) { + let remaining = { + let mut processes = self.processes.lock().await; + processes + .drain() + .filter_map(|(_, process)| match process { + ProcessEntry::Starting => None, + ProcessEntry::Running(process) => Some(process), + }) + .collect::>() + }; + for process in remaining { + process.session.terminate(); + } + } pub(crate) fn initialize(&self) -> Result { if self.initialize_requested.swap(true, Ordering::SeqCst) { @@ -37,4 +124,394 @@ impl ExecServerHandler { self.initialized.store(true, Ordering::SeqCst); Ok(()) } + + fn require_initialized_for(&self, method_family: &str) -> Result<(), JSONRPCErrorError> { + if !self.initialize_requested.load(Ordering::SeqCst) { + return Err(invalid_request(format!( + "client must call initialize before using {method_family} methods" + ))); + } + if !self.initialized.load(Ordering::SeqCst) { + return Err(invalid_request(format!( + "client must send initialized before using {method_family} methods" + ))); + } + Ok(()) + } + + pub(crate) async fn exec(&self, params: ExecParams) -> Result { + self.require_initialized_for("exec")?; + let process_id = params.process_id.clone(); + + let (program, args) = params + .argv + .split_first() + .ok_or_else(|| invalid_params("argv must not be empty".to_string()))?; + + { + let mut process_map = self.processes.lock().await; + if process_map.contains_key(&process_id) { + return Err(invalid_request(format!( + "process {process_id} already exists" + ))); + } + process_map.insert(process_id.clone(), ProcessEntry::Starting); + } + + let spawned_result = if params.tty { + codex_utils_pty::spawn_pty_process( + program, + args, + params.cwd.as_path(), + ¶ms.env, + ¶ms.arg0, + TerminalSize::default(), + ) + .await + } else { + codex_utils_pty::spawn_pipe_process_no_stdin( + program, + args, + params.cwd.as_path(), + ¶ms.env, + ¶ms.arg0, + ) + .await + }; + let spawned = match spawned_result { + Ok(spawned) => spawned, + Err(err) => { + let mut process_map = self.processes.lock().await; + if matches!(process_map.get(&process_id), Some(ProcessEntry::Starting)) { + process_map.remove(&process_id); + } + return Err(internal_error(err.to_string())); + } + }; + + let output_notify = Arc::new(Notify::new()); + { + let mut process_map = self.processes.lock().await; + process_map.insert( + process_id.clone(), + ProcessEntry::Running(Box::new(RunningProcess { + session: spawned.session, + tty: params.tty, + output: VecDeque::new(), + retained_bytes: 0, + next_seq: 1, + exit_code: None, + output_notify: Arc::clone(&output_notify), + })), + ); + } + + tokio::spawn(stream_output( + process_id.clone(), + if params.tty { + ExecOutputStream::Pty + } else { + ExecOutputStream::Stdout + }, + spawned.stdout_rx, + self.notifications.clone(), + Arc::clone(&self.processes), + Arc::clone(&output_notify), + )); + tokio::spawn(stream_output( + process_id.clone(), + if params.tty { + ExecOutputStream::Pty + } else { + ExecOutputStream::Stderr + }, + spawned.stderr_rx, + self.notifications.clone(), + Arc::clone(&self.processes), + Arc::clone(&output_notify), + )); + tokio::spawn(watch_exit( + process_id.clone(), + spawned.exit_rx, + self.notifications.clone(), + Arc::clone(&self.processes), + output_notify, + )); + + Ok(ExecResponse { process_id }) + } + + pub(crate) async fn exec_read( + &self, + params: ReadParams, + ) -> Result { + self.require_initialized_for("exec")?; + let after_seq = params.after_seq.unwrap_or(0); + let max_bytes = params.max_bytes.unwrap_or(usize::MAX); + let wait = Duration::from_millis(params.wait_ms.unwrap_or(0)); + let deadline = tokio::time::Instant::now() + wait; + + loop { + let (response, output_notify) = { + let process_map = self.processes.lock().await; + let process = process_map.get(¶ms.process_id).ok_or_else(|| { + invalid_request(format!("unknown process id {}", params.process_id)) + })?; + let ProcessEntry::Running(process) = process else { + return Err(invalid_request(format!( + "process id {} is starting", + params.process_id + ))); + }; + + let mut chunks = Vec::new(); + let mut total_bytes = 0; + let mut next_seq = process.next_seq; + for retained in process.output.iter().filter(|chunk| chunk.seq > after_seq) { + let chunk_len = retained.chunk.len(); + if !chunks.is_empty() && total_bytes + chunk_len > max_bytes { + break; + } + total_bytes += chunk_len; + chunks.push(ProcessOutputChunk { + seq: retained.seq, + stream: retained.stream, + chunk: retained.chunk.clone().into(), + }); + next_seq = retained.seq + 1; + if total_bytes >= max_bytes { + break; + } + } + + ( + ReadResponse { + chunks, + next_seq, + exited: process.exit_code.is_some(), + exit_code: process.exit_code, + }, + Arc::clone(&process.output_notify), + ) + }; + + if !response.chunks.is_empty() + || response.exited + || tokio::time::Instant::now() >= deadline + { + return Ok(response); + } + + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining.is_zero() { + return Ok(response); + } + let _ = tokio::time::timeout(remaining, output_notify.notified()).await; + } + } + + pub(crate) async fn exec_write( + &self, + params: WriteParams, + ) -> Result { + self.require_initialized_for("exec")?; + let writer_tx = { + let process_map = self.processes.lock().await; + let process = process_map.get(¶ms.process_id).ok_or_else(|| { + invalid_request(format!("unknown process id {}", params.process_id)) + })?; + let ProcessEntry::Running(process) = process else { + return Err(invalid_request(format!( + "process id {} is starting", + params.process_id + ))); + }; + if !process.tty { + return Err(invalid_request(format!( + "stdin is closed for process {}", + params.process_id + ))); + } + process.session.writer_sender() + }; + + writer_tx + .send(params.chunk.into_inner()) + .await + .map_err(|_| internal_error("failed to write to process stdin".to_string()))?; + + Ok(WriteResponse { accepted: true }) + } + + pub(crate) async fn terminate( + &self, + params: TerminateParams, + ) -> Result { + self.require_initialized_for("exec")?; + let running = { + let process_map = self.processes.lock().await; + match process_map.get(¶ms.process_id) { + Some(ProcessEntry::Running(process)) => { + if process.exit_code.is_some() { + return Ok(TerminateResponse { running: false }); + } + process.session.terminate(); + true + } + Some(ProcessEntry::Starting) | None => false, + } + }; + + Ok(TerminateResponse { running }) + } + + pub(crate) async fn fs_read_file( + &self, + params: FsReadFileParams, + ) -> Result { + self.require_initialized_for("filesystem")?; + self.file_system.read_file(params).await + } + + pub(crate) async fn fs_write_file( + &self, + params: FsWriteFileParams, + ) -> Result { + self.require_initialized_for("filesystem")?; + self.file_system.write_file(params).await + } + + pub(crate) async fn fs_create_directory( + &self, + params: FsCreateDirectoryParams, + ) -> Result { + self.require_initialized_for("filesystem")?; + self.file_system.create_directory(params).await + } + + pub(crate) async fn fs_get_metadata( + &self, + params: FsGetMetadataParams, + ) -> Result { + self.require_initialized_for("filesystem")?; + self.file_system.get_metadata(params).await + } + + pub(crate) async fn fs_read_directory( + &self, + params: FsReadDirectoryParams, + ) -> Result { + self.require_initialized_for("filesystem")?; + self.file_system.read_directory(params).await + } + + pub(crate) async fn fs_remove( + &self, + params: FsRemoveParams, + ) -> Result { + self.require_initialized_for("filesystem")?; + self.file_system.remove(params).await + } + + pub(crate) async fn fs_copy( + &self, + params: FsCopyParams, + ) -> Result { + self.require_initialized_for("filesystem")?; + self.file_system.copy(params).await + } } + +async fn stream_output( + process_id: String, + stream: ExecOutputStream, + mut receiver: tokio::sync::mpsc::Receiver>, + notifications: RpcNotificationSender, + processes: Arc>>, + output_notify: Arc, +) { + while let Some(chunk) = receiver.recv().await { + let notification = { + let mut processes = processes.lock().await; + let Some(entry) = processes.get_mut(&process_id) else { + break; + }; + let ProcessEntry::Running(process) = entry else { + break; + }; + let seq = process.next_seq; + process.next_seq += 1; + process.retained_bytes += chunk.len(); + process.output.push_back(RetainedOutputChunk { + seq, + stream, + chunk: chunk.clone(), + }); + while process.retained_bytes > RETAINED_OUTPUT_BYTES_PER_PROCESS { + let Some(evicted) = process.output.pop_front() else { + break; + }; + process.retained_bytes = process.retained_bytes.saturating_sub(evicted.chunk.len()); + warn!( + "retained output cap exceeded for process {process_id}; dropping oldest output" + ); + } + ExecOutputDeltaNotification { + process_id: process_id.clone(), + stream, + chunk: chunk.into(), + } + }; + output_notify.notify_waiters(); + + if notifications + .notify(crate::protocol::EXEC_OUTPUT_DELTA_METHOD, ¬ification) + .await + .is_err() + { + break; + } + } +} + +async fn watch_exit( + process_id: String, + exit_rx: tokio::sync::oneshot::Receiver, + notifications: RpcNotificationSender, + processes: Arc>>, + output_notify: Arc, +) { + let exit_code = exit_rx.await.unwrap_or(-1); + { + let mut processes = processes.lock().await; + if let Some(ProcessEntry::Running(process)) = processes.get_mut(&process_id) { + process.exit_code = Some(exit_code); + } + } + output_notify.notify_waiters(); + if notifications + .notify( + crate::protocol::EXEC_EXITED_METHOD, + &ExecExitedNotification { + process_id: process_id.clone(), + exit_code, + }, + ) + .await + .is_err() + { + return; + } + + tokio::time::sleep(EXITED_PROCESS_RETENTION).await; + let mut processes = processes.lock().await; + if matches!( + processes.get(&process_id), + Some(ProcessEntry::Running(process)) if process.exit_code == Some(exit_code) + ) { + processes.remove(&process_id); + } +} + +#[cfg(test)] +mod tests; diff --git a/codex-rs/exec-server/src/server/handler/tests.rs b/codex-rs/exec-server/src/server/handler/tests.rs new file mode 100644 index 000000000..5b6c9074f --- /dev/null +++ b/codex-rs/exec-server/src/server/handler/tests.rs @@ -0,0 +1,102 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use pretty_assertions::assert_eq; +use tokio::sync::mpsc; + +use super::ExecServerHandler; +use crate::protocol::ExecParams; +use crate::protocol::InitializeResponse; +use crate::protocol::TerminateParams; +use crate::protocol::TerminateResponse; +use crate::rpc::RpcNotificationSender; + +fn exec_params(process_id: &str) -> ExecParams { + let mut env = HashMap::new(); + if let Some(path) = std::env::var_os("PATH") { + env.insert("PATH".to_string(), path.to_string_lossy().into_owned()); + } + ExecParams { + process_id: process_id.to_string(), + argv: vec![ + "bash".to_string(), + "-lc".to_string(), + "sleep 0.1".to_string(), + ], + cwd: std::env::current_dir().expect("cwd"), + env, + tty: false, + arg0: None, + } +} + +async fn initialized_handler() -> Arc { + let (outgoing_tx, _outgoing_rx) = mpsc::channel(16); + let handler = Arc::new(ExecServerHandler::new(RpcNotificationSender::new( + outgoing_tx, + ))); + assert_eq!( + handler.initialize().expect("initialize"), + InitializeResponse {} + ); + handler.initialized().expect("initialized"); + handler +} + +#[tokio::test] +async fn duplicate_process_ids_allow_only_one_successful_start() { + let handler = initialized_handler().await; + let first_handler = Arc::clone(&handler); + let second_handler = Arc::clone(&handler); + + let (first, second) = tokio::join!( + first_handler.exec(exec_params("proc-1")), + second_handler.exec(exec_params("proc-1")), + ); + + let (successes, failures): (Vec<_>, Vec<_>) = + [first, second].into_iter().partition(Result::is_ok); + assert_eq!(successes.len(), 1); + assert_eq!(failures.len(), 1); + + let error = failures + .into_iter() + .next() + .expect("one failed request") + .expect_err("expected duplicate process error"); + assert_eq!(error.code, -32600); + assert_eq!(error.message, "process proc-1 already exists"); + + tokio::time::sleep(Duration::from_millis(150)).await; + handler.shutdown().await; +} + +#[tokio::test] +async fn terminate_reports_false_after_process_exit() { + let handler = initialized_handler().await; + handler + .exec(exec_params("proc-1")) + .await + .expect("start process"); + + let deadline = tokio::time::Instant::now() + Duration::from_secs(1); + loop { + let response = handler + .terminate(TerminateParams { + process_id: "proc-1".to_string(), + }) + .await + .expect("terminate response"); + if response == (TerminateResponse { running: false }) { + break; + } + assert!( + tokio::time::Instant::now() < deadline, + "process should have exited within 1s" + ); + tokio::time::sleep(Duration::from_millis(25)).await; + } + + handler.shutdown().await; +} diff --git a/codex-rs/exec-server/src/server/processor.rs b/codex-rs/exec-server/src/server/processor.rs index 7a8ca40f0..518a1a78e 100644 --- a/codex-rs/exec-server/src/server/processor.rs +++ b/codex-rs/exec-server/src/server/processor.rs @@ -1,53 +1,109 @@ -use codex_app_server_protocol::JSONRPCMessage; -use codex_app_server_protocol::JSONRPCNotification; -use codex_app_server_protocol::JSONRPCRequest; -use tracing::debug; +use std::sync::Arc; -use crate::connection::JsonRpcConnection; -use crate::connection::JsonRpcConnectionEvent; -use crate::protocol::INITIALIZE_METHOD; -use crate::protocol::INITIALIZED_METHOD; -use crate::protocol::InitializeParams; -use crate::server::ExecServerHandler; -use crate::server::jsonrpc::invalid_params; -use crate::server::jsonrpc::invalid_request_message; -use crate::server::jsonrpc::method_not_found; -use crate::server::jsonrpc::response_message; +use tokio::sync::mpsc; +use tracing::debug; use tracing::warn; -pub(crate) async fn run_connection(connection: JsonRpcConnection) { - let (json_outgoing_tx, mut incoming_rx, _connection_tasks) = connection.into_parts(); - let handler = ExecServerHandler::new(); +use crate::connection::CHANNEL_CAPACITY; +use crate::connection::JsonRpcConnection; +use crate::connection::JsonRpcConnectionEvent; +use crate::rpc::RpcNotificationSender; +use crate::rpc::RpcServerOutboundMessage; +use crate::rpc::encode_server_message; +use crate::rpc::invalid_request; +use crate::rpc::method_not_found; +use crate::server::ExecServerHandler; +use crate::server::registry::build_router; - while let Some(event) = incoming_rx.recv().await { - match event { - JsonRpcConnectionEvent::Message(message) => { - let response = match handle_connection_message(&handler, message).await { - Ok(response) => response, - Err(err) => { - tracing::warn!( - "closing exec-server connection after protocol error: {err}" - ); - break; - } - }; - let Some(response) = response else { - continue; - }; - if json_outgoing_tx.send(response).await.is_err() { +pub(crate) async fn run_connection(connection: JsonRpcConnection) { + let router = Arc::new(build_router()); + let (json_outgoing_tx, mut incoming_rx, connection_tasks) = connection.into_parts(); + let (outgoing_tx, mut outgoing_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let notifications = RpcNotificationSender::new(outgoing_tx.clone()); + let handler = Arc::new(ExecServerHandler::new(notifications)); + + let outbound_task = tokio::spawn(async move { + while let Some(message) = outgoing_rx.recv().await { + let json_message = match encode_server_message(message) { + Ok(json_message) => json_message, + Err(err) => { + warn!("failed to serialize exec-server outbound message: {err}"); break; } + }; + if json_outgoing_tx.send(json_message).await.is_err() { + break; } + } + }); + + // Process inbound events sequentially to preserve initialize/initialized ordering. + while let Some(event) = incoming_rx.recv().await { + match event { JsonRpcConnectionEvent::MalformedMessage { reason } => { warn!("ignoring malformed exec-server message: {reason}"); - if json_outgoing_tx - .send(invalid_request_message(reason)) + if outgoing_tx + .send(RpcServerOutboundMessage::Error { + request_id: codex_app_server_protocol::RequestId::Integer(-1), + error: invalid_request(reason), + }) .await .is_err() { break; } } + JsonRpcConnectionEvent::Message(message) => match message { + codex_app_server_protocol::JSONRPCMessage::Request(request) => { + if let Some(route) = router.request_route(request.method.as_str()) { + let message = route(handler.clone(), request).await; + if outgoing_tx.send(message).await.is_err() { + break; + } + } else if outgoing_tx + .send(RpcServerOutboundMessage::Error { + request_id: request.id, + error: method_not_found(format!( + "exec-server stub does not implement `{}` yet", + request.method + )), + }) + .await + .is_err() + { + break; + } + } + codex_app_server_protocol::JSONRPCMessage::Notification(notification) => { + let Some(route) = router.notification_route(notification.method.as_str()) + else { + warn!( + "closing exec-server connection after unexpected notification: {}", + notification.method + ); + break; + }; + if let Err(err) = route(handler.clone(), notification).await { + warn!("closing exec-server connection after protocol error: {err}"); + break; + } + } + codex_app_server_protocol::JSONRPCMessage::Response(response) => { + warn!( + "closing exec-server connection after unexpected client response: {:?}", + response.id + ); + break; + } + codex_app_server_protocol::JSONRPCMessage::Error(error) => { + warn!( + "closing exec-server connection after unexpected client error: {:?}", + error.id + ); + break; + } + }, JsonRpcConnectionEvent::Disconnected { reason } => { if let Some(reason) = reason { debug!("exec-server connection disconnected: {reason}"); @@ -58,64 +114,10 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) { } handler.shutdown().await; -} - -pub(crate) async fn handle_connection_message( - handler: &ExecServerHandler, - message: JSONRPCMessage, -) -> Result, String> { - match message { - JSONRPCMessage::Request(request) => Ok(Some(dispatch_request(handler, request))), - JSONRPCMessage::Notification(notification) => { - handle_notification(handler, notification)?; - Ok(None) - } - JSONRPCMessage::Response(response) => Err(format!( - "unexpected client response for request id {:?}", - response.id - )), - JSONRPCMessage::Error(error) => Err(format!( - "unexpected client error for request id {:?}", - error.id - )), - } -} - -fn dispatch_request(handler: &ExecServerHandler, request: JSONRPCRequest) -> JSONRPCMessage { - let JSONRPCRequest { - id, - method, - params, - trace: _, - } = request; - - match method.as_str() { - INITIALIZE_METHOD => { - let result = serde_json::from_value::( - params.unwrap_or(serde_json::Value::Null), - ) - .map_err(|err| invalid_params(err.to_string())) - .and_then(|_params| handler.initialize()) - .and_then(|response| { - serde_json::to_value(response).map_err(|err| invalid_params(err.to_string())) - }); - response_message(id, result) - } - other => response_message( - id, - Err(method_not_found(format!( - "exec-server stub does not implement `{other}` yet" - ))), - ), - } -} - -fn handle_notification( - handler: &ExecServerHandler, - notification: JSONRPCNotification, -) -> Result<(), String> { - match notification.method.as_str() { - INITIALIZED_METHOD => handler.initialized(), - other => Err(format!("unexpected notification method: {other}")), + drop(outgoing_tx); + for task in connection_tasks { + task.abort(); + let _ = task.await; } + let _ = outbound_task.await; } diff --git a/codex-rs/exec-server/src/server/registry.rs b/codex-rs/exec-server/src/server/registry.rs new file mode 100644 index 000000000..482e5ab61 --- /dev/null +++ b/codex-rs/exec-server/src/server/registry.rs @@ -0,0 +1,110 @@ +use std::sync::Arc; + +use crate::protocol::EXEC_METHOD; +use crate::protocol::EXEC_READ_METHOD; +use crate::protocol::EXEC_TERMINATE_METHOD; +use crate::protocol::EXEC_WRITE_METHOD; +use crate::protocol::ExecParams; +use crate::protocol::FS_COPY_METHOD; +use crate::protocol::FS_CREATE_DIRECTORY_METHOD; +use crate::protocol::FS_GET_METADATA_METHOD; +use crate::protocol::FS_READ_DIRECTORY_METHOD; +use crate::protocol::FS_READ_FILE_METHOD; +use crate::protocol::FS_REMOVE_METHOD; +use crate::protocol::FS_WRITE_FILE_METHOD; +use crate::protocol::INITIALIZE_METHOD; +use crate::protocol::INITIALIZED_METHOD; +use crate::protocol::InitializeParams; +use crate::protocol::ReadParams; +use crate::protocol::TerminateParams; +use crate::protocol::WriteParams; +use crate::rpc::RpcRouter; +use crate::server::ExecServerHandler; +use codex_app_server_protocol::FsCopyParams; +use codex_app_server_protocol::FsCreateDirectoryParams; +use codex_app_server_protocol::FsGetMetadataParams; +use codex_app_server_protocol::FsReadDirectoryParams; +use codex_app_server_protocol::FsReadFileParams; +use codex_app_server_protocol::FsRemoveParams; +use codex_app_server_protocol::FsWriteFileParams; + +pub(crate) fn build_router() -> RpcRouter { + let mut router = RpcRouter::new(); + router.request( + INITIALIZE_METHOD, + |handler: Arc, _params: InitializeParams| async move { + handler.initialize() + }, + ); + router.notification( + INITIALIZED_METHOD, + |handler: Arc, _params: serde_json::Value| async move { + handler.initialized() + }, + ); + router.request( + EXEC_METHOD, + |handler: Arc, params: ExecParams| async move { handler.exec(params).await }, + ); + router.request( + EXEC_READ_METHOD, + |handler: Arc, params: ReadParams| async move { + handler.exec_read(params).await + }, + ); + router.request( + EXEC_WRITE_METHOD, + |handler: Arc, params: WriteParams| async move { + handler.exec_write(params).await + }, + ); + router.request( + EXEC_TERMINATE_METHOD, + |handler: Arc, params: TerminateParams| async move { + handler.terminate(params).await + }, + ); + router.request( + FS_READ_FILE_METHOD, + |handler: Arc, params: FsReadFileParams| async move { + handler.fs_read_file(params).await + }, + ); + router.request( + FS_WRITE_FILE_METHOD, + |handler: Arc, params: FsWriteFileParams| async move { + handler.fs_write_file(params).await + }, + ); + router.request( + FS_CREATE_DIRECTORY_METHOD, + |handler: Arc, params: FsCreateDirectoryParams| async move { + handler.fs_create_directory(params).await + }, + ); + router.request( + FS_GET_METADATA_METHOD, + |handler: Arc, params: FsGetMetadataParams| async move { + handler.fs_get_metadata(params).await + }, + ); + router.request( + FS_READ_DIRECTORY_METHOD, + |handler: Arc, params: FsReadDirectoryParams| async move { + handler.fs_read_directory(params).await + }, + ); + router.request( + FS_REMOVE_METHOD, + |handler: Arc, params: FsRemoveParams| async move { + handler.fs_remove(params).await + }, + ); + router.request( + FS_COPY_METHOD, + |handler: Arc, params: FsCopyParams| async move { + handler.fs_copy(params).await + }, + ); + router +} diff --git a/codex-rs/exec-server/tests/process.rs b/codex-rs/exec-server/tests/process.rs index a99a889ed..4926e6088 100644 --- a/codex-rs/exec-server/tests/process.rs +++ b/codex-rs/exec-server/tests/process.rs @@ -2,15 +2,15 @@ mod common; -use codex_app_server_protocol::JSONRPCError; use codex_app_server_protocol::JSONRPCMessage; use codex_app_server_protocol::JSONRPCResponse; +use codex_exec_server::ExecResponse; use codex_exec_server::InitializeParams; use common::exec_server::exec_server; use pretty_assertions::assert_eq; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn exec_server_stubs_process_start_over_websocket() -> anyhow::Result<()> { +async fn exec_server_starts_process_over_websocket() -> anyhow::Result<()> { let mut server = exec_server().await?; let initialize_id = server .send_request( @@ -29,6 +29,10 @@ async fn exec_server_stubs_process_start_over_websocket() -> anyhow::Result<()> }) .await?; + server + .send_notification("initialized", serde_json::json!({})) + .await?; + let process_start_id = server .send_request( "process/start", @@ -46,18 +50,20 @@ async fn exec_server_stubs_process_start_over_websocket() -> anyhow::Result<()> .wait_for_event(|event| { matches!( event, - JSONRPCMessage::Error(JSONRPCError { id, .. }) if id == &process_start_id + JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &process_start_id ) }) .await?; - let JSONRPCMessage::Error(JSONRPCError { id, error }) = response else { - panic!("expected process/start stub error"); + let JSONRPCMessage::Response(JSONRPCResponse { id, result }) = response else { + panic!("expected process/start response"); }; assert_eq!(id, process_start_id); - assert_eq!(error.code, -32601); + let process_start_response: ExecResponse = serde_json::from_value(result)?; assert_eq!( - error.message, - "exec-server stub does not implement `process/start` yet" + process_start_response, + ExecResponse { + process_id: "proc-1".to_string() + } ); server.shutdown().await?;