From e726a82c8aac6b0f0843bb914ca83f6e28305724 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Mon, 12 Jan 2026 22:07:13 -0800 Subject: [PATCH] Websocket append support (#9128) Support an incremental append request in websocket transport. --- codex-rs/codex-api/src/common.rs | 32 ++++++ .../src/endpoint/responses_websocket.rs | 71 +------------- codex-rs/codex-api/src/lib.rs | 3 + codex-rs/codex-api/src/requests/headers.rs | 2 +- codex-rs/core/src/client.rs | 97 +++++++++++-------- codex-rs/core/tests/suite/websocket.rs | 89 ++++++++++++----- 6 files changed, 164 insertions(+), 130 deletions(-) diff --git a/codex-rs/codex-api/src/common.rs b/codex-rs/codex-api/src/common.rs index db1524d27..2118cf66e 100644 --- a/codex-rs/codex-api/src/common.rs +++ b/codex-rs/codex-api/src/common.rs @@ -136,6 +136,38 @@ pub struct ResponsesApiRequest<'a> { pub text: Option, } +#[derive(Debug, Serialize)] +pub struct ResponseCreateWsRequest { + pub model: String, + pub instructions: String, + pub input: Vec, + pub tools: Vec, + pub tool_choice: String, + pub parallel_tool_calls: bool, + pub reasoning: Option, + pub store: bool, + pub stream: bool, + pub include: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_cache_key: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, +} + +#[derive(Debug, Serialize)] +pub struct ResponseAppendWsRequest { + pub input: Vec, +} +#[derive(Debug, Serialize)] +#[serde(tag = "type")] +#[allow(clippy::large_enum_variant)] +pub enum ResponsesWsRequest { + #[serde(rename = "response.create")] + ResponseCreate(ResponseCreateWsRequest), + #[serde(rename = "response.append")] + ResponseAppend(ResponseAppendWsRequest), +} + pub fn create_text_param_for_request( verbosity: Option, output_schema: &Option, diff --git a/codex-rs/codex-api/src/endpoint/responses_websocket.rs b/codex-rs/codex-api/src/endpoint/responses_websocket.rs index 9cd264505..af3335ce7 100644 --- a/codex-rs/codex-api/src/endpoint/responses_websocket.rs +++ b/codex-rs/codex-api/src/endpoint/responses_websocket.rs @@ -1,13 +1,9 @@ use crate::auth::AuthProvider; -use crate::common::Prompt as ApiPrompt; use crate::common::ResponseEvent; use crate::common::ResponseStream; -use crate::endpoint::responses::ResponsesOptions; +use crate::common::ResponsesWsRequest; use crate::error::ApiError; use crate::provider::Provider; -use crate::requests::ResponsesRequest; -use crate::requests::ResponsesRequestBuilder; -use crate::requests::responses::Compression; use crate::sse::responses::ResponsesStreamEvent; use crate::sse::responses::process_responses_event; use codex_client::TransportError; @@ -28,7 +24,6 @@ use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::client::IntoClientRequest; use tracing::debug; use tracing::trace; -use tracing::warn; use url::Url; type WsStream = WebSocketStream>; @@ -53,19 +48,15 @@ impl ResponsesWebsocketConnection { pub async fn stream_request( &self, - request: ResponsesRequest, + request: ResponsesWsRequest, ) -> Result { - if request.compression == Compression::Zstd { - warn!( - "request compression is not supported for websocket streaming; sending uncompressed payload" - ); - } - let (tx_event, rx_event) = mpsc::channel::>(1600); let stream = Arc::clone(&self.stream); let idle_timeout = self.idle_timeout; - let request_body = request.body; + let request_body = serde_json::to_value(&request).map_err(|err| { + ApiError::Stream(format!("failed to encode websocket request: {err}")) + })?; tokio::spawn(async move { let mut guard = stream.lock().await; @@ -123,58 +114,6 @@ impl ResponsesWebsocketClient { self.provider.stream_idle_timeout, )) } - - pub async fn stream_prompt( - &self, - model: &str, - prompt: &ApiPrompt, - options: ResponsesOptions, - ) -> Result { - let ResponsesOptions { - reasoning, - include, - prompt_cache_key, - text, - store_override, - conversation_id, - session_source, - extra_headers, - compression, - } = options; - - // TODO (pakrym): share with HTTP based Responses API client - let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input) - .tools(&prompt.tools) - .parallel_tool_calls(prompt.parallel_tool_calls) - .reasoning(reasoning) - .include(include) - .prompt_cache_key(prompt_cache_key) - .text(text) - .conversation(conversation_id) - .session_source(session_source) - .store_override(store_override) - .extra_headers(extra_headers) - .compression(compression) - .build(&self.provider)?; - - let connection = self.connect(request.headers.clone()).await?; - connection.stream_request(request).await - } - - pub async fn stream( - &self, - body: Value, - extra_headers: HeaderMap, - compression: Compression, - ) -> Result { - let request = ResponsesRequest { - body, - headers: extra_headers, - compression, - }; - let connection = self.connect(request.headers.clone()).await?; - connection.stream_request(request).await - } } // TODO (pakrym): share with /auth diff --git a/codex-rs/codex-api/src/lib.rs b/codex-rs/codex-api/src/lib.rs index 0128efc35..0f608fd23 100644 --- a/codex-rs/codex-api/src/lib.rs +++ b/codex-rs/codex-api/src/lib.rs @@ -8,6 +8,7 @@ pub mod requests; pub mod sse; pub mod telemetry; +pub use crate::requests::headers::build_conversation_headers; pub use codex_client::RequestTelemetry; pub use codex_client::ReqwestTransport; pub use codex_client::TransportError; @@ -15,6 +16,8 @@ pub use codex_client::TransportError; pub use crate::auth::AuthProvider; pub use crate::common::CompactionInput; pub use crate::common::Prompt; +pub use crate::common::ResponseAppendWsRequest; +pub use crate::common::ResponseCreateWsRequest; pub use crate::common::ResponseEvent; pub use crate::common::ResponseStream; pub use crate::common::ResponsesApiRequest; diff --git a/codex-rs/codex-api/src/requests/headers.rs b/codex-rs/codex-api/src/requests/headers.rs index bdc7bba4f..02f08724f 100644 --- a/codex-rs/codex-api/src/requests/headers.rs +++ b/codex-rs/codex-api/src/requests/headers.rs @@ -2,7 +2,7 @@ use codex_protocol::protocol::SessionSource; use http::HeaderMap; use http::HeaderValue; -pub(crate) fn build_conversation_headers(conversation_id: Option) -> HeaderMap { +pub fn build_conversation_headers(conversation_id: Option) -> HeaderMap { let mut headers = HeaderMap::new(); if let Some(id) = conversation_id { insert_header(&mut headers, "session_id", &id); diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index e6dc69b15..438c7c5ee 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -11,16 +11,18 @@ use codex_api::CompactionInput as ApiCompactionInput; use codex_api::Prompt as ApiPrompt; use codex_api::RequestTelemetry; use codex_api::ReqwestTransport; +use codex_api::ResponseAppendWsRequest; +use codex_api::ResponseCreateWsRequest; use codex_api::ResponseStream as ApiResponseStream; use codex_api::ResponsesClient as ApiResponsesClient; use codex_api::ResponsesOptions as ApiResponsesOptions; -use codex_api::ResponsesRequest; -use codex_api::ResponsesRequestBuilder; use codex_api::ResponsesWebsocketClient as ApiWebSocketResponsesClient; use codex_api::ResponsesWebsocketConnection as ApiWebSocketConnection; use codex_api::SseTelemetry; use codex_api::TransportError; +use codex_api::build_conversation_headers; use codex_api::common::Reasoning; +use codex_api::common::ResponsesWsRequest; use codex_api::create_text_param_for_request; use codex_api::error::ApiError; use codex_api::requests::responses::Compression; @@ -83,6 +85,7 @@ pub struct ModelClient { pub struct ModelClientSession { state: Arc, connection: Option, + websocket_last_items: Vec, } #[allow(clippy::too_many_arguments)] @@ -117,6 +120,7 @@ impl ModelClient { ModelClientSession { state: Arc::clone(&self.state), connection: None, + websocket_last_items: Vec::new(), } } } @@ -320,49 +324,65 @@ impl ModelClientSession { } } - fn build_responses_websocket_request( + fn get_incremental_items(&self, input_items: &[ResponseItem]) -> Option> { + // Checks whether the current request input is an incremental append to the previous request. + // If items in the new request contain all the items from the previous request we build + // a response.append request otherwise we start with a fresh response.create request. + let previous_len = self.websocket_last_items.len(); + let can_append = previous_len > 0 + && input_items.starts_with(&self.websocket_last_items) + && previous_len < input_items.len(); + if can_append { + Some(input_items[previous_len..].to_vec()) + } else { + None + } + } + + fn prepare_websocket_request( &self, - api_provider: &codex_api::Provider, api_prompt: &ApiPrompt, - options: ApiResponsesOptions, - ) -> Result { + options: &ApiResponsesOptions, + ) -> ResponsesWsRequest { + if let Some(append_items) = self.get_incremental_items(&api_prompt.input) { + return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest { + input: append_items, + }); + } + let ApiResponsesOptions { reasoning, include, prompt_cache_key, text, store_override, - conversation_id, - session_source, - extra_headers, - compression, + .. } = options; - ResponsesRequestBuilder::new( - &self.state.model_info.slug, - &api_prompt.instructions, - &api_prompt.input, - ) - .tools(&api_prompt.tools) - .parallel_tool_calls(api_prompt.parallel_tool_calls) - .reasoning(reasoning) - .include(include) - .prompt_cache_key(prompt_cache_key) - .text(text) - .conversation(conversation_id) - .session_source(session_source) - .store_override(store_override) - .extra_headers(extra_headers) - .compression(compression) - .build(api_provider) - .map_err(map_api_error) + let store = store_override.unwrap_or(false); + let payload = ResponseCreateWsRequest { + model: self.state.model_info.slug.clone(), + instructions: api_prompt.instructions.clone(), + input: api_prompt.input.clone(), + tools: api_prompt.tools.clone(), + tool_choice: "auto".to_string(), + parallel_tool_calls: api_prompt.parallel_tool_calls, + reasoning: reasoning.clone(), + store, + stream: true, + include: include.clone(), + prompt_cache_key: prompt_cache_key.clone(), + text: text.clone(), + }; + + ResponsesWsRequest::ResponseCreate(payload) } async fn websocket_connection( &mut self, api_provider: codex_api::Provider, api_auth: CoreAuthProvider, - headers: ApiHeaderMap, + options: &ApiResponsesOptions, ) -> std::result::Result<&ApiWebSocketConnection, ApiError> { let needs_new = match self.connection.as_ref() { Some(conn) => conn.is_closed().await, @@ -370,9 +390,12 @@ impl ModelClientSession { }; if needs_new { - let new_conn = ApiWebSocketResponsesClient::new(api_provider, api_auth) - .connect(headers) - .await?; + let mut headers = options.extra_headers.clone(); + headers.extend(build_conversation_headers(options.conversation_id.clone())); + let new_conn: ApiWebSocketConnection = + ApiWebSocketResponsesClient::new(api_provider, api_auth) + .connect(headers) + .await?; self.connection = Some(new_conn); } @@ -533,15 +556,10 @@ impl ModelClientSession { let compression = self.responses_request_compression(auth.as_ref()); let options = self.build_responses_options(prompt, compression); - let request = - self.build_responses_websocket_request(&api_provider, &api_prompt, options)?; + let request = self.prepare_websocket_request(&api_prompt, &options); let connection = match self - .websocket_connection( - api_provider.clone(), - api_auth.clone(), - request.headers.clone(), - ) + .websocket_connection(api_provider.clone(), api_auth.clone(), &options) .await { Ok(connection) => connection, @@ -558,6 +576,7 @@ impl ModelClientSession { .stream_request(request) .await .map_err(map_api_error)?; + self.websocket_last_items = api_prompt.input.clone(); return Ok(map_response_stream( stream_result, diff --git a/codex-rs/core/tests/suite/websocket.rs b/codex-rs/core/tests/suite/websocket.rs index 6a5f99acb..9d39fb240 100644 --- a/codex-rs/core/tests/suite/websocket.rs +++ b/codex-rs/core/tests/suite/websocket.rs @@ -44,14 +44,7 @@ async fn responses_websocket_streams_request() { let harness = websocket_harness(&server).await; let mut session = harness.client.new_session(); - let mut prompt = Prompt::default(); - prompt.input = vec![ResponseItem::Message { - id: None, - role: "user".into(), - content: vec![ContentItem::InputText { - text: "hello".into(), - }], - }]; + let prompt = prompt_with_input(vec![message_item("hello")]); stream_until_complete(&mut session, &prompt).await; @@ -59,6 +52,7 @@ async fn responses_websocket_streams_request() { assert_eq!(connection.len(), 1); let body = connection.first().expect("missing request").body_json(); + assert_eq!(body["type"].as_str(), Some("response.create")); assert_eq!(body["model"].as_str(), Some(MODEL)); assert_eq!(body["stream"], serde_json::Value::Bool(true)); assert_eq!(body["input"].as_array().map(Vec::len), Some(1)); @@ -67,7 +61,7 @@ async fn responses_websocket_streams_request() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn responses_websocket_reuses_connection() { +async fn responses_websocket_appends_on_prefix() { skip_if_no_network!(); let server = start_websocket_server(vec![vec![ @@ -78,30 +72,77 @@ async fn responses_websocket_reuses_connection() { let harness = websocket_harness(&server).await; let mut session = harness.client.new_session(); - let mut prompt = Prompt::default(); - prompt.input = vec![ResponseItem::Message { - id: None, - role: "user".into(), - content: vec![ContentItem::InputText { - text: "hello".into(), - }], - }]; + let prompt_one = prompt_with_input(vec![message_item("hello")]); + let prompt_two = prompt_with_input(vec![message_item("hello"), message_item("second")]); - for _ in 0..2 { - stream_until_complete(&mut session, &prompt).await; - } + stream_until_complete(&mut session, &prompt_one).await; + stream_until_complete(&mut session, &prompt_two).await; let connection = server.single_connection(); assert_eq!(connection.len(), 2); - let body = connection.first().expect("missing request").body_json(); + let first = connection.first().expect("missing request").body_json(); + let second = connection.get(1).expect("missing request").body_json(); - assert_eq!(body["model"].as_str(), Some(MODEL)); - assert_eq!(body["stream"], serde_json::Value::Bool(true)); - assert_eq!(body["input"].as_array().map(Vec::len), Some(1)); + assert_eq!(first["type"].as_str(), Some("response.create")); + assert_eq!(first["model"].as_str(), Some(MODEL)); + assert_eq!(first["stream"], serde_json::Value::Bool(true)); + assert_eq!(first["input"].as_array().map(Vec::len), Some(1)); + let expected_append = serde_json::json!({ + "type": "response.append", + "input": serde_json::to_value(&prompt_two.input[1..]).expect("serialize append items"), + }); + assert_eq!(second, expected_append); server.shutdown().await; } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn responses_websocket_creates_on_non_prefix() { + skip_if_no_network!(); + + let server = start_websocket_server(vec![vec![ + vec![ev_response_created("resp-1"), ev_completed("resp-1")], + vec![ev_response_created("resp-2"), ev_completed("resp-2")], + ]]) + .await; + + let harness = websocket_harness(&server).await; + let mut session = harness.client.new_session(); + let prompt_one = prompt_with_input(vec![message_item("hello")]); + let prompt_two = prompt_with_input(vec![message_item("different")]); + + stream_until_complete(&mut session, &prompt_one).await; + stream_until_complete(&mut session, &prompt_two).await; + + let connection = server.single_connection(); + assert_eq!(connection.len(), 2); + let second = connection.get(1).expect("missing request").body_json(); + + assert_eq!(second["type"].as_str(), Some("response.create")); + assert_eq!(second["model"].as_str(), Some(MODEL)); + assert_eq!(second["stream"], serde_json::Value::Bool(true)); + assert_eq!( + second["input"], + serde_json::to_value(&prompt_two.input).unwrap() + ); + + server.shutdown().await; +} + +fn message_item(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "user".into(), + content: vec![ContentItem::InputText { text: text.into() }], + } +} + +fn prompt_with_input(input: Vec) -> Prompt { + let mut prompt = Prompt::default(); + prompt.input = input; + prompt +} + fn websocket_provider(server: &WebSocketTestServer) -> ModelProviderInfo { ModelProviderInfo { name: "mock-ws".into(),