🐛 fix(rmcp-client): refresh OAuth tokens using expires_at (#6574)

## Summary
- persist OAuth credential expiry timestamps and rehydrate `expires_in`
- proactively refresh rmcp OAuth tokens when `expires_at` is near, then
persist

## Testing
- just fmt
- just fix -p codex-rmcp-client
- cargo test -p codex-rmcp-client

Fixes #6572
This commit is contained in:
Lael 2025-11-18 15:16:58 +08:00 committed by GitHub
parent 28ebe1c97a
commit f3d4e210d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 138 additions and 15 deletions

View file

@ -50,6 +50,7 @@ use tokio::sync::Mutex;
use crate::find_codex_home::find_codex_home;
const KEYRING_SERVICE: &str = "Codex MCP Credentials";
const REFRESH_SKEW_MILLIS: u64 = 30_000;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct StoredOAuthTokens {
@ -57,6 +58,8 @@ pub struct StoredOAuthTokens {
pub url: String,
pub client_id: String,
pub token_response: WrappedOAuthTokenResponse,
#[serde(default)]
pub expires_at: Option<u64>,
}
/// Determine where Codex should store and read MCP credentials.
@ -113,6 +116,22 @@ pub(crate) fn has_oauth_tokens(
Ok(load_oauth_tokens(server_name, url, store_mode)?.is_some())
}
fn refresh_expires_in_from_timestamp(tokens: &mut StoredOAuthTokens) {
let Some(expires_at) = tokens.expires_at else {
return;
};
match expires_in_from_timestamp(expires_at) {
Some(seconds) => {
let duration = Duration::from_secs(seconds);
tokens.token_response.0.set_expires_in(Some(&duration));
}
None => {
tokens.token_response.0.set_expires_in(None);
}
}
}
fn load_oauth_tokens_from_keyring_with_fallback_to_file<K: KeyringStore>(
keyring_store: &K,
server_name: &str,
@ -137,8 +156,9 @@ fn load_oauth_tokens_from_keyring<K: KeyringStore>(
let key = compute_store_key(server_name, url)?;
match keyring_store.load(KEYRING_SERVICE, &key) {
Ok(Some(serialized)) => {
let tokens: StoredOAuthTokens = serde_json::from_str(&serialized)
let mut tokens: StoredOAuthTokens = serde_json::from_str(&serialized)
.context("failed to deserialize OAuth tokens from keyring")?;
refresh_expires_in_from_timestamp(&mut tokens);
Ok(Some(tokens))
}
Ok(None) => Ok(None),
@ -286,13 +306,24 @@ impl OAuthPersistor {
match maybe_credentials {
Some(credentials) => {
let mut last_credentials = self.inner.last_credentials.lock().await;
let new_token_response = WrappedOAuthTokenResponse(credentials.clone());
let same_token = last_credentials
.as_ref()
.map(|prev| prev.token_response == new_token_response)
.unwrap_or(false);
let expires_at = if same_token {
last_credentials.as_ref().and_then(|prev| prev.expires_at)
} else {
compute_expires_at_millis(&credentials)
};
let stored = StoredOAuthTokens {
server_name: self.inner.server_name.clone(),
url: self.inner.url.clone(),
client_id,
token_response: WrappedOAuthTokenResponse(credentials.clone()),
token_response: new_token_response,
expires_at,
};
let mut last_credentials = self.inner.last_credentials.lock().await;
if last_credentials.as_ref() != Some(&stored) {
save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
*last_credentials = Some(stored);
@ -317,6 +348,30 @@ impl OAuthPersistor {
Ok(())
}
pub(crate) async fn refresh_if_needed(&self) -> Result<()> {
let expires_at = {
let guard = self.inner.last_credentials.lock().await;
guard.as_ref().and_then(|tokens| tokens.expires_at)
};
if !token_needs_refresh(expires_at) {
return Ok(());
}
{
let manager = self.inner.authorization_manager.clone();
let guard = manager.lock().await;
guard.refresh_token().await.with_context(|| {
format!(
"failed to refresh OAuth tokens for server {}",
self.inner.server_name
)
})?;
}
self.persist_if_needed().await
}
}
const FALLBACK_FILENAME: &str = ".credentials.json";
@ -366,19 +421,14 @@ fn load_oauth_tokens_from_file(server_name: &str, url: &str) -> Result<Option<St
token_response.set_scopes(Some(scopes.into_iter().map(Scope::new).collect()));
}
if let Some(expires_at) = entry.expires_at
&& let Some(seconds) = expires_in_from_timestamp(expires_at)
{
let duration = Duration::from_secs(seconds);
token_response.set_expires_in(Some(&duration));
}
let stored = StoredOAuthTokens {
let mut stored = StoredOAuthTokens {
server_name: entry.server_name.clone(),
url: entry.server_url.clone(),
client_id: entry.client_id.clone(),
token_response: WrappedOAuthTokenResponse(token_response),
expires_at: entry.expires_at,
};
refresh_expires_in_from_timestamp(&mut stored);
return Ok(Some(stored));
}
@ -391,6 +441,9 @@ fn save_oauth_tokens_to_file(tokens: &StoredOAuthTokens) -> Result<()> {
let mut store = read_fallback_file()?.unwrap_or_default();
let token_response = &tokens.token_response.0;
let expires_at = tokens
.expires_at
.or_else(|| compute_expires_at_millis(token_response));
let refresh_token = token_response
.refresh_token()
.map(|token| token.secret().to_string());
@ -403,7 +456,7 @@ fn save_oauth_tokens_to_file(tokens: &StoredOAuthTokens) -> Result<()> {
server_url: tokens.url.clone(),
client_id: tokens.client_id.clone(),
access_token: token_response.access_token().secret().to_string(),
expires_at: compute_expires_at_millis(token_response),
expires_at,
refresh_token,
scopes,
};
@ -427,7 +480,7 @@ fn delete_oauth_tokens_from_file(key: &str) -> Result<bool> {
Ok(removed)
}
fn compute_expires_at_millis(response: &OAuthTokenResponse) -> Option<u64> {
pub(crate) fn compute_expires_at_millis(response: &OAuthTokenResponse) -> Option<u64> {
let expires_in = response.expires_in()?;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
@ -454,6 +507,19 @@ fn expires_in_from_timestamp(expires_at: u64) -> Option<u64> {
}
}
fn token_needs_refresh(expires_at: Option<u64>) -> bool {
let Some(expires_at) = expires_at else {
return false;
};
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(0))
.as_millis() as u64;
now.saturating_add(REFRESH_SKEW_MILLIS) >= expires_at
}
fn compute_store_key(server_name: &str, server_url: &str) -> Result<String> {
let mut payload = JsonMap::new();
payload.insert(
@ -589,8 +655,9 @@ mod tests {
store.save(KEYRING_SERVICE, &key, &serialized)?;
let loaded =
super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?;
assert_eq!(loaded, Some(expected));
super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?
.expect("tokens should load from keyring");
assert_tokens_match_without_expiry(&loaded, &expected);
Ok(())
}
@ -750,6 +817,43 @@ mod tests {
Ok(())
}
#[test]
fn refresh_expires_in_from_timestamp_restores_future_durations() {
let mut tokens = sample_tokens();
let expires_at = tokens.expires_at.expect("expires_at should be set");
tokens.token_response.0.set_expires_in(None);
super::refresh_expires_in_from_timestamp(&mut tokens);
let actual = tokens
.token_response
.0
.expires_in()
.expect("expires_in should be restored")
.as_secs();
let expected = super::expires_in_from_timestamp(expires_at)
.expect("expires_at should still be in the future");
let diff = actual.abs_diff(expected);
assert!(diff <= 1, "expires_in drift too large: diff={diff}");
}
#[test]
fn refresh_expires_in_from_timestamp_clears_expired_tokens() {
let mut tokens = sample_tokens();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(0));
let expired_at = now.as_millis() as u64;
tokens.expires_at = Some(expired_at.saturating_sub(1000));
let duration = Duration::from_secs(600);
tokens.token_response.0.set_expires_in(Some(&duration));
super::refresh_expires_in_from_timestamp(&mut tokens);
assert!(tokens.token_response.0.expires_in().is_none());
}
fn assert_tokens_match_without_expiry(
actual: &StoredOAuthTokens,
expected: &StoredOAuthTokens,
@ -757,6 +861,7 @@ mod tests {
assert_eq!(actual.server_name, expected.server_name);
assert_eq!(actual.url, expected.url);
assert_eq!(actual.client_id, expected.client_id);
assert_eq!(actual.expires_at, expected.expires_at);
assert_token_response_match_without_expiry(
&actual.token_response,
&expected.token_response,
@ -803,12 +908,14 @@ mod tests {
]));
let expires_in = Duration::from_secs(3600);
response.set_expires_in(Some(&expires_in));
let expires_at = super::compute_expires_at_millis(&response);
StoredOAuthTokens {
server_name: "test-server".to_string(),
url: "https://example.test".to_string(),
client_id: "client-id".to_string(),
token_response: WrappedOAuthTokenResponse(response),
expires_at,
}
}
}

View file

@ -17,6 +17,7 @@ use urlencoding::decode;
use crate::OAuthCredentialsStoreMode;
use crate::StoredOAuthTokens;
use crate::WrappedOAuthTokenResponse;
use crate::oauth::compute_expires_at_millis;
use crate::save_oauth_tokens;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
@ -91,11 +92,13 @@ pub async fn perform_oauth_login(
let credentials =
credentials_opt.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
let expires_at = compute_expires_at_millis(&credentials);
let stored = StoredOAuthTokens {
server_name: server_name.to_string(),
url: server_url.to_string(),
client_id,
token_response: WrappedOAuthTokenResponse(credentials),
expires_at,
};
save_oauth_tokens(server_name, &stored, store_mode)?;

View file

@ -267,6 +267,7 @@ impl RmcpClient {
params: Option<ListToolsRequestParams>,
timeout: Option<Duration>,
) -> Result<ListToolsResult> {
self.refresh_oauth_if_needed().await;
let service = self.service().await?;
let rmcp_params = params
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
@ -284,6 +285,7 @@ impl RmcpClient {
params: Option<ListResourcesRequestParams>,
timeout: Option<Duration>,
) -> Result<ListResourcesResult> {
self.refresh_oauth_if_needed().await;
let service = self.service().await?;
let rmcp_params = params
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
@ -301,6 +303,7 @@ impl RmcpClient {
params: Option<ListResourceTemplatesRequestParams>,
timeout: Option<Duration>,
) -> Result<ListResourceTemplatesResult> {
self.refresh_oauth_if_needed().await;
let service = self.service().await?;
let rmcp_params = params
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
@ -318,6 +321,7 @@ impl RmcpClient {
params: ReadResourceRequestParams,
timeout: Option<Duration>,
) -> Result<ReadResourceResult> {
self.refresh_oauth_if_needed().await;
let service = self.service().await?;
let rmcp_params: ReadResourceRequestParam = convert_to_rmcp(params)?;
let fut = service.read_resource(rmcp_params);
@ -333,6 +337,7 @@ impl RmcpClient {
arguments: Option<serde_json::Value>,
timeout: Option<Duration>,
) -> Result<CallToolResult> {
self.refresh_oauth_if_needed().await;
let service = self.service().await?;
let params = CallToolRequestParams { arguments, name };
let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?;
@ -371,6 +376,14 @@ impl RmcpClient {
warn!("failed to persist OAuth tokens: {error}");
}
}
async fn refresh_oauth_if_needed(&self) {
if let Some(runtime) = self.oauth_persistor().await
&& let Err(error) = runtime.refresh_if_needed().await
{
warn!("failed to refresh OAuth tokens: {error}");
}
}
}
async fn create_oauth_transport_and_runtime(