From fbd7f9b9864bef4ee074974d649f0939f3bc91e9 Mon Sep 17 00:00:00 2001 From: Ahmed Ibrahim Date: Mon, 16 Mar 2026 21:38:07 -0700 Subject: [PATCH] [stack 2/4] Align main realtime v2 wire and runtime flow (#14830) ## Stack Position 2/4. Built on top of #14828. ## Base - #14828 ## Unblocks - #14829 - #14827 ## Scope - Port the realtime v2 wire parsing, session, app-server, and conversation runtime behavior onto the split websocket-method base. - Branch runtime behavior directly on the current realtime session kind instead of parser-derived flow flags. - Keep regression coverage in the existing e2e suites. --------- Co-authored-by: Codex --- .../schema/json/ClientRequest.json | 6 + .../schema/json/ServerNotification.json | 6 + .../codex_app_server_protocol.schemas.json | 6 + .../codex_app_server_protocol.v2.schemas.json | 6 + ...dRealtimeOutputAudioDeltaNotification.json | 6 + .../typescript/v2/ThreadRealtimeAudioChunk.ts | 2 +- .../src/protocol/common.rs | 5 +- .../app-server-protocol/src/protocol/v2.rs | 5 + .../app-server/src/bespoke_event_handling.rs | 28 ++ .../tests/suite/v2/realtime_conversation.rs | 9 +- .../endpoint/realtime_websocket/methods.rs | 166 +++++++- .../realtime_websocket/methods_common.rs | 1 - .../endpoint/realtime_websocket/methods_v1.rs | 23 +- .../endpoint/realtime_websocket/methods_v2.rs | 57 ++- .../endpoint/realtime_websocket/protocol.rs | 103 ++++- .../realtime_websocket/protocol_v1.rs | 1 + .../realtime_websocket/protocol_v2.rs | 60 ++- .../codex-api/tests/realtime_websocket_e2e.rs | 3 + codex-rs/core/src/codex.rs | 3 + codex-rs/core/src/codex_tests.rs | 1 + codex-rs/core/src/realtime_conversation.rs | 366 +++++++++++++++--- .../core/src/realtime_conversation_tests.rs | 48 +-- .../core/tests/suite/realtime_conversation.rs | 5 + codex-rs/protocol/src/protocol.rs | 15 + codex-rs/tui/src/chatwidget/realtime.rs | 2 + codex-rs/tui/src/voice.rs | 1 + .../tui_app_server/src/chatwidget/realtime.rs | 12 + codex-rs/tui_app_server/src/voice.rs | 1 + 28 files changed, 807 insertions(+), 140 deletions(-) diff --git a/codex-rs/app-server-protocol/schema/json/ClientRequest.json b/codex-rs/app-server-protocol/schema/json/ClientRequest.json index 6138c86d2..dd5c955cf 100644 --- a/codex-rs/app-server-protocol/schema/json/ClientRequest.json +++ b/codex-rs/app-server-protocol/schema/json/ClientRequest.json @@ -2779,6 +2779,12 @@ "data": { "type": "string" }, + "itemId": { + "type": [ + "string", + "null" + ] + }, "numChannels": { "format": "uint16", "minimum": 0.0, diff --git a/codex-rs/app-server-protocol/schema/json/ServerNotification.json b/codex-rs/app-server-protocol/schema/json/ServerNotification.json index 7303fa1ca..14908dbb1 100644 --- a/codex-rs/app-server-protocol/schema/json/ServerNotification.json +++ b/codex-rs/app-server-protocol/schema/json/ServerNotification.json @@ -2750,6 +2750,12 @@ "data": { "type": "string" }, + "itemId": { + "type": [ + "string", + "null" + ] + }, "numChannels": { "format": "uint16", "minimum": 0.0, diff --git a/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json b/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json index 11bdd8938..207c6c92f 100644 --- a/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json +++ b/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.schemas.json @@ -12817,6 +12817,12 @@ "data": { "type": "string" }, + "itemId": { + "type": [ + "string", + "null" + ] + }, "numChannels": { "format": "uint16", "minimum": 0.0, diff --git a/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.v2.schemas.json b/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.v2.schemas.json index 3a1af8abc..62acdafee 100644 --- a/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.v2.schemas.json +++ b/codex-rs/app-server-protocol/schema/json/codex_app_server_protocol.v2.schemas.json @@ -10577,6 +10577,12 @@ "data": { "type": "string" }, + "itemId": { + "type": [ + "string", + "null" + ] + }, "numChannels": { "format": "uint16", "minimum": 0.0, diff --git a/codex-rs/app-server-protocol/schema/json/v2/ThreadRealtimeOutputAudioDeltaNotification.json b/codex-rs/app-server-protocol/schema/json/v2/ThreadRealtimeOutputAudioDeltaNotification.json index d4df6194f..6c75f6755 100644 --- a/codex-rs/app-server-protocol/schema/json/v2/ThreadRealtimeOutputAudioDeltaNotification.json +++ b/codex-rs/app-server-protocol/schema/json/v2/ThreadRealtimeOutputAudioDeltaNotification.json @@ -7,6 +7,12 @@ "data": { "type": "string" }, + "itemId": { + "type": [ + "string", + "null" + ] + }, "numChannels": { "format": "uint16", "minimum": 0.0, diff --git a/codex-rs/app-server-protocol/schema/typescript/v2/ThreadRealtimeAudioChunk.ts b/codex-rs/app-server-protocol/schema/typescript/v2/ThreadRealtimeAudioChunk.ts index 078f64224..eefb79dd6 100644 --- a/codex-rs/app-server-protocol/schema/typescript/v2/ThreadRealtimeAudioChunk.ts +++ b/codex-rs/app-server-protocol/schema/typescript/v2/ThreadRealtimeAudioChunk.ts @@ -5,4 +5,4 @@ /** * EXPERIMENTAL - thread realtime audio chunk. */ -export type ThreadRealtimeAudioChunk = { data: string, sampleRate: number, numChannels: number, samplesPerChannel: number | null, }; +export type ThreadRealtimeAudioChunk = { data: string, sampleRate: number, numChannels: number, samplesPerChannel: number | null, itemId: string | null, }; diff --git a/codex-rs/app-server-protocol/src/protocol/common.rs b/codex-rs/app-server-protocol/src/protocol/common.rs index 75aa7768d..73139a2e0 100644 --- a/codex-rs/app-server-protocol/src/protocol/common.rs +++ b/codex-rs/app-server-protocol/src/protocol/common.rs @@ -1577,6 +1577,7 @@ mod tests { sample_rate: 24_000, num_channels: 1, samples_per_channel: Some(512), + item_id: None, }, }, ); @@ -1589,7 +1590,8 @@ mod tests { "data": "AQID", "sampleRate": 24000, "numChannels": 1, - "samplesPerChannel": 512 + "samplesPerChannel": 512, + "itemId": null } } }), @@ -1641,6 +1643,7 @@ mod tests { sample_rate: 24_000, num_channels: 1, samples_per_channel: Some(512), + item_id: None, }, }, ); diff --git a/codex-rs/app-server-protocol/src/protocol/v2.rs b/codex-rs/app-server-protocol/src/protocol/v2.rs index 0f17889a4..3b481d563 100644 --- a/codex-rs/app-server-protocol/src/protocol/v2.rs +++ b/codex-rs/app-server-protocol/src/protocol/v2.rs @@ -3659,6 +3659,7 @@ pub struct ThreadRealtimeAudioChunk { pub sample_rate: u32, pub num_channels: u16, pub samples_per_channel: Option, + pub item_id: Option, } impl From for ThreadRealtimeAudioChunk { @@ -3668,12 +3669,14 @@ impl From for ThreadRealtimeAudioChunk { sample_rate, num_channels, samples_per_channel, + item_id, } = value; Self { data, sample_rate, num_channels, samples_per_channel, + item_id, } } } @@ -3685,12 +3688,14 @@ impl From for CoreRealtimeAudioFrame { sample_rate, num_channels, samples_per_channel, + item_id, } = value; Self { data, sample_rate, num_channels, samples_per_channel, + item_id, } } } diff --git a/codex-rs/app-server/src/bespoke_event_handling.rs b/codex-rs/app-server/src/bespoke_event_handling.rs index 04bef6655..4f4f995e2 100644 --- a/codex-rs/app-server/src/bespoke_event_handling.rs +++ b/codex-rs/app-server/src/bespoke_event_handling.rs @@ -350,6 +350,20 @@ pub(crate) async fn apply_bespoke_event_handling( if let ApiVersion::V2 = api_version { match event.payload { RealtimeEvent::SessionUpdated { .. } => {} + RealtimeEvent::InputAudioSpeechStarted(event) => { + let notification = ThreadRealtimeItemAddedNotification { + thread_id: conversation_id.to_string(), + item: serde_json::json!({ + "type": "input_audio_buffer.speech_started", + "item_id": event.item_id, + }), + }; + outgoing + .send_server_notification(ServerNotification::ThreadRealtimeItemAdded( + notification, + )) + .await; + } RealtimeEvent::InputTranscriptDelta(_) => {} RealtimeEvent::OutputTranscriptDelta(_) => {} RealtimeEvent::AudioOut(audio) => { @@ -363,6 +377,20 @@ pub(crate) async fn apply_bespoke_event_handling( ) .await; } + RealtimeEvent::ResponseCancelled(event) => { + let notification = ThreadRealtimeItemAddedNotification { + thread_id: conversation_id.to_string(), + item: serde_json::json!({ + "type": "response.cancelled", + "response_id": event.response_id, + }), + }; + outgoing + .send_server_notification(ServerNotification::ThreadRealtimeItemAdded( + notification, + )) + .await; + } RealtimeEvent::ConversationItemAdded(item) => { let notification = ThreadRealtimeItemAddedNotification { thread_id: conversation_id.to_string(), diff --git a/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs b/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs index a771fb874..71b6d6dcf 100644 --- a/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs +++ b/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs @@ -70,6 +70,7 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> { "message": "upstream boom" }), ], + vec![], ]]) .await; @@ -135,6 +136,7 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> { sample_rate: 24_000, num_channels: 1, samples_per_channel: Some(480), + item_id: None, }, }) .await?; @@ -191,7 +193,7 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> { let connections = realtime_server.connections(); assert_eq!(connections.len(), 1); let connection = &connections[0]; - assert_eq!(connection.len(), 3); + assert_eq!(connection.len(), 4); assert_eq!( connection[0].body_json()["type"].as_str(), Some("session.update") @@ -211,6 +213,10 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> { .as_str() .context("expected websocket request type")? .to_string(), + connection[3].body_json()["type"] + .as_str() + .context("expected websocket request type")? + .to_string(), ]; request_types.sort(); assert_eq!( @@ -218,6 +224,7 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> { [ "conversation.item.create".to_string(), "input_audio_buffer.append".to_string(), + "response.create".to_string(), ] ); diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs index 5082f6314..fe83c751a 100644 --- a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs @@ -272,12 +272,12 @@ impl RealtimeWebsocketConnection { impl RealtimeWebsocketWriter { pub async fn send_audio_frame(&self, frame: RealtimeAudioFrame) -> Result<(), ApiError> { - self.send_json(RealtimeOutboundMessage::InputAudioBufferAppend { audio: frame.data }) + self.send_json(&RealtimeOutboundMessage::InputAudioBufferAppend { audio: frame.data }) .await } pub async fn send_conversation_item_create(&self, text: String) -> Result<(), ApiError> { - self.send_json(conversation_item_create_message(self.event_parser, text)) + self.send_json(&conversation_item_create_message(self.event_parser, text)) .await } @@ -286,7 +286,7 @@ impl RealtimeWebsocketWriter { handoff_id: String, output_text: String, ) -> Result<(), ApiError> { - self.send_json(conversation_handoff_append_message( + self.send_json(&conversation_handoff_append_message( self.event_parser, handoff_id, output_text, @@ -294,6 +294,11 @@ impl RealtimeWebsocketWriter { .await } + pub async fn send_response_create(&self) -> Result<(), ApiError> { + self.send_json(&RealtimeOutboundMessage::ResponseCreate) + .await + } + pub async fn send_session_update( &self, instructions: String, @@ -301,7 +306,7 @@ impl RealtimeWebsocketWriter { ) -> Result<(), ApiError> { let session_mode = normalized_session_mode(self.event_parser, session_mode); let session = session_update_session(self.event_parser, instructions, session_mode); - self.send_json(RealtimeOutboundMessage::SessionUpdate { session }) + self.send_json(&RealtimeOutboundMessage::SessionUpdate { session }) .await } @@ -319,11 +324,14 @@ impl RealtimeWebsocketWriter { Ok(()) } - async fn send_json(&self, message: RealtimeOutboundMessage) -> Result<(), ApiError> { - let payload = serde_json::to_string(&message) + async fn send_json(&self, message: &RealtimeOutboundMessage) -> Result<(), ApiError> { + let payload = serde_json::to_string(message) .map_err(|err| ApiError::Stream(format!("failed to encode realtime request: {err}")))?; debug!(?message, "realtime websocket request"); + self.send_payload(payload).await + } + pub async fn send_payload(&self, payload: String) -> Result<(), ApiError> { if self.is_closed.load(Ordering::SeqCst) { return Err(ApiError::Stream( "realtime websocket connection is closed".to_string(), @@ -392,6 +400,7 @@ impl RealtimeWebsocketEvents { async fn update_active_transcript(&self, event: &mut RealtimeEvent) { let mut active_transcript = self.active_transcript.lock().await; match event { + RealtimeEvent::InputAudioSpeechStarted(_) => {} RealtimeEvent::InputTranscriptDelta(RealtimeTranscriptDelta { delta }) => { append_transcript_delta(&mut active_transcript.entries, "user", delta); } @@ -403,6 +412,7 @@ impl RealtimeWebsocketEvents { } RealtimeEvent::SessionUpdated { .. } | RealtimeEvent::AudioOut(_) + | RealtimeEvent::ResponseCancelled(_) | RealtimeEvent::ConversationItemAdded(_) | RealtimeEvent::ConversationItemDone { .. } | RealtimeEvent::Error(_) => {} @@ -616,6 +626,8 @@ mod tests { use crate::endpoint::realtime_websocket::protocol::RealtimeHandoffRequested; use crate::endpoint::realtime_websocket::protocol::RealtimeTranscriptDelta; use crate::endpoint::realtime_websocket::protocol::RealtimeTranscriptEntry; + use codex_protocol::protocol::RealtimeInputAudioSpeechStarted; + use codex_protocol::protocol::RealtimeResponseCancelled; use http::HeaderValue; use pretty_assertions::assert_eq; use serde_json::Value; @@ -660,6 +672,7 @@ mod tests { sample_rate: 48000, num_channels: 1, samples_per_channel: Some(960), + item_id: None, })) ); } @@ -809,10 +822,112 @@ mod tests { sample_rate: 24_000, num_channels: 1, samples_per_channel: None, + item_id: None, })) ); } + #[test] + fn parse_realtime_v2_response_audio_delta_with_item_id() { + let payload = json!({ + "type": "response.audio.delta", + "delta": "AQID", + "item_id": "item_audio_1" + }) + .to_string(); + + assert_eq!( + parse_realtime_event(payload.as_str(), RealtimeEventParser::RealtimeV2), + Some(RealtimeEvent::AudioOut(RealtimeAudioFrame { + data: "AQID".to_string(), + sample_rate: 24_000, + num_channels: 1, + samples_per_channel: None, + item_id: Some("item_audio_1".to_string()), + })) + ); + } + + #[test] + fn parse_realtime_v2_speech_started_event() { + let payload = json!({ + "type": "input_audio_buffer.speech_started", + "item_id": "item_input_1" + }) + .to_string(); + + assert_eq!( + parse_realtime_event(payload.as_str(), RealtimeEventParser::RealtimeV2), + Some(RealtimeEvent::InputAudioSpeechStarted( + RealtimeInputAudioSpeechStarted { + item_id: Some("item_input_1".to_string()), + } + )) + ); + } + + #[test] + fn parse_realtime_v2_response_cancelled_event() { + let payload = json!({ + "type": "response.cancelled", + "response": {"id": "resp_cancelled_1"} + }) + .to_string(); + + assert_eq!( + parse_realtime_event(payload.as_str(), RealtimeEventParser::RealtimeV2), + Some(RealtimeEvent::ResponseCancelled( + RealtimeResponseCancelled { + response_id: Some("resp_cancelled_1".to_string()), + } + )) + ); + } + + #[test] + fn parse_realtime_v2_response_done_handoff_event() { + let payload = json!({ + "type": "response.done", + "response": { + "output": [{ + "id": "item_123", + "type": "function_call", + "name": "codex", + "call_id": "call_123", + "arguments": "{\"prompt\":\"delegate from done\"}" + }] + } + }) + .to_string(); + + assert_eq!( + parse_realtime_event(payload.as_str(), RealtimeEventParser::RealtimeV2), + Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested { + handoff_id: "call_123".to_string(), + item_id: "item_123".to_string(), + input_transcript: "delegate from done".to_string(), + active_transcript: Vec::new(), + })) + ); + } + + #[test] + fn parse_realtime_v2_response_created_event() { + let payload = json!({ + "type": "response.created", + "response": {"id": "resp_created_1"} + }) + .to_string(); + + assert_eq!( + parse_realtime_event(payload.as_str(), RealtimeEventParser::RealtimeV2), + Some(RealtimeEvent::ConversationItemAdded(json!({ + "type": "response.created", + "response": {"id": "resp_created_1"} + }))) + ); + } + #[test] fn merge_request_headers_matches_http_precedence() { let mut provider_headers = HeaderMap::new(); @@ -1169,6 +1284,7 @@ mod tests { sample_rate: 48000, num_channels: 1, samples_per_channel: Some(960), + item_id: None, }) .await .expect("send audio"); @@ -1196,6 +1312,7 @@ mod tests { sample_rate: 48000, num_channels: 1, samples_per_channel: None, + item_id: None, }) ); @@ -1285,9 +1402,38 @@ mod tests { first_json["session"]["type"], Value::String("realtime".to_string()) ); + assert_eq!(first_json["session"]["output_modalities"], json!(["audio"])); + assert_eq!( + first_json["session"]["audio"]["input"]["format"], + json!({ + "type": "audio/pcm", + "rate": 24_000, + }) + ); + assert_eq!( + first_json["session"]["audio"]["input"]["noise_reduction"], + json!({ + "type": "near_field", + }) + ); + assert_eq!( + first_json["session"]["audio"]["input"]["turn_detection"], + json!({ + "type": "server_vad", + "interrupt_response": true, + "create_response": true, + }) + ); + assert_eq!( + first_json["session"]["audio"]["output"]["format"], + json!({ + "type": "audio/pcm", + "rate": 24_000, + }) + ); assert_eq!( first_json["session"]["audio"]["output"]["voice"], - Value::String("alloy".to_string()) + Value::String("marin".to_string()) ); assert_eq!( first_json["session"]["tools"][0]["type"], @@ -1301,6 +1447,10 @@ mod tests { first_json["session"]["tools"][0]["parameters"]["required"], json!(["prompt"]) ); + assert_eq!( + first_json["session"]["tool_choice"], + Value::String("auto".to_string()) + ); ws.send(Message::Text( json!({ @@ -1511,6 +1661,7 @@ mod tests { sample_rate: 24_000, num_channels: 1, samples_per_channel: Some(480), + item_id: None, }) .await .expect("send audio"); @@ -1690,6 +1841,7 @@ mod tests { sample_rate: 48000, num_channels: 1, samples_per_channel: Some(960), + item_id: None, }), ) .await diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods_common.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods_common.rs index 4a5013c65..48f21964a 100644 --- a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods_common.rs +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods_common.rs @@ -12,7 +12,6 @@ use crate::endpoint::realtime_websocket::protocol::RealtimeSessionMode; use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession; pub(super) const REALTIME_AUDIO_SAMPLE_RATE: u32 = 24_000; -pub(super) const REALTIME_AUDIO_FORMAT: &str = "audio/pcm"; pub(super) fn normalized_session_mode( event_parser: RealtimeEventParser, diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods_v1.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods_v1.rs index 8280c4d9a..b31899ff8 100644 --- a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods_v1.rs +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods_v1.rs @@ -1,25 +1,27 @@ -use crate::endpoint::realtime_websocket::methods_common::REALTIME_AUDIO_FORMAT; use crate::endpoint::realtime_websocket::methods_common::REALTIME_AUDIO_SAMPLE_RATE; +use crate::endpoint::realtime_websocket::protocol::AudioFormatType; +use crate::endpoint::realtime_websocket::protocol::ConversationContentType; use crate::endpoint::realtime_websocket::protocol::ConversationItemContent; use crate::endpoint::realtime_websocket::protocol::ConversationItemPayload; +use crate::endpoint::realtime_websocket::protocol::ConversationItemType; use crate::endpoint::realtime_websocket::protocol::ConversationMessageItem; +use crate::endpoint::realtime_websocket::protocol::ConversationRole; use crate::endpoint::realtime_websocket::protocol::RealtimeOutboundMessage; use crate::endpoint::realtime_websocket::protocol::SessionAudio; use crate::endpoint::realtime_websocket::protocol::SessionAudioFormat; use crate::endpoint::realtime_websocket::protocol::SessionAudioInput; use crate::endpoint::realtime_websocket::protocol::SessionAudioOutput; use crate::endpoint::realtime_websocket::protocol::SessionAudioVoice; +use crate::endpoint::realtime_websocket::protocol::SessionType; use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession; -const REALTIME_V1_SESSION_TYPE: &str = "quicksilver"; - pub(super) fn conversation_item_create_message(text: String) -> RealtimeOutboundMessage { RealtimeOutboundMessage::ConversationItemCreate { item: ConversationItemPayload::Message(ConversationMessageItem { - kind: "message".to_string(), - role: "user".to_string(), + r#type: ConversationItemType::Message, + role: ConversationRole::User, content: vec![ConversationItemContent { - kind: "text".to_string(), + r#type: ConversationContentType::Text, text, }], }), @@ -38,20 +40,25 @@ pub(super) fn conversation_handoff_append_message( pub(super) fn session_update_session(instructions: String) -> SessionUpdateSession { SessionUpdateSession { - kind: REALTIME_V1_SESSION_TYPE.to_string(), + r#type: SessionType::Quicksilver, instructions: Some(instructions), + output_modalities: None, audio: SessionAudio { input: SessionAudioInput { format: SessionAudioFormat { - kind: REALTIME_AUDIO_FORMAT.to_string(), + r#type: AudioFormatType::AudioPcm, rate: REALTIME_AUDIO_SAMPLE_RATE, }, + noise_reduction: None, + turn_detection: None, }, output: Some(SessionAudioOutput { + format: None, voice: SessionAudioVoice::Fathom, }), }, tools: None, + tool_choice: None, } } diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods_v2.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods_v2.rs index 59a8f1284..afff680c1 100644 --- a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods_v2.rs +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods_v2.rs @@ -1,31 +1,42 @@ -use crate::endpoint::realtime_websocket::methods_common::REALTIME_AUDIO_FORMAT; use crate::endpoint::realtime_websocket::methods_common::REALTIME_AUDIO_SAMPLE_RATE; +use crate::endpoint::realtime_websocket::protocol::AudioFormatType; +use crate::endpoint::realtime_websocket::protocol::ConversationContentType; use crate::endpoint::realtime_websocket::protocol::ConversationFunctionCallOutputItem; use crate::endpoint::realtime_websocket::protocol::ConversationItemContent; use crate::endpoint::realtime_websocket::protocol::ConversationItemPayload; +use crate::endpoint::realtime_websocket::protocol::ConversationItemType; use crate::endpoint::realtime_websocket::protocol::ConversationMessageItem; +use crate::endpoint::realtime_websocket::protocol::ConversationRole; +use crate::endpoint::realtime_websocket::protocol::NoiseReductionType; use crate::endpoint::realtime_websocket::protocol::RealtimeOutboundMessage; use crate::endpoint::realtime_websocket::protocol::RealtimeSessionMode; use crate::endpoint::realtime_websocket::protocol::SessionAudio; use crate::endpoint::realtime_websocket::protocol::SessionAudioFormat; use crate::endpoint::realtime_websocket::protocol::SessionAudioInput; use crate::endpoint::realtime_websocket::protocol::SessionAudioOutput; +use crate::endpoint::realtime_websocket::protocol::SessionAudioOutputFormat; use crate::endpoint::realtime_websocket::protocol::SessionAudioVoice; use crate::endpoint::realtime_websocket::protocol::SessionFunctionTool; +use crate::endpoint::realtime_websocket::protocol::SessionNoiseReduction; +use crate::endpoint::realtime_websocket::protocol::SessionToolType; +use crate::endpoint::realtime_websocket::protocol::SessionTurnDetection; +use crate::endpoint::realtime_websocket::protocol::SessionType; use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession; +use crate::endpoint::realtime_websocket::protocol::TurnDetectionType; use serde_json::json; -const REALTIME_V2_SESSION_TYPE: &str = "realtime"; +const REALTIME_V2_OUTPUT_MODALITY_AUDIO: &str = "audio"; +const REALTIME_V2_TOOL_CHOICE: &str = "auto"; const REALTIME_V2_CODEX_TOOL_NAME: &str = "codex"; -const REALTIME_V2_CODEX_TOOL_DESCRIPTION: &str = "Delegate work to Codex and return the result."; +const REALTIME_V2_CODEX_TOOL_DESCRIPTION: &str = "Delegate a request to Codex and return the final result to the user. Use this as the default action. If the user asks to do something next, later, after this, or once current work finishes, call this tool so the work is actually queued instead of merely promising to do it later."; pub(super) fn conversation_item_create_message(text: String) -> RealtimeOutboundMessage { RealtimeOutboundMessage::ConversationItemCreate { item: ConversationItemPayload::Message(ConversationMessageItem { - kind: "message".to_string(), - role: "user".to_string(), + r#type: ConversationItemType::Message, + role: ConversationRole::User, content: vec![ConversationItemContent { - kind: "input_text".to_string(), + r#type: ConversationContentType::InputText, text, }], }), @@ -38,7 +49,7 @@ pub(super) fn conversation_handoff_append_message( ) -> RealtimeOutboundMessage { RealtimeOutboundMessage::ConversationItemCreate { item: ConversationItemPayload::FunctionCallOutput(ConversationFunctionCallOutputItem { - kind: "function_call_output".to_string(), + r#type: ConversationItemType::FunctionCallOutput, call_id: handoff_id, output: output_text, }), @@ -51,21 +62,34 @@ pub(super) fn session_update_session( ) -> SessionUpdateSession { match session_mode { RealtimeSessionMode::Conversational => SessionUpdateSession { - kind: REALTIME_V2_SESSION_TYPE.to_string(), + r#type: SessionType::Realtime, instructions: Some(instructions), + output_modalities: Some(vec![REALTIME_V2_OUTPUT_MODALITY_AUDIO.to_string()]), audio: SessionAudio { input: SessionAudioInput { format: SessionAudioFormat { - kind: REALTIME_AUDIO_FORMAT.to_string(), + r#type: AudioFormatType::AudioPcm, rate: REALTIME_AUDIO_SAMPLE_RATE, }, + noise_reduction: Some(SessionNoiseReduction { + r#type: NoiseReductionType::NearField, + }), + turn_detection: Some(SessionTurnDetection { + r#type: TurnDetectionType::ServerVad, + interrupt_response: true, + create_response: true, + }), }, output: Some(SessionAudioOutput { - voice: SessionAudioVoice::Alloy, + format: Some(SessionAudioOutputFormat { + r#type: AudioFormatType::AudioPcm, + rate: REALTIME_AUDIO_SAMPLE_RATE, + }), + voice: SessionAudioVoice::Marin, }), }, tools: Some(vec![SessionFunctionTool { - kind: "function".to_string(), + r#type: SessionToolType::Function, name: REALTIME_V2_CODEX_TOOL_NAME.to_string(), description: REALTIME_V2_CODEX_TOOL_DESCRIPTION.to_string(), parameters: json!({ @@ -73,27 +97,32 @@ pub(super) fn session_update_session( "properties": { "prompt": { "type": "string", - "description": "Prompt text for the delegated Codex task." + "description": "The user request to delegate to Codex." } }, "required": ["prompt"], "additionalProperties": false }), }]), + tool_choice: Some(REALTIME_V2_TOOL_CHOICE.to_string()), }, RealtimeSessionMode::Transcription => SessionUpdateSession { - kind: "transcription".to_string(), + r#type: SessionType::Transcription, instructions: None, + output_modalities: None, audio: SessionAudio { input: SessionAudioInput { format: SessionAudioFormat { - kind: REALTIME_AUDIO_FORMAT.to_string(), + r#type: AudioFormatType::AudioPcm, rate: REALTIME_AUDIO_SAMPLE_RATE, }, + noise_reduction: None, + turn_detection: None, }, output: None, }, tools: None, + tool_choice: None, }, } } diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs index 73c2c1052..2c629249f 100644 --- a/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs @@ -39,6 +39,8 @@ pub(super) enum RealtimeOutboundMessage { handoff_id: String, output_text: String, }, + #[serde(rename = "response.create")] + ResponseCreate, #[serde(rename = "session.update")] SessionUpdate { session: SessionUpdateSession }, #[serde(rename = "conversation.item.create")] @@ -48,12 +50,24 @@ pub(super) enum RealtimeOutboundMessage { #[derive(Debug, Clone, Serialize)] pub(super) struct SessionUpdateSession { #[serde(rename = "type")] - pub(super) kind: String, + pub(super) r#type: SessionType, #[serde(skip_serializing_if = "Option::is_none")] pub(super) instructions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) output_modalities: Option>, pub(super) audio: SessionAudio, #[serde(skip_serializing_if = "Option::is_none")] pub(super) tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) tool_choice: Option, +} + +#[derive(Debug, Clone, Copy, Serialize)] +#[serde(rename_all = "snake_case")] +pub(super) enum SessionType { + Quicksilver, + Realtime, + Transcription, } #[derive(Debug, Clone, Serialize)] @@ -66,17 +80,29 @@ pub(super) struct SessionAudio { #[derive(Debug, Clone, Serialize)] pub(super) struct SessionAudioInput { pub(super) format: SessionAudioFormat, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) noise_reduction: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) turn_detection: Option, } #[derive(Debug, Clone, Serialize)] pub(super) struct SessionAudioFormat { #[serde(rename = "type")] - pub(super) kind: String, + pub(super) r#type: AudioFormatType, pub(super) rate: u32, } +#[derive(Debug, Clone, Copy, Serialize)] +pub(super) enum AudioFormatType { + #[serde(rename = "audio/pcm")] + AudioPcm, +} + #[derive(Debug, Clone, Serialize)] pub(super) struct SessionAudioOutput { + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) format: Option, pub(super) voice: SessionAudioVoice, } @@ -84,18 +110,64 @@ pub(super) struct SessionAudioOutput { pub(super) enum SessionAudioVoice { #[serde(rename = "fathom")] Fathom, - #[serde(rename = "alloy")] - Alloy, + #[serde(rename = "marin")] + Marin, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct SessionNoiseReduction { + #[serde(rename = "type")] + pub(super) r#type: NoiseReductionType, +} + +#[derive(Debug, Clone, Copy, Serialize)] +#[serde(rename_all = "snake_case")] +pub(super) enum NoiseReductionType { + NearField, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct SessionTurnDetection { + #[serde(rename = "type")] + pub(super) r#type: TurnDetectionType, + pub(super) interrupt_response: bool, + pub(super) create_response: bool, +} + +#[derive(Debug, Clone, Copy, Serialize)] +#[serde(rename_all = "snake_case")] +pub(super) enum TurnDetectionType { + ServerVad, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct SessionAudioOutputFormat { + #[serde(rename = "type")] + pub(super) r#type: AudioFormatType, + pub(super) rate: u32, } #[derive(Debug, Clone, Serialize)] pub(super) struct ConversationMessageItem { #[serde(rename = "type")] - pub(super) kind: String, - pub(super) role: String, + pub(super) r#type: ConversationItemType, + pub(super) role: ConversationRole, pub(super) content: Vec, } +#[derive(Debug, Clone, Copy, Serialize)] +#[serde(rename_all = "snake_case")] +pub(super) enum ConversationItemType { + Message, + FunctionCallOutput, +} + +#[derive(Debug, Clone, Copy, Serialize)] +#[serde(rename_all = "snake_case")] +pub(super) enum ConversationRole { + User, +} + #[derive(Debug, Clone, Serialize)] #[serde(untagged)] pub(super) enum ConversationItemPayload { @@ -106,7 +178,7 @@ pub(super) enum ConversationItemPayload { #[derive(Debug, Clone, Serialize)] pub(super) struct ConversationFunctionCallOutputItem { #[serde(rename = "type")] - pub(super) kind: String, + pub(super) r#type: ConversationItemType, pub(super) call_id: String, pub(super) output: String, } @@ -114,19 +186,32 @@ pub(super) struct ConversationFunctionCallOutputItem { #[derive(Debug, Clone, Serialize)] pub(super) struct ConversationItemContent { #[serde(rename = "type")] - pub(super) kind: String, + pub(super) r#type: ConversationContentType, pub(super) text: String, } +#[derive(Debug, Clone, Copy, Serialize)] +#[serde(rename_all = "snake_case")] +pub(super) enum ConversationContentType { + Text, + InputText, +} + #[derive(Debug, Clone, Serialize)] pub(super) struct SessionFunctionTool { #[serde(rename = "type")] - pub(super) kind: String, + pub(super) r#type: SessionToolType, pub(super) name: String, pub(super) description: String, pub(super) parameters: Value, } +#[derive(Debug, Clone, Copy, Serialize)] +#[serde(rename_all = "snake_case")] +pub(super) enum SessionToolType { + Function, +} + pub(super) fn parse_realtime_event( payload: &str, event_parser: RealtimeEventParser, diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol_v1.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol_v1.rs index 04e76fb44..b66cf2b24 100644 --- a/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol_v1.rs +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol_v1.rs @@ -35,6 +35,7 @@ pub(super) fn parse_realtime_event_v1(payload: &str) -> Option { .get("samples_per_channel") .and_then(Value::as_u64) .and_then(|value| u32::try_from(value).ok()), + item_id: None, })) } "conversation.input_transcript.delta" => { diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol_v2.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol_v2.rs index 7ef318d3f..b33007519 100644 --- a/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol_v2.rs +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol_v2.rs @@ -5,6 +5,8 @@ use crate::endpoint::realtime_websocket::protocol_common::parse_transcript_delta use codex_protocol::protocol::RealtimeAudioFrame; use codex_protocol::protocol::RealtimeEvent; use codex_protocol::protocol::RealtimeHandoffRequested; +use codex_protocol::protocol::RealtimeInputAudioSpeechStarted; +use codex_protocol::protocol::RealtimeResponseCancelled; use serde_json::Map as JsonMap; use serde_json::Value; use tracing::debug; @@ -19,7 +21,9 @@ pub(super) fn parse_realtime_event_v2(payload: &str) -> Option { match message_type.as_str() { "session.updated" => parse_session_updated_event(&parsed), - "response.output_audio.delta" => parse_output_audio_delta_event(&parsed), + "response.output_audio.delta" | "response.audio.delta" => { + parse_output_audio_delta_event(&parsed) + } "conversation.item.input_audio_transcription.delta" => { parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::InputTranscriptDelta) } @@ -30,11 +34,37 @@ pub(super) fn parse_realtime_event_v2(payload: &str) -> Option { "response.output_text.delta" | "response.output_audio_transcript.delta" => { parse_transcript_delta_event(&parsed, "delta").map(RealtimeEvent::OutputTranscriptDelta) } + "input_audio_buffer.speech_started" => Some(RealtimeEvent::InputAudioSpeechStarted( + RealtimeInputAudioSpeechStarted { + item_id: parsed + .get("item_id") + .and_then(Value::as_str) + .map(str::to_string), + }, + )), "conversation.item.added" => parsed .get("item") .cloned() .map(RealtimeEvent::ConversationItemAdded), "conversation.item.done" => parse_conversation_item_done_event(&parsed), + "response.created" => Some(RealtimeEvent::ConversationItemAdded(parsed)), + "response.done" => parse_response_done_event(parsed), + "response.cancelled" => Some(RealtimeEvent::ResponseCancelled( + RealtimeResponseCancelled { + response_id: parsed + .get("response") + .and_then(Value::as_object) + .and_then(|response| response.get("id")) + .and_then(Value::as_str) + .map(str::to_string) + .or_else(|| { + parsed + .get("response_id") + .and_then(Value::as_str) + .map(str::to_string) + }), + }, + )), "error" => parse_error_event(&parsed), _ => { debug!("received unsupported realtime v2 event type: {message_type}, data: {payload}"); @@ -67,6 +97,10 @@ fn parse_output_audio_delta_event(parsed: &Value) -> Option { .get("samples_per_channel") .and_then(Value::as_u64) .and_then(|value| u32::try_from(value).ok()), + item_id: parsed + .get("item_id") + .and_then(Value::as_str) + .map(str::to_string), })) } @@ -82,6 +116,30 @@ fn parse_conversation_item_done_event(parsed: &Value) -> Option { .map(|item_id| RealtimeEvent::ConversationItemDone { item_id }) } +fn parse_response_done_event(parsed: Value) -> Option { + if let Some(handoff) = parse_response_done_handoff_requested_event(&parsed) { + return Some(handoff); + } + + Some(RealtimeEvent::ConversationItemAdded(parsed)) +} + +fn parse_response_done_handoff_requested_event(parsed: &Value) -> Option { + let item = parsed + .get("response") + .and_then(Value::as_object) + .and_then(|response| response.get("output")) + .and_then(Value::as_array)? + .iter() + .find(|item| { + item.get("type").and_then(Value::as_str) == Some("function_call") + && item.get("name").and_then(Value::as_str) == Some(CODEX_TOOL_NAME) + })? + .as_object()?; + + parse_handoff_requested_event(item) +} + fn parse_handoff_requested_event(item: &JsonMap) -> Option { let item_type = item.get("type").and_then(Value::as_str); let item_name = item.get("name").and_then(Value::as_str); diff --git a/codex-rs/codex-api/tests/realtime_websocket_e2e.rs b/codex-rs/codex-api/tests/realtime_websocket_e2e.rs index 30786ad92..130ab6fd3 100644 --- a/codex-rs/codex-api/tests/realtime_websocket_e2e.rs +++ b/codex-rs/codex-api/tests/realtime_websocket_e2e.rs @@ -170,6 +170,7 @@ async fn realtime_ws_e2e_session_create_and_event_flow() { sample_rate: 48000, num_channels: 1, samples_per_channel: Some(960), + item_id: None, }) .await .expect("send audio"); @@ -186,6 +187,7 @@ async fn realtime_ws_e2e_session_create_and_event_flow() { sample_rate: 48000, num_channels: 1, samples_per_channel: None, + item_id: None, }) ); @@ -254,6 +256,7 @@ async fn realtime_ws_e2e_send_while_next_event_waits() { sample_rate: 48000, num_channels: 1, samples_per_channel: Some(960), + item_id: None, }), ) .await diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 9787e10e9..735fe7ec4 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -2614,6 +2614,9 @@ impl Session { if !matches!(msg, EventMsg::TurnComplete(_)) { return; } + if let Err(err) = self.conversation.handoff_complete().await { + debug!("failed to finalize realtime handoff output: {err}"); + } self.conversation.clear_active_handoff().await; } diff --git a/codex-rs/core/src/codex_tests.rs b/codex-rs/core/src/codex_tests.rs index fd1bb576b..34ed7bcd6 100644 --- a/codex-rs/core/src/codex_tests.rs +++ b/codex-rs/core/src/codex_tests.rs @@ -2735,6 +2735,7 @@ fn submission_dispatch_span_uses_debug_for_realtime_audio() { sample_rate: 16_000, num_channels: 1, samples_per_channel: Some(160), + item_id: None, }, }), trace: None, diff --git a/codex-rs/core/src/realtime_conversation.rs b/codex-rs/core/src/realtime_conversation.rs index 243f4d8f2..938f922f8 100644 --- a/codex-rs/core/src/realtime_conversation.rs +++ b/codex-rs/core/src/realtime_conversation.rs @@ -11,6 +11,8 @@ use crate::realtime_context::build_realtime_startup_context; use async_channel::Receiver; use async_channel::Sender; use async_channel::TrySendError; +use base64::Engine; +use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; use codex_api::Provider as ApiProvider; use codex_api::RealtimeAudioFrame; use codex_api::RealtimeEvent; @@ -34,6 +36,8 @@ use codex_protocol::protocol::RealtimeHandoffRequested; use http::HeaderMap; use http::HeaderValue; use http::header::AUTHORIZATION; +use serde_json::Value; +use serde_json::json; use std::sync::Arc; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; @@ -49,51 +53,72 @@ const USER_TEXT_IN_QUEUE_CAPACITY: usize = 64; const HANDOFF_OUT_QUEUE_CAPACITY: usize = 64; const OUTPUT_EVENTS_QUEUE_CAPACITY: usize = 256; const REALTIME_STARTUP_CONTEXT_TOKEN_BUDGET: usize = 5_000; +const ACTIVE_RESPONSE_CONFLICT_ERROR_PREFIX: &str = + "Conversation already has an active response in progress:"; pub(crate) struct RealtimeConversationManager { state: Mutex>, } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum RealtimeSessionKind { + V1, + V2, +} + #[derive(Clone, Debug)] struct RealtimeHandoffState { output_tx: Sender, active_handoff: Arc>>, + last_output_text: Arc>>, + session_kind: RealtimeSessionKind, } #[derive(Debug, PartialEq, Eq)] -struct HandoffOutput { - handoff_id: String, - output_text: String, +enum HandoffOutput { + ImmediateAppend { + handoff_id: String, + output_text: String, + }, + FinalToolCall { + handoff_id: String, + output_text: String, + }, +} + +#[derive(Debug, PartialEq, Eq)] +struct OutputAudioState { + item_id: String, + audio_end_ms: u32, +} + +struct RealtimeInputTask { + writer: RealtimeWebsocketWriter, + events: RealtimeWebsocketEvents, + user_text_rx: Receiver, + handoff_output_rx: Receiver, + audio_rx: Receiver, + events_tx: Sender, + handoff_state: RealtimeHandoffState, + session_kind: RealtimeSessionKind, } impl RealtimeHandoffState { - fn new(output_tx: Sender) -> Self { + fn new(output_tx: Sender, session_kind: RealtimeSessionKind) -> Self { Self { output_tx, active_handoff: Arc::new(Mutex::new(None)), + last_output_text: Arc::new(Mutex::new(None)), + session_kind, } } - - async fn send_output(&self, output_text: String) -> CodexResult<()> { - let Some(handoff_id) = self.active_handoff.lock().await.clone() else { - return Ok(()); - }; - - self.output_tx - .send(HandoffOutput { - handoff_id, - output_text, - }) - .await - .map_err(|_| CodexErr::InvalidRequest("conversation is not running".to_string()))?; - Ok(()) - } } #[allow(dead_code)] struct ConversationState { audio_tx: Sender, user_text_tx: Sender, + writer: RealtimeWebsocketWriter, handoff: RealtimeHandoffState, task: JoinHandle<()>, realtime_active: Arc, @@ -129,6 +154,10 @@ impl RealtimeConversationManager { state.task.abort(); let _ = state.task.await; } + let session_kind = match session_config.event_parser { + RealtimeEventParser::V1 => RealtimeSessionKind::V1, + RealtimeEventParser::RealtimeV2 => RealtimeSessionKind::V2, + }; let client = RealtimeWebsocketClient::new(api_provider); let connection = client @@ -152,21 +181,23 @@ impl RealtimeConversationManager { async_channel::bounded::(OUTPUT_EVENTS_QUEUE_CAPACITY); let realtime_active = Arc::new(AtomicBool::new(true)); - let handoff = RealtimeHandoffState::new(handoff_output_tx); - let task = spawn_realtime_input_task( - writer, + let handoff = RealtimeHandoffState::new(handoff_output_tx, session_kind); + let task = spawn_realtime_input_task(RealtimeInputTask { + writer: writer.clone(), events, user_text_rx, handoff_output_rx, audio_rx, events_tx, - handoff.clone(), - ); + handoff_state: handoff.clone(), + session_kind, + }); let mut guard = self.state.lock().await; *guard = Some(ConversationState { audio_tx, user_text_tx, + writer, handoff, task, realtime_active: Arc::clone(&realtime_active), @@ -228,7 +259,51 @@ impl RealtimeConversationManager { state.handoff.clone() }; - handoff.send_output(output_text).await + let Some(handoff_id) = handoff.active_handoff.lock().await.clone() else { + return Ok(()); + }; + + *handoff.last_output_text.lock().await = Some(output_text.clone()); + if matches!(handoff.session_kind, RealtimeSessionKind::V1) { + handoff + .output_tx + .send(HandoffOutput::ImmediateAppend { + handoff_id, + output_text, + }) + .await + .map_err(|_| CodexErr::InvalidRequest("conversation is not running".to_string()))?; + } + Ok(()) + } + + pub(crate) async fn handoff_complete(&self) -> CodexResult<()> { + let handoff = { + let guard = self.state.lock().await; + guard.as_ref().map(|state| state.handoff.clone()) + }; + let Some(handoff) = handoff else { + return Ok(()); + }; + if matches!(handoff.session_kind, RealtimeSessionKind::V1) { + return Ok(()); + } + + let Some(handoff_id) = handoff.active_handoff.lock().await.clone() else { + return Ok(()); + }; + let Some(output_text) = handoff.last_output_text.lock().await.clone() else { + return Ok(()); + }; + + handoff + .output_tx + .send(HandoffOutput::FinalToolCall { + handoff_id, + output_text, + }) + .await + .map_err(|_| CodexErr::InvalidRequest("conversation is not running".to_string())) } pub(crate) async fn active_handoff_id(&self) -> Option { @@ -246,6 +321,7 @@ impl RealtimeConversationManager { }; if let Some(handoff) = handoff { *handoff.active_handoff.lock().await = None; + *handoff.last_output_text.lock().await = None; } } @@ -467,7 +543,6 @@ pub(crate) async fn handle_text( params: ConversationTextParams, ) { debug!(text = %params.text, "[realtime-text] appending realtime conversation text input"); - if let Err(err) = sess.conversation.text_in(params.text).await { error!("failed to append realtime text: {err}"); send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::BadRequest).await; @@ -491,16 +566,23 @@ pub(crate) async fn handle_close(sess: &Arc, sub_id: String) { } } -fn spawn_realtime_input_task( - writer: RealtimeWebsocketWriter, - events: RealtimeWebsocketEvents, - user_text_rx: Receiver, - handoff_output_rx: Receiver, - audio_rx: Receiver, - events_tx: Sender, - handoff_state: RealtimeHandoffState, -) -> JoinHandle<()> { +fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> { + let RealtimeInputTask { + writer, + events, + user_text_rx, + handoff_output_rx, + audio_rx, + events_tx, + handoff_state, + session_kind, + } = input; + tokio::spawn(async move { + let mut pending_response_create = false; + let mut response_in_progress = false; + let mut output_audio_state: Option = None; + loop { tokio::select! { text = user_text_rx.recv() => { @@ -511,23 +593,66 @@ fn spawn_realtime_input_task( warn!("failed to send input text: {mapped_error}"); break; } + if matches!(session_kind, RealtimeSessionKind::V2) { + if response_in_progress { + pending_response_create = true; + } else if let Err(err) = writer.send_response_create().await { + let mapped_error = map_api_error(err); + warn!("failed to send text response.create: {mapped_error}"); + break; + } else { + pending_response_create = false; + response_in_progress = true; + } + } } Err(_) => break, } } handoff_output = handoff_output_rx.recv() => { match handoff_output { - Ok(HandoffOutput { - handoff_id, - output_text, - }) => { - if let Err(err) = writer - .send_conversation_handoff_append(handoff_id, output_text) - .await - { - let mapped_error = map_api_error(err); - warn!("failed to send handoff output: {mapped_error}"); - break; + Ok(handoff_output) => { + match handoff_output { + HandoffOutput::ImmediateAppend { + handoff_id, + output_text, + } => { + if let Err(err) = writer + .send_conversation_handoff_append(handoff_id, output_text) + .await + { + let mapped_error = map_api_error(err); + warn!("failed to send handoff output: {mapped_error}"); + break; + } + } + HandoffOutput::FinalToolCall { + handoff_id, + output_text, + } => { + if let Err(err) = writer + .send_conversation_handoff_append(handoff_id, output_text) + .await + { + let mapped_error = map_api_error(err); + warn!("failed to send handoff output: {mapped_error}"); + break; + } + if matches!(session_kind, RealtimeSessionKind::V2) { + if response_in_progress { + pending_response_create = true; + } else if let Err(err) = writer.send_response_create().await { + let mapped_error = map_api_error(err); + warn!( + "failed to send handoff response.create: {mapped_error}" + ); + break; + } else { + pending_response_create = false; + response_in_progress = true; + } + } + } } } Err(_) => break, @@ -536,12 +661,108 @@ fn spawn_realtime_input_task( event = events.next_event() => { match event { Ok(Some(event)) => { - if let RealtimeEvent::HandoffRequested(handoff) = &event { - *handoff_state.active_handoff.lock().await = - Some(handoff.handoff_id.clone()); + let mut should_stop = false; + let mut forward_event = true; + + match &event { + RealtimeEvent::ConversationItemAdded(item) => { + match item.get("type").and_then(Value::as_str) { + Some("response.created") + if matches!(session_kind, RealtimeSessionKind::V2) => + { + response_in_progress = true; + } + Some("response.done") + if matches!(session_kind, RealtimeSessionKind::V2) => + { + response_in_progress = false; + output_audio_state = None; + if pending_response_create { + if let Err(err) = writer.send_response_create().await { + let mapped_error = map_api_error(err); + warn!( + "failed to send deferred response.create: {mapped_error}" + ); + break; + } + pending_response_create = false; + response_in_progress = true; + } + } + _ => {} + } + } + RealtimeEvent::AudioOut(frame) => { + if matches!(session_kind, RealtimeSessionKind::V2) { + update_output_audio_state(&mut output_audio_state, frame); + } + } + RealtimeEvent::InputAudioSpeechStarted(event) => { + if matches!(session_kind, RealtimeSessionKind::V2) + && let Some(output_audio_state) = + output_audio_state.take() + && event + .item_id + .as_deref() + .is_none_or(|item_id| item_id == output_audio_state.item_id) + && let Err(err) = writer + .send_payload(json!({ + "type": "conversation.item.truncate", + "item_id": output_audio_state.item_id, + "content_index": 0, + "audio_end_ms": output_audio_state.audio_end_ms, + }) + .to_string()) + .await + { + let mapped_error = map_api_error(err); + warn!("failed to truncate realtime audio: {mapped_error}"); + } + } + RealtimeEvent::ResponseCancelled(_) => { + response_in_progress = false; + output_audio_state = None; + if matches!(session_kind, RealtimeSessionKind::V2) + && pending_response_create + { + if let Err(err) = writer.send_response_create().await { + let mapped_error = map_api_error(err); + warn!( + "failed to send deferred response.create after cancellation: {mapped_error}" + ); + break; + } + pending_response_create = false; + response_in_progress = true; + } + } + RealtimeEvent::HandoffRequested(handoff) => { + *handoff_state.active_handoff.lock().await = + Some(handoff.handoff_id.clone()); + *handoff_state.last_output_text.lock().await = None; + response_in_progress = false; + output_audio_state = None; + } + RealtimeEvent::Error(message) + if matches!(session_kind, RealtimeSessionKind::V2) + && message.starts_with(ACTIVE_RESPONSE_CONFLICT_ERROR_PREFIX) => + { + warn!( + "realtime rejected response.create because a response is already in progress; deferring follow-up response.create" + ); + pending_response_create = true; + response_in_progress = true; + forward_event = false; + } + RealtimeEvent::Error(_) => { + should_stop = true; + } + RealtimeEvent::SessionUpdated { .. } + | RealtimeEvent::InputTranscriptDelta(_) + | RealtimeEvent::OutputTranscriptDelta(_) + | RealtimeEvent::ConversationItemDone { .. } => {} } - let should_stop = matches!(&event, RealtimeEvent::Error(_)); - if events_tx.send(event).await.is_err() { + if forward_event && events_tx.send(event).await.is_err() { break; } if should_stop { @@ -588,6 +809,49 @@ fn spawn_realtime_input_task( }) } +fn update_output_audio_state( + output_audio_state: &mut Option, + frame: &RealtimeAudioFrame, +) { + let Some(item_id) = frame.item_id.clone() else { + return; + }; + let audio_end_ms = audio_duration_ms(frame); + if audio_end_ms == 0 { + return; + } + + if let Some(current) = output_audio_state.as_mut() + && current.item_id == item_id + { + current.audio_end_ms = current.audio_end_ms.saturating_add(audio_end_ms); + return; + } + + *output_audio_state = Some(OutputAudioState { + item_id, + audio_end_ms, + }); +} + +fn audio_duration_ms(frame: &RealtimeAudioFrame) -> u32 { + let Some(samples_per_channel) = frame + .samples_per_channel + .or_else(|| decoded_samples_per_channel(frame)) + else { + return 0; + }; + let sample_rate = u64::from(frame.sample_rate.max(1)); + ((u64::from(samples_per_channel) * 1_000) / sample_rate) as u32 +} + +fn decoded_samples_per_channel(frame: &RealtimeAudioFrame) -> Option { + let bytes = BASE64_STANDARD.decode(&frame.data).ok()?; + let channels = usize::from(frame.num_channels.max(1)); + let samples = bytes.len().checked_div(2)?.checked_div(channels)?; + u32::try_from(samples).ok() +} + async fn send_conversation_error( sess: &Arc, sub_id: String, diff --git a/codex-rs/core/src/realtime_conversation_tests.rs b/codex-rs/core/src/realtime_conversation_tests.rs index d6b85a92d..0a32d063c 100644 --- a/codex-rs/core/src/realtime_conversation_tests.rs +++ b/codex-rs/core/src/realtime_conversation_tests.rs @@ -1,5 +1,5 @@ -use super::HandoffOutput; use super::RealtimeHandoffState; +use super::RealtimeSessionKind; use super::realtime_text_from_handoff_request; use async_channel::bounded; use codex_protocol::protocol::RealtimeHandoffRequested; @@ -57,7 +57,7 @@ fn ignores_empty_handoff_request_input_transcript() { #[tokio::test] async fn clears_active_handoff_explicitly() { let (tx, _rx) = bounded(1); - let state = RealtimeHandoffState::new(tx); + let state = RealtimeHandoffState::new(tx, RealtimeSessionKind::V1); *state.active_handoff.lock().await = Some("handoff_1".to_string()); assert_eq!( @@ -68,47 +68,3 @@ async fn clears_active_handoff_explicitly() { *state.active_handoff.lock().await = None; assert_eq!(state.active_handoff.lock().await.clone(), None); } - -#[tokio::test] -async fn sends_multiple_handoff_outputs_until_cleared() { - let (tx, rx) = bounded(4); - let state = RealtimeHandoffState::new(tx); - - state - .send_output("ignored".to_string()) - .await - .expect("send"); - assert!(rx.is_empty()); - - *state.active_handoff.lock().await = Some("handoff_1".to_string()); - state.send_output("result".to_string()).await.expect("send"); - state - .send_output("result 2".to_string()) - .await - .expect("send"); - - let output_1 = rx.recv().await.expect("recv"); - assert_eq!( - output_1, - HandoffOutput { - handoff_id: "handoff_1".to_string(), - output_text: "result".to_string(), - } - ); - - let output_2 = rx.recv().await.expect("recv"); - assert_eq!( - output_2, - HandoffOutput { - handoff_id: "handoff_1".to_string(), - output_text: "result 2".to_string(), - } - ); - - *state.active_handoff.lock().await = None; - state - .send_output("ignored after clear".to_string()) - .await - .expect("send"); - assert!(rx.is_empty()); -} diff --git a/codex-rs/core/tests/suite/realtime_conversation.rs b/codex-rs/core/tests/suite/realtime_conversation.rs index 0d49f8c8d..4ab987121 100644 --- a/codex-rs/core/tests/suite/realtime_conversation.rs +++ b/codex-rs/core/tests/suite/realtime_conversation.rs @@ -176,6 +176,7 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> { sample_rate: 24000, num_channels: 1, samples_per_channel: Some(480), + item_id: None, }, })) .await?; @@ -409,6 +410,7 @@ async fn conversation_audio_before_start_emits_error() -> Result<()> { sample_rate: 24000, num_channels: 1, samples_per_channel: Some(480), + item_id: None, }, })) .await?; @@ -518,6 +520,7 @@ async fn conversation_second_start_replaces_runtime() -> Result<()> { sample_rate: 24000, num_channels: 1, samples_per_channel: Some(480), + item_id: None, }, })) .await?; @@ -1469,6 +1472,7 @@ async fn inbound_handoff_request_clears_active_transcript_after_each_handoff() - sample_rate: 24000, num_channels: 1, samples_per_channel: Some(480), + item_id: None, }, })) .await?; @@ -1954,6 +1958,7 @@ async fn inbound_handoff_request_steers_active_turn() -> Result<()> { sample_rate: 24000, num_channels: 1, samples_per_channel: Some(480), + item_id: None, }, })) .await?; diff --git a/codex-rs/protocol/src/protocol.rs b/codex-rs/protocol/src/protocol.rs index f1f60e163..152743b3e 100644 --- a/codex-rs/protocol/src/protocol.rs +++ b/codex-rs/protocol/src/protocol.rs @@ -139,6 +139,8 @@ pub struct RealtimeAudioFrame { pub num_channels: u16, #[serde(skip_serializing_if = "Option::is_none")] pub samples_per_channel: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub item_id: Option, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, JsonSchema, TS)] @@ -160,15 +162,27 @@ pub struct RealtimeHandoffRequested { pub active_transcript: Vec, } +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, JsonSchema, TS)] +pub struct RealtimeInputAudioSpeechStarted { + pub item_id: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, JsonSchema, TS)] +pub struct RealtimeResponseCancelled { + pub response_id: Option, +} + #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, JsonSchema, TS)] pub enum RealtimeEvent { SessionUpdated { session_id: String, instructions: Option, }, + InputAudioSpeechStarted(RealtimeInputAudioSpeechStarted), InputTranscriptDelta(RealtimeTranscriptDelta), OutputTranscriptDelta(RealtimeTranscriptDelta), AudioOut(RealtimeAudioFrame), + ResponseCancelled(RealtimeResponseCancelled), ConversationItemAdded(Value), ConversationItemDone { item_id: String, @@ -4078,6 +4092,7 @@ mod tests { sample_rate: 24_000, num_channels: 1, samples_per_channel: Some(480), + item_id: None, }, }); let start = Op::RealtimeConversationStart(ConversationStartParams { diff --git a/codex-rs/tui/src/chatwidget/realtime.rs b/codex-rs/tui/src/chatwidget/realtime.rs index 4991aa568..6b5042307 100644 --- a/codex-rs/tui/src/chatwidget/realtime.rs +++ b/codex-rs/tui/src/chatwidget/realtime.rs @@ -264,9 +264,11 @@ impl ChatWidget { RealtimeEvent::SessionUpdated { session_id, .. } => { self.realtime_conversation.session_id = Some(session_id); } + RealtimeEvent::InputAudioSpeechStarted(_) => {} RealtimeEvent::InputTranscriptDelta(_) => {} RealtimeEvent::OutputTranscriptDelta(_) => {} RealtimeEvent::AudioOut(frame) => self.enqueue_realtime_audio_out(&frame), + RealtimeEvent::ResponseCancelled(_) => {} RealtimeEvent::ConversationItemAdded(_item) => {} RealtimeEvent::ConversationItemDone { .. } => {} RealtimeEvent::HandoffRequested(_) => {} diff --git a/codex-rs/tui/src/voice.rs b/codex-rs/tui/src/voice.rs index 7e4d8a85e..07adcfd0a 100644 --- a/codex-rs/tui/src/voice.rs +++ b/codex-rs/tui/src/voice.rs @@ -428,6 +428,7 @@ fn send_realtime_audio_chunk( sample_rate: MODEL_AUDIO_SAMPLE_RATE, num_channels: MODEL_AUDIO_CHANNELS, samples_per_channel: Some(samples_per_channel), + item_id: None, }, }, ))); diff --git a/codex-rs/tui_app_server/src/chatwidget/realtime.rs b/codex-rs/tui_app_server/src/chatwidget/realtime.rs index b954a373a..8a11cb405 100644 --- a/codex-rs/tui_app_server/src/chatwidget/realtime.rs +++ b/codex-rs/tui_app_server/src/chatwidget/realtime.rs @@ -268,9 +268,11 @@ impl ChatWidget { RealtimeEvent::SessionUpdated { session_id, .. } => { self.realtime_conversation.session_id = Some(session_id); } + RealtimeEvent::InputAudioSpeechStarted(_) => self.interrupt_realtime_audio_playback(), RealtimeEvent::InputTranscriptDelta(_) => {} RealtimeEvent::OutputTranscriptDelta(_) => {} RealtimeEvent::AudioOut(frame) => self.enqueue_realtime_audio_out(&frame), + RealtimeEvent::ResponseCancelled(_) => self.interrupt_realtime_audio_playback(), RealtimeEvent::ConversationItemAdded(_item) => {} RealtimeEvent::ConversationItemDone { .. } => {} RealtimeEvent::HandoffRequested(_) => {} @@ -313,6 +315,16 @@ impl ChatWidget { } } + #[cfg(not(target_os = "linux"))] + fn interrupt_realtime_audio_playback(&mut self) { + if let Some(player) = &self.realtime_conversation.audio_player { + player.clear(); + } + } + + #[cfg(target_os = "linux")] + fn interrupt_realtime_audio_playback(&mut self) {} + #[cfg(not(target_os = "linux"))] fn start_realtime_local_audio(&mut self) { if self.realtime_conversation.capture_stop_flag.is_some() { diff --git a/codex-rs/tui_app_server/src/voice.rs b/codex-rs/tui_app_server/src/voice.rs index f448c4573..6758eff4d 100644 --- a/codex-rs/tui_app_server/src/voice.rs +++ b/codex-rs/tui_app_server/src/voice.rs @@ -426,6 +426,7 @@ fn send_realtime_audio_chunk( sample_rate: MODEL_AUDIO_SAMPLE_RATE, num_channels: MODEL_AUDIO_CHANNELS, samples_per_channel: Some(samples_per_channel), + item_id: None, }, }); }