diff --git a/codex-rs/codex-api/src/sse/responses.rs b/codex-rs/codex-api/src/sse/responses.rs index 2ec8271c2..5a1ab832e 100644 --- a/codex-rs/codex-api/src/sse/responses.rs +++ b/codex-rs/codex-api/src/sse/responses.rs @@ -126,7 +126,7 @@ struct ResponseCompletedOutputTokensDetails { } #[derive(Deserialize, Debug)] -struct SseEvent { +struct ResponsesStreamEvent { #[serde(rename = "type")] kind: String, response: Option, @@ -136,6 +136,122 @@ struct SseEvent { content_index: Option, } +#[derive(Debug)] +pub enum ResponsesEventError { + Api(ApiError), +} + +impl ResponsesEventError { + pub fn into_api_error(self) -> ApiError { + match self { + Self::Api(error) => error, + } + } +} + +fn process_responses_event( + event: ResponsesStreamEvent, +) -> std::result::Result, ResponsesEventError> { + match event.kind.as_str() { + "response.output_item.done" => { + if let Some(item_val) = event.item { + if let Ok(item) = serde_json::from_value::(item_val) { + return Ok(Some(ResponseEvent::OutputItemDone(item))); + } + debug!("failed to parse ResponseItem from output_item.done"); + } + } + "response.output_text.delta" => { + if let Some(delta) = event.delta { + return Ok(Some(ResponseEvent::OutputTextDelta(delta))); + } + } + "response.reasoning_summary_text.delta" => { + if let (Some(delta), Some(summary_index)) = (event.delta, event.summary_index) { + return Ok(Some(ResponseEvent::ReasoningSummaryDelta { + delta, + summary_index, + })); + } + } + "response.reasoning_text.delta" => { + if let (Some(delta), Some(content_index)) = (event.delta, event.content_index) { + return Ok(Some(ResponseEvent::ReasoningContentDelta { + delta, + content_index, + })); + } + } + "response.created" => { + if event.response.is_some() { + return Ok(Some(ResponseEvent::Created {})); + } + } + "response.failed" => { + if let Some(resp_val) = event.response { + let mut response_error = ApiError::Stream("response.failed event received".into()); + if let Some(error) = resp_val.get("error") + && let Ok(error) = serde_json::from_value::(error.clone()) + { + if is_context_window_error(&error) { + response_error = ApiError::ContextWindowExceeded; + } else if is_quota_exceeded_error(&error) { + response_error = ApiError::QuotaExceeded; + } else if is_usage_not_included(&error) { + response_error = ApiError::UsageNotIncluded; + } else { + let delay = try_parse_retry_after(&error); + let message = error.message.unwrap_or_default(); + response_error = ApiError::Retryable { message, delay }; + } + } + return Err(ResponsesEventError::Api(response_error)); + } + + return Err(ResponsesEventError::Api(ApiError::Stream( + "response.failed event received".into(), + ))); + } + "response.completed" => { + if let Some(resp_val) = event.response { + match serde_json::from_value::(resp_val) { + Ok(resp) => { + return Ok(Some(ResponseEvent::Completed { + response_id: resp.id, + token_usage: resp.usage.map(Into::into), + })); + } + Err(err) => { + let error = format!("failed to parse ResponseCompleted: {err}"); + debug!("{error}"); + return Err(ResponsesEventError::Api(ApiError::Stream(error))); + } + } + } + } + "response.output_item.added" => { + if let Some(item_val) = event.item { + if let Ok(item) = serde_json::from_value::(item_val) { + return Ok(Some(ResponseEvent::OutputItemAdded(item))); + } + debug!("failed to parse ResponseItem from output_item.done"); + } + } + "response.reasoning_summary_part.added" => { + if let Some(summary_index) = event.summary_index { + return Ok(Some(ResponseEvent::ReasoningSummaryPartAdded { + summary_index, + })); + } + } + _ => { + trace!("unhandled responses event: {}", event.kind); + } + } + + Ok(None) +} + pub async fn process_sse( stream: ByteStream, tx_event: mpsc::Sender>, @@ -143,7 +259,7 @@ pub async fn process_sse( telemetry: Option>, ) { let mut stream = stream.eventsource(); - let mut response_completed: Option = None; + let mut response_completed: Option = None; let mut response_error: Option = None; loop { @@ -161,11 +277,7 @@ pub async fn process_sse( } Ok(None) => { match response_completed.take() { - Some(ResponseCompleted { id, usage }) => { - let event = ResponseEvent::Completed { - response_id: id, - token_usage: usage.map(Into::into), - }; + Some(event) => { let _ = tx_event.send(Ok(event)).await; } None => { @@ -188,7 +300,7 @@ pub async fn process_sse( let raw = sse.data.clone(); trace!("SSE event: {raw}"); - let event: SseEvent = match serde_json::from_str(&sse.data) { + let event: ResponsesStreamEvent = match serde_json::from_str(&sse.data) { Ok(event) => event, Err(e) => { debug!("Failed to parse SSE event: {e}, data: {}", &sse.data); @@ -196,115 +308,19 @@ pub async fn process_sse( } }; - match event.kind.as_str() { - "response.output_item.done" => { - let Some(item_val) = event.item else { continue }; - let Ok(item) = serde_json::from_value::(item_val) else { - debug!("failed to parse ResponseItem from output_item.done"); - continue; - }; - - let event = ResponseEvent::OutputItemDone(item); - if tx_event.send(Ok(event)).await.is_err() { + match process_responses_event(event) { + Ok(Some(event)) => { + if matches!(event, ResponseEvent::Completed { .. }) { + response_completed = Some(event); + } else if tx_event.send(Ok(event)).await.is_err() { return; } } - "response.output_text.delta" => { - if let Some(delta) = event.delta { - let event = ResponseEvent::OutputTextDelta(delta); - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } + Ok(None) => {} + Err(error) => { + response_error = Some(error.into_api_error()); } - "response.reasoning_summary_text.delta" => { - if let (Some(delta), Some(summary_index)) = (event.delta, event.summary_index) { - let event = ResponseEvent::ReasoningSummaryDelta { - delta, - summary_index, - }; - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - } - "response.reasoning_text.delta" => { - if let (Some(delta), Some(content_index)) = (event.delta, event.content_index) { - let event = ResponseEvent::ReasoningContentDelta { - delta, - content_index, - }; - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - } - "response.created" => { - if event.response.is_some() { - let _ = tx_event.send(Ok(ResponseEvent::Created {})).await; - } - } - "response.failed" => { - if let Some(resp_val) = event.response { - response_error = - Some(ApiError::Stream("response.failed event received".into())); - - if let Some(error) = resp_val.get("error") - && let Ok(error) = serde_json::from_value::(error.clone()) - { - if is_context_window_error(&error) { - response_error = Some(ApiError::ContextWindowExceeded); - } else if is_quota_exceeded_error(&error) { - response_error = Some(ApiError::QuotaExceeded); - } else if is_usage_not_included(&error) { - response_error = Some(ApiError::UsageNotIncluded); - } else { - let delay = try_parse_retry_after(&error); - let message = error.message.clone().unwrap_or_default(); - response_error = Some(ApiError::Retryable { message, delay }); - } - } - } - } - "response.completed" => { - if let Some(resp_val) = event.response { - match serde_json::from_value::(resp_val) { - Ok(r) => { - response_completed = Some(r); - } - Err(e) => { - let error = format!("failed to parse ResponseCompleted: {e}"); - debug!(error); - response_error = Some(ApiError::Stream(error)); - continue; - } - }; - }; - } - "response.output_item.added" => { - let Some(item_val) = event.item else { continue }; - let Ok(item) = serde_json::from_value::(item_val) else { - debug!("failed to parse ResponseItem from output_item.done"); - continue; - }; - - let event = ResponseEvent::OutputItemAdded(item); - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - "response.reasoning_summary_part.added" => { - if let Some(summary_index) = event.summary_index { - let event = ResponseEvent::ReasoningSummaryPartAdded { summary_index }; - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - } - _ => { - trace!("unhandled SSE event: {:#?}", event.kind); - } - } + }; } }