From bb5dfd037a50fd5a590973b221bf2924ded654b3 Mon Sep 17 00:00:00 2001 From: Ahmed Ibrahim Date: Wed, 11 Feb 2026 16:45:18 -0800 Subject: [PATCH] Hydrate previous model across resume/fork/rollback/task start (#11497) - Replace pending resume model state with persistent previous_model and hydrate it on resume, fork, rollback, and task end in spawn_task --- codex-rs/core/src/codex.rs | 136 +++++++++++++++++++++++++---- codex-rs/core/src/state/session.rs | 13 ++- codex-rs/core/src/tasks/mod.rs | 10 +++ 3 files changed, 139 insertions(+), 20 deletions(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index d1f4305b3..4696d9d15 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1348,20 +1348,25 @@ impl Session { let mut state = self.state.lock().await; state.initial_context_seeded = true; } + self.set_previous_model(None).await; // Ensure initial items are visible to immediate readers (e.g., tests, forks). self.flush_rollout().await; } InitialHistory::Resumed(resumed_history) => { let rollout_items = resumed_history.history; + let previous_model = Self::last_rollout_model_name(&rollout_items) + .map(std::string::ToString::to_string); { let mut state = self.state.lock().await; state.initial_context_seeded = false; - state.pending_resume_previous_model = None; } + self.set_previous_model(previous_model).await; // If resuming, warn when the last recorded model differs from the current one. let curr = turn_context.model_info.slug.as_str(); - if let Some(prev) = Self::last_model_name(&rollout_items, curr) { + if let Some(prev) = + Self::last_rollout_model_name(&rollout_items).filter(|p| *p != curr) + { warn!("resuming session with different model: previous={prev}, current={curr}"); self.send_event( &turn_context, @@ -1373,9 +1378,6 @@ impl Session { }), ) .await; - - let mut state = self.state.lock().await; - state.pending_resume_previous_model = Some(prev.to_string()); } // Always add response items to conversation history @@ -1399,6 +1401,10 @@ impl Session { self.flush_rollout().await; } InitialHistory::Forked(rollout_items) => { + let previous_model = Self::last_rollout_model_name(&rollout_items) + .map(std::string::ToString::to_string); + self.set_previous_model(previous_model).await; + // Always add response items to conversation history let reconstructed_history = self .reconstruct_history_from_rollout(&turn_context, &rollout_items) @@ -1438,19 +1444,14 @@ impl Session { } } - fn last_model_name<'a>(rollout_items: &'a [RolloutItem], current: &str) -> Option<&'a str> { - let previous = rollout_items.iter().rev().find_map(|it| { + fn last_rollout_model_name(rollout_items: &[RolloutItem]) -> Option<&str> { + rollout_items.iter().rev().find_map(|it| { if let RolloutItem::TurnContext(ctx) = it { Some(ctx.model.as_str()) } else { None } - })?; - if previous == current { - None - } else { - Some(previous) - } + }) } fn last_token_info_from_rollout(rollout_items: &[RolloutItem]) -> Option { @@ -1460,9 +1461,14 @@ impl Session { }) } - async fn take_pending_resume_previous_model(&self) -> Option { + async fn previous_model(&self) -> Option { + let state = self.state.lock().await; + state.previous_model() + } + + pub(crate) async fn set_previous_model(&self, previous_model: Option) { let mut state = self.state.lock().await; - state.pending_resume_previous_model.take() + state.set_previous_model(previous_model); } fn maybe_refresh_shell_snapshot_for_cwd( @@ -3134,10 +3140,10 @@ mod handlers { // Attempt to inject input into current task. if let Err(SteerInputError::NoActiveTurn(items)) = sess.steer_input(items, None).await { sess.seed_initial_context_if_needed(¤t_context).await; - let resumed_model = sess.take_pending_resume_previous_model().await; + let previous_model = sess.previous_model().await; let update_items = sess.build_settings_update_items( previous_context.as_ref(), - resumed_model.as_deref(), + previous_model.as_deref(), ¤t_context, ); if !update_items.is_empty() { @@ -3529,6 +3535,8 @@ mod handlers { } let turn_context = sess.new_default_turn_with_sub_id(sub_id).await; + sess.set_previous_model(Some(turn_context.model_info.slug.clone())) + .await; let mut history = sess.clone_history().await; history.drop_last_n_user_turns(num_turns); @@ -5511,6 +5519,40 @@ mod tests { assert_eq!(expected, history.raw_items()); } + #[tokio::test] + async fn record_initial_history_resumed_hydrates_previous_model() { + let (session, turn_context) = make_session_and_context().await; + let previous_model = "previous-rollout-model"; + let rollout_items = vec![RolloutItem::TurnContext(TurnContextItem { + turn_id: Some(turn_context.sub_id.clone()), + cwd: turn_context.cwd.clone(), + approval_policy: turn_context.approval_policy, + sandbox_policy: turn_context.sandbox_policy.clone(), + model: previous_model.to_string(), + personality: turn_context.personality, + collaboration_mode: Some(turn_context.collaboration_mode.clone()), + effort: turn_context.reasoning_effort, + summary: turn_context.reasoning_summary, + user_instructions: None, + developer_instructions: None, + final_output_json_schema: None, + truncation_policy: Some(turn_context.truncation_policy.into()), + })]; + + session + .record_initial_history(InitialHistory::Resumed(ResumedHistory { + conversation_id: ThreadId::default(), + history: rollout_items, + rollout_path: PathBuf::from("/tmp/resume.jsonl"), + })) + .await; + + assert_eq!( + session.previous_model().await, + Some(previous_model.to_string()) + ); + } + #[tokio::test] async fn resumed_history_seeds_initial_context_on_first_turn_only() { let (session, turn_context) = make_session_and_context().await; @@ -5673,6 +5715,36 @@ mod tests { assert_eq!(expected, history.raw_items()); } + #[tokio::test] + async fn record_initial_history_forked_hydrates_previous_model() { + let (session, turn_context) = make_session_and_context().await; + let previous_model = "forked-rollout-model"; + let rollout_items = vec![RolloutItem::TurnContext(TurnContextItem { + turn_id: Some(turn_context.sub_id.clone()), + cwd: turn_context.cwd.clone(), + approval_policy: turn_context.approval_policy, + sandbox_policy: turn_context.sandbox_policy.clone(), + model: previous_model.to_string(), + personality: turn_context.personality, + collaboration_mode: Some(turn_context.collaboration_mode.clone()), + effort: turn_context.reasoning_effort, + summary: turn_context.reasoning_summary, + user_instructions: None, + developer_instructions: None, + final_output_json_schema: None, + truncation_policy: Some(turn_context.truncation_policy.into()), + })]; + + session + .record_initial_history(InitialHistory::Forked(rollout_items)) + .await; + + assert_eq!( + session.previous_model().await, + Some(previous_model.to_string()) + ); + } + #[tokio::test] async fn thread_rollback_drops_last_turn_from_history() { let (sess, tc, rx) = make_session_and_context_with_rx().await; @@ -5736,6 +5808,10 @@ mod tests { let history = sess.clone_history().await; assert_eq!(expected, history.raw_items()); + assert_eq!( + sess.previous_model().await, + Some(tc.model_info.slug.clone()) + ); } #[tokio::test] @@ -6569,6 +6645,32 @@ mod tests { } } + #[tokio::test] + async fn spawn_task_hydrates_previous_model() { + let (sess, tc, _rx) = make_session_and_context_with_rx().await; + sess.set_previous_model(None).await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: true, + }, + ) + .await; + + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + assert_eq!( + sess.previous_model().await, + Some(tc.model_info.slug.clone()) + ); + } + #[derive(Clone, Copy)] struct NeverEndingTask { kind: TaskKind, diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs index 16614f3a6..2ceb684d0 100644 --- a/codex-rs/core/src/state/session.rs +++ b/codex-rs/core/src/state/session.rs @@ -25,8 +25,8 @@ pub(crate) struct SessionState { /// TODO(owen): This is a temporary solution to avoid updating a thread's updated_at /// timestamp when resuming a session. Remove this once SQLite is in place. pub(crate) initial_context_seeded: bool, - /// Previous rollout model for one-shot model-switch handling on first turn after resume. - pub(crate) pending_resume_previous_model: Option, + /// Previous model seen by the session, used for model-switch handling on task start. + previous_model: Option, /// Startup regular task pre-created during session initialization. pub(crate) startup_regular_task: Option, pub(crate) active_mcp_tool_selection: Option>, @@ -44,7 +44,7 @@ impl SessionState { dependency_env: HashMap::new(), mcp_dependency_prompted: HashSet::new(), initial_context_seeded: false, - pending_resume_previous_model: None, + previous_model: None, startup_regular_task: None, active_mcp_tool_selection: None, } @@ -59,6 +59,13 @@ impl SessionState { self.history.record_items(items, policy); } + pub(crate) fn previous_model(&self) -> Option { + self.previous_model.clone() + } + pub(crate) fn set_previous_model(&mut self, previous_model: Option) { + self.previous_model = previous_model; + } + pub(crate) fn clone_history(&self) -> ContextManager { self.history.clone() } diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index 9a3a4756a..106640251 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -140,6 +140,7 @@ impl Session { tokio::spawn( async move { let ctx_for_finish = Arc::clone(&ctx); + let model_slug = ctx_for_finish.model_info.slug.clone(); let last_agent_message = task_for_run .run( Arc::clone(&session_ctx), @@ -155,6 +156,11 @@ impl Session { sess.on_task_finished(ctx_for_finish, last_agent_message) .await; } + // Set previous model regardless of completion or interruption for model-switch handling. + session_ctx + .clone_session() + .set_previous_model(Some(model_slug)) + .await; done_clone.notify_waiters(); } .instrument(session_span), @@ -267,6 +273,10 @@ impl Session { task.handle.abort(); + // Set previous model even when interrupted so model-switch handling stays correct. + self.set_previous_model(Some(task.turn_context.model_info.slug.clone())) + .await; + let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self))); session_task .abort(session_ctx, Arc::clone(&task.turn_context))