mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 23:11:37 +08:00
Compare commits
3 Commits
sid/fix-to
...
fix/modal-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b075d08064 | ||
|
|
9000d3163a | ||
|
|
04d4f41e77 |
295
tests/tools/test_modal_bulk_upload.py
Normal file
295
tests/tools/test_modal_bulk_upload.py
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
"""Tests for Modal bulk upload via tar/base64 archive."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import tarfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tools.environments import modal as modal_env
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_modal_env(monkeypatch, tmp_path):
    """Create a minimal mock ModalEnvironment for testing upload methods.

    Returns a ModalEnvironment-like object with _sandbox and _worker mocked.
    We don't call __init__ because it requires the Modal SDK; the attributes
    the upload code touches are stubbed directly on the bare instance.
    """
    env = object.__new__(modal_env.ModalEnvironment)
    for attr, value in (
        ("_sandbox", MagicMock()),
        ("_worker", MagicMock()),
        ("_persistent", False),
        ("_task_id", "test"),
        ("_sync_manager", None),
    ):
        setattr(env, attr, value)
    return env
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_stdin():
|
||||||
|
"""Create a mock stdin that captures written data."""
|
||||||
|
stdin = MagicMock()
|
||||||
|
written_chunks = []
|
||||||
|
|
||||||
|
def mock_write(data):
|
||||||
|
written_chunks.append(data)
|
||||||
|
|
||||||
|
stdin.write = mock_write
|
||||||
|
stdin.write_eof = MagicMock()
|
||||||
|
stdin.drain = MagicMock()
|
||||||
|
stdin.drain.aio = AsyncMock()
|
||||||
|
stdin._written_chunks = written_chunks
|
||||||
|
return stdin
|
||||||
|
|
||||||
|
|
||||||
|
def _wire_async_exec(env, exec_calls=None):
    """Wire mock sandbox.exec.aio and a real run_coroutine on the env.

    Optionally captures exec call args into *exec_calls* list.
    Returns (exec_calls, run_kwargs, stdin_mock).
    """
    if exec_calls is None:
        exec_calls = []
    run_kwargs: dict = {}
    stdin_mock = _make_mock_stdin()

    async def mock_exec_fn(*args, **kwargs):
        # Record the positional args so tests can inspect the remote command.
        exec_calls.append(args)
        proc = MagicMock()
        proc.wait = MagicMock()
        # Exit code 0: simulate a successful remote command.
        proc.wait.aio = AsyncMock(return_value=0)
        proc.stdin = stdin_mock
        proc.stderr = MagicMock()
        proc.stderr.read = MagicMock()
        proc.stderr.read.aio = AsyncMock(return_value="")
        return proc

    env._sandbox.exec = MagicMock()
    env._sandbox.exec.aio = mock_exec_fn

    def real_run_coroutine(coro, **kwargs):
        # Capture kwargs (e.g. timeout) for assertions, then actually drive
        # the coroutine on a private event loop so the upload path executes.
        run_kwargs.update(kwargs)
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    env._worker.run_coroutine = real_run_coroutine
    return exec_calls, run_kwargs, stdin_mock
|
||||||
|
|
||||||
|
|
||||||
|
class TestModalBulkUpload:
    """Test _modal_bulk_upload method.

    Uses the module-level helpers to stand in for the Modal SDK: the
    sandbox exec is mocked but the coroutine is genuinely executed.
    """

    def test_empty_files_is_noop(self, monkeypatch, tmp_path):
        """Empty file list should not call worker.run_coroutine."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)
        env._modal_bulk_upload([])
        env._worker.run_coroutine.assert_not_called()

    def test_tar_archive_contains_all_files(self, monkeypatch, tmp_path):
        """The tar archive sent via stdin should contain all files."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src_a = tmp_path / "a.json"
        src_b = tmp_path / "b.py"
        src_a.write_text("cred_content")
        src_b.write_text("skill_content")

        files = [
            (str(src_a), "/root/.hermes/credentials/a.json"),
            (str(src_b), "/root/.hermes/skills/b.py"),
        ]

        exec_calls, _, stdin_mock = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        # Verify the command reads from stdin (no echo with embedded payload)
        assert len(exec_calls) == 1
        args = exec_calls[0]
        assert args[0] == "bash"
        assert args[1] == "-c"
        cmd = args[2]
        assert "mkdir -p" in cmd
        assert "base64 -d" in cmd
        assert "tar xzf" in cmd
        assert "-C /" in cmd

        # Reassemble the base64 payload from stdin chunks and verify tar contents
        payload = "".join(stdin_mock._written_chunks)
        tar_data = base64.b64decode(payload)
        buf = io.BytesIO(tar_data)
        with tarfile.open(fileobj=buf, mode="r:gz") as tar:
            # Archive paths are relative to / (extracted with -C /)
            names = sorted(tar.getnames())
            assert "root/.hermes/credentials/a.json" in names
            assert "root/.hermes/skills/b.py" in names

            # Verify content
            a_content = tar.extractfile("root/.hermes/credentials/a.json").read()
            assert a_content == b"cred_content"
            b_content = tar.extractfile("root/.hermes/skills/b.py").read()
            assert b_content == b"skill_content"

        # Verify stdin was closed
        stdin_mock.write_eof.assert_called_once()

    def test_mkdir_includes_all_parents(self, monkeypatch, tmp_path):
        """Remote parent directories should be pre-created in the command."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src = tmp_path / "f.txt"
        src.write_text("data")

        files = [
            (str(src), "/root/.hermes/credentials/f.txt"),
            (str(src), "/root/.hermes/skills/deep/nested/f.txt"),
        ]

        exec_calls, _, _ = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        cmd = exec_calls[0][2]
        assert "/root/.hermes/credentials" in cmd
        assert "/root/.hermes/skills/deep/nested" in cmd

    def test_single_exec_call(self, monkeypatch, tmp_path):
        """Bulk upload should use exactly one exec call regardless of file count."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        files = []
        for i in range(20):
            src = tmp_path / f"file_{i}.txt"
            src.write_text(f"content_{i}")
            files.append((str(src), f"/root/.hermes/cache/file_{i}.txt"))

        exec_calls, _, _ = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        # Should be exactly 1 exec call, not 20
        assert len(exec_calls) == 1

    def test_bulk_upload_wired_in_filesyncmanager(self, monkeypatch):
        """Verify ModalEnvironment passes bulk_upload_fn to FileSyncManager."""
        captured_kwargs = {}

        def capture_fsm(**kwargs):
            captured_kwargs.update(kwargs)
            return type("M", (), {"sync": lambda self, **k: None})()

        monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm)

        # Create a minimal env without full __init__
        env = object.__new__(modal_env.ModalEnvironment)
        env._sandbox = MagicMock()
        env._worker = MagicMock()
        env._persistent = False
        env._task_id = "test"

        # Manually call the part of __init__ that wires FileSyncManager
        from tools.environments.file_sync import iter_sync_files
        env._sync_manager = modal_env.FileSyncManager(
            get_files_fn=lambda: iter_sync_files("/root/.hermes"),
            upload_fn=env._modal_upload,
            delete_fn=env._modal_delete,
            bulk_upload_fn=env._modal_bulk_upload,
        )

        assert "bulk_upload_fn" in captured_kwargs
        assert captured_kwargs["bulk_upload_fn"] is not None
        assert callable(captured_kwargs["bulk_upload_fn"])

    def test_timeout_set_to_120(self, monkeypatch, tmp_path):
        """Bulk upload uses a 120s timeout (not the per-file 15s)."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src = tmp_path / "f.txt"
        src.write_text("data")
        files = [(str(src), "/root/.hermes/f.txt")]

        _, run_kwargs, _ = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        assert run_kwargs.get("timeout") == 120

    def test_nonzero_exit_raises(self, monkeypatch, tmp_path):
        """Non-zero exit code from remote exec should raise RuntimeError."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src = tmp_path / "f.txt"
        src.write_text("data")
        files = [(str(src), "/root/.hermes/f.txt")]

        stdin_mock = _make_mock_stdin()

        # Hand-rolled variant of _wire_async_exec with a failing exit code.
        async def mock_exec_fn(*args, **kwargs):
            proc = MagicMock()
            proc.wait = MagicMock()
            proc.wait.aio = AsyncMock(return_value=1)  # non-zero exit
            proc.stdin = stdin_mock
            proc.stderr = MagicMock()
            proc.stderr.read = MagicMock()
            proc.stderr.read.aio = AsyncMock(return_value="tar: error")
            return proc

        env._sandbox.exec = MagicMock()
        env._sandbox.exec.aio = mock_exec_fn

        def real_run_coroutine(coro, **kwargs):
            loop = asyncio.new_event_loop()
            try:
                return loop.run_until_complete(coro)
            finally:
                loop.close()

        env._worker.run_coroutine = real_run_coroutine

        with pytest.raises(RuntimeError, match="Modal bulk upload failed"):
            env._modal_bulk_upload(files)

    def test_payload_not_in_command_string(self, monkeypatch, tmp_path):
        """The base64 payload must NOT appear in the bash -c argument.

        This is the core ARG_MAX fix: the payload goes through stdin,
        not embedded in the command string.
        """
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src = tmp_path / "f.txt"
        src.write_text("some data to upload")
        files = [(str(src), "/root/.hermes/f.txt")]

        exec_calls, _, stdin_mock = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        # The command should NOT contain an echo with the payload
        cmd = exec_calls[0][2]
        assert "echo" not in cmd
        # The payload should go through stdin
        assert len(stdin_mock._written_chunks) > 0

    def test_stdin_chunked_for_large_payloads(self, monkeypatch, tmp_path):
        """Payloads larger than _STDIN_CHUNK_SIZE should be split into multiple writes."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        # Use random bytes so gzip cannot compress them -- ensures the
        # base64 payload exceeds one 1 MB chunk.
        import os as _os
        src = tmp_path / "large.bin"
        src.write_bytes(_os.urandom(1024 * 1024 + 512 * 1024))
        files = [(str(src), "/root/.hermes/large.bin")]

        exec_calls, _, stdin_mock = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        # Should have multiple stdin write chunks
        assert len(stdin_mock._written_chunks) >= 2

        # Reassembled payload should still decode to valid tar
        payload = "".join(stdin_mock._written_chunks)
        tar_data = base64.b64decode(payload)
        buf = io.BytesIO(tar_data)
        with tarfile.open(fileobj=buf, mode="r:gz") as tar:
            names = tar.getnames()
            assert "root/.hermes/large.bin" in names
