diff --git a/codex-rs/rmcp-client/src/oauth.rs b/codex-rs/rmcp-client/src/oauth.rs index bd6833fca..f8eafaf23 100644 --- a/codex-rs/rmcp-client/src/oauth.rs +++ b/codex-rs/rmcp-client/src/oauth.rs @@ -50,6 +50,7 @@ use tokio::sync::Mutex; use crate::find_codex_home::find_codex_home; const KEYRING_SERVICE: &str = "Codex MCP Credentials"; +const REFRESH_SKEW_MILLIS: u64 = 30_000; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct StoredOAuthTokens { @@ -57,6 +58,8 @@ pub struct StoredOAuthTokens { pub url: String, pub client_id: String, pub token_response: WrappedOAuthTokenResponse, + #[serde(default)] + pub expires_at: Option, } /// Determine where Codex should store and read MCP credentials. @@ -113,6 +116,22 @@ pub(crate) fn has_oauth_tokens( Ok(load_oauth_tokens(server_name, url, store_mode)?.is_some()) } +fn refresh_expires_in_from_timestamp(tokens: &mut StoredOAuthTokens) { + let Some(expires_at) = tokens.expires_at else { + return; + }; + + match expires_in_from_timestamp(expires_at) { + Some(seconds) => { + let duration = Duration::from_secs(seconds); + tokens.token_response.0.set_expires_in(Some(&duration)); + } + None => { + tokens.token_response.0.set_expires_in(None); + } + } +} + fn load_oauth_tokens_from_keyring_with_fallback_to_file( keyring_store: &K, server_name: &str, @@ -137,8 +156,9 @@ fn load_oauth_tokens_from_keyring( let key = compute_store_key(server_name, url)?; match keyring_store.load(KEYRING_SERVICE, &key) { Ok(Some(serialized)) => { - let tokens: StoredOAuthTokens = serde_json::from_str(&serialized) + let mut tokens: StoredOAuthTokens = serde_json::from_str(&serialized) .context("failed to deserialize OAuth tokens from keyring")?; + refresh_expires_in_from_timestamp(&mut tokens); Ok(Some(tokens)) } Ok(None) => Ok(None), @@ -286,13 +306,24 @@ impl OAuthPersistor { match maybe_credentials { Some(credentials) => { + let mut last_credentials = self.inner.last_credentials.lock().await; + let new_token_response = WrappedOAuthTokenResponse(credentials.clone()); + let same_token = last_credentials + .as_ref() + .map(|prev| prev.token_response == new_token_response) + .unwrap_or(false); + let expires_at = if same_token { + last_credentials.as_ref().and_then(|prev| prev.expires_at) + } else { + compute_expires_at_millis(&credentials) + }; let stored = StoredOAuthTokens { server_name: self.inner.server_name.clone(), url: self.inner.url.clone(), client_id, - token_response: WrappedOAuthTokenResponse(credentials.clone()), + token_response: new_token_response, + expires_at, }; - let mut last_credentials = self.inner.last_credentials.lock().await; if last_credentials.as_ref() != Some(&stored) { save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?; *last_credentials = Some(stored); @@ -317,6 +348,30 @@ impl OAuthPersistor { Ok(()) } + + pub(crate) async fn refresh_if_needed(&self) -> Result<()> { + let expires_at = { + let guard = self.inner.last_credentials.lock().await; + guard.as_ref().and_then(|tokens| tokens.expires_at) + }; + + if !token_needs_refresh(expires_at) { + return Ok(()); + } + + { + let manager = self.inner.authorization_manager.clone(); + let guard = manager.lock().await; + guard.refresh_token().await.with_context(|| { + format!( + "failed to refresh OAuth tokens for server {}", + self.inner.server_name + ) + })?; + } + + self.persist_if_needed().await + } } const FALLBACK_FILENAME: &str = ".credentials.json"; @@ -366,19 +421,14 @@ fn load_oauth_tokens_from_file(server_name: &str, url: &str) -> Result Result<()> { let mut store = read_fallback_file()?.unwrap_or_default(); let token_response = &tokens.token_response.0; + let expires_at = tokens + .expires_at + .or_else(|| compute_expires_at_millis(token_response)); let refresh_token = token_response .refresh_token() .map(|token| token.secret().to_string()); @@ -403,7 +456,7 @@ fn save_oauth_tokens_to_file(tokens: &StoredOAuthTokens) -> Result<()> { server_url: tokens.url.clone(), client_id: tokens.client_id.clone(), access_token: token_response.access_token().secret().to_string(), - expires_at: compute_expires_at_millis(token_response), + expires_at, refresh_token, scopes, }; @@ -427,7 +480,7 @@ fn delete_oauth_tokens_from_file(key: &str) -> Result { Ok(removed) } -fn compute_expires_at_millis(response: &OAuthTokenResponse) -> Option { +pub(crate) fn compute_expires_at_millis(response: &OAuthTokenResponse) -> Option { let expires_in = response.expires_in()?; let now = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -454,6 +507,19 @@ fn expires_in_from_timestamp(expires_at: u64) -> Option { } } +fn token_needs_refresh(expires_at: Option) -> bool { + let Some(expires_at) = expires_at else { + return false; + }; + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_else(|_| Duration::from_secs(0)) + .as_millis() as u64; + + now.saturating_add(REFRESH_SKEW_MILLIS) >= expires_at +} + fn compute_store_key(server_name: &str, server_url: &str) -> Result { let mut payload = JsonMap::new(); payload.insert( @@ -589,8 +655,9 @@ mod tests { store.save(KEYRING_SERVICE, &key, &serialized)?; let loaded = - super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?; - assert_eq!(loaded, Some(expected)); + super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)? + .expect("tokens should load from keyring"); + assert_tokens_match_without_expiry(&loaded, &expected); Ok(()) } @@ -750,6 +817,43 @@ mod tests { Ok(()) } + #[test] + fn refresh_expires_in_from_timestamp_restores_future_durations() { + let mut tokens = sample_tokens(); + let expires_at = tokens.expires_at.expect("expires_at should be set"); + + tokens.token_response.0.set_expires_in(None); + super::refresh_expires_in_from_timestamp(&mut tokens); + + let actual = tokens + .token_response + .0 + .expires_in() + .expect("expires_in should be restored") + .as_secs(); + let expected = super::expires_in_from_timestamp(expires_at) + .expect("expires_at should still be in the future"); + let diff = actual.abs_diff(expected); + assert!(diff <= 1, "expires_in drift too large: diff={diff}"); + } + + #[test] + fn refresh_expires_in_from_timestamp_clears_expired_tokens() { + let mut tokens = sample_tokens(); + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_else(|_| Duration::from_secs(0)); + let expired_at = now.as_millis() as u64; + tokens.expires_at = Some(expired_at.saturating_sub(1000)); + + let duration = Duration::from_secs(600); + tokens.token_response.0.set_expires_in(Some(&duration)); + + super::refresh_expires_in_from_timestamp(&mut tokens); + + assert!(tokens.token_response.0.expires_in().is_none()); + } + fn assert_tokens_match_without_expiry( actual: &StoredOAuthTokens, expected: &StoredOAuthTokens, @@ -757,6 +861,7 @@ mod tests { assert_eq!(actual.server_name, expected.server_name); assert_eq!(actual.url, expected.url); assert_eq!(actual.client_id, expected.client_id); + assert_eq!(actual.expires_at, expected.expires_at); assert_token_response_match_without_expiry( &actual.token_response, &expected.token_response, @@ -803,12 +908,14 @@ mod tests { ])); let expires_in = Duration::from_secs(3600); response.set_expires_in(Some(&expires_in)); + let expires_at = super::compute_expires_at_millis(&response); StoredOAuthTokens { server_name: "test-server".to_string(), url: "https://example.test".to_string(), client_id: "client-id".to_string(), token_response: WrappedOAuthTokenResponse(response), + expires_at, } } } diff --git a/codex-rs/rmcp-client/src/perform_oauth_login.rs b/codex-rs/rmcp-client/src/perform_oauth_login.rs index 425e124d7..d8ffdd394 100644 --- a/codex-rs/rmcp-client/src/perform_oauth_login.rs +++ b/codex-rs/rmcp-client/src/perform_oauth_login.rs @@ -17,6 +17,7 @@ use urlencoding::decode; use crate::OAuthCredentialsStoreMode; use crate::StoredOAuthTokens; use crate::WrappedOAuthTokenResponse; +use crate::oauth::compute_expires_at_millis; use crate::save_oauth_tokens; use crate::utils::apply_default_headers; use crate::utils::build_default_headers; @@ -91,11 +92,13 @@ pub async fn perform_oauth_login( let credentials = credentials_opt.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?; + let expires_at = compute_expires_at_millis(&credentials); let stored = StoredOAuthTokens { server_name: server_name.to_string(), url: server_url.to_string(), client_id, token_response: WrappedOAuthTokenResponse(credentials), + expires_at, }; save_oauth_tokens(server_name, &stored, store_mode)?; diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index f21859dc9..d7d3477b0 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -267,6 +267,7 @@ impl RmcpClient { params: Option, timeout: Option, ) -> Result { + self.refresh_oauth_if_needed().await; let service = self.service().await?; let rmcp_params = params .map(convert_to_rmcp::<_, PaginatedRequestParam>) @@ -284,6 +285,7 @@ impl RmcpClient { params: Option, timeout: Option, ) -> Result { + self.refresh_oauth_if_needed().await; let service = self.service().await?; let rmcp_params = params .map(convert_to_rmcp::<_, PaginatedRequestParam>) @@ -301,6 +303,7 @@ impl RmcpClient { params: Option, timeout: Option, ) -> Result { + self.refresh_oauth_if_needed().await; let service = self.service().await?; let rmcp_params = params .map(convert_to_rmcp::<_, PaginatedRequestParam>) @@ -318,6 +321,7 @@ impl RmcpClient { params: ReadResourceRequestParams, timeout: Option, ) -> Result { + self.refresh_oauth_if_needed().await; let service = self.service().await?; let rmcp_params: ReadResourceRequestParam = convert_to_rmcp(params)?; let fut = service.read_resource(rmcp_params); @@ -333,6 +337,7 @@ impl RmcpClient { arguments: Option, timeout: Option, ) -> Result { + self.refresh_oauth_if_needed().await; let service = self.service().await?; let params = CallToolRequestParams { arguments, name }; let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?; @@ -371,6 +376,14 @@ impl RmcpClient { warn!("failed to persist OAuth tokens: {error}"); } } + + async fn refresh_oauth_if_needed(&self) { + if let Some(runtime) = self.oauth_persistor().await + && let Err(error) = runtime.refresh_if_needed().await + { + warn!("failed to refresh OAuth tokens: {error}"); + } + } } async fn create_oauth_transport_and_runtime(