🐛 fix(rmcp-client): refresh OAuth tokens using expires_at (#6574)
## Summary - persist OAuth credential expiry timestamps and rehydrate `expires_in` - proactively refresh rmcp OAuth tokens when `expires_at` is near, then persist ## Testing - just fmt - just fix -p codex-rmcp-client - cargo test -p codex-rmcp-client Fixes #6572
This commit is contained in:
parent
28ebe1c97a
commit
f3d4e210d8
3 changed files with 138 additions and 15 deletions
|
|
@ -50,6 +50,7 @@ use tokio::sync::Mutex;
|
|||
use crate::find_codex_home::find_codex_home;
|
||||
|
||||
const KEYRING_SERVICE: &str = "Codex MCP Credentials";
|
||||
const REFRESH_SKEW_MILLIS: u64 = 30_000;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct StoredOAuthTokens {
|
||||
|
|
@ -57,6 +58,8 @@ pub struct StoredOAuthTokens {
|
|||
pub url: String,
|
||||
pub client_id: String,
|
||||
pub token_response: WrappedOAuthTokenResponse,
|
||||
#[serde(default)]
|
||||
pub expires_at: Option<u64>,
|
||||
}
|
||||
|
||||
/// Determine where Codex should store and read MCP credentials.
|
||||
|
|
@ -113,6 +116,22 @@ pub(crate) fn has_oauth_tokens(
|
|||
Ok(load_oauth_tokens(server_name, url, store_mode)?.is_some())
|
||||
}
|
||||
|
||||
fn refresh_expires_in_from_timestamp(tokens: &mut StoredOAuthTokens) {
|
||||
let Some(expires_at) = tokens.expires_at else {
|
||||
return;
|
||||
};
|
||||
|
||||
match expires_in_from_timestamp(expires_at) {
|
||||
Some(seconds) => {
|
||||
let duration = Duration::from_secs(seconds);
|
||||
tokens.token_response.0.set_expires_in(Some(&duration));
|
||||
}
|
||||
None => {
|
||||
tokens.token_response.0.set_expires_in(None);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_from_keyring_with_fallback_to_file<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
|
|
@ -137,8 +156,9 @@ fn load_oauth_tokens_from_keyring<K: KeyringStore>(
|
|||
let key = compute_store_key(server_name, url)?;
|
||||
match keyring_store.load(KEYRING_SERVICE, &key) {
|
||||
Ok(Some(serialized)) => {
|
||||
let tokens: StoredOAuthTokens = serde_json::from_str(&serialized)
|
||||
let mut tokens: StoredOAuthTokens = serde_json::from_str(&serialized)
|
||||
.context("failed to deserialize OAuth tokens from keyring")?;
|
||||
refresh_expires_in_from_timestamp(&mut tokens);
|
||||
Ok(Some(tokens))
|
||||
}
|
||||
Ok(None) => Ok(None),
|
||||
|
|
@ -286,13 +306,24 @@ impl OAuthPersistor {
|
|||
|
||||
match maybe_credentials {
|
||||
Some(credentials) => {
|
||||
let mut last_credentials = self.inner.last_credentials.lock().await;
|
||||
let new_token_response = WrappedOAuthTokenResponse(credentials.clone());
|
||||
let same_token = last_credentials
|
||||
.as_ref()
|
||||
.map(|prev| prev.token_response == new_token_response)
|
||||
.unwrap_or(false);
|
||||
let expires_at = if same_token {
|
||||
last_credentials.as_ref().and_then(|prev| prev.expires_at)
|
||||
} else {
|
||||
compute_expires_at_millis(&credentials)
|
||||
};
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: self.inner.server_name.clone(),
|
||||
url: self.inner.url.clone(),
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials.clone()),
|
||||
token_response: new_token_response,
|
||||
expires_at,
|
||||
};
|
||||
let mut last_credentials = self.inner.last_credentials.lock().await;
|
||||
if last_credentials.as_ref() != Some(&stored) {
|
||||
save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
|
||||
*last_credentials = Some(stored);
|
||||
|
|
@ -317,6 +348,30 @@ impl OAuthPersistor {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn refresh_if_needed(&self) -> Result<()> {
|
||||
let expires_at = {
|
||||
let guard = self.inner.last_credentials.lock().await;
|
||||
guard.as_ref().and_then(|tokens| tokens.expires_at)
|
||||
};
|
||||
|
||||
if !token_needs_refresh(expires_at) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
{
|
||||
let manager = self.inner.authorization_manager.clone();
|
||||
let guard = manager.lock().await;
|
||||
guard.refresh_token().await.with_context(|| {
|
||||
format!(
|
||||
"failed to refresh OAuth tokens for server {}",
|
||||
self.inner.server_name
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
self.persist_if_needed().await
|
||||
}
|
||||
}
|
||||
|
||||
const FALLBACK_FILENAME: &str = ".credentials.json";
|
||||
|
|
@ -366,19 +421,14 @@ fn load_oauth_tokens_from_file(server_name: &str, url: &str) -> Result<Option<St
|
|||
token_response.set_scopes(Some(scopes.into_iter().map(Scope::new).collect()));
|
||||
}
|
||||
|
||||
if let Some(expires_at) = entry.expires_at
|
||||
&& let Some(seconds) = expires_in_from_timestamp(expires_at)
|
||||
{
|
||||
let duration = Duration::from_secs(seconds);
|
||||
token_response.set_expires_in(Some(&duration));
|
||||
}
|
||||
|
||||
let stored = StoredOAuthTokens {
|
||||
let mut stored = StoredOAuthTokens {
|
||||
server_name: entry.server_name.clone(),
|
||||
url: entry.server_url.clone(),
|
||||
client_id: entry.client_id.clone(),
|
||||
token_response: WrappedOAuthTokenResponse(token_response),
|
||||
expires_at: entry.expires_at,
|
||||
};
|
||||
refresh_expires_in_from_timestamp(&mut stored);
|
||||
|
||||
return Ok(Some(stored));
|
||||
}
|
||||
|
|
@ -391,6 +441,9 @@ fn save_oauth_tokens_to_file(tokens: &StoredOAuthTokens) -> Result<()> {
|
|||
let mut store = read_fallback_file()?.unwrap_or_default();
|
||||
|
||||
let token_response = &tokens.token_response.0;
|
||||
let expires_at = tokens
|
||||
.expires_at
|
||||
.or_else(|| compute_expires_at_millis(token_response));
|
||||
let refresh_token = token_response
|
||||
.refresh_token()
|
||||
.map(|token| token.secret().to_string());
|
||||
|
|
@ -403,7 +456,7 @@ fn save_oauth_tokens_to_file(tokens: &StoredOAuthTokens) -> Result<()> {
|
|||
server_url: tokens.url.clone(),
|
||||
client_id: tokens.client_id.clone(),
|
||||
access_token: token_response.access_token().secret().to_string(),
|
||||
expires_at: compute_expires_at_millis(token_response),
|
||||
expires_at,
|
||||
refresh_token,
|
||||
scopes,
|
||||
};
|
||||
|
|
@ -427,7 +480,7 @@ fn delete_oauth_tokens_from_file(key: &str) -> Result<bool> {
|
|||
Ok(removed)
|
||||
}
|
||||
|
||||
fn compute_expires_at_millis(response: &OAuthTokenResponse) -> Option<u64> {
|
||||
pub(crate) fn compute_expires_at_millis(response: &OAuthTokenResponse) -> Option<u64> {
|
||||
let expires_in = response.expires_in()?;
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
|
|
@ -454,6 +507,19 @@ fn expires_in_from_timestamp(expires_at: u64) -> Option<u64> {
|
|||
}
|
||||
}
|
||||
|
||||
fn token_needs_refresh(expires_at: Option<u64>) -> bool {
|
||||
let Some(expires_at) = expires_at else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| Duration::from_secs(0))
|
||||
.as_millis() as u64;
|
||||
|
||||
now.saturating_add(REFRESH_SKEW_MILLIS) >= expires_at
|
||||
}
|
||||
|
||||
fn compute_store_key(server_name: &str, server_url: &str) -> Result<String> {
|
||||
let mut payload = JsonMap::new();
|
||||
payload.insert(
|
||||
|
|
@ -589,8 +655,9 @@ mod tests {
|
|||
store.save(KEYRING_SERVICE, &key, &serialized)?;
|
||||
|
||||
let loaded =
|
||||
super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?;
|
||||
assert_eq!(loaded, Some(expected));
|
||||
super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?
|
||||
.expect("tokens should load from keyring");
|
||||
assert_tokens_match_without_expiry(&loaded, &expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -750,6 +817,43 @@ mod tests {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn refresh_expires_in_from_timestamp_restores_future_durations() {
|
||||
let mut tokens = sample_tokens();
|
||||
let expires_at = tokens.expires_at.expect("expires_at should be set");
|
||||
|
||||
tokens.token_response.0.set_expires_in(None);
|
||||
super::refresh_expires_in_from_timestamp(&mut tokens);
|
||||
|
||||
let actual = tokens
|
||||
.token_response
|
||||
.0
|
||||
.expires_in()
|
||||
.expect("expires_in should be restored")
|
||||
.as_secs();
|
||||
let expected = super::expires_in_from_timestamp(expires_at)
|
||||
.expect("expires_at should still be in the future");
|
||||
let diff = actual.abs_diff(expected);
|
||||
assert!(diff <= 1, "expires_in drift too large: diff={diff}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn refresh_expires_in_from_timestamp_clears_expired_tokens() {
|
||||
let mut tokens = sample_tokens();
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| Duration::from_secs(0));
|
||||
let expired_at = now.as_millis() as u64;
|
||||
tokens.expires_at = Some(expired_at.saturating_sub(1000));
|
||||
|
||||
let duration = Duration::from_secs(600);
|
||||
tokens.token_response.0.set_expires_in(Some(&duration));
|
||||
|
||||
super::refresh_expires_in_from_timestamp(&mut tokens);
|
||||
|
||||
assert!(tokens.token_response.0.expires_in().is_none());
|
||||
}
|
||||
|
||||
fn assert_tokens_match_without_expiry(
|
||||
actual: &StoredOAuthTokens,
|
||||
expected: &StoredOAuthTokens,
|
||||
|
|
@ -757,6 +861,7 @@ mod tests {
|
|||
assert_eq!(actual.server_name, expected.server_name);
|
||||
assert_eq!(actual.url, expected.url);
|
||||
assert_eq!(actual.client_id, expected.client_id);
|
||||
assert_eq!(actual.expires_at, expected.expires_at);
|
||||
assert_token_response_match_without_expiry(
|
||||
&actual.token_response,
|
||||
&expected.token_response,
|
||||
|
|
@ -803,12 +908,14 @@ mod tests {
|
|||
]));
|
||||
let expires_in = Duration::from_secs(3600);
|
||||
response.set_expires_in(Some(&expires_in));
|
||||
let expires_at = super::compute_expires_at_millis(&response);
|
||||
|
||||
StoredOAuthTokens {
|
||||
server_name: "test-server".to_string(),
|
||||
url: "https://example.test".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
token_response: WrappedOAuthTokenResponse(response),
|
||||
expires_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ use urlencoding::decode;
|
|||
use crate::OAuthCredentialsStoreMode;
|
||||
use crate::StoredOAuthTokens;
|
||||
use crate::WrappedOAuthTokenResponse;
|
||||
use crate::oauth::compute_expires_at_millis;
|
||||
use crate::save_oauth_tokens;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
|
|
@ -91,11 +92,13 @@ pub async fn perform_oauth_login(
|
|||
let credentials =
|
||||
credentials_opt.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
|
||||
|
||||
let expires_at = compute_expires_at_millis(&credentials);
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: server_name.to_string(),
|
||||
url: server_url.to_string(),
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials),
|
||||
expires_at,
|
||||
};
|
||||
save_oauth_tokens(server_name, &stored, store_mode)?;
|
||||
|
||||
|
|
|
|||
|
|
@ -267,6 +267,7 @@ impl RmcpClient {
|
|||
params: Option<ListToolsRequestParams>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<ListToolsResult> {
|
||||
self.refresh_oauth_if_needed().await;
|
||||
let service = self.service().await?;
|
||||
let rmcp_params = params
|
||||
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
|
||||
|
|
@ -284,6 +285,7 @@ impl RmcpClient {
|
|||
params: Option<ListResourcesRequestParams>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<ListResourcesResult> {
|
||||
self.refresh_oauth_if_needed().await;
|
||||
let service = self.service().await?;
|
||||
let rmcp_params = params
|
||||
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
|
||||
|
|
@ -301,6 +303,7 @@ impl RmcpClient {
|
|||
params: Option<ListResourceTemplatesRequestParams>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<ListResourceTemplatesResult> {
|
||||
self.refresh_oauth_if_needed().await;
|
||||
let service = self.service().await?;
|
||||
let rmcp_params = params
|
||||
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
|
||||
|
|
@ -318,6 +321,7 @@ impl RmcpClient {
|
|||
params: ReadResourceRequestParams,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<ReadResourceResult> {
|
||||
self.refresh_oauth_if_needed().await;
|
||||
let service = self.service().await?;
|
||||
let rmcp_params: ReadResourceRequestParam = convert_to_rmcp(params)?;
|
||||
let fut = service.read_resource(rmcp_params);
|
||||
|
|
@ -333,6 +337,7 @@ impl RmcpClient {
|
|||
arguments: Option<serde_json::Value>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<CallToolResult> {
|
||||
self.refresh_oauth_if_needed().await;
|
||||
let service = self.service().await?;
|
||||
let params = CallToolRequestParams { arguments, name };
|
||||
let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?;
|
||||
|
|
@ -371,6 +376,14 @@ impl RmcpClient {
|
|||
warn!("failed to persist OAuth tokens: {error}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn refresh_oauth_if_needed(&self) {
|
||||
if let Some(runtime) = self.oauth_persistor().await
|
||||
&& let Err(error) = runtime.refresh_if_needed().await
|
||||
{
|
||||
warn!("failed to refresh OAuth tokens: {error}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_oauth_transport_and_runtime(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue