Hydrate previous model across resume/fork/rollback/task start (#11497)
- Replace the one-shot pending-resume model state with a persistent `previous_model`, and hydrate it on resume, fork, rollback, and at task end in `spawn_task`.
This commit is contained in:
parent
23444a063b
commit
bb5dfd037a
3 changed files with 139 additions and 20 deletions
|
|
@ -1348,20 +1348,25 @@ impl Session {
|
|||
let mut state = self.state.lock().await;
|
||||
state.initial_context_seeded = true;
|
||||
}
|
||||
self.set_previous_model(None).await;
|
||||
// Ensure initial items are visible to immediate readers (e.g., tests, forks).
|
||||
self.flush_rollout().await;
|
||||
}
|
||||
InitialHistory::Resumed(resumed_history) => {
|
||||
let rollout_items = resumed_history.history;
|
||||
let previous_model = Self::last_rollout_model_name(&rollout_items)
|
||||
.map(std::string::ToString::to_string);
|
||||
{
|
||||
let mut state = self.state.lock().await;
|
||||
state.initial_context_seeded = false;
|
||||
state.pending_resume_previous_model = None;
|
||||
}
|
||||
self.set_previous_model(previous_model).await;
|
||||
|
||||
// If resuming, warn when the last recorded model differs from the current one.
|
||||
let curr = turn_context.model_info.slug.as_str();
|
||||
if let Some(prev) = Self::last_model_name(&rollout_items, curr) {
|
||||
if let Some(prev) =
|
||||
Self::last_rollout_model_name(&rollout_items).filter(|p| *p != curr)
|
||||
{
|
||||
warn!("resuming session with different model: previous={prev}, current={curr}");
|
||||
self.send_event(
|
||||
&turn_context,
|
||||
|
|
@ -1373,9 +1378,6 @@ impl Session {
|
|||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut state = self.state.lock().await;
|
||||
state.pending_resume_previous_model = Some(prev.to_string());
|
||||
}
|
||||
|
||||
// Always add response items to conversation history
|
||||
|
|
@ -1399,6 +1401,10 @@ impl Session {
|
|||
self.flush_rollout().await;
|
||||
}
|
||||
InitialHistory::Forked(rollout_items) => {
|
||||
let previous_model = Self::last_rollout_model_name(&rollout_items)
|
||||
.map(std::string::ToString::to_string);
|
||||
self.set_previous_model(previous_model).await;
|
||||
|
||||
// Always add response items to conversation history
|
||||
let reconstructed_history = self
|
||||
.reconstruct_history_from_rollout(&turn_context, &rollout_items)
|
||||
|
|
@ -1438,19 +1444,14 @@ impl Session {
|
|||
}
|
||||
}
|
||||
|
||||
fn last_model_name<'a>(rollout_items: &'a [RolloutItem], current: &str) -> Option<&'a str> {
|
||||
let previous = rollout_items.iter().rev().find_map(|it| {
|
||||
fn last_rollout_model_name(rollout_items: &[RolloutItem]) -> Option<&str> {
|
||||
rollout_items.iter().rev().find_map(|it| {
|
||||
if let RolloutItem::TurnContext(ctx) = it {
|
||||
Some(ctx.model.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})?;
|
||||
if previous == current {
|
||||
None
|
||||
} else {
|
||||
Some(previous)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn last_token_info_from_rollout(rollout_items: &[RolloutItem]) -> Option<TokenUsageInfo> {
|
||||
|
|
@ -1460,9 +1461,14 @@ impl Session {
|
|||
})
|
||||
}
|
||||
|
||||
async fn take_pending_resume_previous_model(&self) -> Option<String> {
|
||||
async fn previous_model(&self) -> Option<String> {
|
||||
let state = self.state.lock().await;
|
||||
state.previous_model()
|
||||
}
|
||||
|
||||
pub(crate) async fn set_previous_model(&self, previous_model: Option<String>) {
|
||||
let mut state = self.state.lock().await;
|
||||
state.pending_resume_previous_model.take()
|
||||
state.set_previous_model(previous_model);
|
||||
}
|
||||
|
||||
fn maybe_refresh_shell_snapshot_for_cwd(
|
||||
|
|
@ -3134,10 +3140,10 @@ mod handlers {
|
|||
// Attempt to inject input into current task.
|
||||
if let Err(SteerInputError::NoActiveTurn(items)) = sess.steer_input(items, None).await {
|
||||
sess.seed_initial_context_if_needed(¤t_context).await;
|
||||
let resumed_model = sess.take_pending_resume_previous_model().await;
|
||||
let previous_model = sess.previous_model().await;
|
||||
let update_items = sess.build_settings_update_items(
|
||||
previous_context.as_ref(),
|
||||
resumed_model.as_deref(),
|
||||
previous_model.as_deref(),
|
||||
¤t_context,
|
||||
);
|
||||
if !update_items.is_empty() {
|
||||
|
|
@ -3529,6 +3535,8 @@ mod handlers {
|
|||
}
|
||||
|
||||
let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
|
||||
sess.set_previous_model(Some(turn_context.model_info.slug.clone()))
|
||||
.await;
|
||||
|
||||
let mut history = sess.clone_history().await;
|
||||
history.drop_last_n_user_turns(num_turns);
|
||||
|
|
@ -5511,6 +5519,40 @@ mod tests {
|
|||
assert_eq!(expected, history.raw_items());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn record_initial_history_resumed_hydrates_previous_model() {
|
||||
let (session, turn_context) = make_session_and_context().await;
|
||||
let previous_model = "previous-rollout-model";
|
||||
let rollout_items = vec![RolloutItem::TurnContext(TurnContextItem {
|
||||
turn_id: Some(turn_context.sub_id.clone()),
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy,
|
||||
sandbox_policy: turn_context.sandbox_policy.clone(),
|
||||
model: previous_model.to_string(),
|
||||
personality: turn_context.personality,
|
||||
collaboration_mode: Some(turn_context.collaboration_mode.clone()),
|
||||
effort: turn_context.reasoning_effort,
|
||||
summary: turn_context.reasoning_summary,
|
||||
user_instructions: None,
|
||||
developer_instructions: None,
|
||||
final_output_json_schema: None,
|
||||
truncation_policy: Some(turn_context.truncation_policy.into()),
|
||||
})];
|
||||
|
||||
session
|
||||
.record_initial_history(InitialHistory::Resumed(ResumedHistory {
|
||||
conversation_id: ThreadId::default(),
|
||||
history: rollout_items,
|
||||
rollout_path: PathBuf::from("/tmp/resume.jsonl"),
|
||||
}))
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
session.previous_model().await,
|
||||
Some(previous_model.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn resumed_history_seeds_initial_context_on_first_turn_only() {
|
||||
let (session, turn_context) = make_session_and_context().await;
|
||||
|
|
@ -5673,6 +5715,36 @@ mod tests {
|
|||
assert_eq!(expected, history.raw_items());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn record_initial_history_forked_hydrates_previous_model() {
|
||||
let (session, turn_context) = make_session_and_context().await;
|
||||
let previous_model = "forked-rollout-model";
|
||||
let rollout_items = vec![RolloutItem::TurnContext(TurnContextItem {
|
||||
turn_id: Some(turn_context.sub_id.clone()),
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy,
|
||||
sandbox_policy: turn_context.sandbox_policy.clone(),
|
||||
model: previous_model.to_string(),
|
||||
personality: turn_context.personality,
|
||||
collaboration_mode: Some(turn_context.collaboration_mode.clone()),
|
||||
effort: turn_context.reasoning_effort,
|
||||
summary: turn_context.reasoning_summary,
|
||||
user_instructions: None,
|
||||
developer_instructions: None,
|
||||
final_output_json_schema: None,
|
||||
truncation_policy: Some(turn_context.truncation_policy.into()),
|
||||
})];
|
||||
|
||||
session
|
||||
.record_initial_history(InitialHistory::Forked(rollout_items))
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
session.previous_model().await,
|
||||
Some(previous_model.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_rollback_drops_last_turn_from_history() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx().await;
|
||||
|
|
@ -5736,6 +5808,10 @@ mod tests {
|
|||
|
||||
let history = sess.clone_history().await;
|
||||
assert_eq!(expected, history.raw_items());
|
||||
assert_eq!(
|
||||
sess.previous_model().await,
|
||||
Some(tc.model_info.slug.clone())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -6569,6 +6645,32 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn spawn_task_hydrates_previous_model() {
|
||||
let (sess, tc, _rx) = make_session_and_context_with_rx().await;
|
||||
sess.set_previous_model(None).await;
|
||||
let input = vec![UserInput::Text {
|
||||
text: "hello".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}];
|
||||
|
||||
sess.spawn_task(
|
||||
Arc::clone(&tc),
|
||||
input,
|
||||
NeverEndingTask {
|
||||
kind: TaskKind::Regular,
|
||||
listen_to_cancellation_token: true,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
assert_eq!(
|
||||
sess.previous_model().await,
|
||||
Some(tc.model_info.slug.clone())
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct NeverEndingTask {
|
||||
kind: TaskKind,
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ pub(crate) struct SessionState {
|
|||
/// TODO(owen): This is a temporary solution to avoid updating a thread's updated_at
|
||||
/// timestamp when resuming a session. Remove this once SQLite is in place.
|
||||
pub(crate) initial_context_seeded: bool,
|
||||
/// Previous rollout model for one-shot model-switch handling on first turn after resume.
|
||||
pub(crate) pending_resume_previous_model: Option<String>,
|
||||
/// Previous model seen by the session, used for model-switch handling on task start.
|
||||
previous_model: Option<String>,
|
||||
/// Startup regular task pre-created during session initialization.
|
||||
pub(crate) startup_regular_task: Option<RegularTask>,
|
||||
pub(crate) active_mcp_tool_selection: Option<Vec<String>>,
|
||||
|
|
@ -44,7 +44,7 @@ impl SessionState {
|
|||
dependency_env: HashMap::new(),
|
||||
mcp_dependency_prompted: HashSet::new(),
|
||||
initial_context_seeded: false,
|
||||
pending_resume_previous_model: None,
|
||||
previous_model: None,
|
||||
startup_regular_task: None,
|
||||
active_mcp_tool_selection: None,
|
||||
}
|
||||
|
|
@ -59,6 +59,13 @@ impl SessionState {
|
|||
self.history.record_items(items, policy);
|
||||
}
|
||||
|
||||
pub(crate) fn previous_model(&self) -> Option<String> {
|
||||
self.previous_model.clone()
|
||||
}
|
||||
pub(crate) fn set_previous_model(&mut self, previous_model: Option<String>) {
|
||||
self.previous_model = previous_model;
|
||||
}
|
||||
|
||||
pub(crate) fn clone_history(&self) -> ContextManager {
|
||||
self.history.clone()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -140,6 +140,7 @@ impl Session {
|
|||
tokio::spawn(
|
||||
async move {
|
||||
let ctx_for_finish = Arc::clone(&ctx);
|
||||
let model_slug = ctx_for_finish.model_info.slug.clone();
|
||||
let last_agent_message = task_for_run
|
||||
.run(
|
||||
Arc::clone(&session_ctx),
|
||||
|
|
@ -155,6 +156,11 @@ impl Session {
|
|||
sess.on_task_finished(ctx_for_finish, last_agent_message)
|
||||
.await;
|
||||
}
|
||||
// Set previous model regardless of completion or interruption for model-switch handling.
|
||||
session_ctx
|
||||
.clone_session()
|
||||
.set_previous_model(Some(model_slug))
|
||||
.await;
|
||||
done_clone.notify_waiters();
|
||||
}
|
||||
.instrument(session_span),
|
||||
|
|
@ -267,6 +273,10 @@ impl Session {
|
|||
|
||||
task.handle.abort();
|
||||
|
||||
// Set previous model even when interrupted so model-switch handling stays correct.
|
||||
self.set_previous_model(Some(task.turn_context.model_info.slug.clone()))
|
||||
.await;
|
||||
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||
session_task
|
||||
.abort(session_ctx, Arc::clone(&task.turn_context))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue