diff --git a/codex-rs/app-server/tests/suite/send_message.rs b/codex-rs/app-server/tests/suite/send_message.rs
index 814352a00..ecb742aff 100644
--- a/codex-rs/app-server/tests/suite/send_message.rs
+++ b/codex-rs/app-server/tests/suite/send_message.rs
@@ -1,5 +1,7 @@
 use anyhow::Result;
 use app_test_support::McpProcess;
+use app_test_support::create_fake_rollout;
+use app_test_support::rollout_path;
 use app_test_support::to_response;
 use codex_app_server_protocol::AddConversationListenerParams;
 use codex_app_server_protocol::AddConversationSubscriptionResponse;
@@ -9,18 +11,25 @@ use codex_app_server_protocol::JSONRPCResponse;
 use codex_app_server_protocol::NewConversationParams;
 use codex_app_server_protocol::NewConversationResponse;
 use codex_app_server_protocol::RequestId;
+use codex_app_server_protocol::ResumeConversationParams;
+use codex_app_server_protocol::ResumeConversationResponse;
 use codex_app_server_protocol::SendUserMessageParams;
 use codex_app_server_protocol::SendUserMessageResponse;
 use codex_execpolicy::Policy;
 use codex_protocol::ThreadId;
+use codex_protocol::config_types::ReasoningSummary;
 use codex_protocol::models::ContentItem;
 use codex_protocol::models::DeveloperInstructions;
 use codex_protocol::models::ResponseItem;
 use codex_protocol::protocol::AskForApproval;
 use codex_protocol::protocol::RawResponseItemEvent;
+use codex_protocol::protocol::RolloutItem;
+use codex_protocol::protocol::RolloutLine;
 use codex_protocol::protocol::SandboxPolicy;
+use codex_protocol::protocol::TurnContextItem;
 use core_test_support::responses;
 use pretty_assertions::assert_eq;
+use std::io::Write;
 use std::path::Path;
 use std::path::PathBuf;
 use tempfile::TempDir;
@@ -263,6 +272,114 @@ async fn test_send_message_session_not_found() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn resume_with_model_mismatch_appends_model_switch_once() -> Result<()> {
+    let server = responses::start_mock_server().await;
+    let response_mock = responses::mount_sse_sequence(
+        &server,
+        vec![
+            responses::sse(vec![
+                responses::ev_response_created("resp-1"),
+                responses::ev_assistant_message("msg-1", "Done"),
+                responses::ev_completed("resp-1"),
+            ]),
+            responses::sse(vec![
+                responses::ev_response_created("resp-2"),
+                responses::ev_assistant_message("msg-2", "Done again"),
+                responses::ev_completed("resp-2"),
+            ]),
+        ],
+    )
+    .await;
+
+    let codex_home = TempDir::new()?;
+    create_config_toml(codex_home.path(), &server.uri())?;
+
+    let filename_ts = "2025-01-02T12-00-00";
+    let meta_rfc3339 = "2025-01-02T12:00:00Z";
+    let preview = "Resume me";
+    let conversation_id = create_fake_rollout(
+        codex_home.path(),
+        filename_ts,
+        meta_rfc3339,
+        preview,
+        Some("mock_provider"),
+        None,
+    )?;
+    let rollout_path = rollout_path(codex_home.path(), filename_ts, &conversation_id);
+    append_rollout_turn_context(&rollout_path, meta_rfc3339, "previous-model")?;
+
+    let mut mcp = McpProcess::new(codex_home.path()).await?;
+    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
+
+    let resume_id = mcp
+        .send_resume_conversation_request(ResumeConversationParams {
+            path: Some(rollout_path.clone()),
+            conversation_id: None,
+            history: None,
+            overrides: Some(NewConversationParams {
+                model: Some("gpt-5.2-codex".to_string()),
+                ..Default::default()
+            }),
+        })
+        .await?;
+    timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_notification_message("sessionConfigured"),
+    )
+    .await??;
+    let resume_resp: JSONRPCResponse = timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_response_message(RequestId::Integer(resume_id)),
+    )
+    .await??;
+    let ResumeConversationResponse {
+        conversation_id, ..
+    } = to_response::<ResumeConversationResponse>(resume_resp)?;
+
+    let add_listener_id = mcp
+        .send_add_conversation_listener_request(AddConversationListenerParams {
+            conversation_id,
+            experimental_raw_events: false,
+        })
+        .await?;
+    let add_listener_resp: JSONRPCResponse = timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_response_message(RequestId::Integer(add_listener_id)),
+    )
+    .await??;
+    let AddConversationSubscriptionResponse { subscription_id: _ } =
+        to_response::<_>(add_listener_resp)?;
+
+    send_message("hello after resume", conversation_id, &mut mcp).await?;
+    send_message("second turn", conversation_id, &mut mcp).await?;
+
+    let requests = response_mock.requests();
+    assert_eq!(requests.len(), 2, "expected two model requests");
+
+    let first_developer_texts = requests[0].message_input_texts("developer");
+    let first_model_switch_count = first_developer_texts
+        .iter()
+        .filter(|text| text.contains("<model_switch>"))
+        .count();
+    assert!(
+        first_model_switch_count >= 1,
+        "expected model switch message on first post-resume turn, got {first_developer_texts:?}"
+    );
+
+    let second_developer_texts = requests[1].message_input_texts("developer");
+    let second_model_switch_count = second_developer_texts
+        .iter()
+        .filter(|text| text.contains("<model_switch>"))
+        .count();
+    assert_eq!(
+        second_model_switch_count, 1,
+        "did not expect duplicate model switch message on second post-resume turn, got {second_developer_texts:?}"
+    );
+
+    Ok(())
+}
+
 // ---------------------------------------------------------------------------
 // Helpers
 // ---------------------------------------------------------------------------
@@ -438,3 +555,28 @@ fn content_texts(content: &[ContentItem]) -> Vec<&str> {
         })
         .collect()
 }
+
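+// Appends a synthetic TurnContext line to the fake rollout so the resumed
+// session observes a different previous model. Each rollout line is JSONL; the
+// serialized shape is whatever serde derives for `RolloutLine`, roughly:
+// {"timestamp":"2025-01-02T12:00:00Z","type":"turn_context","payload":{"model":"previous-model",...}}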
+fn append_rollout_turn_context(path: &Path, timestamp: &str, model: &str) -> std::io::Result<()> {
+    let line = RolloutLine {
+        timestamp: timestamp.to_string(),
+        item: RolloutItem::TurnContext(TurnContextItem {
+            cwd: PathBuf::from("/"),
+            approval_policy: AskForApproval::Never,
+            sandbox_policy: SandboxPolicy::DangerFullAccess,
+            model: model.to_string(),
+            personality: None,
+            collaboration_mode: None,
+            effort: None,
+            summary: ReasoningSummary::Auto,
+            user_instructions: None,
+            developer_instructions: None,
+            final_output_json_schema: None,
+            truncation_policy: None,
+        }),
+    };
+    let serialized = serde_json::to_string(&line).map_err(std::io::Error::other)?;
+    std::fs::OpenOptions::new()
+        .append(true)
+        .open(path)?
+        .write_all(format!("{serialized}\n").as_bytes())
+}
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index 45b0a34b7..8567c9286 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -1176,32 +1176,26 @@ impl Session {
         {
             let mut state = self.state.lock().await;
             state.initial_context_seeded = false;
+            state.pending_resume_previous_model = None;
         }
 
         // If resuming, warn when the last recorded model differs from the current one.
-        if let Some(prev) = rollout_items.iter().rev().find_map(|it| {
-            if let RolloutItem::TurnContext(ctx) = it {
-                Some(ctx.model.as_str())
-            } else {
-                None
-            }
-        }) {
-            let curr = turn_context.model_info.slug.as_str();
-            if prev != curr {
-                warn!(
-                    "resuming session with different model: previous={prev}, current={curr}"
-                );
-                self.send_event(
-                    &turn_context,
-                    EventMsg::Warning(WarningEvent {
-                        message: format!(
-                            "This session was recorded with model `{prev}` but is resuming with `{curr}`. \
+        let curr = turn_context.model_info.slug.as_str();
+        if let Some(prev) = Self::last_model_name(&rollout_items, curr) {
+            warn!("resuming session with different model: previous={prev}, current={curr}");
+            self.send_event(
+                &turn_context,
+                EventMsg::Warning(WarningEvent {
+                    message: format!(
+                        "This session was recorded with model `{prev}` but is resuming with `{curr}`. \
                              Consider switching back to `{prev}` as it may affect Codex performance."
-                        ),
-                    }),
-                )
-                .await;
-            }
+                    ),
+                }),
+            )
+            .await;
+
+            let mut state = self.state.lock().await;
+            state.pending_resume_previous_model = Some(prev.to_string());
         }
 
         // Always add response items to conversation history
@@ -1260,6 +1254,21 @@ impl Session {
         }
     }
 
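+    /// Returns the model slug recorded by the most recent TurnContext item in
+    /// the rollout, or `None` when no turn context exists or the recorded
+    /// model already matches `current` (i.e. there is no switch to report).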
+    fn last_model_name<'a>(rollout_items: &'a [RolloutItem], current: &str) -> Option<&'a str> {
+        let previous = rollout_items.iter().rev().find_map(|it| {
+            if let RolloutItem::TurnContext(ctx) = it {
+                Some(ctx.model.as_str())
+            } else {
+                None
+            }
+        })?;
+        if previous == current {
+            None
+        } else {
+            Some(previous)
+        }
+    }
+
     fn last_token_info_from_rollout(rollout_items: &[RolloutItem]) -> Option<TokenUsageInfo> {
         rollout_items.iter().rev().find_map(|item| match item {
             RolloutItem::EventMsg(EventMsg::TokenCount(ev)) => ev.info.clone(),
             _ => None,
         })
     }
 
@@ -1267,6 +1276,11 @@ impl Session {
     }
 
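+    /// One-shot read: `take()` clears the pending marker so the model-switch
+    /// update is attached only to the first turn after a resume.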
+    async fn take_pending_resume_previous_model(&self) -> Option<String> {
+        let mut state = self.state.lock().await;
+        state.pending_resume_previous_model.take()
+    }
+
     pub(crate) async fn update_settings(
         &self,
         updates: SessionSettingsUpdate,
@@ -1504,10 +1518,12 @@ impl Session {
     fn build_model_instructions_update_item(
         &self,
         previous: Option<&Arc<TurnContext>>,
+        resumed_model: Option<&str>,
         next: &TurnContext,
     ) -> Option<ResponseItem> {
-        let prev = previous?;
-        if prev.model_info.slug == next.model_info.slug {
+        let previous_model =
+            resumed_model.or_else(|| previous.map(|prev| prev.model_info.slug.as_str()))?;
+        if previous_model == next.model_info.slug {
             return None;
         }
 
@@ -1522,6 +1538,7 @@ impl Session {
     fn build_settings_update_items(
         &self,
         previous_context: Option<&Arc<TurnContext>>,
+        resumed_model: Option<&str>,
         current_context: &TurnContext,
     ) -> Vec<ResponseItem> {
         let mut update_items = Vec::new();
@@ -1540,9 +1557,11 @@ impl Session {
         {
             update_items.push(collaboration_mode_item);
         }
-        if let Some(model_instructions_item) =
-            self.build_model_instructions_update_item(previous_context, current_context)
-        {
+        if let Some(model_instructions_item) = self.build_model_instructions_update_item(
+            previous_context,
+            resumed_model,
+            current_context,
+        ) {
             update_items.push(model_instructions_item);
         }
         if let Some(personality_item) =
@@ -2819,8 +2838,12 @@ mod handlers {
         // Attempt to inject input into current task
         if let Err(items) = sess.inject_input(items).await {
             sess.seed_initial_context_if_needed(&current_context).await;
-            let update_items =
-                sess.build_settings_update_items(previous_context.as_ref(), &current_context);
+            let resumed_model = sess.take_pending_resume_previous_model().await;
+            let update_items = sess.build_settings_update_items(
+                previous_context.as_ref(),
+                resumed_model.as_deref(),
+                &current_context,
+            );
             if !update_items.is_empty() {
                 sess.record_conversation_items(&current_context, &update_items)
                     .await;
diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs
index deee0d0c7..29bc14d1c 100644
--- a/codex-rs/core/src/state/session.rs
+++ b/codex-rs/core/src/state/session.rs
@@ -24,6 +24,8 @@ pub(crate) struct SessionState {
     /// TODO(owen): This is a temporary solution to avoid updating a thread's updated_at
     /// timestamp when resuming a session. Remove this once SQLite is in place.
     pub(crate) initial_context_seeded: bool,
+    /// Previous rollout model for one-shot model-switch handling on first turn after resume.
+    pub(crate) pending_resume_previous_model: Option<String>,
 }
 
 impl SessionState {
@@ -38,6 +40,7 @@ impl SessionState {
             dependency_env: HashMap::new(),
             mcp_dependency_prompted: HashSet::new(),
             initial_context_seeded: false,
+            pending_resume_previous_model: None,
         }
     }
 
diff --git a/codex-rs/core/tests/suite/resume.rs b/codex-rs/core/tests/suite/resume.rs
index a7be42f43..c0bdcd4fe 100644
--- a/codex-rs/core/tests/suite/resume.rs
+++ b/codex-rs/core/tests/suite/resume.rs
@@ -9,6 +9,7 @@ use core_test_support::responses::ev_completed;
 use core_test_support::responses::ev_reasoning_item;
 use core_test_support::responses::ev_response_created;
 use core_test_support::responses::mount_sse_once;
+use core_test_support::responses::mount_sse_sequence;
 use core_test_support::responses::sse;
 use core_test_support::responses::start_mock_server;
 use core_test_support::skip_if_no_network;
@@ -182,12 +183,22 @@ async fn resume_switches_models_preserves_base_instructions() -> Result<()> {
         .unwrap_or_default()
         .to_string();
 
-    let resumed_sse = sse(vec![
-        ev_response_created("resp-resume"),
-        ev_assistant_message("msg-2", "Resumed turn"),
-        ev_completed("resp-resume"),
-    ]);
-    let resumed_mock = mount_sse_once(&server, resumed_sse).await;
+    let resumed_mock = mount_sse_sequence(
+        &server,
+        vec![
+            sse(vec![
+                ev_response_created("resp-resume-1"),
+                ev_assistant_message("msg-2", "Resumed turn"),
+                ev_completed("resp-resume-1"),
+            ]),
+            sse(vec![
+                ev_response_created("resp-resume-2"),
+                ev_assistant_message("msg-3", "Second resumed turn"),
+                ev_completed("resp-resume-2"),
+            ]),
+        ],
+    )
+    .await;
 
     let mut resume_builder = test_codex().with_config(|config| {
         config.model = Some("gpt-5.2-codex".to_string());
@@ -208,13 +219,139 @@ async fn resume_switches_models_preserves_base_instructions() -> Result<()> {
         })
         .await;
 
-    let resumed_body = resumed_mock.single_request().body_json();
-    let resumed_instructions = resumed_body
-        .get("instructions")
-        .and_then(|v| v.as_str())
-        .unwrap_or_default()
-        .to_string();
-    assert_eq!(resumed_instructions, initial_instructions);
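+    // Drive a second post-resume turn to prove the model-switch message is
+    // emitted exactly once and not repeated on later requests.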
"multi_thread", worker_threads = 2)] +async fn resume_model_switch_is_not_duplicated_after_pre_turn_override() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let mut builder = test_codex().with_config(|config| { + config.model = Some("gpt-5.2".to_string()); + }); + let initial = builder.build(&server).await?; + let codex = Arc::clone(&initial.codex); + let home = initial.home.clone(); + let rollout_path = initial + .session_configured + .rollout_path + .clone() + .expect("rollout path"); + + let initial_mock = mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-initial"), + ev_assistant_message("msg-1", "Completed first turn"), + ev_completed("resp-initial"), + ]), + ) + .await; + codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "Record initial instructions".into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await; + let _ = initial_mock.single_request(); + + let resumed_mock = mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-resume"), + ev_assistant_message("msg-2", "Resumed turn"), + ev_completed("resp-resume"), + ]), + ) + .await; + + let mut resume_builder = test_codex().with_config(|config| { + config.model = Some("gpt-5.2-codex".to_string()); + }); + let resumed = resume_builder.resume(&server, home, rollout_path).await?; + resumed + .codex + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: None, + sandbox_policy: None, + windows_sandbox_level: None, + model: Some("gpt-5.1-codex-max".to_string()), + effort: None, + summary: None, + collaboration_mode: None, + personality: None, + }) + .await?; + resumed + .codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "first turn after override".into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + wait_for_event(&resumed.codex, |event| { + matches!(event, EventMsg::TurnComplete(_)) + }) + .await; + + let request = resumed_mock.single_request(); + let developer_texts = request.message_input_texts("developer"); + let model_switch_count = developer_texts + .iter() + .filter(|text| text.contains("")) + .count(); + assert_eq!(model_switch_count, 1); Ok(()) }