fix(core,app-server) resume with different model (#10719)
## Summary When resuming with a different model, we should also append a developer message with the model instructions ## Testing - [x] Added unit tests
This commit is contained in:
parent
1e1146cd29
commit
fe8b474acd
4 changed files with 347 additions and 42 deletions
|
|
@ -1,5 +1,7 @@
|
|||
use anyhow::Result;
|
||||
use app_test_support::McpProcess;
|
||||
use app_test_support::create_fake_rollout;
|
||||
use app_test_support::rollout_path;
|
||||
use app_test_support::to_response;
|
||||
use codex_app_server_protocol::AddConversationListenerParams;
|
||||
use codex_app_server_protocol::AddConversationSubscriptionResponse;
|
||||
|
|
@ -9,18 +11,25 @@ use codex_app_server_protocol::JSONRPCResponse;
|
|||
use codex_app_server_protocol::NewConversationParams;
|
||||
use codex_app_server_protocol::NewConversationResponse;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::ResumeConversationParams;
|
||||
use codex_app_server_protocol::ResumeConversationResponse;
|
||||
use codex_app_server_protocol::SendUserMessageParams;
|
||||
use codex_app_server_protocol::SendUserMessageResponse;
|
||||
use codex_execpolicy::Policy;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::DeveloperInstructions;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::AskForApproval;
|
||||
use codex_protocol::protocol::RawResponseItemEvent;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::RolloutLine;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_protocol::protocol::TurnContextItem;
|
||||
use core_test_support::responses;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use tempfile::TempDir;
|
||||
|
|
@ -263,6 +272,114 @@ async fn test_send_message_session_not_found() -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn resume_with_model_mismatch_appends_model_switch_once() -> Result<()> {
|
||||
let server = responses::start_mock_server().await;
|
||||
let response_mock = responses::mount_sse_sequence(
|
||||
&server,
|
||||
vec![
|
||||
responses::sse(vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_assistant_message("msg-1", "Done"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]),
|
||||
responses::sse(vec![
|
||||
responses::ev_response_created("resp-2"),
|
||||
responses::ev_assistant_message("msg-2", "Done again"),
|
||||
responses::ev_completed("resp-2"),
|
||||
]),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), &server.uri())?;
|
||||
|
||||
let filename_ts = "2025-01-02T12-00-00";
|
||||
let meta_rfc3339 = "2025-01-02T12:00:00Z";
|
||||
let preview = "Resume me";
|
||||
let conversation_id = create_fake_rollout(
|
||||
codex_home.path(),
|
||||
filename_ts,
|
||||
meta_rfc3339,
|
||||
preview,
|
||||
Some("mock_provider"),
|
||||
None,
|
||||
)?;
|
||||
let rollout_path = rollout_path(codex_home.path(), filename_ts, &conversation_id);
|
||||
append_rollout_turn_context(&rollout_path, meta_rfc3339, "previous-model")?;
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
|
||||
let resume_id = mcp
|
||||
.send_resume_conversation_request(ResumeConversationParams {
|
||||
path: Some(rollout_path.clone()),
|
||||
conversation_id: None,
|
||||
history: None,
|
||||
overrides: Some(NewConversationParams {
|
||||
model: Some("gpt-5.2-codex".to_string()),
|
||||
..Default::default()
|
||||
}),
|
||||
})
|
||||
.await?;
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_message("sessionConfigured"),
|
||||
)
|
||||
.await??;
|
||||
let resume_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(resume_id)),
|
||||
)
|
||||
.await??;
|
||||
let ResumeConversationResponse {
|
||||
conversation_id, ..
|
||||
} = to_response::<ResumeConversationResponse>(resume_resp)?;
|
||||
|
||||
let add_listener_id = mcp
|
||||
.send_add_conversation_listener_request(AddConversationListenerParams {
|
||||
conversation_id,
|
||||
experimental_raw_events: false,
|
||||
})
|
||||
.await?;
|
||||
let add_listener_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(add_listener_id)),
|
||||
)
|
||||
.await??;
|
||||
let AddConversationSubscriptionResponse { subscription_id: _ } =
|
||||
to_response::<_>(add_listener_resp)?;
|
||||
|
||||
send_message("hello after resume", conversation_id, &mut mcp).await?;
|
||||
send_message("second turn", conversation_id, &mut mcp).await?;
|
||||
|
||||
let requests = response_mock.requests();
|
||||
assert_eq!(requests.len(), 2, "expected two model requests");
|
||||
|
||||
let first_developer_texts = requests[0].message_input_texts("developer");
|
||||
let first_model_switch_count = first_developer_texts
|
||||
.iter()
|
||||
.filter(|text| text.contains("<model_switch>"))
|
||||
.count();
|
||||
assert!(
|
||||
first_model_switch_count >= 1,
|
||||
"expected model switch message on first post-resume turn, got {first_developer_texts:?}"
|
||||
);
|
||||
|
||||
let second_developer_texts = requests[1].message_input_texts("developer");
|
||||
let second_model_switch_count = second_developer_texts
|
||||
.iter()
|
||||
.filter(|text| text.contains("<model_switch>"))
|
||||
.count();
|
||||
assert_eq!(
|
||||
second_model_switch_count, 1,
|
||||
"did not expect duplicate model switch message on second post-resume turn, got {second_developer_texts:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
@ -438,3 +555,28 @@ fn content_texts(content: &[ContentItem]) -> Vec<&str> {
|
|||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn append_rollout_turn_context(path: &Path, timestamp: &str, model: &str) -> std::io::Result<()> {
|
||||
let line = RolloutLine {
|
||||
timestamp: timestamp.to_string(),
|
||||
item: RolloutItem::TurnContext(TurnContextItem {
|
||||
cwd: PathBuf::from("/"),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: model.to_string(),
|
||||
personality: None,
|
||||
collaboration_mode: None,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
user_instructions: None,
|
||||
developer_instructions: None,
|
||||
final_output_json_schema: None,
|
||||
truncation_policy: None,
|
||||
}),
|
||||
};
|
||||
let serialized = serde_json::to_string(&line).map_err(std::io::Error::other)?;
|
||||
std::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.open(path)?
|
||||
.write_all(format!("{serialized}\n").as_bytes())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1176,32 +1176,26 @@ impl Session {
|
|||
{
|
||||
let mut state = self.state.lock().await;
|
||||
state.initial_context_seeded = false;
|
||||
state.pending_resume_previous_model = None;
|
||||
}
|
||||
|
||||
// If resuming, warn when the last recorded model differs from the current one.
|
||||
if let Some(prev) = rollout_items.iter().rev().find_map(|it| {
|
||||
if let RolloutItem::TurnContext(ctx) = it {
|
||||
Some(ctx.model.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}) {
|
||||
let curr = turn_context.model_info.slug.as_str();
|
||||
if prev != curr {
|
||||
warn!(
|
||||
"resuming session with different model: previous={prev}, current={curr}"
|
||||
);
|
||||
self.send_event(
|
||||
&turn_context,
|
||||
EventMsg::Warning(WarningEvent {
|
||||
message: format!(
|
||||
"This session was recorded with model `{prev}` but is resuming with `{curr}`. \
|
||||
let curr = turn_context.model_info.slug.as_str();
|
||||
if let Some(prev) = Self::last_model_name(&rollout_items, curr) {
|
||||
warn!("resuming session with different model: previous={prev}, current={curr}");
|
||||
self.send_event(
|
||||
&turn_context,
|
||||
EventMsg::Warning(WarningEvent {
|
||||
message: format!(
|
||||
"This session was recorded with model `{prev}` but is resuming with `{curr}`. \
|
||||
Consider switching back to `{prev}` as it may affect Codex performance."
|
||||
),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut state = self.state.lock().await;
|
||||
state.pending_resume_previous_model = Some(prev.to_string());
|
||||
}
|
||||
|
||||
// Always add response items to conversation history
|
||||
|
|
@ -1260,6 +1254,21 @@ impl Session {
|
|||
}
|
||||
}
|
||||
|
||||
fn last_model_name<'a>(rollout_items: &'a [RolloutItem], current: &str) -> Option<&'a str> {
|
||||
let previous = rollout_items.iter().rev().find_map(|it| {
|
||||
if let RolloutItem::TurnContext(ctx) = it {
|
||||
Some(ctx.model.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})?;
|
||||
if previous == current {
|
||||
None
|
||||
} else {
|
||||
Some(previous)
|
||||
}
|
||||
}
|
||||
|
||||
fn last_token_info_from_rollout(rollout_items: &[RolloutItem]) -> Option<TokenUsageInfo> {
|
||||
rollout_items.iter().rev().find_map(|item| match item {
|
||||
RolloutItem::EventMsg(EventMsg::TokenCount(ev)) => ev.info.clone(),
|
||||
|
|
@ -1267,6 +1276,11 @@ impl Session {
|
|||
})
|
||||
}
|
||||
|
||||
async fn take_pending_resume_previous_model(&self) -> Option<String> {
|
||||
let mut state = self.state.lock().await;
|
||||
state.pending_resume_previous_model.take()
|
||||
}
|
||||
|
||||
pub(crate) async fn update_settings(
|
||||
&self,
|
||||
updates: SessionSettingsUpdate,
|
||||
|
|
@ -1504,10 +1518,12 @@ impl Session {
|
|||
fn build_model_instructions_update_item(
|
||||
&self,
|
||||
previous: Option<&Arc<TurnContext>>,
|
||||
resumed_model: Option<&str>,
|
||||
next: &TurnContext,
|
||||
) -> Option<ResponseItem> {
|
||||
let prev = previous?;
|
||||
if prev.model_info.slug == next.model_info.slug {
|
||||
let previous_model =
|
||||
resumed_model.or_else(|| previous.map(|prev| prev.model_info.slug.as_str()))?;
|
||||
if previous_model == next.model_info.slug {
|
||||
return None;
|
||||
}
|
||||
|
||||
|
|
@ -1522,6 +1538,7 @@ impl Session {
|
|||
fn build_settings_update_items(
|
||||
&self,
|
||||
previous_context: Option<&Arc<TurnContext>>,
|
||||
resumed_model: Option<&str>,
|
||||
current_context: &TurnContext,
|
||||
) -> Vec<ResponseItem> {
|
||||
let mut update_items = Vec::new();
|
||||
|
|
@ -1540,9 +1557,11 @@ impl Session {
|
|||
{
|
||||
update_items.push(collaboration_mode_item);
|
||||
}
|
||||
if let Some(model_instructions_item) =
|
||||
self.build_model_instructions_update_item(previous_context, current_context)
|
||||
{
|
||||
if let Some(model_instructions_item) = self.build_model_instructions_update_item(
|
||||
previous_context,
|
||||
resumed_model,
|
||||
current_context,
|
||||
) {
|
||||
update_items.push(model_instructions_item);
|
||||
}
|
||||
if let Some(personality_item) =
|
||||
|
|
@ -2819,8 +2838,12 @@ mod handlers {
|
|||
// Attempt to inject input into current task
|
||||
if let Err(items) = sess.inject_input(items).await {
|
||||
sess.seed_initial_context_if_needed(¤t_context).await;
|
||||
let update_items =
|
||||
sess.build_settings_update_items(previous_context.as_ref(), ¤t_context);
|
||||
let resumed_model = sess.take_pending_resume_previous_model().await;
|
||||
let update_items = sess.build_settings_update_items(
|
||||
previous_context.as_ref(),
|
||||
resumed_model.as_deref(),
|
||||
¤t_context,
|
||||
);
|
||||
if !update_items.is_empty() {
|
||||
sess.record_conversation_items(¤t_context, &update_items)
|
||||
.await;
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ pub(crate) struct SessionState {
|
|||
/// TODO(owen): This is a temporary solution to avoid updating a thread's updated_at
|
||||
/// timestamp when resuming a session. Remove this once SQLite is in place.
|
||||
pub(crate) initial_context_seeded: bool,
|
||||
/// Previous rollout model for one-shot model-switch handling on first turn after resume.
|
||||
pub(crate) pending_resume_previous_model: Option<String>,
|
||||
}
|
||||
|
||||
impl SessionState {
|
||||
|
|
@ -38,6 +40,7 @@ impl SessionState {
|
|||
dependency_env: HashMap::new(),
|
||||
mcp_dependency_prompted: HashSet::new(),
|
||||
initial_context_seeded: false,
|
||||
pending_resume_previous_model: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ use core_test_support::responses::ev_completed;
|
|||
use core_test_support::responses::ev_reasoning_item;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
|
|
@ -182,12 +183,22 @@ async fn resume_switches_models_preserves_base_instructions() -> Result<()> {
|
|||
.unwrap_or_default()
|
||||
.to_string();
|
||||
|
||||
let resumed_sse = sse(vec![
|
||||
ev_response_created("resp-resume"),
|
||||
ev_assistant_message("msg-2", "Resumed turn"),
|
||||
ev_completed("resp-resume"),
|
||||
]);
|
||||
let resumed_mock = mount_sse_once(&server, resumed_sse).await;
|
||||
let resumed_mock = mount_sse_sequence(
|
||||
&server,
|
||||
vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-resume-1"),
|
||||
ev_assistant_message("msg-2", "Resumed turn"),
|
||||
ev_completed("resp-resume-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-resume-2"),
|
||||
ev_assistant_message("msg-3", "Second resumed turn"),
|
||||
ev_completed("resp-resume-2"),
|
||||
]),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut resume_builder = test_codex().with_config(|config| {
|
||||
config.model = Some("gpt-5.2-codex".to_string());
|
||||
|
|
@ -208,13 +219,139 @@ async fn resume_switches_models_preserves_base_instructions() -> Result<()> {
|
|||
})
|
||||
.await;
|
||||
|
||||
let resumed_body = resumed_mock.single_request().body_json();
|
||||
let resumed_instructions = resumed_body
|
||||
.get("instructions")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
assert_eq!(resumed_instructions, initial_instructions);
|
||||
resumed
|
||||
.codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text: "Second turn after resume".into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
})
|
||||
.await?;
|
||||
wait_for_event(&resumed.codex, |event| {
|
||||
matches!(event, EventMsg::TurnComplete(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
let requests = resumed_mock.requests();
|
||||
assert_eq!(requests.len(), 2, "expected two resumed requests");
|
||||
|
||||
let first_resumed = &requests[0];
|
||||
assert_eq!(first_resumed.instructions_text(), initial_instructions);
|
||||
let first_developer_texts = first_resumed.message_input_texts("developer");
|
||||
let first_model_switch_count = first_developer_texts
|
||||
.iter()
|
||||
.filter(|text| text.contains("<model_switch>"))
|
||||
.count();
|
||||
assert!(
|
||||
first_model_switch_count >= 1,
|
||||
"expected model switch message on first post-resume turn"
|
||||
);
|
||||
|
||||
let second_resumed = &requests[1];
|
||||
assert_eq!(second_resumed.instructions_text(), initial_instructions);
|
||||
let second_developer_texts = second_resumed.message_input_texts("developer");
|
||||
let second_model_switch_count = second_developer_texts
|
||||
.iter()
|
||||
.filter(|text| text.contains("<model_switch>"))
|
||||
.count();
|
||||
assert_eq!(
|
||||
second_model_switch_count, 1,
|
||||
"did not expect duplicate model switch message after first post-resume turn"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn resume_model_switch_is_not_duplicated_after_pre_turn_override() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.model = Some("gpt-5.2".to_string());
|
||||
});
|
||||
let initial = builder.build(&server).await?;
|
||||
let codex = Arc::clone(&initial.codex);
|
||||
let home = initial.home.clone();
|
||||
let rollout_path = initial
|
||||
.session_configured
|
||||
.rollout_path
|
||||
.clone()
|
||||
.expect("rollout path");
|
||||
|
||||
let initial_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_response_created("resp-initial"),
|
||||
ev_assistant_message("msg-1", "Completed first turn"),
|
||||
ev_completed("resp-initial"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text: "Record initial instructions".into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
})
|
||||
.await?;
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await;
|
||||
let _ = initial_mock.single_request();
|
||||
|
||||
let resumed_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_response_created("resp-resume"),
|
||||
ev_assistant_message("msg-2", "Resumed turn"),
|
||||
ev_completed("resp-resume"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut resume_builder = test_codex().with_config(|config| {
|
||||
config.model = Some("gpt-5.2-codex".to_string());
|
||||
});
|
||||
let resumed = resume_builder.resume(&server, home, rollout_path).await?;
|
||||
resumed
|
||||
.codex
|
||||
.submit(Op::OverrideTurnContext {
|
||||
cwd: None,
|
||||
approval_policy: None,
|
||||
sandbox_policy: None,
|
||||
windows_sandbox_level: None,
|
||||
model: Some("gpt-5.1-codex-max".to_string()),
|
||||
effort: None,
|
||||
summary: None,
|
||||
collaboration_mode: None,
|
||||
personality: None,
|
||||
})
|
||||
.await?;
|
||||
resumed
|
||||
.codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text: "first turn after override".into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
})
|
||||
.await?;
|
||||
wait_for_event(&resumed.codex, |event| {
|
||||
matches!(event, EventMsg::TurnComplete(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
let request = resumed_mock.single_request();
|
||||
let developer_texts = request.message_input_texts("developer");
|
||||
let model_switch_count = developer_texts
|
||||
.iter()
|
||||
.filter(|text| text.contains("<model_switch>"))
|
||||
.count();
|
||||
assert_eq!(model_switch_count, 1);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue