Files
hermes-agent/tests/tools/test_ssh_overrides.py

99 lines
3.8 KiB
Python

"""Tests for per-task SSH environment overrides."""
from tools.terminal_tool import (
register_task_env_overrides,
clear_task_env_overrides,
_task_env_overrides,
)
class TestSSHOverridesInConfig:
    """SSH connection settings honour per-task overrides layered on global config.

    NOTE(review): ``_build_ssh_config`` is a hand-copied mirror of the assembly
    logic in ``tools/terminal_tool.py``, not a call into it. These tests
    therefore pin the copy's semantics (plus the real register/clear helpers);
    keep the copy in sync if the production code changes.
    """

    def setup_method(self):
        # Snapshot the shared module-level registry, then start every test from
        # an empty one so registrations cannot leak between tests.
        self._registry_snapshot = dict(_task_env_overrides)
        _task_env_overrides.clear()

    def teardown_method(self):
        # Restore exactly what was registered before this test ran.
        _task_env_overrides.clear()
        _task_env_overrides.update(self._registry_snapshot)

    def _build_ssh_config(self, task_id: str, global_config: dict) -> dict:
        """Assemble the SSH config for *task_id* the way terminal_tool.py does.

        host/user/port/key: the task override is used when it is truthy;
        otherwise the global value, defaulting to "", "", 22 and "".
        persistent: the task override is used whenever the key is present (even
        if False); otherwise the global value, defaulting to False.
        """
        task_overrides = _task_env_overrides.get(task_id, {})

        def pick(key, fallback):
            # A falsy override (None, "", 0) deliberately falls through.
            return task_overrides.get(key) or global_config.get(key, fallback)

        return {
            "host": pick("ssh_host", ""),
            "user": pick("ssh_user", ""),
            "port": pick("ssh_port", 22),
            "key": pick("ssh_key", ""),
            # Presence-based lookup, not `or`, so an explicit False survives.
            "persistent": task_overrides.get(
                "ssh_persistent", global_config.get("ssh_persistent", False)
            ),
        }

    def test_no_overrides_uses_global(self):
        """With nothing registered for the task, every field comes from global config."""
        global_config = {
            "ssh_host": "global.example.com",
            "ssh_user": "root",
            "ssh_port": 22,
            "ssh_key": "/root/.ssh/id_rsa",
            "ssh_persistent": True,
        }
        config = self._build_ssh_config("task-1", global_config)
        for field in ("host", "user", "port", "key"):
            assert config[field] == global_config[f"ssh_{field}"]
        assert config["persistent"] is True

    def test_override_port_and_key(self):
        """Task-level port and key beat the global values; host and user fall through."""
        global_config = {
            "ssh_host": "dojo.pwncollege.com",
            "ssh_user": "hacker",
            "ssh_port": 22,
            "ssh_key": "/default/key",
        }
        register_task_env_overrides(
            "task-42",
            {"ssh_port": 2264, "ssh_key": "/tmp/keys/episode_42"},
        )
        config = self._build_ssh_config("task-42", global_config)
        # The two overridden fields.
        assert (config["port"], config["key"]) == (2264, "/tmp/keys/episode_42")
        # Everything not overridden is inherited from the global config.
        assert (config["host"], config["user"]) == ("dojo.pwncollege.com", "hacker")

    def test_different_tasks_get_different_ports(self):
        """Each of 128 registered task ids resolves to its own SSH port."""
        global_config = {
            "ssh_host": "dojo.pwncollege.com",
            "ssh_user": "hacker",
            "ssh_port": 22,
            "ssh_key": "",
        }
        expected_ports = {f"task-{n}": 2222 + n for n in range(128)}
        for task_id, port in expected_ports.items():
            register_task_env_overrides(task_id, {"ssh_port": port})
        for task_id, port in expected_ports.items():
            assert self._build_ssh_config(task_id, global_config)["port"] == port

    def test_clear_overrides_reverts_to_global(self):
        """Clearing a task's overrides makes it fall back to global config again."""
        global_config = {"ssh_port": 22}
        task_id = "task-99"
        register_task_env_overrides(task_id, {"ssh_port": 9999})
        port_while_overridden = self._build_ssh_config(task_id, global_config)["port"]
        assert port_while_overridden == 9999
        clear_task_env_overrides(task_id)
        port_after_clear = self._build_ssh_config(task_id, global_config)["port"]
        assert port_after_clear == 22

    def test_persistent_false_not_clobbered_by_or(self):
        """An explicit ssh_persistent=False override must win over a True global.

        A naive ``override or global`` lookup would silently discard it.
        """
        register_task_env_overrides("task-x", {"ssh_persistent": False})
        config = self._build_ssh_config("task-x", {"ssh_persistent": True})
        assert config["persistent"] is False