mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-30 16:01:49 +08:00
Compare commits
2 Commits
fix/plugin
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f75e03db06 | ||
|
|
f467af93f1 |
174
tests/tools/test_base_environment.py
Normal file
174
tests/tools/test_base_environment.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
"""Tests for BaseEnvironment unified execution model.
|
||||||
|
|
||||||
|
Tests _wrap_command(), _extract_cwd_from_output(), _embed_stdin_heredoc(),
|
||||||
|
init_session() failure handling, and the CWD marker contract.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from tools.environments.base import BaseEnvironment, _cwd_marker
|
||||||
|
|
||||||
|
|
||||||
|
class _TestableEnv(BaseEnvironment):
|
||||||
|
"""Concrete subclass for testing base class methods."""
|
||||||
|
|
||||||
|
def __init__(self, cwd="/tmp", timeout=10):
|
||||||
|
super().__init__(cwd=cwd, timeout=timeout)
|
||||||
|
|
||||||
|
def _run_bash(self, cmd_string, *, login=False, timeout=120, stdin_data=None):
|
||||||
|
raise NotImplementedError("Use mock")
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestWrapCommand:
|
||||||
|
def test_basic_shape(self):
|
||||||
|
env = _TestableEnv()
|
||||||
|
env._snapshot_ready = True
|
||||||
|
wrapped = env._wrap_command("echo hello", "/tmp")
|
||||||
|
|
||||||
|
assert "source" in wrapped
|
||||||
|
assert "cd /tmp" in wrapped or "cd '/tmp'" in wrapped
|
||||||
|
assert "eval 'echo hello'" in wrapped
|
||||||
|
assert "__hermes_ec=$?" in wrapped
|
||||||
|
assert "export -p >" in wrapped
|
||||||
|
assert "pwd -P >" in wrapped
|
||||||
|
assert env._cwd_marker in wrapped
|
||||||
|
assert "exit $__hermes_ec" in wrapped
|
||||||
|
|
||||||
|
def test_no_snapshot_skips_source(self):
|
||||||
|
env = _TestableEnv()
|
||||||
|
env._snapshot_ready = False
|
||||||
|
wrapped = env._wrap_command("echo hello", "/tmp")
|
||||||
|
|
||||||
|
assert "source" not in wrapped
|
||||||
|
|
||||||
|
def test_single_quote_escaping(self):
|
||||||
|
env = _TestableEnv()
|
||||||
|
env._snapshot_ready = True
|
||||||
|
wrapped = env._wrap_command("echo 'hello world'", "/tmp")
|
||||||
|
|
||||||
|
assert "eval 'echo '\\''hello world'\\'''" in wrapped
|
||||||
|
|
||||||
|
def test_tilde_not_quoted(self):
|
||||||
|
env = _TestableEnv()
|
||||||
|
env._snapshot_ready = True
|
||||||
|
wrapped = env._wrap_command("ls", "~")
|
||||||
|
|
||||||
|
assert "cd ~" in wrapped
|
||||||
|
assert "cd '~'" not in wrapped
|
||||||
|
|
||||||
|
def test_cd_failure_exit_126(self):
|
||||||
|
env = _TestableEnv()
|
||||||
|
env._snapshot_ready = True
|
||||||
|
wrapped = env._wrap_command("ls", "/nonexistent")
|
||||||
|
|
||||||
|
assert "exit 126" in wrapped
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractCwdFromOutput:
|
||||||
|
def test_happy_path(self):
|
||||||
|
env = _TestableEnv()
|
||||||
|
marker = env._cwd_marker
|
||||||
|
result = {
|
||||||
|
"output": f"hello\n{marker}/home/user{marker}\n",
|
||||||
|
}
|
||||||
|
env._extract_cwd_from_output(result)
|
||||||
|
|
||||||
|
assert env.cwd == "/home/user"
|
||||||
|
assert marker not in result["output"]
|
||||||
|
|
||||||
|
def test_missing_marker(self):
|
||||||
|
env = _TestableEnv()
|
||||||
|
result = {"output": "hello world\n"}
|
||||||
|
env._extract_cwd_from_output(result)
|
||||||
|
|
||||||
|
assert env.cwd == "/tmp" # unchanged
|
||||||
|
|
||||||
|
def test_marker_in_command_output(self):
|
||||||
|
"""If the marker appears in command output AND as the real marker,
|
||||||
|
rfind grabs the last (real) one."""
|
||||||
|
env = _TestableEnv()
|
||||||
|
marker = env._cwd_marker
|
||||||
|
result = {
|
||||||
|
"output": f"user typed {marker} in their output\nreal output\n{marker}/correct/path{marker}\n",
|
||||||
|
}
|
||||||
|
env._extract_cwd_from_output(result)
|
||||||
|
|
||||||
|
assert env.cwd == "/correct/path"
|
||||||
|
|
||||||
|
def test_output_cleaned(self):
|
||||||
|
env = _TestableEnv()
|
||||||
|
marker = env._cwd_marker
|
||||||
|
result = {
|
||||||
|
"output": f"hello\n{marker}/tmp{marker}\n",
|
||||||
|
}
|
||||||
|
env._extract_cwd_from_output(result)
|
||||||
|
|
||||||
|
assert "hello" in result["output"]
|
||||||
|
assert marker not in result["output"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbedStdinHeredoc:
|
||||||
|
def test_heredoc_format(self):
|
||||||
|
result = BaseEnvironment._embed_stdin_heredoc("cat", "hello world")
|
||||||
|
|
||||||
|
assert result.startswith("cat << '")
|
||||||
|
assert "hello world" in result
|
||||||
|
assert "HERMES_STDIN_" in result
|
||||||
|
|
||||||
|
def test_unique_delimiter_each_call(self):
|
||||||
|
r1 = BaseEnvironment._embed_stdin_heredoc("cat", "data")
|
||||||
|
r2 = BaseEnvironment._embed_stdin_heredoc("cat", "data")
|
||||||
|
|
||||||
|
# Extract delimiters
|
||||||
|
d1 = r1.split("'")[1]
|
||||||
|
d2 = r2.split("'")[1]
|
||||||
|
assert d1 != d2 # UUID-based, should be unique
|
||||||
|
|
||||||
|
|
||||||
|
class TestInitSessionFailure:
|
||||||
|
def test_snapshot_ready_false_on_failure(self):
|
||||||
|
env = _TestableEnv()
|
||||||
|
|
||||||
|
def failing_run_bash(*args, **kwargs):
|
||||||
|
raise RuntimeError("bash not found")
|
||||||
|
|
||||||
|
env._run_bash = failing_run_bash
|
||||||
|
env.init_session()
|
||||||
|
|
||||||
|
assert env._snapshot_ready is False
|
||||||
|
|
||||||
|
def test_login_flag_when_snapshot_not_ready(self):
|
||||||
|
"""When _snapshot_ready=False, execute() should pass login=True to _run_bash."""
|
||||||
|
env = _TestableEnv()
|
||||||
|
env._snapshot_ready = False
|
||||||
|
|
||||||
|
calls = []
|
||||||
|
def mock_run_bash(cmd, *, login=False, timeout=120, stdin_data=None):
|
||||||
|
calls.append({"login": login})
|
||||||
|
# Return a mock process handle
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.poll.return_value = 0
|
||||||
|
mock.returncode = 0
|
||||||
|
mock.stdout = iter([])
|
||||||
|
return mock
|
||||||
|
|
||||||
|
env._run_bash = mock_run_bash
|
||||||
|
env.execute("echo test")
|
||||||
|
|
||||||
|
assert len(calls) == 1
|
||||||
|
assert calls[0]["login"] is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestCwdMarker:
|
||||||
|
def test_marker_contains_session_id(self):
|
||||||
|
env = _TestableEnv()
|
||||||
|
assert env._session_id in env._cwd_marker
|
||||||
|
|
||||||
|
def test_unique_per_instance(self):
|
||||||
|
env1 = _TestableEnv()
|
||||||
|
env2 = _TestableEnv()
|
||||||
|
assert env1._cwd_marker != env2._cwd_marker
|
||||||
@@ -59,8 +59,8 @@ def daytona_sdk(monkeypatch):
|
|||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def make_env(daytona_sdk, monkeypatch):
|
def make_env(daytona_sdk, monkeypatch):
|
||||||
"""Factory that creates a DaytonaEnvironment with a mocked SDK."""
|
"""Factory that creates a DaytonaEnvironment with a mocked SDK."""
|
||||||
# Prevent is_interrupted from interfering
|
# Prevent is_interrupted from interfering — patch where it's used (base.py)
|
||||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
|
monkeypatch.setattr("tools.environments.base.is_interrupted", lambda: False)
|
||||||
# Prevent skills/credential sync from consuming mock exec calls
|
# Prevent skills/credential sync from consuming mock exec calls
|
||||||
monkeypatch.setattr("tools.credential_files.get_credential_file_mounts", lambda: [])
|
monkeypatch.setattr("tools.credential_files.get_credential_file_mounts", lambda: [])
|
||||||
monkeypatch.setattr("tools.credential_files.get_skills_directory_mount", lambda **kw: None)
|
monkeypatch.setattr("tools.credential_files.get_skills_directory_mount", lambda **kw: None)
|
||||||
@@ -221,41 +221,45 @@ class TestCleanup:
|
|||||||
class TestExecute:
|
class TestExecute:
|
||||||
def test_basic_command(self, make_env):
|
def test_basic_command(self, make_env):
|
||||||
sb = _make_sandbox()
|
sb = _make_sandbox()
|
||||||
# First call: $HOME detection; subsequent calls: actual commands
|
# Calls: (1) $HOME detection, (2) init_session bootstrap, (3) actual command
|
||||||
sb.process.exec.side_effect = [
|
sb.process.exec.side_effect = [
|
||||||
_make_exec_response(result="/root"), # $HOME
|
_make_exec_response(result="/root"), # $HOME
|
||||||
|
_make_exec_response(result="", exit_code=0), # init_session
|
||||||
_make_exec_response(result="hello", exit_code=0), # actual cmd
|
_make_exec_response(result="hello", exit_code=0), # actual cmd
|
||||||
]
|
]
|
||||||
sb.state = "started"
|
sb.state = "started"
|
||||||
env = make_env(sandbox=sb)
|
env = make_env(sandbox=sb)
|
||||||
|
|
||||||
result = env.execute("echo hello")
|
result = env.execute("echo hello")
|
||||||
assert result["output"] == "hello"
|
assert "hello" in result["output"]
|
||||||
assert result["returncode"] == 0
|
assert result["returncode"] == 0
|
||||||
|
|
||||||
def test_command_wrapped_with_shell_timeout(self, make_env):
|
def test_sdk_timeout_passed_to_exec(self, make_env):
|
||||||
|
"""SDK native timeout is passed to sandbox.process.exec()."""
|
||||||
sb = _make_sandbox()
|
sb = _make_sandbox()
|
||||||
sb.process.exec.side_effect = [
|
sb.process.exec.side_effect = [
|
||||||
_make_exec_response(result="/root"),
|
_make_exec_response(result="/root"),
|
||||||
|
_make_exec_response(result="", exit_code=0), # init_session
|
||||||
_make_exec_response(result="ok", exit_code=0),
|
_make_exec_response(result="ok", exit_code=0),
|
||||||
]
|
]
|
||||||
sb.state = "started"
|
sb.state = "started"
|
||||||
env = make_env(sandbox=sb, timeout=42)
|
env = make_env(sandbox=sb, timeout=42)
|
||||||
|
|
||||||
env.execute("echo hello")
|
env.execute("echo hello")
|
||||||
# The command sent to exec should be wrapped with `timeout N sh -c '...'`
|
# The exec call should receive timeout= kwarg (SDK native timeout)
|
||||||
call_args = sb.process.exec.call_args_list[-1]
|
call_args = sb.process.exec.call_args_list[-1]
|
||||||
|
assert call_args[1]["timeout"] == 42
|
||||||
|
# The command should NOT have a shell `timeout` prefix
|
||||||
cmd = call_args[0][0]
|
cmd = call_args[0][0]
|
||||||
assert cmd.startswith("timeout 42 sh -c ")
|
assert not cmd.startswith("timeout ")
|
||||||
# SDK timeout param should NOT be passed
|
|
||||||
assert "timeout" not in call_args[1]
|
|
||||||
|
|
||||||
def test_timeout_returns_exit_code_124(self, make_env):
|
def test_timeout_returns_exit_code_124(self, make_env):
|
||||||
"""Shell timeout utility returns exit code 124."""
|
"""SDK-level timeout surfaces as exit code 124 via _wait_for_process."""
|
||||||
sb = _make_sandbox()
|
sb = _make_sandbox()
|
||||||
sb.process.exec.side_effect = [
|
sb.process.exec.side_effect = [
|
||||||
_make_exec_response(result="/root"),
|
_make_exec_response(result="/root"),
|
||||||
_make_exec_response(result="", exit_code=124),
|
_make_exec_response(result="", exit_code=0), # init_session
|
||||||
|
_make_exec_response(result="", exit_code=124), # actual cmd
|
||||||
]
|
]
|
||||||
sb.state = "started"
|
sb.state = "started"
|
||||||
env = make_env(sandbox=sb)
|
env = make_env(sandbox=sb)
|
||||||
@@ -267,6 +271,7 @@ class TestExecute:
|
|||||||
sb = _make_sandbox()
|
sb = _make_sandbox()
|
||||||
sb.process.exec.side_effect = [
|
sb.process.exec.side_effect = [
|
||||||
_make_exec_response(result="/root"),
|
_make_exec_response(result="/root"),
|
||||||
|
_make_exec_response(result="", exit_code=0), # init_session
|
||||||
_make_exec_response(result="not found", exit_code=127),
|
_make_exec_response(result="not found", exit_code=127),
|
||||||
]
|
]
|
||||||
sb.state = "started"
|
sb.state = "started"
|
||||||
@@ -279,6 +284,7 @@ class TestExecute:
|
|||||||
sb = _make_sandbox()
|
sb = _make_sandbox()
|
||||||
sb.process.exec.side_effect = [
|
sb.process.exec.side_effect = [
|
||||||
_make_exec_response(result="/root"),
|
_make_exec_response(result="/root"),
|
||||||
|
_make_exec_response(result="", exit_code=0), # init_session
|
||||||
_make_exec_response(result="ok", exit_code=0),
|
_make_exec_response(result="ok", exit_code=0),
|
||||||
]
|
]
|
||||||
sb.state = "started"
|
sb.state = "started"
|
||||||
@@ -286,39 +292,47 @@ class TestExecute:
|
|||||||
|
|
||||||
env.execute("python3", stdin_data="print('hi')")
|
env.execute("python3", stdin_data="print('hi')")
|
||||||
# Check that the command passed to exec contains heredoc markers
|
# Check that the command passed to exec contains heredoc markers
|
||||||
# (single quotes get shell-escaped by shlex.quote, so check components)
|
# Base class uses HERMES_STDIN_ prefix for heredoc delimiters
|
||||||
call_args = sb.process.exec.call_args_list[-1]
|
call_args = sb.process.exec.call_args_list[-1]
|
||||||
cmd = call_args[0][0]
|
cmd = call_args[0][0]
|
||||||
assert "HERMES_EOF_" in cmd
|
assert "HERMES_STDIN_" in cmd
|
||||||
assert "print" in cmd
|
assert "print" in cmd
|
||||||
assert "hi" in cmd
|
assert "hi" in cmd
|
||||||
|
|
||||||
def test_custom_cwd_passed_through(self, make_env):
|
def test_custom_cwd_in_command_wrapper(self, make_env):
|
||||||
|
"""CWD is handled by _wrap_command() in the command string, not as a kwarg."""
|
||||||
sb = _make_sandbox()
|
sb = _make_sandbox()
|
||||||
sb.process.exec.side_effect = [
|
sb.process.exec.side_effect = [
|
||||||
_make_exec_response(result="/root"),
|
_make_exec_response(result="/root"),
|
||||||
|
_make_exec_response(result="", exit_code=0), # init_session
|
||||||
_make_exec_response(result="/tmp", exit_code=0),
|
_make_exec_response(result="/tmp", exit_code=0),
|
||||||
]
|
]
|
||||||
sb.state = "started"
|
sb.state = "started"
|
||||||
env = make_env(sandbox=sb)
|
env = make_env(sandbox=sb)
|
||||||
|
|
||||||
env.execute("pwd", cwd="/tmp")
|
env.execute("pwd", cwd="/tmp")
|
||||||
call_kwargs = sb.process.exec.call_args_list[-1][1]
|
# CWD should be embedded in the command string via _wrap_command
|
||||||
assert call_kwargs["cwd"] == "/tmp"
|
call_args = sb.process.exec.call_args_list[-1]
|
||||||
|
cmd = call_args[0][0]
|
||||||
|
assert "cd /tmp" in cmd
|
||||||
|
# CWD should NOT be passed as a kwarg to exec
|
||||||
|
assert "cwd" not in call_args[1]
|
||||||
|
|
||||||
def test_daytona_error_triggers_retry(self, make_env, daytona_sdk):
|
def test_daytona_error_triggers_retry(self, make_env, daytona_sdk):
|
||||||
sb = _make_sandbox()
|
sb = _make_sandbox()
|
||||||
sb.state = "started"
|
sb.state = "started"
|
||||||
sb.process.exec.side_effect = [
|
sb.process.exec.side_effect = [
|
||||||
_make_exec_response(result="/root"), # $HOME
|
_make_exec_response(result="/root"), # $HOME
|
||||||
|
_make_exec_response(result="", exit_code=0), # init_session
|
||||||
daytona_sdk.DaytonaError("transient"), # first attempt fails
|
daytona_sdk.DaytonaError("transient"), # first attempt fails
|
||||||
_make_exec_response(result="ok", exit_code=0), # retry succeeds
|
_make_exec_response(result="ok", exit_code=0), # retry succeeds
|
||||||
]
|
]
|
||||||
env = make_env(sandbox=sb)
|
env = make_env(sandbox=sb)
|
||||||
|
|
||||||
result = env.execute("echo retry")
|
result = env.execute("echo retry")
|
||||||
assert result["output"] == "ok"
|
# DaytonaError now surfaces directly through _ThreadedProcessHandle
|
||||||
assert result["returncode"] == 0
|
# (no retry logic) — the error becomes returncode=1
|
||||||
|
assert result["returncode"] == 1
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -359,14 +373,18 @@ class TestInterrupt:
|
|||||||
calls["n"] += 1
|
calls["n"] += 1
|
||||||
if calls["n"] == 1:
|
if calls["n"] == 1:
|
||||||
return _make_exec_response(result="/root") # $HOME detection
|
return _make_exec_response(result="/root") # $HOME detection
|
||||||
|
if calls["n"] == 2:
|
||||||
|
return _make_exec_response(result="", exit_code=0) # init_session
|
||||||
event.wait(timeout=5) # simulate long-running command
|
event.wait(timeout=5) # simulate long-running command
|
||||||
return _make_exec_response(result="done", exit_code=0)
|
return _make_exec_response(result="done", exit_code=0)
|
||||||
|
|
||||||
sb.process.exec.side_effect = exec_side_effect
|
sb.process.exec.side_effect = exec_side_effect
|
||||||
env = make_env(sandbox=sb)
|
env = make_env(sandbox=sb)
|
||||||
|
|
||||||
|
# is_interrupted is checked by base.py's _wait_for_process,
|
||||||
|
# patch where it's actually referenced (base.py's local binding)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"tools.environments.daytona.is_interrupted", lambda: True
|
"tools.environments.base.is_interrupted", lambda: True
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
result = env.execute("sleep 10")
|
result = env.execute("sleep 10")
|
||||||
@@ -377,23 +395,24 @@ class TestInterrupt:
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Retry exhaustion
|
# DaytonaError surfaces directly (no retry)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
class TestRetryExhausted:
|
class TestRetryExhausted:
|
||||||
def test_both_attempts_fail(self, make_env, daytona_sdk):
|
def test_both_attempts_fail(self, make_env, daytona_sdk):
|
||||||
|
"""DaytonaError surfaces directly as rc=1 (retry logic was removed)."""
|
||||||
sb = _make_sandbox()
|
sb = _make_sandbox()
|
||||||
sb.state = "started"
|
sb.state = "started"
|
||||||
sb.process.exec.side_effect = [
|
sb.process.exec.side_effect = [
|
||||||
_make_exec_response(result="/root"), # $HOME
|
_make_exec_response(result="/root"), # $HOME
|
||||||
daytona_sdk.DaytonaError("fail1"), # first attempt
|
_make_exec_response(result="", exit_code=0), # init_session
|
||||||
daytona_sdk.DaytonaError("fail2"), # retry
|
daytona_sdk.DaytonaError("fail1"), # actual command fails
|
||||||
]
|
]
|
||||||
env = make_env(sandbox=sb)
|
env = make_env(sandbox=sb)
|
||||||
|
|
||||||
result = env.execute("echo x")
|
result = env.execute("echo x")
|
||||||
|
# Error surfaces directly through _ThreadedProcessHandle (rc=1)
|
||||||
assert result["returncode"] == 1
|
assert result["returncode"] == 1
|
||||||
assert "Daytona execution error" in result["output"]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -245,43 +245,42 @@ def _make_execute_only_env(forward_env=None):
|
|||||||
env._timeout_result = lambda timeout: {"output": f"timed out after {timeout}", "returncode": 124}
|
env._timeout_result = lambda timeout: {"output": f"timed out after {timeout}", "returncode": 124}
|
||||||
env._container_id = "test-container"
|
env._container_id = "test-container"
|
||||||
env._docker_exe = "/usr/bin/docker"
|
env._docker_exe = "/usr/bin/docker"
|
||||||
|
# Base class attributes needed by unified execute()
|
||||||
|
env._session_id = "test123"
|
||||||
|
env._snapshot_path = "/tmp/hermes-snap-test123.sh"
|
||||||
|
env._cwd_file = "/tmp/hermes-cwd-test123.txt"
|
||||||
|
env._cwd_marker = "__HERMES_CWD_test123__"
|
||||||
|
env._snapshot_ready = True
|
||||||
|
env._last_sync_time = None
|
||||||
|
env._init_env_args = []
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
def test_execute_uses_hermes_dotenv_for_allowlisted_env(monkeypatch):
|
def test_init_env_args_uses_hermes_dotenv_for_allowlisted_env(monkeypatch):
|
||||||
|
"""_build_init_env_args picks up forwarded env vars from .env file at init time."""
|
||||||
env = _make_execute_only_env(["GITHUB_TOKEN"])
|
env = _make_execute_only_env(["GITHUB_TOKEN"])
|
||||||
popen_calls = []
|
|
||||||
|
|
||||||
def _fake_popen(cmd, **kwargs):
|
|
||||||
popen_calls.append(cmd)
|
|
||||||
return _FakePopen(cmd, **kwargs)
|
|
||||||
|
|
||||||
monkeypatch.delenv("GITHUB_TOKEN", raising=False)
|
monkeypatch.delenv("GITHUB_TOKEN", raising=False)
|
||||||
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {"GITHUB_TOKEN": "value_from_dotenv"})
|
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {"GITHUB_TOKEN": "value_from_dotenv"})
|
||||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
|
||||||
|
|
||||||
result = env.execute("echo hi")
|
args = env._build_init_env_args()
|
||||||
|
args_str = " ".join(args)
|
||||||
|
|
||||||
assert result["returncode"] == 0
|
assert "GITHUB_TOKEN=value_from_dotenv" in args_str
|
||||||
assert "GITHUB_TOKEN=value_from_dotenv" in popen_calls[0]
|
|
||||||
|
|
||||||
|
|
||||||
def test_execute_prefers_shell_env_over_hermes_dotenv(monkeypatch):
|
def test_init_env_args_prefers_shell_env_over_hermes_dotenv(monkeypatch):
|
||||||
|
"""Shell env vars take priority over .env file values in init env args."""
|
||||||
env = _make_execute_only_env(["GITHUB_TOKEN"])
|
env = _make_execute_only_env(["GITHUB_TOKEN"])
|
||||||
popen_calls = []
|
|
||||||
|
|
||||||
def _fake_popen(cmd, **kwargs):
|
|
||||||
popen_calls.append(cmd)
|
|
||||||
return _FakePopen(cmd, **kwargs)
|
|
||||||
|
|
||||||
monkeypatch.setenv("GITHUB_TOKEN", "value_from_shell")
|
monkeypatch.setenv("GITHUB_TOKEN", "value_from_shell")
|
||||||
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {"GITHUB_TOKEN": "value_from_dotenv"})
|
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {"GITHUB_TOKEN": "value_from_dotenv"})
|
||||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
|
||||||
|
|
||||||
env.execute("echo hi")
|
args = env._build_init_env_args()
|
||||||
|
args_str = " ".join(args)
|
||||||
|
|
||||||
assert "GITHUB_TOKEN=value_from_shell" in popen_calls[0]
|
assert "GITHUB_TOKEN=value_from_shell" in args_str
|
||||||
assert "GITHUB_TOKEN=value_from_dotenv" not in popen_calls[0]
|
assert "value_from_dotenv" not in args_str
|
||||||
|
|
||||||
|
|
||||||
# ── docker_env tests ──────────────────────────────────────────────
|
# ── docker_env tests ──────────────────────────────────────────────
|
||||||
@@ -302,64 +301,46 @@ def test_docker_env_appears_in_run_command(monkeypatch):
|
|||||||
assert "GNUPGHOME=/root/.gnupg" in run_args_str
|
assert "GNUPGHOME=/root/.gnupg" in run_args_str
|
||||||
|
|
||||||
|
|
||||||
def test_docker_env_appears_in_exec_command(monkeypatch):
|
def test_docker_env_appears_in_init_env_args(monkeypatch):
|
||||||
"""Explicit docker_env values should also be passed via -e at docker exec time."""
|
"""Explicit docker_env values should appear in _build_init_env_args."""
|
||||||
env = _make_execute_only_env()
|
env = _make_execute_only_env()
|
||||||
env._env = {"MY_VAR": "my_value"}
|
env._env = {"MY_VAR": "my_value"}
|
||||||
popen_calls = []
|
|
||||||
|
|
||||||
def _fake_popen(cmd, **kwargs):
|
args = env._build_init_env_args()
|
||||||
popen_calls.append(cmd)
|
args_str = " ".join(args)
|
||||||
return _FakePopen(cmd, **kwargs)
|
|
||||||
|
|
||||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
assert "MY_VAR=my_value" in args_str
|
||||||
|
|
||||||
env.execute("echo hi")
|
|
||||||
|
|
||||||
assert popen_calls, "Popen should have been called"
|
|
||||||
assert "MY_VAR=my_value" in popen_calls[0]
|
|
||||||
|
|
||||||
|
|
||||||
def test_forward_env_overrides_docker_env(monkeypatch):
|
def test_forward_env_overrides_docker_env_in_init_args(monkeypatch):
|
||||||
"""docker_forward_env should override docker_env for the same key."""
|
"""docker_forward_env should override docker_env for the same key."""
|
||||||
env = _make_execute_only_env(forward_env=["MY_KEY"])
|
env = _make_execute_only_env(forward_env=["MY_KEY"])
|
||||||
env._env = {"MY_KEY": "static_value"}
|
env._env = {"MY_KEY": "static_value"}
|
||||||
popen_calls = []
|
|
||||||
|
|
||||||
def _fake_popen(cmd, **kwargs):
|
|
||||||
popen_calls.append(cmd)
|
|
||||||
return _FakePopen(cmd, **kwargs)
|
|
||||||
|
|
||||||
monkeypatch.setenv("MY_KEY", "dynamic_value")
|
monkeypatch.setenv("MY_KEY", "dynamic_value")
|
||||||
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {})
|
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {})
|
||||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
|
||||||
|
|
||||||
env.execute("echo hi")
|
args = env._build_init_env_args()
|
||||||
|
args_str = " ".join(args)
|
||||||
|
|
||||||
cmd_str = " ".join(popen_calls[0])
|
assert "MY_KEY=dynamic_value" in args_str
|
||||||
assert "MY_KEY=dynamic_value" in cmd_str
|
assert "MY_KEY=static_value" not in args_str
|
||||||
assert "MY_KEY=static_value" not in cmd_str
|
|
||||||
|
|
||||||
|
|
||||||
def test_docker_env_and_forward_env_merge(monkeypatch):
|
def test_docker_env_and_forward_env_merge_in_init_args(monkeypatch):
|
||||||
"""docker_env and docker_forward_env with different keys should both appear."""
|
"""docker_env and docker_forward_env with different keys should both appear."""
|
||||||
env = _make_execute_only_env(forward_env=["TOKEN"])
|
env = _make_execute_only_env(forward_env=["TOKEN"])
|
||||||
env._env = {"SSH_AUTH_SOCK": "/run/user/1000/agent.sock"}
|
env._env = {"SSH_AUTH_SOCK": "/run/user/1000/agent.sock"}
|
||||||
popen_calls = []
|
|
||||||
|
|
||||||
def _fake_popen(cmd, **kwargs):
|
|
||||||
popen_calls.append(cmd)
|
|
||||||
return _FakePopen(cmd, **kwargs)
|
|
||||||
|
|
||||||
monkeypatch.setenv("TOKEN", "secret123")
|
monkeypatch.setenv("TOKEN", "secret123")
|
||||||
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {})
|
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {})
|
||||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
|
||||||
|
|
||||||
env.execute("echo hi")
|
args = env._build_init_env_args()
|
||||||
|
args_str = " ".join(args)
|
||||||
|
|
||||||
|
assert "SSH_AUTH_SOCK=/run/user/1000/agent.sock" in args_str
|
||||||
|
assert "TOKEN=secret123" in args_str
|
||||||
|
|
||||||
cmd_str = " ".join(popen_calls[0])
|
|
||||||
assert "SSH_AUTH_SOCK=/run/user/1000/agent.sock" in cmd_str
|
|
||||||
assert "TOKEN=secret123" in cmd_str
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_env_dict_filters_invalid_keys():
|
def test_normalize_env_dict_filters_invalid_keys():
|
||||||
|
|||||||
@@ -22,21 +22,19 @@ import pytest
|
|||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
||||||
|
|
||||||
from tools.environments.local import (
|
from tools.environments.local import LocalEnvironment
|
||||||
LocalEnvironment,
|
|
||||||
_clean_shell_noise,
|
|
||||||
_extract_fenced_output,
|
|
||||||
_OUTPUT_FENCE,
|
|
||||||
_SHELL_NOISE_SUBSTRINGS,
|
|
||||||
)
|
|
||||||
from tools.file_operations import ShellFileOperations
|
from tools.file_operations import ShellFileOperations
|
||||||
|
|
||||||
|
|
||||||
# ── Shared noise detection ───────────────────────────────────────────────
|
# ── Shared noise detection ───────────────────────────────────────────────
|
||||||
# Every known shell noise pattern. If ANY of these appear in output that
|
# Known shell noise patterns that should never appear in command output.
|
||||||
# isn't explicitly expected, the test fails with a clear message.
|
|
||||||
|
|
||||||
_ALL_NOISE_PATTERNS = list(_SHELL_NOISE_SUBSTRINGS) + [
|
_ALL_NOISE_PATTERNS = [
|
||||||
|
"bash: cannot set terminal process group",
|
||||||
|
"bash: no job control in this shell",
|
||||||
|
"no job control in this shell",
|
||||||
|
"cannot set terminal process group",
|
||||||
|
"tcsetattr: Inappropriate ioctl for device",
|
||||||
"bash: ",
|
"bash: ",
|
||||||
"Inappropriate ioctl",
|
"Inappropriate ioctl",
|
||||||
"Auto-suggestions:",
|
"Auto-suggestions:",
|
||||||
@@ -88,134 +86,6 @@ def populated_dir(tmp_path):
|
|||||||
return tmp_path
|
return tmp_path
|
||||||
|
|
||||||
|
|
||||||
# ── _clean_shell_noise unit tests ────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestCleanShellNoise:
|
|
||||||
def test_single_noise_line(self):
|
|
||||||
output = "bash: no job control in this shell\nhello world\n"
|
|
||||||
result = _clean_shell_noise(output)
|
|
||||||
assert result == "hello world\n"
|
|
||||||
|
|
||||||
def test_double_noise_lines(self):
|
|
||||||
output = (
|
|
||||||
"bash: cannot set terminal process group (-1): Inappropriate ioctl for device\n"
|
|
||||||
"bash: no job control in this shell\n"
|
|
||||||
"actual output here\n"
|
|
||||||
)
|
|
||||||
result = _clean_shell_noise(output)
|
|
||||||
assert result == "actual output here\n"
|
|
||||||
_assert_clean(result)
|
|
||||||
|
|
||||||
def test_tcsetattr_noise(self):
|
|
||||||
output = (
|
|
||||||
"bash: [12345: 2 (255)] tcsetattr: Inappropriate ioctl for device\n"
|
|
||||||
"real content\n"
|
|
||||||
)
|
|
||||||
result = _clean_shell_noise(output)
|
|
||||||
assert result == "real content\n"
|
|
||||||
_assert_clean(result)
|
|
||||||
|
|
||||||
def test_triple_noise_lines(self):
|
|
||||||
output = (
|
|
||||||
"bash: cannot set terminal process group (-1): Inappropriate ioctl for device\n"
|
|
||||||
"bash: no job control in this shell\n"
|
|
||||||
"bash: [999: 2 (255)] tcsetattr: Inappropriate ioctl for device\n"
|
|
||||||
"clean\n"
|
|
||||||
)
|
|
||||||
result = _clean_shell_noise(output)
|
|
||||||
assert result == "clean\n"
|
|
||||||
|
|
||||||
def test_no_noise_untouched(self):
|
|
||||||
assert _clean_shell_noise("hello\nworld\n") == "hello\nworld\n"
|
|
||||||
|
|
||||||
def test_empty_string(self):
|
|
||||||
assert _clean_shell_noise("") == ""
|
|
||||||
|
|
||||||
def test_only_noise_produces_empty(self):
|
|
||||||
output = "bash: no job control in this shell\n"
|
|
||||||
result = _clean_shell_noise(output)
|
|
||||||
_assert_clean(result)
|
|
||||||
|
|
||||||
def test_noise_in_middle_not_stripped(self):
|
|
||||||
"""Noise in the middle is real output and should be preserved."""
|
|
||||||
output = "real\nbash: no job control in this shell\nmore real\n"
|
|
||||||
result = _clean_shell_noise(output)
|
|
||||||
assert result == output
|
|
||||||
|
|
||||||
def test_zsh_restored_session(self):
|
|
||||||
output = "Restored session: Mon Mar 2 22:16:54 +03 2026\nhello\n"
|
|
||||||
result = _clean_shell_noise(output)
|
|
||||||
assert result == "hello\n"
|
|
||||||
|
|
||||||
def test_zsh_saving_session_trailing(self):
|
|
||||||
output = "hello\nSaving session...completed.\n"
|
|
||||||
result = _clean_shell_noise(output)
|
|
||||||
assert result == "hello\n"
|
|
||||||
|
|
||||||
def test_zsh_oh_my_zsh_banner(self):
|
|
||||||
output = "Oh My Zsh on! | Auto-suggestions: press right\nhello\n"
|
|
||||||
result = _clean_shell_noise(output)
|
|
||||||
assert result == "hello\n"
|
|
||||||
|
|
||||||
def test_zsh_full_noise_sandwich(self):
|
|
||||||
"""Both leading and trailing zsh noise stripped."""
|
|
||||||
output = (
|
|
||||||
"Restored session: Mon Mar 2\n"
|
|
||||||
"command not found: docker\n"
|
|
||||||
"Oh My Zsh on!\n"
|
|
||||||
"actual output\n"
|
|
||||||
"Saving session...completed.\n"
|
|
||||||
)
|
|
||||||
result = _clean_shell_noise(output)
|
|
||||||
assert result == "actual output\n"
|
|
||||||
|
|
||||||
def test_last_login_stripped(self):
|
|
||||||
output = "Last login: Mon Mar 2 22:00:00 on ttys001\nhello\n"
|
|
||||||
result = _clean_shell_noise(output)
|
|
||||||
assert result == "hello\n"
|
|
||||||
|
|
||||||
|
|
||||||
# ── _extract_fenced_output unit tests ────────────────────────────────────
|
|
||||||
|
|
||||||
class TestExtractFencedOutput:
|
|
||||||
def test_normal_fenced_output(self):
|
|
||||||
raw = f"noise\n{_OUTPUT_FENCE}hello world\n{_OUTPUT_FENCE}more noise\n"
|
|
||||||
assert _extract_fenced_output(raw) == "hello world\n"
|
|
||||||
|
|
||||||
def test_no_trailing_newline(self):
|
|
||||||
"""printf output with no trailing newline is preserved."""
|
|
||||||
raw = f"noise{_OUTPUT_FENCE}exact{_OUTPUT_FENCE}noise"
|
|
||||||
assert _extract_fenced_output(raw) == "exact"
|
|
||||||
|
|
||||||
def test_no_fences_falls_back(self):
|
|
||||||
"""Without fences, falls back to pattern-based cleaning."""
|
|
||||||
raw = "bash: no job control in this shell\nhello\n"
|
|
||||||
result = _extract_fenced_output(raw)
|
|
||||||
assert result == "hello\n"
|
|
||||||
|
|
||||||
def test_only_start_fence(self):
|
|
||||||
"""Only start fence (e.g. user command called exit)."""
|
|
||||||
raw = f"noise{_OUTPUT_FENCE}hello\nSaving session...\n"
|
|
||||||
result = _extract_fenced_output(raw)
|
|
||||||
assert result == "hello\n"
|
|
||||||
|
|
||||||
def test_user_outputs_fence_string(self):
|
|
||||||
"""If user command outputs the fence marker, it is preserved."""
|
|
||||||
raw = f"noise{_OUTPUT_FENCE}{_OUTPUT_FENCE}real\n{_OUTPUT_FENCE}noise"
|
|
||||||
result = _extract_fenced_output(raw)
|
|
||||||
# first fence -> last fence captures the middle including user's fence
|
|
||||||
assert _OUTPUT_FENCE in result
|
|
||||||
assert "real\n" in result
|
|
||||||
|
|
||||||
def test_empty_command_output(self):
|
|
||||||
raw = f"noise{_OUTPUT_FENCE}{_OUTPUT_FENCE}noise"
|
|
||||||
assert _extract_fenced_output(raw) == ""
|
|
||||||
|
|
||||||
def test_multiline_output(self):
|
|
||||||
raw = f"noise\n{_OUTPUT_FENCE}line1\nline2\nline3\n{_OUTPUT_FENCE}noise\n"
|
|
||||||
assert _extract_fenced_output(raw) == "line1\nline2\nline3\n"
|
|
||||||
|
|
||||||
|
|
||||||
# ── LocalEnvironment.execute() ───────────────────────────────────────────
|
# ── LocalEnvironment.execute() ───────────────────────────────────────────
|
||||||
|
|
||||||
class TestLocalEnvironmentExecute:
|
class TestLocalEnvironmentExecute:
|
||||||
|
|||||||
@@ -1,164 +0,0 @@
|
|||||||
"""Tests for the local persistent shell backend."""
|
|
||||||
|
|
||||||
import glob as glob_mod
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from tools.environments.local import LocalEnvironment
|
|
||||||
from tools.environments.persistent_shell import PersistentShellMixin
|
|
||||||
|
|
||||||
|
|
||||||
class TestLocalConfig:
|
|
||||||
def test_local_persistent_default_false(self, monkeypatch):
|
|
||||||
monkeypatch.delenv("TERMINAL_LOCAL_PERSISTENT", raising=False)
|
|
||||||
from tools.terminal_tool import _get_env_config
|
|
||||||
assert _get_env_config()["local_persistent"] is False
|
|
||||||
|
|
||||||
def test_local_persistent_true(self, monkeypatch):
|
|
||||||
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "true")
|
|
||||||
from tools.terminal_tool import _get_env_config
|
|
||||||
assert _get_env_config()["local_persistent"] is True
|
|
||||||
|
|
||||||
def test_local_persistent_yes(self, monkeypatch):
|
|
||||||
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "yes")
|
|
||||||
from tools.terminal_tool import _get_env_config
|
|
||||||
assert _get_env_config()["local_persistent"] is True
|
|
||||||
|
|
||||||
|
|
||||||
class TestMergeOutput:
|
|
||||||
def test_stdout_only(self):
|
|
||||||
assert PersistentShellMixin._merge_output("out", "") == "out"
|
|
||||||
|
|
||||||
def test_stderr_only(self):
|
|
||||||
assert PersistentShellMixin._merge_output("", "err") == "err"
|
|
||||||
|
|
||||||
def test_both(self):
|
|
||||||
assert PersistentShellMixin._merge_output("out", "err") == "out\nerr"
|
|
||||||
|
|
||||||
def test_empty(self):
|
|
||||||
assert PersistentShellMixin._merge_output("", "") == ""
|
|
||||||
|
|
||||||
def test_strips_trailing_newlines(self):
|
|
||||||
assert PersistentShellMixin._merge_output("out\n\n", "err\n") == "out\nerr"
|
|
||||||
|
|
||||||
|
|
||||||
class TestLocalOneShotRegression:
|
|
||||||
def test_echo(self):
|
|
||||||
env = LocalEnvironment(persistent=False)
|
|
||||||
r = env.execute("echo hello")
|
|
||||||
assert r["returncode"] == 0
|
|
||||||
assert "hello" in r["output"]
|
|
||||||
env.cleanup()
|
|
||||||
|
|
||||||
def test_exit_code(self):
|
|
||||||
env = LocalEnvironment(persistent=False)
|
|
||||||
r = env.execute("exit 42")
|
|
||||||
assert r["returncode"] == 42
|
|
||||||
env.cleanup()
|
|
||||||
|
|
||||||
def test_state_does_not_persist(self):
|
|
||||||
env = LocalEnvironment(persistent=False)
|
|
||||||
env.execute("export HERMES_ONESHOT_LOCAL=yes")
|
|
||||||
r = env.execute("echo $HERMES_ONESHOT_LOCAL")
|
|
||||||
assert r["output"].strip() == ""
|
|
||||||
env.cleanup()
|
|
||||||
|
|
||||||
def test_oneshot_heredoc_does_not_leak_fence_wrapper(self):
|
|
||||||
"""Heredoc closing line must not be merged with the fence wrapper tail."""
|
|
||||||
env = LocalEnvironment(persistent=False)
|
|
||||||
cmd = "cat <<'H_EOF'\nheredoc body line\nH_EOF"
|
|
||||||
r = env.execute(cmd)
|
|
||||||
env.cleanup()
|
|
||||||
assert r["returncode"] == 0
|
|
||||||
assert "heredoc body line" in r["output"]
|
|
||||||
assert "__hermes_rc" not in r["output"]
|
|
||||||
assert "printf '" not in r["output"]
|
|
||||||
assert "exit $" not in r["output"]
|
|
||||||
|
|
||||||
|
|
||||||
class TestLocalPersistent:
|
|
||||||
@pytest.fixture
|
|
||||||
def env(self):
|
|
||||||
e = LocalEnvironment(persistent=True)
|
|
||||||
yield e
|
|
||||||
e.cleanup()
|
|
||||||
|
|
||||||
def test_echo(self, env):
|
|
||||||
r = env.execute("echo hello-persistent")
|
|
||||||
assert r["returncode"] == 0
|
|
||||||
assert "hello-persistent" in r["output"]
|
|
||||||
|
|
||||||
def test_env_var_persists(self, env):
|
|
||||||
env.execute("export HERMES_LOCAL_PERSIST_TEST=works")
|
|
||||||
r = env.execute("echo $HERMES_LOCAL_PERSIST_TEST")
|
|
||||||
assert r["output"].strip() == "works"
|
|
||||||
|
|
||||||
def test_cwd_persists(self, env):
|
|
||||||
env.execute("cd /tmp")
|
|
||||||
r = env.execute("pwd")
|
|
||||||
assert r["output"].strip() == "/tmp"
|
|
||||||
|
|
||||||
def test_exit_code(self, env):
|
|
||||||
r = env.execute("(exit 42)")
|
|
||||||
assert r["returncode"] == 42
|
|
||||||
|
|
||||||
def test_stderr(self, env):
|
|
||||||
r = env.execute("echo oops >&2")
|
|
||||||
assert r["returncode"] == 0
|
|
||||||
assert "oops" in r["output"]
|
|
||||||
|
|
||||||
def test_multiline_output(self, env):
|
|
||||||
r = env.execute("echo a; echo b; echo c")
|
|
||||||
lines = r["output"].strip().splitlines()
|
|
||||||
assert lines == ["a", "b", "c"]
|
|
||||||
|
|
||||||
def test_timeout_then_recovery(self, env):
|
|
||||||
r = env.execute("sleep 999", timeout=2)
|
|
||||||
assert r["returncode"] in (124, 130)
|
|
||||||
r = env.execute("echo alive")
|
|
||||||
assert r["returncode"] == 0
|
|
||||||
assert "alive" in r["output"]
|
|
||||||
|
|
||||||
def test_large_output(self, env):
|
|
||||||
r = env.execute("seq 1 1000")
|
|
||||||
assert r["returncode"] == 0
|
|
||||||
lines = r["output"].strip().splitlines()
|
|
||||||
assert len(lines) == 1000
|
|
||||||
assert lines[0] == "1"
|
|
||||||
assert lines[-1] == "1000"
|
|
||||||
|
|
||||||
def test_shell_variable_persists(self, env):
|
|
||||||
env.execute("MY_LOCAL_VAR=hello123")
|
|
||||||
r = env.execute("echo $MY_LOCAL_VAR")
|
|
||||||
assert r["output"].strip() == "hello123"
|
|
||||||
|
|
||||||
def test_cleanup_removes_temp_files(self, env):
|
|
||||||
env.execute("echo warmup")
|
|
||||||
prefix = env._temp_prefix
|
|
||||||
assert len(glob_mod.glob(f"{prefix}-*")) > 0
|
|
||||||
env.cleanup()
|
|
||||||
remaining = glob_mod.glob(f"{prefix}-*")
|
|
||||||
assert remaining == []
|
|
||||||
|
|
||||||
def test_state_does_not_leak_between_instances(self):
|
|
||||||
env1 = LocalEnvironment(persistent=True)
|
|
||||||
env2 = LocalEnvironment(persistent=True)
|
|
||||||
try:
|
|
||||||
env1.execute("export LEAK_TEST=from_env1")
|
|
||||||
r = env2.execute("echo $LEAK_TEST")
|
|
||||||
assert r["output"].strip() == ""
|
|
||||||
finally:
|
|
||||||
env1.cleanup()
|
|
||||||
env2.cleanup()
|
|
||||||
|
|
||||||
def test_special_characters_in_command(self, env):
|
|
||||||
r = env.execute("echo 'hello world'")
|
|
||||||
assert r["output"].strip() == "hello world"
|
|
||||||
|
|
||||||
def test_pipe_command(self, env):
|
|
||||||
r = env.execute("echo hello | tr 'h' 'H'")
|
|
||||||
assert r["output"].strip() == "Hello"
|
|
||||||
|
|
||||||
def test_multiple_commands_semicolon(self, env):
|
|
||||||
r = env.execute("X=42; echo $X")
|
|
||||||
assert r["output"].strip() == "42"
|
|
||||||
@@ -110,7 +110,7 @@ class _FakeResponse:
|
|||||||
def test_managed_modal_execute_polls_until_completed(monkeypatch):
|
def test_managed_modal_execute_polls_until_completed(monkeypatch):
|
||||||
_install_fake_tools_package()
|
_install_fake_tools_package()
|
||||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||||
modal_common = sys.modules["tools.environments.modal_common"]
|
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||||
|
|
||||||
calls = []
|
calls = []
|
||||||
poll_count = {"value": 0}
|
poll_count = {"value": 0}
|
||||||
@@ -173,7 +173,7 @@ def test_managed_modal_create_sends_a_stable_idempotency_key(monkeypatch):
|
|||||||
def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
|
def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
|
||||||
interrupt_event = _install_fake_tools_package()
|
interrupt_event = _install_fake_tools_package()
|
||||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||||
modal_common = sys.modules["tools.environments.modal_common"]
|
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||||
|
|
||||||
calls = []
|
calls = []
|
||||||
|
|
||||||
@@ -215,7 +215,7 @@ def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
|
|||||||
def test_managed_modal_execute_returns_descriptive_error_on_missing_exec(monkeypatch):
|
def test_managed_modal_execute_returns_descriptive_error_on_missing_exec(monkeypatch):
|
||||||
_install_fake_tools_package()
|
_install_fake_tools_package()
|
||||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||||
modal_common = sys.modules["tools.environments.modal_common"]
|
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||||
|
|
||||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||||
if method == "POST" and url.endswith("/v1/sandboxes"):
|
if method == "POST" and url.endswith("/v1/sandboxes"):
|
||||||
@@ -293,7 +293,7 @@ def test_managed_modal_rejects_host_credential_passthrough():
|
|||||||
def test_managed_modal_execute_times_out_and_cancels(monkeypatch):
|
def test_managed_modal_execute_times_out_and_cancels(monkeypatch):
|
||||||
_install_fake_tools_package()
|
_install_fake_tools_package()
|
||||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||||
modal_common = sys.modules["tools.environments.modal_common"]
|
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||||
|
|
||||||
calls = []
|
calls = []
|
||||||
monotonic_values = iter([0.0, 12.5])
|
monotonic_values = iter([0.0, 12.5])
|
||||||
|
|||||||
@@ -231,20 +231,20 @@ class TestEnsurepipFix:
|
|||||||
"""Verify the pip fix is applied in the ModalEnvironment init."""
|
"""Verify the pip fix is applied in the ModalEnvironment init."""
|
||||||
|
|
||||||
def test_modal_environment_creates_image_with_setup_commands(self):
|
def test_modal_environment_creates_image_with_setup_commands(self):
|
||||||
"""ModalEnvironment.__init__ should create a modal.Image with pip fix."""
|
"""_resolve_modal_image should create a modal.Image with pip fix."""
|
||||||
try:
|
try:
|
||||||
from tools.environments.modal import ModalEnvironment
|
from tools.environments.modal import _resolve_modal_image
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pytest.skip("tools.environments.modal not importable")
|
pytest.skip("tools.environments.modal not importable")
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
source = inspect.getsource(ModalEnvironment.__init__)
|
source = inspect.getsource(_resolve_modal_image)
|
||||||
assert "ensurepip" in source, (
|
assert "ensurepip" in source, (
|
||||||
"ModalEnvironment should include ensurepip fix "
|
"_resolve_modal_image should include ensurepip fix "
|
||||||
"for Modal's legacy image builder"
|
"for Modal's legacy image builder"
|
||||||
)
|
)
|
||||||
assert "setup_dockerfile_commands" in source, (
|
assert "setup_dockerfile_commands" in source, (
|
||||||
"ModalEnvironment should use setup_dockerfile_commands "
|
"_resolve_modal_image should use setup_dockerfile_commands "
|
||||||
"to fix pip before Modal's bootstrap"
|
"to fix pip before Modal's bootstrap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -85,11 +85,47 @@ def _install_modal_test_modules(
|
|||||||
def _prepare_command(self, command: str):
|
def _prepare_command(self, command: str):
|
||||||
return command, None
|
return command, None
|
||||||
|
|
||||||
sys.modules["tools.environments.base"] = types.SimpleNamespace(BaseEnvironment=_DummyBaseEnvironment)
|
def init_session(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Stub _ThreadedProcessHandle: modal.py imports it but only uses it at
|
||||||
|
# runtime inside _run_bash; the snapshot-isolation tests never call _run_bash,
|
||||||
|
# so a class placeholder is sufficient.
|
||||||
|
class _DummyThreadedProcessHandle:
|
||||||
|
def __init__(self, exec_fn, cancel_fn=None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _load_json_store(path):
|
||||||
|
if path.exists():
|
||||||
|
try:
|
||||||
|
return json.loads(path.read_text())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _save_json_store(path, data):
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path.write_text(json.dumps(data, indent=2))
|
||||||
|
|
||||||
|
def _file_mtime_key(host_path):
|
||||||
|
try:
|
||||||
|
st = Path(host_path).stat()
|
||||||
|
return (st.st_mtime, st.st_size)
|
||||||
|
except OSError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
sys.modules["tools.environments.base"] = types.SimpleNamespace(
|
||||||
|
BaseEnvironment=_DummyBaseEnvironment,
|
||||||
|
_ThreadedProcessHandle=_DummyThreadedProcessHandle,
|
||||||
|
_load_json_store=_load_json_store,
|
||||||
|
_save_json_store=_save_json_store,
|
||||||
|
_file_mtime_key=_file_mtime_key,
|
||||||
|
)
|
||||||
sys.modules["tools.interrupt"] = types.SimpleNamespace(is_interrupted=lambda: False)
|
sys.modules["tools.interrupt"] = types.SimpleNamespace(is_interrupted=lambda: False)
|
||||||
sys.modules["tools.credential_files"] = types.SimpleNamespace(
|
sys.modules["tools.credential_files"] = types.SimpleNamespace(
|
||||||
get_credential_file_mounts=lambda: [],
|
get_credential_file_mounts=lambda: [],
|
||||||
iter_skills_files=lambda: [],
|
iter_skills_files=lambda: [],
|
||||||
|
iter_cache_files=lambda: [],
|
||||||
)
|
)
|
||||||
|
|
||||||
from_id_calls: list[str] = []
|
from_id_calls: list[str] = []
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class TestBuildSSHCommand:
|
|||||||
lambda *a, **k: MagicMock(stdout=iter([]),
|
lambda *a, **k: MagicMock(stdout=iter([]),
|
||||||
stderr=iter([]),
|
stderr=iter([]),
|
||||||
stdin=MagicMock()))
|
stdin=MagicMock()))
|
||||||
monkeypatch.setattr("tools.environments.ssh.time.sleep", lambda _: None)
|
monkeypatch.setattr("tools.environments.base.time.sleep", lambda _: None)
|
||||||
|
|
||||||
def test_base_flags(self):
|
def test_base_flags(self):
|
||||||
env = SSHEnvironment(host="h", user="u")
|
env = SSHEnvironment(host="h", user="u")
|
||||||
|
|||||||
144
tests/tools/test_threaded_process_handle.py
Normal file
144
tests/tools/test_threaded_process_handle.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""Tests for _ThreadedProcessHandle — the adapter for SDK backends."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
from tools.environments.base import _ThreadedProcessHandle
|
||||||
|
|
||||||
|
|
||||||
|
class TestBasicExecution:
|
||||||
|
def test_successful_execution(self):
|
||||||
|
def exec_fn():
|
||||||
|
return ("hello world", 0)
|
||||||
|
|
||||||
|
handle = _ThreadedProcessHandle(exec_fn)
|
||||||
|
handle.wait(timeout=5)
|
||||||
|
|
||||||
|
assert handle.returncode == 0
|
||||||
|
output = handle.stdout.read()
|
||||||
|
assert "hello world" in output
|
||||||
|
|
||||||
|
def test_nonzero_exit_code(self):
|
||||||
|
def exec_fn():
|
||||||
|
return ("error occurred", 42)
|
||||||
|
|
||||||
|
handle = _ThreadedProcessHandle(exec_fn)
|
||||||
|
handle.wait(timeout=5)
|
||||||
|
|
||||||
|
assert handle.returncode == 42
|
||||||
|
output = handle.stdout.read()
|
||||||
|
assert "error occurred" in output
|
||||||
|
|
||||||
|
def test_exception_in_exec_fn(self):
|
||||||
|
def exec_fn():
|
||||||
|
raise RuntimeError("boom")
|
||||||
|
|
||||||
|
handle = _ThreadedProcessHandle(exec_fn)
|
||||||
|
handle.wait(timeout=5)
|
||||||
|
|
||||||
|
assert handle.returncode == 1
|
||||||
|
|
||||||
|
def test_empty_output(self):
|
||||||
|
def exec_fn():
|
||||||
|
return ("", 0)
|
||||||
|
|
||||||
|
handle = _ThreadedProcessHandle(exec_fn)
|
||||||
|
handle.wait(timeout=5)
|
||||||
|
|
||||||
|
assert handle.returncode == 0
|
||||||
|
output = handle.stdout.read()
|
||||||
|
assert output == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestPolling:
|
||||||
|
def test_poll_returns_none_while_running(self):
|
||||||
|
event = threading.Event()
|
||||||
|
|
||||||
|
def exec_fn():
|
||||||
|
event.wait(timeout=5)
|
||||||
|
return ("done", 0)
|
||||||
|
|
||||||
|
handle = _ThreadedProcessHandle(exec_fn)
|
||||||
|
assert handle.poll() is None
|
||||||
|
|
||||||
|
event.set()
|
||||||
|
handle.wait(timeout=5)
|
||||||
|
assert handle.poll() == 0
|
||||||
|
|
||||||
|
def test_poll_returns_returncode_when_done(self):
|
||||||
|
def exec_fn():
|
||||||
|
return ("ok", 0)
|
||||||
|
|
||||||
|
handle = _ThreadedProcessHandle(exec_fn)
|
||||||
|
handle.wait(timeout=5)
|
||||||
|
assert handle.poll() == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCancelFn:
|
||||||
|
def test_cancel_fn_called_on_kill(self):
|
||||||
|
called = threading.Event()
|
||||||
|
|
||||||
|
def cancel():
|
||||||
|
called.set()
|
||||||
|
|
||||||
|
def exec_fn():
|
||||||
|
time.sleep(10)
|
||||||
|
return ("", 0)
|
||||||
|
|
||||||
|
handle = _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||||
|
handle.kill()
|
||||||
|
assert called.is_set()
|
||||||
|
|
||||||
|
def test_cancel_fn_none_is_safe(self):
|
||||||
|
def exec_fn():
|
||||||
|
return ("ok", 0)
|
||||||
|
|
||||||
|
handle = _ThreadedProcessHandle(exec_fn, cancel_fn=None)
|
||||||
|
handle.kill() # should not raise
|
||||||
|
handle.wait(timeout=5)
|
||||||
|
assert handle.returncode == 0
|
||||||
|
|
||||||
|
def test_cancel_fn_exception_swallowed(self):
|
||||||
|
def cancel():
|
||||||
|
raise RuntimeError("cancel failed")
|
||||||
|
|
||||||
|
def exec_fn():
|
||||||
|
return ("ok", 0)
|
||||||
|
|
||||||
|
handle = _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||||
|
handle.kill() # should not raise despite cancel raising
|
||||||
|
handle.wait(timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStdoutPipe:
|
||||||
|
def test_stdout_is_readable(self):
|
||||||
|
def exec_fn():
|
||||||
|
return ("line1\nline2\nline3\n", 0)
|
||||||
|
|
||||||
|
handle = _ThreadedProcessHandle(exec_fn)
|
||||||
|
handle.wait(timeout=5)
|
||||||
|
|
||||||
|
lines = handle.stdout.readlines()
|
||||||
|
assert len(lines) == 3
|
||||||
|
assert lines[0] == "line1\n"
|
||||||
|
|
||||||
|
def test_stdout_iterable(self):
|
||||||
|
def exec_fn():
|
||||||
|
return ("a\nb\nc\n", 0)
|
||||||
|
|
||||||
|
handle = _ThreadedProcessHandle(exec_fn)
|
||||||
|
handle.wait(timeout=5)
|
||||||
|
|
||||||
|
collected = list(handle.stdout)
|
||||||
|
assert len(collected) == 3
|
||||||
|
|
||||||
|
def test_unicode_output(self):
|
||||||
|
def exec_fn():
|
||||||
|
return ("hello 世界 🌍\n", 0)
|
||||||
|
|
||||||
|
handle = _ThreadedProcessHandle(exec_fn)
|
||||||
|
handle.wait(timeout=5)
|
||||||
|
|
||||||
|
output = handle.stdout.read()
|
||||||
|
assert "世界" in output
|
||||||
|
assert "🌍" in output
|
||||||
@@ -18,7 +18,7 @@ Architecture (two transports):
|
|||||||
2. Parent ships both files to the remote environment
|
2. Parent ships both files to the remote environment
|
||||||
3. Script runs inside the terminal backend (Docker/SSH/Modal/Daytona/etc.)
|
3. Script runs inside the terminal backend (Docker/SSH/Modal/Daytona/etc.)
|
||||||
4. Tool calls are written as request files; a polling thread on the parent
|
4. Tool calls are written as request files; a polling thread on the parent
|
||||||
reads them via execute_oneshot(), dispatches, and writes response files
|
reads them via env.execute(), dispatches, and writes response files
|
||||||
5. The script polls for response files and continues
|
5. The script polls for response files and continues
|
||||||
|
|
||||||
In both cases, only the script's stdout is returned to the LLM; intermediate
|
In both cases, only the script's stdout is returned to the LLM; intermediate
|
||||||
@@ -536,7 +536,7 @@ def _ship_file_to_remote(env, remote_path: str, content: str) -> None:
|
|||||||
quotes are fine.
|
quotes are fine.
|
||||||
"""
|
"""
|
||||||
encoded = base64.b64encode(content.encode("utf-8")).decode("ascii")
|
encoded = base64.b64encode(content.encode("utf-8")).decode("ascii")
|
||||||
env.execute_oneshot(
|
env.execute(
|
||||||
f"echo '{encoded}' | base64 -d > {remote_path}",
|
f"echo '{encoded}' | base64 -d > {remote_path}",
|
||||||
cwd="/",
|
cwd="/",
|
||||||
timeout=30,
|
timeout=30,
|
||||||
@@ -555,9 +555,9 @@ def _rpc_poll_loop(
|
|||||||
):
|
):
|
||||||
"""Poll the remote filesystem for tool call requests and dispatch them.
|
"""Poll the remote filesystem for tool call requests and dispatch them.
|
||||||
|
|
||||||
Runs in a background thread. Uses ``env.execute_oneshot()`` so it can
|
Runs in a background thread. Each ``env.execute()`` spawns an
|
||||||
operate concurrently with the script-execution thread that holds
|
independent process, so these calls run safely concurrent with the
|
||||||
``env.execute()`` (important for persistent-shell backends like SSH).
|
script-execution thread.
|
||||||
"""
|
"""
|
||||||
from model_tools import handle_function_call
|
from model_tools import handle_function_call
|
||||||
|
|
||||||
@@ -566,7 +566,7 @@ def _rpc_poll_loop(
|
|||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
# List pending request files (skip .tmp partials)
|
# List pending request files (skip .tmp partials)
|
||||||
ls_result = env.execute_oneshot(
|
ls_result = env.execute(
|
||||||
f"ls -1 {rpc_dir}/req_* 2>/dev/null || true",
|
f"ls -1 {rpc_dir}/req_* 2>/dev/null || true",
|
||||||
cwd="/",
|
cwd="/",
|
||||||
timeout=10,
|
timeout=10,
|
||||||
@@ -590,7 +590,7 @@ def _rpc_poll_loop(
|
|||||||
call_start = time.monotonic()
|
call_start = time.monotonic()
|
||||||
|
|
||||||
# Read request
|
# Read request
|
||||||
read_result = env.execute_oneshot(
|
read_result = env.execute(
|
||||||
f"cat {req_file}",
|
f"cat {req_file}",
|
||||||
cwd="/",
|
cwd="/",
|
||||||
timeout=10,
|
timeout=10,
|
||||||
@@ -600,7 +600,7 @@ def _rpc_poll_loop(
|
|||||||
except (json.JSONDecodeError, ValueError):
|
except (json.JSONDecodeError, ValueError):
|
||||||
logger.debug("Malformed RPC request in %s", req_file)
|
logger.debug("Malformed RPC request in %s", req_file)
|
||||||
# Remove bad request to avoid infinite retry
|
# Remove bad request to avoid infinite retry
|
||||||
env.execute_oneshot(f"rm -f {req_file}", cwd="/", timeout=5)
|
env.execute(f"rm -f {req_file}", cwd="/", timeout=5)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tool_name = request.get("tool", "")
|
tool_name = request.get("tool", "")
|
||||||
@@ -664,7 +664,7 @@ def _rpc_poll_loop(
|
|||||||
encoded_result = base64.b64encode(
|
encoded_result = base64.b64encode(
|
||||||
tool_result.encode("utf-8")
|
tool_result.encode("utf-8")
|
||||||
).decode("ascii")
|
).decode("ascii")
|
||||||
env.execute_oneshot(
|
env.execute(
|
||||||
f"echo '{encoded_result}' | base64 -d > {res_file}.tmp"
|
f"echo '{encoded_result}' | base64 -d > {res_file}.tmp"
|
||||||
f" && mv {res_file}.tmp {res_file}",
|
f" && mv {res_file}.tmp {res_file}",
|
||||||
cwd="/",
|
cwd="/",
|
||||||
@@ -672,7 +672,7 @@ def _rpc_poll_loop(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Remove the request file
|
# Remove the request file
|
||||||
env.execute_oneshot(f"rm -f {req_file}", cwd="/", timeout=5)
|
env.execute(f"rm -f {req_file}", cwd="/", timeout=5)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if not stop_event.is_set():
|
if not stop_event.is_set():
|
||||||
@@ -717,7 +717,7 @@ def _execute_remote(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Verify Python is available on the remote
|
# Verify Python is available on the remote
|
||||||
py_check = env.execute_oneshot(
|
py_check = env.execute(
|
||||||
"command -v python3 >/dev/null 2>&1 && echo OK",
|
"command -v python3 >/dev/null 2>&1 && echo OK",
|
||||||
cwd="/", timeout=15,
|
cwd="/", timeout=15,
|
||||||
)
|
)
|
||||||
@@ -734,7 +734,7 @@ def _execute_remote(
|
|||||||
})
|
})
|
||||||
|
|
||||||
# Create sandbox directory on remote
|
# Create sandbox directory on remote
|
||||||
env.execute_oneshot(
|
env.execute(
|
||||||
f"mkdir -p {sandbox_dir}/rpc", cwd="/", timeout=10,
|
f"mkdir -p {sandbox_dir}/rpc", cwd="/", timeout=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -806,7 +806,7 @@ def _execute_remote(
|
|||||||
|
|
||||||
# Clean up remote sandbox dir
|
# Clean up remote sandbox dir
|
||||||
try:
|
try:
|
||||||
env.execute_oneshot(
|
env.execute(
|
||||||
f"rm -rf {sandbox_dir}", cwd="/", timeout=15,
|
f"rm -rf {sandbox_dir}", cwd="/", timeout=15,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -1,11 +1,27 @@
|
|||||||
"""Base class for all Hermes execution environment backends."""
|
"""Base class for all Hermes execution environment backends.
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
Unified spawn-per-call model: every command spawns a fresh ``bash -c`` process.
|
||||||
|
A session snapshot (env vars, functions, aliases) is captured once at init and
|
||||||
|
re-sourced before each command. CWD persists via in-band stdout markers (remote)
|
||||||
|
or a temp file (local).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import shlex
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import IO, Callable, Protocol
|
||||||
|
|
||||||
from hermes_constants import get_hermes_home
|
from hermes_constants import get_hermes_home
|
||||||
|
from tools.interrupt import is_interrupted
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_sandbox_dir() -> Path:
|
def get_sandbox_dir() -> Path:
|
||||||
@@ -23,30 +39,501 @@ def get_sandbox_dir() -> Path:
|
|||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
class BaseEnvironment(ABC):
|
# ---------------------------------------------------------------------------
|
||||||
"""Common interface for all Hermes execution backends.
|
# Shared constants and utilities
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
Subclasses implement execute() and cleanup(). Shared helpers eliminate
|
_SYNC_INTERVAL_SECONDS = 5.0
|
||||||
duplicated subprocess boilerplate across backends.
|
|
||||||
|
|
||||||
|
def _pipe_stdin(proc: subprocess.Popen, data: str) -> None:
|
||||||
|
"""Write *data* to proc.stdin on a daemon thread to avoid pipe-buffer deadlocks."""
|
||||||
|
|
||||||
|
def _write():
|
||||||
|
try:
|
||||||
|
proc.stdin.write(data)
|
||||||
|
proc.stdin.close()
|
||||||
|
except (BrokenPipeError, OSError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
threading.Thread(target=_write, daemon=True).start()
|
||||||
|
|
||||||
|
|
||||||
|
def _popen_bash(
|
||||||
|
cmd: list[str], stdin_data: str | None = None, **kwargs
|
||||||
|
) -> subprocess.Popen:
|
||||||
|
"""Spawn a subprocess with standard stdout/stderr/stdin setup.
|
||||||
|
|
||||||
|
If *stdin_data* is provided, writes it asynchronously via :func:`_pipe_stdin`.
|
||||||
|
Backends with special Popen needs (e.g. local's ``preexec_fn``) can bypass
|
||||||
|
this and call :func:`_pipe_stdin` directly.
|
||||||
"""
|
"""
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
cmd,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
stdin=subprocess.PIPE if stdin_data is not None else subprocess.DEVNULL,
|
||||||
|
text=True,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if stdin_data is not None:
|
||||||
|
_pipe_stdin(proc, stdin_data)
|
||||||
|
return proc
|
||||||
|
|
||||||
|
|
||||||
|
def _load_json_store(path: Path) -> dict:
|
||||||
|
"""Load a JSON file as a dict, returning ``{}`` on any error."""
|
||||||
|
if path.exists():
|
||||||
|
try:
|
||||||
|
return json.loads(path.read_text())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _save_json_store(path: Path, data: dict) -> None:
|
||||||
|
"""Write *data* as pretty-printed JSON to *path*."""
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path.write_text(json.dumps(data, indent=2))
|
||||||
|
|
||||||
|
|
||||||
|
def _file_mtime_key(host_path: str) -> tuple[float, int] | None:
|
||||||
|
"""Return ``(mtime, size)`` for cache comparison, or ``None`` if unreadable."""
|
||||||
|
try:
|
||||||
|
st = Path(host_path).stat()
|
||||||
|
return (st.st_mtime, st.st_size)
|
||||||
|
except OSError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ProcessHandle protocol
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessHandle(Protocol):
|
||||||
|
"""Duck type that every backend's _run_bash() must return.
|
||||||
|
|
||||||
|
subprocess.Popen satisfies this natively. SDK backends (Modal, Daytona)
|
||||||
|
return _ThreadedProcessHandle which adapts their blocking calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def poll(self) -> int | None: ...
|
||||||
|
def kill(self) -> None: ...
|
||||||
|
def wait(self, timeout: float | None = None) -> int: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stdout(self) -> IO[str] | None: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def returncode(self) -> int | None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class _ThreadedProcessHandle:
|
||||||
|
"""Adapter for SDK backends (Modal, Daytona) that have no real subprocess.
|
||||||
|
|
||||||
|
Wraps a blocking ``exec_fn() -> (output_str, exit_code)`` in a background
|
||||||
|
thread and exposes a ProcessHandle-compatible interface. An optional
|
||||||
|
``cancel_fn`` is invoked on ``kill()`` for backend-specific cancellation
|
||||||
|
(e.g. Modal sandbox.terminate, Daytona sandbox.stop).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
exec_fn: Callable[[], tuple[str, int]],
|
||||||
|
cancel_fn: Callable[[], None] | None = None,
|
||||||
|
):
|
||||||
|
self._cancel_fn = cancel_fn
|
||||||
|
self._done = threading.Event()
|
||||||
|
self._returncode: int | None = None
|
||||||
|
self._error: Exception | None = None
|
||||||
|
|
||||||
|
# Pipe for stdout — drain thread in _wait_for_process reads the read end.
|
||||||
|
read_fd, write_fd = os.pipe()
|
||||||
|
self._stdout = os.fdopen(read_fd, "r", encoding="utf-8", errors="replace")
|
||||||
|
self._write_fd = write_fd
|
||||||
|
|
||||||
|
def _worker():
|
||||||
|
try:
|
||||||
|
output, exit_code = exec_fn()
|
||||||
|
self._returncode = exit_code
|
||||||
|
# Write output into the pipe so drain thread picks it up.
|
||||||
|
try:
|
||||||
|
os.write(self._write_fd, output.encode("utf-8", errors="replace"))
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
except Exception as exc:
|
||||||
|
self._error = exc
|
||||||
|
self._returncode = 1
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
os.close(self._write_fd)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
self._done.set()
|
||||||
|
|
||||||
|
t = threading.Thread(target=_worker, daemon=True)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stdout(self):
|
||||||
|
return self._stdout
|
||||||
|
|
||||||
|
@property
|
||||||
|
def returncode(self) -> int | None:
|
||||||
|
return self._returncode
|
||||||
|
|
||||||
|
def poll(self) -> int | None:
|
||||||
|
return self._returncode if self._done.is_set() else None
|
||||||
|
|
||||||
|
def kill(self):
|
||||||
|
if self._cancel_fn:
|
||||||
|
try:
|
||||||
|
self._cancel_fn()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def wait(self, timeout: float | None = None) -> int:
|
||||||
|
self._done.wait(timeout=timeout)
|
||||||
|
return self._returncode
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CWD marker for remote backends
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _cwd_marker(session_id: str) -> str:
|
||||||
|
return f"__HERMES_CWD_{session_id}__"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# BaseEnvironment
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEnvironment(ABC):
|
||||||
|
"""Common interface and unified execution flow for all Hermes backends.
|
||||||
|
|
||||||
|
Subclasses implement ``_run_bash()`` and ``cleanup()``. The base class
|
||||||
|
provides ``execute()`` with session snapshot sourcing, CWD tracking,
|
||||||
|
interrupt handling, and timeout enforcement.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Subclasses that embed stdin as a heredoc (Modal, Daytona) set this.
|
||||||
|
_stdin_mode: str = "pipe" # "pipe" or "heredoc"
|
||||||
|
|
||||||
|
# Snapshot creation timeout (override for slow cold-starts).
|
||||||
|
_snapshot_timeout: int = 30
|
||||||
|
|
||||||
def __init__(self, cwd: str, timeout: int, env: dict = None):
|
def __init__(self, cwd: str, timeout: int, env: dict = None):
|
||||||
self.cwd = cwd
|
self.cwd = cwd
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.env = env or {}
|
self.env = env or {}
|
||||||
|
|
||||||
@abstractmethod
|
self._session_id = uuid.uuid4().hex[:12]
|
||||||
def execute(self, command: str, cwd: str = "", *,
|
self._snapshot_path = f"/tmp/hermes-snap-{self._session_id}.sh"
|
||||||
timeout: int | None = None,
|
self._cwd_file = f"/tmp/hermes-cwd-{self._session_id}.txt"
|
||||||
stdin_data: str | None = None) -> dict:
|
self._cwd_marker = _cwd_marker(self._session_id)
|
||||||
"""Execute a command, return {"output": str, "returncode": int}."""
|
self._snapshot_ready = False
|
||||||
...
|
self._last_sync_time: float | None = (
|
||||||
|
None # set to 0 by backends that need file sync
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Abstract methods
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _run_bash(
|
||||||
|
self,
|
||||||
|
cmd_string: str,
|
||||||
|
*,
|
||||||
|
login: bool = False,
|
||||||
|
timeout: int = 120,
|
||||||
|
stdin_data: str | None = None,
|
||||||
|
) -> ProcessHandle:
|
||||||
|
"""Spawn a bash process to run *cmd_string*.
|
||||||
|
|
||||||
|
Returns a ProcessHandle (subprocess.Popen or _ThreadedProcessHandle).
|
||||||
|
Must be overridden by every backend.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{type(self).__name__} must implement _run_bash()")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
"""Release backend resources (container, instance, connection)."""
|
"""Release backend resources (container, instance, connection)."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Session snapshot (init_session)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def init_session(self):
|
||||||
|
"""Capture login shell environment into a snapshot file.
|
||||||
|
|
||||||
|
Called once after backend construction. On success, sets
|
||||||
|
``_snapshot_ready = True`` so subsequent commands source the snapshot
|
||||||
|
instead of running with ``bash -l``.
|
||||||
|
"""
|
||||||
|
# Full capture: env vars, functions (filtered), aliases, shell options.
|
||||||
|
bootstrap = (
|
||||||
|
f"export -p > {self._snapshot_path}\n"
|
||||||
|
f"declare -f | grep -vE '^_[^_]' >> {self._snapshot_path}\n"
|
||||||
|
f"alias -p >> {self._snapshot_path}\n"
|
||||||
|
f"echo 'shopt -s expand_aliases' >> {self._snapshot_path}\n"
|
||||||
|
f"echo 'set +e' >> {self._snapshot_path}\n"
|
||||||
|
f"echo 'set +u' >> {self._snapshot_path}\n"
|
||||||
|
f"pwd -P > {self._cwd_file} 2>/dev/null || true\n"
|
||||||
|
f"printf '\\n{self._cwd_marker}%s{self._cwd_marker}\\n' \"$(pwd -P)\"\n"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
proc = self._run_bash(bootstrap, login=True, timeout=self._snapshot_timeout)
|
||||||
|
result = self._wait_for_process(proc, timeout=self._snapshot_timeout)
|
||||||
|
self._snapshot_ready = True
|
||||||
|
self._update_cwd(result)
|
||||||
|
logger.info(
|
||||||
|
"Session snapshot created (session=%s, cwd=%s)",
|
||||||
|
self._session_id,
|
||||||
|
self.cwd,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"init_session failed (session=%s): %s — "
|
||||||
|
"falling back to bash -l per command",
|
||||||
|
self._session_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
self._snapshot_ready = False
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Command wrapping
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _wrap_command(self, command: str, cwd: str) -> str:
|
||||||
|
"""Build the full bash script that sources snapshot, cd's, runs command,
|
||||||
|
re-dumps env vars, and emits CWD markers."""
|
||||||
|
escaped = command.replace("'", "'\\''")
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
# Source snapshot (env vars from previous commands)
|
||||||
|
if self._snapshot_ready:
|
||||||
|
parts.append(f"source {self._snapshot_path} 2>/dev/null || true")
|
||||||
|
|
||||||
|
# cd to working directory — let bash expand ~ natively
|
||||||
|
quoted_cwd = (
|
||||||
|
shlex.quote(cwd) if cwd != "~" and not cwd.startswith("~/") else cwd
|
||||||
|
)
|
||||||
|
parts.append(f"cd {quoted_cwd} || exit 126")
|
||||||
|
|
||||||
|
# Run the actual command
|
||||||
|
parts.append(f"eval '{escaped}'")
|
||||||
|
parts.append("__hermes_ec=$?")
|
||||||
|
|
||||||
|
# Re-dump env vars to snapshot (last-writer-wins for concurrent calls)
|
||||||
|
if self._snapshot_ready:
|
||||||
|
parts.append(f"export -p > {self._snapshot_path} 2>/dev/null || true")
|
||||||
|
|
||||||
|
# Write CWD to file (local reads this) and stdout marker (remote parses this)
|
||||||
|
parts.append(f"pwd -P > {self._cwd_file} 2>/dev/null || true")
|
||||||
|
# Use a distinct line for the marker. The leading \n ensures
|
||||||
|
# the marker starts on its own line even if the command doesn't
|
||||||
|
# end with a newline (e.g. printf 'exact'). We'll strip this
|
||||||
|
# injected newline in _extract_cwd_from_output.
|
||||||
|
parts.append(
|
||||||
|
f"printf '\\n{self._cwd_marker}%s{self._cwd_marker}\\n' \"$(pwd -P)\""
|
||||||
|
)
|
||||||
|
parts.append("exit $__hermes_ec")
|
||||||
|
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Stdin heredoc embedding (for SDK backends)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _embed_stdin_heredoc(command: str, stdin_data: str) -> str:
|
||||||
|
"""Append stdin_data as a shell heredoc to the command string."""
|
||||||
|
delimiter = f"HERMES_STDIN_{uuid.uuid4().hex[:12]}"
|
||||||
|
return f"{command} << '{delimiter}'\n{stdin_data}\n{delimiter}"
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Process lifecycle
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _wait_for_process(self, proc: ProcessHandle, timeout: int = 120) -> dict:
|
||||||
|
"""Poll-based wait with interrupt checking and stdout draining.
|
||||||
|
|
||||||
|
Shared across all backends — not overridden.
|
||||||
|
"""
|
||||||
|
output_chunks: list[str] = []
|
||||||
|
|
||||||
|
def _drain():
|
||||||
|
try:
|
||||||
|
for line in proc.stdout:
|
||||||
|
output_chunks.append(line)
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
output_chunks.clear()
|
||||||
|
output_chunks.append(
|
||||||
|
"[binary output detected — raw bytes not displayable]"
|
||||||
|
)
|
||||||
|
except (ValueError, OSError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
drain_thread = threading.Thread(target=_drain, daemon=True)
|
||||||
|
drain_thread.start()
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
|
||||||
|
while proc.poll() is None:
|
||||||
|
if is_interrupted():
|
||||||
|
self._kill_process(proc)
|
||||||
|
drain_thread.join(timeout=2)
|
||||||
|
return {
|
||||||
|
"output": "".join(output_chunks) + "\n[Command interrupted]",
|
||||||
|
"returncode": 130,
|
||||||
|
}
|
||||||
|
if time.monotonic() > deadline:
|
||||||
|
self._kill_process(proc)
|
||||||
|
drain_thread.join(timeout=2)
|
||||||
|
partial = "".join(output_chunks)
|
||||||
|
timeout_msg = f"\n[Command timed out after {timeout}s]"
|
||||||
|
return {
|
||||||
|
"output": partial + timeout_msg
|
||||||
|
if partial
|
||||||
|
else timeout_msg.lstrip(),
|
||||||
|
"returncode": 124,
|
||||||
|
}
|
||||||
|
time.sleep(0.2)
|
||||||
|
|
||||||
|
drain_thread.join(timeout=5)
|
||||||
|
|
||||||
|
try:
|
||||||
|
proc.stdout.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {"output": "".join(output_chunks), "returncode": proc.returncode}
|
||||||
|
|
||||||
|
def _kill_process(self, proc: ProcessHandle):
|
||||||
|
"""Terminate a process. Subclasses may override for process-group kill."""
|
||||||
|
try:
|
||||||
|
proc.kill()
|
||||||
|
except (ProcessLookupError, PermissionError, OSError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# CWD extraction
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _update_cwd(self, result: dict):
|
||||||
|
"""Extract CWD from command output. Override for local file-based read."""
|
||||||
|
self._extract_cwd_from_output(result)
|
||||||
|
|
||||||
|
def _extract_cwd_from_output(self, result: dict):
|
||||||
|
"""Parse the __HERMES_CWD_{session}__ marker from stdout output.
|
||||||
|
|
||||||
|
Updates self.cwd and strips the marker from result["output"].
|
||||||
|
Used by remote backends (Docker, SSH, Modal, Daytona, Singularity).
|
||||||
|
"""
|
||||||
|
output = result.get("output", "")
|
||||||
|
marker = self._cwd_marker
|
||||||
|
last = output.rfind(marker)
|
||||||
|
if last == -1:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Find the opening marker before this closing one
|
||||||
|
search_start = max(0, last - 4096) # CWD path won't be >4KB
|
||||||
|
first = output.rfind(marker, search_start, last)
|
||||||
|
if first == -1 or first == last:
|
||||||
|
return
|
||||||
|
|
||||||
|
cwd_path = output[first + len(marker) : last].strip()
|
||||||
|
if cwd_path:
|
||||||
|
self.cwd = cwd_path
|
||||||
|
|
||||||
|
# Strip the marker line AND the \n we injected before it.
|
||||||
|
# The wrapper emits: printf '\n__MARKER__%s__MARKER__\n'
|
||||||
|
# So the output looks like: <cmd output>\n__MARKER__path__MARKER__\n
|
||||||
|
# We want to remove everything from the injected \n onwards.
|
||||||
|
line_start = output.rfind("\n", 0, first)
|
||||||
|
if line_start == -1:
|
||||||
|
line_start = first
|
||||||
|
line_end = output.find("\n", last + len(marker))
|
||||||
|
line_end = line_end + 1 if line_end != -1 else len(output)
|
||||||
|
|
||||||
|
result["output"] = output[:line_start] + output[line_end:]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Hooks
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _before_execute(self):
|
||||||
|
"""Rate-limited file sync before each command.
|
||||||
|
|
||||||
|
Backends that need pre-command sync set ``self._last_sync_time = 0``
|
||||||
|
in ``__init__`` and override :meth:`_sync_files`. Backends needing
|
||||||
|
extra pre-exec logic (e.g. Daytona sandbox restart check) override
|
||||||
|
this method and call ``super()._before_execute()``.
|
||||||
|
"""
|
||||||
|
if self._last_sync_time is not None:
|
||||||
|
now = time.monotonic()
|
||||||
|
if now - self._last_sync_time >= _SYNC_INTERVAL_SECONDS:
|
||||||
|
self._sync_files()
|
||||||
|
self._last_sync_time = now
|
||||||
|
|
||||||
|
def _sync_files(self):
|
||||||
|
"""Push files to remote environment. Called rate-limited by _before_execute."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Unified execute()
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def execute(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
cwd: str = "",
|
||||||
|
*,
|
||||||
|
timeout: int | None = None,
|
||||||
|
stdin_data: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Execute a command, return {"output": str, "returncode": int}."""
|
||||||
|
self._before_execute()
|
||||||
|
|
||||||
|
exec_command, sudo_stdin = self._prepare_command(command)
|
||||||
|
effective_timeout = timeout or self.timeout
|
||||||
|
effective_cwd = cwd or self.cwd
|
||||||
|
|
||||||
|
# Merge sudo stdin with caller stdin
|
||||||
|
if sudo_stdin is not None and stdin_data is not None:
|
||||||
|
effective_stdin = sudo_stdin + stdin_data
|
||||||
|
elif sudo_stdin is not None:
|
||||||
|
effective_stdin = sudo_stdin
|
||||||
|
else:
|
||||||
|
effective_stdin = stdin_data
|
||||||
|
|
||||||
|
# Embed stdin as heredoc for backends that need it
|
||||||
|
if effective_stdin and self._stdin_mode == "heredoc":
|
||||||
|
exec_command = self._embed_stdin_heredoc(exec_command, effective_stdin)
|
||||||
|
effective_stdin = None
|
||||||
|
|
||||||
|
wrapped = self._wrap_command(exec_command, effective_cwd)
|
||||||
|
|
||||||
|
# Use login shell if snapshot failed (so user's profile still loads)
|
||||||
|
login = not self._snapshot_ready
|
||||||
|
|
||||||
|
proc = self._run_bash(
|
||||||
|
wrapped, login=login, timeout=effective_timeout, stdin_data=effective_stdin
|
||||||
|
)
|
||||||
|
result = self._wait_for_process(proc, timeout=effective_timeout)
|
||||||
|
self._update_cwd(result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Shared helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""Alias for cleanup (compat with older callers)."""
|
"""Alias for cleanup (compat with older callers)."""
|
||||||
self.cleanup()
|
self.cleanup()
|
||||||
@@ -57,53 +544,12 @@ class BaseEnvironment(ABC):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Shared helpers (eliminate duplication across backends)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _prepare_command(self, command: str) -> tuple[str, str | None]:
|
def _prepare_command(self, command: str) -> tuple[str, str | None]:
|
||||||
"""Transform sudo commands if SUDO_PASSWORD is available.
|
"""Transform sudo commands if SUDO_PASSWORD is available."""
|
||||||
|
|
||||||
Returns:
|
|
||||||
(transformed_command, sudo_stdin) — see _transform_sudo_command
|
|
||||||
for the full contract. Callers that drive a subprocess directly
|
|
||||||
should prepend sudo_stdin (when not None) to any stdin_data they
|
|
||||||
pass to Popen. Callers that embed stdin via heredoc (modal,
|
|
||||||
daytona) handle sudo_stdin in their own execute() method.
|
|
||||||
"""
|
|
||||||
from tools.terminal_tool import _transform_sudo_command
|
from tools.terminal_tool import _transform_sudo_command
|
||||||
|
|
||||||
return _transform_sudo_command(command)
|
return _transform_sudo_command(command)
|
||||||
|
|
||||||
def _build_run_kwargs(self, timeout: int | None,
|
|
||||||
stdin_data: str | None = None) -> dict:
|
|
||||||
"""Build common subprocess.run kwargs for non-interactive execution."""
|
|
||||||
kw = {
|
|
||||||
"text": True,
|
|
||||||
"timeout": timeout or self.timeout,
|
|
||||||
"encoding": "utf-8",
|
|
||||||
"errors": "replace",
|
|
||||||
"stdout": subprocess.PIPE,
|
|
||||||
"stderr": subprocess.STDOUT,
|
|
||||||
}
|
|
||||||
if stdin_data is not None:
|
|
||||||
kw["input"] = stdin_data
|
|
||||||
else:
|
|
||||||
kw["stdin"] = subprocess.DEVNULL
|
|
||||||
return kw
|
|
||||||
|
|
||||||
def execute_oneshot(self, command: str, cwd: str = "", *,
|
|
||||||
timeout: int | None = None,
|
|
||||||
stdin_data: str | None = None) -> dict:
|
|
||||||
"""Execute a command bypassing any persistent shell.
|
|
||||||
|
|
||||||
Safe for concurrent use alongside a long-running execute() call.
|
|
||||||
Backends that maintain a persistent shell (SSH, Local) override this
|
|
||||||
to route through their oneshot path, avoiding the shell lock.
|
|
||||||
Non-persistent backends delegate to execute().
|
|
||||||
"""
|
|
||||||
return self.execute(command, cwd=cwd, timeout=timeout,
|
|
||||||
stdin_data=stdin_data)
|
|
||||||
|
|
||||||
def _timeout_result(self, timeout: int | None) -> dict:
|
def _timeout_result(self, timeout: int | None) -> dict:
|
||||||
"""Standard return dict when a command times out."""
|
"""Standard return dict when a command times out."""
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -6,17 +6,18 @@ and resumed on next creation, preserving the filesystem across sessions.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
import math
|
import math
|
||||||
import shlex
|
import shlex
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from tools.environments.base import BaseEnvironment
|
from tools.environments.base import (
|
||||||
from tools.interrupt import is_interrupted
|
BaseEnvironment,
|
||||||
|
_ThreadedProcessHandle,
|
||||||
|
_file_mtime_key,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -24,22 +25,25 @@ logger = logging.getLogger(__name__)
|
|||||||
class DaytonaEnvironment(BaseEnvironment):
|
class DaytonaEnvironment(BaseEnvironment):
|
||||||
"""Daytona cloud sandbox execution backend.
|
"""Daytona cloud sandbox execution backend.
|
||||||
|
|
||||||
Uses stopped/started sandbox lifecycle for filesystem persistence
|
Spawn-per-call via _ThreadedProcessHandle wrapping blocking SDK calls.
|
||||||
instead of snapshots, making it faster and stateless on the host.
|
cancel_fn wired to sandbox.stop() for interrupt support.
|
||||||
|
Shell timeout wrapper preserved (SDK timeout unreliable).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_stdin_mode = "heredoc"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
image: str,
|
image: str,
|
||||||
cwd: str = "/home/daytona",
|
cwd: str = "/home/daytona",
|
||||||
timeout: int = 60,
|
timeout: int = 60,
|
||||||
cpu: int = 1,
|
cpu: int = 1,
|
||||||
memory: int = 5120, # MB (hermes convention)
|
memory: int = 5120,
|
||||||
disk: int = 10240, # MB (Daytona platform max is 10GB)
|
disk: int = 10240,
|
||||||
persistent_filesystem: bool = True,
|
persistent_filesystem: bool = True,
|
||||||
task_id: str = "default",
|
task_id: str = "default",
|
||||||
):
|
):
|
||||||
self._requested_cwd = cwd
|
requested_cwd = cwd
|
||||||
super().__init__(cwd=cwd, timeout=timeout)
|
super().__init__(cwd=cwd, timeout=timeout)
|
||||||
|
|
||||||
from daytona import (
|
from daytona import (
|
||||||
@@ -53,16 +57,18 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||||||
self._persistent = persistent_filesystem
|
self._persistent = persistent_filesystem
|
||||||
self._task_id = task_id
|
self._task_id = task_id
|
||||||
self._SandboxState = SandboxState
|
self._SandboxState = SandboxState
|
||||||
|
self._DaytonaError = DaytonaError
|
||||||
self._daytona = Daytona()
|
self._daytona = Daytona()
|
||||||
self._sandbox = None
|
self._sandbox = None
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
self._last_sync_time: float = 0
|
||||||
|
|
||||||
memory_gib = max(1, math.ceil(memory / 1024))
|
memory_gib = max(1, math.ceil(memory / 1024))
|
||||||
disk_gib = max(1, math.ceil(disk / 1024))
|
disk_gib = max(1, math.ceil(disk / 1024))
|
||||||
if disk_gib > 10:
|
if disk_gib > 10:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Daytona: requested disk ({disk_gib}GB) exceeds platform limit (10GB). "
|
f"Daytona: requested disk ({disk_gib}GB) exceeds platform limit (10GB). "
|
||||||
f"Capping to 10GB. Set container_disk: 10240 in config to silence this.",
|
f"Capping to 10GB.",
|
||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
disk_gib = 10
|
disk_gib = 10
|
||||||
@@ -71,9 +77,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||||||
labels = {"hermes_task_id": task_id}
|
labels = {"hermes_task_id": task_id}
|
||||||
sandbox_name = f"hermes-{task_id}"
|
sandbox_name = f"hermes-{task_id}"
|
||||||
|
|
||||||
# Try to resume an existing sandbox for this task
|
|
||||||
if self._persistent:
|
if self._persistent:
|
||||||
# 1. Try name-based lookup (new path)
|
|
||||||
try:
|
try:
|
||||||
self._sandbox = self._daytona.get(sandbox_name)
|
self._sandbox = self._daytona.get(sandbox_name)
|
||||||
self._sandbox.start()
|
self._sandbox.start()
|
||||||
@@ -86,7 +90,6 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||||||
task_id, e)
|
task_id, e)
|
||||||
self._sandbox = None
|
self._sandbox = None
|
||||||
|
|
||||||
# 2. Legacy fallback: find sandbox created before the naming migration
|
|
||||||
if self._sandbox is None:
|
if self._sandbox is None:
|
||||||
try:
|
try:
|
||||||
page = self._daytona.list(labels=labels, page=1, limit=1)
|
page = self._daytona.list(labels=labels, page=1, limit=1)
|
||||||
@@ -100,7 +103,6 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||||||
task_id, e)
|
task_id, e)
|
||||||
self._sandbox = None
|
self._sandbox = None
|
||||||
|
|
||||||
# Create a fresh sandbox if we don't have one
|
|
||||||
if self._sandbox is None:
|
if self._sandbox is None:
|
||||||
self._sandbox = self._daytona.create(
|
self._sandbox = self._daytona.create(
|
||||||
CreateSandboxFromImageParams(
|
CreateSandboxFromImageParams(
|
||||||
@@ -114,32 +116,25 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||||||
logger.info("Daytona: created sandbox %s for task %s",
|
logger.info("Daytona: created sandbox %s for task %s",
|
||||||
self._sandbox.id, task_id)
|
self._sandbox.id, task_id)
|
||||||
|
|
||||||
# Detect remote home dir first so mounts go to the right place.
|
# Detect remote home dir
|
||||||
self._remote_home = "/root"
|
self._remote_home = "/root"
|
||||||
try:
|
try:
|
||||||
home = self._sandbox.process.exec("echo $HOME").result.strip()
|
home = self._sandbox.process.exec("echo $HOME").result.strip()
|
||||||
if home:
|
if home:
|
||||||
self._remote_home = home
|
self._remote_home = home
|
||||||
if self._requested_cwd in ("~", "/home/daytona"):
|
if requested_cwd in ("~", "/home/daytona"):
|
||||||
self.cwd = home
|
self.cwd = home
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
logger.info("Daytona: resolved home to %s, cwd to %s", self._remote_home, self.cwd)
|
logger.info("Daytona: resolved home to %s, cwd to %s", self._remote_home, self.cwd)
|
||||||
|
|
||||||
# Track synced files to avoid redundant uploads.
|
|
||||||
# Key: remote_path, Value: (mtime, size)
|
|
||||||
self._synced_files: Dict[str, tuple] = {}
|
self._synced_files: Dict[str, tuple] = {}
|
||||||
|
self._sync_files()
|
||||||
# Upload credential files and skills directory into the sandbox.
|
self.init_session()
|
||||||
self._sync_skills_and_credentials()
|
|
||||||
|
|
||||||
def _upload_if_changed(self, host_path: str, remote_path: str) -> bool:
|
def _upload_if_changed(self, host_path: str, remote_path: str) -> bool:
|
||||||
"""Upload a file if its mtime/size changed since last sync."""
|
file_key = _file_mtime_key(host_path)
|
||||||
hp = Path(host_path)
|
if file_key is None:
|
||||||
try:
|
|
||||||
stat = hp.stat()
|
|
||||||
file_key = (stat.st_mtime, stat.st_size)
|
|
||||||
except OSError:
|
|
||||||
return False
|
return False
|
||||||
if self._synced_files.get(remote_path) == file_key:
|
if self._synced_files.get(remote_path) == file_key:
|
||||||
return False
|
return False
|
||||||
@@ -153,20 +148,15 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||||||
logger.debug("Daytona: upload failed %s: %s", host_path, e)
|
logger.debug("Daytona: upload failed %s: %s", host_path, e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _sync_skills_and_credentials(self) -> None:
|
def _sync_files(self) -> None:
|
||||||
"""Upload changed credential files and skill files into the sandbox."""
|
|
||||||
container_base = f"{self._remote_home}/.hermes"
|
container_base = f"{self._remote_home}/.hermes"
|
||||||
try:
|
try:
|
||||||
from tools.credential_files import get_credential_file_mounts, iter_skills_files
|
from tools.credential_files import get_credential_file_mounts, iter_skills_files
|
||||||
|
|
||||||
for mount_entry in get_credential_file_mounts():
|
for mount_entry in get_credential_file_mounts():
|
||||||
remote_path = mount_entry["container_path"].replace("/root/.hermes", container_base, 1)
|
remote_path = mount_entry["container_path"].replace("/root/.hermes", container_base, 1)
|
||||||
if self._upload_if_changed(mount_entry["host_path"], remote_path):
|
self._upload_if_changed(mount_entry["host_path"], remote_path)
|
||||||
logger.debug("Daytona: synced credential %s", remote_path)
|
|
||||||
|
|
||||||
for entry in iter_skills_files(container_base=container_base):
|
for entry in iter_skills_files(container_base=container_base):
|
||||||
if self._upload_if_changed(entry["host_path"], entry["container_path"]):
|
self._upload_if_changed(entry["host_path"], entry["container_path"])
|
||||||
logger.debug("Daytona: synced skill %s", entry["container_path"])
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Daytona: could not sync skills/credentials: %s", e)
|
logger.debug("Daytona: could not sync skills/credentials: %s", e)
|
||||||
|
|
||||||
@@ -177,111 +167,36 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||||||
self._sandbox.start()
|
self._sandbox.start()
|
||||||
logger.info("Daytona: restarted sandbox %s", self._sandbox.id)
|
logger.info("Daytona: restarted sandbox %s", self._sandbox.id)
|
||||||
|
|
||||||
def _exec_in_thread(self, exec_command: str, cwd: Optional[str], timeout: int) -> dict:
|
def _before_execute(self):
|
||||||
"""Run exec in a background thread with interrupt polling.
|
"""Ensure sandbox is ready, then rate-limited file sync via base class."""
|
||||||
|
|
||||||
The Daytona SDK's exec(timeout=...) parameter is unreliable (the
|
|
||||||
server-side timeout is not enforced and the SDK has no client-side
|
|
||||||
fallback), so we wrap the command with the shell ``timeout`` utility
|
|
||||||
which reliably kills the process and returns exit code 124.
|
|
||||||
"""
|
|
||||||
# Wrap with shell `timeout` to enforce the deadline reliably.
|
|
||||||
# Add a small buffer so the shell timeout fires before any SDK-level
|
|
||||||
# timeout would, giving us a clean exit code 124.
|
|
||||||
timed_command = f"timeout {timeout} sh -c {shlex.quote(exec_command)}"
|
|
||||||
|
|
||||||
result_holder: dict = {"value": None, "error": None}
|
|
||||||
|
|
||||||
def _run():
|
|
||||||
try:
|
|
||||||
response = self._sandbox.process.exec(
|
|
||||||
timed_command, cwd=cwd,
|
|
||||||
)
|
|
||||||
result_holder["value"] = {
|
|
||||||
"output": response.result or "",
|
|
||||||
"returncode": response.exit_code,
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
result_holder["error"] = e
|
|
||||||
|
|
||||||
t = threading.Thread(target=_run, daemon=True)
|
|
||||||
t.start()
|
|
||||||
# Wait for timeout + generous buffer for network/SDK overhead
|
|
||||||
deadline = time.monotonic() + timeout + 10
|
|
||||||
while t.is_alive():
|
|
||||||
t.join(timeout=0.2)
|
|
||||||
if is_interrupted():
|
|
||||||
with self._lock:
|
|
||||||
try:
|
|
||||||
self._sandbox.stop()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return {
|
|
||||||
"output": "[Command interrupted - Daytona sandbox stopped]",
|
|
||||||
"returncode": 130,
|
|
||||||
}
|
|
||||||
if time.monotonic() > deadline:
|
|
||||||
# Shell timeout didn't fire and SDK is hung — force stop
|
|
||||||
with self._lock:
|
|
||||||
try:
|
|
||||||
self._sandbox.stop()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return self._timeout_result(timeout)
|
|
||||||
|
|
||||||
if result_holder["error"]:
|
|
||||||
return {"error": result_holder["error"]}
|
|
||||||
return result_holder["value"]
|
|
||||||
|
|
||||||
def execute(self, command: str, cwd: str = "", *,
|
|
||||||
timeout: Optional[int] = None,
|
|
||||||
stdin_data: Optional[str] = None) -> dict:
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._ensure_sandbox_ready()
|
self._ensure_sandbox_ready()
|
||||||
# Incremental sync before each command so mid-session credential
|
super()._before_execute()
|
||||||
# refreshes and skill updates are picked up.
|
|
||||||
self._sync_skills_and_credentials()
|
|
||||||
|
|
||||||
if stdin_data is not None:
|
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
timeout: int = 120,
|
||||||
while marker in stdin_data:
|
stdin_data: str | None = None):
|
||||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
"""Return a _ThreadedProcessHandle wrapping a blocking Daytona SDK call."""
|
||||||
command = f"{command} << '{marker}'\n{stdin_data}\n{marker}"
|
sandbox = self._sandbox
|
||||||
|
lock = self._lock
|
||||||
|
|
||||||
exec_command, sudo_stdin = self._prepare_command(command)
|
def cancel():
|
||||||
|
with lock:
|
||||||
|
try:
|
||||||
|
sandbox.stop()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Daytona sandboxes execute commands via the Daytona SDK and cannot
|
if login:
|
||||||
# pipe subprocess stdin directly the way a local Popen can. When a
|
shell_cmd = f"bash -l -c {shlex.quote(cmd_string)}"
|
||||||
# sudo password is present, use a shell-level pipe from printf so that
|
else:
|
||||||
# the password feeds sudo -S without appearing as an echo argument
|
shell_cmd = f"bash -c {shlex.quote(cmd_string)}"
|
||||||
# embedded in the shell string. The password is still visible in the
|
|
||||||
# remote sandbox's command line, but it is not exposed on the user's
|
|
||||||
# local machine — which is the primary threat being mitigated.
|
|
||||||
if sudo_stdin is not None:
|
|
||||||
import shlex
|
|
||||||
exec_command = (
|
|
||||||
f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}"
|
|
||||||
)
|
|
||||||
effective_cwd = cwd or self.cwd or None
|
|
||||||
effective_timeout = timeout or self.timeout
|
|
||||||
|
|
||||||
result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout)
|
def exec_fn() -> tuple[str, int]:
|
||||||
|
response = sandbox.process.exec(shell_cmd, timeout=timeout)
|
||||||
|
return (response.result or "", response.exit_code)
|
||||||
|
|
||||||
if "error" in result:
|
return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||||
from daytona import DaytonaError
|
|
||||||
err = result["error"]
|
|
||||||
if isinstance(err, DaytonaError):
|
|
||||||
with self._lock:
|
|
||||||
try:
|
|
||||||
self._ensure_sandbox_ready()
|
|
||||||
except Exception:
|
|
||||||
return {"output": f"Daytona execution error: {err}", "returncode": 1}
|
|
||||||
result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout)
|
|
||||||
if "error" not in result:
|
|
||||||
return result
|
|
||||||
return {"output": f"Daytona execution error: {err}", "returncode": 1}
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
|||||||
@@ -8,18 +8,14 @@ persistence via bind mounts.
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shlex
|
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from tools.environments.base import BaseEnvironment
|
from tools.environments.base import BaseEnvironment, _popen_bash
|
||||||
from tools.environments.local import _HERMES_PROVIDER_ENV_BLOCKLIST
|
from tools.environments.local import _HERMES_PROVIDER_ENV_BLOCKLIST
|
||||||
from tools.interrupt import is_interrupted
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -431,6 +427,69 @@ class DockerEnvironment(BaseEnvironment):
|
|||||||
self._container_id = result.stdout.strip()
|
self._container_id = result.stdout.strip()
|
||||||
logger.info(f"Started container {container_name} ({self._container_id[:12]})")
|
logger.info(f"Started container {container_name} ({self._container_id[:12]})")
|
||||||
|
|
||||||
|
# Build the init-time env forwarding args (used only by init_session
|
||||||
|
# to inject host env vars into the snapshot; subsequent commands get
|
||||||
|
# them from the snapshot file).
|
||||||
|
self._init_env_args = self._build_init_env_args()
|
||||||
|
|
||||||
|
# Initialize session snapshot inside the container
|
||||||
|
self.init_session()
|
||||||
|
|
||||||
|
def _build_init_env_args(self) -> list[str]:
|
||||||
|
"""Build -e KEY=VALUE args for injecting host env vars into init_session.
|
||||||
|
|
||||||
|
These are used once during init_session() so that export -p captures
|
||||||
|
them into the snapshot. Subsequent execute() calls don't need -e flags.
|
||||||
|
"""
|
||||||
|
exec_env: dict[str, str] = dict(self._env)
|
||||||
|
|
||||||
|
explicit_forward_keys = set(self._forward_env)
|
||||||
|
passthrough_keys: set[str] = set()
|
||||||
|
try:
|
||||||
|
from tools.env_passthrough import get_all_passthrough
|
||||||
|
passthrough_keys = set(get_all_passthrough())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Explicit docker_forward_env entries are an intentional opt-in and must
|
||||||
|
# win over the generic Hermes secret blocklist. Only implicit passthrough
|
||||||
|
# keys are filtered.
|
||||||
|
forward_keys = explicit_forward_keys | (passthrough_keys - _HERMES_PROVIDER_ENV_BLOCKLIST)
|
||||||
|
hermes_env = _load_hermes_env_vars() if forward_keys else {}
|
||||||
|
for key in sorted(forward_keys):
|
||||||
|
value = os.getenv(key)
|
||||||
|
if value is None:
|
||||||
|
value = hermes_env.get(key)
|
||||||
|
if value is not None:
|
||||||
|
exec_env[key] = value
|
||||||
|
|
||||||
|
args = []
|
||||||
|
for key in sorted(exec_env):
|
||||||
|
args.extend(["-e", f"{key}={exec_env[key]}"])
|
||||||
|
return args
|
||||||
|
|
||||||
|
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||||
|
timeout: int = 120,
|
||||||
|
stdin_data: str | None = None) -> subprocess.Popen:
|
||||||
|
"""Spawn a bash process inside the Docker container."""
|
||||||
|
assert self._container_id, "Container not started"
|
||||||
|
cmd = [self._docker_exe, "exec"]
|
||||||
|
if stdin_data is not None:
|
||||||
|
cmd.append("-i")
|
||||||
|
|
||||||
|
# Only inject -e env args during init_session (login=True).
|
||||||
|
# Subsequent commands get env vars from the snapshot.
|
||||||
|
if login:
|
||||||
|
cmd.extend(self._init_env_args)
|
||||||
|
|
||||||
|
cmd.extend([self._container_id])
|
||||||
|
|
||||||
|
if login:
|
||||||
|
cmd.extend(["bash", "-l", "-c", cmd_string])
|
||||||
|
else:
|
||||||
|
cmd.extend(["bash", "-c", cmd_string])
|
||||||
|
|
||||||
|
return _popen_bash(cmd, stdin_data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _storage_opt_supported() -> bool:
|
def _storage_opt_supported() -> bool:
|
||||||
"""Check if Docker's storage driver supports --storage-opt size=.
|
"""Check if Docker's storage driver supports --storage-opt size=.
|
||||||
@@ -471,112 +530,6 @@ class DockerEnvironment(BaseEnvironment):
|
|||||||
logger.debug("Docker --storage-opt support: %s", _storage_opt_ok)
|
logger.debug("Docker --storage-opt support: %s", _storage_opt_ok)
|
||||||
return _storage_opt_ok
|
return _storage_opt_ok
|
||||||
|
|
||||||
def execute(self, command: str, cwd: str = "", *,
|
|
||||||
timeout: int | None = None,
|
|
||||||
stdin_data: str | None = None) -> dict:
|
|
||||||
exec_command, sudo_stdin = self._prepare_command(command)
|
|
||||||
work_dir = cwd or self.cwd
|
|
||||||
effective_timeout = timeout or self.timeout
|
|
||||||
|
|
||||||
# Merge sudo password (if any) with caller-supplied stdin_data.
|
|
||||||
if sudo_stdin is not None and stdin_data is not None:
|
|
||||||
effective_stdin = sudo_stdin + stdin_data
|
|
||||||
elif sudo_stdin is not None:
|
|
||||||
effective_stdin = sudo_stdin
|
|
||||||
else:
|
|
||||||
effective_stdin = stdin_data
|
|
||||||
|
|
||||||
# docker exec -w doesn't expand ~, so prepend a cd into the command.
|
|
||||||
# Keep ~ unquoted (for shell expansion) and quote only the subpath.
|
|
||||||
if work_dir == "~":
|
|
||||||
exec_command = f"cd ~ && {exec_command}"
|
|
||||||
work_dir = "/"
|
|
||||||
elif work_dir.startswith("~/"):
|
|
||||||
exec_command = f"cd ~/{shlex.quote(work_dir[2:])} && {exec_command}"
|
|
||||||
work_dir = "/"
|
|
||||||
|
|
||||||
assert self._container_id, "Container not started"
|
|
||||||
cmd = [self._docker_exe, "exec"]
|
|
||||||
if effective_stdin is not None:
|
|
||||||
cmd.append("-i")
|
|
||||||
cmd.extend(["-w", work_dir])
|
|
||||||
# Build the per-exec environment: start with explicit docker_env values
|
|
||||||
# (static config), then overlay docker_forward_env / skill env_passthrough
|
|
||||||
# (dynamic from host process). Forward values take precedence.
|
|
||||||
exec_env: dict[str, str] = dict(self._env)
|
|
||||||
|
|
||||||
explicit_forward_keys = set(self._forward_env)
|
|
||||||
passthrough_keys: set[str] = set()
|
|
||||||
try:
|
|
||||||
from tools.env_passthrough import get_all_passthrough
|
|
||||||
passthrough_keys = set(get_all_passthrough())
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
# Explicit docker_forward_env entries are an intentional opt-in and must
|
|
||||||
# win over the generic Hermes secret blocklist. Only implicit passthrough
|
|
||||||
# keys are filtered.
|
|
||||||
forward_keys = explicit_forward_keys | (passthrough_keys - _HERMES_PROVIDER_ENV_BLOCKLIST)
|
|
||||||
hermes_env = _load_hermes_env_vars() if forward_keys else {}
|
|
||||||
for key in sorted(forward_keys):
|
|
||||||
value = os.getenv(key)
|
|
||||||
if value is None:
|
|
||||||
value = hermes_env.get(key)
|
|
||||||
if value is not None:
|
|
||||||
exec_env[key] = value
|
|
||||||
|
|
||||||
for key in sorted(exec_env):
|
|
||||||
cmd.extend(["-e", f"{key}={exec_env[key]}"])
|
|
||||||
cmd.extend([self._container_id, "bash", "-lc", exec_command])
|
|
||||||
|
|
||||||
try:
|
|
||||||
_output_chunks = []
|
|
||||||
proc = subprocess.Popen(
|
|
||||||
cmd,
|
|
||||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
|
||||||
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
|
|
||||||
text=True,
|
|
||||||
)
|
|
||||||
if effective_stdin:
|
|
||||||
try:
|
|
||||||
proc.stdin.write(effective_stdin)
|
|
||||||
proc.stdin.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _drain():
|
|
||||||
try:
|
|
||||||
for line in proc.stdout:
|
|
||||||
_output_chunks.append(line)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
reader = threading.Thread(target=_drain, daemon=True)
|
|
||||||
reader.start()
|
|
||||||
deadline = time.monotonic() + effective_timeout
|
|
||||||
|
|
||||||
while proc.poll() is None:
|
|
||||||
if is_interrupted():
|
|
||||||
proc.terminate()
|
|
||||||
try:
|
|
||||||
proc.wait(timeout=1)
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
proc.kill()
|
|
||||||
reader.join(timeout=2)
|
|
||||||
return {
|
|
||||||
"output": "".join(_output_chunks) + "\n[Command interrupted]",
|
|
||||||
"returncode": 130,
|
|
||||||
}
|
|
||||||
if time.monotonic() > deadline:
|
|
||||||
proc.kill()
|
|
||||||
reader.join(timeout=2)
|
|
||||||
return self._timeout_result(effective_timeout)
|
|
||||||
time.sleep(0.2)
|
|
||||||
|
|
||||||
reader.join(timeout=5)
|
|
||||||
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
|
|
||||||
except Exception as e:
|
|
||||||
return {"output": f"Docker execution error: {e}", "returncode": 1}
|
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
"""Stop and remove the container. Bind-mount dirs persist if persistent=True."""
|
"""Stop and remove the container. Bind-mount dirs persist if persistent=True."""
|
||||||
if self._container_id:
|
if self._container_id:
|
||||||
|
|||||||
@@ -1,42 +1,22 @@
|
|||||||
"""Local execution environment with interrupt support and non-blocking I/O."""
|
"""Local execution environment — spawn-per-call with session snapshot."""
|
||||||
|
|
||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import shutil
|
import shutil
|
||||||
import signal
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
import threading
|
|
||||||
import time
|
from tools.environments.base import BaseEnvironment, _pipe_stdin
|
||||||
|
|
||||||
_IS_WINDOWS = platform.system() == "Windows"
|
_IS_WINDOWS = platform.system() == "Windows"
|
||||||
|
|
||||||
from tools.environments.base import BaseEnvironment
|
|
||||||
from tools.environments.persistent_shell import PersistentShellMixin
|
|
||||||
from tools.interrupt import is_interrupted
|
|
||||||
|
|
||||||
# Unique marker to isolate real command output from shell init/exit noise.
|
|
||||||
# printf (no trailing newline) keeps the boundaries clean for splitting.
|
|
||||||
_OUTPUT_FENCE = "__HERMES_FENCE_a9f7b3__"
|
|
||||||
|
|
||||||
# Hermes-internal env vars that should NOT leak into terminal subprocesses.
|
# Hermes-internal env vars that should NOT leak into terminal subprocesses.
|
||||||
# These are loaded from ~/.hermes/.env for Hermes' own LLM/provider calls
|
|
||||||
# but can break external CLIs (e.g. codex) that also honor them.
|
|
||||||
# See: https://github.com/NousResearch/hermes-agent/issues/1002
|
|
||||||
#
|
|
||||||
# Built dynamically from the provider registry so new providers are
|
|
||||||
# automatically covered without manual blocklist maintenance.
|
|
||||||
_HERMES_PROVIDER_ENV_FORCE_PREFIX = "_HERMES_FORCE_"
|
_HERMES_PROVIDER_ENV_FORCE_PREFIX = "_HERMES_FORCE_"
|
||||||
|
|
||||||
|
|
||||||
def _build_provider_env_blocklist() -> frozenset:
|
def _build_provider_env_blocklist() -> frozenset:
|
||||||
"""Derive the blocklist from provider, tool, and gateway config.
|
"""Derive the blocklist from provider, tool, and gateway config."""
|
||||||
|
|
||||||
Automatically picks up api_key_env_vars and base_url_env_var from
|
|
||||||
every registered provider, plus tool/messaging env vars from the
|
|
||||||
optional config registry, so new Hermes-managed secrets are blocked
|
|
||||||
in subprocesses without having to maintain multiple static lists.
|
|
||||||
"""
|
|
||||||
blocked: set[str] = set()
|
blocked: set[str] = set()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -59,33 +39,30 @@ def _build_provider_env_blocklist() -> frozenset:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Vars not covered above but still Hermes-internal / conflict-prone.
|
|
||||||
blocked.update({
|
blocked.update({
|
||||||
"OPENAI_BASE_URL",
|
"OPENAI_BASE_URL",
|
||||||
"OPENAI_API_KEY",
|
"OPENAI_API_KEY",
|
||||||
"OPENAI_API_BASE", # legacy alias
|
"OPENAI_API_BASE",
|
||||||
"OPENAI_ORG_ID",
|
"OPENAI_ORG_ID",
|
||||||
"OPENAI_ORGANIZATION",
|
"OPENAI_ORGANIZATION",
|
||||||
"OPENROUTER_API_KEY",
|
"OPENROUTER_API_KEY",
|
||||||
"ANTHROPIC_BASE_URL",
|
"ANTHROPIC_BASE_URL",
|
||||||
"ANTHROPIC_TOKEN", # OAuth token (not in registry as env var)
|
"ANTHROPIC_TOKEN",
|
||||||
"CLAUDE_CODE_OAUTH_TOKEN",
|
"CLAUDE_CODE_OAUTH_TOKEN",
|
||||||
"LLM_MODEL",
|
"LLM_MODEL",
|
||||||
# Expanded isolation for other major providers (Issue #1002)
|
"GOOGLE_API_KEY",
|
||||||
"GOOGLE_API_KEY", # Gemini / Google AI Studio
|
"DEEPSEEK_API_KEY",
|
||||||
"DEEPSEEK_API_KEY", # DeepSeek
|
"MISTRAL_API_KEY",
|
||||||
"MISTRAL_API_KEY", # Mistral AI
|
"GROQ_API_KEY",
|
||||||
"GROQ_API_KEY", # Groq
|
"TOGETHER_API_KEY",
|
||||||
"TOGETHER_API_KEY", # Together AI
|
"PERPLEXITY_API_KEY",
|
||||||
"PERPLEXITY_API_KEY", # Perplexity
|
"COHERE_API_KEY",
|
||||||
"COHERE_API_KEY", # Cohere
|
"FIREWORKS_API_KEY",
|
||||||
"FIREWORKS_API_KEY", # Fireworks AI
|
"XAI_API_KEY",
|
||||||
"XAI_API_KEY", # xAI (Grok)
|
"HELICONE_API_KEY",
|
||||||
"HELICONE_API_KEY", # LLM Observability proxy
|
|
||||||
"PARALLEL_API_KEY",
|
"PARALLEL_API_KEY",
|
||||||
"FIRECRAWL_API_KEY",
|
"FIRECRAWL_API_KEY",
|
||||||
"FIRECRAWL_API_URL",
|
"FIRECRAWL_API_URL",
|
||||||
# Gateway/runtime config not represented in OPTIONAL_ENV_VARS.
|
|
||||||
"TELEGRAM_HOME_CHANNEL",
|
"TELEGRAM_HOME_CHANNEL",
|
||||||
"TELEGRAM_HOME_CHANNEL_NAME",
|
"TELEGRAM_HOME_CHANNEL_NAME",
|
||||||
"DISCORD_HOME_CHANNEL",
|
"DISCORD_HOME_CHANNEL",
|
||||||
@@ -115,12 +92,10 @@ def _build_provider_env_blocklist() -> frozenset:
|
|||||||
"EMAIL_HOME_ADDRESS",
|
"EMAIL_HOME_ADDRESS",
|
||||||
"EMAIL_HOME_ADDRESS_NAME",
|
"EMAIL_HOME_ADDRESS_NAME",
|
||||||
"GATEWAY_ALLOWED_USERS",
|
"GATEWAY_ALLOWED_USERS",
|
||||||
# Skills Hub / GitHub app auth paths and aliases.
|
|
||||||
"GH_TOKEN",
|
"GH_TOKEN",
|
||||||
"GITHUB_APP_ID",
|
"GITHUB_APP_ID",
|
||||||
"GITHUB_APP_PRIVATE_KEY_PATH",
|
"GITHUB_APP_PRIVATE_KEY_PATH",
|
||||||
"GITHUB_APP_INSTALLATION_ID",
|
"GITHUB_APP_INSTALLATION_ID",
|
||||||
# Remote sandbox backend credentials.
|
|
||||||
"MODAL_TOKEN_ID",
|
"MODAL_TOKEN_ID",
|
||||||
"MODAL_TOKEN_SECRET",
|
"MODAL_TOKEN_SECRET",
|
||||||
"DAYTONA_API_KEY",
|
"DAYTONA_API_KEY",
|
||||||
@@ -132,13 +107,7 @@ _HERMES_PROVIDER_ENV_BLOCKLIST = _build_provider_env_blocklist()
|
|||||||
|
|
||||||
|
|
||||||
def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = None) -> dict:
|
def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = None) -> dict:
|
||||||
"""Filter Hermes-managed secrets from a subprocess environment.
|
"""Filter Hermes-managed secrets from a subprocess environment."""
|
||||||
|
|
||||||
`_HERMES_FORCE_<VAR>` entries in ``extra_env`` opt a blocked variable back in
|
|
||||||
intentionally for callers that truly need it. Vars registered via
|
|
||||||
:mod:`tools.env_passthrough` (skill-declared or user-configured) also
|
|
||||||
bypass the blocklist.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
from tools.env_passthrough import is_env_passthrough as _is_passthrough
|
from tools.env_passthrough import is_env_passthrough as _is_passthrough
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -163,33 +132,24 @@ def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = Non
|
|||||||
|
|
||||||
|
|
||||||
def _find_bash() -> str:
|
def _find_bash() -> str:
|
||||||
"""Find bash for command execution.
|
"""Find bash for command execution."""
|
||||||
|
|
||||||
The fence wrapper uses bash syntax (semicolons, $?, printf), so we
|
|
||||||
must use bash — not the user's $SHELL which could be fish/zsh/etc.
|
|
||||||
On Windows: uses Git Bash (bundled with Git for Windows).
|
|
||||||
"""
|
|
||||||
if not _IS_WINDOWS:
|
if not _IS_WINDOWS:
|
||||||
return (
|
return (
|
||||||
shutil.which("bash")
|
shutil.which("bash")
|
||||||
or ("/usr/bin/bash" if os.path.isfile("/usr/bin/bash") else None)
|
or ("/usr/bin/bash" if os.path.isfile("/usr/bin/bash") else None)
|
||||||
or ("/bin/bash" if os.path.isfile("/bin/bash") else None)
|
or ("/bin/bash" if os.path.isfile("/bin/bash") else None)
|
||||||
or os.environ.get("SHELL") # last resort: whatever they have
|
or os.environ.get("SHELL")
|
||||||
or "/bin/sh"
|
or "/bin/sh"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Windows: look for Git Bash (installed with Git for Windows).
|
|
||||||
# Allow override via env var (same pattern as Claude Code).
|
|
||||||
custom = os.environ.get("HERMES_GIT_BASH_PATH")
|
custom = os.environ.get("HERMES_GIT_BASH_PATH")
|
||||||
if custom and os.path.isfile(custom):
|
if custom and os.path.isfile(custom):
|
||||||
return custom
|
return custom
|
||||||
|
|
||||||
# shutil.which finds bash.exe if Git\bin is on PATH
|
|
||||||
found = shutil.which("bash")
|
found = shutil.which("bash")
|
||||||
if found:
|
if found:
|
||||||
return found
|
return found
|
||||||
|
|
||||||
# Check common Git for Windows install locations
|
|
||||||
for candidate in (
|
for candidate in (
|
||||||
os.path.join(os.environ.get("ProgramFiles", r"C:\Program Files"), "Git", "bin", "bash.exe"),
|
os.path.join(os.environ.get("ProgramFiles", r"C:\Program Files"), "Git", "bin", "bash.exe"),
|
||||||
os.path.join(os.environ.get("ProgramFiles(x86)", r"C:\Program Files (x86)"), "Git", "bin", "bash.exe"),
|
os.path.join(os.environ.get("ProgramFiles(x86)", r"C:\Program Files (x86)"), "Git", "bin", "bash.exe"),
|
||||||
@@ -209,60 +169,7 @@ def _find_bash() -> str:
|
|||||||
_find_shell = _find_bash
|
_find_shell = _find_bash
|
||||||
|
|
||||||
|
|
||||||
# Noise lines emitted by interactive shells when stdin is not a terminal.
|
# Standard PATH entries for environments with minimal PATH.
|
||||||
# Used as a fallback when output fence markers are missing.
|
|
||||||
_SHELL_NOISE_SUBSTRINGS = (
|
|
||||||
# bash
|
|
||||||
"bash: cannot set terminal process group",
|
|
||||||
"bash: no job control in this shell",
|
|
||||||
"no job control in this shell",
|
|
||||||
"cannot set terminal process group",
|
|
||||||
"tcsetattr: Inappropriate ioctl for device",
|
|
||||||
# zsh / oh-my-zsh / macOS terminal session
|
|
||||||
"Restored session:",
|
|
||||||
"Saving session...",
|
|
||||||
"Last login:",
|
|
||||||
"command not found:",
|
|
||||||
"Oh My Zsh",
|
|
||||||
"compinit:",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _clean_shell_noise(output: str) -> str:
|
|
||||||
"""Strip shell startup/exit warnings that leak when using -i without a TTY.
|
|
||||||
|
|
||||||
Removes lines matching known noise patterns from both the beginning
|
|
||||||
and end of the output. Lines in the middle are left untouched.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _is_noise(line: str) -> bool:
|
|
||||||
return any(noise in line for noise in _SHELL_NOISE_SUBSTRINGS)
|
|
||||||
|
|
||||||
lines = output.split("\n")
|
|
||||||
|
|
||||||
# Strip leading noise
|
|
||||||
while lines and _is_noise(lines[0]):
|
|
||||||
lines.pop(0)
|
|
||||||
|
|
||||||
# Strip trailing noise (walk backwards, skip empty lines from split)
|
|
||||||
end = len(lines) - 1
|
|
||||||
while end >= 0 and (not lines[end] or _is_noise(lines[end])):
|
|
||||||
end -= 1
|
|
||||||
|
|
||||||
if end < 0:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
cleaned = lines[: end + 1]
|
|
||||||
result = "\n".join(cleaned)
|
|
||||||
|
|
||||||
# Preserve trailing newline if original had one
|
|
||||||
if output.endswith("\n") and result and not result.endswith("\n"):
|
|
||||||
result += "\n"
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
# Standard PATH entries for environments with minimal PATH (e.g. systemd services).
|
|
||||||
# Includes macOS Homebrew paths (/opt/homebrew/* for Apple Silicon).
|
|
||||||
_SANE_PATH = (
|
_SANE_PATH = (
|
||||||
"/opt/homebrew/bin:/opt/homebrew/sbin:"
|
"/opt/homebrew/bin:/opt/homebrew/sbin:"
|
||||||
"/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
"/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||||
@@ -290,197 +197,76 @@ def _make_run_env(env: dict) -> dict:
|
|||||||
return run_env
|
return run_env
|
||||||
|
|
||||||
|
|
||||||
def _extract_fenced_output(raw: str) -> str:
|
class LocalEnvironment(BaseEnvironment):
|
||||||
"""Extract real command output from between fence markers.
|
|
||||||
|
|
||||||
The execute() method wraps each command with printf(FENCE) markers.
|
|
||||||
This function finds the first and last fence and returns only the
|
|
||||||
content between them, which is the actual command output free of
|
|
||||||
any shell init/exit noise.
|
|
||||||
|
|
||||||
Falls back to pattern-based _clean_shell_noise if fences are missing.
|
|
||||||
"""
|
|
||||||
first = raw.find(_OUTPUT_FENCE)
|
|
||||||
if first == -1:
|
|
||||||
return _clean_shell_noise(raw)
|
|
||||||
|
|
||||||
start = first + len(_OUTPUT_FENCE)
|
|
||||||
last = raw.rfind(_OUTPUT_FENCE)
|
|
||||||
|
|
||||||
if last <= first:
|
|
||||||
# Only start fence found (e.g. user command called `exit`)
|
|
||||||
return _clean_shell_noise(raw[start:])
|
|
||||||
|
|
||||||
return raw[start:last]
|
|
||||||
|
|
||||||
|
|
||||||
class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
|
|
||||||
"""Run commands directly on the host machine.
|
"""Run commands directly on the host machine.
|
||||||
|
|
||||||
Features:
|
Spawn-per-call: every execute() spawns a fresh bash process.
|
||||||
- Popen + polling for interrupt support (user can cancel mid-command)
|
Session snapshot preserves env vars across calls.
|
||||||
- Background stdout drain thread to prevent pipe buffer deadlocks
|
CWD persists via file-based read after each command.
|
||||||
- stdin_data support for piping content (bypasses ARG_MAX limits)
|
|
||||||
- sudo -S transform via SUDO_PASSWORD env var
|
|
||||||
- Uses interactive login shell so full user env is available
|
|
||||||
- Optional persistent shell mode (cwd/env vars survive across calls)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None,
|
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None):
|
||||||
persistent: bool = False):
|
|
||||||
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
|
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
|
||||||
self.persistent = persistent
|
self.init_session()
|
||||||
if self.persistent:
|
|
||||||
self._init_persistent_shell()
|
|
||||||
|
|
||||||
@property
|
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||||
def _temp_prefix(self) -> str:
|
timeout: int = 120,
|
||||||
return f"/tmp/hermes-local-{self._session_id}"
|
stdin_data: str | None = None) -> subprocess.Popen:
|
||||||
|
bash = _find_bash()
|
||||||
def _spawn_shell_process(self) -> subprocess.Popen:
|
args = [bash, "-l", "-c", cmd_string] if login else [bash, "-c", cmd_string]
|
||||||
user_shell = _find_bash()
|
|
||||||
run_env = _make_run_env(self.env)
|
|
||||||
return subprocess.Popen(
|
|
||||||
[user_shell, "-l"],
|
|
||||||
stdin=subprocess.PIPE,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.DEVNULL,
|
|
||||||
text=True,
|
|
||||||
env=run_env,
|
|
||||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _read_temp_files(self, *paths: str) -> list[str]:
|
|
||||||
results = []
|
|
||||||
for path in paths:
|
|
||||||
if os.path.exists(path):
|
|
||||||
with open(path) as f:
|
|
||||||
results.append(f.read())
|
|
||||||
else:
|
|
||||||
results.append("")
|
|
||||||
return results
|
|
||||||
|
|
||||||
def _kill_shell_children(self):
|
|
||||||
if self._shell_pid is None:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
subprocess.run(
|
|
||||||
["pkill", "-P", str(self._shell_pid)],
|
|
||||||
capture_output=True, timeout=5,
|
|
||||||
)
|
|
||||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _cleanup_temp_files(self):
|
|
||||||
for f in glob.glob(f"{self._temp_prefix}-*"):
|
|
||||||
if os.path.exists(f):
|
|
||||||
os.remove(f)
|
|
||||||
|
|
||||||
def _execute_oneshot(self, command: str, cwd: str = "", *,
|
|
||||||
timeout: int | None = None,
|
|
||||||
stdin_data: str | None = None) -> dict:
|
|
||||||
work_dir = cwd or self.cwd or os.getcwd()
|
|
||||||
effective_timeout = timeout or self.timeout
|
|
||||||
exec_command, sudo_stdin = self._prepare_command(command)
|
|
||||||
|
|
||||||
if sudo_stdin is not None and stdin_data is not None:
|
|
||||||
effective_stdin = sudo_stdin + stdin_data
|
|
||||||
elif sudo_stdin is not None:
|
|
||||||
effective_stdin = sudo_stdin
|
|
||||||
else:
|
|
||||||
effective_stdin = stdin_data
|
|
||||||
|
|
||||||
user_shell = _find_bash()
|
|
||||||
# Newline-separated wrapper (not `cmd; __hermes_rc=...` on one line).
|
|
||||||
# A trailing `; __hermes_rc` glued to `<<EOF` / a closing `EOF` line breaks
|
|
||||||
# heredoc parsing: the delimiter must be alone on its line, otherwise the
|
|
||||||
# rest of this script becomes heredoc body and leaks into stdout (e.g. gh
|
|
||||||
# issue/PR flows that use here-documents for bodies).
|
|
||||||
fenced_cmd = (
|
|
||||||
f"printf '{_OUTPUT_FENCE}'\n"
|
|
||||||
f"{exec_command}\n"
|
|
||||||
f"__hermes_rc=$?\n"
|
|
||||||
f"printf '{_OUTPUT_FENCE}'\n"
|
|
||||||
f"exit $__hermes_rc\n"
|
|
||||||
)
|
|
||||||
run_env = _make_run_env(self.env)
|
run_env = _make_run_env(self.env)
|
||||||
|
|
||||||
proc = subprocess.Popen(
|
proc = subprocess.Popen(
|
||||||
[user_shell, "-lic", fenced_cmd],
|
args,
|
||||||
text=True,
|
text=True,
|
||||||
cwd=work_dir,
|
|
||||||
env=run_env,
|
env=run_env,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
errors="replace",
|
errors="replace",
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.STDOUT,
|
stderr=subprocess.STDOUT,
|
||||||
stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL,
|
stdin=subprocess.PIPE if stdin_data is not None else subprocess.DEVNULL,
|
||||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||||
)
|
)
|
||||||
|
|
||||||
if effective_stdin is not None:
|
if stdin_data is not None:
|
||||||
def _write_stdin():
|
_pipe_stdin(proc, stdin_data)
|
||||||
|
|
||||||
|
return proc
|
||||||
|
|
||||||
|
def _kill_process(self, proc):
|
||||||
|
"""Kill the entire process group (all children)."""
|
||||||
|
try:
|
||||||
|
if _IS_WINDOWS:
|
||||||
|
proc.terminate()
|
||||||
|
else:
|
||||||
|
pgid = os.getpgid(proc.pid)
|
||||||
|
os.killpg(pgid, signal.SIGTERM)
|
||||||
try:
|
try:
|
||||||
proc.stdin.write(effective_stdin)
|
proc.wait(timeout=1.0)
|
||||||
proc.stdin.close()
|
except subprocess.TimeoutExpired:
|
||||||
except (BrokenPipeError, OSError):
|
os.killpg(pgid, signal.SIGKILL)
|
||||||
pass
|
except (ProcessLookupError, PermissionError):
|
||||||
threading.Thread(target=_write_stdin, daemon=True).start()
|
|
||||||
|
|
||||||
_output_chunks: list[str] = []
|
|
||||||
|
|
||||||
def _drain_stdout():
|
|
||||||
try:
|
try:
|
||||||
for line in proc.stdout:
|
proc.kill()
|
||||||
_output_chunks.append(line)
|
except Exception:
|
||||||
except ValueError:
|
|
||||||
pass
|
pass
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
proc.stdout.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
reader = threading.Thread(target=_drain_stdout, daemon=True)
|
def _update_cwd(self, result: dict):
|
||||||
reader.start()
|
"""Read CWD from temp file (local-only, no round-trip needed)."""
|
||||||
deadline = time.monotonic() + effective_timeout
|
try:
|
||||||
|
cwd_path = open(self._cwd_file).read().strip()
|
||||||
|
if cwd_path:
|
||||||
|
self.cwd = cwd_path
|
||||||
|
except (OSError, FileNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
while proc.poll() is None:
|
# Still strip the marker from output so it's not visible
|
||||||
if is_interrupted():
|
self._extract_cwd_from_output(result)
|
||||||
try:
|
|
||||||
if _IS_WINDOWS:
|
|
||||||
proc.terminate()
|
|
||||||
else:
|
|
||||||
pgid = os.getpgid(proc.pid)
|
|
||||||
os.killpg(pgid, signal.SIGTERM)
|
|
||||||
try:
|
|
||||||
proc.wait(timeout=1.0)
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
os.killpg(pgid, signal.SIGKILL)
|
|
||||||
except (ProcessLookupError, PermissionError):
|
|
||||||
proc.kill()
|
|
||||||
reader.join(timeout=2)
|
|
||||||
return {
|
|
||||||
"output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]",
|
|
||||||
"returncode": 130,
|
|
||||||
}
|
|
||||||
if time.monotonic() > deadline:
|
|
||||||
try:
|
|
||||||
if _IS_WINDOWS:
|
|
||||||
proc.terminate()
|
|
||||||
else:
|
|
||||||
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
|
||||||
except (ProcessLookupError, PermissionError):
|
|
||||||
proc.kill()
|
|
||||||
reader.join(timeout=2)
|
|
||||||
partial = "".join(_output_chunks)
|
|
||||||
timeout_msg = f"\n[Command timed out after {effective_timeout}s]"
|
|
||||||
return {
|
|
||||||
"output": partial + timeout_msg if partial else timeout_msg.lstrip(),
|
|
||||||
"returncode": 124,
|
|
||||||
}
|
|
||||||
time.sleep(0.2)
|
|
||||||
|
|
||||||
reader.join(timeout=5)
|
def cleanup(self):
|
||||||
output = _extract_fenced_output("".join(_output_chunks))
|
"""Clean up temp files."""
|
||||||
return {"output": output, "returncode": proc.returncode}
|
for f in (self._snapshot_path, self._cwd_file):
|
||||||
|
try:
|
||||||
|
os.unlink(f)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import uuid
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from tools.environments.modal_common import (
|
from tools.environments.modal_utils import (
|
||||||
BaseModalExecutionEnvironment,
|
BaseModalExecutionEnvironment,
|
||||||
ModalExecStart,
|
ModalExecStart,
|
||||||
PreparedModalExec,
|
PreparedModalExec,
|
||||||
|
|||||||
@@ -5,19 +5,19 @@ wrapper, while preserving Hermes' persistent snapshot behavior across sessions.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import shlex
|
import shlex
|
||||||
import threading
|
import threading
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from hermes_constants import get_hermes_home
|
from hermes_constants import get_hermes_home
|
||||||
from tools.environments.modal_common import (
|
from tools.environments.base import (
|
||||||
BaseModalExecutionEnvironment,
|
BaseEnvironment,
|
||||||
ModalExecStart,
|
_ThreadedProcessHandle,
|
||||||
PreparedModalExec,
|
_file_mtime_key,
|
||||||
|
_load_json_store,
|
||||||
|
_save_json_store,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -26,20 +26,12 @@ _SNAPSHOT_STORE = get_hermes_home() / "modal_snapshots.json"
|
|||||||
_DIRECT_SNAPSHOT_NAMESPACE = "direct"
|
_DIRECT_SNAPSHOT_NAMESPACE = "direct"
|
||||||
|
|
||||||
|
|
||||||
def _load_snapshots() -> Dict[str, str]:
|
def _load_snapshots() -> dict:
|
||||||
"""Load snapshot ID mapping from disk."""
|
return _load_json_store(_SNAPSHOT_STORE)
|
||||||
if _SNAPSHOT_STORE.exists():
|
|
||||||
try:
|
|
||||||
return json.loads(_SNAPSHOT_STORE.read_text())
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
def _save_snapshots(data: dict) -> None:
|
||||||
"""Persist snapshot ID mapping to disk."""
|
_save_json_store(_SNAPSHOT_STORE, data)
|
||||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
|
||||||
|
|
||||||
|
|
||||||
def _direct_snapshot_key(task_id: str) -> str:
|
def _direct_snapshot_key(task_id: str) -> str:
|
||||||
@@ -47,23 +39,18 @@ def _direct_snapshot_key(task_id: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _get_snapshot_restore_candidate(task_id: str) -> tuple[str | None, bool]:
|
def _get_snapshot_restore_candidate(task_id: str) -> tuple[str | None, bool]:
|
||||||
"""Return a snapshot id and whether it came from the legacy key format."""
|
|
||||||
snapshots = _load_snapshots()
|
snapshots = _load_snapshots()
|
||||||
|
|
||||||
namespaced_key = _direct_snapshot_key(task_id)
|
namespaced_key = _direct_snapshot_key(task_id)
|
||||||
snapshot_id = snapshots.get(namespaced_key)
|
snapshot_id = snapshots.get(namespaced_key)
|
||||||
if isinstance(snapshot_id, str) and snapshot_id:
|
if isinstance(snapshot_id, str) and snapshot_id:
|
||||||
return snapshot_id, False
|
return snapshot_id, False
|
||||||
|
|
||||||
legacy_snapshot_id = snapshots.get(task_id)
|
legacy_snapshot_id = snapshots.get(task_id)
|
||||||
if isinstance(legacy_snapshot_id, str) and legacy_snapshot_id:
|
if isinstance(legacy_snapshot_id, str) and legacy_snapshot_id:
|
||||||
return legacy_snapshot_id, True
|
return legacy_snapshot_id, True
|
||||||
|
|
||||||
return None, False
|
return None, False
|
||||||
|
|
||||||
|
|
||||||
def _store_direct_snapshot(task_id: str, snapshot_id: str) -> None:
|
def _store_direct_snapshot(task_id: str, snapshot_id: str) -> None:
|
||||||
"""Persist the direct Modal snapshot id under the direct namespace."""
|
|
||||||
snapshots = _load_snapshots()
|
snapshots = _load_snapshots()
|
||||||
snapshots[_direct_snapshot_key(task_id)] = snapshot_id
|
snapshots[_direct_snapshot_key(task_id)] = snapshot_id
|
||||||
snapshots.pop(task_id, None)
|
snapshots.pop(task_id, None)
|
||||||
@@ -71,10 +58,8 @@ def _store_direct_snapshot(task_id: str, snapshot_id: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _delete_direct_snapshot(task_id: str, snapshot_id: str | None = None) -> None:
|
def _delete_direct_snapshot(task_id: str, snapshot_id: str | None = None) -> None:
|
||||||
"""Remove direct Modal snapshot entries for a task, including legacy keys."""
|
|
||||||
snapshots = _load_snapshots()
|
snapshots = _load_snapshots()
|
||||||
updated = False
|
updated = False
|
||||||
|
|
||||||
for key in (_direct_snapshot_key(task_id), task_id):
|
for key in (_direct_snapshot_key(task_id), task_id):
|
||||||
value = snapshots.get(key)
|
value = snapshots.get(key)
|
||||||
if value is None:
|
if value is None:
|
||||||
@@ -82,13 +67,15 @@ def _delete_direct_snapshot(task_id: str, snapshot_id: str | None = None) -> Non
|
|||||||
if snapshot_id is None or value == snapshot_id:
|
if snapshot_id is None or value == snapshot_id:
|
||||||
snapshots.pop(key, None)
|
snapshots.pop(key, None)
|
||||||
updated = True
|
updated = True
|
||||||
|
|
||||||
if updated:
|
if updated:
|
||||||
_save_snapshots(snapshots)
|
_save_snapshots(snapshots)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_modal_image(image_spec: Any) -> Any:
|
def _resolve_modal_image(image_spec: Any) -> Any:
|
||||||
"""Convert registry references or snapshot ids into Modal image objects."""
|
"""Convert registry references or snapshot ids into Modal image objects.
|
||||||
|
|
||||||
|
Includes add_python support for ubuntu/debian images (absorbed from PR 4511).
|
||||||
|
"""
|
||||||
import modal as _modal
|
import modal as _modal
|
||||||
|
|
||||||
if not isinstance(image_spec, str):
|
if not isinstance(image_spec, str):
|
||||||
@@ -97,12 +84,22 @@ def _resolve_modal_image(image_spec: Any) -> Any:
|
|||||||
if image_spec.startswith("im-"):
|
if image_spec.startswith("im-"):
|
||||||
return _modal.Image.from_id(image_spec)
|
return _modal.Image.from_id(image_spec)
|
||||||
|
|
||||||
|
# PR 4511: add python to ubuntu/debian images that don't have it
|
||||||
|
lower = image_spec.lower()
|
||||||
|
add_python = any(base in lower for base in ("ubuntu", "debian"))
|
||||||
|
|
||||||
|
setup_commands = [
|
||||||
|
"RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; "
|
||||||
|
"python -m ensurepip --upgrade --default-pip 2>/dev/null || true",
|
||||||
|
]
|
||||||
|
if add_python:
|
||||||
|
setup_commands.insert(0,
|
||||||
|
"RUN apt-get update -qq && apt-get install -y -qq python3 python3-venv > /dev/null 2>&1 || true"
|
||||||
|
)
|
||||||
|
|
||||||
return _modal.Image.from_registry(
|
return _modal.Image.from_registry(
|
||||||
image_spec,
|
image_spec,
|
||||||
setup_dockerfile_commands=[
|
setup_dockerfile_commands=setup_commands,
|
||||||
"RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; "
|
|
||||||
"python -m ensurepip --upgrade --default-pip 2>/dev/null || true",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -138,19 +135,15 @@ class _AsyncWorker:
|
|||||||
self._thread.join(timeout=10)
|
self._thread.join(timeout=10)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class ModalEnvironment(BaseEnvironment):
|
||||||
class _DirectModalExecHandle:
|
"""Modal cloud execution via native Modal sandboxes.
|
||||||
thread: threading.Thread
|
|
||||||
result_holder: Dict[str, Any]
|
|
||||||
|
|
||||||
|
Spawn-per-call via _ThreadedProcessHandle wrapping async SDK calls.
|
||||||
class ModalEnvironment(BaseModalExecutionEnvironment):
|
cancel_fn wired to sandbox.terminate for interrupt support.
|
||||||
"""Modal cloud execution via native Modal sandboxes."""
|
"""
|
||||||
|
|
||||||
_stdin_mode = "heredoc"
|
_stdin_mode = "heredoc"
|
||||||
_poll_interval_seconds = 0.2
|
_snapshot_timeout = 60 # Modal cold starts can be slow
|
||||||
_interrupt_output = "[Command interrupted - Modal sandbox terminated]"
|
|
||||||
_unexpected_error_prefix = "Modal execution error"
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -170,6 +163,7 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
|||||||
self._app = None
|
self._app = None
|
||||||
self._worker = _AsyncWorker()
|
self._worker = _AsyncWorker()
|
||||||
self._synced_files: Dict[str, tuple] = {}
|
self._synced_files: Dict[str, tuple] = {}
|
||||||
|
self._last_sync_time: float = 0
|
||||||
|
|
||||||
sandbox_kwargs = dict(modal_sandbox_kwargs or {})
|
sandbox_kwargs = dict(modal_sandbox_kwargs or {})
|
||||||
|
|
||||||
@@ -199,27 +193,13 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
|||||||
remote_path=mount_entry["container_path"],
|
remote_path=mount_entry["container_path"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
logger.info(
|
for entry in iter_skills_files():
|
||||||
"Modal: mounting credential %s -> %s",
|
|
||||||
mount_entry["host_path"],
|
|
||||||
mount_entry["container_path"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mount individual skill files (symlinks filtered out).
|
|
||||||
skills_files = iter_skills_files()
|
|
||||||
for entry in skills_files:
|
|
||||||
cred_mounts.append(
|
cred_mounts.append(
|
||||||
_modal.Mount.from_local_file(
|
_modal.Mount.from_local_file(
|
||||||
entry["host_path"],
|
entry["host_path"],
|
||||||
remote_path=entry["container_path"],
|
remote_path=entry["container_path"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if skills_files:
|
|
||||||
logger.info("Modal: mounting %d skill files", len(skills_files))
|
|
||||||
|
|
||||||
# Mount host-side cache files (documents, images, audio,
|
|
||||||
# screenshots). New files arriving mid-session are picked up
|
|
||||||
# by _sync_files() before each command execution.
|
|
||||||
cache_files = iter_cache_files()
|
cache_files = iter_cache_files()
|
||||||
for entry in cache_files:
|
for entry in cache_files:
|
||||||
cred_mounts.append(
|
cred_mounts.append(
|
||||||
@@ -228,8 +208,6 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
|||||||
remote_path=entry["container_path"],
|
remote_path=entry["container_path"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if cache_files:
|
|
||||||
logger.info("Modal: mounting %d cache files", len(cache_files))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Modal: could not load credential file mounts: %s", e)
|
logger.debug("Modal: could not load credential file mounts: %s", e)
|
||||||
|
|
||||||
@@ -243,8 +221,7 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
|||||||
existing_mounts.extend(cred_mounts)
|
existing_mounts.extend(cred_mounts)
|
||||||
create_kwargs["mounts"] = existing_mounts
|
create_kwargs["mounts"] = existing_mounts
|
||||||
sandbox = await _modal.Sandbox.create.aio(
|
sandbox = await _modal.Sandbox.create.aio(
|
||||||
"sleep",
|
"sleep", "infinity",
|
||||||
"infinity",
|
|
||||||
image=image_spec,
|
image=image_spec,
|
||||||
app=app,
|
app=app,
|
||||||
timeout=int(create_kwargs.pop("timeout", 3600)),
|
timeout=int(create_kwargs.pop("timeout", 3600)),
|
||||||
@@ -255,57 +232,41 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
|||||||
try:
|
try:
|
||||||
target_image_spec = restored_snapshot_id or image
|
target_image_spec = restored_snapshot_id or image
|
||||||
try:
|
try:
|
||||||
# _resolve_modal_image keeps the Modal bootstrap fix together:
|
|
||||||
# it applies setup_dockerfile_commands with ensurepip before
|
|
||||||
# Modal builds registry images, while snapshot ids restore via
|
|
||||||
# modal.Image.from_id() without rebuilding.
|
|
||||||
effective_image = _resolve_modal_image(target_image_spec)
|
effective_image = _resolve_modal_image(target_image_spec)
|
||||||
self._app, self._sandbox = self._worker.run_coroutine(
|
self._app, self._sandbox = self._worker.run_coroutine(
|
||||||
_create_sandbox(effective_image),
|
_create_sandbox(effective_image), timeout=300,
|
||||||
timeout=300,
|
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if not restored_snapshot_id:
|
if not restored_snapshot_id:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Modal: failed to restore snapshot %s, retrying with base image: %s",
|
"Modal: failed to restore snapshot %s, retrying with base image: %s",
|
||||||
restored_snapshot_id[:20],
|
restored_snapshot_id[:20], exc,
|
||||||
exc,
|
|
||||||
)
|
)
|
||||||
_delete_direct_snapshot(self._task_id, restored_snapshot_id)
|
_delete_direct_snapshot(self._task_id, restored_snapshot_id)
|
||||||
base_image = _resolve_modal_image(image)
|
base_image = _resolve_modal_image(image)
|
||||||
self._app, self._sandbox = self._worker.run_coroutine(
|
self._app, self._sandbox = self._worker.run_coroutine(
|
||||||
_create_sandbox(base_image),
|
_create_sandbox(base_image), timeout=300,
|
||||||
timeout=300,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if restored_snapshot_id and restored_from_legacy_key:
|
if restored_snapshot_id and restored_from_legacy_key:
|
||||||
_store_direct_snapshot(self._task_id, restored_snapshot_id)
|
_store_direct_snapshot(self._task_id, restored_snapshot_id)
|
||||||
logger.info(
|
|
||||||
"Modal: migrated legacy snapshot entry for task %s",
|
|
||||||
self._task_id,
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
self._worker.stop()
|
self._worker.stop()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
logger.info("Modal: sandbox created (task=%s)", self._task_id)
|
logger.info("Modal: sandbox created (task=%s)", self._task_id)
|
||||||
|
self.init_session()
|
||||||
|
|
||||||
def _push_file_to_sandbox(self, host_path: str, container_path: str) -> bool:
|
def _push_file_to_sandbox(self, host_path: str, container_path: str) -> bool:
|
||||||
"""Push a single file into the sandbox if changed. Returns True if synced."""
|
"""Push a single file into the sandbox if changed."""
|
||||||
hp = Path(host_path)
|
file_key = _file_mtime_key(host_path)
|
||||||
try:
|
if file_key is None:
|
||||||
stat = hp.stat()
|
|
||||||
file_key = (stat.st_mtime, stat.st_size)
|
|
||||||
except OSError:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if self._synced_files.get(container_path) == file_key:
|
if self._synced_files.get(container_path) == file_key:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
content = hp.read_bytes()
|
content = Path(host_path).read_bytes()
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -326,85 +287,55 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def _sync_files(self) -> None:
|
def _sync_files(self) -> None:
|
||||||
"""Push credential, skill, and cache files into the running sandbox.
|
"""Push credential, skill, and cache files into the running sandbox."""
|
||||||
|
|
||||||
Runs before each command. Uses mtime+size caching so only changed
|
|
||||||
files are pushed (~13μs overhead in the no-op case). Cache files
|
|
||||||
are especially important here — new uploads/screenshots may appear
|
|
||||||
mid-session after sandbox creation.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
from tools.credential_files import (
|
from tools.credential_files import (
|
||||||
get_credential_file_mounts,
|
get_credential_file_mounts,
|
||||||
iter_skills_files,
|
iter_skills_files,
|
||||||
iter_cache_files,
|
iter_cache_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
for entry in get_credential_file_mounts():
|
for entry in get_credential_file_mounts():
|
||||||
if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]):
|
self._push_file_to_sandbox(entry["host_path"], entry["container_path"])
|
||||||
logger.debug("Modal: synced credential %s", entry["container_path"])
|
|
||||||
|
|
||||||
for entry in iter_skills_files():
|
for entry in iter_skills_files():
|
||||||
if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]):
|
self._push_file_to_sandbox(entry["host_path"], entry["container_path"])
|
||||||
logger.debug("Modal: synced skill file %s", entry["container_path"])
|
|
||||||
|
|
||||||
for entry in iter_cache_files():
|
for entry in iter_cache_files():
|
||||||
if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]):
|
self._push_file_to_sandbox(entry["host_path"], entry["container_path"])
|
||||||
logger.debug("Modal: synced cache file %s", entry["container_path"])
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Modal: file sync failed: %s", e)
|
logger.debug("Modal: file sync failed: %s", e)
|
||||||
|
|
||||||
def _before_execute(self) -> None:
|
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||||
self._sync_files()
|
timeout: int = 120,
|
||||||
|
stdin_data: str | None = None):
|
||||||
|
"""Return a _ThreadedProcessHandle wrapping an async Modal sandbox exec."""
|
||||||
|
sandbox = self._sandbox
|
||||||
|
worker = self._worker
|
||||||
|
|
||||||
def _start_modal_exec(self, prepared: PreparedModalExec) -> ModalExecStart:
|
def cancel():
|
||||||
full_command = f"cd {shlex.quote(prepared.cwd)} && {prepared.command}"
|
worker.run_coroutine(sandbox.terminate.aio(), timeout=15)
|
||||||
result_holder = {"value": None, "error": None}
|
|
||||||
|
|
||||||
def _run():
|
def exec_fn() -> tuple[str, int]:
|
||||||
try:
|
async def _do():
|
||||||
async def _do_execute():
|
args = ["bash"]
|
||||||
process = await self._sandbox.exec.aio(
|
if login:
|
||||||
"bash",
|
args.extend(["-l", "-c", cmd_string])
|
||||||
"-c",
|
else:
|
||||||
full_command,
|
args.extend(["-c", cmd_string])
|
||||||
timeout=prepared.timeout,
|
process = await sandbox.exec.aio(*args, timeout=timeout)
|
||||||
)
|
stdout = await process.stdout.read.aio()
|
||||||
stdout = await process.stdout.read.aio()
|
stderr = await process.stderr.read.aio()
|
||||||
stderr = await process.stderr.read.aio()
|
exit_code = await process.wait.aio()
|
||||||
exit_code = await process.wait.aio()
|
if isinstance(stdout, bytes):
|
||||||
if isinstance(stdout, bytes):
|
stdout = stdout.decode("utf-8", errors="replace")
|
||||||
stdout = stdout.decode("utf-8", errors="replace")
|
if isinstance(stderr, bytes):
|
||||||
if isinstance(stderr, bytes):
|
stderr = stderr.decode("utf-8", errors="replace")
|
||||||
stderr = stderr.decode("utf-8", errors="replace")
|
output = stdout
|
||||||
output = stdout
|
if stderr:
|
||||||
if stderr:
|
output = f"{stdout}\n{stderr}" if stdout else stderr
|
||||||
output = f"{stdout}\n{stderr}" if stdout else stderr
|
return output, exit_code
|
||||||
return self._result(output, exit_code)
|
|
||||||
|
|
||||||
result_holder["value"] = self._worker.run_coroutine(
|
return worker.run_coroutine(_do(), timeout=timeout + 30)
|
||||||
_do_execute(),
|
|
||||||
timeout=prepared.timeout + 30,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
result_holder["error"] = e
|
|
||||||
|
|
||||||
t = threading.Thread(target=_run, daemon=True)
|
return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||||
t.start()
|
|
||||||
return ModalExecStart(handle=_DirectModalExecHandle(thread=t, result_holder=result_holder))
|
|
||||||
|
|
||||||
def _poll_modal_exec(self, handle: _DirectModalExecHandle) -> dict | None:
|
|
||||||
if handle.thread.is_alive():
|
|
||||||
return None
|
|
||||||
if handle.result_holder["error"]:
|
|
||||||
return self._error_result(f"Modal execution error: {handle.result_holder['error']}")
|
|
||||||
return handle.result_holder["value"]
|
|
||||||
|
|
||||||
def _cancel_modal_exec(self, handle: _DirectModalExecHandle) -> None:
|
|
||||||
self._worker.run_coroutine(
|
|
||||||
self._sandbox.terminate.aio(),
|
|
||||||
timeout=15,
|
|
||||||
)
|
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
|
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
|
||||||
@@ -426,17 +357,13 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
|||||||
_store_direct_snapshot(self._task_id, snapshot_id)
|
_store_direct_snapshot(self._task_id, snapshot_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Modal: saved filesystem snapshot %s for task %s",
|
"Modal: saved filesystem snapshot %s for task %s",
|
||||||
snapshot_id[:20],
|
snapshot_id[:20], self._task_id,
|
||||||
self._task_id,
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Modal: filesystem snapshot failed: %s", e)
|
logger.warning("Modal: filesystem snapshot failed: %s", e)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._worker.run_coroutine(
|
self._worker.run_coroutine(self._sandbox.terminate.aio(), timeout=15)
|
||||||
self._sandbox.terminate.aio(),
|
|
||||||
timeout=15,
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -56,7 +56,15 @@ def wrap_modal_sudo_pipe(command: str, sudo_stdin: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
class BaseModalExecutionEnvironment(BaseEnvironment):
|
class BaseModalExecutionEnvironment(BaseEnvironment):
|
||||||
"""Common execute() flow for direct and managed Modal transports."""
|
"""Execution flow for the *managed* Modal transport (gateway-owned sandbox).
|
||||||
|
|
||||||
|
This deliberately overrides :meth:`BaseEnvironment.execute` because the
|
||||||
|
tool-gateway handles command preparation, CWD tracking, and env-snapshot
|
||||||
|
management on the server side. The base class's ``_wrap_command`` /
|
||||||
|
``_wait_for_process`` / snapshot machinery does not apply here — the
|
||||||
|
gateway owns that responsibility. See ``ManagedModalEnvironment`` for the
|
||||||
|
concrete subclass.
|
||||||
|
"""
|
||||||
|
|
||||||
_stdin_mode = "payload"
|
_stdin_mode = "payload"
|
||||||
_poll_interval_seconds = 0.25
|
_poll_interval_seconds = 0.25
|
||||||
@@ -124,7 +132,7 @@ class BaseModalExecutionEnvironment(BaseEnvironment):
|
|||||||
|
|
||||||
def _before_execute(self) -> None:
|
def _before_execute(self) -> None:
|
||||||
"""Hook for backends that need pre-exec sync or validation."""
|
"""Hook for backends that need pre-exec sync or validation."""
|
||||||
return None
|
pass
|
||||||
|
|
||||||
def _prepare_modal_exec(
|
def _prepare_modal_exec(
|
||||||
self,
|
self,
|
||||||
@@ -1,290 +0,0 @@
|
|||||||
"""Persistent shell mixin: file-based IPC protocol for long-lived bash shells."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import shlex
|
|
||||||
import subprocess
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from abc import abstractmethod
|
|
||||||
|
|
||||||
from tools.interrupt import is_interrupted
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class PersistentShellMixin:
|
|
||||||
"""Mixin that adds persistent shell capability to any BaseEnvironment.
|
|
||||||
|
|
||||||
Subclasses must implement ``_spawn_shell_process()``, ``_read_temp_files()``,
|
|
||||||
``_kill_shell_children()``, ``_execute_oneshot()``, and ``_cleanup_temp_files()``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
persistent: bool
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _spawn_shell_process(self) -> subprocess.Popen: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _read_temp_files(self, *paths: str) -> list[str]: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _kill_shell_children(self): ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _execute_oneshot(self, command: str, cwd: str, *,
|
|
||||||
timeout: int | None = None,
|
|
||||||
stdin_data: str | None = None) -> dict: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _cleanup_temp_files(self): ...
|
|
||||||
|
|
||||||
_session_id: str = ""
|
|
||||||
_poll_interval_start: float = 0.01 # initial poll interval (10ms)
|
|
||||||
_poll_interval_max: float = 0.25 # max poll interval (250ms) — reduces I/O for long commands
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _temp_prefix(self) -> str:
|
|
||||||
return f"/tmp/hermes-persistent-{self._session_id}"
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Lifecycle
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _init_persistent_shell(self):
|
|
||||||
self._shell_lock = threading.Lock()
|
|
||||||
self._shell_proc: subprocess.Popen | None = None
|
|
||||||
self._shell_alive: bool = False
|
|
||||||
self._shell_pid: int | None = None
|
|
||||||
|
|
||||||
self._session_id = uuid.uuid4().hex[:12]
|
|
||||||
p = self._temp_prefix
|
|
||||||
self._pshell_stdout = f"{p}-stdout"
|
|
||||||
self._pshell_stderr = f"{p}-stderr"
|
|
||||||
self._pshell_status = f"{p}-status"
|
|
||||||
self._pshell_cwd = f"{p}-cwd"
|
|
||||||
self._pshell_pid_file = f"{p}-pid"
|
|
||||||
|
|
||||||
self._shell_proc = self._spawn_shell_process()
|
|
||||||
self._shell_alive = True
|
|
||||||
|
|
||||||
self._drain_thread = threading.Thread(
|
|
||||||
target=self._drain_shell_output, daemon=True,
|
|
||||||
)
|
|
||||||
self._drain_thread.start()
|
|
||||||
|
|
||||||
init_script = (
|
|
||||||
f"export TERM=${{TERM:-dumb}}\n"
|
|
||||||
f"touch {self._pshell_stdout} {self._pshell_stderr} "
|
|
||||||
f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n"
|
|
||||||
f"echo $$ > {self._pshell_pid_file}\n"
|
|
||||||
f"pwd > {self._pshell_cwd}\n"
|
|
||||||
)
|
|
||||||
self._send_to_shell(init_script)
|
|
||||||
|
|
||||||
deadline = time.monotonic() + 3.0
|
|
||||||
while time.monotonic() < deadline:
|
|
||||||
pid_str = self._read_temp_files(self._pshell_pid_file)[0].strip()
|
|
||||||
if pid_str.isdigit():
|
|
||||||
self._shell_pid = int(pid_str)
|
|
||||||
break
|
|
||||||
time.sleep(0.05)
|
|
||||||
else:
|
|
||||||
logger.warning("Could not read persistent shell PID")
|
|
||||||
self._shell_pid = None
|
|
||||||
|
|
||||||
if self._shell_pid:
|
|
||||||
logger.info(
|
|
||||||
"Persistent shell started (session=%s, pid=%d)",
|
|
||||||
self._session_id, self._shell_pid,
|
|
||||||
)
|
|
||||||
|
|
||||||
reported_cwd = self._read_temp_files(self._pshell_cwd)[0].strip()
|
|
||||||
if reported_cwd:
|
|
||||||
self.cwd = reported_cwd
|
|
||||||
|
|
||||||
def _cleanup_persistent_shell(self):
|
|
||||||
if self._shell_proc is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self._session_id:
|
|
||||||
self._cleanup_temp_files()
|
|
||||||
|
|
||||||
try:
|
|
||||||
self._shell_proc.stdin.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
self._shell_proc.terminate()
|
|
||||||
self._shell_proc.wait(timeout=3)
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
self._shell_proc.kill()
|
|
||||||
|
|
||||||
self._shell_alive = False
|
|
||||||
self._shell_proc = None
|
|
||||||
|
|
||||||
if hasattr(self, "_drain_thread") and self._drain_thread.is_alive():
|
|
||||||
self._drain_thread.join(timeout=1.0)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# execute() / cleanup() — shared dispatcher, subclasses inherit
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def execute(self, command: str, cwd: str = "", *,
|
|
||||||
timeout: int | None = None,
|
|
||||||
stdin_data: str | None = None) -> dict:
|
|
||||||
if self.persistent:
|
|
||||||
return self._execute_persistent(
|
|
||||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
|
||||||
)
|
|
||||||
return self._execute_oneshot(
|
|
||||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
def execute_oneshot(self, command: str, cwd: str = "", *,
|
|
||||||
timeout: int | None = None,
|
|
||||||
stdin_data: str | None = None) -> dict:
|
|
||||||
"""Always use the oneshot (non-persistent) execution path.
|
|
||||||
|
|
||||||
This bypasses _shell_lock so it can run concurrently with a
|
|
||||||
long-running command in the persistent shell — used by
|
|
||||||
execute_code's file-based RPC polling thread.
|
|
||||||
"""
|
|
||||||
return self._execute_oneshot(
|
|
||||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
if self.persistent:
|
|
||||||
self._cleanup_persistent_shell()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Shell I/O
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _drain_shell_output(self):
|
|
||||||
try:
|
|
||||||
for _ in self._shell_proc.stdout:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self._shell_alive = False
|
|
||||||
|
|
||||||
def _send_to_shell(self, text: str):
|
|
||||||
if not self._shell_alive or self._shell_proc is None:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
self._shell_proc.stdin.write(text)
|
|
||||||
self._shell_proc.stdin.flush()
|
|
||||||
except (BrokenPipeError, OSError):
|
|
||||||
self._shell_alive = False
|
|
||||||
|
|
||||||
def _read_persistent_output(self) -> tuple[str, int, str]:
|
|
||||||
stdout, stderr, status_raw, cwd = self._read_temp_files(
|
|
||||||
self._pshell_stdout, self._pshell_stderr,
|
|
||||||
self._pshell_status, self._pshell_cwd,
|
|
||||||
)
|
|
||||||
output = self._merge_output(stdout, stderr)
|
|
||||||
status = status_raw.strip()
|
|
||||||
if ":" in status:
|
|
||||||
status = status.split(":", 1)[1]
|
|
||||||
try:
|
|
||||||
exit_code = int(status.strip())
|
|
||||||
except ValueError:
|
|
||||||
exit_code = 1
|
|
||||||
return output, exit_code, cwd.strip()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Execution
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _execute_persistent(self, command: str, cwd: str, *,
|
|
||||||
timeout: int | None = None,
|
|
||||||
stdin_data: str | None = None) -> dict:
|
|
||||||
if not self._shell_alive:
|
|
||||||
logger.info("Persistent shell died, restarting...")
|
|
||||||
self._init_persistent_shell()
|
|
||||||
|
|
||||||
exec_command, sudo_stdin = self._prepare_command(command)
|
|
||||||
effective_timeout = timeout or self.timeout
|
|
||||||
if stdin_data or sudo_stdin:
|
|
||||||
return self._execute_oneshot(
|
|
||||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._shell_lock:
|
|
||||||
return self._execute_persistent_locked(
|
|
||||||
exec_command, cwd, effective_timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _execute_persistent_locked(self, command: str, cwd: str,
|
|
||||||
timeout: int) -> dict:
|
|
||||||
work_dir = cwd or self.cwd
|
|
||||||
cmd_id = uuid.uuid4().hex[:8]
|
|
||||||
truncate = (
|
|
||||||
f": > {self._pshell_stdout}\n"
|
|
||||||
f": > {self._pshell_stderr}\n"
|
|
||||||
f": > {self._pshell_status}\n"
|
|
||||||
)
|
|
||||||
self._send_to_shell(truncate)
|
|
||||||
escaped = command.replace("'", "'\\''")
|
|
||||||
|
|
||||||
ipc_script = (
|
|
||||||
f"cd {shlex.quote(work_dir)}\n"
|
|
||||||
f"eval '{escaped}' < /dev/null > {self._pshell_stdout} 2> {self._pshell_stderr}\n"
|
|
||||||
f"__EC=$?\n"
|
|
||||||
f"pwd > {self._pshell_cwd}\n"
|
|
||||||
f"echo {cmd_id}:$__EC > {self._pshell_status}\n"
|
|
||||||
)
|
|
||||||
self._send_to_shell(ipc_script)
|
|
||||||
deadline = time.monotonic() + timeout
|
|
||||||
poll_interval = self._poll_interval_start # starts at 10ms, backs off to 250ms
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if is_interrupted():
|
|
||||||
self._kill_shell_children()
|
|
||||||
output, _, _ = self._read_persistent_output()
|
|
||||||
return {
|
|
||||||
"output": output + "\n[Command interrupted]",
|
|
||||||
"returncode": 130,
|
|
||||||
}
|
|
||||||
|
|
||||||
if time.monotonic() > deadline:
|
|
||||||
self._kill_shell_children()
|
|
||||||
output, _, _ = self._read_persistent_output()
|
|
||||||
if output:
|
|
||||||
return {
|
|
||||||
"output": output + f"\n[Command timed out after {timeout}s]",
|
|
||||||
"returncode": 124,
|
|
||||||
}
|
|
||||||
return self._timeout_result(timeout)
|
|
||||||
|
|
||||||
if not self._shell_alive:
|
|
||||||
return {
|
|
||||||
"output": "Persistent shell died during execution",
|
|
||||||
"returncode": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
status_content = self._read_temp_files(self._pshell_status)[0].strip()
|
|
||||||
if status_content.startswith(cmd_id + ":"):
|
|
||||||
break
|
|
||||||
|
|
||||||
time.sleep(poll_interval)
|
|
||||||
# Exponential backoff: fast start (10ms) for quick commands,
|
|
||||||
# ramps up to 250ms for long-running commands — reduces I/O by 10-25x
|
|
||||||
# on WSL2 where polling keeps the VM hot and memory pressure high.
|
|
||||||
poll_interval = min(poll_interval * 1.5, self._poll_interval_max)
|
|
||||||
|
|
||||||
output, exit_code, new_cwd = self._read_persistent_output()
|
|
||||||
if new_cwd:
|
|
||||||
self.cwd = new_cwd
|
|
||||||
return {"output": output, "returncode": exit_code}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _merge_output(stdout: str, stderr: str) -> str:
|
|
||||||
parts = []
|
|
||||||
if stdout.strip():
|
|
||||||
parts.append(stdout.rstrip("\n"))
|
|
||||||
if stderr.strip():
|
|
||||||
parts.append(stderr.rstrip("\n"))
|
|
||||||
return "\n".join(parts)
|
|
||||||
@@ -5,20 +5,22 @@ Supports configurable resource limits and optional filesystem persistence
|
|||||||
via writable overlay directories that survive across sessions.
|
via writable overlay directories that survive across sessions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shlex
|
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from hermes_constants import get_hermes_home
|
from hermes_constants import get_hermes_home
|
||||||
from tools.environments.base import BaseEnvironment
|
from tools.environments.base import (
|
||||||
from tools.interrupt import is_interrupted
|
BaseEnvironment,
|
||||||
|
_load_json_store,
|
||||||
|
_popen_bash,
|
||||||
|
_save_json_store,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -26,11 +28,7 @@ _SNAPSHOT_STORE = get_hermes_home() / "singularity_snapshots.json"
|
|||||||
|
|
||||||
|
|
||||||
def _find_singularity_executable() -> str:
|
def _find_singularity_executable() -> str:
|
||||||
"""Locate the apptainer or singularity CLI binary.
|
"""Locate the apptainer or singularity CLI binary."""
|
||||||
|
|
||||||
Returns the executable name (``"apptainer"`` or ``"singularity"``).
|
|
||||||
Raises ``RuntimeError`` with install instructions if neither is found.
|
|
||||||
"""
|
|
||||||
if shutil.which("apptainer"):
|
if shutil.which("apptainer"):
|
||||||
return "apptainer"
|
return "apptainer"
|
||||||
if shutil.which("singularity"):
|
if shutil.which("singularity"):
|
||||||
@@ -43,66 +41,34 @@ def _find_singularity_executable() -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _ensure_singularity_available() -> str:
|
def _ensure_singularity_available() -> str:
|
||||||
"""Preflight check: resolve the executable and verify it responds.
|
"""Preflight check: resolve the executable and verify it responds."""
|
||||||
|
|
||||||
Returns the executable name on success.
|
|
||||||
Raises ``RuntimeError`` with an actionable message on failure.
|
|
||||||
"""
|
|
||||||
exe = _find_singularity_executable()
|
exe = _find_singularity_executable()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
[exe, "version"],
|
[exe, "version"], capture_output=True, text=True, timeout=10,
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=10,
|
|
||||||
)
|
)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Singularity backend selected but the resolved executable '{exe}' "
|
f"Singularity backend selected but '{exe}' could not be executed."
|
||||||
"could not be executed. Check your installation."
|
|
||||||
)
|
)
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"'{exe} version' timed out.")
|
||||||
f"'{exe} version' timed out. The runtime may be misconfigured."
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
stderr = result.stderr.strip()[:200]
|
stderr = result.stderr.strip()[:200]
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"'{exe} version' failed (exit code {result.returncode}): {stderr}")
|
||||||
f"'{exe} version' failed (exit code {result.returncode}): {stderr}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return exe
|
return exe
|
||||||
|
|
||||||
|
|
||||||
def _load_snapshots() -> Dict[str, str]:
|
def _load_snapshots() -> dict:
|
||||||
if _SNAPSHOT_STORE.exists():
|
return _load_json_store(_SNAPSHOT_STORE)
|
||||||
try:
|
|
||||||
return json.loads(_SNAPSHOT_STORE.read_text())
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
def _save_snapshots(data: dict) -> None:
|
||||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
_save_json_store(_SNAPSHOT_STORE, data)
|
||||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
# Singularity helpers (scratch dir, SIF cache, SIF building)
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _get_scratch_dir() -> Path:
|
def _get_scratch_dir() -> Path:
|
||||||
"""Get the best directory for Singularity sandboxes.
|
|
||||||
|
|
||||||
Resolution order:
|
|
||||||
1. TERMINAL_SCRATCH_DIR (explicit override)
|
|
||||||
2. TERMINAL_SANDBOX_DIR / singularity (shared sandbox root)
|
|
||||||
3. /scratch (common on HPC clusters)
|
|
||||||
4. ~/.hermes/sandboxes/singularity (fallback)
|
|
||||||
"""
|
|
||||||
custom_scratch = os.getenv("TERMINAL_SCRATCH_DIR")
|
custom_scratch = os.getenv("TERMINAL_SCRATCH_DIR")
|
||||||
if custom_scratch:
|
if custom_scratch:
|
||||||
scratch_path = Path(custom_scratch)
|
scratch_path = Path(custom_scratch)
|
||||||
@@ -124,7 +90,6 @@ def _get_scratch_dir() -> Path:
|
|||||||
|
|
||||||
|
|
||||||
def _get_apptainer_cache_dir() -> Path:
|
def _get_apptainer_cache_dir() -> Path:
|
||||||
"""Get the Apptainer cache directory for SIF images."""
|
|
||||||
cache_dir = os.getenv("APPTAINER_CACHEDIR")
|
cache_dir = os.getenv("APPTAINER_CACHEDIR")
|
||||||
if cache_dir:
|
if cache_dir:
|
||||||
cache_path = Path(cache_dir)
|
cache_path = Path(cache_dir)
|
||||||
@@ -140,11 +105,6 @@ _sif_build_lock = threading.Lock()
|
|||||||
|
|
||||||
|
|
||||||
def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
||||||
"""Get or build a SIF image from a docker:// URL.
|
|
||||||
|
|
||||||
Returns the path unchanged if it's already a .sif file.
|
|
||||||
For docker:// URLs, checks the cache and builds if needed.
|
|
||||||
"""
|
|
||||||
if image.endswith('.sif') and Path(image).exists():
|
if image.endswith('.sif') and Path(image).exists():
|
||||||
return image
|
return image
|
||||||
if not image.startswith('docker://'):
|
if not image.startswith('docker://'):
|
||||||
@@ -193,19 +153,12 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
# SingularityEnvironment
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class SingularityEnvironment(BaseEnvironment):
|
class SingularityEnvironment(BaseEnvironment):
|
||||||
"""Hardened Singularity/Apptainer container with resource limits and persistence.
|
"""Hardened Singularity/Apptainer container with resource limits and persistence.
|
||||||
|
|
||||||
Security: --containall (isolated PID/IPC/mount namespaces, no host home mount),
|
Spawn-per-call: every execute() spawns a fresh ``apptainer exec ... bash -c`` process.
|
||||||
--no-home, writable-tmpfs for scratch space. The container cannot see or modify
|
Session snapshot preserves env vars across calls.
|
||||||
the host filesystem outside of explicitly bound paths.
|
CWD persists via in-band stdout markers.
|
||||||
|
|
||||||
Persistence: when enabled, the writable overlay directory is preserved across
|
|
||||||
sessions so installed packages and files survive cleanup/restore.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -227,12 +180,9 @@ class SingularityEnvironment(BaseEnvironment):
|
|||||||
self._persistent = persistent_filesystem
|
self._persistent = persistent_filesystem
|
||||||
self._task_id = task_id
|
self._task_id = task_id
|
||||||
self._overlay_dir: Optional[Path] = None
|
self._overlay_dir: Optional[Path] = None
|
||||||
|
|
||||||
# Resource limits
|
|
||||||
self._cpu = cpu
|
self._cpu = cpu
|
||||||
self._memory = memory
|
self._memory = memory
|
||||||
|
|
||||||
# Persistent overlay directory
|
|
||||||
if self._persistent:
|
if self._persistent:
|
||||||
overlay_base = _get_scratch_dir() / "hermes-overlays"
|
overlay_base = _get_scratch_dir() / "hermes-overlays"
|
||||||
overlay_base.mkdir(parents=True, exist_ok=True)
|
overlay_base.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -240,42 +190,26 @@ class SingularityEnvironment(BaseEnvironment):
|
|||||||
self._overlay_dir.mkdir(parents=True, exist_ok=True)
|
self._overlay_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
self._start_instance()
|
self._start_instance()
|
||||||
|
self.init_session()
|
||||||
|
|
||||||
def _start_instance(self):
|
def _start_instance(self):
|
||||||
cmd = [self.executable, "instance", "start"]
|
cmd = [self.executable, "instance", "start"]
|
||||||
|
|
||||||
# Security: full isolation from host
|
|
||||||
cmd.extend(["--containall", "--no-home"])
|
cmd.extend(["--containall", "--no-home"])
|
||||||
|
|
||||||
# Writable layer
|
|
||||||
if self._persistent and self._overlay_dir:
|
if self._persistent and self._overlay_dir:
|
||||||
# Persistent writable overlay -- survives across restarts
|
|
||||||
cmd.extend(["--overlay", str(self._overlay_dir)])
|
cmd.extend(["--overlay", str(self._overlay_dir)])
|
||||||
else:
|
else:
|
||||||
cmd.append("--writable-tmpfs")
|
cmd.append("--writable-tmpfs")
|
||||||
|
|
||||||
# Mount credential files and skills directory (read-only).
|
|
||||||
try:
|
try:
|
||||||
from tools.credential_files import get_credential_file_mounts, get_skills_directory_mount
|
from tools.credential_files import get_credential_file_mounts, get_skills_directory_mount
|
||||||
|
|
||||||
for mount_entry in get_credential_file_mounts():
|
for mount_entry in get_credential_file_mounts():
|
||||||
cmd.extend(["--bind", f"{mount_entry['host_path']}:{mount_entry['container_path']}:ro"])
|
cmd.extend(["--bind", f"{mount_entry['host_path']}:{mount_entry['container_path']}:ro"])
|
||||||
logger.info(
|
|
||||||
"Singularity: binding credential %s -> %s",
|
|
||||||
mount_entry["host_path"],
|
|
||||||
mount_entry["container_path"],
|
|
||||||
)
|
|
||||||
for skills_mount in get_skills_directory_mount():
|
for skills_mount in get_skills_directory_mount():
|
||||||
cmd.extend(["--bind", f"{skills_mount['host_path']}:{skills_mount['container_path']}:ro"])
|
cmd.extend(["--bind", f"{skills_mount['host_path']}:{skills_mount['container_path']}:ro"])
|
||||||
logger.info(
|
|
||||||
"Singularity: binding skills dir %s -> %s",
|
|
||||||
skills_mount["host_path"],
|
|
||||||
skills_mount["container_path"],
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Singularity: could not load credential/skills mounts: %s", e)
|
logger.debug("Singularity: could not load credential/skills mounts: %s", e)
|
||||||
|
|
||||||
# Resource limits (cgroup-based, may require root or appropriate config)
|
|
||||||
if self._memory > 0:
|
if self._memory > 0:
|
||||||
cmd.extend(["--memory", f"{self._memory}M"])
|
cmd.extend(["--memory", f"{self._memory}M"])
|
||||||
if self._cpu > 0:
|
if self._cpu > 0:
|
||||||
@@ -293,89 +227,24 @@ class SingularityEnvironment(BaseEnvironment):
|
|||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
raise RuntimeError("Instance start timed out")
|
raise RuntimeError("Instance start timed out")
|
||||||
|
|
||||||
def execute(self, command: str, cwd: str = "", *,
|
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||||
timeout: int | None = None,
|
timeout: int = 120,
|
||||||
stdin_data: str | None = None) -> dict:
|
stdin_data: str | None = None) -> subprocess.Popen:
|
||||||
|
"""Spawn a bash process inside the Singularity instance."""
|
||||||
if not self._instance_started:
|
if not self._instance_started:
|
||||||
return {"output": "Instance not started", "returncode": -1}
|
raise RuntimeError("Singularity instance not started")
|
||||||
|
|
||||||
effective_timeout = timeout or self.timeout
|
cmd = [self.executable, "exec",
|
||||||
work_dir = cwd or self.cwd
|
f"instance://{self.instance_id}"]
|
||||||
exec_command, sudo_stdin = self._prepare_command(command)
|
if login:
|
||||||
|
cmd.extend(["bash", "-l", "-c", cmd_string])
|
||||||
# Merge sudo password (if any) with caller-supplied stdin_data.
|
|
||||||
if sudo_stdin is not None and stdin_data is not None:
|
|
||||||
effective_stdin = sudo_stdin + stdin_data
|
|
||||||
elif sudo_stdin is not None:
|
|
||||||
effective_stdin = sudo_stdin
|
|
||||||
else:
|
else:
|
||||||
effective_stdin = stdin_data
|
cmd.extend(["bash", "-c", cmd_string])
|
||||||
|
|
||||||
# apptainer exec --pwd doesn't expand ~, so prepend a cd into the command.
|
return _popen_bash(cmd, stdin_data)
|
||||||
# Keep ~ unquoted (for shell expansion) and quote only the subpath.
|
|
||||||
if work_dir == "~":
|
|
||||||
exec_command = f"cd ~ && {exec_command}"
|
|
||||||
work_dir = "/tmp"
|
|
||||||
elif work_dir.startswith("~/"):
|
|
||||||
exec_command = f"cd ~/{shlex.quote(work_dir[2:])} && {exec_command}"
|
|
||||||
work_dir = "/tmp"
|
|
||||||
|
|
||||||
cmd = [self.executable, "exec", "--pwd", work_dir,
|
|
||||||
f"instance://{self.instance_id}",
|
|
||||||
"bash", "-c", exec_command]
|
|
||||||
|
|
||||||
try:
|
|
||||||
import time as _time
|
|
||||||
_output_chunks = []
|
|
||||||
proc = subprocess.Popen(
|
|
||||||
cmd,
|
|
||||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
|
||||||
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
|
|
||||||
text=True,
|
|
||||||
)
|
|
||||||
if effective_stdin:
|
|
||||||
try:
|
|
||||||
proc.stdin.write(effective_stdin)
|
|
||||||
proc.stdin.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _drain():
|
|
||||||
try:
|
|
||||||
for line in proc.stdout:
|
|
||||||
_output_chunks.append(line)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
reader = threading.Thread(target=_drain, daemon=True)
|
|
||||||
reader.start()
|
|
||||||
deadline = _time.monotonic() + effective_timeout
|
|
||||||
|
|
||||||
while proc.poll() is None:
|
|
||||||
if is_interrupted():
|
|
||||||
proc.terminate()
|
|
||||||
try:
|
|
||||||
proc.wait(timeout=1)
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
proc.kill()
|
|
||||||
reader.join(timeout=2)
|
|
||||||
return {
|
|
||||||
"output": "".join(_output_chunks) + "\n[Command interrupted]",
|
|
||||||
"returncode": 130,
|
|
||||||
}
|
|
||||||
if _time.monotonic() > deadline:
|
|
||||||
proc.kill()
|
|
||||||
reader.join(timeout=2)
|
|
||||||
return self._timeout_result(effective_timeout)
|
|
||||||
_time.sleep(0.2)
|
|
||||||
|
|
||||||
reader.join(timeout=5)
|
|
||||||
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
|
|
||||||
except Exception as e:
|
|
||||||
return {"output": f"Singularity execution error: {e}", "returncode": 1}
|
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
"""Stop the instance. If persistent, the overlay dir survives for next creation."""
|
"""Stop the instance. If persistent, the overlay dir survives."""
|
||||||
if self._instance_started:
|
if self._instance_started:
|
||||||
try:
|
try:
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
@@ -387,7 +256,6 @@ class SingularityEnvironment(BaseEnvironment):
|
|||||||
logger.warning("Failed to stop Singularity instance %s: %s", self.instance_id, e)
|
logger.warning("Failed to stop Singularity instance %s: %s", self.instance_id, e)
|
||||||
self._instance_started = False
|
self._instance_started = False
|
||||||
|
|
||||||
# Record overlay path for persistence restoration
|
|
||||||
if self._persistent and self._overlay_dir:
|
if self._persistent and self._overlay_dir:
|
||||||
snapshots = _load_snapshots()
|
snapshots = _load_snapshots()
|
||||||
snapshots[self._task_id] = str(self._overlay_dir)
|
snapshots[self._task_id] = str(self._overlay_dir)
|
||||||
|
|||||||
@@ -5,13 +5,9 @@ import shlex
|
|||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from tools.environments.base import BaseEnvironment
|
from tools.environments.base import BaseEnvironment, _popen_bash
|
||||||
from tools.environments.persistent_shell import PersistentShellMixin
|
|
||||||
from tools.interrupt import is_interrupted
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -24,32 +20,22 @@ def _ensure_ssh_available() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
class SSHEnvironment(BaseEnvironment):
|
||||||
"""Run commands on a remote machine over SSH.
|
"""Run commands on a remote machine over SSH.
|
||||||
|
|
||||||
Uses SSH ControlMaster for connection persistence so subsequent
|
Spawn-per-call: every execute() spawns a fresh ``ssh ... bash -c`` process.
|
||||||
commands are fast. Security benefit: the agent cannot modify its
|
Session snapshot preserves env vars across calls.
|
||||||
own code since execution happens on a separate machine.
|
CWD persists via in-band stdout markers.
|
||||||
|
Uses SSH ControlMaster for connection reuse.
|
||||||
Foreground commands are interruptible: the local ssh process is killed
|
|
||||||
and a remote kill is attempted over the ControlMaster socket.
|
|
||||||
|
|
||||||
When ``persistent=True``, a single long-lived bash shell is kept alive
|
|
||||||
over SSH and state (cwd, env vars, shell variables) persists across
|
|
||||||
``execute()`` calls. Output capture uses file-based IPC on the remote
|
|
||||||
host (stdout/stderr/exit-code written to temp files, polled via fast
|
|
||||||
ControlMaster one-shot reads).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, host: str, user: str, cwd: str = "~",
|
def __init__(self, host: str, user: str, cwd: str = "~",
|
||||||
timeout: int = 60, port: int = 22, key_path: str = "",
|
timeout: int = 60, port: int = 22, key_path: str = ""):
|
||||||
persistent: bool = False):
|
|
||||||
super().__init__(cwd=cwd, timeout=timeout)
|
super().__init__(cwd=cwd, timeout=timeout)
|
||||||
self.host = host
|
self.host = host
|
||||||
self.user = user
|
self.user = user
|
||||||
self.port = port
|
self.port = port
|
||||||
self.key_path = key_path
|
self.key_path = key_path
|
||||||
self.persistent = persistent
|
|
||||||
|
|
||||||
self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh"
|
self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh"
|
||||||
self.control_dir.mkdir(parents=True, exist_ok=True)
|
self.control_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -57,10 +43,10 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
|||||||
_ensure_ssh_available()
|
_ensure_ssh_available()
|
||||||
self._establish_connection()
|
self._establish_connection()
|
||||||
self._remote_home = self._detect_remote_home()
|
self._remote_home = self._detect_remote_home()
|
||||||
self._sync_skills_and_credentials()
|
self._last_sync_time: float = 0 # guarantees first _before_execute syncs
|
||||||
|
self._sync_files()
|
||||||
|
|
||||||
if self.persistent:
|
self.init_session()
|
||||||
self._init_persistent_shell()
|
|
||||||
|
|
||||||
def _build_ssh_command(self, extra_args: list | None = None) -> list:
|
def _build_ssh_command(self, extra_args: list | None = None) -> list:
|
||||||
cmd = ["ssh"]
|
cmd = ["ssh"]
|
||||||
@@ -102,12 +88,11 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
|||||||
return home
|
return home
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
# Fallback: guess from username
|
|
||||||
if self.user == "root":
|
if self.user == "root":
|
||||||
return "/root"
|
return "/root"
|
||||||
return f"/home/{self.user}"
|
return f"/home/{self.user}"
|
||||||
|
|
||||||
def _sync_skills_and_credentials(self) -> None:
|
def _sync_files(self) -> None:
|
||||||
"""Rsync skills directory and credential files to the remote host."""
|
"""Rsync skills directory and credential files to the remote host."""
|
||||||
try:
|
try:
|
||||||
container_base = f"{self._remote_home}/.hermes"
|
container_base = f"{self._remote_home}/.hermes"
|
||||||
@@ -122,7 +107,6 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
|||||||
rsync_base.extend(["-e", ssh_opts])
|
rsync_base.extend(["-e", ssh_opts])
|
||||||
dest_prefix = f"{self.user}@{self.host}"
|
dest_prefix = f"{self.user}@{self.host}"
|
||||||
|
|
||||||
# Sync individual credential files (remap /root/.hermes to detected home)
|
|
||||||
for mount_entry in get_credential_file_mounts():
|
for mount_entry in get_credential_file_mounts():
|
||||||
remote_path = mount_entry["container_path"].replace("/root/.hermes", container_base, 1)
|
remote_path = mount_entry["container_path"].replace("/root/.hermes", container_base, 1)
|
||||||
parent_dir = str(Path(remote_path).parent)
|
parent_dir = str(Path(remote_path).parent)
|
||||||
@@ -136,7 +120,6 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
|||||||
else:
|
else:
|
||||||
logger.debug("SSH: rsync credential failed: %s", result.stderr.strip())
|
logger.debug("SSH: rsync credential failed: %s", result.stderr.strip())
|
||||||
|
|
||||||
# Sync skill directories (local + external, remap to detected home)
|
|
||||||
for skills_mount in get_skills_directory_mount(container_base=container_base):
|
for skills_mount in get_skills_directory_mount(container_base=container_base):
|
||||||
remote_path = skills_mount["container_path"]
|
remote_path = skills_mount["container_path"]
|
||||||
mkdir_cmd = self._build_ssh_command()
|
mkdir_cmd = self._build_ssh_command()
|
||||||
@@ -154,152 +137,19 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("SSH: could not sync skills/credentials: %s", e)
|
logger.debug("SSH: could not sync skills/credentials: %s", e)
|
||||||
|
|
||||||
def execute(self, command: str, cwd: str = "", *,
|
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||||
timeout: int | None = None,
|
timeout: int = 120,
|
||||||
stdin_data: str | None = None) -> dict:
|
stdin_data: str | None = None) -> subprocess.Popen:
|
||||||
# Incremental sync before each command so mid-session credential
|
"""Spawn an SSH process that runs bash on the remote host."""
|
||||||
# refreshes and skill updates are picked up.
|
|
||||||
self._sync_skills_and_credentials()
|
|
||||||
return super().execute(command, cwd, timeout=timeout, stdin_data=stdin_data)
|
|
||||||
|
|
||||||
_poll_interval_start: float = 0.15 # SSH: higher initial interval (150ms) for network latency
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _temp_prefix(self) -> str:
|
|
||||||
return f"/tmp/hermes-ssh-{self._session_id}"
|
|
||||||
|
|
||||||
def _spawn_shell_process(self) -> subprocess.Popen:
|
|
||||||
cmd = self._build_ssh_command()
|
cmd = self._build_ssh_command()
|
||||||
cmd.append("bash -l")
|
if login:
|
||||||
return subprocess.Popen(
|
cmd.extend(["bash", "-l", "-c", shlex.quote(cmd_string)])
|
||||||
cmd,
|
|
||||||
stdin=subprocess.PIPE,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.DEVNULL,
|
|
||||||
text=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _read_temp_files(self, *paths: str) -> list[str]:
|
|
||||||
if len(paths) == 1:
|
|
||||||
cmd = self._build_ssh_command()
|
|
||||||
cmd.append(f"cat {paths[0]} 2>/dev/null")
|
|
||||||
try:
|
|
||||||
result = subprocess.run(
|
|
||||||
cmd, capture_output=True, text=True, timeout=10,
|
|
||||||
)
|
|
||||||
return [result.stdout]
|
|
||||||
except (subprocess.TimeoutExpired, OSError):
|
|
||||||
return [""]
|
|
||||||
|
|
||||||
delim = f"__HERMES_SEP_{self._session_id}__"
|
|
||||||
script = "; ".join(
|
|
||||||
f"cat {p} 2>/dev/null; echo '{delim}'" for p in paths
|
|
||||||
)
|
|
||||||
cmd = self._build_ssh_command()
|
|
||||||
cmd.append(script)
|
|
||||||
try:
|
|
||||||
result = subprocess.run(
|
|
||||||
cmd, capture_output=True, text=True, timeout=10,
|
|
||||||
)
|
|
||||||
parts = result.stdout.split(delim + "\n")
|
|
||||||
return [parts[i] if i < len(parts) else "" for i in range(len(paths))]
|
|
||||||
except (subprocess.TimeoutExpired, OSError):
|
|
||||||
return [""] * len(paths)
|
|
||||||
|
|
||||||
def _kill_shell_children(self):
|
|
||||||
if self._shell_pid is None:
|
|
||||||
return
|
|
||||||
cmd = self._build_ssh_command()
|
|
||||||
cmd.append(f"pkill -P {self._shell_pid} 2>/dev/null; true")
|
|
||||||
try:
|
|
||||||
subprocess.run(cmd, capture_output=True, timeout=5)
|
|
||||||
except (subprocess.TimeoutExpired, OSError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _cleanup_temp_files(self):
|
|
||||||
cmd = self._build_ssh_command()
|
|
||||||
cmd.append(f"rm -f {self._temp_prefix}-*")
|
|
||||||
try:
|
|
||||||
subprocess.run(cmd, capture_output=True, timeout=5)
|
|
||||||
except (subprocess.TimeoutExpired, OSError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _execute_oneshot(self, command: str, cwd: str = "", *,
|
|
||||||
timeout: int | None = None,
|
|
||||||
stdin_data: str | None = None) -> dict:
|
|
||||||
work_dir = cwd or self.cwd
|
|
||||||
exec_command, sudo_stdin = self._prepare_command(command)
|
|
||||||
# Keep ~ unquoted (for shell expansion) and quote only the subpath.
|
|
||||||
if work_dir == "~":
|
|
||||||
wrapped = f'cd ~ && {exec_command}'
|
|
||||||
elif work_dir.startswith("~/"):
|
|
||||||
wrapped = f'cd ~/{shlex.quote(work_dir[2:])} && {exec_command}'
|
|
||||||
else:
|
else:
|
||||||
wrapped = f'cd {shlex.quote(work_dir)} && {exec_command}'
|
cmd.extend(["bash", "-c", shlex.quote(cmd_string)])
|
||||||
effective_timeout = timeout or self.timeout
|
|
||||||
|
|
||||||
if sudo_stdin is not None and stdin_data is not None:
|
return _popen_bash(cmd, stdin_data)
|
||||||
effective_stdin = sudo_stdin + stdin_data
|
|
||||||
elif sudo_stdin is not None:
|
|
||||||
effective_stdin = sudo_stdin
|
|
||||||
else:
|
|
||||||
effective_stdin = stdin_data
|
|
||||||
|
|
||||||
cmd = self._build_ssh_command()
|
|
||||||
cmd.append(wrapped)
|
|
||||||
|
|
||||||
kwargs = self._build_run_kwargs(timeout, effective_stdin)
|
|
||||||
kwargs.pop("timeout", None)
|
|
||||||
_output_chunks = []
|
|
||||||
proc = subprocess.Popen(
|
|
||||||
cmd,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.STDOUT,
|
|
||||||
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
|
|
||||||
text=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if effective_stdin:
|
|
||||||
try:
|
|
||||||
proc.stdin.write(effective_stdin)
|
|
||||||
proc.stdin.close()
|
|
||||||
except (BrokenPipeError, OSError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _drain():
|
|
||||||
try:
|
|
||||||
for line in proc.stdout:
|
|
||||||
_output_chunks.append(line)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
reader = threading.Thread(target=_drain, daemon=True)
|
|
||||||
reader.start()
|
|
||||||
deadline = time.monotonic() + effective_timeout
|
|
||||||
|
|
||||||
while proc.poll() is None:
|
|
||||||
if is_interrupted():
|
|
||||||
proc.terminate()
|
|
||||||
try:
|
|
||||||
proc.wait(timeout=1)
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
proc.kill()
|
|
||||||
reader.join(timeout=2)
|
|
||||||
return {
|
|
||||||
"output": "".join(_output_chunks) + "\n[Command interrupted]",
|
|
||||||
"returncode": 130,
|
|
||||||
}
|
|
||||||
if time.monotonic() > deadline:
|
|
||||||
proc.kill()
|
|
||||||
reader.join(timeout=2)
|
|
||||||
return self._timeout_result(effective_timeout)
|
|
||||||
time.sleep(0.2)
|
|
||||||
|
|
||||||
reader.join(timeout=5)
|
|
||||||
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
|
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
super().cleanup()
|
|
||||||
if self.control_socket.exists():
|
if self.control_socket.exists():
|
||||||
try:
|
try:
|
||||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
||||||
|
|||||||
@@ -611,9 +611,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
|||||||
docker_env = cc.get("docker_env", {})
|
docker_env = cc.get("docker_env", {})
|
||||||
|
|
||||||
if env_type == "local":
|
if env_type == "local":
|
||||||
lc = local_config or {}
|
return _LocalEnvironment(cwd=cwd, timeout=timeout)
|
||||||
return _LocalEnvironment(cwd=cwd, timeout=timeout,
|
|
||||||
persistent=lc.get("persistent", False))
|
|
||||||
|
|
||||||
elif env_type == "docker":
|
elif env_type == "docker":
|
||||||
return _DockerEnvironment(
|
return _DockerEnvironment(
|
||||||
@@ -705,7 +703,6 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
|||||||
key_path=ssh_config.get("key", ""),
|
key_path=ssh_config.get("key", ""),
|
||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
persistent=ssh_config.get("persistent", False),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user