diff --git a/codex-rs/codex-api/src/endpoint/mod.rs b/codex-rs/codex-api/src/endpoint/mod.rs index 0dede138e..981643904 100644 --- a/codex-rs/codex-api/src/endpoint/mod.rs +++ b/codex-rs/codex-api/src/endpoint/mod.rs @@ -2,6 +2,7 @@ pub mod aggregate; pub mod compact; pub mod memories; pub mod models; +pub mod realtime_websocket; pub mod responses; pub mod responses_websocket; mod session; diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs new file mode 100644 index 000000000..e9a297de0 --- /dev/null +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs @@ -0,0 +1,824 @@ +use crate::endpoint::realtime_websocket::protocol::ConversationItem; +use crate::endpoint::realtime_websocket::protocol::ConversationItemContent; +use crate::endpoint::realtime_websocket::protocol::RealtimeAudioFrame; +use crate::endpoint::realtime_websocket::protocol::RealtimeEvent; +use crate::endpoint::realtime_websocket::protocol::RealtimeOutboundMessage; +use crate::endpoint::realtime_websocket::protocol::RealtimeSessionConfig; +use crate::endpoint::realtime_websocket::protocol::SessionCreateSession; +use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession; +use crate::endpoint::realtime_websocket::protocol::parse_realtime_event; +use crate::error::ApiError; +use crate::provider::Provider; +use codex_utils_rustls_provider::ensure_rustls_crypto_provider; +use futures::SinkExt; +use futures::StreamExt; +use http::HeaderMap; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use tokio::net::TcpStream; +use tokio::sync::Mutex; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio_tungstenite::MaybeTlsStream; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::tungstenite::Error as WsError; +use tokio_tungstenite::tungstenite::Message; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tracing::info; +use tracing::trace; +use tungstenite::protocol::WebSocketConfig; +use url::Url; + +struct WsStream { + tx_command: mpsc::Sender, + pump_task: tokio::task::JoinHandle<()>, +} + +enum WsCommand { + Send { + message: Message, + tx_result: oneshot::Sender>, + }, + Close { + tx_result: oneshot::Sender>, + }, +} + +impl WsStream { + fn new( + inner: WebSocketStream>, + ) -> (Self, mpsc::UnboundedReceiver>) { + let (tx_command, mut rx_command) = mpsc::channel::(32); + let (tx_message, rx_message) = mpsc::unbounded_channel::>(); + + let pump_task = tokio::spawn(async move { + let mut inner = inner; + loop { + tokio::select! { + command = rx_command.recv() => { + let Some(command) = command else { + break; + }; + match command { + WsCommand::Send { message, tx_result } => { + let result = inner.send(message).await; + let should_break = result.is_err(); + let _ = tx_result.send(result); + if should_break { + break; + } + } + WsCommand::Close { tx_result } => { + let result = inner.close(None).await; + let _ = tx_result.send(result); + break; + } + } + } + message = inner.next() => { + let Some(message) = message else { + break; + }; + match message { + Ok(Message::Ping(payload)) => { + if let Err(err) = inner.send(Message::Pong(payload)).await { + let _ = tx_message.send(Err(err)); + break; + } + } + Ok(Message::Pong(_)) => {} + Ok(message @ (Message::Text(_) + | Message::Binary(_) + | Message::Close(_) + | Message::Frame(_))) => { + let is_close = matches!(message, Message::Close(_)); + if tx_message.send(Ok(message)).is_err() { + break; + } + if is_close { + break; + } + } + Err(err) => { + let _ = tx_message.send(Err(err)); + break; + } + } + } + } + } + }); + + ( + Self { + tx_command, + pump_task, + }, + rx_message, + ) + } + + async fn request( + &self, + make_command: impl FnOnce(oneshot::Sender>) -> WsCommand, + ) -> Result<(), WsError> { + let (tx_result, rx_result) = oneshot::channel(); + if self.tx_command.send(make_command(tx_result)).await.is_err() { + return Err(WsError::ConnectionClosed); + } + rx_result.await.unwrap_or(Err(WsError::ConnectionClosed)) + } + + async fn send(&self, message: Message) -> Result<(), WsError> { + self.request(|tx_result| WsCommand::Send { message, tx_result }) + .await + } + + async fn close(&self) -> Result<(), WsError> { + self.request(|tx_result| WsCommand::Close { tx_result }) + .await + } +} + +impl Drop for WsStream { + fn drop(&mut self) { + self.pump_task.abort(); + } +} + +pub struct RealtimeWebsocketConnection { + writer: RealtimeWebsocketWriter, + events: RealtimeWebsocketEvents, +} + +#[derive(Clone)] +pub struct RealtimeWebsocketWriter { + stream: Arc, + is_closed: Arc, +} + +#[derive(Clone)] +pub struct RealtimeWebsocketEvents { + rx_message: Arc>>>, + is_closed: Arc, +} + +impl RealtimeWebsocketConnection { + pub async fn send_audio_frame(&self, frame: RealtimeAudioFrame) -> Result<(), ApiError> { + self.writer.send_audio_frame(frame).await + } + + pub async fn send_conversation_item_create(&self, text: String) -> Result<(), ApiError> { + self.writer.send_conversation_item_create(text).await + } + + pub async fn send_session_update( + &self, + backend_prompt: String, + conversation_id: Option, + ) -> Result<(), ApiError> { + self.writer + .send_session_update(backend_prompt, conversation_id) + .await + } + + pub async fn send_session_create( + &self, + backend_prompt: String, + conversation_id: Option, + ) -> Result<(), ApiError> { + self.writer + .send_session_create(backend_prompt, conversation_id) + .await + } + + pub async fn close(&self) -> Result<(), ApiError> { + self.writer.close().await + } + + pub async fn next_event(&self) -> Result, ApiError> { + self.events.next_event().await + } + + pub fn writer(&self) -> RealtimeWebsocketWriter { + self.writer.clone() + } + + pub fn events(&self) -> RealtimeWebsocketEvents { + self.events.clone() + } + + fn new( + stream: WsStream, + rx_message: mpsc::UnboundedReceiver>, + ) -> Self { + let stream = Arc::new(stream); + let is_closed = Arc::new(AtomicBool::new(false)); + Self { + writer: RealtimeWebsocketWriter { + stream: Arc::clone(&stream), + is_closed: Arc::clone(&is_closed), + }, + events: RealtimeWebsocketEvents { + rx_message: Arc::new(Mutex::new(rx_message)), + is_closed, + }, + } + } +} + +impl RealtimeWebsocketWriter { + pub async fn send_audio_frame(&self, frame: RealtimeAudioFrame) -> Result<(), ApiError> { + self.send_json(RealtimeOutboundMessage::InputAudioDelta { + delta: frame.data, + sample_rate: frame.sample_rate, + num_channels: frame.num_channels, + samples_per_channel: frame.samples_per_channel, + }) + .await + } + + pub async fn send_conversation_item_create(&self, text: String) -> Result<(), ApiError> { + self.send_json(RealtimeOutboundMessage::ConversationItemCreate { + item: ConversationItem { + kind: "message".to_string(), + role: "user".to_string(), + content: vec![ConversationItemContent { + kind: "text".to_string(), + text, + }], + }, + }) + .await + } + + pub async fn send_session_update( + &self, + backend_prompt: String, + conversation_id: Option, + ) -> Result<(), ApiError> { + self.send_json(RealtimeOutboundMessage::SessionUpdate { + session: Some(SessionUpdateSession { + backend_prompt, + conversation_id, + }), + }) + .await + } + + pub async fn send_session_create( + &self, + backend_prompt: String, + conversation_id: Option, + ) -> Result<(), ApiError> { + self.send_json(RealtimeOutboundMessage::SessionCreate { + session: SessionCreateSession { + backend_prompt, + conversation_id, + }, + }) + .await + } + + pub async fn close(&self) -> Result<(), ApiError> { + if self.is_closed.swap(true, Ordering::SeqCst) { + return Ok(()); + } + if let Err(err) = self.stream.close().await + && !matches!(err, WsError::ConnectionClosed | WsError::AlreadyClosed) + { + return Err(ApiError::Stream(format!( + "failed to close websocket: {err}" + ))); + } + Ok(()) + } + + async fn send_json(&self, message: RealtimeOutboundMessage) -> Result<(), ApiError> { + let payload = serde_json::to_string(&message) + .map_err(|err| ApiError::Stream(format!("failed to encode realtime request: {err}")))?; + trace!("realtime websocket request: {payload}"); + + if self.is_closed.load(Ordering::SeqCst) { + return Err(ApiError::Stream( + "realtime websocket connection is closed".to_string(), + )); + } + + self.stream + .send(Message::Text(payload.into())) + .await + .map_err(|err| ApiError::Stream(format!("failed to send realtime request: {err}")))?; + Ok(()) + } +} + +impl RealtimeWebsocketEvents { + pub async fn next_event(&self) -> Result, ApiError> { + if self.is_closed.load(Ordering::SeqCst) { + return Ok(None); + } + + loop { + let msg = match self.rx_message.lock().await.recv().await { + Some(Ok(msg)) => msg, + Some(Err(err)) => { + self.is_closed.store(true, Ordering::SeqCst); + return Err(ApiError::Stream(format!( + "failed to read websocket message: {err}" + ))); + } + None => { + self.is_closed.store(true, Ordering::SeqCst); + return Ok(None); + } + }; + + match msg { + Message::Text(text) => { + if let Some(event) = parse_realtime_event(&text) { + return Ok(Some(event)); + } + } + Message::Close(_) => { + self.is_closed.store(true, Ordering::SeqCst); + return Ok(None); + } + Message::Binary(_) => { + return Ok(Some(RealtimeEvent::Error( + "unexpected binary realtime websocket event".to_string(), + ))); + } + Message::Frame(_) | Message::Ping(_) | Message::Pong(_) => {} + } + } + } +} + +pub struct RealtimeWebsocketClient { + provider: Provider, +} + +impl RealtimeWebsocketClient { + pub fn new(provider: Provider) -> Self { + Self { provider } + } + + pub async fn connect( + &self, + config: RealtimeSessionConfig, + extra_headers: HeaderMap, + default_headers: HeaderMap, + ) -> Result { + ensure_rustls_crypto_provider(); + let ws_url = websocket_url_from_api_url(config.api_url.as_str())?; + + let mut request = ws_url + .as_str() + .into_client_request() + .map_err(|err| ApiError::Stream(format!("failed to build websocket request: {err}")))?; + let headers = merge_request_headers(&self.provider.headers, extra_headers, default_headers); + request.headers_mut().extend(headers); + + info!("connecting realtime websocket: {ws_url}"); + let (stream, _) = + tokio_tungstenite::connect_async_with_config(request, Some(websocket_config()), false) + .await + .map_err(|err| { + ApiError::Stream(format!("failed to connect realtime websocket: {err}")) + })?; + + let (stream, rx_message) = WsStream::new(stream); + let connection = RealtimeWebsocketConnection::new(stream, rx_message); + connection + .send_session_create(config.prompt, config.session_id) + .await?; + Ok(connection) + } +} + +fn merge_request_headers( + provider_headers: &HeaderMap, + extra_headers: HeaderMap, + default_headers: HeaderMap, +) -> HeaderMap { + let mut headers = provider_headers.clone(); + headers.extend(extra_headers); + for (name, value) in &default_headers { + if let http::header::Entry::Vacant(entry) = headers.entry(name) { + entry.insert(value.clone()); + } + } + headers +} + +fn websocket_config() -> WebSocketConfig { + WebSocketConfig::default() +} + +fn websocket_url_from_api_url(api_url: &str) -> Result { + let mut url = Url::parse(api_url) + .map_err(|err| ApiError::Stream(format!("failed to parse realtime api_url: {err}")))?; + + match url.scheme() { + "ws" | "wss" => { + if url.path().is_empty() || url.path() == "/" { + url.set_path("/ws"); + } + Ok(url) + } + "http" | "https" => { + if url.path().is_empty() || url.path() == "/" { + url.set_path("/ws"); + } + let scheme = if url.scheme() == "http" { "ws" } else { "wss" }; + let _ = url.set_scheme(scheme); + Ok(url) + } + scheme => Err(ApiError::Stream(format!( + "unsupported realtime api_url scheme: {scheme}" + ))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http::HeaderValue; + use pretty_assertions::assert_eq; + use serde_json::Value; + use serde_json::json; + use std::collections::HashMap; + use std::time::Duration; + use tokio::net::TcpListener; + use tokio_tungstenite::accept_async; + use tokio_tungstenite::tungstenite::Message; + + #[test] + fn parse_session_created_event() { + let payload = json!({ + "type": "session.created", + "session": {"id": "sess_123"} + }) + .to_string(); + + assert_eq!( + parse_realtime_event(payload.as_str()), + Some(RealtimeEvent::SessionCreated { + session_id: "sess_123".to_string() + }) + ); + } + + #[test] + fn parse_audio_delta_event() { + let payload = json!({ + "type": "response.output_audio.delta", + "delta": "AAA=", + "sample_rate": 48000, + "num_channels": 1, + "samples_per_channel": 960 + }) + .to_string(); + assert_eq!( + parse_realtime_event(payload.as_str()), + Some(RealtimeEvent::AudioOut(RealtimeAudioFrame { + data: "AAA=".to_string(), + sample_rate: 48000, + num_channels: 1, + samples_per_channel: Some(960), + })) + ); + } + + #[test] + fn parse_conversation_item_added_event() { + let payload = json!({ + "type": "conversation.item.added", + "item": {"type": "spawn_transcript", "seq": 7} + }) + .to_string(); + assert_eq!( + parse_realtime_event(payload.as_str()), + Some(RealtimeEvent::ConversationItemAdded( + json!({"type": "spawn_transcript", "seq": 7}) + )) + ); + } + + #[test] + fn merge_request_headers_matches_http_precedence() { + let mut provider_headers = HeaderMap::new(); + provider_headers.insert( + "originator", + HeaderValue::from_static("provider-originator"), + ); + provider_headers.insert("x-priority", HeaderValue::from_static("provider")); + + let mut extra_headers = HeaderMap::new(); + extra_headers.insert("x-priority", HeaderValue::from_static("extra")); + + let mut default_headers = HeaderMap::new(); + default_headers.insert("originator", HeaderValue::from_static("default-originator")); + default_headers.insert("x-priority", HeaderValue::from_static("default")); + default_headers.insert("x-default-only", HeaderValue::from_static("default-only")); + + let merged = merge_request_headers(&provider_headers, extra_headers, default_headers); + + assert_eq!( + merged.get("originator"), + Some(&HeaderValue::from_static("provider-originator")) + ); + assert_eq!( + merged.get("x-priority"), + Some(&HeaderValue::from_static("extra")) + ); + assert_eq!( + merged.get("x-default-only"), + Some(&HeaderValue::from_static("default-only")) + ); + } + + #[test] + fn websocket_url_from_http_base_defaults_to_ws_path() { + let url = websocket_url_from_api_url("http://127.0.0.1:8011").expect("build ws url"); + assert_eq!(url.as_str(), "ws://127.0.0.1:8011/ws"); + } + + #[test] + fn websocket_url_from_ws_base_defaults_to_ws_path() { + let url = websocket_url_from_api_url("wss://example.com").expect("build ws url"); + assert_eq!(url.as_str(), "wss://example.com/ws"); + } + + #[tokio::test] + async fn e2e_connect_and_exchange_events_against_mock_ws_server() { + let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind"); + let addr = listener.local_addr().expect("local addr"); + + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.expect("accept"); + let mut ws = accept_async(stream).await.expect("accept ws"); + + let first = ws + .next() + .await + .expect("first msg") + .expect("first msg ok") + .into_text() + .expect("text"); + let first_json: Value = serde_json::from_str(&first).expect("json"); + assert_eq!(first_json["type"], "session.create"); + assert_eq!( + first_json["session"]["backend_prompt"], + Value::String("backend prompt".to_string()) + ); + assert_eq!( + first_json["session"]["conversation_id"], + Value::String("conv_1".to_string()) + ); + + ws.send(Message::Text( + json!({ + "type": "session.created", + "session": {"id": "sess_mock"} + }) + .to_string() + .into(), + )) + .await + .expect("send session.created"); + + let second = ws + .next() + .await + .expect("second msg") + .expect("second msg ok") + .into_text() + .expect("text"); + let second_json: Value = serde_json::from_str(&second).expect("json"); + assert_eq!(second_json["type"], "response.input_audio.delta"); + + let third = ws + .next() + .await + .expect("third msg") + .expect("third msg ok") + .into_text() + .expect("text"); + let third_json: Value = serde_json::from_str(&third).expect("json"); + assert_eq!(third_json["type"], "conversation.item.create"); + assert_eq!(third_json["item"]["content"][0]["text"], "hello agent"); + + ws.send(Message::Text( + json!({ + "type": "response.output_audio.delta", + "delta": "AQID", + "sample_rate": 48000, + "num_channels": 1 + }) + .to_string() + .into(), + )) + .await + .expect("send audio"); + + ws.send(Message::Text( + json!({ + "type": "conversation.item.added", + "item": {"type": "spawn_transcript", "seq": 2} + }) + .to_string() + .into(), + )) + .await + .expect("send item added"); + }); + + let provider = Provider { + name: "test".to_string(), + base_url: "http://localhost".to_string(), + query_params: Some(HashMap::new()), + headers: HeaderMap::new(), + retry: crate::provider::RetryConfig { + max_attempts: 1, + base_delay: Duration::from_millis(1), + retry_429: false, + retry_5xx: false, + retry_transport: false, + }, + stream_idle_timeout: Duration::from_secs(5), + }; + let client = RealtimeWebsocketClient::new(provider); + let connection = client + .connect( + RealtimeSessionConfig { + api_url: format!("ws://{addr}"), + prompt: "backend prompt".to_string(), + session_id: Some("conv_1".to_string()), + }, + HeaderMap::new(), + HeaderMap::new(), + ) + .await + .expect("connect"); + + let created = connection + .next_event() + .await + .expect("next event") + .expect("event"); + assert_eq!( + created, + RealtimeEvent::SessionCreated { + session_id: "sess_mock".to_string() + } + ); + + connection + .send_audio_frame(RealtimeAudioFrame { + data: "AQID".to_string(), + sample_rate: 48000, + num_channels: 1, + samples_per_channel: Some(960), + }) + .await + .expect("send audio"); + connection + .send_conversation_item_create("hello agent".to_string()) + .await + .expect("send item"); + + let audio_event = connection + .next_event() + .await + .expect("next event") + .expect("event"); + assert_eq!( + audio_event, + RealtimeEvent::AudioOut(RealtimeAudioFrame { + data: "AQID".to_string(), + sample_rate: 48000, + num_channels: 1, + samples_per_channel: None, + }) + ); + + let added_event = connection + .next_event() + .await + .expect("next event") + .expect("event"); + assert_eq!( + added_event, + RealtimeEvent::ConversationItemAdded(json!({ + "type": "spawn_transcript", + "seq": 2 + })) + ); + + connection.close().await.expect("close"); + server.await.expect("server task"); + } + + #[tokio::test] + async fn send_does_not_block_while_next_event_waits_for_inbound_data() { + let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind"); + let addr = listener.local_addr().expect("local addr"); + + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.expect("accept"); + let mut ws = accept_async(stream).await.expect("accept ws"); + + let first = ws + .next() + .await + .expect("first msg") + .expect("first msg ok") + .into_text() + .expect("text"); + let first_json: Value = serde_json::from_str(&first).expect("json"); + assert_eq!(first_json["type"], "session.create"); + + let second = ws + .next() + .await + .expect("second msg") + .expect("second msg ok") + .into_text() + .expect("text"); + let second_json: Value = serde_json::from_str(&second).expect("json"); + assert_eq!(second_json["type"], "response.input_audio.delta"); + + ws.send(Message::Text( + json!({ + "type": "session.created", + "session": {"id": "sess_after_send"} + }) + .to_string() + .into(), + )) + .await + .expect("send session.created"); + }); + + let provider = Provider { + name: "test".to_string(), + base_url: "http://localhost".to_string(), + query_params: Some(HashMap::new()), + headers: HeaderMap::new(), + retry: crate::provider::RetryConfig { + max_attempts: 1, + base_delay: Duration::from_millis(1), + retry_429: false, + retry_5xx: false, + retry_transport: false, + }, + stream_idle_timeout: Duration::from_secs(5), + }; + let client = RealtimeWebsocketClient::new(provider); + let connection = client + .connect( + RealtimeSessionConfig { + api_url: format!("ws://{addr}"), + prompt: "backend prompt".to_string(), + session_id: Some("conv_1".to_string()), + }, + HeaderMap::new(), + HeaderMap::new(), + ) + .await + .expect("connect"); + + let (send_result, next_result) = tokio::join!( + async { + tokio::time::timeout( + Duration::from_millis(200), + connection.send_audio_frame(RealtimeAudioFrame { + data: "AQID".to_string(), + sample_rate: 48000, + num_channels: 1, + samples_per_channel: Some(960), + }), + ) + .await + }, + connection.next_event() + ); + + send_result + .expect("send should not block on next_event") + .expect("send audio"); + let next_event = next_result.expect("next event").expect("event"); + assert_eq!( + next_event, + RealtimeEvent::SessionCreated { + session_id: "sess_after_send".to_string() + } + ); + + connection.close().await.expect("close"); + server.await.expect("server task"); + } +} diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/mod.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/mod.rs new file mode 100644 index 000000000..469fea8dc --- /dev/null +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/mod.rs @@ -0,0 +1,10 @@ +pub mod methods; +pub mod protocol; + +pub use methods::RealtimeWebsocketClient; +pub use methods::RealtimeWebsocketConnection; +pub use methods::RealtimeWebsocketEvents; +pub use methods::RealtimeWebsocketWriter; +pub use protocol::RealtimeAudioFrame; +pub use protocol::RealtimeEvent; +pub use protocol::RealtimeSessionConfig; diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs new file mode 100644 index 000000000..db63f5179 --- /dev/null +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs @@ -0,0 +1,161 @@ +use serde::Deserialize; +use serde::Serialize; +use serde_json::Value; +use tracing::debug; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RealtimeSessionConfig { + pub api_url: String, + pub prompt: String, + pub session_id: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct RealtimeAudioFrame { + pub data: String, + pub sample_rate: u32, + pub num_channels: u16, + #[serde(skip_serializing_if = "Option::is_none")] + pub samples_per_channel: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum RealtimeEvent { + SessionCreated { session_id: String }, + SessionUpdated { backend_prompt: Option }, + AudioOut(RealtimeAudioFrame), + ConversationItemAdded(Value), + Error(String), +} + +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type")] +pub(super) enum RealtimeOutboundMessage { + #[serde(rename = "response.input_audio.delta")] + InputAudioDelta { + delta: String, + sample_rate: u32, + num_channels: u16, + #[serde(skip_serializing_if = "Option::is_none")] + samples_per_channel: Option, + }, + #[serde(rename = "session.create")] + SessionCreate { session: SessionCreateSession }, + #[serde(rename = "session.update")] + SessionUpdate { + #[serde(skip_serializing_if = "Option::is_none")] + session: Option, + }, + #[serde(rename = "conversation.item.create")] + ConversationItemCreate { item: ConversationItem }, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct SessionUpdateSession { + pub(super) backend_prompt: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) conversation_id: Option, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct SessionCreateSession { + pub(super) backend_prompt: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) conversation_id: Option, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct ConversationItem { + #[serde(rename = "type")] + pub(super) kind: String, + pub(super) role: String, + pub(super) content: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct ConversationItemContent { + #[serde(rename = "type")] + pub(super) kind: String, + pub(super) text: String, +} + +pub(super) fn parse_realtime_event(payload: &str) -> Option { + let parsed: Value = match serde_json::from_str(payload) { + Ok(msg) => msg, + Err(err) => { + debug!("failed to parse realtime event: {err}, data: {payload}"); + return None; + } + }; + + let message_type = match parsed.get("type").and_then(Value::as_str) { + Some(message_type) => message_type, + None => { + debug!("received realtime event without type field: {payload}"); + return None; + } + }; + match message_type { + "session.created" => { + let session = parsed.get("session").and_then(Value::as_object); + let session_id = session + .and_then(|session| session.get("id")) + .and_then(Value::as_str) + .map(str::to_string) + .or_else(|| { + parsed + .get("session_id") + .and_then(Value::as_str) + .map(str::to_string) + }); + session_id.map(|id| RealtimeEvent::SessionCreated { session_id: id }) + } + "session.updated" => { + let backend_prompt = parsed + .get("session") + .and_then(Value::as_object) + .and_then(|session| session.get("backend_prompt")) + .and_then(Value::as_str) + .map(str::to_string); + Some(RealtimeEvent::SessionUpdated { backend_prompt }) + } + "response.output_audio.delta" => { + let data = parsed + .get("delta") + .and_then(Value::as_str) + .or_else(|| parsed.get("data").and_then(Value::as_str)) + .map(str::to_string)?; + let sample_rate = parsed + .get("sample_rate") + .and_then(Value::as_u64) + .and_then(|v| u32::try_from(v).ok())?; + let num_channels = parsed + .get("num_channels") + .and_then(Value::as_u64) + .and_then(|v| u16::try_from(v).ok())?; + Some(RealtimeEvent::AudioOut(RealtimeAudioFrame { + data, + sample_rate, + num_channels, + samples_per_channel: parsed + .get("samples_per_channel") + .and_then(Value::as_u64) + .and_then(|v| u32::try_from(v).ok()), + })) + } + "conversation.item.added" => parsed + .get("item") + .cloned() + .map(RealtimeEvent::ConversationItemAdded), + "error" => parsed + .get("message") + .and_then(Value::as_str) + .map(str::to_string) + .or_else(|| parsed.get("error").map(std::string::ToString::to_string)) + .map(RealtimeEvent::Error), + _ => { + debug!("received unsupported realtime event type: {message_type}, data: {payload}"); + None + } + } +} diff --git a/codex-rs/codex-api/src/lib.rs b/codex-rs/codex-api/src/lib.rs index fba35442c..ff8953c03 100644 --- a/codex-rs/codex-api/src/lib.rs +++ b/codex-rs/codex-api/src/lib.rs @@ -29,6 +29,11 @@ pub use crate::endpoint::aggregate::AggregateStreamExt; pub use crate::endpoint::compact::CompactClient; pub use crate::endpoint::memories::MemoriesClient; pub use crate::endpoint::models::ModelsClient; +pub use crate::endpoint::realtime_websocket::RealtimeAudioFrame; +pub use crate::endpoint::realtime_websocket::RealtimeEvent; +pub use crate::endpoint::realtime_websocket::RealtimeSessionConfig; +pub use crate::endpoint::realtime_websocket::RealtimeWebsocketClient; +pub use crate::endpoint::realtime_websocket::RealtimeWebsocketConnection; pub use crate::endpoint::responses::ResponsesClient; pub use crate::endpoint::responses::ResponsesOptions; pub use crate::endpoint::responses_websocket::ResponsesWebsocketClient; diff --git a/codex-rs/codex-api/tests/realtime_websocket_e2e.rs b/codex-rs/codex-api/tests/realtime_websocket_e2e.rs new file mode 100644 index 000000000..08d93b914 --- /dev/null +++ b/codex-rs/codex-api/tests/realtime_websocket_e2e.rs @@ -0,0 +1,368 @@ +use std::collections::HashMap; +use std::future::Future; +use std::time::Duration; + +use codex_api::RealtimeAudioFrame; +use codex_api::RealtimeEvent; +use codex_api::RealtimeSessionConfig; +use codex_api::RealtimeWebsocketClient; +use codex_api::provider::Provider; +use codex_api::provider::RetryConfig; +use futures::SinkExt; +use futures::StreamExt; +use http::HeaderMap; +use serde_json::Value; +use serde_json::json; +use tokio::net::TcpListener; +use tokio_tungstenite::accept_async; +use tokio_tungstenite::tungstenite::Message; + +type RealtimeWsStream = tokio_tungstenite::WebSocketStream; + +async fn spawn_realtime_ws_server( + handler: Handler, +) -> (String, tokio::task::JoinHandle<()>) +where + Handler: FnOnce(RealtimeWsStream) -> Fut + Send + 'static, + Fut: Future + Send + 'static, +{ + let listener = match TcpListener::bind("127.0.0.1:0").await { + Ok(listener) => listener, + Err(err) => panic!("failed to bind test websocket listener: {err}"), + }; + let addr = match listener.local_addr() { + Ok(addr) => addr.to_string(), + Err(err) => panic!("failed to read local websocket listener address: {err}"), + }; + + let server = tokio::spawn(async move { + let (stream, _) = match listener.accept().await { + Ok(stream) => stream, + Err(err) => panic!("failed to accept test websocket connection: {err}"), + }; + let ws = match accept_async(stream).await { + Ok(ws) => ws, + Err(err) => panic!("failed to complete websocket handshake: {err}"), + }; + handler(ws).await; + }); + + (addr, server) +} + +fn test_provider() -> Provider { + Provider { + name: "test".to_string(), + base_url: "http://localhost".to_string(), + query_params: Some(HashMap::new()), + headers: HeaderMap::new(), + retry: RetryConfig { + max_attempts: 1, + base_delay: Duration::from_millis(1), + retry_429: false, + retry_5xx: false, + retry_transport: false, + }, + stream_idle_timeout: Duration::from_secs(5), + } +} + +#[tokio::test] +async fn realtime_ws_e2e_session_create_and_event_flow() { + let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move { + let first = ws + .next() + .await + .expect("first msg") + .expect("first msg ok") + .into_text() + .expect("text"); + let first_json: Value = serde_json::from_str(&first).expect("json"); + assert_eq!(first_json["type"], "session.create"); + assert_eq!( + first_json["session"]["backend_prompt"], + Value::String("backend prompt".to_string()) + ); + assert_eq!( + first_json["session"]["conversation_id"], + Value::String("conv_123".to_string()) + ); + + ws.send(Message::Text( + json!({ + "type": "session.created", + "session": {"id": "sess_mock"} + }) + .to_string() + .into(), + )) + .await + .expect("send session.created"); + + let second = ws + .next() + .await + .expect("second msg") + .expect("second msg ok") + .into_text() + .expect("text"); + let second_json: Value = serde_json::from_str(&second).expect("json"); + assert_eq!(second_json["type"], "response.input_audio.delta"); + + ws.send(Message::Text( + json!({ + "type": "response.output_audio.delta", + "delta": "AQID", + "sample_rate": 48000, + "num_channels": 1 + }) + .to_string() + .into(), + )) + .await + .expect("send audio out"); + }) + .await; + + let client = RealtimeWebsocketClient::new(test_provider()); + let connection = client + .connect( + RealtimeSessionConfig { + api_url: format!("ws://{addr}"), + prompt: "backend prompt".to_string(), + session_id: Some("conv_123".to_string()), + }, + HeaderMap::new(), + HeaderMap::new(), + ) + .await + .expect("connect"); + + let created = connection + .next_event() + .await + .expect("next event") + .expect("event"); + assert_eq!( + created, + RealtimeEvent::SessionCreated { + session_id: "sess_mock".to_string() + } + ); + + connection + .send_audio_frame(RealtimeAudioFrame { + data: "AQID".to_string(), + sample_rate: 48000, + num_channels: 1, + samples_per_channel: Some(960), + }) + .await + .expect("send audio"); + + let audio_event = connection + .next_event() + .await + .expect("next event") + .expect("event"); + assert_eq!( + audio_event, + RealtimeEvent::AudioOut(RealtimeAudioFrame { + data: "AQID".to_string(), + sample_rate: 48000, + num_channels: 1, + samples_per_channel: None, + }) + ); + + connection.close().await.expect("close"); + server.await.expect("server task"); +} + +#[tokio::test] +async fn realtime_ws_e2e_send_while_next_event_waits() { + let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move { + let first = ws + .next() + .await + .expect("first msg") + .expect("first msg ok") + .into_text() + .expect("text"); + let first_json: Value = serde_json::from_str(&first).expect("json"); + assert_eq!(first_json["type"], "session.create"); + + let second = ws + .next() + .await + .expect("second msg") + .expect("second msg ok") + .into_text() + .expect("text"); + let second_json: Value = serde_json::from_str(&second).expect("json"); + assert_eq!(second_json["type"], "response.input_audio.delta"); + + ws.send(Message::Text( + json!({ + "type": "session.created", + "session": {"id": "sess_after_send"} + }) + .to_string() + .into(), + )) + .await + .expect("send session.created"); + }) + .await; + + let client = RealtimeWebsocketClient::new(test_provider()); + let connection = client + .connect( + RealtimeSessionConfig { + api_url: format!("ws://{addr}"), + prompt: "backend prompt".to_string(), + session_id: Some("conv_123".to_string()), + }, + HeaderMap::new(), + HeaderMap::new(), + ) + .await + .expect("connect"); + + let (send_result, next_result) = tokio::join!( + async { + tokio::time::timeout( + Duration::from_millis(200), + connection.send_audio_frame(RealtimeAudioFrame { + data: "AQID".to_string(), + sample_rate: 48000, + num_channels: 1, + samples_per_channel: Some(960), + }), + ) + .await + }, + connection.next_event() + ); + + send_result + .expect("send should not block on next_event") + .expect("send audio"); + let next_event = next_result.expect("next event").expect("event"); + assert_eq!( + next_event, + RealtimeEvent::SessionCreated { + session_id: "sess_after_send".to_string() + } + ); + + connection.close().await.expect("close"); + server.await.expect("server task"); +} + +#[tokio::test] +async fn realtime_ws_e2e_disconnected_emitted_once() { + let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move { + let first = ws + .next() + .await + .expect("first msg") + .expect("first msg ok") + .into_text() + .expect("text"); + let first_json: Value = serde_json::from_str(&first).expect("json"); + assert_eq!(first_json["type"], "session.create"); + + ws.send(Message::Close(None)).await.expect("send close"); + }) + .await; + + let client = RealtimeWebsocketClient::new(test_provider()); + let connection = client + .connect( + RealtimeSessionConfig { + api_url: format!("ws://{addr}"), + prompt: "backend prompt".to_string(), + session_id: Some("conv_123".to_string()), + }, + HeaderMap::new(), + HeaderMap::new(), + ) + .await + .expect("connect"); + + let first = connection.next_event().await.expect("next event"); + assert_eq!(first, None); + + let second = connection.next_event().await.expect("next event"); + assert_eq!(second, None); + + server.await.expect("server task"); +} + +#[tokio::test] +async fn realtime_ws_e2e_ignores_unknown_text_events() { + let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move { + let first = ws + .next() + .await + .expect("first msg") + .expect("first msg ok") + .into_text() + .expect("text"); + let first_json: Value = serde_json::from_str(&first).expect("json"); + assert_eq!(first_json["type"], "session.create"); + + ws.send(Message::Text( + json!({ + "type": "response.created", + "response": {"id": "resp_unknown"} + }) + .to_string() + .into(), + )) + .await + .expect("send unknown event"); + + ws.send(Message::Text( + json!({ + "type": "session.created", + "session": {"id": "sess_after_unknown"} + }) + .to_string() + .into(), + )) + .await + .expect("send session.created"); + }) + .await; + + let client = RealtimeWebsocketClient::new(test_provider()); + let connection = client + .connect( + RealtimeSessionConfig { + api_url: format!("ws://{addr}"), + prompt: "backend prompt".to_string(), + session_id: Some("conv_123".to_string()), + }, + HeaderMap::new(), + HeaderMap::new(), + ) + .await + .expect("connect"); + + let event = connection + .next_event() + .await + .expect("next event") + .expect("event"); + assert_eq!( + event, + RealtimeEvent::SessionCreated { + session_id: "sess_after_unknown".to_string() + } + ); + + connection.close().await.expect("close"); + server.await.expect("server task"); +}