Reuse websocket connection (#9127)
Reuses the connection but still sends full requests.
This commit is contained in:
parent
12779c7c07
commit
d75626ad99
10 changed files with 293 additions and 120 deletions
|
|
@ -16,8 +16,10 @@ use futures::StreamExt;
|
|||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
|
|
@ -31,6 +33,69 @@ use url::Url;
|
|||
|
||||
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
|
||||
pub struct ResponsesWebsocketConnection {
|
||||
stream: Arc<Mutex<Option<WsStream>>>,
|
||||
// TODO (pakrym): is this the right place for timeout?
|
||||
idle_timeout: Duration,
|
||||
}
|
||||
|
||||
impl ResponsesWebsocketConnection {
|
||||
fn new(stream: WsStream, idle_timeout: Duration) -> Self {
|
||||
Self {
|
||||
stream: Arc::new(Mutex::new(Some(stream))),
|
||||
idle_timeout,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn is_closed(&self) -> bool {
|
||||
self.stream.lock().await.is_none()
|
||||
}
|
||||
|
||||
pub async fn stream_request(
|
||||
&self,
|
||||
request: ResponsesRequest,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
if request.compression == Compression::Zstd {
|
||||
warn!(
|
||||
"request compression is not supported for websocket streaming; sending uncompressed payload"
|
||||
);
|
||||
}
|
||||
|
||||
let (tx_event, rx_event) =
|
||||
mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
|
||||
let stream = Arc::clone(&self.stream);
|
||||
let idle_timeout = self.idle_timeout;
|
||||
let request_body = request.body;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut guard = stream.lock().await;
|
||||
let Some(ws_stream) = guard.as_mut() else {
|
||||
let _ = tx_event
|
||||
.send(Err(ApiError::Stream(
|
||||
"websocket connection is closed".to_string(),
|
||||
)))
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
|
||||
if let Err(err) = run_websocket_response_stream(
|
||||
ws_stream,
|
||||
tx_event.clone(),
|
||||
request_body,
|
||||
idle_timeout,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let _ = ws_stream.close(None).await;
|
||||
*guard = None;
|
||||
let _ = tx_event.send(Err(err)).await;
|
||||
}
|
||||
});
|
||||
|
||||
Ok(ResponseStream { rx_event })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ResponsesWebsocketClient<A: AuthProvider> {
|
||||
provider: Provider,
|
||||
auth: A,
|
||||
|
|
@ -41,12 +106,22 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
|
|||
Self { provider, auth }
|
||||
}
|
||||
|
||||
pub async fn stream_request(
|
||||
pub async fn connect(
|
||||
&self,
|
||||
request: ResponsesRequest,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
self.stream(request.body, request.headers, request.compression)
|
||||
.await
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<ResponsesWebsocketConnection, ApiError> {
|
||||
let ws_url = Url::parse(&self.provider.url_for_path("responses"))
|
||||
.map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?;
|
||||
|
||||
let mut headers = self.provider.headers.clone();
|
||||
headers.extend(extra_headers);
|
||||
apply_auth_headers(&mut headers, &self.auth);
|
||||
|
||||
let stream = connect_websocket(ws_url, headers).await?;
|
||||
Ok(ResponsesWebsocketConnection::new(
|
||||
stream,
|
||||
self.provider.stream_idle_timeout,
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn stream_prompt(
|
||||
|
|
@ -82,7 +157,8 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
|
|||
.compression(compression)
|
||||
.build(&self.provider)?;
|
||||
|
||||
self.stream_request(request).await
|
||||
let connection = self.connect(request.headers.clone()).await?;
|
||||
connection.stream_request(request).await
|
||||
}
|
||||
|
||||
pub async fn stream(
|
||||
|
|
@ -91,41 +167,13 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
|
|||
extra_headers: HeaderMap,
|
||||
compression: Compression,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
if compression == Compression::Zstd {
|
||||
warn!(
|
||||
"request compression is not supported for websocket streaming; sending uncompressed payload"
|
||||
);
|
||||
}
|
||||
|
||||
let ws_url = Url::parse(&self.provider.url_for_path("responses"))
|
||||
.map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?;
|
||||
let mut headers = self.provider.headers.clone();
|
||||
headers.extend(extra_headers);
|
||||
apply_auth_headers(&mut headers, &self.auth);
|
||||
|
||||
let connection = connect_websocket(ws_url, headers).await?;
|
||||
|
||||
let (tx_event, rx_event) =
|
||||
mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
|
||||
let idle_timeout = self.provider.stream_idle_timeout;
|
||||
|
||||
// TODO (pakrym): surface rate limits
|
||||
// TODO (pakrym): check models etags
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = run_websocket_response_stream(
|
||||
connection.stream,
|
||||
tx_event.clone(),
|
||||
body,
|
||||
idle_timeout,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let _ = tx_event.send(Err(err)).await;
|
||||
}
|
||||
});
|
||||
|
||||
Ok(ResponseStream { rx_event })
|
||||
let request = ResponsesRequest {
|
||||
body,
|
||||
headers: extra_headers,
|
||||
compression,
|
||||
};
|
||||
let connection = self.connect(request.headers.clone()).await?;
|
||||
connection.stream_request(request).await
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -143,11 +191,7 @@ fn apply_auth_headers(headers: &mut HeaderMap, auth: &impl AuthProvider) {
|
|||
}
|
||||
}
|
||||
|
||||
struct WebSocketConnection {
|
||||
stream: WsStream,
|
||||
}
|
||||
|
||||
async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WebSocketConnection, ApiError> {
|
||||
async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WsStream, ApiError> {
|
||||
let mut request = url
|
||||
.clone()
|
||||
.into_client_request()
|
||||
|
|
@ -157,7 +201,7 @@ async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WebSocketConn
|
|||
let (stream, _) = tokio_tungstenite::connect_async(request)
|
||||
.await
|
||||
.map_err(|err| map_ws_error(err, &url))?;
|
||||
Ok(WebSocketConnection { stream })
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn map_ws_error(err: WsError, url: &Url) -> ApiError {
|
||||
|
|
@ -185,7 +229,7 @@ fn map_ws_error(err: WsError, url: &Url) -> ApiError {
|
|||
}
|
||||
|
||||
async fn run_websocket_response_stream(
|
||||
mut ws_stream: WsStream,
|
||||
ws_stream: &mut WsStream,
|
||||
tx_event: mpsc::Sender<std::result::Result<ResponseEvent, ApiError>>,
|
||||
request_body: Value,
|
||||
idle_timeout: Duration,
|
||||
|
|
@ -193,7 +237,6 @@ async fn run_websocket_response_stream(
|
|||
let request_text = match serde_json::to_string(&request_body) {
|
||||
Ok(text) => text,
|
||||
Err(err) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(format!(
|
||||
"failed to encode websocket request: {err}"
|
||||
)));
|
||||
|
|
@ -201,7 +244,6 @@ async fn run_websocket_response_stream(
|
|||
};
|
||||
|
||||
if let Err(err) = ws_stream.send(Message::Text(request_text)).await {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(format!(
|
||||
"failed to send websocket request: {err}"
|
||||
)));
|
||||
|
|
@ -214,17 +256,14 @@ async fn run_websocket_response_stream(
|
|||
let message = match response {
|
||||
Ok(Some(Ok(msg))) => msg,
|
||||
Ok(Some(Err(err))) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(err.to_string()));
|
||||
}
|
||||
Ok(None) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
));
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
|
@ -249,24 +288,20 @@ async fn run_websocket_response_stream(
|
|||
}
|
||||
Ok(None) => {}
|
||||
Err(error) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(error.into_api_error());
|
||||
}
|
||||
}
|
||||
}
|
||||
Message::Binary(_) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream("unexpected binary websocket event".into()));
|
||||
}
|
||||
Message::Ping(payload) => {
|
||||
if ws_stream.send(Message::Pong(payload)).await.is_err() {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream("websocket ping failed".into()));
|
||||
}
|
||||
}
|
||||
Message::Pong(_) => {}
|
||||
Message::Close(_) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(
|
||||
"websocket closed before response.completed".into(),
|
||||
));
|
||||
|
|
@ -275,6 +310,5 @@ async fn run_websocket_response_stream(
|
|||
}
|
||||
}
|
||||
|
||||
let _ = ws_stream.close(None).await;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ pub use crate::endpoint::models::ModelsClient;
|
|||
pub use crate::endpoint::responses::ResponsesClient;
|
||||
pub use crate::endpoint::responses::ResponsesOptions;
|
||||
pub use crate::endpoint::responses_websocket::ResponsesWebsocketClient;
|
||||
pub use crate::endpoint::responses_websocket::ResponsesWebsocketConnection;
|
||||
pub use crate::error::ApiError;
|
||||
pub use crate::provider::Provider;
|
||||
pub use crate::provider::WireApi;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::api_bridge::CoreAuthProvider;
|
||||
use crate::api_bridge::auth_provider_from_auth;
|
||||
use crate::api_bridge::map_api_error;
|
||||
use crate::auth::UnauthorizedRecovery;
|
||||
|
|
@ -13,7 +14,10 @@ use codex_api::ReqwestTransport;
|
|||
use codex_api::ResponseStream as ApiResponseStream;
|
||||
use codex_api::ResponsesClient as ApiResponsesClient;
|
||||
use codex_api::ResponsesOptions as ApiResponsesOptions;
|
||||
use codex_api::ResponsesRequest;
|
||||
use codex_api::ResponsesRequestBuilder;
|
||||
use codex_api::ResponsesWebsocketClient as ApiWebSocketResponsesClient;
|
||||
use codex_api::ResponsesWebsocketConnection as ApiWebSocketConnection;
|
||||
use codex_api::SseTelemetry;
|
||||
use codex_api::TransportError;
|
||||
use codex_api::common::Reasoning;
|
||||
|
|
@ -76,9 +80,9 @@ pub struct ModelClient {
|
|||
state: Arc<ModelClientState>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelClientSession {
|
||||
state: Arc<ModelClientState>,
|
||||
connection: Option<ApiWebSocketConnection>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
|
|
@ -112,6 +116,7 @@ impl ModelClient {
|
|||
pub fn new_session(&self) -> ModelClientSession {
|
||||
ModelClientSession {
|
||||
state: Arc::clone(&self.state),
|
||||
connection: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -228,7 +233,7 @@ impl ModelClientSession {
|
|||
///
|
||||
/// For Chat providers, the underlying stream is optionally aggregated
|
||||
/// based on the `show_raw_agent_reasoning` flag in the config.
|
||||
pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
match self.state.provider.wire_api {
|
||||
WireApi::Responses => self.stream_responses_api(prompt).await,
|
||||
WireApi::ResponsesWebsocket => self.stream_responses_websocket(prompt).await,
|
||||
|
|
@ -315,6 +320,67 @@ impl ModelClientSession {
|
|||
}
|
||||
}
|
||||
|
||||
fn build_responses_websocket_request(
|
||||
&self,
|
||||
api_provider: &codex_api::Provider,
|
||||
api_prompt: &ApiPrompt,
|
||||
options: ApiResponsesOptions,
|
||||
) -> Result<ResponsesRequest> {
|
||||
let ApiResponsesOptions {
|
||||
reasoning,
|
||||
include,
|
||||
prompt_cache_key,
|
||||
text,
|
||||
store_override,
|
||||
conversation_id,
|
||||
session_source,
|
||||
extra_headers,
|
||||
compression,
|
||||
} = options;
|
||||
|
||||
ResponsesRequestBuilder::new(
|
||||
&self.state.model_info.slug,
|
||||
&api_prompt.instructions,
|
||||
&api_prompt.input,
|
||||
)
|
||||
.tools(&api_prompt.tools)
|
||||
.parallel_tool_calls(api_prompt.parallel_tool_calls)
|
||||
.reasoning(reasoning)
|
||||
.include(include)
|
||||
.prompt_cache_key(prompt_cache_key)
|
||||
.text(text)
|
||||
.conversation(conversation_id)
|
||||
.session_source(session_source)
|
||||
.store_override(store_override)
|
||||
.extra_headers(extra_headers)
|
||||
.compression(compression)
|
||||
.build(api_provider)
|
||||
.map_err(map_api_error)
|
||||
}
|
||||
|
||||
async fn websocket_connection(
|
||||
&mut self,
|
||||
api_provider: codex_api::Provider,
|
||||
api_auth: CoreAuthProvider,
|
||||
headers: ApiHeaderMap,
|
||||
) -> std::result::Result<&ApiWebSocketConnection, ApiError> {
|
||||
let needs_new = match self.connection.as_ref() {
|
||||
Some(conn) => conn.is_closed().await,
|
||||
None => true,
|
||||
};
|
||||
|
||||
if needs_new {
|
||||
let new_conn = ApiWebSocketResponsesClient::new(api_provider, api_auth)
|
||||
.connect(headers)
|
||||
.await?;
|
||||
self.connection = Some(new_conn);
|
||||
}
|
||||
|
||||
self.connection.as_ref().ok_or(ApiError::Stream(
|
||||
"websocket connection is unavailable".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn responses_request_compression(&self, auth: Option<&crate::auth::CodexAuth>) -> Compression {
|
||||
if self
|
||||
.state
|
||||
|
|
@ -447,7 +513,7 @@ impl ModelClientSession {
|
|||
}
|
||||
|
||||
/// Streams a turn via the Responses API over WebSocket transport.
|
||||
async fn stream_responses_websocket(&self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
async fn stream_responses_websocket(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
let auth_manager = self.state.auth_manager.clone();
|
||||
let api_prompt = self.build_responses_request(prompt)?;
|
||||
|
||||
|
|
@ -467,16 +533,18 @@ impl ModelClientSession {
|
|||
let compression = self.responses_request_compression(auth.as_ref());
|
||||
|
||||
let options = self.build_responses_options(prompt, compression);
|
||||
let client = ApiWebSocketResponsesClient::new(api_provider, api_auth);
|
||||
let request =
|
||||
self.build_responses_websocket_request(&api_provider, &api_prompt, options)?;
|
||||
|
||||
let stream_result = client
|
||||
.stream_prompt(&self.state.model_info.slug, &api_prompt, options)
|
||||
.await;
|
||||
|
||||
match stream_result {
|
||||
Ok(stream) => {
|
||||
return Ok(map_response_stream(stream, self.state.otel_manager.clone()));
|
||||
}
|
||||
let connection = match self
|
||||
.websocket_connection(
|
||||
api_provider.clone(),
|
||||
api_auth.clone(),
|
||||
request.headers.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(connection) => connection,
|
||||
Err(ApiError::Transport(TransportError::Http { status, .. }))
|
||||
if status == StatusCode::UNAUTHORIZED =>
|
||||
{
|
||||
|
|
@ -484,7 +552,17 @@ impl ModelClientSession {
|
|||
continue;
|
||||
}
|
||||
Err(err) => return Err(map_api_error(err)),
|
||||
}
|
||||
};
|
||||
|
||||
let stream_result = connection
|
||||
.stream_request(request)
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
|
||||
return Ok(map_response_stream(
|
||||
stream_result,
|
||||
self.state.otel_manager.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2673,7 +2673,7 @@ async fn run_model_turn(
|
|||
output_schema: turn_context.final_output_json_schema.clone(),
|
||||
};
|
||||
|
||||
let client_session = turn_context.client.new_session();
|
||||
let mut client_session = turn_context.client.new_session();
|
||||
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
|
|
@ -2681,7 +2681,7 @@ async fn run_model_turn(
|
|||
Arc::clone(&router),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
&client_session,
|
||||
&mut client_session,
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
&prompt,
|
||||
cancellation_token.child_token(),
|
||||
|
|
@ -2773,7 +2773,7 @@ async fn try_run_turn(
|
|||
router: Arc<ToolRouter>,
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
client_session: &ModelClientSession,
|
||||
client_session: &mut ModelClientSession,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
prompt: &Prompt,
|
||||
cancellation_token: CancellationToken,
|
||||
|
|
|
|||
|
|
@ -297,7 +297,7 @@ async fn drain_to_completed(
|
|||
turn_context: &TurnContext,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<()> {
|
||||
let client_session = turn_context.client.new_session();
|
||||
let mut client_session = turn_context.client.new_session();
|
||||
let mut stream = client_session.stream(prompt).await?;
|
||||
loop {
|
||||
let maybe_event = stream.next().await;
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
|
|||
SessionSource::Exec,
|
||||
);
|
||||
|
||||
let client = ModelClient::new(
|
||||
let mut client_session = ModelClient::new(
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_info,
|
||||
|
|
@ -104,7 +104,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
|
|||
let mut prompt = Prompt::default();
|
||||
prompt.input = input;
|
||||
|
||||
let mut stream = match client.stream(&prompt).await {
|
||||
let mut stream = match client_session.stream(&prompt).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => panic!("stream chat failed: {e}"),
|
||||
};
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
|
|||
SessionSource::Exec,
|
||||
);
|
||||
|
||||
let client = ModelClient::new(
|
||||
let mut client = ModelClient::new(
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_info,
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
|
|||
session_source.clone(),
|
||||
);
|
||||
|
||||
let client = ModelClient::new(
|
||||
let mut client_session = ModelClient::new(
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_info,
|
||||
|
|
@ -103,7 +103,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
|
|||
}],
|
||||
}];
|
||||
|
||||
let mut stream = client.stream(&prompt).await.expect("stream failed");
|
||||
let mut stream = client_session.stream(&prompt).await.expect("stream failed");
|
||||
while let Some(event) = stream.next().await {
|
||||
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
|
||||
break;
|
||||
|
|
@ -177,7 +177,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
|
|||
session_source.clone(),
|
||||
);
|
||||
|
||||
let client = ModelClient::new(
|
||||
let mut client_session = ModelClient::new(
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_info,
|
||||
|
|
@ -199,7 +199,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
|
|||
}],
|
||||
}];
|
||||
|
||||
let mut stream = client.stream(&prompt).await.expect("stream failed");
|
||||
let mut stream = client_session.stream(&prompt).await.expect("stream failed");
|
||||
while let Some(event) = stream.next().await {
|
||||
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
|
||||
break;
|
||||
|
|
@ -271,7 +271,7 @@ async fn responses_respects_model_info_overrides_from_config() {
|
|||
session_source.clone(),
|
||||
);
|
||||
|
||||
let client = ModelClient::new(
|
||||
let mut client = ModelClient::new(
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_info,
|
||||
|
|
|
|||
|
|
@ -1171,7 +1171,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
|
|||
SessionSource::Exec,
|
||||
);
|
||||
|
||||
let client = ModelClient::new(
|
||||
let mut client = ModelClient::new(
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_info,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
use codex_core::AuthManager;
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::ContentItem;
|
||||
use codex_core::ModelClient;
|
||||
use codex_core::ModelClientSession;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::Prompt;
|
||||
use codex_core::ResponseEvent;
|
||||
|
|
@ -11,23 +13,97 @@ use codex_core::models_manager::manager::ModelsManager;
|
|||
use codex_core::protocol::SessionSource;
|
||||
use codex_otel::OtelManager;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::responses::WebSocketTestServer;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::start_websocket_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use futures::StreamExt;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::sync::Arc;
|
||||
use tempfile::TempDir;
|
||||
|
||||
const MODEL: &str = "gpt-5.2-codex";
|
||||
|
||||
struct WebsocketTestHarness {
|
||||
_codex_home: TempDir,
|
||||
client: ModelClient,
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_streams_request() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_websocket_server(vec![vec![vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_completed("resp-1"),
|
||||
]]])
|
||||
.await;
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
let harness = websocket_harness(&server).await;
|
||||
let mut session = harness.client.new_session();
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
}];
|
||||
|
||||
stream_until_complete(&mut session, &prompt).await;
|
||||
|
||||
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 1);
|
||||
let body = connection.first().expect("missing request").body_json();
|
||||
|
||||
assert_eq!(body["model"].as_str(), Some(MODEL));
|
||||
assert_eq!(body["stream"], serde_json::Value::Bool(true));
|
||||
assert_eq!(body["input"].as_array().map(Vec::len), Some(1));
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_reuses_connection() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_websocket_server(vec![vec![
|
||||
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
|
||||
vec![ev_response_created("resp-2"), ev_completed("resp-2")],
|
||||
]])
|
||||
.await;
|
||||
|
||||
let harness = websocket_harness(&server).await;
|
||||
let mut session = harness.client.new_session();
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
}];
|
||||
|
||||
for _ in 0..2 {
|
||||
stream_until_complete(&mut session, &prompt).await;
|
||||
}
|
||||
|
||||
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 2);
|
||||
let body = connection.first().expect("missing request").body_json();
|
||||
|
||||
assert_eq!(body["model"].as_str(), Some(MODEL));
|
||||
assert_eq!(body["stream"], serde_json::Value::Bool(true));
|
||||
assert_eq!(body["input"].as_array().map(Vec::len), Some(1));
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
fn websocket_provider(server: &WebSocketTestServer) -> ModelProviderInfo {
|
||||
ModelProviderInfo {
|
||||
name: "mock-ws".into(),
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
env_key: None,
|
||||
|
|
@ -41,23 +117,21 @@ async fn responses_websocket_streams_request() {
|
|||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(5_000),
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async fn websocket_harness(server: &WebSocketTestServer) -> WebsocketTestHarness {
|
||||
let provider = websocket_provider(server);
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home).await;
|
||||
config.model_provider_id = provider.name.clone();
|
||||
config.model_provider = provider.clone();
|
||||
let effort = config.model_reasoning_effort;
|
||||
let summary = config.model_reasoning_summary;
|
||||
let model = ModelsManager::get_model_offline(config.model.as_deref());
|
||||
config.model = Some(model.clone());
|
||||
config.model = Some(MODEL.to_string());
|
||||
let config = Arc::new(config);
|
||||
let model_info = ModelsManager::construct_model_info_offline(model.as_str(), &config);
|
||||
let model_info = ModelsManager::construct_model_info_offline(MODEL, &config);
|
||||
let conversation_id = ThreadId::new();
|
||||
let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key"));
|
||||
let otel_manager = OtelManager::new(
|
||||
conversation_id,
|
||||
model.as_str(),
|
||||
MODEL,
|
||||
model_info.slug.as_str(),
|
||||
None,
|
||||
Some("test@test.com".to_string()),
|
||||
|
|
@ -66,31 +140,27 @@ async fn responses_websocket_streams_request() {
|
|||
"test".to_string(),
|
||||
SessionSource::Exec,
|
||||
);
|
||||
|
||||
let client = ModelClient::new(
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_info,
|
||||
otel_manager,
|
||||
provider,
|
||||
effort,
|
||||
summary,
|
||||
provider.clone(),
|
||||
None,
|
||||
ReasoningSummary::Auto,
|
||||
conversation_id,
|
||||
SessionSource::Exec,
|
||||
)
|
||||
.new_session();
|
||||
);
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
}];
|
||||
WebsocketTestHarness {
|
||||
_codex_home: codex_home,
|
||||
client,
|
||||
}
|
||||
}
|
||||
|
||||
let mut stream = client
|
||||
.stream(&prompt)
|
||||
async fn stream_until_complete(session: &mut ModelClientSession, prompt: &Prompt) {
|
||||
let mut stream = session
|
||||
.stream(prompt)
|
||||
.await
|
||||
.expect("websocket stream failed");
|
||||
|
||||
|
|
@ -99,14 +169,4 @@ async fn responses_websocket_streams_request() {
|
|||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 1);
|
||||
let request = connection.first().cloned().unwrap();
|
||||
let body = request.body_json();
|
||||
assert_eq!(body["model"].as_str(), Some(model.as_str()));
|
||||
assert_eq!(body["stream"], serde_json::Value::Bool(true));
|
||||
assert_eq!(body["input"].as_array().map(Vec::len), Some(1));
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue