diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 832592d40..45dd4f2b6 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1518,6 +1518,21 @@ impl Session { format!("auto-compact-{id}") } + pub(crate) async fn route_realtime_text_input(self: &Arc, text: String) { + handlers::user_input_or_turn( + self, + self.next_internal_sub_id(), + Op::UserInput { + items: vec![UserInput::Text { + text, + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }, + ) + .await; + } + pub(crate) async fn get_total_token_usage(&self) -> i64 { let state = self.state.lock().await; state.get_total_token_usage(state.server_reasoning_included()) diff --git a/codex-rs/core/src/realtime_conversation.rs b/codex-rs/core/src/realtime_conversation.rs index b67874593..b1540ca94 100644 --- a/codex-rs/core/src/realtime_conversation.rs +++ b/codex-rs/core/src/realtime_conversation.rs @@ -25,6 +25,7 @@ use codex_protocol::protocol::RealtimeConversationClosedEvent; use codex_protocol::protocol::RealtimeConversationRealtimeEvent; use codex_protocol::protocol::RealtimeConversationStartedEvent; use http::HeaderMap; +use serde_json::Value; use std::sync::Arc; use tokio::sync::Mutex; use tokio::task::JoinHandle; @@ -209,11 +210,22 @@ pub(crate) async fn handle_start( msg, }; while let Ok(event) = events_rx.recv().await { + let maybe_routed_text = match &event { + RealtimeEvent::ConversationItemAdded(item) => { + realtime_text_from_conversation_item(item) + } + _ => None, + }; sess_clone .send_event_raw(ev(EventMsg::RealtimeConversationRealtime( - RealtimeConversationRealtimeEvent { payload: event }, + RealtimeConversationRealtimeEvent { + payload: event.clone(), + }, ))) .await; + if let Some(text) = maybe_routed_text { + sess_clone.route_realtime_text_input(text).await; + } } if let Some(()) = sess_clone.conversation.running_state().await { sess_clone @@ -239,6 +251,19 @@ pub(crate) async fn handle_audio( } } +fn realtime_text_from_conversation_item(item: &Value) -> Option { + if item.get("type").and_then(Value::as_str) != Some("message") { + return None; + } + let content = item.get("content")?.as_array()?; + let text = content + .iter() + .filter(|entry| entry.get("type").and_then(Value::as_str) == Some("text")) + .filter_map(|entry| entry.get("text").and_then(Value::as_str)) + .collect::(); + if text.is_empty() { None } else { Some(text) } +} + pub(crate) async fn handle_text( sess: &Arc, sub_id: String, @@ -355,3 +380,64 @@ async fn send_conversation_error( }) .await; } + +#[cfg(test)] +mod tests { + use super::realtime_text_from_conversation_item; + use pretty_assertions::assert_eq; + use serde_json::json; + + #[test] + fn extracts_text_from_message_items_ignoring_role() { + let assistant = json!({ + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "hello"}], + }); + assert_eq!( + realtime_text_from_conversation_item(&assistant), + Some("hello".to_string()) + ); + + let user = json!({ + "type": "message", + "role": "user", + "content": [{"type": "text", "text": "world"}], + }); + assert_eq!( + realtime_text_from_conversation_item(&user), + Some("world".to_string()) + ); + } + + #[test] + fn extracts_and_concatenates_text_entries_only() { + let item = json!({ + "type": "message", + "content": [ + {"type": "text", "text": "a"}, + {"type": "ignored", "text": "x"}, + {"type": "text", "text": "b"} + ], + }); + assert_eq!( + realtime_text_from_conversation_item(&item), + Some("ab".to_string()) + ); + } + + #[test] + fn ignores_non_message_or_missing_text() { + let non_message = json!({ + "type": "tool_call", + "content": [{"type": "text", "text": "nope"}], + }); + assert_eq!(realtime_text_from_conversation_item(&non_message), None); + + let no_text = json!({ + "type": "message", + "content": [{"type": "other", "value": 1}], + }); + assert_eq!(realtime_text_from_conversation_item(&no_text), None); + } +} diff --git a/codex-rs/core/tests/suite/realtime_conversation.rs b/codex-rs/core/tests/suite/realtime_conversation.rs index 466686788..ee4f139f6 100644 --- a/codex-rs/core/tests/suite/realtime_conversation.rs +++ b/codex-rs/core/tests/suite/realtime_conversation.rs @@ -9,13 +9,21 @@ use codex_protocol::protocol::Op; use codex_protocol::protocol::RealtimeAudioFrame; use codex_protocol::protocol::RealtimeConversationRealtimeEvent; use codex_protocol::protocol::RealtimeEvent; +use codex_protocol::user_input::UserInput; +use core_test_support::responses; +use core_test_support::responses::start_mock_server; use core_test_support::responses::start_websocket_server; use core_test_support::skip_if_no_network; +use core_test_support::streaming_sse::StreamingSseChunk; +use core_test_support::streaming_sse::start_streaming_sse_server; use core_test_support::test_codex::test_codex; +use core_test_support::wait_for_event; use core_test_support::wait_for_event_match; use pretty_assertions::assert_eq; +use serde_json::Value; use serde_json::json; use std::time::Duration; +use tokio::sync::oneshot; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn conversation_start_audio_text_close_round_trip() -> Result<()> { @@ -29,22 +37,12 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> { "session": { "id": "sess_1" } })], vec![], - vec![ - json!({ - "type": "response.output_audio.delta", - "delta": "AQID", - "sample_rate": 24000, - "num_channels": 1 - }), - json!({ - "type": "conversation.item.added", - "item": { - "type": "message", - "role": "assistant", - "content": [{"type": "text", "text": "hi"}] - } - }), - ], + vec![json!({ + "type": "response.output_audio.delta", + "delta": "AQID", + "sample_rate": 24000, + "num_channels": 1 + })], ], ]) .await; @@ -458,3 +456,237 @@ async fn conversation_uses_experimental_realtime_ws_backend_prompt_override() -> server.shutdown().await; Ok(()) } + +fn sse_event(event: Value) -> String { + responses::sse(vec![event]) +} + +fn message_input_texts(body: &Value, role: &str) -> Vec { + body.get("input") + .and_then(Value::as_array) + .into_iter() + .flatten() + .filter(|item| item.get("type").and_then(Value::as_str) == Some("message")) + .filter(|item| item.get("role").and_then(Value::as_str) == Some(role)) + .filter_map(|item| item.get("content").and_then(Value::as_array)) + .flatten() + .filter(|span| span.get("type").and_then(Value::as_str) == Some("input_text")) + .filter_map(|span| span.get("text").and_then(Value::as_str).map(str::to_owned)) + .collect() +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn inbound_realtime_text_starts_turn_and_ignores_role() -> Result<()> { + skip_if_no_network!(Ok(())); + + let api_server = start_mock_server().await; + let response_mock = responses::mount_sse_once( + &api_server, + responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_assistant_message("msg-1", "ok"), + responses::ev_completed("resp-1"), + ]), + ) + .await; + + let realtime_server = start_websocket_server(vec![vec![vec![ + json!({ + "type": "session.created", + "session": { "id": "sess_inbound" } + }), + json!({ + "type": "conversation.item.added", + "item": { + "type": "message", + "role": "user", + "content": [{"type": "text", "text": "text from realtime"}] + } + }), + ]]]) + .await; + + let mut builder = test_codex().with_config({ + let realtime_base_url = realtime_server.uri().to_string(); + move |config| { + config.experimental_realtime_ws_base_url = Some(realtime_base_url); + } + }); + let test = builder.build(&api_server).await?; + + test.codex + .submit(Op::RealtimeConversationStart(ConversationStartParams { + prompt: "backend prompt".to_string(), + session_id: None, + })) + .await?; + + let session_created = wait_for_event_match(&test.codex, |msg| match msg { + EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent { + payload: RealtimeEvent::SessionCreated { session_id }, + }) => Some(session_id.clone()), + _ => None, + }) + .await; + assert_eq!(session_created, "sess_inbound"); + + wait_for_event(&test.codex, |event| { + matches!(event, EventMsg::TurnComplete(_)) + }) + .await; + + let request = response_mock.single_request(); + let user_texts = request.message_input_texts("user"); + assert!(user_texts.iter().any(|text| text == "text from realtime")); + + realtime_server.shutdown().await; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn inbound_realtime_text_steers_active_turn() -> Result<()> { + skip_if_no_network!(Ok(())); + + let (gate_completed_tx, gate_completed_rx) = oneshot::channel(); + let first_chunks = vec![ + StreamingSseChunk { + gate: None, + body: sse_event(responses::ev_response_created("resp-1")), + }, + StreamingSseChunk { + gate: None, + body: sse_event(responses::ev_message_item_added("msg-1", "")), + }, + StreamingSseChunk { + gate: None, + body: sse_event(responses::ev_output_text_delta("first ")), + }, + StreamingSseChunk { + gate: None, + body: sse_event(responses::ev_output_text_delta("turn")), + }, + StreamingSseChunk { + gate: None, + body: sse_event(responses::ev_assistant_message("msg-1", "first turn")), + }, + StreamingSseChunk { + gate: Some(gate_completed_rx), + body: sse_event(responses::ev_completed("resp-1")), + }, + ]; + let second_chunks = vec![ + StreamingSseChunk { + gate: None, + body: sse_event(responses::ev_response_created("resp-2")), + }, + StreamingSseChunk { + gate: None, + body: sse_event(responses::ev_completed("resp-2")), + }, + ]; + let (api_server, _completions) = + start_streaming_sse_server(vec![first_chunks, second_chunks]).await; + + let realtime_server = start_websocket_server(vec![vec![ + vec![json!({ + "type": "session.created", + "session": { "id": "sess_steer" } + })], + vec![json!({ + "type": "conversation.item.added", + "item": { + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "steer via realtime"}] + } + })], + ]]) + .await; + + let mut builder = test_codex().with_model("gpt-5.1").with_config({ + let realtime_base_url = realtime_server.uri().to_string(); + move |config| { + config.experimental_realtime_ws_base_url = Some(realtime_base_url); + } + }); + let test = builder.build_with_streaming_server(&api_server).await?; + + test.codex + .submit(Op::RealtimeConversationStart(ConversationStartParams { + prompt: "backend prompt".to_string(), + session_id: None, + })) + .await?; + let _ = wait_for_event_match(&test.codex, |msg| match msg { + EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent { + payload: RealtimeEvent::SessionCreated { session_id }, + }) if session_id == "sess_steer" => Some(()), + _ => None, + }) + .await; + + test.codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "first prompt".to_string(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + + wait_for_event(&test.codex, |event| { + matches!(event, EventMsg::AgentMessageContentDelta(_)) + }) + .await; + + test.codex + .submit(Op::RealtimeConversationAudio(ConversationAudioParams { + frame: RealtimeAudioFrame { + data: "AQID".to_string(), + sample_rate: 24000, + num_channels: 1, + samples_per_channel: Some(480), + }, + })) + .await?; + + let _ = wait_for_event_match(&test.codex, |msg| match msg { + EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent { + payload: RealtimeEvent::ConversationItemAdded(item), + }) => item + .get("content") + .and_then(Value::as_array) + .into_iter() + .flatten() + .any(|content| { + content.get("text").and_then(Value::as_str) == Some("steer via realtime") + }) + .then_some(()), + _ => None, + }) + .await; + + let _ = gate_completed_tx.send(()); + wait_for_event(&test.codex, |event| { + matches!(event, EventMsg::TurnComplete(_)) + }) + .await; + + let requests = api_server.requests().await; + assert_eq!(requests.len(), 2); + + let first_body: Value = serde_json::from_slice(&requests[0]).expect("parse first request"); + let second_body: Value = serde_json::from_slice(&requests[1]).expect("parse second request"); + let first_texts = message_input_texts(&first_body, "user"); + let second_texts = message_input_texts(&second_body, "user"); + + assert!(first_texts.iter().any(|text| text == "first prompt")); + assert!(!first_texts.iter().any(|text| text == "steer via realtime")); + assert!(second_texts.iter().any(|text| text == "first prompt")); + assert!(second_texts.iter().any(|text| text == "steer via realtime")); + + realtime_server.shutdown().await; + api_server.shutdown().await; + Ok(()) +}