refactor: codex app-server ThreadState (#11419)

This is a functional no-op: it consolidates thread-specific message-processor and event-handling state into ThreadState.
This commit is contained in:
Max Johnson 2026-02-11 12:20:54 -08:00 committed by GitHub
parent 42e22f3bde
commit b5339a591d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 258 additions and 187 deletions

View file

@ -1,14 +1,12 @@
use crate::codex_message_processor::ApiVersion;
use crate::codex_message_processor::PendingInterrupts;
use crate::codex_message_processor::PendingRollbacks;
use crate::codex_message_processor::TurnSummary;
use crate::codex_message_processor::TurnSummaryStore;
use crate::codex_message_processor::read_rollout_items_from_rollout;
use crate::codex_message_processor::read_summary_from_rollout;
use crate::codex_message_processor::summary_to_thread;
use crate::error_code::INTERNAL_ERROR_CODE;
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
use crate::outgoing_message::OutgoingMessageSender;
use crate::thread_state::ThreadState;
use crate::thread_state::TurnSummary;
use codex_app_server_protocol::AccountRateLimitsUpdatedNotification;
use codex_app_server_protocol::AgentMessageDeltaNotification;
use codex_app_server_protocol::ApplyPatchApprovalParams;
@ -98,6 +96,7 @@ use std::collections::HashMap;
use std::convert::TryFrom;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use tracing::error;
@ -109,9 +108,7 @@ pub(crate) async fn apply_bespoke_event_handling(
conversation_id: ThreadId,
conversation: Arc<CodexThread>,
outgoing: Arc<OutgoingMessageSender>,
pending_interrupts: PendingInterrupts,
pending_rollbacks: PendingRollbacks,
turn_summary_store: TurnSummaryStore,
thread_state: Arc<tokio::sync::Mutex<ThreadState>>,
api_version: ApiVersion,
fallback_model_provider: String,
) {
@ -122,13 +119,7 @@ pub(crate) async fn apply_bespoke_event_handling(
match msg {
EventMsg::TurnStarted(_) => {}
EventMsg::TurnComplete(_ev) => {
handle_turn_complete(
conversation_id,
event_turn_id,
&outgoing,
&turn_summary_store,
)
.await;
handle_turn_complete(conversation_id, event_turn_id, &outgoing, &thread_state).await;
}
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
call_id,
@ -159,9 +150,11 @@ pub(crate) async fn apply_bespoke_event_handling(
let patch_changes = convert_patch_changes(&changes);
let first_start = {
let mut map = turn_summary_store.lock().await;
let summary = map.entry(conversation_id).or_default();
summary.file_change_started.insert(item_id.clone())
let mut state = thread_state.lock().await;
state
.turn_summary
.file_change_started
.insert(item_id.clone())
};
if first_start {
let item = ThreadItem::FileChange {
@ -198,7 +191,7 @@ pub(crate) async fn apply_bespoke_event_handling(
rx,
conversation,
outgoing,
turn_summary_store,
thread_state.clone(),
)
.await;
});
@ -718,7 +711,7 @@ pub(crate) async fn apply_bespoke_event_handling(
return handle_thread_rollback_failed(
conversation_id,
message,
&pending_rollbacks,
&thread_state,
&outgoing,
)
.await;
@ -729,7 +722,7 @@ pub(crate) async fn apply_bespoke_event_handling(
codex_error_info: ev.codex_error_info.map(V2CodexErrorInfo::from),
additional_details: None,
};
handle_error(conversation_id, turn_error.clone(), &turn_summary_store).await;
handle_error(conversation_id, turn_error.clone(), &thread_state).await;
outgoing
.send_server_notification(ServerNotification::Error(ErrorNotification {
error: turn_error.clone(),
@ -867,9 +860,11 @@ pub(crate) async fn apply_bespoke_event_handling(
let item_id = patch_begin_event.call_id.clone();
let first_start = {
let mut map = turn_summary_store.lock().await;
let summary = map.entry(conversation_id).or_default();
summary.file_change_started.insert(item_id.clone())
let mut state = thread_state.lock().await;
state
.turn_summary
.file_change_started
.insert(item_id.clone())
};
if first_start {
let item = ThreadItem::FileChange {
@ -905,7 +900,7 @@ pub(crate) async fn apply_bespoke_event_handling(
status,
event_turn_id.clone(),
outgoing.as_ref(),
&turn_summary_store,
&thread_state,
)
.await;
}
@ -950,9 +945,8 @@ pub(crate) async fn apply_bespoke_event_handling(
// We need to detect which item type it is so we can emit the right notification.
// We already have state tracking FileChange items on item/started, so let's use that.
let is_file_change = {
let map = turn_summary_store.lock().await;
map.get(&conversation_id)
.is_some_and(|summary| summary.file_change_started.contains(&item_id))
let state = thread_state.lock().await;
state.turn_summary.file_change_started.contains(&item_id)
};
if is_file_change {
let notification = FileChangeOutputDeltaNotification {
@ -1049,8 +1043,8 @@ pub(crate) async fn apply_bespoke_event_handling(
// If this is a TurnAborted, reply to any pending interrupt requests.
EventMsg::TurnAborted(turn_aborted_event) => {
let pending = {
let mut map = pending_interrupts.lock().await;
map.remove(&conversation_id).unwrap_or_default()
let mut state = thread_state.lock().await;
std::mem::take(&mut state.pending_interrupts)
};
if !pending.is_empty() {
for (rid, ver) in pending {
@ -1069,18 +1063,12 @@ pub(crate) async fn apply_bespoke_event_handling(
}
}
handle_turn_interrupted(
conversation_id,
event_turn_id,
&outgoing,
&turn_summary_store,
)
.await;
handle_turn_interrupted(conversation_id, event_turn_id, &outgoing, &thread_state).await;
}
EventMsg::ThreadRolledBack(_rollback_event) => {
let pending = {
let mut map = pending_rollbacks.lock().await;
map.remove(&conversation_id)
let mut state = thread_state.lock().await;
state.pending_rollbacks.take()
};
if let Some(request_id) = pending {
@ -1245,14 +1233,11 @@ async fn complete_file_change_item(
status: PatchApplyStatus,
turn_id: String,
outgoing: &OutgoingMessageSender,
turn_summary_store: &TurnSummaryStore,
thread_state: &Arc<Mutex<ThreadState>>,
) {
{
let mut map = turn_summary_store.lock().await;
if let Some(summary) = map.get_mut(&conversation_id) {
summary.file_change_started.remove(&item_id);
}
}
let mut state = thread_state.lock().await;
state.turn_summary.file_change_started.remove(&item_id);
drop(state);
let item = ThreadItem::FileChange {
id: item_id,
@ -1324,20 +1309,20 @@ async fn maybe_emit_raw_response_item_completed(
}
async fn find_and_remove_turn_summary(
conversation_id: ThreadId,
turn_summary_store: &TurnSummaryStore,
_conversation_id: ThreadId,
thread_state: &Arc<Mutex<ThreadState>>,
) -> TurnSummary {
let mut map = turn_summary_store.lock().await;
map.remove(&conversation_id).unwrap_or_default()
let mut state = thread_state.lock().await;
std::mem::take(&mut state.turn_summary)
}
async fn handle_turn_complete(
conversation_id: ThreadId,
event_turn_id: String,
outgoing: &OutgoingMessageSender,
turn_summary_store: &TurnSummaryStore,
thread_state: &Arc<Mutex<ThreadState>>,
) {
let turn_summary = find_and_remove_turn_summary(conversation_id, turn_summary_store).await;
let turn_summary = find_and_remove_turn_summary(conversation_id, thread_state).await;
let (status, error) = match turn_summary.last_error {
Some(error) => (TurnStatus::Failed, Some(error)),
@ -1351,9 +1336,9 @@ async fn handle_turn_interrupted(
conversation_id: ThreadId,
event_turn_id: String,
outgoing: &OutgoingMessageSender,
turn_summary_store: &TurnSummaryStore,
thread_state: &Arc<Mutex<ThreadState>>,
) {
find_and_remove_turn_summary(conversation_id, turn_summary_store).await;
find_and_remove_turn_summary(conversation_id, thread_state).await;
emit_turn_completed_with_status(
conversation_id,
@ -1366,15 +1351,12 @@ async fn handle_turn_interrupted(
}
async fn handle_thread_rollback_failed(
conversation_id: ThreadId,
_conversation_id: ThreadId,
message: String,
pending_rollbacks: &PendingRollbacks,
thread_state: &Arc<Mutex<ThreadState>>,
outgoing: &OutgoingMessageSender,
) {
let pending_rollback = {
let mut map = pending_rollbacks.lock().await;
map.remove(&conversation_id)
};
let pending_rollback = thread_state.lock().await.pending_rollbacks.take();
if let Some(request_id) = pending_rollback {
outgoing
@ -1419,12 +1401,12 @@ async fn handle_token_count_event(
}
async fn handle_error(
conversation_id: ThreadId,
_conversation_id: ThreadId,
error: TurnError,
turn_summary_store: &TurnSummaryStore,
thread_state: &Arc<Mutex<ThreadState>>,
) {
let mut map = turn_summary_store.lock().await;
map.entry(conversation_id).or_default().last_error = Some(error);
let mut state = thread_state.lock().await;
state.turn_summary.last_error = Some(error);
}
async fn on_patch_approval_response(
@ -1652,7 +1634,7 @@ async fn on_file_change_request_approval_response(
receiver: oneshot::Receiver<JsonValue>,
codex: Arc<CodexThread>,
outgoing: Arc<OutgoingMessageSender>,
turn_summary_store: TurnSummaryStore,
thread_state: Arc<Mutex<ThreadState>>,
) {
let response = receiver.await;
let (decision, completion_status) = match response {
@ -1685,7 +1667,7 @@ async fn on_file_change_request_approval_response(
status,
event_turn_id.clone(),
outgoing.as_ref(),
&turn_summary_store,
&thread_state,
)
.await;
}
@ -1915,13 +1897,12 @@ mod tests {
use pretty_assertions::assert_eq;
use rmcp::model::Content;
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
fn new_turn_summary_store() -> TurnSummaryStore {
Arc::new(Mutex::new(HashMap::new()))
fn new_thread_state() -> Arc<Mutex<ThreadState>> {
Arc::new(Mutex::new(ThreadState::default()))
}
async fn recv_broadcast_message(
@ -1999,7 +1980,7 @@ mod tests {
#[tokio::test]
async fn test_handle_error_records_message() -> Result<()> {
let conversation_id = ThreadId::new();
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
handle_error(
conversation_id,
@ -2008,11 +1989,11 @@ mod tests {
codex_error_info: Some(V2CodexErrorInfo::InternalServerError),
additional_details: None,
},
&turn_summary_store,
&thread_state,
)
.await;
let turn_summary = find_and_remove_turn_summary(conversation_id, &turn_summary_store).await;
let turn_summary = find_and_remove_turn_summary(conversation_id, &thread_state).await;
assert_eq!(
turn_summary.last_error,
Some(TurnError {
@ -2030,13 +2011,13 @@ mod tests {
let event_turn_id = "complete1".to_string();
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
handle_turn_complete(
conversation_id,
event_turn_id.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
@ -2057,7 +2038,7 @@ mod tests {
async fn test_handle_turn_interrupted_emits_interrupted_with_error() -> Result<()> {
let conversation_id = ThreadId::new();
let event_turn_id = "interrupt1".to_string();
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
handle_error(
conversation_id,
TurnError {
@ -2065,7 +2046,7 @@ mod tests {
codex_error_info: None,
additional_details: None,
},
&turn_summary_store,
&thread_state,
)
.await;
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
@ -2075,7 +2056,7 @@ mod tests {
conversation_id,
event_turn_id.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
@ -2096,7 +2077,7 @@ mod tests {
async fn test_handle_turn_complete_emits_failed_with_error() -> Result<()> {
let conversation_id = ThreadId::new();
let event_turn_id = "complete_err1".to_string();
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
handle_error(
conversation_id,
TurnError {
@ -2104,7 +2085,7 @@ mod tests {
codex_error_info: Some(V2CodexErrorInfo::Other),
additional_details: None,
},
&turn_summary_store,
&thread_state,
)
.await;
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
@ -2114,7 +2095,7 @@ mod tests {
conversation_id,
event_turn_id.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
@ -2336,7 +2317,7 @@ mod tests {
// Conversation A will have two turns; Conversation B will have one turn.
let conversation_a = ThreadId::new();
let conversation_b = ThreadId::new();
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
@ -2350,16 +2331,10 @@ mod tests {
codex_error_info: Some(V2CodexErrorInfo::BadRequest),
additional_details: None,
},
&turn_summary_store,
)
.await;
handle_turn_complete(
conversation_a,
a_turn1.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
handle_turn_complete(conversation_a, a_turn1.clone(), &outgoing, &thread_state).await;
// Turn 1 on conversation B
let b_turn1 = "b_turn1".to_string();
@ -2370,26 +2345,14 @@ mod tests {
codex_error_info: None,
additional_details: None,
},
&turn_summary_store,
)
.await;
handle_turn_complete(
conversation_b,
b_turn1.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
handle_turn_complete(conversation_b, b_turn1.clone(), &outgoing, &thread_state).await;
// Turn 2 on conversation A
let a_turn2 = "a_turn2".to_string();
handle_turn_complete(
conversation_a,
a_turn2.clone(),
&outgoing,
&turn_summary_store,
)
.await;
handle_turn_complete(conversation_a, a_turn2.clone(), &outgoing, &thread_state).await;
// Verify: A turn 1
let msg = recv_broadcast_message(&mut rx).await?;

View file

@ -137,7 +137,6 @@ use codex_app_server_protocol::ThreadStartedNotification;
use codex_app_server_protocol::ThreadUnarchiveParams;
use codex_app_server_protocol::ThreadUnarchiveResponse;
use codex_app_server_protocol::Turn;
use codex_app_server_protocol::TurnError;
use codex_app_server_protocol::TurnInterruptParams;
use codex_app_server_protocol::TurnStartParams;
use codex_app_server_protocol::TurnStartResponse;
@ -252,20 +251,7 @@ use uuid::Uuid;
use crate::filters::compute_source_filters;
use crate::filters::source_kind_matches;
type PendingInterruptQueue = Vec<(ConnectionRequestId, ApiVersion)>;
pub(crate) type PendingInterrupts = Arc<Mutex<HashMap<ThreadId, PendingInterruptQueue>>>;
pub(crate) type PendingRollbacks = Arc<Mutex<HashMap<ThreadId, ConnectionRequestId>>>;
/// Per-conversation accumulation of the latest states e.g. error message while a turn runs.
#[derive(Default, Clone)]
pub(crate) struct TurnSummary {
pub(crate) file_change_started: HashSet<String>,
pub(crate) last_error: Option<TurnError>,
}
pub(crate) type TurnSummaryStore = Arc<Mutex<HashMap<ThreadId, TurnSummary>>>;
use crate::thread_state::ThreadStateManager;
const THREAD_LIST_DEFAULT_LIMIT: usize = 25;
const THREAD_LIST_MAX_LIMIT: usize = 100;
@ -303,21 +289,16 @@ pub(crate) struct CodexMessageProcessor {
config: Arc<Config>,
cli_overrides: Vec<(String, TomlValue)>,
cloud_requirements: Arc<RwLock<CloudRequirementsLoader>>,
conversation_listeners: HashMap<Uuid, oneshot::Sender<()>>,
listener_thread_ids_by_subscription: HashMap<Uuid, ThreadId>,
active_login: Arc<Mutex<Option<ActiveLogin>>>,
// Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives.
pending_interrupts: PendingInterrupts,
// Queue of pending rollback requests per conversation. We reply when ThreadRollback arrives.
pending_rollbacks: PendingRollbacks,
turn_summary_store: TurnSummaryStore,
thread_state_manager: ThreadStateManager,
pending_fuzzy_searches: Arc<Mutex<HashMap<String, Arc<AtomicBool>>>>,
feedback: CodexFeedback,
}
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, Default)]
pub(crate) enum ApiVersion {
V1,
#[default]
V2,
}
@ -375,12 +356,8 @@ impl CodexMessageProcessor {
config,
cli_overrides,
cloud_requirements,
conversation_listeners: HashMap::new(),
listener_thread_ids_by_subscription: HashMap::new(),
active_login: Arc::new(Mutex::new(None)),
pending_interrupts: Arc::new(Mutex::new(HashMap::new())),
pending_rollbacks: Arc::new(Mutex::new(HashMap::new())),
turn_summary_store: Arc::new(Mutex::new(HashMap::new())),
thread_state_manager: ThreadStateManager::new(),
pending_fuzzy_searches: Arc::new(Mutex::new(HashMap::new())),
feedback,
}
@ -1012,7 +989,7 @@ impl CodexMessageProcessor {
cloud_requirements.as_ref(),
auth_manager.clone(),
chatgpt_base_url,
codex_home.clone(),
codex_home,
);
sync_default_client_residency_requirement(
&cli_overrides,
@ -1120,7 +1097,7 @@ impl CodexMessageProcessor {
cloud_requirements.as_ref(),
auth_manager.clone(),
chatgpt_base_url,
codex_home.clone(),
codex_home,
);
sync_default_client_residency_requirement(
&cli_overrides,
@ -2330,25 +2307,32 @@ impl CodexMessageProcessor {
let request = request_id.clone();
{
let mut map = self.pending_rollbacks.lock().await;
if map.contains_key(&thread_id) {
self.send_invalid_request_error(
request.clone(),
"rollback already in progress for this thread".to_string(),
)
.await;
return;
let rollback_already_in_progress = {
let thread_state = self.thread_state_manager.thread_state(thread_id);
let mut thread_state = thread_state.lock().await;
if thread_state.pending_rollbacks.is_some() {
true
} else {
thread_state.pending_rollbacks = Some(request.clone());
false
}
map.insert(thread_id, request.clone());
};
if rollback_already_in_progress {
self.send_invalid_request_error(
request.clone(),
"rollback already in progress for this thread".to_string(),
)
.await;
return;
}
if let Err(err) = thread.submit(Op::ThreadRollback { num_turns }).await {
// No ThreadRollback event will arrive if an error occurs.
// Clean up and reply immediately.
let mut map = self.pending_rollbacks.lock().await;
map.remove(&thread_id);
let thread_state = self.thread_state_manager.thread_state(thread_id);
let mut thread_state = thread_state.lock().await;
thread_state.pending_rollbacks = None;
drop(thread_state);
self.send_internal_error(request, format!("failed to start rollback: {err}"))
.await;
@ -2646,11 +2630,7 @@ impl CodexMessageProcessor {
/// Best-effort: attach a listener for thread_id if missing.
pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) {
if self
.listener_thread_ids_by_subscription
.values()
.any(|entry| *entry == thread_id)
{
if self.thread_state_manager.has_listener_for_thread(thread_id) {
return;
}
@ -4313,7 +4293,8 @@ impl CodexMessageProcessor {
let mut state_db_ctx = None;
// If the thread is active, request shutdown and wait briefly.
if let Some(conversation) = self.thread_manager.remove_thread(&thread_id).await {
let removed_conversation = self.thread_manager.remove_thread(&thread_id).await;
if let Some(conversation) = removed_conversation {
if let Some(ctx) = conversation.state_db() {
state_db_ctx = Some(ctx);
}
@ -4341,6 +4322,9 @@ impl CodexMessageProcessor {
error!("failed to submit Shutdown to thread {thread_id}: {err}");
}
}
self.thread_state_manager
.remove_thread_state(thread_id)
.await;
}
if state_db_ctx.is_none() {
@ -4880,9 +4864,10 @@ impl CodexMessageProcessor {
// Record the pending interrupt so we can reply when TurnAborted arrives.
{
let mut map = self.pending_interrupts.lock().await;
map.entry(conversation_id)
.or_default()
let pending_interrupts = self.thread_state_manager.thread_state(conversation_id);
let mut thread_state = pending_interrupts.lock().await;
thread_state
.pending_interrupts
.push((request, ApiVersion::V1));
}
@ -5276,9 +5261,10 @@ impl CodexMessageProcessor {
// Record the pending interrupt so we can reply when TurnAborted arrives.
{
let mut map = self.pending_interrupts.lock().await;
map.entry(thread_uuid)
.or_default()
let thread_state = self.thread_state_manager.thread_state(thread_uuid);
let mut thread_state = thread_state.lock().await;
thread_state
.pending_interrupts
.push((request, ApiVersion::V2));
}
@ -5315,16 +5301,13 @@ impl CodexMessageProcessor {
params: RemoveConversationListenerParams,
) {
let RemoveConversationListenerParams { subscription_id } = params;
match self.conversation_listeners.remove(&subscription_id) {
Some(sender) => {
// Signal the spawned task to exit and acknowledge.
let _ = sender.send(());
if let Some(thread_id) = self
.listener_thread_ids_by_subscription
.remove(&subscription_id)
{
info!("removed listener for thread {thread_id}");
}
match self
.thread_state_manager
.remove_listener(subscription_id)
.await
{
Some(thread_id) => {
info!("removed listener for thread {thread_id}");
let response = RemoveConversationSubscriptionResponse {};
self.outgoing.send_response(request_id, response).await;
}
@ -5342,7 +5325,7 @@ impl CodexMessageProcessor {
async fn attach_conversation_listener(
&mut self,
conversation_id: ThreadId,
experimental_raw_events: bool,
raw_events_enabled: bool,
api_version: ApiVersion,
) -> Result<Uuid, JSONRPCErrorError> {
let conversation = match self.thread_manager.get_thread(conversation_id).await {
@ -5358,16 +5341,11 @@ impl CodexMessageProcessor {
let subscription_id = Uuid::new_v4();
let (cancel_tx, mut cancel_rx) = oneshot::channel();
self.conversation_listeners
.insert(subscription_id, cancel_tx);
self.listener_thread_ids_by_subscription
.insert(subscription_id, conversation_id);
let thread_state = self
.thread_state_manager
.set_listener(subscription_id, conversation_id, cancel_tx)
.await;
let outgoing_for_task = self.outgoing.clone();
let pending_interrupts = self.pending_interrupts.clone();
let pending_rollbacks = self.pending_rollbacks.clone();
let turn_summary_store = self.turn_summary_store.clone();
let api_version_for_task = api_version;
let fallback_model_provider = self.config.model_provider_id.clone();
tokio::spawn(async move {
loop {
@ -5385,10 +5363,9 @@ impl CodexMessageProcessor {
}
};
if let EventMsg::RawResponseItem(_) = &event.msg
&& !experimental_raw_events {
continue;
}
if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled {
continue;
}
// For now, we send a notification for every event,
// JSON-serializing the `Event` as-is, but these should
@ -5427,10 +5404,8 @@ impl CodexMessageProcessor {
conversation_id,
conversation.clone(),
outgoing_for_task.clone(),
pending_interrupts.clone(),
pending_rollbacks.clone(),
turn_summary_store.clone(),
api_version_for_task,
thread_state.clone(),
api_version,
fallback_model_provider.clone(),
)
.await;
@ -5699,7 +5674,7 @@ fn replace_cloud_requirements_loader(
cloud_requirements: &RwLock<CloudRequirementsLoader>,
auth_manager: Arc<AuthManager>,
chatgpt_base_url: String,
codex_home: std::path::PathBuf,
codex_home: PathBuf,
) {
let loader = cloud_requirements_loader(auth_manager, chatgpt_base_url, codex_home);
if let Ok(mut guard) = cloud_requirements.write() {
@ -6316,4 +6291,30 @@ mod tests {
assert_eq!(summary, expected);
Ok(())
}
#[tokio::test]
async fn removing_one_listener_does_not_cancel_other_subscriptions_for_same_thread()
-> Result<()> {
let mut manager = ThreadStateManager::new();
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
let listener_a = Uuid::new_v4();
let listener_b = Uuid::new_v4();
let (cancel_a, cancel_rx_a) = oneshot::channel();
let (cancel_b, mut cancel_rx_b) = oneshot::channel();
manager.set_listener(listener_a, thread_id, cancel_a).await;
manager.set_listener(listener_b, thread_id, cancel_b).await;
assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id));
assert_eq!(cancel_rx_a.await, Ok(()));
assert!(
tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx_b)
.await
.is_err()
);
assert_eq!(manager.remove_listener(listener_b).await, Some(thread_id));
assert_eq!(cancel_rx_b.await, Ok(()));
Ok(())
}
}

View file

@ -62,6 +62,7 @@ mod fuzzy_file_search;
mod message_processor;
mod models;
mod outgoing_message;
mod thread_state;
mod transport;
pub use crate::transport::AppServerTransport;

View file

@ -0,0 +1,106 @@
use crate::outgoing_message::ConnectionRequestId;
use codex_app_server_protocol::TurnError;
use codex_protocol::ThreadId;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use uuid::Uuid;
// Interrupt requests queued while a turn is running, each paired with the API
// version it was issued under; they are replied to when TurnAborted arrives.
type PendingInterruptQueue = Vec<(
    ConnectionRequestId,
    crate::codex_message_processor::ApiVersion,
)>;
/// Per-conversation accumulation of the latest states e.g. error message while a turn runs.
#[derive(Default, Clone)]
pub(crate) struct TurnSummary {
    // Ids of FileChange items for which a "started" item has already been
    // emitted; checked before emitting deltas and cleared on completion.
    pub(crate) file_change_started: HashSet<String>,
    // Most recent error recorded during the turn, if any; reported when the
    // turn completes or is interrupted.
    pub(crate) last_error: Option<TurnError>,
}
// Mutable per-thread bookkeeping shared (behind a mutex) between the message
// processor and the event-handling task for one conversation thread.
#[derive(Default)]
pub(crate) struct ThreadState {
    // Interrupt requests awaiting a TurnAborted event before being answered.
    pub(crate) pending_interrupts: PendingInterruptQueue,
    // At most one in-flight rollback request; taken when ThreadRolledBack
    // (or a rollback failure) arrives.
    pub(crate) pending_rollbacks: Option<ConnectionRequestId>,
    // State accumulated for the currently running turn; reset between turns.
    pub(crate) turn_summary: TurnSummary,
    // Cancellation handles for active event listeners, keyed by subscription id.
    pub(crate) listener_cancel_txs: HashMap<Uuid, oneshot::Sender<()>>,
}
impl ThreadState {
    /// Register `cancel_tx` as the cancellation handle for `subscription_id`.
    /// If the subscription already had a handle, that older listener is told
    /// to shut down by firing its cancel channel.
    fn set_listener(&mut self, subscription_id: Uuid, cancel_tx: oneshot::Sender<()>) {
        let displaced = self.listener_cancel_txs.insert(subscription_id, cancel_tx);
        if let Some(previous_tx) = displaced {
            let _ = previous_tx.send(());
        }
    }

    /// Cancel and forget the listener registered under `subscription_id`,
    /// if one exists.
    fn clear_listener(&mut self, subscription_id: Uuid) {
        let removed = self.listener_cancel_txs.remove(&subscription_id);
        if let Some(tx) = removed {
            let _ = tx.send(());
        }
    }

    /// Cancel every registered listener and empty the registry.
    fn clear_listeners(&mut self) {
        self.listener_cancel_txs.drain().for_each(|(_, tx)| {
            let _ = tx.send(());
        });
    }
}
// Owns the per-thread ThreadState instances and the mapping from listener
// subscription ids to the thread they observe.
#[derive(Default)]
pub(crate) struct ThreadStateManager {
    // Shared, lock-protected state per conversation thread, created lazily.
    thread_states: HashMap<ThreadId, Arc<Mutex<ThreadState>>>,
    // Reverse index: listener subscription id -> thread it is attached to.
    thread_id_by_subscription: HashMap<Uuid, ThreadId>,
}
impl ThreadStateManager {
    /// Create an empty manager with no thread state and no subscriptions.
    pub(crate) fn new() -> Self {
        Self::default()
    }

    /// Returns true if at least one active subscription targets `thread_id`.
    pub(crate) fn has_listener_for_thread(&self, thread_id: ThreadId) -> bool {
        self.thread_id_by_subscription
            .values()
            .any(|existing| *existing == thread_id)
    }

    /// Fetch (or lazily create) the shared state handle for `thread_id`.
    pub(crate) fn thread_state(&mut self, thread_id: ThreadId) -> Arc<Mutex<ThreadState>> {
        // `or_default` is equivalent to the explicit
        // `or_insert_with(|| Arc::new(Mutex::new(ThreadState::default())))`
        // since `Arc<Mutex<ThreadState>>` derives its Default through the chain.
        self.thread_states.entry(thread_id).or_default().clone()
    }

    /// Cancel the listener registered under `subscription_id` and return the
    /// thread it was attached to, or `None` if the subscription is unknown.
    pub(crate) async fn remove_listener(&mut self, subscription_id: Uuid) -> Option<ThreadId> {
        let thread_id = self.thread_id_by_subscription.remove(&subscription_id)?;
        if let Some(thread_state) = self.thread_states.get(&thread_id) {
            thread_state.lock().await.clear_listener(subscription_id);
        }
        Some(thread_id)
    }

    /// Drop all state for `thread_id`: cancel every listener attached to it
    /// and forget its subscriptions.
    pub(crate) async fn remove_thread_state(&mut self, thread_id: ThreadId) {
        if let Some(thread_state) = self.thread_states.remove(&thread_id) {
            thread_state.lock().await.clear_listeners();
        }
        // Also purge the reverse subscription -> thread index.
        self.thread_id_by_subscription
            .retain(|_, existing_thread_id| *existing_thread_id != thread_id);
    }

    /// Register a listener for `thread_id` under `subscription_id` and return
    /// that thread's shared state. An existing listener with the same
    /// subscription id is cancelled first (see `ThreadState::set_listener`).
    pub(crate) async fn set_listener(
        &mut self,
        subscription_id: Uuid,
        thread_id: ThreadId,
        cancel_tx: oneshot::Sender<()>,
    ) -> Arc<Mutex<ThreadState>> {
        self.thread_id_by_subscription
            .insert(subscription_id, thread_id);
        let thread_state = self.thread_state(thread_id);
        thread_state
            .lock()
            .await
            .set_listener(subscription_id, cancel_tx);
        thread_state
    }
}