diff --git a/codex-rs/app-server/README.md b/codex-rs/app-server/README.md index aece5fa3c..27bb78e49 100644 --- a/codex-rs/app-server/README.md +++ b/codex-rs/app-server/README.md @@ -557,7 +557,7 @@ Today both notifications carry an empty `items` array even when item events were `ThreadItem` is the tagged union carried in turn responses and `item/*` notifications. Currently we support events for the following items: -- `userMessage` — `{id, content}` where `content` is a list of user inputs (`text`, `image`, or `localImage`). +- `userMessage` — `{id, content}` where `content` is a list of user inputs (`text`, `image`, or `localImage`). Cyber model-routing warnings are surfaced as synthetic `userMessage` items with `text` prefixed by `Warning:`. - `agentMessage` — `{id, text}` containing the accumulated agent reply. - `plan` — `{id, text}` emitted for plan-mode turns; plan text can stream via `item/plan/delta` (experimental). - `reasoning` — `{id, summary, content}` where `summary` holds streamed reasoning summaries (applicable for most OpenAI models) and `content` holds raw reasoning blocks (applicable for e.g. open source models). diff --git a/codex-rs/app-server/src/bespoke_event_handling.rs b/codex-rs/app-server/src/bespoke_event_handling.rs index da7abe059..87ddf7e18 100644 --- a/codex-rs/app-server/src/bespoke_event_handling.rs +++ b/codex-rs/app-server/src/bespoke_event_handling.rs @@ -68,6 +68,7 @@ use codex_app_server_protocol::TurnInterruptResponse; use codex_app_server_protocol::TurnPlanStep; use codex_app_server_protocol::TurnPlanUpdatedNotification; use codex_app_server_protocol::TurnStatus; +use codex_app_server_protocol::UserInput as V2UserInput; use codex_app_server_protocol::build_turns_from_rollout_items; use codex_core::CodexThread; use codex_core::parse_command::shlex_join; @@ -95,6 +96,8 @@ use codex_protocol::request_user_input::RequestUserInputAnswer as CoreRequestUse use codex_protocol::request_user_input::RequestUserInputResponse as CoreRequestUserInputResponse; use std::collections::HashMap; use std::convert::TryFrom; +use std::hash::Hash; +use std::hash::Hasher; use std::path::PathBuf; use std::sync::Arc; use tokio::sync::Mutex; @@ -122,6 +125,35 @@ pub(crate) async fn apply_bespoke_event_handling( EventMsg::TurnComplete(_ev) => { handle_turn_complete(conversation_id, event_turn_id, &outgoing, &thread_state).await; } + EventMsg::Warning(warning_event) => { + if matches!(api_version, ApiVersion::V2) + && is_safety_check_downgrade_warning(&warning_event.message) + { + let item = ThreadItem::UserMessage { + id: warning_item_id(&event_turn_id, &warning_event.message), + content: vec![V2UserInput::Text { + text: format!("Warning: {}", warning_event.message), + text_elements: Vec::new(), + }], + }; + let started = ItemStartedNotification { + thread_id: conversation_id.to_string(), + turn_id: event_turn_id.clone(), + item: item.clone(), + }; + outgoing + .send_server_notification(ServerNotification::ItemStarted(started)) + .await; + let completed = ItemCompletedNotification { + thread_id: conversation_id.to_string(), + turn_id: event_turn_id.clone(), + item, + }; + outgoing + .send_server_notification(ServerNotification::ItemCompleted(completed)) + .await; + } + } EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { call_id, turn_id, @@ -1286,6 +1318,18 @@ async fn complete_command_execution_item( .await; } +fn is_safety_check_downgrade_warning(message: &str) -> bool { + message.contains("Your account was flagged for potentially high-risk cyber 
activity") + && message.contains("apply for trusted access: https://chatgpt.com/cyber") +} + +fn warning_item_id(turn_id: &str, message: &str) -> String { + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + message.hash(&mut hasher); + let digest = hasher.finish(); + format!("{turn_id}-warning-{digest:x}") +} + async fn maybe_emit_raw_response_item_completed( api_version: ApiVersion, conversation_id: ThreadId, @@ -2016,6 +2060,18 @@ mod tests { assert_eq!(item, expected); } + #[test] + fn safety_check_downgrade_warning_detection_matches_expected_message() { + let warning = "Your account was flagged for potentially high-risk cyber activity and this request was routed to gpt-5.2 as a fallback. To regain access to gpt-5.3-codex, apply for trusted access: https://chatgpt.com/cyber\nLearn more: https://developers.openai.com/codex/concepts/cyber-safety"; + assert!(is_safety_check_downgrade_warning(warning)); + } + + #[test] + fn safety_check_downgrade_warning_detection_ignores_other_warnings() { + let warning = "Model metadata for `mock-model` not found. Defaulting to fallback metadata; this can degrade performance and cause issues."; + assert!(!is_safety_check_downgrade_warning(warning)); + } + #[tokio::test] async fn test_handle_error_records_message() -> Result<()> { let conversation_id = ThreadId::new(); diff --git a/codex-rs/app-server/tests/suite/v2/mod.rs b/codex-rs/app-server/tests/suite/v2/mod.rs index 48622acdd..1eacc2a84 100644 --- a/codex-rs/app-server/tests/suite/v2/mod.rs +++ b/codex-rs/app-server/tests/suite/v2/mod.rs @@ -15,6 +15,7 @@ mod plan_item; mod rate_limits; mod request_user_input; mod review; +mod safety_check_downgrade; mod skills_list; mod thread_archive; mod thread_fork; diff --git a/codex-rs/app-server/tests/suite/v2/safety_check_downgrade.rs b/codex-rs/app-server/tests/suite/v2/safety_check_downgrade.rs new file mode 100644 index 000000000..20a7c6023 --- /dev/null +++ b/codex-rs/app-server/tests/suite/v2/safety_check_downgrade.rs @@ -0,0 +1,266 @@ +use anyhow::Result; +use app_test_support::McpProcess; +use app_test_support::to_response; +use codex_app_server_protocol::ItemCompletedNotification; +use codex_app_server_protocol::ItemStartedNotification; +use codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::RequestId; +use codex_app_server_protocol::ThreadItem; +use codex_app_server_protocol::ThreadStartParams; +use codex_app_server_protocol::ThreadStartResponse; +use codex_app_server_protocol::TurnStartParams; +use codex_app_server_protocol::TurnStartResponse; +use codex_app_server_protocol::UserInput; +use core_test_support::responses; +use core_test_support::skip_if_no_network; +use pretty_assertions::assert_eq; +use tempfile::TempDir; +use tokio::time::timeout; + +const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); +const REQUESTED_MODEL: &str = "gpt-5.1-codex-max"; +const SERVER_MODEL: &str = "gpt-5.2-codex"; + +#[tokio::test] +async fn openai_model_header_mismatch_emits_warning_item_v2() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + let body = responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_assistant_message("msg-1", "Done"), + responses::ev_completed("resp-1"), + ]); + let response = responses::sse_response(body).insert_header("OpenAI-Model", SERVER_MODEL); + let _response_mock = responses::mount_response_once(&server, response).await; + + 
let codex_home = TempDir::new()?;
+    create_config_toml(codex_home.path(), &server.uri())?;
+
+    let mut mcp = McpProcess::new(codex_home.path()).await?;
+    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
+
+    let thread_req = mcp
+        .send_thread_start_request(ThreadStartParams {
+            model: Some(REQUESTED_MODEL.to_string()),
+            ..Default::default()
+        })
+        .await?;
+    let thread_resp: JSONRPCResponse = timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_response_message(RequestId::Integer(thread_req)),
+    )
+    .await??;
+    let ThreadStartResponse { thread, .. } = to_response::<ThreadStartResponse>(thread_resp)?;
+
+    let turn_req = mcp
+        .send_turn_start_request(TurnStartParams {
+            thread_id: thread.id.clone(),
+            input: vec![UserInput::Text {
+                text: "trigger safeguard".to_string(),
+                text_elements: Vec::new(),
+            }],
+            ..Default::default()
+        })
+        .await?;
+    let _turn_resp: JSONRPCResponse = timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_response_message(RequestId::Integer(turn_req)),
+    )
+    .await??;
+    let _turn_start: TurnStartResponse = to_response(_turn_resp)?;
+
+    let warning_started = timeout(DEFAULT_READ_TIMEOUT, async {
+        loop {
+            let notification: JSONRPCNotification = mcp
+                .read_stream_until_notification_message("item/started")
+                .await?;
+            let params = notification.params.expect("item/started params");
+            let started: ItemStartedNotification =
+                serde_json::from_value(params).expect("deserialize item/started");
+            if warning_text_from_item(&started.item).is_some_and(is_cyber_model_warning_text) {
+                return Ok::<ItemStartedNotification, anyhow::Error>(started);
+            }
+        }
+    })
+    .await??;
+
+    let warning_text =
+        warning_text_from_item(&warning_started.item).expect("expected warning user message item");
+    assert!(warning_text.contains("Warning:"));
+    assert!(warning_text.contains("gpt-5.2 as a fallback"));
+    assert!(warning_text.contains("regain access to gpt-5.3-codex"));
+
+    let warning_completed = timeout(DEFAULT_READ_TIMEOUT, async {
+        loop {
+            let notification: JSONRPCNotification = mcp
+                .read_stream_until_notification_message("item/completed")
+                .await?;
+            let params = notification.params.expect("item/completed params");
+            let completed: ItemCompletedNotification =
+                serde_json::from_value(params).expect("deserialize item/completed");
+            if warning_text_from_item(&completed.item).is_some_and(is_cyber_model_warning_text) {
+                return Ok::<ItemCompletedNotification, anyhow::Error>(completed);
+            }
+        }
+    })
+    .await??;
+    assert_eq!(
+        warning_text_from_item(&warning_completed.item),
+        warning_text_from_item(&warning_started.item)
+    );
+
+    timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_notification_message("turn/completed"),
+    )
+    .await??;
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn response_model_field_mismatch_emits_warning_item_v2_when_header_matches_requested()
+-> Result<()> {
+    skip_if_no_network!(Ok(()));
+
+    let server = responses::start_mock_server().await;
+    let body = responses::sse(vec![
+        serde_json::json!({
+            "type": "response.created",
+            "response": {
+                "id": "resp-1",
+                "model": SERVER_MODEL,
+            }
+        }),
+        responses::ev_assistant_message("msg-1", "Done"),
+        responses::ev_completed("resp-1"),
+    ]);
+    let response = responses::sse_response(body).insert_header("OpenAI-Model", REQUESTED_MODEL);
+    let _response_mock = responses::mount_response_once(&server, response).await;
+
+    let codex_home = TempDir::new()?;
+    create_config_toml(codex_home.path(), &server.uri())?;
+
+    let mut mcp = McpProcess::new(codex_home.path()).await?;
+    timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
+
+    let thread_req = mcp
+        .send_thread_start_request(ThreadStartParams {
+            model: Some(REQUESTED_MODEL.to_string()),
+            ..Default::default()
+        })
+        .await?;
+    let thread_resp: JSONRPCResponse = timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_response_message(RequestId::Integer(thread_req)),
+    )
+    .await??;
+    let ThreadStartResponse { thread, .. } = to_response::<ThreadStartResponse>(thread_resp)?;
+
+    let turn_req = mcp
+        .send_turn_start_request(TurnStartParams {
+            thread_id: thread.id.clone(),
+            input: vec![UserInput::Text {
+                text: "trigger response model check".to_string(),
+                text_elements: Vec::new(),
+            }],
+            ..Default::default()
+        })
+        .await?;
+    let turn_resp: JSONRPCResponse = timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_response_message(RequestId::Integer(turn_req)),
+    )
+    .await??;
+    let _turn_start: TurnStartResponse = to_response(turn_resp)?;
+
+    let warning_started = timeout(DEFAULT_READ_TIMEOUT, async {
+        loop {
+            let notification: JSONRPCNotification = mcp
+                .read_stream_until_notification_message("item/started")
+                .await?;
+            let params = notification.params.expect("item/started params");
+            let started: ItemStartedNotification =
+                serde_json::from_value(params).expect("deserialize item/started");
+            if warning_text_from_item(&started.item).is_some_and(is_cyber_model_warning_text) {
+                return Ok::<ItemStartedNotification, anyhow::Error>(started);
+            }
+        }
+    })
+    .await??;
+    let warning_text =
+        warning_text_from_item(&warning_started.item).expect("expected warning user message item");
+    assert!(warning_text.contains("gpt-5.2 as a fallback"));
+
+    let warning_completed = timeout(DEFAULT_READ_TIMEOUT, async {
+        loop {
+            let notification: JSONRPCNotification = mcp
+                .read_stream_until_notification_message("item/completed")
+                .await?;
+            let params = notification.params.expect("item/completed params");
+            let completed: ItemCompletedNotification =
+                serde_json::from_value(params).expect("deserialize item/completed");
+            if warning_text_from_item(&completed.item).is_some_and(is_cyber_model_warning_text) {
+                return Ok::<ItemCompletedNotification, anyhow::Error>(completed);
+            }
+        }
+    })
+    .await??;
+    assert_eq!(
+        warning_text_from_item(&warning_completed.item),
+        warning_text_from_item(&warning_started.item)
+    );
+
+    timeout(
+        DEFAULT_READ_TIMEOUT,
+        mcp.read_stream_until_notification_message("turn/completed"),
+    )
+    .await??;
+
+    Ok(())
+}
+
+fn warning_text_from_item(item: &ThreadItem) -> Option<&str> {
+    let ThreadItem::UserMessage { content, .. } = item else {
+        return None;
+    };
+
+    content.iter().find_map(|input| match input {
+        UserInput::Text { text, ..
} if text.starts_with("Warning: ") => Some(text.as_str()), + _ => None, + }) +} + +fn is_cyber_model_warning_text(text: &str) -> bool { + text.contains("flagged for potentially high-risk cyber activity") + && text.contains("apply for trusted access: https://chatgpt.com/cyber") +} + +fn create_config_toml(codex_home: &std::path::Path, server_uri: &str) -> std::io::Result<()> { + let config_toml = codex_home.join("config.toml"); + std::fs::write( + config_toml, + format!( + r#" +model = "{REQUESTED_MODEL}" +approval_policy = "never" +sandbox_mode = "read-only" + +model_provider = "mock_provider" + +[features] +remote_models = false +personality = true + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "{server_uri}/v1" +wire_api = "responses" +request_max_retries = 0 +stream_max_retries = 0 +"# + ), + ) +} diff --git a/codex-rs/codex-api/src/common.rs b/codex-rs/codex-api/src/common.rs index af08b6c97..3ef96c185 100644 --- a/codex-rs/codex-api/src/common.rs +++ b/codex-rs/codex-api/src/common.rs @@ -56,6 +56,9 @@ pub enum ResponseEvent { Created, OutputItemDone(ResponseItem), OutputItemAdded(ResponseItem), + /// Emitted when the server includes `OpenAI-Model` on the stream response. + /// This can differ from the requested model when backend safety routing applies. + ServerModel(String), /// Emitted when `X-Reasoning-Included: true` is present on the response, /// meaning the server already accounted for past reasoning tokens and the /// client should not re-estimate them. diff --git a/codex-rs/codex-api/src/endpoint/aggregate.rs b/codex-rs/codex-api/src/endpoint/aggregate.rs index a91eec90a..6d8e785c8 100644 --- a/codex-rs/codex-api/src/endpoint/aggregate.rs +++ b/codex-rs/codex-api/src/endpoint/aggregate.rs @@ -63,6 +63,9 @@ impl Stream for AggregatedStream { Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag)))) => { return Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag)))); } + Poll::Ready(Some(Ok(ResponseEvent::ServerModel(model)))) => { + return Poll::Ready(Some(Ok(ResponseEvent::ServerModel(model)))); + } Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id, token_usage, diff --git a/codex-rs/codex-api/src/endpoint/responses_websocket.rs b/codex-rs/codex-api/src/endpoint/responses_websocket.rs index 6ebf5ab65..aa559e983 100644 --- a/codex-rs/codex-api/src/endpoint/responses_websocket.rs +++ b/codex-rs/codex-api/src/endpoint/responses_websocket.rs @@ -163,6 +163,7 @@ impl Drop for WsStream { const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state"; const X_MODELS_ETAG_HEADER: &str = "x-models-etag"; const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included"; +const OPENAI_MODEL_HEADER: &str = "openai-model"; pub struct ResponsesWebsocketConnection { stream: Arc>>, @@ -170,6 +171,7 @@ pub struct ResponsesWebsocketConnection { idle_timeout: Duration, server_reasoning_included: bool, models_etag: Option, + server_model: Option, telemetry: Option>, } @@ -179,6 +181,7 @@ impl ResponsesWebsocketConnection { idle_timeout: Duration, server_reasoning_included: bool, models_etag: Option, + server_model: Option, telemetry: Option>, ) -> Self { Self { @@ -186,6 +189,7 @@ impl ResponsesWebsocketConnection { idle_timeout, server_reasoning_included, models_etag, + server_model, telemetry, } } @@ -204,12 +208,16 @@ impl ResponsesWebsocketConnection { let idle_timeout = self.idle_timeout; let server_reasoning_included = self.server_reasoning_included; let models_etag = self.models_etag.clone(); + let server_model = self.server_model.clone(); let 
telemetry = self.telemetry.clone(); let request_body = serde_json::to_value(&request).map_err(|err| { ApiError::Stream(format!("failed to encode websocket request: {err}")) })?; tokio::spawn(async move { + if let Some(model) = server_model { + let _ = tx_event.send(Ok(ResponseEvent::ServerModel(model))).await; + } if let Some(etag) = models_etag { let _ = tx_event.send(Ok(ResponseEvent::ModelsEtag(etag))).await; } @@ -273,13 +281,14 @@ impl ResponsesWebsocketClient { merge_request_headers(&self.provider.headers, extra_headers, default_headers); add_auth_headers_to_header_map(&self.auth, &mut headers); - let (stream, server_reasoning_included, models_etag) = + let (stream, server_reasoning_included, models_etag, server_model) = connect_websocket(ws_url, headers, turn_state.clone()).await?; Ok(ResponsesWebsocketConnection::new( stream, self.provider.stream_idle_timeout, server_reasoning_included, models_etag, + server_model, telemetry, )) } @@ -304,7 +313,7 @@ async fn connect_websocket( url: Url, headers: HeaderMap, turn_state: Option>>, -) -> Result<(WsStream, bool, Option), ApiError> { +) -> Result<(WsStream, bool, Option, Option), ApiError> { ensure_rustls_crypto_provider(); info!("connecting to websocket: {url}"); @@ -341,6 +350,11 @@ async fn connect_websocket( .get(X_MODELS_ETAG_HEADER) .and_then(|value| value.to_str().ok()) .map(ToString::to_string); + let server_model = response + .headers() + .get(OPENAI_MODEL_HEADER) + .and_then(|value| value.to_str().ok()) + .map(ToString::to_string); if let Some(turn_state) = turn_state && let Some(header_value) = response .headers() @@ -349,7 +363,12 @@ async fn connect_websocket( { let _ = turn_state.set(header_value.to_string()); } - Ok((WsStream::new(stream), reasoning_included, models_etag)) + Ok(( + WsStream::new(stream), + reasoning_included, + models_etag, + server_model, + )) } fn websocket_config() -> WebSocketConfig { @@ -469,6 +488,7 @@ async fn run_websocket_response_stream( idle_timeout: Duration, telemetry: Option>, ) -> Result<(), ApiError> { + let mut last_server_model: Option = None; let request_text = match serde_json::to_string(&request_body) { Ok(text) => text, Err(err) => { @@ -536,6 +556,14 @@ async fn run_websocket_response_stream( } continue; } + if let Some(model) = event.response_model() + && last_server_model.as_deref() != Some(model.as_str()) + { + let _ = tx_event + .send(Ok(ResponseEvent::ServerModel(model.clone()))) + .await; + last_server_model = Some(model); + } match process_responses_event(event) { Ok(Some(event)) => { let is_completed = matches!(event, ResponseEvent::Completed { .. }); diff --git a/codex-rs/codex-api/src/sse/responses.rs b/codex-rs/codex-api/src/sse/responses.rs index 11e4f9de3..75c79ba5c 100644 --- a/codex-rs/codex-api/src/sse/responses.rs +++ b/codex-rs/codex-api/src/sse/responses.rs @@ -26,6 +26,7 @@ use tracing::debug; use tracing::trace; const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included"; +const OPENAI_MODEL_HEADER: &str = "openai-model"; /// Streams SSE events from an on-disk fixture for tests. 
pub fn stream_from_fixture( @@ -60,6 +61,11 @@ pub fn spawn_response_stream( .get("X-Models-Etag") .and_then(|v| v.to_str().ok()) .map(ToString::to_string); + let server_model = stream_response + .headers + .get(OPENAI_MODEL_HEADER) + .and_then(|v| v.to_str().ok()) + .map(ToString::to_string); let reasoning_included = stream_response .headers .get(X_REASONING_INCLUDED_HEADER) @@ -74,6 +80,9 @@ pub fn spawn_response_stream( } let (tx_event, rx_event) = mpsc::channel::>(1600); tokio::spawn(async move { + if let Some(model) = server_model { + let _ = tx_event.send(Ok(ResponseEvent::ServerModel(model))).await; + } for snapshot in rate_limit_snapshots { let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await; } @@ -169,6 +178,41 @@ impl ResponsesStreamEvent { pub fn kind(&self) -> &str { &self.kind } + + pub fn response_model(&self) -> Option { + self.response.as_ref().and_then(extract_server_model) + } +} + +fn extract_server_model(value: &Value) -> Option { + value + .get("model") + .and_then(json_value_as_string) + .or_else(|| { + value + .get("headers") + .and_then(header_openai_model_value_from_json) + }) +} + +fn header_openai_model_value_from_json(value: &Value) -> Option { + let headers = value.as_object()?; + headers.iter().find_map(|(name, value)| { + if name.eq_ignore_ascii_case("openai-model") || name.eq_ignore_ascii_case("x-openai-model") + { + json_value_as_string(value) + } else { + None + } + }) +} + +fn json_value_as_string(value: &Value) -> Option { + match value { + Value::String(value) => Some(value.clone()), + Value::Array(items) => items.first().and_then(json_value_as_string), + _ => None, + } } #[derive(Debug)] @@ -339,6 +383,7 @@ pub async fn process_sse( ) { let mut stream = stream.eventsource(); let mut response_error: Option = None; + let mut last_server_model: Option = None; loop { let start = Instant::now(); @@ -378,6 +423,19 @@ pub async fn process_sse( } }; + if let Some(model) = event.response_model() + && last_server_model.as_deref() != Some(model.as_str()) + { + if tx_event + .send(Ok(ResponseEvent::ServerModel(model.clone()))) + .await + .is_err() + { + return; + } + last_server_model = Some(model); + } + match process_responses_event(event) { Ok(Some(event)) => { let is_completed = matches!(event, ResponseEvent::Completed { .. 
}); @@ -456,9 +514,13 @@ mod tests { use super::*; use assert_matches::assert_matches; use bytes::Bytes; + use codex_client::StreamResponse; use codex_protocol::models::MessagePhase; use codex_protocol::models::ResponseItem; use futures::stream; + use http::HeaderMap; + use http::HeaderValue; + use http::StatusCode; use pretty_assertions::assert_eq; use serde_json::json; use tokio::sync::mpsc; @@ -870,6 +932,149 @@ mod tests { } } + #[tokio::test] + async fn spawn_response_stream_emits_server_model_header() { + let mut headers = HeaderMap::new(); + headers.insert( + OPENAI_MODEL_HEADER, + HeaderValue::from_static(CYBER_RESTRICTED_MODEL_FOR_TESTS), + ); + let bytes = stream::iter(Vec::>::new()); + let stream_response = StreamResponse { + status: StatusCode::OK, + headers, + bytes: Box::pin(bytes), + }; + + let mut stream = spawn_response_stream(stream_response, idle_timeout(), None, None); + let event = stream + .rx_event + .recv() + .await + .expect("expected server model event") + .expect("expected ok event"); + + match event { + ResponseEvent::ServerModel(model) => { + assert_eq!(model, CYBER_RESTRICTED_MODEL_FOR_TESTS); + } + other => panic!("expected server model event, got {other:?}"), + } + } + + #[tokio::test] + async fn process_sse_emits_server_model_from_response_payload() { + let events = run_sse(vec![ + json!({ + "type": "response.created", + "response": { + "id": "resp-1", + "model": CYBER_RESTRICTED_MODEL_FOR_TESTS + } + }), + json!({ + "type": "response.completed", + "response": { + "id": "resp-1", + "model": CYBER_RESTRICTED_MODEL_FOR_TESTS + } + }), + ]) + .await; + + assert_eq!(events.len(), 3); + assert_matches!( + &events[0], + ResponseEvent::ServerModel(model) if model == CYBER_RESTRICTED_MODEL_FOR_TESTS + ); + assert_matches!(&events[1], ResponseEvent::Created); + assert_matches!( + &events[2], + ResponseEvent::Completed { + response_id, + token_usage: None, + can_append: false + } if response_id == "resp-1" + ); + } + + #[tokio::test] + async fn process_sse_emits_server_model_from_response_headers_payload() { + let events = run_sse(vec![ + json!({ + "type": "response.created", + "response": { + "id": "resp-1", + "headers": { + "OpenAI-Model": CYBER_RESTRICTED_MODEL_FOR_TESTS + } + } + }), + json!({ + "type": "response.completed", + "response": { + "id": "resp-1" + } + }), + ]) + .await; + + assert_eq!(events.len(), 3); + assert_matches!( + &events[0], + ResponseEvent::ServerModel(model) if model == CYBER_RESTRICTED_MODEL_FOR_TESTS + ); + assert_matches!(&events[1], ResponseEvent::Created); + assert_matches!( + &events[2], + ResponseEvent::Completed { + response_id, + token_usage: None, + can_append: false + } if response_id == "resp-1" + ); + } + + #[tokio::test] + async fn process_sse_emits_server_model_again_when_response_model_changes() { + let events = run_sse(vec![ + json!({ + "type": "response.created", + "response": { + "id": "resp-1", + "model": "gpt-5.2-codex" + } + }), + json!({ + "type": "response.completed", + "response": { + "id": "resp-1", + "model": "gpt-5.3-codex" + } + }), + ]) + .await; + + assert_eq!(events.len(), 4); + assert_matches!( + &events[0], + ResponseEvent::ServerModel(model) if model == "gpt-5.2-codex" + ); + assert_matches!(&events[1], ResponseEvent::Created); + assert_matches!( + &events[2], + ResponseEvent::ServerModel(model) if model == "gpt-5.3-codex" + ); + assert_matches!( + &events[3], + ResponseEvent::Completed { + response_id, + token_usage: None, + can_append: false + } if response_id == "resp-1" + ); + } + #[test] fn 
test_try_parse_retry_after() { let err = Error { @@ -909,4 +1114,6 @@ mod tests { let delay = try_parse_retry_after(&err); assert_eq!(delay, Some(Duration::from_secs(35))); } + + const CYBER_RESTRICTED_MODEL_FOR_TESTS: &str = "gpt-5.3-codex"; } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 1341e6341..2b0a76809 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -279,6 +279,8 @@ pub struct CodexSpawnOk { pub(crate) const INITIAL_SUBMIT_ID: &str = ""; pub(crate) const SUBMISSION_CHANNEL_CAPACITY: usize = 64; +const CYBER_VERIFY_URL: &str = "https://chatgpt.com/cyber"; +const CYBER_SAFETY_URL: &str = "https://developers.openai.com/codex/concepts/cyber-safety"; impl Codex { /// Spawn a new [`Codex`] and initialize the session. @@ -2560,6 +2562,35 @@ impl Session { self.record_conversation_items(ctx, &[item]).await; } + async fn maybe_warn_on_server_model_mismatch( + self: &Arc, + turn_context: &Arc, + server_model: String, + ) -> bool { + let requested_model = turn_context.model_info.slug.as_str(); + if server_model == requested_model { + info!("server reported model {server_model} (matches requested model)"); + return false; + } + + warn!("server reported model {server_model} while requested model was {requested_model}"); + + let warning_message = format!( + "Your account was flagged for potentially high-risk cyber activity and this request was routed to gpt-5.2 as a fallback. To regain access to gpt-5.3-codex, apply for trusted access: {CYBER_VERIFY_URL} or learn more: {CYBER_SAFETY_URL}" + ); + + self.send_event( + turn_context, + EventMsg::Warning(WarningEvent { + message: warning_message.clone(), + }), + ) + .await; + self.record_model_warning(warning_message, turn_context) + .await; + true + } + pub(crate) async fn replace_history(&self, items: Vec) { let mut state = self.state.lock().await; state.replace_history(items); @@ -4435,6 +4466,7 @@ pub(crate) async fn run_turn( // Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains // many turns, from the perspective of the user, it is a single turn. let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + let mut server_model_warning_emitted_for_turn = false; // `ModelClientSession` is turn-scoped and caches WebSocket + sticky routing state, so we reuse // one instance across retries within this turn. 
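The flag above is deliberately turn-scoped rather than request-scoped: one turn can issue several sampling requests (retries, tool-call follow-ups), and each stream can report a server model. A minimal sketch of the intended dedup behavior, with simplified stand-in names rather than the real `Session`/`TurnContext` plumbing:

```rust
// Sketch only: `note_server_model` stands in for
// `maybe_warn_on_server_model_mismatch` plus the flag update in
// `try_run_sampling_request`; the names here are illustrative.
fn note_server_model(
    requested_model: &str,
    server_model: &str,
    warned_this_turn: &mut bool,
) -> Option<String> {
    // Matching models never warn; a prior warning in this turn suppresses repeats.
    if *warned_this_turn || server_model == requested_model {
        return None;
    }
    *warned_this_turn = true;
    Some(format!(
        "server reported model {server_model} while requested model was {requested_model}"
    ))
}

fn main() {
    let mut warned = false;
    // The first mismatching stream in the turn warns once...
    assert!(note_server_model("gpt-5.3-codex", "gpt-5.2", &mut warned).is_some());
    // ...and subsequent streams in the same turn stay quiet, matching the
    // `only_emits_one_warning_per_turn` test below.
    assert!(note_server_model("gpt-5.3-codex", "gpt-5.2", &mut warned).is_none());
}
```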
@@ -4497,6 +4529,7 @@ pub(crate) async fn run_turn( sampling_request_input, &explicitly_enabled_connectors, skills_outcome.as_ref(), + &mut server_model_warning_emitted_for_turn, cancellation_token.child_token(), ) .await @@ -4870,6 +4903,7 @@ async fn run_sampling_request( input: Vec, explicitly_enabled_connectors: &HashSet, skills_outcome: Option<&SkillLoadOutcome>, + server_model_warning_emitted_for_turn: &mut bool, cancellation_token: CancellationToken, ) -> CodexResult { let router = built_tools( @@ -4906,6 +4940,7 @@ async fn run_sampling_request( client_session, turn_metadata_header, Arc::clone(&turn_diff_tracker), + server_model_warning_emitted_for_turn, &prompt, cancellation_token.child_token(), ) @@ -5474,6 +5509,7 @@ async fn try_run_sampling_request( client_session: &mut ModelClientSession, turn_metadata_header: Option<&str>, turn_diff_tracker: SharedTurnDiffTracker, + server_model_warning_emitted_for_turn: &mut bool, prompt: &Prompt, cancellation_token: CancellationToken, ) -> CodexResult { @@ -5616,6 +5652,15 @@ async fn try_run_sampling_request( active_item = Some(turn_item); } } + ResponseEvent::ServerModel(server_model) => { + if !*server_model_warning_emitted_for_turn + && sess + .maybe_warn_on_server_model_mismatch(&turn_context, server_model) + .await + { + *server_model_warning_emitted_for_turn = true; + } + } ResponseEvent::ServerReasoningIncluded(included) => { sess.set_server_reasoning_included(included).await; } diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index 113cf9f67..03eec5bde 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -103,6 +103,7 @@ mod resume_warning; mod review; mod rmcp_client; mod rollout_list_find; +mod safety_check_downgrade; mod search_tool; mod seatbelt; mod shell_command; diff --git a/codex-rs/core/tests/suite/safety_check_downgrade.rs b/codex-rs/core/tests/suite/safety_check_downgrade.rs new file mode 100644 index 000000000..a22fead90 --- /dev/null +++ b/codex-rs/core/tests/suite/safety_check_downgrade.rs @@ -0,0 +1,228 @@ +use anyhow::Result; +use codex_core::protocol::AskForApproval; +use codex_core::protocol::EventMsg; +use codex_core::protocol::Op; +use codex_core::protocol::SandboxPolicy; +use codex_protocol::config_types::ReasoningSummary; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseItem; +use codex_protocol::user_input::UserInput; +use core_test_support::responses::ev_assistant_message; +use core_test_support::responses::ev_function_call; +use core_test_support::responses::ev_response_created; +use core_test_support::responses::mount_response_once; +use core_test_support::responses::mount_response_sequence; +use core_test_support::responses::sse; +use core_test_support::responses::sse_completed; +use core_test_support::responses::sse_response; +use core_test_support::responses::start_mock_server; +use core_test_support::skip_if_no_network; +use core_test_support::test_codex::test_codex; +use core_test_support::wait_for_event; +use pretty_assertions::assert_eq; + +const SERVER_MODEL: &str = "gpt-5.2"; +const REQUESTED_MODEL: &str = "gpt-5.3-codex"; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn openai_model_header_mismatch_emits_warning_event_and_warning_item() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let response = + sse_response(sse_completed("resp-1")).insert_header("OpenAI-Model", SERVER_MODEL); + let _mock = mount_response_once(&server, 
response).await; + + let mut builder = test_codex().with_model(REQUESTED_MODEL); + let test = builder.build(&server).await?; + + test.codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "trigger safety check".to_string(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + cwd: test.cwd_path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: REQUESTED_MODEL.to_string(), + effort: test.config.model_reasoning_effort, + summary: ReasoningSummary::Auto, + collaboration_mode: None, + personality: None, + }) + .await?; + + let warning = wait_for_event(&test.codex, |event| matches!(event, EventMsg::Warning(_))).await; + let EventMsg::Warning(warning) = warning else { + panic!("expected warning event"); + }; + assert!(warning.message.contains(REQUESTED_MODEL)); + assert!(warning.message.contains(SERVER_MODEL)); + + let warning_item = wait_for_event(&test.codex, |event| { + matches!( + event, + EventMsg::RawResponseItem(raw) + if matches!( + &raw.item, + ResponseItem::Message { content, .. } + if content.iter().any(|item| matches!( + item, + ContentItem::InputText { text } if text.starts_with("Warning: ") + )) + ) + ) + }) + .await; + let EventMsg::RawResponseItem(raw) = warning_item else { + panic!("expected raw response item event"); + }; + let ResponseItem::Message { role, content, .. } = raw.item else { + panic!("expected warning to be recorded as a message item"); + }; + assert_eq!(role, "user"); + let warning_text = content.iter().find_map(|item| match item { + ContentItem::InputText { text } => Some(text.as_str()), + _ => None, + }); + let warning_text = warning_text.expect("warning message should include input_text content"); + assert!(warning_text.contains(REQUESTED_MODEL)); + assert!(warning_text.contains(SERVER_MODEL)); + + let _ = wait_for_event(&test.codex, |event| { + matches!(event, EventMsg::TurnComplete(_)) + }) + .await; + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn response_model_field_mismatch_emits_warning_when_header_matches_requested() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let response = sse_response(sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": { + "id": "resp-1", + "model": SERVER_MODEL, + } + }), + core_test_support::responses::ev_completed("resp-1"), + ])) + .insert_header("OpenAI-Model", REQUESTED_MODEL); + let _mock = mount_response_once(&server, response).await; + + let mut builder = test_codex().with_model(REQUESTED_MODEL); + let test = builder.build(&server).await?; + + test.codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "trigger response model check".to_string(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + cwd: test.cwd_path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: REQUESTED_MODEL.to_string(), + effort: test.config.model_reasoning_effort, + summary: ReasoningSummary::Auto, + collaboration_mode: None, + personality: None, + }) + .await?; + + let warning = wait_for_event(&test.codex, |event| { + matches!( + event, + EventMsg::Warning(warning) + if warning + .message + .contains("flagged for potentially high-risk cyber activity") + ) + }) + .await; + let EventMsg::Warning(warning) = warning else { + panic!("expected warning event"); + }; + assert!(warning.message.contains("gpt-5.2 as a fallback")); + + let _ = 
wait_for_event(&test.codex, |event| { + matches!(event, EventMsg::TurnComplete(_)) + }) + .await; + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn openai_model_header_mismatch_only_emits_one_warning_per_turn() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let tool_args = serde_json::json!({ + "command": "echo hello", + "timeout_ms": 1_000 + }); + + let first_response = sse_response(sse(vec![ + ev_response_created("resp-1"), + ev_function_call( + "call-1", + "shell_command", + &serde_json::to_string(&tool_args)?, + ), + core_test_support::responses::ev_completed("resp-1"), + ])) + .insert_header("OpenAI-Model", SERVER_MODEL); + let second_response = sse_response(sse(vec![ + ev_response_created("resp-2"), + ev_assistant_message("msg-1", "done"), + core_test_support::responses::ev_completed("resp-2"), + ])) + .insert_header("OpenAI-Model", SERVER_MODEL); + let _mock = mount_response_sequence(&server, vec![first_response, second_response]).await; + + let mut builder = test_codex().with_model(REQUESTED_MODEL); + let test = builder.build(&server).await?; + + test.codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "trigger follow-up turn".to_string(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + cwd: test.cwd_path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: REQUESTED_MODEL.to_string(), + effort: test.config.model_reasoning_effort, + summary: ReasoningSummary::Auto, + collaboration_mode: None, + personality: None, + }) + .await?; + + let mut warning_count = 0; + loop { + let event = wait_for_event(&test.codex, |_| true).await; + match event { + EventMsg::Warning(warning) if warning.message.contains(REQUESTED_MODEL) => { + warning_count += 1; + } + EventMsg::TurnComplete(_) => break, + _ => {} + } + } + + assert_eq!(warning_count, 1); + + Ok(()) +} diff --git a/codex-rs/otel/src/traces/otel_manager.rs b/codex-rs/otel/src/traces/otel_manager.rs index f50eb5dfa..e3ba85eab 100644 --- a/codex-rs/otel/src/traces/otel_manager.rs +++ b/codex-rs/otel/src/traces/otel_manager.rs @@ -748,6 +748,7 @@ impl OtelManager { ResponseEvent::ReasoningSummaryPartAdded { .. } => { "reasoning_summary_part_added".into() } + ResponseEvent::ServerModel(_) => "server_model".into(), ResponseEvent::ServerReasoningIncluded(_) => "server_reasoning_included".into(), ResponseEvent::RateLimits(_) => "rate_limits".into(), ResponseEvent::ModelsEtag(_) => "models_etag".into(),
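Header matching relies on `http::HeaderMap` lookups being case-insensitive, which is why the readers use a lowercase `openai-model` constant while the server (and the test mocks) send `OpenAI-Model`. A standalone illustration of that lookup shape:

```rust
use http::header::{HeaderMap, HeaderName, HeaderValue};

fn main() {
    let mut headers = HeaderMap::new();
    // Wire spelling as sent by the server; HeaderName normalizes to lowercase.
    headers.insert(
        HeaderName::from_bytes(b"OpenAI-Model").expect("valid header name"),
        HeaderValue::from_static("gpt-5.2-codex"),
    );
    // Same extraction shape as spawn_response_stream / connect_websocket.
    let server_model = headers
        .get("openai-model")
        .and_then(|value| value.to_str().ok())
        .map(ToString::to_string);
    assert_eq!(server_model.as_deref(), Some("gpt-5.2-codex"));
}
```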
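When the model arrives in the event payload instead of (or in addition to) the transport headers, `extract_server_model` prefers the response's top-level `model` field and only then falls back to an embedded `headers` object, where values may be a string or a single-element array. A simplified, runnable version of that lookup order (the real helper also accepts `x-openai-model`):

```rust
use serde_json::{Value, json};

fn server_model(response: &Value) -> Option<String> {
    // Prefer the top-level "model" field...
    if let Some(model) = response.get("model").and_then(value_as_string) {
        return Some(model);
    }
    // ...then fall back to a "headers" object, matched case-insensitively.
    response
        .get("headers")?
        .as_object()?
        .iter()
        .find_map(|(name, value)| {
            name.eq_ignore_ascii_case("openai-model")
                .then(|| value_as_string(value))
                .flatten()
        })
}

fn value_as_string(value: &Value) -> Option<String> {
    match value {
        Value::String(text) => Some(text.clone()),
        // Header values occasionally arrive as single-element arrays.
        Value::Array(items) => items.first().and_then(value_as_string),
        _ => None,
    }
}

fn main() {
    let payload = json!({
        "id": "resp-1",
        "headers": { "OpenAI-Model": ["gpt-5.2-codex"] }
    });
    assert_eq!(server_model(&payload).as_deref(), Some("gpt-5.2-codex"));
}
```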
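On the app-server side, `warning_item_id` keys the synthetic item by turn id plus a hash of the warning text, so the same warning re-observed within a turn maps to a stable id. A runnable mirror of that helper (note `DefaultHasher` is deterministic for a given Rust release but not guaranteed across releases, which is fine since ids only need to be unique within a live session):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

fn warning_item_id(turn_id: &str, message: &str) -> String {
    let mut hasher = DefaultHasher::new();
    message.hash(&mut hasher);
    let digest = hasher.finish();
    format!("{turn_id}-warning-{digest:x}")
}

fn main() {
    let first = warning_item_id("turn-1", "Warning: routed to a fallback model");
    let second = warning_item_id("turn-1", "Warning: routed to a fallback model");
    // Same turn and message always yield the same item id.
    assert_eq!(first, second);
    // A different message in the same turn yields a distinct id.
    assert_ne!(first, warning_item_id("turn-1", "Warning: something else"));
}
```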
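Clients that want to special-case these synthetic items can reuse the same filter the v2 test applies: a `userMessage` whose text content starts with `Warning: `. A sketch with local stand-in types (hypothetical simplifications; the real `ThreadItem`/`UserInput` live in `codex_app_server_protocol`):

```rust
// Just enough structure to run the filter; not the real protocol definitions.
#[allow(dead_code)]
enum UserInput {
    Text { text: String },
    Image { url: String },
}

#[allow(dead_code)]
enum ThreadItem {
    UserMessage { id: String, content: Vec<UserInput> },
    AgentMessage { id: String, text: String },
}

// Same predicate shape as `warning_text_from_item` in the test suite above.
fn warning_text(item: &ThreadItem) -> Option<&str> {
    let ThreadItem::UserMessage { content, .. } = item else {
        return None;
    };
    content.iter().find_map(|input| match input {
        UserInput::Text { text } if text.starts_with("Warning: ") => Some(text.as_str()),
        _ => None,
    })
}

fn main() {
    let item = ThreadItem::UserMessage {
        id: "turn-1-warning-9f2c".into(),
        content: vec![UserInput::Text {
            text: "Warning: this request was routed to a fallback model".into(),
        }],
    };
    assert!(warning_text(&item).is_some());
}
```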