Move metadata calculation out of client (#10589)

Model client shouldn't be responsible for this.
This commit is contained in:
pakrym-oai 2026-02-03 21:59:13 -08:00 committed by GitHub
parent 38a47700b5
commit 56ebfff1a8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 70 additions and 155 deletions

View file

@ -1,13 +1,10 @@
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::RwLock;
use crate::api_bridge::CoreAuthProvider;
use crate::api_bridge::auth_provider_from_auth;
use crate::api_bridge::map_api_error;
use crate::auth::UnauthorizedRecovery;
use crate::turn_metadata::build_turn_metadata_header;
use codex_api::CompactClient as ApiCompactClient;
use codex_api::CompactionInput as ApiCompactionInput;
use codex_api::Prompt as ApiPrompt;
@ -75,12 +72,6 @@ pub const X_CODEX_TURN_METADATA_HEADER: &str = "x-codex-turn-metadata";
pub const X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER: &str =
"x-responsesapi-include-timing-metrics";
#[derive(Debug, Default)]
struct TurnMetadataCache {
cwd: Option<PathBuf>,
header: Option<HeaderValue>,
}
#[derive(Debug)]
struct ModelClientState {
config: Arc<Config>,
@ -93,7 +84,6 @@ struct ModelClientState {
summary: ReasoningSummaryConfig,
session_source: SessionSource,
transport_manager: TransportManager,
turn_metadata_cache: Arc<RwLock<TurnMetadataCache>>,
}
#[derive(Debug, Clone)]
@ -106,6 +96,7 @@ pub struct ModelClientSession {
connection: Option<ApiWebSocketConnection>,
websocket_last_items: Vec<ResponseItem>,
transport_manager: TransportManager,
turn_metadata_header: Option<String>,
/// Turn state for sticky routing.
///
/// This is an `OnceLock` that stores the turn state value received from the server
@ -145,53 +136,20 @@ impl ModelClient {
summary,
session_source,
transport_manager,
turn_metadata_cache: Arc::new(RwLock::new(TurnMetadataCache::default())),
}),
}
}
pub fn new_session(&self, turn_metadata_cwd: Option<PathBuf>) -> ModelClientSession {
self.prewarm_turn_metadata_header(turn_metadata_cwd);
pub fn new_session(&self, turn_metadata_header: Option<String>) -> ModelClientSession {
ModelClientSession {
state: Arc::clone(&self.state),
connection: None,
websocket_last_items: Vec::new(),
transport_manager: self.state.transport_manager.clone(),
turn_metadata_header,
turn_state: Arc::new(OnceLock::new()),
}
}
/// Refresh turn metadata in the background and update a cached header that request
/// builders can read without blocking.
fn prewarm_turn_metadata_header(&self, turn_metadata_cwd: Option<PathBuf>) {
let turn_metadata_cwd =
turn_metadata_cwd.map(|cwd| std::fs::canonicalize(&cwd).unwrap_or(cwd));
if let Ok(mut cache) = self.state.turn_metadata_cache.write()
&& cache.cwd != turn_metadata_cwd
{
cache.cwd = turn_metadata_cwd.clone();
cache.header = None;
}
let Some(cwd) = turn_metadata_cwd else {
return;
};
let turn_metadata_cache = Arc::clone(&self.state.turn_metadata_cache);
if let Ok(handle) = tokio::runtime::Handle::try_current() {
let _task = handle.spawn(async move {
let header = build_turn_metadata_header(cwd.as_path())
.await
.and_then(|value| HeaderValue::from_str(value.as_str()).ok());
if let Ok(mut cache) = turn_metadata_cache.write()
&& cache.cwd.as_ref() == Some(&cwd)
{
cache.header = header;
}
});
}
}
}
impl ModelClient {
@ -300,14 +258,6 @@ impl ModelClient {
}
impl ModelClientSession {
fn turn_metadata_header(&self) -> Option<HeaderValue> {
self.state
.turn_metadata_cache
.try_read()
.ok()
.and_then(|cache| cache.header.clone())
}
/// Streams a single model turn using the configured Responses transport.
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
let wire_api = self.state.provider.wire_api;
@ -364,7 +314,10 @@ impl ModelClientSession {
prompt: &Prompt,
compression: Compression,
) -> ApiResponsesOptions {
let turn_metadata_header = self.turn_metadata_header();
let turn_metadata_header = self
.turn_metadata_header
.as_deref()
.and_then(|value| HeaderValue::from_str(value).ok());
let model_info = &self.state.model_info;
let default_reasoning_effort = model_info.default_reasoning_level;

View file

@ -34,6 +34,7 @@ use crate::stream_events_utils::last_assistant_message_from_item;
use crate::terminal;
use crate::transport_manager::TransportManager;
use crate::truncate::TruncationPolicy;
use crate::turn_metadata::build_turn_metadata_header;
use crate::user_notification::UserNotifier;
use crate::util::error_or_panic;
use async_channel::Receiver;
@ -80,6 +81,7 @@ use rmcp::model::RequestId;
use serde_json;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::OnceCell;
use tokio::sync::RwLock;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
@ -90,6 +92,7 @@ use tracing::field;
use tracing::info;
use tracing::info_span;
use tracing::instrument;
use tracing::trace;
use tracing::trace_span;
use tracing::warn;
@ -501,6 +504,7 @@ pub(crate) struct TurnContext {
pub(crate) tool_call_gate: Arc<ReadinessFlag>,
pub(crate) truncation_policy: TruncationPolicy,
pub(crate) dynamic_tools: Vec<DynamicToolSpec>,
turn_metadata_header: OnceCell<Option<String>>,
}
impl TurnContext {
pub(crate) fn resolve_path(&self, path: Option<String>) -> PathBuf {
@ -514,6 +518,38 @@ impl TurnContext {
.as_deref()
.unwrap_or(compact::SUMMARIZATION_PROMPT)
}
async fn build_turn_metadata_header(&self) -> Option<String> {
self.turn_metadata_header
.get_or_init(|| async { build_turn_metadata_header(self.cwd.as_path()).await })
.await
.clone()
}
pub async fn resolve_turn_metadata_header(&self) -> Option<String> {
const TURN_METADATA_HEADER_TIMEOUT_MS: u64 = 250;
match tokio::time::timeout(
std::time::Duration::from_millis(TURN_METADATA_HEADER_TIMEOUT_MS),
self.build_turn_metadata_header(),
)
.await
{
Ok(header) => header,
Err(_) => {
warn!("timed out after 250ms while building turn metadata header");
self.turn_metadata_header.get().cloned().flatten()
}
}
}
pub fn spawn_turn_metadata_header_task(self: &Arc<Self>) {
let context = Arc::clone(self);
tokio::spawn(async move {
trace!("Spawning turn metadata calculation task");
context.build_turn_metadata_header().await;
trace!("Turn metadata calculation task completed");
});
}
}
#[derive(Clone)]
@ -682,10 +718,11 @@ impl Session {
web_search_mode: per_turn_config.web_search_mode,
});
let cwd = session_configuration.cwd.clone();
TurnContext {
sub_id,
client,
cwd: session_configuration.cwd.clone(),
cwd,
developer_instructions: session_configuration.developer_instructions.clone(),
compact_prompt: session_configuration.compact_prompt.clone(),
user_instructions: session_configuration.user_instructions.clone(),
@ -702,6 +739,7 @@ impl Session {
tool_call_gate: Arc::new(ReadinessFlag::new()),
truncation_policy: model_info.truncation_policy.into(),
dynamic_tools: session_configuration.dynamic_tools.clone(),
turn_metadata_header: OnceCell::new(),
}
}
@ -1246,10 +1284,13 @@ impl Session {
sub_id,
self.services.transport_manager.clone(),
);
if let Some(final_schema) = final_output_json_schema {
turn_context.final_output_json_schema = final_schema;
}
Arc::new(turn_context)
let turn_context = Arc::new(turn_context);
turn_context.spawn_turn_metadata_header_task();
turn_context
}
pub(crate) async fn new_default_turn(&self) -> Arc<TurnContext> {
@ -3274,6 +3315,7 @@ async fn spawn_review_thread(
tool_call_gate: Arc::new(ReadinessFlag::new()),
dynamic_tools: parent_turn_context.dynamic_tools.clone(),
truncation_policy: model_info.truncation_policy.into(),
turn_metadata_header: parent_turn_context.turn_metadata_header.clone(),
};
// Seed the child task with the review prompt as the initial user message.
@ -3478,9 +3520,8 @@ pub(crate) async fn run_turn(
// many turns, from the perspective of the user, it is a single turn.
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let mut client_session = turn_context
.client
.new_session(Some(turn_context.cwd.clone()));
let turn_metadata_header = turn_context.resolve_turn_metadata_header().await;
let mut client_session = turn_context.client.new_session(turn_metadata_header);
loop {
// Note that pending_input would be something like a message the user

View file

@ -337,9 +337,8 @@ async fn drain_to_completed(
turn_context: &TurnContext,
prompt: &Prompt,
) -> CodexResult<()> {
let mut client_session = turn_context
.client
.new_session(Some(turn_context.cwd.clone()));
let turn_metadata_header = turn_context.resolve_turn_metadata_header().await;
let mut client_session = turn_context.client.new_session(turn_metadata_header);
let mut stream = client_session.stream(prompt).await?;
loop {
let maybe_event = stream.next().await;

View file

@ -139,6 +139,7 @@ pub use exec_policy::check_execpolicy_for_warnings;
pub use exec_policy::load_exec_policy;
pub use safety::get_platform_sandbox;
pub use tools::spec::parse_tool_input_schema;
pub use turn_metadata::build_turn_metadata_header;
// Re-export the protocol types from the standalone `codex-protocol` crate so existing
// `codex_core::protocol::...` references continue to work across the workspace.
pub use codex_protocol::protocol;

View file

@ -20,7 +20,7 @@ struct TurnMetadata {
workspaces: BTreeMap<String, TurnMetadataWorkspace>,
}
pub(crate) async fn build_turn_metadata_header(cwd: &Path) -> Option<String> {
pub async fn build_turn_metadata_header(cwd: &Path) -> Option<String> {
let repo_root = get_git_repo_root(cwd)?;
let (latest_git_commit_hash, associated_remote_urls) = tokio::join!(

View file

@ -408,88 +408,14 @@ async fn responses_stream_includes_turn_metadata_header_for_git_workspace_e2e()
responses::ev_response_created("resp-1"),
responses::ev_completed("resp-1"),
]);
let provider = ModelProviderInfo {
name: "mock".into(),
base_url: Some(format!("{}/v1", server.uri())),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: Some(5_000),
requires_openai_auth: false,
supports_websockets: false,
};
let codex_home = TempDir::new().expect("failed to create TempDir");
let mut config = load_default_config_for_test(&codex_home).await;
config.model_provider_id = provider.name.clone();
config.model_provider = provider.clone();
let effort = config.model_reasoning_effort;
let summary = config.model_reasoning_summary;
let model = ModelsManager::get_model_offline(config.model.as_deref());
config.model = Some(model.clone());
let config = Arc::new(config);
let conversation_id = ThreadId::new();
let auth_mode = AuthMode::Chatgpt;
let session_source =
SessionSource::SubAgent(SubAgentSource::Other("turn-metadata-e2e".to_string()));
let model_info = ModelsManager::construct_model_info_offline(model.as_str(), &config);
let otel_manager = OtelManager::new(
conversation_id,
model.as_str(),
model_info.slug.as_str(),
None,
Some("test@test.com".to_string()),
Some(auth_mode),
false,
"test".to_string(),
session_source.clone(),
);
let client = ModelClient::new(
Arc::clone(&config),
None,
model_info,
otel_manager,
provider,
effort,
summary,
conversation_id,
session_source,
TransportManager::new(),
);
let workspace = TempDir::new().expect("workspace tempdir");
let cwd = workspace.path();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {
id: None,
role: "user".into(),
content: vec![ContentItem::InputText {
text: "hello".into(),
}],
end_turn: None,
phase: None,
}];
let test = test_codex().build(&server).await.expect("build test codex");
let cwd = test.cwd_path();
let first_request = responses::mount_sse_once(&server, response_body.clone()).await;
let mut first_session = client.new_session(Some(cwd.to_path_buf()));
let mut first_stream = first_session
.stream(&prompt)
test.submit_turn("hello")
.await
.expect("stream first turn");
while let Some(event) = first_stream.next().await {
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
break;
}
}
.expect("submit first turn prompt");
assert_eq!(
first_request
.single_request()
@ -539,21 +465,13 @@ async fn responses_stream_includes_turn_metadata_header_for_git_workspace_e2e()
.trim()
.to_string();
let repo_root = std::fs::canonicalize(cwd)
.unwrap_or_else(|_| cwd.to_path_buf())
.to_string_lossy()
.into_owned();
let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(5);
loop {
let request_recorder = responses::mount_sse_once(&server, response_body.clone()).await;
let mut session = client.new_session(Some(cwd.to_path_buf()));
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let mut stream = session.stream(&prompt).await.expect("stream post-git turn");
while let Some(event) = stream.next().await {
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
break;
}
}
test.submit_turn("hello")
.await
.expect("submit post-git turn prompt");
let maybe_header = request_recorder
.single_request()
@ -561,11 +479,14 @@ async fn responses_stream_includes_turn_metadata_header_for_git_workspace_e2e()
if let Some(header_value) = maybe_header {
let parsed: serde_json::Value = serde_json::from_str(&header_value)
.expect("x-codex-turn-metadata should be valid JSON");
let workspace = parsed
let workspaces = parsed
.get("workspaces")
.and_then(serde_json::Value::as_object)
.and_then(|workspaces| workspaces.get(&repo_root))
.expect("metadata should include cwd repo root workspace entry");
.expect("metadata should include workspaces");
let workspace = workspaces
.values()
.next()
.expect("metadata should include at least one workspace entry");
assert_eq!(
workspace