diff --git a/codex-rs/codex-api/src/endpoint/aggregate.rs b/codex-rs/codex-api/src/endpoint/aggregate.rs deleted file mode 100644 index 6d8e785c8..000000000 --- a/codex-rs/codex-api/src/endpoint/aggregate.rs +++ /dev/null @@ -1,163 +0,0 @@ -use crate::common::ResponseEvent; -use crate::common::ResponseStream; -use crate::error::ApiError; -use codex_protocol::models::ContentItem; -use codex_protocol::models::ReasoningItemContent; -use codex_protocol::models::ResponseItem; -use futures::Stream; -use std::collections::VecDeque; -use std::pin::Pin; -use std::task::Context; -use std::task::Poll; - -/// Stream adapter that merges token deltas into a single assistant message per turn. -pub struct AggregatedStream { - inner: ResponseStream, - cumulative: String, - cumulative_reasoning: String, - pending: VecDeque, -} - -impl Stream for AggregatedStream { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - if let Some(ev) = this.pending.pop_front() { - return Poll::Ready(Some(Ok(ev))); - } - - loop { - match Pin::new(&mut this.inner).poll_next(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(None) => return Poll::Ready(None), - Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))), - Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => { - let is_assistant_message = matches!( - &item, - ResponseItem::Message { role, .. } if role == "assistant" - ); - - if is_assistant_message { - if this.cumulative.is_empty() - && let ResponseItem::Message { content, .. } = &item - && let Some(text) = content.iter().find_map(|c| match c { - ContentItem::OutputText { text } => Some(text), - _ => None, - }) - { - this.cumulative.push_str(text); - } - continue; - } - - return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))); - } - Poll::Ready(Some(Ok(ResponseEvent::ServerReasoningIncluded(included)))) => { - return Poll::Ready(Some(Ok(ResponseEvent::ServerReasoningIncluded(included)))); - } - Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => { - return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))); - } - Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag)))) => { - return Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag)))); - } - Poll::Ready(Some(Ok(ResponseEvent::ServerModel(model)))) => { - return Poll::Ready(Some(Ok(ResponseEvent::ServerModel(model)))); - } - Poll::Ready(Some(Ok(ResponseEvent::Completed { - response_id, - token_usage, - can_append: _can_append, - }))) => { - let mut emitted_any = false; - - if !this.cumulative_reasoning.is_empty() { - let aggregated_reasoning = ResponseItem::Reasoning { - id: String::new(), - summary: Vec::new(), - content: Some(vec![ReasoningItemContent::ReasoningText { - text: std::mem::take(&mut this.cumulative_reasoning), - }]), - encrypted_content: None, - }; - this.pending - .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning)); - emitted_any = true; - } - - if !this.cumulative.is_empty() { - let aggregated_message = ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: std::mem::take(&mut this.cumulative), - }], - end_turn: None, - phase: None, - }; - this.pending - .push_back(ResponseEvent::OutputItemDone(aggregated_message)); - emitted_any = true; - } - - if emitted_any { - this.pending.push_back(ResponseEvent::Completed { - response_id: response_id.clone(), - token_usage: token_usage.clone(), - can_append: false, - }); - if let Some(ev) = this.pending.pop_front() { - return Poll::Ready(Some(Ok(ev))); - } - } - - return Poll::Ready(Some(Ok(ResponseEvent::Completed { - response_id, - token_usage, - can_append: false, - }))); - } - Poll::Ready(Some(Ok(ResponseEvent::Created))) => continue, - Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => { - this.cumulative.push_str(&delta); - continue; - } - Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta { - delta, - content_index: _, - }))) => { - this.cumulative_reasoning.push_str(&delta); - continue; - } - Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta { .. }))) => continue, - Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded { .. }))) => continue, - Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => { - return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))); - } - } - } - } -} - -pub trait AggregateStreamExt { - fn aggregate(self) -> AggregatedStream; -} - -impl AggregateStreamExt for ResponseStream { - fn aggregate(self) -> AggregatedStream { - AggregatedStream::new(self) - } -} - -impl AggregatedStream { - fn new(inner: ResponseStream) -> Self { - AggregatedStream { - inner, - cumulative: String::new(), - cumulative_reasoning: String::new(), - pending: VecDeque::new(), - } - } -} diff --git a/codex-rs/codex-api/src/endpoint/mod.rs b/codex-rs/codex-api/src/endpoint/mod.rs index 981643904..6a748e533 100644 --- a/codex-rs/codex-api/src/endpoint/mod.rs +++ b/codex-rs/codex-api/src/endpoint/mod.rs @@ -1,4 +1,3 @@ -pub mod aggregate; pub mod compact; pub mod memories; pub mod models; diff --git a/codex-rs/codex-api/src/lib.rs b/codex-rs/codex-api/src/lib.rs index 23c08d7e2..b9152ec57 100644 --- a/codex-rs/codex-api/src/lib.rs +++ b/codex-rs/codex-api/src/lib.rs @@ -25,7 +25,6 @@ pub use crate::common::ResponseEvent; pub use crate::common::ResponseStream; pub use crate::common::ResponsesApiRequest; pub use crate::common::create_text_param_for_request; -pub use crate::endpoint::aggregate::AggregateStreamExt; pub use crate::endpoint::compact::CompactClient; pub use crate::endpoint::memories::MemoriesClient; pub use crate::endpoint::models::ModelsClient; diff --git a/codex-rs/codex-api/tests/sse_end_to_end.rs b/codex-rs/codex-api/tests/sse_end_to_end.rs index 9f2903378..ca840d1b9 100644 --- a/codex-rs/codex-api/tests/sse_end_to_end.rs +++ b/codex-rs/codex-api/tests/sse_end_to_end.rs @@ -3,7 +3,6 @@ use std::time::Duration; use anyhow::Result; use async_trait::async_trait; use bytes::Bytes; -use codex_api::AggregateStreamExt; use codex_api::AuthProvider; use codex_api::Provider; use codex_api::ResponseEvent; @@ -14,7 +13,6 @@ use codex_client::Request; use codex_client::Response; use codex_client::StreamResponse; use codex_client::TransportError; -use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; use futures::StreamExt; use http::HeaderMap; @@ -172,69 +170,3 @@ async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()> Ok(()) } - -#[tokio::test] -async fn responses_stream_aggregates_output_text_deltas() -> Result<()> { - let delta1 = serde_json::json!({ - "type": "response.output_text.delta", - "delta": "Hello, " - }); - - let delta2 = serde_json::json!({ - "type": "response.output_text.delta", - "delta": "world" - }); - - let completed = serde_json::json!({ - "type": "response.completed", - "response": { "id": "resp-agg" } - }); - - let body = build_responses_body(vec![delta1, delta2, completed]); - let transport = FixtureSseTransport::new(body); - let client = ResponsesClient::new(transport, provider("openai"), NoAuth); - - let stream = client - .stream( - serde_json::json!({"echo": true}), - HeaderMap::new(), - Compression::None, - None, - ) - .await?; - - let mut stream = stream.aggregate(); - let mut events = Vec::new(); - while let Some(ev) = stream.next().await { - events.push(ev?); - } - - let events: Vec = events - .into_iter() - .filter(|ev| !matches!(ev, ResponseEvent::RateLimits(_))) - .collect(); - - assert_eq!(events.len(), 2); - - match &events[0] { - ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => { - let mut aggregated = String::new(); - for item in content { - if let ContentItem::OutputText { text } = item { - aggregated.push_str(text); - } - } - assert_eq!(aggregated, "Hello, world"); - } - other => panic!("unexpected first event: {other:?}"), - } - - match &events[1] { - ResponseEvent::Completed { response_id, .. } => { - assert_eq!(response_id, "resp-agg"); - } - other => panic!("unexpected second event: {other:?}"), - } - - Ok(()) -}