diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index bc37e4d5e..035a8f4be 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -2065,17 +2065,9 @@ impl Session { } InitialHistory::Resumed(resumed_history) => { let rollout_items = resumed_history.history; - - let reconstructed_rollout = self - .reconstruct_history_from_rollout(&turn_context, &rollout_items) + let previous_turn_settings = self + .apply_rollout_reconstruction(&turn_context, &rollout_items) .await; - let previous_turn_settings = reconstructed_rollout.previous_turn_settings.clone(); - self.set_previous_turn_settings(previous_turn_settings.clone()) - .await; - { - let mut state = self.state.lock().await; - state.set_reference_context_item(reconstructed_rollout.reference_context_item); - } // If resuming, warn when the last recorded model differs from the current one. let curr: &str = turn_context.model_info.slug.as_str(); @@ -2097,13 +2089,6 @@ impl Session { .await; } - // Always add response items to conversation history - let reconstructed_history = reconstructed_rollout.history; - if !reconstructed_history.is_empty() { - self.record_into_history(&reconstructed_history, &turn_context) - .await; - } - // Seed usage info from the recorded rollout so UIs can show token counts // immediately on resume/fork. if let Some(info) = Self::last_token_info_from_rollout(&rollout_items) { @@ -2118,26 +2103,8 @@ impl Session { } } InitialHistory::Forked(rollout_items) => { - let reconstructed_rollout = self - .reconstruct_history_from_rollout(&turn_context, &rollout_items) + self.apply_rollout_reconstruction(&turn_context, &rollout_items) .await; - self.set_previous_turn_settings( - reconstructed_rollout.previous_turn_settings.clone(), - ) - .await; - { - let mut state = self.state.lock().await; - state.set_reference_context_item( - reconstructed_rollout.reference_context_item.clone(), - ); - } - - // Always add response items to conversation history - let reconstructed_history = reconstructed_rollout.history; - if !reconstructed_history.is_empty() { - self.record_into_history(&reconstructed_history, &turn_context) - .await; - } // Seed usage info from the recorded rollout so UIs can show token counts // immediately on resume/fork. @@ -2171,6 +2138,25 @@ impl Session { } } + async fn apply_rollout_reconstruction( + &self, + turn_context: &TurnContext, + rollout_items: &[RolloutItem], + ) -> Option { + let reconstructed_rollout = self + .reconstruct_history_from_rollout(turn_context, rollout_items) + .await; + let previous_turn_settings = reconstructed_rollout.previous_turn_settings.clone(); + self.replace_history( + reconstructed_rollout.history, + reconstructed_rollout.reference_context_item, + ) + .await; + self.set_previous_turn_settings(previous_turn_settings.clone()) + .await; + previous_turn_settings + } + fn last_token_info_from_rollout(rollout_items: &[RolloutItem]) -> Option { rollout_items.iter().rev().find_map(|item| match item { RolloutItem::EventMsg(EventMsg::TokenCount(ev)) => ev.info.clone(), @@ -2613,31 +2599,17 @@ impl Session { } pub(crate) async fn send_event_raw(&self, event: Event) { - // Record the last known agent status. - if let Some(status) = agent_status_from_event(&event.msg) { - self.agent_status.send_replace(status); - } // Persist the event into rollout (recorder filters as needed) let rollout_items = vec![RolloutItem::EventMsg(event.msg.clone())]; self.persist_rollout_items(&rollout_items).await; - if let Err(e) = self.tx_event.send(event).await { - debug!("dropping event because channel is closed: {e}"); - } + self.deliver_event_raw(event).await; } - /// Persist the event to the rollout file, flush it, and only then deliver it to clients. - /// - /// Most events can be delivered immediately after queueing the rollout write, but some - /// clients (e.g. app-server thread/rollback) re-read the rollout file synchronously on - /// receipt of the event and depend on the marker already being visible on disk. - pub(crate) async fn send_event_raw_flushed(&self, event: Event) { + async fn deliver_event_raw(&self, event: Event) { // Record the last known agent status. if let Some(status) = agent_status_from_event(&event.msg) { self.agent_status.send_replace(status); } - self.persist_rollout_items(&[RolloutItem::EventMsg(event.msg.clone())]) - .await; - self.flush_rollout().await; if let Err(e) = self.tx_event.send(event).await { debug!("dropping event because channel is closed: {e}"); } @@ -5070,29 +5042,22 @@ mod handlers { }; let rollback_event = ThreadRolledBackEvent { num_turns }; + let rollback_msg = EventMsg::ThreadRolledBack(rollback_event.clone()); let replay_items = initial_history .get_rollout_items() .into_iter() - .chain(std::iter::once(RolloutItem::EventMsg( - EventMsg::ThreadRolledBack(rollback_event.clone()), - ))) + .chain(std::iter::once(RolloutItem::EventMsg(rollback_msg.clone()))) .collect::>(); - - let reconstructed = sess - .reconstruct_history_from_rollout(turn_context.as_ref(), replay_items.as_slice()) + sess.persist_rollout_items(&[RolloutItem::EventMsg(rollback_msg.clone())]) .await; - sess.replace_history( - reconstructed.history, - reconstructed.reference_context_item.clone(), - ) - .await; - sess.set_previous_turn_settings(reconstructed.previous_turn_settings) + sess.flush_rollout().await; + sess.apply_rollout_reconstruction(turn_context.as_ref(), replay_items.as_slice()) .await; sess.recompute_token_usage(turn_context.as_ref()).await; - sess.send_event_raw_flushed(Event { + sess.deliver_event_raw(Event { id: turn_context.sub_id.clone(), - msg: EventMsg::ThreadRolledBack(rollback_event), + msg: rollback_msg, }) .await; } diff --git a/codex-rs/core/src/codex_tests.rs b/codex-rs/core/src/codex_tests.rs index ed5d5790b..a06f6a94e 100644 --- a/codex-rs/core/src/codex_tests.rs +++ b/codex-rs/core/src/codex_tests.rs @@ -1204,6 +1204,99 @@ async fn thread_rollback_recomputes_previous_turn_settings_and_reference_context ); } +#[tokio::test] +async fn thread_rollback_restores_cleared_reference_context_item_after_compaction() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + attach_rollout_recorder(&sess).await; + + let first_context_item = tc.to_turn_context_item(); + let first_turn_id = first_context_item + .turn_id + .clone() + .expect("turn context should have turn_id"); + let compact_turn_id = "compact-turn".to_string(); + let rolled_back_turn_id = "rolled-back-turn".to_string(); + let compacted_history = vec![ + user_message("turn 1 user"), + user_message("summary after compaction"), + ]; + + sess.persist_rollout_items(&[ + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: first_turn_id.clone(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { + message: "turn 1 user".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + })), + RolloutItem::TurnContext(first_context_item.clone()), + RolloutItem::ResponseItem(user_message("turn 1 user")), + RolloutItem::ResponseItem(assistant_message("turn 1 assistant")), + RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: first_turn_id, + last_agent_message: None, + })), + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: compact_turn_id.clone(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::Compacted(CompactedItem { + message: "summary after compaction".to_string(), + replacement_history: Some(compacted_history.clone()), + }), + RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: compact_turn_id, + last_agent_message: None, + })), + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: rolled_back_turn_id.clone(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { + message: "turn 2 user".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + })), + RolloutItem::TurnContext(TurnContextItem { + turn_id: Some(rolled_back_turn_id.clone()), + model: "rolled-back-model".to_string(), + ..first_context_item.clone() + }), + RolloutItem::ResponseItem(user_message("turn 2 user")), + RolloutItem::ResponseItem(assistant_message("turn 2 assistant")), + RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: rolled_back_turn_id, + last_agent_message: None, + })), + ]) + .await; + sess.replace_history( + vec![assistant_message("stale history")], + Some(first_context_item), + ) + .await; + + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + let rollback_event = wait_for_thread_rolled_back(&rx).await; + assert_eq!(rollback_event.num_turns, 1); + + assert_eq!(sess.clone_history().await.raw_items(), compacted_history); + assert!(sess.reference_context_item().await.is_none()); +} + #[tokio::test] async fn thread_rollback_persists_marker_and_replays_cumulatively() { let (sess, tc, rx) = make_session_and_context_with_rx().await;