diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index 99f8464ac..09ae0eb0d 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -212,6 +212,8 @@ impl ModelClient {
         include_timing_metrics: bool,
         beta_features_header: Option<String>,
     ) -> Self {
+        let enable_responses_websockets =
+            enable_responses_websockets || enable_responses_websockets_v2;
         Self {
             state: Arc::new(ModelClientState {
                 auth_manager,
@@ -351,7 +353,9 @@ impl ModelClient {
     /// to be eligible.
     pub fn responses_websocket_enabled(&self, model_info: &ModelInfo) -> bool {
         self.state.provider.supports_websockets
-            && (self.state.enable_responses_websockets || model_info.prefer_websockets)
+            && (self.state.enable_responses_websockets
+                || self.state.enable_responses_websockets_v2
+                || model_info.prefer_websockets)
     }
 
     fn responses_websockets_v2_enabled(&self) -> bool {
diff --git a/codex-rs/core/tests/suite/client_websockets.rs b/codex-rs/core/tests/suite/client_websockets.rs
index 1db81ecf8..626d0695b 100755
--- a/codex-rs/core/tests/suite/client_websockets.rs
+++ b/codex-rs/core/tests/suite/client_websockets.rs
@@ -189,6 +189,108 @@ async fn responses_websocket_prewarm_uses_model_preference_when_feature_disabled
     server.shutdown().await;
 }
 
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn responses_websocket_v2_prewarm_runs_when_only_v2_feature_enabled() {
+    skip_if_no_network!();
+
+    let server = start_websocket_server(vec![vec![vec![
+        ev_response_created("resp-1"),
+        ev_completed("resp-1"),
+    ]]])
+    .await;
+
+    let harness = websocket_harness_with_options(&server, false, false, true, false).await;
+    let mut client_session = harness.client.new_session();
+    client_session
+        .prewarm_websocket(&harness.otel_manager, &harness.model_info)
+        .await
+        .expect("websocket prewarm failed");
+
+    assert_eq!(server.handshakes().len(), 1);
+    assert_eq!(server.single_connection().len(), 0);
+
+    let prompt = prompt_with_input(vec![message_item("hello")]);
+    stream_until_complete(&mut client_session, &harness, &prompt).await;
+
+    assert_eq!(server.handshakes().len(), 1);
+    assert_eq!(server.single_connection().len(), 1);
+
+    let handshake = server.single_handshake();
+    let openai_beta_header = handshake
+        .header(OPENAI_BETA_HEADER)
+        .expect("missing OpenAI-Beta header");
+    assert!(
+        openai_beta_header
+            .split(',')
+            .map(str::trim)
+            .any(|value| value == WS_V2_BETA_HEADER_VALUE)
+    );
+    assert!(
+        !openai_beta_header
+            .split(',')
+            .map(str::trim)
+            .any(|value| value == OPENAI_BETA_RESPONSES_WEBSOCKETS)
+    );
+
+    server.shutdown().await;
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn responses_websocket_v2_requests_use_v2_when_model_prefers_websockets() {
+    skip_if_no_network!();
+
+    let server = start_websocket_server(vec![vec![
+        vec![
+            ev_response_created("resp-1"),
+            ev_assistant_message("msg-1", "assistant output"),
+            ev_done_with_id("resp-1"),
+        ],
+        vec![ev_response_created("resp-2"), ev_completed("resp-2")],
+    ]])
+    .await;
+
+    let harness = websocket_harness_with_options(&server, false, false, true, true).await;
+    let mut client_session = harness.client.new_session();
+    let prompt_one = prompt_with_input(vec![message_item("hello")]);
+    let prompt_two = prompt_with_input(vec![
+        message_item("hello"),
+        assistant_message_item("msg-1", "assistant output"),
+        message_item("second"),
+    ]);
+
+    stream_until_complete(&mut client_session, &harness, &prompt_one).await;
+    stream_until_complete(&mut client_session, &harness, &prompt_two).await;
+
+    let connection = server.single_connection();
+    assert_eq!(connection.len(), 2);
+    let second = connection.get(1).expect("missing request").body_json();
+    assert_eq!(second["type"].as_str(), Some("response.create"));
+    assert_eq!(second["previous_response_id"].as_str(), Some("resp-1"));
+    assert_eq!(
+        second["input"],
+        serde_json::to_value(&prompt_two.input[2..]).unwrap()
+    );
+
+    let handshake = server.single_handshake();
+    let openai_beta_header = handshake
+        .header(OPENAI_BETA_HEADER)
+        .expect("missing OpenAI-Beta header");
+    assert!(
+        openai_beta_header
+            .split(',')
+            .map(str::trim)
+            .any(|value| value == WS_V2_BETA_HEADER_VALUE)
+    );
+    assert!(
+        !openai_beta_header
+            .split(',')
+            .map(str::trim)
+            .any(|value| value == OPENAI_BETA_RESPONSES_WEBSOCKETS)
+    );
+
+    server.shutdown().await;
+}
+
 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
 #[traced_test]
 async fn responses_websocket_emits_websocket_telemetry_events() {