diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index a7680e73e..4fa75abe1 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -18,16 +18,15 @@ use codex_app_server_protocol::FsWriteFileResponse; use codex_app_server_protocol::JSONRPCNotification; use serde_json::Value; use tokio::sync::broadcast; -use tokio::sync::mpsc; use tokio::time::timeout; use tokio_tungstenite::connect_async; use tracing::debug; use tracing::warn; use crate::client_api::ExecServerClientConnectOptions; -use crate::client_api::ExecServerEvent; use crate::client_api::RemoteExecServerConnectArgs; use crate::connection::JsonRpcConnection; +use crate::process::ExecServerEvent; use crate::protocol::EXEC_EXITED_METHOD; use crate::protocol::EXEC_METHOD; use crate::protocol::EXEC_OUTPUT_DELTA_METHOD; @@ -58,11 +57,6 @@ use crate::protocol::WriteResponse; use crate::rpc::RpcCallError; use crate::rpc::RpcClient; use crate::rpc::RpcClientEvent; -use crate::rpc::RpcNotificationSender; -use crate::rpc::RpcServerOutboundMessage; - -mod local_backend; -use local_backend::LocalBackend; const CONNECT_TIMEOUT: Duration = Duration::from_secs(10); const INITIALIZE_TIMEOUT: Duration = Duration::from_secs(10); @@ -96,43 +90,14 @@ impl RemoteExecServerConnectArgs { } } -enum ClientBackend { - Remote(RpcClient), - InProcess(LocalBackend), -} - -impl ClientBackend { - fn as_local(&self) -> Option<&LocalBackend> { - match self { - ClientBackend::Remote(_) => None, - ClientBackend::InProcess(backend) => Some(backend), - } - } - - fn as_remote(&self) -> Option<&RpcClient> { - match self { - ClientBackend::Remote(client) => Some(client), - ClientBackend::InProcess(_) => None, - } - } -} - struct Inner { - backend: ClientBackend, + client: RpcClient, events_tx: broadcast::Sender, reader_task: tokio::task::JoinHandle<()>, } impl Drop for Inner { fn drop(&mut self) { - if let Some(backend) = self.backend.as_local() - && let Ok(handle) = tokio::runtime::Handle::try_current() - { - let backend = backend.clone(); - handle.spawn(async move { - backend.shutdown().await; - }); - } self.reader_task.abort(); } } @@ -167,40 +132,6 @@ pub enum ExecServerError { } impl ExecServerClient { - pub async fn connect_in_process( - options: ExecServerClientConnectOptions, - ) -> Result { - let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(256); - let backend = LocalBackend::new(crate::server::ExecServerHandler::new( - RpcNotificationSender::new(outgoing_tx), - )); - let inner = Arc::new_cyclic(|weak| { - let weak = weak.clone(); - let reader_task = tokio::spawn(async move { - while let Some(message) = outgoing_rx.recv().await { - if let Some(inner) = weak.upgrade() - && let Err(err) = handle_in_process_outbound_message(&inner, message).await - { - warn!( - "in-process exec-server client closing after unexpected response: {err}" - ); - return; - } - } - }); - - Inner { - backend: ClientBackend::InProcess(backend), - events_tx: broadcast::channel(256).0, - reader_task, - } - }); - - let client = Self { inner }; - client.initialize(options).await?; - Ok(client) - } - pub async fn connect_websocket( args: RemoteExecServerConnectArgs, ) -> Result { @@ -241,17 +172,11 @@ impl ExecServerClient { } = options; timeout(initialize_timeout, async { - let response = if let Some(backend) = self.inner.backend.as_local() { - backend.initialize().await? - } else { - let params = InitializeParams { client_name }; - let Some(remote) = self.inner.backend.as_remote() else { - return Err(ExecServerError::Protocol( - "remote backend missing during initialize".to_string(), - )); - }; - remote.call(INITIALIZE_METHOD, ¶ms).await? - }; + let response = self + .inner + .client + .call(INITIALIZE_METHOD, &InitializeParams { client_name }) + .await?; self.notify_initialized().await?; Ok(response) }) @@ -262,27 +187,16 @@ impl ExecServerClient { } pub async fn exec(&self, params: ExecParams) -> Result { - if let Some(backend) = self.inner.backend.as_local() { - return backend.exec(params).await; - } - let Some(remote) = self.inner.backend.as_remote() else { - return Err(ExecServerError::Protocol( - "remote backend missing during exec".to_string(), - )); - }; - remote.call(EXEC_METHOD, ¶ms).await.map_err(Into::into) + self.inner + .client + .call(EXEC_METHOD, ¶ms) + .await + .map_err(Into::into) } pub async fn read(&self, params: ReadParams) -> Result { - if let Some(backend) = self.inner.backend.as_local() { - return backend.exec_read(params).await; - } - let Some(remote) = self.inner.backend.as_remote() else { - return Err(ExecServerError::Protocol( - "remote backend missing during read".to_string(), - )); - }; - remote + self.inner + .client .call(EXEC_READ_METHOD, ¶ms) .await .map_err(Into::into) @@ -293,38 +207,28 @@ impl ExecServerClient { process_id: &str, chunk: Vec, ) -> Result { - let params = WriteParams { - process_id: process_id.to_string(), - chunk: chunk.into(), - }; - if let Some(backend) = self.inner.backend.as_local() { - return backend.exec_write(params).await; - } - let Some(remote) = self.inner.backend.as_remote() else { - return Err(ExecServerError::Protocol( - "remote backend missing during write".to_string(), - )); - }; - remote - .call(EXEC_WRITE_METHOD, ¶ms) + self.inner + .client + .call( + EXEC_WRITE_METHOD, + &WriteParams { + process_id: process_id.to_string(), + chunk: chunk.into(), + }, + ) .await .map_err(Into::into) } pub async fn terminate(&self, process_id: &str) -> Result { - let params = TerminateParams { - process_id: process_id.to_string(), - }; - if let Some(backend) = self.inner.backend.as_local() { - return backend.terminate(params).await; - } - let Some(remote) = self.inner.backend.as_remote() else { - return Err(ExecServerError::Protocol( - "remote backend missing during terminate".to_string(), - )); - }; - remote - .call(EXEC_TERMINATE_METHOD, ¶ms) + self.inner + .client + .call( + EXEC_TERMINATE_METHOD, + &TerminateParams { + process_id: process_id.to_string(), + }, + ) .await .map_err(Into::into) } @@ -333,15 +237,8 @@ impl ExecServerClient { &self, params: FsReadFileParams, ) -> Result { - if let Some(backend) = self.inner.backend.as_local() { - return backend.fs_read_file(params).await; - } - let Some(remote) = self.inner.backend.as_remote() else { - return Err(ExecServerError::Protocol( - "remote backend missing during fs/readFile".to_string(), - )); - }; - remote + self.inner + .client .call(FS_READ_FILE_METHOD, ¶ms) .await .map_err(Into::into) @@ -351,15 +248,8 @@ impl ExecServerClient { &self, params: FsWriteFileParams, ) -> Result { - if let Some(backend) = self.inner.backend.as_local() { - return backend.fs_write_file(params).await; - } - let Some(remote) = self.inner.backend.as_remote() else { - return Err(ExecServerError::Protocol( - "remote backend missing during fs/writeFile".to_string(), - )); - }; - remote + self.inner + .client .call(FS_WRITE_FILE_METHOD, ¶ms) .await .map_err(Into::into) @@ -369,15 +259,8 @@ impl ExecServerClient { &self, params: FsCreateDirectoryParams, ) -> Result { - if let Some(backend) = self.inner.backend.as_local() { - return backend.fs_create_directory(params).await; - } - let Some(remote) = self.inner.backend.as_remote() else { - return Err(ExecServerError::Protocol( - "remote backend missing during fs/createDirectory".to_string(), - )); - }; - remote + self.inner + .client .call(FS_CREATE_DIRECTORY_METHOD, ¶ms) .await .map_err(Into::into) @@ -387,15 +270,8 @@ impl ExecServerClient { &self, params: FsGetMetadataParams, ) -> Result { - if let Some(backend) = self.inner.backend.as_local() { - return backend.fs_get_metadata(params).await; - } - let Some(remote) = self.inner.backend.as_remote() else { - return Err(ExecServerError::Protocol( - "remote backend missing during fs/getMetadata".to_string(), - )); - }; - remote + self.inner + .client .call(FS_GET_METADATA_METHOD, ¶ms) .await .map_err(Into::into) @@ -405,15 +281,8 @@ impl ExecServerClient { &self, params: FsReadDirectoryParams, ) -> Result { - if let Some(backend) = self.inner.backend.as_local() { - return backend.fs_read_directory(params).await; - } - let Some(remote) = self.inner.backend.as_remote() else { - return Err(ExecServerError::Protocol( - "remote backend missing during fs/readDirectory".to_string(), - )); - }; - remote + self.inner + .client .call(FS_READ_DIRECTORY_METHOD, ¶ms) .await .map_err(Into::into) @@ -423,30 +292,16 @@ impl ExecServerClient { &self, params: FsRemoveParams, ) -> Result { - if let Some(backend) = self.inner.backend.as_local() { - return backend.fs_remove(params).await; - } - let Some(remote) = self.inner.backend.as_remote() else { - return Err(ExecServerError::Protocol( - "remote backend missing during fs/remove".to_string(), - )); - }; - remote + self.inner + .client .call(FS_REMOVE_METHOD, ¶ms) .await .map_err(Into::into) } pub async fn fs_copy(&self, params: FsCopyParams) -> Result { - if let Some(backend) = self.inner.backend.as_local() { - return backend.fs_copy(params).await; - } - let Some(remote) = self.inner.backend.as_remote() else { - return Err(ExecServerError::Protocol( - "remote backend missing during fs/copy".to_string(), - )); - }; - remote + self.inner + .client .call(FS_COPY_METHOD, ¶ms) .await .map_err(Into::into) @@ -482,7 +337,7 @@ impl ExecServerClient { }); Inner { - backend: ClientBackend::Remote(rpc_client), + client: rpc_client, events_tx: broadcast::channel(256).0, reader_task, } @@ -494,13 +349,11 @@ impl ExecServerClient { } async fn notify_initialized(&self) -> Result<(), ExecServerError> { - match &self.inner.backend { - ClientBackend::Remote(client) => client - .notify(INITIALIZED_METHOD, &serde_json::json!({})) - .await - .map_err(ExecServerError::Json), - ClientBackend::InProcess(backend) => backend.initialized().await, - } + self.inner + .client + .notify(INITIALIZED_METHOD, &serde_json::json!({})) + .await + .map_err(ExecServerError::Json) } } @@ -517,20 +370,6 @@ impl From for ExecServerError { } } -async fn handle_in_process_outbound_message( - inner: &Arc, - message: RpcServerOutboundMessage, -) -> Result<(), ExecServerError> { - match message { - RpcServerOutboundMessage::Response { .. } | RpcServerOutboundMessage::Error { .. } => Err( - ExecServerError::Protocol("unexpected in-process RPC response".to_string()), - ), - RpcServerOutboundMessage::Notification(notification) => { - handle_server_notification(inner, notification).await - } - } -} - async fn handle_server_notification( inner: &Arc, notification: JSONRPCNotification, diff --git a/codex-rs/exec-server/src/client/local_backend.rs b/codex-rs/exec-server/src/client/local_backend.rs deleted file mode 100644 index e23a5361d..000000000 --- a/codex-rs/exec-server/src/client/local_backend.rs +++ /dev/null @@ -1,200 +0,0 @@ -use std::sync::Arc; - -use crate::protocol::ExecParams; -use crate::protocol::ExecResponse; -use crate::protocol::InitializeResponse; -use crate::protocol::ReadParams; -use crate::protocol::ReadResponse; -use crate::protocol::TerminateParams; -use crate::protocol::TerminateResponse; -use crate::protocol::WriteParams; -use crate::protocol::WriteResponse; -use crate::server::ExecServerHandler; -use codex_app_server_protocol::FsCopyParams; -use codex_app_server_protocol::FsCopyResponse; -use codex_app_server_protocol::FsCreateDirectoryParams; -use codex_app_server_protocol::FsCreateDirectoryResponse; -use codex_app_server_protocol::FsGetMetadataParams; -use codex_app_server_protocol::FsGetMetadataResponse; -use codex_app_server_protocol::FsReadDirectoryParams; -use codex_app_server_protocol::FsReadDirectoryResponse; -use codex_app_server_protocol::FsReadFileParams; -use codex_app_server_protocol::FsReadFileResponse; -use codex_app_server_protocol::FsRemoveParams; -use codex_app_server_protocol::FsRemoveResponse; -use codex_app_server_protocol::FsWriteFileParams; -use codex_app_server_protocol::FsWriteFileResponse; - -use super::ExecServerError; - -#[derive(Clone)] -pub(super) struct LocalBackend { - handler: Arc, -} - -impl LocalBackend { - pub(super) fn new(handler: ExecServerHandler) -> Self { - Self { - handler: Arc::new(handler), - } - } - - pub(super) async fn shutdown(&self) { - self.handler.shutdown().await; - } - - pub(super) async fn initialize(&self) -> Result { - self.handler - .initialize() - .map_err(|error| ExecServerError::Server { - code: error.code, - message: error.message, - }) - } - - pub(super) async fn initialized(&self) -> Result<(), ExecServerError> { - self.handler - .initialized() - .map_err(ExecServerError::Protocol) - } - - pub(super) async fn exec(&self, params: ExecParams) -> Result { - self.handler - .exec(params) - .await - .map_err(|error| ExecServerError::Server { - code: error.code, - message: error.message, - }) - } - - pub(super) async fn exec_read( - &self, - params: ReadParams, - ) -> Result { - self.handler - .exec_read(params) - .await - .map_err(|error| ExecServerError::Server { - code: error.code, - message: error.message, - }) - } - - pub(super) async fn exec_write( - &self, - params: WriteParams, - ) -> Result { - self.handler - .exec_write(params) - .await - .map_err(|error| ExecServerError::Server { - code: error.code, - message: error.message, - }) - } - - pub(super) async fn terminate( - &self, - params: TerminateParams, - ) -> Result { - self.handler - .terminate(params) - .await - .map_err(|error| ExecServerError::Server { - code: error.code, - message: error.message, - }) - } - - pub(super) async fn fs_read_file( - &self, - params: FsReadFileParams, - ) -> Result { - self.handler - .fs_read_file(params) - .await - .map_err(|error| ExecServerError::Server { - code: error.code, - message: error.message, - }) - } - - pub(super) async fn fs_write_file( - &self, - params: FsWriteFileParams, - ) -> Result { - self.handler - .fs_write_file(params) - .await - .map_err(|error| ExecServerError::Server { - code: error.code, - message: error.message, - }) - } - - pub(super) async fn fs_create_directory( - &self, - params: FsCreateDirectoryParams, - ) -> Result { - self.handler - .fs_create_directory(params) - .await - .map_err(|error| ExecServerError::Server { - code: error.code, - message: error.message, - }) - } - - pub(super) async fn fs_get_metadata( - &self, - params: FsGetMetadataParams, - ) -> Result { - self.handler - .fs_get_metadata(params) - .await - .map_err(|error| ExecServerError::Server { - code: error.code, - message: error.message, - }) - } - - pub(super) async fn fs_read_directory( - &self, - params: FsReadDirectoryParams, - ) -> Result { - self.handler - .fs_read_directory(params) - .await - .map_err(|error| ExecServerError::Server { - code: error.code, - message: error.message, - }) - } - - pub(super) async fn fs_remove( - &self, - params: FsRemoveParams, - ) -> Result { - self.handler - .fs_remove(params) - .await - .map_err(|error| ExecServerError::Server { - code: error.code, - message: error.message, - }) - } - - pub(super) async fn fs_copy( - &self, - params: FsCopyParams, - ) -> Result { - self.handler - .fs_copy(params) - .await - .map_err(|error| ExecServerError::Server { - code: error.code, - message: error.message, - }) - } -} diff --git a/codex-rs/exec-server/src/client_api.rs b/codex-rs/exec-server/src/client_api.rs index 962d3ba36..6e8976341 100644 --- a/codex-rs/exec-server/src/client_api.rs +++ b/codex-rs/exec-server/src/client_api.rs @@ -1,8 +1,5 @@ use std::time::Duration; -use crate::protocol::ExecExitedNotification; -use crate::protocol::ExecOutputDeltaNotification; - /// Connection options for any exec-server client transport. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ExecServerClientConnectOptions { @@ -18,10 +15,3 @@ pub struct RemoteExecServerConnectArgs { pub connect_timeout: Duration, pub initialize_timeout: Duration, } - -/// Connection-level server events. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ExecServerEvent { - OutputDelta(ExecOutputDeltaNotification), - Exited(ExecExitedNotification), -} diff --git a/codex-rs/exec-server/src/environment.rs b/codex-rs/exec-server/src/environment.rs index 3ca1cfe90..7cc3f7840 100644 --- a/codex-rs/exec-server/src/environment.rs +++ b/codex-rs/exec-server/src/environment.rs @@ -1,15 +1,42 @@ +use std::sync::Arc; + use crate::ExecServerClient; use crate::ExecServerError; use crate::RemoteExecServerConnectArgs; use crate::file_system::ExecutorFileSystem; use crate::local_file_system::LocalFileSystem; +use crate::local_process::LocalProcess; +use crate::process::ExecProcess; use crate::remote_file_system::RemoteFileSystem; -use std::sync::Arc; +use crate::remote_process::RemoteProcess; -#[derive(Clone, Default)] +pub trait ExecutorEnvironment: Send + Sync { + fn get_executor(&self) -> Arc; +} + +#[derive(Clone)] pub struct Environment { experimental_exec_server_url: Option, remote_exec_server_client: Option, + executor: Arc, +} + +impl Default for Environment { + fn default() -> Self { + let local_process = LocalProcess::default(); + if let Err(err) = local_process.initialize() { + panic!("default local process initialization should succeed: {err:?}"); + } + if let Err(err) = local_process.initialized() { + panic!("default local process should accept initialized notification: {err}"); + } + + Self { + experimental_exec_server_url: None, + remote_exec_server_client: None, + executor: Arc::new(local_process), + } + } } impl std::fmt::Debug for Environment { @@ -19,11 +46,7 @@ impl std::fmt::Debug for Environment { "experimental_exec_server_url", &self.experimental_exec_server_url, ) - .field( - "has_remote_exec_server_client", - &self.remote_exec_server_client.is_some(), - ) - .finish() + .finish_non_exhaustive() } } @@ -31,22 +54,38 @@ impl Environment { pub async fn create( experimental_exec_server_url: Option, ) -> Result { - let remote_exec_server_client = - if let Some(websocket_url) = experimental_exec_server_url.as_deref() { - Some( - ExecServerClient::connect_websocket(RemoteExecServerConnectArgs::new( - websocket_url.to_string(), - "codex-core".to_string(), - )) - .await?, - ) - } else { - None - }; + let remote_exec_server_client = if let Some(url) = &experimental_exec_server_url { + Some( + ExecServerClient::connect_websocket(RemoteExecServerConnectArgs { + websocket_url: url.clone(), + client_name: "codex-environment".to_string(), + connect_timeout: std::time::Duration::from_secs(5), + initialize_timeout: std::time::Duration::from_secs(5), + }) + .await?, + ) + } else { + None + }; + + let executor: Arc = if let Some(client) = remote_exec_server_client.clone() + { + Arc::new(RemoteProcess::new(client)) + } else { + let local_process = LocalProcess::default(); + local_process + .initialize() + .map_err(|err| ExecServerError::Protocol(err.message))?; + local_process + .initialized() + .map_err(ExecServerError::Protocol)?; + Arc::new(local_process) + }; Ok(Self { experimental_exec_server_url, remote_exec_server_client, + executor, }) } @@ -54,8 +93,8 @@ impl Environment { self.experimental_exec_server_url.as_deref() } - pub fn remote_exec_server_client(&self) -> Option<&ExecServerClient> { - self.remote_exec_server_client.as_ref() + pub fn get_executor(&self) -> Arc { + Arc::clone(&self.executor) } pub fn get_filesystem(&self) -> Arc { @@ -67,6 +106,12 @@ impl Environment { } } +impl ExecutorEnvironment for Environment { + fn get_executor(&self) -> Arc { + Arc::clone(&self.executor) + } +} + #[cfg(test)] mod tests { use super::Environment; @@ -77,6 +122,31 @@ mod tests { let environment = Environment::create(None).await.expect("create environment"); assert_eq!(environment.experimental_exec_server_url(), None); - assert!(environment.remote_exec_server_client().is_none()); + assert!(environment.remote_exec_server_client.is_none()); + } + + #[tokio::test] + async fn default_environment_has_ready_local_executor() { + let environment = Environment::default(); + + let response = environment + .get_executor() + .start(crate::ExecParams { + process_id: "default-env-proc".to_string(), + argv: vec!["true".to_string()], + cwd: std::env::current_dir().expect("read current dir"), + env: Default::default(), + tty: false, + arg0: None, + }) + .await + .expect("start process"); + + assert_eq!( + response, + crate::ExecResponse { + process_id: "default-env-proc".to_string(), + } + ); } } diff --git a/codex-rs/exec-server/src/lib.rs b/codex-rs/exec-server/src/lib.rs index 55c42ebb9..68ff9f654 100644 --- a/codex-rs/exec-server/src/lib.rs +++ b/codex-rs/exec-server/src/lib.rs @@ -4,15 +4,17 @@ mod connection; mod environment; mod file_system; mod local_file_system; +mod local_process; +mod process; mod protocol; mod remote_file_system; +mod remote_process; mod rpc; mod server; pub use client::ExecServerClient; pub use client::ExecServerError; pub use client_api::ExecServerClientConnectOptions; -pub use client_api::ExecServerEvent; pub use client_api::RemoteExecServerConnectArgs; pub use codex_app_server_protocol::FsCopyParams; pub use codex_app_server_protocol::FsCopyResponse; @@ -20,7 +22,6 @@ pub use codex_app_server_protocol::FsCreateDirectoryParams; pub use codex_app_server_protocol::FsCreateDirectoryResponse; pub use codex_app_server_protocol::FsGetMetadataParams; pub use codex_app_server_protocol::FsGetMetadataResponse; -pub use codex_app_server_protocol::FsReadDirectoryEntry; pub use codex_app_server_protocol::FsReadDirectoryParams; pub use codex_app_server_protocol::FsReadDirectoryResponse; pub use codex_app_server_protocol::FsReadFileParams; @@ -30,6 +31,7 @@ pub use codex_app_server_protocol::FsRemoveResponse; pub use codex_app_server_protocol::FsWriteFileParams; pub use codex_app_server_protocol::FsWriteFileResponse; pub use environment::Environment; +pub use environment::ExecutorEnvironment; pub use file_system::CopyOptions; pub use file_system::CreateDirectoryOptions; pub use file_system::ExecutorFileSystem; @@ -37,6 +39,8 @@ pub use file_system::FileMetadata; pub use file_system::FileSystemResult; pub use file_system::ReadDirectoryEntry; pub use file_system::RemoveOptions; +pub use process::ExecProcess; +pub use process::ExecServerEvent; pub use protocol::ExecExitedNotification; pub use protocol::ExecOutputDeltaNotification; pub use protocol::ExecOutputStream; diff --git a/codex-rs/exec-server/src/local_process.rs b/codex-rs/exec-server/src/local_process.rs new file mode 100644 index 000000000..c233da3d7 --- /dev/null +++ b/codex-rs/exec-server/src/local_process.rs @@ -0,0 +1,515 @@ +use std::collections::HashMap; +use std::collections::VecDeque; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use async_trait::async_trait; +use codex_app_server_protocol::JSONRPCErrorError; +use codex_utils_pty::ExecCommandSession; +use codex_utils_pty::TerminalSize; +use tokio::sync::Mutex; +use tokio::sync::Notify; +use tokio::sync::broadcast; +use tokio::sync::mpsc; +use tracing::warn; + +use crate::ExecProcess; +use crate::ExecServerError; +use crate::ExecServerEvent; +use crate::protocol::ExecExitedNotification; +use crate::protocol::ExecOutputDeltaNotification; +use crate::protocol::ExecOutputStream; +use crate::protocol::ExecParams; +use crate::protocol::ExecResponse; +use crate::protocol::InitializeResponse; +use crate::protocol::ProcessOutputChunk; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; +use crate::protocol::TerminateParams; +use crate::protocol::TerminateResponse; +use crate::protocol::WriteParams; +use crate::protocol::WriteResponse; +use crate::rpc::RpcNotificationSender; +use crate::rpc::RpcServerOutboundMessage; +use crate::rpc::internal_error; +use crate::rpc::invalid_params; +use crate::rpc::invalid_request; + +const RETAINED_OUTPUT_BYTES_PER_PROCESS: usize = 1024 * 1024; +const EVENT_CHANNEL_CAPACITY: usize = 256; +const NOTIFICATION_CHANNEL_CAPACITY: usize = 256; +#[cfg(test)] +const EXITED_PROCESS_RETENTION: Duration = Duration::from_millis(25); +#[cfg(not(test))] +const EXITED_PROCESS_RETENTION: Duration = Duration::from_secs(30); + +#[derive(Clone)] +struct RetainedOutputChunk { + seq: u64, + stream: ExecOutputStream, + chunk: Vec, +} + +struct RunningProcess { + session: ExecCommandSession, + tty: bool, + output: VecDeque, + retained_bytes: usize, + next_seq: u64, + exit_code: Option, + output_notify: Arc, +} + +enum ProcessEntry { + Starting, + Running(Box), +} + +struct Inner { + notifications: RpcNotificationSender, + events_tx: broadcast::Sender, + processes: Mutex>, + initialize_requested: AtomicBool, + initialized: AtomicBool, +} + +#[derive(Clone)] +pub(crate) struct LocalProcess { + inner: Arc, +} + +impl Default for LocalProcess { + fn default() -> Self { + let (outgoing_tx, mut outgoing_rx) = + mpsc::channel::(NOTIFICATION_CHANNEL_CAPACITY); + tokio::spawn(async move { while outgoing_rx.recv().await.is_some() {} }); + Self::new(RpcNotificationSender::new(outgoing_tx)) + } +} + +impl LocalProcess { + pub(crate) fn new(notifications: RpcNotificationSender) -> Self { + Self { + inner: Arc::new(Inner { + notifications, + events_tx: broadcast::channel(EVENT_CHANNEL_CAPACITY).0, + processes: Mutex::new(HashMap::new()), + initialize_requested: AtomicBool::new(false), + initialized: AtomicBool::new(false), + }), + } + } + + pub(crate) async fn shutdown(&self) { + let remaining = { + let mut processes = self.inner.processes.lock().await; + processes + .drain() + .filter_map(|(_, process)| match process { + ProcessEntry::Starting => None, + ProcessEntry::Running(process) => Some(process), + }) + .collect::>() + }; + for process in remaining { + process.session.terminate(); + } + } + + pub(crate) fn initialize(&self) -> Result { + if self.inner.initialize_requested.swap(true, Ordering::SeqCst) { + return Err(invalid_request( + "initialize may only be sent once per connection".to_string(), + )); + } + Ok(InitializeResponse {}) + } + + pub(crate) fn initialized(&self) -> Result<(), String> { + if !self.inner.initialize_requested.load(Ordering::SeqCst) { + return Err("received `initialized` notification before `initialize`".into()); + } + self.inner.initialized.store(true, Ordering::SeqCst); + Ok(()) + } + + pub(crate) fn require_initialized_for( + &self, + method_family: &str, + ) -> Result<(), JSONRPCErrorError> { + if !self.inner.initialize_requested.load(Ordering::SeqCst) { + return Err(invalid_request(format!( + "client must call initialize before using {method_family} methods" + ))); + } + if !self.inner.initialized.load(Ordering::SeqCst) { + return Err(invalid_request(format!( + "client must send initialized before using {method_family} methods" + ))); + } + Ok(()) + } + + pub(crate) async fn exec(&self, params: ExecParams) -> Result { + self.require_initialized_for("exec")?; + let process_id = params.process_id.clone(); + + let (program, args) = params + .argv + .split_first() + .ok_or_else(|| invalid_params("argv must not be empty".to_string()))?; + + { + let mut process_map = self.inner.processes.lock().await; + if process_map.contains_key(&process_id) { + return Err(invalid_request(format!( + "process {process_id} already exists" + ))); + } + process_map.insert(process_id.clone(), ProcessEntry::Starting); + } + + let spawned_result = if params.tty { + codex_utils_pty::spawn_pty_process( + program, + args, + params.cwd.as_path(), + ¶ms.env, + ¶ms.arg0, + TerminalSize::default(), + ) + .await + } else { + codex_utils_pty::spawn_pipe_process_no_stdin( + program, + args, + params.cwd.as_path(), + ¶ms.env, + ¶ms.arg0, + ) + .await + }; + let spawned = match spawned_result { + Ok(spawned) => spawned, + Err(err) => { + let mut process_map = self.inner.processes.lock().await; + if matches!(process_map.get(&process_id), Some(ProcessEntry::Starting)) { + process_map.remove(&process_id); + } + return Err(internal_error(err.to_string())); + } + }; + + let output_notify = Arc::new(Notify::new()); + { + let mut process_map = self.inner.processes.lock().await; + process_map.insert( + process_id.clone(), + ProcessEntry::Running(Box::new(RunningProcess { + session: spawned.session, + tty: params.tty, + output: VecDeque::new(), + retained_bytes: 0, + next_seq: 1, + exit_code: None, + output_notify: Arc::clone(&output_notify), + })), + ); + } + + tokio::spawn(stream_output( + process_id.clone(), + if params.tty { + ExecOutputStream::Pty + } else { + ExecOutputStream::Stdout + }, + spawned.stdout_rx, + Arc::clone(&self.inner), + Arc::clone(&output_notify), + )); + tokio::spawn(stream_output( + process_id.clone(), + if params.tty { + ExecOutputStream::Pty + } else { + ExecOutputStream::Stderr + }, + spawned.stderr_rx, + Arc::clone(&self.inner), + Arc::clone(&output_notify), + )); + tokio::spawn(watch_exit( + process_id.clone(), + spawned.exit_rx, + Arc::clone(&self.inner), + output_notify, + )); + + Ok(ExecResponse { process_id }) + } + + pub(crate) async fn exec_read( + &self, + params: ReadParams, + ) -> Result { + self.require_initialized_for("exec")?; + let after_seq = params.after_seq.unwrap_or(0); + let max_bytes = params.max_bytes.unwrap_or(usize::MAX); + let wait = Duration::from_millis(params.wait_ms.unwrap_or(0)); + let deadline = tokio::time::Instant::now() + wait; + + loop { + let (response, output_notify) = { + let process_map = self.inner.processes.lock().await; + let process = process_map.get(¶ms.process_id).ok_or_else(|| { + invalid_request(format!("unknown process id {}", params.process_id)) + })?; + let ProcessEntry::Running(process) = process else { + return Err(invalid_request(format!( + "process id {} is starting", + params.process_id + ))); + }; + + let mut chunks = Vec::new(); + let mut total_bytes = 0; + let mut next_seq = process.next_seq; + for retained in process.output.iter().filter(|chunk| chunk.seq > after_seq) { + let chunk_len = retained.chunk.len(); + if !chunks.is_empty() && total_bytes + chunk_len > max_bytes { + break; + } + total_bytes += chunk_len; + chunks.push(ProcessOutputChunk { + seq: retained.seq, + stream: retained.stream, + chunk: retained.chunk.clone().into(), + }); + next_seq = retained.seq + 1; + if total_bytes >= max_bytes { + break; + } + } + + ( + ReadResponse { + chunks, + next_seq, + exited: process.exit_code.is_some(), + exit_code: process.exit_code, + }, + Arc::clone(&process.output_notify), + ) + }; + + if !response.chunks.is_empty() + || response.exited + || tokio::time::Instant::now() >= deadline + { + return Ok(response); + } + + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining.is_zero() { + return Ok(response); + } + let _ = tokio::time::timeout(remaining, output_notify.notified()).await; + } + } + + pub(crate) async fn exec_write( + &self, + params: WriteParams, + ) -> Result { + self.require_initialized_for("exec")?; + let writer_tx = { + let process_map = self.inner.processes.lock().await; + let process = process_map.get(¶ms.process_id).ok_or_else(|| { + invalid_request(format!("unknown process id {}", params.process_id)) + })?; + let ProcessEntry::Running(process) = process else { + return Err(invalid_request(format!( + "process id {} is starting", + params.process_id + ))); + }; + if !process.tty { + return Err(invalid_request(format!( + "stdin is closed for process {}", + params.process_id + ))); + } + process.session.writer_sender() + }; + + writer_tx + .send(params.chunk.into_inner()) + .await + .map_err(|_| internal_error("failed to write to process stdin".to_string()))?; + + Ok(WriteResponse { accepted: true }) + } + + pub(crate) async fn terminate_process( + &self, + params: TerminateParams, + ) -> Result { + self.require_initialized_for("exec")?; + let running = { + let process_map = self.inner.processes.lock().await; + match process_map.get(¶ms.process_id) { + Some(ProcessEntry::Running(process)) => { + if process.exit_code.is_some() { + return Ok(TerminateResponse { running: false }); + } + process.session.terminate(); + true + } + Some(ProcessEntry::Starting) | None => false, + } + }; + + Ok(TerminateResponse { running }) + } +} + +#[async_trait] +impl ExecProcess for LocalProcess { + async fn start(&self, params: ExecParams) -> Result { + self.exec(params).await.map_err(map_handler_error) + } + + async fn read(&self, params: ReadParams) -> Result { + self.exec_read(params).await.map_err(map_handler_error) + } + + async fn write( + &self, + process_id: &str, + chunk: Vec, + ) -> Result { + self.exec_write(WriteParams { + process_id: process_id.to_string(), + chunk: chunk.into(), + }) + .await + .map_err(map_handler_error) + } + + async fn terminate(&self, process_id: &str) -> Result { + self.terminate_process(TerminateParams { + process_id: process_id.to_string(), + }) + .await + .map_err(map_handler_error) + } + + fn subscribe_events(&self) -> broadcast::Receiver { + self.inner.events_tx.subscribe() + } +} + +fn map_handler_error(error: JSONRPCErrorError) -> ExecServerError { + ExecServerError::Server { + code: error.code, + message: error.message, + } +} + +async fn stream_output( + process_id: String, + stream: ExecOutputStream, + mut receiver: tokio::sync::mpsc::Receiver>, + inner: Arc, + output_notify: Arc, +) { + while let Some(chunk) = receiver.recv().await { + let notification = { + let mut processes = inner.processes.lock().await; + let Some(entry) = processes.get_mut(&process_id) else { + break; + }; + let ProcessEntry::Running(process) = entry else { + break; + }; + let seq = process.next_seq; + process.next_seq += 1; + process.retained_bytes += chunk.len(); + process.output.push_back(RetainedOutputChunk { + seq, + stream, + chunk: chunk.clone(), + }); + while process.retained_bytes > RETAINED_OUTPUT_BYTES_PER_PROCESS { + let Some(evicted) = process.output.pop_front() else { + break; + }; + process.retained_bytes = process.retained_bytes.saturating_sub(evicted.chunk.len()); + warn!( + "retained output cap exceeded for process {process_id}; dropping oldest output" + ); + } + ExecOutputDeltaNotification { + process_id: process_id.clone(), + stream, + chunk: chunk.into(), + } + }; + output_notify.notify_waiters(); + let _ = inner + .events_tx + .send(ExecServerEvent::OutputDelta(notification.clone())); + + if inner + .notifications + .notify(crate::protocol::EXEC_OUTPUT_DELTA_METHOD, ¬ification) + .await + .is_err() + { + break; + } + } +} + +async fn watch_exit( + process_id: String, + exit_rx: tokio::sync::oneshot::Receiver, + inner: Arc, + output_notify: Arc, +) { + let exit_code = exit_rx.await.unwrap_or(-1); + { + let mut processes = inner.processes.lock().await; + if let Some(ProcessEntry::Running(process)) = processes.get_mut(&process_id) { + process.exit_code = Some(exit_code); + } + } + output_notify.notify_waiters(); + let notification = ExecExitedNotification { + process_id: process_id.clone(), + exit_code, + }; + let _ = inner + .events_tx + .send(ExecServerEvent::Exited(notification.clone())); + if inner + .notifications + .notify(crate::protocol::EXEC_EXITED_METHOD, ¬ification) + .await + .is_err() + { + return; + } + + tokio::time::sleep(EXITED_PROCESS_RETENTION).await; + let mut processes = inner.processes.lock().await; + if matches!( + processes.get(&process_id), + Some(ProcessEntry::Running(process)) if process.exit_code == Some(exit_code) + ) { + processes.remove(&process_id); + } +} diff --git a/codex-rs/exec-server/src/process.rs b/codex-rs/exec-server/src/process.rs new file mode 100644 index 000000000..b2d743c32 --- /dev/null +++ b/codex-rs/exec-server/src/process.rs @@ -0,0 +1,35 @@ +use async_trait::async_trait; +use tokio::sync::broadcast; + +use crate::ExecServerError; +use crate::protocol::ExecExitedNotification; +use crate::protocol::ExecOutputDeltaNotification; +use crate::protocol::ExecParams; +use crate::protocol::ExecResponse; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; +use crate::protocol::TerminateResponse; +use crate::protocol::WriteResponse; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ExecServerEvent { + OutputDelta(ExecOutputDeltaNotification), + Exited(ExecExitedNotification), +} + +#[async_trait] +pub trait ExecProcess: Send + Sync { + async fn start(&self, params: ExecParams) -> Result; + + async fn read(&self, params: ReadParams) -> Result; + + async fn write( + &self, + process_id: &str, + chunk: Vec, + ) -> Result; + + async fn terminate(&self, process_id: &str) -> Result; + + fn subscribe_events(&self) -> broadcast::Receiver; +} diff --git a/codex-rs/exec-server/src/remote_process.rs b/codex-rs/exec-server/src/remote_process.rs new file mode 100644 index 000000000..c34c1fe6a --- /dev/null +++ b/codex-rs/exec-server/src/remote_process.rs @@ -0,0 +1,51 @@ +use async_trait::async_trait; +use tokio::sync::broadcast; + +use crate::ExecProcess; +use crate::ExecServerClient; +use crate::ExecServerError; +use crate::ExecServerEvent; +use crate::protocol::ExecParams; +use crate::protocol::ExecResponse; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; +use crate::protocol::TerminateResponse; +use crate::protocol::WriteResponse; + +#[derive(Clone)] +pub(crate) struct RemoteProcess { + client: ExecServerClient, +} + +impl RemoteProcess { + pub(crate) fn new(client: ExecServerClient) -> Self { + Self { client } + } +} + +#[async_trait] +impl ExecProcess for RemoteProcess { + async fn start(&self, params: ExecParams) -> Result { + self.client.exec(params).await + } + + async fn read(&self, params: ReadParams) -> Result { + self.client.read(params).await + } + + async fn write( + &self, + process_id: &str, + chunk: Vec, + ) -> Result { + self.client.write(process_id, chunk).await + } + + async fn terminate(&self, process_id: &str) -> Result { + self.client.terminate(process_id).await + } + + fn subscribe_events(&self) -> broadcast::Receiver { + self.client.event_receiver() + } +} diff --git a/codex-rs/exec-server/src/server.rs b/codex-rs/exec-server/src/server.rs index 4bd90dd9a..46de5aa49 100644 --- a/codex-rs/exec-server/src/server.rs +++ b/codex-rs/exec-server/src/server.rs @@ -1,5 +1,6 @@ mod file_system_handler; mod handler; +mod process_handler; mod processor; mod registry; mod transport; diff --git a/codex-rs/exec-server/src/server/handler.rs b/codex-rs/exec-server/src/server/handler.rs index 0ddd7ee50..0fe2588d0 100644 --- a/codex-rs/exec-server/src/server/handler.rs +++ b/codex-rs/exec-server/src/server/handler.rs @@ -1,10 +1,3 @@ -use std::collections::HashMap; -use std::collections::VecDeque; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering; -use std::time::Duration; - use codex_app_server_protocol::FsCopyParams; use codex_app_server_protocol::FsCopyResponse; use codex_app_server_protocol::FsCreateDirectoryParams; @@ -20,19 +13,10 @@ use codex_app_server_protocol::FsRemoveResponse; use codex_app_server_protocol::FsWriteFileParams; use codex_app_server_protocol::FsWriteFileResponse; use codex_app_server_protocol::JSONRPCErrorError; -use codex_utils_pty::ExecCommandSession; -use codex_utils_pty::TerminalSize; -use tokio::sync::Mutex; -use tokio::sync::Notify; -use tracing::warn; -use crate::protocol::ExecExitedNotification; -use crate::protocol::ExecOutputDeltaNotification; -use crate::protocol::ExecOutputStream; use crate::protocol::ExecParams; use crate::protocol::ExecResponse; use crate::protocol::InitializeResponse; -use crate::protocol::ProcessOutputChunk; use crate::protocol::ReadParams; use crate::protocol::ReadResponse; use crate::protocol::TerminateParams; @@ -40,336 +24,65 @@ use crate::protocol::TerminateResponse; use crate::protocol::WriteParams; use crate::protocol::WriteResponse; use crate::rpc::RpcNotificationSender; -use crate::rpc::internal_error; -use crate::rpc::invalid_params; -use crate::rpc::invalid_request; use crate::server::file_system_handler::FileSystemHandler; - -const RETAINED_OUTPUT_BYTES_PER_PROCESS: usize = 1024 * 1024; -#[cfg(test)] -const EXITED_PROCESS_RETENTION: Duration = Duration::from_millis(25); -#[cfg(not(test))] -const EXITED_PROCESS_RETENTION: Duration = Duration::from_secs(30); +use crate::server::process_handler::ProcessHandler; #[derive(Clone)] -struct RetainedOutputChunk { - seq: u64, - stream: ExecOutputStream, - chunk: Vec, -} - -struct RunningProcess { - session: ExecCommandSession, - tty: bool, - output: VecDeque, - retained_bytes: usize, - next_seq: u64, - exit_code: Option, - output_notify: Arc, -} - -enum ProcessEntry { - Starting, - Running(Box), -} - pub(crate) struct ExecServerHandler { - notifications: RpcNotificationSender, + process: ProcessHandler, file_system: FileSystemHandler, - processes: Arc>>, - initialize_requested: AtomicBool, - initialized: AtomicBool, } impl ExecServerHandler { pub(crate) fn new(notifications: RpcNotificationSender) -> Self { Self { - notifications, + process: ProcessHandler::new(notifications), file_system: FileSystemHandler::default(), - processes: Arc::new(Mutex::new(HashMap::new())), - initialize_requested: AtomicBool::new(false), - initialized: AtomicBool::new(false), } } pub(crate) async fn shutdown(&self) { - let remaining = { - let mut processes = self.processes.lock().await; - processes - .drain() - .filter_map(|(_, process)| match process { - ProcessEntry::Starting => None, - ProcessEntry::Running(process) => Some(process), - }) - .collect::>() - }; - for process in remaining { - process.session.terminate(); - } + self.process.shutdown().await; } pub(crate) fn initialize(&self) -> Result { - if self.initialize_requested.swap(true, Ordering::SeqCst) { - return Err(invalid_request( - "initialize may only be sent once per connection".to_string(), - )); - } - Ok(InitializeResponse {}) + self.process.initialize() } pub(crate) fn initialized(&self) -> Result<(), String> { - if !self.initialize_requested.load(Ordering::SeqCst) { - return Err("received `initialized` notification before `initialize`".into()); - } - self.initialized.store(true, Ordering::SeqCst); - Ok(()) - } - - fn require_initialized_for(&self, method_family: &str) -> Result<(), JSONRPCErrorError> { - if !self.initialize_requested.load(Ordering::SeqCst) { - return Err(invalid_request(format!( - "client must call initialize before using {method_family} methods" - ))); - } - if !self.initialized.load(Ordering::SeqCst) { - return Err(invalid_request(format!( - "client must send initialized before using {method_family} methods" - ))); - } - Ok(()) + self.process.initialized() } pub(crate) async fn exec(&self, params: ExecParams) -> Result { - self.require_initialized_for("exec")?; - let process_id = params.process_id.clone(); - - let (program, args) = params - .argv - .split_first() - .ok_or_else(|| invalid_params("argv must not be empty".to_string()))?; - - { - let mut process_map = self.processes.lock().await; - if process_map.contains_key(&process_id) { - return Err(invalid_request(format!( - "process {process_id} already exists" - ))); - } - process_map.insert(process_id.clone(), ProcessEntry::Starting); - } - - let spawned_result = if params.tty { - codex_utils_pty::spawn_pty_process( - program, - args, - params.cwd.as_path(), - ¶ms.env, - ¶ms.arg0, - TerminalSize::default(), - ) - .await - } else { - codex_utils_pty::spawn_pipe_process_no_stdin( - program, - args, - params.cwd.as_path(), - ¶ms.env, - ¶ms.arg0, - ) - .await - }; - let spawned = match spawned_result { - Ok(spawned) => spawned, - Err(err) => { - let mut process_map = self.processes.lock().await; - if matches!(process_map.get(&process_id), Some(ProcessEntry::Starting)) { - process_map.remove(&process_id); - } - return Err(internal_error(err.to_string())); - } - }; - - let output_notify = Arc::new(Notify::new()); - { - let mut process_map = self.processes.lock().await; - process_map.insert( - process_id.clone(), - ProcessEntry::Running(Box::new(RunningProcess { - session: spawned.session, - tty: params.tty, - output: VecDeque::new(), - retained_bytes: 0, - next_seq: 1, - exit_code: None, - output_notify: Arc::clone(&output_notify), - })), - ); - } - - tokio::spawn(stream_output( - process_id.clone(), - if params.tty { - ExecOutputStream::Pty - } else { - ExecOutputStream::Stdout - }, - spawned.stdout_rx, - self.notifications.clone(), - Arc::clone(&self.processes), - Arc::clone(&output_notify), - )); - tokio::spawn(stream_output( - process_id.clone(), - if params.tty { - ExecOutputStream::Pty - } else { - ExecOutputStream::Stderr - }, - spawned.stderr_rx, - self.notifications.clone(), - Arc::clone(&self.processes), - Arc::clone(&output_notify), - )); - tokio::spawn(watch_exit( - process_id.clone(), - spawned.exit_rx, - self.notifications.clone(), - Arc::clone(&self.processes), - output_notify, - )); - - Ok(ExecResponse { process_id }) + self.process.exec(params).await } pub(crate) async fn exec_read( &self, params: ReadParams, ) -> Result { - self.require_initialized_for("exec")?; - let after_seq = params.after_seq.unwrap_or(0); - let max_bytes = params.max_bytes.unwrap_or(usize::MAX); - let wait = Duration::from_millis(params.wait_ms.unwrap_or(0)); - let deadline = tokio::time::Instant::now() + wait; - - loop { - let (response, output_notify) = { - let process_map = self.processes.lock().await; - let process = process_map.get(¶ms.process_id).ok_or_else(|| { - invalid_request(format!("unknown process id {}", params.process_id)) - })?; - let ProcessEntry::Running(process) = process else { - return Err(invalid_request(format!( - "process id {} is starting", - params.process_id - ))); - }; - - let mut chunks = Vec::new(); - let mut total_bytes = 0; - let mut next_seq = process.next_seq; - for retained in process.output.iter().filter(|chunk| chunk.seq > after_seq) { - let chunk_len = retained.chunk.len(); - if !chunks.is_empty() && total_bytes + chunk_len > max_bytes { - break; - } - total_bytes += chunk_len; - chunks.push(ProcessOutputChunk { - seq: retained.seq, - stream: retained.stream, - chunk: retained.chunk.clone().into(), - }); - next_seq = retained.seq + 1; - if total_bytes >= max_bytes { - break; - } - } - - ( - ReadResponse { - chunks, - next_seq, - exited: process.exit_code.is_some(), - exit_code: process.exit_code, - }, - Arc::clone(&process.output_notify), - ) - }; - - if !response.chunks.is_empty() - || response.exited - || tokio::time::Instant::now() >= deadline - { - return Ok(response); - } - - let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); - if remaining.is_zero() { - return Ok(response); - } - let _ = tokio::time::timeout(remaining, output_notify.notified()).await; - } + self.process.exec_read(params).await } pub(crate) async fn exec_write( &self, params: WriteParams, ) -> Result { - self.require_initialized_for("exec")?; - let writer_tx = { - let process_map = self.processes.lock().await; - let process = process_map.get(¶ms.process_id).ok_or_else(|| { - invalid_request(format!("unknown process id {}", params.process_id)) - })?; - let ProcessEntry::Running(process) = process else { - return Err(invalid_request(format!( - "process id {} is starting", - params.process_id - ))); - }; - if !process.tty { - return Err(invalid_request(format!( - "stdin is closed for process {}", - params.process_id - ))); - } - process.session.writer_sender() - }; - - writer_tx - .send(params.chunk.into_inner()) - .await - .map_err(|_| internal_error("failed to write to process stdin".to_string()))?; - - Ok(WriteResponse { accepted: true }) + self.process.exec_write(params).await } pub(crate) async fn terminate( &self, params: TerminateParams, ) -> Result { - self.require_initialized_for("exec")?; - let running = { - let process_map = self.processes.lock().await; - match process_map.get(¶ms.process_id) { - Some(ProcessEntry::Running(process)) => { - if process.exit_code.is_some() { - return Ok(TerminateResponse { running: false }); - } - process.session.terminate(); - true - } - Some(ProcessEntry::Starting) | None => false, - } - }; - - Ok(TerminateResponse { running }) + self.process.terminate(params).await } pub(crate) async fn fs_read_file( &self, params: FsReadFileParams, ) -> Result { - self.require_initialized_for("filesystem")?; + self.process.require_initialized_for("filesystem")?; self.file_system.read_file(params).await } @@ -377,7 +90,7 @@ impl ExecServerHandler { &self, params: FsWriteFileParams, ) -> Result { - self.require_initialized_for("filesystem")?; + self.process.require_initialized_for("filesystem")?; self.file_system.write_file(params).await } @@ -385,7 +98,7 @@ impl ExecServerHandler { &self, params: FsCreateDirectoryParams, ) -> Result { - self.require_initialized_for("filesystem")?; + self.process.require_initialized_for("filesystem")?; self.file_system.create_directory(params).await } @@ -393,7 +106,7 @@ impl ExecServerHandler { &self, params: FsGetMetadataParams, ) -> Result { - self.require_initialized_for("filesystem")?; + self.process.require_initialized_for("filesystem")?; self.file_system.get_metadata(params).await } @@ -401,7 +114,7 @@ impl ExecServerHandler { &self, params: FsReadDirectoryParams, ) -> Result { - self.require_initialized_for("filesystem")?; + self.process.require_initialized_for("filesystem")?; self.file_system.read_directory(params).await } @@ -409,7 +122,7 @@ impl ExecServerHandler { &self, params: FsRemoveParams, ) -> Result { - self.require_initialized_for("filesystem")?; + self.process.require_initialized_for("filesystem")?; self.file_system.remove(params).await } @@ -417,101 +130,10 @@ impl ExecServerHandler { &self, params: FsCopyParams, ) -> Result { - self.require_initialized_for("filesystem")?; + self.process.require_initialized_for("filesystem")?; self.file_system.copy(params).await } } -async fn stream_output( - process_id: String, - stream: ExecOutputStream, - mut receiver: tokio::sync::mpsc::Receiver>, - notifications: RpcNotificationSender, - processes: Arc>>, - output_notify: Arc, -) { - while let Some(chunk) = receiver.recv().await { - let notification = { - let mut processes = processes.lock().await; - let Some(entry) = processes.get_mut(&process_id) else { - break; - }; - let ProcessEntry::Running(process) = entry else { - break; - }; - let seq = process.next_seq; - process.next_seq += 1; - process.retained_bytes += chunk.len(); - process.output.push_back(RetainedOutputChunk { - seq, - stream, - chunk: chunk.clone(), - }); - while process.retained_bytes > RETAINED_OUTPUT_BYTES_PER_PROCESS { - let Some(evicted) = process.output.pop_front() else { - break; - }; - process.retained_bytes = process.retained_bytes.saturating_sub(evicted.chunk.len()); - warn!( - "retained output cap exceeded for process {process_id}; dropping oldest output" - ); - } - ExecOutputDeltaNotification { - process_id: process_id.clone(), - stream, - chunk: chunk.into(), - } - }; - output_notify.notify_waiters(); - - if notifications - .notify(crate::protocol::EXEC_OUTPUT_DELTA_METHOD, ¬ification) - .await - .is_err() - { - break; - } - } -} - -async fn watch_exit( - process_id: String, - exit_rx: tokio::sync::oneshot::Receiver, - notifications: RpcNotificationSender, - processes: Arc>>, - output_notify: Arc, -) { - let exit_code = exit_rx.await.unwrap_or(-1); - { - let mut processes = processes.lock().await; - if let Some(ProcessEntry::Running(process)) = processes.get_mut(&process_id) { - process.exit_code = Some(exit_code); - } - } - output_notify.notify_waiters(); - if notifications - .notify( - crate::protocol::EXEC_EXITED_METHOD, - &ExecExitedNotification { - process_id: process_id.clone(), - exit_code, - }, - ) - .await - .is_err() - { - return; - } - - tokio::time::sleep(EXITED_PROCESS_RETENTION).await; - let mut processes = processes.lock().await; - if matches!( - processes.get(&process_id), - Some(ProcessEntry::Running(process)) if process.exit_code == Some(exit_code) - ) { - processes.remove(&process_id); - } -} - #[cfg(test)] mod tests; diff --git a/codex-rs/exec-server/src/server/process_handler.rs b/codex-rs/exec-server/src/server/process_handler.rs new file mode 100644 index 000000000..6f22890d3 --- /dev/null +++ b/codex-rs/exec-server/src/server/process_handler.rs @@ -0,0 +1,70 @@ +use codex_app_server_protocol::JSONRPCErrorError; + +use crate::local_process::LocalProcess; +use crate::protocol::ExecParams; +use crate::protocol::ExecResponse; +use crate::protocol::InitializeResponse; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; +use crate::protocol::TerminateParams; +use crate::protocol::TerminateResponse; +use crate::protocol::WriteParams; +use crate::protocol::WriteResponse; +use crate::rpc::RpcNotificationSender; + +#[derive(Clone)] +pub(crate) struct ProcessHandler { + process: LocalProcess, +} + +impl ProcessHandler { + pub(crate) fn new(notifications: RpcNotificationSender) -> Self { + Self { + process: LocalProcess::new(notifications), + } + } + + pub(crate) async fn shutdown(&self) { + self.process.shutdown().await; + } + + pub(crate) fn initialize(&self) -> Result { + self.process.initialize() + } + + pub(crate) fn initialized(&self) -> Result<(), String> { + self.process.initialized() + } + + pub(crate) fn require_initialized_for( + &self, + method_family: &str, + ) -> Result<(), JSONRPCErrorError> { + self.process.require_initialized_for(method_family) + } + + pub(crate) async fn exec(&self, params: ExecParams) -> Result { + self.process.exec(params).await + } + + pub(crate) async fn exec_read( + &self, + params: ReadParams, + ) -> Result { + self.process.exec_read(params).await + } + + pub(crate) async fn exec_write( + &self, + params: WriteParams, + ) -> Result { + self.process.exec_write(params).await + } + + pub(crate) async fn terminate( + &self, + params: TerminateParams, + ) -> Result { + self.process.terminate_process(params).await + } +} diff --git a/codex-rs/exec-server/tests/exec_process.rs b/codex-rs/exec-server/tests/exec_process.rs new file mode 100644 index 000000000..d72f83b95 --- /dev/null +++ b/codex-rs/exec-server/tests/exec_process.rs @@ -0,0 +1,87 @@ +#![cfg(unix)] + +mod common; + +use std::sync::Arc; + +use anyhow::Result; +use codex_exec_server::Environment; +use codex_exec_server::ExecParams; +use codex_exec_server::ExecProcess; +use codex_exec_server::ExecResponse; +use codex_exec_server::ReadParams; +use pretty_assertions::assert_eq; +use test_case::test_case; + +use common::exec_server::ExecServerHarness; +use common::exec_server::exec_server; + +struct ProcessContext { + process: Arc, + _server: Option, +} + +async fn create_process_context(use_remote: bool) -> Result { + if use_remote { + let server = exec_server().await?; + let environment = Environment::create(Some(server.websocket_url().to_string())).await?; + Ok(ProcessContext { + process: environment.get_executor(), + _server: Some(server), + }) + } else { + let environment = Environment::create(None).await?; + Ok(ProcessContext { + process: environment.get_executor(), + _server: None, + }) + } +} + +async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> { + let context = create_process_context(use_remote).await?; + let response = context + .process + .start(ExecParams { + process_id: "proc-1".to_string(), + argv: vec!["true".to_string()], + cwd: std::env::current_dir()?, + env: Default::default(), + tty: false, + arg0: None, + }) + .await?; + assert_eq!( + response, + ExecResponse { + process_id: "proc-1".to_string(), + } + ); + + let mut next_seq = 0; + loop { + let read = context + .process + .read(ReadParams { + process_id: "proc-1".to_string(), + after_seq: Some(next_seq), + max_bytes: None, + wait_ms: Some(100), + }) + .await?; + next_seq = read.next_seq; + if read.exited { + assert_eq!(read.exit_code, Some(0)); + break; + } + } + + Ok(()) +} + +#[test_case(false ; "local")] +#[test_case(true ; "remote")] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn exec_process_starts_and_exits(use_remote: bool) -> Result<()> { + assert_exec_process_starts_and_exits(use_remote).await +}