diff --git a/codex-rs/codex-api/src/endpoint/responses_websocket.rs b/codex-rs/codex-api/src/endpoint/responses_websocket.rs
index aa559e983..bdd32fbd5 100644
--- a/codex-rs/codex-api/src/endpoint/responses_websocket.rs
+++ b/codex-rs/codex-api/src/endpoint/responses_websocket.rs
@@ -175,6 +175,22 @@ pub struct ResponsesWebsocketConnection {
     telemetry: Option>,
 }
 
+/// Hand-written `Debug` output: the metadata fields are formatted as-is,
+/// while `stream` and `telemetry` are shown as fixed placeholders
+/// (`telemetry` still reveals whether a value is present).
+impl std::fmt::Debug for ResponsesWebsocketConnection {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ResponsesWebsocketConnection")
+            .field("stream", &"<stream>")
+            .field("idle_timeout", &self.idle_timeout)
+            .field("server_reasoning_included", &self.server_reasoning_included)
+            .field("models_etag", &self.models_etag)
+            .field("server_model", &self.server_model)
+            .field("telemetry", &self.telemetry.as_ref().map(|_| "<telemetry>"))
+            .finish()
+    }
+}
+
 impl ResponsesWebsocketConnection {
     fn new(
         stream: WsStream,
diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index 658e212fa..6f33910ce 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -28,6 +28,7 @@
 
 use std::collections::HashMap;
 use std::sync::Arc;
+use std::sync::Mutex as StdMutex;
 use std::sync::OnceLock;
 use std::sync::atomic::AtomicBool;
 use std::sync::atomic::Ordering;
@@ -123,6 +124,7 @@ struct ModelClientState {
     include_timing_metrics: bool,
     beta_features_header: Option,
     disable_websockets: AtomicBool,
+    cached_websocket_connection: StdMutex<Option<ApiWebSocketConnection>>, // handed to next session
 }
 
 /// Resolved API client setup for a single request attempt.
@@ -228,6 +230,7 @@ impl ModelClient {
                 include_timing_metrics,
                 beta_features_header,
                 disable_websockets: AtomicBool::new(false),
+                cached_websocket_connection: StdMutex::new(None),
             }),
         }
     }
@@ -239,13 +242,41 @@ impl ModelClient {
     pub fn new_session(&self) -> ModelClientSession {
         ModelClientSession {
             client: self.clone(),
-            connection: None,
+            // Start from the connection a dropped session handed back, if any.
+            connection: self.take_cached_websocket_connection(),
             websocket_last_request: None,
             websocket_last_response_rx: None,
             turn_state: Arc::new(OnceLock::new()),
         }
     }
 
+    /// Removes and returns the WebSocket connection currently held in the
+    /// client-level cache, leaving the cache empty.
+    ///
+    /// A poisoned lock is recovered with `PoisonError::into_inner` rather
+    /// than propagated as a panic.
+    fn take_cached_websocket_connection(&self) -> Option<ApiWebSocketConnection> {
+        self.state
+            .cached_websocket_connection
+            .lock()
+            .unwrap_or_else(std::sync::PoisonError::into_inner)
+            .take()
+    }
+
+    /// Stores `connection` in the client-level cache, replacing (and thereby
+    /// dropping) any connection that was already stored there.
+    ///
+    /// A poisoned lock is recovered with `PoisonError::into_inner` rather
+    /// than propagated as a panic.
+    fn store_cached_websocket_connection(&self, connection: ApiWebSocketConnection) {
+        let mut cached = self
+            .state
+            .cached_websocket_connection
+            .lock()
+            .unwrap_or_else(std::sync::PoisonError::into_inner);
+        *cached = Some(connection);
+    }
+
     /// Compacts the current conversation history using the Compact endpoint.
     ///
     /// This is a unary call (no streaming) that returns a new list of
@@ -452,6 +483,17 @@ impl ModelClient {
     }
 }
 
+/// Hands the session's WebSocket connection (if it has one) back to the
+/// owning `ModelClient` when the session is dropped, so that a later
+/// `new_session` call can pick it up.
+impl Drop for ModelClientSession {
+    fn drop(&mut self) {
+        if let Some(connection) = self.connection.take() {
+            self.client.store_cached_websocket_connection(connection);
+        }
+    }
+}
+
 impl ModelClientSession {
     fn activate_http_fallback(&self, websocket_enabled: bool) -> bool {
         websocket_enabled
diff --git a/codex-rs/core/tests/suite/client_websockets.rs b/codex-rs/core/tests/suite/client_websockets.rs
index 1c7c578ca..736e5f50f 100755
--- a/codex-rs/core/tests/suite/client_websockets.rs
+++ b/codex-rs/core/tests/suite/client_websockets.rs
@@ -118,6 +118,38 @@ async fn responses_websocket_preconnect_reuses_connection() {
     server.shutdown().await;
 }
 
+/// Streams one prompt in a session, drops that session, then streams a
+/// second prompt in a fresh session and checks the server-side counters.
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn responses_websocket_reuses_connection_after_session_drop() {
+    skip_if_no_network!();
+
+    let server = start_websocket_server(vec![vec![
+        vec![ev_response_created("resp-1"), ev_completed("resp-1")],
+        vec![ev_response_created("resp-2"), ev_completed("resp-2")],
+    ]])
+    .await;
+
+    let harness = websocket_harness(&server).await;
+    let prompt_one = prompt_with_input(vec![message_item("hello")]);
+    let prompt_two = prompt_with_input(vec![message_item("again")]);
+
+    // Scope the first session so it is dropped before the second is created.
+    {
+        let mut client_session = harness.client.new_session();
+        stream_until_complete(&mut client_session, &harness, &prompt_one).await;
+    }
+
+    let mut client_session = harness.client.new_session();
+    stream_until_complete(&mut client_session, &harness, &prompt_two).await;
+
+    // Expect one handshake and two entries on the single recorded connection.
+    assert_eq!(server.handshakes().len(), 1);
+    assert_eq!(server.single_connection().len(), 2);
+
+    server.shutdown().await;
+}
+
 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
 async fn responses_websocket_preconnect_is_reused_even_with_header_changes() {
     skip_if_no_network!();