Websocket append support (#9128)
Support an incremental append request in websocket transport.
This commit is contained in:
parent
ddae70bd62
commit
e726a82c8a
6 changed files with 164 additions and 130 deletions
|
|
@ -136,6 +136,38 @@ pub struct ResponsesApiRequest<'a> {
|
|||
pub text: Option<TextControls>,
|
||||
}
|
||||
|
||||
/// Payload of a `response.create` websocket request: the full request,
/// carrying the model, instructions, complete input, tools and options.
///
/// Wrapped by `ResponsesWsRequest::ResponseCreate`, which adds the
/// `"type": "response.create"` tag on serialization.
#[derive(Debug, Serialize)]
|
||||
pub struct ResponseCreateWsRequest {
|
||||
pub model: String,
|
||||
pub instructions: String,
|
||||
// Complete list of input items for this request (not a delta).
pub input: Vec<ResponseItem>,
|
||||
pub tools: Vec<Value>,
|
||||
pub tool_choice: String,
|
||||
pub parallel_tool_calls: bool,
|
||||
// No `skip_serializing_if` here, so `None` is serialized as `null`
// rather than being omitted.
pub reasoning: Option<Reasoning>,
|
||||
pub store: bool,
|
||||
pub stream: bool,
|
||||
pub include: Vec<String>,
|
||||
// Omitted from the serialized payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_cache_key: Option<String>,
|
||||
// Omitted from the serialized payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub text: Option<TextControls>,
|
||||
}
|
||||
|
||||
/// Payload of a `response.append` websocket request.
///
/// Carries only the input items to append; it has no model, instructions,
/// tools or option fields. Wrapped by `ResponsesWsRequest::ResponseAppend`,
/// which adds the `"type": "response.append"` tag on serialization.
#[derive(Debug, Serialize)]
|
||||
pub struct ResponseAppendWsRequest {
|
||||
// Only the new items, not the full input history.
pub input: Vec<ResponseItem>,
|
||||
}
|
||||
/// A request sent over the Responses websocket transport.
///
/// Serialized as an internally tagged enum: the variant's fields are emitted
/// at the top level of the JSON object alongside a `"type"` field whose value
/// is the variant's `rename` string (`"response.create"` or
/// `"response.append"`).
#[derive(Debug, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
// `ResponseCreate` holds the full request payload while `ResponseAppend`
// holds a single `Vec`, which is the size gap this lint reports.
// NOTE(review): presumably boxing was judged unnecessary here — confirm.
#[allow(clippy::large_enum_variant)]
|
||||
pub enum ResponsesWsRequest {
|
||||
/// Start a new response from a full request payload.
#[serde(rename = "response.create")]
|
||||
ResponseCreate(ResponseCreateWsRequest),
|
||||
/// Send only additional input items.
#[serde(rename = "response.append")]
|
||||
ResponseAppend(ResponseAppendWsRequest),
|
||||
}
|
||||
|
||||
pub fn create_text_param_for_request(
|
||||
verbosity: Option<VerbosityConfig>,
|
||||
output_schema: &Option<Value>,
|
||||
|
|
|
|||
|
|
@ -1,13 +1,9 @@
|
|||
use crate::auth::AuthProvider;
|
||||
use crate::common::Prompt as ApiPrompt;
|
||||
use crate::common::ResponseEvent;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::endpoint::responses::ResponsesOptions;
|
||||
use crate::common::ResponsesWsRequest;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::requests::ResponsesRequest;
|
||||
use crate::requests::ResponsesRequestBuilder;
|
||||
use crate::requests::responses::Compression;
|
||||
use crate::sse::responses::ResponsesStreamEvent;
|
||||
use crate::sse::responses::process_responses_event;
|
||||
use codex_client::TransportError;
|
||||
|
|
@ -28,7 +24,6 @@ use tokio_tungstenite::tungstenite::Message;
|
|||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
use url::Url;
|
||||
|
||||
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
|
|
@ -53,19 +48,15 @@ impl ResponsesWebsocketConnection {
|
|||
|
||||
pub async fn stream_request(
|
||||
&self,
|
||||
request: ResponsesRequest,
|
||||
request: ResponsesWsRequest,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
if request.compression == Compression::Zstd {
|
||||
warn!(
|
||||
"request compression is not supported for websocket streaming; sending uncompressed payload"
|
||||
);
|
||||
}
|
||||
|
||||
let (tx_event, rx_event) =
|
||||
mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
|
||||
let stream = Arc::clone(&self.stream);
|
||||
let idle_timeout = self.idle_timeout;
|
||||
let request_body = request.body;
|
||||
let request_body = serde_json::to_value(&request).map_err(|err| {
|
||||
ApiError::Stream(format!("failed to encode websocket request: {err}"))
|
||||
})?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut guard = stream.lock().await;
|
||||
|
|
@ -123,58 +114,6 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
|
|||
self.provider.stream_idle_timeout,
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn stream_prompt(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt: &ApiPrompt,
|
||||
options: ResponsesOptions,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let ResponsesOptions {
|
||||
reasoning,
|
||||
include,
|
||||
prompt_cache_key,
|
||||
text,
|
||||
store_override,
|
||||
conversation_id,
|
||||
session_source,
|
||||
extra_headers,
|
||||
compression,
|
||||
} = options;
|
||||
|
||||
// TODO (pakrym): share with HTTP based Responses API client
|
||||
let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input)
|
||||
.tools(&prompt.tools)
|
||||
.parallel_tool_calls(prompt.parallel_tool_calls)
|
||||
.reasoning(reasoning)
|
||||
.include(include)
|
||||
.prompt_cache_key(prompt_cache_key)
|
||||
.text(text)
|
||||
.conversation(conversation_id)
|
||||
.session_source(session_source)
|
||||
.store_override(store_override)
|
||||
.extra_headers(extra_headers)
|
||||
.compression(compression)
|
||||
.build(&self.provider)?;
|
||||
|
||||
let connection = self.connect(request.headers.clone()).await?;
|
||||
connection.stream_request(request).await
|
||||
}
|
||||
|
||||
pub async fn stream(
|
||||
&self,
|
||||
body: Value,
|
||||
extra_headers: HeaderMap,
|
||||
compression: Compression,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let request = ResponsesRequest {
|
||||
body,
|
||||
headers: extra_headers,
|
||||
compression,
|
||||
};
|
||||
let connection = self.connect(request.headers.clone()).await?;
|
||||
connection.stream_request(request).await
|
||||
}
|
||||
}
|
||||
|
||||
// TODO (pakrym): share with /auth
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ pub mod requests;
|
|||
pub mod sse;
|
||||
pub mod telemetry;
|
||||
|
||||
pub use crate::requests::headers::build_conversation_headers;
|
||||
pub use codex_client::RequestTelemetry;
|
||||
pub use codex_client::ReqwestTransport;
|
||||
pub use codex_client::TransportError;
|
||||
|
|
@ -15,6 +16,8 @@ pub use codex_client::TransportError;
|
|||
pub use crate::auth::AuthProvider;
|
||||
pub use crate::common::CompactionInput;
|
||||
pub use crate::common::Prompt;
|
||||
pub use crate::common::ResponseAppendWsRequest;
|
||||
pub use crate::common::ResponseCreateWsRequest;
|
||||
pub use crate::common::ResponseEvent;
|
||||
pub use crate::common::ResponseStream;
|
||||
pub use crate::common::ResponsesApiRequest;
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use codex_protocol::protocol::SessionSource;
|
|||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
|
||||
pub(crate) fn build_conversation_headers(conversation_id: Option<String>) -> HeaderMap {
|
||||
pub fn build_conversation_headers(conversation_id: Option<String>) -> HeaderMap {
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Some(id) = conversation_id {
|
||||
insert_header(&mut headers, "session_id", &id);
|
||||
|
|
|
|||
|
|
@ -11,16 +11,18 @@ use codex_api::CompactionInput as ApiCompactionInput;
|
|||
use codex_api::Prompt as ApiPrompt;
|
||||
use codex_api::RequestTelemetry;
|
||||
use codex_api::ReqwestTransport;
|
||||
use codex_api::ResponseAppendWsRequest;
|
||||
use codex_api::ResponseCreateWsRequest;
|
||||
use codex_api::ResponseStream as ApiResponseStream;
|
||||
use codex_api::ResponsesClient as ApiResponsesClient;
|
||||
use codex_api::ResponsesOptions as ApiResponsesOptions;
|
||||
use codex_api::ResponsesRequest;
|
||||
use codex_api::ResponsesRequestBuilder;
|
||||
use codex_api::ResponsesWebsocketClient as ApiWebSocketResponsesClient;
|
||||
use codex_api::ResponsesWebsocketConnection as ApiWebSocketConnection;
|
||||
use codex_api::SseTelemetry;
|
||||
use codex_api::TransportError;
|
||||
use codex_api::build_conversation_headers;
|
||||
use codex_api::common::Reasoning;
|
||||
use codex_api::common::ResponsesWsRequest;
|
||||
use codex_api::create_text_param_for_request;
|
||||
use codex_api::error::ApiError;
|
||||
use codex_api::requests::responses::Compression;
|
||||
|
|
@ -83,6 +85,7 @@ pub struct ModelClient {
|
|||
pub struct ModelClientSession {
|
||||
state: Arc<ModelClientState>,
|
||||
// Websocket connection held by this session, if any; `None` until one is
// stored here.
connection: Option<ApiWebSocketConnection>,
|
||||
// Input items of the most recent request handed to the websocket
// connection. Starts empty; compared against the next request's input to
// choose between `response.append` and `response.create`.
websocket_last_items: Vec<ResponseItem>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
|
|
@ -117,6 +120,7 @@ impl ModelClient {
|
|||
ModelClientSession {
|
||||
state: Arc::clone(&self.state),
|
||||
connection: None,
|
||||
websocket_last_items: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -320,49 +324,65 @@ impl ModelClientSession {
|
|||
}
|
||||
}
|
||||
|
||||
fn build_responses_websocket_request(
|
||||
/// Returns the items to send in a `response.append` request, or `None` when
/// a full `response.create` request is required instead.
///
/// An append is possible only when all of the following hold:
/// - items from a previous request are recorded (`websocket_last_items` is
///   non-empty),
/// - `input_items` begins with exactly those recorded items, and
/// - `input_items` is strictly longer than the recorded list.
///
/// The returned vector is a clone of the tail of `input_items` that follows
/// the recorded prefix.
fn get_incremental_items(&self, input_items: &[ResponseItem]) -> Option<Vec<ResponseItem>> {
|
||||
// Decide whether the current input is an incremental extension of the previous request:
|
||||
// the new items must start with every item of the previous request, in order, and add at
|
||||
// least one more. If so, only the tail is sent; otherwise the caller starts a fresh request.
|
||||
// NOTE(review): this looks only at the recorded items, not at which connection received
// them — confirm `websocket_last_items` is cleared when a new connection is opened.
let previous_len = self.websocket_last_items.len();
|
||||
let can_append = previous_len > 0
|
||||
&& input_items.starts_with(&self.websocket_last_items)
|
||||
&& previous_len < input_items.len();
|
||||
if can_append {
|
||||
// In bounds: `can_append` guarantees `previous_len < input_items.len()`.
Some(input_items[previous_len..].to_vec())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn prepare_websocket_request(
|
||||
&self,
|
||||
api_provider: &codex_api::Provider,
|
||||
api_prompt: &ApiPrompt,
|
||||
options: ApiResponsesOptions,
|
||||
) -> Result<ResponsesRequest> {
|
||||
options: &ApiResponsesOptions,
|
||||
) -> ResponsesWsRequest {
|
||||
if let Some(append_items) = self.get_incremental_items(&api_prompt.input) {
|
||||
return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest {
|
||||
input: append_items,
|
||||
});
|
||||
}
|
||||
|
||||
let ApiResponsesOptions {
|
||||
reasoning,
|
||||
include,
|
||||
prompt_cache_key,
|
||||
text,
|
||||
store_override,
|
||||
conversation_id,
|
||||
session_source,
|
||||
extra_headers,
|
||||
compression,
|
||||
..
|
||||
} = options;
|
||||
|
||||
ResponsesRequestBuilder::new(
|
||||
&self.state.model_info.slug,
|
||||
&api_prompt.instructions,
|
||||
&api_prompt.input,
|
||||
)
|
||||
.tools(&api_prompt.tools)
|
||||
.parallel_tool_calls(api_prompt.parallel_tool_calls)
|
||||
.reasoning(reasoning)
|
||||
.include(include)
|
||||
.prompt_cache_key(prompt_cache_key)
|
||||
.text(text)
|
||||
.conversation(conversation_id)
|
||||
.session_source(session_source)
|
||||
.store_override(store_override)
|
||||
.extra_headers(extra_headers)
|
||||
.compression(compression)
|
||||
.build(api_provider)
|
||||
.map_err(map_api_error)
|
||||
let store = store_override.unwrap_or(false);
|
||||
let payload = ResponseCreateWsRequest {
|
||||
model: self.state.model_info.slug.clone(),
|
||||
instructions: api_prompt.instructions.clone(),
|
||||
input: api_prompt.input.clone(),
|
||||
tools: api_prompt.tools.clone(),
|
||||
tool_choice: "auto".to_string(),
|
||||
parallel_tool_calls: api_prompt.parallel_tool_calls,
|
||||
reasoning: reasoning.clone(),
|
||||
store,
|
||||
stream: true,
|
||||
include: include.clone(),
|
||||
prompt_cache_key: prompt_cache_key.clone(),
|
||||
text: text.clone(),
|
||||
};
|
||||
|
||||
ResponsesWsRequest::ResponseCreate(payload)
|
||||
}
|
||||
|
||||
async fn websocket_connection(
|
||||
&mut self,
|
||||
api_provider: codex_api::Provider,
|
||||
api_auth: CoreAuthProvider,
|
||||
headers: ApiHeaderMap,
|
||||
options: &ApiResponsesOptions,
|
||||
) -> std::result::Result<&ApiWebSocketConnection, ApiError> {
|
||||
let needs_new = match self.connection.as_ref() {
|
||||
Some(conn) => conn.is_closed().await,
|
||||
|
|
@ -370,9 +390,12 @@ impl ModelClientSession {
|
|||
};
|
||||
|
||||
if needs_new {
|
||||
let new_conn = ApiWebSocketResponsesClient::new(api_provider, api_auth)
|
||||
.connect(headers)
|
||||
.await?;
|
||||
let mut headers = options.extra_headers.clone();
|
||||
headers.extend(build_conversation_headers(options.conversation_id.clone()));
|
||||
let new_conn: ApiWebSocketConnection =
|
||||
ApiWebSocketResponsesClient::new(api_provider, api_auth)
|
||||
.connect(headers)
|
||||
.await?;
|
||||
self.connection = Some(new_conn);
|
||||
}
|
||||
|
||||
|
|
@ -533,15 +556,10 @@ impl ModelClientSession {
|
|||
let compression = self.responses_request_compression(auth.as_ref());
|
||||
|
||||
let options = self.build_responses_options(prompt, compression);
|
||||
let request =
|
||||
self.build_responses_websocket_request(&api_provider, &api_prompt, options)?;
|
||||
let request = self.prepare_websocket_request(&api_prompt, &options);
|
||||
|
||||
let connection = match self
|
||||
.websocket_connection(
|
||||
api_provider.clone(),
|
||||
api_auth.clone(),
|
||||
request.headers.clone(),
|
||||
)
|
||||
.websocket_connection(api_provider.clone(), api_auth.clone(), &options)
|
||||
.await
|
||||
{
|
||||
Ok(connection) => connection,
|
||||
|
|
@ -558,6 +576,7 @@ impl ModelClientSession {
|
|||
.stream_request(request)
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
self.websocket_last_items = api_prompt.input.clone();
|
||||
|
||||
return Ok(map_response_stream(
|
||||
stream_result,
|
||||
|
|
|
|||
|
|
@ -44,14 +44,7 @@ async fn responses_websocket_streams_request() {
|
|||
|
||||
let harness = websocket_harness(&server).await;
|
||||
let mut session = harness.client.new_session();
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
}];
|
||||
let prompt = prompt_with_input(vec![message_item("hello")]);
|
||||
|
||||
stream_until_complete(&mut session, &prompt).await;
|
||||
|
||||
|
|
@ -59,6 +52,7 @@ async fn responses_websocket_streams_request() {
|
|||
assert_eq!(connection.len(), 1);
|
||||
let body = connection.first().expect("missing request").body_json();
|
||||
|
||||
assert_eq!(body["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(body["model"].as_str(), Some(MODEL));
|
||||
assert_eq!(body["stream"], serde_json::Value::Bool(true));
|
||||
assert_eq!(body["input"].as_array().map(Vec::len), Some(1));
|
||||
|
|
@ -67,7 +61,7 @@ async fn responses_websocket_streams_request() {
|
|||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_reuses_connection() {
|
||||
async fn responses_websocket_appends_on_prefix() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_websocket_server(vec![vec![
|
||||
|
|
@ -78,30 +72,77 @@ async fn responses_websocket_reuses_connection() {
|
|||
|
||||
let harness = websocket_harness(&server).await;
|
||||
let mut session = harness.client.new_session();
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
}];
|
||||
let prompt_one = prompt_with_input(vec![message_item("hello")]);
|
||||
let prompt_two = prompt_with_input(vec![message_item("hello"), message_item("second")]);
|
||||
|
||||
for _ in 0..2 {
|
||||
stream_until_complete(&mut session, &prompt).await;
|
||||
}
|
||||
stream_until_complete(&mut session, &prompt_one).await;
|
||||
stream_until_complete(&mut session, &prompt_two).await;
|
||||
|
||||
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 2);
|
||||
let body = connection.first().expect("missing request").body_json();
|
||||
let first = connection.first().expect("missing request").body_json();
|
||||
let second = connection.get(1).expect("missing request").body_json();
|
||||
|
||||
assert_eq!(body["model"].as_str(), Some(MODEL));
|
||||
assert_eq!(body["stream"], serde_json::Value::Bool(true));
|
||||
assert_eq!(body["input"].as_array().map(Vec::len), Some(1));
|
||||
assert_eq!(first["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(first["model"].as_str(), Some(MODEL));
|
||||
assert_eq!(first["stream"], serde_json::Value::Bool(true));
|
||||
assert_eq!(first["input"].as_array().map(Vec::len), Some(1));
|
||||
let expected_append = serde_json::json!({
|
||||
"type": "response.append",
|
||||
"input": serde_json::to_value(&prompt_two.input[1..]).expect("serialize append items"),
|
||||
});
|
||||
assert_eq!(second, expected_append);
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
/// Verifies that when the second prompt's input does not start with the
/// first prompt's input, the second websocket request is a full
/// `response.create` carrying the complete input rather than a
/// `response.append`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_creates_on_non_prefix() {
|
||||
skip_if_no_network!();
|
||||
|
||||
// One scripted connection with two responses, one per streamed prompt.
let server = start_websocket_server(vec![vec![
|
||||
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
|
||||
vec![ev_response_created("resp-2"), ev_completed("resp-2")],
|
||||
]])
|
||||
.await;
|
||||
|
||||
let harness = websocket_harness(&server).await;
|
||||
let mut session = harness.client.new_session();
|
||||
let prompt_one = prompt_with_input(vec![message_item("hello")]);
|
||||
// Different first item, so `prompt_one`'s input is not a prefix of this one.
let prompt_two = prompt_with_input(vec![message_item("different")]);
|
||||
|
||||
stream_until_complete(&mut session, &prompt_one).await;
|
||||
stream_until_complete(&mut session, &prompt_two).await;
|
||||
|
||||
// Presumably `single_connection` yields the requests seen on the only connection — TODO confirm.
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 2);
|
||||
let second = connection.get(1).expect("missing request").body_json();
|
||||
|
||||
assert_eq!(second["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(second["model"].as_str(), Some(MODEL));
|
||||
assert_eq!(second["stream"], serde_json::Value::Bool(true));
|
||||
// The full input of the second prompt is sent, not a delta.
assert_eq!(
|
||||
second["input"],
|
||||
serde_json::to_value(&prompt_two.input).unwrap()
|
||||
);
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
/// Test helper: builds a `ResponseItem::Message` with role `"user"`, no id,
/// and a single `InputText` content item holding `text`.
fn message_item(text: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText { text: text.into() }],
|
||||
}
|
||||
}
|
||||
|
||||
/// Test helper: returns a default `Prompt` whose `input` is replaced by the
/// given items; every other field keeps its `Default` value.
fn prompt_with_input(input: Vec<ResponseItem>) -> Prompt {
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = input;
|
||||
prompt
|
||||
}
|
||||
|
||||
fn websocket_provider(server: &WebSocketTestServer) -> ModelProviderInfo {
|
||||
ModelProviderInfo {
|
||||
name: "mock-ws".into(),
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue