app-server: thread resume subscriptions (#11474)

This stack layer makes app-server thread event delivery connection-aware
so resumed/attached threads only emit notifications and approval prompts
to subscribed connections.

- Added per-thread subscription tracking in `ThreadState`
(`subscribed_connections`) and mapped subscription ids to `(thread_id,
connection_id)`.
- Updated listener lifecycle so removing a subscription or closing a
connection only removes that connection from the thread’s subscriber
set; listener shutdown now happens when the last subscriber is gone.
- Added `connection_closed(connection_id)` plumbing (`lib.rs` ->
`message_processor.rs` -> `codex_message_processor.rs`) so disconnect
cleanup happens immediately.
- Scoped bespoke event handling outputs through `TargetedOutgoing` to
send requests/notifications only to subscribed connections.
- Kept existing thread resume behavior while aligning with the latest
split-loop transport structure.
This commit is contained in:
Max Johnson 2026-02-11 16:21:13 -08:00 committed by GitHub
parent 703fb38d2a
commit c0ecc2e1e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 648 additions and 150 deletions

View file

@ -4,7 +4,7 @@ use crate::codex_message_processor::read_summary_from_rollout;
use crate::codex_message_processor::summary_to_thread;
use crate::error_code::INTERNAL_ERROR_CODE;
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
use crate::outgoing_message::OutgoingMessageSender;
use crate::outgoing_message::ThreadScopedOutgoingMessageSender;
use crate::thread_state::ThreadState;
use crate::thread_state::TurnSummary;
use codex_app_server_protocol::AccountRateLimitsUpdatedNotification;
@ -107,7 +107,7 @@ pub(crate) async fn apply_bespoke_event_handling(
event: Event,
conversation_id: ThreadId,
conversation: Arc<CodexThread>,
outgoing: Arc<OutgoingMessageSender>,
outgoing: ThreadScopedOutgoingMessageSender,
thread_state: Arc<tokio::sync::Mutex<ThreadState>>,
api_version: ApiVersion,
fallback_model_provider: String,
@ -850,7 +850,7 @@ pub(crate) async fn apply_bespoke_event_handling(
conversation_id,
&event_turn_id,
raw_response_item_event.item,
outgoing.as_ref(),
&outgoing,
)
.await;
}
@ -899,7 +899,7 @@ pub(crate) async fn apply_bespoke_event_handling(
changes,
status,
event_turn_id.clone(),
outgoing.as_ref(),
&outgoing,
&thread_state,
)
.await;
@ -1142,7 +1142,7 @@ pub(crate) async fn apply_bespoke_event_handling(
&event_turn_id,
turn_diff_event,
api_version,
outgoing.as_ref(),
&outgoing,
)
.await;
}
@ -1152,7 +1152,7 @@ pub(crate) async fn apply_bespoke_event_handling(
&event_turn_id,
plan_update_event,
api_version,
outgoing.as_ref(),
&outgoing,
)
.await;
}
@ -1166,7 +1166,7 @@ async fn handle_turn_diff(
event_turn_id: &str,
turn_diff_event: TurnDiffEvent,
api_version: ApiVersion,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
if let ApiVersion::V2 = api_version {
let notification = TurnDiffUpdatedNotification {
@ -1185,7 +1185,7 @@ async fn handle_turn_plan_update(
event_turn_id: &str,
plan_update_event: UpdatePlanArgs,
api_version: ApiVersion,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
// `update_plan` is a todo/checklist tool; it is not related to plan-mode updates
if let ApiVersion::V2 = api_version {
@ -1210,7 +1210,7 @@ async fn emit_turn_completed_with_status(
event_turn_id: String,
status: TurnStatus,
error: Option<TurnError>,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
let notification = TurnCompletedNotification {
thread_id: conversation_id.to_string(),
@ -1232,7 +1232,7 @@ async fn complete_file_change_item(
changes: Vec<FileUpdateChange>,
status: PatchApplyStatus,
turn_id: String,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
thread_state: &Arc<Mutex<ThreadState>>,
) {
let mut state = thread_state.lock().await;
@ -1264,7 +1264,7 @@ async fn complete_command_execution_item(
process_id: Option<String>,
command_actions: Vec<V2ParsedCommand>,
status: CommandExecutionStatus,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
let item = ThreadItem::CommandExecution {
id: item_id,
@ -1292,7 +1292,7 @@ async fn maybe_emit_raw_response_item_completed(
conversation_id: ThreadId,
turn_id: &str,
item: codex_protocol::models::ResponseItem,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
let ApiVersion::V2 = api_version else {
return;
@ -1319,7 +1319,7 @@ async fn find_and_remove_turn_summary(
async fn handle_turn_complete(
conversation_id: ThreadId,
event_turn_id: String,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
thread_state: &Arc<Mutex<ThreadState>>,
) {
let turn_summary = find_and_remove_turn_summary(conversation_id, thread_state).await;
@ -1335,7 +1335,7 @@ async fn handle_turn_complete(
async fn handle_turn_interrupted(
conversation_id: ThreadId,
event_turn_id: String,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
thread_state: &Arc<Mutex<ThreadState>>,
) {
find_and_remove_turn_summary(conversation_id, thread_state).await;
@ -1354,7 +1354,7 @@ async fn handle_thread_rollback_failed(
_conversation_id: ThreadId,
message: String,
thread_state: &Arc<Mutex<ThreadState>>,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
let pending_rollback = thread_state.lock().await.pending_rollbacks.take();
@ -1376,7 +1376,7 @@ async fn handle_token_count_event(
conversation_id: ThreadId,
turn_id: String,
token_count_event: TokenCountEvent,
outgoing: &OutgoingMessageSender,
outgoing: &ThreadScopedOutgoingMessageSender,
) {
let TokenCountEvent { info, rate_limits } = token_count_event;
if let Some(token_usage) = info.map(ThreadTokenUsage::from) {
@ -1633,7 +1633,7 @@ async fn on_file_change_request_approval_response(
changes: Vec<FileUpdateChange>,
receiver: oneshot::Receiver<JsonValue>,
codex: Arc<CodexThread>,
outgoing: Arc<OutgoingMessageSender>,
outgoing: ThreadScopedOutgoingMessageSender,
thread_state: Arc<Mutex<ThreadState>>,
) {
let response = receiver.await;
@ -1666,7 +1666,7 @@ async fn on_file_change_request_approval_response(
changes,
status,
event_turn_id.clone(),
outgoing.as_ref(),
&outgoing,
&thread_state,
)
.await;
@ -1693,7 +1693,7 @@ async fn on_command_execution_request_approval_response(
command_actions: Vec<V2ParsedCommand>,
receiver: oneshot::Receiver<JsonValue>,
conversation: Arc<CodexThread>,
outgoing: Arc<OutgoingMessageSender>,
outgoing: ThreadScopedOutgoingMessageSender,
) {
let response = receiver.await;
let (decision, completion_status) = match response {
@ -1748,7 +1748,7 @@ async fn on_command_execution_request_approval_response(
None,
command_actions.clone(),
status,
outgoing.as_ref(),
&outgoing,
)
.await;
}
@ -1876,6 +1876,7 @@ async fn construct_mcp_tool_call_end_notification(
mod tests {
use super::*;
use crate::CHANNEL_CAPACITY;
use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::OutgoingEnvelope;
use crate::outgoing_message::OutgoingMessage;
use crate::outgoing_message::OutgoingMessageSender;
@ -1914,9 +1915,7 @@ mod tests {
.ok_or_else(|| anyhow!("should send one message"))?;
match envelope {
OutgoingEnvelope::Broadcast { message } => Ok(message),
OutgoingEnvelope::ToConnection { connection_id, .. } => {
bail!("unexpected targeted message for connection {connection_id:?}")
}
OutgoingEnvelope::ToConnection { message, .. } => Ok(message),
}
}
@ -2011,6 +2010,7 @@ mod tests {
let event_turn_id = "complete1".to_string();
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
let thread_state = new_thread_state();
handle_turn_complete(
@ -2051,6 +2051,7 @@ mod tests {
.await;
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
handle_turn_interrupted(
conversation_id,
@ -2090,6 +2091,7 @@ mod tests {
.await;
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
handle_turn_complete(
conversation_id,
@ -2122,7 +2124,8 @@ mod tests {
#[tokio::test]
async fn test_handle_turn_plan_update_emits_notification_for_v2() -> Result<()> {
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = OutgoingMessageSender::new(tx);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
let update = UpdatePlanArgs {
explanation: Some("need plan".to_string()),
plan: vec![
@ -2172,6 +2175,7 @@ mod tests {
let turn_id = "turn-123".to_string();
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
let info = TokenUsageInfo {
total_token_usage: TokenUsage {
@ -2255,6 +2259,7 @@ mod tests {
let turn_id = "turn-456".to_string();
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
handle_token_count_event(
conversation_id,
@ -2321,6 +2326,7 @@ mod tests {
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
// Turn 1 on conversation A
let a_turn1 = "a_turn1".to_string();
@ -2542,7 +2548,8 @@ mod tests {
#[tokio::test]
async fn test_handle_turn_diff_emits_v2_notification() -> Result<()> {
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = OutgoingMessageSender::new(tx);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
let unified_diff = "--- a\n+++ b\n".to_string();
let conversation_id = ThreadId::new();
@ -2575,7 +2582,8 @@ mod tests {
#[tokio::test]
async fn test_handle_turn_diff_is_noop_for_v1() -> Result<()> {
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = OutgoingMessageSender::new(tx);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]);
let conversation_id = ThreadId::new();
handle_turn_diff(

View file

@ -7,6 +7,7 @@ use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::ConnectionRequestId;
use crate::outgoing_message::OutgoingMessageSender;
use crate::outgoing_message::OutgoingNotification;
use crate::outgoing_message::ThreadScopedOutgoingMessageSender;
use chrono::DateTime;
use chrono::SecondsFormat;
use chrono::Utc;
@ -251,6 +252,7 @@ use uuid::Uuid;
use crate::filters::compute_source_filters;
use crate::filters::source_kind_matches;
use crate::thread_state::ThreadState;
use crate::thread_state::ThreadStateManager;
const THREAD_LIST_DEFAULT_LIMIT: usize = 25;
@ -1961,8 +1963,9 @@ impl CodexMessageProcessor {
// Auto-attach a thread listener when starting a thread.
// Use the same behavior as the v1 API, with opt-in support for raw item events.
if let Err(err) = self
.attach_conversation_listener(
.ensure_conversation_listener(
thread_id,
request_id.connection_id,
experimental_raw_events,
ApiVersion::V2,
)
@ -2628,20 +2631,28 @@ impl CodexMessageProcessor {
self.thread_manager.subscribe_thread_created()
}
/// Best-effort: attach a listener for thread_id if missing.
pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) {
if self.thread_state_manager.has_listener_for_thread(thread_id) {
return;
}
pub(crate) async fn connection_closed(&mut self, connection_id: ConnectionId) {
self.thread_state_manager
.remove_connection(connection_id)
.await;
}
if let Err(err) = self
.attach_conversation_listener(thread_id, false, ApiVersion::V2)
.await
{
warn!(
"failed to attach listener for thread {thread_id}: {message}",
message = err.message
);
/// Best-effort: ensure initialized connections are subscribed to this thread.
pub(crate) async fn try_attach_thread_listener(
&mut self,
thread_id: ThreadId,
connection_ids: Vec<ConnectionId>,
) {
for connection_id in connection_ids {
if let Err(err) = self
.ensure_conversation_listener(thread_id, connection_id, false, ApiVersion::V2)
.await
{
warn!(
"failed to auto-attach listener for thread {thread_id}: {message}",
message = err.message
);
}
}
}
@ -2793,7 +2804,12 @@ impl CodexMessageProcessor {
};
// Auto-attach a thread listener when resuming a thread.
if let Err(err) = self
.attach_conversation_listener(thread_id, false, ApiVersion::V2)
.ensure_conversation_listener(
thread_id,
request_id.connection_id,
false,
ApiVersion::V2,
)
.await
{
tracing::warn!(
@ -3019,7 +3035,12 @@ impl CodexMessageProcessor {
};
// Auto-attach a conversation listener when forking a thread.
if let Err(err) = self
.attach_conversation_listener(thread_id, false, ApiVersion::V2)
.ensure_conversation_listener(
thread_id,
request_id.connection_id,
false,
ApiVersion::V2,
)
.await
{
tracing::warn!(
@ -5136,7 +5157,12 @@ impl CodexMessageProcessor {
})?;
if let Err(err) = self
.attach_conversation_listener(thread_id, false, ApiVersion::V2)
.ensure_conversation_listener(
thread_id,
request_id.connection_id,
false,
ApiVersion::V2,
)
.await
{
tracing::warn!(
@ -5281,18 +5307,38 @@ impl CodexMessageProcessor {
conversation_id,
experimental_raw_events,
} = params;
match self
.attach_conversation_listener(conversation_id, experimental_raw_events, ApiVersion::V1)
.await
{
Ok(subscription_id) => {
let response = AddConversationSubscriptionResponse { subscription_id };
self.outgoing.send_response(request_id, response).await;
let conversation = match self.thread_manager.get_thread(conversation_id).await {
Ok(conv) => conv,
Err(_) => {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: format!("thread not found: {conversation_id}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
}
Err(err) => {
self.outgoing.send_error(request_id, err).await;
}
}
};
let subscription_id = Uuid::new_v4();
let thread_state = self
.thread_state_manager
.set_listener(
subscription_id,
conversation_id,
request_id.connection_id,
experimental_raw_events,
)
.await;
self.ensure_listener_task_running(
conversation_id,
conversation,
thread_state,
ApiVersion::V1,
)
.await;
let response = AddConversationSubscriptionResponse { subscription_id };
self.outgoing.send_response(request_id, response).await;
}
async fn remove_thread_listener(
@ -5322,12 +5368,13 @@ impl CodexMessageProcessor {
}
}
async fn attach_conversation_listener(
async fn ensure_conversation_listener(
&mut self,
conversation_id: ThreadId,
connection_id: ConnectionId,
raw_events_enabled: bool,
api_version: ApiVersion,
) -> Result<Uuid, JSONRPCErrorError> {
) -> Result<(), JSONRPCErrorError> {
let conversation = match self.thread_manager.get_thread(conversation_id).await {
Ok(conv) => conv,
Err(_) => {
@ -5338,13 +5385,30 @@ impl CodexMessageProcessor {
});
}
};
let subscription_id = Uuid::new_v4();
let (cancel_tx, mut cancel_rx) = oneshot::channel();
let thread_state = self
.thread_state_manager
.set_listener(subscription_id, conversation_id, cancel_tx)
.ensure_connection_subscribed(conversation_id, connection_id, raw_events_enabled)
.await;
self.ensure_listener_task_running(conversation_id, conversation, thread_state, api_version)
.await;
Ok(())
}
async fn ensure_listener_task_running(
&self,
conversation_id: ThreadId,
conversation: Arc<CodexThread>,
thread_state: Arc<Mutex<ThreadState>>,
api_version: ApiVersion,
) {
let (cancel_tx, mut cancel_rx) = oneshot::channel();
{
let mut thread_state = thread_state.lock().await;
if thread_state.listener_matches(&conversation) {
return;
}
thread_state.set_listener(cancel_tx, &conversation);
}
let outgoing_for_task = self.outgoing.clone();
let fallback_model_provider = self.config.model_provider_id.clone();
tokio::spawn(async move {
@ -5363,10 +5427,6 @@ impl CodexMessageProcessor {
}
};
if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled {
continue;
}
// For now, we send a notification for every event,
// JSON-serializing the `Event` as-is, but these should
// be migrated to be variants of `ServerNotification`
@ -5391,19 +5451,38 @@ impl CodexMessageProcessor {
"conversationId".to_string(),
conversation_id.to_string().into(),
);
let (subscribed_connection_ids, raw_events_enabled) = {
let thread_state = thread_state.lock().await;
(
thread_state.subscribed_connection_ids(),
thread_state.experimental_raw_events,
)
};
if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled {
continue;
}
outgoing_for_task
.send_notification(OutgoingNotification {
method: format!("codex/event/{event_formatted}"),
params: Some(params.into()),
})
.await;
if !subscribed_connection_ids.is_empty() {
outgoing_for_task
.send_notification_to_connections(
&subscribed_connection_ids,
OutgoingNotification {
method: format!("codex/event/{event_formatted}"),
params: Some(params.into()),
},
)
.await;
}
let thread_outgoing = ThreadScopedOutgoingMessageSender::new(
outgoing_for_task.clone(),
subscribed_connection_ids,
);
apply_bespoke_event_handling(
event.clone(),
conversation_id,
conversation.clone(),
outgoing_for_task.clone(),
thread_outgoing,
thread_state.clone(),
api_version,
fallback_model_provider.clone(),
@ -5413,9 +5492,7 @@ impl CodexMessageProcessor {
}
}
});
Ok(subscription_id)
}
async fn git_diff_to_origin(&self, request_id: ConnectionRequestId, cwd: PathBuf) {
let diff = git_diff_to_remote(&cwd).await;
match diff {
@ -6299,22 +6376,137 @@ mod tests {
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
let listener_a = Uuid::new_v4();
let listener_b = Uuid::new_v4();
let (cancel_a, cancel_rx_a) = oneshot::channel();
let (cancel_b, mut cancel_rx_b) = oneshot::channel();
let connection_a = ConnectionId(1);
let connection_b = ConnectionId(2);
let (cancel_tx, mut cancel_rx) = oneshot::channel();
manager.set_listener(listener_a, thread_id, cancel_a).await;
manager.set_listener(listener_b, thread_id, cancel_b).await;
manager
.set_listener(listener_a, thread_id, connection_a, false)
.await;
manager
.set_listener(listener_b, thread_id, connection_b, false)
.await;
{
let state = manager.thread_state(thread_id);
state.lock().await.cancel_tx = Some(cancel_tx);
}
assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id));
assert_eq!(cancel_rx_a.await, Ok(()));
assert!(
tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx_b)
tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx)
.await
.is_err()
);
assert_eq!(manager.remove_listener(listener_b).await, Some(thread_id));
assert_eq!(cancel_rx.await, Ok(()));
Ok(())
}
// Removing one subscription must only drop that subscription's connection
// from the thread's subscriber set; other connections stay subscribed.
#[tokio::test]
async fn removing_listener_unsubscribes_its_connection() -> Result<()> {
    let mut manager = ThreadStateManager::new();
    let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
    let listener_a = Uuid::new_v4();
    let listener_b = Uuid::new_v4();
    let connection_a = ConnectionId(1);
    let connection_b = ConnectionId(2);
    // Two distinct connections subscribe to the same thread.
    manager
        .set_listener(listener_a, thread_id, connection_a, false)
        .await;
    manager
        .set_listener(listener_b, thread_id, connection_b, false)
        .await;
    // Dropping subscription A removes only connection A from the set.
    assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id));
    let state = manager.thread_state(thread_id);
    let subscribed_connection_ids = state.lock().await.subscribed_connection_ids();
    assert_eq!(subscribed_connection_ids, vec![connection_b]);
    Ok(())
}
// The experimental_raw_events flag is last-write-wins: each set_listener
// call overwrites the thread's flag, regardless of earlier subscriptions.
#[tokio::test]
async fn set_listener_uses_last_write_for_raw_events() -> Result<()> {
    let mut manager = ThreadStateManager::new();
    let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
    let listener_a = Uuid::new_v4();
    let listener_b = Uuid::new_v4();
    let connection_a = ConnectionId(1);
    let connection_b = ConnectionId(2);
    // First subscriber enables raw events.
    manager
        .set_listener(listener_a, thread_id, connection_a, true)
        .await;
    {
        let state = manager.thread_state(thread_id);
        assert!(state.lock().await.experimental_raw_events);
    }
    // Second subscriber disables them; the later write wins.
    manager
        .set_listener(listener_b, thread_id, connection_b, false)
        .await;
    let state = manager.thread_state(thread_id);
    assert!(!state.lock().await.experimental_raw_events);
    Ok(())
}
// When a connection closes and it was the thread's last subscriber, the
// listener task must be cancelled and the subscription cleaned up.
#[tokio::test]
async fn removing_connection_clears_subscription_and_listener_when_last_subscriber()
-> Result<()> {
    let mut manager = ThreadStateManager::new();
    let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
    let listener = Uuid::new_v4();
    let connection = ConnectionId(1);
    let (cancel_tx, cancel_rx) = oneshot::channel();
    manager
        .set_listener(listener, thread_id, connection, false)
        .await;
    {
        // Install a cancel handle so we can observe listener shutdown.
        let state = manager.thread_state(thread_id);
        state.lock().await.cancel_tx = Some(cancel_tx);
    }
    manager.remove_connection(connection).await;
    // Last subscriber gone: the cancel signal fires...
    assert_eq!(cancel_rx.await, Ok(()));
    // ...and the subscription was already removed, so remove_listener
    // finds nothing to do.
    assert_eq!(manager.remove_listener(listener).await, None);
    let state = manager.thread_state(thread_id);
    assert!(state.lock().await.subscribed_connection_ids().is_empty());
    Ok(())
}
// When an auto-attached connection goes away, the listener must survive as
// long as another connection is still subscribed to the thread.
//
// Fix: the previous text of this test referenced `listener_b` and
// `cancel_rx_b`, neither of which is defined here (leftovers from an older
// test), so it could not compile. Those stray assertions are removed; the
// test now checks only that the cancel signal does not fire and that
// connection_b remains the sole subscriber.
#[tokio::test]
async fn removing_auto_attached_connection_preserves_listener_for_other_connections()
-> Result<()> {
    let mut manager = ThreadStateManager::new();
    let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
    let connection_a = ConnectionId(1);
    let connection_b = ConnectionId(2);
    let (cancel_tx, mut cancel_rx) = oneshot::channel();
    // Both connections are auto-subscribed (no explicit subscription ids).
    manager
        .ensure_connection_subscribed(thread_id, connection_a, false)
        .await;
    manager
        .ensure_connection_subscribed(thread_id, connection_b, false)
        .await;
    {
        // Install a cancel handle so we can observe listener shutdown.
        let state = manager.thread_state(thread_id);
        state.lock().await.cancel_tx = Some(cancel_tx);
    }
    manager.remove_connection(connection_a).await;
    // connection_b is still subscribed, so the cancel signal must NOT fire.
    assert!(
        tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx)
            .await
            .is_err()
    );
    let state = manager.thread_state(thread_id);
    assert_eq!(
        state.lock().await.subscribed_connection_ids(),
        vec![connection_b]
    );
    Ok(())
}
}

View file

@ -27,7 +27,6 @@ use crate::transport::CHANNEL_CAPACITY;
use crate::transport::ConnectionState;
use crate::transport::OutboundConnectionState;
use crate::transport::TransportEvent;
use crate::transport::has_initialized_connections;
use crate::transport::route_outgoing_envelope;
use crate::transport::start_stdio_connection;
use crate::transport::start_websocket_acceptor;
@ -490,6 +489,7 @@ pub async fn run_main_with_transport(
{
break;
}
processor.connection_closed(connection_id).await;
connections.remove(&connection_id);
if shutdown_when_no_connections && connections.is_empty() {
break;
@ -544,8 +544,19 @@ pub async fn run_main_with_transport(
created = thread_created_rx.recv(), if listen_for_threads => {
match created {
Ok(thread_id) => {
if has_initialized_connections(&connections) {
processor.try_attach_thread_listener(thread_id).await;
let initialized_connection_ids: Vec<ConnectionId> = connections
.iter()
.filter_map(|(connection_id, connection_state)| {
connection_state.session.initialized.then_some(*connection_id)
})
.collect();
if !initialized_connection_ids.is_empty() {
processor
.try_attach_thread_listener(
thread_id,
initialized_connection_ids,
)
.await;
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {

View file

@ -396,9 +396,19 @@ impl MessageProcessor {
}
}
pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) {
pub(crate) async fn try_attach_thread_listener(
&mut self,
thread_id: ThreadId,
connection_ids: Vec<ConnectionId>,
) {
self.codex_message_processor
.try_attach_thread_listener(thread_id)
.try_attach_thread_listener(thread_id, connection_ids)
.await;
}
pub(crate) async fn connection_closed(&mut self, connection_id: ConnectionId) {
self.codex_message_processor
.connection_closed(connection_id)
.await;
}

View file

@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
@ -48,6 +49,62 @@ pub(crate) struct OutgoingMessageSender {
request_id_to_callback: Mutex<HashMap<RequestId, oneshot::Sender<Result>>>,
}
/// An outgoing-message sender scoped to the connections subscribed to one
/// thread. Requests and server notifications are delivered only to the
/// subscribed connections; responses and errors are keyed by the originating
/// request id and are forwarded unscoped.
#[derive(Clone)]
pub(crate) struct ThreadScopedOutgoingMessageSender {
    outgoing: Arc<OutgoingMessageSender>,
    connection_ids: Arc<Vec<ConnectionId>>,
}

impl ThreadScopedOutgoingMessageSender {
    pub(crate) fn new(
        outgoing: Arc<OutgoingMessageSender>,
        connection_ids: Vec<ConnectionId>,
    ) -> Self {
        let connection_ids = Arc::new(connection_ids);
        Self {
            outgoing,
            connection_ids,
        }
    }

    /// Send a server request to every subscribed connection.
    ///
    /// When no connection is subscribed, the request is dropped and the
    /// returned receiver's sender is already gone, so awaiting it resolves
    /// with a `RecvError` rather than hanging.
    pub(crate) async fn send_request(
        &self,
        payload: ServerRequestPayload,
    ) -> oneshot::Receiver<Result> {
        let targets = self.connection_ids.as_slice();
        if targets.is_empty() {
            let (_tx, rx) = oneshot::channel();
            rx
        } else {
            self.outgoing
                .send_request_to_connections(targets, payload)
                .await
        }
    }

    /// Send a server notification to the subscribed connections; a no-op
    /// when the subscriber set is empty.
    pub(crate) async fn send_server_notification(&self, notification: ServerNotification) {
        let targets = self.connection_ids.as_slice();
        if !targets.is_empty() {
            self.outgoing
                .send_server_notification_to_connections(targets, notification)
                .await;
        }
    }

    /// Responses target the originating request directly, so they bypass the
    /// thread's subscriber scoping.
    pub(crate) async fn send_response<T: Serialize>(
        &self,
        request_id: ConnectionRequestId,
        response: T,
    ) {
        self.outgoing.send_response(request_id, response).await;
    }

    /// Errors, like responses, are addressed by request id and bypass the
    /// subscriber scoping.
    pub(crate) async fn send_error(
        &self,
        request_id: ConnectionRequestId,
        error: JSONRPCErrorError,
    ) {
        self.outgoing.send_error(request_id, error).await;
    }
}
impl OutgoingMessageSender {
pub(crate) fn new(sender: mpsc::Sender<OutgoingEnvelope>) -> Self {
Self {
@ -57,17 +114,28 @@ impl OutgoingMessageSender {
}
}
pub(crate) async fn send_request(
pub(crate) async fn send_request_to_connections(
&self,
connection_ids: &[ConnectionId],
request: ServerRequestPayload,
) -> oneshot::Receiver<Result> {
let (_id, rx) = self.send_request_with_id(request).await;
let (_id, rx) = self
.send_request_with_id_to_connections(connection_ids, request)
.await;
rx
}
pub(crate) async fn send_request_with_id(
&self,
request: ServerRequestPayload,
) -> (RequestId, oneshot::Receiver<Result>) {
self.send_request_with_id_to_connections(&[], request).await
}
async fn send_request_with_id_to_connections(
&self,
connection_ids: &[ConnectionId],
request: ServerRequestPayload,
) -> (RequestId, oneshot::Receiver<Result>) {
let id = RequestId::Integer(self.next_server_request_id.fetch_add(1, Ordering::Relaxed));
let outgoing_message_id = id.clone();
@ -79,13 +147,34 @@ impl OutgoingMessageSender {
let outgoing_message =
OutgoingMessage::Request(request.request_with_id(outgoing_message_id.clone()));
if let Err(err) = self
.sender
.send(OutgoingEnvelope::Broadcast {
message: outgoing_message,
})
.await
{
let send_result = if connection_ids.is_empty() {
self.sender
.send(OutgoingEnvelope::Broadcast {
message: outgoing_message,
})
.await
} else {
let mut send_error = None;
for connection_id in connection_ids {
if let Err(err) = self
.sender
.send(OutgoingEnvelope::ToConnection {
connection_id: *connection_id,
message: outgoing_message.clone(),
})
.await
{
send_error = Some(err);
break;
}
}
match send_error {
Some(err) => Err(err),
None => Ok(()),
}
};
if let Err(err) = send_result {
warn!("failed to send request {outgoing_message_id:?} to client: {err:?}");
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
request_id_to_callback.remove(&outgoing_message_id);
@ -172,29 +261,71 @@ impl OutgoingMessageSender {
}
pub(crate) async fn send_server_notification(&self, notification: ServerNotification) {
if let Err(err) = self
.sender
.send(OutgoingEnvelope::Broadcast {
message: OutgoingMessage::AppServerNotification(notification),
})
.await
{
warn!("failed to send server notification to client: {err:?}");
self.send_server_notification_to_connections(&[], notification)
.await;
}
pub(crate) async fn send_server_notification_to_connections(
&self,
connection_ids: &[ConnectionId],
notification: ServerNotification,
) {
let outgoing_message = OutgoingMessage::AppServerNotification(notification);
if connection_ids.is_empty() {
if let Err(err) = self
.sender
.send(OutgoingEnvelope::Broadcast {
message: outgoing_message,
})
.await
{
warn!("failed to send server notification to client: {err:?}");
}
return;
}
for connection_id in connection_ids {
if let Err(err) = self
.sender
.send(OutgoingEnvelope::ToConnection {
connection_id: *connection_id,
message: outgoing_message.clone(),
})
.await
{
warn!("failed to send server notification to client: {err:?}");
}
}
}
/// All notifications should be migrated to [`ServerNotification`] and
/// [`OutgoingMessage::Notification`] should be removed.
pub(crate) async fn send_notification(&self, notification: OutgoingNotification) {
pub(crate) async fn send_notification_to_connections(
&self,
connection_ids: &[ConnectionId],
notification: OutgoingNotification,
) {
let outgoing_message = OutgoingMessage::Notification(notification);
if let Err(err) = self
.sender
.send(OutgoingEnvelope::Broadcast {
message: outgoing_message,
})
.await
{
warn!("failed to send notification to client: {err:?}");
if connection_ids.is_empty() {
if let Err(err) = self
.sender
.send(OutgoingEnvelope::Broadcast {
message: outgoing_message,
})
.await
{
warn!("failed to send notification to client: {err:?}");
}
return;
}
for connection_id in connection_ids {
if let Err(err) = self
.sender
.send(OutgoingEnvelope::ToConnection {
connection_id: *connection_id,
message: outgoing_message.clone(),
})
.await
{
warn!("failed to send notification to client: {err:?}");
}
}
}

View file

@ -1,9 +1,12 @@
use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::ConnectionRequestId;
use codex_app_server_protocol::TurnError;
use codex_core::CodexThread;
use codex_protocol::ThreadId;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::Weak;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use uuid::Uuid;
@ -25,33 +28,66 @@ pub(crate) struct ThreadState {
pub(crate) pending_interrupts: PendingInterruptQueue,
pub(crate) pending_rollbacks: Option<ConnectionRequestId>,
pub(crate) turn_summary: TurnSummary,
pub(crate) listener_cancel_txs: HashMap<Uuid, oneshot::Sender<()>>,
pub(crate) cancel_tx: Option<oneshot::Sender<()>>,
pub(crate) experimental_raw_events: bool,
listener_thread: Option<Weak<CodexThread>>,
subscribed_connections: HashSet<ConnectionId>,
}
impl ThreadState {
fn set_listener(&mut self, subscription_id: Uuid, cancel_tx: oneshot::Sender<()>) {
if let Some(previous) = self.listener_cancel_txs.insert(subscription_id, cancel_tx) {
pub(crate) fn listener_matches(&self, conversation: &Arc<CodexThread>) -> bool {
self.listener_thread
.as_ref()
.and_then(Weak::upgrade)
.is_some_and(|existing| Arc::ptr_eq(&existing, conversation))
}
pub(crate) fn set_listener(
&mut self,
cancel_tx: oneshot::Sender<()>,
conversation: &Arc<CodexThread>,
) {
if let Some(previous) = self.cancel_tx.replace(cancel_tx) {
let _ = previous.send(());
}
self.listener_thread = Some(Arc::downgrade(conversation));
}
fn clear_listener(&mut self, subscription_id: Uuid) {
if let Some(cancel_tx) = self.listener_cancel_txs.remove(&subscription_id) {
pub(crate) fn clear_listener(&mut self) {
if let Some(cancel_tx) = self.cancel_tx.take() {
let _ = cancel_tx.send(());
}
self.listener_thread = None;
}
fn clear_listeners(&mut self) {
for (_, cancel_tx) in self.listener_cancel_txs.drain() {
let _ = cancel_tx.send(());
}
pub(crate) fn add_connection(&mut self, connection_id: ConnectionId) {
self.subscribed_connections.insert(connection_id);
}
pub(crate) fn remove_connection(&mut self, connection_id: ConnectionId) {
self.subscribed_connections.remove(&connection_id);
}
pub(crate) fn subscribed_connection_ids(&self) -> Vec<ConnectionId> {
self.subscribed_connections.iter().copied().collect()
}
pub(crate) fn set_experimental_raw_events(&mut self, enabled: bool) {
self.experimental_raw_events = enabled;
}
}
// Per-subscription bookkeeping: which thread the subscription targets and
// which connection created it, so removing the subscription can update both
// the thread's subscriber set and the per-connection thread index.
#[derive(Clone, Copy)]
struct SubscriptionState {
    thread_id: ThreadId,
    connection_id: ConnectionId,
}
/// Central registry tying together per-thread state, active subscriptions, and
/// the connections that hold them.
#[derive(Default)]
pub(crate) struct ThreadStateManager {
// Shared, lockable state for each known thread.
thread_states: HashMap<ThreadId, Arc<Mutex<ThreadState>>>,
// NOTE(review): this span looks like rendered diff residue. The
// `thread_id_by_subscription` map appears to be the removed-side field that
// `subscription_state_by_id` replaces (it duplicates the thread-id half of
// `SubscriptionState`) — confirm against the real source before relying on it.
thread_id_by_subscription: HashMap<Uuid, ThreadId>,
// subscription id -> (thread, connection) that registered it.
subscription_state_by_id: HashMap<Uuid, SubscriptionState>,
// Reverse index: the set of threads each connection is subscribed to.
thread_ids_by_connection: HashMap<ConnectionId, HashSet<ThreadId>>,
}
impl ThreadStateManager {
@ -59,12 +95,6 @@ impl ThreadStateManager {
Self::default()
}
/// Whether any active subscription currently targets `thread_id`.
pub(crate) fn has_listener_for_thread(&self, thread_id: ThreadId) -> bool {
    for subscribed_thread in self.thread_id_by_subscription.values() {
        if *subscribed_thread == thread_id {
            return true;
        }
    }
    false
}
pub(crate) fn thread_state(&mut self, thread_id: ThreadId) -> Arc<Mutex<ThreadState>> {
self.thread_states
.entry(thread_id)
@ -73,34 +103,119 @@ impl ThreadStateManager {
}
/// Drops the subscription identified by `subscription_id`.
///
/// If this was the connection's last subscription to the thread, the
/// connection is removed from the thread's subscriber set; if the thread is
/// then left with no subscribers at all, its listener task is cancelled.
/// Returns the thread the subscription belonged to, or `None` when the id is
/// unknown.
pub(crate) async fn remove_listener(&mut self, subscription_id: Uuid) -> Option<ThreadId> {
    let subscription_state = self.subscription_state_by_id.remove(&subscription_id)?;
    let thread_id = subscription_state.thread_id;
    // The same connection may hold several subscriptions to one thread; only
    // detach the connection once its last subscription is gone.
    let connection_still_subscribed_to_thread =
        self.subscription_state_by_id.values().any(|state| {
            state.thread_id == thread_id
                && state.connection_id == subscription_state.connection_id
        });
    if !connection_still_subscribed_to_thread {
        let mut remove_connection_entry = false;
        if let Some(thread_ids) = self
            .thread_ids_by_connection
            .get_mut(&subscription_state.connection_id)
        {
            thread_ids.remove(&thread_id);
            remove_connection_entry = thread_ids.is_empty();
        }
        if remove_connection_entry {
            // Connection follows no threads any more; drop its index entry.
            self.thread_ids_by_connection
                .remove(&subscription_state.connection_id);
        }
        if let Some(thread_state) = self.thread_states.get(&thread_id) {
            thread_state
                .lock()
                .await
                .remove_connection(subscription_state.connection_id);
        }
    }
    if let Some(thread_state) = self.thread_states.get(&thread_id) {
        let mut thread_state = thread_state.lock().await;
        // Only shut the listener down once the last subscriber is gone.
        if thread_state.subscribed_connection_ids().is_empty() {
            thread_state.clear_listener();
        }
    }
    Some(thread_id)
}
/// Forgets all state for `thread_id`: cancels its listener task and scrubs the
/// thread from the subscription and connection indices.
pub(crate) async fn remove_thread_state(&mut self, thread_id: ThreadId) {
    if let Some(thread_state) = self.thread_states.remove(&thread_id) {
        thread_state.lock().await.clear_listener();
    }
    self.subscription_state_by_id
        .retain(|_, state| state.thread_id != thread_id);
    // Drop the thread from every connection's reverse index, and drop any
    // connection entry that is left empty.
    self.thread_ids_by_connection.retain(|_, thread_ids| {
        thread_ids.remove(&thread_id);
        !thread_ids.is_empty()
    });
}
// NOTE(review): this span looks like rendered diff residue containing both the
// removed and the added versions of `set_listener` interleaved. The
// `thread_id_by_subscription` insert and the trailing
// `.set_listener(subscription_id, cancel_tx)` chain appear to be removed-side
// lines; the `SubscriptionState` insert and the guarded
// add_connection/set_experimental_raw_events block appear to be the
// added-side replacement. It is also unclear whether the `cancel_tx` parameter
// survives in the new signature (the new `ThreadState::set_listener` takes
// `(cancel_tx, conversation)` instead). Confirm against the real source
// before editing this function.
pub(crate) async fn set_listener(
&mut self,
subscription_id: Uuid,
thread_id: ThreadId,
cancel_tx: oneshot::Sender<()>,
connection_id: ConnectionId,
experimental_raw_events: bool,
) -> Arc<Mutex<ThreadState>> {
// Presumably removed-side: superseded by the `SubscriptionState` insert below.
self.thread_id_by_subscription
.insert(subscription_id, thread_id);
// Record which (thread, connection) pair this subscription belongs to.
self.subscription_state_by_id.insert(
subscription_id,
SubscriptionState {
thread_id,
connection_id,
},
);
// Reverse index: remember that this connection follows this thread.
self.thread_ids_by_connection
.entry(connection_id)
.or_default()
.insert(thread_id);
let thread_state = self.thread_state(thread_id);
// Presumably removed-side: old per-subscription cancel-handle registration.
thread_state
.lock()
.await
.set_listener(subscription_id, cancel_tx);
{
let mut thread_state_guard = thread_state.lock().await;
thread_state_guard.add_connection(connection_id);
thread_state_guard.set_experimental_raw_events(experimental_raw_events);
}
thread_state
}
/// Ensures `connection_id` is registered as a subscriber of `thread_id`,
/// creating the thread state on demand, and returns that shared state.
/// The raw-events flag is only ever switched on here, never off.
pub(crate) async fn ensure_connection_subscribed(
    &mut self,
    thread_id: ThreadId,
    connection_id: ConnectionId,
    experimental_raw_events: bool,
) -> Arc<Mutex<ThreadState>> {
    // Record the thread under the connection's reverse index.
    self.thread_ids_by_connection
        .entry(connection_id)
        .or_default()
        .insert(thread_id);
    let thread_state = self.thread_state(thread_id);
    let mut guard = thread_state.lock().await;
    guard.add_connection(connection_id);
    if experimental_raw_events {
        // Opt in only; this call never disables the flag.
        guard.set_experimental_raw_events(true);
    }
    drop(guard);
    thread_state
}
/// Handles a connection disconnect: forgets every subscription it held and
/// detaches it from each thread it followed, cancelling the listener of any
/// thread left without subscribers.
pub(crate) async fn remove_connection(&mut self, connection_id: ConnectionId) {
    let thread_ids = match self.thread_ids_by_connection.remove(&connection_id) {
        Some(thread_ids) => thread_ids,
        None => return,
    };
    self.subscription_state_by_id
        .retain(|_, state| state.connection_id != connection_id);
    for thread_id in thread_ids {
        let Some(thread_state) = self.thread_states.get(&thread_id) else {
            continue;
        };
        let mut guard = thread_state.lock().await;
        guard.remove_connection(connection_id);
        if guard.subscribed_connection_ids().is_empty() {
            // Last subscriber gone: shut the listener task down.
            guard.clear_listener();
        }
    }
}
}

View file

@ -478,6 +478,9 @@ pub(crate) async fn route_outgoing_envelope(
);
return disconnected;
};
if should_skip_notification_for_connection(connection_state, &message) {
return disconnected;
}
if connection_state.writer.send(message).await.is_err() {
connections.remove(&connection_id);
disconnected.push(connection_id);
@ -511,14 +514,6 @@ pub(crate) async fn route_outgoing_envelope(
disconnected
}
/// True when at least one connected client has completed initialization.
pub(crate) fn has_initialized_connections(
    connections: &HashMap<ConnectionId, ConnectionState>,
) -> bool {
    for connection_state in connections.values() {
        if connection_state.session.initialized {
            return true;
        }
    }
    false
}
#[cfg(test)]
mod tests {
use super::*;
@ -746,4 +741,40 @@ mod tests {
let queued_json = serde_json::to_value(queued_outgoing).expect("serialize queued message");
assert_eq!(queued_json, json!({ "method": "queued" }));
}
#[tokio::test]
async fn to_connection_notification_respects_opt_out_filters() {
    // One connection that has opted out of `task_started` notifications.
    let connection_id = ConnectionId(7);
    let (writer_tx, mut writer_rx) = mpsc::channel(1);
    let initialized = Arc::new(AtomicBool::new(true));
    let mut opted_out = HashSet::new();
    opted_out.insert("codex/event/task_started".to_string());
    let opted_out_notification_methods = Arc::new(RwLock::new(opted_out));
    let connection_state =
        OutboundConnectionState::new(writer_tx, initialized, opted_out_notification_methods);
    let mut connections = HashMap::from([(connection_id, connection_state)]);

    // Route an opted-out notification directly at that connection.
    let notification = crate::outgoing_message::OutgoingNotification {
        method: "codex/event/task_started".to_string(),
        params: None,
    };
    let envelope = OutgoingEnvelope::ToConnection {
        connection_id,
        message: OutgoingMessage::Notification(notification),
    };
    let disconnected = route_outgoing_envelope(&mut connections, envelope).await;

    // The connection stays healthy, but the message never reaches its writer.
    assert_eq!(disconnected, Vec::<ConnectionId>::new());
    assert!(
        writer_rx.try_recv().is_err(),
        "opted-out notification should be dropped"
    );
}
}