Turn-state sticky routing per turn (#9332)
- capture the header from SSE/WS handshakes, store it per ModelClientSession using `Oncelock`, echo it on turn-scoped requests, and add SSE+WS integration tests for within-turn persistence + cross-turn reset. - keep `x-codex-turn-state` sticky within a user turn to maintain routing continuity for retries/tool follow-ups.
This commit is contained in:
parent
4125c825f9
commit
ebdd8795e9
11 changed files with 343 additions and 24 deletions
|
|
@ -87,6 +87,7 @@ impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
|
|||
extra_headers,
|
||||
RequestCompression::None,
|
||||
spawn_chat_stream,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ use codex_protocol::protocol::SessionSource;
|
|||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use tracing::instrument;
|
||||
|
||||
pub struct ResponsesClient<T: HttpTransport, A: AuthProvider> {
|
||||
|
|
@ -36,6 +37,7 @@ pub struct ResponsesOptions {
|
|||
pub session_source: Option<SessionSource>,
|
||||
pub extra_headers: HeaderMap,
|
||||
pub compression: Compression,
|
||||
pub turn_state: Option<Arc<OnceLock<String>>>,
|
||||
}
|
||||
|
||||
impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
||||
|
|
@ -58,9 +60,15 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
|||
pub async fn stream_request(
|
||||
&self,
|
||||
request: ResponsesRequest,
|
||||
turn_state: Option<Arc<OnceLock<String>>>,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
self.stream(request.body, request.headers, request.compression)
|
||||
.await
|
||||
self.stream(
|
||||
request.body,
|
||||
request.headers,
|
||||
request.compression,
|
||||
turn_state,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all, err)]
|
||||
|
|
@ -80,6 +88,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
|||
session_source,
|
||||
extra_headers,
|
||||
compression,
|
||||
turn_state,
|
||||
} = options;
|
||||
|
||||
let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input)
|
||||
|
|
@ -96,7 +105,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
|||
.compression(compression)
|
||||
.build(self.streaming.provider())?;
|
||||
|
||||
self.stream_request(request).await
|
||||
self.stream_request(request, turn_state).await
|
||||
}
|
||||
|
||||
fn path(&self) -> &'static str {
|
||||
|
|
@ -111,6 +120,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
|||
body: Value,
|
||||
extra_headers: HeaderMap,
|
||||
compression: Compression,
|
||||
turn_state: Option<Arc<OnceLock<String>>>,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let compression = match compression {
|
||||
Compression::None => RequestCompression::None,
|
||||
|
|
@ -124,6 +134,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
|||
extra_headers,
|
||||
compression,
|
||||
spawn_response_stream,
|
||||
turn_state,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ use http::HeaderMap;
|
|||
use http::HeaderValue;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
|
|
@ -27,6 +28,7 @@ use tracing::trace;
|
|||
use url::Url;
|
||||
|
||||
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
|
||||
|
||||
pub struct ResponsesWebsocketConnection {
|
||||
stream: Arc<Mutex<Option<WsStream>>>,
|
||||
|
|
@ -100,6 +102,7 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
|
|||
pub async fn connect(
|
||||
&self,
|
||||
extra_headers: HeaderMap,
|
||||
turn_state: Option<Arc<OnceLock<String>>>,
|
||||
) -> Result<ResponsesWebsocketConnection, ApiError> {
|
||||
let ws_url = Url::parse(&self.provider.url_for_path("responses"))
|
||||
.map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?;
|
||||
|
|
@ -108,7 +111,7 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
|
|||
headers.extend(extra_headers);
|
||||
apply_auth_headers(&mut headers, &self.auth);
|
||||
|
||||
let stream = connect_websocket(ws_url, headers).await?;
|
||||
let stream = connect_websocket(ws_url, headers, turn_state).await?;
|
||||
Ok(ResponsesWebsocketConnection::new(
|
||||
stream,
|
||||
self.provider.stream_idle_timeout,
|
||||
|
|
@ -130,16 +133,28 @@ fn apply_auth_headers(headers: &mut HeaderMap, auth: &impl AuthProvider) {
|
|||
}
|
||||
}
|
||||
|
||||
async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WsStream, ApiError> {
|
||||
async fn connect_websocket(
|
||||
url: Url,
|
||||
headers: HeaderMap,
|
||||
turn_state: Option<Arc<OnceLock<String>>>,
|
||||
) -> Result<WsStream, ApiError> {
|
||||
let mut request = url
|
||||
.clone()
|
||||
.into_client_request()
|
||||
.map_err(|err| ApiError::Stream(format!("failed to build websocket request: {err}")))?;
|
||||
request.headers_mut().extend(headers);
|
||||
|
||||
let (stream, _) = tokio_tungstenite::connect_async(request)
|
||||
let (stream, response) = tokio_tungstenite::connect_async(request)
|
||||
.await
|
||||
.map_err(|err| map_ws_error(err, &url))?;
|
||||
if let Some(turn_state) = turn_state
|
||||
&& let Some(header_value) = response
|
||||
.headers()
|
||||
.get(X_CODEX_TURN_STATE_HEADER)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
{
|
||||
let _ = turn_state.set(header_value.to_string());
|
||||
}
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ use http::HeaderMap;
|
|||
use http::Method;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
|
||||
pub(crate) struct StreamingClient<T: HttpTransport, A: AuthProvider> {
|
||||
|
|
@ -23,6 +24,13 @@ pub(crate) struct StreamingClient<T: HttpTransport, A: AuthProvider> {
|
|||
sse_telemetry: Option<Arc<dyn SseTelemetry>>,
|
||||
}
|
||||
|
||||
type StreamSpawner = fn(
|
||||
StreamResponse,
|
||||
Duration,
|
||||
Option<Arc<dyn SseTelemetry>>,
|
||||
Option<Arc<OnceLock<String>>>,
|
||||
) -> ResponseStream;
|
||||
|
||||
impl<T: HttpTransport, A: AuthProvider> StreamingClient<T, A> {
|
||||
pub(crate) fn new(transport: T, provider: Provider, auth: A) -> Self {
|
||||
Self {
|
||||
|
|
@ -54,7 +62,8 @@ impl<T: HttpTransport, A: AuthProvider> StreamingClient<T, A> {
|
|||
body: Value,
|
||||
extra_headers: HeaderMap,
|
||||
compression: RequestCompression,
|
||||
spawner: fn(StreamResponse, Duration, Option<Arc<dyn SseTelemetry>>) -> ResponseStream,
|
||||
spawner: StreamSpawner,
|
||||
turn_state: Option<Arc<OnceLock<String>>>,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let builder = || {
|
||||
let mut req = self.provider.build_request(Method::POST, path);
|
||||
|
|
@ -80,6 +89,7 @@ impl<T: HttpTransport, A: AuthProvider> StreamingClient<T, A> {
|
|||
stream_response,
|
||||
self.provider.stream_idle_timeout,
|
||||
self.sse_telemetry.clone(),
|
||||
turn_state,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ use futures::Stream;
|
|||
use futures::StreamExt;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Instant;
|
||||
|
|
@ -21,7 +23,8 @@ use tracing::trace;
|
|||
pub(crate) fn spawn_chat_stream(
|
||||
stream_response: StreamResponse,
|
||||
idle_timeout: Duration,
|
||||
telemetry: Option<std::sync::Arc<dyn SseTelemetry>>,
|
||||
telemetry: Option<Arc<dyn SseTelemetry>>,
|
||||
_turn_state: Option<Arc<OnceLock<String>>>,
|
||||
) -> ResponseStream {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
|
||||
tokio::spawn(async move {
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ use serde_json::Value;
|
|||
use std::io::BufRead;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Instant;
|
||||
|
|
@ -49,6 +50,7 @@ pub fn spawn_response_stream(
|
|||
stream_response: StreamResponse,
|
||||
idle_timeout: Duration,
|
||||
telemetry: Option<Arc<dyn SseTelemetry>>,
|
||||
turn_state: Option<Arc<OnceLock<String>>>,
|
||||
) -> ResponseStream {
|
||||
let rate_limits = parse_rate_limit(&stream_response.headers);
|
||||
let models_etag = stream_response
|
||||
|
|
@ -56,6 +58,14 @@ pub fn spawn_response_stream(
|
|||
.get("X-Models-Etag")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(ToString::to_string);
|
||||
if let Some(turn_state) = turn_state.as_ref()
|
||||
&& let Some(header_value) = stream_response
|
||||
.headers
|
||||
.get("x-codex-turn-state")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
let _ = turn_state.set(header_value.to_string());
|
||||
}
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
|
||||
tokio::spawn(async move {
|
||||
if let Some(snapshot) = rate_limits {
|
||||
|
|
|
|||
|
|
@ -231,7 +231,7 @@ async fn responses_client_uses_responses_path_for_responses_wire() -> Result<()>
|
|||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client
|
||||
.stream(body, HeaderMap::new(), Compression::None)
|
||||
.stream(body, HeaderMap::new(), Compression::None, None)
|
||||
.await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
|
|
@ -247,7 +247,7 @@ async fn responses_client_uses_chat_path_for_chat_wire() -> Result<()> {
|
|||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client
|
||||
.stream(body, HeaderMap::new(), Compression::None)
|
||||
.stream(body, HeaderMap::new(), Compression::None, None)
|
||||
.await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
|
|
@ -264,7 +264,7 @@ async fn streaming_client_adds_auth_headers() -> Result<()> {
|
|||
|
||||
let body = serde_json::json!({ "model": "gpt-test" });
|
||||
let _stream = client
|
||||
.stream(body, HeaderMap::new(), Compression::None)
|
||||
.stream(body, HeaderMap::new(), Compression::None, None)
|
||||
.await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
|
|
|
|||
|
|
@ -129,6 +129,7 @@ async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()>
|
|||
serde_json::json!({"echo": true}),
|
||||
HeaderMap::new(),
|
||||
Compression::None,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
|
@ -198,6 +199,7 @@ async fn responses_stream_aggregates_output_text_deltas() -> Result<()> {
|
|||
serde_json::json!({"echo": true}),
|
||||
HeaderMap::new(),
|
||||
Compression::None,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use crate::api_bridge::CoreAuthProvider;
|
||||
use crate::api_bridge::auth_provider_from_auth;
|
||||
|
|
@ -66,6 +67,7 @@ use crate::tools::spec::create_tools_json_for_chat_completions_api;
|
|||
use crate::tools::spec::create_tools_json_for_responses_api;
|
||||
|
||||
pub const WEB_SEARCH_ELIGIBLE_HEADER: &str = "x-oai-web-search-eligible";
|
||||
pub const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ModelClientState {
|
||||
|
|
@ -89,6 +91,17 @@ pub struct ModelClientSession {
|
|||
state: Arc<ModelClientState>,
|
||||
connection: Option<ApiWebSocketConnection>,
|
||||
websocket_last_items: Vec<ResponseItem>,
|
||||
/// Turn state for sticky routing.
|
||||
///
|
||||
/// This is an `OnceLock` that stores the turn state value received from the server
|
||||
/// on turn start via the `x-codex-turn-state` response header. Once set, this value
|
||||
/// should be sent back to the server in the `x-codex-turn-state` request header for
|
||||
/// all subsequent requests within the same turn to maintain sticky routing.
|
||||
///
|
||||
/// This is a contract between the client and server: we receive it at turn start,
|
||||
/// keep sending it unchanged between turn requests (e.g., for retries, incremental
|
||||
/// appends, or continuation requests), and must not send it between different turns.
|
||||
turn_state: Arc<OnceLock<String>>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
|
|
@ -124,6 +137,7 @@ impl ModelClient {
|
|||
state: Arc::clone(&self.state),
|
||||
connection: None,
|
||||
websocket_last_items: Vec::new(),
|
||||
turn_state: Arc::new(OnceLock::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -226,7 +240,6 @@ impl ModelClient {
|
|||
extra_headers.insert("x-openai-subagent", val);
|
||||
}
|
||||
}
|
||||
|
||||
client
|
||||
.compact_input(&payload, extra_headers)
|
||||
.await
|
||||
|
|
@ -322,8 +335,9 @@ impl ModelClientSession {
|
|||
store_override: None,
|
||||
conversation_id: Some(conversation_id),
|
||||
session_source: Some(self.state.session_source.clone()),
|
||||
extra_headers: build_responses_headers(&self.state.config),
|
||||
extra_headers: build_responses_headers(&self.state.config, Some(&self.turn_state)),
|
||||
compression,
|
||||
turn_state: Some(Arc::clone(&self.turn_state)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -397,7 +411,7 @@ impl ModelClientSession {
|
|||
headers.extend(build_conversation_headers(options.conversation_id.clone()));
|
||||
let new_conn: ApiWebSocketConnection =
|
||||
ApiWebSocketResponsesClient::new(api_provider, api_auth)
|
||||
.connect(headers)
|
||||
.connect(headers, options.turn_state.clone())
|
||||
.await?;
|
||||
self.connection = Some(new_conn);
|
||||
}
|
||||
|
|
@ -638,7 +652,10 @@ fn beta_feature_headers(config: &Config) -> ApiHeaderMap {
|
|||
headers
|
||||
}
|
||||
|
||||
fn build_responses_headers(config: &Config) -> ApiHeaderMap {
|
||||
fn build_responses_headers(
|
||||
config: &Config,
|
||||
turn_state: Option<&Arc<OnceLock<String>>>,
|
||||
) -> ApiHeaderMap {
|
||||
let mut headers = beta_feature_headers(config);
|
||||
headers.insert(
|
||||
WEB_SEARCH_ELIGIBLE_HEADER,
|
||||
|
|
@ -650,6 +667,12 @@ fn build_responses_headers(config: &Config) -> ApiHeaderMap {
|
|||
},
|
||||
),
|
||||
);
|
||||
if let Some(turn_state) = turn_state
|
||||
&& let Some(state) = turn_state.get()
|
||||
&& let Ok(header_value) = HeaderValue::from_str(state)
|
||||
{
|
||||
headers.insert(X_CODEX_TURN_STATE_HEADER, header_value);
|
||||
}
|
||||
headers
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ use serde_json::Value;
|
|||
use tokio::net::TcpListener;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::tungstenite::handshake::server::Request;
|
||||
use tokio_tungstenite::tungstenite::handshake::server::Response;
|
||||
use wiremock::BodyPrintLimit;
|
||||
use wiremock::Match;
|
||||
use wiremock::Mock;
|
||||
|
|
@ -19,6 +21,8 @@ use wiremock::MockBuilder;
|
|||
use wiremock::MockServer;
|
||||
use wiremock::Respond;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::http::HeaderName;
|
||||
use wiremock::http::HeaderValue;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path_regex;
|
||||
|
||||
|
|
@ -216,9 +220,30 @@ impl WebSocketRequest {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WebSocketHandshake {
|
||||
headers: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
impl WebSocketHandshake {
|
||||
pub fn header(&self, name: &str) -> Option<String> {
|
||||
self.headers
|
||||
.iter()
|
||||
.find(|(header, _)| header.eq_ignore_ascii_case(name))
|
||||
.map(|(_, value)| value.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WebSocketConnectionConfig {
|
||||
pub requests: Vec<Vec<Value>>,
|
||||
pub response_headers: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
pub struct WebSocketTestServer {
|
||||
uri: String,
|
||||
connections: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
|
||||
handshakes: Arc<Mutex<Vec<WebSocketHandshake>>>,
|
||||
shutdown: oneshot::Sender<()>,
|
||||
task: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
|
@ -240,6 +265,18 @@ impl WebSocketTestServer {
|
|||
connections.first().cloned().unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn handshakes(&self) -> Vec<WebSocketHandshake> {
|
||||
self.handshakes.lock().unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn single_handshake(&self) -> WebSocketHandshake {
|
||||
let handshakes = self.handshakes.lock().unwrap();
|
||||
if handshakes.len() != 1 {
|
||||
panic!("expected 1 handshake, got {}", handshakes.len());
|
||||
}
|
||||
handshakes.first().cloned().unwrap()
|
||||
}
|
||||
|
||||
pub async fn shutdown(self) {
|
||||
let _ = self.shutdown.send(());
|
||||
let _ = self.task.await;
|
||||
|
|
@ -786,13 +823,28 @@ pub async fn start_mock_server() -> MockServer {
|
|||
/// request message, the server records the payload and streams the matching
|
||||
/// events as WebSocket text frames before moving to the next request.
|
||||
pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSocketTestServer {
|
||||
let connections = connections
|
||||
.into_iter()
|
||||
.map(|requests| WebSocketConnectionConfig {
|
||||
requests,
|
||||
response_headers: Vec::new(),
|
||||
})
|
||||
.collect();
|
||||
start_websocket_server_with_headers(connections).await
|
||||
}
|
||||
|
||||
pub async fn start_websocket_server_with_headers(
|
||||
connections: Vec<WebSocketConnectionConfig>,
|
||||
) -> WebSocketTestServer {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("bind websocket server");
|
||||
let addr = listener.local_addr().expect("websocket server address");
|
||||
let uri = format!("ws://{addr}");
|
||||
let connections_log = Arc::new(Mutex::new(Vec::new()));
|
||||
let handshakes_log = Arc::new(Mutex::new(Vec::new()));
|
||||
let requests = Arc::clone(&connections_log);
|
||||
let handshakes = Arc::clone(&handshakes_log);
|
||||
let connections = Arc::new(Mutex::new(VecDeque::from(connections)));
|
||||
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
|
||||
|
||||
|
|
@ -806,27 +858,57 @@ pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSoc
|
|||
Ok(value) => value,
|
||||
Err(_) => return,
|
||||
};
|
||||
let mut ws_stream = match tokio_tungstenite::accept_async(stream).await {
|
||||
Ok(ws) => ws,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let connection_requests = {
|
||||
let connection = {
|
||||
let mut pending = connections.lock().unwrap();
|
||||
pending.pop_front()
|
||||
};
|
||||
|
||||
let Some(connection_requests) = connection_requests else {
|
||||
let _ = ws_stream.close(None).await;
|
||||
let Some(connection) = connection else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let response_headers = connection.response_headers.clone();
|
||||
let handshake_log = Arc::clone(&handshakes);
|
||||
let callback = move |req: &Request, mut response: Response| {
|
||||
let headers = req
|
||||
.headers()
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
value
|
||||
.to_str()
|
||||
.ok()
|
||||
.map(|value| (name.as_str().to_string(), value.to_string()))
|
||||
})
|
||||
.collect();
|
||||
handshake_log
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push(WebSocketHandshake { headers });
|
||||
|
||||
let headers_mut = response.headers_mut();
|
||||
for (name, value) in &response_headers {
|
||||
if let (Ok(name), Ok(value)) = (
|
||||
HeaderName::from_bytes(name.as_bytes()),
|
||||
HeaderValue::from_str(value),
|
||||
) {
|
||||
headers_mut.insert(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
};
|
||||
|
||||
let mut ws_stream = match tokio_tungstenite::accept_hdr_async(stream, callback).await {
|
||||
Ok(ws) => ws,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let connection_index = {
|
||||
let mut log = requests.lock().unwrap();
|
||||
log.push(Vec::new());
|
||||
log.len() - 1
|
||||
};
|
||||
for request_events in connection_requests {
|
||||
for request_events in connection.requests {
|
||||
let Some(Ok(message)) = ws_stream.next().await else {
|
||||
break;
|
||||
};
|
||||
|
|
@ -858,6 +940,7 @@ pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSoc
|
|||
WebSocketTestServer {
|
||||
uri,
|
||||
connections: connections_log,
|
||||
handshakes: handshakes_log,
|
||||
shutdown: shutdown_tx,
|
||||
task,
|
||||
}
|
||||
|
|
@ -942,6 +1025,45 @@ pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) -> Res
|
|||
response_mock
|
||||
}
|
||||
|
||||
/// Mounts a sequence of responses for each POST to `/v1/responses`.
|
||||
/// Panics if more requests are received than responses provided.
|
||||
pub async fn mount_response_sequence(
|
||||
server: &MockServer,
|
||||
responses: Vec<ResponseTemplate>,
|
||||
) -> ResponseMock {
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
struct SeqResponder {
|
||||
num_calls: AtomicUsize,
|
||||
responses: Vec<ResponseTemplate>,
|
||||
}
|
||||
|
||||
impl Respond for SeqResponder {
|
||||
fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
|
||||
let call_num = self.num_calls.fetch_add(1, Ordering::SeqCst);
|
||||
self.responses
|
||||
.get(call_num)
|
||||
.unwrap_or_else(|| panic!("no response for {call_num}"))
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
let num_calls = responses.len();
|
||||
let responder = SeqResponder {
|
||||
num_calls: AtomicUsize::new(0),
|
||||
responses,
|
||||
};
|
||||
|
||||
let (mock, response_mock) = base_mock();
|
||||
mock.respond_with(responder)
|
||||
.up_to_n_times(num_calls as u64)
|
||||
.expect(num_calls as u64)
|
||||
.mount(server)
|
||||
.await;
|
||||
response_mock
|
||||
}
|
||||
|
||||
/// Validate invariants on the request body sent to `/v1/responses`.
|
||||
///
|
||||
/// - No `function_call_output`/`custom_tool_call_output` with missing/empty `call_id`.
|
||||
|
|
|
|||
122
codex-rs/core/tests/suite/turn_state.rs
Normal file
122
codex-rs/core/tests/suite/turn_state.rs
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
|
||||
use anyhow::Result;
|
||||
use core_test_support::responses::WebSocketConnectionConfig;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_done;
|
||||
use core_test_support::responses::ev_reasoning_item;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::ev_shell_command_call;
|
||||
use core_test_support::responses::mount_response_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::sse_response;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::responses::start_websocket_server_with_headers;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
const TURN_STATE_HEADER: &str = "x-codex-turn-state";
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_turn_state_persists_within_turn_and_resets_after() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let call_id = "shell-turn-state";
|
||||
|
||||
let first_response = sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_reasoning_item("rsn-1", &["thinking"], &[]),
|
||||
ev_shell_command_call(call_id, "echo turn-state"),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
let second_response = sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]);
|
||||
let third_response = sse(vec![
|
||||
ev_response_created("resp-3"),
|
||||
ev_assistant_message("msg-2", "done"),
|
||||
ev_completed("resp-3"),
|
||||
]);
|
||||
|
||||
// First response sets turn_state; follow-up request in the same turn should echo it.
|
||||
let responses = vec![
|
||||
sse_response(first_response).insert_header(TURN_STATE_HEADER, "ts-1"),
|
||||
sse_response(second_response),
|
||||
sse_response(third_response),
|
||||
];
|
||||
let request_log = mount_response_sequence(&server, responses).await;
|
||||
|
||||
let test = test_codex().build(&server).await?;
|
||||
test.submit_turn("run a shell command").await?;
|
||||
test.submit_turn("second turn").await?;
|
||||
|
||||
let requests = request_log.requests();
|
||||
assert_eq!(requests.len(), 3);
|
||||
// Initial turn request has no header; follow-up has it; next turn clears it.
|
||||
assert_eq!(requests[0].header(TURN_STATE_HEADER), None);
|
||||
assert_eq!(
|
||||
requests[1].header(TURN_STATE_HEADER),
|
||||
Some("ts-1".to_string())
|
||||
);
|
||||
assert_eq!(requests[2].header(TURN_STATE_HEADER), None);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn websocket_turn_state_persists_within_turn_and_resets_after() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let call_id = "ws-shell-turn-state";
|
||||
// First connection delivers turn_state; second (same turn) must send it; third (new turn) must not.
|
||||
let server = start_websocket_server_with_headers(vec![
|
||||
WebSocketConnectionConfig {
|
||||
requests: vec![vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_reasoning_item("rsn-1", &["thinking"], &[]),
|
||||
ev_shell_command_call(call_id, "echo websocket"),
|
||||
ev_done(),
|
||||
]],
|
||||
response_headers: vec![(TURN_STATE_HEADER.to_string(), "ts-1".to_string())],
|
||||
},
|
||||
WebSocketConnectionConfig {
|
||||
requests: vec![vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]],
|
||||
response_headers: Vec::new(),
|
||||
},
|
||||
WebSocketConnectionConfig {
|
||||
requests: vec![vec![
|
||||
ev_response_created("resp-3"),
|
||||
ev_assistant_message("msg-2", "done"),
|
||||
ev_completed("resp-3"),
|
||||
]],
|
||||
response_headers: Vec::new(),
|
||||
},
|
||||
])
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex();
|
||||
let test = builder.build_with_websocket_server(&server).await?;
|
||||
test.submit_turn("run the echo command").await?;
|
||||
test.submit_turn("second turn").await?;
|
||||
|
||||
let handshakes = server.handshakes();
|
||||
assert_eq!(handshakes.len(), 3);
|
||||
assert_eq!(handshakes[0].header(TURN_STATE_HEADER), None);
|
||||
assert_eq!(
|
||||
handshakes[1].header(TURN_STATE_HEADER),
|
||||
Some("ts-1".to_string())
|
||||
);
|
||||
assert_eq!(handshakes[2].header(TURN_STATE_HEADER), None);
|
||||
|
||||
server.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue