diff --git a/codex-rs/utils/git/src/branch.rs b/codex-rs/utils/git/src/branch.rs index 543dffd9c..f65de0c6b 100644 --- a/codex-rs/utils/git/src/branch.rs +++ b/codex-rs/utils/git/src/branch.rs @@ -7,7 +7,8 @@ use crate::operations::resolve_head; use crate::operations::resolve_repository_root; use crate::operations::run_git_for_stdout; -/// Returns the merge-base commit between `HEAD` and the provided branch, if both exist. +/// Returns the merge-base commit between `HEAD` and the latest version between local +/// and remote of the provided branch, if both exist. /// /// The function mirrors `git merge-base HEAD ` but returns `Ok(None)` when /// the repository has no `HEAD` yet or when the branch cannot be resolved. @@ -22,26 +23,23 @@ pub fn merge_base_with_head( None => return Ok(None), }; - let branch_ref = match run_git_for_stdout( - repo_root.as_path(), - vec![ - OsString::from("rev-parse"), - OsString::from("--verify"), - OsString::from(branch), - ], - None, - ) { - Ok(rev) => rev, - Err(GitToolingError::GitCommand { .. }) => return Ok(None), - Err(other) => return Err(other), + let Some(branch_ref) = resolve_branch_ref(repo_root.as_path(), branch)? else { + return Ok(None); }; + let preferred_ref = + if let Some(upstream) = resolve_upstream_if_remote_ahead(repo_root.as_path(), branch)? { + resolve_branch_ref(repo_root.as_path(), &upstream)?.unwrap_or(branch_ref) + } else { + branch_ref + }; + let merge_base = run_git_for_stdout( repo_root.as_path(), vec![ OsString::from("merge-base"), OsString::from(head), - OsString::from(branch_ref), + OsString::from(preferred_ref), ], None, )?; @@ -49,6 +47,75 @@ pub fn merge_base_with_head( Ok(Some(merge_base)) } +fn resolve_branch_ref(repo_root: &Path, branch: &str) -> Result, GitToolingError> { + let rev = run_git_for_stdout( + repo_root, + vec![ + OsString::from("rev-parse"), + OsString::from("--verify"), + OsString::from(branch), + ], + None, + ); + + match rev { + Ok(rev) => Ok(Some(rev)), + Err(GitToolingError::GitCommand { .. }) => Ok(None), + Err(other) => Err(other), + } +} + +fn resolve_upstream_if_remote_ahead( + repo_root: &Path, + branch: &str, +) -> Result, GitToolingError> { + let upstream = match run_git_for_stdout( + repo_root, + vec![ + OsString::from("rev-parse"), + OsString::from("--abbrev-ref"), + OsString::from("--symbolic-full-name"), + OsString::from(format!("{branch}@{{upstream}}")), + ], + None, + ) { + Ok(name) => { + let trimmed = name.trim(); + if trimmed.is_empty() { + return Ok(None); + } + trimmed.to_string() + } + Err(GitToolingError::GitCommand { .. }) => return Ok(None), + Err(other) => return Err(other), + }; + + let counts = match run_git_for_stdout( + repo_root, + vec![ + OsString::from("rev-list"), + OsString::from("--left-right"), + OsString::from("--count"), + OsString::from(format!("{branch}...{upstream}")), + ], + None, + ) { + Ok(counts) => counts, + Err(GitToolingError::GitCommand { .. }) => return Ok(None), + Err(other) => return Err(other), + }; + + let mut parts = counts.split_whitespace(); + let _left: i64 = parts.next().unwrap_or("0").parse().unwrap_or(0); + let right: i64 = parts.next().unwrap_or("0").parse().unwrap_or(0); + + if right > 0 { + Ok(Some(upstream)) + } else { + Ok(None) + } +} + #[cfg(test)] mod tests { use super::merge_base_with_head; @@ -126,6 +193,51 @@ mod tests { Ok(()) } + #[test] + fn merge_base_prefers_upstream_when_remote_ahead() -> Result<(), GitToolingError> { + let temp = tempdir()?; + let repo = temp.path().join("repo"); + let remote = temp.path().join("remote.git"); + std::fs::create_dir_all(&repo)?; + std::fs::create_dir_all(&remote)?; + + run_git_in(&remote, &["init", "--bare"]); + run_git_in(&repo, &["init", "--initial-branch=main"]); + run_git_in(&repo, &["config", "core.autocrlf", "false"]); + + std::fs::write(repo.join("base.txt"), "base\n")?; + run_git_in(&repo, &["add", "base.txt"]); + commit(&repo, "base commit"); + + run_git_in( + &repo, + &["remote", "add", "origin", remote.to_str().unwrap()], + ); + run_git_in(&repo, &["push", "-u", "origin", "main"]); + + run_git_in(&repo, &["checkout", "-b", "feature"]); + std::fs::write(repo.join("feature.txt"), "feature change\n")?; + run_git_in(&repo, &["add", "feature.txt"]); + commit(&repo, "feature commit"); + + run_git_in(&repo, &["checkout", "--orphan", "rewrite"]); + run_git_in(&repo, &["rm", "-rf", "."]); + std::fs::write(repo.join("new-main.txt"), "rewritten main\n")?; + run_git_in(&repo, &["add", "new-main.txt"]); + commit(&repo, "rewrite main"); + run_git_in(&repo, &["branch", "-M", "rewrite", "main"]); + run_git_in(&repo, &["branch", "--set-upstream-to=origin/main", "main"]); + + run_git_in(&repo, &["checkout", "feature"]); + run_git_in(&repo, &["fetch", "origin"]); + + let expected = run_git_stdout(&repo, &["merge-base", "HEAD", "origin/main"]); + let merge_base = merge_base_with_head(&repo, "main")?; + assert_eq!(merge_base, Some(expected)); + + Ok(()) + } + #[test] fn merge_base_returns_none_when_branch_missing() -> Result<(), GitToolingError> { let temp = tempdir()?;