feat: support mcp in-session login (#7751)
### Summary * Added `mcpServer/oauthLogin` in app server for supporting in session MCP server login * Added `McpServerOauthLoginParams` and `McpServerOauthLoginResponse` to support above method with response returning the auth URL for consumer to open browser or display accordingly. * Added `McpServerOauthLoginCompletedNotification` which the app server would emit on MCP server login success or failure (i.e. timeout). * Refactored rmcp-client oath_login to have the ability on starting a auth server which the codex_message_processor uses for in-session auth.
This commit is contained in:
parent
fa4cac1e6b
commit
893f5261eb
8 changed files with 392 additions and 59 deletions
1
codex-rs/Cargo.lock
generated
1
codex-rs/Cargo.lock
generated
|
|
@ -887,6 +887,7 @@ dependencies = [
|
|||
"codex-file-search",
|
||||
"codex-login",
|
||||
"codex-protocol",
|
||||
"codex-rmcp-client",
|
||||
"codex-utils-json-to-toml",
|
||||
"core_test_support",
|
||||
"mcp-types",
|
||||
|
|
|
|||
|
|
@ -139,6 +139,11 @@ client_request_definitions! {
|
|||
response: v2::ModelListResponse,
|
||||
},
|
||||
|
||||
McpServerOauthLogin => "mcpServer/oauth/login" {
|
||||
params: v2::McpServerOauthLoginParams,
|
||||
response: v2::McpServerOauthLoginResponse,
|
||||
},
|
||||
|
||||
McpServersList => "mcpServers/list" {
|
||||
params: v2::ListMcpServersParams,
|
||||
response: v2::ListMcpServersResponse,
|
||||
|
|
@ -524,6 +529,7 @@ server_notification_definitions! {
|
|||
CommandExecutionOutputDelta => "item/commandExecution/outputDelta" (v2::CommandExecutionOutputDeltaNotification),
|
||||
FileChangeOutputDelta => "item/fileChange/outputDelta" (v2::FileChangeOutputDeltaNotification),
|
||||
McpToolCallProgress => "item/mcpToolCall/progress" (v2::McpToolCallProgressNotification),
|
||||
McpServerOauthLoginCompleted => "mcpServer/oauthLogin/completed" (v2::McpServerOauthLoginCompletedNotification),
|
||||
AccountUpdated => "account/updated" (v2::AccountUpdatedNotification),
|
||||
AccountRateLimitsUpdated => "account/rateLimits/updated" (v2::AccountRateLimitsUpdatedNotification),
|
||||
ReasoningSummaryTextDelta => "item/reasoning/summaryTextDelta" (v2::ReasoningSummaryTextDeltaNotification),
|
||||
|
|
|
|||
|
|
@ -688,6 +688,26 @@ pub struct ListMcpServersResponse {
|
|||
pub next_cursor: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct McpServerOauthLoginParams {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[ts(optional)]
|
||||
pub scopes: Option<Vec<String>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[ts(optional)]
|
||||
pub timeout_secs: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct McpServerOauthLoginResponse {
|
||||
pub authorization_url: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
|
|
@ -1467,6 +1487,17 @@ pub struct McpToolCallProgressNotification {
|
|||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct McpServerOauthLoginCompletedNotification {
|
||||
pub name: String,
|
||||
pub success: bool,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[ts(optional)]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ codex-login = { workspace = true }
|
|||
codex-protocol = { workspace = true }
|
||||
codex-app-server-protocol = { workspace = true }
|
||||
codex-feedback = { workspace = true }
|
||||
codex-rmcp-client = { workspace = true }
|
||||
codex-utils-json-to-toml = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
|
|
|
|||
|
|
@ -55,6 +55,9 @@ use codex_app_server_protocol::LoginChatGptResponse;
|
|||
use codex_app_server_protocol::LogoutAccountResponse;
|
||||
use codex_app_server_protocol::LogoutChatGptResponse;
|
||||
use codex_app_server_protocol::McpServer;
|
||||
use codex_app_server_protocol::McpServerOauthLoginCompletedNotification;
|
||||
use codex_app_server_protocol::McpServerOauthLoginParams;
|
||||
use codex_app_server_protocol::McpServerOauthLoginResponse;
|
||||
use codex_app_server_protocol::ModelListParams;
|
||||
use codex_app_server_protocol::ModelListResponse;
|
||||
use codex_app_server_protocol::NewConversationParams;
|
||||
|
|
@ -115,6 +118,7 @@ use codex_core::config::Config;
|
|||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::config::ConfigToml;
|
||||
use codex_core::config::edit::ConfigEditsBuilder;
|
||||
use codex_core::config::types::McpServerTransportConfig;
|
||||
use codex_core::config_loader::load_config_as_toml;
|
||||
use codex_core::default_client::get_codex_user_agent;
|
||||
use codex_core::exec::ExecParams;
|
||||
|
|
@ -147,6 +151,7 @@ use codex_protocol::protocol::RolloutItem;
|
|||
use codex_protocol::protocol::SessionMetaLine;
|
||||
use codex_protocol::protocol::USER_MESSAGE_BEGIN;
|
||||
use codex_protocol::user_input::UserInput as CoreInputItem;
|
||||
use codex_rmcp_client::perform_oauth_login_return_url;
|
||||
use codex_utils_json_to_toml::json_to_toml;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
|
|
@ -161,6 +166,7 @@ use std::time::Duration;
|
|||
use tokio::select;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::oneshot;
|
||||
use toml::Value as TomlValue;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
|
@ -198,6 +204,7 @@ pub(crate) struct CodexMessageProcessor {
|
|||
outgoing: Arc<OutgoingMessageSender>,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
config: Arc<Config>,
|
||||
cli_overrides: Vec<(String, TomlValue)>,
|
||||
conversation_listeners: HashMap<Uuid, oneshot::Sender<()>>,
|
||||
active_login: Arc<Mutex<Option<ActiveLogin>>>,
|
||||
// Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives.
|
||||
|
|
@ -244,6 +251,7 @@ impl CodexMessageProcessor {
|
|||
outgoing: Arc<OutgoingMessageSender>,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
config: Arc<Config>,
|
||||
cli_overrides: Vec<(String, TomlValue)>,
|
||||
feedback: CodexFeedback,
|
||||
) -> Self {
|
||||
Self {
|
||||
|
|
@ -252,6 +260,7 @@ impl CodexMessageProcessor {
|
|||
outgoing,
|
||||
codex_linux_sandbox_exe,
|
||||
config,
|
||||
cli_overrides,
|
||||
conversation_listeners: HashMap::new(),
|
||||
active_login: Arc::new(Mutex::new(None)),
|
||||
pending_interrupts: Arc::new(Mutex::new(HashMap::new())),
|
||||
|
|
@ -261,6 +270,16 @@ impl CodexMessageProcessor {
|
|||
}
|
||||
}
|
||||
|
||||
async fn load_latest_config(&self) -> Result<Config, JSONRPCErrorError> {
|
||||
Config::load_with_cli_overrides(self.cli_overrides.clone(), ConfigOverrides::default())
|
||||
.await
|
||||
.map_err(|err| JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("failed to reload config: {err}"),
|
||||
data: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn review_request_from_target(
|
||||
target: ApiReviewTarget,
|
||||
) -> Result<(ReviewRequest, String), JSONRPCErrorError> {
|
||||
|
|
@ -369,6 +388,9 @@ impl CodexMessageProcessor {
|
|||
ClientRequest::ModelList { request_id, params } => {
|
||||
self.list_models(request_id, params).await;
|
||||
}
|
||||
ClientRequest::McpServerOauthLogin { request_id, params } => {
|
||||
self.mcp_server_oauth_login(request_id, params).await;
|
||||
}
|
||||
ClientRequest::McpServersList { request_id, params } => {
|
||||
self.list_mcp_servers(request_id, params).await;
|
||||
}
|
||||
|
|
@ -1916,6 +1938,110 @@ impl CodexMessageProcessor {
|
|||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
|
||||
async fn mcp_server_oauth_login(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
params: McpServerOauthLoginParams,
|
||||
) {
|
||||
let config = match self.load_latest_config().await {
|
||||
Ok(config) => config,
|
||||
Err(error) => {
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if !config.features.enabled(Feature::RmcpClient) {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: "OAuth login is only supported when [features].rmcp_client is true in config.toml".to_string(),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
|
||||
let McpServerOauthLoginParams {
|
||||
name,
|
||||
scopes,
|
||||
timeout_secs,
|
||||
} = params;
|
||||
|
||||
let Some(server) = config.mcp_servers.get(&name) else {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!("No MCP server named '{name}' found."),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
};
|
||||
|
||||
let (url, http_headers, env_http_headers) = match &server.transport {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
..
|
||||
} => (url.clone(), http_headers.clone(), env_http_headers.clone()),
|
||||
_ => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: "OAuth login is only supported for streamable HTTP servers."
|
||||
.to_string(),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match perform_oauth_login_return_url(
|
||||
&name,
|
||||
&url,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
scopes.as_deref().unwrap_or_default(),
|
||||
timeout_secs,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(handle) => {
|
||||
let authorization_url = handle.authorization_url().to_string();
|
||||
let notification_name = name.clone();
|
||||
let outgoing = Arc::clone(&self.outgoing);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let (success, error) = match handle.wait().await {
|
||||
Ok(()) => (true, None),
|
||||
Err(err) => (false, Some(err.to_string())),
|
||||
};
|
||||
|
||||
let notification = ServerNotification::McpServerOauthLoginCompleted(
|
||||
McpServerOauthLoginCompletedNotification {
|
||||
name: notification_name,
|
||||
success,
|
||||
error,
|
||||
},
|
||||
);
|
||||
outgoing.send_server_notification(notification).await;
|
||||
});
|
||||
|
||||
let response = McpServerOauthLoginResponse { authorization_url };
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("failed to login to MCP server '{name}': {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_mcp_servers(&self, request_id: RequestId, params: ListMcpServersParams) {
|
||||
let snapshot = collect_mcp_snapshot(self.config.as_ref()).await;
|
||||
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ impl MessageProcessor {
|
|||
outgoing.clone(),
|
||||
codex_linux_sandbox_exe,
|
||||
Arc::clone(&config),
|
||||
cli_overrides.clone(),
|
||||
feedback,
|
||||
);
|
||||
let config_api = ConfigApi::new(config.codex_home.clone(), cli_overrides);
|
||||
|
|
|
|||
|
|
@ -16,7 +16,9 @@ pub use oauth::WrappedOAuthTokenResponse;
|
|||
pub use oauth::delete_oauth_tokens;
|
||||
pub(crate) use oauth::load_oauth_tokens;
|
||||
pub use oauth::save_oauth_tokens;
|
||||
pub use perform_oauth_login::OauthLoginHandle;
|
||||
pub use perform_oauth_login::perform_oauth_login;
|
||||
pub use perform_oauth_login::perform_oauth_login_return_url;
|
||||
pub use rmcp::model::ElicitationAction;
|
||||
pub use rmcp_client::Elicitation;
|
||||
pub use rmcp_client::ElicitationResponse;
|
||||
|
|
|
|||
|
|
@ -22,6 +22,11 @@ use crate::save_oauth_tokens;
|
|||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
|
||||
struct OauthHeaders {
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
struct CallbackServerGuard {
|
||||
server: Arc<Server>,
|
||||
}
|
||||
|
|
@ -40,70 +45,52 @@ pub async fn perform_oauth_login(
|
|||
env_http_headers: Option<HashMap<String, String>>,
|
||||
scopes: &[String],
|
||||
) -> Result<()> {
|
||||
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
|
||||
let guard = CallbackServerGuard {
|
||||
server: Arc::clone(&server),
|
||||
let headers = OauthHeaders {
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
};
|
||||
OauthLoginFlow::new(
|
||||
server_name,
|
||||
server_url,
|
||||
store_mode,
|
||||
headers,
|
||||
scopes,
|
||||
true,
|
||||
None,
|
||||
)
|
||||
.await?
|
||||
.finish()
|
||||
.await
|
||||
}
|
||||
|
||||
let redirect_uri = match server.server_addr() {
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
|
||||
format!("http://{}:{}/callback", addr.ip(), addr.port())
|
||||
}
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
|
||||
format!("http://[{}]:{}/callback", addr.ip(), addr.port())
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
_ => return Err(anyhow!("unable to determine callback address")),
|
||||
pub async fn perform_oauth_login_return_url(
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
scopes: &[String],
|
||||
timeout_secs: Option<i64>,
|
||||
) -> Result<OauthLoginHandle> {
|
||||
let headers = OauthHeaders {
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
};
|
||||
let flow = OauthLoginFlow::new(
|
||||
server_name,
|
||||
server_url,
|
||||
store_mode,
|
||||
headers,
|
||||
scopes,
|
||||
false,
|
||||
timeout_secs,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
spawn_callback_server(server, tx);
|
||||
let authorization_url = flow.authorization_url();
|
||||
let completion = flow.spawn();
|
||||
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
|
||||
|
||||
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
|
||||
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
|
||||
oauth_state
|
||||
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
|
||||
.await?;
|
||||
let auth_url = oauth_state.get_authorization_url().await?;
|
||||
|
||||
println!("Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n");
|
||||
|
||||
if webbrowser::open(&auth_url).is_err() {
|
||||
println!("(Browser launch failed; please copy the URL above manually.)");
|
||||
}
|
||||
|
||||
let (code, csrf_state) = timeout(Duration::from_secs(300), rx)
|
||||
.await
|
||||
.context("timed out waiting for OAuth callback")?
|
||||
.context("OAuth callback was cancelled")?;
|
||||
|
||||
oauth_state
|
||||
.handle_callback(&code, &csrf_state)
|
||||
.await
|
||||
.context("failed to handle OAuth callback")?;
|
||||
|
||||
let (client_id, credentials_opt) = oauth_state
|
||||
.get_credentials()
|
||||
.await
|
||||
.context("failed to retrieve OAuth credentials")?;
|
||||
let credentials =
|
||||
credentials_opt.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
|
||||
|
||||
let expires_at = compute_expires_at_millis(&credentials);
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: server_name.to_string(),
|
||||
url: server_url.to_string(),
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials),
|
||||
expires_at,
|
||||
};
|
||||
save_oauth_tokens(server_name, &stored, store_mode)?;
|
||||
|
||||
drop(guard);
|
||||
Ok(())
|
||||
Ok(OauthLoginHandle::new(authorization_url, completion))
|
||||
}
|
||||
|
||||
fn spawn_callback_server(server: Arc<Server>, tx: oneshot::Sender<(String, String)>) {
|
||||
|
|
@ -160,3 +147,181 @@ fn parse_oauth_callback(path: &str) -> Option<OauthCallbackResult> {
|
|||
state: state?,
|
||||
})
|
||||
}
|
||||
|
||||
pub struct OauthLoginHandle {
|
||||
authorization_url: String,
|
||||
completion: oneshot::Receiver<Result<()>>,
|
||||
}
|
||||
|
||||
impl OauthLoginHandle {
|
||||
fn new(authorization_url: String, completion: oneshot::Receiver<Result<()>>) -> Self {
|
||||
Self {
|
||||
authorization_url,
|
||||
completion,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn authorization_url(&self) -> &str {
|
||||
&self.authorization_url
|
||||
}
|
||||
|
||||
pub fn into_parts(self) -> (String, oneshot::Receiver<Result<()>>) {
|
||||
(self.authorization_url, self.completion)
|
||||
}
|
||||
|
||||
pub async fn wait(self) -> Result<()> {
|
||||
self.completion
|
||||
.await
|
||||
.map_err(|err| anyhow!("OAuth login task was cancelled: {err}"))?
|
||||
}
|
||||
}
|
||||
|
||||
struct OauthLoginFlow {
|
||||
auth_url: String,
|
||||
oauth_state: OAuthState,
|
||||
rx: oneshot::Receiver<(String, String)>,
|
||||
guard: CallbackServerGuard,
|
||||
server_name: String,
|
||||
server_url: String,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
launch_browser: bool,
|
||||
timeout: Duration,
|
||||
}
|
||||
|
||||
impl OauthLoginFlow {
|
||||
async fn new(
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
headers: OauthHeaders,
|
||||
scopes: &[String],
|
||||
launch_browser: bool,
|
||||
timeout_secs: Option<i64>,
|
||||
) -> Result<Self> {
|
||||
const DEFAULT_OAUTH_TIMEOUT_SECS: i64 = 300;
|
||||
|
||||
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
|
||||
let guard = CallbackServerGuard {
|
||||
server: Arc::clone(&server),
|
||||
};
|
||||
|
||||
let redirect_uri = match server.server_addr() {
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
|
||||
let ip = addr.ip();
|
||||
let port = addr.port();
|
||||
format!("http://{ip}:{port}/callback")
|
||||
}
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
|
||||
let ip = addr.ip();
|
||||
let port = addr.port();
|
||||
format!("http://[{ip}]:{port}/callback")
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
_ => return Err(anyhow!("unable to determine callback address")),
|
||||
};
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
spawn_callback_server(server, tx);
|
||||
|
||||
let OauthHeaders {
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
} = headers;
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
|
||||
|
||||
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
|
||||
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
|
||||
oauth_state
|
||||
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
|
||||
.await?;
|
||||
let auth_url = oauth_state.get_authorization_url().await?;
|
||||
let timeout_secs = timeout_secs.unwrap_or(DEFAULT_OAUTH_TIMEOUT_SECS).max(1);
|
||||
let timeout = Duration::from_secs(timeout_secs as u64);
|
||||
|
||||
Ok(Self {
|
||||
auth_url,
|
||||
oauth_state,
|
||||
rx,
|
||||
guard,
|
||||
server_name: server_name.to_string(),
|
||||
server_url: server_url.to_string(),
|
||||
store_mode,
|
||||
launch_browser,
|
||||
timeout,
|
||||
})
|
||||
}
|
||||
|
||||
fn authorization_url(&self) -> String {
|
||||
self.auth_url.clone()
|
||||
}
|
||||
|
||||
async fn finish(mut self) -> Result<()> {
|
||||
if self.launch_browser {
|
||||
let server_name = &self.server_name;
|
||||
let auth_url = &self.auth_url;
|
||||
println!(
|
||||
"Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n"
|
||||
);
|
||||
|
||||
if webbrowser::open(auth_url).is_err() {
|
||||
println!("(Browser launch failed; please copy the URL above manually.)");
|
||||
}
|
||||
}
|
||||
|
||||
let result = async {
|
||||
let (code, csrf_state) = timeout(self.timeout, &mut self.rx)
|
||||
.await
|
||||
.context("timed out waiting for OAuth callback")?
|
||||
.context("OAuth callback was cancelled")?;
|
||||
|
||||
self.oauth_state
|
||||
.handle_callback(&code, &csrf_state)
|
||||
.await
|
||||
.context("failed to handle OAuth callback")?;
|
||||
|
||||
let (client_id, credentials_opt) = self
|
||||
.oauth_state
|
||||
.get_credentials()
|
||||
.await
|
||||
.context("failed to retrieve OAuth credentials")?;
|
||||
let credentials = credentials_opt
|
||||
.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
|
||||
|
||||
let expires_at = compute_expires_at_millis(&credentials);
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: self.server_name.clone(),
|
||||
url: self.server_url.clone(),
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials),
|
||||
expires_at,
|
||||
};
|
||||
save_oauth_tokens(&self.server_name, &stored, self.store_mode)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
drop(self.guard);
|
||||
result
|
||||
}
|
||||
|
||||
fn spawn(self) -> oneshot::Receiver<Result<()>> {
|
||||
let server_name_for_logging = self.server_name.clone();
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let result = self.finish().await;
|
||||
|
||||
if let Err(err) = &result {
|
||||
eprintln!(
|
||||
"Failed to complete OAuth login for '{server_name_for_logging}': {err:#}"
|
||||
);
|
||||
}
|
||||
|
||||
let _ = tx.send(result);
|
||||
});
|
||||
|
||||
rx
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue