diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 314bfe7b4..1d0dc7fb6 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -1,13 +1,10 @@ -use std::path::PathBuf; use std::sync::Arc; use std::sync::OnceLock; -use std::sync::RwLock; use crate::api_bridge::CoreAuthProvider; use crate::api_bridge::auth_provider_from_auth; use crate::api_bridge::map_api_error; use crate::auth::UnauthorizedRecovery; -use crate::turn_metadata::build_turn_metadata_header; use codex_api::CompactClient as ApiCompactClient; use codex_api::CompactionInput as ApiCompactionInput; use codex_api::Prompt as ApiPrompt; @@ -75,12 +72,6 @@ pub const X_CODEX_TURN_METADATA_HEADER: &str = "x-codex-turn-metadata"; pub const X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER: &str = "x-responsesapi-include-timing-metrics"; -#[derive(Debug, Default)] -struct TurnMetadataCache { - cwd: Option, - header: Option, -} - #[derive(Debug)] struct ModelClientState { config: Arc, @@ -93,7 +84,6 @@ struct ModelClientState { summary: ReasoningSummaryConfig, session_source: SessionSource, transport_manager: TransportManager, - turn_metadata_cache: Arc>, } #[derive(Debug, Clone)] @@ -106,6 +96,7 @@ pub struct ModelClientSession { connection: Option, websocket_last_items: Vec, transport_manager: TransportManager, + turn_metadata_header: Option, /// Turn state for sticky routing. /// /// This is an `OnceLock` that stores the turn state value received from the server @@ -145,53 +136,20 @@ impl ModelClient { summary, session_source, transport_manager, - turn_metadata_cache: Arc::new(RwLock::new(TurnMetadataCache::default())), }), } } - pub fn new_session(&self, turn_metadata_cwd: Option) -> ModelClientSession { - self.prewarm_turn_metadata_header(turn_metadata_cwd); + pub fn new_session(&self, turn_metadata_header: Option) -> ModelClientSession { ModelClientSession { state: Arc::clone(&self.state), connection: None, websocket_last_items: Vec::new(), transport_manager: self.state.transport_manager.clone(), + turn_metadata_header, turn_state: Arc::new(OnceLock::new()), } } - - /// Refresh turn metadata in the background and update a cached header that request - /// builders can read without blocking. - fn prewarm_turn_metadata_header(&self, turn_metadata_cwd: Option) { - let turn_metadata_cwd = - turn_metadata_cwd.map(|cwd| std::fs::canonicalize(&cwd).unwrap_or(cwd)); - - if let Ok(mut cache) = self.state.turn_metadata_cache.write() - && cache.cwd != turn_metadata_cwd - { - cache.cwd = turn_metadata_cwd.clone(); - cache.header = None; - } - - let Some(cwd) = turn_metadata_cwd else { - return; - }; - let turn_metadata_cache = Arc::clone(&self.state.turn_metadata_cache); - if let Ok(handle) = tokio::runtime::Handle::try_current() { - let _task = handle.spawn(async move { - let header = build_turn_metadata_header(cwd.as_path()) - .await - .and_then(|value| HeaderValue::from_str(value.as_str()).ok()); - - if let Ok(mut cache) = turn_metadata_cache.write() - && cache.cwd.as_ref() == Some(&cwd) - { - cache.header = header; - } - }); - } - } } impl ModelClient { @@ -300,14 +258,6 @@ impl ModelClient { } impl ModelClientSession { - fn turn_metadata_header(&self) -> Option { - self.state - .turn_metadata_cache - .try_read() - .ok() - .and_then(|cache| cache.header.clone()) - } - /// Streams a single model turn using the configured Responses transport. pub async fn stream(&mut self, prompt: &Prompt) -> Result { let wire_api = self.state.provider.wire_api; @@ -364,7 +314,10 @@ impl ModelClientSession { prompt: &Prompt, compression: Compression, ) -> ApiResponsesOptions { - let turn_metadata_header = self.turn_metadata_header(); + let turn_metadata_header = self + .turn_metadata_header + .as_deref() + .and_then(|value| HeaderValue::from_str(value).ok()); let model_info = &self.state.model_info; let default_reasoning_effort = model_info.default_reasoning_level; diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index b7c040405..7fe3b8c81 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -34,6 +34,7 @@ use crate::stream_events_utils::last_assistant_message_from_item; use crate::terminal; use crate::transport_manager::TransportManager; use crate::truncate::TruncationPolicy; +use crate::turn_metadata::build_turn_metadata_header; use crate::user_notification::UserNotifier; use crate::util::error_or_panic; use async_channel::Receiver; @@ -80,6 +81,7 @@ use rmcp::model::RequestId; use serde_json; use serde_json::Value; use tokio::sync::Mutex; +use tokio::sync::OnceCell; use tokio::sync::RwLock; use tokio::sync::oneshot; use tokio_util::sync::CancellationToken; @@ -90,6 +92,7 @@ use tracing::field; use tracing::info; use tracing::info_span; use tracing::instrument; +use tracing::trace; use tracing::trace_span; use tracing::warn; @@ -501,6 +504,7 @@ pub(crate) struct TurnContext { pub(crate) tool_call_gate: Arc, pub(crate) truncation_policy: TruncationPolicy, pub(crate) dynamic_tools: Vec, + turn_metadata_header: OnceCell>, } impl TurnContext { pub(crate) fn resolve_path(&self, path: Option) -> PathBuf { @@ -514,6 +518,38 @@ impl TurnContext { .as_deref() .unwrap_or(compact::SUMMARIZATION_PROMPT) } + + async fn build_turn_metadata_header(&self) -> Option { + self.turn_metadata_header + .get_or_init(|| async { build_turn_metadata_header(self.cwd.as_path()).await }) + .await + .clone() + } + + pub async fn resolve_turn_metadata_header(&self) -> Option { + const TURN_METADATA_HEADER_TIMEOUT_MS: u64 = 250; + match tokio::time::timeout( + std::time::Duration::from_millis(TURN_METADATA_HEADER_TIMEOUT_MS), + self.build_turn_metadata_header(), + ) + .await + { + Ok(header) => header, + Err(_) => { + warn!("timed out after 250ms while building turn metadata header"); + self.turn_metadata_header.get().cloned().flatten() + } + } + } + + pub fn spawn_turn_metadata_header_task(self: &Arc) { + let context = Arc::clone(self); + tokio::spawn(async move { + trace!("Spawning turn metadata calculation task"); + context.build_turn_metadata_header().await; + trace!("Turn metadata calculation task completed"); + }); + } } #[derive(Clone)] @@ -682,10 +718,11 @@ impl Session { web_search_mode: per_turn_config.web_search_mode, }); + let cwd = session_configuration.cwd.clone(); TurnContext { sub_id, client, - cwd: session_configuration.cwd.clone(), + cwd, developer_instructions: session_configuration.developer_instructions.clone(), compact_prompt: session_configuration.compact_prompt.clone(), user_instructions: session_configuration.user_instructions.clone(), @@ -702,6 +739,7 @@ impl Session { tool_call_gate: Arc::new(ReadinessFlag::new()), truncation_policy: model_info.truncation_policy.into(), dynamic_tools: session_configuration.dynamic_tools.clone(), + turn_metadata_header: OnceCell::new(), } } @@ -1246,10 +1284,13 @@ impl Session { sub_id, self.services.transport_manager.clone(), ); + if let Some(final_schema) = final_output_json_schema { turn_context.final_output_json_schema = final_schema; } - Arc::new(turn_context) + let turn_context = Arc::new(turn_context); + turn_context.spawn_turn_metadata_header_task(); + turn_context } pub(crate) async fn new_default_turn(&self) -> Arc { @@ -3274,6 +3315,7 @@ async fn spawn_review_thread( tool_call_gate: Arc::new(ReadinessFlag::new()), dynamic_tools: parent_turn_context.dynamic_tools.clone(), truncation_policy: model_info.truncation_policy.into(), + turn_metadata_header: parent_turn_context.turn_metadata_header.clone(), }; // Seed the child task with the review prompt as the initial user message. @@ -3478,9 +3520,8 @@ pub(crate) async fn run_turn( // many turns, from the perspective of the user, it is a single turn. let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); - let mut client_session = turn_context - .client - .new_session(Some(turn_context.cwd.clone())); + let turn_metadata_header = turn_context.resolve_turn_metadata_header().await; + let mut client_session = turn_context.client.new_session(turn_metadata_header); loop { // Note that pending_input would be something like a message the user diff --git a/codex-rs/core/src/compact.rs b/codex-rs/core/src/compact.rs index ee94e4994..7ae773884 100644 --- a/codex-rs/core/src/compact.rs +++ b/codex-rs/core/src/compact.rs @@ -337,9 +337,8 @@ async fn drain_to_completed( turn_context: &TurnContext, prompt: &Prompt, ) -> CodexResult<()> { - let mut client_session = turn_context - .client - .new_session(Some(turn_context.cwd.clone())); + let turn_metadata_header = turn_context.resolve_turn_metadata_header().await; + let mut client_session = turn_context.client.new_session(turn_metadata_header); let mut stream = client_session.stream(prompt).await?; loop { let maybe_event = stream.next().await; diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index c5ac0d8f8..9bd4c8725 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -139,6 +139,7 @@ pub use exec_policy::check_execpolicy_for_warnings; pub use exec_policy::load_exec_policy; pub use safety::get_platform_sandbox; pub use tools::spec::parse_tool_input_schema; +pub use turn_metadata::build_turn_metadata_header; // Re-export the protocol types from the standalone `codex-protocol` crate so existing // `codex_core::protocol::...` references continue to work across the workspace. pub use codex_protocol::protocol; diff --git a/codex-rs/core/src/turn_metadata.rs b/codex-rs/core/src/turn_metadata.rs index 58b394a2f..6d4878c7a 100644 --- a/codex-rs/core/src/turn_metadata.rs +++ b/codex-rs/core/src/turn_metadata.rs @@ -20,7 +20,7 @@ struct TurnMetadata { workspaces: BTreeMap, } -pub(crate) async fn build_turn_metadata_header(cwd: &Path) -> Option { +pub async fn build_turn_metadata_header(cwd: &Path) -> Option { let repo_root = get_git_repo_root(cwd)?; let (latest_git_commit_hash, associated_remote_urls) = tokio::join!( diff --git a/codex-rs/core/tests/responses_headers.rs b/codex-rs/core/tests/responses_headers.rs index 86e511c47..fd65bcf32 100644 --- a/codex-rs/core/tests/responses_headers.rs +++ b/codex-rs/core/tests/responses_headers.rs @@ -408,88 +408,14 @@ async fn responses_stream_includes_turn_metadata_header_for_git_workspace_e2e() responses::ev_response_created("resp-1"), responses::ev_completed("resp-1"), ]); - let provider = ModelProviderInfo { - name: "mock".into(), - base_url: Some(format!("{}/v1", server.uri())), - env_key: None, - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(5_000), - requires_openai_auth: false, - supports_websockets: false, - }; - let codex_home = TempDir::new().expect("failed to create TempDir"); - let mut config = load_default_config_for_test(&codex_home).await; - config.model_provider_id = provider.name.clone(); - config.model_provider = provider.clone(); - let effort = config.model_reasoning_effort; - let summary = config.model_reasoning_summary; - let model = ModelsManager::get_model_offline(config.model.as_deref()); - config.model = Some(model.clone()); - let config = Arc::new(config); - - let conversation_id = ThreadId::new(); - let auth_mode = AuthMode::Chatgpt; - let session_source = - SessionSource::SubAgent(SubAgentSource::Other("turn-metadata-e2e".to_string())); - let model_info = ModelsManager::construct_model_info_offline(model.as_str(), &config); - let otel_manager = OtelManager::new( - conversation_id, - model.as_str(), - model_info.slug.as_str(), - None, - Some("test@test.com".to_string()), - Some(auth_mode), - false, - "test".to_string(), - session_source.clone(), - ); - - let client = ModelClient::new( - Arc::clone(&config), - None, - model_info, - otel_manager, - provider, - effort, - summary, - conversation_id, - session_source, - TransportManager::new(), - ); - - let workspace = TempDir::new().expect("workspace tempdir"); - let cwd = workspace.path(); - - let mut prompt = Prompt::default(); - prompt.input = vec![ResponseItem::Message { - id: None, - role: "user".into(), - content: vec![ContentItem::InputText { - text: "hello".into(), - }], - end_turn: None, - phase: None, - }]; + let test = test_codex().build(&server).await.expect("build test codex"); + let cwd = test.cwd_path(); let first_request = responses::mount_sse_once(&server, response_body.clone()).await; - let mut first_session = client.new_session(Some(cwd.to_path_buf())); - let mut first_stream = first_session - .stream(&prompt) + test.submit_turn("hello") .await - .expect("stream first turn"); - while let Some(event) = first_stream.next().await { - if matches!(event, Ok(ResponseEvent::Completed { .. })) { - break; - } - } + .expect("submit first turn prompt"); assert_eq!( first_request .single_request() @@ -539,21 +465,13 @@ async fn responses_stream_includes_turn_metadata_header_for_git_workspace_e2e() .trim() .to_string(); - let repo_root = std::fs::canonicalize(cwd) - .unwrap_or_else(|_| cwd.to_path_buf()) - .to_string_lossy() - .into_owned(); let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(5); loop { let request_recorder = responses::mount_sse_once(&server, response_body.clone()).await; - let mut session = client.new_session(Some(cwd.to_path_buf())); tokio::time::sleep(std::time::Duration::from_millis(50)).await; - let mut stream = session.stream(&prompt).await.expect("stream post-git turn"); - while let Some(event) = stream.next().await { - if matches!(event, Ok(ResponseEvent::Completed { .. })) { - break; - } - } + test.submit_turn("hello") + .await + .expect("submit post-git turn prompt"); let maybe_header = request_recorder .single_request() @@ -561,11 +479,14 @@ async fn responses_stream_includes_turn_metadata_header_for_git_workspace_e2e() if let Some(header_value) = maybe_header { let parsed: serde_json::Value = serde_json::from_str(&header_value) .expect("x-codex-turn-metadata should be valid JSON"); - let workspace = parsed + let workspaces = parsed .get("workspaces") .and_then(serde_json::Value::as_object) - .and_then(|workspaces| workspaces.get(&repo_root)) - .expect("metadata should include cwd repo root workspace entry"); + .expect("metadata should include workspaces"); + let workspace = workspaces + .values() + .next() + .expect("metadata should include at least one workspace entry"); assert_eq!( workspace