fix(app-server): make thread/start non-blocking (#13033)

Stop `thread/start` from blocking other app-server requests.

Before this change, `thread/start` ran inline on the request loop, so
slow startup paths like MCP auth checks could hold up unrelated requests
on the same connection, including `thread/loaded/list`. This moves
`thread/start` into a background task.

Doing so revealed an issue where we were doing nested locking
(and there were possible race conditions that could introduce a
"phantom listener"). This PR also refactors the listener/subscription
bookkeeping - listener/subscription state is now centralized in
`ThreadStateManager` instead of being split across multiple lock
domains. That makes late auto-attach on `thread/start` race-safe and
avoids reintroducing disconnected clients as phantom subscribers.
This commit is contained in:
Owen Lin 2026-02-27 17:40:08 -08:00 committed by GitHub
parent 6604608bad
commit 8fa792868c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 592 additions and 268 deletions

View file

@ -393,6 +393,23 @@ pub(crate) enum ApiVersion {
V2,
}
/// Owned, cloneable bundle of the handles and settings used by the
/// listener-related associated functions (`thread_start_task`,
/// `ensure_conversation_listener_task`, `ensure_listener_task_running_task`).
///
/// It is built from the processor's own fields and then moved into those
/// functions, which lets them run inside a spawned task without borrowing
/// `CodexMessageProcessor`.
#[derive(Clone)]
struct ListenerTaskContext {
// Used to start threads and to look up an existing thread by id.
thread_manager: Arc<ThreadManager>,
// Shared listener/subscription bookkeeping.
thread_state_manager: ThreadStateManager,
// Sender for responses, errors and server notifications.
outgoing: Arc<OutgoingMessageSender>,
thread_watch_manager: ThreadWatchManager,
// Copied from `config.model_provider_id` when the context is built.
fallback_model_provider: String,
// Copied from `config.codex_home` when the context is built.
codex_home: PathBuf,
// When false, the listener task tracks current-turn events on the thread state.
single_client_mode: bool,
}
/// Successful outcomes of trying to attach a connection as a listener on a
/// thread. Failures (e.g. the thread cannot be found) are reported separately
/// as a `JSONRPCErrorError`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum EnsureConversationListenerResult {
// The connection was subscribed and the thread's listener task was ensured.
Attached,
// `try_ensure_connection_subscribed` returned `None`; callers treat this as
// "connection already closed" and skip the attach without raising an error.
ConnectionClosed,
}
pub(crate) struct CodexMessageProcessorArgs {
pub(crate) auth_manager: Arc<AuthManager>,
pub(crate) thread_manager: Arc<ThreadManager>,
@ -2025,7 +2042,7 @@ impl CodexMessageProcessor {
}
}
async fn thread_start(&mut self, request_id: ConnectionRequestId, params: ThreadStartParams) {
async fn thread_start(&self, request_id: ConnectionRequestId, params: ThreadStartParams) {
let ThreadStartParams {
model,
model_provider,
@ -2054,11 +2071,51 @@ impl CodexMessageProcessor {
personality,
);
typesafe_overrides.ephemeral = ephemeral;
let cli_overrides = self.cli_overrides.clone();
let cloud_requirements = self.current_cloud_requirements();
let listener_task_context = ListenerTaskContext {
thread_manager: Arc::clone(&self.thread_manager),
thread_state_manager: self.thread_state_manager.clone(),
outgoing: Arc::clone(&self.outgoing),
thread_watch_manager: self.thread_watch_manager.clone(),
fallback_model_provider: self.config.model_provider_id.clone(),
codex_home: self.config.codex_home.clone(),
single_client_mode: self.single_client_mode,
};
tokio::spawn(async move {
Self::thread_start_task(
listener_task_context,
cli_overrides,
cloud_requirements,
request_id,
config,
typesafe_overrides,
dynamic_tools,
persist_extended_history,
service_name,
experimental_raw_events,
)
.await;
});
}
#[allow(clippy::too_many_arguments)]
async fn thread_start_task(
listener_task_context: ListenerTaskContext,
cli_overrides: Vec<(String, TomlValue)>,
cloud_requirements: CloudRequirementsLoader,
request_id: ConnectionRequestId,
config_overrides: Option<HashMap<String, serde_json::Value>>,
typesafe_overrides: ConfigOverrides,
dynamic_tools: Option<Vec<ApiDynamicToolSpec>>,
persist_extended_history: bool,
service_name: Option<String>,
experimental_raw_events: bool,
) {
let config = match derive_config_from_params(
&self.cli_overrides,
config,
&cli_overrides,
config_overrides,
typesafe_overrides,
&cloud_requirements,
)
@ -2071,7 +2128,10 @@ impl CodexMessageProcessor {
message: format!("error deriving config: {err}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
listener_task_context
.outgoing
.send_error(request_id, error)
.await;
return;
}
};
@ -2086,7 +2146,10 @@ impl CodexMessageProcessor {
message,
data: None,
};
self.outgoing.send_error(request_id, error).await;
listener_task_context
.outgoing
.send_error(request_id, error)
.await;
return;
}
dynamic_tools
@ -2099,7 +2162,7 @@ impl CodexMessageProcessor {
.collect()
};
match self
match listener_task_context
.thread_manager
.start_thread_with_tools_and_service_name(
config,
@ -2124,29 +2187,28 @@ impl CodexMessageProcessor {
);
// Auto-attach a thread listener when starting a thread.
// Use the same behavior as the v1 API, with opt-in support for raw item events.
if let Err(err) = self
.ensure_conversation_listener(
Self::log_listener_attach_result(
Self::ensure_conversation_listener_task(
listener_task_context.clone(),
thread_id,
request_id.connection_id,
experimental_raw_events,
ApiVersion::V2,
)
.await
{
tracing::warn!(
"failed to attach listener for thread {}: {}",
thread_id,
err.message
);
}
.await,
thread_id,
request_id.connection_id,
"thread",
);
self.thread_watch_manager
listener_task_context
.thread_watch_manager
.upsert_thread(thread.clone())
.await;
thread.status = resolve_thread_status(
self.thread_watch_manager
listener_task_context
.thread_watch_manager
.loaded_status_for_thread(&thread.id)
.await,
false,
@ -2162,10 +2224,14 @@ impl CodexMessageProcessor {
reasoning_effort: config_snapshot.reasoning_effort,
};
self.outgoing.send_response(request_id, response).await;
listener_task_context
.outgoing
.send_response(request_id, response)
.await;
let notif = ThreadStartedNotification { thread };
self.outgoing
listener_task_context
.outgoing
.send_server_notification(ServerNotification::ThreadStarted(notif))
.await;
}
@ -2175,7 +2241,10 @@ impl CodexMessageProcessor {
message: format!("error creating thread: {err}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
listener_task_context
.outgoing
.send_error(request_id, error)
.await;
}
}
}
@ -2515,7 +2584,7 @@ impl CodexMessageProcessor {
let request = request_id.clone();
let rollback_already_in_progress = {
let thread_state = self.thread_state_manager.thread_state(thread_id);
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
let mut thread_state = thread_state.lock().await;
if thread_state.pending_rollbacks.is_some() {
true
@ -2536,7 +2605,7 @@ impl CodexMessageProcessor {
if let Err(err) = thread.submit(Op::ThreadRollback { num_turns }).await {
// No ThreadRollback event will arrive if an error occurs.
// Clean up and reply immediately.
let thread_state = self.thread_state_manager.thread_state(thread_id);
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
let mut thread_state = thread_state.lock().await;
thread_state.pending_rollbacks = None;
drop(thread_state);
@ -2882,6 +2951,12 @@ impl CodexMessageProcessor {
self.thread_manager.subscribe_thread_created()
}
/// Tells the thread state manager that `connection_id` has been initialized,
/// by forwarding to `ThreadStateManager::connection_initialized`.
pub(crate) async fn connection_initialized(&self, connection_id: ConnectionId) {
self.thread_state_manager
.connection_initialized(connection_id)
.await;
}
pub(crate) async fn connection_closed(&mut self, connection_id: ConnectionId) {
self.thread_state_manager
.remove_connection(connection_id)
@ -2906,15 +2981,13 @@ impl CodexMessageProcessor {
}
for connection_id in connection_ids {
if let Err(err) = self
.ensure_conversation_listener(thread_id, connection_id, false, ApiVersion::V2)
.await
{
warn!(
"failed to auto-attach listener for thread {thread_id}: {message}",
message = err.message
);
}
Self::log_listener_attach_result(
self.ensure_conversation_listener(thread_id, connection_id, false, ApiVersion::V2)
.await,
thread_id,
connection_id,
"thread",
);
}
}
@ -3039,21 +3112,18 @@ impl CodexMessageProcessor {
return;
};
// Auto-attach a thread listener when resuming a thread.
if let Err(err) = self
.ensure_conversation_listener(
Self::log_listener_attach_result(
self.ensure_conversation_listener(
thread_id,
request_id.connection_id,
false,
ApiVersion::V2,
)
.await
{
tracing::warn!(
"failed to attach listener for thread {}: {}",
thread_id,
err.message
);
}
.await,
thread_id,
request_id.connection_id,
"thread",
);
let Some(mut thread) = self
.load_thread_from_rollout_or_send_internal(
@ -3191,7 +3261,10 @@ impl CodexMessageProcessor {
return true;
}
let thread_state = self.thread_state_manager.thread_state(existing_thread_id);
let thread_state = self
.thread_state_manager
.thread_state(existing_thread_id)
.await;
self.ensure_listener_task_running(
existing_thread_id,
existing_thread.clone(),
@ -3542,21 +3615,18 @@ impl CodexMessageProcessor {
return;
};
// Auto-attach a conversation listener when forking a thread.
if let Err(err) = self
.ensure_conversation_listener(
Self::log_listener_attach_result(
self.ensure_conversation_listener(
thread_id,
request_id.connection_id,
false,
ApiVersion::V2,
)
.await
{
tracing::warn!(
"failed to attach listener for thread {}: {}",
thread_id,
err.message
);
}
.await,
thread_id,
request_id.connection_id,
"thread",
);
let mut thread = match read_summary_from_rollout(
rollout_path.as_path(),
@ -5692,7 +5762,10 @@ impl CodexMessageProcessor {
// Record the pending interrupt so we can reply when TurnAborted arrives.
{
let pending_interrupts = self.thread_state_manager.thread_state(conversation_id);
let pending_interrupts = self
.thread_state_manager
.thread_state(conversation_id)
.await;
let mut thread_state = pending_interrupts.lock().await;
thread_state
.pending_interrupts
@ -5895,7 +5968,7 @@ impl CodexMessageProcessor {
}
};
if let Err(error) = self
match self
.ensure_conversation_listener(
thread_id,
request_id.connection_id,
@ -5904,8 +5977,14 @@ impl CodexMessageProcessor {
)
.await
{
self.outgoing.send_error(request_id, error).await;
return None;
Ok(EnsureConversationListenerResult::Attached) => {}
Ok(EnsureConversationListenerResult::ConnectionClosed) => {
return None;
}
Err(error) => {
self.outgoing.send_error(request_id, error).await;
return None;
}
}
if !thread.enabled(Feature::RealtimeConversation) {
@ -6174,21 +6253,18 @@ impl CodexMessageProcessor {
data: None,
})?;
if let Err(err) = self
.ensure_conversation_listener(
Self::log_listener_attach_result(
self.ensure_conversation_listener(
thread_id,
request_id.connection_id,
false,
ApiVersion::V2,
)
.await
{
tracing::warn!(
"failed to attach listener for review thread {}: {}",
thread_id,
err.message
);
}
.await,
thread_id,
request_id.connection_id,
"review thread",
);
let fallback_provider = self.config.model_provider_id.as_str();
if let Some(rollout_path) = review_thread.rollout_path() {
@ -6315,7 +6391,7 @@ impl CodexMessageProcessor {
// Record the pending interrupt so we can reply when TurnAborted arrives.
{
let thread_state = self.thread_state_manager.thread_state(thread_uuid);
let thread_state = self.thread_state_manager.thread_state(thread_uuid).await;
let mut thread_state = thread_state.lock().await;
thread_state
.pending_interrupts
@ -6397,13 +6473,42 @@ impl CodexMessageProcessor {
}
async fn ensure_conversation_listener(
&mut self,
&self,
conversation_id: ThreadId,
connection_id: ConnectionId,
raw_events_enabled: bool,
api_version: ApiVersion,
) -> Result<(), JSONRPCErrorError> {
let conversation = match self.thread_manager.get_thread(conversation_id).await {
) -> Result<EnsureConversationListenerResult, JSONRPCErrorError> {
Self::ensure_conversation_listener_task(
ListenerTaskContext {
thread_manager: Arc::clone(&self.thread_manager),
thread_state_manager: self.thread_state_manager.clone(),
outgoing: Arc::clone(&self.outgoing),
thread_watch_manager: self.thread_watch_manager.clone(),
fallback_model_provider: self.config.model_provider_id.clone(),
codex_home: self.config.codex_home.clone(),
single_client_mode: self.single_client_mode,
},
conversation_id,
connection_id,
raw_events_enabled,
api_version,
)
.await
}
async fn ensure_conversation_listener_task(
listener_task_context: ListenerTaskContext,
conversation_id: ThreadId,
connection_id: ConnectionId,
raw_events_enabled: bool,
api_version: ApiVersion,
) -> Result<EnsureConversationListenerResult, JSONRPCErrorError> {
let conversation = match listener_task_context
.thread_manager
.get_thread(conversation_id)
.await
{
Ok(conv) => conv,
Err(_) => {
return Err(JSONRPCErrorError {
@ -6413,13 +6518,46 @@ impl CodexMessageProcessor {
});
}
};
let thread_state = self
let Some(thread_state) = listener_task_context
.thread_state_manager
.ensure_connection_subscribed(conversation_id, connection_id, raw_events_enabled)
.await;
self.ensure_listener_task_running(conversation_id, conversation, thread_state, api_version)
.await;
Ok(())
.try_ensure_connection_subscribed(conversation_id, connection_id, raw_events_enabled)
.await
else {
return Ok(EnsureConversationListenerResult::ConnectionClosed);
};
Self::ensure_listener_task_running_task(
listener_task_context,
conversation_id,
conversation,
thread_state,
api_version,
)
.await;
Ok(EnsureConversationListenerResult::Attached)
}
/// Logs the outcome of a best-effort listener auto-attach. Nothing is
/// returned or propagated, so a failed attach never aborts the caller.
///
/// * `result` - value returned by `ensure_conversation_listener` /
///   `ensure_conversation_listener_task`.
/// * `thread_id` - thread the attach was attempted on.
/// * `connection_id` - connection that was being attached; only logged in the
///   `ConnectionClosed` case.
/// * `thread_kind` - label interpolated into the warning message (callers pass
///   `"thread"` or `"review thread"`); only used in the `Err` case.
fn log_listener_attach_result(
result: Result<EnsureConversationListenerResult, JSONRPCErrorError>,
thread_id: ThreadId,
connection_id: ConnectionId,
thread_kind: &'static str,
) {
match result {
// Successful attach: nothing to report.
Ok(EnsureConversationListenerResult::Attached) => {}
// Connection already gone: expected during races, so only debug-level.
Ok(EnsureConversationListenerResult::ConnectionClosed) => {
tracing::debug!(
thread_id = %thread_id,
connection_id = ?connection_id,
"skipping auto-attach for closed connection"
);
}
// Genuine failure: surface it as a warning, but do not propagate.
Err(err) => {
tracing::warn!(
"failed to attach listener for {thread_kind} {thread_id}: {message}",
message = err.message
);
}
}
}
async fn ensure_listener_task_running(
@ -6428,6 +6566,31 @@ impl CodexMessageProcessor {
conversation: Arc<CodexThread>,
thread_state: Arc<Mutex<ThreadState>>,
api_version: ApiVersion,
) {
Self::ensure_listener_task_running_task(
ListenerTaskContext {
thread_manager: Arc::clone(&self.thread_manager),
thread_state_manager: self.thread_state_manager.clone(),
outgoing: Arc::clone(&self.outgoing),
thread_watch_manager: self.thread_watch_manager.clone(),
fallback_model_provider: self.config.model_provider_id.clone(),
codex_home: self.config.codex_home.clone(),
single_client_mode: self.single_client_mode,
},
conversation_id,
conversation,
thread_state,
api_version,
)
.await;
}
async fn ensure_listener_task_running_task(
listener_task_context: ListenerTaskContext,
conversation_id: ThreadId,
conversation: Arc<CodexThread>,
thread_state: Arc<Mutex<ThreadState>>,
api_version: ApiVersion,
) {
let (cancel_tx, mut cancel_rx) = oneshot::channel();
let (mut listener_command_rx, listener_generation) = {
@ -6437,12 +6600,16 @@ impl CodexMessageProcessor {
}
thread_state.set_listener(cancel_tx, &conversation)
};
let outgoing_for_task = self.outgoing.clone();
let thread_manager = self.thread_manager.clone();
let thread_watch_manager = self.thread_watch_manager.clone();
let fallback_model_provider = self.config.model_provider_id.clone();
let single_client_mode = self.single_client_mode;
let codex_home = self.config.codex_home.clone();
let ListenerTaskContext {
outgoing,
thread_manager,
thread_state_manager,
thread_watch_manager,
fallback_model_provider,
codex_home,
single_client_mode,
} = listener_task_context;
let outgoing_for_task = Arc::clone(&outgoing);
tokio::spawn(async move {
loop {
tokio::select! {
@ -6488,16 +6655,16 @@ impl CodexMessageProcessor {
"conversationId".to_string(),
conversation_id.to_string().into(),
);
let (subscribed_connection_ids, raw_events_enabled) = {
let raw_events_enabled = {
let mut thread_state = thread_state.lock().await;
if !single_client_mode {
thread_state.track_current_turn_event(&event.msg);
}
(
thread_state.subscribed_connection_ids(),
thread_state.experimental_raw_events,
)
thread_state.experimental_raw_events
};
let subscribed_connection_ids = thread_state_manager
.subscribed_connection_ids(conversation_id)
.await;
if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled {
continue;
}
@ -6540,6 +6707,7 @@ impl CodexMessageProcessor {
handle_thread_listener_command(
conversation_id,
codex_home.as_path(),
&thread_state_manager,
&thread_state,
&thread_watch_manager,
&outgoing_for_task,
@ -6857,6 +7025,7 @@ impl CodexMessageProcessor {
async fn handle_thread_listener_command(
conversation_id: ThreadId,
codex_home: &Path,
thread_state_manager: &ThreadStateManager,
thread_state: &Arc<Mutex<ThreadState>>,
thread_watch_manager: &ThreadWatchManager,
outgoing: &Arc<OutgoingMessageSender>,
@ -6867,6 +7036,7 @@ async fn handle_thread_listener_command(
handle_pending_thread_resume_request(
conversation_id,
codex_home,
thread_state_manager,
thread_state,
thread_watch_manager,
outgoing,
@ -6878,8 +7048,13 @@ async fn handle_thread_listener_command(
request_id,
completion_tx,
} => {
resolve_pending_server_request(conversation_id, thread_state, outgoing, request_id)
.await;
resolve_pending_server_request(
conversation_id,
thread_state_manager,
outgoing,
request_id,
)
.await;
let _ = completion_tx.send(());
}
}
@ -6888,6 +7063,7 @@ async fn handle_thread_listener_command(
async fn handle_pending_thread_resume_request(
conversation_id: ThreadId,
codex_home: &Path,
thread_state_manager: &ThreadStateManager,
thread_state: &Arc<Mutex<ThreadState>>,
thread_watch_manager: &ThreadWatchManager,
outgoing: &Arc<OutgoingMessageSender>,
@ -6976,18 +7152,21 @@ async fn handle_pending_thread_resume_request(
outgoing
.replay_requests_to_connection_for_thread(connection_id, conversation_id)
.await;
thread_state.lock().await.add_connection(connection_id);
let _attached = thread_state_manager
.try_add_connection_to_thread(conversation_id, connection_id)
.await;
}
async fn resolve_pending_server_request(
conversation_id: ThreadId,
thread_state: &Arc<Mutex<ThreadState>>,
thread_state_manager: &ThreadStateManager,
outgoing: &Arc<OutgoingMessageSender>,
request_id: RequestId,
) {
let thread_id = conversation_id.to_string();
let subscribed_connection_ids = thread_state.lock().await.subscribed_connection_ids();
let subscribed_connection_ids = thread_state_manager
.subscribed_connection_ids(conversation_id)
.await;
let outgoing = ThreadScopedOutgoingMessageSender::new(
outgoing.clone(),
subscribed_connection_ids,
@ -7951,9 +8130,7 @@ mod tests {
#[tokio::test]
async fn aborting_pending_request_clears_pending_state() -> Result<()> {
let thread_id = ThreadId::from_string("bfd12a78-5900-467b-9bc5-d3d35df08191")?;
let thread_state = Arc::new(Mutex::new(ThreadState::default()));
let connection_id = ConnectionId(7);
thread_state.lock().await.add_connection(connection_id);
let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::channel(8);
let outgoing = Arc::new(OutgoingMessageSender::new(outgoing_tx));
@ -8047,7 +8224,7 @@ mod tests {
#[tokio::test]
async fn removing_listeners_retains_thread_listener_when_last_subscriber_leaves() -> Result<()>
{
let mut manager = ThreadStateManager::new();
let manager = ThreadStateManager::new();
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
let listener_a = Uuid::new_v4();
let listener_b = Uuid::new_v4();
@ -8062,7 +8239,7 @@ mod tests {
.set_listener(listener_b, thread_id, connection_b, false)
.await;
{
let state = manager.thread_state(thread_id);
let state = manager.thread_state(thread_id).await;
state.lock().await.cancel_tx = Some(cancel_tx);
}
@ -8078,14 +8255,18 @@ mod tests {
.await
.is_err()
);
let state = manager.thread_state(thread_id);
assert!(state.lock().await.subscribed_connection_ids().is_empty());
assert!(
manager
.subscribed_connection_ids(thread_id)
.await
.is_empty()
);
Ok(())
}
#[tokio::test]
async fn removing_listener_unsubscribes_its_connection() -> Result<()> {
let mut manager = ThreadStateManager::new();
let manager = ThreadStateManager::new();
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
let listener_a = Uuid::new_v4();
let listener_b = Uuid::new_v4();
@ -8100,15 +8281,14 @@ mod tests {
.await;
assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id));
let state = manager.thread_state(thread_id);
let subscribed_connection_ids = state.lock().await.subscribed_connection_ids();
let subscribed_connection_ids = manager.subscribed_connection_ids(thread_id).await;
assert_eq!(subscribed_connection_ids, vec![connection_b]);
Ok(())
}
#[tokio::test]
async fn set_listener_uses_last_write_for_raw_events() -> Result<()> {
let mut manager = ThreadStateManager::new();
let manager = ThreadStateManager::new();
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
let listener_a = Uuid::new_v4();
let listener_b = Uuid::new_v4();
@ -8119,13 +8299,13 @@ mod tests {
.set_listener(listener_a, thread_id, connection_a, true)
.await;
{
let state = manager.thread_state(thread_id);
let state = manager.thread_state(thread_id).await;
assert!(state.lock().await.experimental_raw_events);
}
manager
.set_listener(listener_b, thread_id, connection_b, false)
.await;
let state = manager.thread_state(thread_id);
let state = manager.thread_state(thread_id).await;
assert!(!state.lock().await.experimental_raw_events);
Ok(())
}
@ -8133,7 +8313,7 @@ mod tests {
#[tokio::test]
async fn removing_connection_retains_listener_and_active_turn_when_last_subscriber_disconnects()
-> Result<()> {
let mut manager = ThreadStateManager::new();
let manager = ThreadStateManager::new();
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
let listener = Uuid::new_v4();
let connection = ConnectionId(1);
@ -8143,7 +8323,7 @@ mod tests {
.set_listener(listener, thread_id, connection, false)
.await;
{
let state = manager.thread_state(thread_id);
let state = manager.thread_state(thread_id).await;
let mut state = state.lock().await;
state.cancel_tx = Some(cancel_tx);
state.track_current_turn_event(&EventMsg::TurnStarted(
@ -8163,9 +8343,14 @@ mod tests {
);
assert_eq!(manager.remove_listener(listener).await, None);
let state = manager.thread_state(thread_id);
let state = manager.thread_state(thread_id).await;
let state = state.lock().await;
assert!(state.subscribed_connection_ids().is_empty());
assert!(
manager
.subscribed_connection_ids(thread_id)
.await
.is_empty()
);
assert!(state.cancel_tx.is_some());
let active_turn = state.active_turn_snapshot().expect("active turn snapshot");
assert_eq!(active_turn.id, "turn-1");
@ -8175,16 +8360,18 @@ mod tests {
#[tokio::test]
async fn removing_thread_state_clears_listener_and_active_turn_history() -> Result<()> {
let mut manager = ThreadStateManager::new();
let manager = ThreadStateManager::new();
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
let connection = ConnectionId(1);
let (cancel_tx, cancel_rx) = oneshot::channel();
manager.connection_initialized(connection).await;
manager
.ensure_connection_subscribed(thread_id, connection, false)
.await;
.try_ensure_connection_subscribed(thread_id, connection, false)
.await
.expect("connection should be live");
{
let state = manager.thread_state(thread_id);
let state = manager.thread_state(thread_id).await;
let mut state = state.lock().await;
state.cancel_tx = Some(cancel_tx);
state.track_current_turn_event(&EventMsg::TurnStarted(
@ -8199,9 +8386,14 @@ mod tests {
manager.remove_thread_state(thread_id).await;
assert_eq!(cancel_rx.await, Ok(()));
let state = manager.thread_state(thread_id);
let state = manager.thread_state(thread_id).await;
let state = state.lock().await;
assert!(state.subscribed_connection_ids().is_empty());
assert!(
manager
.subscribed_connection_ids(thread_id)
.await
.is_empty()
);
assert!(state.cancel_tx.is_none());
assert!(state.active_turn_snapshot().is_none());
Ok(())
@ -8210,20 +8402,24 @@ mod tests {
#[tokio::test]
async fn removing_auto_attached_connection_preserves_listener_for_other_connections()
-> Result<()> {
let mut manager = ThreadStateManager::new();
let manager = ThreadStateManager::new();
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
let connection_a = ConnectionId(1);
let connection_b = ConnectionId(2);
let (cancel_tx, mut cancel_rx) = oneshot::channel();
manager.connection_initialized(connection_a).await;
manager.connection_initialized(connection_b).await;
manager
.ensure_connection_subscribed(thread_id, connection_a, false)
.await;
.try_ensure_connection_subscribed(thread_id, connection_a, false)
.await
.expect("connection_a should be live");
manager
.ensure_connection_subscribed(thread_id, connection_b, false)
.await;
.try_ensure_connection_subscribed(thread_id, connection_b, false)
.await
.expect("connection_b should be live");
{
let state = manager.thread_state(thread_id);
let state = manager.thread_state(thread_id).await;
state.lock().await.cancel_tx = Some(cancel_tx);
}
@ -8234,11 +8430,29 @@ mod tests {
.is_err()
);
let state = manager.thread_state(thread_id);
assert_eq!(
state.lock().await.subscribed_connection_ids(),
manager.subscribed_connection_ids(thread_id).await,
vec![connection_b]
);
Ok(())
}
// Regression test for "phantom subscribers": once a connection has been
// removed, a late auto-subscribe for it must be rejected and must not leave
// the thread with any subscriber.
#[tokio::test]
async fn closed_connection_cannot_be_reintroduced_by_auto_subscribe() -> Result<()> {
let manager = ThreadStateManager::new();
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
let connection = ConnectionId(1);
// Initialize the connection, then remove it again before subscribing.
manager.connection_initialized(connection).await;
manager.remove_connection(connection).await;
// Subscribing the removed connection must yield `None`.
assert!(
manager
.try_ensure_connection_subscribed(thread_id, connection, false)
.await
.is_none()
);
// And the rejected attempt must not have registered a subscriber.
assert!(!manager.has_subscribers(thread_id).await);
Ok(())
}
}

View file

@ -338,6 +338,9 @@ impl MessageProcessor {
session.initialized = true;
outbound_initialized.store(true, Ordering::Release);
self.codex_message_processor
.connection_initialized(connection_id)
.await;
return;
}
}

View file

@ -60,7 +60,6 @@ pub(crate) struct ThreadState {
listener_command_tx: Option<mpsc::UnboundedSender<ThreadListenerCommand>>,
current_turn_history: ThreadHistoryBuilder,
listener_thread: Option<Weak<CodexThread>>,
subscribed_connections: HashSet<ConnectionId>,
}
impl ThreadState {
@ -95,18 +94,6 @@ impl ThreadState {
self.listener_thread = None;
}
pub(crate) fn add_connection(&mut self, connection_id: ConnectionId) {
self.subscribed_connections.insert(connection_id);
}
pub(crate) fn remove_connection(&mut self, connection_id: ConnectionId) {
self.subscribed_connections.remove(&connection_id);
}
pub(crate) fn subscribed_connection_ids(&self) -> Vec<ConnectionId> {
self.subscribed_connections.iter().copied().collect()
}
pub(crate) fn set_experimental_raw_events(&mut self, enabled: bool) {
self.experimental_raw_events = enabled;
}
@ -135,55 +122,112 @@ struct SubscriptionState {
connection_id: ConnectionId,
}
/// Per-thread record held by the manager: the thread's shared `ThreadState`
/// together with the set of connection ids subscribed to that thread.
struct ThreadEntry {
// Shared handle; callers clone the `Arc` and lock the state themselves.
state: Arc<Mutex<ThreadState>>,
// Source for `subscribed_connection_ids` and `has_subscribers`.
connection_ids: HashSet<ConnectionId>,
}
/// A default entry holds a fresh default `ThreadState` and no subscribers.
impl Default for ThreadEntry {
fn default() -> Self {
Self {
state: Arc::new(Mutex::new(ThreadState::default())),
connection_ids: HashSet::new(),
}
}
}
#[derive(Default)]
pub(crate) struct ThreadStateManager {
thread_states: HashMap<ThreadId, Arc<Mutex<ThreadState>>>,
struct ThreadStateManagerInner {
live_connections: HashSet<ConnectionId>,
threads: HashMap<ThreadId, ThreadEntry>,
subscription_state_by_id: HashMap<Uuid, SubscriptionState>,
thread_ids_by_connection: HashMap<ConnectionId, HashSet<ThreadId>>,
}
/// Cloneable handle to the listener/subscription bookkeeping.
///
/// All clones share the same `ThreadStateManagerInner` behind a single
/// `Arc<Mutex<_>>`, which is why the methods take `&self` rather than
/// `&mut self`.
#[derive(Clone, Default)]
pub(crate) struct ThreadStateManager {
state: Arc<Mutex<ThreadStateManagerInner>>,
}
impl ThreadStateManager {
pub(crate) fn new() -> Self {
Self::default()
}
pub(crate) fn thread_state(&mut self, thread_id: ThreadId) -> Arc<Mutex<ThreadState>> {
self.thread_states
.entry(thread_id)
.or_insert_with(|| Arc::new(Mutex::new(ThreadState::default())))
.clone()
/// Records `connection_id` as live by inserting it into `live_connections`.
///
/// NOTE(review): presumably `try_ensure_connection_subscribed` consults this
/// set to reject closed connections; its body is not visible here, so confirm.
pub(crate) async fn connection_initialized(&self, connection_id: ConnectionId) {
self.state
.lock()
.await
.live_connections
.insert(connection_id);
}
pub(crate) async fn remove_listener(&mut self, subscription_id: Uuid) -> Option<ThreadId> {
let subscription_state = self.subscription_state_by_id.remove(&subscription_id)?;
/// Returns a snapshot of the connection ids subscribed to `thread_id`.
///
/// Returns an empty `Vec` when the manager has no entry for the thread; the
/// lookup does not create one. The ids are copied out of a `HashSet`, so
/// their order is unspecified.
pub(crate) async fn subscribed_connection_ids(&self, thread_id: ThreadId) -> Vec<ConnectionId> {
let state = self.state.lock().await;
state
.threads
.get(&thread_id)
.map(|thread_entry| thread_entry.connection_ids.iter().copied().collect())
.unwrap_or_default()
}
/// Returns the shared `ThreadState` handle for `thread_id`.
///
/// If the manager has no entry for the thread yet, a default `ThreadEntry`
/// is inserted first, so calling this always leaves an entry behind.
pub(crate) async fn thread_state(&self, thread_id: ThreadId) -> Arc<Mutex<ThreadState>> {
let mut state = self.state.lock().await;
state.threads.entry(thread_id).or_default().state.clone()
}
pub(crate) async fn remove_listener(&self, subscription_id: Uuid) -> Option<ThreadId> {
let (subscription_state, connection_still_subscribed_to_thread, thread_state) = {
let mut state = self.state.lock().await;
let subscription_state = state.subscription_state_by_id.remove(&subscription_id)?;
let thread_id = subscription_state.thread_id;
let connection_still_subscribed_to_thread = state
.subscription_state_by_id
.values()
.any(|subscription_state_entry| {
subscription_state_entry.thread_id == thread_id
&& subscription_state_entry.connection_id
== subscription_state.connection_id
});
if !connection_still_subscribed_to_thread {
let mut remove_connection_entry = false;
if let Some(thread_ids) = state
.thread_ids_by_connection
.get_mut(&subscription_state.connection_id)
{
thread_ids.remove(&thread_id);
remove_connection_entry = thread_ids.is_empty();
}
if remove_connection_entry {
state
.thread_ids_by_connection
.remove(&subscription_state.connection_id);
}
if let Some(thread_entry) = state.threads.get_mut(&thread_id) {
thread_entry
.connection_ids
.remove(&subscription_state.connection_id);
}
}
let thread_state = state.threads.get(&thread_id).map(|thread_entry| {
(
thread_entry.connection_ids.is_empty(),
thread_entry.state.clone(),
)
});
(
subscription_state,
connection_still_subscribed_to_thread,
thread_state,
)
};
let thread_id = subscription_state.thread_id;
let connection_still_subscribed_to_thread =
self.subscription_state_by_id.values().any(|state| {
state.thread_id == thread_id
&& state.connection_id == subscription_state.connection_id
});
if !connection_still_subscribed_to_thread {
let mut remove_connection_entry = false;
if let Some(thread_ids) = self
.thread_ids_by_connection
.get_mut(&subscription_state.connection_id)
{
thread_ids.remove(&thread_id);
remove_connection_entry = thread_ids.is_empty();
}
if remove_connection_entry {
self.thread_ids_by_connection
.remove(&subscription_state.connection_id);
}
}
if let Some(thread_state) = self.thread_states.get(&thread_id) {
let mut thread_state = thread_state.lock().await;
if !connection_still_subscribed_to_thread {
thread_state.remove_connection(subscription_state.connection_id);
}
if thread_state.subscribed_connection_ids().is_empty() {
if let Some((no_subscribers, thread_state)) = thread_state {
let thread_state = thread_state.lock().await;
if !connection_still_subscribed_to_thread && no_subscribers {
tracing::debug!(
thread_id = %thread_id,
subscription_id = %subscription_id,
@ -196,8 +240,24 @@ impl ThreadStateManager {
Some(thread_id)
}
pub(crate) async fn remove_thread_state(&mut self, thread_id: ThreadId) {
if let Some(thread_state) = self.thread_states.remove(&thread_id) {
pub(crate) async fn remove_thread_state(&self, thread_id: ThreadId) {
let thread_state = {
let mut state = self.state.lock().await;
let thread_state = state
.threads
.remove(&thread_id)
.map(|thread_entry| thread_entry.state);
state
.subscription_state_by_id
.retain(|_, state| state.thread_id != thread_id);
state.thread_ids_by_connection.retain(|_, thread_ids| {
thread_ids.remove(&thread_id);
!thread_ids.is_empty()
});
thread_state
};
if let Some(thread_state) = thread_state {
let mut thread_state = thread_state.lock().await;
tracing::debug!(
thread_id = %thread_id,
@ -208,142 +268,189 @@ impl ThreadStateManager {
);
thread_state.clear_listener();
}
self.subscription_state_by_id
.retain(|_, state| state.thread_id != thread_id);
self.thread_ids_by_connection.retain(|_, thread_ids| {
thread_ids.remove(&thread_id);
!thread_ids.is_empty()
});
}
// NOTE(review): two revisions of `unsubscribe_connection_from_thread` are interleaved in this
// span (`&mut self` with direct field access, and `&self` going through `self.state.lock()`).
//
// Both revisions return `false` when the thread is unknown or when `connection_id` is not
// recorded as subscribed to `thread_id`, and `true` after removing the pairing.
pub(crate) async fn unsubscribe_connection_from_thread(
&mut self,
&self,
thread_id: ThreadId,
connection_id: ConnectionId,
) -> bool {
let Some(thread_state) = self.thread_states.get(&thread_id) else {
return false;
// `&self` revision: every check and removal below happens under the `state` lock.
{
let mut state = self.state.lock().await;
if !state.threads.contains_key(&thread_id) {
return false;
}
if !state
.thread_ids_by_connection
.get(&connection_id)
.is_some_and(|thread_ids| thread_ids.contains(&thread_id))
{
return false;
}
// Drop the thread from the connection's set, and the connection's entry once it is empty.
if let Some(thread_ids) = state.thread_ids_by_connection.get_mut(&connection_id) {
thread_ids.remove(&thread_id);
if thread_ids.is_empty() {
state.thread_ids_by_connection.remove(&connection_id);
}
}
if let Some(thread_entry) = state.threads.get_mut(&thread_id) {
thread_entry.connection_ids.remove(&connection_id);
}
// Remove every subscription that matches both this thread and this connection.
state
.subscription_state_by_id
.retain(|_, subscription_state| {
!(subscription_state.thread_id == thread_id
&& subscription_state.connection_id == connection_id)
});
};
// `&mut self` revision: the same membership check and removals on the manager's own fields,
// followed by removing the connection from the per-thread state.
if !self
.thread_ids_by_connection
.get(&connection_id)
.is_some_and(|thread_ids| thread_ids.contains(&thread_id))
{
return false;
}
if let Some(thread_ids) = self.thread_ids_by_connection.get_mut(&connection_id) {
thread_ids.remove(&thread_id);
if thread_ids.is_empty() {
self.thread_ids_by_connection.remove(&connection_id);
}
}
self.subscription_state_by_id.retain(|_, state| {
!(state.thread_id == thread_id && state.connection_id == connection_id)
});
let mut thread_state = thread_state.lock().await;
thread_state.remove_connection(connection_id);
true
}
// NOTE(review): two revisions of the body are interleaved here. One looks the thread up in
// `self.thread_states` and asks the per-thread state for `subscribed_connection_ids()`; the
// other locks `self.state` and checks the thread entry's `connection_ids`.
//
// Either way the result is `false` for an unknown thread and `true` only when at least one
// connection is recorded for it.
pub(crate) async fn has_subscribers(&self, thread_id: ThreadId) -> bool {
let Some(thread_state) = self.thread_states.get(&thread_id) else {
return false;
};
!thread_state
self.state
.lock()
.await
.subscribed_connection_ids()
.is_empty()
.threads
.get(&thread_id)
.is_some_and(|thread_entry| !thread_entry.connection_ids.is_empty())
}
// NOTE(review): two revisions of `set_listener` are interleaved in this span (`&mut self` with
// direct field access, and `&self` going through `self.state.lock()`).
//
// Registers `subscription_id` for the (`thread_id`, `connection_id`) pair, records the
// connection as subscribed to the thread, and returns the shared per-thread state handle.
pub(crate) async fn set_listener(
&mut self,
&self,
subscription_id: Uuid,
thread_id: ThreadId,
connection_id: ConnectionId,
experimental_raw_events: bool,
) -> Arc<Mutex<ThreadState>> {
// `&mut self` revision: update the manager's own maps, then fetch the thread state.
self.subscription_state_by_id.insert(
subscription_id,
SubscriptionState {
thread_id,
connection_id,
},
);
self.thread_ids_by_connection
.entry(connection_id)
.or_default()
.insert(thread_id);
let thread_state = self.thread_state(thread_id);
// `&self` revision: the same inserts under the `state` lock; the thread entry is created on
// demand and only a clone of its state handle leaves the locked scope.
let thread_state = {
let mut state = self.state.lock().await;
state.subscription_state_by_id.insert(
subscription_id,
SubscriptionState {
thread_id,
connection_id,
},
);
state
.thread_ids_by_connection
.entry(connection_id)
.or_default()
.insert(thread_id);
let thread_entry = state.threads.entry(thread_id).or_default();
thread_entry.connection_ids.insert(connection_id);
thread_entry.state.clone()
};
// Per-thread lock, scoped so the guard is released before the handle is returned. The raw
// events flag is overwritten with whatever the caller passed.
{
let mut thread_state_guard = thread_state.lock().await;
thread_state_guard.add_connection(connection_id);
thread_state_guard.set_experimental_raw_events(experimental_raw_events);
}
thread_state
}
// NOTE(review): this span interleaves `ensure_connection_subscribed` (`&mut self`, returns the
// state handle directly) with `try_ensure_connection_subscribed` (`&self`, returns an `Option`).
pub(crate) async fn ensure_connection_subscribed(
&mut self,
// `try_` revision: returns `None` when `connection_id` is not in `live_connections`; otherwise
// records the subscription and returns `Some` with the shared per-thread state handle.
pub(crate) async fn try_ensure_connection_subscribed(
&self,
thread_id: ThreadId,
connection_id: ConnectionId,
experimental_raw_events: bool,
) -> Arc<Mutex<ThreadState>> {
self.thread_ids_by_connection
.entry(connection_id)
.or_default()
.insert(thread_id);
let thread_state = self.thread_state(thread_id);
) -> Option<Arc<Mutex<ThreadState>>> {
// The liveness check and the index updates share one `state` guard, so the connection cannot
// be added to the indexes after it was found missing from `live_connections`.
let thread_state = {
let mut state = self.state.lock().await;
if !state.live_connections.contains(&connection_id) {
return None;
}
state
.thread_ids_by_connection
.entry(connection_id)
.or_default()
.insert(thread_id);
let thread_entry = state.threads.entry(thread_id).or_default();
thread_entry.connection_ids.insert(connection_id);
thread_entry.state.clone()
};
{
let mut thread_state_guard = thread_state.lock().await;
thread_state_guard.add_connection(connection_id);
// The flag is only ever switched on here; passing `false` leaves the current value untouched.
if experimental_raw_events {
thread_state_guard.set_experimental_raw_events(true);
}
}
thread_state
Some(thread_state)
}
// NOTE(review): `&mut self` revision of `remove_connection`. Only its first part appears before
// the text of `try_add_connection_to_thread` begins.
pub(crate) async fn remove_connection(&mut self, connection_id: ConnectionId) {
// Take the set of threads this connection was subscribed to (empty if none was recorded) and
// drop every subscription that belongs to the connection.
let thread_ids = self
.thread_ids_by_connection
.remove(&connection_id)
.unwrap_or_default();
self.subscription_state_by_id
.retain(|_, state| state.connection_id != connection_id);
// With no recorded threads, fall back to removing the connection from every known thread
// state, logging those left with zero subscribers, and return early.
if thread_ids.is_empty() {
for thread_state in self.thread_states.values() {
let mut thread_state = thread_state.lock().await;
thread_state.remove_connection(connection_id);
if thread_state.subscribed_connection_ids().is_empty() {
tracing::debug!(
connection_id = ?connection_id,
listener_generation = thread_state.listener_generation,
"retaining thread listener after connection disconnect left zero subscribers"
);
}
}
return;
/// Records `connection_id` as a subscriber of `thread_id`, provided the connection is still
/// present in `live_connections`.
///
/// Returns `true` when the pairing was recorded. Returns `false`, leaving every index
/// untouched, when the connection is not live.
pub(crate) async fn try_add_connection_to_thread(
    &self,
    thread_id: ThreadId,
    connection_id: ConnectionId,
) -> bool {
    // One guard covers both the liveness check and the inserts.
    let mut guard = self.state.lock().await;
    let is_live = guard.live_connections.contains(&connection_id);
    if is_live {
        let subscribed_threads = guard
            .thread_ids_by_connection
            .entry(connection_id)
            .or_default();
        subscribed_threads.insert(thread_id);

        // The thread entry is created on demand if it does not exist yet.
        let thread_entry = guard.threads.entry(thread_id).or_default();
        thread_entry.connection_ids.insert(connection_id);
    }
    is_live
}
// NOTE(review): remainder of the `&mut self` revision of `remove_connection`. For each thread
// the connection was recorded against, remove the connection from that thread's state and log
// when the thread is left with zero subscribers.
for thread_id in thread_ids {
if let Some(thread_state) = self.thread_states.get(&thread_id) {
let mut thread_state = thread_state.lock().await;
thread_state.remove_connection(connection_id);
if thread_state.subscribed_connection_ids().is_empty() {
tracing::debug!(
thread_id = %thread_id,
connection_id = ?connection_id,
listener_generation = thread_state.listener_generation,
"retaining thread listener after connection disconnect left zero subscribers"
);
/// Forgets a connection: removes it from `live_connections`, drops its subscriptions, and
/// detaches it from every thread it was recorded against.
///
/// Threads that end up with no connections are only logged; nothing in this function clears
/// their listener.
pub(crate) async fn remove_connection(&self, connection_id: ConnectionId) {
    // All index bookkeeping happens under a single `state` guard. Only the handles of threads
    // left without connections escape the scope, so no per-thread lock is taken while `state`
    // is held.
    let orphaned_threads: Vec<_> = {
        let mut guard = self.state.lock().await;
        guard.live_connections.remove(&connection_id);
        guard
            .subscription_state_by_id
            .retain(|_, subscription| subscription.connection_id != connection_id);
        let subscribed_thread_ids = guard
            .thread_ids_by_connection
            .remove(&connection_id)
            .unwrap_or_default();
        subscribed_thread_ids
            .into_iter()
            .filter_map(|thread_id| {
                // Threads without an entry have nothing to detach from and nothing to log.
                let thread_entry = guard.threads.get_mut(&thread_id)?;
                thread_entry.connection_ids.remove(&connection_id);
                thread_entry
                    .connection_ids
                    .is_empty()
                    .then(|| (thread_id, thread_entry.state.clone()))
            })
            .collect()
    };

    for (thread_id, thread_state) in orphaned_threads {
        let listener_generation = thread_state.lock().await.listener_generation;
        tracing::debug!(
            thread_id = %thread_id,
            connection_id = ?connection_id,
            listener_generation,
            "retaining thread listener after connection disconnect left zero subscribers"
        );
    }
}
}