diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index c5b4e6277..c4532cafd 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -393,6 +393,23 @@ pub(crate) enum ApiVersion { V2, } +#[derive(Clone)] +struct ListenerTaskContext { + thread_manager: Arc, + thread_state_manager: ThreadStateManager, + outgoing: Arc, + thread_watch_manager: ThreadWatchManager, + fallback_model_provider: String, + codex_home: PathBuf, + single_client_mode: bool, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum EnsureConversationListenerResult { + Attached, + ConnectionClosed, +} + pub(crate) struct CodexMessageProcessorArgs { pub(crate) auth_manager: Arc, pub(crate) thread_manager: Arc, @@ -2025,7 +2042,7 @@ impl CodexMessageProcessor { } } - async fn thread_start(&mut self, request_id: ConnectionRequestId, params: ThreadStartParams) { + async fn thread_start(&self, request_id: ConnectionRequestId, params: ThreadStartParams) { let ThreadStartParams { model, model_provider, @@ -2054,11 +2071,51 @@ impl CodexMessageProcessor { personality, ); typesafe_overrides.ephemeral = ephemeral; - + let cli_overrides = self.cli_overrides.clone(); let cloud_requirements = self.current_cloud_requirements(); + let listener_task_context = ListenerTaskContext { + thread_manager: Arc::clone(&self.thread_manager), + thread_state_manager: self.thread_state_manager.clone(), + outgoing: Arc::clone(&self.outgoing), + thread_watch_manager: self.thread_watch_manager.clone(), + fallback_model_provider: self.config.model_provider_id.clone(), + codex_home: self.config.codex_home.clone(), + single_client_mode: self.single_client_mode, + }; + + tokio::spawn(async move { + Self::thread_start_task( + listener_task_context, + cli_overrides, + cloud_requirements, + request_id, + config, + typesafe_overrides, + dynamic_tools, + persist_extended_history, + service_name, + experimental_raw_events, + ) + .await; + }); + } + + #[allow(clippy::too_many_arguments)] + async fn thread_start_task( + listener_task_context: ListenerTaskContext, + cli_overrides: Vec<(String, TomlValue)>, + cloud_requirements: CloudRequirementsLoader, + request_id: ConnectionRequestId, + config_overrides: Option>, + typesafe_overrides: ConfigOverrides, + dynamic_tools: Option>, + persist_extended_history: bool, + service_name: Option, + experimental_raw_events: bool, + ) { let config = match derive_config_from_params( - &self.cli_overrides, - config, + &cli_overrides, + config_overrides, typesafe_overrides, &cloud_requirements, ) @@ -2071,7 +2128,10 @@ impl CodexMessageProcessor { message: format!("error deriving config: {err}"), data: None, }; - self.outgoing.send_error(request_id, error).await; + listener_task_context + .outgoing + .send_error(request_id, error) + .await; return; } }; @@ -2086,7 +2146,10 @@ impl CodexMessageProcessor { message, data: None, }; - self.outgoing.send_error(request_id, error).await; + listener_task_context + .outgoing + .send_error(request_id, error) + .await; return; } dynamic_tools @@ -2099,7 +2162,7 @@ impl CodexMessageProcessor { .collect() }; - match self + match listener_task_context .thread_manager .start_thread_with_tools_and_service_name( config, @@ -2124,29 +2187,28 @@ impl CodexMessageProcessor { ); // Auto-attach a thread listener when starting a thread. - // Use the same behavior as the v1 API, with opt-in support for raw item events. - if let Err(err) = self - .ensure_conversation_listener( + Self::log_listener_attach_result( + Self::ensure_conversation_listener_task( + listener_task_context.clone(), thread_id, request_id.connection_id, experimental_raw_events, ApiVersion::V2, ) - .await - { - tracing::warn!( - "failed to attach listener for thread {}: {}", - thread_id, - err.message - ); - } + .await, + thread_id, + request_id.connection_id, + "thread", + ); - self.thread_watch_manager + listener_task_context + .thread_watch_manager .upsert_thread(thread.clone()) .await; thread.status = resolve_thread_status( - self.thread_watch_manager + listener_task_context + .thread_watch_manager .loaded_status_for_thread(&thread.id) .await, false, @@ -2162,10 +2224,14 @@ impl CodexMessageProcessor { reasoning_effort: config_snapshot.reasoning_effort, }; - self.outgoing.send_response(request_id, response).await; + listener_task_context + .outgoing + .send_response(request_id, response) + .await; let notif = ThreadStartedNotification { thread }; - self.outgoing + listener_task_context + .outgoing .send_server_notification(ServerNotification::ThreadStarted(notif)) .await; } @@ -2175,7 +2241,10 @@ impl CodexMessageProcessor { message: format!("error creating thread: {err}"), data: None, }; - self.outgoing.send_error(request_id, error).await; + listener_task_context + .outgoing + .send_error(request_id, error) + .await; } } } @@ -2515,7 +2584,7 @@ impl CodexMessageProcessor { let request = request_id.clone(); let rollback_already_in_progress = { - let thread_state = self.thread_state_manager.thread_state(thread_id); + let thread_state = self.thread_state_manager.thread_state(thread_id).await; let mut thread_state = thread_state.lock().await; if thread_state.pending_rollbacks.is_some() { true @@ -2536,7 +2605,7 @@ impl CodexMessageProcessor { if let Err(err) = thread.submit(Op::ThreadRollback { num_turns }).await { // No ThreadRollback event will arrive if an error occurs. // Clean up and reply immediately. - let thread_state = self.thread_state_manager.thread_state(thread_id); + let thread_state = self.thread_state_manager.thread_state(thread_id).await; let mut thread_state = thread_state.lock().await; thread_state.pending_rollbacks = None; drop(thread_state); @@ -2882,6 +2951,12 @@ impl CodexMessageProcessor { self.thread_manager.subscribe_thread_created() } + pub(crate) async fn connection_initialized(&self, connection_id: ConnectionId) { + self.thread_state_manager + .connection_initialized(connection_id) + .await; + } + pub(crate) async fn connection_closed(&mut self, connection_id: ConnectionId) { self.thread_state_manager .remove_connection(connection_id) @@ -2906,15 +2981,13 @@ impl CodexMessageProcessor { } for connection_id in connection_ids { - if let Err(err) = self - .ensure_conversation_listener(thread_id, connection_id, false, ApiVersion::V2) - .await - { - warn!( - "failed to auto-attach listener for thread {thread_id}: {message}", - message = err.message - ); - } + Self::log_listener_attach_result( + self.ensure_conversation_listener(thread_id, connection_id, false, ApiVersion::V2) + .await, + thread_id, + connection_id, + "thread", + ); } } @@ -3039,21 +3112,18 @@ impl CodexMessageProcessor { return; }; // Auto-attach a thread listener when resuming a thread. - if let Err(err) = self - .ensure_conversation_listener( + Self::log_listener_attach_result( + self.ensure_conversation_listener( thread_id, request_id.connection_id, false, ApiVersion::V2, ) - .await - { - tracing::warn!( - "failed to attach listener for thread {}: {}", - thread_id, - err.message - ); - } + .await, + thread_id, + request_id.connection_id, + "thread", + ); let Some(mut thread) = self .load_thread_from_rollout_or_send_internal( @@ -3191,7 +3261,10 @@ impl CodexMessageProcessor { return true; } - let thread_state = self.thread_state_manager.thread_state(existing_thread_id); + let thread_state = self + .thread_state_manager + .thread_state(existing_thread_id) + .await; self.ensure_listener_task_running( existing_thread_id, existing_thread.clone(), @@ -3542,21 +3615,18 @@ impl CodexMessageProcessor { return; }; // Auto-attach a conversation listener when forking a thread. - if let Err(err) = self - .ensure_conversation_listener( + Self::log_listener_attach_result( + self.ensure_conversation_listener( thread_id, request_id.connection_id, false, ApiVersion::V2, ) - .await - { - tracing::warn!( - "failed to attach listener for thread {}: {}", - thread_id, - err.message - ); - } + .await, + thread_id, + request_id.connection_id, + "thread", + ); let mut thread = match read_summary_from_rollout( rollout_path.as_path(), @@ -5692,7 +5762,10 @@ impl CodexMessageProcessor { // Record the pending interrupt so we can reply when TurnAborted arrives. { - let pending_interrupts = self.thread_state_manager.thread_state(conversation_id); + let pending_interrupts = self + .thread_state_manager + .thread_state(conversation_id) + .await; let mut thread_state = pending_interrupts.lock().await; thread_state .pending_interrupts @@ -5895,7 +5968,7 @@ impl CodexMessageProcessor { } }; - if let Err(error) = self + match self .ensure_conversation_listener( thread_id, request_id.connection_id, @@ -5904,8 +5977,14 @@ impl CodexMessageProcessor { ) .await { - self.outgoing.send_error(request_id, error).await; - return None; + Ok(EnsureConversationListenerResult::Attached) => {} + Ok(EnsureConversationListenerResult::ConnectionClosed) => { + return None; + } + Err(error) => { + self.outgoing.send_error(request_id, error).await; + return None; + } } if !thread.enabled(Feature::RealtimeConversation) { @@ -6174,21 +6253,18 @@ impl CodexMessageProcessor { data: None, })?; - if let Err(err) = self - .ensure_conversation_listener( + Self::log_listener_attach_result( + self.ensure_conversation_listener( thread_id, request_id.connection_id, false, ApiVersion::V2, ) - .await - { - tracing::warn!( - "failed to attach listener for review thread {}: {}", - thread_id, - err.message - ); - } + .await, + thread_id, + request_id.connection_id, + "review thread", + ); let fallback_provider = self.config.model_provider_id.as_str(); if let Some(rollout_path) = review_thread.rollout_path() { @@ -6315,7 +6391,7 @@ impl CodexMessageProcessor { // Record the pending interrupt so we can reply when TurnAborted arrives. { - let thread_state = self.thread_state_manager.thread_state(thread_uuid); + let thread_state = self.thread_state_manager.thread_state(thread_uuid).await; let mut thread_state = thread_state.lock().await; thread_state .pending_interrupts @@ -6397,13 +6473,42 @@ impl CodexMessageProcessor { } async fn ensure_conversation_listener( - &mut self, + &self, conversation_id: ThreadId, connection_id: ConnectionId, raw_events_enabled: bool, api_version: ApiVersion, - ) -> Result<(), JSONRPCErrorError> { - let conversation = match self.thread_manager.get_thread(conversation_id).await { + ) -> Result { + Self::ensure_conversation_listener_task( + ListenerTaskContext { + thread_manager: Arc::clone(&self.thread_manager), + thread_state_manager: self.thread_state_manager.clone(), + outgoing: Arc::clone(&self.outgoing), + thread_watch_manager: self.thread_watch_manager.clone(), + fallback_model_provider: self.config.model_provider_id.clone(), + codex_home: self.config.codex_home.clone(), + single_client_mode: self.single_client_mode, + }, + conversation_id, + connection_id, + raw_events_enabled, + api_version, + ) + .await + } + + async fn ensure_conversation_listener_task( + listener_task_context: ListenerTaskContext, + conversation_id: ThreadId, + connection_id: ConnectionId, + raw_events_enabled: bool, + api_version: ApiVersion, + ) -> Result { + let conversation = match listener_task_context + .thread_manager + .get_thread(conversation_id) + .await + { Ok(conv) => conv, Err(_) => { return Err(JSONRPCErrorError { @@ -6413,13 +6518,46 @@ impl CodexMessageProcessor { }); } }; - let thread_state = self + let Some(thread_state) = listener_task_context .thread_state_manager - .ensure_connection_subscribed(conversation_id, connection_id, raw_events_enabled) - .await; - self.ensure_listener_task_running(conversation_id, conversation, thread_state, api_version) - .await; - Ok(()) + .try_ensure_connection_subscribed(conversation_id, connection_id, raw_events_enabled) + .await + else { + return Ok(EnsureConversationListenerResult::ConnectionClosed); + }; + Self::ensure_listener_task_running_task( + listener_task_context, + conversation_id, + conversation, + thread_state, + api_version, + ) + .await; + Ok(EnsureConversationListenerResult::Attached) + } + + fn log_listener_attach_result( + result: Result, + thread_id: ThreadId, + connection_id: ConnectionId, + thread_kind: &'static str, + ) { + match result { + Ok(EnsureConversationListenerResult::Attached) => {} + Ok(EnsureConversationListenerResult::ConnectionClosed) => { + tracing::debug!( + thread_id = %thread_id, + connection_id = ?connection_id, + "skipping auto-attach for closed connection" + ); + } + Err(err) => { + tracing::warn!( + "failed to attach listener for {thread_kind} {thread_id}: {message}", + message = err.message + ); + } + } } async fn ensure_listener_task_running( @@ -6428,6 +6566,31 @@ impl CodexMessageProcessor { conversation: Arc, thread_state: Arc>, api_version: ApiVersion, + ) { + Self::ensure_listener_task_running_task( + ListenerTaskContext { + thread_manager: Arc::clone(&self.thread_manager), + thread_state_manager: self.thread_state_manager.clone(), + outgoing: Arc::clone(&self.outgoing), + thread_watch_manager: self.thread_watch_manager.clone(), + fallback_model_provider: self.config.model_provider_id.clone(), + codex_home: self.config.codex_home.clone(), + single_client_mode: self.single_client_mode, + }, + conversation_id, + conversation, + thread_state, + api_version, + ) + .await; + } + + async fn ensure_listener_task_running_task( + listener_task_context: ListenerTaskContext, + conversation_id: ThreadId, + conversation: Arc, + thread_state: Arc>, + api_version: ApiVersion, ) { let (cancel_tx, mut cancel_rx) = oneshot::channel(); let (mut listener_command_rx, listener_generation) = { @@ -6437,12 +6600,16 @@ impl CodexMessageProcessor { } thread_state.set_listener(cancel_tx, &conversation) }; - let outgoing_for_task = self.outgoing.clone(); - let thread_manager = self.thread_manager.clone(); - let thread_watch_manager = self.thread_watch_manager.clone(); - let fallback_model_provider = self.config.model_provider_id.clone(); - let single_client_mode = self.single_client_mode; - let codex_home = self.config.codex_home.clone(); + let ListenerTaskContext { + outgoing, + thread_manager, + thread_state_manager, + thread_watch_manager, + fallback_model_provider, + codex_home, + single_client_mode, + } = listener_task_context; + let outgoing_for_task = Arc::clone(&outgoing); tokio::spawn(async move { loop { tokio::select! { @@ -6488,16 +6655,16 @@ impl CodexMessageProcessor { "conversationId".to_string(), conversation_id.to_string().into(), ); - let (subscribed_connection_ids, raw_events_enabled) = { + let raw_events_enabled = { let mut thread_state = thread_state.lock().await; if !single_client_mode { thread_state.track_current_turn_event(&event.msg); } - ( - thread_state.subscribed_connection_ids(), - thread_state.experimental_raw_events, - ) + thread_state.experimental_raw_events }; + let subscribed_connection_ids = thread_state_manager + .subscribed_connection_ids(conversation_id) + .await; if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled { continue; } @@ -6540,6 +6707,7 @@ impl CodexMessageProcessor { handle_thread_listener_command( conversation_id, codex_home.as_path(), + &thread_state_manager, &thread_state, &thread_watch_manager, &outgoing_for_task, @@ -6857,6 +7025,7 @@ impl CodexMessageProcessor { async fn handle_thread_listener_command( conversation_id: ThreadId, codex_home: &Path, + thread_state_manager: &ThreadStateManager, thread_state: &Arc>, thread_watch_manager: &ThreadWatchManager, outgoing: &Arc, @@ -6867,6 +7036,7 @@ async fn handle_thread_listener_command( handle_pending_thread_resume_request( conversation_id, codex_home, + thread_state_manager, thread_state, thread_watch_manager, outgoing, @@ -6878,8 +7048,13 @@ async fn handle_thread_listener_command( request_id, completion_tx, } => { - resolve_pending_server_request(conversation_id, thread_state, outgoing, request_id) - .await; + resolve_pending_server_request( + conversation_id, + thread_state_manager, + outgoing, + request_id, + ) + .await; let _ = completion_tx.send(()); } } @@ -6888,6 +7063,7 @@ async fn handle_thread_listener_command( async fn handle_pending_thread_resume_request( conversation_id: ThreadId, codex_home: &Path, + thread_state_manager: &ThreadStateManager, thread_state: &Arc>, thread_watch_manager: &ThreadWatchManager, outgoing: &Arc, @@ -6976,18 +7152,21 @@ async fn handle_pending_thread_resume_request( outgoing .replay_requests_to_connection_for_thread(connection_id, conversation_id) .await; - - thread_state.lock().await.add_connection(connection_id); + let _attached = thread_state_manager + .try_add_connection_to_thread(conversation_id, connection_id) + .await; } async fn resolve_pending_server_request( conversation_id: ThreadId, - thread_state: &Arc>, + thread_state_manager: &ThreadStateManager, outgoing: &Arc, request_id: RequestId, ) { let thread_id = conversation_id.to_string(); - let subscribed_connection_ids = thread_state.lock().await.subscribed_connection_ids(); + let subscribed_connection_ids = thread_state_manager + .subscribed_connection_ids(conversation_id) + .await; let outgoing = ThreadScopedOutgoingMessageSender::new( outgoing.clone(), subscribed_connection_ids, @@ -7951,9 +8130,7 @@ mod tests { #[tokio::test] async fn aborting_pending_request_clears_pending_state() -> Result<()> { let thread_id = ThreadId::from_string("bfd12a78-5900-467b-9bc5-d3d35df08191")?; - let thread_state = Arc::new(Mutex::new(ThreadState::default())); let connection_id = ConnectionId(7); - thread_state.lock().await.add_connection(connection_id); let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::channel(8); let outgoing = Arc::new(OutgoingMessageSender::new(outgoing_tx)); @@ -8047,7 +8224,7 @@ mod tests { #[tokio::test] async fn removing_listeners_retains_thread_listener_when_last_subscriber_leaves() -> Result<()> { - let mut manager = ThreadStateManager::new(); + let manager = ThreadStateManager::new(); let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?; let listener_a = Uuid::new_v4(); let listener_b = Uuid::new_v4(); @@ -8062,7 +8239,7 @@ mod tests { .set_listener(listener_b, thread_id, connection_b, false) .await; { - let state = manager.thread_state(thread_id); + let state = manager.thread_state(thread_id).await; state.lock().await.cancel_tx = Some(cancel_tx); } @@ -8078,14 +8255,18 @@ mod tests { .await .is_err() ); - let state = manager.thread_state(thread_id); - assert!(state.lock().await.subscribed_connection_ids().is_empty()); + assert!( + manager + .subscribed_connection_ids(thread_id) + .await + .is_empty() + ); Ok(()) } #[tokio::test] async fn removing_listener_unsubscribes_its_connection() -> Result<()> { - let mut manager = ThreadStateManager::new(); + let manager = ThreadStateManager::new(); let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?; let listener_a = Uuid::new_v4(); let listener_b = Uuid::new_v4(); @@ -8100,15 +8281,14 @@ mod tests { .await; assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id)); - let state = manager.thread_state(thread_id); - let subscribed_connection_ids = state.lock().await.subscribed_connection_ids(); + let subscribed_connection_ids = manager.subscribed_connection_ids(thread_id).await; assert_eq!(subscribed_connection_ids, vec![connection_b]); Ok(()) } #[tokio::test] async fn set_listener_uses_last_write_for_raw_events() -> Result<()> { - let mut manager = ThreadStateManager::new(); + let manager = ThreadStateManager::new(); let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?; let listener_a = Uuid::new_v4(); let listener_b = Uuid::new_v4(); @@ -8119,13 +8299,13 @@ mod tests { .set_listener(listener_a, thread_id, connection_a, true) .await; { - let state = manager.thread_state(thread_id); + let state = manager.thread_state(thread_id).await; assert!(state.lock().await.experimental_raw_events); } manager .set_listener(listener_b, thread_id, connection_b, false) .await; - let state = manager.thread_state(thread_id); + let state = manager.thread_state(thread_id).await; assert!(!state.lock().await.experimental_raw_events); Ok(()) } @@ -8133,7 +8313,7 @@ mod tests { #[tokio::test] async fn removing_connection_retains_listener_and_active_turn_when_last_subscriber_disconnects() -> Result<()> { - let mut manager = ThreadStateManager::new(); + let manager = ThreadStateManager::new(); let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?; let listener = Uuid::new_v4(); let connection = ConnectionId(1); @@ -8143,7 +8323,7 @@ mod tests { .set_listener(listener, thread_id, connection, false) .await; { - let state = manager.thread_state(thread_id); + let state = manager.thread_state(thread_id).await; let mut state = state.lock().await; state.cancel_tx = Some(cancel_tx); state.track_current_turn_event(&EventMsg::TurnStarted( @@ -8163,9 +8343,14 @@ mod tests { ); assert_eq!(manager.remove_listener(listener).await, None); - let state = manager.thread_state(thread_id); + let state = manager.thread_state(thread_id).await; let state = state.lock().await; - assert!(state.subscribed_connection_ids().is_empty()); + assert!( + manager + .subscribed_connection_ids(thread_id) + .await + .is_empty() + ); assert!(state.cancel_tx.is_some()); let active_turn = state.active_turn_snapshot().expect("active turn snapshot"); assert_eq!(active_turn.id, "turn-1"); @@ -8175,16 +8360,18 @@ mod tests { #[tokio::test] async fn removing_thread_state_clears_listener_and_active_turn_history() -> Result<()> { - let mut manager = ThreadStateManager::new(); + let manager = ThreadStateManager::new(); let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?; let connection = ConnectionId(1); let (cancel_tx, cancel_rx) = oneshot::channel(); + manager.connection_initialized(connection).await; manager - .ensure_connection_subscribed(thread_id, connection, false) - .await; + .try_ensure_connection_subscribed(thread_id, connection, false) + .await + .expect("connection should be live"); { - let state = manager.thread_state(thread_id); + let state = manager.thread_state(thread_id).await; let mut state = state.lock().await; state.cancel_tx = Some(cancel_tx); state.track_current_turn_event(&EventMsg::TurnStarted( @@ -8199,9 +8386,14 @@ mod tests { manager.remove_thread_state(thread_id).await; assert_eq!(cancel_rx.await, Ok(())); - let state = manager.thread_state(thread_id); + let state = manager.thread_state(thread_id).await; let state = state.lock().await; - assert!(state.subscribed_connection_ids().is_empty()); + assert!( + manager + .subscribed_connection_ids(thread_id) + .await + .is_empty() + ); assert!(state.cancel_tx.is_none()); assert!(state.active_turn_snapshot().is_none()); Ok(()) @@ -8210,20 +8402,24 @@ mod tests { #[tokio::test] async fn removing_auto_attached_connection_preserves_listener_for_other_connections() -> Result<()> { - let mut manager = ThreadStateManager::new(); + let manager = ThreadStateManager::new(); let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?; let connection_a = ConnectionId(1); let connection_b = ConnectionId(2); let (cancel_tx, mut cancel_rx) = oneshot::channel(); + manager.connection_initialized(connection_a).await; + manager.connection_initialized(connection_b).await; manager - .ensure_connection_subscribed(thread_id, connection_a, false) - .await; + .try_ensure_connection_subscribed(thread_id, connection_a, false) + .await + .expect("connection_a should be live"); manager - .ensure_connection_subscribed(thread_id, connection_b, false) - .await; + .try_ensure_connection_subscribed(thread_id, connection_b, false) + .await + .expect("connection_b should be live"); { - let state = manager.thread_state(thread_id); + let state = manager.thread_state(thread_id).await; state.lock().await.cancel_tx = Some(cancel_tx); } @@ -8234,11 +8430,29 @@ mod tests { .is_err() ); - let state = manager.thread_state(thread_id); assert_eq!( - state.lock().await.subscribed_connection_ids(), + manager.subscribed_connection_ids(thread_id).await, vec![connection_b] ); Ok(()) } + + #[tokio::test] + async fn closed_connection_cannot_be_reintroduced_by_auto_subscribe() -> Result<()> { + let manager = ThreadStateManager::new(); + let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?; + let connection = ConnectionId(1); + + manager.connection_initialized(connection).await; + manager.remove_connection(connection).await; + + assert!( + manager + .try_ensure_connection_subscribed(thread_id, connection, false) + .await + .is_none() + ); + assert!(!manager.has_subscribers(thread_id).await); + Ok(()) + } } diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 79f845ad4..2c3034aa1 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -338,6 +338,9 @@ impl MessageProcessor { session.initialized = true; outbound_initialized.store(true, Ordering::Release); + self.codex_message_processor + .connection_initialized(connection_id) + .await; return; } } diff --git a/codex-rs/app-server/src/thread_state.rs b/codex-rs/app-server/src/thread_state.rs index dbacb0ea7..a60dc07b5 100644 --- a/codex-rs/app-server/src/thread_state.rs +++ b/codex-rs/app-server/src/thread_state.rs @@ -60,7 +60,6 @@ pub(crate) struct ThreadState { listener_command_tx: Option>, current_turn_history: ThreadHistoryBuilder, listener_thread: Option>, - subscribed_connections: HashSet, } impl ThreadState { @@ -95,18 +94,6 @@ impl ThreadState { self.listener_thread = None; } - pub(crate) fn add_connection(&mut self, connection_id: ConnectionId) { - self.subscribed_connections.insert(connection_id); - } - - pub(crate) fn remove_connection(&mut self, connection_id: ConnectionId) { - self.subscribed_connections.remove(&connection_id); - } - - pub(crate) fn subscribed_connection_ids(&self) -> Vec { - self.subscribed_connections.iter().copied().collect() - } - pub(crate) fn set_experimental_raw_events(&mut self, enabled: bool) { self.experimental_raw_events = enabled; } @@ -135,55 +122,112 @@ struct SubscriptionState { connection_id: ConnectionId, } +struct ThreadEntry { + state: Arc>, + connection_ids: HashSet, +} + +impl Default for ThreadEntry { + fn default() -> Self { + Self { + state: Arc::new(Mutex::new(ThreadState::default())), + connection_ids: HashSet::new(), + } + } +} + #[derive(Default)] -pub(crate) struct ThreadStateManager { - thread_states: HashMap>>, +struct ThreadStateManagerInner { + live_connections: HashSet, + threads: HashMap, subscription_state_by_id: HashMap, thread_ids_by_connection: HashMap>, } +#[derive(Clone, Default)] +pub(crate) struct ThreadStateManager { + state: Arc>, +} + impl ThreadStateManager { pub(crate) fn new() -> Self { Self::default() } - pub(crate) fn thread_state(&mut self, thread_id: ThreadId) -> Arc> { - self.thread_states - .entry(thread_id) - .or_insert_with(|| Arc::new(Mutex::new(ThreadState::default()))) - .clone() + pub(crate) async fn connection_initialized(&self, connection_id: ConnectionId) { + self.state + .lock() + .await + .live_connections + .insert(connection_id); } - pub(crate) async fn remove_listener(&mut self, subscription_id: Uuid) -> Option { - let subscription_state = self.subscription_state_by_id.remove(&subscription_id)?; + pub(crate) async fn subscribed_connection_ids(&self, thread_id: ThreadId) -> Vec { + let state = self.state.lock().await; + state + .threads + .get(&thread_id) + .map(|thread_entry| thread_entry.connection_ids.iter().copied().collect()) + .unwrap_or_default() + } + + pub(crate) async fn thread_state(&self, thread_id: ThreadId) -> Arc> { + let mut state = self.state.lock().await; + state.threads.entry(thread_id).or_default().state.clone() + } + + pub(crate) async fn remove_listener(&self, subscription_id: Uuid) -> Option { + let (subscription_state, connection_still_subscribed_to_thread, thread_state) = { + let mut state = self.state.lock().await; + let subscription_state = state.subscription_state_by_id.remove(&subscription_id)?; + let thread_id = subscription_state.thread_id; + + let connection_still_subscribed_to_thread = state + .subscription_state_by_id + .values() + .any(|subscription_state_entry| { + subscription_state_entry.thread_id == thread_id + && subscription_state_entry.connection_id + == subscription_state.connection_id + }); + if !connection_still_subscribed_to_thread { + let mut remove_connection_entry = false; + if let Some(thread_ids) = state + .thread_ids_by_connection + .get_mut(&subscription_state.connection_id) + { + thread_ids.remove(&thread_id); + remove_connection_entry = thread_ids.is_empty(); + } + if remove_connection_entry { + state + .thread_ids_by_connection + .remove(&subscription_state.connection_id); + } + if let Some(thread_entry) = state.threads.get_mut(&thread_id) { + thread_entry + .connection_ids + .remove(&subscription_state.connection_id); + } + } + + let thread_state = state.threads.get(&thread_id).map(|thread_entry| { + ( + thread_entry.connection_ids.is_empty(), + thread_entry.state.clone(), + ) + }); + ( + subscription_state, + connection_still_subscribed_to_thread, + thread_state, + ) + }; let thread_id = subscription_state.thread_id; - let connection_still_subscribed_to_thread = - self.subscription_state_by_id.values().any(|state| { - state.thread_id == thread_id - && state.connection_id == subscription_state.connection_id - }); - if !connection_still_subscribed_to_thread { - let mut remove_connection_entry = false; - if let Some(thread_ids) = self - .thread_ids_by_connection - .get_mut(&subscription_state.connection_id) - { - thread_ids.remove(&thread_id); - remove_connection_entry = thread_ids.is_empty(); - } - if remove_connection_entry { - self.thread_ids_by_connection - .remove(&subscription_state.connection_id); - } - } - - if let Some(thread_state) = self.thread_states.get(&thread_id) { - let mut thread_state = thread_state.lock().await; - if !connection_still_subscribed_to_thread { - thread_state.remove_connection(subscription_state.connection_id); - } - if thread_state.subscribed_connection_ids().is_empty() { + if let Some((no_subscribers, thread_state)) = thread_state { + let thread_state = thread_state.lock().await; + if !connection_still_subscribed_to_thread && no_subscribers { tracing::debug!( thread_id = %thread_id, subscription_id = %subscription_id, @@ -196,8 +240,24 @@ impl ThreadStateManager { Some(thread_id) } - pub(crate) async fn remove_thread_state(&mut self, thread_id: ThreadId) { - if let Some(thread_state) = self.thread_states.remove(&thread_id) { + pub(crate) async fn remove_thread_state(&self, thread_id: ThreadId) { + let thread_state = { + let mut state = self.state.lock().await; + let thread_state = state + .threads + .remove(&thread_id) + .map(|thread_entry| thread_entry.state); + state + .subscription_state_by_id + .retain(|_, state| state.thread_id != thread_id); + state.thread_ids_by_connection.retain(|_, thread_ids| { + thread_ids.remove(&thread_id); + !thread_ids.is_empty() + }); + thread_state + }; + + if let Some(thread_state) = thread_state { let mut thread_state = thread_state.lock().await; tracing::debug!( thread_id = %thread_id, @@ -208,142 +268,189 @@ impl ThreadStateManager { ); thread_state.clear_listener(); } - self.subscription_state_by_id - .retain(|_, state| state.thread_id != thread_id); - self.thread_ids_by_connection.retain(|_, thread_ids| { - thread_ids.remove(&thread_id); - !thread_ids.is_empty() - }); } pub(crate) async fn unsubscribe_connection_from_thread( - &mut self, + &self, thread_id: ThreadId, connection_id: ConnectionId, ) -> bool { - let Some(thread_state) = self.thread_states.get(&thread_id) else { - return false; + { + let mut state = self.state.lock().await; + if !state.threads.contains_key(&thread_id) { + return false; + } + + if !state + .thread_ids_by_connection + .get(&connection_id) + .is_some_and(|thread_ids| thread_ids.contains(&thread_id)) + { + return false; + } + + if let Some(thread_ids) = state.thread_ids_by_connection.get_mut(&connection_id) { + thread_ids.remove(&thread_id); + if thread_ids.is_empty() { + state.thread_ids_by_connection.remove(&connection_id); + } + } + if let Some(thread_entry) = state.threads.get_mut(&thread_id) { + thread_entry.connection_ids.remove(&connection_id); + } + + state + .subscription_state_by_id + .retain(|_, subscription_state| { + !(subscription_state.thread_id == thread_id + && subscription_state.connection_id == connection_id) + }); }; - if !self - .thread_ids_by_connection - .get(&connection_id) - .is_some_and(|thread_ids| thread_ids.contains(&thread_id)) - { - return false; - } - - if let Some(thread_ids) = self.thread_ids_by_connection.get_mut(&connection_id) { - thread_ids.remove(&thread_id); - if thread_ids.is_empty() { - self.thread_ids_by_connection.remove(&connection_id); - } - } - - self.subscription_state_by_id.retain(|_, state| { - !(state.thread_id == thread_id && state.connection_id == connection_id) - }); - - let mut thread_state = thread_state.lock().await; - thread_state.remove_connection(connection_id); true } pub(crate) async fn has_subscribers(&self, thread_id: ThreadId) -> bool { - let Some(thread_state) = self.thread_states.get(&thread_id) else { - return false; - }; - !thread_state + self.state .lock() .await - .subscribed_connection_ids() - .is_empty() + .threads + .get(&thread_id) + .is_some_and(|thread_entry| !thread_entry.connection_ids.is_empty()) } pub(crate) async fn set_listener( - &mut self, + &self, subscription_id: Uuid, thread_id: ThreadId, connection_id: ConnectionId, experimental_raw_events: bool, ) -> Arc> { - self.subscription_state_by_id.insert( - subscription_id, - SubscriptionState { - thread_id, - connection_id, - }, - ); - self.thread_ids_by_connection - .entry(connection_id) - .or_default() - .insert(thread_id); - let thread_state = self.thread_state(thread_id); + let thread_state = { + let mut state = self.state.lock().await; + state.subscription_state_by_id.insert( + subscription_id, + SubscriptionState { + thread_id, + connection_id, + }, + ); + state + .thread_ids_by_connection + .entry(connection_id) + .or_default() + .insert(thread_id); + let thread_entry = state.threads.entry(thread_id).or_default(); + thread_entry.connection_ids.insert(connection_id); + thread_entry.state.clone() + }; { let mut thread_state_guard = thread_state.lock().await; - thread_state_guard.add_connection(connection_id); thread_state_guard.set_experimental_raw_events(experimental_raw_events); } thread_state } - pub(crate) async fn ensure_connection_subscribed( - &mut self, + pub(crate) async fn try_ensure_connection_subscribed( + &self, thread_id: ThreadId, connection_id: ConnectionId, experimental_raw_events: bool, - ) -> Arc> { - self.thread_ids_by_connection - .entry(connection_id) - .or_default() - .insert(thread_id); - let thread_state = self.thread_state(thread_id); + ) -> Option>> { + let thread_state = { + let mut state = self.state.lock().await; + if !state.live_connections.contains(&connection_id) { + return None; + } + state + .thread_ids_by_connection + .entry(connection_id) + .or_default() + .insert(thread_id); + let thread_entry = state.threads.entry(thread_id).or_default(); + thread_entry.connection_ids.insert(connection_id); + thread_entry.state.clone() + }; { let mut thread_state_guard = thread_state.lock().await; - thread_state_guard.add_connection(connection_id); if experimental_raw_events { thread_state_guard.set_experimental_raw_events(true); } } - thread_state + Some(thread_state) } - pub(crate) async fn remove_connection(&mut self, connection_id: ConnectionId) { - let thread_ids = self - .thread_ids_by_connection - .remove(&connection_id) - .unwrap_or_default(); - self.subscription_state_by_id - .retain(|_, state| state.connection_id != connection_id); - - if thread_ids.is_empty() { - for thread_state in self.thread_states.values() { - let mut thread_state = thread_state.lock().await; - thread_state.remove_connection(connection_id); - if thread_state.subscribed_connection_ids().is_empty() { - tracing::debug!( - connection_id = ?connection_id, - listener_generation = thread_state.listener_generation, - "retaining thread listener after connection disconnect left zero subscribers" - ); - } - } - return; + pub(crate) async fn try_add_connection_to_thread( + &self, + thread_id: ThreadId, + connection_id: ConnectionId, + ) -> bool { + let mut state = self.state.lock().await; + if !state.live_connections.contains(&connection_id) { + return false; } + state + .thread_ids_by_connection + .entry(connection_id) + .or_default() + .insert(thread_id); + state + .threads + .entry(thread_id) + .or_default() + .connection_ids + .insert(connection_id); + true + } - for thread_id in thread_ids { - if let Some(thread_state) = self.thread_states.get(&thread_id) { - let mut thread_state = thread_state.lock().await; - thread_state.remove_connection(connection_id); - if thread_state.subscribed_connection_ids().is_empty() { - tracing::debug!( - thread_id = %thread_id, - connection_id = ?connection_id, - listener_generation = thread_state.listener_generation, - "retaining thread listener after connection disconnect left zero subscribers" - ); + pub(crate) async fn remove_connection(&self, connection_id: ConnectionId) { + let thread_states = { + let mut state = self.state.lock().await; + state.live_connections.remove(&connection_id); + let thread_ids = state + .thread_ids_by_connection + .remove(&connection_id) + .unwrap_or_default(); + state + .subscription_state_by_id + .retain(|_, state| state.connection_id != connection_id); + for thread_id in &thread_ids { + if let Some(thread_entry) = state.threads.get_mut(thread_id) { + thread_entry.connection_ids.remove(&connection_id); } } + thread_ids + .into_iter() + .map(|thread_id| { + ( + thread_id, + state + .threads + .get(&thread_id) + .is_none_or(|thread_entry| thread_entry.connection_ids.is_empty()), + state + .threads + .get(&thread_id) + .map(|thread_entry| thread_entry.state.clone()), + ) + }) + .collect::>() + }; + + for (thread_id, no_subscribers, thread_state) in thread_states { + if !no_subscribers { + continue; + } + let Some(thread_state) = thread_state else { + continue; + }; + let listener_generation = thread_state.lock().await.listener_generation; + tracing::debug!( + thread_id = %thread_id, + connection_id = ?connection_id, + listener_generation, + "retaining thread listener after connection disconnect left zero subscribers" + ); } } }