|
||||||
517
tests/tools/test_ssh_bulk_upload.py
Normal file
517
tests/tools/test_ssh_bulk_upload.py
Normal file
@@ -0,0 +1,517 @@
|
|||||||
|
"""Tests for SSH bulk upload via tar pipe."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tools.environments import ssh as ssh_env
|
||||||
|
from tools.environments.file_sync import quoted_mkdir_command, unique_parent_dirs
|
||||||
|
from tools.environments.ssh import SSHEnvironment
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_proc(*, returncode=0, poll_return=0, communicate_return=(b"", b""),
|
||||||
|
stderr_read=b""):
|
||||||
|
"""Create a MagicMock mimicking subprocess.Popen for tar/ssh pipes."""
|
||||||
|
m = MagicMock()
|
||||||
|
m.stdout = MagicMock()
|
||||||
|
m.returncode = returncode
|
||||||
|
m.poll.return_value = poll_return
|
||||||
|
m.communicate.return_value = communicate_return
|
||||||
|
m.stderr = MagicMock()
|
||||||
|
m.stderr.read.return_value = stderr_read
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_env(monkeypatch):
    """Create an SSHEnvironment with mocked connection/sync.

    All network-touching setup hooks are patched to no-ops so the
    constructor completes without a real SSH connection; FileSyncManager
    is replaced with an inert stand-in.
    """
    monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
    monkeypatch.setattr(
        ssh_env, "FileSyncManager",
        lambda **kw: type("M", (), {"sync": lambda self, **k: None})(),
    )
    return SSHEnvironment(host="example.com", user="testuser")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSHBulkUpload:
    """Unit tests for _ssh_bulk_upload — tar pipe mechanics.

    subprocess.run (mkdir) and subprocess.Popen (tar | ssh pipe) are
    patched; each test inspects the captured commands or mock processes.
    """

    def test_empty_files_is_noop(self, mock_env):
        """Empty file list should not spawn any subprocesses."""
        with patch.object(subprocess, "run") as mock_run, \
                patch.object(subprocess, "Popen") as mock_popen:
            mock_env._ssh_bulk_upload([])
            mock_run.assert_not_called()
            mock_popen.assert_not_called()

    def test_mkdir_batched_into_single_call(self, mock_env, tmp_path):
        """All parent directories should be created in one SSH call."""
        # Create test files
        f1 = tmp_path / "a.txt"
        f1.write_text("aaa")
        f2 = tmp_path / "b.txt"
        f2.write_text("bbb")

        files = [
            (str(f1), "/home/testuser/.hermes/skills/a.txt"),
            (str(f2), "/home/testuser/.hermes/credentials/b.txt"),
        ]

        # Mock subprocess.run for mkdir and Popen for tar pipe
        mock_run = MagicMock(return_value=subprocess.CompletedProcess([], 0))

        def make_proc(cmd, **kwargs):
            m = MagicMock()
            m.stdout = MagicMock()
            m.returncode = 0
            m.poll.return_value = 0
            m.communicate.return_value = (b"", b"")
            m.stderr = MagicMock()
            m.stderr.read.return_value = b""
            return m

        with patch.object(subprocess, "run", mock_run), \
                patch.object(subprocess, "Popen", side_effect=make_proc):
            mock_env._ssh_bulk_upload(files)

        # Exactly one subprocess.run call for mkdir
        assert mock_run.call_count == 1
        mkdir_cmd = mock_run.call_args[0][0]
        # Should contain mkdir -p with both parent dirs
        mkdir_str = " ".join(mkdir_cmd)
        assert "mkdir -p" in mkdir_str
        assert "/home/testuser/.hermes/skills" in mkdir_str
        assert "/home/testuser/.hermes/credentials" in mkdir_str

    def test_staging_symlinks_mirror_remote_layout(self, mock_env, tmp_path):
        """Symlinks in staging dir should mirror the remote path structure."""
        f1 = tmp_path / "local_a.txt"
        f1.write_text("content a")

        files = [
            (str(f1), "/home/testuser/.hermes/skills/my_skill.md"),
        ]

        staging_paths = []

        def capture_tar_cmd(cmd, **kwargs):
            if cmd[0] == "tar":
                # Capture the staging dir from -C argument
                c_idx = cmd.index("-C")
                staging_dir = cmd[c_idx + 1]
                # Check the symlink exists
                expected = os.path.join(
                    staging_dir, "home/testuser/.hermes/skills/my_skill.md"
                )
                staging_paths.append(expected)
                assert os.path.islink(expected), f"Expected symlink at {expected}"
                assert os.readlink(expected) == os.path.abspath(str(f1))
            # Every Popen call (tar and ssh) gets a successful mock proc.
            mock = MagicMock()
            mock.stdout = MagicMock()
            mock.returncode = 0
            mock.poll.return_value = 0
            mock.communicate.return_value = (b"", b"")
            mock.stderr = MagicMock()
            mock.stderr.read.return_value = b""
            return mock

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=capture_tar_cmd):
            mock_env._ssh_bulk_upload(files)

        assert len(staging_paths) == 1, "tar command should have been called"

    def test_tar_pipe_commands(self, mock_env, tmp_path):
        """Verify tar and SSH commands are wired correctly."""
        f1 = tmp_path / "x.txt"
        f1.write_text("x")

        files = [(str(f1), "/home/testuser/.hermes/cache/x.txt")]

        popen_cmds = []

        def capture_popen(cmd, **kwargs):
            popen_cmds.append(cmd)
            mock = MagicMock()
            mock.stdout = MagicMock()
            mock.returncode = 0
            mock.poll.return_value = 0
            mock.communicate.return_value = (b"", b"")
            mock.stderr = MagicMock()
            mock.stderr.read.return_value = b""
            return mock

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=capture_popen):
            mock_env._ssh_bulk_upload(files)

        assert len(popen_cmds) == 2, "Should spawn tar + ssh processes"

        tar_cmd = popen_cmds[0]
        ssh_cmd = popen_cmds[1]

        # tar: create, dereference symlinks, to stdout
        assert tar_cmd[0] == "tar"
        assert "-chf" in tar_cmd
        assert "-" in tar_cmd  # stdout
        assert "-C" in tar_cmd

        # ssh: extract from stdin at /
        ssh_str = " ".join(ssh_cmd)
        assert "ssh" in ssh_str
        assert "tar xf - -C /" in ssh_str
        assert "testuser@example.com" in ssh_str

    def test_mkdir_failure_raises(self, mock_env, tmp_path):
        """mkdir failure should raise RuntimeError before tar pipe."""
        f1 = tmp_path / "y.txt"
        f1.write_text("y")
        files = [(str(f1), "/home/testuser/.hermes/skills/y.txt")]

        failed_run = subprocess.CompletedProcess([], 1, stderr="Permission denied")
        with patch.object(subprocess, "run", return_value=failed_run):
            with pytest.raises(RuntimeError, match="remote mkdir failed"):
                mock_env._ssh_bulk_upload(files)

    def test_tar_create_failure_raises(self, mock_env, tmp_path):
        """tar create failure should raise RuntimeError."""
        f1 = tmp_path / "z.txt"
        f1.write_text("z")
        files = [(str(f1), "/home/testuser/.hermes/skills/z.txt")]

        # Local tar exits non-zero; ssh side succeeds.
        mock_tar = MagicMock()
        mock_tar.stdout = MagicMock()
        mock_tar.returncode = 1
        mock_tar.poll.return_value = 1
        mock_tar.communicate.return_value = (b"tar: error", b"")
        mock_tar.stderr = MagicMock()
        mock_tar.stderr.read.return_value = b"tar: error"

        mock_ssh = MagicMock()
        mock_ssh.communicate.return_value = (b"", b"")
        mock_ssh.returncode = 0

        def popen_side_effect(cmd, **kwargs):
            if cmd[0] == "tar":
                return mock_tar
            return mock_ssh

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=popen_side_effect):
            with pytest.raises(RuntimeError, match="tar create failed"):
                mock_env._ssh_bulk_upload(files)

    def test_ssh_extract_failure_raises(self, mock_env, tmp_path):
        """SSH tar extract failure should raise RuntimeError."""
        f1 = tmp_path / "w.txt"
        f1.write_text("w")
        files = [(str(f1), "/home/testuser/.hermes/skills/w.txt")]

        # Local tar succeeds; remote extract over ssh exits non-zero.
        mock_tar = MagicMock()
        mock_tar.stdout = MagicMock()
        mock_tar.returncode = 0
        mock_tar.poll.return_value = 0
        mock_tar.communicate.return_value = (b"", b"")
        mock_tar.stderr = MagicMock()
        mock_tar.stderr.read.return_value = b""

        mock_ssh = MagicMock()
        mock_ssh.communicate.return_value = (b"", b"Permission denied")
        mock_ssh.returncode = 1

        def popen_side_effect(cmd, **kwargs):
            if cmd[0] == "tar":
                return mock_tar
            return mock_ssh

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=popen_side_effect):
            with pytest.raises(RuntimeError, match="tar extract over SSH failed"):
                mock_env._ssh_bulk_upload(files)

    def test_ssh_command_uses_control_socket(self, mock_env, tmp_path):
        """SSH command for tar extract should reuse ControlMaster socket."""
        f1 = tmp_path / "c.txt"
        f1.write_text("c")
        files = [(str(f1), "/home/testuser/.hermes/cache/c.txt")]

        popen_cmds = []

        def capture_popen(cmd, **kwargs):
            popen_cmds.append(cmd)
            mock = MagicMock()
            mock.stdout = MagicMock()
            mock.returncode = 0
            mock.poll.return_value = 0
            mock.communicate.return_value = (b"", b"")
            mock.stderr = MagicMock()
            mock.stderr.read.return_value = b""
            return mock

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=capture_popen):
            mock_env._ssh_bulk_upload(files)

        # The SSH command (second Popen call) should include ControlPath
        ssh_cmd = popen_cmds[1]
        assert f"ControlPath={mock_env.control_socket}" in " ".join(ssh_cmd)

    def test_custom_port_and_key_in_ssh_command(self, monkeypatch, tmp_path):
        """Bulk upload SSH command should include custom port and key."""
        # Build the env by hand (not via the fixture) to pass port/key.
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
        monkeypatch.setattr(
            ssh_env, "FileSyncManager",
            lambda **kw: type("M", (), {"sync": lambda self, **k: None})(),
        )
        env = SSHEnvironment(host="h", user="u", port=2222, key_path="/my/key")

        f1 = tmp_path / "d.txt"
        f1.write_text("d")
        files = [(str(f1), "/home/u/.hermes/skills/d.txt")]

        run_cmds = []
        popen_cmds = []

        def capture_run(cmd, **kwargs):
            run_cmds.append(cmd)
            return subprocess.CompletedProcess([], 0)

        def capture_popen(cmd, **kwargs):
            popen_cmds.append(cmd)
            mock = MagicMock()
            mock.stdout = MagicMock()
            mock.returncode = 0
            mock.poll.return_value = 0
            mock.communicate.return_value = (b"", b"")
            mock.stderr = MagicMock()
            mock.stderr.read.return_value = b""
            return mock

        with patch.object(subprocess, "run", side_effect=capture_run), \
                patch.object(subprocess, "Popen", side_effect=capture_popen):
            env._ssh_bulk_upload(files)

        # Check mkdir SSH call includes port and key
        assert len(run_cmds) == 1
        mkdir_cmd = run_cmds[0]
        assert "-p" in mkdir_cmd and "2222" in mkdir_cmd
        assert "-i" in mkdir_cmd and "/my/key" in mkdir_cmd

        # Check tar extract SSH call includes port and key
        ssh_cmd = popen_cmds[1]
        assert "-p" in ssh_cmd and "2222" in ssh_cmd
        assert "-i" in ssh_cmd and "/my/key" in ssh_cmd

    def test_parent_dirs_deduplicated(self, mock_env, tmp_path):
        """Multiple files in the same dir should produce one mkdir entry."""
        f1 = tmp_path / "a.txt"
        f1.write_text("a")
        f2 = tmp_path / "b.txt"
        f2.write_text("b")
        f3 = tmp_path / "c.txt"
        f3.write_text("c")

        files = [
            (str(f1), "/home/testuser/.hermes/skills/a.txt"),
            (str(f2), "/home/testuser/.hermes/skills/b.txt"),
            (str(f3), "/home/testuser/.hermes/credentials/c.txt"),
        ]

        run_cmds = []

        def capture_run(cmd, **kwargs):
            run_cmds.append(cmd)
            return subprocess.CompletedProcess([], 0)

        def make_mock_proc(cmd, **kwargs):
            mock = MagicMock()
            mock.stdout = MagicMock()
            mock.returncode = 0
            mock.poll.return_value = 0
            mock.communicate.return_value = (b"", b"")
            mock.stderr = MagicMock()
            mock.stderr.read.return_value = b""
            return mock

        with patch.object(subprocess, "run", side_effect=capture_run), \
                patch.object(subprocess, "Popen", side_effect=make_mock_proc):
            mock_env._ssh_bulk_upload(files)

        # Only one mkdir call
        assert len(run_cmds) == 1
        mkdir_str = " ".join(run_cmds[0])
        # skills dir should appear exactly once despite two files
        assert mkdir_str.count("/home/testuser/.hermes/skills") == 1
        assert "/home/testuser/.hermes/credentials" in mkdir_str

    def test_tar_stdout_closed_for_sigpipe(self, mock_env, tmp_path):
        """tar_proc.stdout must be closed so SIGPIPE propagates correctly."""
        f1 = tmp_path / "s.txt"
        f1.write_text("s")
        files = [(str(f1), "/home/testuser/.hermes/skills/s.txt")]

        mock_tar_stdout = MagicMock()

        def make_proc(cmd, **kwargs):
            mock = MagicMock()
            if cmd[0] == "tar":
                mock.stdout = mock_tar_stdout
            else:
                mock.stdout = MagicMock()
            mock.returncode = 0
            mock.poll.return_value = 0
            mock.communicate.return_value = (b"", b"")
            mock.stderr = MagicMock()
            mock.stderr.read.return_value = b""
            return mock

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=make_proc):
            mock_env._ssh_bulk_upload(files)

        mock_tar_stdout.close.assert_called_once()

    def test_timeout_kills_both_processes(self, mock_env, tmp_path):
        """TimeoutExpired during communicate should kill both processes."""
        f1 = tmp_path / "t.txt"
        f1.write_text("t")
        files = [(str(f1), "/home/testuser/.hermes/skills/t.txt")]

        # Both processes look still-running (returncode/poll None).
        mock_tar = MagicMock()
        mock_tar.stdout = MagicMock()
        mock_tar.returncode = None
        mock_tar.poll.return_value = None

        mock_ssh = MagicMock()
        mock_ssh.communicate.side_effect = subprocess.TimeoutExpired("ssh", 120)
        mock_ssh.returncode = None

        def make_proc(cmd, **kwargs):
            if cmd[0] == "tar":
                return mock_tar
            return mock_ssh

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=make_proc):
            with pytest.raises(RuntimeError, match="SSH bulk upload timed out"):
                mock_env._ssh_bulk_upload(files)

        mock_tar.kill.assert_called_once()
        mock_ssh.kill.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSHBulkUploadWiring:
    """Verify bulk_upload_fn is wired into FileSyncManager."""

    def test_filesyncmanager_receives_bulk_upload_fn(self, monkeypatch):
        """SSHEnvironment should pass _ssh_bulk_upload to FileSyncManager."""
        # Neutralize all network-touching constructor hooks.
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)

        captured_kwargs = {}

        class FakeSyncManager:

            def __init__(self, **kwargs):
                captured_kwargs.update(kwargs)

            def sync(self, **kw):
                pass

        monkeypatch.setattr(ssh_env, "FileSyncManager", FakeSyncManager)

        env = SSHEnvironment(host="h", user="u")

        assert "bulk_upload_fn" in captured_kwargs
        assert captured_kwargs["bulk_upload_fn"] is not None
        # Should be the bound method
        assert callable(captured_kwargs["bulk_upload_fn"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestSharedHelpers:
    """Direct unit tests for file_sync.py helpers."""

    def test_quoted_mkdir_command_basic(self):
        result = quoted_mkdir_command(["/a", "/b/c"])
        assert result == "mkdir -p /a /b/c"

    def test_quoted_mkdir_command_quotes_special_chars(self):
        result = quoted_mkdir_command(["/path/with spaces", "/path/'quotes'"])
        assert "mkdir -p" in result
        # shlex.quote wraps in single quotes
        assert "'/path/with spaces'" in result

    def test_quoted_mkdir_command_empty(self):
        # Degenerate case: no dirs still yields the command prefix.
        result = quoted_mkdir_command([])
        assert result == "mkdir -p "

    def test_unique_parent_dirs_deduplicates(self):
        files = [
            ("/local/a.txt", "/remote/dir/a.txt"),
            ("/local/b.txt", "/remote/dir/b.txt"),
            ("/local/c.txt", "/remote/other/c.txt"),
        ]
        result = unique_parent_dirs(files)
        assert result == ["/remote/dir", "/remote/other"]

    def test_unique_parent_dirs_sorted(self):
        files = [
            ("/local/z.txt", "/z/file.txt"),
            ("/local/a.txt", "/a/file.txt"),
        ]
        result = unique_parent_dirs(files)
        assert result == ["/a", "/z"]

    def test_unique_parent_dirs_empty(self):
        assert unique_parent_dirs([]) == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSHBulkUploadEdgeCases:
|
||||||
|
"""Edge cases for _ssh_bulk_upload."""
|
||||||
|
|
||||||
|
def test_ssh_popen_failure_kills_tar(self, mock_env, tmp_path):
|
||||||
|
"""If SSH Popen raises, tar process must be killed and cleaned up."""
|
||||||
|
f1 = tmp_path / "e.txt"
|
||||||
|
f1.write_text("e")
|
||||||
|
files = [(str(f1), "/home/testuser/.hermes/skills/e.txt")]
|
||||||
|
|
||||||
|
mock_tar = _mock_proc()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def failing_ssh_popen(cmd, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return mock_tar # tar Popen succeeds
|
||||||
|
raise OSError("SSH binary not found")
|
||||||
|
|
||||||
|
with patch.object(subprocess, "run",
|
||||||
|
return_value=subprocess.CompletedProcess([], 0)), \
|
||||||
|
patch.object(subprocess, "Popen", side_effect=failing_ssh_popen):
|
||||||
|
with pytest.raises(OSError, match="SSH binary not found"):
|
||||||
|
mock_env._ssh_bulk_upload(files)
|
||||||
|
|
||||||
|
mock_tar.kill.assert_called_once()
|
||||||
|
mock_tar.wait.assert_called_once()
|
||||||
@@ -15,7 +15,13 @@ from tools.environments.base import (
|
|||||||
BaseEnvironment,
|
BaseEnvironment,
|
||||||
_ThreadedProcessHandle,
|
_ThreadedProcessHandle,
|
||||||
)
|
)
|
||||||
from tools.environments.file_sync import FileSyncManager, iter_sync_files, quoted_rm_command
|
from tools.environments.file_sync import (
|
||||||
|
FileSyncManager,
|
||||||
|
iter_sync_files,
|
||||||
|
quoted_mkdir_command,
|
||||||
|
quoted_rm_command,
|
||||||
|
unique_parent_dirs,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -150,11 +156,9 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||||||
if not files:
|
if not files:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Pre-create all unique parent directories in one shell call
|
parents = unique_parent_dirs(files)
|
||||||
parents = sorted({str(Path(remote).parent) for _, remote in files})
|
|
||||||
if parents:
|
if parents:
|
||||||
mkdir_cmd = "mkdir -p " + " ".join(shlex.quote(p) for p in parents)
|
self._sandbox.process.exec(quoted_mkdir_command(parents))
|
||||||
self._sandbox.process.exec(mkdir_cmd)
|
|
||||||
|
|
||||||
uploads = [
|
uploads = [
|
||||||
FileUpload(source=host_path, destination=remote_path)
|
FileUpload(source=host_path, destination=remote_path)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import shlex
|
import shlex
|
||||||
import time
|
import time
|
||||||
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from tools.environments.base import _file_mtime_key
|
from tools.environments.base import _file_mtime_key
|
||||||
@@ -60,6 +61,16 @@ def quoted_rm_command(remote_paths: list[str]) -> str:
|
|||||||
return "rm -f " + " ".join(shlex.quote(p) for p in remote_paths)
|
return "rm -f " + " ".join(shlex.quote(p) for p in remote_paths)
|
||||||
|
|
||||||
|
|
||||||
|
def quoted_mkdir_command(dirs: list[str]) -> str:
|
||||||
|
"""Build a shell ``mkdir -p`` command for a batch of directories."""
|
||||||
|
return "mkdir -p " + " ".join(shlex.quote(d) for d in dirs)
|
||||||
|
|
||||||
|
|
||||||
|
def unique_parent_dirs(files: list[tuple[str, str]]) -> list[str]:
|
||||||
|
"""Extract sorted unique parent directories from (host, remote) pairs."""
|
||||||
|
return sorted({str(Path(remote).parent) for _, remote in files})
|
||||||
|
|
||||||
|
|
||||||
class FileSyncManager:
|
class FileSyncManager:
|
||||||
"""Tracks local file changes and syncs to a remote environment.
|
"""Tracks local file changes and syncs to a remote environment.
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,11 @@ wrapper, while preserving Hermes' persistent snapshot behavior across sessions.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
import shlex
|
import shlex
|
||||||
|
import tarfile
|
||||||
import threading
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
@@ -18,7 +21,13 @@ from tools.environments.base import (
|
|||||||
_load_json_store,
|
_load_json_store,
|
||||||
_save_json_store,
|
_save_json_store,
|
||||||
)
|
)
|
||||||
from tools.environments.file_sync import FileSyncManager, iter_sync_files, quoted_rm_command
|
from tools.environments.file_sync import (
|
||||||
|
FileSyncManager,
|
||||||
|
iter_sync_files,
|
||||||
|
quoted_mkdir_command,
|
||||||
|
quoted_rm_command,
|
||||||
|
unique_parent_dirs,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -259,26 +268,86 @@ class ModalEnvironment(BaseEnvironment):
|
|||||||
get_files_fn=lambda: iter_sync_files("/root/.hermes"),
|
get_files_fn=lambda: iter_sync_files("/root/.hermes"),
|
||||||
upload_fn=self._modal_upload,
|
upload_fn=self._modal_upload,
|
||||||
delete_fn=self._modal_delete,
|
delete_fn=self._modal_delete,
|
||||||
|
bulk_upload_fn=self._modal_bulk_upload,
|
||||||
)
|
)
|
||||||
self._sync_manager.sync(force=True)
|
self._sync_manager.sync(force=True)
|
||||||
self.init_session()
|
self.init_session()
|
||||||
|
|
||||||
def _modal_upload(self, host_path: str, remote_path: str) -> None:
|
def _modal_upload(self, host_path: str, remote_path: str) -> None:
|
||||||
"""Upload a single file via base64-over-exec."""
|
"""Upload a single file via base64 piped through stdin."""
|
||||||
import base64
|
|
||||||
content = Path(host_path).read_bytes()
|
content = Path(host_path).read_bytes()
|
||||||
b64 = base64.b64encode(content).decode("ascii")
|
b64 = base64.b64encode(content).decode("ascii")
|
||||||
container_dir = str(Path(remote_path).parent)
|
container_dir = str(Path(remote_path).parent)
|
||||||
cmd = (
|
cmd = (
|
||||||
f"mkdir -p {shlex.quote(container_dir)} && "
|
f"mkdir -p {shlex.quote(container_dir)} && "
|
||||||
f"echo {shlex.quote(b64)} | base64 -d > {shlex.quote(remote_path)}"
|
f"base64 -d > {shlex.quote(remote_path)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _write():
|
async def _write():
|
||||||
proc = await self._sandbox.exec.aio("bash", "-c", cmd)
|
proc = await self._sandbox.exec.aio("bash", "-c", cmd)
|
||||||
await proc.wait.aio()
|
offset = 0
|
||||||
|
chunk_size = self._STDIN_CHUNK_SIZE
|
||||||
|
while offset < len(b64):
|
||||||
|
proc.stdin.write(b64[offset:offset + chunk_size])
|
||||||
|
await proc.stdin.drain.aio()
|
||||||
|
offset += chunk_size
|
||||||
|
proc.stdin.write_eof()
|
||||||
|
await proc.stdin.drain.aio()
|
||||||
|
exit_code = await proc.wait.aio()
|
||||||
|
if exit_code != 0:
|
||||||
|
raise RuntimeError(f"Modal upload failed (exit {exit_code})")
|
||||||
|
|
||||||
self._worker.run_coroutine(_write(), timeout=15)
|
self._worker.run_coroutine(_write(), timeout=30)
|
||||||
|
|
||||||
|
# Modal SDK stdin buffer limit (legacy server path). The command-router
|
||||||
|
# path allows 16 MB, but we must stay under the smaller 2 MB cap for
|
||||||
|
# compatibility. Chunks are written below this threshold and flushed
|
||||||
|
# individually via drain().
|
||||||
|
_STDIN_CHUNK_SIZE = 1 * 1024 * 1024 # 1 MB — safe for both transport paths
|
||||||
|
|
||||||
|
def _modal_bulk_upload(self, files: list[tuple[str, str]]) -> None:
|
||||||
|
"""Upload many files via tar archive piped through stdin.
|
||||||
|
|
||||||
|
Builds a gzipped tar archive in memory and streams it into a
|
||||||
|
``base64 -d | tar xzf -`` pipeline via the process's stdin,
|
||||||
|
avoiding the Modal SDK's 64 KB ``ARG_MAX_BYTES`` exec-arg limit.
|
||||||
|
"""
|
||||||
|
if not files:
|
||||||
|
return
|
||||||
|
|
||||||
|
buf = io.BytesIO()
|
||||||
|
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
|
||||||
|
for host_path, remote_path in files:
|
||||||
|
tar.add(host_path, arcname=remote_path.lstrip("/"))
|
||||||
|
payload = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||||
|
|
||||||
|
parents = unique_parent_dirs(files)
|
||||||
|
mkdir_part = quoted_mkdir_command(parents)
|
||||||
|
cmd = f"{mkdir_part} && base64 -d | tar xzf - -C /"
|
||||||
|
|
||||||
|
async def _bulk():
|
||||||
|
proc = await self._sandbox.exec.aio("bash", "-c", cmd)
|
||||||
|
|
||||||
|
# Stream payload through stdin in chunks to stay under the
|
||||||
|
# SDK's per-write buffer limit (2 MB legacy / 16 MB router).
|
||||||
|
offset = 0
|
||||||
|
chunk_size = self._STDIN_CHUNK_SIZE
|
||||||
|
while offset < len(payload):
|
||||||
|
proc.stdin.write(payload[offset:offset + chunk_size])
|
||||||
|
await proc.stdin.drain.aio()
|
||||||
|
offset += chunk_size
|
||||||
|
|
||||||
|
proc.stdin.write_eof()
|
||||||
|
await proc.stdin.drain.aio()
|
||||||
|
|
||||||
|
exit_code = await proc.wait.aio()
|
||||||
|
if exit_code != 0:
|
||||||
|
stderr_text = await proc.stderr.read.aio()
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Modal bulk upload failed (exit {exit_code}): {stderr_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._worker.run_coroutine(_bulk(), timeout=120)
|
||||||
|
|
||||||
def _modal_delete(self, remote_paths: list[str]) -> None:
|
def _modal_delete(self, remote_paths: list[str]) -> None:
|
||||||
"""Batch-delete remote files via exec."""
|
"""Batch-delete remote files via exec."""
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""SSH remote execution environment with ControlMaster connection persistence."""
|
"""SSH remote execution environment with ControlMaster connection persistence."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import shlex
|
import shlex
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -8,7 +9,13 @@ import tempfile
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from tools.environments.base import BaseEnvironment, _popen_bash
|
from tools.environments.base import BaseEnvironment, _popen_bash
|
||||||
from tools.environments.file_sync import FileSyncManager, iter_sync_files, quoted_rm_command
|
from tools.environments.file_sync import (
|
||||||
|
FileSyncManager,
|
||||||
|
iter_sync_files,
|
||||||
|
quoted_mkdir_command,
|
||||||
|
quoted_rm_command,
|
||||||
|
unique_parent_dirs,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -50,6 +57,7 @@ class SSHEnvironment(BaseEnvironment):
|
|||||||
get_files_fn=lambda: iter_sync_files(f"{self._remote_home}/.hermes"),
|
get_files_fn=lambda: iter_sync_files(f"{self._remote_home}/.hermes"),
|
||||||
upload_fn=self._scp_upload,
|
upload_fn=self._scp_upload,
|
||||||
delete_fn=self._ssh_delete,
|
delete_fn=self._ssh_delete,
|
||||||
|
bulk_upload_fn=self._ssh_bulk_upload,
|
||||||
)
|
)
|
||||||
self._sync_manager.sync(force=True)
|
self._sync_manager.sync(force=True)
|
||||||
|
|
||||||
@@ -107,9 +115,8 @@ class SSHEnvironment(BaseEnvironment):
|
|||||||
"""Create base ~/.hermes directory tree on remote in one SSH call."""
|
"""Create base ~/.hermes directory tree on remote in one SSH call."""
|
||||||
base = f"{self._remote_home}/.hermes"
|
base = f"{self._remote_home}/.hermes"
|
||||||
dirs = [base, f"{base}/skills", f"{base}/credentials", f"{base}/cache"]
|
dirs = [base, f"{base}/skills", f"{base}/credentials", f"{base}/cache"]
|
||||||
mkdir_cmd = "mkdir -p " + " ".join(shlex.quote(d) for d in dirs)
|
|
||||||
cmd = self._build_ssh_command()
|
cmd = self._build_ssh_command()
|
||||||
cmd.append(mkdir_cmd)
|
cmd.append(quoted_mkdir_command(dirs))
|
||||||
subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||||
|
|
||||||
# _get_sync_files provided via iter_sync_files in FileSyncManager init
|
# _get_sync_files provided via iter_sync_files in FileSyncManager init
|
||||||
@@ -131,6 +138,86 @@ class SSHEnvironment(BaseEnvironment):
|
|||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
raise RuntimeError(f"scp failed: {result.stderr.strip()}")
|
raise RuntimeError(f"scp failed: {result.stderr.strip()}")
|
||||||
|
|
||||||
|
def _ssh_bulk_upload(self, files: list[tuple[str, str]]) -> None:
|
||||||
|
"""Upload many files in a single tar-over-SSH stream.
|
||||||
|
|
||||||
|
Pipes ``tar c`` on the local side through an SSH connection to
|
||||||
|
``tar x`` on the remote, transferring all files in one TCP stream
|
||||||
|
instead of spawning a subprocess per file. Directory creation is
|
||||||
|
batched into a single ``mkdir -p`` call beforehand.
|
||||||
|
|
||||||
|
Typical improvement: ~580 files goes from O(N) scp round-trips
|
||||||
|
to a single streaming transfer.
|
||||||
|
"""
|
||||||
|
if not files:
|
||||||
|
return
|
||||||
|
|
||||||
|
parents = unique_parent_dirs(files)
|
||||||
|
if parents:
|
||||||
|
cmd = self._build_ssh_command()
|
||||||
|
cmd.append(quoted_mkdir_command(parents))
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(f"remote mkdir failed: {result.stderr.strip()}")
|
||||||
|
|
||||||
|
# Symlink staging avoids fragile GNU tar --transform rules.
|
||||||
|
with tempfile.TemporaryDirectory(prefix="hermes-ssh-bulk-") as staging:
|
||||||
|
for host_path, remote_path in files:
|
||||||
|
staged = os.path.join(staging, remote_path.lstrip("/"))
|
||||||
|
os.makedirs(os.path.dirname(staged), exist_ok=True)
|
||||||
|
os.symlink(os.path.abspath(host_path), staged)
|
||||||
|
|
||||||
|
tar_cmd = ["tar", "-chf", "-", "-C", staging, "."]
|
||||||
|
ssh_cmd = self._build_ssh_command()
|
||||||
|
ssh_cmd.append("tar xf - -C /")
|
||||||
|
|
||||||
|
tar_proc = subprocess.Popen(
|
||||||
|
tar_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
ssh_proc = subprocess.Popen(
|
||||||
|
ssh_cmd, stdin=tar_proc.stdout, stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
tar_proc.kill()
|
||||||
|
tar_proc.wait()
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Allow tar_proc to receive SIGPIPE if ssh_proc exits early
|
||||||
|
tar_proc.stdout.close()
|
||||||
|
|
||||||
|
try:
|
||||||
|
_, ssh_stderr = ssh_proc.communicate(timeout=120)
|
||||||
|
# Use communicate() instead of wait() to drain stderr and
|
||||||
|
# avoid deadlock if tar produces more than PIPE_BUF of errors.
|
||||||
|
# stdout is already closed (for SIGPIPE); only drain stderr.
|
||||||
|
# Cannot use communicate() here — it would call fileno()
|
||||||
|
# on the closed stdout fd and raise ValueError.
|
||||||
|
tar_stderr_raw = b""
|
||||||
|
if tar_proc.poll() is None:
|
||||||
|
tar_proc.wait(timeout=10)
|
||||||
|
tar_stderr_raw = tar_proc.stderr.read() if tar_proc.stderr else b""
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
tar_proc.kill()
|
||||||
|
ssh_proc.kill()
|
||||||
|
tar_proc.wait()
|
||||||
|
ssh_proc.wait()
|
||||||
|
raise RuntimeError("SSH bulk upload timed out")
|
||||||
|
|
||||||
|
if tar_proc.returncode != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"tar create failed (rc={tar_proc.returncode}): "
|
||||||
|
f"{tar_stderr_raw.decode(errors='replace').strip()}"
|
||||||
|
)
|
||||||
|
if ssh_proc.returncode != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"tar extract over SSH failed (rc={ssh_proc.returncode}): "
|
||||||
|
f"{ssh_stderr.decode(errors='replace').strip()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("SSH: bulk-uploaded %d file(s) via tar pipe", len(files))
|
||||||
|
|
||||||
def _ssh_delete(self, remote_paths: list[str]) -> None:
|
def _ssh_delete(self, remote_paths: list[str]) -> None:
|
||||||
"""Batch-delete remote files in one SSH call."""
|
"""Batch-delete remote files in one SSH call."""
|
||||||
cmd = self._build_ssh_command()
|
cmd = self._build_ssh_command()
|
||||||
|
|||||||
Reference in New Issue
Block a user