diff --git a/codex-rs/codex-api/src/endpoint/responses_websocket.rs b/codex-rs/codex-api/src/endpoint/responses_websocket.rs index bc64f3bfb..9cd264505 100644 --- a/codex-rs/codex-api/src/endpoint/responses_websocket.rs +++ b/codex-rs/codex-api/src/endpoint/responses_websocket.rs @@ -16,8 +16,10 @@ use futures::StreamExt; use http::HeaderMap; use http::HeaderValue; use serde_json::Value; +use std::sync::Arc; use std::time::Duration; use tokio::net::TcpStream; +use tokio::sync::Mutex; use tokio::sync::mpsc; use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::WebSocketStream; @@ -31,6 +33,69 @@ use url::Url; type WsStream = WebSocketStream>; +pub struct ResponsesWebsocketConnection { + stream: Arc>>, + // TODO (pakrym): is this the right place for timeout? + idle_timeout: Duration, +} + +impl ResponsesWebsocketConnection { + fn new(stream: WsStream, idle_timeout: Duration) -> Self { + Self { + stream: Arc::new(Mutex::new(Some(stream))), + idle_timeout, + } + } + + pub async fn is_closed(&self) -> bool { + self.stream.lock().await.is_none() + } + + pub async fn stream_request( + &self, + request: ResponsesRequest, + ) -> Result { + if request.compression == Compression::Zstd { + warn!( + "request compression is not supported for websocket streaming; sending uncompressed payload" + ); + } + + let (tx_event, rx_event) = + mpsc::channel::>(1600); + let stream = Arc::clone(&self.stream); + let idle_timeout = self.idle_timeout; + let request_body = request.body; + + tokio::spawn(async move { + let mut guard = stream.lock().await; + let Some(ws_stream) = guard.as_mut() else { + let _ = tx_event + .send(Err(ApiError::Stream( + "websocket connection is closed".to_string(), + ))) + .await; + return; + }; + + if let Err(err) = run_websocket_response_stream( + ws_stream, + tx_event.clone(), + request_body, + idle_timeout, + ) + .await + { + let _ = ws_stream.close(None).await; + *guard = None; + let _ = tx_event.send(Err(err)).await; + } + }); + + Ok(ResponseStream { rx_event }) + } +} + pub struct ResponsesWebsocketClient { provider: Provider, auth: A, @@ -41,12 +106,22 @@ impl ResponsesWebsocketClient { Self { provider, auth } } - pub async fn stream_request( + pub async fn connect( &self, - request: ResponsesRequest, - ) -> Result { - self.stream(request.body, request.headers, request.compression) - .await + extra_headers: HeaderMap, + ) -> Result { + let ws_url = Url::parse(&self.provider.url_for_path("responses")) + .map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?; + + let mut headers = self.provider.headers.clone(); + headers.extend(extra_headers); + apply_auth_headers(&mut headers, &self.auth); + + let stream = connect_websocket(ws_url, headers).await?; + Ok(ResponsesWebsocketConnection::new( + stream, + self.provider.stream_idle_timeout, + )) } pub async fn stream_prompt( @@ -82,7 +157,8 @@ impl ResponsesWebsocketClient { .compression(compression) .build(&self.provider)?; - self.stream_request(request).await + let connection = self.connect(request.headers.clone()).await?; + connection.stream_request(request).await } pub async fn stream( @@ -91,41 +167,13 @@ impl ResponsesWebsocketClient { extra_headers: HeaderMap, compression: Compression, ) -> Result { - if compression == Compression::Zstd { - warn!( - "request compression is not supported for websocket streaming; sending uncompressed payload" - ); - } - - let ws_url = Url::parse(&self.provider.url_for_path("responses")) - .map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?; - let mut headers = self.provider.headers.clone(); - headers.extend(extra_headers); - apply_auth_headers(&mut headers, &self.auth); - - let connection = connect_websocket(ws_url, headers).await?; - - let (tx_event, rx_event) = - mpsc::channel::>(1600); - let idle_timeout = self.provider.stream_idle_timeout; - - // TODO (pakrym): surface rate limits - // TODO (pakrym): check models etags - - tokio::spawn(async move { - if let Err(err) = run_websocket_response_stream( - connection.stream, - tx_event.clone(), - body, - idle_timeout, - ) - .await - { - let _ = tx_event.send(Err(err)).await; - } - }); - - Ok(ResponseStream { rx_event }) + let request = ResponsesRequest { + body, + headers: extra_headers, + compression, + }; + let connection = self.connect(request.headers.clone()).await?; + connection.stream_request(request).await } } @@ -143,11 +191,7 @@ fn apply_auth_headers(headers: &mut HeaderMap, auth: &impl AuthProvider) { } } -struct WebSocketConnection { - stream: WsStream, -} - -async fn connect_websocket(url: Url, headers: HeaderMap) -> Result { +async fn connect_websocket(url: Url, headers: HeaderMap) -> Result { let mut request = url .clone() .into_client_request() @@ -157,7 +201,7 @@ async fn connect_websocket(url: Url, headers: HeaderMap) -> Result ApiError { @@ -185,7 +229,7 @@ fn map_ws_error(err: WsError, url: &Url) -> ApiError { } async fn run_websocket_response_stream( - mut ws_stream: WsStream, + ws_stream: &mut WsStream, tx_event: mpsc::Sender>, request_body: Value, idle_timeout: Duration, @@ -193,7 +237,6 @@ async fn run_websocket_response_stream( let request_text = match serde_json::to_string(&request_body) { Ok(text) => text, Err(err) => { - let _ = ws_stream.close(None).await; return Err(ApiError::Stream(format!( "failed to encode websocket request: {err}" ))); @@ -201,7 +244,6 @@ async fn run_websocket_response_stream( }; if let Err(err) = ws_stream.send(Message::Text(request_text)).await { - let _ = ws_stream.close(None).await; return Err(ApiError::Stream(format!( "failed to send websocket request: {err}" ))); @@ -214,17 +256,14 @@ async fn run_websocket_response_stream( let message = match response { Ok(Some(Ok(msg))) => msg, Ok(Some(Err(err))) => { - let _ = ws_stream.close(None).await; return Err(ApiError::Stream(err.to_string())); } Ok(None) => { - let _ = ws_stream.close(None).await; return Err(ApiError::Stream( "stream closed before response.completed".into(), )); } Err(err) => { - let _ = ws_stream.close(None).await; return Err(err); } }; @@ -249,24 +288,20 @@ async fn run_websocket_response_stream( } Ok(None) => {} Err(error) => { - let _ = ws_stream.close(None).await; return Err(error.into_api_error()); } } } Message::Binary(_) => { - let _ = ws_stream.close(None).await; return Err(ApiError::Stream("unexpected binary websocket event".into())); } Message::Ping(payload) => { if ws_stream.send(Message::Pong(payload)).await.is_err() { - let _ = ws_stream.close(None).await; return Err(ApiError::Stream("websocket ping failed".into())); } } Message::Pong(_) => {} Message::Close(_) => { - let _ = ws_stream.close(None).await; return Err(ApiError::Stream( "websocket closed before response.completed".into(), )); @@ -275,6 +310,5 @@ async fn run_websocket_response_stream( } } - let _ = ws_stream.close(None).await; Ok(()) } diff --git a/codex-rs/codex-api/src/lib.rs b/codex-rs/codex-api/src/lib.rs index 4e82b874b..0128efc35 100644 --- a/codex-rs/codex-api/src/lib.rs +++ b/codex-rs/codex-api/src/lib.rs @@ -26,6 +26,7 @@ pub use crate::endpoint::models::ModelsClient; pub use crate::endpoint::responses::ResponsesClient; pub use crate::endpoint::responses::ResponsesOptions; pub use crate::endpoint::responses_websocket::ResponsesWebsocketClient; +pub use crate::endpoint::responses_websocket::ResponsesWebsocketConnection; pub use crate::error::ApiError; pub use crate::provider::Provider; pub use crate::provider::WireApi; diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index eb866527c..e6dc69b15 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use crate::api_bridge::CoreAuthProvider; use crate::api_bridge::auth_provider_from_auth; use crate::api_bridge::map_api_error; use crate::auth::UnauthorizedRecovery; @@ -13,7 +14,10 @@ use codex_api::ReqwestTransport; use codex_api::ResponseStream as ApiResponseStream; use codex_api::ResponsesClient as ApiResponsesClient; use codex_api::ResponsesOptions as ApiResponsesOptions; +use codex_api::ResponsesRequest; +use codex_api::ResponsesRequestBuilder; use codex_api::ResponsesWebsocketClient as ApiWebSocketResponsesClient; +use codex_api::ResponsesWebsocketConnection as ApiWebSocketConnection; use codex_api::SseTelemetry; use codex_api::TransportError; use codex_api::common::Reasoning; @@ -76,9 +80,9 @@ pub struct ModelClient { state: Arc, } -#[derive(Debug, Clone)] pub struct ModelClientSession { state: Arc, + connection: Option, } #[allow(clippy::too_many_arguments)] @@ -112,6 +116,7 @@ impl ModelClient { pub fn new_session(&self) -> ModelClientSession { ModelClientSession { state: Arc::clone(&self.state), + connection: None, } } } @@ -228,7 +233,7 @@ impl ModelClientSession { /// /// For Chat providers, the underlying stream is optionally aggregated /// based on the `show_raw_agent_reasoning` flag in the config. - pub async fn stream(&self, prompt: &Prompt) -> Result { + pub async fn stream(&mut self, prompt: &Prompt) -> Result { match self.state.provider.wire_api { WireApi::Responses => self.stream_responses_api(prompt).await, WireApi::ResponsesWebsocket => self.stream_responses_websocket(prompt).await, @@ -315,6 +320,67 @@ impl ModelClientSession { } } + fn build_responses_websocket_request( + &self, + api_provider: &codex_api::Provider, + api_prompt: &ApiPrompt, + options: ApiResponsesOptions, + ) -> Result { + let ApiResponsesOptions { + reasoning, + include, + prompt_cache_key, + text, + store_override, + conversation_id, + session_source, + extra_headers, + compression, + } = options; + + ResponsesRequestBuilder::new( + &self.state.model_info.slug, + &api_prompt.instructions, + &api_prompt.input, + ) + .tools(&api_prompt.tools) + .parallel_tool_calls(api_prompt.parallel_tool_calls) + .reasoning(reasoning) + .include(include) + .prompt_cache_key(prompt_cache_key) + .text(text) + .conversation(conversation_id) + .session_source(session_source) + .store_override(store_override) + .extra_headers(extra_headers) + .compression(compression) + .build(api_provider) + .map_err(map_api_error) + } + + async fn websocket_connection( + &mut self, + api_provider: codex_api::Provider, + api_auth: CoreAuthProvider, + headers: ApiHeaderMap, + ) -> std::result::Result<&ApiWebSocketConnection, ApiError> { + let needs_new = match self.connection.as_ref() { + Some(conn) => conn.is_closed().await, + None => true, + }; + + if needs_new { + let new_conn = ApiWebSocketResponsesClient::new(api_provider, api_auth) + .connect(headers) + .await?; + self.connection = Some(new_conn); + } + + self.connection.as_ref().ok_or(ApiError::Stream( + "websocket connection is unavailable".to_string(), + )) + } + fn responses_request_compression(&self, auth: Option<&crate::auth::CodexAuth>) -> Compression { if self .state @@ -447,7 +513,7 @@ impl ModelClientSession { } /// Streams a turn via the Responses API over WebSocket transport. - async fn stream_responses_websocket(&self, prompt: &Prompt) -> Result { + async fn stream_responses_websocket(&mut self, prompt: &Prompt) -> Result { let auth_manager = self.state.auth_manager.clone(); let api_prompt = self.build_responses_request(prompt)?; @@ -467,16 +533,18 @@ impl ModelClientSession { let compression = self.responses_request_compression(auth.as_ref()); let options = self.build_responses_options(prompt, compression); - let client = ApiWebSocketResponsesClient::new(api_provider, api_auth); + let request = + self.build_responses_websocket_request(&api_provider, &api_prompt, options)?; - let stream_result = client - .stream_prompt(&self.state.model_info.slug, &api_prompt, options) - .await; - - match stream_result { - Ok(stream) => { - return Ok(map_response_stream(stream, self.state.otel_manager.clone())); - } + let connection = match self + .websocket_connection( + api_provider.clone(), + api_auth.clone(), + request.headers.clone(), + ) + .await + { + Ok(connection) => connection, Err(ApiError::Transport(TransportError::Http { status, .. })) if status == StatusCode::UNAUTHORIZED => { @@ -484,7 +552,17 @@ impl ModelClientSession { continue; } Err(err) => return Err(map_api_error(err)), - } + }; + + let stream_result = connection + .stream_request(request) + .await + .map_err(map_api_error)?; + + return Ok(map_response_stream( + stream_result, + self.state.otel_manager.clone(), + )); } } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 38687cc48..3ecb82ec5 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -2673,7 +2673,7 @@ async fn run_model_turn( output_schema: turn_context.final_output_json_schema.clone(), }; - let client_session = turn_context.client.new_session(); + let mut client_session = turn_context.client.new_session(); let mut retries = 0; loop { @@ -2681,7 +2681,7 @@ async fn run_model_turn( Arc::clone(&router), Arc::clone(&sess), Arc::clone(&turn_context), - &client_session, + &mut client_session, Arc::clone(&turn_diff_tracker), &prompt, cancellation_token.child_token(), @@ -2773,7 +2773,7 @@ async fn try_run_turn( router: Arc, sess: Arc, turn_context: Arc, - client_session: &ModelClientSession, + client_session: &mut ModelClientSession, turn_diff_tracker: SharedTurnDiffTracker, prompt: &Prompt, cancellation_token: CancellationToken, diff --git a/codex-rs/core/src/compact.rs b/codex-rs/core/src/compact.rs index 2a518dfeb..120e701bb 100644 --- a/codex-rs/core/src/compact.rs +++ b/codex-rs/core/src/compact.rs @@ -297,7 +297,7 @@ async fn drain_to_completed( turn_context: &TurnContext, prompt: &Prompt, ) -> CodexResult<()> { - let client_session = turn_context.client.new_session(); + let mut client_session = turn_context.client.new_session(); let mut stream = client_session.stream(prompt).await?; loop { let maybe_event = stream.next().await; diff --git a/codex-rs/core/tests/chat_completions_payload.rs b/codex-rs/core/tests/chat_completions_payload.rs index c8fef336e..23b50823f 100644 --- a/codex-rs/core/tests/chat_completions_payload.rs +++ b/codex-rs/core/tests/chat_completions_payload.rs @@ -88,7 +88,7 @@ async fn run_request(input: Vec) -> Value { SessionSource::Exec, ); - let client = ModelClient::new( + let mut client_session = ModelClient::new( Arc::clone(&config), None, model_info, @@ -104,7 +104,7 @@ async fn run_request(input: Vec) -> Value { let mut prompt = Prompt::default(); prompt.input = input; - let mut stream = match client.stream(&prompt).await { + let mut stream = match client_session.stream(&prompt).await { Ok(s) => s, Err(e) => panic!("stream chat failed: {e}"), }; diff --git a/codex-rs/core/tests/chat_completions_sse.rs b/codex-rs/core/tests/chat_completions_sse.rs index 157475580..f6d7eb24f 100644 --- a/codex-rs/core/tests/chat_completions_sse.rs +++ b/codex-rs/core/tests/chat_completions_sse.rs @@ -89,7 +89,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec { SessionSource::Exec, ); - let client = ModelClient::new( + let mut client = ModelClient::new( Arc::clone(&config), None, model_info, diff --git a/codex-rs/core/tests/responses_headers.rs b/codex-rs/core/tests/responses_headers.rs index 3efbb2b7e..8be6e3634 100644 --- a/codex-rs/core/tests/responses_headers.rs +++ b/codex-rs/core/tests/responses_headers.rs @@ -81,7 +81,7 @@ async fn responses_stream_includes_subagent_header_on_review() { session_source.clone(), ); - let client = ModelClient::new( + let mut client_session = ModelClient::new( Arc::clone(&config), None, model_info, @@ -103,7 +103,7 @@ async fn responses_stream_includes_subagent_header_on_review() { }], }]; - let mut stream = client.stream(&prompt).await.expect("stream failed"); + let mut stream = client_session.stream(&prompt).await.expect("stream failed"); while let Some(event) = stream.next().await { if matches!(event, Ok(ResponseEvent::Completed { .. })) { break; @@ -177,7 +177,7 @@ async fn responses_stream_includes_subagent_header_on_other() { session_source.clone(), ); - let client = ModelClient::new( + let mut client_session = ModelClient::new( Arc::clone(&config), None, model_info, @@ -199,7 +199,7 @@ async fn responses_stream_includes_subagent_header_on_other() { }], }]; - let mut stream = client.stream(&prompt).await.expect("stream failed"); + let mut stream = client_session.stream(&prompt).await.expect("stream failed"); while let Some(event) = stream.next().await { if matches!(event, Ok(ResponseEvent::Completed { .. })) { break; @@ -271,7 +271,7 @@ async fn responses_respects_model_info_overrides_from_config() { session_source.clone(), ); - let client = ModelClient::new( + let mut client = ModelClient::new( Arc::clone(&config), None, model_info, diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs index 458b355f1..ecb8dcbbf 100644 --- a/codex-rs/core/tests/suite/client.rs +++ b/codex-rs/core/tests/suite/client.rs @@ -1171,7 +1171,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() { SessionSource::Exec, ); - let client = ModelClient::new( + let mut client = ModelClient::new( Arc::clone(&config), None, model_info, diff --git a/codex-rs/core/tests/suite/websocket.rs b/codex-rs/core/tests/suite/websocket.rs index fc15c8ae8..6a5f99acb 100644 --- a/codex-rs/core/tests/suite/websocket.rs +++ b/codex-rs/core/tests/suite/websocket.rs @@ -1,7 +1,9 @@ +#![allow(clippy::expect_used, clippy::unwrap_used)] use codex_core::AuthManager; use codex_core::CodexAuth; use codex_core::ContentItem; use codex_core::ModelClient; +use codex_core::ModelClientSession; use codex_core::ModelProviderInfo; use codex_core::Prompt; use codex_core::ResponseEvent; @@ -11,23 +13,97 @@ use codex_core::models_manager::manager::ModelsManager; use codex_core::protocol::SessionSource; use codex_otel::OtelManager; use codex_protocol::ThreadId; +use codex_protocol::config_types::ReasoningSummary; use core_test_support::load_default_config_for_test; +use core_test_support::responses::WebSocketTestServer; use core_test_support::responses::ev_completed; use core_test_support::responses::ev_response_created; use core_test_support::responses::start_websocket_server; +use core_test_support::skip_if_no_network; use futures::StreamExt; +use pretty_assertions::assert_eq; use std::sync::Arc; use tempfile::TempDir; +const MODEL: &str = "gpt-5.2-codex"; + +struct WebsocketTestHarness { + _codex_home: TempDir, + client: ModelClient, +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn responses_websocket_streams_request() { + skip_if_no_network!(); + let server = start_websocket_server(vec![vec![vec![ ev_response_created("resp-1"), ev_completed("resp-1"), ]]]) .await; - let provider = ModelProviderInfo { + let harness = websocket_harness(&server).await; + let mut session = harness.client.new_session(); + let mut prompt = Prompt::default(); + prompt.input = vec![ResponseItem::Message { + id: None, + role: "user".into(), + content: vec![ContentItem::InputText { + text: "hello".into(), + }], + }]; + + stream_until_complete(&mut session, &prompt).await; + + let connection = server.single_connection(); + assert_eq!(connection.len(), 1); + let body = connection.first().expect("missing request").body_json(); + + assert_eq!(body["model"].as_str(), Some(MODEL)); + assert_eq!(body["stream"], serde_json::Value::Bool(true)); + assert_eq!(body["input"].as_array().map(Vec::len), Some(1)); + + server.shutdown().await; +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn responses_websocket_reuses_connection() { + skip_if_no_network!(); + + let server = start_websocket_server(vec![vec![ + vec![ev_response_created("resp-1"), ev_completed("resp-1")], + vec![ev_response_created("resp-2"), ev_completed("resp-2")], + ]]) + .await; + + let harness = websocket_harness(&server).await; + let mut session = harness.client.new_session(); + let mut prompt = Prompt::default(); + prompt.input = vec![ResponseItem::Message { + id: None, + role: "user".into(), + content: vec![ContentItem::InputText { + text: "hello".into(), + }], + }]; + + for _ in 0..2 { + stream_until_complete(&mut session, &prompt).await; + } + + let connection = server.single_connection(); + assert_eq!(connection.len(), 2); + let body = connection.first().expect("missing request").body_json(); + + assert_eq!(body["model"].as_str(), Some(MODEL)); + assert_eq!(body["stream"], serde_json::Value::Bool(true)); + assert_eq!(body["input"].as_array().map(Vec::len), Some(1)); + + server.shutdown().await; +} + +fn websocket_provider(server: &WebSocketTestServer) -> ModelProviderInfo { + ModelProviderInfo { name: "mock-ws".into(), base_url: Some(format!("{}/v1", server.uri())), env_key: None, @@ -41,23 +117,21 @@ async fn responses_websocket_streams_request() { stream_max_retries: Some(0), stream_idle_timeout_ms: Some(5_000), requires_openai_auth: false, - }; + } +} +async fn websocket_harness(server: &WebSocketTestServer) -> WebsocketTestHarness { + let provider = websocket_provider(server); let codex_home = TempDir::new().unwrap(); let mut config = load_default_config_for_test(&codex_home).await; - config.model_provider_id = provider.name.clone(); - config.model_provider = provider.clone(); - let effort = config.model_reasoning_effort; - let summary = config.model_reasoning_summary; - let model = ModelsManager::get_model_offline(config.model.as_deref()); - config.model = Some(model.clone()); + config.model = Some(MODEL.to_string()); let config = Arc::new(config); - let model_info = ModelsManager::construct_model_info_offline(model.as_str(), &config); + let model_info = ModelsManager::construct_model_info_offline(MODEL, &config); let conversation_id = ThreadId::new(); let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); let otel_manager = OtelManager::new( conversation_id, - model.as_str(), + MODEL, model_info.slug.as_str(), None, Some("test@test.com".to_string()), @@ -66,31 +140,27 @@ async fn responses_websocket_streams_request() { "test".to_string(), SessionSource::Exec, ); - let client = ModelClient::new( Arc::clone(&config), None, model_info, otel_manager, - provider, - effort, - summary, + provider.clone(), + None, + ReasoningSummary::Auto, conversation_id, SessionSource::Exec, - ) - .new_session(); + ); - let mut prompt = Prompt::default(); - prompt.input = vec![ResponseItem::Message { - id: None, - role: "user".into(), - content: vec![ContentItem::InputText { - text: "hello".into(), - }], - }]; + WebsocketTestHarness { + _codex_home: codex_home, + client, + } +} - let mut stream = client - .stream(&prompt) +async fn stream_until_complete(session: &mut ModelClientSession, prompt: &Prompt) { + let mut stream = session + .stream(prompt) .await .expect("websocket stream failed"); @@ -99,14 +169,4 @@ async fn responses_websocket_streams_request() { break; } } - - let connection = server.single_connection(); - assert_eq!(connection.len(), 1); - let request = connection.first().cloned().unwrap(); - let body = request.body_json(); - assert_eq!(body["model"].as_str(), Some(model.as_str())); - assert_eq!(body["stream"], serde_json::Value::Bool(true)); - assert_eq!(body["input"].as_array().map(Vec::len), Some(1)); - - server.shutdown().await; }