diff --git a/codex-rs/app-server/src/bespoke_event_handling.rs b/codex-rs/app-server/src/bespoke_event_handling.rs index cfa286b45..e0fc1cfa4 100644 --- a/codex-rs/app-server/src/bespoke_event_handling.rs +++ b/codex-rs/app-server/src/bespoke_event_handling.rs @@ -1,14 +1,12 @@ use crate::codex_message_processor::ApiVersion; -use crate::codex_message_processor::PendingInterrupts; -use crate::codex_message_processor::PendingRollbacks; -use crate::codex_message_processor::TurnSummary; -use crate::codex_message_processor::TurnSummaryStore; use crate::codex_message_processor::read_rollout_items_from_rollout; use crate::codex_message_processor::read_summary_from_rollout; use crate::codex_message_processor::summary_to_thread; use crate::error_code::INTERNAL_ERROR_CODE; use crate::error_code::INVALID_REQUEST_ERROR_CODE; use crate::outgoing_message::OutgoingMessageSender; +use crate::thread_state::ThreadState; +use crate::thread_state::TurnSummary; use codex_app_server_protocol::AccountRateLimitsUpdatedNotification; use codex_app_server_protocol::AgentMessageDeltaNotification; use codex_app_server_protocol::ApplyPatchApprovalParams; @@ -98,6 +96,7 @@ use std::collections::HashMap; use std::convert::TryFrom; use std::path::PathBuf; use std::sync::Arc; +use tokio::sync::Mutex; use tokio::sync::oneshot; use tracing::error; @@ -109,9 +108,7 @@ pub(crate) async fn apply_bespoke_event_handling( conversation_id: ThreadId, conversation: Arc, outgoing: Arc, - pending_interrupts: PendingInterrupts, - pending_rollbacks: PendingRollbacks, - turn_summary_store: TurnSummaryStore, + thread_state: Arc>, api_version: ApiVersion, fallback_model_provider: String, ) { @@ -122,13 +119,7 @@ pub(crate) async fn apply_bespoke_event_handling( match msg { EventMsg::TurnStarted(_) => {} EventMsg::TurnComplete(_ev) => { - handle_turn_complete( - conversation_id, - event_turn_id, - &outgoing, - &turn_summary_store, - ) - .await; + handle_turn_complete(conversation_id, event_turn_id, &outgoing, &thread_state).await; } EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { call_id, @@ -159,9 +150,11 @@ pub(crate) async fn apply_bespoke_event_handling( let patch_changes = convert_patch_changes(&changes); let first_start = { - let mut map = turn_summary_store.lock().await; - let summary = map.entry(conversation_id).or_default(); - summary.file_change_started.insert(item_id.clone()) + let mut state = thread_state.lock().await; + state + .turn_summary + .file_change_started + .insert(item_id.clone()) }; if first_start { let item = ThreadItem::FileChange { @@ -198,7 +191,7 @@ pub(crate) async fn apply_bespoke_event_handling( rx, conversation, outgoing, - turn_summary_store, + thread_state.clone(), ) .await; }); @@ -718,7 +711,7 @@ pub(crate) async fn apply_bespoke_event_handling( return handle_thread_rollback_failed( conversation_id, message, - &pending_rollbacks, + &thread_state, &outgoing, ) .await; @@ -729,7 +722,7 @@ pub(crate) async fn apply_bespoke_event_handling( codex_error_info: ev.codex_error_info.map(V2CodexErrorInfo::from), additional_details: None, }; - handle_error(conversation_id, turn_error.clone(), &turn_summary_store).await; + handle_error(conversation_id, turn_error.clone(), &thread_state).await; outgoing .send_server_notification(ServerNotification::Error(ErrorNotification { error: turn_error.clone(), @@ -867,9 +860,11 @@ pub(crate) async fn apply_bespoke_event_handling( let item_id = patch_begin_event.call_id.clone(); let first_start = { - let mut map = turn_summary_store.lock().await; - let summary = map.entry(conversation_id).or_default(); - summary.file_change_started.insert(item_id.clone()) + let mut state = thread_state.lock().await; + state + .turn_summary + .file_change_started + .insert(item_id.clone()) }; if first_start { let item = ThreadItem::FileChange { @@ -905,7 +900,7 @@ pub(crate) async fn apply_bespoke_event_handling( status, event_turn_id.clone(), outgoing.as_ref(), - &turn_summary_store, + &thread_state, ) .await; } @@ -950,9 +945,8 @@ pub(crate) async fn apply_bespoke_event_handling( // We need to detect which item type it is so we can emit the right notification. // We already have state tracking FileChange items on item/started, so let's use that. let is_file_change = { - let map = turn_summary_store.lock().await; - map.get(&conversation_id) - .is_some_and(|summary| summary.file_change_started.contains(&item_id)) + let state = thread_state.lock().await; + state.turn_summary.file_change_started.contains(&item_id) }; if is_file_change { let notification = FileChangeOutputDeltaNotification { @@ -1049,8 +1043,8 @@ pub(crate) async fn apply_bespoke_event_handling( // If this is a TurnAborted, reply to any pending interrupt requests. EventMsg::TurnAborted(turn_aborted_event) => { let pending = { - let mut map = pending_interrupts.lock().await; - map.remove(&conversation_id).unwrap_or_default() + let mut state = thread_state.lock().await; + std::mem::take(&mut state.pending_interrupts) }; if !pending.is_empty() { for (rid, ver) in pending { @@ -1069,18 +1063,12 @@ pub(crate) async fn apply_bespoke_event_handling( } } - handle_turn_interrupted( - conversation_id, - event_turn_id, - &outgoing, - &turn_summary_store, - ) - .await; + handle_turn_interrupted(conversation_id, event_turn_id, &outgoing, &thread_state).await; } EventMsg::ThreadRolledBack(_rollback_event) => { let pending = { - let mut map = pending_rollbacks.lock().await; - map.remove(&conversation_id) + let mut state = thread_state.lock().await; + state.pending_rollbacks.take() }; if let Some(request_id) = pending { @@ -1245,14 +1233,11 @@ async fn complete_file_change_item( status: PatchApplyStatus, turn_id: String, outgoing: &OutgoingMessageSender, - turn_summary_store: &TurnSummaryStore, + thread_state: &Arc>, ) { - { - let mut map = turn_summary_store.lock().await; - if let Some(summary) = map.get_mut(&conversation_id) { - summary.file_change_started.remove(&item_id); - } - } + let mut state = thread_state.lock().await; + state.turn_summary.file_change_started.remove(&item_id); + drop(state); let item = ThreadItem::FileChange { id: item_id, @@ -1324,20 +1309,20 @@ async fn maybe_emit_raw_response_item_completed( } async fn find_and_remove_turn_summary( - conversation_id: ThreadId, - turn_summary_store: &TurnSummaryStore, + _conversation_id: ThreadId, + thread_state: &Arc>, ) -> TurnSummary { - let mut map = turn_summary_store.lock().await; - map.remove(&conversation_id).unwrap_or_default() + let mut state = thread_state.lock().await; + std::mem::take(&mut state.turn_summary) } async fn handle_turn_complete( conversation_id: ThreadId, event_turn_id: String, outgoing: &OutgoingMessageSender, - turn_summary_store: &TurnSummaryStore, + thread_state: &Arc>, ) { - let turn_summary = find_and_remove_turn_summary(conversation_id, turn_summary_store).await; + let turn_summary = find_and_remove_turn_summary(conversation_id, thread_state).await; let (status, error) = match turn_summary.last_error { Some(error) => (TurnStatus::Failed, Some(error)), @@ -1351,9 +1336,9 @@ async fn handle_turn_interrupted( conversation_id: ThreadId, event_turn_id: String, outgoing: &OutgoingMessageSender, - turn_summary_store: &TurnSummaryStore, + thread_state: &Arc>, ) { - find_and_remove_turn_summary(conversation_id, turn_summary_store).await; + find_and_remove_turn_summary(conversation_id, thread_state).await; emit_turn_completed_with_status( conversation_id, @@ -1366,15 +1351,12 @@ async fn handle_turn_interrupted( } async fn handle_thread_rollback_failed( - conversation_id: ThreadId, + _conversation_id: ThreadId, message: String, - pending_rollbacks: &PendingRollbacks, + thread_state: &Arc>, outgoing: &OutgoingMessageSender, ) { - let pending_rollback = { - let mut map = pending_rollbacks.lock().await; - map.remove(&conversation_id) - }; + let pending_rollback = thread_state.lock().await.pending_rollbacks.take(); if let Some(request_id) = pending_rollback { outgoing @@ -1419,12 +1401,12 @@ async fn handle_token_count_event( } async fn handle_error( - conversation_id: ThreadId, + _conversation_id: ThreadId, error: TurnError, - turn_summary_store: &TurnSummaryStore, + thread_state: &Arc>, ) { - let mut map = turn_summary_store.lock().await; - map.entry(conversation_id).or_default().last_error = Some(error); + let mut state = thread_state.lock().await; + state.turn_summary.last_error = Some(error); } async fn on_patch_approval_response( @@ -1652,7 +1634,7 @@ async fn on_file_change_request_approval_response( receiver: oneshot::Receiver, codex: Arc, outgoing: Arc, - turn_summary_store: TurnSummaryStore, + thread_state: Arc>, ) { let response = receiver.await; let (decision, completion_status) = match response { @@ -1685,7 +1667,7 @@ async fn on_file_change_request_approval_response( status, event_turn_id.clone(), outgoing.as_ref(), - &turn_summary_store, + &thread_state, ) .await; } @@ -1915,13 +1897,12 @@ mod tests { use pretty_assertions::assert_eq; use rmcp::model::Content; use serde_json::Value as JsonValue; - use std::collections::HashMap; use std::time::Duration; use tokio::sync::Mutex; use tokio::sync::mpsc; - fn new_turn_summary_store() -> TurnSummaryStore { - Arc::new(Mutex::new(HashMap::new())) + fn new_thread_state() -> Arc> { + Arc::new(Mutex::new(ThreadState::default())) } async fn recv_broadcast_message( @@ -1999,7 +1980,7 @@ mod tests { #[tokio::test] async fn test_handle_error_records_message() -> Result<()> { let conversation_id = ThreadId::new(); - let turn_summary_store = new_turn_summary_store(); + let thread_state = new_thread_state(); handle_error( conversation_id, @@ -2008,11 +1989,11 @@ mod tests { codex_error_info: Some(V2CodexErrorInfo::InternalServerError), additional_details: None, }, - &turn_summary_store, + &thread_state, ) .await; - let turn_summary = find_and_remove_turn_summary(conversation_id, &turn_summary_store).await; + let turn_summary = find_and_remove_turn_summary(conversation_id, &thread_state).await; assert_eq!( turn_summary.last_error, Some(TurnError { @@ -2030,13 +2011,13 @@ mod tests { let event_turn_id = "complete1".to_string(); let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); let outgoing = Arc::new(OutgoingMessageSender::new(tx)); - let turn_summary_store = new_turn_summary_store(); + let thread_state = new_thread_state(); handle_turn_complete( conversation_id, event_turn_id.clone(), &outgoing, - &turn_summary_store, + &thread_state, ) .await; @@ -2057,7 +2038,7 @@ mod tests { async fn test_handle_turn_interrupted_emits_interrupted_with_error() -> Result<()> { let conversation_id = ThreadId::new(); let event_turn_id = "interrupt1".to_string(); - let turn_summary_store = new_turn_summary_store(); + let thread_state = new_thread_state(); handle_error( conversation_id, TurnError { @@ -2065,7 +2046,7 @@ mod tests { codex_error_info: None, additional_details: None, }, - &turn_summary_store, + &thread_state, ) .await; let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); @@ -2075,7 +2056,7 @@ mod tests { conversation_id, event_turn_id.clone(), &outgoing, - &turn_summary_store, + &thread_state, ) .await; @@ -2096,7 +2077,7 @@ mod tests { async fn test_handle_turn_complete_emits_failed_with_error() -> Result<()> { let conversation_id = ThreadId::new(); let event_turn_id = "complete_err1".to_string(); - let turn_summary_store = new_turn_summary_store(); + let thread_state = new_thread_state(); handle_error( conversation_id, TurnError { @@ -2104,7 +2085,7 @@ mod tests { codex_error_info: Some(V2CodexErrorInfo::Other), additional_details: None, }, - &turn_summary_store, + &thread_state, ) .await; let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); @@ -2114,7 +2095,7 @@ mod tests { conversation_id, event_turn_id.clone(), &outgoing, - &turn_summary_store, + &thread_state, ) .await; @@ -2336,7 +2317,7 @@ mod tests { // Conversation A will have two turns; Conversation B will have one turn. let conversation_a = ThreadId::new(); let conversation_b = ThreadId::new(); - let turn_summary_store = new_turn_summary_store(); + let thread_state = new_thread_state(); let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); let outgoing = Arc::new(OutgoingMessageSender::new(tx)); @@ -2350,16 +2331,10 @@ mod tests { codex_error_info: Some(V2CodexErrorInfo::BadRequest), additional_details: None, }, - &turn_summary_store, - ) - .await; - handle_turn_complete( - conversation_a, - a_turn1.clone(), - &outgoing, - &turn_summary_store, + &thread_state, ) .await; + handle_turn_complete(conversation_a, a_turn1.clone(), &outgoing, &thread_state).await; // Turn 1 on conversation B let b_turn1 = "b_turn1".to_string(); @@ -2370,26 +2345,14 @@ mod tests { codex_error_info: None, additional_details: None, }, - &turn_summary_store, - ) - .await; - handle_turn_complete( - conversation_b, - b_turn1.clone(), - &outgoing, - &turn_summary_store, + &thread_state, ) .await; + handle_turn_complete(conversation_b, b_turn1.clone(), &outgoing, &thread_state).await; // Turn 2 on conversation A let a_turn2 = "a_turn2".to_string(); - handle_turn_complete( - conversation_a, - a_turn2.clone(), - &outgoing, - &turn_summary_store, - ) - .await; + handle_turn_complete(conversation_a, a_turn2.clone(), &outgoing, &thread_state).await; // Verify: A turn 1 let msg = recv_broadcast_message(&mut rx).await?; diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index bb6062200..5e0e32dad 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -137,7 +137,6 @@ use codex_app_server_protocol::ThreadStartedNotification; use codex_app_server_protocol::ThreadUnarchiveParams; use codex_app_server_protocol::ThreadUnarchiveResponse; use codex_app_server_protocol::Turn; -use codex_app_server_protocol::TurnError; use codex_app_server_protocol::TurnInterruptParams; use codex_app_server_protocol::TurnStartParams; use codex_app_server_protocol::TurnStartResponse; @@ -252,20 +251,7 @@ use uuid::Uuid; use crate::filters::compute_source_filters; use crate::filters::source_kind_matches; - -type PendingInterruptQueue = Vec<(ConnectionRequestId, ApiVersion)>; -pub(crate) type PendingInterrupts = Arc>>; - -pub(crate) type PendingRollbacks = Arc>>; - -/// Per-conversation accumulation of the latest states e.g. error message while a turn runs. -#[derive(Default, Clone)] -pub(crate) struct TurnSummary { - pub(crate) file_change_started: HashSet, - pub(crate) last_error: Option, -} - -pub(crate) type TurnSummaryStore = Arc>>; +use crate::thread_state::ThreadStateManager; const THREAD_LIST_DEFAULT_LIMIT: usize = 25; const THREAD_LIST_MAX_LIMIT: usize = 100; @@ -303,21 +289,16 @@ pub(crate) struct CodexMessageProcessor { config: Arc, cli_overrides: Vec<(String, TomlValue)>, cloud_requirements: Arc>, - conversation_listeners: HashMap>, - listener_thread_ids_by_subscription: HashMap, active_login: Arc>>, - // Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives. - pending_interrupts: PendingInterrupts, - // Queue of pending rollback requests per conversation. We reply when ThreadRollback arrives. - pending_rollbacks: PendingRollbacks, - turn_summary_store: TurnSummaryStore, + thread_state_manager: ThreadStateManager, pending_fuzzy_searches: Arc>>>, feedback: CodexFeedback, } -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Default)] pub(crate) enum ApiVersion { V1, + #[default] V2, } @@ -375,12 +356,8 @@ impl CodexMessageProcessor { config, cli_overrides, cloud_requirements, - conversation_listeners: HashMap::new(), - listener_thread_ids_by_subscription: HashMap::new(), active_login: Arc::new(Mutex::new(None)), - pending_interrupts: Arc::new(Mutex::new(HashMap::new())), - pending_rollbacks: Arc::new(Mutex::new(HashMap::new())), - turn_summary_store: Arc::new(Mutex::new(HashMap::new())), + thread_state_manager: ThreadStateManager::new(), pending_fuzzy_searches: Arc::new(Mutex::new(HashMap::new())), feedback, } @@ -1012,7 +989,7 @@ impl CodexMessageProcessor { cloud_requirements.as_ref(), auth_manager.clone(), chatgpt_base_url, - codex_home.clone(), + codex_home, ); sync_default_client_residency_requirement( &cli_overrides, @@ -1120,7 +1097,7 @@ impl CodexMessageProcessor { cloud_requirements.as_ref(), auth_manager.clone(), chatgpt_base_url, - codex_home.clone(), + codex_home, ); sync_default_client_residency_requirement( &cli_overrides, @@ -2330,25 +2307,32 @@ impl CodexMessageProcessor { let request = request_id.clone(); - { - let mut map = self.pending_rollbacks.lock().await; - if map.contains_key(&thread_id) { - self.send_invalid_request_error( - request.clone(), - "rollback already in progress for this thread".to_string(), - ) - .await; - return; + let rollback_already_in_progress = { + let thread_state = self.thread_state_manager.thread_state(thread_id); + let mut thread_state = thread_state.lock().await; + if thread_state.pending_rollbacks.is_some() { + true + } else { + thread_state.pending_rollbacks = Some(request.clone()); + false } - - map.insert(thread_id, request.clone()); + }; + if rollback_already_in_progress { + self.send_invalid_request_error( + request.clone(), + "rollback already in progress for this thread".to_string(), + ) + .await; + return; } if let Err(err) = thread.submit(Op::ThreadRollback { num_turns }).await { // No ThreadRollback event will arrive if an error occurs. // Clean up and reply immediately. - let mut map = self.pending_rollbacks.lock().await; - map.remove(&thread_id); + let thread_state = self.thread_state_manager.thread_state(thread_id); + let mut thread_state = thread_state.lock().await; + thread_state.pending_rollbacks = None; + drop(thread_state); self.send_internal_error(request, format!("failed to start rollback: {err}")) .await; @@ -2646,11 +2630,7 @@ impl CodexMessageProcessor { /// Best-effort: attach a listener for thread_id if missing. pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) { - if self - .listener_thread_ids_by_subscription - .values() - .any(|entry| *entry == thread_id) - { + if self.thread_state_manager.has_listener_for_thread(thread_id) { return; } @@ -4313,7 +4293,8 @@ impl CodexMessageProcessor { let mut state_db_ctx = None; // If the thread is active, request shutdown and wait briefly. - if let Some(conversation) = self.thread_manager.remove_thread(&thread_id).await { + let removed_conversation = self.thread_manager.remove_thread(&thread_id).await; + if let Some(conversation) = removed_conversation { if let Some(ctx) = conversation.state_db() { state_db_ctx = Some(ctx); } @@ -4341,6 +4322,9 @@ impl CodexMessageProcessor { error!("failed to submit Shutdown to thread {thread_id}: {err}"); } } + self.thread_state_manager + .remove_thread_state(thread_id) + .await; } if state_db_ctx.is_none() { @@ -4880,9 +4864,10 @@ impl CodexMessageProcessor { // Record the pending interrupt so we can reply when TurnAborted arrives. { - let mut map = self.pending_interrupts.lock().await; - map.entry(conversation_id) - .or_default() + let pending_interrupts = self.thread_state_manager.thread_state(conversation_id); + let mut thread_state = pending_interrupts.lock().await; + thread_state + .pending_interrupts .push((request, ApiVersion::V1)); } @@ -5276,9 +5261,10 @@ impl CodexMessageProcessor { // Record the pending interrupt so we can reply when TurnAborted arrives. { - let mut map = self.pending_interrupts.lock().await; - map.entry(thread_uuid) - .or_default() + let thread_state = self.thread_state_manager.thread_state(thread_uuid); + let mut thread_state = thread_state.lock().await; + thread_state + .pending_interrupts .push((request, ApiVersion::V2)); } @@ -5315,16 +5301,13 @@ impl CodexMessageProcessor { params: RemoveConversationListenerParams, ) { let RemoveConversationListenerParams { subscription_id } = params; - match self.conversation_listeners.remove(&subscription_id) { - Some(sender) => { - // Signal the spawned task to exit and acknowledge. - let _ = sender.send(()); - if let Some(thread_id) = self - .listener_thread_ids_by_subscription - .remove(&subscription_id) - { - info!("removed listener for thread {thread_id}"); - } + match self + .thread_state_manager + .remove_listener(subscription_id) + .await + { + Some(thread_id) => { + info!("removed listener for thread {thread_id}"); let response = RemoveConversationSubscriptionResponse {}; self.outgoing.send_response(request_id, response).await; } @@ -5342,7 +5325,7 @@ impl CodexMessageProcessor { async fn attach_conversation_listener( &mut self, conversation_id: ThreadId, - experimental_raw_events: bool, + raw_events_enabled: bool, api_version: ApiVersion, ) -> Result { let conversation = match self.thread_manager.get_thread(conversation_id).await { @@ -5358,16 +5341,11 @@ impl CodexMessageProcessor { let subscription_id = Uuid::new_v4(); let (cancel_tx, mut cancel_rx) = oneshot::channel(); - self.conversation_listeners - .insert(subscription_id, cancel_tx); - self.listener_thread_ids_by_subscription - .insert(subscription_id, conversation_id); - + let thread_state = self + .thread_state_manager + .set_listener(subscription_id, conversation_id, cancel_tx) + .await; let outgoing_for_task = self.outgoing.clone(); - let pending_interrupts = self.pending_interrupts.clone(); - let pending_rollbacks = self.pending_rollbacks.clone(); - let turn_summary_store = self.turn_summary_store.clone(); - let api_version_for_task = api_version; let fallback_model_provider = self.config.model_provider_id.clone(); tokio::spawn(async move { loop { @@ -5385,10 +5363,9 @@ impl CodexMessageProcessor { } }; - if let EventMsg::RawResponseItem(_) = &event.msg - && !experimental_raw_events { - continue; - } + if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled { + continue; + } // For now, we send a notification for every event, // JSON-serializing the `Event` as-is, but these should @@ -5427,10 +5404,8 @@ impl CodexMessageProcessor { conversation_id, conversation.clone(), outgoing_for_task.clone(), - pending_interrupts.clone(), - pending_rollbacks.clone(), - turn_summary_store.clone(), - api_version_for_task, + thread_state.clone(), + api_version, fallback_model_provider.clone(), ) .await; @@ -5699,7 +5674,7 @@ fn replace_cloud_requirements_loader( cloud_requirements: &RwLock, auth_manager: Arc, chatgpt_base_url: String, - codex_home: std::path::PathBuf, + codex_home: PathBuf, ) { let loader = cloud_requirements_loader(auth_manager, chatgpt_base_url, codex_home); if let Ok(mut guard) = cloud_requirements.write() { @@ -6316,4 +6291,30 @@ mod tests { assert_eq!(summary, expected); Ok(()) } + + #[tokio::test] + async fn removing_one_listener_does_not_cancel_other_subscriptions_for_same_thread() + -> Result<()> { + let mut manager = ThreadStateManager::new(); + let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?; + let listener_a = Uuid::new_v4(); + let listener_b = Uuid::new_v4(); + let (cancel_a, cancel_rx_a) = oneshot::channel(); + let (cancel_b, mut cancel_rx_b) = oneshot::channel(); + + manager.set_listener(listener_a, thread_id, cancel_a).await; + manager.set_listener(listener_b, thread_id, cancel_b).await; + + assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id)); + assert_eq!(cancel_rx_a.await, Ok(())); + assert!( + tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx_b) + .await + .is_err() + ); + + assert_eq!(manager.remove_listener(listener_b).await, Some(thread_id)); + assert_eq!(cancel_rx_b.await, Ok(())); + Ok(()) + } } diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 2a31b2053..06595d5b3 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -62,6 +62,7 @@ mod fuzzy_file_search; mod message_processor; mod models; mod outgoing_message; +mod thread_state; mod transport; pub use crate::transport::AppServerTransport; diff --git a/codex-rs/app-server/src/thread_state.rs b/codex-rs/app-server/src/thread_state.rs new file mode 100644 index 000000000..eb263c5e4 --- /dev/null +++ b/codex-rs/app-server/src/thread_state.rs @@ -0,0 +1,106 @@ +use crate::outgoing_message::ConnectionRequestId; +use codex_app_server_protocol::TurnError; +use codex_protocol::ThreadId; +use std::collections::HashMap; +use std::collections::HashSet; +use std::sync::Arc; +use tokio::sync::Mutex; +use tokio::sync::oneshot; +use uuid::Uuid; + +type PendingInterruptQueue = Vec<( + ConnectionRequestId, + crate::codex_message_processor::ApiVersion, +)>; + +/// Per-conversation accumulation of the latest states e.g. error message while a turn runs. +#[derive(Default, Clone)] +pub(crate) struct TurnSummary { + pub(crate) file_change_started: HashSet, + pub(crate) last_error: Option, +} + +#[derive(Default)] +pub(crate) struct ThreadState { + pub(crate) pending_interrupts: PendingInterruptQueue, + pub(crate) pending_rollbacks: Option, + pub(crate) turn_summary: TurnSummary, + pub(crate) listener_cancel_txs: HashMap>, +} + +impl ThreadState { + fn set_listener(&mut self, subscription_id: Uuid, cancel_tx: oneshot::Sender<()>) { + if let Some(previous) = self.listener_cancel_txs.insert(subscription_id, cancel_tx) { + let _ = previous.send(()); + } + } + + fn clear_listener(&mut self, subscription_id: Uuid) { + if let Some(cancel_tx) = self.listener_cancel_txs.remove(&subscription_id) { + let _ = cancel_tx.send(()); + } + } + + fn clear_listeners(&mut self) { + for (_, cancel_tx) in self.listener_cancel_txs.drain() { + let _ = cancel_tx.send(()); + } + } +} + +#[derive(Default)] +pub(crate) struct ThreadStateManager { + thread_states: HashMap>>, + thread_id_by_subscription: HashMap, +} + +impl ThreadStateManager { + pub(crate) fn new() -> Self { + Self::default() + } + + pub(crate) fn has_listener_for_thread(&self, thread_id: ThreadId) -> bool { + self.thread_id_by_subscription + .values() + .any(|existing| *existing == thread_id) + } + + pub(crate) fn thread_state(&mut self, thread_id: ThreadId) -> Arc> { + self.thread_states + .entry(thread_id) + .or_insert_with(|| Arc::new(Mutex::new(ThreadState::default()))) + .clone() + } + + pub(crate) async fn remove_listener(&mut self, subscription_id: Uuid) -> Option { + let thread_id = self.thread_id_by_subscription.remove(&subscription_id)?; + if let Some(thread_state) = self.thread_states.get(&thread_id) { + thread_state.lock().await.clear_listener(subscription_id); + } + Some(thread_id) + } + + pub(crate) async fn remove_thread_state(&mut self, thread_id: ThreadId) { + if let Some(thread_state) = self.thread_states.remove(&thread_id) { + thread_state.lock().await.clear_listeners(); + } + self.thread_id_by_subscription + .retain(|_, existing_thread_id| *existing_thread_id != thread_id); + } + + pub(crate) async fn set_listener( + &mut self, + subscription_id: Uuid, + thread_id: ThreadId, + cancel_tx: oneshot::Sender<()>, + ) -> Arc> { + self.thread_id_by_subscription + .insert(subscription_id, thread_id); + let thread_state = self.thread_state(thread_id); + thread_state + .lock() + .await + .set_listener(subscription_id, cancel_tx); + thread_state + } +}