Reuse websocket connection (#9127)

Reuses the connection but still sends full requests.
This commit is contained in:
pakrym-oai 2026-01-12 19:30:09 -08:00 committed by GitHub
parent 12779c7c07
commit d75626ad99
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 293 additions and 120 deletions

View file

@ -16,8 +16,10 @@ use futures::StreamExt;
use http::HeaderMap;
use http::HeaderValue;
use serde_json::Value;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
@ -31,6 +33,69 @@ use url::Url;
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// A live websocket connection to the Responses API that can serve multiple
/// sequential requests via `stream_request`, avoiding a reconnect per turn.
pub struct ResponsesWebsocketConnection {
    /// The underlying stream; set to `None` when the connection is
    /// invalidated after a stream error, which makes `is_closed` return true.
    /// Shared with the per-request task spawned by `stream_request`.
    stream: Arc<Mutex<Option<WsStream>>>,
    // TODO (pakrym): is this the right place for timeout?
    /// Passed to `run_websocket_response_stream` as its idle timeout.
    idle_timeout: Duration,
}
impl ResponsesWebsocketConnection {
    /// Wraps an already-established websocket stream for reuse.
    fn new(stream: WsStream, idle_timeout: Duration) -> Self {
        Self {
            stream: Arc::new(Mutex::new(Some(stream))),
            idle_timeout,
        }
    }

    /// True once the stream slot has been cleared after a stream error
    /// (see `stream_request`); callers should then reconnect.
    pub async fn is_closed(&self) -> bool {
        self.stream.lock().await.is_none()
    }

    /// Sends `request` over the shared websocket and returns the event
    /// stream immediately; the request itself runs in a spawned task.
    ///
    /// The task holds the stream mutex for the whole exchange, so concurrent
    /// calls are serialized onto the single underlying connection.
    pub async fn stream_request(
        &self,
        request: ResponsesRequest,
    ) -> Result<ResponseStream, ApiError> {
        // Zstd is not honored on this transport; warn and send the plain
        // JSON body instead of failing the request.
        if request.compression == Compression::Zstd {
            warn!(
                "request compression is not supported for websocket streaming; sending uncompressed payload"
            );
        }
        let (tx_event, rx_event) =
            mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
        let stream = Arc::clone(&self.stream);
        let idle_timeout = self.idle_timeout;
        let request_body = request.body;
        tokio::spawn(async move {
            // tokio::sync::Mutex, so holding the guard across the awaits
            // below is sound.
            let mut guard = stream.lock().await;
            let Some(ws_stream) = guard.as_mut() else {
                // Connection was invalidated before this request ran.
                let _ = tx_event
                    .send(Err(ApiError::Stream(
                        "websocket connection is closed".to_string(),
                    )))
                    .await;
                return;
            };
            if let Err(err) = run_websocket_response_stream(
                ws_stream,
                tx_event.clone(),
                request_body,
                idle_timeout,
            )
            .await
            {
                // On any stream error: close the socket and clear the slot so
                // `is_closed` reports true and the next caller reconnects.
                let _ = ws_stream.close(None).await;
                *guard = None;
                let _ = tx_event.send(Err(err)).await;
            }
        });
        Ok(ResponseStream { rx_event })
    }
}
pub struct ResponsesWebsocketClient<A: AuthProvider> {
provider: Provider,
auth: A,
@ -41,12 +106,22 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
Self { provider, auth }
}
pub async fn stream_request(
pub async fn connect(
&self,
request: ResponsesRequest,
) -> Result<ResponseStream, ApiError> {
self.stream(request.body, request.headers, request.compression)
.await
extra_headers: HeaderMap,
) -> Result<ResponsesWebsocketConnection, ApiError> {
let ws_url = Url::parse(&self.provider.url_for_path("responses"))
.map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?;
let mut headers = self.provider.headers.clone();
headers.extend(extra_headers);
apply_auth_headers(&mut headers, &self.auth);
let stream = connect_websocket(ws_url, headers).await?;
Ok(ResponsesWebsocketConnection::new(
stream,
self.provider.stream_idle_timeout,
))
}
pub async fn stream_prompt(
@ -82,7 +157,8 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
.compression(compression)
.build(&self.provider)?;
self.stream_request(request).await
let connection = self.connect(request.headers.clone()).await?;
connection.stream_request(request).await
}
pub async fn stream(
@ -91,41 +167,13 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
extra_headers: HeaderMap,
compression: Compression,
) -> Result<ResponseStream, ApiError> {
if compression == Compression::Zstd {
warn!(
"request compression is not supported for websocket streaming; sending uncompressed payload"
);
}
let ws_url = Url::parse(&self.provider.url_for_path("responses"))
.map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?;
let mut headers = self.provider.headers.clone();
headers.extend(extra_headers);
apply_auth_headers(&mut headers, &self.auth);
let connection = connect_websocket(ws_url, headers).await?;
let (tx_event, rx_event) =
mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
let idle_timeout = self.provider.stream_idle_timeout;
// TODO (pakrym): surface rate limits
// TODO (pakrym): check models etags
tokio::spawn(async move {
if let Err(err) = run_websocket_response_stream(
connection.stream,
tx_event.clone(),
body,
idle_timeout,
)
.await
{
let _ = tx_event.send(Err(err)).await;
}
});
Ok(ResponseStream { rx_event })
let request = ResponsesRequest {
body,
headers: extra_headers,
compression,
};
let connection = self.connect(request.headers.clone()).await?;
connection.stream_request(request).await
}
}
@ -143,11 +191,7 @@ fn apply_auth_headers(headers: &mut HeaderMap, auth: &impl AuthProvider) {
}
}
struct WebSocketConnection {
stream: WsStream,
}
async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WebSocketConnection, ApiError> {
async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WsStream, ApiError> {
let mut request = url
.clone()
.into_client_request()
@ -157,7 +201,7 @@ async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WebSocketConn
let (stream, _) = tokio_tungstenite::connect_async(request)
.await
.map_err(|err| map_ws_error(err, &url))?;
Ok(WebSocketConnection { stream })
Ok(stream)
}
fn map_ws_error(err: WsError, url: &Url) -> ApiError {
@ -185,7 +229,7 @@ fn map_ws_error(err: WsError, url: &Url) -> ApiError {
}
async fn run_websocket_response_stream(
mut ws_stream: WsStream,
ws_stream: &mut WsStream,
tx_event: mpsc::Sender<std::result::Result<ResponseEvent, ApiError>>,
request_body: Value,
idle_timeout: Duration,
@ -193,7 +237,6 @@ async fn run_websocket_response_stream(
let request_text = match serde_json::to_string(&request_body) {
Ok(text) => text,
Err(err) => {
let _ = ws_stream.close(None).await;
return Err(ApiError::Stream(format!(
"failed to encode websocket request: {err}"
)));
@ -201,7 +244,6 @@ async fn run_websocket_response_stream(
};
if let Err(err) = ws_stream.send(Message::Text(request_text)).await {
let _ = ws_stream.close(None).await;
return Err(ApiError::Stream(format!(
"failed to send websocket request: {err}"
)));
@ -214,17 +256,14 @@ async fn run_websocket_response_stream(
let message = match response {
Ok(Some(Ok(msg))) => msg,
Ok(Some(Err(err))) => {
let _ = ws_stream.close(None).await;
return Err(ApiError::Stream(err.to_string()));
}
Ok(None) => {
let _ = ws_stream.close(None).await;
return Err(ApiError::Stream(
"stream closed before response.completed".into(),
));
}
Err(err) => {
let _ = ws_stream.close(None).await;
return Err(err);
}
};
@ -249,24 +288,20 @@ async fn run_websocket_response_stream(
}
Ok(None) => {}
Err(error) => {
let _ = ws_stream.close(None).await;
return Err(error.into_api_error());
}
}
}
Message::Binary(_) => {
let _ = ws_stream.close(None).await;
return Err(ApiError::Stream("unexpected binary websocket event".into()));
}
Message::Ping(payload) => {
if ws_stream.send(Message::Pong(payload)).await.is_err() {
let _ = ws_stream.close(None).await;
return Err(ApiError::Stream("websocket ping failed".into()));
}
}
Message::Pong(_) => {}
Message::Close(_) => {
let _ = ws_stream.close(None).await;
return Err(ApiError::Stream(
"websocket closed before response.completed".into(),
));
@ -275,6 +310,5 @@ async fn run_websocket_response_stream(
}
}
let _ = ws_stream.close(None).await;
Ok(())
}

View file

@ -26,6 +26,7 @@ pub use crate::endpoint::models::ModelsClient;
pub use crate::endpoint::responses::ResponsesClient;
pub use crate::endpoint::responses::ResponsesOptions;
pub use crate::endpoint::responses_websocket::ResponsesWebsocketClient;
pub use crate::endpoint::responses_websocket::ResponsesWebsocketConnection;
pub use crate::error::ApiError;
pub use crate::provider::Provider;
pub use crate::provider::WireApi;

View file

@ -1,5 +1,6 @@
use std::sync::Arc;
use crate::api_bridge::CoreAuthProvider;
use crate::api_bridge::auth_provider_from_auth;
use crate::api_bridge::map_api_error;
use crate::auth::UnauthorizedRecovery;
@ -13,7 +14,10 @@ use codex_api::ReqwestTransport;
use codex_api::ResponseStream as ApiResponseStream;
use codex_api::ResponsesClient as ApiResponsesClient;
use codex_api::ResponsesOptions as ApiResponsesOptions;
use codex_api::ResponsesRequest;
use codex_api::ResponsesRequestBuilder;
use codex_api::ResponsesWebsocketClient as ApiWebSocketResponsesClient;
use codex_api::ResponsesWebsocketConnection as ApiWebSocketConnection;
use codex_api::SseTelemetry;
use codex_api::TransportError;
use codex_api::common::Reasoning;
@ -76,9 +80,9 @@ pub struct ModelClient {
state: Arc<ModelClientState>,
}
#[derive(Debug, Clone)]
// NOTE(review): the derive above requires `ApiWebSocketConnection` to be
// `Debug + Clone`; the connection type visible in this change declares no
// derives — confirm this compiles or that the derive was meant to be dropped.
/// Per-turn session over a shared `ModelClient`, caching the websocket
/// connection so successive requests can reuse it.
pub struct ModelClientSession {
    // Shared client configuration/state, cloned cheaply via Arc.
    state: Arc<ModelClientState>,
    // Cached websocket connection; `None` until the first websocket request
    // establishes one, or after the previous connection was invalidated.
    connection: Option<ApiWebSocketConnection>,
}
#[allow(clippy::too_many_arguments)]
@ -112,6 +116,7 @@ impl ModelClient {
pub fn new_session(&self) -> ModelClientSession {
ModelClientSession {
state: Arc::clone(&self.state),
connection: None,
}
}
}
@ -228,7 +233,7 @@ impl ModelClientSession {
///
/// For Chat providers, the underlying stream is optionally aggregated
/// based on the `show_raw_agent_reasoning` flag in the config.
pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
match self.state.provider.wire_api {
WireApi::Responses => self.stream_responses_api(prompt).await,
WireApi::ResponsesWebsocket => self.stream_responses_websocket(prompt).await,
@ -315,6 +320,67 @@ impl ModelClientSession {
}
}
/// Builds the complete `ResponsesRequest` payload for a websocket turn from
/// the prepared prompt and the per-request options.
///
/// Returns the builder's error (mapped through `map_api_error`) if the
/// request cannot be assembled for this provider.
fn build_responses_websocket_request(
    &self,
    api_provider: &codex_api::Provider,
    api_prompt: &ApiPrompt,
    options: ApiResponsesOptions,
) -> Result<ResponsesRequest> {
    // Destructure so each option is moved into the builder exactly once.
    let ApiResponsesOptions {
        reasoning,
        include,
        prompt_cache_key,
        text,
        store_override,
        conversation_id,
        session_source,
        extra_headers,
        compression,
    } = options;
    ResponsesRequestBuilder::new(
        &self.state.model_info.slug,
        &api_prompt.instructions,
        &api_prompt.input,
    )
    .tools(&api_prompt.tools)
    .parallel_tool_calls(api_prompt.parallel_tool_calls)
    .reasoning(reasoning)
    .include(include)
    .prompt_cache_key(prompt_cache_key)
    .text(text)
    .conversation(conversation_id)
    .session_source(session_source)
    .store_override(store_override)
    .extra_headers(extra_headers)
    .compression(compression)
    .build(api_provider)
    .map_err(map_api_error)
}
/// Returns a live websocket connection, reconnecting when the cached one is
/// missing or has been closed after a stream error.
///
/// # Errors
/// Propagates connection failures from `connect`; returns
/// `ApiError::Stream` if the cached slot is somehow empty afterwards.
async fn websocket_connection(
    &mut self,
    api_provider: codex_api::Provider,
    api_auth: CoreAuthProvider,
    headers: ApiHeaderMap,
) -> std::result::Result<&ApiWebSocketConnection, ApiError> {
    // Reconnect when there is no cached connection, or when the cached one
    // was invalidated by a previous stream error.
    let needs_new = match self.connection.as_ref() {
        Some(conn) => conn.is_closed().await,
        None => true,
    };
    if needs_new {
        let new_conn = ApiWebSocketResponsesClient::new(api_provider, api_auth)
            .connect(headers)
            .await?;
        self.connection = Some(new_conn);
    }
    // `ok_or_else` defers building the error String to the (practically
    // unreachable) None branch instead of allocating it on every call.
    self.connection.as_ref().ok_or_else(|| {
        ApiError::Stream("websocket connection is unavailable".to_string())
    })
}
fn responses_request_compression(&self, auth: Option<&crate::auth::CodexAuth>) -> Compression {
if self
.state
@ -447,7 +513,7 @@ impl ModelClientSession {
}
/// Streams a turn via the Responses API over WebSocket transport.
async fn stream_responses_websocket(&self, prompt: &Prompt) -> Result<ResponseStream> {
async fn stream_responses_websocket(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
let auth_manager = self.state.auth_manager.clone();
let api_prompt = self.build_responses_request(prompt)?;
@ -467,16 +533,18 @@ impl ModelClientSession {
let compression = self.responses_request_compression(auth.as_ref());
let options = self.build_responses_options(prompt, compression);
let client = ApiWebSocketResponsesClient::new(api_provider, api_auth);
let request =
self.build_responses_websocket_request(&api_provider, &api_prompt, options)?;
let stream_result = client
.stream_prompt(&self.state.model_info.slug, &api_prompt, options)
.await;
match stream_result {
Ok(stream) => {
return Ok(map_response_stream(stream, self.state.otel_manager.clone()));
}
let connection = match self
.websocket_connection(
api_provider.clone(),
api_auth.clone(),
request.headers.clone(),
)
.await
{
Ok(connection) => connection,
Err(ApiError::Transport(TransportError::Http { status, .. }))
if status == StatusCode::UNAUTHORIZED =>
{
@ -484,7 +552,17 @@ impl ModelClientSession {
continue;
}
Err(err) => return Err(map_api_error(err)),
}
};
let stream_result = connection
.stream_request(request)
.await
.map_err(map_api_error)?;
return Ok(map_response_stream(
stream_result,
self.state.otel_manager.clone(),
));
}
}

View file

@ -2673,7 +2673,7 @@ async fn run_model_turn(
output_schema: turn_context.final_output_json_schema.clone(),
};
let client_session = turn_context.client.new_session();
let mut client_session = turn_context.client.new_session();
let mut retries = 0;
loop {
@ -2681,7 +2681,7 @@ async fn run_model_turn(
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
&client_session,
&mut client_session,
Arc::clone(&turn_diff_tracker),
&prompt,
cancellation_token.child_token(),
@ -2773,7 +2773,7 @@ async fn try_run_turn(
router: Arc<ToolRouter>,
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
client_session: &ModelClientSession,
client_session: &mut ModelClientSession,
turn_diff_tracker: SharedTurnDiffTracker,
prompt: &Prompt,
cancellation_token: CancellationToken,

View file

@ -297,7 +297,7 @@ async fn drain_to_completed(
turn_context: &TurnContext,
prompt: &Prompt,
) -> CodexResult<()> {
let client_session = turn_context.client.new_session();
let mut client_session = turn_context.client.new_session();
let mut stream = client_session.stream(prompt).await?;
loop {
let maybe_event = stream.next().await;

View file

@ -88,7 +88,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
SessionSource::Exec,
);
let client = ModelClient::new(
let mut client_session = ModelClient::new(
Arc::clone(&config),
None,
model_info,
@ -104,7 +104,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
let mut prompt = Prompt::default();
prompt.input = input;
let mut stream = match client.stream(&prompt).await {
let mut stream = match client_session.stream(&prompt).await {
Ok(s) => s,
Err(e) => panic!("stream chat failed: {e}"),
};

View file

@ -89,7 +89,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
SessionSource::Exec,
);
let client = ModelClient::new(
let mut client = ModelClient::new(
Arc::clone(&config),
None,
model_info,

View file

@ -81,7 +81,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
session_source.clone(),
);
let client = ModelClient::new(
let mut client_session = ModelClient::new(
Arc::clone(&config),
None,
model_info,
@ -103,7 +103,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
}],
}];
let mut stream = client.stream(&prompt).await.expect("stream failed");
let mut stream = client_session.stream(&prompt).await.expect("stream failed");
while let Some(event) = stream.next().await {
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
break;
@ -177,7 +177,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
session_source.clone(),
);
let client = ModelClient::new(
let mut client_session = ModelClient::new(
Arc::clone(&config),
None,
model_info,
@ -199,7 +199,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
}],
}];
let mut stream = client.stream(&prompt).await.expect("stream failed");
let mut stream = client_session.stream(&prompt).await.expect("stream failed");
while let Some(event) = stream.next().await {
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
break;
@ -271,7 +271,7 @@ async fn responses_respects_model_info_overrides_from_config() {
session_source.clone(),
);
let client = ModelClient::new(
let mut client = ModelClient::new(
Arc::clone(&config),
None,
model_info,

View file

@ -1171,7 +1171,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
SessionSource::Exec,
);
let client = ModelClient::new(
let mut client = ModelClient::new(
Arc::clone(&config),
None,
model_info,

View file

@ -1,7 +1,9 @@
#![allow(clippy::expect_used, clippy::unwrap_used)]
use codex_core::AuthManager;
use codex_core::CodexAuth;
use codex_core::ContentItem;
use codex_core::ModelClient;
use codex_core::ModelClientSession;
use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseEvent;
@ -11,23 +13,97 @@ use codex_core::models_manager::manager::ModelsManager;
use codex_core::protocol::SessionSource;
use codex_otel::OtelManager;
use codex_protocol::ThreadId;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::load_default_config_for_test;
use core_test_support::responses::WebSocketTestServer;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::start_websocket_server;
use core_test_support::skip_if_no_network;
use futures::StreamExt;
use pretty_assertions::assert_eq;
use std::sync::Arc;
use tempfile::TempDir;
const MODEL: &str = "gpt-5.2-codex";
/// Test fixture bundling the client under test with the temporary CODEX
/// home directory backing its config.
struct WebsocketTestHarness {
    // Held only to keep the TempDir alive (and its files on disk) for the
    // duration of the test; never read directly.
    _codex_home: TempDir,
    client: ModelClient,
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_streams_request() {
skip_if_no_network!();
let server = start_websocket_server(vec![vec![vec![
ev_response_created("resp-1"),
ev_completed("resp-1"),
]]])
.await;
let provider = ModelProviderInfo {
let harness = websocket_harness(&server).await;
let mut session = harness.client.new_session();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {
id: None,
role: "user".into(),
content: vec![ContentItem::InputText {
text: "hello".into(),
}],
}];
stream_until_complete(&mut session, &prompt).await;
let connection = server.single_connection();
assert_eq!(connection.len(), 1);
let body = connection.first().expect("missing request").body_json();
assert_eq!(body["model"].as_str(), Some(MODEL));
assert_eq!(body["stream"], serde_json::Value::Bool(true));
assert_eq!(body["input"].as_array().map(Vec::len), Some(1));
server.shutdown().await;
}
/// Two turns on one session must travel over a single websocket connection.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_reuses_connection() {
    skip_if_no_network!();
    // One connection script containing two complete response exchanges.
    let first_turn = vec![ev_response_created("resp-1"), ev_completed("resp-1")];
    let second_turn = vec![ev_response_created("resp-2"), ev_completed("resp-2")];
    let server = start_websocket_server(vec![vec![first_turn, second_turn]]).await;
    let harness = websocket_harness(&server).await;
    let mut session = harness.client.new_session();
    let mut user_prompt = Prompt::default();
    user_prompt.input = vec![ResponseItem::Message {
        id: None,
        role: "user".into(),
        content: vec![ContentItem::InputText {
            text: "hello".into(),
        }],
    }];
    // The first turn establishes the connection; the second must reuse it.
    stream_until_complete(&mut session, &user_prompt).await;
    stream_until_complete(&mut session, &user_prompt).await;
    // Both requests should have been recorded on the same (single) connection.
    let recorded = server.single_connection();
    assert_eq!(recorded.len(), 2);
    let first_body = recorded.first().expect("missing request").body_json();
    assert_eq!(first_body["model"].as_str(), Some(MODEL));
    assert_eq!(first_body["stream"], serde_json::Value::Bool(true));
    assert_eq!(first_body["input"].as_array().map(Vec::len), Some(1));
    server.shutdown().await;
}
fn websocket_provider(server: &WebSocketTestServer) -> ModelProviderInfo {
ModelProviderInfo {
name: "mock-ws".into(),
base_url: Some(format!("{}/v1", server.uri())),
env_key: None,
@ -41,23 +117,21 @@ async fn responses_websocket_streams_request() {
stream_max_retries: Some(0),
stream_idle_timeout_ms: Some(5_000),
requires_openai_auth: false,
};
}
}
async fn websocket_harness(server: &WebSocketTestServer) -> WebsocketTestHarness {
let provider = websocket_provider(server);
let codex_home = TempDir::new().unwrap();
let mut config = load_default_config_for_test(&codex_home).await;
config.model_provider_id = provider.name.clone();
config.model_provider = provider.clone();
let effort = config.model_reasoning_effort;
let summary = config.model_reasoning_summary;
let model = ModelsManager::get_model_offline(config.model.as_deref());
config.model = Some(model.clone());
config.model = Some(MODEL.to_string());
let config = Arc::new(config);
let model_info = ModelsManager::construct_model_info_offline(model.as_str(), &config);
let model_info = ModelsManager::construct_model_info_offline(MODEL, &config);
let conversation_id = ThreadId::new();
let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key"));
let otel_manager = OtelManager::new(
conversation_id,
model.as_str(),
MODEL,
model_info.slug.as_str(),
None,
Some("test@test.com".to_string()),
@ -66,31 +140,27 @@ async fn responses_websocket_streams_request() {
"test".to_string(),
SessionSource::Exec,
);
let client = ModelClient::new(
Arc::clone(&config),
None,
model_info,
otel_manager,
provider,
effort,
summary,
provider.clone(),
None,
ReasoningSummary::Auto,
conversation_id,
SessionSource::Exec,
)
.new_session();
);
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {
id: None,
role: "user".into(),
content: vec![ContentItem::InputText {
text: "hello".into(),
}],
}];
WebsocketTestHarness {
_codex_home: codex_home,
client,
}
}
let mut stream = client
.stream(&prompt)
async fn stream_until_complete(session: &mut ModelClientSession, prompt: &Prompt) {
let mut stream = session
.stream(prompt)
.await
.expect("websocket stream failed");
@ -99,14 +169,4 @@ async fn responses_websocket_streams_request() {
break;
}
}
let connection = server.single_connection();
assert_eq!(connection.len(), 1);
let request = connection.first().cloned().unwrap();
let body = request.body_json();
assert_eq!(body["model"].as_str(), Some(model.as_str()));
assert_eq!(body["stream"], serde_json::Value::Bool(true));
assert_eq!(body["input"].as_array().map(Vec::len), Some(1));
server.shutdown().await;
}