From 0639c3389263c4fea045ecbcf8b95f84e8197fa9 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Tue, 10 Feb 2026 11:14:36 -0800 Subject: [PATCH] Compare full request for websockets incrementality (#11343) Tools can dynamically change mid-turn now. We need to be more thorough about reusing incremental connections. --- codex-rs/codex-api/src/common.rs | 12 +-- codex-rs/core/src/client.rs | 66 ++++++++------- .../core/tests/suite/client_websockets.rs | 82 +++++++++++++++++++ 3 files changed, 124 insertions(+), 36 deletions(-) diff --git a/codex-rs/codex-api/src/common.rs b/codex-rs/codex-api/src/common.rs index 69f697c06..bfd8bf666 100644 --- a/codex-rs/codex-api/src/common.rs +++ b/codex-rs/codex-api/src/common.rs @@ -80,7 +80,7 @@ pub enum ResponseEvent { ModelsEtag(String), } -#[derive(Debug, Serialize, Clone)] +#[derive(Debug, Serialize, Clone, PartialEq)] pub struct Reasoning { #[serde(skip_serializing_if = "Option::is_none")] pub effort: Option, @@ -88,14 +88,14 @@ pub struct Reasoning { pub summary: Option, } -#[derive(Debug, Serialize, Default, Clone)] +#[derive(Debug, Serialize, Default, Clone, PartialEq)] #[serde(rename_all = "snake_case")] pub enum TextFormatType { #[default] JsonSchema, } -#[derive(Debug, Serialize, Default, Clone)] +#[derive(Debug, Serialize, Default, Clone, PartialEq)] pub struct TextFormat { /// Format type used by the OpenAI text controls. pub r#type: TextFormatType, @@ -109,7 +109,7 @@ pub struct TextFormat { /// Controls the `text` field for the Responses API, combining verbosity and /// optional JSON schema output formatting. -#[derive(Debug, Serialize, Default, Clone)] +#[derive(Debug, Serialize, Default, Clone, PartialEq)] pub struct TextControls { #[serde(skip_serializing_if = "Option::is_none")] pub verbosity: Option, @@ -117,7 +117,7 @@ pub struct TextControls { pub format: Option, } -#[derive(Debug, Serialize, Default, Clone)] +#[derive(Debug, Serialize, Default, Clone, PartialEq)] #[serde(rename_all = "lowercase")] pub enum OpenAiVerbosity { Low, @@ -136,7 +136,7 @@ impl From for OpenAiVerbosity { } } -#[derive(Debug, Serialize, Clone)] +#[derive(Debug, Serialize, Clone, PartialEq)] pub struct ResponsesApiRequest { pub model: String, pub instructions: String, diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 199ced5a8..1a6e2e0c4 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -155,8 +155,8 @@ pub struct ModelClient { /// The session establishes a Responses WebSocket connection lazily and reuses it across multiple /// requests within the turn. It also caches per-turn state: /// -/// - The last request's input items, so subsequent calls can use `response.append` when the input -/// is an incremental extension of the previous request. +/// - The last full request, so subsequent calls can use `response.append` only when the current +/// request is an incremental extension of the previous one. /// - The `x-codex-turn-state` sticky-routing token, which must be replayed for all requests within /// the same turn. /// @@ -166,7 +166,7 @@ pub struct ModelClient { pub struct ModelClientSession { client: ModelClient, connection: Option, - websocket_last_items: Vec, + websocket_last_request: Option, websocket_last_response_id: Option, websocket_last_response_id_rx: Option>, /// Turn state for sticky routing. @@ -230,7 +230,7 @@ impl ModelClient { ModelClientSession { client: self.clone(), connection: None, - websocket_last_items: Vec::new(), + websocket_last_request: None, websocket_last_response_id: None, websocket_last_response_id_rx: None, turn_state: Arc::new(OnceLock::new()), @@ -530,16 +530,25 @@ impl ModelClientSession { } } - fn get_incremental_items(&self, input_items: &[ResponseItem]) -> Option> { - // Checks whether the current request input is an incremental append to the previous request. - // If items in the new request contain all the items from the previous request we build - // a response.append request otherwise we start with a fresh response.create request. - let previous_len = self.websocket_last_items.len(); - let can_append = previous_len > 0 - && input_items.starts_with(&self.websocket_last_items) - && previous_len < input_items.len(); - if can_append { - Some(input_items[previous_len..].to_vec()) + fn get_incremental_items(&self, request: &ResponsesApiRequest) -> Option> { + // Checks whether the current request is an incremental append to the previous request. + // We only append when non-input request fields are unchanged and `input` is a strict + // extension of the previous input. + let previous_request = self.websocket_last_request.as_ref()?; + let mut previous_without_input = previous_request.clone(); + previous_without_input.input.clear(); + let mut request_without_input = request.clone(); + request_without_input.input.clear(); + if previous_without_input != request_without_input { + return None; + } + + let previous_len = previous_request.input.len(); + if previous_len > 0 + && request.input.starts_with(&previous_request.input) + && previous_len < request.input.len() + { + Some(request.input[previous_len..].to_vec()) } else { None } @@ -571,10 +580,10 @@ impl ModelClientSession { fn prepare_websocket_request( &mut self, payload: ResponseCreateWsRequest, - ) -> (ResponsesWsRequest, Vec) { - let full_input = payload.input.clone(); + request: &ResponsesApiRequest, + ) -> ResponsesWsRequest { let responses_websockets_v2_enabled = self.client.responses_websockets_v2_enabled(); - let incremental_items = self.get_incremental_items(&full_input); + let incremental_items = self.get_incremental_items(request); if let Some(append_items) = incremental_items { if responses_websockets_v2_enabled && let Some(previous_response_id) = self.websocket_previous_response_id() @@ -584,20 +593,17 @@ impl ModelClientSession { input: append_items, ..payload }; - return (ResponsesWsRequest::ResponseCreate(payload), full_input); + return ResponsesWsRequest::ResponseCreate(payload); } if !responses_websockets_v2_enabled { - return ( - ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest { - input: append_items, - }), - full_input, - ); + return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest { + input: append_items, + }); } } - (ResponsesWsRequest::ResponseCreate(payload), full_input) + ResponsesWsRequest::ResponseCreate(payload) } /// Opportunistically warms a websocket for this turn-scoped client session. @@ -650,7 +656,7 @@ impl ModelClientSession { }; if needs_new { - self.websocket_last_items.clear(); + self.websocket_last_request = None; self.websocket_last_response_id = None; self.websocket_last_response_id_rx = None; let turn_state = options @@ -806,7 +812,7 @@ impl ModelClientSession { Err(err) => return Err(map_api_error(err)), } - let (request, request_input) = self.prepare_websocket_request(ws_payload); + let ws_request = self.prepare_websocket_request(ws_payload, &request); let stream_result = self .connection @@ -816,10 +822,10 @@ impl ModelClientSession { "websocket connection is unavailable".to_string(), )) })? - .stream_request(request) + .stream_request(ws_request) .await .map_err(map_api_error)?; - self.websocket_last_items = request_input; + self.websocket_last_request = Some(request); let (last_response_id_sender, last_response_id_receiver) = oneshot::channel(); self.websocket_last_response_id_rx = Some(last_response_id_receiver); let mut last_response_id_sender = Some(last_response_id_sender); @@ -928,7 +934,7 @@ impl ModelClientSession { ); self.connection = None; - self.websocket_last_items.clear(); + self.websocket_last_request = None; } activated } diff --git a/codex-rs/core/tests/suite/client_websockets.rs b/codex-rs/core/tests/suite/client_websockets.rs index 1991a6290..d1ebc6377 100755 --- a/codex-rs/core/tests/suite/client_websockets.rs +++ b/codex-rs/core/tests/suite/client_websockets.rs @@ -22,6 +22,7 @@ use codex_otel::metrics::MetricsConfig; use codex_protocol::ThreadId; use codex_protocol::account::PlanType; use codex_protocol::config_types::ReasoningSummary; +use codex_protocol::models::BaseInstructions; use codex_protocol::openai_models::ModelInfo; use codex_protocol::openai_models::ReasoningEffort as ReasoningEffortConfig; use codex_protocol::user_input::UserInput; @@ -603,6 +604,42 @@ async fn responses_websocket_creates_on_non_prefix() { server.shutdown().await; } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn responses_websocket_creates_when_non_input_request_fields_change() { + skip_if_no_network!(); + + let server = start_websocket_server(vec![vec![ + vec![ev_response_created("resp-1"), ev_completed("resp-1")], + vec![ev_response_created("resp-2"), ev_completed("resp-2")], + ]]) + .await; + + let harness = websocket_harness(&server).await; + let mut client_session = harness.client.new_session(); + let prompt_one = + prompt_with_input_and_instructions(vec![message_item("hello")], "base instructions one"); + let prompt_two = prompt_with_input_and_instructions( + vec![message_item("hello"), message_item("second")], + "base instructions two", + ); + + stream_until_complete(&mut client_session, &harness, &prompt_one).await; + stream_until_complete(&mut client_session, &harness, &prompt_two).await; + + let connection = server.single_connection(); + assert_eq!(connection.len(), 2); + let second = connection.get(1).expect("missing request").body_json(); + + assert_eq!(second["type"].as_str(), Some("response.create")); + assert_eq!(second.get("previous_response_id"), None); + assert_eq!( + second["input"], + serde_json::to_value(&prompt_two.input).expect("serialize full input") + ); + + server.shutdown().await; +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn responses_websocket_v2_creates_with_previous_response_id_on_prefix() { skip_if_no_network!(); @@ -637,6 +674,43 @@ async fn responses_websocket_v2_creates_with_previous_response_id_on_prefix() { server.shutdown().await; } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn responses_websocket_v2_creates_without_previous_response_id_when_non_input_fields_change() +{ + skip_if_no_network!(); + + let server = start_websocket_server(vec![vec![ + vec![ev_response_created("resp-1"), ev_completed("resp-1")], + vec![ev_response_created("resp-2"), ev_completed("resp-2")], + ]]) + .await; + + let harness = websocket_harness_with_v2(&server, true).await; + let mut session = harness.client.new_session(); + let prompt_one = + prompt_with_input_and_instructions(vec![message_item("hello")], "base instructions one"); + let prompt_two = prompt_with_input_and_instructions( + vec![message_item("hello"), message_item("second")], + "base instructions two", + ); + + stream_until_complete(&mut session, &harness, &prompt_one).await; + stream_until_complete(&mut session, &harness, &prompt_two).await; + + let connection = server.single_connection(); + assert_eq!(connection.len(), 2); + let second = connection.get(1).expect("missing request").body_json(); + + assert_eq!(second["type"].as_str(), Some("response.create")); + assert_eq!(second.get("previous_response_id"), None); + assert_eq!( + second["input"], + serde_json::to_value(&prompt_two.input).expect("serialize full input") + ); + + server.shutdown().await; +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn responses_websocket_v2_after_error_uses_full_create_without_previous_response_id() { skip_if_no_network!(); @@ -778,6 +852,14 @@ fn prompt_with_input(input: Vec) -> Prompt { prompt } +fn prompt_with_input_and_instructions(input: Vec, instructions: &str) -> Prompt { + let mut prompt = prompt_with_input(input); + prompt.base_instructions = BaseInstructions { + text: instructions.to_string(), + }; + prompt +} + fn websocket_provider(server: &WebSocketTestServer) -> ModelProviderInfo { ModelProviderInfo { name: "mock-ws".into(),