Websocket append support (#9128)

Support incremental append requests (`response.append`) in the websocket transport.
This commit is contained in:
pakrym-oai 2026-01-12 22:07:13 -08:00 committed by GitHub
parent ddae70bd62
commit e726a82c8a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 164 additions and 130 deletions

View file

@ -136,6 +136,38 @@ pub struct ResponsesApiRequest<'a> {
pub text: Option<TextControls>,
}
/// Body of a `response.create` message sent over the Responses websocket transport.
///
/// Wrapped by `ResponsesWsRequest::ResponseCreate` for serialization, which emits a
/// `"type": "response.create"` field next to the fields declared here.
#[derive(Debug, Serialize)]
pub struct ResponseCreateWsRequest {
/// Model slug the request targets.
pub model: String,
/// Instructions text sent with the request.
pub instructions: String,
/// Complete list of input items for this request.
pub input: Vec<ResponseItem>,
/// Tool definitions, already encoded as JSON values.
pub tools: Vec<Value>,
/// Tool selection mode string, e.g. `"auto"`.
pub tool_choice: String,
/// Parallel-tool-calls flag forwarded with the request.
pub parallel_tool_calls: bool,
/// Optional reasoning settings. There is no `skip_serializing_if` on this field, so
/// `None` is serialized as an explicit `null`.
pub reasoning: Option<Reasoning>,
/// Value of the request's `store` flag.
pub store: bool,
/// Value of the request's `stream` flag.
pub stream: bool,
/// Entries for the request's `include` list.
pub include: Vec<String>,
/// Optional prompt cache key; the field is omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_cache_key: Option<String>,
/// Optional text controls; the field is omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<TextControls>,
}
/// Body of a `response.append` message sent over the Responses websocket transport.
///
/// Unlike a create request it carries nothing but input items; wrapped by
/// `ResponsesWsRequest::ResponseAppend` for serialization, which emits a
/// `"type": "response.append"` field next to `input`.
#[derive(Debug, Serialize)]
pub struct ResponseAppendWsRequest {
/// Input items to append.
// NOTE(review): presumably only the items that were not part of the previous
// request on the same connection — confirm against the caller.
pub input: Vec<ResponseItem>,
}
/// A request message for the Responses websocket transport.
///
/// Serialized in serde's internally tagged form: the fields of the variant's payload
/// are emitted in a single JSON object together with a `"type"` field that holds the
/// renamed variant tag (`"response.create"` or `"response.append"`).
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
// The two payloads differ considerably in size (a dozen fields vs. a single `input`
// list). NOTE(review): presumably the lint is silenced to avoid boxing the larger
// variant — confirm before changing.
#[allow(clippy::large_enum_variant)]
pub enum ResponsesWsRequest {
/// A request built from a complete payload; tagged `"response.create"`.
#[serde(rename = "response.create")]
ResponseCreate(ResponseCreateWsRequest),
/// A request carrying only input items; tagged `"response.append"`.
#[serde(rename = "response.append")]
ResponseAppend(ResponseAppendWsRequest),
}
pub fn create_text_param_for_request(
verbosity: Option<VerbosityConfig>,
output_schema: &Option<Value>,

View file

@ -1,13 +1,9 @@
use crate::auth::AuthProvider;
use crate::common::Prompt as ApiPrompt;
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::endpoint::responses::ResponsesOptions;
use crate::common::ResponsesWsRequest;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::requests::ResponsesRequest;
use crate::requests::ResponsesRequestBuilder;
use crate::requests::responses::Compression;
use crate::sse::responses::ResponsesStreamEvent;
use crate::sse::responses::process_responses_event;
use codex_client::TransportError;
@ -28,7 +24,6 @@ use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tracing::debug;
use tracing::trace;
use tracing::warn;
use url::Url;
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
@ -53,19 +48,15 @@ impl ResponsesWebsocketConnection {
pub async fn stream_request(
&self,
request: ResponsesRequest,
request: ResponsesWsRequest,
) -> Result<ResponseStream, ApiError> {
if request.compression == Compression::Zstd {
warn!(
"request compression is not supported for websocket streaming; sending uncompressed payload"
);
}
let (tx_event, rx_event) =
mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
let stream = Arc::clone(&self.stream);
let idle_timeout = self.idle_timeout;
let request_body = request.body;
let request_body = serde_json::to_value(&request).map_err(|err| {
ApiError::Stream(format!("failed to encode websocket request: {err}"))
})?;
tokio::spawn(async move {
let mut guard = stream.lock().await;
@ -123,58 +114,6 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
self.provider.stream_idle_timeout,
))
}
pub async fn stream_prompt(
&self,
model: &str,
prompt: &ApiPrompt,
options: ResponsesOptions,
) -> Result<ResponseStream, ApiError> {
let ResponsesOptions {
reasoning,
include,
prompt_cache_key,
text,
store_override,
conversation_id,
session_source,
extra_headers,
compression,
} = options;
// TODO (pakrym): share with HTTP based Responses API client
let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input)
.tools(&prompt.tools)
.parallel_tool_calls(prompt.parallel_tool_calls)
.reasoning(reasoning)
.include(include)
.prompt_cache_key(prompt_cache_key)
.text(text)
.conversation(conversation_id)
.session_source(session_source)
.store_override(store_override)
.extra_headers(extra_headers)
.compression(compression)
.build(&self.provider)?;
let connection = self.connect(request.headers.clone()).await?;
connection.stream_request(request).await
}
pub async fn stream(
&self,
body: Value,
extra_headers: HeaderMap,
compression: Compression,
) -> Result<ResponseStream, ApiError> {
let request = ResponsesRequest {
body,
headers: extra_headers,
compression,
};
let connection = self.connect(request.headers.clone()).await?;
connection.stream_request(request).await
}
}
// TODO (pakrym): share with /auth

