Fall back to http when websockets fail (#10139)

Not all proxies are expected to work with WebSockets, so fall back to HTTP
when the WebSocket transport fails.
This commit is contained in:
pakrym-oai 2026-01-29 10:36:21 -08:00 committed by GitHub
parent 798c4b3260
commit 3b1cddf001
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 205 additions and 2 deletions

View file

@ -65,6 +65,7 @@ use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::tools::spec::create_tools_json_for_chat_completions_api;
use crate::tools::spec::create_tools_json_for_responses_api;
use crate::transport_manager::TransportManager;
pub const WEB_SEARCH_ELIGIBLE_HEADER: &str = "x-oai-web-search-eligible";
pub const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
@ -80,6 +81,7 @@ struct ModelClientState {
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
session_source: SessionSource,
transport_manager: TransportManager,
}
#[derive(Debug, Clone)]
@ -91,6 +93,7 @@ pub struct ModelClientSession {
state: Arc<ModelClientState>,
connection: Option<ApiWebSocketConnection>,
websocket_last_items: Vec<ResponseItem>,
transport_manager: TransportManager,
/// Turn state for sticky routing.
///
/// This is an `OnceLock` that stores the turn state value received from the server
@ -116,6 +119,7 @@ impl ModelClient {
summary: ReasoningSummaryConfig,
conversation_id: ThreadId,
session_source: SessionSource,
transport_manager: TransportManager,
) -> Self {
Self {
state: Arc::new(ModelClientState {
@ -128,6 +132,7 @@ impl ModelClient {
effort,
summary,
session_source,
transport_manager,
}),
}
}
@ -137,6 +142,7 @@ impl ModelClient {
state: Arc::clone(&self.state),
connection: None,
websocket_last_items: Vec::new(),
transport_manager: self.state.transport_manager.clone(),
turn_state: Arc::new(OnceLock::new()),
}
}
@ -171,6 +177,10 @@ impl ModelClient {
self.state.session_source.clone()
}
/// Returns a clone of the transport manager held in this client's state.
pub(crate) fn transport_manager(&self) -> TransportManager {
    let manager = &self.state.transport_manager;
    manager.clone()
}
/// Returns the currently configured model slug.
pub fn get_model(&self) -> String {
self.state.model_info.slug.clone()
@ -250,7 +260,10 @@ impl ModelClientSession {
/// For Chat providers, the underlying stream is optionally aggregated
/// based on the `show_raw_agent_reasoning` flag in the config.
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
match self.state.provider.wire_api {
let wire_api = self
.transport_manager
.effective_wire_api(self.state.provider.wire_api);
match wire_api {
WireApi::Responses => self.stream_responses_api(prompt).await,
WireApi::ResponsesWebsocket => self.stream_responses_websocket(prompt).await,
WireApi::Chat => {
@ -271,6 +284,24 @@ impl ModelClientSession {
}
}
/// Tries to activate the HTTP fallback for this session's provider.
///
/// Returns `true` only when this call activated the fallback. In that case a
/// warning is logged, the `codex.transport.fallback_to_http` counter is
/// incremented, and the session's websocket connection and remembered
/// websocket items are discarded. Returns `false` (with no side effects in
/// this method) when the transport manager declines the activation.
pub(crate) fn try_switch_fallback_transport(&mut self) -> bool {
    let provider_wire_api = self.state.provider.wire_api;
    if !self.transport_manager.activate_http_fallback(provider_wire_api) {
        return false;
    }

    warn!("falling back to HTTP");
    self.state.otel_manager.counter(
        "codex.transport.fallback_to_http",
        1,
        &[("from_wire_api", "responses_websocket")],
    );

    // Drop websocket-specific session state held by this session.
    self.connection = None;
    self.websocket_last_items.clear();
    true
}
fn build_responses_request(&self, prompt: &Prompt) -> Result<ApiPrompt> {
let instructions = prompt.base_instructions.text.clone();
let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;

View file

@ -30,6 +30,7 @@ use crate::stream_events_utils::HandleOutputCtx;
use crate::stream_events_utils::handle_non_tool_response_item;
use crate::stream_events_utils::handle_output_item_done;
use crate::terminal;
use crate::transport_manager::TransportManager;
use crate::truncate::TruncationPolicy;
use crate::user_notification::UserNotifier;
use crate::util::error_or_panic;
@ -624,6 +625,7 @@ impl Session {
model_info: ModelInfo,
conversation_id: ThreadId,
sub_id: String,
transport_manager: TransportManager,
) -> TurnContext {
let otel_manager = otel_manager.clone().with_model(
session_configuration.collaboration_mode.model(),
@ -640,6 +642,7 @@ impl Session {
session_configuration.model_reasoning_summary,
conversation_id,
session_configuration.session_source.clone(),
transport_manager,
);
let tools_config = ToolsConfig::new(&ToolsConfigParams {
@ -869,6 +872,7 @@ impl Session {
skills_manager,
agent_control,
state_db: state_db_ctx.clone(),
transport_manager: TransportManager::new(),
};
let sess = Arc::new(Session {
@ -1188,6 +1192,7 @@ impl Session {
model_info,
self.conversation_id,
sub_id,
self.services.transport_manager.clone(),
);
if let Some(final_schema) = final_output_json_schema {
turn_context.final_output_json_schema = final_schema;
@ -3042,6 +3047,7 @@ async fn spawn_review_thread(
per_turn_config.model_reasoning_summary,
sess.conversation_id,
parent_turn_context.client.get_session_source(),
parent_turn_context.client.transport_manager(),
);
let review_turn_context = TurnContext {
@ -3537,7 +3543,9 @@ async fn run_sampling_request(
)
.await
{
Ok(output) => return Ok(output),
Ok(output) => {
return Ok(output);
}
Err(CodexErr::ContextWindowExceeded) => {
sess.set_total_tokens_full(&turn_context).await;
return Err(CodexErr::ContextWindowExceeded);
@ -3558,6 +3566,17 @@ async fn run_sampling_request(
// Use the configured provider-specific stream retry budget.
let max_retries = turn_context.client.get_provider().stream_max_retries();
if retries >= max_retries && client_session.try_switch_fallback_transport() {
sess.send_event(
&turn_context,
EventMsg::Warning(WarningEvent {
message: format!("Falling back from WebSockets to HTTPS transport. {err:#}"),
}),
)
.await;
retries = 0;
continue;
}
if retries < max_retries {
retries += 1;
let delay = match &err {
@ -4753,6 +4772,7 @@ mod tests {
skills_manager,
agent_control,
state_db: None,
transport_manager: TransportManager::new(),
};
let turn_context = Session::make_turn_context(
@ -4764,6 +4784,7 @@ mod tests {
model_info,
conversation_id,
"turn_id".to_string(),
services.transport_manager.clone(),
);
let session = Session {
@ -4865,6 +4886,7 @@ mod tests {
skills_manager,
agent_control,
state_db: None,
transport_manager: TransportManager::new(),
};
let turn_context = Arc::new(Session::make_turn_context(
@ -4876,6 +4898,7 @@ mod tests {
model_info,
conversation_id,
"turn_id".to_string(),
services.transport_manager.clone(),
));
let session = Arc::new(Session {

View file

@ -38,6 +38,7 @@ pub mod landlock;
pub mod mcp;
mod mcp_connection_manager;
pub mod models_manager;
mod transport_manager;
pub use mcp_connection_manager::MCP_SANDBOX_STATE_CAPABILITY;
pub use mcp_connection_manager::MCP_SANDBOX_STATE_METHOD;
pub use mcp_connection_manager::SandboxState;
@ -113,6 +114,7 @@ pub use rollout::list::parse_cursor;
pub use rollout::list::read_head_for_summary;
pub use rollout::list::read_session_meta_line;
pub use rollout::rollout_date_parts;
pub use transport_manager::TransportManager;
mod function_tool;
mod state;
mod tasks;

View file

@ -9,6 +9,7 @@ use crate::models_manager::manager::ModelsManager;
use crate::skills::SkillsManager;
use crate::state_db::StateDbHandle;
use crate::tools::sandboxing::ApprovalStore;
use crate::transport_manager::TransportManager;
use crate::unified_exec::UnifiedExecProcessManager;
use crate::user_notification::UserNotifier;
use codex_otel::OtelManager;
@ -32,4 +33,5 @@ pub(crate) struct SessionServices {
pub(crate) skills_manager: Arc<SkillsManager>,
pub(crate) agent_control: AgentControl,
pub(crate) state_db: Option<StateDbHandle>,
pub(crate) transport_manager: TransportManager,
}

View file

@ -781,6 +781,7 @@ mod tests {
turn.client.get_reasoning_summary(),
session.conversation_id,
session_source,
session.services.transport_manager.clone(),
);
let invocation = invocation(
@ -1221,6 +1222,7 @@ mod tests {
let mut base_config = (*turn.client.config()).clone();
base_config.user_instructions = Some("base-user".to_string());
turn.user_instructions = Some("resolved-user".to_string());
let transport_manager = turn.client.transport_manager();
turn.client = ModelClient::new(
Arc::new(base_config.clone()),
Some(session.services.auth_manager.clone()),
@ -1231,6 +1233,7 @@ mod tests {
turn.client.get_reasoning_summary(),
session.conversation_id,
session_source,
transport_manager,
);
let base_instructions = BaseInstructions {
text: "base".to_string(),

View file

@ -0,0 +1,31 @@
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use crate::model_provider_info::WireApi;
/// Records whether the HTTP fallback has been activated.
///
/// The flag lives behind an `Arc`, so every clone of a `TransportManager`
/// observes and updates the same value.
#[derive(Clone, Debug, Default)]
pub struct TransportManager {
    /// `true` once the HTTP fallback has been activated.
    fallback_to_http: Arc<AtomicBool>,
}

impl TransportManager {
    /// Creates a manager with the fallback flag unset.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the wire API that should actually be used.
    ///
    /// Yields `WireApi::Responses` when the provider is configured for
    /// `WireApi::ResponsesWebsocket` and the fallback flag is set; in every
    /// other case the provider's own wire API is returned unchanged.
    pub fn effective_wire_api(&self, provider_wire_api: WireApi) -> WireApi {
        if provider_wire_api != WireApi::ResponsesWebsocket {
            return provider_wire_api;
        }
        if self.fallback_to_http.load(Ordering::Relaxed) {
            WireApi::Responses
        } else {
            provider_wire_api
        }
    }

    /// Sets the fallback flag for a websocket provider.
    ///
    /// Returns `true` only when the provider uses
    /// `WireApi::ResponsesWebsocket` and the flag was not already set. For
    /// any other wire API the flag is left untouched and `false` is returned.
    pub fn activate_http_fallback(&self, provider_wire_api: WireApi) -> bool {
        if provider_wire_api != WireApi::ResponsesWebsocket {
            return false;
        }
        let was_already_active = self.fallback_to_http.swap(true, Ordering::Relaxed);
        !was_already_active
    }
}

View file

@ -11,6 +11,7 @@ use codex_core::ModelClient;
use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseItem;
use codex_core::TransportManager;
use codex_core::WireApi;
use codex_core::models_manager::manager::ModelsManager;
use codex_otel::OtelManager;
@ -98,6 +99,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
summary,
conversation_id,
SessionSource::Exec,
TransportManager::new(),
)
.new_session();

View file

@ -10,6 +10,7 @@ use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::TransportManager;
use codex_core::WireApi;
use codex_core::models_manager::manager::ModelsManager;
use codex_otel::OtelManager;
@ -99,6 +100,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
summary,
conversation_id,
SessionSource::Exec,
TransportManager::new(),
)
.new_session();

View file

@ -9,6 +9,7 @@ use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::TransportManager;
use codex_core::WEB_SEARCH_ELIGIBLE_HEADER;
use codex_core::WireApi;
use codex_core::models_manager::manager::ModelsManager;
@ -94,6 +95,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
summary,
conversation_id,
session_source,
TransportManager::new(),
)
.new_session();
@ -191,6 +193,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
summary,
conversation_id,
session_source,
TransportManager::new(),
)
.new_session();
@ -346,6 +349,7 @@ async fn responses_respects_model_info_overrides_from_config() {
summary,
conversation_id,
session_source,
TransportManager::new(),
)
.new_session();

View file

@ -11,6 +11,7 @@ use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::ThreadManager;
use codex_core::TransportManager;
use codex_core::WireApi;
use codex_core::auth::AuthCredentialsStoreMode;
use codex_core::built_in_model_providers;
@ -1186,6 +1187,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
summary,
conversation_id,
SessionSource::Exec,
TransportManager::new(),
)
.new_session();

View file

@ -8,6 +8,7 @@ use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::TransportManager;
use codex_core::WireApi;
use codex_core::models_manager::manager::ModelsManager;
use codex_core::protocol::SessionSource;
@ -228,6 +229,7 @@ async fn websocket_harness(server: &WebSocketTestServer) -> WebsocketTestHarness
ReasoningSummary::Auto,
conversation_id,
SessionSource::Exec,
TransportManager::new(),
);
WebsocketTestHarness {

View file

@ -80,3 +80,4 @@ mod user_notification;
mod user_shell_cmd;
mod view_image;
mod web_search_cached;
mod websocket_fallback;

View file

@ -0,0 +1,98 @@
use anyhow::Result;
use codex_core::WireApi;
use core_test_support::responses;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::test_codex;
use pretty_assertions::assert_eq;
use wiremock::http::Method;
/// Configures a `ResponsesWebsocket` provider with zero stream/request retries,
/// submits one turn, and checks that the mock server saw exactly one `GET`
/// (counted as the websocket attempt) and one `POST` (counted as the HTTP
/// attempt) on a path ending in `/responses`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = responses::start_mock_server().await;
    let sse_body = sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]);
    let response_mock = mount_sse_once(&server, sse_body).await;

    let base_url = format!("{}/v1", server.uri());
    let mut builder = test_codex().with_config(move |config| {
        config.model_provider.base_url = Some(base_url);
        config.model_provider.wire_api = WireApi::ResponsesWebsocket;
        config.model_provider.stream_max_retries = Some(0);
        config.model_provider.request_max_retries = Some(0);
    });
    let test = builder.build(&server).await?;

    test.submit_turn("hello").await?;

    // Tally requests to the responses endpoint by HTTP method.
    let requests = server.received_requests().await.unwrap_or_default();
    let mut websocket_attempts = 0usize;
    let mut http_attempts = 0usize;
    for req in &requests {
        if !req.url.path().ends_with("/responses") {
            continue;
        }
        if req.method == Method::GET {
            websocket_attempts += 1;
        } else if req.method == Method::POST {
            http_attempts += 1;
        }
    }

    assert_eq!(websocket_attempts, 1);
    assert_eq!(http_attempts, 1);
    assert_eq!(response_mock.requests().len(), 1);
    Ok(())
}
/// Submits two turns against a `ResponsesWebsocket` provider with zero
/// stream/request retries and checks that only one `GET` (counted as the
/// websocket attempt) reached the `/responses` path while both turns produced
/// a `POST` (counted as HTTP attempts), i.e. the second turn did not try the
/// websocket again.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_fallback_is_sticky_across_turns() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = responses::start_mock_server().await;
    let sse_bodies = vec![
        sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]),
        sse(vec![ev_response_created("resp-2"), ev_completed("resp-2")]),
    ];
    let response_mock = mount_sse_sequence(&server, sse_bodies).await;

    let base_url = format!("{}/v1", server.uri());
    let mut builder = test_codex().with_config(move |config| {
        config.model_provider.base_url = Some(base_url);
        config.model_provider.wire_api = WireApi::ResponsesWebsocket;
        config.model_provider.stream_max_retries = Some(0);
        config.model_provider.request_max_retries = Some(0);
    });
    let test = builder.build(&server).await?;

    test.submit_turn("first").await?;
    test.submit_turn("second").await?;

    // Tally requests to the responses endpoint by HTTP method.
    let requests = server.received_requests().await.unwrap_or_default();
    let mut websocket_attempts = 0usize;
    let mut http_attempts = 0usize;
    for req in &requests {
        if !req.url.path().ends_with("/responses") {
            continue;
        }
        if req.method == Method::GET {
            websocket_attempts += 1;
        } else if req.method == Method::POST {
            http_attempts += 1;
        }
    }

    assert_eq!(websocket_attempts, 1);
    assert_eq!(http_attempts, 2);
    assert_eq!(response_mock.requests().len(), 2);
    Ok(())
}