Send warmup request (#11258)

Send a request with `generate: false` but a full set of tools and
instructions to pre-warm inference.

---------

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
pakrym-oai 2026-02-24 08:15:47 -08:00 committed by GitHub
parent 0679e70bfc
commit 97d0068658
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 516 additions and 173 deletions

View file

@ -176,6 +176,7 @@ impl From<&ResponsesApiRequest> for ResponseCreateWsRequest {
include: request.include.clone(),
prompt_cache_key: request.prompt_cache_key.clone(),
text: request.text.clone(),
generate: None,
client_metadata: None,
}
}
@ -200,6 +201,8 @@ pub struct ResponseCreateWsRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<TextControls>,
#[serde(skip_serializing_if = "Option::is_none")]
pub generate: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub client_metadata: Option<HashMap<String, String>>,
}

View file

@ -12,19 +12,17 @@
//! requests during that turn. It caches a Responses WebSocket connection (opened lazily) and stores
//! per-turn state such as the `x-codex-turn-state` token used for sticky routing.
//!
//! Prewarm is intentionally handshake-only: it may warm a socket and capture sticky-routing
//! state, but the first `response.create` payload is still sent only when a turn starts.
//! WebSocket prewarm is a v2-only `response.create` with `generate=false`; it waits for completion
//! so the next request can reuse the same connection and `previous_response_id`.
//!
//! Startup prewarm is owned by turn-scoped callers (for example, a pre-created regular task). When
//! a warmed [`ModelClientSession`] is available, turn execution can reuse it; otherwise the turn
//! lazily opens a websocket on first stream call.
//! Turn execution performs prewarm as a best-effort step before the first stream request so the
//! subsequent request can reuse the same connection.
//!
//! ## Retry-Budget Tradeoff
//!
//! Startup prewarm is treated as the first websocket connection attempt for the first turn. If
//! it fails, the stream attempt fails and the retry/fallback loop decides whether to retry or fall
//! back. This avoids duplicate handshakes but means a failed prewarm can consume one retry
//! budget slot before any turn payload is sent.
//! V2 request prewarm is treated as the first websocket connection attempt for a turn. If it
//! fails, normal stream retry/fallback logic handles recovery on the same turn. V1 prewarm
//! remains connection-only.
use std::collections::HashMap;
use std::sync::Arc;
@ -146,7 +144,7 @@ struct ModelClientState {
include_timing_metrics: bool,
beta_features_header: Option<String>,
disable_websockets: AtomicBool,
cached_websocket_connection: StdMutex<Option<ApiWebSocketConnection>>,
cached_websocket_session: StdMutex<WebsocketSession>,
}
/// Resolved API client setup for a single request attempt.
@ -191,9 +189,7 @@ pub struct ModelClient {
/// contract and can cause routing bugs.
pub struct ModelClientSession {
client: ModelClient,
connection: Option<ApiWebSocketConnection>,
websocket_last_request: Option<ResponsesApiRequest>,
websocket_last_response_rx: Option<oneshot::Receiver<LastResponse>>,
websocket_session: WebsocketSession,
/// Turn state for sticky routing.
///
/// This is an `OnceLock` that stores the turn state value received from the server
@ -214,6 +210,13 @@ struct LastResponse {
can_append: bool,
}
#[derive(Debug, Default)]
struct WebsocketSession {
connection: Option<ApiWebSocketConnection>,
last_request: Option<ResponsesApiRequest>,
last_response_rx: Option<oneshot::Receiver<LastResponse>>,
}
enum WebsocketStreamOutcome {
Stream(ResponseStream),
FallbackToHttp,
@ -248,7 +251,7 @@ impl ModelClient {
include_timing_metrics,
beta_features_header,
disable_websockets: AtomicBool::new(false),
cached_websocket_connection: StdMutex::new(None),
cached_websocket_session: StdMutex::new(WebsocketSession::default()),
}),
}
}
@ -260,27 +263,26 @@ impl ModelClient {
pub fn new_session(&self) -> ModelClientSession {
ModelClientSession {
client: self.clone(),
connection: self.take_cached_websocket_connection(),
websocket_last_request: None,
websocket_last_response_rx: None,
websocket_session: self.take_cached_websocket_session(),
turn_state: Arc::new(OnceLock::new()),
}
}
fn take_cached_websocket_connection(&self) -> Option<ApiWebSocketConnection> {
self.state
.cached_websocket_connection
fn take_cached_websocket_session(&self) -> WebsocketSession {
let mut cached_websocket_session = self
.state
.cached_websocket_session
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.take()
.unwrap_or_else(std::sync::PoisonError::into_inner);
std::mem::take(&mut *cached_websocket_session)
}
fn store_cached_websocket_connection(&self, connection: ApiWebSocketConnection) {
fn store_cached_websocket_session(&self, websocket_session: WebsocketSession) {
*self
.state
.cached_websocket_connection
.cached_websocket_session
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner) = Some(connection);
.unwrap_or_else(std::sync::PoisonError::into_inner) = websocket_session;
}
/// Compacts the current conversation history using the Compact endpoint.
@ -492,9 +494,9 @@ impl ModelClient {
impl Drop for ModelClientSession {
fn drop(&mut self) {
if let Some(connection) = self.connection.take() {
self.client.store_cached_websocket_connection(connection);
}
let websocket_session = std::mem::take(&mut self.websocket_session);
self.client
.store_cached_websocket_session(websocket_session);
}
}
@ -600,12 +602,13 @@ impl ModelClientSession {
&self,
request: &ResponsesApiRequest,
last_response: Option<&LastResponse>,
allow_empty_delta: bool,
) -> Option<Vec<ResponseItem>> {
// Checks whether the current request is an incremental append to the previous request.
// We only append when non-input request fields are unchanged and `input` is a strict
// extension of the previous known input. Server-returned output items are treated as part
// of the baseline so we do not resend them.
let previous_request = self.websocket_last_request.as_ref()?;
let previous_request = self.websocket_session.last_request.as_ref()?;
let mut previous_without_input = previous_request.clone();
previous_without_input.input.clear();
let mut request_without_input = request.clone();
@ -623,9 +626,8 @@ impl ModelClientSession {
}
let baseline_len = baseline.len();
if baseline_len > 0
&& request.input.starts_with(&baseline)
&& baseline_len < request.input.len()
if request.input.starts_with(&baseline)
&& (allow_empty_delta || baseline_len < request.input.len())
{
Some(request.input[baseline_len..].to_vec())
} else {
@ -635,7 +637,8 @@ impl ModelClientSession {
}
fn get_last_response(&mut self) -> Option<LastResponse> {
self.websocket_last_response_rx
self.websocket_session
.last_response_rx
.take()
.and_then(|mut receiver| match receiver.try_recv() {
Ok(last_response) => Some(last_response),
@ -652,7 +655,10 @@ impl ModelClientSession {
let Some(last_response) = self.get_last_response() else {
return ResponsesWsRequest::ResponseCreate(payload);
};
let Some(append_items) = self.get_incremental_items(request, Some(&last_response)) else {
let allow_empty_delta = matches!(ws_version, ResponsesWebsocketVersion::V2);
let Some(append_items) =
self.get_incremental_items(request, Some(&last_response), allow_empty_delta)
else {
return ResponsesWsRequest::ResponseCreate(payload);
};
@ -682,10 +688,10 @@ impl ModelClientSession {
}
}
/// Opportunistically warms a websocket for this turn-scoped client session.
/// Opportunistically preconnects a websocket for this turn-scoped client session.
///
/// This performs only connection setup; it never sends prompt payloads.
pub async fn prewarm_websocket(
pub async fn preconnect_websocket(
&mut self,
otel_manager: &OtelManager,
model_info: &ModelInfo,
@ -693,7 +699,7 @@ impl ModelClientSession {
let Some(ws_version) = self.client.active_ws_version(model_info) else {
return Ok(());
};
if self.connection.is_some() {
if self.websocket_session.connection.is_some() {
return Ok(());
}
@ -714,10 +720,9 @@ impl ModelClientSession {
None,
)
.await?;
self.connection = Some(connection);
self.websocket_session.connection = Some(connection);
Ok(())
}
/// Returns a websocket connection for this turn.
async fn websocket_connection(
&mut self,
@ -728,14 +733,14 @@ impl ModelClientSession {
turn_metadata_header: Option<&str>,
options: &ApiResponsesOptions,
) -> std::result::Result<&ApiWebSocketConnection, ApiError> {
let needs_new = match self.connection.as_ref() {
let needs_new = match self.websocket_session.connection.as_ref() {
Some(conn) => conn.is_closed().await,
None => true,
};
if needs_new {
self.websocket_last_request = None;
self.websocket_last_response_rx = None;
self.websocket_session.last_request = None;
self.websocket_session.last_response_rx = None;
let turn_state = options
.turn_state
.clone()
@ -751,12 +756,15 @@ impl ModelClientSession {
turn_metadata_header,
)
.await?;
self.connection = Some(new_conn);
self.websocket_session.connection = Some(new_conn);
}
self.connection.as_ref().ok_or(ApiError::Stream(
"websocket connection is unavailable".to_string(),
))
self.websocket_session
.connection
.as_ref()
.ok_or(ApiError::Stream(
"websocket connection is unavailable".to_string(),
))
}
fn responses_request_compression(&self, auth: Option<&crate::auth::CodexAuth>) -> Compression {
@ -848,6 +856,7 @@ impl ModelClientSession {
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
turn_metadata_header: Option<&str>,
warmup: bool,
) -> Result<WebsocketStreamOutcome> {
let auth_manager = self.client.state.auth_manager.clone();
@ -866,10 +875,13 @@ impl ModelClientSession {
effort,
summary,
)?;
let ws_payload = ResponseCreateWsRequest {
let mut ws_payload = ResponseCreateWsRequest {
client_metadata: build_ws_client_metadata(turn_metadata_header),
..ResponseCreateWsRequest::from(&request)
};
if warmup {
ws_payload.generate = Some(false);
}
match self
.websocket_connection(
@ -898,8 +910,9 @@ impl ModelClientSession {
}
let ws_request = self.prepare_websocket_request(ws_payload, &request, ws_version);
self.websocket_session.last_request = Some(request);
let stream_result = self
.websocket_session
.connection
.as_ref()
.ok_or_else(|| {
@ -910,11 +923,9 @@ impl ModelClientSession {
.stream_request(ws_request)
.await
.map_err(map_api_error)?;
self.websocket_last_request = Some(request);
let (stream, last_request_rx) =
map_response_stream(stream_result, otel_manager.clone());
self.websocket_last_response_rx = Some(last_request_rx);
self.websocket_session.last_response_rx = Some(last_request_rx);
return Ok(WebsocketStreamOutcome::Stream(stream));
}
}
@ -936,6 +947,62 @@ impl ModelClientSession {
websocket_telemetry
}
#[allow(clippy::too_many_arguments)]
pub async fn prewarm_websocket(
&mut self,
prompt: &Prompt,
model_info: &ModelInfo,
otel_manager: &OtelManager,
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
turn_metadata_header: Option<&str>,
) -> Result<()> {
let Some(ws_version) = self.client.active_ws_version(model_info) else {
return Ok(());
};
if self.websocket_session.last_request.is_some() {
return Ok(());
}
if matches!(ws_version, ResponsesWebsocketVersion::V1) {
self.preconnect_websocket(otel_manager, model_info)
.await
.map_err(map_api_error)?;
return Ok(());
}
match self
.stream_responses_websocket(
prompt,
model_info,
ws_version,
otel_manager,
effort,
summary,
turn_metadata_header,
true,
)
.await
{
Ok(WebsocketStreamOutcome::Stream(mut stream)) => {
// Wait for the v2 warmup request to complete before sending the first turn request.
while let Some(event) = stream.next().await {
match event {
Ok(ResponseEvent::Completed { .. }) => break,
Err(err) => return Err(err),
_ => {}
}
}
Ok(())
}
Ok(WebsocketStreamOutcome::FallbackToHttp) => {
self.try_switch_fallback_transport(otel_manager, model_info);
Ok(())
}
Err(err) => Err(err),
}
}
#[allow(clippy::too_many_arguments)]
/// Streams a single model request within the current turn.
///
@ -965,6 +1032,7 @@ impl ModelClientSession {
effort,
summary,
turn_metadata_header,
false,
)
.await?
{
@ -1009,9 +1077,9 @@ impl ModelClientSession {
&[("from_wire_api", "responses_websocket")],
);
self.connection = None;
self.websocket_last_request = None;
self.websocket_last_response_rx = None;
self.websocket_session.connection = None;
self.websocket_session.last_request = None;
self.websocket_session.last_response_rx = None;
}
activated
}

View file

@ -106,6 +106,7 @@ use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::sync::oneshot;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
use tracing::debug;
@ -1265,7 +1266,7 @@ impl Session {
}
};
session_configuration.thread_name = thread_name.clone();
let mut state = SessionState::new(session_configuration.clone());
let state = SessionState::new(session_configuration.clone());
let managed_network_requirements_enabled = config.managed_network_requirements_enabled();
let network_approval = Arc::new(NetworkApprovalService::default());
// The managed proxy can call back into core for allowlist-miss decisions.
@ -1372,16 +1373,6 @@ impl Session {
config.js_repl_node_module_dirs.clone(),
));
let prewarm_model_info = models_manager
.get_model_info(session_configuration.collaboration_mode.model(), &config)
.await;
let startup_regular_task = RegularTask::with_startup_prewarm(
services.model_client.clone(),
services.otel_manager.clone(),
prewarm_model_info,
);
state.set_startup_regular_task(startup_regular_task);
let sess = Arc::new(Session {
conversation_id,
tx_event: tx_event.clone(),
@ -1399,7 +1390,6 @@ impl Session {
let mut guard = network_policy_decider_session.write().await;
*guard = Arc::downgrade(&sess);
}
// Dispatch the SessionConfiguredEvent first and then report any errors.
// If resuming, include converted initial messages in the payload so UIs can render them immediately.
let initial_messages = initial_history.get_event_msgs();
@ -1429,7 +1419,6 @@ impl Session {
// Start the watcher after SessionConfigured so it cannot emit earlier events.
sess.start_file_watcher_listener();
// Construct sandbox_state before MCP startup so it can be sent to each
// MCP server immediately after it becomes ready (avoiding blocking).
let sandbox_state = SandboxState {
@ -1490,6 +1479,8 @@ impl Session {
));
}
}
sess.schedule_startup_prewarm(session_configuration.base_instructions.clone())
.await;
// record_initial_history can emit events. We record only after the SessionConfiguredEvent is emitted.
sess.record_initial_history(initial_history).await;
@ -2155,8 +2146,69 @@ impl Session {
}
pub(crate) async fn take_startup_regular_task(&self) -> Option<RegularTask> {
let startup_regular_task = {
let mut state = self.state.lock().await;
state.take_startup_regular_task()
};
let startup_regular_task = startup_regular_task?;
match startup_regular_task.await {
Ok(Ok(regular_task)) => Some(regular_task),
Ok(Err(err)) => {
warn!("startup websocket prewarm setup failed: {err:#}");
None
}
Err(err) => {
warn!("startup websocket prewarm setup join failed: {err}");
None
}
}
}
async fn schedule_startup_prewarm(self: &Arc<Self>, base_instructions: String) {
let sess = Arc::clone(self);
let startup_regular_task: JoinHandle<CodexResult<RegularTask>> =
tokio::spawn(
async move { sess.schedule_startup_prewarm_inner(base_instructions).await },
);
let mut state = self.state.lock().await;
state.take_startup_regular_task()
state.set_startup_regular_task(startup_regular_task);
}
async fn schedule_startup_prewarm_inner(
self: &Arc<Self>,
base_instructions: String,
) -> CodexResult<RegularTask> {
let startup_turn_context = self
.new_default_turn_with_sub_id(INITIAL_SUBMIT_ID.to_owned())
.await;
let startup_cancellation_token = CancellationToken::new();
let startup_router = built_tools(
self,
startup_turn_context.as_ref(),
&[],
&HashSet::new(),
None,
&startup_cancellation_token,
)
.await?;
let startup_prompt = build_prompt(
Vec::new(),
startup_router.as_ref(),
startup_turn_context.as_ref(),
BaseInstructions {
text: base_instructions,
},
);
let startup_turn_metadata_header = startup_turn_context
.turn_metadata_state
.current_header_value();
RegularTask::with_startup_prewarm(
self.services.model_client.clone(),
startup_prompt,
startup_turn_context,
startup_turn_metadata_header,
)
.await
}
pub(crate) async fn get_config(&self) -> std::sync::Arc<Config> {
@ -5331,6 +5383,21 @@ fn codex_apps_connector_id(tool: &crate::mcp_connection_manager::ToolInfo) -> Op
tool.connector_id.as_deref()
}
fn build_prompt(
input: Vec<ResponseItem>,
router: &ToolRouter,
turn_context: &TurnContext,
base_instructions: BaseInstructions,
) -> Prompt {
Prompt {
input,
tools: router.specs(),
parallel_tool_calls: turn_context.model_info.supports_parallel_tool_calls,
base_instructions,
personality: turn_context.personality,
output_schema: turn_context.final_output_json_schema.clone(),
}
}
#[allow(clippy::too_many_arguments)]
#[instrument(level = "trace",
skip_all,
@ -5362,19 +5429,14 @@ async fn run_sampling_request(
)
.await?;
let model_supports_parallel = turn_context.model_info.supports_parallel_tool_calls;
let tools = router.specs();
let base_instructions = sess.get_base_instructions().await;
let prompt = Prompt {
let prompt = build_prompt(
input,
tools,
parallel_tool_calls: model_supports_parallel,
router.as_ref(),
turn_context.as_ref(),
base_instructions,
personality: turn_context.personality,
output_schema: turn_context.final_output_json_schema.clone(),
};
);
let mut retries = 0;
loop {
let err = match try_run_sampling_request(

View file

@ -3,9 +3,11 @@
use codex_protocol::models::ResponseItem;
use std::collections::HashMap;
use std::collections::HashSet;
use tokio::task::JoinHandle;
use crate::codex::SessionConfiguration;
use crate::context_manager::ContextManager;
use crate::error::Result as CodexResult;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::TokenUsage;
use crate::protocol::TokenUsageInfo;
@ -26,7 +28,7 @@ pub(crate) struct SessionState {
/// resume or `/compact`).
previous_model: Option<String>,
/// Startup regular task pre-created during session initialization.
pub(crate) startup_regular_task: Option<RegularTask>,
pub(crate) startup_regular_task: Option<JoinHandle<CodexResult<RegularTask>>>,
pub(crate) active_mcp_tool_selection: Option<Vec<String>>,
pub(crate) active_connector_selection: HashSet<String>,
}
@ -155,11 +157,13 @@ impl SessionState {
self.dependency_env.clone()
}
pub(crate) fn set_startup_regular_task(&mut self, task: RegularTask) {
pub(crate) fn set_startup_regular_task(&mut self, task: JoinHandle<CodexResult<RegularTask>>) {
self.startup_regular_task = Some(task);
}
pub(crate) fn take_startup_regular_task(&mut self) -> Option<RegularTask> {
pub(crate) fn take_startup_regular_task(
&mut self,
) -> Option<JoinHandle<CodexResult<RegularTask>>> {
self.startup_regular_task.take()
}

View file

@ -3,77 +3,61 @@ use std::sync::Mutex;
use crate::client::ModelClient;
use crate::client::ModelClientSession;
use crate::client_common::Prompt;
use crate::codex::TurnContext;
use crate::codex::run_turn;
use crate::error::Result as CodexResult;
use crate::state::TaskKind;
use async_trait::async_trait;
use codex_otel::OtelManager;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::user_input::UserInput;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
use tracing::trace_span;
use tracing::warn;
use super::SessionTask;
use super::SessionTaskContext;
type PrewarmedSessionTask = JoinHandle<Option<ModelClientSession>>;
pub(crate) struct RegularTask {
prewarmed_session_task: Mutex<Option<PrewarmedSessionTask>>,
prewarmed_session: Mutex<Option<ModelClientSession>>,
}
impl Default for RegularTask {
fn default() -> Self {
Self {
prewarmed_session_task: Mutex::new(None),
prewarmed_session: Mutex::new(None),
}
}
}
impl RegularTask {
pub(crate) fn with_startup_prewarm(
pub(crate) async fn with_startup_prewarm(
model_client: ModelClient,
otel_manager: OtelManager,
model_info: ModelInfo,
) -> Self {
let prewarmed_session_task = tokio::spawn(async move {
let mut client_session = model_client.new_session();
match client_session
.prewarm_websocket(&otel_manager, &model_info)
.await
{
Ok(()) => Some(client_session),
Err(err) => {
warn!("startup websocket prewarm task failed: {err}");
None
}
}
});
prompt: Prompt,
turn_context: Arc<TurnContext>,
turn_metadata_header: Option<String>,
) -> CodexResult<Self> {
let mut client_session = model_client.new_session();
client_session
.prewarm_websocket(
&prompt,
&turn_context.model_info,
&turn_context.otel_manager,
turn_context.reasoning_effort,
turn_context.reasoning_summary,
turn_metadata_header.as_deref(),
)
.await?;
Self {
prewarmed_session_task: Mutex::new(Some(prewarmed_session_task)),
}
Ok(Self {
prewarmed_session: Mutex::new(Some(client_session)),
})
}
async fn take_prewarmed_session(&self) -> Option<ModelClientSession> {
let prewarmed_session_task = self
.prewarmed_session_task
self.prewarmed_session
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.take();
match prewarmed_session_task {
Some(task) => match task.await {
Ok(client_session) => client_session,
Err(err) => {
warn!("startup websocket prewarm task join failed: {err}");
None
}
},
None => None,
}
.take()
}
}

View file

@ -300,8 +300,8 @@ pub struct WebSocketConnectionConfig {
pub response_headers: Vec<(String, String)>,
/// Optional delay inserted before accepting the websocket handshake.
///
/// Tests use this to force startup preconnect into an in-flight state so first-turn adoption
/// paths can be exercised deterministically.
/// Tests use this to force websocket setup into an in-flight state so first-turn warmup paths
/// can be exercised deterministically.
pub accept_delay: Option<Duration>,
}
@ -337,7 +337,7 @@ impl WebSocketTestServer {
/// Waits until at least `expected` websocket handshakes have been observed or timeout elapses.
///
/// Uses a short bounded polling interval so tests can deterministically wait for background
/// preconnect activity without busy-spinning.
/// websocket activity without busy-spinning.
pub async fn wait_for_handshakes(&self, expected: usize, timeout: Duration) -> bool {
if self.handshakes.lock().unwrap().len() >= expected {
return true;

View file

@ -4,6 +4,7 @@ use core_test_support::responses::WebSocketConnectionConfig;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_done;
use core_test_support::responses::ev_done_with_id;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::ev_shell_command_call;
use core_test_support::responses::start_websocket_server;
@ -38,24 +39,28 @@ async fn websocket_test_codex_shell_chain() -> Result<()> {
let mut builder = test_codex();
let test = builder.build_with_websocket_server(&server).await?;
test.submit_turn("run the echo command").await?;
test.submit_turn_with_policy(
"run the echo command",
test.config.permissions.sandbox_policy.get().clone(),
)
.await?;
let connection = server.single_connection();
assert_eq!(connection.len(), 2);
let first = connection
let first_turn = connection
.first()
.expect("missing first request")
.expect("missing first turn request")
.body_json();
let second = connection
let second_turn = connection
.get(1)
.expect("missing second request")
.expect("missing second turn request")
.body_json();
assert_eq!(first["type"].as_str(), Some("response.create"));
assert_eq!(second["type"].as_str(), Some("response.append"));
assert_eq!(first_turn["type"].as_str(), Some("response.create"));
assert_eq!(second_turn["type"].as_str(), Some("response.append"));
let append_items = second
let append_items = second_turn
.get("input")
.and_then(Value::as_array)
.expect("response.append input array");
@ -75,50 +80,81 @@ async fn websocket_test_codex_shell_chain() -> Result<()> {
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_preconnect_happens_on_session_start() -> Result<()> {
async fn websocket_first_turn_uses_preconnect_and_create() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_websocket_server(vec![vec![vec![
ev_response_created("resp-1"),
ev_assistant_message("msg-1", "hello"),
ev_completed("resp-1"),
]]])
.await;
let mut builder = test_codex();
let test = builder.build_with_websocket_server(&server).await?;
assert!(
server.wait_for_handshakes(1, Duration::from_secs(2)).await,
"expected websocket preconnect handshake during session startup"
);
test.submit_turn("hello").await?;
test.submit_turn_with_policy(
"hello",
test.config.permissions.sandbox_policy.get().clone(),
)
.await?;
assert_eq!(server.handshakes().len(), 1);
assert_eq!(server.single_connection().len(), 1);
let connection = server.single_connection();
assert_eq!(connection.len(), 1);
let turn = connection
.first()
.expect("missing turn request")
.body_json();
assert!(
turn["tools"]
.as_array()
.is_some_and(|tools| !tools.is_empty()),
"expected request tools to be populated"
);
assert_eq!(turn["type"].as_str(), Some("response.create"));
server.shutdown().await;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_first_turn_waits_for_inflight_preconnect() -> Result<()> {
async fn websocket_first_turn_handles_handshake_delay_with_preconnect() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig {
requests: vec![vec![ev_response_created("resp-1"), ev_completed("resp-1")]],
requests: vec![vec![
ev_response_created("resp-1"),
ev_assistant_message("msg-1", "hello"),
ev_completed("resp-1"),
]],
response_headers: Vec::new(),
// Delay handshake so submit_turn() observes startup preconnect as in-flight.
// Delay handshake so turn processing must tolerate websocket startup latency.
accept_delay: Some(Duration::from_millis(150)),
}])
.await;
let mut builder = test_codex();
let test = builder.build_with_websocket_server(&server).await?;
test.submit_turn("hello").await?;
test.submit_turn_with_policy(
"hello",
test.config.permissions.sandbox_policy.get().clone(),
)
.await?;
assert_eq!(server.handshakes().len(), 1);
assert_eq!(server.single_connection().len(), 1);
let connection = server.single_connection();
assert_eq!(connection.len(), 1);
let turn = connection
.first()
.expect("missing turn request")
.body_json();
assert!(
turn["tools"]
.as_array()
.is_some_and(|tools| !tools.is_empty()),
"expected request tools to be populated"
);
assert_eq!(turn["type"].as_str(), Some("response.create"));
server.shutdown().await;
Ok(())
@ -130,6 +166,7 @@ async fn websocket_v2_test_codex_shell_chain() -> Result<()> {
let call_id = "shell-command-call";
let server = start_websocket_server(vec![vec![
vec![ev_response_created("warm-1"), ev_done_with_id("warm-1")],
vec![
ev_response_created("resp-1"),
ev_shell_command_call(call_id, "echo websocket"),
@ -148,25 +185,42 @@ async fn websocket_v2_test_codex_shell_chain() -> Result<()> {
});
let test = builder.build_with_websocket_server(&server).await?;
test.submit_turn("run the echo command").await?;
test.submit_turn_with_policy(
"run the echo command",
test.config.permissions.sandbox_policy.get().clone(),
)
.await?;
let connection = server.single_connection();
assert_eq!(connection.len(), 2);
assert_eq!(connection.len(), 3);
let first = connection
let warmup = connection
.first()
.expect("missing first request")
.expect("missing warmup request")
.body_json();
let second = connection
let first_turn = connection
.get(1)
.expect("missing second request")
.expect("missing first turn request")
.body_json();
let second_turn = connection
.get(2)
.expect("missing second turn request")
.body_json();
assert_eq!(first["type"].as_str(), Some("response.create"));
assert_eq!(second["type"].as_str(), Some("response.create"));
assert_eq!(second["previous_response_id"].as_str(), Some("resp-1"));
assert_eq!(warmup["type"].as_str(), Some("response.create"));
assert_eq!(warmup["generate"].as_bool(), Some(false));
assert_eq!(first_turn["type"].as_str(), Some("response.create"));
assert_eq!(first_turn["previous_response_id"].as_str(), Some("warm-1"));
assert!(
first_turn
.get("input")
.and_then(Value::as_array)
.is_some_and(|items| !items.is_empty())
);
assert_eq!(second_turn["type"].as_str(), Some("response.create"));
assert_eq!(second_turn["previous_response_id"].as_str(), Some("resp-1"));
let create_items = second
let create_items = second_turn
.get("input")
.and_then(Value::as_array)
.expect("response.create input array");

View file

@ -107,9 +107,9 @@ async fn responses_websocket_preconnect_reuses_connection() {
let harness = websocket_harness(&server).await;
let mut client_session = harness.client.new_session();
client_session
.prewarm_websocket(&harness.otel_manager, &harness.model_info)
.preconnect_websocket(&harness.otel_manager, &harness.model_info)
.await
.expect("websocket prewarm failed");
.expect("websocket preconnect failed");
let prompt = prompt_with_input(vec![message_item("hello")]);
stream_until_complete(&mut client_session, &harness, &prompt).await;
@ -119,6 +119,54 @@ async fn responses_websocket_preconnect_reuses_connection() {
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_request_prewarm_reuses_connection() {
skip_if_no_network!();
let server = start_websocket_server(vec![vec![
vec![ev_response_created("warm-1"), ev_done_with_id("warm-1")],
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
]])
.await;
let harness = websocket_harness_with_options(&server, false, false, true, true).await;
let mut client_session = harness.client.new_session();
let prompt = prompt_with_input(vec![message_item("hello")]);
client_session
.prewarm_websocket(
&prompt,
&harness.model_info,
&harness.otel_manager,
harness.effort,
harness.summary,
None,
)
.await
.expect("websocket prewarm failed");
stream_until_complete(&mut client_session, &harness, &prompt).await;
assert_eq!(server.handshakes().len(), 1);
let connection = server.single_connection();
assert_eq!(connection.len(), 2);
let warmup = connection
.first()
.expect("missing warmup request")
.body_json();
let follow_up = connection
.get(1)
.expect("missing follow-up request")
.body_json();
assert_eq!(warmup["type"].as_str(), Some("response.create"));
assert_eq!(warmup["generate"].as_bool(), Some(false));
assert_eq!(warmup["tools"], serde_json::json!([]));
assert_eq!(follow_up["type"].as_str(), Some("response.create"));
assert_eq!(follow_up["previous_response_id"].as_str(), Some("warm-1"));
assert_eq!(follow_up["input"], serde_json::json!([]));
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_reuses_connection_after_session_drop() {
skip_if_no_network!();
@ -160,9 +208,9 @@ async fn responses_websocket_preconnect_is_reused_even_with_header_changes() {
let harness = websocket_harness(&server).await;
let mut client_session = harness.client.new_session();
client_session
.prewarm_websocket(&harness.otel_manager, &harness.model_info)
.preconnect_websocket(&harness.otel_manager, &harness.model_info)
.await
.expect("websocket prewarm failed");
.expect("websocket preconnect failed");
let prompt = prompt_with_input(vec![message_item("hello")]);
let mut stream = client_session
.stream(
@ -188,6 +236,69 @@ async fn responses_websocket_preconnect_is_reused_even_with_header_changes() {
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_request_prewarm_is_reused_even_with_header_changes() {
skip_if_no_network!();
let server = start_websocket_server(vec![vec![
vec![ev_response_created("warm-1"), ev_done_with_id("warm-1")],
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
]])
.await;
let harness = websocket_harness_with_options(&server, false, false, true, true).await;
let mut client_session = harness.client.new_session();
let prompt = prompt_with_input(vec![message_item("hello")]);
client_session
.prewarm_websocket(
&prompt,
&harness.model_info,
&harness.otel_manager,
harness.effort,
harness.summary,
None,
)
.await
.expect("websocket prewarm failed");
let mut stream = client_session
.stream(
&prompt,
&harness.model_info,
&harness.otel_manager,
harness.effort,
harness.summary,
None,
)
.await
.expect("websocket stream failed");
while let Some(event) = stream.next().await {
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
break;
}
}
assert_eq!(server.handshakes().len(), 1);
let connection = server.single_connection();
assert_eq!(connection.len(), 2);
let warmup = connection
.first()
.expect("missing warmup request")
.body_json();
let follow_up = connection
.get(1)
.expect("missing follow-up request")
.body_json();
assert_eq!(warmup["type"].as_str(), Some("response.create"));
assert_eq!(warmup["generate"].as_bool(), Some(false));
assert_eq!(warmup["tools"], serde_json::json!([]));
assert_eq!(follow_up["type"].as_str(), Some("response.create"));
assert_eq!(follow_up["previous_response_id"].as_str(), Some("warm-1"));
assert_eq!(follow_up["input"], serde_json::json!([]));
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_prewarm_uses_model_preference_when_feature_disabled() {
skip_if_no_network!();
@ -200,26 +311,39 @@ async fn responses_websocket_prewarm_uses_model_preference_when_feature_disabled
let harness = websocket_harness_with_options(&server, false, false, false, true).await;
let mut client_session = harness.client.new_session();
let prompt = prompt_with_input(vec![message_item("hello")]);
client_session
.prewarm_websocket(&harness.otel_manager, &harness.model_info)
.prewarm_websocket(
&prompt,
&harness.model_info,
&harness.otel_manager,
harness.effort,
harness.summary,
None,
)
.await
.expect("websocket prewarm failed");
// Prewarm should only perform the handshake, not send response.create.
// V1 prewarm only preconnects and should not issue a request.
assert_eq!(server.handshakes().len(), 1);
assert_eq!(server.single_connection().len(), 0);
let prompt = prompt_with_input(vec![message_item("hello")]);
stream_until_complete(&mut client_session, &harness, &prompt).await;
assert_eq!(server.handshakes().len(), 1);
assert_eq!(server.single_connection().len(), 1);
let connection = server.single_connection();
assert_eq!(connection.len(), 1);
let turn = connection
.first()
.expect("missing turn request")
.body_json();
assert_eq!(turn["type"].as_str(), Some("response.create"));
assert_eq!(turn["input"], serde_json::to_value(&prompt.input).unwrap());
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_v2_prewarm_runs_when_only_v2_feature_enabled() {
async fn responses_websocket_preconnect_runs_when_only_v2_feature_enabled() {
skip_if_no_network!();
let server = start_websocket_server(vec![vec![vec![
@ -231,9 +355,9 @@ async fn responses_websocket_v2_prewarm_runs_when_only_v2_feature_enabled() {
let harness = websocket_harness_with_options(&server, false, false, true, false).await;
let mut client_session = harness.client.new_session();
client_session
.prewarm_websocket(&harness.otel_manager, &harness.model_info)
.preconnect_websocket(&harness.otel_manager, &harness.model_info)
.await
.expect("websocket prewarm failed");
.expect("websocket preconnect failed");
assert_eq!(server.handshakes().len(), 1);
assert_eq!(server.single_connection().len(), 0);
@ -320,6 +444,50 @@ async fn responses_websocket_v2_requests_use_v2_when_model_prefers_websockets()
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_v2_incremental_requests_are_reused_across_turns() {
    skip_if_no_network!();
    // Script one connection with two exchanges: the first turn emits an
    // assistant message, the second turn simply completes.
    let first_turn_events = vec![
        ev_response_created("resp-1"),
        ev_assistant_message("msg-1", "assistant output"),
        ev_done_with_id("resp-1"),
    ];
    let second_turn_events = vec![ev_response_created("resp-2"), ev_completed("resp-2")];
    let server = start_websocket_server(vec![vec![first_turn_events, second_turn_events]]).await;
    let harness = websocket_harness_with_options(&server, false, false, true, true).await;
    let first_prompt = prompt_with_input(vec![message_item("hello")]);
    let second_prompt = prompt_with_input(vec![
        message_item("hello"),
        assistant_message_item("msg-1", "assistant output"),
        message_item("second"),
    ]);
    // Run the first turn in its own scope so that session is dropped before
    // the second turn starts.
    {
        let mut first_session = harness.client.new_session();
        stream_until_complete(&mut first_session, &harness, &first_prompt).await;
    }
    let mut second_session = harness.client.new_session();
    stream_until_complete(&mut second_session, &harness, &second_prompt).await;
    // Both turns rode the same websocket: one handshake, two requests.
    assert_eq!(server.handshakes().len(), 1);
    let requests = server.single_connection();
    assert_eq!(requests.len(), 2);
    let incremental = requests.get(1).expect("missing request").body_json();
    assert_eq!(incremental["type"].as_str(), Some("response.create"));
    assert_eq!(incremental["previous_response_id"].as_str(), Some("resp-1"));
    // Only the items added after the first turn's shared prefix are re-sent.
    assert_eq!(
        incremental["input"],
        serde_json::to_value(&second_prompt.input[2..]).unwrap()
    );
    server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_v2_wins_when_both_features_enabled() {
skip_if_no_network!();

View file

@ -67,9 +67,8 @@ async fn websocket_fallback_switches_to_http_on_upgrade_required_connect() -> Re
.filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses"))
.count();
// One websocket attempt comes from startup preconnect and one from the first turn's stream
// attempt before fallback activates; after fallback, transport is HTTP. This matches the
// retry-budget tradeoff documented in [`codex_core::client`] module docs.
// Startup prewarm now only preconnects for v1 (one websocket GET with no request body).
// The first turn then attempts websocket once, sees 426, and falls back to HTTP.
assert_eq!(websocket_attempts, 2);
assert_eq!(http_attempts, 1);
assert_eq!(response_mock.requests().len(), 1);
@ -112,7 +111,7 @@ async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result
.filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses"))
.count();
// One websocket attempt comes from startup preconnect.
// Deferred request prewarm is attempted at startup.
// The first turn then makes 3 websocket stream attempts (initial try + 2 retries),
// after which fallback activates and the request is replayed over HTTP.
assert_eq!(websocket_attempts, 4);
@ -233,7 +232,8 @@ async fn websocket_fallback_is_sticky_across_turns() -> Result<()> {
.count();
// WebSocket attempts all happen on the first turn:
// 1 startup preconnect + 3 stream attempts (initial try + 2 retries) before fallback.
// 1 deferred request prewarm attempt (startup) + 3 stream attempts
// (initial try + 2 retries) before fallback.
// Fallback is sticky, so the second turn stays on HTTP and adds no websocket attempts.
assert_eq!(websocket_attempts, 4);
assert_eq!(http_attempts, 2);