Fix turn context reconstruction after backtracking (#14616)
## Summary - reuse rollout reconstruction when applying a backtrack rollback so `reference_context_item` is restored from persisted rollout state - build rollback replay from the flushed rollout items plus the rollback marker, avoiding the extra reread/fallback path - add regression coverage for rollback after compaction so turn-context diffing stays aligned after backtracking Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
parent
69c8a1ef9e
commit
bbd329a812
2 changed files with 124 additions and 66 deletions
|
|
@ -2065,17 +2065,9 @@ impl Session {
|
|||
}
|
||||
InitialHistory::Resumed(resumed_history) => {
|
||||
let rollout_items = resumed_history.history;
|
||||
|
||||
let reconstructed_rollout = self
|
||||
.reconstruct_history_from_rollout(&turn_context, &rollout_items)
|
||||
let previous_turn_settings = self
|
||||
.apply_rollout_reconstruction(&turn_context, &rollout_items)
|
||||
.await;
|
||||
let previous_turn_settings = reconstructed_rollout.previous_turn_settings.clone();
|
||||
self.set_previous_turn_settings(previous_turn_settings.clone())
|
||||
.await;
|
||||
{
|
||||
let mut state = self.state.lock().await;
|
||||
state.set_reference_context_item(reconstructed_rollout.reference_context_item);
|
||||
}
|
||||
|
||||
// If resuming, warn when the last recorded model differs from the current one.
|
||||
let curr: &str = turn_context.model_info.slug.as_str();
|
||||
|
|
@ -2097,13 +2089,6 @@ impl Session {
|
|||
.await;
|
||||
}
|
||||
|
||||
// Always add response items to conversation history
|
||||
let reconstructed_history = reconstructed_rollout.history;
|
||||
if !reconstructed_history.is_empty() {
|
||||
self.record_into_history(&reconstructed_history, &turn_context)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Seed usage info from the recorded rollout so UIs can show token counts
|
||||
// immediately on resume/fork.
|
||||
if let Some(info) = Self::last_token_info_from_rollout(&rollout_items) {
|
||||
|
|
@ -2118,26 +2103,8 @@ impl Session {
|
|||
}
|
||||
}
|
||||
InitialHistory::Forked(rollout_items) => {
|
||||
let reconstructed_rollout = self
|
||||
.reconstruct_history_from_rollout(&turn_context, &rollout_items)
|
||||
self.apply_rollout_reconstruction(&turn_context, &rollout_items)
|
||||
.await;
|
||||
self.set_previous_turn_settings(
|
||||
reconstructed_rollout.previous_turn_settings.clone(),
|
||||
)
|
||||
.await;
|
||||
{
|
||||
let mut state = self.state.lock().await;
|
||||
state.set_reference_context_item(
|
||||
reconstructed_rollout.reference_context_item.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
// Always add response items to conversation history
|
||||
let reconstructed_history = reconstructed_rollout.history;
|
||||
if !reconstructed_history.is_empty() {
|
||||
self.record_into_history(&reconstructed_history, &turn_context)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Seed usage info from the recorded rollout so UIs can show token counts
|
||||
// immediately on resume/fork.
|
||||
|
|
@ -2171,6 +2138,25 @@ impl Session {
|
|||
}
|
||||
}
|
||||
|
||||
async fn apply_rollout_reconstruction(
|
||||
&self,
|
||||
turn_context: &TurnContext,
|
||||
rollout_items: &[RolloutItem],
|
||||
) -> Option<PreviousTurnSettings> {
|
||||
let reconstructed_rollout = self
|
||||
.reconstruct_history_from_rollout(turn_context, rollout_items)
|
||||
.await;
|
||||
let previous_turn_settings = reconstructed_rollout.previous_turn_settings.clone();
|
||||
self.replace_history(
|
||||
reconstructed_rollout.history,
|
||||
reconstructed_rollout.reference_context_item,
|
||||
)
|
||||
.await;
|
||||
self.set_previous_turn_settings(previous_turn_settings.clone())
|
||||
.await;
|
||||
previous_turn_settings
|
||||
}
|
||||
|
||||
fn last_token_info_from_rollout(rollout_items: &[RolloutItem]) -> Option<TokenUsageInfo> {
|
||||
rollout_items.iter().rev().find_map(|item| match item {
|
||||
RolloutItem::EventMsg(EventMsg::TokenCount(ev)) => ev.info.clone(),
|
||||
|
|
@ -2613,31 +2599,17 @@ impl Session {
|
|||
}
|
||||
|
||||
pub(crate) async fn send_event_raw(&self, event: Event) {
|
||||
// Record the last known agent status.
|
||||
if let Some(status) = agent_status_from_event(&event.msg) {
|
||||
self.agent_status.send_replace(status);
|
||||
}
|
||||
// Persist the event into rollout (recorder filters as needed)
|
||||
let rollout_items = vec![RolloutItem::EventMsg(event.msg.clone())];
|
||||
self.persist_rollout_items(&rollout_items).await;
|
||||
if let Err(e) = self.tx_event.send(event).await {
|
||||
debug!("dropping event because channel is closed: {e}");
|
||||
}
|
||||
self.deliver_event_raw(event).await;
|
||||
}
|
||||
|
||||
/// Persist the event to the rollout file, flush it, and only then deliver it to clients.
|
||||
///
|
||||
/// Most events can be delivered immediately after queueing the rollout write, but some
|
||||
/// clients (e.g. app-server thread/rollback) re-read the rollout file synchronously on
|
||||
/// receipt of the event and depend on the marker already being visible on disk.
|
||||
pub(crate) async fn send_event_raw_flushed(&self, event: Event) {
|
||||
async fn deliver_event_raw(&self, event: Event) {
|
||||
// Record the last known agent status.
|
||||
if let Some(status) = agent_status_from_event(&event.msg) {
|
||||
self.agent_status.send_replace(status);
|
||||
}
|
||||
self.persist_rollout_items(&[RolloutItem::EventMsg(event.msg.clone())])
|
||||
.await;
|
||||
self.flush_rollout().await;
|
||||
if let Err(e) = self.tx_event.send(event).await {
|
||||
debug!("dropping event because channel is closed: {e}");
|
||||
}
|
||||
|
|
@ -5070,29 +5042,22 @@ mod handlers {
|
|||
};
|
||||
|
||||
let rollback_event = ThreadRolledBackEvent { num_turns };
|
||||
let rollback_msg = EventMsg::ThreadRolledBack(rollback_event.clone());
|
||||
let replay_items = initial_history
|
||||
.get_rollout_items()
|
||||
.into_iter()
|
||||
.chain(std::iter::once(RolloutItem::EventMsg(
|
||||
EventMsg::ThreadRolledBack(rollback_event.clone()),
|
||||
)))
|
||||
.chain(std::iter::once(RolloutItem::EventMsg(rollback_msg.clone())))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let reconstructed = sess
|
||||
.reconstruct_history_from_rollout(turn_context.as_ref(), replay_items.as_slice())
|
||||
sess.persist_rollout_items(&[RolloutItem::EventMsg(rollback_msg.clone())])
|
||||
.await;
|
||||
sess.replace_history(
|
||||
reconstructed.history,
|
||||
reconstructed.reference_context_item.clone(),
|
||||
)
|
||||
.await;
|
||||
sess.set_previous_turn_settings(reconstructed.previous_turn_settings)
|
||||
sess.flush_rollout().await;
|
||||
sess.apply_rollout_reconstruction(turn_context.as_ref(), replay_items.as_slice())
|
||||
.await;
|
||||
sess.recompute_token_usage(turn_context.as_ref()).await;
|
||||
|
||||
sess.send_event_raw_flushed(Event {
|
||||
sess.deliver_event_raw(Event {
|
||||
id: turn_context.sub_id.clone(),
|
||||
msg: EventMsg::ThreadRolledBack(rollback_event),
|
||||
msg: rollback_msg,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1204,6 +1204,99 @@ async fn thread_rollback_recomputes_previous_turn_settings_and_reference_context
|
|||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_rollback_restores_cleared_reference_context_item_after_compaction() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx().await;
|
||||
attach_rollout_recorder(&sess).await;
|
||||
|
||||
let first_context_item = tc.to_turn_context_item();
|
||||
let first_turn_id = first_context_item
|
||||
.turn_id
|
||||
.clone()
|
||||
.expect("turn context should have turn_id");
|
||||
let compact_turn_id = "compact-turn".to_string();
|
||||
let rolled_back_turn_id = "rolled-back-turn".to_string();
|
||||
let compacted_history = vec![
|
||||
user_message("turn 1 user"),
|
||||
user_message("summary after compaction"),
|
||||
];
|
||||
|
||||
sess.persist_rollout_items(&[
|
||||
RolloutItem::EventMsg(EventMsg::TurnStarted(
|
||||
codex_protocol::protocol::TurnStartedEvent {
|
||||
turn_id: first_turn_id.clone(),
|
||||
model_context_window: Some(128_000),
|
||||
collaboration_mode_kind: ModeKind::Default,
|
||||
},
|
||||
)),
|
||||
RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent {
|
||||
message: "turn 1 user".to_string(),
|
||||
images: None,
|
||||
local_images: Vec::new(),
|
||||
text_elements: Vec::new(),
|
||||
})),
|
||||
RolloutItem::TurnContext(first_context_item.clone()),
|
||||
RolloutItem::ResponseItem(user_message("turn 1 user")),
|
||||
RolloutItem::ResponseItem(assistant_message("turn 1 assistant")),
|
||||
RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent {
|
||||
turn_id: first_turn_id,
|
||||
last_agent_message: None,
|
||||
})),
|
||||
RolloutItem::EventMsg(EventMsg::TurnStarted(
|
||||
codex_protocol::protocol::TurnStartedEvent {
|
||||
turn_id: compact_turn_id.clone(),
|
||||
model_context_window: Some(128_000),
|
||||
collaboration_mode_kind: ModeKind::Default,
|
||||
},
|
||||
)),
|
||||
RolloutItem::Compacted(CompactedItem {
|
||||
message: "summary after compaction".to_string(),
|
||||
replacement_history: Some(compacted_history.clone()),
|
||||
}),
|
||||
RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent {
|
||||
turn_id: compact_turn_id,
|
||||
last_agent_message: None,
|
||||
})),
|
||||
RolloutItem::EventMsg(EventMsg::TurnStarted(
|
||||
codex_protocol::protocol::TurnStartedEvent {
|
||||
turn_id: rolled_back_turn_id.clone(),
|
||||
model_context_window: Some(128_000),
|
||||
collaboration_mode_kind: ModeKind::Default,
|
||||
},
|
||||
)),
|
||||
RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent {
|
||||
message: "turn 2 user".to_string(),
|
||||
images: None,
|
||||
local_images: Vec::new(),
|
||||
text_elements: Vec::new(),
|
||||
})),
|
||||
RolloutItem::TurnContext(TurnContextItem {
|
||||
turn_id: Some(rolled_back_turn_id.clone()),
|
||||
model: "rolled-back-model".to_string(),
|
||||
..first_context_item.clone()
|
||||
}),
|
||||
RolloutItem::ResponseItem(user_message("turn 2 user")),
|
||||
RolloutItem::ResponseItem(assistant_message("turn 2 assistant")),
|
||||
RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent {
|
||||
turn_id: rolled_back_turn_id,
|
||||
last_agent_message: None,
|
||||
})),
|
||||
])
|
||||
.await;
|
||||
sess.replace_history(
|
||||
vec![assistant_message("stale history")],
|
||||
Some(first_context_item),
|
||||
)
|
||||
.await;
|
||||
|
||||
handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await;
|
||||
let rollback_event = wait_for_thread_rolled_back(&rx).await;
|
||||
assert_eq!(rollback_event.num_turns, 1);
|
||||
|
||||
assert_eq!(sess.clone_history().await.raw_items(), compacted_history);
|
||||
assert!(sess.reference_context_item().await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_rollback_persists_marker_and_replays_cumulatively() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx().await;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue