diff --git a/hermes_cli/main.py b/hermes_cli/main.py index b835efb0f5..2a4aa7b2a1 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -2123,6 +2123,43 @@ def _restore_stashed_changes( return True +def _get_update_target(git_cmd: list[str]) -> tuple[str, str, str, str]: + """Resolve the current branch and the remote/branch it should update from. + + Returns ``(branch, remote, remote_branch, upstream_ref)`` where + *upstream_ref* is ``remote/remote_branch``. When the branch has a + configured upstream (e.g. tracking a fork remote), that upstream is + used; otherwise falls back to ``origin/``. + """ + result = subprocess.run( + git_cmd + ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=PROJECT_ROOT, + capture_output=True, + text=True, + check=True, + ) + branch = result.stdout.strip() + if branch == "HEAD": + raise RuntimeError("Cannot update from a detached HEAD; check out a branch first.") + + remote = "origin" + remote_branch = branch + upstream_ref = f"{remote}/{remote_branch}" + + upstream = subprocess.run( + git_cmd + ["rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"], + cwd=PROJECT_ROOT, + capture_output=True, + text=True, + check=False, + ) + upstream_name = upstream.stdout.strip() + if upstream.returncode == 0 and upstream_name and "/" in upstream_name: + remote, remote_branch = upstream_name.split("/", 1) + upstream_ref = upstream_name + + return branch, remote, remote_branch, upstream_ref + def cmd_update(args): """Update Hermes Agent to the latest version.""" @@ -2164,29 +2201,21 @@ def cmd_update(args): if sys.platform == "win32": git_cmd = ["git", "-c", "windows.appendAtomically=false"] - subprocess.run(git_cmd + ["fetch", "origin"], cwd=PROJECT_ROOT, check=True) - - # Get current branch - result = subprocess.run( - git_cmd + ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=PROJECT_ROOT, - capture_output=True, - text=True, - check=True - ) - branch = result.stdout.strip() + branch, remote, remote_branch, upstream_ref = _get_update_target(git_cmd) + subprocess.run(git_cmd + ["fetch", remote], cwd=PROJECT_ROOT, check=True) - # Fall back to main if the current branch doesn't exist on the remote + # Fall back to main if the resolved branch doesn't exist on the remote verify = subprocess.run( - git_cmd + ["rev-parse", "--verify", f"origin/{branch}"], + git_cmd + ["rev-parse", "--verify", upstream_ref], cwd=PROJECT_ROOT, capture_output=True, text=True, ) if verify.returncode != 0: - branch = "main" + remote, remote_branch, upstream_ref = "origin", "main", "origin/main" + subprocess.run(git_cmd + ["fetch", remote], cwd=PROJECT_ROOT, check=True) # Check if there are updates result = subprocess.run( - git_cmd + ["rev-list", f"HEAD..origin/{branch}", "--count"], + git_cmd + ["rev-list", f"HEAD..{upstream_ref}", "--count"], cwd=PROJECT_ROOT, capture_output=True, text=True, @@ -2205,7 +2234,7 @@ def cmd_update(args): print("→ Pulling updates...") try: - subprocess.run(git_cmd + ["pull", "origin", branch], cwd=PROJECT_ROOT, check=True) + subprocess.run(git_cmd + ["pull", remote, remote_branch], cwd=PROJECT_ROOT, check=True) finally: if auto_stash_ref is not None: _restore_stashed_changes( diff --git a/tests/hermes_cli/test_update_upstream.py b/tests/hermes_cli/test_update_upstream.py new file mode 100644 index 0000000000..69e3e298e0 --- /dev/null +++ b/tests/hermes_cli/test_update_upstream.py @@ -0,0 +1,87 @@ +"""Tests for _get_update_target — upstream branch resolution in hermes update.""" + +import subprocess +from unittest.mock import patch + +import pytest + + +# --------------------------------------------------------------------------- +# _get_update_target unit tests +# --------------------------------------------------------------------------- + +class TestGetUpdateTarget: + """Test upstream resolution for hermes update.""" + + def _make_run(self, head_output, upstream_output=None, upstream_rc=0): + """Build a subprocess.run side-effect that fakes git rev-parse calls.""" + def fake_run(cmd, **kwargs): + if "rev-parse" in cmd and "--abbrev-ref" in cmd and "HEAD" in cmd: + return subprocess.CompletedProcess(cmd, 0, stdout=head_output, stderr="") + if "rev-parse" in cmd and "@{u}" in cmd: + return subprocess.CompletedProcess( + cmd, upstream_rc, + stdout=upstream_output or "", stderr="" if upstream_rc == 0 else "fatal: no upstream\n", + ) + raise AssertionError(f"Unexpected command: {cmd}") + return fake_run + + def test_uses_tracking_upstream(self): + """Branch tracking fork/feature should resolve to that remote.""" + from hermes_cli.main import _get_update_target + + fake = self._make_run("fix/discord\n", "fork/fix/discord\n") + with patch("subprocess.run", side_effect=fake): + branch, remote, remote_branch, upstream_ref = _get_update_target(["git"]) + + assert branch == "fix/discord" + assert remote == "fork" + assert remote_branch == "fix/discord" + assert upstream_ref == "fork/fix/discord" + + def test_falls_back_to_origin_when_no_upstream(self): + """Branch with no upstream should default to origin/.""" + from hermes_cli.main import _get_update_target + + fake = self._make_run("feature/local\n", upstream_rc=128) + with patch("subprocess.run", side_effect=fake): + branch, remote, remote_branch, upstream_ref = _get_update_target(["git"]) + + assert branch == "feature/local" + assert remote == "origin" + assert remote_branch == "feature/local" + assert upstream_ref == "origin/feature/local" + + def test_main_branch_with_origin_upstream(self): + """Standard main branch tracking origin/main.""" + from hermes_cli.main import _get_update_target + + fake = self._make_run("main\n", "origin/main\n") + with patch("subprocess.run", side_effect=fake): + branch, remote, remote_branch, upstream_ref = _get_update_target(["git"]) + + assert branch == "main" + assert remote == "origin" + assert remote_branch == "main" + assert upstream_ref == "origin/main" + + def test_detached_head_raises(self): + """Detached HEAD should raise RuntimeError.""" + from hermes_cli.main import _get_update_target + + fake = self._make_run("HEAD\n") + with patch("subprocess.run", side_effect=fake): + with pytest.raises(RuntimeError, match="detached HEAD"): + _get_update_target(["git"]) + + def test_upstream_with_nested_slashes(self): + """Upstream like origin/fix/deep/path should split on first slash only.""" + from hermes_cli.main import _get_update_target + + fake = self._make_run("fix/deep/path\n", "origin/fix/deep/path\n") + with patch("subprocess.run", side_effect=fake): + branch, remote, remote_branch, upstream_ref = _get_update_target(["git"]) + + assert remote == "origin" + assert remote_branch == "fix/deep/path" + assert upstream_ref == "origin/fix/deep/path"