From f53612d3b24e5f61ed4bd08a2b6fb4bd70a5b8fd Mon Sep 17 00:00:00 2001 From: alexsong-oai Date: Thu, 26 Feb 2026 20:16:19 -0800 Subject: [PATCH] Add a background job to refresh the requirements local cache (#12936) - Update the cloud requirements cache TTL to 30 minutes. - Add a background job to refresh the cache every 5 minutes. - Ensure there is only one refresh job per process. --- codex-rs/cloud-requirements/src/lib.rs | 129 ++++++++++++++++++++++++- 1 file changed, 127 insertions(+), 2 deletions(-) diff --git a/codex-rs/cloud-requirements/src/lib.rs b/codex-rs/cloud-requirements/src/lib.rs index 6f6bf3b6d..d89ff5352 100644 --- a/codex-rs/cloud-requirements/src/lib.rs +++ b/codex-rs/cloud-requirements/src/lib.rs @@ -28,17 +28,21 @@ use serde::Serialize; use sha2::Sha256; use std::path::PathBuf; use std::sync::Arc; +use std::sync::Mutex; +use std::sync::OnceLock; use std::time::Duration; use std::time::Instant; use thiserror::Error; use tokio::fs; +use tokio::task::JoinHandle; use tokio::time::sleep; use tokio::time::timeout; const CLOUD_REQUIREMENTS_TIMEOUT: Duration = Duration::from_secs(15); const CLOUD_REQUIREMENTS_MAX_ATTEMPTS: usize = 5; const CLOUD_REQUIREMENTS_CACHE_FILENAME: &str = "cloud-requirements-cache.json"; -const CLOUD_REQUIREMENTS_CACHE_TTL: Duration = Duration::from_secs(60 * 60); +const CLOUD_REQUIREMENTS_CACHE_REFRESH_INTERVAL: Duration = Duration::from_secs(5 * 60); +const CLOUD_REQUIREMENTS_CACHE_TTL: Duration = Duration::from_secs(30 * 60); const CLOUD_REQUIREMENTS_CACHE_WRITE_HMAC_KEY: &[u8] = b"codex-cloud-requirements-cache-v3-064f8542-75b4-494c-a294-97d3ce597271"; const CLOUD_REQUIREMENTS_CACHE_READ_HMAC_KEYS: &[&[u8]] = @@ -46,6 +50,11 @@ const CLOUD_REQUIREMENTS_CACHE_READ_HMAC_KEYS: &[&[u8]] = type HmacSha256 = Hmac; +fn refresher_task_slot() -> &'static Mutex>> { + static REFRESHER_TASK: OnceLock>>> = OnceLock::new(); + REFRESHER_TASK.get_or_init(|| Mutex::new(None)) +} + #[derive(Clone, Copy, Debug, Eq, PartialEq)] enum FetchCloudRequirementsStatus { BackendClientInit, @@ -188,6 +197,7 @@ impl RequirementsFetcher for BackendRequirementsFetcher { } } +#[derive(Clone)] struct CloudRequirementsService { auth_manager: Arc, fetcher: Arc, @@ -325,6 +335,54 @@ impl CloudRequirementsService { None } + async fn refresh_cache_in_background(&self) { + loop { + sleep(CLOUD_REQUIREMENTS_CACHE_REFRESH_INTERVAL).await; + match timeout(self.timeout, self.refresh_cache()).await { + Ok(true) => {} + Ok(false) => break, + Err(_) => { + tracing::warn!( + "Timed out refreshing cloud requirements cache from remote; keeping existing cache" + ); + } + } + } + } + + async fn refresh_cache(&self) -> bool { + let Some(auth) = self.auth_manager.auth().await else { + return false; + }; + if !auth.is_chatgpt_auth() + || !matches!( + auth.account_plan_type(), + Some(PlanType::Business | PlanType::Enterprise) + ) + { + return false; + } + + let token_data = auth.get_token_data().ok(); + let chatgpt_user_id = token_data + .as_ref() + .and_then(|token_data| token_data.id_token.chatgpt_user_id.as_deref()); + let account_id = auth.get_account_id(); + let account_id = account_id.as_deref(); + + if self + .fetch_with_retries(&auth, chatgpt_user_id, account_id) + .await + .is_none() + { + tracing::warn!( + path = %self.cache_path.display(), + "Failed to refresh cloud requirements cache from remote" + ); + } + true + } + async fn load_cache( &self, chatgpt_user_id: Option<&str>, @@ -452,7 +510,17 @@ pub fn cloud_requirements_loader( codex_home, CLOUD_REQUIREMENTS_TIMEOUT, ); + let refresh_service = service.clone(); let task = tokio::spawn(async move { service.fetch_with_timeout().await }); + let refresh_task = + tokio::spawn(async move { refresh_service.refresh_cache_in_background().await }); + let mut refresher_guard = refresher_task_slot().lock().unwrap_or_else(|err| { + tracing::warn!("cloud requirements refresher task slot was poisoned"); + err.into_inner() + }); + if let Some(existing_task) = refresher_guard.replace(refresh_task) { + existing_task.abort(); + } CloudRequirementsLoader::new(async move { task.await .inspect_err(|err| tracing::warn!(error = %err, "Cloud requirements task failed")) @@ -1052,7 +1120,11 @@ mod tests { let cache_file: CloudRequirementsCacheFile = serde_json::from_str(&std::fs::read_to_string(path).expect("read cache")) .expect("parse cache"); - assert!(cache_file.signed_payload.expires_at > Utc::now()); + assert!( + cache_file.signed_payload.expires_at + <= cache_file.signed_payload.cached_at + ChronoDuration::minutes(30) + ); + assert!(cache_file.signed_payload.expires_at > cache_file.signed_payload.cached_at); assert!(cache_file.signed_payload.cached_at <= Utc::now()); assert_eq!( cache_file.signed_payload.chatgpt_user_id, @@ -1130,4 +1202,57 @@ mod tests { CLOUD_REQUIREMENTS_MAX_ATTEMPTS ); } + + #[tokio::test] + async fn refresh_from_remote_updates_cached_cloud_requirements() { + let codex_home = tempdir().expect("tempdir"); + let fetcher = Arc::new(SequenceFetcher::new(vec![ + Ok(Some("allowed_approval_policies = [\"never\"]".to_string())), + Ok(Some( + "allowed_approval_policies = [\"on-request\"]".to_string(), + )), + ])); + let service = CloudRequirementsService::new( + auth_manager_with_plan("business"), + fetcher, + codex_home.path().to_path_buf(), + CLOUD_REQUIREMENTS_TIMEOUT, + ); + + assert_eq!( + service.fetch().await, + Some(ConfigRequirementsToml { + allowed_approval_policies: Some(vec![AskForApproval::Never]), + allowed_sandbox_modes: None, + allowed_web_search_modes: None, + mcp_servers: None, + rules: None, + enforce_residency: None, + network: None, + }) + ); + + service.refresh_cache().await; + + let path = codex_home.path().join(CLOUD_REQUIREMENTS_CACHE_FILENAME); + let cache_file: CloudRequirementsCacheFile = + serde_json::from_str(&std::fs::read_to_string(path).expect("read cache")) + .expect("parse cache"); + assert_eq!( + cache_file + .signed_payload + .contents + .as_deref() + .and_then(|contents| parse_cloud_requirements(contents).ok().flatten()), + Some(ConfigRequirementsToml { + allowed_approval_policies: Some(vec![AskForApproval::OnRequest]), + allowed_sandbox_modes: None, + allowed_web_search_modes: None, + mcp_servers: None, + rules: None, + enforce_residency: None, + network: None, + }) + ); + } }