Add a background job to refresh the requirements local cache (#12936)

- Update the cloud requirements cache TTL to 30 minutes.
- Add a background job to refresh the cache every 5 minutes.
- Ensure there is only one refresh job per process.
This commit is contained in:
alexsong-oai 2026-02-26 20:16:19 -08:00 committed by GitHub
parent cee009d117
commit f53612d3b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -28,17 +28,21 @@ use serde::Serialize;
use sha2::Sha256;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::OnceLock;
use std::time::Duration;
use std::time::Instant;
use thiserror::Error;
use tokio::fs;
use tokio::task::JoinHandle;
use tokio::time::sleep;
use tokio::time::timeout;
const CLOUD_REQUIREMENTS_TIMEOUT: Duration = Duration::from_secs(15);
const CLOUD_REQUIREMENTS_MAX_ATTEMPTS: usize = 5;
const CLOUD_REQUIREMENTS_CACHE_FILENAME: &str = "cloud-requirements-cache.json";
const CLOUD_REQUIREMENTS_CACHE_TTL: Duration = Duration::from_secs(60 * 60);
const CLOUD_REQUIREMENTS_CACHE_REFRESH_INTERVAL: Duration = Duration::from_secs(5 * 60);
const CLOUD_REQUIREMENTS_CACHE_TTL: Duration = Duration::from_secs(30 * 60);
const CLOUD_REQUIREMENTS_CACHE_WRITE_HMAC_KEY: &[u8] =
b"codex-cloud-requirements-cache-v3-064f8542-75b4-494c-a294-97d3ce597271";
const CLOUD_REQUIREMENTS_CACHE_READ_HMAC_KEYS: &[&[u8]] =
@ -46,6 +50,11 @@ const CLOUD_REQUIREMENTS_CACHE_READ_HMAC_KEYS: &[&[u8]] =
type HmacSha256 = Hmac<Sha256>;
/// Returns the process-wide slot holding the handle of the currently running
/// cache-refresher task, creating the slot lazily on first access.
///
/// Keeping the handle in a single static `Mutex<Option<JoinHandle<()>>>`
/// is what guarantees at most one refresh job per process: a new task is
/// swapped in and any previous one is aborted by the caller.
fn refresher_task_slot() -> &'static Mutex<Option<JoinHandle<()>>> {
    static REFRESHER_TASK: OnceLock<Mutex<Option<JoinHandle<()>>>> = OnceLock::new();
    // `Mutex::<Option<_>>::default()` is `Mutex::new(None)`.
    REFRESHER_TASK.get_or_init(Mutex::default)
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum FetchCloudRequirementsStatus {
BackendClientInit,
@ -188,6 +197,7 @@ impl RequirementsFetcher for BackendRequirementsFetcher {
}
}
#[derive(Clone)]
struct CloudRequirementsService {
auth_manager: Arc<AuthManager>,
fetcher: Arc<dyn RequirementsFetcher>,
@ -325,6 +335,54 @@ impl CloudRequirementsService {
None
}
/// Runs forever (until the task is aborted), re-fetching the cloud
/// requirements cache every `CLOUD_REQUIREMENTS_CACHE_REFRESH_INTERVAL`.
///
/// The loop exits only when `refresh_cache` reports `false` (auth is
/// missing or ineligible). A timeout on a single refresh attempt is logged
/// and the loop keeps going with the existing cache intact.
async fn refresh_cache_in_background(&self) {
    loop {
        sleep(CLOUD_REQUIREMENTS_CACHE_REFRESH_INTERVAL).await;
        let outcome = timeout(self.timeout, self.refresh_cache()).await;
        if matches!(outcome, Ok(false)) {
            // Auth is no longer eligible; stop refreshing for good.
            return;
        }
        if outcome.is_err() {
            tracing::warn!(
                "Timed out refreshing cloud requirements cache from remote; keeping existing cache"
            );
        }
    }
}
/// Performs one remote refresh of the cloud requirements cache.
///
/// Returns `false` when there is no auth, the auth is not ChatGPT-based, or
/// the account plan is not Business/Enterprise — signalling the background
/// loop to stop. Returns `true` otherwise, even when the remote fetch
/// itself fails: that failure is only logged and the existing on-disk
/// cache is left untouched.
async fn refresh_cache(&self) -> bool {
    let Some(auth) = self.auth_manager.auth().await else {
        return false;
    };
    // Only ChatGPT Business/Enterprise accounts use cloud requirements.
    let eligible_plan = matches!(
        auth.account_plan_type(),
        Some(PlanType::Business | PlanType::Enterprise)
    );
    if !(auth.is_chatgpt_auth() && eligible_plan) {
        return false;
    }
    let token_data = auth.get_token_data().ok();
    let chatgpt_user_id = token_data
        .as_ref()
        .and_then(|data| data.id_token.chatgpt_user_id.as_deref());
    let account_id = auth.get_account_id();
    let fetched = self
        .fetch_with_retries(&auth, chatgpt_user_id, account_id.as_deref())
        .await;
    if fetched.is_none() {
        // Best-effort: keep the stale cache rather than failing hard.
        tracing::warn!(
            path = %self.cache_path.display(),
            "Failed to refresh cloud requirements cache from remote"
        );
    }
    true
}
async fn load_cache(
&self,
chatgpt_user_id: Option<&str>,
@ -452,7 +510,17 @@ pub fn cloud_requirements_loader(
codex_home,
CLOUD_REQUIREMENTS_TIMEOUT,
);
let refresh_service = service.clone();
let task = tokio::spawn(async move { service.fetch_with_timeout().await });
let refresh_task =
tokio::spawn(async move { refresh_service.refresh_cache_in_background().await });
let mut refresher_guard = refresher_task_slot().lock().unwrap_or_else(|err| {
tracing::warn!("cloud requirements refresher task slot was poisoned");
err.into_inner()
});
if let Some(existing_task) = refresher_guard.replace(refresh_task) {
existing_task.abort();
}
CloudRequirementsLoader::new(async move {
task.await
.inspect_err(|err| tracing::warn!(error = %err, "Cloud requirements task failed"))
@ -1052,7 +1120,11 @@ mod tests {
let cache_file: CloudRequirementsCacheFile =
serde_json::from_str(&std::fs::read_to_string(path).expect("read cache"))
.expect("parse cache");
assert!(cache_file.signed_payload.expires_at > Utc::now());
assert!(
cache_file.signed_payload.expires_at
<= cache_file.signed_payload.cached_at + ChronoDuration::minutes(30)
);
assert!(cache_file.signed_payload.expires_at > cache_file.signed_payload.cached_at);
assert!(cache_file.signed_payload.cached_at <= Utc::now());
assert_eq!(
cache_file.signed_payload.chatgpt_user_id,
@ -1130,4 +1202,57 @@ mod tests {
CLOUD_REQUIREMENTS_MAX_ATTEMPTS
);
}
/// Verifies that `refresh_cache` re-fetches requirements from the remote
/// backend and overwrites the on-disk cache with the new payload.
#[tokio::test]
async fn refresh_from_remote_updates_cached_cloud_requirements() {
    let codex_home = tempdir().expect("tempdir");
    // First response seeds the cache via `fetch`; the second is what the
    // subsequent `refresh_cache` call should write to disk.
    let fetcher = Arc::new(SequenceFetcher::new(vec![
        Ok(Some("allowed_approval_policies = [\"never\"]".to_string())),
        Ok(Some(
            "allowed_approval_policies = [\"on-request\"]".to_string(),
        )),
    ]));
    let service = CloudRequirementsService::new(
        auth_manager_with_plan("business"),
        fetcher,
        codex_home.path().to_path_buf(),
        CLOUD_REQUIREMENTS_TIMEOUT,
    );
    // Initial fetch populates the cache from the first sequence entry.
    assert_eq!(
        service.fetch().await,
        Some(ConfigRequirementsToml {
            allowed_approval_policies: Some(vec![AskForApproval::Never]),
            allowed_sandbox_modes: None,
            allowed_web_search_modes: None,
            mcp_servers: None,
            rules: None,
            enforce_residency: None,
            network: None,
        })
    );
    // A Business-plan ChatGPT auth is eligible, so the refresh must report
    // `true` (i.e. the background loop should keep running).
    assert!(service.refresh_cache().await);
    let path = codex_home.path().join(CLOUD_REQUIREMENTS_CACHE_FILENAME);
    let cache_file: CloudRequirementsCacheFile =
        serde_json::from_str(&std::fs::read_to_string(path).expect("read cache"))
            .expect("parse cache");
    // The cache on disk now holds the second (refreshed) payload.
    assert_eq!(
        cache_file
            .signed_payload
            .contents
            .as_deref()
            .and_then(|contents| parse_cloud_requirements(contents).ok().flatten()),
        Some(ConfigRequirementsToml {
            allowed_approval_policies: Some(vec![AskForApproval::OnRequest]),
            allowed_sandbox_modes: None,
            allowed_web_search_modes: None,
            mcp_servers: None,
            rules: None,
            enforce_residency: None,
            network: None,
        })
    );
}
}