View file

@ -8,6 +8,7 @@ pub mod requests;
pub mod sse;
pub mod telemetry;
pub use crate::requests::headers::build_conversation_headers;
pub use codex_client::RequestTelemetry;
pub use codex_client::ReqwestTransport;
pub use codex_client::TransportError;
@ -15,6 +16,8 @@ pub use codex_client::TransportError;
pub use crate::auth::AuthProvider;
pub use crate::common::CompactionInput;
pub use crate::common::Prompt;
pub use crate::common::ResponseAppendWsRequest;
pub use crate::common::ResponseCreateWsRequest;
pub use crate::common::ResponseEvent;
pub use crate::common::ResponseStream;
pub use crate::common::ResponsesApiRequest;

View file

@ -2,7 +2,7 @@ use codex_protocol::protocol::SessionSource;
use http::HeaderMap;
use http::HeaderValue;
pub(crate) fn build_conversation_headers(conversation_id: Option<String>) -> HeaderMap {
pub fn build_conversation_headers(conversation_id: Option<String>) -> HeaderMap {
let mut headers = HeaderMap::new();
if let Some(id) = conversation_id {
insert_header(&mut headers, "session_id", &id);

View file

@ -11,16 +11,18 @@ use codex_api::CompactionInput as ApiCompactionInput;
use codex_api::Prompt as ApiPrompt;
use codex_api::RequestTelemetry;
use codex_api::ReqwestTransport;
use codex_api::ResponseAppendWsRequest;
use codex_api::ResponseCreateWsRequest;
use codex_api::ResponseStream as ApiResponseStream;
use codex_api::ResponsesClient as ApiResponsesClient;
use codex_api::ResponsesOptions as ApiResponsesOptions;
use codex_api::ResponsesRequest;
use codex_api::ResponsesRequestBuilder;
use codex_api::ResponsesWebsocketClient as ApiWebSocketResponsesClient;
use codex_api::ResponsesWebsocketConnection as ApiWebSocketConnection;
use codex_api::SseTelemetry;
use codex_api::TransportError;
use codex_api::build_conversation_headers;
use codex_api::common::Reasoning;
use codex_api::common::ResponsesWsRequest;
use codex_api::create_text_param_for_request;
use codex_api::error::ApiError;
use codex_api::requests::responses::Compression;
@ -83,6 +85,7 @@ pub struct ModelClient {
pub struct ModelClientSession {
state: Arc<ModelClientState>,
connection: Option<ApiWebSocketConnection>,
websocket_last_items: Vec<ResponseItem>,
}
#[allow(clippy::too_many_arguments)]
@ -117,6 +120,7 @@ impl ModelClient {
ModelClientSession {
state: Arc::clone(&self.state),
connection: None,
websocket_last_items: Vec::new(),
}
}
}
@ -320,49 +324,65 @@ impl ModelClientSession {
}
}
fn build_responses_websocket_request(
/// Returns the items to send in a `response.append` request when `input_items` merely
/// extends the input recorded for the previous websocket request.
///
/// `Some(tail)` is returned only when all of the following hold:
/// - the recorded item list from the previous request is non-empty,
/// - `input_items` begins with exactly those items, and
/// - `input_items` contains at least one item after them.
///
/// `tail` is a clone of the items that follow the shared prefix. In every other case
/// `None` is returned, signalling that a fresh `response.create` request is needed.
///
/// NOTE(review): only the item lists are compared. Whether the websocket connection
/// that saw the previous request is still the one in use is not checked here — confirm
/// that callers handle a reconnect between two requests.
fn get_incremental_items(&self, input_items: &[ResponseItem]) -> Option<Vec<ResponseItem>> {
    let previous_items = self.websocket_last_items.as_slice();
    if previous_items.is_empty() {
        // Nothing was sent before, so there is nothing to append to.
        return None;
    }

    match input_items.strip_prefix(previous_items) {
        // A non-empty remainder means the new input strictly extends the old one.
        Some(appended) if !appended.is_empty() => Some(appended.to_vec()),
        // Either the prefix does not match or no new items were added.
        _ => None,
    }
}
fn prepare_websocket_request(
&self,
api_provider: &codex_api::Provider,
api_prompt: &ApiPrompt,
options: ApiResponsesOptions,
) -> Result<ResponsesRequest> {
options: &ApiResponsesOptions,
) -> ResponsesWsRequest {
if let Some(append_items) = self.get_incremental_items(&api_prompt.input) {
return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest {
input: append_items,
});
}
let ApiResponsesOptions {
reasoning,
include,
prompt_cache_key,
text,
store_override,
conversation_id,
session_source,
extra_headers,
compression,
..
} = options;
ResponsesRequestBuilder::new(
&self.state.model_info.slug,
&api_prompt.instructions,
&api_prompt.input,
)
.tools(&api_prompt.tools)
.parallel_tool_calls(api_prompt.parallel_tool_calls)
.reasoning(reasoning)
.include(include)
.prompt_cache_key(prompt_cache_key)
.text(text)
.conversation(conversation_id)
.session_source(session_source)
.store_override(store_override)
.extra_headers(extra_headers)
.compression(compression)
.build(api_provider)
.map_err(map_api_error)
let store = store_override.unwrap_or(false);
let payload = ResponseCreateWsRequest {
model: self.state.model_info.slug.clone(),
instructions: api_prompt.instructions.clone(),
input: api_prompt.input.clone(),
tools: api_prompt.tools.clone(),
tool_choice: "auto".to_string(),
parallel_tool_calls: api_prompt.parallel_tool_calls,
reasoning: reasoning.clone(),
store,
stream: true,
include: include.clone(),
prompt_cache_key: prompt_cache_key.clone(),
text: text.clone(),
};
ResponsesWsRequest::ResponseCreate(payload)
}
async fn websocket_connection(
&mut self,
api_provider: codex_api::Provider,
api_auth: CoreAuthProvider,
headers: ApiHeaderMap,
options: &ApiResponsesOptions,
) -> std::result::Result<&ApiWebSocketConnection, ApiError> {
let needs_new = match self.connection.as_ref() {
Some(conn) => conn.is_closed().await,
@ -370,9 +390,12 @@ impl ModelClientSession {
};
if needs_new {
let new_conn = ApiWebSocketResponsesClient::new(api_provider, api_auth)
.connect(headers)
.await?;
let mut headers = options.extra_headers.clone();
headers.extend(build_conversation_headers(options.conversation_id.clone()));
let new_conn: ApiWebSocketConnection =
ApiWebSocketResponsesClient::new(api_provider, api_auth)
.connect(headers)
.await?;
self.connection = Some(new_conn);
}
@ -533,15 +556,10 @@ impl ModelClientSession {
let compression = self.responses_request_compression(auth.as_ref());
let options = self.build_responses_options(prompt, compression);
let request =
self.build_responses_websocket_request(&api_provider, &api_prompt, options)?;
let request = self.prepare_websocket_request(&api_prompt, &options);
let connection = match self
.websocket_connection(
api_provider.clone(),
api_auth.clone(),
request.headers.clone(),
)
.websocket_connection(api_provider.clone(), api_auth.clone(), &options)
.await
{
Ok(connection) => connection,
@ -558,6 +576,7 @@ impl ModelClientSession {
.stream_request(request)
.await
.map_err(map_api_error)?;
self.websocket_last_items = api_prompt.input.clone();
return Ok(map_response_stream(
stream_result,

View file

@ -44,14 +44,7 @@ async fn responses_websocket_streams_request() {
let harness = websocket_harness(&server).await;
let mut session = harness.client.new_session();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {
id: None,
role: "user".into(),
content: vec![ContentItem::InputText {
text: "hello".into(),
}],
}];
let prompt = prompt_with_input(vec![message_item("hello")]);
stream_until_complete(&mut session, &prompt).await;
@ -59,6 +52,7 @@ async fn responses_websocket_streams_request() {
assert_eq!(connection.len(), 1);
let body = connection.first().expect("missing request").body_json();
assert_eq!(body["type"].as_str(), Some("response.create"));
assert_eq!(body["model"].as_str(), Some(MODEL));
assert_eq!(body["stream"], serde_json::Value::Bool(true));
assert_eq!(body["input"].as_array().map(Vec::len), Some(1));
@ -67,7 +61,7 @@ async fn responses_websocket_streams_request() {
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_reuses_connection() {
async fn responses_websocket_appends_on_prefix() {
skip_if_no_network!();
let server = start_websocket_server(vec![vec![
@ -78,30 +72,77 @@ async fn responses_websocket_reuses_connection() {
let harness = websocket_harness(&server).await;
let mut session = harness.client.new_session();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {
id: None,
role: "user".into(),
content: vec![ContentItem::InputText {
text: "hello".into(),
}],
}];
let prompt_one = prompt_with_input(vec![message_item("hello")]);
let prompt_two = prompt_with_input(vec![message_item("hello"), message_item("second")]);
for _ in 0..2 {
stream_until_complete(&mut session, &prompt).await;
}
stream_until_complete(&mut session, &prompt_one).await;
stream_until_complete(&mut session, &prompt_two).await;
let connection = server.single_connection();
assert_eq!(connection.len(), 2);
let body = connection.first().expect("missing request").body_json();
let first = connection.first().expect("missing request").body_json();
let second = connection.get(1).expect("missing request").body_json();
assert_eq!(body["model"].as_str(), Some(MODEL));
assert_eq!(body["stream"], serde_json::Value::Bool(true));
assert_eq!(body["input"].as_array().map(Vec::len), Some(1));
assert_eq!(first["type"].as_str(), Some("response.create"));
assert_eq!(first["model"].as_str(), Some(MODEL));
assert_eq!(first["stream"], serde_json::Value::Bool(true));
assert_eq!(first["input"].as_array().map(Vec::len), Some(1));
let expected_append = serde_json::json!({
"type": "response.append",
"input": serde_json::to_value(&prompt_two.input[1..]).expect("serialize append items"),
});
assert_eq!(second, expected_append);
server.shutdown().await;
}
/// Verifies that a second request whose input does not start with the first request's
/// input is sent as a full `response.create` message (not `response.append`) on the
/// same websocket connection.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_creates_on_non_prefix() {
    skip_if_no_network!();

    // One connection that answers two requests in sequence.
    let server = start_websocket_server(vec![vec![
        vec![ev_response_created("resp-1"), ev_completed("resp-1")],
        vec![ev_response_created("resp-2"), ev_completed("resp-2")],
    ]])
    .await;

    let harness = websocket_harness(&server).await;
    let mut session = harness.client.new_session();

    // The second prompt shares no prefix with the first one.
    let prompt_one = prompt_with_input(vec![message_item("hello")]);
    let prompt_two = prompt_with_input(vec![message_item("different")]);

    stream_until_complete(&mut session, &prompt_one).await;
    stream_until_complete(&mut session, &prompt_two).await;

    let connection = server.single_connection();
    assert_eq!(connection.len(), 2);

    let second = connection.get(1).expect("missing request").body_json();
    assert_eq!(second["type"].as_str(), Some("response.create"));
    assert_eq!(second["model"].as_str(), Some(MODEL));
    assert_eq!(second["stream"], serde_json::Value::Bool(true));
    // The full input of the second prompt must be resent.
    assert_eq!(
        second["input"],
        serde_json::to_value(&prompt_two.input).expect("serialize input items")
    );

    server.shutdown().await;
}
/// Builds a `ResponseItem::Message` with role `"user"`, no id, and a single
/// `ContentItem::InputText` entry holding `text`.
fn message_item(text: &str) -> ResponseItem {
    let content = vec![ContentItem::InputText { text: text.into() }];
    let role = "user".into();
    ResponseItem::Message {
        id: None,
        role,
        content,
    }
}
/// Returns a default `Prompt` whose `input` is replaced by the given items.
///
/// NOTE(review): the field is assigned after `Prompt::default()` instead of using
/// struct-update syntax; presumably not every `Prompt` field is visible from here —
/// verify before "simplifying".
fn prompt_with_input(input: Vec<ResponseItem>) -> Prompt {
    let mut built = Prompt::default();
    built.input = input;
    built
}
fn websocket_provider(server: &WebSocketTestServer) -> ModelProviderInfo {
ModelProviderInfo {
name: "mock-ws".into(),