diff --git a/codex-rs/codex-api/src/sse/responses.rs b/codex-rs/codex-api/src/sse/responses.rs index f279ba5ed..7f0981c5c 100644 --- a/codex-rs/codex-api/src/sse/responses.rs +++ b/codex-rs/codex-api/src/sse/responses.rs @@ -88,6 +88,14 @@ struct ResponseCompleted { usage: Option, } +#[derive(Debug, Deserialize)] +struct ResponseDone { + #[serde(default)] + id: Option, + #[serde(default)] + usage: Option, +} + #[derive(Debug, Deserialize)] struct ResponseCompletedUsage { input_tokens: i64, @@ -229,6 +237,29 @@ pub fn process_responses_event( } } } + "response.done" => { + if let Some(resp_val) = event.response { + match serde_json::from_value::(resp_val) { + Ok(resp) => { + return Ok(Some(ResponseEvent::Completed { + response_id: resp.id.unwrap_or_default(), + token_usage: resp.usage.map(Into::into), + })); + } + Err(err) => { + let error = format!("failed to parse ResponseDone: {err}"); + debug!("{error}"); + return Err(ResponsesEventError::Api(ApiError::Stream(error))); + } + } + } + + debug!("response.done missing response payload"); + return Ok(Some(ResponseEvent::Completed { + response_id: String::new(), + token_usage: None, + })); + } "response.output_item.added" => { if let Some(item_val) = event.item { if let Ok(item) = serde_json::from_value::(item_val) { @@ -517,6 +548,65 @@ mod tests { } } + #[tokio::test] + async fn response_done_emits_completed() { + let done = json!({ + "type": "response.done", + "response": { + "usage": { + "input_tokens": 1, + "input_tokens_details": null, + "output_tokens": 2, + "output_tokens_details": null, + "total_tokens": 3 + } + } + }) + .to_string(); + + let sse1 = format!("event: response.done\ndata: {done}\n\n"); + + let events = collect_events(&[sse1.as_bytes()]).await; + + assert_eq!(events.len(), 1); + + match &events[0] { + Ok(ResponseEvent::Completed { + response_id, + token_usage, + }) => { + assert_eq!(response_id, ""); + assert!(token_usage.is_some()); + } + other => panic!("unexpected event: {other:?}"), + }
+ } + + #[tokio::test] + async fn response_done_without_payload_emits_completed() { + let done = json!({ + "type": "response.done" + }) + .to_string(); + + let sse1 = format!("event: response.done\ndata: {done}\n\n"); + + let events = collect_events(&[sse1.as_bytes()]).await; + + assert_eq!(events.len(), 1); + + match &events[0] { + Ok(ResponseEvent::Completed { + response_id, + token_usage, + }) => { + assert_eq!(response_id, ""); + assert!(token_usage.is_none()); + } + other => panic!("unexpected event: {other:?}"), + } + } + #[tokio::test] async fn error_when_error_event() { let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_689bcf18d7f08194bf3440ba62fe05d803fee0cdac429894","object":"response","created_at":1755041560,"status":"failed","background":false,"error":{"code":"rate_limit_exceeded","message":"Rate limit reached for gpt-5.1 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."}, "usage":null,"user":null,"metadata":{}}}"#; diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 019e57740..be362e287 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -2543,6 +2543,8 @@ pub(crate) async fn run_turn( // many turns, from the perspective of the user, it is a single turn. let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + let mut client_session = turn_context.client.new_session(); + loop { // Note that pending_input would be something like a message the user // submitted through the UI while the model was running. 
Though the UI @@ -2573,6 +2575,7 @@ pub(crate) async fn run_turn( Arc::clone(&sess), Arc::clone(&turn_context), Arc::clone(&turn_diff_tracker), + &mut client_session, turn_input, cancellation_token.child_token(), ) @@ -2650,6 +2653,7 @@ async fn run_model_turn( sess: Arc, turn_context: Arc, turn_diff_tracker: SharedTurnDiffTracker, + client_session: &mut ModelClientSession, input: Vec, cancellation_token: CancellationToken, ) -> CodexResult { @@ -2684,15 +2688,13 @@ async fn run_model_turn( output_schema: turn_context.final_output_json_schema.clone(), }; - let mut client_session = turn_context.client.new_session(); - let mut retries = 0; loop { let err = match try_run_turn( Arc::clone(&router), Arc::clone(&sess), Arc::clone(&turn_context), - &mut client_session, + client_session, Arc::clone(&turn_diff_tracker), &prompt, cancellation_token.child_token(), diff --git a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs index 552966e79..8f698ed28 100644 --- a/codex-rs/core/tests/common/responses.rs +++ b/codex-rs/core/tests/common/responses.rs @@ -319,6 +319,15 @@ pub fn ev_completed(id: &str) -> Value { }) } +pub fn ev_done() -> Value { + serde_json::json!({ + "type": "response.done", + "response": { + "usage": {"input_tokens":0,"input_tokens_details":null,"output_tokens":0,"output_tokens_details":null,"total_tokens":0} + } + }) +} + /// Convenience: SSE event for a created response with a specific id. 
pub fn ev_response_created(id: &str) -> Value { serde_json::json!({ diff --git a/codex-rs/core/tests/common/test_codex.rs b/codex-rs/core/tests/common/test_codex.rs index 7aaa096c3..6d59cd4df 100644 --- a/codex-rs/core/tests/common/test_codex.rs +++ b/codex-rs/core/tests/common/test_codex.rs @@ -8,6 +8,7 @@ use codex_core::CodexAuth; use codex_core::CodexThread; use codex_core::ModelProviderInfo; use codex_core::ThreadManager; +use codex_core::WireApi; use codex_core::built_in_model_providers; use codex_core::config::Config; use codex_core::features::Feature; @@ -23,6 +24,7 @@ use tempfile::TempDir; use wiremock::MockServer; use crate::load_default_config_for_test; +use crate::responses::WebSocketTestServer; use crate::responses::start_mock_server; use crate::streaming_sse::StreamingSseServer; use crate::wait_for_event; @@ -101,6 +103,21 @@ impl TestCodexBuilder { .await } + pub async fn build_with_websocket_server( + &mut self, + server: &WebSocketTestServer, + ) -> anyhow::Result { + let base_url = format!("{}/v1", server.uri()); + let home = Arc::new(TempDir::new()?); + let base_url_clone = base_url.clone(); + self.config_mutators.push(Box::new(move |config| { + config.model_provider.base_url = Some(base_url_clone); + config.model_provider.wire_api = WireApi::ResponsesWebsocket; + })); + self.build_with_home_and_base_url(base_url, home, None) + .await + } + pub async fn resume( &mut self, server: &wiremock::MockServer, diff --git a/codex-rs/core/tests/suite/agent_websocket.rs b/codex-rs/core/tests/suite/agent_websocket.rs new file mode 100644 index 000000000..940995fb8 --- /dev/null +++ b/codex-rs/core/tests/suite/agent_websocket.rs @@ -0,0 +1,69 @@ +use anyhow::Result; +use core_test_support::responses::ev_assistant_message; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_done; +use core_test_support::responses::ev_response_created; +use core_test_support::responses::ev_shell_command_call; +use 
core_test_support::responses::start_websocket_server; +use core_test_support::skip_if_no_network; +use core_test_support::test_codex::test_codex; +use pretty_assertions::assert_eq; +use serde_json::Value; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn websocket_test_codex_shell_chain() -> Result<()> { + skip_if_no_network!(Ok(())); + + let call_id = "shell-command-call"; + let server = start_websocket_server(vec![vec![ + vec![ + ev_response_created("resp-1"), + ev_shell_command_call(call_id, "echo websocket"), + ev_done(), + ], + vec![ + ev_response_created("resp-2"), + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ], + ]]) + .await; + + let mut builder = test_codex(); + + let test = builder.build_with_websocket_server(&server).await?; + test.submit_turn("run the echo command").await?; + + let connection = server.single_connection(); + assert_eq!(connection.len(), 2); + + let first = connection + .first() + .expect("missing first request") + .body_json(); + let second = connection + .get(1) + .expect("missing second request") + .body_json(); + + assert_eq!(first["type"].as_str(), Some("response.create")); + assert_eq!(second["type"].as_str(), Some("response.append")); + + let append_items = second + .get("input") + .and_then(Value::as_array) + .expect("response.append input array"); + assert!(!append_items.is_empty()); + + let output_item = append_items + .iter() + .find(|item| item.get("type").and_then(Value::as_str) == Some("function_call_output")) + .expect("function_call_output in append"); + assert_eq!( + output_item.get("call_id").and_then(Value::as_str), + Some(call_id) + ); + + server.shutdown().await; + Ok(()) +} diff --git a/codex-rs/core/tests/suite/websocket.rs b/codex-rs/core/tests/suite/client_websockets.rs similarity index 100% rename from codex-rs/core/tests/suite/websocket.rs rename to codex-rs/core/tests/suite/client_websockets.rs diff --git a/codex-rs/core/tests/suite/mod.rs 
b/codex-rs/core/tests/suite/mod.rs index c75cb5407..66fcb5cdb 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -15,12 +15,14 @@ pub static CODEX_ALIASES_TEMP_DIR: TempDir = unsafe { #[cfg(not(target_os = "windows"))] mod abort_tasks; +mod agent_websocket; mod apply_patch_cli; #[cfg(not(target_os = "windows"))] mod approvals; mod auth_refresh; mod cli_stream; mod client; +mod client_websockets; mod codex_delegate; mod compact; mod compact_remote; @@ -72,4 +74,3 @@ mod user_notification; mod user_shell_cmd; mod view_image; mod web_search_cached; -mod websocket;