diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index ab7df33fa..d883e55ba 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -984,8 +984,10 @@ dependencies = [ "thiserror 2.0.17", "tokio", "tokio-test", + "tokio-tungstenite", "tokio-util", "tracing", + "url", "wiremock", ] @@ -2126,6 +2128,7 @@ dependencies = [ "codex-protocol", "codex-utils-absolute-path", "codex-utils-cargo-bin", + "futures", "notify", "pretty_assertions", "regex-lite", @@ -2134,6 +2137,7 @@ dependencies = [ "shlex", "tempfile", "tokio", + "tokio-tungstenite", "walkdir", "wiremock", ] @@ -2361,6 +2365,12 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "data-encoding" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" + [[package]] name = "dbus" version = "0.9.9" @@ -7117,6 +7127,18 @@ dependencies = [ "tokio-stream", ] +[[package]] +name = "tokio-tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -7511,6 +7533,25 @@ dependencies = [ "ratatui-core", ] +[[package]] +name = "tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.3.1", + "httparse", + "log", + "rand 0.8.5", + "sha1", + "thiserror 1.0.69", + "url", + "utf-8", +] + [[package]] name = "typenum" version = "1.18.0" diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index a2d8b09b6..8809c0aca 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -209,6 +209,7 @@ tiny_http = "0.12" tokio = "1" tokio-stream = "0.1.18" tokio-test = "0.4" 
+tokio-tungstenite = "0.21.0" tokio-util = "0.7.18" toml = "0.9.5" toml_edit = "0.24.0" diff --git a/codex-rs/codex-api/Cargo.toml b/codex-rs/codex-api/Cargo.toml index e9fc78878..761e57236 100644 --- a/codex-rs/codex-api/Cargo.toml +++ b/codex-rs/codex-api/Cargo.toml @@ -14,11 +14,13 @@ http = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } thiserror = { workspace = true } -tokio = { workspace = true, features = ["macros", "rt", "sync", "time"] } +tokio = { workspace = true, features = ["macros", "net", "rt", "sync", "time"] } +tokio-tungstenite = { workspace = true } tracing = { workspace = true } eventsource-stream = { workspace = true } regex-lite = { workspace = true } tokio-util = { workspace = true, features = ["codec"] } +url = { workspace = true } [dev-dependencies] anyhow = { workspace = true } diff --git a/codex-rs/codex-api/src/endpoint/mod.rs b/codex-rs/codex-api/src/endpoint/mod.rs index cb0eeb9f2..2fa116c08 100644 --- a/codex-rs/codex-api/src/endpoint/mod.rs +++ b/codex-rs/codex-api/src/endpoint/mod.rs @@ -2,4 +2,5 @@ pub mod chat; pub mod compact; pub mod models; pub mod responses; +pub mod responses_websocket; mod streaming; diff --git a/codex-rs/codex-api/src/endpoint/responses_websocket.rs b/codex-rs/codex-api/src/endpoint/responses_websocket.rs new file mode 100644 index 000000000..bc64f3bfb --- /dev/null +++ b/codex-rs/codex-api/src/endpoint/responses_websocket.rs @@ -0,0 +1,280 @@ +use crate::auth::AuthProvider; +use crate::common::Prompt as ApiPrompt; +use crate::common::ResponseEvent; +use crate::common::ResponseStream; +use crate::endpoint::responses::ResponsesOptions; +use crate::error::ApiError; +use crate::provider::Provider; +use crate::requests::ResponsesRequest; +use crate::requests::ResponsesRequestBuilder; +use crate::requests::responses::Compression; +use crate::sse::responses::ResponsesStreamEvent; +use crate::sse::responses::process_responses_event; +use 
codex_client::TransportError; +use futures::SinkExt; +use futures::StreamExt; +use http::HeaderMap; +use http::HeaderValue; +use serde_json::Value; +use std::time::Duration; +use tokio::net::TcpStream; +use tokio::sync::mpsc; +use tokio_tungstenite::MaybeTlsStream; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::tungstenite::Error as WsError; +use tokio_tungstenite::tungstenite::Message; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tracing::debug; +use tracing::trace; +use tracing::warn; +use url::Url; + +type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>; + +pub struct ResponsesWebsocketClient<A> { + provider: Provider, + auth: A, +} + +impl<A: AuthProvider> ResponsesWebsocketClient<A> { + pub fn new(provider: Provider, auth: A) -> Self { + Self { provider, auth } + } + + pub async fn stream_request( + &self, + request: ResponsesRequest, + ) -> Result<ResponseStream, ApiError> { + self.stream(request.body, request.headers, request.compression) + .await + } + + pub async fn stream_prompt( + &self, + model: &str, + prompt: &ApiPrompt, + options: ResponsesOptions, + ) -> Result<ResponseStream, ApiError> { + let ResponsesOptions { + reasoning, + include, + prompt_cache_key, + text, + store_override, + conversation_id, + session_source, + extra_headers, + compression, + } = options; + + // TODO (pakrym): share with HTTP based Responses API client + let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input) + .tools(&prompt.tools) + .parallel_tool_calls(prompt.parallel_tool_calls) + .reasoning(reasoning) + .include(include) + .prompt_cache_key(prompt_cache_key) + .text(text) + .conversation(conversation_id) + .session_source(session_source) + .store_override(store_override) + .extra_headers(extra_headers) + .compression(compression) + .build(&self.provider)?; + + self.stream_request(request).await + } + + pub async fn stream( + &self, + body: Value, + extra_headers: HeaderMap, + compression: Compression, + ) -> Result<ResponseStream, ApiError> { + if compression == Compression::Zstd { + warn!( + "request compression is 
not supported for websocket streaming; sending uncompressed payload" + ); + } + + let ws_url = Url::parse(&self.provider.url_for_path("responses")) + .map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?; + let mut headers = self.provider.headers.clone(); + headers.extend(extra_headers); + apply_auth_headers(&mut headers, &self.auth); + + let connection = connect_websocket(ws_url, headers).await?; + + let (tx_event, rx_event) = + mpsc::channel::<Result<ResponseEvent, ApiError>>(1600); + let idle_timeout = self.provider.stream_idle_timeout; + + // TODO (pakrym): surface rate limits + // TODO (pakrym): check models etags + + tokio::spawn(async move { + if let Err(err) = run_websocket_response_stream( + connection.stream, + tx_event.clone(), + body, + idle_timeout, + ) + .await + { + let _ = tx_event.send(Err(err)).await; + } + }); + + Ok(ResponseStream { rx_event }) + } +} + +// TODO (pakrym): share with /auth +fn apply_auth_headers(headers: &mut HeaderMap, auth: &impl AuthProvider) { + if let Some(token) = auth.bearer_token() + && let Ok(header) = HeaderValue::from_str(&format!("Bearer {token}")) + { + let _ = headers.insert(http::header::AUTHORIZATION, header); + } + if let Some(account_id) = auth.account_id() + && let Ok(header) = HeaderValue::from_str(&account_id) + { + let _ = headers.insert("ChatGPT-Account-ID", header); + } +} + +struct WebSocketConnection { + stream: WsStream, +} + +async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WebSocketConnection, ApiError> { + let mut request = url + .clone() + .into_client_request() + .map_err(|err| ApiError::Stream(format!("failed to build websocket request: {err}")))?; + request.headers_mut().extend(headers); + + let (stream, _) = tokio_tungstenite::connect_async(request) + .await + .map_err(|err| map_ws_error(err, &url))?; + Ok(WebSocketConnection { stream }) +} + +fn map_ws_error(err: WsError, url: &Url) -> ApiError { + match err { + WsError::Http(response) => { + let status = response.status(); + let headers = 
response.headers().clone(); + let body = response + .body() + .as_ref() + .and_then(|bytes| String::from_utf8(bytes.clone()).ok()); + ApiError::Transport(TransportError::Http { + status, + url: Some(url.to_string()), + headers: Some(headers), + body, + }) + } + WsError::ConnectionClosed | WsError::AlreadyClosed => { + ApiError::Stream("websocket closed".to_string()) + } + WsError::Io(err) => ApiError::Transport(TransportError::Network(err.to_string())), + other => ApiError::Transport(TransportError::Network(other.to_string())), + } +} + +async fn run_websocket_response_stream( + mut ws_stream: WsStream, + tx_event: mpsc::Sender<Result<ResponseEvent, ApiError>>, + request_body: Value, + idle_timeout: Duration, +) -> Result<(), ApiError> { + let request_text = match serde_json::to_string(&request_body) { + Ok(text) => text, + Err(err) => { + let _ = ws_stream.close(None).await; + return Err(ApiError::Stream(format!( + "failed to encode websocket request: {err}" + ))); + } + }; + + if let Err(err) = ws_stream.send(Message::Text(request_text)).await { + let _ = ws_stream.close(None).await; + return Err(ApiError::Stream(format!( + "failed to send websocket request: {err}" + ))); + } + + loop { + let response = tokio::time::timeout(idle_timeout, ws_stream.next()) + .await + .map_err(|_| ApiError::Stream("idle timeout waiting for websocket".into())); + let message = match response { + Ok(Some(Ok(msg))) => msg, + Ok(Some(Err(err))) => { + let _ = ws_stream.close(None).await; + return Err(ApiError::Stream(err.to_string())); + } + Ok(None) => { + let _ = ws_stream.close(None).await; + return Err(ApiError::Stream( + "stream closed before response.completed".into(), + )); + } + Err(err) => { + let _ = ws_stream.close(None).await; + return Err(err); + } + }; + + match message { + Message::Text(text) => { + trace!("websocket event: {text}"); + let event = match serde_json::from_str::<ResponsesStreamEvent>(&text) { + Ok(event) => event, + Err(err) => { + debug!("failed to parse websocket event: {err}, data: {text}"); + continue; + } 
+ }; + match process_responses_event(event) { + Ok(Some(event)) => { + let is_completed = matches!(event, ResponseEvent::Completed { .. }); + let _ = tx_event.send(Ok(event)).await; + if is_completed { + break; + } + } + Ok(None) => {} + Err(error) => { + let _ = ws_stream.close(None).await; + return Err(error.into_api_error()); + } + } + } + Message::Binary(_) => { + let _ = ws_stream.close(None).await; + return Err(ApiError::Stream("unexpected binary websocket event".into())); + } + Message::Ping(payload) => { + if ws_stream.send(Message::Pong(payload)).await.is_err() { + let _ = ws_stream.close(None).await; + return Err(ApiError::Stream("websocket ping failed".into())); + } + } + Message::Pong(_) => {} + Message::Close(_) => { + let _ = ws_stream.close(None).await; + return Err(ApiError::Stream( + "websocket closed before response.completed".into(), + )); + } + _ => {} + } + } + + let _ = ws_stream.close(None).await; + Ok(()) +} diff --git a/codex-rs/codex-api/src/lib.rs b/codex-rs/codex-api/src/lib.rs index d0c382ac8..4e82b874b 100644 --- a/codex-rs/codex-api/src/lib.rs +++ b/codex-rs/codex-api/src/lib.rs @@ -25,6 +25,7 @@ pub use crate::endpoint::compact::CompactClient; pub use crate::endpoint::models::ModelsClient; pub use crate::endpoint::responses::ResponsesClient; pub use crate::endpoint::responses::ResponsesOptions; +pub use crate::endpoint::responses_websocket::ResponsesWebsocketClient; pub use crate::error::ApiError; pub use crate::provider::Provider; pub use crate::provider::WireApi; diff --git a/codex-rs/codex-api/src/sse/responses.rs b/codex-rs/codex-api/src/sse/responses.rs index 5a1ab832e..f279ba5ed 100644 --- a/codex-rs/codex-api/src/sse/responses.rs +++ b/codex-rs/codex-api/src/sse/responses.rs @@ -126,7 +126,7 @@ struct ResponseCompletedOutputTokensDetails { } #[derive(Deserialize, Debug)] -struct ResponsesStreamEvent { +pub struct ResponsesStreamEvent { #[serde(rename = "type")] kind: String, response: Option, @@ -149,7 +149,7 @@ impl 
ResponsesEventError { } } -fn process_responses_event( +pub fn process_responses_event( event: ResponsesStreamEvent, ) -> std::result::Result<Option<ResponseEvent>, ResponsesEventError> { match event.kind.as_str() { diff --git a/codex-rs/core/src/client.rs index bec015b4c..eb866527c 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -13,6 +13,7 @@ use codex_api::ReqwestTransport; use codex_api::ResponseStream as ApiResponseStream; use codex_api::ResponsesClient as ApiResponsesClient; use codex_api::ResponsesOptions as ApiResponsesOptions; +use codex_api::ResponsesWebsocketClient as ApiWebSocketResponsesClient; use codex_api::SseTelemetry; use codex_api::TransportError; use codex_api::common::Reasoning; @@ -57,8 +58,8 @@ use crate::model_provider_info::WireApi; use crate::tools::spec::create_tools_json_for_chat_completions_api; use crate::tools::spec::create_tools_json_for_responses_api; -#[derive(Debug, Clone)] -pub struct ModelClient { +#[derive(Debug)] +struct ModelClientState { config: Arc<Config>, auth_manager: Option<Arc<AuthManager>>, model_info: ModelInfo, @@ -70,6 +71,16 @@ pub struct ModelClient { session_source: SessionSource, } +#[derive(Debug, Clone)] +pub struct ModelClient { + state: Arc<ModelClientState>, +} + +#[derive(Debug, Clone)] +pub struct ModelClientSession { + state: Arc<ModelClientState>, +} + #[allow(clippy::too_many_arguments)] impl ModelClient { pub fn new( @@ -84,20 +95,30 @@ impl ModelClient { session_source: SessionSource, ) -> Self { Self { - config, - auth_manager, - model_info, - otel_manager, - provider, - conversation_id, - effort, - summary, - session_source, + state: Arc::new(ModelClientState { + config, + auth_manager, + model_info, + otel_manager, + provider, + conversation_id, + effort, + summary, + session_source, + }), } } + pub fn new_session(&self) -> ModelClientSession { + ModelClientSession { + state: Arc::clone(&self.state), + } + } +} + +impl ModelClient { pub fn get_model_context_window(&self) -> Option<u64> { - let model_info = 
self.get_model_info(); + let model_info = &self.state.model_info; let effective_context_window_percent = model_info.effective_context_window_percent; model_info.context_window.map(|context_window| { context_window.saturating_mul(effective_context_window_percent) / 100 @@ -105,39 +126,210 @@ impl ModelClient { } pub fn config(&self) -> Arc { - Arc::clone(&self.config) + Arc::clone(&self.state.config) } pub fn provider(&self) -> &ModelProviderInfo { - &self.provider + &self.state.provider } + pub fn get_provider(&self) -> ModelProviderInfo { + self.state.provider.clone() + } + + pub fn get_otel_manager(&self) -> OtelManager { + self.state.otel_manager.clone() + } + + pub fn get_session_source(&self) -> SessionSource { + self.state.session_source.clone() + } + + /// Returns the currently configured model slug. + pub fn get_model(&self) -> String { + self.state.model_info.slug.clone() + } + + pub fn get_model_info(&self) -> ModelInfo { + self.state.model_info.clone() + } + + /// Returns the current reasoning effort setting. + pub fn get_reasoning_effort(&self) -> Option { + self.state.effort + } + + /// Returns the current reasoning summary setting. + pub fn get_reasoning_summary(&self) -> ReasoningSummaryConfig { + self.state.summary + } + + pub fn get_auth_manager(&self) -> Option> { + self.state.auth_manager.clone() + } + + /// Compacts the current conversation history using the Compact endpoint. + /// + /// This is a unary call (no streaming) that returns a new list of + /// `ResponseItem`s representing the compacted transcript. 
+ pub async fn compact_conversation_history(&self, prompt: &Prompt) -> Result> { + if prompt.input.is_empty() { + return Ok(Vec::new()); + } + let auth_manager = self.state.auth_manager.clone(); + let auth = match auth_manager.as_ref() { + Some(manager) => manager.auth().await, + None => None, + }; + let api_provider = self + .state + .provider + .to_api_provider(auth.as_ref().map(|a| a.mode))?; + let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?; + let transport = ReqwestTransport::new(build_reqwest_client()); + let request_telemetry = self.build_request_telemetry(); + let client = ApiCompactClient::new(transport, api_provider, api_auth) + .with_telemetry(Some(request_telemetry)); + + let instructions = prompt + .get_full_instructions(&self.state.model_info) + .into_owned(); + let payload = ApiCompactionInput { + model: &self.state.model_info.slug, + input: &prompt.input, + instructions: &instructions, + }; + + let mut extra_headers = ApiHeaderMap::new(); + if let SessionSource::SubAgent(sub) = &self.state.session_source { + let subagent = if let crate::protocol::SubAgentSource::Other(label) = sub { + label.clone() + } else { + serde_json::to_value(sub) + .ok() + .and_then(|v| v.as_str().map(std::string::ToString::to_string)) + .unwrap_or_else(|| "other".to_string()) + }; + if let Ok(val) = HeaderValue::from_str(&subagent) { + extra_headers.insert("x-openai-subagent", val); + } + } + + client + .compact_input(&payload, extra_headers) + .await + .map_err(map_api_error) + } +} + +impl ModelClientSession { /// Streams a single model turn using either the Responses or Chat /// Completions wire API, depending on the configured provider. /// /// For Chat providers, the underlying stream is optionally aggregated /// based on the `show_raw_agent_reasoning` flag in the config. 
pub async fn stream(&self, prompt: &Prompt) -> Result { - match self.provider.wire_api { + match self.state.provider.wire_api { WireApi::Responses => self.stream_responses_api(prompt).await, + WireApi::ResponsesWebsocket => self.stream_responses_websocket(prompt).await, WireApi::Chat => { let api_stream = self.stream_chat_completions(prompt).await?; - if self.config.show_raw_agent_reasoning { + if self.state.config.show_raw_agent_reasoning { Ok(map_response_stream( api_stream.streaming_mode(), - self.otel_manager.clone(), + self.state.otel_manager.clone(), )) } else { Ok(map_response_stream( api_stream.aggregate(), - self.otel_manager.clone(), + self.state.otel_manager.clone(), )) } } } } + fn build_responses_request(&self, prompt: &Prompt) -> Result { + let model_info = self.state.model_info.clone(); + let instructions = prompt.get_full_instructions(&model_info).into_owned(); + let tools_json: Vec = create_tools_json_for_responses_api(&prompt.tools)?; + Ok(build_api_prompt(prompt, instructions, tools_json)) + } + + fn build_responses_options( + &self, + prompt: &Prompt, + compression: Compression, + ) -> ApiResponsesOptions { + let model_info = &self.state.model_info; + + let default_reasoning_effort = model_info.default_reasoning_level; + let reasoning = if model_info.supports_reasoning_summaries { + Some(Reasoning { + effort: self.state.effort.or(default_reasoning_effort), + summary: if self.state.summary == ReasoningSummaryConfig::None { + None + } else { + Some(self.state.summary) + }, + }) + } else { + None + }; + + let include = if reasoning.is_some() { + vec!["reasoning.encrypted_content".to_string()] + } else { + Vec::new() + }; + + let verbosity = if model_info.support_verbosity { + self.state + .config + .model_verbosity + .or(model_info.default_verbosity) + } else { + if self.state.config.model_verbosity.is_some() { + warn!( + "model_verbosity is set but ignored as the model does not support verbosity: {}", + model_info.slug + ); + } + None + }; + + let 
text = create_text_param_for_request(verbosity, &prompt.output_schema); + let conversation_id = self.state.conversation_id.to_string(); + + ApiResponsesOptions { + reasoning, + include, + prompt_cache_key: Some(conversation_id.clone()), + text, + store_override: None, + conversation_id: Some(conversation_id), + session_source: Some(self.state.session_source.clone()), + extra_headers: beta_feature_headers(&self.state.config), + compression, + } + } + + fn responses_request_compression(&self, auth: Option<&crate::auth::CodexAuth>) -> Compression { + if self + .state + .config + .features + .enabled(Feature::EnableRequestCompression) + && auth.is_some_and(|auth| auth.mode == AuthMode::ChatGPT) + && self.state.provider.is_openai() + { + Compression::Zstd + } else { + Compression::None + } + } + /// Streams a turn via the OpenAI Chat Completions API. /// /// This path is only used when the provider is configured with @@ -149,13 +341,13 @@ impl ModelClient { )); } - let auth_manager = self.auth_manager.clone(); - let model_info = self.get_model_info(); + let auth_manager = self.state.auth_manager.clone(); + let model_info = self.state.model_info.clone(); let instructions = prompt.get_full_instructions(&model_info).into_owned(); let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?; let api_prompt = build_api_prompt(prompt, instructions, tools_json); - let conversation_id = self.conversation_id.to_string(); - let session_source = self.session_source.clone(); + let conversation_id = self.state.conversation_id.to_string(); + let session_source = self.state.session_source.clone(); let mut auth_recovery = auth_manager .as_ref() @@ -166,9 +358,10 @@ impl ModelClient { None => None, }; let api_provider = self + .state .provider .to_api_provider(auth.as_ref().map(|a| a.mode))?; - let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?; + let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?; let transport = 
ReqwestTransport::new(build_reqwest_client()); let (request_telemetry, sse_telemetry) = self.build_streaming_telemetry(); let client = ApiChatClient::new(transport, api_provider, api_auth) @@ -176,7 +369,7 @@ impl ModelClient { let stream_result = client .stream_prompt( - &self.get_model(), + &self.state.model_info.slug, &api_prompt, Some(conversation_id.clone()), Some(session_source.clone()), @@ -203,52 +396,14 @@ impl ModelClient { async fn stream_responses_api(&self, prompt: &Prompt) -> Result { if let Some(path) = &*CODEX_RS_SSE_FIXTURE { warn!(path, "Streaming from fixture"); - let stream = codex_api::stream_from_fixture(path, self.provider.stream_idle_timeout()) - .map_err(map_api_error)?; - return Ok(map_response_stream(stream, self.otel_manager.clone())); + let stream = + codex_api::stream_from_fixture(path, self.state.provider.stream_idle_timeout()) + .map_err(map_api_error)?; + return Ok(map_response_stream(stream, self.state.otel_manager.clone())); } - let auth_manager = self.auth_manager.clone(); - let model_info = self.get_model_info(); - let instructions = prompt.get_full_instructions(&model_info).into_owned(); - let tools_json: Vec = create_tools_json_for_responses_api(&prompt.tools)?; - - let default_reasoning_effort = model_info.default_reasoning_level; - let reasoning = if model_info.supports_reasoning_summaries { - Some(Reasoning { - effort: self.effort.or(default_reasoning_effort), - summary: if self.summary == ReasoningSummaryConfig::None { - None - } else { - Some(self.summary) - }, - }) - } else { - None - }; - - let include: Vec = if reasoning.is_some() { - vec!["reasoning.encrypted_content".to_string()] - } else { - vec![] - }; - - let verbosity = if model_info.support_verbosity { - self.config.model_verbosity.or(model_info.default_verbosity) - } else { - if self.config.model_verbosity.is_some() { - warn!( - "model_verbosity is set but ignored as the model does not support verbosity: {}", - model_info.slug - ); - } - None - }; - - let text 
= create_text_param_for_request(verbosity, &prompt.output_schema); - let api_prompt = build_api_prompt(prompt, instructions.clone(), tools_json); - let conversation_id = self.conversation_id.to_string(); - let session_source = self.session_source.clone(); + let auth_manager = self.state.auth_manager.clone(); + let api_prompt = self.build_responses_request(prompt)?; let mut auth_recovery = auth_manager .as_ref() @@ -259,47 +414,26 @@ impl ModelClient { None => None, }; let api_provider = self + .state .provider .to_api_provider(auth.as_ref().map(|a| a.mode))?; - let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?; + let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?; let transport = ReqwestTransport::new(build_reqwest_client()); let (request_telemetry, sse_telemetry) = self.build_streaming_telemetry(); - let compression = if self - .config - .features - .enabled(Feature::EnableRequestCompression) - && auth - .as_ref() - .is_some_and(|auth| auth.mode == AuthMode::ChatGPT) - && self.provider.is_openai() - { - Compression::Zstd - } else { - Compression::None - }; + let compression = self.responses_request_compression(auth.as_ref()); let client = ApiResponsesClient::new(transport, api_provider, api_auth) .with_telemetry(Some(request_telemetry), Some(sse_telemetry)); - let options = ApiResponsesOptions { - reasoning: reasoning.clone(), - include: include.clone(), - prompt_cache_key: Some(conversation_id.clone()), - text: text.clone(), - store_override: None, - conversation_id: Some(conversation_id.clone()), - session_source: Some(session_source.clone()), - extra_headers: beta_feature_headers(&self.config), - compression, - }; + let options = self.build_responses_options(prompt, compression); let stream_result = client - .stream_prompt(&self.get_model(), &api_prompt, options) + .stream_prompt(&self.state.model_info.slug, &api_prompt, options) .await; match stream_result { Ok(stream) => { - return Ok(map_response_stream(stream, 
self.otel_manager.clone())); + return Ok(map_response_stream(stream, self.state.otel_manager.clone())); } Err(ApiError::Transport(TransportError::Http { status, .. })) if status == StatusCode::UNAUTHORIZED => @@ -312,106 +446,61 @@ impl ModelClient { } } - pub fn get_provider(&self) -> ModelProviderInfo { - self.provider.clone() - } + /// Streams a turn via the Responses API over WebSocket transport. + async fn stream_responses_websocket(&self, prompt: &Prompt) -> Result { + let auth_manager = self.state.auth_manager.clone(); + let api_prompt = self.build_responses_request(prompt)?; - pub fn get_otel_manager(&self) -> OtelManager { - self.otel_manager.clone() - } - - pub fn get_session_source(&self) -> SessionSource { - self.session_source.clone() - } - - /// Returns the currently configured model slug. - pub fn get_model(&self) -> String { - self.model_info.slug.clone() - } - - pub fn get_model_info(&self) -> ModelInfo { - self.model_info.clone() - } - - /// Returns the current reasoning effort setting. - pub fn get_reasoning_effort(&self) -> Option { - self.effort - } - - /// Returns the current reasoning summary setting. - pub fn get_reasoning_summary(&self) -> ReasoningSummaryConfig { - self.summary - } - - pub fn get_auth_manager(&self) -> Option> { - self.auth_manager.clone() - } - - /// Compacts the current conversation history using the Compact endpoint. - /// - /// This is a unary call (no streaming) that returns a new list of - /// `ResponseItem`s representing the compacted transcript. 
- pub async fn compact_conversation_history(&self, prompt: &Prompt) -> Result> { - if prompt.input.is_empty() { - return Ok(Vec::new()); - } - let auth_manager = self.auth_manager.clone(); - let auth = match auth_manager.as_ref() { - Some(manager) => manager.auth().await, - None => None, - }; - let api_provider = self - .provider - .to_api_provider(auth.as_ref().map(|a| a.mode))?; - let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?; - let transport = ReqwestTransport::new(build_reqwest_client()); - let request_telemetry = self.build_request_telemetry(); - let client = ApiCompactClient::new(transport, api_provider, api_auth) - .with_telemetry(Some(request_telemetry)); - - let instructions = prompt - .get_full_instructions(&self.get_model_info()) - .into_owned(); - let payload = ApiCompactionInput { - model: &self.get_model(), - input: &prompt.input, - instructions: &instructions, - }; - - let mut extra_headers = ApiHeaderMap::new(); - if let SessionSource::SubAgent(sub) = &self.session_source { - let subagent = if let crate::protocol::SubAgentSource::Other(label) = sub { - label.clone() - } else { - serde_json::to_value(sub) - .ok() - .and_then(|v| v.as_str().map(std::string::ToString::to_string)) - .unwrap_or_else(|| "other".to_string()) + let mut auth_recovery = auth_manager + .as_ref() + .map(super::auth::AuthManager::unauthorized_recovery); + loop { + let auth = match auth_manager.as_ref() { + Some(manager) => manager.auth().await, + None => None, }; - if let Ok(val) = HeaderValue::from_str(&subagent) { - extra_headers.insert("x-openai-subagent", val); + let api_provider = self + .state + .provider + .to_api_provider(auth.as_ref().map(|a| a.mode))?; + let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?; + let compression = self.responses_request_compression(auth.as_ref()); + + let options = self.build_responses_options(prompt, compression); + let client = ApiWebSocketResponsesClient::new(api_provider, api_auth); + + 
let stream_result = client + .stream_prompt(&self.state.model_info.slug, &api_prompt, options) + .await; + + match stream_result { + Ok(stream) => { + return Ok(map_response_stream(stream, self.state.otel_manager.clone())); + } + Err(ApiError::Transport(TransportError::Http { status, .. })) + if status == StatusCode::UNAUTHORIZED => + { + handle_unauthorized(status, &mut auth_recovery).await?; + continue; + } + Err(err) => return Err(map_api_error(err)), } } - - client - .compact_input(&payload, extra_headers) - .await - .map_err(map_api_error) } -} -impl ModelClient { /// Builds request and SSE telemetry for streaming API calls (Chat/Responses). fn build_streaming_telemetry(&self) -> (Arc, Arc) { - let telemetry = Arc::new(ApiTelemetry::new(self.otel_manager.clone())); + let telemetry = Arc::new(ApiTelemetry::new(self.state.otel_manager.clone())); let request_telemetry: Arc = telemetry.clone(); let sse_telemetry: Arc = telemetry; (request_telemetry, sse_telemetry) } +} +impl ModelClient { /// Builds request telemetry for unary API calls (e.g., Compact endpoint). 
fn build_request_telemetry(&self) -> Arc { - let telemetry = Arc::new(ApiTelemetry::new(self.otel_manager.clone())); + let telemetry = Arc::new(ApiTelemetry::new(self.state.otel_manager.clone())); let request_telemetry: Arc = telemetry; request_telemetry } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 2fd7e84f4..38687cc48 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -78,6 +78,7 @@ use tracing::warn; use crate::ModelProviderInfo; use crate::WireApi; use crate::client::ModelClient; +use crate::client::ModelClientSession; use crate::client_common::Prompt; use crate::client_common::ResponseEvent; use crate::compact::collect_user_messages; @@ -2672,12 +2673,15 @@ async fn run_model_turn( output_schema: turn_context.final_output_json_schema.clone(), }; + let client_session = turn_context.client.new_session(); + let mut retries = 0; loop { let err = match try_run_turn( Arc::clone(&router), Arc::clone(&sess), Arc::clone(&turn_context), + &client_session, Arc::clone(&turn_diff_tracker), &prompt, cancellation_token.child_token(), @@ -2769,6 +2773,7 @@ async fn try_run_turn( router: Arc, sess: Arc, turn_context: Arc, + client_session: &ModelClientSession, turn_diff_tracker: SharedTurnDiffTracker, prompt: &Prompt, cancellation_token: CancellationToken, @@ -2797,9 +2802,7 @@ async fn try_run_turn( ); sess.persist_rollout_items(&[rollout_item]).await; - let mut stream = turn_context - .client - .clone() + let mut stream = client_session .stream(prompt) .instrument(trace_span!("stream_request")) .or_cancel(&cancellation_token) diff --git a/codex-rs/core/src/compact.rs b/codex-rs/core/src/compact.rs index c8509cc5c..2a518dfeb 100644 --- a/codex-rs/core/src/compact.rs +++ b/codex-rs/core/src/compact.rs @@ -297,7 +297,8 @@ async fn drain_to_completed( turn_context: &TurnContext, prompt: &Prompt, ) -> CodexResult<()> { - let mut stream = turn_context.client.clone().stream(prompt).await?; + let client_session = 
turn_context.client.new_session(); + let mut stream = client_session.stream(prompt).await?; loop { let maybe_event = stream.next().await; let Some(event) = maybe_event else { diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 1fb25ebc1..c4a5ef92b 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -126,6 +126,7 @@ pub use codex_protocol::protocol; pub use codex_protocol::config_types as protocol_config_types; pub use client::ModelClient; +pub use client::ModelClientSession; pub use client_common::Prompt; pub use client_common::REVIEW_PROMPT; pub use client_common::ResponseEvent; diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs index 961739223..c5be0fc8d 100644 --- a/codex-rs/core/src/model_provider_info.rs +++ b/codex-rs/core/src/model_provider_info.rs @@ -42,6 +42,10 @@ pub enum WireApi { /// The Responses API exposed by OpenAI at `/v1/responses`. Responses, + /// Experimental: Responses API over WebSocket transport. + #[serde(rename = "responses_websocket")] + ResponsesWebsocket, + /// Regular Chat Completions compatible with `/v1/chat/completions`. 
#[default] Chat, @@ -156,6 +160,7 @@ impl ModelProviderInfo { query_params: self.query_params.clone(), wire: match self.wire_api { WireApi::Responses => ApiWireApi::Responses, + WireApi::ResponsesWebsocket => ApiWireApi::Responses, WireApi::Chat => ApiWireApi::Chat, }, headers, diff --git a/codex-rs/core/tests/chat_completions_payload.rs b/codex-rs/core/tests/chat_completions_payload.rs index 54d13367a..c8fef336e 100644 --- a/codex-rs/core/tests/chat_completions_payload.rs +++ b/codex-rs/core/tests/chat_completions_payload.rs @@ -98,7 +98,8 @@ async fn run_request(input: Vec) -> Value { summary, conversation_id, SessionSource::Exec, - ); + ) + .new_session(); let mut prompt = Prompt::default(); prompt.input = input; diff --git a/codex-rs/core/tests/chat_completions_sse.rs b/codex-rs/core/tests/chat_completions_sse.rs index 65b1f229b..157475580 100644 --- a/codex-rs/core/tests/chat_completions_sse.rs +++ b/codex-rs/core/tests/chat_completions_sse.rs @@ -99,7 +99,8 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec { summary, conversation_id, SessionSource::Exec, - ); + ) + .new_session(); let mut prompt = Prompt::default(); prompt.input = vec![ResponseItem::Message { diff --git a/codex-rs/core/tests/common/Cargo.toml b/codex-rs/core/tests/common/Cargo.toml index c61a09568..8e9f53943 100644 --- a/codex-rs/core/tests/common/Cargo.toml +++ b/codex-rs/core/tests/common/Cargo.toml @@ -15,11 +15,13 @@ codex-core = { workspace = true, features = ["test-support"] } codex-protocol = { workspace = true } codex-utils-absolute-path = { workspace = true } codex-utils-cargo-bin = { workspace = true } +futures = { workspace = true } notify = { workspace = true } regex-lite = { workspace = true } serde_json = { workspace = true } tempfile = { workspace = true } -tokio = { workspace = true, features = ["time"] } +tokio = { workspace = true, features = ["net", "time"] } +tokio-tungstenite = { workspace = true } walkdir = { workspace = true } wiremock = { workspace = true } 
shlex = { workspace = true } diff --git a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs index 710d03fc7..552966e79 100644 --- a/codex-rs/core/tests/common/responses.rs +++ b/codex-rs/core/tests/common/responses.rs @@ -1,3 +1,4 @@ +use std::collections::VecDeque; use std::sync::Arc; use std::sync::Mutex; use std::time::Duration; @@ -5,7 +6,12 @@ use std::time::Duration; use anyhow::Result; use base64::Engine; use codex_protocol::openai_models::ModelsResponse; +use futures::SinkExt; +use futures::StreamExt; use serde_json::Value; +use tokio::net::TcpListener; +use tokio::sync::oneshot; +use tokio_tungstenite::tungstenite::Message; use wiremock::BodyPrintLimit; use wiremock::Match; use wiremock::Mock; @@ -199,6 +205,47 @@ impl ResponsesRequest { } } +#[derive(Debug, Clone)] +pub struct WebSocketRequest { + body: Value, +} + +impl WebSocketRequest { + pub fn body_json(&self) -> Value { + self.body.clone() + } +} + +pub struct WebSocketTestServer { + uri: String, + connections: Arc>>>, + shutdown: oneshot::Sender<()>, + task: tokio::task::JoinHandle<()>, +} + +impl WebSocketTestServer { + pub fn uri(&self) -> &str { + &self.uri + } + + pub fn connections(&self) -> Vec> { + self.connections.lock().unwrap().clone() + } + + pub fn single_connection(&self) -> Vec { + let connections = self.connections.lock().unwrap(); + if connections.len() != 1 { + panic!("expected 1 connection, got {}", connections.len()); + } + connections.first().cloned().unwrap_or_default() + } + + pub async fn shutdown(self) { + let _ = self.shutdown.send(()); + let _ = self.task.await; + } +} + #[derive(Debug, Clone)] pub struct ModelsMock { requests: Arc>>, @@ -724,6 +771,91 @@ pub async fn start_mock_server() -> MockServer { server } +/// Starts a lightweight WebSocket server for `/v1/responses` tests. +/// +/// Each connection consumes a queue of request/event sequences. 
For each +/// request message, the server records the payload and streams the matching +/// events as WebSocket text frames before moving to the next request. +pub async fn start_websocket_server(connections: Vec>>) -> WebSocketTestServer { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind websocket server"); + let addr = listener.local_addr().expect("websocket server address"); + let uri = format!("ws://{addr}"); + let connections_log = Arc::new(Mutex::new(Vec::new())); + let requests = Arc::clone(&connections_log); + let connections = Arc::new(Mutex::new(VecDeque::from(connections))); + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + + let task = tokio::spawn(async move { + loop { + let accept_res = tokio::select! { + _ = &mut shutdown_rx => return, + accept_res = listener.accept() => accept_res, + }; + let (stream, _) = match accept_res { + Ok(value) => value, + Err(_) => return, + }; + let mut ws_stream = match tokio_tungstenite::accept_async(stream).await { + Ok(ws) => ws, + Err(_) => continue, + }; + + let connection_requests = { + let mut pending = connections.lock().unwrap(); + pending.pop_front() + }; + + let Some(connection_requests) = connection_requests else { + let _ = ws_stream.close(None).await; + continue; + }; + + let mut connection_log = Vec::new(); + for request_events in connection_requests { + let Some(Ok(message)) = ws_stream.next().await else { + break; + }; + if let Some(body) = parse_ws_request_body(message) { + connection_log.push(WebSocketRequest { body }); + } + + for event in &request_events { + let Ok(payload) = serde_json::to_string(event) else { + continue; + }; + if ws_stream.send(Message::Text(payload)).await.is_err() { + break; + } + } + } + + requests.lock().unwrap().push(connection_log); + let _ = ws_stream.close(None).await; + + if connections.lock().unwrap().is_empty() { + return; + } + } + }); + + WebSocketTestServer { + uri, + connections: connections_log, + shutdown: shutdown_tx, + task, 
+ } +} + +fn parse_ws_request_body(message: Message) -> Option { + match message { + Message::Text(text) => serde_json::from_str(&text).ok(), + Message::Binary(bytes) => serde_json::from_slice(&bytes).ok(), + _ => None, + } +} + #[derive(Clone)] pub struct FunctionCallResponseMocks { pub function_call: ResponseMock, diff --git a/codex-rs/core/tests/responses_headers.rs b/codex-rs/core/tests/responses_headers.rs index dbbf0d57d..3efbb2b7e 100644 --- a/codex-rs/core/tests/responses_headers.rs +++ b/codex-rs/core/tests/responses_headers.rs @@ -91,7 +91,8 @@ async fn responses_stream_includes_subagent_header_on_review() { summary, conversation_id, session_source, - ); + ) + .new_session(); let mut prompt = Prompt::default(); prompt.input = vec![ResponseItem::Message { @@ -186,7 +187,8 @@ async fn responses_stream_includes_subagent_header_on_other() { summary, conversation_id, session_source, - ); + ) + .new_session(); let mut prompt = Prompt::default(); prompt.input = vec![ResponseItem::Message { @@ -279,7 +281,8 @@ async fn responses_respects_model_info_overrides_from_config() { summary, conversation_id, session_source, - ); + ) + .new_session(); let mut prompt = Prompt::default(); prompt.input = vec![ResponseItem::Message { diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs index f376ad073..458b355f1 100644 --- a/codex-rs/core/tests/suite/client.rs +++ b/codex-rs/core/tests/suite/client.rs @@ -1181,7 +1181,8 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() { summary, conversation_id, SessionSource::Exec, - ); + ) + .new_session(); let mut prompt = Prompt::default(); prompt.input.push(ResponseItem::Reasoning { diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index 2a8216aa4..1838df3ca 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -71,3 +71,4 @@ mod user_notification; mod user_shell_cmd; mod view_image; mod web_search_cached; +mod 
websocket; diff --git a/codex-rs/core/tests/suite/stream_no_completed.rs b/codex-rs/core/tests/suite/stream_no_completed.rs index f82aaceaf..3aa20c0c7 100644 --- a/codex-rs/core/tests/suite/stream_no_completed.rs +++ b/codex-rs/core/tests/suite/stream_no_completed.rs @@ -67,7 +67,7 @@ async fn retries_on_early_close() { name: "openai".into(), base_url: Some(format!("{}/v1", server.uri())), // Environment variable that should exist in the test environment. - // ModelClient will return an error if the environment variable for the + // ModelClientSession will return an error if the environment variable for the // provider is not set. env_key: Some("PATH".into()), env_key_instructions: None, diff --git a/codex-rs/core/tests/suite/websocket.rs b/codex-rs/core/tests/suite/websocket.rs new file mode 100644 index 000000000..fc15c8ae8 --- /dev/null +++ b/codex-rs/core/tests/suite/websocket.rs @@ -0,0 +1,112 @@ +use codex_core::AuthManager; +use codex_core::CodexAuth; +use codex_core::ContentItem; +use codex_core::ModelClient; +use codex_core::ModelProviderInfo; +use codex_core::Prompt; +use codex_core::ResponseEvent; +use codex_core::ResponseItem; +use codex_core::WireApi; +use codex_core::models_manager::manager::ModelsManager; +use codex_core::protocol::SessionSource; +use codex_otel::OtelManager; +use codex_protocol::ThreadId; +use core_test_support::load_default_config_for_test; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_response_created; +use core_test_support::responses::start_websocket_server; +use futures::StreamExt; +use std::sync::Arc; +use tempfile::TempDir; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn responses_websocket_streams_request() { + let server = start_websocket_server(vec![vec![vec![ + ev_response_created("resp-1"), + ev_completed("resp-1"), + ]]]) + .await; + + let provider = ModelProviderInfo { + name: "mock-ws".into(), + base_url: Some(format!("{}/v1", server.uri())), + env_key: 
None, + env_key_instructions: None, + experimental_bearer_token: None, + wire_api: WireApi::ResponsesWebsocket, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: Some(0), + stream_max_retries: Some(0), + stream_idle_timeout_ms: Some(5_000), + requires_openai_auth: false, + }; + + let codex_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&codex_home).await; + config.model_provider_id = provider.name.clone(); + config.model_provider = provider.clone(); + let effort = config.model_reasoning_effort; + let summary = config.model_reasoning_summary; + let model = ModelsManager::get_model_offline(config.model.as_deref()); + config.model = Some(model.clone()); + let config = Arc::new(config); + let model_info = ModelsManager::construct_model_info_offline(model.as_str(), &config); + let conversation_id = ThreadId::new(); + let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); + let otel_manager = OtelManager::new( + conversation_id, + model.as_str(), + model_info.slug.as_str(), + None, + Some("test@test.com".to_string()), + auth_manager.get_auth_mode(), + false, + "test".to_string(), + SessionSource::Exec, + ); + + let client = ModelClient::new( + Arc::clone(&config), + None, + model_info, + otel_manager, + provider, + effort, + summary, + conversation_id, + SessionSource::Exec, + ) + .new_session(); + + let mut prompt = Prompt::default(); + prompt.input = vec![ResponseItem::Message { + id: None, + role: "user".into(), + content: vec![ContentItem::InputText { + text: "hello".into(), + }], + }]; + + let mut stream = client + .stream(&prompt) + .await + .expect("websocket stream failed"); + + while let Some(event) = stream.next().await { + if matches!(event, Ok(ResponseEvent::Completed { .. 
})) { + break; + } + } + + let connection = server.single_connection(); + assert_eq!(connection.len(), 1); + let request = connection.first().cloned().unwrap(); + let body = request.body_json(); + assert_eq!(body["model"].as_str(), Some(model.as_str())); + assert_eq!(body["stream"], serde_json::Value::Bool(true)); + assert_eq!(body["input"].as_array().map(Vec::len), Some(1)); + + server.shutdown().await; +} diff --git a/codex-rs/protocol/src/protocol.rs b/codex-rs/protocol/src/protocol.rs index 4efef6bc1..54cfba8eb 100644 --- a/codex-rs/protocol/src/protocol.rs +++ b/codex-rs/protocol/src/protocol.rs @@ -102,7 +102,7 @@ pub enum Op { /// Policy to use for tool calls such as `local_shell`. sandbox_policy: SandboxPolicy, - /// Must be a valid model slug for the [`crate::client::ModelClient`] + /// Must be a valid model slug for the configured client session /// associated with this conversation. model: String,