mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-03 09:17:09 +08:00
`tar xf - -C /` extracts the staging directory tree to the remote root. By default, GNU tar overwrites the metadata (including the mode) of directories that already exist. When the local umask is 002 (the Ubuntu default), the staging directories are 0775, so tar chmods /home/<user> to 0775 — breaking sshd StrictModes, which requires 0755 or stricter for home directories. Add --no-overwrite-dir to the remote tar command so that the metadata of existing directories is preserved. Fixes #17767.
292 lines
12 KiB
Python
292 lines
12 KiB
Python
"""SSH remote execution environment with ControlMaster connection persistence."""
|
|
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
import shlex
|
|
import shutil
|
|
import subprocess
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
from tools.environments.base import BaseEnvironment, _popen_bash
|
|
from tools.environments.file_sync import (
|
|
FileSyncManager,
|
|
iter_sync_files,
|
|
quoted_mkdir_command,
|
|
quoted_rm_command,
|
|
unique_parent_dirs,
|
|
)
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _ensure_ssh_available() -> None:
    """Verify that an ``ssh`` executable can be located on ``PATH``.

    Raises:
        RuntimeError: If ``shutil.which`` cannot find ``ssh``.
    """
    ssh_binary = shutil.which("ssh")
    if ssh_binary:
        return
    raise RuntimeError(
        "SSH is not installed or not in PATH. Install OpenSSH client: apt install openssh-client"
    )
