diff --git a/codex-rs/core/src/auth.rs b/codex-rs/core/src/auth.rs index 2cb2ce51f..217334b2a 100644 --- a/codex-rs/core/src/auth.rs +++ b/codex-rs/core/src/auth.rs @@ -100,6 +100,7 @@ const REFRESH_TOKEN_REUSED_MESSAGE: &str = "Your access token could not be refre const REFRESH_TOKEN_INVALIDATED_MESSAGE: &str = "Your access token could not be refreshed because your refresh token was revoked. Please log out and sign in again."; const REFRESH_TOKEN_UNKNOWN_MESSAGE: &str = "Your access token could not be refreshed. Please log out and sign in again."; +const REFRESH_TOKEN_ACCOUNT_MISMATCH_MESSAGE: &str = "Your access token could not be refreshed because you have since logged out or signed in to another account. Please sign in again."; const REFRESH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token"; pub const REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR: &str = "CODEX_REFRESH_TOKEN_URL_OVERRIDE"; @@ -584,7 +585,8 @@ fn load_auth( Ok(Some(auth)) } -fn update_tokens( +// Persist refreshed tokens into auth storage and update last_refresh. +fn persist_tokens( storage: &Arc, id_token: Option, access_token: Option, @@ -609,7 +611,9 @@ fn update_tokens( Ok(auth_dot_json) } -async fn try_refresh_token( +// Requests refreshed ChatGPT OAuth tokens from the auth service using a refresh token. +// The caller is responsible for persisting any returned tokens. +async fn request_chatgpt_token_refresh( refresh_token: String, client: &CodexHttpClient, ) -> Result { @@ -823,7 +827,11 @@ enum UnauthorizedRecoveryStep { } enum ReloadOutcome { - Reloaded, + /// Reload was performed and the cached auth changed + ReloadedChanged, + /// Reload was performed and the cached auth remained the same + ReloadedNoChange, + /// Reload was skipped (missing or mismatched account id) Skipped, } @@ -910,17 +918,20 @@ impl UnauthorizedRecovery { .manager .reload_if_account_id_matches(self.expected_account_id.as_deref()) { - ReloadOutcome::Reloaded => { + ReloadOutcome::ReloadedChanged | ReloadOutcome::ReloadedNoChange => { self.step = UnauthorizedRecoveryStep::RefreshToken; } ReloadOutcome::Skipped => { - self.manager.refresh_token().await?; self.step = UnauthorizedRecoveryStep::Done; + return Err(RefreshTokenError::Permanent(RefreshTokenFailedError::new( + RefreshTokenFailedReason::Other, + REFRESH_TOKEN_ACCOUNT_MISMATCH_MESSAGE.to_string(), + ))); } } } UnauthorizedRecoveryStep::RefreshToken => { - self.manager.refresh_token().await?; + self.manager.refresh_token_from_authority().await?; self.step = UnauthorizedRecoveryStep::Done; } UnauthorizedRecoveryStep::ExternalRefresh => { @@ -1060,8 +1071,30 @@ impl AuthManager { } tracing::info!("Reloading auth for account {expected_account_id}"); + let cached_before_reload = self.auth_cached(); + let auth_changed = + !Self::auths_equal_for_refresh(cached_before_reload.as_ref(), new_auth.as_ref()); self.set_cached_auth(new_auth); - ReloadOutcome::Reloaded + if auth_changed { + ReloadOutcome::ReloadedChanged + } else { + ReloadOutcome::ReloadedNoChange + } + } + + fn auths_equal_for_refresh(a: Option<&CodexAuth>, b: Option<&CodexAuth>) -> bool { + match (a, b) { + (None, None) => true, + (Some(a), Some(b)) => match (a.api_auth_mode(), b.api_auth_mode()) { + (ApiAuthMode::ApiKey, ApiAuthMode::ApiKey) => a.api_key() == b.api_key(), + (ApiAuthMode::Chatgpt, ApiAuthMode::Chatgpt) + | (ApiAuthMode::ChatgptAuthTokens, ApiAuthMode::ChatgptAuthTokens) => { + a.get_current_auth_json() == b.get_current_auth_json() + } + _ => false, + }, + _ => false, + } } fn auths_equal(a: Option<&CodexAuth>, b: Option<&CodexAuth>) -> bool { @@ -1144,10 +1177,37 @@ impl AuthManager { UnauthorizedRecovery::new(Arc::clone(self)) } - /// Attempt to refresh the current auth token (if any). On success, reload - /// the auth state from disk so other components observe refreshed token. - /// If the token refresh fails, returns the error to the caller. + /// Attempt to refresh the token by first performing a guarded reload. Auth + /// is reloaded from storage only when the account id matches the currently + /// cached account id. If the persisted token differs from the cached token, we + /// can assume that some other instance already refreshed it. If the persisted + /// token is the same as the cached, then ask the token authority to refresh. pub async fn refresh_token(&self) -> Result<(), RefreshTokenError> { + let auth_before_reload = self.auth_cached(); + let expected_account_id = auth_before_reload + .as_ref() + .and_then(CodexAuth::get_account_id); + + match self.reload_if_account_id_matches(expected_account_id.as_deref()) { + ReloadOutcome::ReloadedChanged => { + tracing::info!("Skipping token refresh because auth changed after guarded reload."); + Ok(()) + } + ReloadOutcome::ReloadedNoChange => self.refresh_token_from_authority().await, + ReloadOutcome::Skipped => { + Err(RefreshTokenError::Permanent(RefreshTokenFailedError::new( + RefreshTokenFailedReason::Other, + REFRESH_TOKEN_ACCOUNT_MISMATCH_MESSAGE.to_string(), + ))) + } + } + } + + /// Attempt to refresh the current auth token from the authority that issued + /// the token. On success, reloads the auth state from disk so other components + /// observe refreshed token. If the token refresh fails, returns the error to + /// the caller. + pub async fn refresh_token_from_authority(&self) -> Result<(), RefreshTokenError> { tracing::info!("Refreshing token"); let auth = match self.auth_cached() { @@ -1165,10 +1225,8 @@ impl AuthManager { "Token data is not available.", )) })?; - self.refresh_tokens(&chatgpt_auth, token_data.refresh_token) + self.refresh_and_persist_chatgpt_token(&chatgpt_auth, token_data.refresh_token) .await?; - // Reload to pick up persisted changes. - self.reload(); Ok(()) } CodexAuth::ApiKey(_) => Ok(()), @@ -1215,9 +1273,8 @@ impl AuthManager { if last_refresh >= Utc::now() - chrono::Duration::days(TOKEN_REFRESH_INTERVAL) { return Ok(false); } - self.refresh_tokens(chatgpt_auth, tokens.refresh_token) + self.refresh_and_persist_chatgpt_token(chatgpt_auth, tokens.refresh_token) .await?; - self.reload(); Ok(true) } @@ -1273,20 +1330,23 @@ impl AuthManager { Ok(()) } - async fn refresh_tokens( + // Refreshes ChatGPT OAuth tokens, persists the updated auth state, and + // reloads the in-memory cache so callers immediately observe new tokens. + async fn refresh_and_persist_chatgpt_token( &self, auth: &ChatgptAuth, refresh_token: String, ) -> Result<(), RefreshTokenError> { - let refresh_response = try_refresh_token(refresh_token, auth.client()).await?; + let refresh_response = request_chatgpt_token_refresh(refresh_token, auth.client()).await?; - update_tokens( + persist_tokens( auth.storage(), refresh_response.id_token, refresh_response.access_token, refresh_response.refresh_token, ) .map_err(RefreshTokenError::from)?; + self.reload(); Ok(()) } @@ -1328,7 +1388,7 @@ mod tests { codex_home.path().to_path_buf(), AuthCredentialsStoreMode::File, ); - let updated = super::update_tokens( + let updated = super::persist_tokens( &storage, None, Some("new-access-token".to_string()), diff --git a/codex-rs/core/tests/suite/auth_refresh.rs b/codex-rs/core/tests/suite/auth_refresh.rs index a6be08f23..f5b13f091 100644 --- a/codex-rs/core/tests/suite/auth_refresh.rs +++ b/codex-rs/core/tests/suite/auth_refresh.rs @@ -17,7 +17,6 @@ use codex_core::token_data::TokenData; use core_test_support::skip_if_no_network; use pretty_assertions::assert_eq; use serde::Serialize; -use serde_json::Value; use serde_json::json; use std::ffi::OsString; use std::sync::Arc; @@ -58,6 +57,69 @@ async fn refresh_token_succeeds_updates_storage() -> Result<()> { }; ctx.write_auth(&initial_auth)?; + ctx.auth_manager + .refresh_token_from_authority() + .await + .context("refresh should succeed")?; + + let refreshed_tokens = TokenData { + access_token: "new-access-token".to_string(), + refresh_token: "new-refresh-token".to_string(), + ..initial_tokens.clone() + }; + let stored = ctx.load_auth()?; + let tokens = stored.tokens.as_ref().context("tokens should exist")?; + assert_eq!(tokens, &refreshed_tokens); + let refreshed_at = stored + .last_refresh + .as_ref() + .context("last_refresh should be recorded")?; + assert!( + *refreshed_at >= initial_last_refresh, + "last_refresh should advance" + ); + + let cached_auth = ctx + .auth_manager + .auth() + .await + .context("auth should be cached")?; + let cached = cached_auth + .get_token_data() + .context("token data should be cached")?; + assert_eq!(cached, refreshed_tokens); + + server.verify().await; + Ok(()) +} + +#[serial_test::serial(auth_refresh)] +#[tokio::test] +async fn refresh_token_refreshes_when_auth_is_unchanged() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/oauth/token")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "access_token": "new-access-token", + "refresh_token": "new-refresh-token" + }))) + .expect(1) + .mount(&server) + .await; + + let ctx = RefreshTokenTestContext::new(&server)?; + let initial_last_refresh = Utc::now() - Duration::days(1); + let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN); + let initial_auth = AuthDotJson { + auth_mode: Some(AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(initial_tokens.clone()), + last_refresh: Some(initial_last_refresh), + }; + ctx.write_auth(&initial_auth)?; + ctx.auth_manager .refresh_token() .await @@ -94,6 +156,128 @@ async fn refresh_token_succeeds_updates_storage() -> Result<()> { Ok(()) } +#[serial_test::serial(auth_refresh)] +#[tokio::test] +async fn refresh_token_skips_refresh_when_auth_changed() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = MockServer::start().await; + let ctx = RefreshTokenTestContext::new(&server)?; + + let initial_last_refresh = Utc::now() - Duration::days(1); + let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN); + let initial_auth = AuthDotJson { + auth_mode: Some(AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(initial_tokens), + last_refresh: Some(initial_last_refresh), + }; + ctx.write_auth(&initial_auth)?; + + let disk_tokens = build_tokens("disk-access-token", "disk-refresh-token"); + let disk_auth = AuthDotJson { + auth_mode: Some(AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(disk_tokens.clone()), + last_refresh: Some(initial_last_refresh), + }; + save_auth( + ctx.codex_home.path(), + &disk_auth, + AuthCredentialsStoreMode::File, + )?; + + ctx.auth_manager + .refresh_token() + .await + .context("refresh should be skipped")?; + + let stored = ctx.load_auth()?; + assert_eq!(stored, disk_auth); + + let cached_auth = ctx + .auth_manager + .auth_cached() + .context("auth should be cached")?; + let cached_tokens = cached_auth + .get_token_data() + .context("token data should be cached")?; + assert_eq!(cached_tokens, disk_tokens); + + let requests = server.received_requests().await.unwrap_or_default(); + assert!(requests.is_empty(), "expected no refresh token requests"); + + Ok(()) +} + +#[serial_test::serial(auth_refresh)] +#[tokio::test] +async fn refresh_token_errors_on_account_mismatch() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/oauth/token")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "access_token": "recovered-access-token", + "refresh_token": "recovered-refresh-token" + }))) + .expect(0) + .mount(&server) + .await; + + let ctx = RefreshTokenTestContext::new(&server)?; + let initial_last_refresh = Utc::now() - Duration::days(1); + let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN); + let initial_auth = AuthDotJson { + auth_mode: Some(AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(initial_tokens.clone()), + last_refresh: Some(initial_last_refresh), + }; + ctx.write_auth(&initial_auth)?; + + let mut disk_tokens = build_tokens("disk-access-token", "disk-refresh-token"); + disk_tokens.account_id = Some("other-account".to_string()); + let disk_auth = AuthDotJson { + auth_mode: Some(AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(disk_tokens), + last_refresh: Some(initial_last_refresh), + }; + save_auth( + ctx.codex_home.path(), + &disk_auth, + AuthCredentialsStoreMode::File, + )?; + + let err = ctx + .auth_manager + .refresh_token() + .await + .err() + .context("refresh should fail due to account mismatch")?; + assert_eq!(err.failed_reason(), Some(RefreshTokenFailedReason::Other)); + + let stored = ctx.load_auth()?; + assert_eq!(stored, disk_auth); + + let requests = server.received_requests().await.unwrap_or_default(); + assert!(requests.is_empty(), "expected no refresh token requests"); + + let cached_after = ctx + .auth_manager + .auth_cached() + .context("auth should be cached after refresh")?; + let cached_after_tokens = cached_after + .get_token_data() + .context("token data should remain cached")?; + assert_eq!(cached_after_tokens, initial_tokens); + + server.verify().await; + Ok(()) +} + #[serial_test::serial(auth_refresh)] #[tokio::test] async fn returns_fresh_tokens_as_is() -> Result<()> { @@ -227,7 +411,7 @@ async fn refresh_token_returns_permanent_error_for_expired_refresh_token() -> Re let err = ctx .auth_manager - .refresh_token() + .refresh_token_from_authority() .await .err() .context("refresh should fail")?; @@ -277,7 +461,7 @@ async fn refresh_token_returns_transient_error_on_server_failure() -> Result<()> let err = ctx .auth_manager - .refresh_token() + .refresh_token_from_authority() .await .err() .context("refresh should fail")?; @@ -394,7 +578,7 @@ async fn unauthorized_recovery_reloads_then_refreshes_tokens() -> Result<()> { #[serial_test::serial(auth_refresh)] #[tokio::test] -async fn unauthorized_recovery_skips_reload_on_account_mismatch() -> Result<()> { +async fn unauthorized_recovery_errors_on_account_mismatch() -> Result<()> { skip_if_no_network!(Ok(())); let server = MockServer::start().await; @@ -404,7 +588,7 @@ async fn unauthorized_recovery_skips_reload_on_account_mismatch() -> Result<()> "access_token": "recovered-access-token", "refresh_token": "recovered-refresh-token" }))) - .expect(1) + .expect(0) .mount(&server) .await; @@ -421,11 +605,6 @@ async fn unauthorized_recovery_skips_reload_on_account_mismatch() -> Result<()> let mut disk_tokens = build_tokens("disk-access-token", "disk-refresh-token"); disk_tokens.account_id = Some("other-account".to_string()); - let expected_tokens = TokenData { - access_token: "recovered-access-token".to_string(), - refresh_token: "recovered-refresh-token".to_string(), - ..disk_tokens.clone() - }; let disk_auth = AuthDotJson { auth_mode: Some(AuthMode::Chatgpt), openai_api_key: None, @@ -450,34 +629,27 @@ async fn unauthorized_recovery_skips_reload_on_account_mismatch() -> Result<()> let mut recovery = ctx.auth_manager.unauthorized_recovery(); assert!(recovery.has_next()); - recovery.next().await?; + let err = recovery + .next() + .await + .err() + .context("recovery should fail due to account mismatch")?; + assert_eq!(err.failed_reason(), Some(RefreshTokenFailedReason::Other)); let stored = ctx.load_auth()?; - let tokens = stored.tokens.as_ref().context("tokens should exist")?; - assert_eq!(tokens, &expected_tokens); + assert_eq!(stored, disk_auth); let requests = server.received_requests().await.unwrap_or_default(); - let request = requests - .first() - .context("expected a refresh token request")?; - let body: Value = - serde_json::from_slice(&request.body).context("refresh request body should be json")?; - let refresh_token = body - .get("refresh_token") - .and_then(Value::as_str) - .context("refresh_token should be set")?; - assert_eq!(refresh_token, INITIAL_REFRESH_TOKEN); + assert!(requests.is_empty(), "expected no refresh token requests"); let cached_after = ctx .auth_manager - .auth() - .await + .auth_cached() .context("auth should remain cached after refresh")?; let cached_after_tokens = cached_after .get_token_data() - .context("token data should reflect refreshed tokens")?; - assert_eq!(cached_after_tokens, expected_tokens); - assert!(!recovery.has_next()); + .context("token data should remain cached")?; + assert_eq!(cached_after_tokens, initial_tokens); server.verify().await; Ok(())