Fall back to HTTP on UPGRADE_REQUIRED (#10824)
Allow the server to trigger a connection downgrade in case the protocol changes in incompatible ways.
This commit is contained in:
parent
d68e9c0f19
commit
6d08298f4e
3 changed files with 114 additions and 38 deletions
|
|
@ -225,6 +225,11 @@ pub struct ModelClientSession {
|
|||
turn_state: Arc<OnceLock<String>>,
|
||||
}
|
||||
|
||||
/// Result of attempting to stream a response over the WebSocket transport.
///
/// Lets the WebSocket streaming path tell its caller either "here is the stream"
/// or "do not use WebSockets for this request" without surfacing an error.
enum WebsocketStreamOutcome {
|
||||
/// The WebSocket request went through; carries the resulting response stream.
Stream(ResponseStream),
|
||||
/// The connection attempt was rejected with HTTP 426 (`UPGRADE_REQUIRED`).
/// The caller reacts by switching to the fallback transport and issuing the
/// request through `stream_responses_api` instead.
FallbackToHttp,
|
||||
}
|
||||
|
||||
impl ModelClient {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Creates a new session-scoped `ModelClient`.
|
||||
|
|
@ -926,7 +931,7 @@ impl ModelClientSession {
|
|||
effort: Option<ReasoningEffortConfig>,
|
||||
summary: ReasoningSummaryConfig,
|
||||
turn_metadata_header: Option<&str>,
|
||||
) -> Result<ResponseStream> {
|
||||
) -> Result<WebsocketStreamOutcome> {
|
||||
let auth_manager = self.client.state.auth_manager.clone();
|
||||
let api_prompt = Self::build_responses_request(prompt)?;
|
||||
|
||||
|
|
@ -957,6 +962,11 @@ impl ModelClientSession {
|
|||
.await
|
||||
{
|
||||
Ok(_) => {}
|
||||
Err(ApiError::Transport(TransportError::Http { status, .. }))
|
||||
if status == StatusCode::UPGRADE_REQUIRED =>
|
||||
{
|
||||
return Ok(WebsocketStreamOutcome::FallbackToHttp);
|
||||
}
|
||||
Err(ApiError::Transport(
|
||||
unauthorized_transport @ TransportError::Http { status, .. },
|
||||
)) if status == StatusCode::UNAUTHORIZED => {
|
||||
|
|
@ -992,7 +1002,10 @@ impl ModelClientSession {
|
|||
}
|
||||
});
|
||||
|
||||
return Ok(map_response_stream(stream_result, otel_manager.clone()));
|
||||
return Ok(WebsocketStreamOutcome::Stream(map_response_stream(
|
||||
stream_result,
|
||||
otel_manager.clone(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1036,26 +1049,33 @@ impl ModelClientSession {
|
|||
self.client.responses_websocket_enabled() && !self.client.disable_websockets();
|
||||
|
||||
if websocket_enabled {
|
||||
self.stream_responses_websocket(
|
||||
prompt,
|
||||
model_info,
|
||||
otel_manager,
|
||||
effort,
|
||||
summary,
|
||||
turn_metadata_header,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
self.stream_responses_api(
|
||||
prompt,
|
||||
model_info,
|
||||
otel_manager,
|
||||
effort,
|
||||
summary,
|
||||
turn_metadata_header,
|
||||
)
|
||||
.await
|
||||
match self
|
||||
.stream_responses_websocket(
|
||||
prompt,
|
||||
model_info,
|
||||
otel_manager,
|
||||
effort,
|
||||
summary,
|
||||
turn_metadata_header,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
WebsocketStreamOutcome::Stream(stream) => return Ok(stream),
|
||||
WebsocketStreamOutcome::FallbackToHttp => {
|
||||
self.try_switch_fallback_transport(otel_manager);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.stream_responses_api(
|
||||
prompt,
|
||||
model_info,
|
||||
otel_manager,
|
||||
effort,
|
||||
summary,
|
||||
turn_metadata_header,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -558,9 +558,8 @@ impl TurnContext {
|
|||
}
|
||||
|
||||
async fn build_turn_metadata_header(&self) -> Option<String> {
|
||||
let cwd = self.cwd.clone();
|
||||
self.turn_metadata_header
|
||||
.get_or_init(|| async { build_turn_metadata_header(cwd).await })
|
||||
.get_or_init(|| async { build_turn_metadata_header(self.cwd.clone()).await })
|
||||
.await
|
||||
.clone()
|
||||
}
|
||||
|
|
@ -1098,14 +1097,14 @@ impl Session {
|
|||
// Warm a websocket in the background so the first turn can reuse it.
|
||||
// This performs only connection setup; user input is still sent later via response.create
|
||||
// when submit_turn() runs.
|
||||
sess.services.model_client.pre_establish_connection(
|
||||
sess.services.otel_manager.clone(),
|
||||
resolve_turn_metadata_header_with_timeout(
|
||||
build_turn_metadata_header(session_configuration.cwd.clone()),
|
||||
None,
|
||||
)
|
||||
.boxed(),
|
||||
);
|
||||
let turn_metadata_header = resolve_turn_metadata_header_with_timeout(
|
||||
build_turn_metadata_header(session_configuration.cwd.clone()),
|
||||
None,
|
||||
)
|
||||
.boxed();
|
||||
sess.services
|
||||
.model_client
|
||||
.pre_establish_connection(sess.services.otel_manager.clone(), turn_metadata_header);
|
||||
|
||||
// Dispatch the SessionConfiguredEvent first and then report any errors.
|
||||
// If resuming, include converted initial messages in the payload so UIs can render them immediately.
|
||||
|
|
|
|||
|
|
@ -9,13 +9,23 @@ use core_test_support::responses::sse;
|
|||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use pretty_assertions::assert_eq;
|
||||
use wiremock::Mock;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::http::Method;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path_regex;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result<()> {
|
||||
async fn websocket_fallback_switches_to_http_on_upgrade_required_connect() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path_regex(".*/responses$"))
|
||||
.respond_with(ResponseTemplate::new(426))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let response_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]),
|
||||
|
|
@ -28,7 +38,9 @@ async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result
|
|||
config.model_provider.base_url = Some(base_url);
|
||||
config.model_provider.wire_api = codex_core::WireApi::Responses;
|
||||
config.features.enable(Feature::ResponsesWebsockets);
|
||||
config.model_provider.stream_max_retries = Some(0);
|
||||
// If we don't treat 426 specially, the sampling loop would retry the WebSocket
|
||||
// handshake before switching to the HTTP transport.
|
||||
config.model_provider.stream_max_retries = Some(2);
|
||||
config.model_provider.request_max_retries = Some(0);
|
||||
}
|
||||
});
|
||||
|
|
@ -56,6 +68,51 @@ async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Checks that a turn is replayed over HTTP once the WebSocket retry budget is spent.
///
/// Only an SSE response is mounted on the mock server (no dedicated handler for the
/// `GET .../responses` handshake is registered here), so WebSocket attempts are
/// expected to fail — NOTE(review): presumably with the mock server's default
/// response for unmatched requests; confirm against `start_mock_server`.
/// With `stream_max_retries = 2`, the test expects 4 `GET` attempts in total and
/// exactly one `POST`, which is the request served by the SSE mock.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
let response_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_config({
|
||||
let base_url = format!("{}/v1", server.uri());
|
||||
move |config| {
|
||||
config.model_provider.base_url = Some(base_url);
|
||||
|
||||
config.model_provider.wire_api = codex_core::WireApi::Responses;
|
||||
config.features.enable(Feature::ResponsesWebsockets);
|
||||
|
||||
// Two retries on top of the initial attempt: this is the "3 stream attempts"
// figure used in the assertions at the end of the test.
config.model_provider.stream_max_retries = Some(2);
|
||||
config.model_provider.request_max_retries = Some(0);
|
||||
}
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
test.submit_turn("hello").await?;
|
||||
|
||||
// Classify everything the mock server saw by HTTP method: a GET on a path ending
// in `/responses` is counted as a WebSocket attempt, a POST on the same path as
// an HTTP attempt.
let requests = server.received_requests().await.unwrap_or_default();
|
||||
let websocket_attempts = requests
|
||||
.iter()
|
||||
.filter(|req| req.method == Method::GET && req.url.path().ends_with("/responses"))
|
||||
.count();
|
||||
let http_attempts = requests
|
||||
.iter()
|
||||
.filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses"))
|
||||
.count();
|
||||
|
||||
// One websocket attempt comes from startup preconnect.
|
||||
// The first turn then makes 3 websocket stream attempts (initial try + 2 retries),
|
||||
// after which fallback activates and the request is replayed over HTTP.
|
||||
assert_eq!(websocket_attempts, 4);
|
||||
assert_eq!(http_attempts, 1);
|
||||
// The SSE mock must have served exactly the single replayed HTTP request.
assert_eq!(response_mock.requests().len(), 1);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn websocket_fallback_is_sticky_across_turns() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
|
@ -76,7 +133,7 @@ async fn websocket_fallback_is_sticky_across_turns() -> Result<()> {
|
|||
config.model_provider.base_url = Some(base_url);
|
||||
config.model_provider.wire_api = codex_core::WireApi::Responses;
|
||||
config.features.enable(Feature::ResponsesWebsockets);
|
||||
config.model_provider.stream_max_retries = Some(0);
|
||||
config.model_provider.stream_max_retries = Some(2);
|
||||
config.model_provider.request_max_retries = Some(0);
|
||||
}
|
||||
});
|
||||
|
|
@ -95,10 +152,10 @@ async fn websocket_fallback_is_sticky_across_turns() -> Result<()> {
|
|||
.filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses"))
|
||||
.count();
|
||||
|
||||
// The first turn issues exactly two websocket attempts (startup preconnect + first stream
|
||||
// attempt). After fallback becomes sticky, subsequent turns stay on HTTP. This mirrors the
|
||||
// retry-budget tradeoff documented in [`codex_core::client`] module docs.
|
||||
assert_eq!(websocket_attempts, 2);
|
||||
// WebSocket attempts all happen on the first turn:
|
||||
// 1 startup preconnect + 3 stream attempts (initial try + 2 retries) before fallback.
|
||||
// Fallback is sticky, so the second turn stays on HTTP and adds no websocket attempts.
|
||||
assert_eq!(websocket_attempts, 4);
|
||||
assert_eq!(http_attempts, 2);
|
||||
assert_eq!(response_mock.requests().len(), 2);
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue