app-server: thread resume subscriptions (#11474)
This stack layer makes app-server thread event delivery connection-aware so resumed/attached threads only emit notifications and approval prompts to subscribed connections. - Added per-thread subscription tracking in `ThreadState` (`subscribed_connections`) and mapped subscription ids to `(thread_id, connection_id)`. - Updated listener lifecycle so removing a subscription or closing a connection only removes that connection from the thread’s subscriber set; listener shutdown now happens when the last subscriber is gone. - Added `connection_closed(connection_id)` plumbing (`lib.rs` -> `message_processor.rs` -> `codex_message_processor.rs`) so disconnect cleanup happens immediately. - Scoped bespoke event handling outputs through `TargetedOutgoing` to send requests/notifications only to subscribed connections. - Kept existing threadresume behavior while aligning with the latest split-loop transport structure.
This commit is contained in:
parent
703fb38d2a
commit
c0ecc2e1e1
7 changed files with 648 additions and 150 deletions
|
|
@ -4,7 +4,7 @@ use crate::codex_message_processor::read_summary_from_rollout;
|
|||
use crate::codex_message_processor::summary_to_thread;
|
||||
use crate::error_code::INTERNAL_ERROR_CODE;
|
||||
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::outgoing_message::ThreadScopedOutgoingMessageSender;
|
||||
use crate::thread_state::ThreadState;
|
||||
use crate::thread_state::TurnSummary;
|
||||
use codex_app_server_protocol::AccountRateLimitsUpdatedNotification;
|
||||
|
|
@ -107,7 +107,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
|||
event: Event,
|
||||
conversation_id: ThreadId,
|
||||
conversation: Arc<CodexThread>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
outgoing: ThreadScopedOutgoingMessageSender,
|
||||
thread_state: Arc<tokio::sync::Mutex<ThreadState>>,
|
||||
api_version: ApiVersion,
|
||||
fallback_model_provider: String,
|
||||
|
|
@ -850,7 +850,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
|||
conversation_id,
|
||||
&event_turn_id,
|
||||
raw_response_item_event.item,
|
||||
outgoing.as_ref(),
|
||||
&outgoing,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
|
@ -899,7 +899,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
|||
changes,
|
||||
status,
|
||||
event_turn_id.clone(),
|
||||
outgoing.as_ref(),
|
||||
&outgoing,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
|
|
@ -1142,7 +1142,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
|||
&event_turn_id,
|
||||
turn_diff_event,
|
||||
api_version,
|
||||
outgoing.as_ref(),
|
||||
&outgoing,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
|
@ -1152,7 +1152,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
|||
&event_turn_id,
|
||||
plan_update_event,
|
||||
api_version,
|
||||
outgoing.as_ref(),
|
||||
&outgoing,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
|
@ -1166,7 +1166,7 @@ async fn handle_turn_diff(
|
|||
event_turn_id: &str,
|
||||
turn_diff_event: TurnDiffEvent,
|
||||
api_version: ApiVersion,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
) {
|
||||
if let ApiVersion::V2 = api_version {
|
||||
let notification = TurnDiffUpdatedNotification {
|
||||
|
|
@ -1185,7 +1185,7 @@ async fn handle_turn_plan_update(
|
|||
event_turn_id: &str,
|
||||
plan_update_event: UpdatePlanArgs,
|
||||
api_version: ApiVersion,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
) {
|
||||
// `update_plan` is a todo/checklist tool; it is not related to plan-mode updates
|
||||
if let ApiVersion::V2 = api_version {
|
||||
|
|
@ -1210,7 +1210,7 @@ async fn emit_turn_completed_with_status(
|
|||
event_turn_id: String,
|
||||
status: TurnStatus,
|
||||
error: Option<TurnError>,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
) {
|
||||
let notification = TurnCompletedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
|
|
@ -1232,7 +1232,7 @@ async fn complete_file_change_item(
|
|||
changes: Vec<FileUpdateChange>,
|
||||
status: PatchApplyStatus,
|
||||
turn_id: String,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
let mut state = thread_state.lock().await;
|
||||
|
|
@ -1264,7 +1264,7 @@ async fn complete_command_execution_item(
|
|||
process_id: Option<String>,
|
||||
command_actions: Vec<V2ParsedCommand>,
|
||||
status: CommandExecutionStatus,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
) {
|
||||
let item = ThreadItem::CommandExecution {
|
||||
id: item_id,
|
||||
|
|
@ -1292,7 +1292,7 @@ async fn maybe_emit_raw_response_item_completed(
|
|||
conversation_id: ThreadId,
|
||||
turn_id: &str,
|
||||
item: codex_protocol::models::ResponseItem,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
) {
|
||||
let ApiVersion::V2 = api_version else {
|
||||
return;
|
||||
|
|
@ -1319,7 +1319,7 @@ async fn find_and_remove_turn_summary(
|
|||
async fn handle_turn_complete(
|
||||
conversation_id: ThreadId,
|
||||
event_turn_id: String,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
let turn_summary = find_and_remove_turn_summary(conversation_id, thread_state).await;
|
||||
|
|
@ -1335,7 +1335,7 @@ async fn handle_turn_complete(
|
|||
async fn handle_turn_interrupted(
|
||||
conversation_id: ThreadId,
|
||||
event_turn_id: String,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
find_and_remove_turn_summary(conversation_id, thread_state).await;
|
||||
|
|
@ -1354,7 +1354,7 @@ async fn handle_thread_rollback_failed(
|
|||
_conversation_id: ThreadId,
|
||||
message: String,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
) {
|
||||
let pending_rollback = thread_state.lock().await.pending_rollbacks.take();
|
||||
|
||||
|
|
@ -1376,7 +1376,7 @@ async fn handle_token_count_event(
|
|||
conversation_id: ThreadId,
|
||||
turn_id: String,
|
||||
token_count_event: TokenCountEvent,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
) {
|
||||
let TokenCountEvent { info, rate_limits } = token_count_event;
|
||||
if let Some(token_usage) = info.map(ThreadTokenUsage::from) {
|
||||
|
|
@ -1633,7 +1633,7 @@ async fn on_file_change_request_approval_response(
|
|||
changes: Vec<FileUpdateChange>,
|
||||
receiver: oneshot::Receiver<JsonValue>,
|
||||
codex: Arc<CodexThread>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
outgoing: ThreadScopedOutgoingMessageSender,
|
||||
thread_state: Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
let response = receiver.await;
|
||||
|
|
@ -1666,7 +1666,7 @@ async fn on_file_change_request_approval_response(
|
|||
changes,
|
||||
status,
|
||||
event_turn_id.clone(),
|
||||
outgoing.as_ref(),
|
||||
&outgoing,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
|
|
@ -1693,7 +1693,7 @@ async fn on_command_execution_request_approval_response(
|
|||
command_actions: Vec<V2ParsedCommand>,
|
||||
receiver: oneshot::Receiver<JsonValue>,
|
||||
conversation: Arc<CodexThread>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
outgoing: ThreadScopedOutgoingMessageSender,
|
||||
) {
|
||||
let response = receiver.await;
|
||||
let (decision, completion_status) = match response {
|
||||
|
|
@ -1748,7 +1748,7 @@ async fn on_command_execution_request_approval_response(
|
|||
None,
|
||||
command_actions.clone(),
|
||||
status,
|
||||
outgoing.as_ref(),
|
||||
&outgoing,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
|
@ -1876,6 +1876,7 @@ async fn construct_mcp_tool_call_end_notification(
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::CHANNEL_CAPACITY;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::OutgoingEnvelope;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
|
@ -1914,9 +1915,7 @@ mod tests {
|
|||
.ok_or_else(|| anyhow!("should send one message"))?;
|
||||
match envelope {
|
||||
OutgoingEnvelope::Broadcast { message } => Ok(message),
|
||||
OutgoingEnvelope::ToConnection { connection_id, .. } => {
|
||||
bail!("unexpected targeted message for connection {connection_id:?}")
|
||||
}
|
||||
OutgoingEnvelope::ToConnection { message, .. } => Ok(message),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2011,6 +2010,7 @@ mod tests {
|
|||
let event_turn_id = "complete1".to_string();
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
|
||||
let thread_state = new_thread_state();
|
||||
|
||||
handle_turn_complete(
|
||||
|
|
@ -2051,6 +2051,7 @@ mod tests {
|
|||
.await;
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
|
||||
|
||||
handle_turn_interrupted(
|
||||
conversation_id,
|
||||
|
|
@ -2090,6 +2091,7 @@ mod tests {
|
|||
.await;
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
|
||||
|
||||
handle_turn_complete(
|
||||
conversation_id,
|
||||
|
|
@ -2122,7 +2124,8 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn test_handle_turn_plan_update_emits_notification_for_v2() -> Result<()> {
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = OutgoingMessageSender::new(tx);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
|
||||
let update = UpdatePlanArgs {
|
||||
explanation: Some("need plan".to_string()),
|
||||
plan: vec![
|
||||
|
|
@ -2172,6 +2175,7 @@ mod tests {
|
|||
let turn_id = "turn-123".to_string();
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
|
||||
|
||||
let info = TokenUsageInfo {
|
||||
total_token_usage: TokenUsage {
|
||||
|
|
@ -2255,6 +2259,7 @@ mod tests {
|
|||
let turn_id = "turn-456".to_string();
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
|
||||
|
||||
handle_token_count_event(
|
||||
conversation_id,
|
||||
|
|
@ -2321,6 +2326,7 @@ mod tests {
|
|||
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
|
||||
|
||||
// Turn 1 on conversation A
|
||||
let a_turn1 = "a_turn1".to_string();
|
||||
|
|
@ -2542,7 +2548,8 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn test_handle_turn_diff_emits_v2_notification() -> Result<()> {
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = OutgoingMessageSender::new(tx);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
|
||||
let unified_diff = "--- a\n+++ b\n".to_string();
|
||||
let conversation_id = ThreadId::new();
|
||||
|
||||
|
|
@ -2575,7 +2582,8 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn test_handle_turn_diff_is_noop_for_v1() -> Result<()> {
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = OutgoingMessageSender::new(tx);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
|
||||
let conversation_id = ThreadId::new();
|
||||
|
||||
handle_turn_diff(
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ use crate::outgoing_message::ConnectionId;
|
|||
use crate::outgoing_message::ConnectionRequestId;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::outgoing_message::OutgoingNotification;
|
||||
use crate::outgoing_message::ThreadScopedOutgoingMessageSender;
|
||||
use chrono::DateTime;
|
||||
use chrono::SecondsFormat;
|
||||
use chrono::Utc;
|
||||
|
|
@ -251,6 +252,7 @@ use uuid::Uuid;
|
|||
|
||||
use crate::filters::compute_source_filters;
|
||||
use crate::filters::source_kind_matches;
|
||||
use crate::thread_state::ThreadState;
|
||||
use crate::thread_state::ThreadStateManager;
|
||||
|
||||
const THREAD_LIST_DEFAULT_LIMIT: usize = 25;
|
||||
|
|
@ -1961,8 +1963,9 @@ impl CodexMessageProcessor {
|
|||
// Auto-attach a thread listener when starting a thread.
|
||||
// Use the same behavior as the v1 API, with opt-in support for raw item events.
|
||||
if let Err(err) = self
|
||||
.attach_conversation_listener(
|
||||
.ensure_conversation_listener(
|
||||
thread_id,
|
||||
request_id.connection_id,
|
||||
experimental_raw_events,
|
||||
ApiVersion::V2,
|
||||
)
|
||||
|
|
@ -2628,20 +2631,28 @@ impl CodexMessageProcessor {
|
|||
self.thread_manager.subscribe_thread_created()
|
||||
}
|
||||
|
||||
/// Best-effort: attach a listener for thread_id if missing.
|
||||
pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) {
|
||||
if self.thread_state_manager.has_listener_for_thread(thread_id) {
|
||||
return;
|
||||
}
|
||||
pub(crate) async fn connection_closed(&mut self, connection_id: ConnectionId) {
|
||||
self.thread_state_manager
|
||||
.remove_connection(connection_id)
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Err(err) = self
|
||||
.attach_conversation_listener(thread_id, false, ApiVersion::V2)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
"failed to attach listener for thread {thread_id}: {message}",
|
||||
message = err.message
|
||||
);
|
||||
/// Best-effort: ensure initialized connections are subscribed to this thread.
|
||||
pub(crate) async fn try_attach_thread_listener(
|
||||
&mut self,
|
||||
thread_id: ThreadId,
|
||||
connection_ids: Vec<ConnectionId>,
|
||||
) {
|
||||
for connection_id in connection_ids {
|
||||
if let Err(err) = self
|
||||
.ensure_conversation_listener(thread_id, connection_id, false, ApiVersion::V2)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
"failed to auto-attach listener for thread {thread_id}: {message}",
|
||||
message = err.message
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2793,7 +2804,12 @@ impl CodexMessageProcessor {
|
|||
};
|
||||
// Auto-attach a thread listener when resuming a thread.
|
||||
if let Err(err) = self
|
||||
.attach_conversation_listener(thread_id, false, ApiVersion::V2)
|
||||
.ensure_conversation_listener(
|
||||
thread_id,
|
||||
request_id.connection_id,
|
||||
false,
|
||||
ApiVersion::V2,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
|
|
@ -3019,7 +3035,12 @@ impl CodexMessageProcessor {
|
|||
};
|
||||
// Auto-attach a conversation listener when forking a thread.
|
||||
if let Err(err) = self
|
||||
.attach_conversation_listener(thread_id, false, ApiVersion::V2)
|
||||
.ensure_conversation_listener(
|
||||
thread_id,
|
||||
request_id.connection_id,
|
||||
false,
|
||||
ApiVersion::V2,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
|
|
@ -5136,7 +5157,12 @@ impl CodexMessageProcessor {
|
|||
})?;
|
||||
|
||||
if let Err(err) = self
|
||||
.attach_conversation_listener(thread_id, false, ApiVersion::V2)
|
||||
.ensure_conversation_listener(
|
||||
thread_id,
|
||||
request_id.connection_id,
|
||||
false,
|
||||
ApiVersion::V2,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
|
|
@ -5281,18 +5307,38 @@ impl CodexMessageProcessor {
|
|||
conversation_id,
|
||||
experimental_raw_events,
|
||||
} = params;
|
||||
match self
|
||||
.attach_conversation_listener(conversation_id, experimental_raw_events, ApiVersion::V1)
|
||||
.await
|
||||
{
|
||||
Ok(subscription_id) => {
|
||||
let response = AddConversationSubscriptionResponse { subscription_id };
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
let conversation = match self.thread_manager.get_thread(conversation_id).await {
|
||||
Ok(conv) => conv,
|
||||
Err(_) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!("thread not found: {conversation_id}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
Err(err) => {
|
||||
self.outgoing.send_error(request_id, err).await;
|
||||
}
|
||||
}
|
||||
};
|
||||
let subscription_id = Uuid::new_v4();
|
||||
let thread_state = self
|
||||
.thread_state_manager
|
||||
.set_listener(
|
||||
subscription_id,
|
||||
conversation_id,
|
||||
request_id.connection_id,
|
||||
experimental_raw_events,
|
||||
)
|
||||
.await;
|
||||
self.ensure_listener_task_running(
|
||||
conversation_id,
|
||||
conversation,
|
||||
thread_state,
|
||||
ApiVersion::V1,
|
||||
)
|
||||
.await;
|
||||
|
||||
let response = AddConversationSubscriptionResponse { subscription_id };
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
|
||||
async fn remove_thread_listener(
|
||||
|
|
@ -5322,12 +5368,13 @@ impl CodexMessageProcessor {
|
|||
}
|
||||
}
|
||||
|
||||
async fn attach_conversation_listener(
|
||||
async fn ensure_conversation_listener(
|
||||
&mut self,
|
||||
conversation_id: ThreadId,
|
||||
connection_id: ConnectionId,
|
||||
raw_events_enabled: bool,
|
||||
api_version: ApiVersion,
|
||||
) -> Result<Uuid, JSONRPCErrorError> {
|
||||
) -> Result<(), JSONRPCErrorError> {
|
||||
let conversation = match self.thread_manager.get_thread(conversation_id).await {
|
||||
Ok(conv) => conv,
|
||||
Err(_) => {
|
||||
|
|
@ -5338,13 +5385,30 @@ impl CodexMessageProcessor {
|
|||
});
|
||||
}
|
||||
};
|
||||
|
||||
let subscription_id = Uuid::new_v4();
|
||||
let (cancel_tx, mut cancel_rx) = oneshot::channel();
|
||||
let thread_state = self
|
||||
.thread_state_manager
|
||||
.set_listener(subscription_id, conversation_id, cancel_tx)
|
||||
.ensure_connection_subscribed(conversation_id, connection_id, raw_events_enabled)
|
||||
.await;
|
||||
self.ensure_listener_task_running(conversation_id, conversation, thread_state, api_version)
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ensure_listener_task_running(
|
||||
&self,
|
||||
conversation_id: ThreadId,
|
||||
conversation: Arc<CodexThread>,
|
||||
thread_state: Arc<Mutex<ThreadState>>,
|
||||
api_version: ApiVersion,
|
||||
) {
|
||||
let (cancel_tx, mut cancel_rx) = oneshot::channel();
|
||||
{
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
if thread_state.listener_matches(&conversation) {
|
||||
return;
|
||||
}
|
||||
thread_state.set_listener(cancel_tx, &conversation);
|
||||
}
|
||||
let outgoing_for_task = self.outgoing.clone();
|
||||
let fallback_model_provider = self.config.model_provider_id.clone();
|
||||
tokio::spawn(async move {
|
||||
|
|
@ -5363,10 +5427,6 @@ impl CodexMessageProcessor {
|
|||
}
|
||||
};
|
||||
|
||||
if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
// For now, we send a notification for every event,
|
||||
// JSON-serializing the `Event` as-is, but these should
|
||||
// be migrated to be variants of `ServerNotification`
|
||||
|
|
@ -5391,19 +5451,38 @@ impl CodexMessageProcessor {
|
|||
"conversationId".to_string(),
|
||||
conversation_id.to_string().into(),
|
||||
);
|
||||
let (subscribed_connection_ids, raw_events_enabled) = {
|
||||
let thread_state = thread_state.lock().await;
|
||||
(
|
||||
thread_state.subscribed_connection_ids(),
|
||||
thread_state.experimental_raw_events,
|
||||
)
|
||||
};
|
||||
if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
outgoing_for_task
|
||||
.send_notification(OutgoingNotification {
|
||||
method: format!("codex/event/{event_formatted}"),
|
||||
params: Some(params.into()),
|
||||
})
|
||||
.await;
|
||||
if !subscribed_connection_ids.is_empty() {
|
||||
outgoing_for_task
|
||||
.send_notification_to_connections(
|
||||
&subscribed_connection_ids,
|
||||
OutgoingNotification {
|
||||
method: format!("codex/event/{event_formatted}"),
|
||||
params: Some(params.into()),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let thread_outgoing = ThreadScopedOutgoingMessageSender::new(
|
||||
outgoing_for_task.clone(),
|
||||
subscribed_connection_ids,
|
||||
);
|
||||
apply_bespoke_event_handling(
|
||||
event.clone(),
|
||||
conversation_id,
|
||||
conversation.clone(),
|
||||
outgoing_for_task.clone(),
|
||||
thread_outgoing,
|
||||
thread_state.clone(),
|
||||
api_version,
|
||||
fallback_model_provider.clone(),
|
||||
|
|
@ -5413,9 +5492,7 @@ impl CodexMessageProcessor {
|
|||
}
|
||||
}
|
||||
});
|
||||
Ok(subscription_id)
|
||||
}
|
||||
|
||||
async fn git_diff_to_origin(&self, request_id: ConnectionRequestId, cwd: PathBuf) {
|
||||
let diff = git_diff_to_remote(&cwd).await;
|
||||
match diff {
|
||||
|
|
@ -6299,22 +6376,137 @@ mod tests {
|
|||
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
|
||||
let listener_a = Uuid::new_v4();
|
||||
let listener_b = Uuid::new_v4();
|
||||
let (cancel_a, cancel_rx_a) = oneshot::channel();
|
||||
let (cancel_b, mut cancel_rx_b) = oneshot::channel();
|
||||
let connection_a = ConnectionId(1);
|
||||
let connection_b = ConnectionId(2);
|
||||
let (cancel_tx, mut cancel_rx) = oneshot::channel();
|
||||
|
||||
manager.set_listener(listener_a, thread_id, cancel_a).await;
|
||||
manager.set_listener(listener_b, thread_id, cancel_b).await;
|
||||
manager
|
||||
.set_listener(listener_a, thread_id, connection_a, false)
|
||||
.await;
|
||||
manager
|
||||
.set_listener(listener_b, thread_id, connection_b, false)
|
||||
.await;
|
||||
{
|
||||
let state = manager.thread_state(thread_id);
|
||||
state.lock().await.cancel_tx = Some(cancel_tx);
|
||||
}
|
||||
|
||||
assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id));
|
||||
assert_eq!(cancel_rx_a.await, Ok(()));
|
||||
assert!(
|
||||
tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx_b)
|
||||
tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx)
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
assert_eq!(manager.remove_listener(listener_b).await, Some(thread_id));
|
||||
assert_eq!(cancel_rx.await, Ok(()));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn removing_listener_unsubscribes_its_connection() -> Result<()> {
|
||||
let mut manager = ThreadStateManager::new();
|
||||
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
|
||||
let listener_a = Uuid::new_v4();
|
||||
let listener_b = Uuid::new_v4();
|
||||
let connection_a = ConnectionId(1);
|
||||
let connection_b = ConnectionId(2);
|
||||
|
||||
manager
|
||||
.set_listener(listener_a, thread_id, connection_a, false)
|
||||
.await;
|
||||
manager
|
||||
.set_listener(listener_b, thread_id, connection_b, false)
|
||||
.await;
|
||||
|
||||
assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id));
|
||||
let state = manager.thread_state(thread_id);
|
||||
let subscribed_connection_ids = state.lock().await.subscribed_connection_ids();
|
||||
assert_eq!(subscribed_connection_ids, vec![connection_b]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn set_listener_uses_last_write_for_raw_events() -> Result<()> {
|
||||
let mut manager = ThreadStateManager::new();
|
||||
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
|
||||
let listener_a = Uuid::new_v4();
|
||||
let listener_b = Uuid::new_v4();
|
||||
let connection_a = ConnectionId(1);
|
||||
let connection_b = ConnectionId(2);
|
||||
|
||||
manager
|
||||
.set_listener(listener_a, thread_id, connection_a, true)
|
||||
.await;
|
||||
{
|
||||
let state = manager.thread_state(thread_id);
|
||||
assert!(state.lock().await.experimental_raw_events);
|
||||
}
|
||||
manager
|
||||
.set_listener(listener_b, thread_id, connection_b, false)
|
||||
.await;
|
||||
let state = manager.thread_state(thread_id);
|
||||
assert!(!state.lock().await.experimental_raw_events);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn removing_connection_clears_subscription_and_listener_when_last_subscriber()
|
||||
-> Result<()> {
|
||||
let mut manager = ThreadStateManager::new();
|
||||
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
|
||||
let listener = Uuid::new_v4();
|
||||
let connection = ConnectionId(1);
|
||||
let (cancel_tx, cancel_rx) = oneshot::channel();
|
||||
|
||||
manager
|
||||
.set_listener(listener, thread_id, connection, false)
|
||||
.await;
|
||||
{
|
||||
let state = manager.thread_state(thread_id);
|
||||
state.lock().await.cancel_tx = Some(cancel_tx);
|
||||
}
|
||||
|
||||
manager.remove_connection(connection).await;
|
||||
assert_eq!(cancel_rx.await, Ok(()));
|
||||
assert_eq!(manager.remove_listener(listener).await, None);
|
||||
|
||||
let state = manager.thread_state(thread_id);
|
||||
assert!(state.lock().await.subscribed_connection_ids().is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn removing_auto_attached_connection_preserves_listener_for_other_connections()
|
||||
-> Result<()> {
|
||||
let mut manager = ThreadStateManager::new();
|
||||
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
|
||||
let connection_a = ConnectionId(1);
|
||||
let connection_b = ConnectionId(2);
|
||||
let (cancel_tx, mut cancel_rx) = oneshot::channel();
|
||||
|
||||
manager
|
||||
.ensure_connection_subscribed(thread_id, connection_a, false)
|
||||
.await;
|
||||
manager
|
||||
.ensure_connection_subscribed(thread_id, connection_b, false)
|
||||
.await;
|
||||
{
|
||||
let state = manager.thread_state(thread_id);
|
||||
state.lock().await.cancel_tx = Some(cancel_tx);
|
||||
}
|
||||
|
||||
manager.remove_connection(connection_a).await;
|
||||
assert!(
|
||||
tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx)
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
assert_eq!(manager.remove_listener(listener_b).await, Some(thread_id));
|
||||
assert_eq!(cancel_rx_b.await, Ok(()));
|
||||
let state = manager.thread_state(thread_id);
|
||||
assert_eq!(
|
||||
state.lock().await.subscribed_connection_ids(),
|
||||
vec![connection_b]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,7 +27,6 @@ use crate::transport::CHANNEL_CAPACITY;
|
|||
use crate::transport::ConnectionState;
|
||||
use crate::transport::OutboundConnectionState;
|
||||
use crate::transport::TransportEvent;
|
||||
use crate::transport::has_initialized_connections;
|
||||
use crate::transport::route_outgoing_envelope;
|
||||
use crate::transport::start_stdio_connection;
|
||||
use crate::transport::start_websocket_acceptor;
|
||||
|
|
@ -490,6 +489,7 @@ pub async fn run_main_with_transport(
|
|||
{
|
||||
break;
|
||||
}
|
||||
processor.connection_closed(connection_id).await;
|
||||
connections.remove(&connection_id);
|
||||
if shutdown_when_no_connections && connections.is_empty() {
|
||||
break;
|
||||
|
|
@ -544,8 +544,19 @@ pub async fn run_main_with_transport(
|
|||
created = thread_created_rx.recv(), if listen_for_threads => {
|
||||
match created {
|
||||
Ok(thread_id) => {
|
||||
if has_initialized_connections(&connections) {
|
||||
processor.try_attach_thread_listener(thread_id).await;
|
||||
let initialized_connection_ids: Vec<ConnectionId> = connections
|
||||
.iter()
|
||||
.filter_map(|(connection_id, connection_state)| {
|
||||
connection_state.session.initialized.then_some(*connection_id)
|
||||
})
|
||||
.collect();
|
||||
if !initialized_connection_ids.is_empty() {
|
||||
processor
|
||||
.try_attach_thread_listener(
|
||||
thread_id,
|
||||
initialized_connection_ids,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
|
||||
|
|
|
|||
|
|
@ -396,9 +396,19 @@ impl MessageProcessor {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) {
|
||||
pub(crate) async fn try_attach_thread_listener(
|
||||
&mut self,
|
||||
thread_id: ThreadId,
|
||||
connection_ids: Vec<ConnectionId>,
|
||||
) {
|
||||
self.codex_message_processor
|
||||
.try_attach_thread_listener(thread_id)
|
||||
.try_attach_thread_listener(thread_id, connection_ids)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn connection_closed(&mut self, connection_id: ConnectionId) {
|
||||
self.codex_message_processor
|
||||
.connection_closed(connection_id)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
|
|
@ -48,6 +49,62 @@ pub(crate) struct OutgoingMessageSender {
|
|||
request_id_to_callback: Mutex<HashMap<RequestId, oneshot::Sender<Result>>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ThreadScopedOutgoingMessageSender {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
connection_ids: Arc<Vec<ConnectionId>>,
|
||||
}
|
||||
|
||||
impl ThreadScopedOutgoingMessageSender {
|
||||
pub(crate) fn new(
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
connection_ids: Vec<ConnectionId>,
|
||||
) -> Self {
|
||||
Self {
|
||||
outgoing,
|
||||
connection_ids: Arc::new(connection_ids),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn send_request(
|
||||
&self,
|
||||
payload: ServerRequestPayload,
|
||||
) -> oneshot::Receiver<Result> {
|
||||
if self.connection_ids.is_empty() {
|
||||
let (_tx, rx) = oneshot::channel();
|
||||
return rx;
|
||||
}
|
||||
self.outgoing
|
||||
.send_request_to_connections(self.connection_ids.as_slice(), payload)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn send_server_notification(&self, notification: ServerNotification) {
|
||||
if self.connection_ids.is_empty() {
|
||||
return;
|
||||
}
|
||||
self.outgoing
|
||||
.send_server_notification_to_connections(self.connection_ids.as_slice(), notification)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn send_response<T: Serialize>(
|
||||
&self,
|
||||
request_id: ConnectionRequestId,
|
||||
response: T,
|
||||
) {
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn send_error(
|
||||
&self,
|
||||
request_id: ConnectionRequestId,
|
||||
error: JSONRPCErrorError,
|
||||
) {
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
}
|
||||
}
|
||||
|
||||
impl OutgoingMessageSender {
|
||||
pub(crate) fn new(sender: mpsc::Sender<OutgoingEnvelope>) -> Self {
|
||||
Self {
|
||||
|
|
@ -57,17 +114,28 @@ impl OutgoingMessageSender {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn send_request(
|
||||
pub(crate) async fn send_request_to_connections(
|
||||
&self,
|
||||
connection_ids: &[ConnectionId],
|
||||
request: ServerRequestPayload,
|
||||
) -> oneshot::Receiver<Result> {
|
||||
let (_id, rx) = self.send_request_with_id(request).await;
|
||||
let (_id, rx) = self
|
||||
.send_request_with_id_to_connections(connection_ids, request)
|
||||
.await;
|
||||
rx
|
||||
}
|
||||
|
||||
pub(crate) async fn send_request_with_id(
|
||||
&self,
|
||||
request: ServerRequestPayload,
|
||||
) -> (RequestId, oneshot::Receiver<Result>) {
|
||||
self.send_request_with_id_to_connections(&[], request).await
|
||||
}
|
||||
|
||||
async fn send_request_with_id_to_connections(
|
||||
&self,
|
||||
connection_ids: &[ConnectionId],
|
||||
request: ServerRequestPayload,
|
||||
) -> (RequestId, oneshot::Receiver<Result>) {
|
||||
let id = RequestId::Integer(self.next_server_request_id.fetch_add(1, Ordering::Relaxed));
|
||||
let outgoing_message_id = id.clone();
|
||||
|
|
@ -79,13 +147,34 @@ impl OutgoingMessageSender {
|
|||
|
||||
let outgoing_message =
|
||||
OutgoingMessage::Request(request.request_with_id(outgoing_message_id.clone()));
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
{
|
||||
let send_result = if connection_ids.is_empty() {
|
||||
self.sender
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
} else {
|
||||
let mut send_error = None;
|
||||
for connection_id in connection_ids {
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id: *connection_id,
|
||||
message: outgoing_message.clone(),
|
||||
})
|
||||
.await
|
||||
{
|
||||
send_error = Some(err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
match send_error {
|
||||
Some(err) => Err(err),
|
||||
None => Ok(()),
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = send_result {
|
||||
warn!("failed to send request {outgoing_message_id:?} to client: {err:?}");
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
request_id_to_callback.remove(&outgoing_message_id);
|
||||
|
|
@ -172,29 +261,71 @@ impl OutgoingMessageSender {
|
|||
}
|
||||
|
||||
pub(crate) async fn send_server_notification(&self, notification: ServerNotification) {
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
message: OutgoingMessage::AppServerNotification(notification),
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send server notification to client: {err:?}");
|
||||
self.send_server_notification_to_connections(&[], notification)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn send_server_notification_to_connections(
|
||||
&self,
|
||||
connection_ids: &[ConnectionId],
|
||||
notification: ServerNotification,
|
||||
) {
|
||||
let outgoing_message = OutgoingMessage::AppServerNotification(notification);
|
||||
if connection_ids.is_empty() {
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send server notification to client: {err:?}");
|
||||
}
|
||||
return;
|
||||
}
|
||||
for connection_id in connection_ids {
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id: *connection_id,
|
||||
message: outgoing_message.clone(),
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send server notification to client: {err:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// All notifications should be migrated to [`ServerNotification`] and
|
||||
/// [`OutgoingMessage::Notification`] should be removed.
|
||||
pub(crate) async fn send_notification(&self, notification: OutgoingNotification) {
|
||||
pub(crate) async fn send_notification_to_connections(
|
||||
&self,
|
||||
connection_ids: &[ConnectionId],
|
||||
notification: OutgoingNotification,
|
||||
) {
|
||||
let outgoing_message = OutgoingMessage::Notification(notification);
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send notification to client: {err:?}");
|
||||
if connection_ids.is_empty() {
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send notification to client: {err:?}");
|
||||
}
|
||||
return;
|
||||
}
|
||||
for connection_id in connection_ids {
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id: *connection_id,
|
||||
message: outgoing_message.clone(),
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send notification to client: {err:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::ConnectionRequestId;
|
||||
use codex_app_server_protocol::TurnError;
|
||||
use codex_core::CodexThread;
|
||||
use codex_protocol::ThreadId;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Weak;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::oneshot;
|
||||
use uuid::Uuid;
|
||||
|
|
@ -25,33 +28,66 @@ pub(crate) struct ThreadState {
|
|||
pub(crate) pending_interrupts: PendingInterruptQueue,
|
||||
pub(crate) pending_rollbacks: Option<ConnectionRequestId>,
|
||||
pub(crate) turn_summary: TurnSummary,
|
||||
pub(crate) listener_cancel_txs: HashMap<Uuid, oneshot::Sender<()>>,
|
||||
pub(crate) cancel_tx: Option<oneshot::Sender<()>>,
|
||||
pub(crate) experimental_raw_events: bool,
|
||||
listener_thread: Option<Weak<CodexThread>>,
|
||||
subscribed_connections: HashSet<ConnectionId>,
|
||||
}
|
||||
|
||||
impl ThreadState {
|
||||
fn set_listener(&mut self, subscription_id: Uuid, cancel_tx: oneshot::Sender<()>) {
|
||||
if let Some(previous) = self.listener_cancel_txs.insert(subscription_id, cancel_tx) {
|
||||
pub(crate) fn listener_matches(&self, conversation: &Arc<CodexThread>) -> bool {
|
||||
self.listener_thread
|
||||
.as_ref()
|
||||
.and_then(Weak::upgrade)
|
||||
.is_some_and(|existing| Arc::ptr_eq(&existing, conversation))
|
||||
}
|
||||
|
||||
pub(crate) fn set_listener(
|
||||
&mut self,
|
||||
cancel_tx: oneshot::Sender<()>,
|
||||
conversation: &Arc<CodexThread>,
|
||||
) {
|
||||
if let Some(previous) = self.cancel_tx.replace(cancel_tx) {
|
||||
let _ = previous.send(());
|
||||
}
|
||||
self.listener_thread = Some(Arc::downgrade(conversation));
|
||||
}
|
||||
|
||||
fn clear_listener(&mut self, subscription_id: Uuid) {
|
||||
if let Some(cancel_tx) = self.listener_cancel_txs.remove(&subscription_id) {
|
||||
pub(crate) fn clear_listener(&mut self) {
|
||||
if let Some(cancel_tx) = self.cancel_tx.take() {
|
||||
let _ = cancel_tx.send(());
|
||||
}
|
||||
self.listener_thread = None;
|
||||
}
|
||||
|
||||
fn clear_listeners(&mut self) {
|
||||
for (_, cancel_tx) in self.listener_cancel_txs.drain() {
|
||||
let _ = cancel_tx.send(());
|
||||
}
|
||||
pub(crate) fn add_connection(&mut self, connection_id: ConnectionId) {
|
||||
self.subscribed_connections.insert(connection_id);
|
||||
}
|
||||
|
||||
pub(crate) fn remove_connection(&mut self, connection_id: ConnectionId) {
|
||||
self.subscribed_connections.remove(&connection_id);
|
||||
}
|
||||
|
||||
pub(crate) fn subscribed_connection_ids(&self) -> Vec<ConnectionId> {
|
||||
self.subscribed_connections.iter().copied().collect()
|
||||
}
|
||||
|
||||
pub(crate) fn set_experimental_raw_events(&mut self, enabled: bool) {
|
||||
self.experimental_raw_events = enabled;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct SubscriptionState {
|
||||
thread_id: ThreadId,
|
||||
connection_id: ConnectionId,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct ThreadStateManager {
|
||||
thread_states: HashMap<ThreadId, Arc<Mutex<ThreadState>>>,
|
||||
thread_id_by_subscription: HashMap<Uuid, ThreadId>,
|
||||
subscription_state_by_id: HashMap<Uuid, SubscriptionState>,
|
||||
thread_ids_by_connection: HashMap<ConnectionId, HashSet<ThreadId>>,
|
||||
}
|
||||
|
||||
impl ThreadStateManager {
|
||||
|
|
@ -59,12 +95,6 @@ impl ThreadStateManager {
|
|||
Self::default()
|
||||
}
|
||||
|
||||
pub(crate) fn has_listener_for_thread(&self, thread_id: ThreadId) -> bool {
|
||||
self.thread_id_by_subscription
|
||||
.values()
|
||||
.any(|existing| *existing == thread_id)
|
||||
}
|
||||
|
||||
pub(crate) fn thread_state(&mut self, thread_id: ThreadId) -> Arc<Mutex<ThreadState>> {
|
||||
self.thread_states
|
||||
.entry(thread_id)
|
||||
|
|
@ -73,34 +103,119 @@ impl ThreadStateManager {
|
|||
}
|
||||
|
||||
pub(crate) async fn remove_listener(&mut self, subscription_id: Uuid) -> Option<ThreadId> {
|
||||
let thread_id = self.thread_id_by_subscription.remove(&subscription_id)?;
|
||||
let subscription_state = self.subscription_state_by_id.remove(&subscription_id)?;
|
||||
let thread_id = subscription_state.thread_id;
|
||||
|
||||
let connection_still_subscribed_to_thread =
|
||||
self.subscription_state_by_id.values().any(|state| {
|
||||
state.thread_id == thread_id
|
||||
&& state.connection_id == subscription_state.connection_id
|
||||
});
|
||||
if !connection_still_subscribed_to_thread {
|
||||
let mut remove_connection_entry = false;
|
||||
if let Some(thread_ids) = self
|
||||
.thread_ids_by_connection
|
||||
.get_mut(&subscription_state.connection_id)
|
||||
{
|
||||
thread_ids.remove(&thread_id);
|
||||
remove_connection_entry = thread_ids.is_empty();
|
||||
}
|
||||
if remove_connection_entry {
|
||||
self.thread_ids_by_connection
|
||||
.remove(&subscription_state.connection_id);
|
||||
}
|
||||
if let Some(thread_state) = self.thread_states.get(&thread_id) {
|
||||
thread_state
|
||||
.lock()
|
||||
.await
|
||||
.remove_connection(subscription_state.connection_id);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(thread_state) = self.thread_states.get(&thread_id) {
|
||||
thread_state.lock().await.clear_listener(subscription_id);
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
if thread_state.subscribed_connection_ids().is_empty() {
|
||||
thread_state.clear_listener();
|
||||
}
|
||||
}
|
||||
Some(thread_id)
|
||||
}
|
||||
|
||||
pub(crate) async fn remove_thread_state(&mut self, thread_id: ThreadId) {
|
||||
if let Some(thread_state) = self.thread_states.remove(&thread_id) {
|
||||
thread_state.lock().await.clear_listeners();
|
||||
thread_state.lock().await.clear_listener();
|
||||
}
|
||||
self.thread_id_by_subscription
|
||||
.retain(|_, existing_thread_id| *existing_thread_id != thread_id);
|
||||
self.subscription_state_by_id
|
||||
.retain(|_, state| state.thread_id != thread_id);
|
||||
self.thread_ids_by_connection.retain(|_, thread_ids| {
|
||||
thread_ids.remove(&thread_id);
|
||||
!thread_ids.is_empty()
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) async fn set_listener(
|
||||
&mut self,
|
||||
subscription_id: Uuid,
|
||||
thread_id: ThreadId,
|
||||
cancel_tx: oneshot::Sender<()>,
|
||||
connection_id: ConnectionId,
|
||||
experimental_raw_events: bool,
|
||||
) -> Arc<Mutex<ThreadState>> {
|
||||
self.thread_id_by_subscription
|
||||
.insert(subscription_id, thread_id);
|
||||
self.subscription_state_by_id.insert(
|
||||
subscription_id,
|
||||
SubscriptionState {
|
||||
thread_id,
|
||||
connection_id,
|
||||
},
|
||||
);
|
||||
self.thread_ids_by_connection
|
||||
.entry(connection_id)
|
||||
.or_default()
|
||||
.insert(thread_id);
|
||||
let thread_state = self.thread_state(thread_id);
|
||||
thread_state
|
||||
.lock()
|
||||
.await
|
||||
.set_listener(subscription_id, cancel_tx);
|
||||
{
|
||||
let mut thread_state_guard = thread_state.lock().await;
|
||||
thread_state_guard.add_connection(connection_id);
|
||||
thread_state_guard.set_experimental_raw_events(experimental_raw_events);
|
||||
}
|
||||
thread_state
|
||||
}
|
||||
|
||||
pub(crate) async fn ensure_connection_subscribed(
|
||||
&mut self,
|
||||
thread_id: ThreadId,
|
||||
connection_id: ConnectionId,
|
||||
experimental_raw_events: bool,
|
||||
) -> Arc<Mutex<ThreadState>> {
|
||||
self.thread_ids_by_connection
|
||||
.entry(connection_id)
|
||||
.or_default()
|
||||
.insert(thread_id);
|
||||
let thread_state = self.thread_state(thread_id);
|
||||
{
|
||||
let mut thread_state_guard = thread_state.lock().await;
|
||||
thread_state_guard.add_connection(connection_id);
|
||||
if experimental_raw_events {
|
||||
thread_state_guard.set_experimental_raw_events(true);
|
||||
}
|
||||
}
|
||||
thread_state
|
||||
}
|
||||
|
||||
pub(crate) async fn remove_connection(&mut self, connection_id: ConnectionId) {
|
||||
let Some(thread_ids) = self.thread_ids_by_connection.remove(&connection_id) else {
|
||||
return;
|
||||
};
|
||||
self.subscription_state_by_id
|
||||
.retain(|_, state| state.connection_id != connection_id);
|
||||
|
||||
for thread_id in thread_ids {
|
||||
if let Some(thread_state) = self.thread_states.get(&thread_id) {
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state.remove_connection(connection_id);
|
||||
if thread_state.subscribed_connection_ids().is_empty() {
|
||||
thread_state.clear_listener();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -478,6 +478,9 @@ pub(crate) async fn route_outgoing_envelope(
|
|||
);
|
||||
return disconnected;
|
||||
};
|
||||
if should_skip_notification_for_connection(connection_state, &message) {
|
||||
return disconnected;
|
||||
}
|
||||
if connection_state.writer.send(message).await.is_err() {
|
||||
connections.remove(&connection_id);
|
||||
disconnected.push(connection_id);
|
||||
|
|
@ -511,14 +514,6 @@ pub(crate) async fn route_outgoing_envelope(
|
|||
disconnected
|
||||
}
|
||||
|
||||
pub(crate) fn has_initialized_connections(
|
||||
connections: &HashMap<ConnectionId, ConnectionState>,
|
||||
) -> bool {
|
||||
connections
|
||||
.values()
|
||||
.any(|connection| connection.session.initialized)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -746,4 +741,40 @@ mod tests {
|
|||
let queued_json = serde_json::to_value(queued_outgoing).expect("serialize queued message");
|
||||
assert_eq!(queued_json, json!({ "method": "queued" }));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn to_connection_notification_respects_opt_out_filters() {
|
||||
let connection_id = ConnectionId(7);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
let initialized = Arc::new(AtomicBool::new(true));
|
||||
let opted_out_notification_methods = Arc::new(RwLock::new(HashSet::from([
|
||||
"codex/event/task_started".to_string(),
|
||||
])));
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(writer_tx, initialized, opted_out_notification_methods),
|
||||
);
|
||||
|
||||
let disconnected = route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::Notification(
|
||||
crate::outgoing_message::OutgoingNotification {
|
||||
method: "codex/event/task_started".to_string(),
|
||||
params: None,
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(disconnected, Vec::<ConnectionId>::new());
|
||||
assert!(
|
||||
writer_rx.try_recv().is_err(),
|
||||
"opted-out notification should be dropped"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue