fix(app-server): make thread/start non-blocking (#13033)
Stop `thread/start` from blocking other app-server requests. Before this change, `thread/start ran` inline on the request loop, so slow startup paths like MCP auth checks could hold up unrelated requests on the same connection, including `thread/loaded/list`. This moves `thread/start` into a background task. While doing so, it revealed an issue where we were doing nested locking (and there were some race conditions possible that could introduce a "phantom listener"). This PR also refactors the listener/subscription bookkeeping - listener/subscription state is now centralized in `ThreadStateManager` instead of being split across multiple lock domains. That makes late auto-attach on `thread/start` race-safe and avoids reintroducing disconnected clients as phantom subscribers.
This commit is contained in:
parent
6604608bad
commit
8fa792868c
3 changed files with 592 additions and 268 deletions
|
|
@ -393,6 +393,23 @@ pub(crate) enum ApiVersion {
|
|||
V2,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ListenerTaskContext {
|
||||
thread_manager: Arc<ThreadManager>,
|
||||
thread_state_manager: ThreadStateManager,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
thread_watch_manager: ThreadWatchManager,
|
||||
fallback_model_provider: String,
|
||||
codex_home: PathBuf,
|
||||
single_client_mode: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
enum EnsureConversationListenerResult {
|
||||
Attached,
|
||||
ConnectionClosed,
|
||||
}
|
||||
|
||||
pub(crate) struct CodexMessageProcessorArgs {
|
||||
pub(crate) auth_manager: Arc<AuthManager>,
|
||||
pub(crate) thread_manager: Arc<ThreadManager>,
|
||||
|
|
@ -2025,7 +2042,7 @@ impl CodexMessageProcessor {
|
|||
}
|
||||
}
|
||||
|
||||
async fn thread_start(&mut self, request_id: ConnectionRequestId, params: ThreadStartParams) {
|
||||
async fn thread_start(&self, request_id: ConnectionRequestId, params: ThreadStartParams) {
|
||||
let ThreadStartParams {
|
||||
model,
|
||||
model_provider,
|
||||
|
|
@ -2054,11 +2071,51 @@ impl CodexMessageProcessor {
|
|||
personality,
|
||||
);
|
||||
typesafe_overrides.ephemeral = ephemeral;
|
||||
|
||||
let cli_overrides = self.cli_overrides.clone();
|
||||
let cloud_requirements = self.current_cloud_requirements();
|
||||
let listener_task_context = ListenerTaskContext {
|
||||
thread_manager: Arc::clone(&self.thread_manager),
|
||||
thread_state_manager: self.thread_state_manager.clone(),
|
||||
outgoing: Arc::clone(&self.outgoing),
|
||||
thread_watch_manager: self.thread_watch_manager.clone(),
|
||||
fallback_model_provider: self.config.model_provider_id.clone(),
|
||||
codex_home: self.config.codex_home.clone(),
|
||||
single_client_mode: self.single_client_mode,
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
Self::thread_start_task(
|
||||
listener_task_context,
|
||||
cli_overrides,
|
||||
cloud_requirements,
|
||||
request_id,
|
||||
config,
|
||||
typesafe_overrides,
|
||||
dynamic_tools,
|
||||
persist_extended_history,
|
||||
service_name,
|
||||
experimental_raw_events,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn thread_start_task(
|
||||
listener_task_context: ListenerTaskContext,
|
||||
cli_overrides: Vec<(String, TomlValue)>,
|
||||
cloud_requirements: CloudRequirementsLoader,
|
||||
request_id: ConnectionRequestId,
|
||||
config_overrides: Option<HashMap<String, serde_json::Value>>,
|
||||
typesafe_overrides: ConfigOverrides,
|
||||
dynamic_tools: Option<Vec<ApiDynamicToolSpec>>,
|
||||
persist_extended_history: bool,
|
||||
service_name: Option<String>,
|
||||
experimental_raw_events: bool,
|
||||
) {
|
||||
let config = match derive_config_from_params(
|
||||
&self.cli_overrides,
|
||||
config,
|
||||
&cli_overrides,
|
||||
config_overrides,
|
||||
typesafe_overrides,
|
||||
&cloud_requirements,
|
||||
)
|
||||
|
|
@ -2071,7 +2128,10 @@ impl CodexMessageProcessor {
|
|||
message: format!("error deriving config: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
listener_task_context
|
||||
.outgoing
|
||||
.send_error(request_id, error)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
|
@ -2086,7 +2146,10 @@ impl CodexMessageProcessor {
|
|||
message,
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
listener_task_context
|
||||
.outgoing
|
||||
.send_error(request_id, error)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
dynamic_tools
|
||||
|
|
@ -2099,7 +2162,7 @@ impl CodexMessageProcessor {
|
|||
.collect()
|
||||
};
|
||||
|
||||
match self
|
||||
match listener_task_context
|
||||
.thread_manager
|
||||
.start_thread_with_tools_and_service_name(
|
||||
config,
|
||||
|
|
@ -2124,29 +2187,28 @@ impl CodexMessageProcessor {
|
|||
);
|
||||
|
||||
// Auto-attach a thread listener when starting a thread.
|
||||
// Use the same behavior as the v1 API, with opt-in support for raw item events.
|
||||
if let Err(err) = self
|
||||
.ensure_conversation_listener(
|
||||
Self::log_listener_attach_result(
|
||||
Self::ensure_conversation_listener_task(
|
||||
listener_task_context.clone(),
|
||||
thread_id,
|
||||
request_id.connection_id,
|
||||
experimental_raw_events,
|
||||
ApiVersion::V2,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
"failed to attach listener for thread {}: {}",
|
||||
thread_id,
|
||||
err.message
|
||||
);
|
||||
}
|
||||
.await,
|
||||
thread_id,
|
||||
request_id.connection_id,
|
||||
"thread",
|
||||
);
|
||||
|
||||
self.thread_watch_manager
|
||||
listener_task_context
|
||||
.thread_watch_manager
|
||||
.upsert_thread(thread.clone())
|
||||
.await;
|
||||
|
||||
thread.status = resolve_thread_status(
|
||||
self.thread_watch_manager
|
||||
listener_task_context
|
||||
.thread_watch_manager
|
||||
.loaded_status_for_thread(&thread.id)
|
||||
.await,
|
||||
false,
|
||||
|
|
@ -2162,10 +2224,14 @@ impl CodexMessageProcessor {
|
|||
reasoning_effort: config_snapshot.reasoning_effort,
|
||||
};
|
||||
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
listener_task_context
|
||||
.outgoing
|
||||
.send_response(request_id, response)
|
||||
.await;
|
||||
|
||||
let notif = ThreadStartedNotification { thread };
|
||||
self.outgoing
|
||||
listener_task_context
|
||||
.outgoing
|
||||
.send_server_notification(ServerNotification::ThreadStarted(notif))
|
||||
.await;
|
||||
}
|
||||
|
|
@ -2175,7 +2241,10 @@ impl CodexMessageProcessor {
|
|||
message: format!("error creating thread: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
listener_task_context
|
||||
.outgoing
|
||||
.send_error(request_id, error)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2515,7 +2584,7 @@ impl CodexMessageProcessor {
|
|||
let request = request_id.clone();
|
||||
|
||||
let rollback_already_in_progress = {
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_id);
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
if thread_state.pending_rollbacks.is_some() {
|
||||
true
|
||||
|
|
@ -2536,7 +2605,7 @@ impl CodexMessageProcessor {
|
|||
if let Err(err) = thread.submit(Op::ThreadRollback { num_turns }).await {
|
||||
// No ThreadRollback event will arrive if an error occurs.
|
||||
// Clean up and reply immediately.
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_id);
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state.pending_rollbacks = None;
|
||||
drop(thread_state);
|
||||
|
|
@ -2882,6 +2951,12 @@ impl CodexMessageProcessor {
|
|||
self.thread_manager.subscribe_thread_created()
|
||||
}
|
||||
|
||||
pub(crate) async fn connection_initialized(&self, connection_id: ConnectionId) {
|
||||
self.thread_state_manager
|
||||
.connection_initialized(connection_id)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn connection_closed(&mut self, connection_id: ConnectionId) {
|
||||
self.thread_state_manager
|
||||
.remove_connection(connection_id)
|
||||
|
|
@ -2906,15 +2981,13 @@ impl CodexMessageProcessor {
|
|||
}
|
||||
|
||||
for connection_id in connection_ids {
|
||||
if let Err(err) = self
|
||||
.ensure_conversation_listener(thread_id, connection_id, false, ApiVersion::V2)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
"failed to auto-attach listener for thread {thread_id}: {message}",
|
||||
message = err.message
|
||||
);
|
||||
}
|
||||
Self::log_listener_attach_result(
|
||||
self.ensure_conversation_listener(thread_id, connection_id, false, ApiVersion::V2)
|
||||
.await,
|
||||
thread_id,
|
||||
connection_id,
|
||||
"thread",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3039,21 +3112,18 @@ impl CodexMessageProcessor {
|
|||
return;
|
||||
};
|
||||
// Auto-attach a thread listener when resuming a thread.
|
||||
if let Err(err) = self
|
||||
.ensure_conversation_listener(
|
||||
Self::log_listener_attach_result(
|
||||
self.ensure_conversation_listener(
|
||||
thread_id,
|
||||
request_id.connection_id,
|
||||
false,
|
||||
ApiVersion::V2,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
"failed to attach listener for thread {}: {}",
|
||||
thread_id,
|
||||
err.message
|
||||
);
|
||||
}
|
||||
.await,
|
||||
thread_id,
|
||||
request_id.connection_id,
|
||||
"thread",
|
||||
);
|
||||
|
||||
let Some(mut thread) = self
|
||||
.load_thread_from_rollout_or_send_internal(
|
||||
|
|
@ -3191,7 +3261,10 @@ impl CodexMessageProcessor {
|
|||
return true;
|
||||
}
|
||||
|
||||
let thread_state = self.thread_state_manager.thread_state(existing_thread_id);
|
||||
let thread_state = self
|
||||
.thread_state_manager
|
||||
.thread_state(existing_thread_id)
|
||||
.await;
|
||||
self.ensure_listener_task_running(
|
||||
existing_thread_id,
|
||||
existing_thread.clone(),
|
||||
|
|
@ -3542,21 +3615,18 @@ impl CodexMessageProcessor {
|
|||
return;
|
||||
};
|
||||
// Auto-attach a conversation listener when forking a thread.
|
||||
if let Err(err) = self
|
||||
.ensure_conversation_listener(
|
||||
Self::log_listener_attach_result(
|
||||
self.ensure_conversation_listener(
|
||||
thread_id,
|
||||
request_id.connection_id,
|
||||
false,
|
||||
ApiVersion::V2,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
"failed to attach listener for thread {}: {}",
|
||||
thread_id,
|
||||
err.message
|
||||
);
|
||||
}
|
||||
.await,
|
||||
thread_id,
|
||||
request_id.connection_id,
|
||||
"thread",
|
||||
);
|
||||
|
||||
let mut thread = match read_summary_from_rollout(
|
||||
rollout_path.as_path(),
|
||||
|
|
@ -5692,7 +5762,10 @@ impl CodexMessageProcessor {
|
|||
|
||||
// Record the pending interrupt so we can reply when TurnAborted arrives.
|
||||
{
|
||||
let pending_interrupts = self.thread_state_manager.thread_state(conversation_id);
|
||||
let pending_interrupts = self
|
||||
.thread_state_manager
|
||||
.thread_state(conversation_id)
|
||||
.await;
|
||||
let mut thread_state = pending_interrupts.lock().await;
|
||||
thread_state
|
||||
.pending_interrupts
|
||||
|
|
@ -5895,7 +5968,7 @@ impl CodexMessageProcessor {
|
|||
}
|
||||
};
|
||||
|
||||
if let Err(error) = self
|
||||
match self
|
||||
.ensure_conversation_listener(
|
||||
thread_id,
|
||||
request_id.connection_id,
|
||||
|
|
@ -5904,8 +5977,14 @@ impl CodexMessageProcessor {
|
|||
)
|
||||
.await
|
||||
{
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return None;
|
||||
Ok(EnsureConversationListenerResult::Attached) => {}
|
||||
Ok(EnsureConversationListenerResult::ConnectionClosed) => {
|
||||
return None;
|
||||
}
|
||||
Err(error) => {
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
if !thread.enabled(Feature::RealtimeConversation) {
|
||||
|
|
@ -6174,21 +6253,18 @@ impl CodexMessageProcessor {
|
|||
data: None,
|
||||
})?;
|
||||
|
||||
if let Err(err) = self
|
||||
.ensure_conversation_listener(
|
||||
Self::log_listener_attach_result(
|
||||
self.ensure_conversation_listener(
|
||||
thread_id,
|
||||
request_id.connection_id,
|
||||
false,
|
||||
ApiVersion::V2,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
"failed to attach listener for review thread {}: {}",
|
||||
thread_id,
|
||||
err.message
|
||||
);
|
||||
}
|
||||
.await,
|
||||
thread_id,
|
||||
request_id.connection_id,
|
||||
"review thread",
|
||||
);
|
||||
|
||||
let fallback_provider = self.config.model_provider_id.as_str();
|
||||
if let Some(rollout_path) = review_thread.rollout_path() {
|
||||
|
|
@ -6315,7 +6391,7 @@ impl CodexMessageProcessor {
|
|||
|
||||
// Record the pending interrupt so we can reply when TurnAborted arrives.
|
||||
{
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_uuid);
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_uuid).await;
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state
|
||||
.pending_interrupts
|
||||
|
|
@ -6397,13 +6473,42 @@ impl CodexMessageProcessor {
|
|||
}
|
||||
|
||||
async fn ensure_conversation_listener(
|
||||
&mut self,
|
||||
&self,
|
||||
conversation_id: ThreadId,
|
||||
connection_id: ConnectionId,
|
||||
raw_events_enabled: bool,
|
||||
api_version: ApiVersion,
|
||||
) -> Result<(), JSONRPCErrorError> {
|
||||
let conversation = match self.thread_manager.get_thread(conversation_id).await {
|
||||
) -> Result<EnsureConversationListenerResult, JSONRPCErrorError> {
|
||||
Self::ensure_conversation_listener_task(
|
||||
ListenerTaskContext {
|
||||
thread_manager: Arc::clone(&self.thread_manager),
|
||||
thread_state_manager: self.thread_state_manager.clone(),
|
||||
outgoing: Arc::clone(&self.outgoing),
|
||||
thread_watch_manager: self.thread_watch_manager.clone(),
|
||||
fallback_model_provider: self.config.model_provider_id.clone(),
|
||||
codex_home: self.config.codex_home.clone(),
|
||||
single_client_mode: self.single_client_mode,
|
||||
},
|
||||
conversation_id,
|
||||
connection_id,
|
||||
raw_events_enabled,
|
||||
api_version,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn ensure_conversation_listener_task(
|
||||
listener_task_context: ListenerTaskContext,
|
||||
conversation_id: ThreadId,
|
||||
connection_id: ConnectionId,
|
||||
raw_events_enabled: bool,
|
||||
api_version: ApiVersion,
|
||||
) -> Result<EnsureConversationListenerResult, JSONRPCErrorError> {
|
||||
let conversation = match listener_task_context
|
||||
.thread_manager
|
||||
.get_thread(conversation_id)
|
||||
.await
|
||||
{
|
||||
Ok(conv) => conv,
|
||||
Err(_) => {
|
||||
return Err(JSONRPCErrorError {
|
||||
|
|
@ -6413,13 +6518,46 @@ impl CodexMessageProcessor {
|
|||
});
|
||||
}
|
||||
};
|
||||
let thread_state = self
|
||||
let Some(thread_state) = listener_task_context
|
||||
.thread_state_manager
|
||||
.ensure_connection_subscribed(conversation_id, connection_id, raw_events_enabled)
|
||||
.await;
|
||||
self.ensure_listener_task_running(conversation_id, conversation, thread_state, api_version)
|
||||
.await;
|
||||
Ok(())
|
||||
.try_ensure_connection_subscribed(conversation_id, connection_id, raw_events_enabled)
|
||||
.await
|
||||
else {
|
||||
return Ok(EnsureConversationListenerResult::ConnectionClosed);
|
||||
};
|
||||
Self::ensure_listener_task_running_task(
|
||||
listener_task_context,
|
||||
conversation_id,
|
||||
conversation,
|
||||
thread_state,
|
||||
api_version,
|
||||
)
|
||||
.await;
|
||||
Ok(EnsureConversationListenerResult::Attached)
|
||||
}
|
||||
|
||||
fn log_listener_attach_result(
|
||||
result: Result<EnsureConversationListenerResult, JSONRPCErrorError>,
|
||||
thread_id: ThreadId,
|
||||
connection_id: ConnectionId,
|
||||
thread_kind: &'static str,
|
||||
) {
|
||||
match result {
|
||||
Ok(EnsureConversationListenerResult::Attached) => {}
|
||||
Ok(EnsureConversationListenerResult::ConnectionClosed) => {
|
||||
tracing::debug!(
|
||||
thread_id = %thread_id,
|
||||
connection_id = ?connection_id,
|
||||
"skipping auto-attach for closed connection"
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
"failed to attach listener for {thread_kind} {thread_id}: {message}",
|
||||
message = err.message
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn ensure_listener_task_running(
|
||||
|
|
@ -6428,6 +6566,31 @@ impl CodexMessageProcessor {
|
|||
conversation: Arc<CodexThread>,
|
||||
thread_state: Arc<Mutex<ThreadState>>,
|
||||
api_version: ApiVersion,
|
||||
) {
|
||||
Self::ensure_listener_task_running_task(
|
||||
ListenerTaskContext {
|
||||
thread_manager: Arc::clone(&self.thread_manager),
|
||||
thread_state_manager: self.thread_state_manager.clone(),
|
||||
outgoing: Arc::clone(&self.outgoing),
|
||||
thread_watch_manager: self.thread_watch_manager.clone(),
|
||||
fallback_model_provider: self.config.model_provider_id.clone(),
|
||||
codex_home: self.config.codex_home.clone(),
|
||||
single_client_mode: self.single_client_mode,
|
||||
},
|
||||
conversation_id,
|
||||
conversation,
|
||||
thread_state,
|
||||
api_version,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn ensure_listener_task_running_task(
|
||||
listener_task_context: ListenerTaskContext,
|
||||
conversation_id: ThreadId,
|
||||
conversation: Arc<CodexThread>,
|
||||
thread_state: Arc<Mutex<ThreadState>>,
|
||||
api_version: ApiVersion,
|
||||
) {
|
||||
let (cancel_tx, mut cancel_rx) = oneshot::channel();
|
||||
let (mut listener_command_rx, listener_generation) = {
|
||||
|
|
@ -6437,12 +6600,16 @@ impl CodexMessageProcessor {
|
|||
}
|
||||
thread_state.set_listener(cancel_tx, &conversation)
|
||||
};
|
||||
let outgoing_for_task = self.outgoing.clone();
|
||||
let thread_manager = self.thread_manager.clone();
|
||||
let thread_watch_manager = self.thread_watch_manager.clone();
|
||||
let fallback_model_provider = self.config.model_provider_id.clone();
|
||||
let single_client_mode = self.single_client_mode;
|
||||
let codex_home = self.config.codex_home.clone();
|
||||
let ListenerTaskContext {
|
||||
outgoing,
|
||||
thread_manager,
|
||||
thread_state_manager,
|
||||
thread_watch_manager,
|
||||
fallback_model_provider,
|
||||
codex_home,
|
||||
single_client_mode,
|
||||
} = listener_task_context;
|
||||
let outgoing_for_task = Arc::clone(&outgoing);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
|
|
@ -6488,16 +6655,16 @@ impl CodexMessageProcessor {
|
|||
"conversationId".to_string(),
|
||||
conversation_id.to_string().into(),
|
||||
);
|
||||
let (subscribed_connection_ids, raw_events_enabled) = {
|
||||
let raw_events_enabled = {
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
if !single_client_mode {
|
||||
thread_state.track_current_turn_event(&event.msg);
|
||||
}
|
||||
(
|
||||
thread_state.subscribed_connection_ids(),
|
||||
thread_state.experimental_raw_events,
|
||||
)
|
||||
thread_state.experimental_raw_events
|
||||
};
|
||||
let subscribed_connection_ids = thread_state_manager
|
||||
.subscribed_connection_ids(conversation_id)
|
||||
.await;
|
||||
if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -6540,6 +6707,7 @@ impl CodexMessageProcessor {
|
|||
handle_thread_listener_command(
|
||||
conversation_id,
|
||||
codex_home.as_path(),
|
||||
&thread_state_manager,
|
||||
&thread_state,
|
||||
&thread_watch_manager,
|
||||
&outgoing_for_task,
|
||||
|
|
@ -6857,6 +7025,7 @@ impl CodexMessageProcessor {
|
|||
async fn handle_thread_listener_command(
|
||||
conversation_id: ThreadId,
|
||||
codex_home: &Path,
|
||||
thread_state_manager: &ThreadStateManager,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
thread_watch_manager: &ThreadWatchManager,
|
||||
outgoing: &Arc<OutgoingMessageSender>,
|
||||
|
|
@ -6867,6 +7036,7 @@ async fn handle_thread_listener_command(
|
|||
handle_pending_thread_resume_request(
|
||||
conversation_id,
|
||||
codex_home,
|
||||
thread_state_manager,
|
||||
thread_state,
|
||||
thread_watch_manager,
|
||||
outgoing,
|
||||
|
|
@ -6878,8 +7048,13 @@ async fn handle_thread_listener_command(
|
|||
request_id,
|
||||
completion_tx,
|
||||
} => {
|
||||
resolve_pending_server_request(conversation_id, thread_state, outgoing, request_id)
|
||||
.await;
|
||||
resolve_pending_server_request(
|
||||
conversation_id,
|
||||
thread_state_manager,
|
||||
outgoing,
|
||||
request_id,
|
||||
)
|
||||
.await;
|
||||
let _ = completion_tx.send(());
|
||||
}
|
||||
}
|
||||
|
|
@ -6888,6 +7063,7 @@ async fn handle_thread_listener_command(
|
|||
async fn handle_pending_thread_resume_request(
|
||||
conversation_id: ThreadId,
|
||||
codex_home: &Path,
|
||||
thread_state_manager: &ThreadStateManager,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
thread_watch_manager: &ThreadWatchManager,
|
||||
outgoing: &Arc<OutgoingMessageSender>,
|
||||
|
|
@ -6976,18 +7152,21 @@ async fn handle_pending_thread_resume_request(
|
|||
outgoing
|
||||
.replay_requests_to_connection_for_thread(connection_id, conversation_id)
|
||||
.await;
|
||||
|
||||
thread_state.lock().await.add_connection(connection_id);
|
||||
let _attached = thread_state_manager
|
||||
.try_add_connection_to_thread(conversation_id, connection_id)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn resolve_pending_server_request(
|
||||
conversation_id: ThreadId,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
thread_state_manager: &ThreadStateManager,
|
||||
outgoing: &Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
) {
|
||||
let thread_id = conversation_id.to_string();
|
||||
let subscribed_connection_ids = thread_state.lock().await.subscribed_connection_ids();
|
||||
let subscribed_connection_ids = thread_state_manager
|
||||
.subscribed_connection_ids(conversation_id)
|
||||
.await;
|
||||
let outgoing = ThreadScopedOutgoingMessageSender::new(
|
||||
outgoing.clone(),
|
||||
subscribed_connection_ids,
|
||||
|
|
@ -7951,9 +8130,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn aborting_pending_request_clears_pending_state() -> Result<()> {
|
||||
let thread_id = ThreadId::from_string("bfd12a78-5900-467b-9bc5-d3d35df08191")?;
|
||||
let thread_state = Arc::new(Mutex::new(ThreadState::default()));
|
||||
let connection_id = ConnectionId(7);
|
||||
thread_state.lock().await.add_connection(connection_id);
|
||||
|
||||
let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::channel(8);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(outgoing_tx));
|
||||
|
|
@ -8047,7 +8224,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn removing_listeners_retains_thread_listener_when_last_subscriber_leaves() -> Result<()>
|
||||
{
|
||||
let mut manager = ThreadStateManager::new();
|
||||
let manager = ThreadStateManager::new();
|
||||
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
|
||||
let listener_a = Uuid::new_v4();
|
||||
let listener_b = Uuid::new_v4();
|
||||
|
|
@ -8062,7 +8239,7 @@ mod tests {
|
|||
.set_listener(listener_b, thread_id, connection_b, false)
|
||||
.await;
|
||||
{
|
||||
let state = manager.thread_state(thread_id);
|
||||
let state = manager.thread_state(thread_id).await;
|
||||
state.lock().await.cancel_tx = Some(cancel_tx);
|
||||
}
|
||||
|
||||
|
|
@ -8078,14 +8255,18 @@ mod tests {
|
|||
.await
|
||||
.is_err()
|
||||
);
|
||||
let state = manager.thread_state(thread_id);
|
||||
assert!(state.lock().await.subscribed_connection_ids().is_empty());
|
||||
assert!(
|
||||
manager
|
||||
.subscribed_connection_ids(thread_id)
|
||||
.await
|
||||
.is_empty()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn removing_listener_unsubscribes_its_connection() -> Result<()> {
|
||||
let mut manager = ThreadStateManager::new();
|
||||
let manager = ThreadStateManager::new();
|
||||
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
|
||||
let listener_a = Uuid::new_v4();
|
||||
let listener_b = Uuid::new_v4();
|
||||
|
|
@ -8100,15 +8281,14 @@ mod tests {
|
|||
.await;
|
||||
|
||||
assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id));
|
||||
let state = manager.thread_state(thread_id);
|
||||
let subscribed_connection_ids = state.lock().await.subscribed_connection_ids();
|
||||
let subscribed_connection_ids = manager.subscribed_connection_ids(thread_id).await;
|
||||
assert_eq!(subscribed_connection_ids, vec![connection_b]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn set_listener_uses_last_write_for_raw_events() -> Result<()> {
|
||||
let mut manager = ThreadStateManager::new();
|
||||
let manager = ThreadStateManager::new();
|
||||
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
|
||||
let listener_a = Uuid::new_v4();
|
||||
let listener_b = Uuid::new_v4();
|
||||
|
|
@ -8119,13 +8299,13 @@ mod tests {
|
|||
.set_listener(listener_a, thread_id, connection_a, true)
|
||||
.await;
|
||||
{
|
||||
let state = manager.thread_state(thread_id);
|
||||
let state = manager.thread_state(thread_id).await;
|
||||
assert!(state.lock().await.experimental_raw_events);
|
||||
}
|
||||
manager
|
||||
.set_listener(listener_b, thread_id, connection_b, false)
|
||||
.await;
|
||||
let state = manager.thread_state(thread_id);
|
||||
let state = manager.thread_state(thread_id).await;
|
||||
assert!(!state.lock().await.experimental_raw_events);
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -8133,7 +8313,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn removing_connection_retains_listener_and_active_turn_when_last_subscriber_disconnects()
|
||||
-> Result<()> {
|
||||
let mut manager = ThreadStateManager::new();
|
||||
let manager = ThreadStateManager::new();
|
||||
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
|
||||
let listener = Uuid::new_v4();
|
||||
let connection = ConnectionId(1);
|
||||
|
|
@ -8143,7 +8323,7 @@ mod tests {
|
|||
.set_listener(listener, thread_id, connection, false)
|
||||
.await;
|
||||
{
|
||||
let state = manager.thread_state(thread_id);
|
||||
let state = manager.thread_state(thread_id).await;
|
||||
let mut state = state.lock().await;
|
||||
state.cancel_tx = Some(cancel_tx);
|
||||
state.track_current_turn_event(&EventMsg::TurnStarted(
|
||||
|
|
@ -8163,9 +8343,14 @@ mod tests {
|
|||
);
|
||||
assert_eq!(manager.remove_listener(listener).await, None);
|
||||
|
||||
let state = manager.thread_state(thread_id);
|
||||
let state = manager.thread_state(thread_id).await;
|
||||
let state = state.lock().await;
|
||||
assert!(state.subscribed_connection_ids().is_empty());
|
||||
assert!(
|
||||
manager
|
||||
.subscribed_connection_ids(thread_id)
|
||||
.await
|
||||
.is_empty()
|
||||
);
|
||||
assert!(state.cancel_tx.is_some());
|
||||
let active_turn = state.active_turn_snapshot().expect("active turn snapshot");
|
||||
assert_eq!(active_turn.id, "turn-1");
|
||||
|
|
@ -8175,16 +8360,18 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn removing_thread_state_clears_listener_and_active_turn_history() -> Result<()> {
|
||||
let mut manager = ThreadStateManager::new();
|
||||
let manager = ThreadStateManager::new();
|
||||
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
|
||||
let connection = ConnectionId(1);
|
||||
let (cancel_tx, cancel_rx) = oneshot::channel();
|
||||
|
||||
manager.connection_initialized(connection).await;
|
||||
manager
|
||||
.ensure_connection_subscribed(thread_id, connection, false)
|
||||
.await;
|
||||
.try_ensure_connection_subscribed(thread_id, connection, false)
|
||||
.await
|
||||
.expect("connection should be live");
|
||||
{
|
||||
let state = manager.thread_state(thread_id);
|
||||
let state = manager.thread_state(thread_id).await;
|
||||
let mut state = state.lock().await;
|
||||
state.cancel_tx = Some(cancel_tx);
|
||||
state.track_current_turn_event(&EventMsg::TurnStarted(
|
||||
|
|
@ -8199,9 +8386,14 @@ mod tests {
|
|||
manager.remove_thread_state(thread_id).await;
|
||||
assert_eq!(cancel_rx.await, Ok(()));
|
||||
|
||||
let state = manager.thread_state(thread_id);
|
||||
let state = manager.thread_state(thread_id).await;
|
||||
let state = state.lock().await;
|
||||
assert!(state.subscribed_connection_ids().is_empty());
|
||||
assert!(
|
||||
manager
|
||||
.subscribed_connection_ids(thread_id)
|
||||
.await
|
||||
.is_empty()
|
||||
);
|
||||
assert!(state.cancel_tx.is_none());
|
||||
assert!(state.active_turn_snapshot().is_none());
|
||||
Ok(())
|
||||
|
|
@ -8210,20 +8402,24 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn removing_auto_attached_connection_preserves_listener_for_other_connections()
|
||||
-> Result<()> {
|
||||
let mut manager = ThreadStateManager::new();
|
||||
let manager = ThreadStateManager::new();
|
||||
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
|
||||
let connection_a = ConnectionId(1);
|
||||
let connection_b = ConnectionId(2);
|
||||
let (cancel_tx, mut cancel_rx) = oneshot::channel();
|
||||
|
||||
manager.connection_initialized(connection_a).await;
|
||||
manager.connection_initialized(connection_b).await;
|
||||
manager
|
||||
.ensure_connection_subscribed(thread_id, connection_a, false)
|
||||
.await;
|
||||
.try_ensure_connection_subscribed(thread_id, connection_a, false)
|
||||
.await
|
||||
.expect("connection_a should be live");
|
||||
manager
|
||||
.ensure_connection_subscribed(thread_id, connection_b, false)
|
||||
.await;
|
||||
.try_ensure_connection_subscribed(thread_id, connection_b, false)
|
||||
.await
|
||||
.expect("connection_b should be live");
|
||||
{
|
||||
let state = manager.thread_state(thread_id);
|
||||
let state = manager.thread_state(thread_id).await;
|
||||
state.lock().await.cancel_tx = Some(cancel_tx);
|
||||
}
|
||||
|
||||
|
|
@ -8234,11 +8430,29 @@ mod tests {
|
|||
.is_err()
|
||||
);
|
||||
|
||||
let state = manager.thread_state(thread_id);
|
||||
assert_eq!(
|
||||
state.lock().await.subscribed_connection_ids(),
|
||||
manager.subscribed_connection_ids(thread_id).await,
|
||||
vec![connection_b]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn closed_connection_cannot_be_reintroduced_by_auto_subscribe() -> Result<()> {
|
||||
let manager = ThreadStateManager::new();
|
||||
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
|
||||
let connection = ConnectionId(1);
|
||||
|
||||
manager.connection_initialized(connection).await;
|
||||
manager.remove_connection(connection).await;
|
||||
|
||||
assert!(
|
||||
manager
|
||||
.try_ensure_connection_subscribed(thread_id, connection, false)
|
||||
.await
|
||||
.is_none()
|
||||
);
|
||||
assert!(!manager.has_subscribers(thread_id).await);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -338,6 +338,9 @@ impl MessageProcessor {
|
|||
|
||||
session.initialized = true;
|
||||
outbound_initialized.store(true, Ordering::Release);
|
||||
self.codex_message_processor
|
||||
.connection_initialized(connection_id)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -60,7 +60,6 @@ pub(crate) struct ThreadState {
|
|||
listener_command_tx: Option<mpsc::UnboundedSender<ThreadListenerCommand>>,
|
||||
current_turn_history: ThreadHistoryBuilder,
|
||||
listener_thread: Option<Weak<CodexThread>>,
|
||||
subscribed_connections: HashSet<ConnectionId>,
|
||||
}
|
||||
|
||||
impl ThreadState {
|
||||
|
|
@ -95,18 +94,6 @@ impl ThreadState {
|
|||
self.listener_thread = None;
|
||||
}
|
||||
|
||||
pub(crate) fn add_connection(&mut self, connection_id: ConnectionId) {
|
||||
self.subscribed_connections.insert(connection_id);
|
||||
}
|
||||
|
||||
pub(crate) fn remove_connection(&mut self, connection_id: ConnectionId) {
|
||||
self.subscribed_connections.remove(&connection_id);
|
||||
}
|
||||
|
||||
pub(crate) fn subscribed_connection_ids(&self) -> Vec<ConnectionId> {
|
||||
self.subscribed_connections.iter().copied().collect()
|
||||
}
|
||||
|
||||
pub(crate) fn set_experimental_raw_events(&mut self, enabled: bool) {
|
||||
self.experimental_raw_events = enabled;
|
||||
}
|
||||
|
|
@ -135,55 +122,112 @@ struct SubscriptionState {
|
|||
connection_id: ConnectionId,
|
||||
}
|
||||
|
||||
struct ThreadEntry {
|
||||
state: Arc<Mutex<ThreadState>>,
|
||||
connection_ids: HashSet<ConnectionId>,
|
||||
}
|
||||
|
||||
impl Default for ThreadEntry {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
state: Arc::new(Mutex::new(ThreadState::default())),
|
||||
connection_ids: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct ThreadStateManager {
|
||||
thread_states: HashMap<ThreadId, Arc<Mutex<ThreadState>>>,
|
||||
struct ThreadStateManagerInner {
|
||||
live_connections: HashSet<ConnectionId>,
|
||||
threads: HashMap<ThreadId, ThreadEntry>,
|
||||
subscription_state_by_id: HashMap<Uuid, SubscriptionState>,
|
||||
thread_ids_by_connection: HashMap<ConnectionId, HashSet<ThreadId>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub(crate) struct ThreadStateManager {
|
||||
state: Arc<Mutex<ThreadStateManagerInner>>,
|
||||
}
|
||||
|
||||
impl ThreadStateManager {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub(crate) fn thread_state(&mut self, thread_id: ThreadId) -> Arc<Mutex<ThreadState>> {
|
||||
self.thread_states
|
||||
.entry(thread_id)
|
||||
.or_insert_with(|| Arc::new(Mutex::new(ThreadState::default())))
|
||||
.clone()
|
||||
pub(crate) async fn connection_initialized(&self, connection_id: ConnectionId) {
|
||||
self.state
|
||||
.lock()
|
||||
.await
|
||||
.live_connections
|
||||
.insert(connection_id);
|
||||
}
|
||||
|
||||
pub(crate) async fn remove_listener(&mut self, subscription_id: Uuid) -> Option<ThreadId> {
|
||||
let subscription_state = self.subscription_state_by_id.remove(&subscription_id)?;
|
||||
pub(crate) async fn subscribed_connection_ids(&self, thread_id: ThreadId) -> Vec<ConnectionId> {
|
||||
let state = self.state.lock().await;
|
||||
state
|
||||
.threads
|
||||
.get(&thread_id)
|
||||
.map(|thread_entry| thread_entry.connection_ids.iter().copied().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub(crate) async fn thread_state(&self, thread_id: ThreadId) -> Arc<Mutex<ThreadState>> {
|
||||
let mut state = self.state.lock().await;
|
||||
state.threads.entry(thread_id).or_default().state.clone()
|
||||
}
|
||||
|
||||
pub(crate) async fn remove_listener(&self, subscription_id: Uuid) -> Option<ThreadId> {
|
||||
let (subscription_state, connection_still_subscribed_to_thread, thread_state) = {
|
||||
let mut state = self.state.lock().await;
|
||||
let subscription_state = state.subscription_state_by_id.remove(&subscription_id)?;
|
||||
let thread_id = subscription_state.thread_id;
|
||||
|
||||
let connection_still_subscribed_to_thread = state
|
||||
.subscription_state_by_id
|
||||
.values()
|
||||
.any(|subscription_state_entry| {
|
||||
subscription_state_entry.thread_id == thread_id
|
||||
&& subscription_state_entry.connection_id
|
||||
== subscription_state.connection_id
|
||||
});
|
||||
if !connection_still_subscribed_to_thread {
|
||||
let mut remove_connection_entry = false;
|
||||
if let Some(thread_ids) = state
|
||||
.thread_ids_by_connection
|
||||
.get_mut(&subscription_state.connection_id)
|
||||
{
|
||||
thread_ids.remove(&thread_id);
|
||||
remove_connection_entry = thread_ids.is_empty();
|
||||
}
|
||||
if remove_connection_entry {
|
||||
state
|
||||
.thread_ids_by_connection
|
||||
.remove(&subscription_state.connection_id);
|
||||
}
|
||||
if let Some(thread_entry) = state.threads.get_mut(&thread_id) {
|
||||
thread_entry
|
||||
.connection_ids
|
||||
.remove(&subscription_state.connection_id);
|
||||
}
|
||||
}
|
||||
|
||||
let thread_state = state.threads.get(&thread_id).map(|thread_entry| {
|
||||
(
|
||||
thread_entry.connection_ids.is_empty(),
|
||||
thread_entry.state.clone(),
|
||||
)
|
||||
});
|
||||
(
|
||||
subscription_state,
|
||||
connection_still_subscribed_to_thread,
|
||||
thread_state,
|
||||
)
|
||||
};
|
||||
let thread_id = subscription_state.thread_id;
|
||||
|
||||
let connection_still_subscribed_to_thread =
|
||||
self.subscription_state_by_id.values().any(|state| {
|
||||
state.thread_id == thread_id
|
||||
&& state.connection_id == subscription_state.connection_id
|
||||
});
|
||||
if !connection_still_subscribed_to_thread {
|
||||
let mut remove_connection_entry = false;
|
||||
if let Some(thread_ids) = self
|
||||
.thread_ids_by_connection
|
||||
.get_mut(&subscription_state.connection_id)
|
||||
{
|
||||
thread_ids.remove(&thread_id);
|
||||
remove_connection_entry = thread_ids.is_empty();
|
||||
}
|
||||
if remove_connection_entry {
|
||||
self.thread_ids_by_connection
|
||||
.remove(&subscription_state.connection_id);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(thread_state) = self.thread_states.get(&thread_id) {
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
if !connection_still_subscribed_to_thread {
|
||||
thread_state.remove_connection(subscription_state.connection_id);
|
||||
}
|
||||
if thread_state.subscribed_connection_ids().is_empty() {
|
||||
if let Some((no_subscribers, thread_state)) = thread_state {
|
||||
let thread_state = thread_state.lock().await;
|
||||
if !connection_still_subscribed_to_thread && no_subscribers {
|
||||
tracing::debug!(
|
||||
thread_id = %thread_id,
|
||||
subscription_id = %subscription_id,
|
||||
|
|
@ -196,8 +240,24 @@ impl ThreadStateManager {
|
|||
Some(thread_id)
|
||||
}
|
||||
|
||||
pub(crate) async fn remove_thread_state(&mut self, thread_id: ThreadId) {
|
||||
if let Some(thread_state) = self.thread_states.remove(&thread_id) {
|
||||
pub(crate) async fn remove_thread_state(&self, thread_id: ThreadId) {
|
||||
let thread_state = {
|
||||
let mut state = self.state.lock().await;
|
||||
let thread_state = state
|
||||
.threads
|
||||
.remove(&thread_id)
|
||||
.map(|thread_entry| thread_entry.state);
|
||||
state
|
||||
.subscription_state_by_id
|
||||
.retain(|_, state| state.thread_id != thread_id);
|
||||
state.thread_ids_by_connection.retain(|_, thread_ids| {
|
||||
thread_ids.remove(&thread_id);
|
||||
!thread_ids.is_empty()
|
||||
});
|
||||
thread_state
|
||||
};
|
||||
|
||||
if let Some(thread_state) = thread_state {
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
tracing::debug!(
|
||||
thread_id = %thread_id,
|
||||
|
|
@ -208,142 +268,189 @@ impl ThreadStateManager {
|
|||
);
|
||||
thread_state.clear_listener();
|
||||
}
|
||||
self.subscription_state_by_id
|
||||
.retain(|_, state| state.thread_id != thread_id);
|
||||
self.thread_ids_by_connection.retain(|_, thread_ids| {
|
||||
thread_ids.remove(&thread_id);
|
||||
!thread_ids.is_empty()
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) async fn unsubscribe_connection_from_thread(
|
||||
&mut self,
|
||||
&self,
|
||||
thread_id: ThreadId,
|
||||
connection_id: ConnectionId,
|
||||
) -> bool {
|
||||
let Some(thread_state) = self.thread_states.get(&thread_id) else {
|
||||
return false;
|
||||
{
|
||||
let mut state = self.state.lock().await;
|
||||
if !state.threads.contains_key(&thread_id) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if !state
|
||||
.thread_ids_by_connection
|
||||
.get(&connection_id)
|
||||
.is_some_and(|thread_ids| thread_ids.contains(&thread_id))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(thread_ids) = state.thread_ids_by_connection.get_mut(&connection_id) {
|
||||
thread_ids.remove(&thread_id);
|
||||
if thread_ids.is_empty() {
|
||||
state.thread_ids_by_connection.remove(&connection_id);
|
||||
}
|
||||
}
|
||||
if let Some(thread_entry) = state.threads.get_mut(&thread_id) {
|
||||
thread_entry.connection_ids.remove(&connection_id);
|
||||
}
|
||||
|
||||
state
|
||||
.subscription_state_by_id
|
||||
.retain(|_, subscription_state| {
|
||||
!(subscription_state.thread_id == thread_id
|
||||
&& subscription_state.connection_id == connection_id)
|
||||
});
|
||||
};
|
||||
|
||||
if !self
|
||||
.thread_ids_by_connection
|
||||
.get(&connection_id)
|
||||
.is_some_and(|thread_ids| thread_ids.contains(&thread_id))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(thread_ids) = self.thread_ids_by_connection.get_mut(&connection_id) {
|
||||
thread_ids.remove(&thread_id);
|
||||
if thread_ids.is_empty() {
|
||||
self.thread_ids_by_connection.remove(&connection_id);
|
||||
}
|
||||
}
|
||||
|
||||
self.subscription_state_by_id.retain(|_, state| {
|
||||
!(state.thread_id == thread_id && state.connection_id == connection_id)
|
||||
});
|
||||
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state.remove_connection(connection_id);
|
||||
true
|
||||
}
|
||||
|
||||
pub(crate) async fn has_subscribers(&self, thread_id: ThreadId) -> bool {
|
||||
let Some(thread_state) = self.thread_states.get(&thread_id) else {
|
||||
return false;
|
||||
};
|
||||
!thread_state
|
||||
self.state
|
||||
.lock()
|
||||
.await
|
||||
.subscribed_connection_ids()
|
||||
.is_empty()
|
||||
.threads
|
||||
.get(&thread_id)
|
||||
.is_some_and(|thread_entry| !thread_entry.connection_ids.is_empty())
|
||||
}
|
||||
|
||||
pub(crate) async fn set_listener(
|
||||
&mut self,
|
||||
&self,
|
||||
subscription_id: Uuid,
|
||||
thread_id: ThreadId,
|
||||
connection_id: ConnectionId,
|
||||
experimental_raw_events: bool,
|
||||
) -> Arc<Mutex<ThreadState>> {
|
||||
self.subscription_state_by_id.insert(
|
||||
subscription_id,
|
||||
SubscriptionState {
|
||||
thread_id,
|
||||
connection_id,
|
||||
},
|
||||
);
|
||||
self.thread_ids_by_connection
|
||||
.entry(connection_id)
|
||||
.or_default()
|
||||
.insert(thread_id);
|
||||
let thread_state = self.thread_state(thread_id);
|
||||
let thread_state = {
|
||||
let mut state = self.state.lock().await;
|
||||
state.subscription_state_by_id.insert(
|
||||
subscription_id,
|
||||
SubscriptionState {
|
||||
thread_id,
|
||||
connection_id,
|
||||
},
|
||||
);
|
||||
state
|
||||
.thread_ids_by_connection
|
||||
.entry(connection_id)
|
||||
.or_default()
|
||||
.insert(thread_id);
|
||||
let thread_entry = state.threads.entry(thread_id).or_default();
|
||||
thread_entry.connection_ids.insert(connection_id);
|
||||
thread_entry.state.clone()
|
||||
};
|
||||
{
|
||||
let mut thread_state_guard = thread_state.lock().await;
|
||||
thread_state_guard.add_connection(connection_id);
|
||||
thread_state_guard.set_experimental_raw_events(experimental_raw_events);
|
||||
}
|
||||
thread_state
|
||||
}
|
||||
|
||||
pub(crate) async fn ensure_connection_subscribed(
|
||||
&mut self,
|
||||
pub(crate) async fn try_ensure_connection_subscribed(
|
||||
&self,
|
||||
thread_id: ThreadId,
|
||||
connection_id: ConnectionId,
|
||||
experimental_raw_events: bool,
|
||||
) -> Arc<Mutex<ThreadState>> {
|
||||
self.thread_ids_by_connection
|
||||
.entry(connection_id)
|
||||
.or_default()
|
||||
.insert(thread_id);
|
||||
let thread_state = self.thread_state(thread_id);
|
||||
) -> Option<Arc<Mutex<ThreadState>>> {
|
||||
let thread_state = {
|
||||
let mut state = self.state.lock().await;
|
||||
if !state.live_connections.contains(&connection_id) {
|
||||
return None;
|
||||
}
|
||||
state
|
||||
.thread_ids_by_connection
|
||||
.entry(connection_id)
|
||||
.or_default()
|
||||
.insert(thread_id);
|
||||
let thread_entry = state.threads.entry(thread_id).or_default();
|
||||
thread_entry.connection_ids.insert(connection_id);
|
||||
thread_entry.state.clone()
|
||||
};
|
||||
{
|
||||
let mut thread_state_guard = thread_state.lock().await;
|
||||
thread_state_guard.add_connection(connection_id);
|
||||
if experimental_raw_events {
|
||||
thread_state_guard.set_experimental_raw_events(true);
|
||||
}
|
||||
}
|
||||
thread_state
|
||||
Some(thread_state)
|
||||
}
|
||||
|
||||
pub(crate) async fn remove_connection(&mut self, connection_id: ConnectionId) {
|
||||
let thread_ids = self
|
||||
.thread_ids_by_connection
|
||||
.remove(&connection_id)
|
||||
.unwrap_or_default();
|
||||
self.subscription_state_by_id
|
||||
.retain(|_, state| state.connection_id != connection_id);
|
||||
|
||||
if thread_ids.is_empty() {
|
||||
for thread_state in self.thread_states.values() {
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state.remove_connection(connection_id);
|
||||
if thread_state.subscribed_connection_ids().is_empty() {
|
||||
tracing::debug!(
|
||||
connection_id = ?connection_id,
|
||||
listener_generation = thread_state.listener_generation,
|
||||
"retaining thread listener after connection disconnect left zero subscribers"
|
||||
);
|
||||
}
|
||||
}
|
||||
return;
|
||||
pub(crate) async fn try_add_connection_to_thread(
|
||||
&self,
|
||||
thread_id: ThreadId,
|
||||
connection_id: ConnectionId,
|
||||
) -> bool {
|
||||
let mut state = self.state.lock().await;
|
||||
if !state.live_connections.contains(&connection_id) {
|
||||
return false;
|
||||
}
|
||||
state
|
||||
.thread_ids_by_connection
|
||||
.entry(connection_id)
|
||||
.or_default()
|
||||
.insert(thread_id);
|
||||
state
|
||||
.threads
|
||||
.entry(thread_id)
|
||||
.or_default()
|
||||
.connection_ids
|
||||
.insert(connection_id);
|
||||
true
|
||||
}
|
||||
|
||||
for thread_id in thread_ids {
|
||||
if let Some(thread_state) = self.thread_states.get(&thread_id) {
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state.remove_connection(connection_id);
|
||||
if thread_state.subscribed_connection_ids().is_empty() {
|
||||
tracing::debug!(
|
||||
thread_id = %thread_id,
|
||||
connection_id = ?connection_id,
|
||||
listener_generation = thread_state.listener_generation,
|
||||
"retaining thread listener after connection disconnect left zero subscribers"
|
||||
);
|
||||
pub(crate) async fn remove_connection(&self, connection_id: ConnectionId) {
|
||||
let thread_states = {
|
||||
let mut state = self.state.lock().await;
|
||||
state.live_connections.remove(&connection_id);
|
||||
let thread_ids = state
|
||||
.thread_ids_by_connection
|
||||
.remove(&connection_id)
|
||||
.unwrap_or_default();
|
||||
state
|
||||
.subscription_state_by_id
|
||||
.retain(|_, state| state.connection_id != connection_id);
|
||||
for thread_id in &thread_ids {
|
||||
if let Some(thread_entry) = state.threads.get_mut(thread_id) {
|
||||
thread_entry.connection_ids.remove(&connection_id);
|
||||
}
|
||||
}
|
||||
thread_ids
|
||||
.into_iter()
|
||||
.map(|thread_id| {
|
||||
(
|
||||
thread_id,
|
||||
state
|
||||
.threads
|
||||
.get(&thread_id)
|
||||
.is_none_or(|thread_entry| thread_entry.connection_ids.is_empty()),
|
||||
state
|
||||
.threads
|
||||
.get(&thread_id)
|
||||
.map(|thread_entry| thread_entry.state.clone()),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
for (thread_id, no_subscribers, thread_state) in thread_states {
|
||||
if !no_subscribers {
|
||||
continue;
|
||||
}
|
||||
let Some(thread_state) = thread_state else {
|
||||
continue;
|
||||
};
|
||||
let listener_generation = thread_state.lock().await.listener_generation;
|
||||
tracing::debug!(
|
||||
thread_id = %thread_id,
|
||||
connection_id = ?connection_id,
|
||||
listener_generation,
|
||||
"retaining thread listener after connection disconnect left zero subscribers"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue