diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 4e3ec84e8..20099b02e 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1354,6 +1354,7 @@ dependencies = [ "strum_macros 0.27.2", "tempfile", "thiserror 2.0.18", + "tracing", "ts-rs", "uuid", ] diff --git a/codex-rs/app-server-protocol/Cargo.toml b/codex-rs/app-server-protocol/Cargo.toml index 2ab6f291f..f5df7af03 100644 --- a/codex-rs/app-server-protocol/Cargo.toml +++ b/codex-rs/app-server-protocol/Cargo.toml @@ -25,6 +25,7 @@ strum_macros = { workspace = true } thiserror = { workspace = true } ts-rs = { workspace = true } inventory = { workspace = true } +tracing = { workspace = true } uuid = { workspace = true, features = ["serde", "v7"] } [dev-dependencies] diff --git a/codex-rs/app-server-protocol/src/protocol/thread_history.rs b/codex-rs/app-server-protocol/src/protocol/thread_history.rs index e5005077e..52850a67e 100644 --- a/codex-rs/app-server-protocol/src/protocol/thread_history.rs +++ b/codex-rs/app-server-protocol/src/protocol/thread_history.rs @@ -24,9 +24,13 @@ use codex_protocol::protocol::CompactedItem; use codex_protocol::protocol::ContextCompactedEvent; use codex_protocol::protocol::ErrorEvent; use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::ExecCommandBeginEvent; use codex_protocol::protocol::ExecCommandEndEvent; use codex_protocol::protocol::ItemCompletedEvent; +use codex_protocol::protocol::ItemStartedEvent; +use codex_protocol::protocol::McpToolCallBeginEvent; use codex_protocol::protocol::McpToolCallEndEvent; +use codex_protocol::protocol::PatchApplyBeginEvent; use codex_protocol::protocol::PatchApplyEndEvent; use codex_protocol::protocol::ReviewOutputEvent; use codex_protocol::protocol::RolloutItem; @@ -36,8 +40,10 @@ use codex_protocol::protocol::TurnCompleteEvent; use codex_protocol::protocol::TurnStartedEvent; use codex_protocol::protocol::UserMessageEvent; use codex_protocol::protocol::ViewImageToolCallEvent; +use codex_protocol::protocol::WebSearchBeginEvent; use codex_protocol::protocol::WebSearchEndEvent; use std::collections::HashMap; +use tracing::warn; use uuid::Uuid; #[cfg(test)] @@ -57,14 +63,20 @@ pub fn build_turns_from_rollout_items(items: &[RolloutItem]) -> Vec { builder.finish() } -struct ThreadHistoryBuilder { +pub struct ThreadHistoryBuilder { turns: Vec, current_turn: Option, next_item_index: i64, } +impl Default for ThreadHistoryBuilder { + fn default() -> Self { + Self::new() + } +} + impl ThreadHistoryBuilder { - fn new() -> Self { + pub fn new() -> Self { Self { turns: Vec::new(), current_turn: None, @@ -72,14 +84,32 @@ impl ThreadHistoryBuilder { } } - fn finish(mut self) -> Vec { + pub fn reset(&mut self) { + *self = Self::new(); + } + + pub fn finish(mut self) -> Vec { self.finish_current_turn(); self.turns } + pub fn active_turn_snapshot(&self) -> Option { + self.current_turn + .as_ref() + .map(Turn::from) + .or_else(|| self.turns.last().cloned()) + } + + pub fn has_active_turn(&self) -> bool { + self.current_turn.is_some() + } + + /// Shared reducer for persisted rollout replay and in-memory current-turn + /// tracking used by running thread resume/rejoin. + /// /// This function should handle all EventMsg variants that can be persisted in a rollout file. /// See `should_persist_event_msg` in `codex-rs/core/rollout/policy.rs`. - fn handle_event(&mut self, event: &EventMsg) { + pub fn handle_event(&mut self, event: &EventMsg) { match event { EventMsg::UserMessage(payload) => self.handle_user_message(payload), EventMsg::AgentMessage(payload) => { @@ -89,21 +119,35 @@ impl ThreadHistoryBuilder { EventMsg::AgentReasoningRawContent(payload) => { self.handle_agent_reasoning_raw_content(payload) } + EventMsg::WebSearchBegin(payload) => self.handle_web_search_begin(payload), EventMsg::WebSearchEnd(payload) => self.handle_web_search_end(payload), + EventMsg::ExecCommandBegin(payload) => self.handle_exec_command_begin(payload), EventMsg::ExecCommandEnd(payload) => self.handle_exec_command_end(payload), + EventMsg::PatchApplyBegin(payload) => self.handle_patch_apply_begin(payload), EventMsg::PatchApplyEnd(payload) => self.handle_patch_apply_end(payload), + EventMsg::McpToolCallBegin(payload) => self.handle_mcp_tool_call_begin(payload), EventMsg::McpToolCallEnd(payload) => self.handle_mcp_tool_call_end(payload), EventMsg::ViewImageToolCall(payload) => self.handle_view_image_tool_call(payload), + EventMsg::CollabAgentSpawnBegin(payload) => { + self.handle_collab_agent_spawn_begin(payload) + } EventMsg::CollabAgentSpawnEnd(payload) => self.handle_collab_agent_spawn_end(payload), + EventMsg::CollabAgentInteractionBegin(payload) => { + self.handle_collab_agent_interaction_begin(payload) + } EventMsg::CollabAgentInteractionEnd(payload) => { self.handle_collab_agent_interaction_end(payload) } + EventMsg::CollabWaitingBegin(payload) => self.handle_collab_waiting_begin(payload), EventMsg::CollabWaitingEnd(payload) => self.handle_collab_waiting_end(payload), + EventMsg::CollabCloseBegin(payload) => self.handle_collab_close_begin(payload), EventMsg::CollabCloseEnd(payload) => self.handle_collab_close_end(payload), + EventMsg::CollabResumeBegin(payload) => self.handle_collab_resume_begin(payload), EventMsg::CollabResumeEnd(payload) => self.handle_collab_resume_end(payload), EventMsg::ContextCompacted(payload) => self.handle_context_compacted(payload), EventMsg::EnteredReviewMode(payload) => self.handle_entered_review_mode(payload), EventMsg::ExitedReviewMode(payload) => self.handle_exited_review_mode(payload), + EventMsg::ItemStarted(payload) => self.handle_item_started(payload), EventMsg::ItemCompleted(payload) => self.handle_item_completed(payload), EventMsg::Error(payload) => self.handle_error(payload), EventMsg::TokenCount(_) => {} @@ -116,7 +160,7 @@ impl ThreadHistoryBuilder { } } - fn handle_rollout_item(&mut self, item: &RolloutItem) { + pub fn handle_rollout_item(&mut self, item: &RolloutItem) { match item { RolloutItem::EventMsg(event) => self.handle_event(event), RolloutItem::Compacted(payload) => self.handle_compacted(payload), @@ -199,15 +243,51 @@ impl ThreadHistoryBuilder { }); } - fn handle_item_completed(&mut self, payload: &ItemCompletedEvent) { - if let codex_protocol::items::TurnItem::Plan(plan) = &payload.item - && plan.text.is_empty() - { - return; + fn handle_item_started(&mut self, payload: &ItemStartedEvent) { + match &payload.item { + codex_protocol::items::TurnItem::Plan(plan) => { + if plan.text.is_empty() { + return; + } + self.upsert_item_in_turn_id( + &payload.turn_id, + ThreadItem::from(payload.item.clone()), + ); + } + codex_protocol::items::TurnItem::UserMessage(_) + | codex_protocol::items::TurnItem::AgentMessage(_) + | codex_protocol::items::TurnItem::Reasoning(_) + | codex_protocol::items::TurnItem::WebSearch(_) + | codex_protocol::items::TurnItem::ContextCompaction(_) => {} } + } - let item = ThreadItem::from(payload.item.clone()); - self.ensure_turn().items.push(item); + fn handle_item_completed(&mut self, payload: &ItemCompletedEvent) { + match &payload.item { + codex_protocol::items::TurnItem::Plan(plan) => { + if plan.text.is_empty() { + return; + } + self.upsert_item_in_turn_id( + &payload.turn_id, + ThreadItem::from(payload.item.clone()), + ); + } + codex_protocol::items::TurnItem::UserMessage(_) + | codex_protocol::items::TurnItem::AgentMessage(_) + | codex_protocol::items::TurnItem::Reasoning(_) + | codex_protocol::items::TurnItem::WebSearch(_) + | codex_protocol::items::TurnItem::ContextCompaction(_) => {} + } + } + + fn handle_web_search_begin(&mut self, payload: &WebSearchBeginEvent) { + let item = ThreadItem::WebSearch { + id: payload.call_id.clone(), + query: String::new(), + action: None, + }; + self.upsert_item_in_current_turn(item); } fn handle_web_search_end(&mut self, payload: &WebSearchEndEvent) { @@ -216,7 +296,30 @@ impl ThreadHistoryBuilder { query: payload.query.clone(), action: Some(WebSearchAction::from(payload.action.clone())), }; - self.ensure_turn().items.push(item); + self.upsert_item_in_current_turn(item); + } + + fn handle_exec_command_begin(&mut self, payload: &ExecCommandBeginEvent) { + let command = shlex::try_join(payload.command.iter().map(String::as_str)) + .unwrap_or_else(|_| payload.command.join(" ")); + let command_actions = payload + .parsed_cmd + .iter() + .cloned() + .map(CommandAction::from) + .collect(); + let item = ThreadItem::CommandExecution { + id: payload.call_id.clone(), + command, + cwd: payload.cwd.clone(), + process_id: payload.process_id.clone(), + status: CommandExecutionStatus::InProgress, + command_actions, + aggregated_output: None, + exit_code: None, + duration_ms: None, + }; + self.upsert_item_in_turn_id(&payload.turn_id, item); } fn handle_exec_command_end(&mut self, payload: &ExecCommandEndEvent) { @@ -246,33 +349,25 @@ impl ThreadHistoryBuilder { exit_code: Some(payload.exit_code), duration_ms: Some(duration_ms), }; - // Command completions can arrive out of order. Unified exec may return // while a PTY is still running, then emit ExecCommandEnd later from a // background exit watcher when that process finally exits. By then, a // newer user turn may already have started. Route by event turn_id so // replay preserves the original turn association. - if let Some(turn) = self.current_turn.as_mut() - && turn.id == payload.turn_id - { - turn.items.push(item); - return; - } + self.upsert_item_in_turn_id(&payload.turn_id, item); + } - // If the originating turn is already finalized, append there instead - // of attaching to whichever turn is currently active during replay. - if let Some(turn) = self - .turns - .iter_mut() - .find(|turn| turn.id == payload.turn_id) - { - turn.items.push(item); - return; + fn handle_patch_apply_begin(&mut self, payload: &PatchApplyBeginEvent) { + let item = ThreadItem::FileChange { + id: payload.call_id.clone(), + changes: convert_patch_changes(&payload.changes), + status: PatchApplyStatus::InProgress, + }; + if payload.turn_id.is_empty() { + self.upsert_item_in_current_turn(item); + } else { + self.upsert_item_in_turn_id(&payload.turn_id, item); } - - // Backward-compatibility fallback for partial/legacy streams where the - // event turn_id does not match any known replay turn. - self.ensure_turn().items.push(item); } fn handle_patch_apply_end(&mut self, payload: &PatchApplyEndEvent) { @@ -282,7 +377,29 @@ impl ThreadHistoryBuilder { changes: convert_patch_changes(&payload.changes), status, }; - self.ensure_turn().items.push(item); + if payload.turn_id.is_empty() { + self.upsert_item_in_current_turn(item); + } else { + self.upsert_item_in_turn_id(&payload.turn_id, item); + } + } + + fn handle_mcp_tool_call_begin(&mut self, payload: &McpToolCallBeginEvent) { + let item = ThreadItem::McpToolCall { + id: payload.call_id.clone(), + server: payload.invocation.server.clone(), + tool: payload.invocation.tool.clone(), + status: McpToolCallStatus::InProgress, + arguments: payload + .invocation + .arguments + .clone() + .unwrap_or(serde_json::Value::Null), + result: None, + error: None, + duration_ms: None, + }; + self.upsert_item_in_current_turn(item); } fn handle_mcp_tool_call_end(&mut self, payload: &McpToolCallEndEvent) { @@ -321,7 +438,7 @@ impl ThreadHistoryBuilder { error, duration_ms, }; - self.ensure_turn().items.push(item); + self.upsert_item_in_current_turn(item); } fn handle_view_image_tool_call(&mut self, payload: &ViewImageToolCallEvent) { @@ -329,7 +446,23 @@ impl ThreadHistoryBuilder { id: payload.call_id.clone(), path: payload.path.to_string_lossy().into_owned(), }; - self.ensure_turn().items.push(item); + self.upsert_item_in_current_turn(item); + } + + fn handle_collab_agent_spawn_begin( + &mut self, + payload: &codex_protocol::protocol::CollabAgentSpawnBeginEvent, + ) { + let item = ThreadItem::CollabAgentToolCall { + id: payload.call_id.clone(), + tool: CollabAgentTool::SpawnAgent, + status: CollabAgentToolCallStatus::InProgress, + sender_thread_id: payload.sender_thread_id.to_string(), + receiver_thread_ids: Vec::new(), + prompt: Some(payload.prompt.clone()), + agents_states: HashMap::new(), + }; + self.upsert_item_in_current_turn(item); } fn handle_collab_agent_spawn_end( @@ -353,17 +486,31 @@ impl ThreadHistoryBuilder { } None => (Vec::new(), HashMap::new()), }; - self.ensure_turn() - .items - .push(ThreadItem::CollabAgentToolCall { - id: payload.call_id.clone(), - tool: CollabAgentTool::SpawnAgent, - status, - sender_thread_id: payload.sender_thread_id.to_string(), - receiver_thread_ids, - prompt: Some(payload.prompt.clone()), - agents_states, - }); + self.upsert_item_in_current_turn(ThreadItem::CollabAgentToolCall { + id: payload.call_id.clone(), + tool: CollabAgentTool::SpawnAgent, + status, + sender_thread_id: payload.sender_thread_id.to_string(), + receiver_thread_ids, + prompt: Some(payload.prompt.clone()), + agents_states, + }); + } + + fn handle_collab_agent_interaction_begin( + &mut self, + payload: &codex_protocol::protocol::CollabAgentInteractionBeginEvent, + ) { + let item = ThreadItem::CollabAgentToolCall { + id: payload.call_id.clone(), + tool: CollabAgentTool::SendInput, + status: CollabAgentToolCallStatus::InProgress, + sender_thread_id: payload.sender_thread_id.to_string(), + receiver_thread_ids: vec![payload.receiver_thread_id.to_string()], + prompt: Some(payload.prompt.clone()), + agents_states: HashMap::new(), + }; + self.upsert_item_in_current_turn(item); } fn handle_collab_agent_interaction_end( @@ -376,17 +523,35 @@ impl ThreadHistoryBuilder { }; let receiver_id = payload.receiver_thread_id.to_string(); let received_status = CollabAgentState::from(payload.status.clone()); - self.ensure_turn() - .items - .push(ThreadItem::CollabAgentToolCall { - id: payload.call_id.clone(), - tool: CollabAgentTool::SendInput, - status, - sender_thread_id: payload.sender_thread_id.to_string(), - receiver_thread_ids: vec![receiver_id.clone()], - prompt: Some(payload.prompt.clone()), - agents_states: [(receiver_id, received_status)].into_iter().collect(), - }); + self.upsert_item_in_current_turn(ThreadItem::CollabAgentToolCall { + id: payload.call_id.clone(), + tool: CollabAgentTool::SendInput, + status, + sender_thread_id: payload.sender_thread_id.to_string(), + receiver_thread_ids: vec![receiver_id.clone()], + prompt: Some(payload.prompt.clone()), + agents_states: [(receiver_id, received_status)].into_iter().collect(), + }); + } + + fn handle_collab_waiting_begin( + &mut self, + payload: &codex_protocol::protocol::CollabWaitingBeginEvent, + ) { + let item = ThreadItem::CollabAgentToolCall { + id: payload.call_id.clone(), + tool: CollabAgentTool::Wait, + status: CollabAgentToolCallStatus::InProgress, + sender_thread_id: payload.sender_thread_id.to_string(), + receiver_thread_ids: payload + .receiver_thread_ids + .iter() + .map(ToString::to_string) + .collect(), + prompt: None, + agents_states: HashMap::new(), + }; + self.upsert_item_in_current_turn(item); } fn handle_collab_waiting_end( @@ -410,17 +575,31 @@ impl ThreadHistoryBuilder { .iter() .map(|(id, status)| (id.to_string(), CollabAgentState::from(status.clone()))) .collect(); - self.ensure_turn() - .items - .push(ThreadItem::CollabAgentToolCall { - id: payload.call_id.clone(), - tool: CollabAgentTool::Wait, - status, - sender_thread_id: payload.sender_thread_id.to_string(), - receiver_thread_ids, - prompt: None, - agents_states, - }); + self.upsert_item_in_current_turn(ThreadItem::CollabAgentToolCall { + id: payload.call_id.clone(), + tool: CollabAgentTool::Wait, + status, + sender_thread_id: payload.sender_thread_id.to_string(), + receiver_thread_ids, + prompt: None, + agents_states, + }); + } + + fn handle_collab_close_begin( + &mut self, + payload: &codex_protocol::protocol::CollabCloseBeginEvent, + ) { + let item = ThreadItem::CollabAgentToolCall { + id: payload.call_id.clone(), + tool: CollabAgentTool::CloseAgent, + status: CollabAgentToolCallStatus::InProgress, + sender_thread_id: payload.sender_thread_id.to_string(), + receiver_thread_ids: vec![payload.receiver_thread_id.to_string()], + prompt: None, + agents_states: HashMap::new(), + }; + self.upsert_item_in_current_turn(item); } fn handle_collab_close_end(&mut self, payload: &codex_protocol::protocol::CollabCloseEndEvent) { @@ -435,17 +614,31 @@ impl ThreadHistoryBuilder { )] .into_iter() .collect(); - self.ensure_turn() - .items - .push(ThreadItem::CollabAgentToolCall { - id: payload.call_id.clone(), - tool: CollabAgentTool::CloseAgent, - status, - sender_thread_id: payload.sender_thread_id.to_string(), - receiver_thread_ids: vec![receiver_id], - prompt: None, - agents_states, - }); + self.upsert_item_in_current_turn(ThreadItem::CollabAgentToolCall { + id: payload.call_id.clone(), + tool: CollabAgentTool::CloseAgent, + status, + sender_thread_id: payload.sender_thread_id.to_string(), + receiver_thread_ids: vec![receiver_id], + prompt: None, + agents_states, + }); + } + + fn handle_collab_resume_begin( + &mut self, + payload: &codex_protocol::protocol::CollabResumeBeginEvent, + ) { + let item = ThreadItem::CollabAgentToolCall { + id: payload.call_id.clone(), + tool: CollabAgentTool::ResumeAgent, + status: CollabAgentToolCallStatus::InProgress, + sender_thread_id: payload.sender_thread_id.to_string(), + receiver_thread_ids: vec![payload.receiver_thread_id.to_string()], + prompt: None, + agents_states: HashMap::new(), + }; + self.upsert_item_in_current_turn(item); } fn handle_collab_resume_end( @@ -463,17 +656,15 @@ impl ThreadHistoryBuilder { )] .into_iter() .collect(); - self.ensure_turn() - .items - .push(ThreadItem::CollabAgentToolCall { - id: payload.call_id.clone(), - tool: CollabAgentTool::ResumeAgent, - status, - sender_thread_id: payload.sender_thread_id.to_string(), - receiver_thread_ids: vec![receiver_id], - prompt: None, - agents_states, - }); + self.upsert_item_in_current_turn(ThreadItem::CollabAgentToolCall { + id: payload.call_id.clone(), + tool: CollabAgentTool::ResumeAgent, + status, + sender_thread_id: payload.sender_thread_id.to_string(), + receiver_thread_ids: vec![receiver_id], + prompt: None, + agents_states, + }); } fn handle_context_compacted(&mut self, _payload: &ContextCompactedEvent) { @@ -548,6 +739,7 @@ impl ThreadHistoryBuilder { self.finish_current_turn(); self.current_turn = Some( self.new_turn(Some(payload.turn_id.clone())) + .with_status(TurnStatus::InProgress) .opened_explicitly(), ); } @@ -642,6 +834,30 @@ impl ThreadHistoryBuilder { unreachable!("current turn must exist after initialization"); } + fn upsert_item_in_turn_id(&mut self, turn_id: &str, item: ThreadItem) { + if let Some(turn) = self.current_turn.as_mut() + && turn.id == turn_id + { + upsert_turn_item(&mut turn.items, item); + return; + } + + if let Some(turn) = self.turns.iter_mut().find(|turn| turn.id == turn_id) { + upsert_turn_item(&mut turn.items, item); + return; + } + + warn!( + item_id = item.id(), + "dropping turn-scoped item for unknown turn id `{turn_id}`" + ); + } + + fn upsert_item_in_current_turn(&mut self, item: ThreadItem) { + let turn = self.ensure_turn(); + upsert_turn_item(&mut turn.items, item); + } + fn next_item_id(&mut self) -> String { let id = format!("item-{}", self.next_item_index); self.next_item_index += 1; @@ -684,7 +900,7 @@ fn render_review_output_text(output: &ReviewOutputEvent) -> String { } } -fn convert_patch_changes( +pub fn convert_patch_changes( changes: &HashMap, ) -> Vec { let mut converted: Vec = changes @@ -726,6 +942,17 @@ fn format_file_change_diff(change: &codex_protocol::protocol::FileChange) -> Str } } +fn upsert_turn_item(items: &mut Vec, item: ThreadItem) { + if let Some(existing_item) = items + .iter_mut() + .find(|existing_item| existing_item.id() == item.id()) + { + *existing_item = item; + return; + } + items.push(item); +} + struct PendingTurn { id: String, items: Vec, @@ -744,6 +971,11 @@ impl PendingTurn { self.opened_explicitly = true; self } + + fn with_status(mut self, status: TurnStatus) -> Self { + self.status = status; + self + } } impl From for Turn { @@ -757,10 +989,23 @@ impl From for Turn { } } +impl From<&PendingTurn> for Turn { + fn from(value: &PendingTurn) -> Self { + Self { + id: value.id.clone(), + items: value.items.clone(), + error: value.error.clone(), + status: value.status.clone(), + } + } +} + #[cfg(test)] mod tests { use super::*; use codex_protocol::ThreadId; + use codex_protocol::items::TurnItem as CoreTurnItem; + use codex_protocol::items::UserMessageItem as CoreUserMessageItem; use codex_protocol::models::MessagePhase as CoreMessagePhase; use codex_protocol::models::WebSearchAction as CoreWebSearchAction; use codex_protocol::parse_command::ParsedCommand; @@ -771,6 +1016,7 @@ mod tests { use codex_protocol::protocol::CompactedItem; use codex_protocol::protocol::ExecCommandEndEvent; use codex_protocol::protocol::ExecCommandSource; + use codex_protocol::protocol::ItemStartedEvent; use codex_protocol::protocol::McpInvocation; use codex_protocol::protocol::McpToolCallEndEvent; use codex_protocol::protocol::ThreadRolledBackEvent; @@ -816,11 +1062,11 @@ mod tests { }), ]; - let items = events - .into_iter() - .map(RolloutItem::EventMsg) - .collect::>(); - let turns = build_turns_from_rollout_items(&items); + let mut builder = ThreadHistoryBuilder::new(); + for event in &events { + builder.handle_event(event); + } + let turns = builder.finish(); assert_eq!(turns.len(), 2); let first = &turns[0]; @@ -883,6 +1129,55 @@ mod tests { ); } + #[test] + fn ignores_non_plan_item_lifecycle_events() { + let turn_id = "turn-1"; + let thread_id = ThreadId::new(); + let events = vec![ + EventMsg::TurnStarted(TurnStartedEvent { + turn_id: turn_id.to_string(), + model_context_window: None, + collaboration_mode_kind: Default::default(), + }), + EventMsg::UserMessage(UserMessageEvent { + message: "hello".into(), + images: None, + text_elements: Vec::new(), + local_images: Vec::new(), + }), + EventMsg::ItemStarted(ItemStartedEvent { + thread_id, + turn_id: turn_id.to_string(), + item: CoreTurnItem::UserMessage(CoreUserMessageItem { + id: "user-item-id".to_string(), + content: Vec::new(), + }), + }), + EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: turn_id.to_string(), + last_agent_message: None, + }), + ]; + + let items = events + .into_iter() + .map(RolloutItem::EventMsg) + .collect::>(); + let turns = build_turns_from_rollout_items(&items); + assert_eq!(turns.len(), 1); + assert_eq!(turns[0].items.len(), 1); + assert_eq!( + turns[0].items[0], + ThreadItem::UserMessage { + id: "item-1".into(), + content: vec![UserInput::Text { + text: "hello".into(), + text_elements: Vec::new(), + }], + } + ); + } + #[test] fn preserves_agent_message_phase_in_history() { let events = vec![EventMsg::AgentMessage(AgentMessageEvent { @@ -1212,6 +1507,11 @@ mod tests { #[test] fn reconstructs_tool_items_from_persisted_completion_events() { let events = vec![ + EventMsg::TurnStarted(TurnStartedEvent { + turn_id: "turn-1".into(), + model_context_window: None, + collaboration_mode_kind: Default::default(), + }), EventMsg::UserMessage(UserMessageEvent { message: "run tools".into(), images: None, @@ -1311,6 +1611,11 @@ mod tests { #[test] fn reconstructs_declined_exec_and_patch_items() { let events = vec![ + EventMsg::TurnStarted(TurnStartedEvent { + turn_id: "turn-1".into(), + model_context_window: None, + collaboration_mode_kind: Default::default(), + }), EventMsg::UserMessage(UserMessageEvent { message: "run tools".into(), images: None, @@ -1471,6 +1776,82 @@ mod tests { ); } + #[test] + fn drops_late_turn_scoped_item_for_unknown_turn_id() { + let events = vec![ + EventMsg::TurnStarted(TurnStartedEvent { + turn_id: "turn-a".into(), + model_context_window: None, + collaboration_mode_kind: Default::default(), + }), + EventMsg::UserMessage(UserMessageEvent { + message: "first".into(), + images: None, + text_elements: Vec::new(), + local_images: Vec::new(), + }), + EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: "turn-a".into(), + last_agent_message: None, + }), + EventMsg::TurnStarted(TurnStartedEvent { + turn_id: "turn-b".into(), + model_context_window: None, + collaboration_mode_kind: Default::default(), + }), + EventMsg::UserMessage(UserMessageEvent { + message: "second".into(), + images: None, + text_elements: Vec::new(), + local_images: Vec::new(), + }), + EventMsg::ExecCommandEnd(ExecCommandEndEvent { + call_id: "exec-unknown-turn".into(), + process_id: Some("pid-42".into()), + turn_id: "turn-missing".into(), + command: vec!["echo".into(), "done".into()], + cwd: PathBuf::from("/tmp"), + parsed_cmd: vec![ParsedCommand::Unknown { + cmd: "echo done".into(), + }], + source: ExecCommandSource::Agent, + interaction_input: None, + stdout: "done\n".into(), + stderr: String::new(), + aggregated_output: "done\n".into(), + exit_code: 0, + duration: Duration::from_millis(5), + formatted_output: "done\n".into(), + status: CoreExecCommandStatus::Completed, + }), + EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: "turn-b".into(), + last_agent_message: None, + }), + ]; + + let mut builder = ThreadHistoryBuilder::new(); + for event in &events { + builder.handle_event(event); + } + let turns = builder.finish(); + assert_eq!(turns.len(), 2); + assert_eq!(turns[0].id, "turn-a"); + assert_eq!(turns[1].id, "turn-b"); + assert_eq!(turns[0].items.len(), 1); + assert_eq!(turns[1].items.len(), 1); + assert_eq!( + turns[1].items[0], + ThreadItem::UserMessage { + id: "item-2".into(), + content: vec![UserInput::Text { + text: "second".into(), + text_elements: Vec::new(), + }], + } + ); + } + #[test] fn late_turn_complete_does_not_close_active_turn() { let events = vec![ @@ -1572,7 +1953,7 @@ mod tests { assert_eq!(turns.len(), 2); assert_eq!(turns[0].id, "turn-a"); assert_eq!(turns[1].id, "turn-b"); - assert_eq!(turns[1].status, TurnStatus::Completed); + assert_eq!(turns[1].status, TurnStatus::InProgress); assert_eq!(turns[1].items.len(), 2); } diff --git a/codex-rs/app-server-protocol/src/protocol/v2.rs b/codex-rs/app-server-protocol/src/protocol/v2.rs index 328fdc091..6a1959bb2 100644 --- a/codex-rs/app-server-protocol/src/protocol/v2.rs +++ b/codex-rs/app-server-protocol/src/protocol/v2.rs @@ -2727,6 +2727,26 @@ pub enum ThreadItem { ContextCompaction { id: String }, } +impl ThreadItem { + pub fn id(&self) -> &str { + match self { + ThreadItem::UserMessage { id, .. } + | ThreadItem::AgentMessage { id, .. } + | ThreadItem::Plan { id, .. } + | ThreadItem::Reasoning { id, .. } + | ThreadItem::CommandExecution { id, .. } + | ThreadItem::FileChange { id, .. } + | ThreadItem::McpToolCall { id, .. } + | ThreadItem::CollabAgentToolCall { id, .. } + | ThreadItem::WebSearch { id, .. } + | ThreadItem::ImageView { id, .. } + | ThreadItem::EnteredReviewMode { id, .. } + | ThreadItem::ExitedReviewMode { id, .. } + | ThreadItem::ContextCompaction { id, .. } => id, + } + } +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(tag = "type", rename_all = "camelCase")] #[ts(tag = "type", rename_all = "camelCase")] diff --git a/codex-rs/app-server/src/bespoke_event_handling.rs b/codex-rs/app-server/src/bespoke_event_handling.rs index b206eac35..2e87a1684 100644 --- a/codex-rs/app-server/src/bespoke_event_handling.rs +++ b/codex-rs/app-server/src/bespoke_event_handling.rs @@ -45,7 +45,6 @@ use codex_app_server_protocol::McpToolCallResult; use codex_app_server_protocol::McpToolCallStatus; use codex_app_server_protocol::ModelReroutedNotification; use codex_app_server_protocol::PatchApplyStatus; -use codex_app_server_protocol::PatchChangeKind as V2PatchChangeKind; use codex_app_server_protocol::PlanDeltaNotification; use codex_app_server_protocol::RawResponseItemCompletedNotification; use codex_app_server_protocol::ReasoningSummaryPartAddedNotification; @@ -72,6 +71,7 @@ use codex_app_server_protocol::TurnPlanStep; use codex_app_server_protocol::TurnPlanUpdatedNotification; use codex_app_server_protocol::TurnStatus; use codex_app_server_protocol::build_turns_from_rollout_items; +use codex_app_server_protocol::convert_patch_changes; use codex_core::CodexThread; use codex_core::ThreadManager; use codex_core::parse_command::shlex_join; @@ -81,7 +81,6 @@ use codex_core::protocol::Event; use codex_core::protocol::EventMsg; use codex_core::protocol::ExecApprovalRequestEvent; use codex_core::protocol::ExecCommandEndEvent; -use codex_core::protocol::FileChange as CoreFileChange; use codex_core::protocol::McpToolCallBeginEvent; use codex_core::protocol::McpToolCallEndEvent; use codex_core::protocol::Op; @@ -1706,46 +1705,6 @@ fn render_review_output_text(output: &ReviewOutputEvent) -> String { } } -fn convert_patch_changes(changes: &HashMap) -> Vec { - let mut converted: Vec = changes - .iter() - .map(|(path, change)| FileUpdateChange { - path: path.to_string_lossy().into_owned(), - kind: map_patch_change_kind(change), - diff: format_file_change_diff(change), - }) - .collect(); - converted.sort_by(|a, b| a.path.cmp(&b.path)); - converted -} - -fn map_patch_change_kind(change: &CoreFileChange) -> V2PatchChangeKind { - match change { - CoreFileChange::Add { .. } => V2PatchChangeKind::Add, - CoreFileChange::Delete { .. } => V2PatchChangeKind::Delete, - CoreFileChange::Update { move_path, .. } => V2PatchChangeKind::Update { - move_path: move_path.clone(), - }, - } -} - -fn format_file_change_diff(change: &CoreFileChange) -> String { - match change { - CoreFileChange::Add { content } => content.clone(), - CoreFileChange::Delete { content } => content.clone(), - CoreFileChange::Update { - unified_diff, - move_path, - } => { - if let Some(path) = move_path { - format!("{unified_diff}\n\nMoved to: {}", path.display()) - } else { - unified_diff.clone() - } - } - } -} - fn map_file_change_approval_decision( decision: FileChangeApprovalDecision, ) -> (ReviewDecision, Option) { diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index 58ca63306..ea2701d63 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -336,6 +336,7 @@ pub(crate) struct CodexMessageProcessor { outgoing: Arc, codex_linux_sandbox_exe: Option, config: Arc, + single_client_mode: bool, cli_overrides: Vec<(String, TomlValue)>, cloud_requirements: Arc>, active_login: Arc>>, @@ -361,6 +362,7 @@ pub(crate) struct CodexMessageProcessorArgs { pub(crate) config: Arc, pub(crate) cli_overrides: Vec<(String, TomlValue)>, pub(crate) cloud_requirements: Arc>, + pub(crate) single_client_mode: bool, pub(crate) feedback: CodexFeedback, } @@ -397,6 +399,7 @@ impl CodexMessageProcessor { config, cli_overrides, cloud_requirements, + single_client_mode, feedback, } = args; Self { @@ -405,6 +408,7 @@ impl CodexMessageProcessor { outgoing: outgoing.clone(), codex_linux_sandbox_exe, config, + single_client_mode, cli_overrides, cloud_requirements, active_login: Arc::new(Mutex::new(None)), @@ -3042,21 +3046,14 @@ impl CodexMessageProcessor { return true; } - if let Err(err) = self - .ensure_conversation_listener( - existing_thread_id, - request_id.connection_id, - false, - ApiVersion::V2, - ) - .await - { - tracing::warn!( - "failed to attach listener for thread {}: {}", - existing_thread_id, - err.message - ); - } + let thread_state = self.thread_state_manager.thread_state(existing_thread_id); + self.ensure_listener_task_running( + existing_thread_id, + existing_thread.clone(), + thread_state.clone(), + ApiVersion::V2, + ) + .await; let config_snapshot = existing_thread.config_snapshot().await; let mismatch_details = collect_resume_override_mismatches(params, &config_snapshot); @@ -3068,41 +3065,39 @@ impl CodexMessageProcessor { ); } - let Some(mut thread) = self - .load_thread_from_rollout_or_send_internal( - request_id.clone(), - existing_thread_id, - rollout_path.as_path(), - config_snapshot.model_provider_id.as_str(), - ) - .await - else { + let listener_command_tx = { + let thread_state = thread_state.lock().await; + thread_state.listener_command_tx() + }; + let Some(listener_command_tx) = listener_command_tx else { + let err = JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: format!( + "failed to enqueue running thread resume for thread {existing_thread_id}: thread listener is not running" + ), + data: None, + }; + self.outgoing.send_error(request_id, err).await; return true; }; - let ThreadConfigSnapshot { - model, - model_provider_id, - approval_policy, - sandbox_policy, - cwd, - reasoning_effort, - .. - } = config_snapshot; - thread.status = self - .thread_watch_manager - .loaded_status_for_thread(&thread.id) - .await; - let response = ThreadResumeResponse { - thread, - model, - model_provider: model_provider_id, - cwd, - approval_policy: approval_policy.into(), - sandbox: sandbox_policy.into(), - reasoning_effort, - }; - self.outgoing.send_response(request_id, response).await; + let command = crate::thread_state::ThreadListenerCommand::SendThreadResumeResponse( + crate::thread_state::PendingThreadResumeRequest { + request_id: request_id.clone(), + rollout_path, + config_snapshot, + }, + ); + if listener_command_tx.send(command).is_err() { + let err = JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: format!( + "failed to enqueue running thread resume for thread {existing_thread_id}: thread listener command channel is closed" + ), + data: None, + }; + self.outgoing.send_error(request_id, err).await; + } return true; } false @@ -5817,17 +5812,18 @@ impl CodexMessageProcessor { api_version: ApiVersion, ) { let (cancel_tx, mut cancel_rx) = oneshot::channel(); - { + let mut listener_command_rx = { let mut thread_state = thread_state.lock().await; if thread_state.listener_matches(&conversation) { return; } - thread_state.set_listener(cancel_tx, &conversation); - } + thread_state.set_listener(cancel_tx, &conversation) + }; let outgoing_for_task = self.outgoing.clone(); let thread_manager = self.thread_manager.clone(); let thread_watch_manager = self.thread_watch_manager.clone(); let fallback_model_provider = self.config.model_provider_id.clone(); + let single_client_mode = self.single_client_mode; tokio::spawn(async move { loop { tokio::select! { @@ -5869,7 +5865,10 @@ impl CodexMessageProcessor { conversation_id.to_string().into(), ); let (subscribed_connection_ids, raw_events_enabled) = { - let thread_state = thread_state.lock().await; + let mut thread_state = thread_state.lock().await; + if !single_client_mode { + thread_state.track_current_turn_event(&event.msg); + } ( thread_state.subscribed_connection_ids(), thread_state.experimental_raw_events, @@ -5908,6 +5907,25 @@ impl CodexMessageProcessor { ) .await; } + listener_command = listener_command_rx.recv() => { + let Some(listener_command) = listener_command else { + break; + }; + match listener_command { + crate::thread_state::ThreadListenerCommand::SendThreadResumeResponse( + resume_request, + ) => { + handle_pending_thread_resume_request( + conversation_id, + &thread_state, + &thread_watch_manager, + &outgoing_for_task, + resume_request, + ) + .await; + } + } + } } } }); @@ -6206,6 +6224,106 @@ impl CodexMessageProcessor { } } +async fn handle_pending_thread_resume_request( + conversation_id: ThreadId, + thread_state: &Arc>, + thread_watch_manager: &ThreadWatchManager, + outgoing: &Arc, + pending: crate::thread_state::PendingThreadResumeRequest, +) { + let active_turn = { + let state = thread_state.lock().await; + state.active_turn_snapshot() + }; + + let request_id = pending.request_id; + let connection_id = request_id.connection_id; + let mut thread = match load_thread_for_running_resume_response( + conversation_id, + pending.rollout_path.as_path(), + pending.config_snapshot.model_provider_id.as_str(), + active_turn.as_ref(), + ) + .await + { + Ok(thread) => thread, + Err(message) => { + outgoing + .send_error( + request_id, + JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message, + data: None, + }, + ) + .await; + return; + } + }; + thread.status = thread_watch_manager + .loaded_status_for_thread(&thread.id) + .await; + + let ThreadConfigSnapshot { + model, + model_provider_id, + approval_policy, + sandbox_policy, + cwd, + reasoning_effort, + .. + } = pending.config_snapshot; + let response = ThreadResumeResponse { + thread, + model, + model_provider: model_provider_id, + cwd, + approval_policy: approval_policy.into(), + sandbox: sandbox_policy.into(), + reasoning_effort, + }; + outgoing.send_response(request_id, response).await; + thread_state.lock().await.add_connection(connection_id); +} + +async fn load_thread_for_running_resume_response( + conversation_id: ThreadId, + rollout_path: &Path, + fallback_provider: &str, + active_turn: Option<&Turn>, +) -> std::result::Result { + let mut thread = read_summary_from_rollout(rollout_path, fallback_provider) + .await + .map(summary_to_thread) + .map_err(|err| { + format!( + "failed to load rollout `{}` for thread {conversation_id}: {err}", + rollout_path.display() + ) + })?; + + let mut turns = read_rollout_items_from_rollout(rollout_path) + .await + .map(|items| build_turns_from_rollout_items(&items)) + .map_err(|err| { + format!( + "failed to load rollout `{}` for thread {conversation_id}: {err}", + rollout_path.display() + ) + })?; + if let Some(active_turn) = active_turn { + merge_turn_history_with_active_turn(&mut turns, active_turn.clone()); + } + thread.turns = turns; + Ok(thread) +} + +fn merge_turn_history_with_active_turn(turns: &mut Vec, active_turn: Turn) { + turns.retain(|turn| turn.id != active_turn.id); + turns.push(active_turn); +} + fn collect_resume_override_mismatches( request: &ThreadResumeParams, config_snapshot: &ThreadConfigSnapshot, diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 7752ee92c..8e7e1455c 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -238,7 +238,8 @@ pub async fn run_main_with_transport( Some(start_websocket_acceptor(bind_address, transport_event_tx.clone()).await?); } } - let shutdown_when_no_connections = matches!(transport, AppServerTransport::Stdio); + let single_client_mode = matches!(transport, AppServerTransport::Stdio); + let shutdown_when_no_connections = single_client_mode; // Parse CLI overrides once and derive the base Config eagerly so later // components do not need to work with raw TOML values. @@ -439,6 +440,7 @@ pub async fn run_main_with_transport( outgoing: outgoing_message_sender, codex_linux_sandbox_exe, config: Arc::new(config), + single_client_mode, cli_overrides, loader_overrides, cloud_requirements: cloud_requirements.clone(), diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index f1c6029a2..e80e1790f 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -139,6 +139,7 @@ pub(crate) struct MessageProcessorArgs { pub(crate) outgoing: Arc, pub(crate) codex_linux_sandbox_exe: Option, pub(crate) config: Arc, + pub(crate) single_client_mode: bool, pub(crate) cli_overrides: Vec<(String, TomlValue)>, pub(crate) loader_overrides: LoaderOverrides, pub(crate) cloud_requirements: CloudRequirementsLoader, @@ -154,6 +155,7 @@ impl MessageProcessor { outgoing, codex_linux_sandbox_exe, config, + single_client_mode, cli_overrides, loader_overrides, cloud_requirements, @@ -184,6 +186,7 @@ impl MessageProcessor { config: Arc::clone(&config), cli_overrides: cli_overrides.clone(), cloud_requirements: cloud_requirements.clone(), + single_client_mode, feedback, }); let config_api = ConfigApi::new( diff --git a/codex-rs/app-server/src/thread_state.rs b/codex-rs/app-server/src/thread_state.rs index fd0f58abd..264c9a39d 100644 --- a/codex-rs/app-server/src/thread_state.rs +++ b/codex-rs/app-server/src/thread_state.rs @@ -1,13 +1,19 @@ use crate::outgoing_message::ConnectionId; use crate::outgoing_message::ConnectionRequestId; +use codex_app_server_protocol::ThreadHistoryBuilder; +use codex_app_server_protocol::Turn; use codex_app_server_protocol::TurnError; use codex_core::CodexThread; +use codex_core::ThreadConfigSnapshot; use codex_protocol::ThreadId; +use codex_protocol::protocol::EventMsg; use std::collections::HashMap; use std::collections::HashSet; +use std::path::PathBuf; use std::sync::Arc; use std::sync::Weak; use tokio::sync::Mutex; +use tokio::sync::mpsc; use tokio::sync::oneshot; use uuid::Uuid; @@ -16,6 +22,16 @@ type PendingInterruptQueue = Vec<( crate::codex_message_processor::ApiVersion, )>; +pub(crate) struct PendingThreadResumeRequest { + pub(crate) request_id: ConnectionRequestId, + pub(crate) rollout_path: PathBuf, + pub(crate) config_snapshot: ThreadConfigSnapshot, +} + +pub(crate) enum ThreadListenerCommand { + SendThreadResumeResponse(PendingThreadResumeRequest), +} + /// Per-conversation accumulation of the latest states e.g. error message while a turn runs. #[derive(Default, Clone)] pub(crate) struct TurnSummary { @@ -31,6 +47,8 @@ pub(crate) struct ThreadState { pub(crate) turn_summary: TurnSummary, pub(crate) cancel_tx: Option>, pub(crate) experimental_raw_events: bool, + listener_command_tx: Option>, + current_turn_history: ThreadHistoryBuilder, listener_thread: Option>, subscribed_connections: HashSet, } @@ -47,17 +65,22 @@ impl ThreadState { &mut self, cancel_tx: oneshot::Sender<()>, conversation: &Arc, - ) { + ) -> mpsc::UnboundedReceiver { if let Some(previous) = self.cancel_tx.replace(cancel_tx) { let _ = previous.send(()); } + let (listener_command_tx, listener_command_rx) = mpsc::unbounded_channel(); + self.listener_command_tx = Some(listener_command_tx); self.listener_thread = Some(Arc::downgrade(conversation)); + listener_command_rx } pub(crate) fn clear_listener(&mut self) { if let Some(cancel_tx) = self.cancel_tx.take() { let _ = cancel_tx.send(()); } + self.listener_command_tx = None; + self.current_turn_history.reset(); self.listener_thread = None; } @@ -76,6 +99,23 @@ impl ThreadState { pub(crate) fn set_experimental_raw_events(&mut self, enabled: bool) { self.experimental_raw_events = enabled; } + + pub(crate) fn listener_command_tx( + &self, + ) -> Option> { + self.listener_command_tx.clone() + } + + pub(crate) fn active_turn_snapshot(&self) -> Option { + self.current_turn_history.active_turn_snapshot() + } + + pub(crate) fn track_current_turn_event(&mut self, event: &EventMsg) { + self.current_turn_history.handle_event(event); + if !self.current_turn_history.has_active_turn() { + self.current_turn_history.reset(); + } + } } #[derive(Clone, Copy)] @@ -200,12 +240,24 @@ impl ThreadStateManager { } pub(crate) async fn remove_connection(&mut self, connection_id: ConnectionId) { - let Some(thread_ids) = self.thread_ids_by_connection.remove(&connection_id) else { - return; - }; + let thread_ids = self + .thread_ids_by_connection + .remove(&connection_id) + .unwrap_or_default(); self.subscription_state_by_id .retain(|_, state| state.connection_id != connection_id); + if thread_ids.is_empty() { + for thread_state in self.thread_states.values() { + let mut thread_state = thread_state.lock().await; + thread_state.remove_connection(connection_id); + if thread_state.subscribed_connection_ids().is_empty() { + thread_state.clear_listener(); + } + } + return; + } + for thread_id in thread_ids { if let Some(thread_state) = self.thread_states.get(&thread_id) { let mut thread_state = thread_state.lock().await;