diff --git a/codex-rs/codex-api/src/common.rs b/codex-rs/codex-api/src/common.rs index f27f936b5..7176a1586 100644 --- a/codex-rs/codex-api/src/common.rs +++ b/codex-rs/codex-api/src/common.rs @@ -176,6 +176,7 @@ impl From<&ResponsesApiRequest> for ResponseCreateWsRequest { include: request.include.clone(), prompt_cache_key: request.prompt_cache_key.clone(), text: request.text.clone(), + generate: None, client_metadata: None, } } @@ -200,6 +201,8 @@ pub struct ResponseCreateWsRequest { #[serde(skip_serializing_if = "Option::is_none")] pub text: Option, #[serde(skip_serializing_if = "Option::is_none")] + pub generate: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub client_metadata: Option>, } diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 523984104..74a2f15c1 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -12,19 +12,17 @@ //! requests during that turn. It caches a Responses WebSocket connection (opened lazily) and stores //! per-turn state such as the `x-codex-turn-state` token used for sticky routing. //! -//! Prewarm is intentionally handshake-only: it may warm a socket and capture sticky-routing -//! state, but the first `response.create` payload is still sent only when a turn starts. +//! WebSocket prewarm is a v2-only `response.create` with `generate=false`; it waits for completion +//! so the next request can reuse the same connection and `previous_response_id`. //! -//! Startup prewarm is owned by turn-scoped callers (for example, a pre-created regular task). When -//! a warmed [`ModelClientSession`] is available, turn execution can reuse it; otherwise the turn -//! lazily opens a websocket on first stream call. +//! Turn execution performs prewarm as a best-effort step before the first stream request so the +//! subsequent request can reuse the same connection. //! //! ## Retry-Budget Tradeoff //! -//! Startup prewarm is treated as the first websocket connection attempt for the first turn. If -//! it fails, the stream attempt fails and the retry/fallback loop decides whether to retry or fall -//! back. This avoids duplicate handshakes but means a failed prewarm can consume one retry -//! budget slot before any turn payload is sent. +//! V2 request prewarm is treated as the first websocket connection attempt for a turn. If it +//! fails, normal stream retry/fallback logic handles recovery on the same turn. V1 prewarm +//! remains connection-only. use std::collections::HashMap; use std::sync::Arc; @@ -146,7 +144,7 @@ struct ModelClientState { include_timing_metrics: bool, beta_features_header: Option, disable_websockets: AtomicBool, - cached_websocket_connection: StdMutex>, + cached_websocket_session: StdMutex, } /// Resolved API client setup for a single request attempt. @@ -191,9 +189,7 @@ pub struct ModelClient { /// contract and can cause routing bugs. pub struct ModelClientSession { client: ModelClient, - connection: Option, - websocket_last_request: Option, - websocket_last_response_rx: Option>, + websocket_session: WebsocketSession, /// Turn state for sticky routing. /// /// This is an `OnceLock` that stores the turn state value received from the server @@ -214,6 +210,13 @@ struct LastResponse { can_append: bool, } +#[derive(Debug, Default)] +struct WebsocketSession { + connection: Option, + last_request: Option, + last_response_rx: Option>, +} + enum WebsocketStreamOutcome { Stream(ResponseStream), FallbackToHttp, @@ -248,7 +251,7 @@ impl ModelClient { include_timing_metrics, beta_features_header, disable_websockets: AtomicBool::new(false), - cached_websocket_connection: StdMutex::new(None), + cached_websocket_session: StdMutex::new(WebsocketSession::default()), }), } } @@ -260,27 +263,26 @@ impl ModelClient { pub fn new_session(&self) -> ModelClientSession { ModelClientSession { client: self.clone(), - connection: self.take_cached_websocket_connection(), - websocket_last_request: None, - websocket_last_response_rx: None, + websocket_session: self.take_cached_websocket_session(), turn_state: Arc::new(OnceLock::new()), } } - fn take_cached_websocket_connection(&self) -> Option { - self.state - .cached_websocket_connection + fn take_cached_websocket_session(&self) -> WebsocketSession { + let mut cached_websocket_session = self + .state + .cached_websocket_session .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .take() + .unwrap_or_else(std::sync::PoisonError::into_inner); + std::mem::take(&mut *cached_websocket_session) } - fn store_cached_websocket_connection(&self, connection: ApiWebSocketConnection) { + fn store_cached_websocket_session(&self, websocket_session: WebsocketSession) { *self .state - .cached_websocket_connection + .cached_websocket_session .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(connection); + .unwrap_or_else(std::sync::PoisonError::into_inner) = websocket_session; } /// Compacts the current conversation history using the Compact endpoint. @@ -492,9 +494,9 @@ impl ModelClient { impl Drop for ModelClientSession { fn drop(&mut self) { - if let Some(connection) = self.connection.take() { - self.client.store_cached_websocket_connection(connection); - } + let websocket_session = std::mem::take(&mut self.websocket_session); + self.client + .store_cached_websocket_session(websocket_session); } } @@ -600,12 +602,13 @@ impl ModelClientSession { &self, request: &ResponsesApiRequest, last_response: Option<&LastResponse>, + allow_empty_delta: bool, ) -> Option> { // Checks whether the current request is an incremental append to the previous request. // We only append when non-input request fields are unchanged and `input` is a strict // extension of the previous known input. Server-returned output items are treated as part // of the baseline so we do not resend them. - let previous_request = self.websocket_last_request.as_ref()?; + let previous_request = self.websocket_session.last_request.as_ref()?; let mut previous_without_input = previous_request.clone(); previous_without_input.input.clear(); let mut request_without_input = request.clone(); @@ -623,9 +626,8 @@ impl ModelClientSession { } let baseline_len = baseline.len(); - if baseline_len > 0 - && request.input.starts_with(&baseline) - && baseline_len < request.input.len() + if request.input.starts_with(&baseline) + && (allow_empty_delta || baseline_len < request.input.len()) { Some(request.input[baseline_len..].to_vec()) } else { @@ -635,7 +637,8 @@ impl ModelClientSession { } fn get_last_response(&mut self) -> Option { - self.websocket_last_response_rx + self.websocket_session + .last_response_rx .take() .and_then(|mut receiver| match receiver.try_recv() { Ok(last_response) => Some(last_response), @@ -652,7 +655,10 @@ impl ModelClientSession { let Some(last_response) = self.get_last_response() else { return ResponsesWsRequest::ResponseCreate(payload); }; - let Some(append_items) = self.get_incremental_items(request, Some(&last_response)) else { + let allow_empty_delta = matches!(ws_version, ResponsesWebsocketVersion::V2); + let Some(append_items) = + self.get_incremental_items(request, Some(&last_response), allow_empty_delta) + else { return ResponsesWsRequest::ResponseCreate(payload); }; @@ -682,10 +688,10 @@ impl ModelClientSession { } } - /// Opportunistically warms a websocket for this turn-scoped client session. + /// Opportunistically preconnects a websocket for this turn-scoped client session. /// /// This performs only connection setup; it never sends prompt payloads. - pub async fn prewarm_websocket( + pub async fn preconnect_websocket( &mut self, otel_manager: &OtelManager, model_info: &ModelInfo, @@ -693,7 +699,7 @@ impl ModelClientSession { let Some(ws_version) = self.client.active_ws_version(model_info) else { return Ok(()); }; - if self.connection.is_some() { + if self.websocket_session.connection.is_some() { return Ok(()); } @@ -714,10 +720,9 @@ impl ModelClientSession { None, ) .await?; - self.connection = Some(connection); + self.websocket_session.connection = Some(connection); Ok(()) } - /// Returns a websocket connection for this turn. async fn websocket_connection( &mut self, @@ -728,14 +733,14 @@ impl ModelClientSession { turn_metadata_header: Option<&str>, options: &ApiResponsesOptions, ) -> std::result::Result<&ApiWebSocketConnection, ApiError> { - let needs_new = match self.connection.as_ref() { + let needs_new = match self.websocket_session.connection.as_ref() { Some(conn) => conn.is_closed().await, None => true, }; if needs_new { - self.websocket_last_request = None; - self.websocket_last_response_rx = None; + self.websocket_session.last_request = None; + self.websocket_session.last_response_rx = None; let turn_state = options .turn_state .clone() @@ -751,12 +756,15 @@ impl ModelClientSession { turn_metadata_header, ) .await?; - self.connection = Some(new_conn); + self.websocket_session.connection = Some(new_conn); } - self.connection.as_ref().ok_or(ApiError::Stream( - "websocket connection is unavailable".to_string(), - )) + self.websocket_session + .connection + .as_ref() + .ok_or(ApiError::Stream( + "websocket connection is unavailable".to_string(), + )) } fn responses_request_compression(&self, auth: Option<&crate::auth::CodexAuth>) -> Compression { @@ -848,6 +856,7 @@ impl ModelClientSession { effort: Option, summary: ReasoningSummaryConfig, turn_metadata_header: Option<&str>, + warmup: bool, ) -> Result { let auth_manager = self.client.state.auth_manager.clone(); @@ -866,10 +875,13 @@ impl ModelClientSession { effort, summary, )?; - let ws_payload = ResponseCreateWsRequest { + let mut ws_payload = ResponseCreateWsRequest { client_metadata: build_ws_client_metadata(turn_metadata_header), ..ResponseCreateWsRequest::from(&request) }; + if warmup { + ws_payload.generate = Some(false); + } match self .websocket_connection( @@ -898,8 +910,9 @@ impl ModelClientSession { } let ws_request = self.prepare_websocket_request(ws_payload, &request, ws_version); - + self.websocket_session.last_request = Some(request); let stream_result = self + .websocket_session .connection .as_ref() .ok_or_else(|| { @@ -910,11 +923,9 @@ impl ModelClientSession { .stream_request(ws_request) .await .map_err(map_api_error)?; - self.websocket_last_request = Some(request); let (stream, last_request_rx) = map_response_stream(stream_result, otel_manager.clone()); - self.websocket_last_response_rx = Some(last_request_rx); - + self.websocket_session.last_response_rx = Some(last_request_rx); return Ok(WebsocketStreamOutcome::Stream(stream)); } } @@ -936,6 +947,62 @@ impl ModelClientSession { websocket_telemetry } + #[allow(clippy::too_many_arguments)] + pub async fn prewarm_websocket( + &mut self, + prompt: &Prompt, + model_info: &ModelInfo, + otel_manager: &OtelManager, + effort: Option, + summary: ReasoningSummaryConfig, + turn_metadata_header: Option<&str>, + ) -> Result<()> { + let Some(ws_version) = self.client.active_ws_version(model_info) else { + return Ok(()); + }; + if self.websocket_session.last_request.is_some() { + return Ok(()); + } + + if matches!(ws_version, ResponsesWebsocketVersion::V1) { + self.preconnect_websocket(otel_manager, model_info) + .await + .map_err(map_api_error)?; + return Ok(()); + } + + match self + .stream_responses_websocket( + prompt, + model_info, + ws_version, + otel_manager, + effort, + summary, + turn_metadata_header, + true, + ) + .await + { + Ok(WebsocketStreamOutcome::Stream(mut stream)) => { + // Wait for the v2 warmup request to complete before sending the first turn request. + while let Some(event) = stream.next().await { + match event { + Ok(ResponseEvent::Completed { .. }) => break, + Err(err) => return Err(err), + _ => {} + } + } + Ok(()) + } + Ok(WebsocketStreamOutcome::FallbackToHttp) => { + self.try_switch_fallback_transport(otel_manager, model_info); + Ok(()) + } + Err(err) => Err(err), + } + } + #[allow(clippy::too_many_arguments)] /// Streams a single model request within the current turn. /// @@ -965,6 +1032,7 @@ impl ModelClientSession { effort, summary, turn_metadata_header, + false, ) .await? { @@ -1009,9 +1077,9 @@ impl ModelClientSession { &[("from_wire_api", "responses_websocket")], ); - self.connection = None; - self.websocket_last_request = None; - self.websocket_last_response_rx = None; + self.websocket_session.connection = None; + self.websocket_session.last_request = None; + self.websocket_session.last_response_rx = None; } activated } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 94d2cde3c..90ed386b1 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -106,6 +106,7 @@ use tokio::sync::Mutex; use tokio::sync::RwLock; use tokio::sync::oneshot; use tokio::sync::watch; +use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tracing::Instrument; use tracing::debug; @@ -1265,7 +1266,7 @@ impl Session { } }; session_configuration.thread_name = thread_name.clone(); - let mut state = SessionState::new(session_configuration.clone()); + let state = SessionState::new(session_configuration.clone()); let managed_network_requirements_enabled = config.managed_network_requirements_enabled(); let network_approval = Arc::new(NetworkApprovalService::default()); // The managed proxy can call back into core for allowlist-miss decisions. @@ -1372,16 +1373,6 @@ impl Session { config.js_repl_node_module_dirs.clone(), )); - let prewarm_model_info = models_manager - .get_model_info(session_configuration.collaboration_mode.model(), &config) - .await; - let startup_regular_task = RegularTask::with_startup_prewarm( - services.model_client.clone(), - services.otel_manager.clone(), - prewarm_model_info, - ); - state.set_startup_regular_task(startup_regular_task); - let sess = Arc::new(Session { conversation_id, tx_event: tx_event.clone(), @@ -1399,7 +1390,6 @@ impl Session { let mut guard = network_policy_decider_session.write().await; *guard = Arc::downgrade(&sess); } - // Dispatch the SessionConfiguredEvent first and then report any errors. // If resuming, include converted initial messages in the payload so UIs can render them immediately. let initial_messages = initial_history.get_event_msgs(); @@ -1429,7 +1419,6 @@ impl Session { // Start the watcher after SessionConfigured so it cannot emit earlier events. sess.start_file_watcher_listener(); - // Construct sandbox_state before MCP startup so it can be sent to each // MCP server immediately after it becomes ready (avoiding blocking). let sandbox_state = SandboxState { @@ -1490,6 +1479,8 @@ impl Session { )); } } + sess.schedule_startup_prewarm(session_configuration.base_instructions.clone()) + .await; // record_initial_history can emit events. We record only after the SessionConfiguredEvent is emitted. sess.record_initial_history(initial_history).await; @@ -2155,8 +2146,69 @@ impl Session { } pub(crate) async fn take_startup_regular_task(&self) -> Option { + let startup_regular_task = { + let mut state = self.state.lock().await; + state.take_startup_regular_task() + }; + let startup_regular_task = startup_regular_task?; + match startup_regular_task.await { + Ok(Ok(regular_task)) => Some(regular_task), + Ok(Err(err)) => { + warn!("startup websocket prewarm setup failed: {err:#}"); + None + } + Err(err) => { + warn!("startup websocket prewarm setup join failed: {err}"); + None + } + } + } + + async fn schedule_startup_prewarm(self: &Arc, base_instructions: String) { + let sess = Arc::clone(self); + let startup_regular_task: JoinHandle> = + tokio::spawn( + async move { sess.schedule_startup_prewarm_inner(base_instructions).await }, + ); let mut state = self.state.lock().await; - state.take_startup_regular_task() + state.set_startup_regular_task(startup_regular_task); + } + + async fn schedule_startup_prewarm_inner( + self: &Arc, + base_instructions: String, + ) -> CodexResult { + let startup_turn_context = self + .new_default_turn_with_sub_id(INITIAL_SUBMIT_ID.to_owned()) + .await; + let startup_cancellation_token = CancellationToken::new(); + let startup_router = built_tools( + self, + startup_turn_context.as_ref(), + &[], + &HashSet::new(), + None, + &startup_cancellation_token, + ) + .await?; + let startup_prompt = build_prompt( + Vec::new(), + startup_router.as_ref(), + startup_turn_context.as_ref(), + BaseInstructions { + text: base_instructions, + }, + ); + let startup_turn_metadata_header = startup_turn_context + .turn_metadata_state + .current_header_value(); + RegularTask::with_startup_prewarm( + self.services.model_client.clone(), + startup_prompt, + startup_turn_context, + startup_turn_metadata_header, + ) + .await } pub(crate) async fn get_config(&self) -> std::sync::Arc { @@ -5331,6 +5383,21 @@ fn codex_apps_connector_id(tool: &crate::mcp_connection_manager::ToolInfo) -> Op tool.connector_id.as_deref() } +fn build_prompt( + input: Vec, + router: &ToolRouter, + turn_context: &TurnContext, + base_instructions: BaseInstructions, +) -> Prompt { + Prompt { + input, + tools: router.specs(), + parallel_tool_calls: turn_context.model_info.supports_parallel_tool_calls, + base_instructions, + personality: turn_context.personality, + output_schema: turn_context.final_output_json_schema.clone(), + } +} #[allow(clippy::too_many_arguments)] #[instrument(level = "trace", skip_all, @@ -5362,19 +5429,14 @@ async fn run_sampling_request( ) .await?; - let model_supports_parallel = turn_context.model_info.supports_parallel_tool_calls; - - let tools = router.specs(); let base_instructions = sess.get_base_instructions().await; - let prompt = Prompt { + let prompt = build_prompt( input, - tools, - parallel_tool_calls: model_supports_parallel, + router.as_ref(), + turn_context.as_ref(), base_instructions, - personality: turn_context.personality, - output_schema: turn_context.final_output_json_schema.clone(), - }; + ); let mut retries = 0; loop { let err = match try_run_sampling_request( diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs index 160225174..4523b9356 100644 --- a/codex-rs/core/src/state/session.rs +++ b/codex-rs/core/src/state/session.rs @@ -3,9 +3,11 @@ use codex_protocol::models::ResponseItem; use std::collections::HashMap; use std::collections::HashSet; +use tokio::task::JoinHandle; use crate::codex::SessionConfiguration; use crate::context_manager::ContextManager; +use crate::error::Result as CodexResult; use crate::protocol::RateLimitSnapshot; use crate::protocol::TokenUsage; use crate::protocol::TokenUsageInfo; @@ -26,7 +28,7 @@ pub(crate) struct SessionState { /// resume or `/compact`). previous_model: Option, /// Startup regular task pre-created during session initialization. - pub(crate) startup_regular_task: Option, + pub(crate) startup_regular_task: Option>>, pub(crate) active_mcp_tool_selection: Option>, pub(crate) active_connector_selection: HashSet, } @@ -155,11 +157,13 @@ impl SessionState { self.dependency_env.clone() } - pub(crate) fn set_startup_regular_task(&mut self, task: RegularTask) { + pub(crate) fn set_startup_regular_task(&mut self, task: JoinHandle>) { self.startup_regular_task = Some(task); } - pub(crate) fn take_startup_regular_task(&mut self) -> Option { + pub(crate) fn take_startup_regular_task( + &mut self, + ) -> Option>> { self.startup_regular_task.take() } diff --git a/codex-rs/core/src/tasks/regular.rs b/codex-rs/core/src/tasks/regular.rs index 1428f7775..725039d67 100644 --- a/codex-rs/core/src/tasks/regular.rs +++ b/codex-rs/core/src/tasks/regular.rs @@ -3,77 +3,61 @@ use std::sync::Mutex; use crate::client::ModelClient; use crate::client::ModelClientSession; +use crate::client_common::Prompt; use crate::codex::TurnContext; use crate::codex::run_turn; +use crate::error::Result as CodexResult; use crate::state::TaskKind; use async_trait::async_trait; -use codex_otel::OtelManager; -use codex_protocol::openai_models::ModelInfo; use codex_protocol::user_input::UserInput; -use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tracing::Instrument; use tracing::trace_span; -use tracing::warn; use super::SessionTask; use super::SessionTaskContext; -type PrewarmedSessionTask = JoinHandle>; - pub(crate) struct RegularTask { - prewarmed_session_task: Mutex>, + prewarmed_session: Mutex>, } impl Default for RegularTask { fn default() -> Self { Self { - prewarmed_session_task: Mutex::new(None), + prewarmed_session: Mutex::new(None), } } } impl RegularTask { - pub(crate) fn with_startup_prewarm( + pub(crate) async fn with_startup_prewarm( model_client: ModelClient, - otel_manager: OtelManager, - model_info: ModelInfo, - ) -> Self { - let prewarmed_session_task = tokio::spawn(async move { - let mut client_session = model_client.new_session(); - match client_session - .prewarm_websocket(&otel_manager, &model_info) - .await - { - Ok(()) => Some(client_session), - Err(err) => { - warn!("startup websocket prewarm task failed: {err}"); - None - } - } - }); + prompt: Prompt, + turn_context: Arc, + turn_metadata_header: Option, + ) -> CodexResult { + let mut client_session = model_client.new_session(); + client_session + .prewarm_websocket( + &prompt, + &turn_context.model_info, + &turn_context.otel_manager, + turn_context.reasoning_effort, + turn_context.reasoning_summary, + turn_metadata_header.as_deref(), + ) + .await?; - Self { - prewarmed_session_task: Mutex::new(Some(prewarmed_session_task)), - } + Ok(Self { + prewarmed_session: Mutex::new(Some(client_session)), + }) } async fn take_prewarmed_session(&self) -> Option { - let prewarmed_session_task = self - .prewarmed_session_task + self.prewarmed_session .lock() .unwrap_or_else(std::sync::PoisonError::into_inner) - .take(); - match prewarmed_session_task { - Some(task) => match task.await { - Ok(client_session) => client_session, - Err(err) => { - warn!("startup websocket prewarm task join failed: {err}"); - None - } - }, - None => None, - } + .take() } } diff --git a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs index 649d1cf0b..4c7f26ddb 100644 --- a/codex-rs/core/tests/common/responses.rs +++ b/codex-rs/core/tests/common/responses.rs @@ -300,8 +300,8 @@ pub struct WebSocketConnectionConfig { pub response_headers: Vec<(String, String)>, /// Optional delay inserted before accepting the websocket handshake. /// - /// Tests use this to force startup preconnect into an in-flight state so first-turn adoption - /// paths can be exercised deterministically. + /// Tests use this to force websocket setup into an in-flight state so first-turn warmup paths + /// can be exercised deterministically. pub accept_delay: Option, } @@ -337,7 +337,7 @@ impl WebSocketTestServer { /// Waits until at least `expected` websocket handshakes have been observed or timeout elapses. /// /// Uses a short bounded polling interval so tests can deterministically wait for background - /// preconnect activity without busy-spinning. + /// websocket activity without busy-spinning. pub async fn wait_for_handshakes(&self, expected: usize, timeout: Duration) -> bool { if self.handshakes.lock().unwrap().len() >= expected { return true; diff --git a/codex-rs/core/tests/suite/agent_websocket.rs b/codex-rs/core/tests/suite/agent_websocket.rs index f9f33bb56..26b07bbd5 100644 --- a/codex-rs/core/tests/suite/agent_websocket.rs +++ b/codex-rs/core/tests/suite/agent_websocket.rs @@ -4,6 +4,7 @@ use core_test_support::responses::WebSocketConnectionConfig; use core_test_support::responses::ev_assistant_message; use core_test_support::responses::ev_completed; use core_test_support::responses::ev_done; +use core_test_support::responses::ev_done_with_id; use core_test_support::responses::ev_response_created; use core_test_support::responses::ev_shell_command_call; use core_test_support::responses::start_websocket_server; @@ -38,24 +39,28 @@ async fn websocket_test_codex_shell_chain() -> Result<()> { let mut builder = test_codex(); let test = builder.build_with_websocket_server(&server).await?; - test.submit_turn("run the echo command").await?; + test.submit_turn_with_policy( + "run the echo command", + test.config.permissions.sandbox_policy.get().clone(), + ) + .await?; let connection = server.single_connection(); assert_eq!(connection.len(), 2); - let first = connection + let first_turn = connection .first() - .expect("missing first request") + .expect("missing first turn request") .body_json(); - let second = connection + let second_turn = connection .get(1) - .expect("missing second request") + .expect("missing second turn request") .body_json(); - assert_eq!(first["type"].as_str(), Some("response.create")); - assert_eq!(second["type"].as_str(), Some("response.append")); + assert_eq!(first_turn["type"].as_str(), Some("response.create")); + assert_eq!(second_turn["type"].as_str(), Some("response.append")); - let append_items = second + let append_items = second_turn .get("input") .and_then(Value::as_array) .expect("response.append input array"); @@ -75,50 +80,81 @@ async fn websocket_test_codex_shell_chain() -> Result<()> { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn websocket_preconnect_happens_on_session_start() -> Result<()> { +async fn websocket_first_turn_uses_preconnect_and_create() -> Result<()> { skip_if_no_network!(Ok(())); let server = start_websocket_server(vec![vec![vec![ ev_response_created("resp-1"), + ev_assistant_message("msg-1", "hello"), ev_completed("resp-1"), ]]]) .await; let mut builder = test_codex(); let test = builder.build_with_websocket_server(&server).await?; - - assert!( - server.wait_for_handshakes(1, Duration::from_secs(2)).await, - "expected websocket preconnect handshake during session startup" - ); - - test.submit_turn("hello").await?; + test.submit_turn_with_policy( + "hello", + test.config.permissions.sandbox_policy.get().clone(), + ) + .await?; assert_eq!(server.handshakes().len(), 1); - assert_eq!(server.single_connection().len(), 1); + let connection = server.single_connection(); + assert_eq!(connection.len(), 1); + let turn = connection + .first() + .expect("missing turn request") + .body_json(); + assert!( + turn["tools"] + .as_array() + .is_some_and(|tools| !tools.is_empty()), + "expected request tools to be populated" + ); + assert_eq!(turn["type"].as_str(), Some("response.create")); server.shutdown().await; Ok(()) } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn websocket_first_turn_waits_for_inflight_preconnect() -> Result<()> { +async fn websocket_first_turn_handles_handshake_delay_with_preconnect() -> Result<()> { skip_if_no_network!(Ok(())); let server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig { - requests: vec![vec![ev_response_created("resp-1"), ev_completed("resp-1")]], + requests: vec![vec![ + ev_response_created("resp-1"), + ev_assistant_message("msg-1", "hello"), + ev_completed("resp-1"), + ]], response_headers: Vec::new(), - // Delay handshake so submit_turn() observes startup preconnect as in-flight. + // Delay handshake so turn processing must tolerate websocket startup latency. accept_delay: Some(Duration::from_millis(150)), }]) .await; let mut builder = test_codex(); let test = builder.build_with_websocket_server(&server).await?; - test.submit_turn("hello").await?; + test.submit_turn_with_policy( + "hello", + test.config.permissions.sandbox_policy.get().clone(), + ) + .await?; assert_eq!(server.handshakes().len(), 1); - assert_eq!(server.single_connection().len(), 1); + let connection = server.single_connection(); + assert_eq!(connection.len(), 1); + let turn = connection + .first() + .expect("missing turn request") + .body_json(); + assert!( + turn["tools"] + .as_array() + .is_some_and(|tools| !tools.is_empty()), + "expected request tools to be populated" + ); + assert_eq!(turn["type"].as_str(), Some("response.create")); server.shutdown().await; Ok(()) @@ -130,6 +166,7 @@ async fn websocket_v2_test_codex_shell_chain() -> Result<()> { let call_id = "shell-command-call"; let server = start_websocket_server(vec![vec![ + vec![ev_response_created("warm-1"), ev_done_with_id("warm-1")], vec![ ev_response_created("resp-1"), ev_shell_command_call(call_id, "echo websocket"), @@ -148,25 +185,42 @@ async fn websocket_v2_test_codex_shell_chain() -> Result<()> { }); let test = builder.build_with_websocket_server(&server).await?; - test.submit_turn("run the echo command").await?; + test.submit_turn_with_policy( + "run the echo command", + test.config.permissions.sandbox_policy.get().clone(), + ) + .await?; let connection = server.single_connection(); - assert_eq!(connection.len(), 2); + assert_eq!(connection.len(), 3); - let first = connection + let warmup = connection .first() - .expect("missing first request") + .expect("missing warmup request") .body_json(); - let second = connection + let first_turn = connection .get(1) - .expect("missing second request") + .expect("missing first turn request") + .body_json(); + let second_turn = connection + .get(2) + .expect("missing second turn request") .body_json(); - assert_eq!(first["type"].as_str(), Some("response.create")); - assert_eq!(second["type"].as_str(), Some("response.create")); - assert_eq!(second["previous_response_id"].as_str(), Some("resp-1")); + assert_eq!(warmup["type"].as_str(), Some("response.create")); + assert_eq!(warmup["generate"].as_bool(), Some(false)); + assert_eq!(first_turn["type"].as_str(), Some("response.create")); + assert_eq!(first_turn["previous_response_id"].as_str(), Some("warm-1")); + assert!( + first_turn + .get("input") + .and_then(Value::as_array) + .is_some_and(|items| !items.is_empty()) + ); + assert_eq!(second_turn["type"].as_str(), Some("response.create")); + assert_eq!(second_turn["previous_response_id"].as_str(), Some("resp-1")); - let create_items = second + let create_items = second_turn .get("input") .and_then(Value::as_array) .expect("response.create input array"); diff --git a/codex-rs/core/tests/suite/client_websockets.rs b/codex-rs/core/tests/suite/client_websockets.rs index d8b1940d1..3dec0c440 100755 --- a/codex-rs/core/tests/suite/client_websockets.rs +++ b/codex-rs/core/tests/suite/client_websockets.rs @@ -107,9 +107,9 @@ async fn responses_websocket_preconnect_reuses_connection() { let harness = websocket_harness(&server).await; let mut client_session = harness.client.new_session(); client_session - .prewarm_websocket(&harness.otel_manager, &harness.model_info) + .preconnect_websocket(&harness.otel_manager, &harness.model_info) .await - .expect("websocket prewarm failed"); + .expect("websocket preconnect failed"); let prompt = prompt_with_input(vec![message_item("hello")]); stream_until_complete(&mut client_session, &harness, &prompt).await; @@ -119,6 +119,54 @@ async fn responses_websocket_preconnect_reuses_connection() { server.shutdown().await; } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn responses_websocket_request_prewarm_reuses_connection() { + skip_if_no_network!(); + + let server = start_websocket_server(vec![vec![ + vec![ev_response_created("warm-1"), ev_done_with_id("warm-1")], + vec![ev_response_created("resp-1"), ev_completed("resp-1")], + ]]) + .await; + + let harness = websocket_harness_with_options(&server, false, false, true, true).await; + let mut client_session = harness.client.new_session(); + let prompt = prompt_with_input(vec![message_item("hello")]); + client_session + .prewarm_websocket( + &prompt, + &harness.model_info, + &harness.otel_manager, + harness.effort, + harness.summary, + None, + ) + .await + .expect("websocket prewarm failed"); + stream_until_complete(&mut client_session, &harness, &prompt).await; + + assert_eq!(server.handshakes().len(), 1); + let connection = server.single_connection(); + assert_eq!(connection.len(), 2); + let warmup = connection + .first() + .expect("missing warmup request") + .body_json(); + let follow_up = connection + .get(1) + .expect("missing follow-up request") + .body_json(); + + assert_eq!(warmup["type"].as_str(), Some("response.create")); + assert_eq!(warmup["generate"].as_bool(), Some(false)); + assert_eq!(warmup["tools"], serde_json::json!([])); + assert_eq!(follow_up["type"].as_str(), Some("response.create")); + assert_eq!(follow_up["previous_response_id"].as_str(), Some("warm-1")); + assert_eq!(follow_up["input"], serde_json::json!([])); + + server.shutdown().await; +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn responses_websocket_reuses_connection_after_session_drop() { skip_if_no_network!(); @@ -160,9 +208,9 @@ async fn responses_websocket_preconnect_is_reused_even_with_header_changes() { let harness = websocket_harness(&server).await; let mut client_session = harness.client.new_session(); client_session - .prewarm_websocket(&harness.otel_manager, &harness.model_info) + .preconnect_websocket(&harness.otel_manager, &harness.model_info) .await - .expect("websocket prewarm failed"); + .expect("websocket preconnect failed"); let prompt = prompt_with_input(vec![message_item("hello")]); let mut stream = client_session .stream( @@ -188,6 +236,69 @@ async fn responses_websocket_preconnect_is_reused_even_with_header_changes() { server.shutdown().await; } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn responses_websocket_request_prewarm_is_reused_even_with_header_changes() { + skip_if_no_network!(); + + let server = start_websocket_server(vec![vec![ + vec![ev_response_created("warm-1"), ev_done_with_id("warm-1")], + vec![ev_response_created("resp-1"), ev_completed("resp-1")], + ]]) + .await; + + let harness = websocket_harness_with_options(&server, false, false, true, true).await; + let mut client_session = harness.client.new_session(); + let prompt = prompt_with_input(vec![message_item("hello")]); + client_session + .prewarm_websocket( + &prompt, + &harness.model_info, + &harness.otel_manager, + harness.effort, + harness.summary, + None, + ) + .await + .expect("websocket prewarm failed"); + let mut stream = client_session + .stream( + &prompt, + &harness.model_info, + &harness.otel_manager, + harness.effort, + harness.summary, + None, + ) + .await + .expect("websocket stream failed"); + + while let Some(event) = stream.next().await { + if matches!(event, Ok(ResponseEvent::Completed { .. })) { + break; + } + } + + assert_eq!(server.handshakes().len(), 1); + let connection = server.single_connection(); + assert_eq!(connection.len(), 2); + let warmup = connection + .first() + .expect("missing warmup request") + .body_json(); + let follow_up = connection + .get(1) + .expect("missing follow-up request") + .body_json(); + assert_eq!(warmup["type"].as_str(), Some("response.create")); + assert_eq!(warmup["generate"].as_bool(), Some(false)); + assert_eq!(warmup["tools"], serde_json::json!([])); + assert_eq!(follow_up["type"].as_str(), Some("response.create")); + assert_eq!(follow_up["previous_response_id"].as_str(), Some("warm-1")); + assert_eq!(follow_up["input"], serde_json::json!([])); + + server.shutdown().await; +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn responses_websocket_prewarm_uses_model_preference_when_feature_disabled() { skip_if_no_network!(); @@ -200,26 +311,39 @@ async fn responses_websocket_prewarm_uses_model_preference_when_feature_disabled let harness = websocket_harness_with_options(&server, false, false, false, true).await; let mut client_session = harness.client.new_session(); + let prompt = prompt_with_input(vec![message_item("hello")]); client_session - .prewarm_websocket(&harness.otel_manager, &harness.model_info) + .prewarm_websocket( + &prompt, + &harness.model_info, + &harness.otel_manager, + harness.effort, + harness.summary, + None, + ) .await .expect("websocket prewarm failed"); - // Prewarm should only perform the handshake, not send response.create. + // V1 prewarm only preconnects and should not issue a request. assert_eq!(server.handshakes().len(), 1); assert_eq!(server.single_connection().len(), 0); - let prompt = prompt_with_input(vec![message_item("hello")]); stream_until_complete(&mut client_session, &harness, &prompt).await; - assert_eq!(server.handshakes().len(), 1); - assert_eq!(server.single_connection().len(), 1); + let connection = server.single_connection(); + assert_eq!(connection.len(), 1); + let turn = connection + .first() + .expect("missing turn request") + .body_json(); + assert_eq!(turn["type"].as_str(), Some("response.create")); + assert_eq!(turn["input"], serde_json::to_value(&prompt.input).unwrap()); server.shutdown().await; } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn responses_websocket_v2_prewarm_runs_when_only_v2_feature_enabled() { +async fn responses_websocket_preconnect_runs_when_only_v2_feature_enabled() { skip_if_no_network!(); let server = start_websocket_server(vec![vec![vec![ @@ -231,9 +355,9 @@ async fn responses_websocket_v2_prewarm_runs_when_only_v2_feature_enabled() { let harness = websocket_harness_with_options(&server, false, false, true, false).await; let mut client_session = harness.client.new_session(); client_session - .prewarm_websocket(&harness.otel_manager, &harness.model_info) + .preconnect_websocket(&harness.otel_manager, &harness.model_info) .await - .expect("websocket prewarm failed"); + .expect("websocket preconnect failed"); assert_eq!(server.handshakes().len(), 1); assert_eq!(server.single_connection().len(), 0); @@ -320,6 +444,50 @@ async fn responses_websocket_v2_requests_use_v2_when_model_prefers_websockets() server.shutdown().await; } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn responses_websocket_v2_incremental_requests_are_reused_across_turns() { + skip_if_no_network!(); + + let server = start_websocket_server(vec![vec![ + vec![ + ev_response_created("resp-1"), + ev_assistant_message("msg-1", "assistant output"), + ev_done_with_id("resp-1"), + ], + vec![ev_response_created("resp-2"), ev_completed("resp-2")], + ]]) + .await; + + let harness = websocket_harness_with_options(&server, false, false, true, true).await; + let prompt_one = prompt_with_input(vec![message_item("hello")]); + let prompt_two = prompt_with_input(vec![ + message_item("hello"), + assistant_message_item("msg-1", "assistant output"), + message_item("second"), + ]); + + { + let mut client_session = harness.client.new_session(); + stream_until_complete(&mut client_session, &harness, &prompt_one).await; + } + + let mut client_session = harness.client.new_session(); + stream_until_complete(&mut client_session, &harness, &prompt_two).await; + + assert_eq!(server.handshakes().len(), 1); + let connection = server.single_connection(); + assert_eq!(connection.len(), 2); + let second = connection.get(1).expect("missing request").body_json(); + assert_eq!(second["type"].as_str(), Some("response.create")); + assert_eq!(second["previous_response_id"].as_str(), Some("resp-1")); + assert_eq!( + second["input"], + serde_json::to_value(&prompt_two.input[2..]).unwrap() + ); + + server.shutdown().await; +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn responses_websocket_v2_wins_when_both_features_enabled() { skip_if_no_network!(); diff --git a/codex-rs/core/tests/suite/websocket_fallback.rs b/codex-rs/core/tests/suite/websocket_fallback.rs index b5c27ae92..20c7b3b2f 100644 --- a/codex-rs/core/tests/suite/websocket_fallback.rs +++ b/codex-rs/core/tests/suite/websocket_fallback.rs @@ -67,9 +67,8 @@ async fn websocket_fallback_switches_to_http_on_upgrade_required_connect() -> Re .filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses")) .count(); - // One websocket attempt comes from startup preconnect and one from the first turn's stream - // attempt before fallback activates; after fallback, transport is HTTP. This matches the - // retry-budget tradeoff documented in [`codex_core::client`] module docs. + // Startup prewarm now only preconnects for v1 (one websocket GET with no request body). + // The first turn then attempts websocket once, sees 426, and falls back to HTTP. assert_eq!(websocket_attempts, 2); assert_eq!(http_attempts, 1); assert_eq!(response_mock.requests().len(), 1); @@ -112,7 +111,7 @@ async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result .filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses")) .count(); - // One websocket attempt comes from startup preconnect. + // Deferred request prewarm is attempted at startup. // The first turn then makes 3 websocket stream attempts (initial try + 2 retries), // after which fallback activates and the request is replayed over HTTP. assert_eq!(websocket_attempts, 4); @@ -233,7 +232,8 @@ async fn websocket_fallback_is_sticky_across_turns() -> Result<()> { .count(); // WebSocket attempts all happen on the first turn: - // 1 startup preconnect + 3 stream attempts (initial try + 2 retries) before fallback. + // 1 deferred request prewarm attempt (startup) + 3 stream attempts + // (initial try + 2 retries) before fallback. // Fallback is sticky, so the second turn stays on HTTP and adds no websocket attempts. assert_eq!(websocket_attempts, 4); assert_eq!(http_attempts, 2);