Move metadata calculation out of client (#10589)
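The model client shouldn't be responsible for calculating turn metadata; the turn context now builds the header and passes it to the client session.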
This commit is contained in:
parent
38a47700b5
commit
56ebfff1a8
6 changed files with 70 additions and 155 deletions
|
|
@ -1,13 +1,10 @@
|
|||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use crate::api_bridge::CoreAuthProvider;
|
||||
use crate::api_bridge::auth_provider_from_auth;
|
||||
use crate::api_bridge::map_api_error;
|
||||
use crate::auth::UnauthorizedRecovery;
|
||||
use crate::turn_metadata::build_turn_metadata_header;
|
||||
use codex_api::CompactClient as ApiCompactClient;
|
||||
use codex_api::CompactionInput as ApiCompactionInput;
|
||||
use codex_api::Prompt as ApiPrompt;
|
||||
|
|
@ -75,12 +72,6 @@ pub const X_CODEX_TURN_METADATA_HEADER: &str = "x-codex-turn-metadata";
|
|||
pub const X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER: &str =
|
||||
"x-responsesapi-include-timing-metrics";
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct TurnMetadataCache {
|
||||
cwd: Option<PathBuf>,
|
||||
header: Option<HeaderValue>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ModelClientState {
|
||||
config: Arc<Config>,
|
||||
|
|
@ -93,7 +84,6 @@ struct ModelClientState {
|
|||
summary: ReasoningSummaryConfig,
|
||||
session_source: SessionSource,
|
||||
transport_manager: TransportManager,
|
||||
turn_metadata_cache: Arc<RwLock<TurnMetadataCache>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
|
@ -106,6 +96,7 @@ pub struct ModelClientSession {
|
|||
connection: Option<ApiWebSocketConnection>,
|
||||
websocket_last_items: Vec<ResponseItem>,
|
||||
transport_manager: TransportManager,
|
||||
turn_metadata_header: Option<String>,
|
||||
/// Turn state for sticky routing.
|
||||
///
|
||||
/// This is an `OnceLock` that stores the turn state value received from the server
|
||||
|
|
@ -145,53 +136,20 @@ impl ModelClient {
|
|||
summary,
|
||||
session_source,
|
||||
transport_manager,
|
||||
turn_metadata_cache: Arc::new(RwLock::new(TurnMetadataCache::default())),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_session(&self, turn_metadata_cwd: Option<PathBuf>) -> ModelClientSession {
|
||||
self.prewarm_turn_metadata_header(turn_metadata_cwd);
|
||||
pub fn new_session(&self, turn_metadata_header: Option<String>) -> ModelClientSession {
|
||||
ModelClientSession {
|
||||
state: Arc::clone(&self.state),
|
||||
connection: None,
|
||||
websocket_last_items: Vec::new(),
|
||||
transport_manager: self.state.transport_manager.clone(),
|
||||
turn_metadata_header,
|
||||
turn_state: Arc::new(OnceLock::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Refresh turn metadata in the background and update a cached header that request
|
||||
/// builders can read without blocking.
|
||||
fn prewarm_turn_metadata_header(&self, turn_metadata_cwd: Option<PathBuf>) {
|
||||
let turn_metadata_cwd =
|
||||
turn_metadata_cwd.map(|cwd| std::fs::canonicalize(&cwd).unwrap_or(cwd));
|
||||
|
||||
if let Ok(mut cache) = self.state.turn_metadata_cache.write()
|
||||
&& cache.cwd != turn_metadata_cwd
|
||||
{
|
||||
cache.cwd = turn_metadata_cwd.clone();
|
||||
cache.header = None;
|
||||
}
|
||||
|
||||
let Some(cwd) = turn_metadata_cwd else {
|
||||
return;
|
||||
};
|
||||
let turn_metadata_cache = Arc::clone(&self.state.turn_metadata_cache);
|
||||
if let Ok(handle) = tokio::runtime::Handle::try_current() {
|
||||
let _task = handle.spawn(async move {
|
||||
let header = build_turn_metadata_header(cwd.as_path())
|
||||
.await
|
||||
.and_then(|value| HeaderValue::from_str(value.as_str()).ok());
|
||||
|
||||
if let Ok(mut cache) = turn_metadata_cache.write()
|
||||
&& cache.cwd.as_ref() == Some(&cwd)
|
||||
{
|
||||
cache.header = header;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelClient {
|
||||
|
|
@ -300,14 +258,6 @@ impl ModelClient {
|
|||
}
|
||||
|
||||
impl ModelClientSession {
|
||||
/// Returns a copy of the cached turn metadata header without blocking.
///
/// Yields `None` when the read lock cannot be taken immediately or when
/// no header is cached.
fn turn_metadata_header(&self) -> Option<HeaderValue> {
    let guard = self.state.turn_metadata_cache.try_read().ok()?;
    guard.header.clone()
}
|
||||
|
||||
/// Streams a single model turn using the configured Responses transport.
|
||||
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
let wire_api = self.state.provider.wire_api;
|
||||
|
|
@ -364,7 +314,10 @@ impl ModelClientSession {
|
|||
prompt: &Prompt,
|
||||
compression: Compression,
|
||||
) -> ApiResponsesOptions {
|
||||
let turn_metadata_header = self.turn_metadata_header();
|
||||
let turn_metadata_header = self
|
||||
.turn_metadata_header
|
||||
.as_deref()
|
||||
.and_then(|value| HeaderValue::from_str(value).ok());
|
||||
let model_info = &self.state.model_info;
|
||||
|
||||
let default_reasoning_effort = model_info.default_reasoning_level;
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ use crate::stream_events_utils::last_assistant_message_from_item;
|
|||
use crate::terminal;
|
||||
use crate::transport_manager::TransportManager;
|
||||
use crate::truncate::TruncationPolicy;
|
||||
use crate::turn_metadata::build_turn_metadata_header;
|
||||
use crate::user_notification::UserNotifier;
|
||||
use crate::util::error_or_panic;
|
||||
use async_channel::Receiver;
|
||||
|
|
@ -80,6 +81,7 @@ use rmcp::model::RequestId;
|
|||
use serde_json;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::OnceCell;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
|
@ -90,6 +92,7 @@ use tracing::field;
|
|||
use tracing::info;
|
||||
use tracing::info_span;
|
||||
use tracing::instrument;
|
||||
use tracing::trace;
|
||||
use tracing::trace_span;
|
||||
use tracing::warn;
|
||||
|
||||
|
|
@ -501,6 +504,7 @@ pub(crate) struct TurnContext {
|
|||
pub(crate) tool_call_gate: Arc<ReadinessFlag>,
|
||||
pub(crate) truncation_policy: TruncationPolicy,
|
||||
pub(crate) dynamic_tools: Vec<DynamicToolSpec>,
|
||||
turn_metadata_header: OnceCell<Option<String>>,
|
||||
}
|
||||
impl TurnContext {
|
||||
pub(crate) fn resolve_path(&self, path: Option<String>) -> PathBuf {
|
||||
|
|
@ -514,6 +518,38 @@ impl TurnContext {
|
|||
.as_deref()
|
||||
.unwrap_or(compact::SUMMARIZATION_PROMPT)
|
||||
}
|
||||
|
||||
async fn build_turn_metadata_header(&self) -> Option<String> {
|
||||
self.turn_metadata_header
|
||||
.get_or_init(|| async { build_turn_metadata_header(self.cwd.as_path()).await })
|
||||
.await
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub async fn resolve_turn_metadata_header(&self) -> Option<String> {
|
||||
const TURN_METADATA_HEADER_TIMEOUT_MS: u64 = 250;
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(TURN_METADATA_HEADER_TIMEOUT_MS),
|
||||
self.build_turn_metadata_header(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(header) => header,
|
||||
Err(_) => {
|
||||
warn!("timed out after 250ms while building turn metadata header");
|
||||
self.turn_metadata_header.get().cloned().flatten()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn spawn_turn_metadata_header_task(self: &Arc<Self>) {
|
||||
let context = Arc::clone(self);
|
||||
tokio::spawn(async move {
|
||||
trace!("Spawning turn metadata calculation task");
|
||||
context.build_turn_metadata_header().await;
|
||||
trace!("Turn metadata calculation task completed");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
|
@ -682,10 +718,11 @@ impl Session {
|
|||
web_search_mode: per_turn_config.web_search_mode,
|
||||
});
|
||||
|
||||
let cwd = session_configuration.cwd.clone();
|
||||
TurnContext {
|
||||
sub_id,
|
||||
client,
|
||||
cwd: session_configuration.cwd.clone(),
|
||||
cwd,
|
||||
developer_instructions: session_configuration.developer_instructions.clone(),
|
||||
compact_prompt: session_configuration.compact_prompt.clone(),
|
||||
user_instructions: session_configuration.user_instructions.clone(),
|
||||
|
|
@ -702,6 +739,7 @@ impl Session {
|
|||
tool_call_gate: Arc::new(ReadinessFlag::new()),
|
||||
truncation_policy: model_info.truncation_policy.into(),
|
||||
dynamic_tools: session_configuration.dynamic_tools.clone(),
|
||||
turn_metadata_header: OnceCell::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1246,10 +1284,13 @@ impl Session {
|
|||
sub_id,
|
||||
self.services.transport_manager.clone(),
|
||||
);
|
||||
|
||||
if let Some(final_schema) = final_output_json_schema {
|
||||
turn_context.final_output_json_schema = final_schema;
|
||||
}
|
||||
Arc::new(turn_context)
|
||||
let turn_context = Arc::new(turn_context);
|
||||
turn_context.spawn_turn_metadata_header_task();
|
||||
turn_context
|
||||
}
|
||||
|
||||
pub(crate) async fn new_default_turn(&self) -> Arc<TurnContext> {
|
||||
|
|
@ -3274,6 +3315,7 @@ async fn spawn_review_thread(
|
|||
tool_call_gate: Arc::new(ReadinessFlag::new()),
|
||||
dynamic_tools: parent_turn_context.dynamic_tools.clone(),
|
||||
truncation_policy: model_info.truncation_policy.into(),
|
||||
turn_metadata_header: parent_turn_context.turn_metadata_header.clone(),
|
||||
};
|
||||
|
||||
// Seed the child task with the review prompt as the initial user message.
|
||||
|
|
@ -3478,9 +3520,8 @@ pub(crate) async fn run_turn(
|
|||
// many turns, from the perspective of the user, it is a single turn.
|
||||
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
|
||||
let mut client_session = turn_context
|
||||
.client
|
||||
.new_session(Some(turn_context.cwd.clone()));
|
||||
let turn_metadata_header = turn_context.resolve_turn_metadata_header().await;
|
||||
let mut client_session = turn_context.client.new_session(turn_metadata_header);
|
||||
|
||||
loop {
|
||||
// Note that pending_input would be something like a message the user
|
||||
|
|
|
|||
|
|
@ -337,9 +337,8 @@ async fn drain_to_completed(
|
|||
turn_context: &TurnContext,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<()> {
|
||||
let mut client_session = turn_context
|
||||
.client
|
||||
.new_session(Some(turn_context.cwd.clone()));
|
||||
let turn_metadata_header = turn_context.resolve_turn_metadata_header().await;
|
||||
let mut client_session = turn_context.client.new_session(turn_metadata_header);
|
||||
let mut stream = client_session.stream(prompt).await?;
|
||||
loop {
|
||||
let maybe_event = stream.next().await;
|
||||
|
|
|
|||
|
|
@ -139,6 +139,7 @@ pub use exec_policy::check_execpolicy_for_warnings;
|
|||
pub use exec_policy::load_exec_policy;
|
||||
pub use safety::get_platform_sandbox;
|
||||
pub use tools::spec::parse_tool_input_schema;
|
||||
pub use turn_metadata::build_turn_metadata_header;
|
||||
// Re-export the protocol types from the standalone `codex-protocol` crate so existing
|
||||
// `codex_core::protocol::...` references continue to work across the workspace.
|
||||
pub use codex_protocol::protocol;
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ struct TurnMetadata {
|
|||
workspaces: BTreeMap<String, TurnMetadataWorkspace>,
|
||||
}
|
||||
|
||||
pub(crate) async fn build_turn_metadata_header(cwd: &Path) -> Option<String> {
|
||||
pub async fn build_turn_metadata_header(cwd: &Path) -> Option<String> {
|
||||
let repo_root = get_git_repo_root(cwd)?;
|
||||
|
||||
let (latest_git_commit_hash, associated_remote_urls) = tokio::join!(
|
||||
|
|
|
|||
|
|
@ -408,88 +408,14 @@ async fn responses_stream_includes_turn_metadata_header_for_git_workspace_e2e()
|
|||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]);
|
||||
let provider = ModelProviderInfo {
|
||||
name: "mock".into(),
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(5_000),
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().expect("failed to create TempDir");
|
||||
let mut config = load_default_config_for_test(&codex_home).await;
|
||||
config.model_provider_id = provider.name.clone();
|
||||
config.model_provider = provider.clone();
|
||||
let effort = config.model_reasoning_effort;
|
||||
let summary = config.model_reasoning_summary;
|
||||
let model = ModelsManager::get_model_offline(config.model.as_deref());
|
||||
config.model = Some(model.clone());
|
||||
let config = Arc::new(config);
|
||||
|
||||
let conversation_id = ThreadId::new();
|
||||
let auth_mode = AuthMode::Chatgpt;
|
||||
let session_source =
|
||||
SessionSource::SubAgent(SubAgentSource::Other("turn-metadata-e2e".to_string()));
|
||||
let model_info = ModelsManager::construct_model_info_offline(model.as_str(), &config);
|
||||
let otel_manager = OtelManager::new(
|
||||
conversation_id,
|
||||
model.as_str(),
|
||||
model_info.slug.as_str(),
|
||||
None,
|
||||
Some("test@test.com".to_string()),
|
||||
Some(auth_mode),
|
||||
false,
|
||||
"test".to_string(),
|
||||
session_source.clone(),
|
||||
);
|
||||
|
||||
let client = ModelClient::new(
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_info,
|
||||
otel_manager,
|
||||
provider,
|
||||
effort,
|
||||
summary,
|
||||
conversation_id,
|
||||
session_source,
|
||||
TransportManager::new(),
|
||||
);
|
||||
|
||||
let workspace = TempDir::new().expect("workspace tempdir");
|
||||
let cwd = workspace.path();
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
end_turn: None,
|
||||
phase: None,
|
||||
}];
|
||||
let test = test_codex().build(&server).await.expect("build test codex");
|
||||
let cwd = test.cwd_path();
|
||||
|
||||
let first_request = responses::mount_sse_once(&server, response_body.clone()).await;
|
||||
let mut first_session = client.new_session(Some(cwd.to_path_buf()));
|
||||
let mut first_stream = first_session
|
||||
.stream(&prompt)
|
||||
test.submit_turn("hello")
|
||||
.await
|
||||
.expect("stream first turn");
|
||||
while let Some(event) = first_stream.next().await {
|
||||
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
.expect("submit first turn prompt");
|
||||
assert_eq!(
|
||||
first_request
|
||||
.single_request()
|
||||
|
|
@ -539,21 +465,13 @@ async fn responses_stream_includes_turn_metadata_header_for_git_workspace_e2e()
|
|||
.trim()
|
||||
.to_string();
|
||||
|
||||
let repo_root = std::fs::canonicalize(cwd)
|
||||
.unwrap_or_else(|_| cwd.to_path_buf())
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(5);
|
||||
loop {
|
||||
let request_recorder = responses::mount_sse_once(&server, response_body.clone()).await;
|
||||
let mut session = client.new_session(Some(cwd.to_path_buf()));
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
let mut stream = session.stream(&prompt).await.expect("stream post-git turn");
|
||||
while let Some(event) = stream.next().await {
|
||||
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
test.submit_turn("hello")
|
||||
.await
|
||||
.expect("submit post-git turn prompt");
|
||||
|
||||
let maybe_header = request_recorder
|
||||
.single_request()
|
||||
|
|
@ -561,11 +479,14 @@ async fn responses_stream_includes_turn_metadata_header_for_git_workspace_e2e()
|
|||
if let Some(header_value) = maybe_header {
|
||||
let parsed: serde_json::Value = serde_json::from_str(&header_value)
|
||||
.expect("x-codex-turn-metadata should be valid JSON");
|
||||
let workspace = parsed
|
||||
let workspaces = parsed
|
||||
.get("workspaces")
|
||||
.and_then(serde_json::Value::as_object)
|
||||
.and_then(|workspaces| workspaces.get(&repo_root))
|
||||
.expect("metadata should include cwd repo root workspace entry");
|
||||
.expect("metadata should include workspaces");
|
||||
let workspace = workspaces
|
||||
.values()
|
||||
.next()
|
||||
.expect("metadata should include at least one workspace entry");
|
||||
|
||||
assert_eq!(
|
||||
workspace
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue