From 6d08298f4e54f8a4ac2cda4171bcb2c954605fcd Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Sat, 7 Feb 2026 21:06:33 -0800 Subject: [PATCH] Fallback to HTTP on UPGRADE_REQUIRED (#10824) Allow the server to trigger a connection downgrade in case the protocol changes in incompatible ways. --- codex-rs/core/src/client.rs | 62 ++++++++++------ codex-rs/core/src/codex.rs | 19 +++-- .../core/tests/suite/websocket_fallback.rs | 71 +++++++++++++++++-- 3 files changed, 114 insertions(+), 38 deletions(-) diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 3cecec337..04ab5748f 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -225,6 +225,11 @@ pub struct ModelClientSession { turn_state: Arc>, } +enum WebsocketStreamOutcome { + Stream(ResponseStream), + FallbackToHttp, +} + impl ModelClient { #[allow(clippy::too_many_arguments)] /// Creates a new session-scoped `ModelClient`. @@ -926,7 +931,7 @@ impl ModelClientSession { effort: Option, summary: ReasoningSummaryConfig, turn_metadata_header: Option<&str>, - ) -> Result { + ) -> Result { let auth_manager = self.client.state.auth_manager.clone(); let api_prompt = Self::build_responses_request(prompt)?; @@ -957,6 +962,11 @@ impl ModelClientSession { .await { Ok(_) => {} + Err(ApiError::Transport(TransportError::Http { status, .. })) + if status == StatusCode::UPGRADE_REQUIRED => + { + return Ok(WebsocketStreamOutcome::FallbackToHttp); + } Err(ApiError::Transport( unauthorized_transport @ TransportError::Http { status, .. }, )) if status == StatusCode::UNAUTHORIZED => { @@ -992,7 +1002,10 @@ impl ModelClientSession { } }); - return Ok(map_response_stream(stream_result, otel_manager.clone())); + return Ok(WebsocketStreamOutcome::Stream(map_response_stream( + stream_result, + otel_manager.clone(), + ))); } } @@ -1036,26 +1049,33 @@ impl ModelClientSession { self.client.responses_websocket_enabled() && !self.client.disable_websockets(); if websocket_enabled { - self.stream_responses_websocket( - prompt, - model_info, - otel_manager, - effort, - summary, - turn_metadata_header, - ) - .await - } else { - self.stream_responses_api( - prompt, - model_info, - otel_manager, - effort, - summary, - turn_metadata_header, - ) - .await + match self + .stream_responses_websocket( + prompt, + model_info, + otel_manager, + effort, + summary, + turn_metadata_header, + ) + .await? + { + WebsocketStreamOutcome::Stream(stream) => return Ok(stream), + WebsocketStreamOutcome::FallbackToHttp => { + self.try_switch_fallback_transport(otel_manager); + } + } } + + self.stream_responses_api( + prompt, + model_info, + otel_manager, + effort, + summary, + turn_metadata_header, + ) + .await } } } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index bac87686a..3840aa708 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -558,9 +558,8 @@ impl TurnContext { } async fn build_turn_metadata_header(&self) -> Option { - let cwd = self.cwd.clone(); self.turn_metadata_header - .get_or_init(|| async { build_turn_metadata_header(cwd).await }) + .get_or_init(|| async { build_turn_metadata_header(self.cwd.clone()).await }) .await .clone() } @@ -1098,14 +1097,14 @@ impl Session { // Warm a websocket in the background so the first turn can reuse it. // This performs only connection setup; user input is still sent later via response.create // when submit_turn() runs. - sess.services.model_client.pre_establish_connection( - sess.services.otel_manager.clone(), - resolve_turn_metadata_header_with_timeout( - build_turn_metadata_header(session_configuration.cwd.clone()), - None, - ) - .boxed(), - ); + let turn_metadata_header = resolve_turn_metadata_header_with_timeout( + build_turn_metadata_header(session_configuration.cwd.clone()), + None, + ) + .boxed(); + sess.services + .model_client + .pre_establish_connection(sess.services.otel_manager.clone(), turn_metadata_header); // Dispatch the SessionConfiguredEvent first and then report any errors. // If resuming, include converted initial messages in the payload so UIs can render them immediately. diff --git a/codex-rs/core/tests/suite/websocket_fallback.rs b/codex-rs/core/tests/suite/websocket_fallback.rs index b5ac66d1a..843f3168f 100644 --- a/codex-rs/core/tests/suite/websocket_fallback.rs +++ b/codex-rs/core/tests/suite/websocket_fallback.rs @@ -9,13 +9,23 @@ use core_test_support::responses::sse; use core_test_support::skip_if_no_network; use core_test_support::test_codex::test_codex; use pretty_assertions::assert_eq; +use wiremock::Mock; +use wiremock::ResponseTemplate; use wiremock::http::Method; +use wiremock::matchers::method; +use wiremock::matchers::path_regex; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result<()> { +async fn websocket_fallback_switches_to_http_on_upgrade_required_connect() -> Result<()> { skip_if_no_network!(Ok(())); let server = responses::start_mock_server().await; + Mock::given(method("GET")) + .and(path_regex(".*/responses$")) + .respond_with(ResponseTemplate::new(426)) + .mount(&server) + .await; + let response_mock = mount_sse_once( &server, sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]), @@ -28,7 +38,9 @@ async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result config.model_provider.base_url = Some(base_url); config.model_provider.wire_api = codex_core::WireApi::Responses; config.features.enable(Feature::ResponsesWebsockets); - config.model_provider.stream_max_retries = Some(0); + // If we don't treat 426 specially, the sampling loop would retry the WebSocket + // handshake before switching to the HTTP transport. + config.model_provider.stream_max_retries = Some(2); config.model_provider.request_max_retries = Some(0); } }); @@ -56,6 +68,51 @@ async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + let response_mock = mount_sse_once( + &server, + sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]), + ) + .await; + + let mut builder = test_codex().with_config({ + let base_url = format!("{}/v1", server.uri()); + move |config| { + config.model_provider.base_url = Some(base_url); + config.model_provider.wire_api = codex_core::WireApi::Responses; + config.features.enable(Feature::ResponsesWebsockets); + config.model_provider.stream_max_retries = Some(2); + config.model_provider.request_max_retries = Some(0); + } + }); + let test = builder.build(&server).await?; + + test.submit_turn("hello").await?; + + let requests = server.received_requests().await.unwrap_or_default(); + let websocket_attempts = requests + .iter() + .filter(|req| req.method == Method::GET && req.url.path().ends_with("/responses")) + .count(); + let http_attempts = requests + .iter() + .filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses")) + .count(); + + // One websocket attempt comes from startup preconnect. + // The first turn then makes 3 websocket stream attempts (initial try + 2 retries), + // after which fallback activates and the request is replayed over HTTP. + assert_eq!(websocket_attempts, 4); + assert_eq!(http_attempts, 1); + assert_eq!(response_mock.requests().len(), 1); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn websocket_fallback_is_sticky_across_turns() -> Result<()> { skip_if_no_network!(Ok(())); @@ -76,7 +133,7 @@ async fn websocket_fallback_is_sticky_across_turns() -> Result<()> { config.model_provider.base_url = Some(base_url); config.model_provider.wire_api = codex_core::WireApi::Responses; config.features.enable(Feature::ResponsesWebsockets); - config.model_provider.stream_max_retries = Some(0); + config.model_provider.stream_max_retries = Some(2); config.model_provider.request_max_retries = Some(0); } }); @@ -95,10 +152,10 @@ async fn websocket_fallback_is_sticky_across_turns() -> Result<()> { .filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses")) .count(); - // The first turn issues exactly two websocket attempts (startup preconnect + first stream - // attempt). After fallback becomes sticky, subsequent turns stay on HTTP. This mirrors the - // retry-budget tradeoff documented in [`codex_core::client`] module docs. - assert_eq!(websocket_attempts, 2); + // WebSocket attempts all happen on the first turn: + // 1 startup preconnect + 3 stream attempts (initial try + 2 retries) before fallback. + // Fallback is sticky, so the second turn stays on HTTP and adds no websocket attempts. + assert_eq!(websocket_attempts, 4); assert_eq!(http_attempts, 2); assert_eq!(response_mock.requests().len(), 2);