Send warmup request (#11258)
Send a request with `generate: false` but a full set of tools and instructions to pre-warm inference. --------- Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
parent
0679e70bfc
commit
97d0068658
9 changed files with 516 additions and 173 deletions
|
|
@ -176,6 +176,7 @@ impl From<&ResponsesApiRequest> for ResponseCreateWsRequest {
|
|||
include: request.include.clone(),
|
||||
prompt_cache_key: request.prompt_cache_key.clone(),
|
||||
text: request.text.clone(),
|
||||
generate: None,
|
||||
client_metadata: None,
|
||||
}
|
||||
}
|
||||
|
|
@ -200,6 +201,8 @@ pub struct ResponseCreateWsRequest {
|
|||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub text: Option<TextControls>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub generate: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub client_metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,19 +12,17 @@
|
|||
//! requests during that turn. It caches a Responses WebSocket connection (opened lazily) and stores
|
||||
//! per-turn state such as the `x-codex-turn-state` token used for sticky routing.
|
||||
//!
|
||||
//! Prewarm is intentionally handshake-only: it may warm a socket and capture sticky-routing
|
||||
//! state, but the first `response.create` payload is still sent only when a turn starts.
|
||||
//! WebSocket prewarm is a v2-only `response.create` with `generate=false`; it waits for completion
|
||||
//! so the next request can reuse the same connection and `previous_response_id`.
|
||||
//!
|
||||
//! Startup prewarm is owned by turn-scoped callers (for example, a pre-created regular task). When
|
||||
//! a warmed [`ModelClientSession`] is available, turn execution can reuse it; otherwise the turn
|
||||
//! lazily opens a websocket on first stream call.
|
||||
//! Turn execution performs prewarm as a best-effort step before the first stream request so the
|
||||
//! subsequent request can reuse the same connection.
|
||||
//!
|
||||
//! ## Retry-Budget Tradeoff
|
||||
//!
|
||||
//! Startup prewarm is treated as the first websocket connection attempt for the first turn. If
|
||||
//! it fails, the stream attempt fails and the retry/fallback loop decides whether to retry or fall
|
||||
//! back. This avoids duplicate handshakes but means a failed prewarm can consume one retry
|
||||
//! budget slot before any turn payload is sent.
|
||||
//! V2 request prewarm is treated as the first websocket connection attempt for a turn. If it
|
||||
//! fails, normal stream retry/fallback logic handles recovery on the same turn. V1 prewarm
|
||||
//! remains connection-only.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
|
@ -146,7 +144,7 @@ struct ModelClientState {
|
|||
include_timing_metrics: bool,
|
||||
beta_features_header: Option<String>,
|
||||
disable_websockets: AtomicBool,
|
||||
cached_websocket_connection: StdMutex<Option<ApiWebSocketConnection>>,
|
||||
cached_websocket_session: StdMutex<WebsocketSession>,
|
||||
}
|
||||
|
||||
/// Resolved API client setup for a single request attempt.
|
||||
|
|
@ -191,9 +189,7 @@ pub struct ModelClient {
|
|||
/// contract and can cause routing bugs.
|
||||
pub struct ModelClientSession {
|
||||
client: ModelClient,
|
||||
connection: Option<ApiWebSocketConnection>,
|
||||
websocket_last_request: Option<ResponsesApiRequest>,
|
||||
websocket_last_response_rx: Option<oneshot::Receiver<LastResponse>>,
|
||||
websocket_session: WebsocketSession,
|
||||
/// Turn state for sticky routing.
|
||||
///
|
||||
/// This is an `OnceLock` that stores the turn state value received from the server
|
||||
|
|
@ -214,6 +210,13 @@ struct LastResponse {
|
|||
can_append: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct WebsocketSession {
|
||||
connection: Option<ApiWebSocketConnection>,
|
||||
last_request: Option<ResponsesApiRequest>,
|
||||
last_response_rx: Option<oneshot::Receiver<LastResponse>>,
|
||||
}
|
||||
|
||||
enum WebsocketStreamOutcome {
|
||||
Stream(ResponseStream),
|
||||
FallbackToHttp,
|
||||
|
|
@ -248,7 +251,7 @@ impl ModelClient {
|
|||
include_timing_metrics,
|
||||
beta_features_header,
|
||||
disable_websockets: AtomicBool::new(false),
|
||||
cached_websocket_connection: StdMutex::new(None),
|
||||
cached_websocket_session: StdMutex::new(WebsocketSession::default()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
|
@ -260,27 +263,26 @@ impl ModelClient {
|
|||
pub fn new_session(&self) -> ModelClientSession {
|
||||
ModelClientSession {
|
||||
client: self.clone(),
|
||||
connection: self.take_cached_websocket_connection(),
|
||||
websocket_last_request: None,
|
||||
websocket_last_response_rx: None,
|
||||
websocket_session: self.take_cached_websocket_session(),
|
||||
turn_state: Arc::new(OnceLock::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn take_cached_websocket_connection(&self) -> Option<ApiWebSocketConnection> {
|
||||
self.state
|
||||
.cached_websocket_connection
|
||||
fn take_cached_websocket_session(&self) -> WebsocketSession {
|
||||
let mut cached_websocket_session = self
|
||||
.state
|
||||
.cached_websocket_session
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.take()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
std::mem::take(&mut *cached_websocket_session)
|
||||
}
|
||||
|
||||
fn store_cached_websocket_connection(&self, connection: ApiWebSocketConnection) {
|
||||
fn store_cached_websocket_session(&self, websocket_session: WebsocketSession) {
|
||||
*self
|
||||
.state
|
||||
.cached_websocket_connection
|
||||
.cached_websocket_session
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) = Some(connection);
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) = websocket_session;
|
||||
}
|
||||
|
||||
/// Compacts the current conversation history using the Compact endpoint.
|
||||
|
|
@ -492,9 +494,9 @@ impl ModelClient {
|
|||
|
||||
impl Drop for ModelClientSession {
|
||||
fn drop(&mut self) {
|
||||
if let Some(connection) = self.connection.take() {
|
||||
self.client.store_cached_websocket_connection(connection);
|
||||
}
|
||||
let websocket_session = std::mem::take(&mut self.websocket_session);
|
||||
self.client
|
||||
.store_cached_websocket_session(websocket_session);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -600,12 +602,13 @@ impl ModelClientSession {
|
|||
&self,
|
||||
request: &ResponsesApiRequest,
|
||||
last_response: Option<&LastResponse>,
|
||||
allow_empty_delta: bool,
|
||||
) -> Option<Vec<ResponseItem>> {
|
||||
// Checks whether the current request is an incremental append to the previous request.
|
||||
// We only append when non-input request fields are unchanged and `input` is a strict
|
||||
// extension of the previous known input. Server-returned output items are treated as part
|
||||
// of the baseline so we do not resend them.
|
||||
let previous_request = self.websocket_last_request.as_ref()?;
|
||||
let previous_request = self.websocket_session.last_request.as_ref()?;
|
||||
let mut previous_without_input = previous_request.clone();
|
||||
previous_without_input.input.clear();
|
||||
let mut request_without_input = request.clone();
|
||||
|
|
@ -623,9 +626,8 @@ impl ModelClientSession {
|
|||
}
|
||||
|
||||
let baseline_len = baseline.len();
|
||||
if baseline_len > 0
|
||||
&& request.input.starts_with(&baseline)
|
||||
&& baseline_len < request.input.len()
|
||||
if request.input.starts_with(&baseline)
|
||||
&& (allow_empty_delta || baseline_len < request.input.len())
|
||||
{
|
||||
Some(request.input[baseline_len..].to_vec())
|
||||
} else {
|
||||
|
|
@ -635,7 +637,8 @@ impl ModelClientSession {
|
|||
}
|
||||
|
||||
fn get_last_response(&mut self) -> Option<LastResponse> {
|
||||
self.websocket_last_response_rx
|
||||
self.websocket_session
|
||||
.last_response_rx
|
||||
.take()
|
||||
.and_then(|mut receiver| match receiver.try_recv() {
|
||||
Ok(last_response) => Some(last_response),
|
||||
|
|
@ -652,7 +655,10 @@ impl ModelClientSession {
|
|||
let Some(last_response) = self.get_last_response() else {
|
||||
return ResponsesWsRequest::ResponseCreate(payload);
|
||||
};
|
||||
let Some(append_items) = self.get_incremental_items(request, Some(&last_response)) else {
|
||||
let allow_empty_delta = matches!(ws_version, ResponsesWebsocketVersion::V2);
|
||||
let Some(append_items) =
|
||||
self.get_incremental_items(request, Some(&last_response), allow_empty_delta)
|
||||
else {
|
||||
return ResponsesWsRequest::ResponseCreate(payload);
|
||||
};
|
||||
|
||||
|
|
@ -682,10 +688,10 @@ impl ModelClientSession {
|
|||
}
|
||||
}
|
||||
|
||||
/// Opportunistically warms a websocket for this turn-scoped client session.
|
||||
/// Opportunistically preconnects a websocket for this turn-scoped client session.
|
||||
///
|
||||
/// This performs only connection setup; it never sends prompt payloads.
|
||||
pub async fn prewarm_websocket(
|
||||
pub async fn preconnect_websocket(
|
||||
&mut self,
|
||||
otel_manager: &OtelManager,
|
||||
model_info: &ModelInfo,
|
||||
|
|
@ -693,7 +699,7 @@ impl ModelClientSession {
|
|||
let Some(ws_version) = self.client.active_ws_version(model_info) else {
|
||||
return Ok(());
|
||||
};
|
||||
if self.connection.is_some() {
|
||||
if self.websocket_session.connection.is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
|
@ -714,10 +720,9 @@ impl ModelClientSession {
|
|||
None,
|
||||
)
|
||||
.await?;
|
||||
self.connection = Some(connection);
|
||||
self.websocket_session.connection = Some(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns a websocket connection for this turn.
|
||||
async fn websocket_connection(
|
||||
&mut self,
|
||||
|
|
@ -728,14 +733,14 @@ impl ModelClientSession {
|
|||
turn_metadata_header: Option<&str>,
|
||||
options: &ApiResponsesOptions,
|
||||
) -> std::result::Result<&ApiWebSocketConnection, ApiError> {
|
||||
let needs_new = match self.connection.as_ref() {
|
||||
let needs_new = match self.websocket_session.connection.as_ref() {
|
||||
Some(conn) => conn.is_closed().await,
|
||||
None => true,
|
||||
};
|
||||
|
||||
if needs_new {
|
||||
self.websocket_last_request = None;
|
||||
self.websocket_last_response_rx = None;
|
||||
self.websocket_session.last_request = None;
|
||||
self.websocket_session.last_response_rx = None;
|
||||
let turn_state = options
|
||||
.turn_state
|
||||
.clone()
|
||||
|
|
@ -751,12 +756,15 @@ impl ModelClientSession {
|
|||
turn_metadata_header,
|
||||
)
|
||||
.await?;
|
||||
self.connection = Some(new_conn);
|
||||
self.websocket_session.connection = Some(new_conn);
|
||||
}
|
||||
|
||||
self.connection.as_ref().ok_or(ApiError::Stream(
|
||||
"websocket connection is unavailable".to_string(),
|
||||
))
|
||||
self.websocket_session
|
||||
.connection
|
||||
.as_ref()
|
||||
.ok_or(ApiError::Stream(
|
||||
"websocket connection is unavailable".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn responses_request_compression(&self, auth: Option<&crate::auth::CodexAuth>) -> Compression {
|
||||
|
|
@ -848,6 +856,7 @@ impl ModelClientSession {
|
|||
effort: Option<ReasoningEffortConfig>,
|
||||
summary: ReasoningSummaryConfig,
|
||||
turn_metadata_header: Option<&str>,
|
||||
warmup: bool,
|
||||
) -> Result<WebsocketStreamOutcome> {
|
||||
let auth_manager = self.client.state.auth_manager.clone();
|
||||
|
||||
|
|
@ -866,10 +875,13 @@ impl ModelClientSession {
|
|||
effort,
|
||||
summary,
|
||||
)?;
|
||||
let ws_payload = ResponseCreateWsRequest {
|
||||
let mut ws_payload = ResponseCreateWsRequest {
|
||||
client_metadata: build_ws_client_metadata(turn_metadata_header),
|
||||
..ResponseCreateWsRequest::from(&request)
|
||||
};
|
||||
if warmup {
|
||||
ws_payload.generate = Some(false);
|
||||
}
|
||||
|
||||
match self
|
||||
.websocket_connection(
|
||||
|
|
@ -898,8 +910,9 @@ impl ModelClientSession {
|
|||
}
|
||||
|
||||
let ws_request = self.prepare_websocket_request(ws_payload, &request, ws_version);
|
||||
|
||||
self.websocket_session.last_request = Some(request);
|
||||
let stream_result = self
|
||||
.websocket_session
|
||||
.connection
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
|
|
@ -910,11 +923,9 @@ impl ModelClientSession {
|
|||
.stream_request(ws_request)
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
self.websocket_last_request = Some(request);
|
||||
let (stream, last_request_rx) =
|
||||
map_response_stream(stream_result, otel_manager.clone());
|
||||
self.websocket_last_response_rx = Some(last_request_rx);
|
||||
|
||||
self.websocket_session.last_response_rx = Some(last_request_rx);
|
||||
return Ok(WebsocketStreamOutcome::Stream(stream));
|
||||
}
|
||||
}
|
||||
|
|
@ -936,6 +947,62 @@ impl ModelClientSession {
|
|||
websocket_telemetry
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn prewarm_websocket(
|
||||
&mut self,
|
||||
prompt: &Prompt,
|
||||
model_info: &ModelInfo,
|
||||
otel_manager: &OtelManager,
|
||||
effort: Option<ReasoningEffortConfig>,
|
||||
summary: ReasoningSummaryConfig,
|
||||
turn_metadata_header: Option<&str>,
|
||||
) -> Result<()> {
|
||||
let Some(ws_version) = self.client.active_ws_version(model_info) else {
|
||||
return Ok(());
|
||||
};
|
||||
if self.websocket_session.last_request.is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if matches!(ws_version, ResponsesWebsocketVersion::V1) {
|
||||
self.preconnect_websocket(otel_manager, model_info)
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match self
|
||||
.stream_responses_websocket(
|
||||
prompt,
|
||||
model_info,
|
||||
ws_version,
|
||||
otel_manager,
|
||||
effort,
|
||||
summary,
|
||||
turn_metadata_header,
|
||||
true,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(WebsocketStreamOutcome::Stream(mut stream)) => {
|
||||
// Wait for the v2 warmup request to complete before sending the first turn request.
|
||||
while let Some(event) = stream.next().await {
|
||||
match event {
|
||||
Ok(ResponseEvent::Completed { .. }) => break,
|
||||
Err(err) => return Err(err),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Ok(WebsocketStreamOutcome::FallbackToHttp) => {
|
||||
self.try_switch_fallback_transport(otel_manager, model_info);
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Streams a single model request within the current turn.
|
||||
///
|
||||
|
|
@ -965,6 +1032,7 @@ impl ModelClientSession {
|
|||
effort,
|
||||
summary,
|
||||
turn_metadata_header,
|
||||
false,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
|
|
@ -1009,9 +1077,9 @@ impl ModelClientSession {
|
|||
&[("from_wire_api", "responses_websocket")],
|
||||
);
|
||||
|
||||
self.connection = None;
|
||||
self.websocket_last_request = None;
|
||||
self.websocket_last_response_rx = None;
|
||||
self.websocket_session.connection = None;
|
||||
self.websocket_session.last_request = None;
|
||||
self.websocket_session.last_response_rx = None;
|
||||
}
|
||||
activated
|
||||
}
|
||||
|
|
|
|||
|
|
@ -106,6 +106,7 @@ use tokio::sync::Mutex;
|
|||
use tokio::sync::RwLock;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::watch;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::Instrument;
|
||||
use tracing::debug;
|
||||
|
|
@ -1265,7 +1266,7 @@ impl Session {
|
|||
}
|
||||
};
|
||||
session_configuration.thread_name = thread_name.clone();
|
||||
let mut state = SessionState::new(session_configuration.clone());
|
||||
let state = SessionState::new(session_configuration.clone());
|
||||
let managed_network_requirements_enabled = config.managed_network_requirements_enabled();
|
||||
let network_approval = Arc::new(NetworkApprovalService::default());
|
||||
// The managed proxy can call back into core for allowlist-miss decisions.
|
||||
|
|
@ -1372,16 +1373,6 @@ impl Session {
|
|||
config.js_repl_node_module_dirs.clone(),
|
||||
));
|
||||
|
||||
let prewarm_model_info = models_manager
|
||||
.get_model_info(session_configuration.collaboration_mode.model(), &config)
|
||||
.await;
|
||||
let startup_regular_task = RegularTask::with_startup_prewarm(
|
||||
services.model_client.clone(),
|
||||
services.otel_manager.clone(),
|
||||
prewarm_model_info,
|
||||
);
|
||||
state.set_startup_regular_task(startup_regular_task);
|
||||
|
||||
let sess = Arc::new(Session {
|
||||
conversation_id,
|
||||
tx_event: tx_event.clone(),
|
||||
|
|
@ -1399,7 +1390,6 @@ impl Session {
|
|||
let mut guard = network_policy_decider_session.write().await;
|
||||
*guard = Arc::downgrade(&sess);
|
||||
}
|
||||
|
||||
// Dispatch the SessionConfiguredEvent first and then report any errors.
|
||||
// If resuming, include converted initial messages in the payload so UIs can render them immediately.
|
||||
let initial_messages = initial_history.get_event_msgs();
|
||||
|
|
@ -1429,7 +1419,6 @@ impl Session {
|
|||
|
||||
// Start the watcher after SessionConfigured so it cannot emit earlier events.
|
||||
sess.start_file_watcher_listener();
|
||||
|
||||
// Construct sandbox_state before MCP startup so it can be sent to each
|
||||
// MCP server immediately after it becomes ready (avoiding blocking).
|
||||
let sandbox_state = SandboxState {
|
||||
|
|
@ -1490,6 +1479,8 @@ impl Session {
|
|||
));
|
||||
}
|
||||
}
|
||||
sess.schedule_startup_prewarm(session_configuration.base_instructions.clone())
|
||||
.await;
|
||||
|
||||
// record_initial_history can emit events. We record only after the SessionConfiguredEvent is emitted.
|
||||
sess.record_initial_history(initial_history).await;
|
||||
|
|
@ -2155,8 +2146,69 @@ impl Session {
|
|||
}
|
||||
|
||||
pub(crate) async fn take_startup_regular_task(&self) -> Option<RegularTask> {
|
||||
let startup_regular_task = {
|
||||
let mut state = self.state.lock().await;
|
||||
state.take_startup_regular_task()
|
||||
};
|
||||
let startup_regular_task = startup_regular_task?;
|
||||
match startup_regular_task.await {
|
||||
Ok(Ok(regular_task)) => Some(regular_task),
|
||||
Ok(Err(err)) => {
|
||||
warn!("startup websocket prewarm setup failed: {err:#}");
|
||||
None
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("startup websocket prewarm setup join failed: {err}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn schedule_startup_prewarm(self: &Arc<Self>, base_instructions: String) {
|
||||
let sess = Arc::clone(self);
|
||||
let startup_regular_task: JoinHandle<CodexResult<RegularTask>> =
|
||||
tokio::spawn(
|
||||
async move { sess.schedule_startup_prewarm_inner(base_instructions).await },
|
||||
);
|
||||
let mut state = self.state.lock().await;
|
||||
state.take_startup_regular_task()
|
||||
state.set_startup_regular_task(startup_regular_task);
|
||||
}
|
||||
|
||||
async fn schedule_startup_prewarm_inner(
|
||||
self: &Arc<Self>,
|
||||
base_instructions: String,
|
||||
) -> CodexResult<RegularTask> {
|
||||
let startup_turn_context = self
|
||||
.new_default_turn_with_sub_id(INITIAL_SUBMIT_ID.to_owned())
|
||||
.await;
|
||||
let startup_cancellation_token = CancellationToken::new();
|
||||
let startup_router = built_tools(
|
||||
self,
|
||||
startup_turn_context.as_ref(),
|
||||
&[],
|
||||
&HashSet::new(),
|
||||
None,
|
||||
&startup_cancellation_token,
|
||||
)
|
||||
.await?;
|
||||
let startup_prompt = build_prompt(
|
||||
Vec::new(),
|
||||
startup_router.as_ref(),
|
||||
startup_turn_context.as_ref(),
|
||||
BaseInstructions {
|
||||
text: base_instructions,
|
||||
},
|
||||
);
|
||||
let startup_turn_metadata_header = startup_turn_context
|
||||
.turn_metadata_state
|
||||
.current_header_value();
|
||||
RegularTask::with_startup_prewarm(
|
||||
self.services.model_client.clone(),
|
||||
startup_prompt,
|
||||
startup_turn_context,
|
||||
startup_turn_metadata_header,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn get_config(&self) -> std::sync::Arc<Config> {
|
||||
|
|
@ -5331,6 +5383,21 @@ fn codex_apps_connector_id(tool: &crate::mcp_connection_manager::ToolInfo) -> Op
|
|||
tool.connector_id.as_deref()
|
||||
}
|
||||
|
||||
fn build_prompt(
|
||||
input: Vec<ResponseItem>,
|
||||
router: &ToolRouter,
|
||||
turn_context: &TurnContext,
|
||||
base_instructions: BaseInstructions,
|
||||
) -> Prompt {
|
||||
Prompt {
|
||||
input,
|
||||
tools: router.specs(),
|
||||
parallel_tool_calls: turn_context.model_info.supports_parallel_tool_calls,
|
||||
base_instructions,
|
||||
personality: turn_context.personality,
|
||||
output_schema: turn_context.final_output_json_schema.clone(),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[instrument(level = "trace",
|
||||
skip_all,
|
||||
|
|
@ -5362,19 +5429,14 @@ async fn run_sampling_request(
|
|||
)
|
||||
.await?;
|
||||
|
||||
let model_supports_parallel = turn_context.model_info.supports_parallel_tool_calls;
|
||||
|
||||
let tools = router.specs();
|
||||
let base_instructions = sess.get_base_instructions().await;
|
||||
|
||||
let prompt = Prompt {
|
||||
let prompt = build_prompt(
|
||||
input,
|
||||
tools,
|
||||
parallel_tool_calls: model_supports_parallel,
|
||||
router.as_ref(),
|
||||
turn_context.as_ref(),
|
||||
base_instructions,
|
||||
personality: turn_context.personality,
|
||||
output_schema: turn_context.final_output_json_schema.clone(),
|
||||
};
|
||||
);
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
let err = match try_run_sampling_request(
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@
|
|||
use codex_protocol::models::ResponseItem;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use crate::codex::SessionConfiguration;
|
||||
use crate::context_manager::ContextManager;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::protocol::RateLimitSnapshot;
|
||||
use crate::protocol::TokenUsage;
|
||||
use crate::protocol::TokenUsageInfo;
|
||||
|
|
@ -26,7 +28,7 @@ pub(crate) struct SessionState {
|
|||
/// resume or `/compact`).
|
||||
previous_model: Option<String>,
|
||||
/// Startup regular task pre-created during session initialization.
|
||||
pub(crate) startup_regular_task: Option<RegularTask>,
|
||||
pub(crate) startup_regular_task: Option<JoinHandle<CodexResult<RegularTask>>>,
|
||||
pub(crate) active_mcp_tool_selection: Option<Vec<String>>,
|
||||
pub(crate) active_connector_selection: HashSet<String>,
|
||||
}
|
||||
|
|
@ -155,11 +157,13 @@ impl SessionState {
|
|||
self.dependency_env.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn set_startup_regular_task(&mut self, task: RegularTask) {
|
||||
pub(crate) fn set_startup_regular_task(&mut self, task: JoinHandle<CodexResult<RegularTask>>) {
|
||||
self.startup_regular_task = Some(task);
|
||||
}
|
||||
|
||||
pub(crate) fn take_startup_regular_task(&mut self) -> Option<RegularTask> {
|
||||
pub(crate) fn take_startup_regular_task(
|
||||
&mut self,
|
||||
) -> Option<JoinHandle<CodexResult<RegularTask>>> {
|
||||
self.startup_regular_task.take()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,77 +3,61 @@ use std::sync::Mutex;
|
|||
|
||||
use crate::client::ModelClient;
|
||||
use crate::client::ModelClientSession;
|
||||
use crate::client_common::Prompt;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::codex::run_turn;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::state::TaskKind;
|
||||
use async_trait::async_trait;
|
||||
use codex_otel::OtelManager;
|
||||
use codex_protocol::openai_models::ModelInfo;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::Instrument;
|
||||
use tracing::trace_span;
|
||||
use tracing::warn;
|
||||
|
||||
use super::SessionTask;
|
||||
use super::SessionTaskContext;
|
||||
|
||||
type PrewarmedSessionTask = JoinHandle<Option<ModelClientSession>>;
|
||||
|
||||
pub(crate) struct RegularTask {
|
||||
prewarmed_session_task: Mutex<Option<PrewarmedSessionTask>>,
|
||||
prewarmed_session: Mutex<Option<ModelClientSession>>,
|
||||
}
|
||||
|
||||
impl Default for RegularTask {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
prewarmed_session_task: Mutex::new(None),
|
||||
prewarmed_session: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RegularTask {
|
||||
pub(crate) fn with_startup_prewarm(
|
||||
pub(crate) async fn with_startup_prewarm(
|
||||
model_client: ModelClient,
|
||||
otel_manager: OtelManager,
|
||||
model_info: ModelInfo,
|
||||
) -> Self {
|
||||
let prewarmed_session_task = tokio::spawn(async move {
|
||||
let mut client_session = model_client.new_session();
|
||||
match client_session
|
||||
.prewarm_websocket(&otel_manager, &model_info)
|
||||
.await
|
||||
{
|
||||
Ok(()) => Some(client_session),
|
||||
Err(err) => {
|
||||
warn!("startup websocket prewarm task failed: {err}");
|
||||
None
|
||||
}
|
||||
}
|
||||
});
|
||||
prompt: Prompt,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_metadata_header: Option<String>,
|
||||
) -> CodexResult<Self> {
|
||||
let mut client_session = model_client.new_session();
|
||||
client_session
|
||||
.prewarm_websocket(
|
||||
&prompt,
|
||||
&turn_context.model_info,
|
||||
&turn_context.otel_manager,
|
||||
turn_context.reasoning_effort,
|
||||
turn_context.reasoning_summary,
|
||||
turn_metadata_header.as_deref(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Self {
|
||||
prewarmed_session_task: Mutex::new(Some(prewarmed_session_task)),
|
||||
}
|
||||
Ok(Self {
|
||||
prewarmed_session: Mutex::new(Some(client_session)),
|
||||
})
|
||||
}
|
||||
|
||||
async fn take_prewarmed_session(&self) -> Option<ModelClientSession> {
|
||||
let prewarmed_session_task = self
|
||||
.prewarmed_session_task
|
||||
self.prewarmed_session
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.take();
|
||||
match prewarmed_session_task {
|
||||
Some(task) => match task.await {
|
||||
Ok(client_session) => client_session,
|
||||
Err(err) => {
|
||||
warn!("startup websocket prewarm task join failed: {err}");
|
||||
None
|
||||
}
|
||||
},
|
||||
None => None,
|
||||
}
|
||||
.take()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -300,8 +300,8 @@ pub struct WebSocketConnectionConfig {
|
|||
pub response_headers: Vec<(String, String)>,
|
||||
/// Optional delay inserted before accepting the websocket handshake.
|
||||
///
|
||||
/// Tests use this to force startup preconnect into an in-flight state so first-turn adoption
|
||||
/// paths can be exercised deterministically.
|
||||
/// Tests use this to force websocket setup into an in-flight state so first-turn warmup paths
|
||||
/// can be exercised deterministically.
|
||||
pub accept_delay: Option<Duration>,
|
||||
}
|
||||
|
||||
|
|
@ -337,7 +337,7 @@ impl WebSocketTestServer {
|
|||
/// Waits until at least `expected` websocket handshakes have been observed or timeout elapses.
|
||||
///
|
||||
/// Uses a short bounded polling interval so tests can deterministically wait for background
|
||||
/// preconnect activity without busy-spinning.
|
||||
/// websocket activity without busy-spinning.
|
||||
pub async fn wait_for_handshakes(&self, expected: usize, timeout: Duration) -> bool {
|
||||
if self.handshakes.lock().unwrap().len() >= expected {
|
||||
return true;
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ use core_test_support::responses::WebSocketConnectionConfig;
|
|||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_done;
|
||||
use core_test_support::responses::ev_done_with_id;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::ev_shell_command_call;
|
||||
use core_test_support::responses::start_websocket_server;
|
||||
|
|
@ -38,24 +39,28 @@ async fn websocket_test_codex_shell_chain() -> Result<()> {
|
|||
let mut builder = test_codex();
|
||||
|
||||
let test = builder.build_with_websocket_server(&server).await?;
|
||||
test.submit_turn("run the echo command").await?;
|
||||
test.submit_turn_with_policy(
|
||||
"run the echo command",
|
||||
test.config.permissions.sandbox_policy.get().clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 2);
|
||||
|
||||
let first = connection
|
||||
let first_turn = connection
|
||||
.first()
|
||||
.expect("missing first request")
|
||||
.expect("missing first turn request")
|
||||
.body_json();
|
||||
let second = connection
|
||||
let second_turn = connection
|
||||
.get(1)
|
||||
.expect("missing second request")
|
||||
.expect("missing second turn request")
|
||||
.body_json();
|
||||
|
||||
assert_eq!(first["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(second["type"].as_str(), Some("response.append"));
|
||||
assert_eq!(first_turn["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(second_turn["type"].as_str(), Some("response.append"));
|
||||
|
||||
let append_items = second
|
||||
let append_items = second_turn
|
||||
.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.expect("response.append input array");
|
||||
|
|
@ -75,50 +80,81 @@ async fn websocket_test_codex_shell_chain() -> Result<()> {
|
|||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn websocket_preconnect_happens_on_session_start() -> Result<()> {
|
||||
async fn websocket_first_turn_uses_preconnect_and_create() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_websocket_server(vec![vec![vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_assistant_message("msg-1", "hello"),
|
||||
ev_completed("resp-1"),
|
||||
]]])
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex();
|
||||
let test = builder.build_with_websocket_server(&server).await?;
|
||||
|
||||
assert!(
|
||||
server.wait_for_handshakes(1, Duration::from_secs(2)).await,
|
||||
"expected websocket preconnect handshake during session startup"
|
||||
);
|
||||
|
||||
test.submit_turn("hello").await?;
|
||||
test.submit_turn_with_policy(
|
||||
"hello",
|
||||
test.config.permissions.sandbox_policy.get().clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(server.handshakes().len(), 1);
|
||||
assert_eq!(server.single_connection().len(), 1);
|
||||
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 1);
|
||||
let turn = connection
|
||||
.first()
|
||||
.expect("missing turn request")
|
||||
.body_json();
|
||||
assert!(
|
||||
turn["tools"]
|
||||
.as_array()
|
||||
.is_some_and(|tools| !tools.is_empty()),
|
||||
"expected request tools to be populated"
|
||||
);
|
||||
assert_eq!(turn["type"].as_str(), Some("response.create"));
|
||||
|
||||
server.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn websocket_first_turn_waits_for_inflight_preconnect() -> Result<()> {
|
||||
async fn websocket_first_turn_handles_handshake_delay_with_preconnect() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig {
|
||||
requests: vec![vec![ev_response_created("resp-1"), ev_completed("resp-1")]],
|
||||
requests: vec![vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_assistant_message("msg-1", "hello"),
|
||||
ev_completed("resp-1"),
|
||||
]],
|
||||
response_headers: Vec::new(),
|
||||
// Delay handshake so submit_turn() observes startup preconnect as in-flight.
|
||||
// Delay handshake so turn processing must tolerate websocket startup latency.
|
||||
accept_delay: Some(Duration::from_millis(150)),
|
||||
}])
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex();
|
||||
let test = builder.build_with_websocket_server(&server).await?;
|
||||
test.submit_turn("hello").await?;
|
||||
test.submit_turn_with_policy(
|
||||
"hello",
|
||||
test.config.permissions.sandbox_policy.get().clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(server.handshakes().len(), 1);
|
||||
assert_eq!(server.single_connection().len(), 1);
|
||||
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 1);
|
||||
let turn = connection
|
||||
.first()
|
||||
.expect("missing turn request")
|
||||
.body_json();
|
||||
assert!(
|
||||
turn["tools"]
|
||||
.as_array()
|
||||
.is_some_and(|tools| !tools.is_empty()),
|
||||
"expected request tools to be populated"
|
||||
);
|
||||
assert_eq!(turn["type"].as_str(), Some("response.create"));
|
||||
|
||||
server.shutdown().await;
|
||||
Ok(())
|
||||
|
|
@ -130,6 +166,7 @@ async fn websocket_v2_test_codex_shell_chain() -> Result<()> {
|
|||
|
||||
let call_id = "shell-command-call";
|
||||
let server = start_websocket_server(vec![vec![
|
||||
vec![ev_response_created("warm-1"), ev_done_with_id("warm-1")],
|
||||
vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_shell_command_call(call_id, "echo websocket"),
|
||||
|
|
@ -148,25 +185,42 @@ async fn websocket_v2_test_codex_shell_chain() -> Result<()> {
|
|||
});
|
||||
|
||||
let test = builder.build_with_websocket_server(&server).await?;
|
||||
test.submit_turn("run the echo command").await?;
|
||||
test.submit_turn_with_policy(
|
||||
"run the echo command",
|
||||
test.config.permissions.sandbox_policy.get().clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 2);
|
||||
assert_eq!(connection.len(), 3);
|
||||
|
||||
let first = connection
|
||||
let warmup = connection
|
||||
.first()
|
||||
.expect("missing first request")
|
||||
.expect("missing warmup request")
|
||||
.body_json();
|
||||
let second = connection
|
||||
let first_turn = connection
|
||||
.get(1)
|
||||
.expect("missing second request")
|
||||
.expect("missing first turn request")
|
||||
.body_json();
|
||||
let second_turn = connection
|
||||
.get(2)
|
||||
.expect("missing second turn request")
|
||||
.body_json();
|
||||
|
||||
assert_eq!(first["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(second["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(second["previous_response_id"].as_str(), Some("resp-1"));
|
||||
assert_eq!(warmup["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(warmup["generate"].as_bool(), Some(false));
|
||||
assert_eq!(first_turn["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(first_turn["previous_response_id"].as_str(), Some("warm-1"));
|
||||
assert!(
|
||||
first_turn
|
||||
.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.is_some_and(|items| !items.is_empty())
|
||||
);
|
||||
assert_eq!(second_turn["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(second_turn["previous_response_id"].as_str(), Some("resp-1"));
|
||||
|
||||
let create_items = second
|
||||
let create_items = second_turn
|
||||
.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.expect("response.create input array");
|
||||
|
|
|
|||
|
|
@ -107,9 +107,9 @@ async fn responses_websocket_preconnect_reuses_connection() {
|
|||
let harness = websocket_harness(&server).await;
|
||||
let mut client_session = harness.client.new_session();
|
||||
client_session
|
||||
.prewarm_websocket(&harness.otel_manager, &harness.model_info)
|
||||
.preconnect_websocket(&harness.otel_manager, &harness.model_info)
|
||||
.await
|
||||
.expect("websocket prewarm failed");
|
||||
.expect("websocket preconnect failed");
|
||||
let prompt = prompt_with_input(vec![message_item("hello")]);
|
||||
stream_until_complete(&mut client_session, &harness, &prompt).await;
|
||||
|
||||
|
|
@ -119,6 +119,54 @@ async fn responses_websocket_preconnect_reuses_connection() {
|
|||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_request_prewarm_reuses_connection() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_websocket_server(vec![vec![
|
||||
vec![ev_response_created("warm-1"), ev_done_with_id("warm-1")],
|
||||
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
|
||||
]])
|
||||
.await;
|
||||
|
||||
let harness = websocket_harness_with_options(&server, false, false, true, true).await;
|
||||
let mut client_session = harness.client.new_session();
|
||||
let prompt = prompt_with_input(vec![message_item("hello")]);
|
||||
client_session
|
||||
.prewarm_websocket(
|
||||
&prompt,
|
||||
&harness.model_info,
|
||||
&harness.otel_manager,
|
||||
harness.effort,
|
||||
harness.summary,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("websocket prewarm failed");
|
||||
stream_until_complete(&mut client_session, &harness, &prompt).await;
|
||||
|
||||
assert_eq!(server.handshakes().len(), 1);
|
||||
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 2);
|
||||
let warmup = connection
|
||||
.first()
|
||||
.expect("missing warmup request")
|
||||
.body_json();
|
||||
let follow_up = connection
|
||||
.get(1)
|
||||
.expect("missing follow-up request")
|
||||
.body_json();
|
||||
|
||||
assert_eq!(warmup["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(warmup["generate"].as_bool(), Some(false));
|
||||
assert_eq!(warmup["tools"], serde_json::json!([]));
|
||||
assert_eq!(follow_up["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(follow_up["previous_response_id"].as_str(), Some("warm-1"));
|
||||
assert_eq!(follow_up["input"], serde_json::json!([]));
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_reuses_connection_after_session_drop() {
|
||||
skip_if_no_network!();
|
||||
|
|
@ -160,9 +208,9 @@ async fn responses_websocket_preconnect_is_reused_even_with_header_changes() {
|
|||
let harness = websocket_harness(&server).await;
|
||||
let mut client_session = harness.client.new_session();
|
||||
client_session
|
||||
.prewarm_websocket(&harness.otel_manager, &harness.model_info)
|
||||
.preconnect_websocket(&harness.otel_manager, &harness.model_info)
|
||||
.await
|
||||
.expect("websocket prewarm failed");
|
||||
.expect("websocket preconnect failed");
|
||||
let prompt = prompt_with_input(vec![message_item("hello")]);
|
||||
let mut stream = client_session
|
||||
.stream(
|
||||
|
|
@ -188,6 +236,69 @@ async fn responses_websocket_preconnect_is_reused_even_with_header_changes() {
|
|||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_request_prewarm_is_reused_even_with_header_changes() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_websocket_server(vec![vec![
|
||||
vec![ev_response_created("warm-1"), ev_done_with_id("warm-1")],
|
||||
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
|
||||
]])
|
||||
.await;
|
||||
|
||||
let harness = websocket_harness_with_options(&server, false, false, true, true).await;
|
||||
let mut client_session = harness.client.new_session();
|
||||
let prompt = prompt_with_input(vec![message_item("hello")]);
|
||||
client_session
|
||||
.prewarm_websocket(
|
||||
&prompt,
|
||||
&harness.model_info,
|
||||
&harness.otel_manager,
|
||||
harness.effort,
|
||||
harness.summary,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("websocket prewarm failed");
|
||||
let mut stream = client_session
|
||||
.stream(
|
||||
&prompt,
|
||||
&harness.model_info,
|
||||
&harness.otel_manager,
|
||||
harness.effort,
|
||||
harness.summary,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("websocket stream failed");
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(server.handshakes().len(), 1);
|
||||
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 2);
|
||||
let warmup = connection
|
||||
.first()
|
||||
.expect("missing warmup request")
|
||||
.body_json();
|
||||
let follow_up = connection
|
||||
.get(1)
|
||||
.expect("missing follow-up request")
|
||||
.body_json();
|
||||
assert_eq!(warmup["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(warmup["generate"].as_bool(), Some(false));
|
||||
assert_eq!(warmup["tools"], serde_json::json!([]));
|
||||
assert_eq!(follow_up["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(follow_up["previous_response_id"].as_str(), Some("warm-1"));
|
||||
assert_eq!(follow_up["input"], serde_json::json!([]));
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_prewarm_uses_model_preference_when_feature_disabled() {
|
||||
skip_if_no_network!();
|
||||
|
|
@ -200,26 +311,39 @@ async fn responses_websocket_prewarm_uses_model_preference_when_feature_disabled
|
|||
|
||||
let harness = websocket_harness_with_options(&server, false, false, false, true).await;
|
||||
let mut client_session = harness.client.new_session();
|
||||
let prompt = prompt_with_input(vec![message_item("hello")]);
|
||||
client_session
|
||||
.prewarm_websocket(&harness.otel_manager, &harness.model_info)
|
||||
.prewarm_websocket(
|
||||
&prompt,
|
||||
&harness.model_info,
|
||||
&harness.otel_manager,
|
||||
harness.effort,
|
||||
harness.summary,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("websocket prewarm failed");
|
||||
|
||||
// Prewarm should only perform the handshake, not send response.create.
|
||||
// V1 prewarm only preconnects and should not issue a request.
|
||||
assert_eq!(server.handshakes().len(), 1);
|
||||
assert_eq!(server.single_connection().len(), 0);
|
||||
|
||||
let prompt = prompt_with_input(vec![message_item("hello")]);
|
||||
stream_until_complete(&mut client_session, &harness, &prompt).await;
|
||||
|
||||
assert_eq!(server.handshakes().len(), 1);
|
||||
assert_eq!(server.single_connection().len(), 1);
|
||||
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 1);
|
||||
let turn = connection
|
||||
.first()
|
||||
.expect("missing turn request")
|
||||
.body_json();
|
||||
assert_eq!(turn["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(turn["input"], serde_json::to_value(&prompt.input).unwrap());
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_v2_prewarm_runs_when_only_v2_feature_enabled() {
|
||||
async fn responses_websocket_preconnect_runs_when_only_v2_feature_enabled() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_websocket_server(vec![vec![vec![
|
||||
|
|
@ -231,9 +355,9 @@ async fn responses_websocket_v2_prewarm_runs_when_only_v2_feature_enabled() {
|
|||
let harness = websocket_harness_with_options(&server, false, false, true, false).await;
|
||||
let mut client_session = harness.client.new_session();
|
||||
client_session
|
||||
.prewarm_websocket(&harness.otel_manager, &harness.model_info)
|
||||
.preconnect_websocket(&harness.otel_manager, &harness.model_info)
|
||||
.await
|
||||
.expect("websocket prewarm failed");
|
||||
.expect("websocket preconnect failed");
|
||||
|
||||
assert_eq!(server.handshakes().len(), 1);
|
||||
assert_eq!(server.single_connection().len(), 0);
|
||||
|
|
@ -320,6 +444,50 @@ async fn responses_websocket_v2_requests_use_v2_when_model_prefers_websockets()
|
|||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_v2_incremental_requests_are_reused_across_turns() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_websocket_server(vec![vec![
|
||||
vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_assistant_message("msg-1", "assistant output"),
|
||||
ev_done_with_id("resp-1"),
|
||||
],
|
||||
vec![ev_response_created("resp-2"), ev_completed("resp-2")],
|
||||
]])
|
||||
.await;
|
||||
|
||||
let harness = websocket_harness_with_options(&server, false, false, true, true).await;
|
||||
let prompt_one = prompt_with_input(vec![message_item("hello")]);
|
||||
let prompt_two = prompt_with_input(vec![
|
||||
message_item("hello"),
|
||||
assistant_message_item("msg-1", "assistant output"),
|
||||
message_item("second"),
|
||||
]);
|
||||
|
||||
{
|
||||
let mut client_session = harness.client.new_session();
|
||||
stream_until_complete(&mut client_session, &harness, &prompt_one).await;
|
||||
}
|
||||
|
||||
let mut client_session = harness.client.new_session();
|
||||
stream_until_complete(&mut client_session, &harness, &prompt_two).await;
|
||||
|
||||
assert_eq!(server.handshakes().len(), 1);
|
||||
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 2);
|
||||
let second = connection.get(1).expect("missing request").body_json();
|
||||
assert_eq!(second["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(second["previous_response_id"].as_str(), Some("resp-1"));
|
||||
assert_eq!(
|
||||
second["input"],
|
||||
serde_json::to_value(&prompt_two.input[2..]).unwrap()
|
||||
);
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_v2_wins_when_both_features_enabled() {
|
||||
skip_if_no_network!();
|
||||
|
|
|
|||
|
|
@ -67,9 +67,8 @@ async fn websocket_fallback_switches_to_http_on_upgrade_required_connect() -> Re
|
|||
.filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses"))
|
||||
.count();
|
||||
|
||||
// One websocket attempt comes from startup preconnect and one from the first turn's stream
|
||||
// attempt before fallback activates; after fallback, transport is HTTP. This matches the
|
||||
// retry-budget tradeoff documented in [`codex_core::client`] module docs.
|
||||
// Startup prewarm now only preconnects for v1 (one websocket GET with no request body).
|
||||
// The first turn then attempts websocket once, sees 426, and falls back to HTTP.
|
||||
assert_eq!(websocket_attempts, 2);
|
||||
assert_eq!(http_attempts, 1);
|
||||
assert_eq!(response_mock.requests().len(), 1);
|
||||
|
|
@ -112,7 +111,7 @@ async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result
|
|||
.filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses"))
|
||||
.count();
|
||||
|
||||
// One websocket attempt comes from startup preconnect.
|
||||
// Deferred request prewarm is attempted at startup.
|
||||
// The first turn then makes 3 websocket stream attempts (initial try + 2 retries),
|
||||
// after which fallback activates and the request is replayed over HTTP.
|
||||
assert_eq!(websocket_attempts, 4);
|
||||
|
|
@ -233,7 +232,8 @@ async fn websocket_fallback_is_sticky_across_turns() -> Result<()> {
|
|||
.count();
|
||||
|
||||
// WebSocket attempts all happen on the first turn:
|
||||
// 1 startup preconnect + 3 stream attempts (initial try + 2 retries) before fallback.
|
||||
// 1 deferred request prewarm attempt (startup) + 3 stream attempts
|
||||
// (initial try + 2 retries) before fallback.
|
||||
// Fallback is sticky, so the second turn stays on HTTP and adds no websocket attempts.
|
||||
assert_eq!(websocket_attempts, 4);
|
||||
assert_eq!(http_attempts, 2);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue