Add a background job to refresh the requirements local cache (#12936)
- Update the cloud requirements cache TTL to 30 minutes. - Add a background job to refresh the cache every 5 minutes. - Ensure there is only one refresh job per process.
This commit is contained in:
parent
cee009d117
commit
f53612d3b2
1 changed files with 127 additions and 2 deletions
|
|
@ -28,17 +28,21 @@ use serde::Serialize;
|
|||
use sha2::Sha256;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use thiserror::Error;
|
||||
use tokio::fs;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::sleep;
|
||||
use tokio::time::timeout;
|
||||
|
||||
const CLOUD_REQUIREMENTS_TIMEOUT: Duration = Duration::from_secs(15);
|
||||
const CLOUD_REQUIREMENTS_MAX_ATTEMPTS: usize = 5;
|
||||
const CLOUD_REQUIREMENTS_CACHE_FILENAME: &str = "cloud-requirements-cache.json";
|
||||
const CLOUD_REQUIREMENTS_CACHE_TTL: Duration = Duration::from_secs(60 * 60);
|
||||
const CLOUD_REQUIREMENTS_CACHE_REFRESH_INTERVAL: Duration = Duration::from_secs(5 * 60);
|
||||
const CLOUD_REQUIREMENTS_CACHE_TTL: Duration = Duration::from_secs(30 * 60);
|
||||
const CLOUD_REQUIREMENTS_CACHE_WRITE_HMAC_KEY: &[u8] =
|
||||
b"codex-cloud-requirements-cache-v3-064f8542-75b4-494c-a294-97d3ce597271";
|
||||
const CLOUD_REQUIREMENTS_CACHE_READ_HMAC_KEYS: &[&[u8]] =
|
||||
|
|
@ -46,6 +50,11 @@ const CLOUD_REQUIREMENTS_CACHE_READ_HMAC_KEYS: &[&[u8]] =
|
|||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
fn refresher_task_slot() -> &'static Mutex<Option<JoinHandle<()>>> {
|
||||
static REFRESHER_TASK: OnceLock<Mutex<Option<JoinHandle<()>>>> = OnceLock::new();
|
||||
REFRESHER_TASK.get_or_init(|| Mutex::new(None))
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
enum FetchCloudRequirementsStatus {
|
||||
BackendClientInit,
|
||||
|
|
@ -188,6 +197,7 @@ impl RequirementsFetcher for BackendRequirementsFetcher {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CloudRequirementsService {
|
||||
auth_manager: Arc<AuthManager>,
|
||||
fetcher: Arc<dyn RequirementsFetcher>,
|
||||
|
|
@ -325,6 +335,54 @@ impl CloudRequirementsService {
|
|||
None
|
||||
}
|
||||
|
||||
async fn refresh_cache_in_background(&self) {
|
||||
loop {
|
||||
sleep(CLOUD_REQUIREMENTS_CACHE_REFRESH_INTERVAL).await;
|
||||
match timeout(self.timeout, self.refresh_cache()).await {
|
||||
Ok(true) => {}
|
||||
Ok(false) => break,
|
||||
Err(_) => {
|
||||
tracing::warn!(
|
||||
"Timed out refreshing cloud requirements cache from remote; keeping existing cache"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn refresh_cache(&self) -> bool {
|
||||
let Some(auth) = self.auth_manager.auth().await else {
|
||||
return false;
|
||||
};
|
||||
if !auth.is_chatgpt_auth()
|
||||
|| !matches!(
|
||||
auth.account_plan_type(),
|
||||
Some(PlanType::Business | PlanType::Enterprise)
|
||||
)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
let token_data = auth.get_token_data().ok();
|
||||
let chatgpt_user_id = token_data
|
||||
.as_ref()
|
||||
.and_then(|token_data| token_data.id_token.chatgpt_user_id.as_deref());
|
||||
let account_id = auth.get_account_id();
|
||||
let account_id = account_id.as_deref();
|
||||
|
||||
if self
|
||||
.fetch_with_retries(&auth, chatgpt_user_id, account_id)
|
||||
.await
|
||||
.is_none()
|
||||
{
|
||||
tracing::warn!(
|
||||
path = %self.cache_path.display(),
|
||||
"Failed to refresh cloud requirements cache from remote"
|
||||
);
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
async fn load_cache(
|
||||
&self,
|
||||
chatgpt_user_id: Option<&str>,
|
||||
|
|
@ -452,7 +510,17 @@ pub fn cloud_requirements_loader(
|
|||
codex_home,
|
||||
CLOUD_REQUIREMENTS_TIMEOUT,
|
||||
);
|
||||
let refresh_service = service.clone();
|
||||
let task = tokio::spawn(async move { service.fetch_with_timeout().await });
|
||||
let refresh_task =
|
||||
tokio::spawn(async move { refresh_service.refresh_cache_in_background().await });
|
||||
let mut refresher_guard = refresher_task_slot().lock().unwrap_or_else(|err| {
|
||||
tracing::warn!("cloud requirements refresher task slot was poisoned");
|
||||
err.into_inner()
|
||||
});
|
||||
if let Some(existing_task) = refresher_guard.replace(refresh_task) {
|
||||
existing_task.abort();
|
||||
}
|
||||
CloudRequirementsLoader::new(async move {
|
||||
task.await
|
||||
.inspect_err(|err| tracing::warn!(error = %err, "Cloud requirements task failed"))
|
||||
|
|
@ -1052,7 +1120,11 @@ mod tests {
|
|||
let cache_file: CloudRequirementsCacheFile =
|
||||
serde_json::from_str(&std::fs::read_to_string(path).expect("read cache"))
|
||||
.expect("parse cache");
|
||||
assert!(cache_file.signed_payload.expires_at > Utc::now());
|
||||
assert!(
|
||||
cache_file.signed_payload.expires_at
|
||||
<= cache_file.signed_payload.cached_at + ChronoDuration::minutes(30)
|
||||
);
|
||||
assert!(cache_file.signed_payload.expires_at > cache_file.signed_payload.cached_at);
|
||||
assert!(cache_file.signed_payload.cached_at <= Utc::now());
|
||||
assert_eq!(
|
||||
cache_file.signed_payload.chatgpt_user_id,
|
||||
|
|
@ -1130,4 +1202,57 @@ mod tests {
|
|||
CLOUD_REQUIREMENTS_MAX_ATTEMPTS
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_from_remote_updates_cached_cloud_requirements() {
|
||||
let codex_home = tempdir().expect("tempdir");
|
||||
let fetcher = Arc::new(SequenceFetcher::new(vec![
|
||||
Ok(Some("allowed_approval_policies = [\"never\"]".to_string())),
|
||||
Ok(Some(
|
||||
"allowed_approval_policies = [\"on-request\"]".to_string(),
|
||||
)),
|
||||
]));
|
||||
let service = CloudRequirementsService::new(
|
||||
auth_manager_with_plan("business"),
|
||||
fetcher,
|
||||
codex_home.path().to_path_buf(),
|
||||
CLOUD_REQUIREMENTS_TIMEOUT,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
service.fetch().await,
|
||||
Some(ConfigRequirementsToml {
|
||||
allowed_approval_policies: Some(vec![AskForApproval::Never]),
|
||||
allowed_sandbox_modes: None,
|
||||
allowed_web_search_modes: None,
|
||||
mcp_servers: None,
|
||||
rules: None,
|
||||
enforce_residency: None,
|
||||
network: None,
|
||||
})
|
||||
);
|
||||
|
||||
service.refresh_cache().await;
|
||||
|
||||
let path = codex_home.path().join(CLOUD_REQUIREMENTS_CACHE_FILENAME);
|
||||
let cache_file: CloudRequirementsCacheFile =
|
||||
serde_json::from_str(&std::fs::read_to_string(path).expect("read cache"))
|
||||
.expect("parse cache");
|
||||
assert_eq!(
|
||||
cache_file
|
||||
.signed_payload
|
||||
.contents
|
||||
.as_deref()
|
||||
.and_then(|contents| parse_cloud_requirements(contents).ok().flatten()),
|
||||
Some(ConfigRequirementsToml {
|
||||
allowed_approval_policies: Some(vec![AskForApproval::OnRequest]),
|
||||
allowed_sandbox_modes: None,
|
||||
allowed_web_search_modes: None,
|
||||
mcp_servers: None,
|
||||
rules: None,
|
||||
enforce_residency: None,
|
||||
network: None,
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue