Hydrate previous model across resume/fork/rollback/task start (#11497)

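- Replace the one-shot pending resume model state with a persistent previous_model, and
hydrate it on resume, fork, rollback, and task end (completion or interruption) in spawn_task, so that model-switch handling at the next task start sees the right value.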
This commit is contained in:
Ahmed Ibrahim 2026-02-11 16:45:18 -08:00 committed by GitHub
parent 23444a063b
commit bb5dfd037a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 139 additions and 20 deletions

View file

@ -1348,20 +1348,25 @@ impl Session {
let mut state = self.state.lock().await;
state.initial_context_seeded = true;
}
self.set_previous_model(None).await;
// Ensure initial items are visible to immediate readers (e.g., tests, forks).
self.flush_rollout().await;
}
InitialHistory::Resumed(resumed_history) => {
let rollout_items = resumed_history.history;
let previous_model = Self::last_rollout_model_name(&rollout_items)
.map(std::string::ToString::to_string);
{
let mut state = self.state.lock().await;
state.initial_context_seeded = false;
state.pending_resume_previous_model = None;
}
self.set_previous_model(previous_model).await;
// If resuming, warn when the last recorded model differs from the current one.
let curr = turn_context.model_info.slug.as_str();
if let Some(prev) = Self::last_model_name(&rollout_items, curr) {
if let Some(prev) =
Self::last_rollout_model_name(&rollout_items).filter(|p| *p != curr)
{
warn!("resuming session with different model: previous={prev}, current={curr}");
self.send_event(
&turn_context,
@ -1373,9 +1378,6 @@ impl Session {
}),
)
.await;
let mut state = self.state.lock().await;
state.pending_resume_previous_model = Some(prev.to_string());
}
// Always add response items to conversation history
@ -1399,6 +1401,10 @@ impl Session {
self.flush_rollout().await;
}
InitialHistory::Forked(rollout_items) => {
let previous_model = Self::last_rollout_model_name(&rollout_items)
.map(std::string::ToString::to_string);
self.set_previous_model(previous_model).await;
// Always add response items to conversation history
let reconstructed_history = self
.reconstruct_history_from_rollout(&turn_context, &rollout_items)
@ -1438,19 +1444,14 @@ impl Session {
}
}
fn last_model_name<'a>(rollout_items: &'a [RolloutItem], current: &str) -> Option<&'a str> {
let previous = rollout_items.iter().rev().find_map(|it| {
fn last_rollout_model_name(rollout_items: &[RolloutItem]) -> Option<&str> {
rollout_items.iter().rev().find_map(|it| {
if let RolloutItem::TurnContext(ctx) = it {
Some(ctx.model.as_str())
} else {
None
}
})?;
if previous == current {
None
} else {
Some(previous)
}
})
}
fn last_token_info_from_rollout(rollout_items: &[RolloutItem]) -> Option<TokenUsageInfo> {
@ -1460,9 +1461,14 @@ impl Session {
})
}
async fn take_pending_resume_previous_model(&self) -> Option<String> {
async fn previous_model(&self) -> Option<String> {
let state = self.state.lock().await;
state.previous_model()
}
pub(crate) async fn set_previous_model(&self, previous_model: Option<String>) {
let mut state = self.state.lock().await;
state.pending_resume_previous_model.take()
state.set_previous_model(previous_model);
}
fn maybe_refresh_shell_snapshot_for_cwd(
@ -3134,10 +3140,10 @@ mod handlers {
// Attempt to inject input into current task.
if let Err(SteerInputError::NoActiveTurn(items)) = sess.steer_input(items, None).await {
sess.seed_initial_context_if_needed(&current_context).await;
let resumed_model = sess.take_pending_resume_previous_model().await;
let previous_model = sess.previous_model().await;
let update_items = sess.build_settings_update_items(
previous_context.as_ref(),
resumed_model.as_deref(),
previous_model.as_deref(),
&current_context,
);
if !update_items.is_empty() {
@ -3529,6 +3535,8 @@ mod handlers {
}
let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
sess.set_previous_model(Some(turn_context.model_info.slug.clone()))
.await;
let mut history = sess.clone_history().await;
history.drop_last_n_user_turns(num_turns);
@ -5511,6 +5519,40 @@ mod tests {
assert_eq!(expected, history.raw_items());
}
// Recording a resumed history whose rollout contains a single `TurnContext`
// item must leave the session's `previous_model` equal to that item's model.
#[tokio::test]
async fn record_initial_history_resumed_hydrates_previous_model() {
    const ROLLOUT_MODEL: &str = "previous-rollout-model";
    let (session, turn_context) = make_session_and_context().await;

    // Every field except `model` is copied from the live turn context or left unset,
    // so the model name is the only thing that distinguishes the recorded turn.
    let recorded_turn = TurnContextItem {
        turn_id: Some(turn_context.sub_id.clone()),
        cwd: turn_context.cwd.clone(),
        approval_policy: turn_context.approval_policy,
        sandbox_policy: turn_context.sandbox_policy.clone(),
        model: ROLLOUT_MODEL.to_string(),
        personality: turn_context.personality,
        collaboration_mode: Some(turn_context.collaboration_mode.clone()),
        effort: turn_context.reasoning_effort,
        summary: turn_context.reasoning_summary,
        user_instructions: None,
        developer_instructions: None,
        final_output_json_schema: None,
        truncation_policy: Some(turn_context.truncation_policy.into()),
    };
    let resumed = ResumedHistory {
        conversation_id: ThreadId::default(),
        history: vec![RolloutItem::TurnContext(recorded_turn)],
        rollout_path: PathBuf::from("/tmp/resume.jsonl"),
    };

    session
        .record_initial_history(InitialHistory::Resumed(resumed))
        .await;

    let hydrated = session.previous_model().await;
    assert_eq!(hydrated, Some(ROLLOUT_MODEL.to_string()));
}
#[tokio::test]
async fn resumed_history_seeds_initial_context_on_first_turn_only() {
let (session, turn_context) = make_session_and_context().await;
@ -5673,6 +5715,36 @@ mod tests {
assert_eq!(expected, history.raw_items());
}
// Recording a forked history whose rollout contains a single `TurnContext`
// item must leave the session's `previous_model` equal to that item's model.
#[tokio::test]
async fn record_initial_history_forked_hydrates_previous_model() {
    const ROLLOUT_MODEL: &str = "forked-rollout-model";
    let (session, turn_context) = make_session_and_context().await;

    // Every field except `model` is copied from the live turn context or left unset,
    // so the model name is the only thing that distinguishes the recorded turn.
    let recorded_turn = TurnContextItem {
        turn_id: Some(turn_context.sub_id.clone()),
        cwd: turn_context.cwd.clone(),
        approval_policy: turn_context.approval_policy,
        sandbox_policy: turn_context.sandbox_policy.clone(),
        model: ROLLOUT_MODEL.to_string(),
        personality: turn_context.personality,
        collaboration_mode: Some(turn_context.collaboration_mode.clone()),
        effort: turn_context.reasoning_effort,
        summary: turn_context.reasoning_summary,
        user_instructions: None,
        developer_instructions: None,
        final_output_json_schema: None,
        truncation_policy: Some(turn_context.truncation_policy.into()),
    };
    let forked_items = vec![RolloutItem::TurnContext(recorded_turn)];

    session
        .record_initial_history(InitialHistory::Forked(forked_items))
        .await;

    let hydrated = session.previous_model().await;
    assert_eq!(hydrated, Some(ROLLOUT_MODEL.to_string()));
}
#[tokio::test]
async fn thread_rollback_drops_last_turn_from_history() {
let (sess, tc, rx) = make_session_and_context_with_rx().await;
@ -5736,6 +5808,10 @@ mod tests {
let history = sess.clone_history().await;
assert_eq!(expected, history.raw_items());
assert_eq!(
sess.previous_model().await,
Some(tc.model_info.slug.clone())
);
}
#[tokio::test]
@ -6569,6 +6645,32 @@ mod tests {
}
}
// Spawning a task and then interrupting it must leave `previous_model` set to
// the model slug of the turn context the task was spawned with.
#[tokio::test]
async fn spawn_task_hydrates_previous_model() {
    let (sess, tc, _rx) = make_session_and_context_with_rx().await;

    // Clear the value first so the final assertion can only pass if the
    // spawn/abort sequence below writes it.
    sess.set_previous_model(None).await;

    let never_ending = NeverEndingTask {
        kind: TaskKind::Regular,
        listen_to_cancellation_token: true,
    };
    let user_input = vec![UserInput::Text {
        text: "hello".to_string(),
        text_elements: Vec::new(),
    }];

    sess.spawn_task(Arc::clone(&tc), user_input, never_ending)
        .await;
    sess.abort_all_tasks(TurnAbortReason::Interrupted).await;

    let observed = sess.previous_model().await;
    assert_eq!(observed, Some(tc.model_info.slug.clone()));
}
#[derive(Clone, Copy)]
struct NeverEndingTask {
kind: TaskKind,

View file

@ -25,8 +25,8 @@ pub(crate) struct SessionState {
/// TODO(owen): This is a temporary solution to avoid updating a thread's updated_at
/// timestamp when resuming a session. Remove this once SQLite is in place.
pub(crate) initial_context_seeded: bool,
/// Previous rollout model for one-shot model-switch handling on first turn after resume.
pub(crate) pending_resume_previous_model: Option<String>,
/// Previous model seen by the session, used for model-switch handling on task start.
previous_model: Option<String>,
/// Startup regular task pre-created during session initialization.
pub(crate) startup_regular_task: Option<RegularTask>,
pub(crate) active_mcp_tool_selection: Option<Vec<String>>,
@ -44,7 +44,7 @@ impl SessionState {
dependency_env: HashMap::new(),
mcp_dependency_prompted: HashSet::new(),
initial_context_seeded: false,
pending_resume_previous_model: None,
previous_model: None,
startup_regular_task: None,
active_mcp_tool_selection: None,
}
@ -59,6 +59,13 @@ impl SessionState {
self.history.record_items(items, policy);
}
/// Returns an owned copy of the stored previous model name, or `None` when no
/// previous model has been recorded. The stored value is left untouched, so
/// repeated calls return the same result until `set_previous_model` is called.
pub(crate) fn previous_model(&self) -> Option<String> {
self.previous_model.clone()
}
/// Overwrites the stored previous model name with `previous_model`.
/// Passing `None` clears any previously recorded value.
pub(crate) fn set_previous_model(&mut self, previous_model: Option<String>) {
self.previous_model = previous_model;
}
pub(crate) fn clone_history(&self) -> ContextManager {
self.history.clone()
}

View file

@ -140,6 +140,7 @@ impl Session {
tokio::spawn(
async move {
let ctx_for_finish = Arc::clone(&ctx);
let model_slug = ctx_for_finish.model_info.slug.clone();
let last_agent_message = task_for_run
.run(
Arc::clone(&session_ctx),
@ -155,6 +156,11 @@ impl Session {
sess.on_task_finished(ctx_for_finish, last_agent_message)
.await;
}
// Set previous model regardless of completion or interruption for model-switch handling.
session_ctx
.clone_session()
.set_previous_model(Some(model_slug))
.await;
done_clone.notify_waiters();
}
.instrument(session_span),
@ -267,6 +273,10 @@ impl Session {
task.handle.abort();
// Set previous model even when interrupted so model-switch handling stays correct.
self.set_previous_model(Some(task.turn_context.model_info.slug.clone()))
.await;
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
session_task
.abort(session_ctx, Arc::clone(&task.turn_context))