Fall back to http when websockets fail (#10139)
I expect not all proxies work with websockets, fall back to http if websockets fail.
This commit is contained in:
parent
798c4b3260
commit
3b1cddf001
13 changed files with 205 additions and 2 deletions
|
|
@ -65,6 +65,7 @@ use crate::model_provider_info::ModelProviderInfo;
|
|||
use crate::model_provider_info::WireApi;
|
||||
use crate::tools::spec::create_tools_json_for_chat_completions_api;
|
||||
use crate::tools::spec::create_tools_json_for_responses_api;
|
||||
use crate::transport_manager::TransportManager;
|
||||
|
||||
pub const WEB_SEARCH_ELIGIBLE_HEADER: &str = "x-oai-web-search-eligible";
|
||||
pub const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
|
||||
|
|
@ -80,6 +81,7 @@ struct ModelClientState {
|
|||
effort: Option<ReasoningEffortConfig>,
|
||||
summary: ReasoningSummaryConfig,
|
||||
session_source: SessionSource,
|
||||
transport_manager: TransportManager,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
|
@ -91,6 +93,7 @@ pub struct ModelClientSession {
|
|||
state: Arc<ModelClientState>,
|
||||
connection: Option<ApiWebSocketConnection>,
|
||||
websocket_last_items: Vec<ResponseItem>,
|
||||
transport_manager: TransportManager,
|
||||
/// Turn state for sticky routing.
|
||||
///
|
||||
/// This is an `OnceLock` that stores the turn state value received from the server
|
||||
|
|
@ -116,6 +119,7 @@ impl ModelClient {
|
|||
summary: ReasoningSummaryConfig,
|
||||
conversation_id: ThreadId,
|
||||
session_source: SessionSource,
|
||||
transport_manager: TransportManager,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: Arc::new(ModelClientState {
|
||||
|
|
@ -128,6 +132,7 @@ impl ModelClient {
|
|||
effort,
|
||||
summary,
|
||||
session_source,
|
||||
transport_manager,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
|
@ -137,6 +142,7 @@ impl ModelClient {
|
|||
state: Arc::clone(&self.state),
|
||||
connection: None,
|
||||
websocket_last_items: Vec::new(),
|
||||
transport_manager: self.state.transport_manager.clone(),
|
||||
turn_state: Arc::new(OnceLock::new()),
|
||||
}
|
||||
}
|
||||
|
|
@ -171,6 +177,10 @@ impl ModelClient {
|
|||
self.state.session_source.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn transport_manager(&self) -> TransportManager {
|
||||
self.state.transport_manager.clone()
|
||||
}
|
||||
|
||||
/// Returns the currently configured model slug.
|
||||
pub fn get_model(&self) -> String {
|
||||
self.state.model_info.slug.clone()
|
||||
|
|
@ -250,7 +260,10 @@ impl ModelClientSession {
|
|||
/// For Chat providers, the underlying stream is optionally aggregated
|
||||
/// based on the `show_raw_agent_reasoning` flag in the config.
|
||||
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
match self.state.provider.wire_api {
|
||||
let wire_api = self
|
||||
.transport_manager
|
||||
.effective_wire_api(self.state.provider.wire_api);
|
||||
match wire_api {
|
||||
WireApi::Responses => self.stream_responses_api(prompt).await,
|
||||
WireApi::ResponsesWebsocket => self.stream_responses_websocket(prompt).await,
|
||||
WireApi::Chat => {
|
||||
|
|
@ -271,6 +284,24 @@ impl ModelClientSession {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn try_switch_fallback_transport(&mut self) -> bool {
|
||||
let activated = self
|
||||
.transport_manager
|
||||
.activate_http_fallback(self.state.provider.wire_api);
|
||||
if activated {
|
||||
warn!("falling back to HTTP");
|
||||
self.state.otel_manager.counter(
|
||||
"codex.transport.fallback_to_http",
|
||||
1,
|
||||
&[("from_wire_api", "responses_websocket")],
|
||||
);
|
||||
|
||||
self.connection = None;
|
||||
self.websocket_last_items.clear();
|
||||
}
|
||||
activated
|
||||
}
|
||||
|
||||
fn build_responses_request(&self, prompt: &Prompt) -> Result<ApiPrompt> {
|
||||
let instructions = prompt.base_instructions.text.clone();
|
||||
let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ use crate::stream_events_utils::HandleOutputCtx;
|
|||
use crate::stream_events_utils::handle_non_tool_response_item;
|
||||
use crate::stream_events_utils::handle_output_item_done;
|
||||
use crate::terminal;
|
||||
use crate::transport_manager::TransportManager;
|
||||
use crate::truncate::TruncationPolicy;
|
||||
use crate::user_notification::UserNotifier;
|
||||
use crate::util::error_or_panic;
|
||||
|
|
@ -624,6 +625,7 @@ impl Session {
|
|||
model_info: ModelInfo,
|
||||
conversation_id: ThreadId,
|
||||
sub_id: String,
|
||||
transport_manager: TransportManager,
|
||||
) -> TurnContext {
|
||||
let otel_manager = otel_manager.clone().with_model(
|
||||
session_configuration.collaboration_mode.model(),
|
||||
|
|
@ -640,6 +642,7 @@ impl Session {
|
|||
session_configuration.model_reasoning_summary,
|
||||
conversation_id,
|
||||
session_configuration.session_source.clone(),
|
||||
transport_manager,
|
||||
);
|
||||
|
||||
let tools_config = ToolsConfig::new(&ToolsConfigParams {
|
||||
|
|
@ -869,6 +872,7 @@ impl Session {
|
|||
skills_manager,
|
||||
agent_control,
|
||||
state_db: state_db_ctx.clone(),
|
||||
transport_manager: TransportManager::new(),
|
||||
};
|
||||
|
||||
let sess = Arc::new(Session {
|
||||
|
|
@ -1188,6 +1192,7 @@ impl Session {
|
|||
model_info,
|
||||
self.conversation_id,
|
||||
sub_id,
|
||||
self.services.transport_manager.clone(),
|
||||
);
|
||||
if let Some(final_schema) = final_output_json_schema {
|
||||
turn_context.final_output_json_schema = final_schema;
|
||||
|
|
@ -3042,6 +3047,7 @@ async fn spawn_review_thread(
|
|||
per_turn_config.model_reasoning_summary,
|
||||
sess.conversation_id,
|
||||
parent_turn_context.client.get_session_source(),
|
||||
parent_turn_context.client.transport_manager(),
|
||||
);
|
||||
|
||||
let review_turn_context = TurnContext {
|
||||
|
|
@ -3537,7 +3543,9 @@ async fn run_sampling_request(
|
|||
)
|
||||
.await
|
||||
{
|
||||
Ok(output) => return Ok(output),
|
||||
Ok(output) => {
|
||||
return Ok(output);
|
||||
}
|
||||
Err(CodexErr::ContextWindowExceeded) => {
|
||||
sess.set_total_tokens_full(&turn_context).await;
|
||||
return Err(CodexErr::ContextWindowExceeded);
|
||||
|
|
@ -3558,6 +3566,17 @@ async fn run_sampling_request(
|
|||
|
||||
// Use the configured provider-specific stream retry budget.
|
||||
let max_retries = turn_context.client.get_provider().stream_max_retries();
|
||||
if retries >= max_retries && client_session.try_switch_fallback_transport() {
|
||||
sess.send_event(
|
||||
&turn_context,
|
||||
EventMsg::Warning(WarningEvent {
|
||||
message: format!("Falling back from WebSockets to HTTPS transport. {err:#}"),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
retries = 0;
|
||||
continue;
|
||||
}
|
||||
if retries < max_retries {
|
||||
retries += 1;
|
||||
let delay = match &err {
|
||||
|
|
@ -4753,6 +4772,7 @@ mod tests {
|
|||
skills_manager,
|
||||
agent_control,
|
||||
state_db: None,
|
||||
transport_manager: TransportManager::new(),
|
||||
};
|
||||
|
||||
let turn_context = Session::make_turn_context(
|
||||
|
|
@ -4764,6 +4784,7 @@ mod tests {
|
|||
model_info,
|
||||
conversation_id,
|
||||
"turn_id".to_string(),
|
||||
services.transport_manager.clone(),
|
||||
);
|
||||
|
||||
let session = Session {
|
||||
|
|
@ -4865,6 +4886,7 @@ mod tests {
|
|||
skills_manager,
|
||||
agent_control,
|
||||
state_db: None,
|
||||
transport_manager: TransportManager::new(),
|
||||
};
|
||||
|
||||
let turn_context = Arc::new(Session::make_turn_context(
|
||||
|
|
@ -4876,6 +4898,7 @@ mod tests {
|
|||
model_info,
|
||||
conversation_id,
|
||||
"turn_id".to_string(),
|
||||
services.transport_manager.clone(),
|
||||
));
|
||||
|
||||
let session = Arc::new(Session {
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ pub mod landlock;
|
|||
pub mod mcp;
|
||||
mod mcp_connection_manager;
|
||||
pub mod models_manager;
|
||||
mod transport_manager;
|
||||
pub use mcp_connection_manager::MCP_SANDBOX_STATE_CAPABILITY;
|
||||
pub use mcp_connection_manager::MCP_SANDBOX_STATE_METHOD;
|
||||
pub use mcp_connection_manager::SandboxState;
|
||||
|
|
@ -113,6 +114,7 @@ pub use rollout::list::parse_cursor;
|
|||
pub use rollout::list::read_head_for_summary;
|
||||
pub use rollout::list::read_session_meta_line;
|
||||
pub use rollout::rollout_date_parts;
|
||||
pub use transport_manager::TransportManager;
|
||||
mod function_tool;
|
||||
mod state;
|
||||
mod tasks;
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ use crate::models_manager::manager::ModelsManager;
|
|||
use crate::skills::SkillsManager;
|
||||
use crate::state_db::StateDbHandle;
|
||||
use crate::tools::sandboxing::ApprovalStore;
|
||||
use crate::transport_manager::TransportManager;
|
||||
use crate::unified_exec::UnifiedExecProcessManager;
|
||||
use crate::user_notification::UserNotifier;
|
||||
use codex_otel::OtelManager;
|
||||
|
|
@ -32,4 +33,5 @@ pub(crate) struct SessionServices {
|
|||
pub(crate) skills_manager: Arc<SkillsManager>,
|
||||
pub(crate) agent_control: AgentControl,
|
||||
pub(crate) state_db: Option<StateDbHandle>,
|
||||
pub(crate) transport_manager: TransportManager,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -781,6 +781,7 @@ mod tests {
|
|||
turn.client.get_reasoning_summary(),
|
||||
session.conversation_id,
|
||||
session_source,
|
||||
session.services.transport_manager.clone(),
|
||||
);
|
||||
|
||||
let invocation = invocation(
|
||||
|
|
@ -1221,6 +1222,7 @@ mod tests {
|
|||
let mut base_config = (*turn.client.config()).clone();
|
||||
base_config.user_instructions = Some("base-user".to_string());
|
||||
turn.user_instructions = Some("resolved-user".to_string());
|
||||
let transport_manager = turn.client.transport_manager();
|
||||
turn.client = ModelClient::new(
|
||||
Arc::new(base_config.clone()),
|
||||
Some(session.services.auth_manager.clone()),
|
||||
|
|
@ -1231,6 +1233,7 @@ mod tests {
|
|||
turn.client.get_reasoning_summary(),
|
||||
session.conversation_id,
|
||||
session_source,
|
||||
transport_manager,
|
||||
);
|
||||
let base_instructions = BaseInstructions {
|
||||
text: "base".to_string(),
|
||||
|
|
|
|||
31
codex-rs/core/src/transport_manager.rs
Normal file
31
codex-rs/core/src/transport_manager.rs
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use crate::model_provider_info::WireApi;
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct TransportManager {
|
||||
fallback_to_http: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl TransportManager {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn effective_wire_api(&self, provider_wire_api: WireApi) -> WireApi {
|
||||
if self.fallback_to_http.load(Ordering::Relaxed)
|
||||
&& provider_wire_api == WireApi::ResponsesWebsocket
|
||||
{
|
||||
WireApi::Responses
|
||||
} else {
|
||||
provider_wire_api
|
||||
}
|
||||
}
|
||||
|
||||
pub fn activate_http_fallback(&self, provider_wire_api: WireApi) -> bool {
|
||||
provider_wire_api == WireApi::ResponsesWebsocket
|
||||
&& !self.fallback_to_http.swap(true, Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
|
@ -11,6 +11,7 @@ use codex_core::ModelClient;
|
|||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::Prompt;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::TransportManager;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::models_manager::manager::ModelsManager;
|
||||
use codex_otel::OtelManager;
|
||||
|
|
@ -98,6 +99,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
|
|||
summary,
|
||||
conversation_id,
|
||||
SessionSource::Exec,
|
||||
TransportManager::new(),
|
||||
)
|
||||
.new_session();
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ use codex_core::ModelProviderInfo;
|
|||
use codex_core::Prompt;
|
||||
use codex_core::ResponseEvent;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::TransportManager;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::models_manager::manager::ModelsManager;
|
||||
use codex_otel::OtelManager;
|
||||
|
|
@ -99,6 +100,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
|
|||
summary,
|
||||
conversation_id,
|
||||
SessionSource::Exec,
|
||||
TransportManager::new(),
|
||||
)
|
||||
.new_session();
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ use codex_core::ModelProviderInfo;
|
|||
use codex_core::Prompt;
|
||||
use codex_core::ResponseEvent;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::TransportManager;
|
||||
use codex_core::WEB_SEARCH_ELIGIBLE_HEADER;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::models_manager::manager::ModelsManager;
|
||||
|
|
@ -94,6 +95,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
|
|||
summary,
|
||||
conversation_id,
|
||||
session_source,
|
||||
TransportManager::new(),
|
||||
)
|
||||
.new_session();
|
||||
|
||||
|
|
@ -191,6 +193,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
|
|||
summary,
|
||||
conversation_id,
|
||||
session_source,
|
||||
TransportManager::new(),
|
||||
)
|
||||
.new_session();
|
||||
|
||||
|
|
@ -346,6 +349,7 @@ async fn responses_respects_model_info_overrides_from_config() {
|
|||
summary,
|
||||
conversation_id,
|
||||
session_source,
|
||||
TransportManager::new(),
|
||||
)
|
||||
.new_session();
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ use codex_core::Prompt;
|
|||
use codex_core::ResponseEvent;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::ThreadManager;
|
||||
use codex_core::TransportManager;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::auth::AuthCredentialsStoreMode;
|
||||
use codex_core::built_in_model_providers;
|
||||
|
|
@ -1186,6 +1187,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
|
|||
summary,
|
||||
conversation_id,
|
||||
SessionSource::Exec,
|
||||
TransportManager::new(),
|
||||
)
|
||||
.new_session();
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ use codex_core::ModelProviderInfo;
|
|||
use codex_core::Prompt;
|
||||
use codex_core::ResponseEvent;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::TransportManager;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::models_manager::manager::ModelsManager;
|
||||
use codex_core::protocol::SessionSource;
|
||||
|
|
@ -228,6 +229,7 @@ async fn websocket_harness(server: &WebSocketTestServer) -> WebsocketTestHarness
|
|||
ReasoningSummary::Auto,
|
||||
conversation_id,
|
||||
SessionSource::Exec,
|
||||
TransportManager::new(),
|
||||
);
|
||||
|
||||
WebsocketTestHarness {
|
||||
|
|
|
|||
|
|
@ -80,3 +80,4 @@ mod user_notification;
|
|||
mod user_shell_cmd;
|
||||
mod view_image;
|
||||
mod web_search_cached;
|
||||
mod websocket_fallback;
|
||||
|
|
|
|||
98
codex-rs/core/tests/suite/websocket_fallback.rs
Normal file
98
codex-rs/core/tests/suite/websocket_fallback.rs
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
use anyhow::Result;
|
||||
use codex_core::WireApi;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use pretty_assertions::assert_eq;
|
||||
use wiremock::http::Method;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
let response_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_config({
|
||||
let base_url = format!("{}/v1", server.uri());
|
||||
move |config| {
|
||||
config.model_provider.base_url = Some(base_url);
|
||||
config.model_provider.wire_api = WireApi::ResponsesWebsocket;
|
||||
config.model_provider.stream_max_retries = Some(0);
|
||||
config.model_provider.request_max_retries = Some(0);
|
||||
}
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
test.submit_turn("hello").await?;
|
||||
|
||||
let requests = server.received_requests().await.unwrap_or_default();
|
||||
let websocket_attempts = requests
|
||||
.iter()
|
||||
.filter(|req| req.method == Method::GET && req.url.path().ends_with("/responses"))
|
||||
.count();
|
||||
let http_attempts = requests
|
||||
.iter()
|
||||
.filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses"))
|
||||
.count();
|
||||
|
||||
assert_eq!(websocket_attempts, 1);
|
||||
assert_eq!(http_attempts, 1);
|
||||
assert_eq!(response_mock.requests().len(), 1);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn websocket_fallback_is_sticky_across_turns() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
let response_mock = mount_sse_sequence(
|
||||
&server,
|
||||
vec![
|
||||
sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]),
|
||||
sse(vec![ev_response_created("resp-2"), ev_completed("resp-2")]),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_config({
|
||||
let base_url = format!("{}/v1", server.uri());
|
||||
move |config| {
|
||||
config.model_provider.base_url = Some(base_url);
|
||||
config.model_provider.wire_api = WireApi::ResponsesWebsocket;
|
||||
config.model_provider.stream_max_retries = Some(0);
|
||||
config.model_provider.request_max_retries = Some(0);
|
||||
}
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
test.submit_turn("first").await?;
|
||||
test.submit_turn("second").await?;
|
||||
|
||||
let requests = server.received_requests().await.unwrap_or_default();
|
||||
let websocket_attempts = requests
|
||||
.iter()
|
||||
.filter(|req| req.method == Method::GET && req.url.path().ends_with("/responses"))
|
||||
.count();
|
||||
let http_attempts = requests
|
||||
.iter()
|
||||
.filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses"))
|
||||
.count();
|
||||
|
||||
assert_eq!(websocket_attempts, 1);
|
||||
assert_eq!(http_attempts, 2);
|
||||
assert_eq!(response_mock.requests().len(), 2);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue