Route inbound realtime text into turn start or steer (#12469)
- Route inbound realtime websocket text into normal user input handling so it steers an active turn or starts a new one
This commit is contained in:
parent
2ba2c57af4
commit
031d701705
3 changed files with 350 additions and 17 deletions
|
|
@ -1518,6 +1518,21 @@ impl Session {
|
|||
format!("auto-compact-{id}")
|
||||
}
|
||||
|
||||
pub(crate) async fn route_realtime_text_input(self: &Arc<Self>, text: String) {
|
||||
handlers::user_input_or_turn(
|
||||
self,
|
||||
self.next_internal_sub_id(),
|
||||
Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text,
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn get_total_token_usage(&self) -> i64 {
|
||||
let state = self.state.lock().await;
|
||||
state.get_total_token_usage(state.server_reasoning_included())
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ use codex_protocol::protocol::RealtimeConversationClosedEvent;
|
|||
use codex_protocol::protocol::RealtimeConversationRealtimeEvent;
|
||||
use codex_protocol::protocol::RealtimeConversationStartedEvent;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
|
|
@ -209,11 +210,22 @@ pub(crate) async fn handle_start(
|
|||
msg,
|
||||
};
|
||||
while let Ok(event) = events_rx.recv().await {
|
||||
let maybe_routed_text = match &event {
|
||||
RealtimeEvent::ConversationItemAdded(item) => {
|
||||
realtime_text_from_conversation_item(item)
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
sess_clone
|
||||
.send_event_raw(ev(EventMsg::RealtimeConversationRealtime(
|
||||
RealtimeConversationRealtimeEvent { payload: event },
|
||||
RealtimeConversationRealtimeEvent {
|
||||
payload: event.clone(),
|
||||
},
|
||||
)))
|
||||
.await;
|
||||
if let Some(text) = maybe_routed_text {
|
||||
sess_clone.route_realtime_text_input(text).await;
|
||||
}
|
||||
}
|
||||
if let Some(()) = sess_clone.conversation.running_state().await {
|
||||
sess_clone
|
||||
|
|
@ -239,6 +251,19 @@ pub(crate) async fn handle_audio(
|
|||
}
|
||||
}
|
||||
|
||||
fn realtime_text_from_conversation_item(item: &Value) -> Option<String> {
|
||||
if item.get("type").and_then(Value::as_str) != Some("message") {
|
||||
return None;
|
||||
}
|
||||
let content = item.get("content")?.as_array()?;
|
||||
let text = content
|
||||
.iter()
|
||||
.filter(|entry| entry.get("type").and_then(Value::as_str) == Some("text"))
|
||||
.filter_map(|entry| entry.get("text").and_then(Value::as_str))
|
||||
.collect::<String>();
|
||||
if text.is_empty() { None } else { Some(text) }
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_text(
|
||||
sess: &Arc<Session>,
|
||||
sub_id: String,
|
||||
|
|
@ -355,3 +380,64 @@ async fn send_conversation_error(
|
|||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::realtime_text_from_conversation_item;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    /// Message items yield their text for both assistant and user roles.
    #[test]
    fn extracts_text_from_message_items_ignoring_role() {
        let cases = [("assistant", "hello"), ("user", "world")];
        for (role, body) in cases {
            let item = json!({
                "type": "message",
                "role": role,
                "content": [{"type": "text", "text": body}],
            });
            let extracted = realtime_text_from_conversation_item(&item);
            assert_eq!(extracted, Some(body.to_string()));
        }
    }

    /// Only `"text"` entries contribute, and they are joined without a separator.
    #[test]
    fn extracts_and_concatenates_text_entries_only() {
        let mixed_content = json!({
            "type": "message",
            "content": [
                {"type": "text", "text": "a"},
                {"type": "ignored", "text": "x"},
                {"type": "text", "text": "b"}
            ],
        });
        let extracted = realtime_text_from_conversation_item(&mixed_content);
        assert_eq!(extracted, Some("ab".to_string()));
    }

    /// Non-message items and messages without any text entry produce `None`.
    #[test]
    fn ignores_non_message_or_missing_text() {
        let tool_call_item = json!({
            "type": "tool_call",
            "content": [{"type": "text", "text": "nope"}],
        });
        let message_without_text = json!({
            "type": "message",
            "content": [{"type": "other", "value": 1}],
        });

        assert_eq!(realtime_text_from_conversation_item(&tool_call_item), None);
        assert_eq!(
            realtime_text_from_conversation_item(&message_without_text),
            None
        );
    }
}
|
||||
|
|
|
|||
|
|
@ -9,13 +9,21 @@ use codex_protocol::protocol::Op;
|
|||
use codex_protocol::protocol::RealtimeAudioFrame;
|
||||
use codex_protocol::protocol::RealtimeConversationRealtimeEvent;
|
||||
use codex_protocol::protocol::RealtimeEvent;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::responses::start_websocket_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::streaming_sse::StreamingSseChunk;
|
||||
use core_test_support::streaming_sse::start_streaming_sse_server;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_match;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn conversation_start_audio_text_close_round_trip() -> Result<()> {
|
||||
|
|
@ -29,22 +37,12 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> {
|
|||
"session": { "id": "sess_1" }
|
||||
})],
|
||||
vec![],
|
||||
vec![
|
||||
json!({
|
||||
"type": "response.output_audio.delta",
|
||||
"delta": "AQID",
|
||||
"sample_rate": 24000,
|
||||
"num_channels": 1
|
||||
}),
|
||||
json!({
|
||||
"type": "conversation.item.added",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hi"}]
|
||||
}
|
||||
}),
|
||||
],
|
||||
vec![json!({
|
||||
"type": "response.output_audio.delta",
|
||||
"delta": "AQID",
|
||||
"sample_rate": 24000,
|
||||
"num_channels": 1
|
||||
})],
|
||||
],
|
||||
])
|
||||
.await;
|
||||
|
|
@ -458,3 +456,237 @@ async fn conversation_uses_experimental_realtime_ws_backend_prompt_override() ->
|
|||
server.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sse_event(event: Value) -> String {
|
||||
responses::sse(vec![event])
|
||||
}
|
||||
|
||||
fn message_input_texts(body: &Value, role: &str) -> Vec<String> {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.filter(|item| item.get("type").and_then(Value::as_str) == Some("message"))
|
||||
.filter(|item| item.get("role").and_then(Value::as_str) == Some(role))
|
||||
.filter_map(|item| item.get("content").and_then(Value::as_array))
|
||||
.flatten()
|
||||
.filter(|span| span.get("type").and_then(Value::as_str) == Some("input_text"))
|
||||
.filter_map(|span| span.get("text").and_then(Value::as_str).map(str::to_owned))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn inbound_realtime_text_starts_turn_and_ignores_role() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let api_server = start_mock_server().await;
|
||||
let response_mock = responses::mount_sse_once(
|
||||
&api_server,
|
||||
responses::sse(vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_assistant_message("msg-1", "ok"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let realtime_server = start_websocket_server(vec![vec![vec![
|
||||
json!({
|
||||
"type": "session.created",
|
||||
"session": { "id": "sess_inbound" }
|
||||
}),
|
||||
json!({
|
||||
"type": "conversation.item.added",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "text from realtime"}]
|
||||
}
|
||||
}),
|
||||
]]])
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_config({
|
||||
let realtime_base_url = realtime_server.uri().to_string();
|
||||
move |config| {
|
||||
config.experimental_realtime_ws_base_url = Some(realtime_base_url);
|
||||
}
|
||||
});
|
||||
let test = builder.build(&api_server).await?;
|
||||
|
||||
test.codex
|
||||
.submit(Op::RealtimeConversationStart(ConversationStartParams {
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: None,
|
||||
}))
|
||||
.await?;
|
||||
|
||||
let session_created = wait_for_event_match(&test.codex, |msg| match msg {
|
||||
EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent {
|
||||
payload: RealtimeEvent::SessionCreated { session_id },
|
||||
}) => Some(session_id.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
assert_eq!(session_created, "sess_inbound");
|
||||
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::TurnComplete(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
let request = response_mock.single_request();
|
||||
let user_texts = request.message_input_texts("user");
|
||||
assert!(user_texts.iter().any(|text| text == "text from realtime"));
|
||||
|
||||
realtime_server.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn inbound_realtime_text_steers_active_turn() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let (gate_completed_tx, gate_completed_rx) = oneshot::channel();
|
||||
let first_chunks = vec![
|
||||
StreamingSseChunk {
|
||||
gate: None,
|
||||
body: sse_event(responses::ev_response_created("resp-1")),
|
||||
},
|
||||
StreamingSseChunk {
|
||||
gate: None,
|
||||
body: sse_event(responses::ev_message_item_added("msg-1", "")),
|
||||
},
|
||||
StreamingSseChunk {
|
||||
gate: None,
|
||||
body: sse_event(responses::ev_output_text_delta("first ")),
|
||||
},
|
||||
StreamingSseChunk {
|
||||
gate: None,
|
||||
body: sse_event(responses::ev_output_text_delta("turn")),
|
||||
},
|
||||
StreamingSseChunk {
|
||||
gate: None,
|
||||
body: sse_event(responses::ev_assistant_message("msg-1", "first turn")),
|
||||
},
|
||||
StreamingSseChunk {
|
||||
gate: Some(gate_completed_rx),
|
||||
body: sse_event(responses::ev_completed("resp-1")),
|
||||
},
|
||||
];
|
||||
let second_chunks = vec![
|
||||
StreamingSseChunk {
|
||||
gate: None,
|
||||
body: sse_event(responses::ev_response_created("resp-2")),
|
||||
},
|
||||
StreamingSseChunk {
|
||||
gate: None,
|
||||
body: sse_event(responses::ev_completed("resp-2")),
|
||||
},
|
||||
];
|
||||
let (api_server, _completions) =
|
||||
start_streaming_sse_server(vec![first_chunks, second_chunks]).await;
|
||||
|
||||
let realtime_server = start_websocket_server(vec![vec![
|
||||
vec![json!({
|
||||
"type": "session.created",
|
||||
"session": { "id": "sess_steer" }
|
||||
})],
|
||||
vec![json!({
|
||||
"type": "conversation.item.added",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "steer via realtime"}]
|
||||
}
|
||||
})],
|
||||
]])
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_model("gpt-5.1").with_config({
|
||||
let realtime_base_url = realtime_server.uri().to_string();
|
||||
move |config| {
|
||||
config.experimental_realtime_ws_base_url = Some(realtime_base_url);
|
||||
}
|
||||
});
|
||||
let test = builder.build_with_streaming_server(&api_server).await?;
|
||||
|
||||
test.codex
|
||||
.submit(Op::RealtimeConversationStart(ConversationStartParams {
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: None,
|
||||
}))
|
||||
.await?;
|
||||
let _ = wait_for_event_match(&test.codex, |msg| match msg {
|
||||
EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent {
|
||||
payload: RealtimeEvent::SessionCreated { session_id },
|
||||
}) if session_id == "sess_steer" => Some(()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
|
||||
test.codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text: "first prompt".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::AgentMessageContentDelta(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
test.codex
|
||||
.submit(Op::RealtimeConversationAudio(ConversationAudioParams {
|
||||
frame: RealtimeAudioFrame {
|
||||
data: "AQID".to_string(),
|
||||
sample_rate: 24000,
|
||||
num_channels: 1,
|
||||
samples_per_channel: Some(480),
|
||||
},
|
||||
}))
|
||||
.await?;
|
||||
|
||||
let _ = wait_for_event_match(&test.codex, |msg| match msg {
|
||||
EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent {
|
||||
payload: RealtimeEvent::ConversationItemAdded(item),
|
||||
}) => item
|
||||
.get("content")
|
||||
.and_then(Value::as_array)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.any(|content| {
|
||||
content.get("text").and_then(Value::as_str) == Some("steer via realtime")
|
||||
})
|
||||
.then_some(()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
|
||||
let _ = gate_completed_tx.send(());
|
||||
wait_for_event(&test.codex, |event| {
|
||||
matches!(event, EventMsg::TurnComplete(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
let requests = api_server.requests().await;
|
||||
assert_eq!(requests.len(), 2);
|
||||
|
||||
let first_body: Value = serde_json::from_slice(&requests[0]).expect("parse first request");
|
||||
let second_body: Value = serde_json::from_slice(&requests[1]).expect("parse second request");
|
||||
let first_texts = message_input_texts(&first_body, "user");
|
||||
let second_texts = message_input_texts(&second_body, "user");
|
||||
|
||||
assert!(first_texts.iter().any(|text| text == "first prompt"));
|
||||
assert!(!first_texts.iter().any(|text| text == "steer via realtime"));
|
||||
assert!(second_texts.iter().any(|text| text == "first prompt"));
|
||||
assert!(second_texts.iter().any(|text| text == "steer via realtime"));
|
||||
|
||||
realtime_server.shutdown().await;
|
||||
api_server.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue