diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 6f33910ce..523984104 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -92,6 +92,7 @@ use crate::auth::RefreshTokenError; use crate::client_common::Prompt; use crate::client_common::ResponseEvent; use crate::client_common::ResponseStream; +use crate::config::Config; use crate::default_client::build_reqwest_client; use crate::error::CodexErr; use crate::error::Result; @@ -107,6 +108,28 @@ pub const X_CODEX_TURN_METADATA_HEADER: &str = "x-codex-turn-metadata"; pub const X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER: &str = "x-responsesapi-include-timing-metrics"; const RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE: &str = "responses_websockets=2026-02-06"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ResponsesWebsocketVersion { + V1, + V2, +} + +pub fn ws_version_from_features(config: &Config) -> Option { + match ( + config + .features + .enabled(crate::features::Feature::ResponsesWebsockets), + config + .features + .enabled(crate::features::Feature::ResponsesWebsocketsV2), + ) { + (_, true) => Some(ResponsesWebsocketVersion::V2), + (true, false) => Some(ResponsesWebsocketVersion::V1), + (false, false) => None, + } +} + /// Session-scoped state shared by all [`ModelClient`] clones. /// /// This is intentionally kept minimal so `ModelClient` does not need to hold a full `Config`. Most @@ -118,8 +141,7 @@ struct ModelClientState { provider: ModelProviderInfo, session_source: SessionSource, model_verbosity: Option, - enable_responses_websockets: bool, - enable_responses_websockets_v2: bool, + responses_websocket_version: Option, enable_request_compression: bool, include_timing_metrics: bool, beta_features_header: Option, @@ -209,14 +231,11 @@ impl ModelClient { provider: ModelProviderInfo, session_source: SessionSource, model_verbosity: Option, - enable_responses_websockets: bool, - enable_responses_websockets_v2: bool, + responses_websocket_version: Option, enable_request_compression: bool, include_timing_metrics: bool, beta_features_header: Option, ) -> Self { - let enable_responses_websockets = - enable_responses_websockets || enable_responses_websockets_v2; Self { state: Arc::new(ModelClientState { auth_manager, @@ -224,8 +243,7 @@ impl ModelClient { provider, session_source, model_verbosity, - enable_responses_websockets, - enable_responses_websockets_v2, + responses_websocket_version, enable_request_compression, include_timing_metrics, beta_features_header, @@ -367,26 +385,25 @@ impl ModelClient { request_telemetry } - /// Returns whether this session is configured to use Responses-over-WebSocket. + /// Returns the active Responses-over-WebSocket version for this session. /// /// This combines provider capability and feature gating; both must be true for websocket paths /// to be eligible. - pub fn responses_websocket_enabled(&self, model_info: &ModelInfo) -> bool { - self.state.provider.supports_websockets - && (self.state.enable_responses_websockets - || self.state.enable_responses_websockets_v2 - || model_info.prefer_websockets) - } - - fn responses_websockets_v2_enabled(&self) -> bool { - self.state.enable_responses_websockets_v2 - } - - /// Returns whether websocket transport has been permanently disabled for this session. /// - /// Once set by fallback activation, subsequent turns must stay on HTTP transport. - fn websockets_disabled(&self) -> bool { - self.state.disable_websockets.load(Ordering::Relaxed) + /// If websockets are only enabled via model preference (no explicit feature flag), default to + /// v1 behavior. + pub fn active_ws_version(&self, model_info: &ModelInfo) -> Option { + if !self.state.provider.supports_websockets + || self.state.disable_websockets.load(Ordering::Relaxed) + { + return None; + } + + match self.state.responses_websocket_version { + Some(version) => Some(version), + None if model_info.prefer_websockets => Some(ResponsesWebsocketVersion::V1), + None => None, + } } /// Returns auth + provider configuration resolved from the current session auth state. @@ -419,10 +436,12 @@ impl ModelClient { otel_manager: &OtelManager, api_provider: codex_api::Provider, api_auth: CoreAuthProvider, + ws_version: ResponsesWebsocketVersion, turn_state: Option>>, turn_metadata_header: Option<&str>, ) -> std::result::Result { - let headers = self.build_websocket_headers(turn_state.as_ref(), turn_metadata_header); + let headers = + self.build_websocket_headers(ws_version, turn_state.as_ref(), turn_metadata_header); let websocket_telemetry = ModelClientSession::build_websocket_telemetry(otel_manager); ApiWebSocketResponsesClient::new(api_provider, api_auth) .connect( @@ -440,6 +459,7 @@ impl ModelClient { /// replayed on reconnect within the same turn. fn build_websocket_headers( &self, + ws_version: ResponsesWebsocketVersion, turn_state: Option<&Arc>>, turn_metadata_header: Option<&str>, ) -> ApiHeaderMap { @@ -452,10 +472,9 @@ impl ModelClient { headers.extend(build_conversation_headers(Some( self.state.conversation_id.to_string(), ))); - let responses_websockets_beta_header = if self.responses_websockets_v2_enabled() { - RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE - } else { - OPENAI_BETA_RESPONSES_WEBSOCKETS + let responses_websockets_beta_header = match ws_version { + ResponsesWebsocketVersion::V2 => RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE, + ResponsesWebsocketVersion::V1 => OPENAI_BETA_RESPONSES_WEBSOCKETS, }; headers.insert( OPENAI_BETA_HEADER, @@ -628,35 +647,39 @@ impl ModelClientSession { &mut self, payload: ResponseCreateWsRequest, request: &ResponsesApiRequest, + ws_version: ResponsesWebsocketVersion, ) -> ResponsesWsRequest { let Some(last_response) = self.get_last_response() else { return ResponsesWsRequest::ResponseCreate(payload); }; - let responses_websockets_v2_enabled = self.client.responses_websockets_v2_enabled(); - if !responses_websockets_v2_enabled && !last_response.can_append { - trace!("incremental request failed, can't append"); + let Some(append_items) = self.get_incremental_items(request, Some(&last_response)) else { return ResponsesWsRequest::ResponseCreate(payload); - } - let incremental_items = self.get_incremental_items(request, Some(&last_response)); - if let Some(append_items) = incremental_items { - if responses_websockets_v2_enabled && !last_response.response_id.is_empty() { - let payload = ResponseCreateWsRequest { + }; + + match ws_version { + ResponsesWebsocketVersion::V2 => { + if last_response.response_id.is_empty() { + trace!("incremental request failed, no previous response id"); + return ResponsesWsRequest::ResponseCreate(payload); + } + + ResponsesWsRequest::ResponseCreate(ResponseCreateWsRequest { previous_response_id: Some(last_response.response_id), input: append_items, ..payload - }; - return ResponsesWsRequest::ResponseCreate(payload); + }) } - - if !responses_websockets_v2_enabled { - return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest { + ResponsesWebsocketVersion::V1 => { + if !last_response.can_append { + trace!("incremental request failed, can't append"); + return ResponsesWsRequest::ResponseCreate(payload); + } + ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest { input: append_items, client_metadata: payload.client_metadata, - }); + }) } } - - ResponsesWsRequest::ResponseCreate(payload) } /// Opportunistically warms a websocket for this turn-scoped client session. @@ -667,10 +690,9 @@ impl ModelClientSession { otel_manager: &OtelManager, model_info: &ModelInfo, ) -> std::result::Result<(), ApiError> { - if !self.client.responses_websocket_enabled(model_info) || self.client.websockets_disabled() - { + let Some(ws_version) = self.client.active_ws_version(model_info) else { return Ok(()); - } + }; if self.connection.is_some() { return Ok(()); } @@ -687,6 +709,7 @@ impl ModelClientSession { otel_manager, client_setup.api_provider, client_setup.api_auth, + ws_version, Some(Arc::clone(&self.turn_state)), None, ) @@ -701,6 +724,7 @@ impl ModelClientSession { otel_manager: &OtelManager, api_provider: codex_api::Provider, api_auth: CoreAuthProvider, + ws_version: ResponsesWebsocketVersion, turn_metadata_header: Option<&str>, options: &ApiResponsesOptions, ) -> std::result::Result<&ApiWebSocketConnection, ApiError> { @@ -722,6 +746,7 @@ impl ModelClientSession { otel_manager, api_provider, api_auth, + ws_version, Some(turn_state), turn_metadata_header, ) @@ -818,6 +843,7 @@ impl ModelClientSession { &mut self, prompt: &Prompt, model_info: &ModelInfo, + ws_version: ResponsesWebsocketVersion, otel_manager: &OtelManager, effort: Option, summary: ReasoningSummaryConfig, @@ -850,6 +876,7 @@ impl ModelClientSession { otel_manager, client_setup.api_provider, client_setup.api_auth, + ws_version, turn_metadata_header, &options, ) @@ -870,7 +897,7 @@ impl ModelClientSession { Err(err) => return Err(map_api_error(err)), } - let ws_request = self.prepare_websocket_request(ws_payload, &request); + let ws_request = self.prepare_websocket_request(ws_payload, &request, ws_version); let stream_result = self .connection @@ -928,14 +955,12 @@ impl ModelClientSession { let wire_api = self.client.state.provider.wire_api; match wire_api { WireApi::Responses => { - let websocket_enabled = self.client.responses_websocket_enabled(model_info) - && !self.client.websockets_disabled(); - - if websocket_enabled { + if let Some(ws_version) = self.client.active_ws_version(model_info) { match self .stream_responses_websocket( prompt, model_info, + ws_version, otel_manager, effort, summary, @@ -974,7 +999,7 @@ impl ModelClientSession { otel_manager: &OtelManager, model_info: &ModelInfo, ) -> bool { - let websocket_enabled = self.client.responses_websocket_enabled(model_info); + let websocket_enabled = self.client.active_ws_version(model_info).is_some(); let activated = self.activate_http_fallback(websocket_enabled); if activated { warn!("falling back to HTTP"); @@ -1224,8 +1249,7 @@ mod tests { provider, session_source, None, - false, - false, + None, false, false, None, diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 6c9c9a678..55a806218 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -45,6 +45,7 @@ use crate::terminal; use crate::truncate::TruncationPolicy; use crate::turn_metadata::TurnMetadataState; use crate::util::error_or_panic; +use crate::ws_version_from_features; use async_channel::Receiver; use async_channel::Sender; use codex_hooks::HookEvent; @@ -1335,9 +1336,7 @@ impl Session { session_configuration.provider.clone(), session_configuration.session_source.clone(), config.model_verbosity, - config.features.enabled(Feature::ResponsesWebsockets) - || config.features.enabled(Feature::ResponsesWebsocketsV2), - config.features.enabled(Feature::ResponsesWebsocketsV2), + ws_version_from_features(config.as_ref()), config.features.enabled(Feature::EnableRequestCompression), config.features.enabled(Feature::RuntimeMetrics), Self::build_model_client_beta_features_header(config.as_ref()), @@ -5198,10 +5197,11 @@ async fn run_sampling_request( // transient reconnect messages. In debug builds, keep full visibility for diagnosis. let report_error = retries > 1 || cfg!(debug_assertions) - || !sess + || sess .services .model_client - .responses_websocket_enabled(&turn_context.model_info); + .active_ws_version(&turn_context.model_info) + .is_none(); if report_error { // Surface retry information to any UI/front‑end so the @@ -7844,10 +7844,7 @@ mod tests { session_configuration.provider.clone(), session_configuration.session_source.clone(), config.model_verbosity, - model_info.prefer_websockets - || config.features.enabled(Feature::ResponsesWebsockets) - || config.features.enabled(Feature::ResponsesWebsocketsV2), - config.features.enabled(Feature::ResponsesWebsocketsV2), + ws_version_from_features(config.as_ref()), config.features.enabled(Feature::EnableRequestCompression), config.features.enabled(Feature::RuntimeMetrics), Session::build_model_client_beta_features_header(config.as_ref()), @@ -8000,10 +7997,7 @@ mod tests { session_configuration.provider.clone(), session_configuration.session_source.clone(), config.model_verbosity, - model_info.prefer_websockets - || config.features.enabled(Feature::ResponsesWebsockets) - || config.features.enabled(Feature::ResponsesWebsocketsV2), - config.features.enabled(Feature::ResponsesWebsocketsV2), + ws_version_from_features(config.as_ref()), config.features.enabled(Feature::EnableRequestCompression), config.features.enabled(Feature::RuntimeMetrics), Session::build_model_client_beta_features_header(config.as_ref()), diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index a0ad85a42..e40905e31 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -157,6 +157,8 @@ pub use zsh_exec_bridge::maybe_run_zsh_exec_wrapper_mode; pub use client::ModelClient; pub use client::ModelClientSession; +pub use client::ResponsesWebsocketVersion; +pub use client::ws_version_from_features; pub use client_common::Prompt; pub use client_common::REVIEW_PROMPT; pub use client_common::ResponseEvent; diff --git a/codex-rs/core/tests/responses_headers.rs b/codex-rs/core/tests/responses_headers.rs index b392081ee..126dc2c28 100644 --- a/codex-rs/core/tests/responses_headers.rs +++ b/codex-rs/core/tests/responses_headers.rs @@ -6,6 +6,7 @@ use codex_core::ModelClient; use codex_core::ModelProviderInfo; use codex_core::Prompt; use codex_core::ResponseEvent; +use codex_core::ResponsesWebsocketVersion; use codex_core::WireApi; use codex_otel::OtelManager; use codex_otel::TelemetryAuthMode; @@ -91,8 +92,7 @@ async fn responses_stream_includes_subagent_header_on_review() { provider.clone(), session_source, config.model_verbosity, - false, - false, + None::, false, false, None, @@ -197,8 +197,7 @@ async fn responses_stream_includes_subagent_header_on_other() { provider.clone(), session_source, config.model_verbosity, - false, - false, + None::, false, false, None, @@ -302,8 +301,7 @@ async fn responses_respects_model_info_overrides_from_config() { provider.clone(), session_source, config.model_verbosity, - false, - false, + None::, false, false, None, diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs index 77a2ba33e..5fc94a033 100644 --- a/codex-rs/core/tests/suite/client.rs +++ b/codex-rs/core/tests/suite/client.rs @@ -4,6 +4,7 @@ use codex_core::ModelProviderInfo; use codex_core::NewThread; use codex_core::Prompt; use codex_core::ResponseEvent; +use codex_core::ResponsesWebsocketVersion; use codex_core::ThreadManager; use codex_core::WireApi; use codex_core::auth::AuthCredentialsStoreMode; @@ -1353,8 +1354,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() { provider.clone(), SessionSource::Exec, config.model_verbosity, - false, - false, + None::, false, false, None, diff --git a/codex-rs/core/tests/suite/client_websockets.rs b/codex-rs/core/tests/suite/client_websockets.rs index 54e197e56..d8b1940d1 100755 --- a/codex-rs/core/tests/suite/client_websockets.rs +++ b/codex-rs/core/tests/suite/client_websockets.rs @@ -8,6 +8,7 @@ use codex_core::ResponseEvent; use codex_core::WireApi; use codex_core::X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER; use codex_core::features::Feature; +use codex_core::ws_version_from_features; use codex_otel::OtelManager; use codex_otel::TelemetryAuthMode; use codex_otel::metrics::MetricsClient; @@ -319,6 +320,62 @@ async fn responses_websocket_v2_requests_use_v2_when_model_prefers_websockets() server.shutdown().await; } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn responses_websocket_v2_wins_when_both_features_enabled() { + skip_if_no_network!(); + + let server = start_websocket_server(vec![vec![ + vec![ + ev_response_created("resp-1"), + ev_assistant_message("msg-1", "assistant output"), + ev_done_with_id("resp-1"), + ], + vec![ev_response_created("resp-2"), ev_completed("resp-2")], + ]]) + .await; + + let harness = websocket_harness_with_options(&server, false, true, true, false).await; + let mut client_session = harness.client.new_session(); + let prompt_one = prompt_with_input(vec![message_item("hello")]); + let prompt_two = prompt_with_input(vec![ + message_item("hello"), + assistant_message_item("msg-1", "assistant output"), + message_item("second"), + ]); + + stream_until_complete(&mut client_session, &harness, &prompt_one).await; + stream_until_complete(&mut client_session, &harness, &prompt_two).await; + + let connection = server.single_connection(); + assert_eq!(connection.len(), 2); + let second = connection.get(1).expect("missing request").body_json(); + assert_eq!(second["type"].as_str(), Some("response.create")); + assert_eq!(second["previous_response_id"].as_str(), Some("resp-1")); + assert_eq!( + second["input"], + serde_json::to_value(&prompt_two.input[2..]).unwrap() + ); + + let handshake = server.single_handshake(); + let openai_beta_header = handshake + .header(OPENAI_BETA_HEADER) + .expect("missing OpenAI-Beta header"); + assert!( + openai_beta_header + .split(',') + .map(str::trim) + .any(|value| value == WS_V2_BETA_HEADER_VALUE) + ); + assert!( + !openai_beta_header + .split(',') + .map(str::trim) + .any(|value| value == OPENAI_BETA_RESPONSES_WEBSOCKETS) + ); + + server.shutdown().await; +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[traced_test] async fn responses_websocket_emits_websocket_telemetry_events() { @@ -1251,8 +1308,7 @@ async fn websocket_harness_with_options( provider.clone(), SessionSource::Exec, config.model_verbosity, - websocket_enabled, - websocket_v2_enabled, + ws_version_from_features(&config), false, runtime_metrics_enabled, None,