|
|
|
|
|
|
class SSHEnvironment(BaseEnvironment):
    """Run commands on a remote machine over SSH.

    Spawn-per-call: every execute() spawns a fresh ``ssh ... bash -c`` process.
    Session snapshot preserves env vars across calls.
    CWD persists via in-band stdout markers.
    Uses SSH ControlMaster for connection reuse.
    """

    def __init__(self, host: str, user: str, cwd: str = "~",
                 timeout: int = 60, port: int = 22, key_path: str = ""):
        """Connect to the remote host and prepare it for command execution.

        Args:
            host: Remote hostname or address.
            user: Remote login name.
            cwd: Initial working directory, forwarded to the base class.
            timeout: Default timeout, forwarded to the base class.
            port: SSH port; ``-p`` is only passed when it differs from 22.
            key_path: Identity file passed via ``-i``; empty string to omit.

        Raises:
            RuntimeError: If the ssh client is missing, or the initial
                connection fails or times out.
        """
        super().__init__(cwd=cwd, timeout=timeout)
        self.host = host
        self.user = user
        self.port = port
        self.key_path = key_path

        self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh"
        self.control_dir.mkdir(parents=True, exist_ok=True)
        # Keep the socket filename short and deterministic so the full path
        # stays under the 104-byte sun_path limit that macOS enforces on
        # Unix domain sockets. A raw ``user@host:port`` — especially with an
        # IPv6 host — plus the 16-byte random suffix SSH appends in
        # ControlMaster mode easily exceeds the limit under macOS's
        # deeply-nested $TMPDIR (e.g. /var/folders/xx/yy/T/). Hashing the
        # triple keeps the path stable across reconnects so ControlMaster
        # reuse still works.
        _socket_id = hashlib.sha256(
            f"{user}@{host}:{port}".encode()
        ).hexdigest()[:16]
        self.control_socket = self.control_dir / f"{_socket_id}.sock"
        _ensure_ssh_available()
        self._establish_connection()
        self._remote_home = self._detect_remote_home()

        self._ensure_remote_dirs()
        self._sync_manager = FileSyncManager(
            get_files_fn=lambda: iter_sync_files(f"{self._remote_home}/.hermes"),
            upload_fn=self._scp_upload,
            delete_fn=self._ssh_delete,
            bulk_upload_fn=self._ssh_bulk_upload,
            bulk_download_fn=self._ssh_bulk_download,
        )
        self._sync_manager.sync(force=True)

        self.init_session()

    def _build_ssh_command(self, extra_args: list | None = None) -> list:
        """Return the ``ssh`` argv prefix shared by every remote call.

        The list ends with ``user@host``; callers append the remote command.

        Args:
            extra_args: Optional extra ssh arguments, inserted just before
                the ``user@host`` destination.
        """
        cmd = [
            "ssh",
            "-o", f"ControlPath={self.control_socket}",
            "-o", "ControlMaster=auto",
            "-o", "ControlPersist=300",
            "-o", "BatchMode=yes",
            "-o", "StrictHostKeyChecking=accept-new",
            "-o", "ConnectTimeout=10",
        ]
        if self.port != 22:
            cmd.extend(["-p", str(self.port)])
        if self.key_path:
            cmd.extend(["-i", self.key_path])
        if extra_args:
            cmd.extend(extra_args)
        cmd.append(f"{self.user}@{self.host}")
        return cmd

    def _establish_connection(self):
        """Run a trivial remote ``echo`` to verify the host is reachable.

        Raises:
            RuntimeError: If ssh exits non-zero or does not finish within
                15 seconds.
        """
        cmd = self._build_ssh_command()
        cmd.append("echo 'SSH connection established'")
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
        except subprocess.TimeoutExpired as exc:
            # Chain explicitly so the original timeout stays in the traceback.
            raise RuntimeError(
                f"SSH connection to {self.user}@{self.host} timed out"
            ) from exc
        if result.returncode != 0:
            error_msg = result.stderr.strip() or result.stdout.strip()
            raise RuntimeError(f"SSH connection failed: {error_msg}")

    def _detect_remote_home(self) -> str:
        """Detect the remote user's home directory.

        Asks the remote shell for ``$HOME``. If that fails for any reason,
        falls back to ``/root`` for the ``root`` user and ``/home/<user>``
        otherwise.
        """
        try:
            cmd = self._build_ssh_command()
            cmd.append("echo $HOME")
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
            home = result.stdout.strip()
            if home and result.returncode == 0:
                logger.debug("SSH: remote home = %s", home)
                return home
        except Exception:
            # Best-effort probe: keep going with the guess below, but leave
            # a trace instead of discarding the failure silently.
            logger.debug(
                "SSH: remote home detection failed; using default",
                exc_info=True,
            )
        if self.user == "root":
            return "/root"
        return f"/home/{self.user}"

    # ------------------------------------------------------------------
    # File sync (via FileSyncManager)
    # ------------------------------------------------------------------

    def _ensure_remote_dirs(self) -> None:
        """Create base ~/.hermes directory tree on remote in one SSH call.

        Best-effort: a non-zero exit status is logged at debug level and
        otherwise ignored.
        """
        base = f"{self._remote_home}/.hermes"
        dirs = [base, f"{base}/skills", f"{base}/credentials", f"{base}/cache"]
        cmd = self._build_ssh_command()
        cmd.append(quoted_mkdir_command(dirs))
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
        if result.returncode != 0:
            logger.debug(
                "SSH: creating remote base dirs failed (rc=%s): %s",
                result.returncode, result.stderr.strip(),
            )

    # _get_sync_files provided via iter_sync_files in FileSyncManager init

    def _scp_upload(self, host_path: str, remote_path: str) -> None:
        """Upload a single file via scp over ControlMaster.

        The remote parent directory is created first with ``mkdir -p``; the
        result of that call is not checked, so a failure there surfaces as
        an scp error.

        Raises:
            RuntimeError: If scp exits non-zero.
        """
        parent = str(Path(remote_path).parent)
        mkdir_cmd = self._build_ssh_command()
        mkdir_cmd.append(f"mkdir -p {shlex.quote(parent)}")
        subprocess.run(mkdir_cmd, capture_output=True, text=True, timeout=10)

        scp_cmd = ["scp", "-o", f"ControlPath={self.control_socket}"]
        if self.port != 22:
            # scp spells the port flag with an uppercase -P.
            scp_cmd.extend(["-P", str(self.port)])
        if self.key_path:
            scp_cmd.extend(["-i", self.key_path])
        scp_cmd.extend([host_path, f"{self.user}@{self.host}:{remote_path}"])
        result = subprocess.run(scp_cmd, capture_output=True, text=True, timeout=30)
        if result.returncode != 0:
            raise RuntimeError(f"scp failed: {result.stderr.strip()}")

    def _ssh_bulk_upload(self, files: list[tuple[str, str]]) -> None:
        """Upload many files in a single tar-over-SSH stream.

        Pipes ``tar c`` on the local side through an SSH connection to
        ``tar x`` on the remote, transferring all files in one TCP stream
        instead of spawning a subprocess per file. Directory creation is
        batched into a single ``mkdir -p`` call beforehand.

        Typical improvement: ~580 files goes from O(N) scp round-trips
        to a single streaming transfer.

        Args:
            files: ``(host_path, remote_path)`` pairs to transfer.

        Raises:
            RuntimeError: If the remote mkdir fails, the transfer times out,
                or either tar process exits non-zero.
        """
        if not files:
            return

        parents = unique_parent_dirs(files)
        if parents:
            cmd = self._build_ssh_command()
            cmd.append(quoted_mkdir_command(parents))
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
            if result.returncode != 0:
                raise RuntimeError(f"remote mkdir failed: {result.stderr.strip()}")

        # Symlink staging avoids fragile GNU tar --transform rules.
        with tempfile.TemporaryDirectory(prefix="hermes-ssh-bulk-") as staging:
            for host_path, remote_path in files:
                staged = os.path.join(staging, remote_path.lstrip("/"))
                os.makedirs(os.path.dirname(staged), exist_ok=True)
                os.symlink(os.path.abspath(host_path), staged)

            # -h makes tar archive the symlink targets, not the links.
            tar_cmd = ["tar", "-chf", "-", "-C", staging, "."]
            ssh_cmd = self._build_ssh_command()
            # --no-overwrite-dir prevents tar from overwriting the mode of
            # existing directories (e.g. /home/<user>) with the staging
            # directory's mode. Without this, a umask 002 produces 0775
            # dirs which breaks sshd StrictModes (refuses authorized_keys).
            ssh_cmd.append("tar xf - --no-overwrite-dir -C /")

            tar_proc = subprocess.Popen(
                tar_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            try:
                ssh_proc = subprocess.Popen(
                    ssh_cmd, stdin=tar_proc.stdout, stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )
            except Exception:
                tar_proc.kill()
                tar_proc.wait()
                # Release our ends of the tar pipes before propagating.
                for pipe in (tar_proc.stdout, tar_proc.stderr):
                    if pipe is not None and not pipe.closed:
                        pipe.close()
                raise

            # Allow tar_proc to receive SIGPIPE if ssh_proc exits early
            tar_proc.stdout.close()

            try:
                _, ssh_stderr = ssh_proc.communicate(timeout=120)
                # Use communicate() instead of wait() to drain stderr and
                # avoid deadlock if tar produces more than PIPE_BUF of errors.
                tar_stderr_raw = b""
                if tar_proc.poll() is None:
                    _, tar_stderr_raw = tar_proc.communicate(timeout=10)
                else:
                    tar_stderr_raw = tar_proc.stderr.read() if tar_proc.stderr else b""
            except subprocess.TimeoutExpired:
                tar_proc.kill()
                ssh_proc.kill()
                tar_proc.wait()
                ssh_proc.wait()
                raise RuntimeError("SSH bulk upload timed out")
            finally:
                # Close any pipe handles communicate() did not already
                # close (timeout path, or tar having exited early), so they
                # do not linger until garbage collection.
                for pipe in (tar_proc.stderr, ssh_proc.stdout, ssh_proc.stderr):
                    if pipe is not None and not pipe.closed:
                        pipe.close()

            if tar_proc.returncode != 0:
                raise RuntimeError(
                    f"tar create failed (rc={tar_proc.returncode}): "
                    f"{tar_stderr_raw.decode(errors='replace').strip()}"
                )
            if ssh_proc.returncode != 0:
                raise RuntimeError(
                    f"tar extract over SSH failed (rc={ssh_proc.returncode}): "
                    f"{ssh_stderr.decode(errors='replace').strip()}"
                )

        logger.debug("SSH: bulk-uploaded %d file(s) via tar pipe", len(files))

    def _ssh_bulk_download(self, dest: Path) -> None:
        """Download remote .hermes/ as a tar archive.

        Args:
            dest: Local path the tar stream is written to.

        Raises:
            RuntimeError: If the remote tar command exits non-zero.
        """
        # Tar from / with the full path so archive entries preserve absolute
        # paths (e.g. home/user/.hermes/skills/f.py), matching _pushed_hashes keys.
        rel_base = f"{self._remote_home}/.hermes".lstrip("/")
        ssh_cmd = self._build_ssh_command()
        ssh_cmd.append(f"tar cf - -C / {shlex.quote(rel_base)}")
        with open(dest, "wb") as f:
            result = subprocess.run(ssh_cmd, stdout=f, stderr=subprocess.PIPE, timeout=120)
        if result.returncode != 0:
            raise RuntimeError(f"SSH bulk download failed: {result.stderr.decode(errors='replace').strip()}")

    def _ssh_delete(self, remote_paths: list[str]) -> None:
        """Batch-delete remote files in one SSH call.

        Raises:
            RuntimeError: If the remote rm command exits non-zero.
        """
        cmd = self._build_ssh_command()
        cmd.append(quoted_rm_command(remote_paths))
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
        if result.returncode != 0:
            raise RuntimeError(f"remote rm failed: {result.stderr.strip()}")

    def _before_execute(self) -> None:
        """Sync files to remote via FileSyncManager (rate-limited internally)."""
        self._sync_manager.sync()

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    def _run_bash(self, cmd_string: str, *, login: bool = False,
                  timeout: int = 120,
                  stdin_data: str | None = None) -> subprocess.Popen:
        """Spawn an SSH process that runs bash on the remote host.

        Args:
            cmd_string: Script passed to ``bash -c``; shell-quoted here
                before being appended to the ssh command line.
            login: When true, run ``bash -l`` instead of plain ``bash``.
            timeout: Accepted as part of the signature; not used in this
                method.
            stdin_data: Forwarded unchanged to ``_popen_bash``.
        """
        cmd = self._build_ssh_command()
        bash_argv = ["bash", "-l", "-c"] if login else ["bash", "-c"]
        cmd.extend(bash_argv)
        cmd.append(shlex.quote(cmd_string))

        return _popen_bash(cmd, stdin_data)

    def cleanup(self):
        """Sync files back from the remote, then shut down the ControlMaster.

        The control connection is closed and its socket removed even when
        the sync-back step raises; that exception still propagates.
        """
        # __init__ can fail before _sync_manager is assigned; tolerate that
        # so cleanup of a half-constructed instance does not itself crash.
        sync_manager = getattr(self, "_sync_manager", None)
        try:
            if sync_manager:
                logger.info("SSH: syncing files from sandbox...")
                sync_manager.sync_back()
        finally:
            if self.control_socket.exists():
                try:
                    cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
                           "-O", "exit", f"{self.user}@{self.host}"]
                    subprocess.run(cmd, capture_output=True, timeout=5)
                except (OSError, subprocess.SubprocessError):
                    # Best-effort shutdown; fall through to socket removal.
                    pass
                try:
                    self.control_socket.unlink()
                except OSError:
                    pass
|