mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 23:11:37 +08:00
99 lines
3.8 KiB
Python
99 lines
3.8 KiB
Python
"""Tests for per-task SSH environment overrides."""
|
|
|
|
from tools.terminal_tool import (
|
|
register_task_env_overrides,
|
|
clear_task_env_overrides,
|
|
_task_env_overrides,
|
|
)
|
|
|
|
|
|
class TestSSHOverridesInConfig:
|
|
"""Verify SSH config assembly respects per-task overrides."""
|
|
|
|
def setup_method(self):
|
|
self._saved = dict(_task_env_overrides)
|
|
_task_env_overrides.clear()
|
|
|
|
def teardown_method(self):
|
|
_task_env_overrides.clear()
|
|
_task_env_overrides.update(self._saved)
|
|
|
|
def _build_ssh_config(self, task_id: str, global_config: dict) -> dict:
|
|
"""Replicate the SSH config assembly logic from terminal_tool.py."""
|
|
overrides = _task_env_overrides.get(task_id, {})
|
|
return {
|
|
"host": overrides.get("ssh_host") or global_config.get("ssh_host", ""),
|
|
"user": overrides.get("ssh_user") or global_config.get("ssh_user", ""),
|
|
"port": overrides.get("ssh_port") or global_config.get("ssh_port", 22),
|
|
"key": overrides.get("ssh_key") or global_config.get("ssh_key", ""),
|
|
"persistent": overrides.get("ssh_persistent", global_config.get("ssh_persistent", False)),
|
|
}
|
|
|
|
def test_no_overrides_uses_global(self):
|
|
"""Without per-task overrides, global config is used."""
|
|
global_config = {
|
|
"ssh_host": "global.example.com",
|
|
"ssh_user": "root",
|
|
"ssh_port": 22,
|
|
"ssh_key": "/root/.ssh/id_rsa",
|
|
"ssh_persistent": True,
|
|
}
|
|
result = self._build_ssh_config("task-1", global_config)
|
|
assert result["host"] == "global.example.com"
|
|
assert result["user"] == "root"
|
|
assert result["port"] == 22
|
|
assert result["key"] == "/root/.ssh/id_rsa"
|
|
assert result["persistent"] is True
|
|
|
|
def test_override_port_and_key(self):
|
|
"""Per-task overrides for port and key take precedence."""
|
|
global_config = {
|
|
"ssh_host": "dojo.pwncollege.com",
|
|
"ssh_user": "hacker",
|
|
"ssh_port": 22,
|
|
"ssh_key": "/default/key",
|
|
}
|
|
register_task_env_overrides("task-42", {
|
|
"ssh_port": 2264,
|
|
"ssh_key": "/tmp/keys/episode_42",
|
|
})
|
|
result = self._build_ssh_config("task-42", global_config)
|
|
assert result["port"] == 2264
|
|
assert result["key"] == "/tmp/keys/episode_42"
|
|
# Non-overridden fields fall through to global
|
|
assert result["host"] == "dojo.pwncollege.com"
|
|
assert result["user"] == "hacker"
|
|
|
|
def test_different_tasks_get_different_ports(self):
|
|
"""128 parallel rollouts each get their own SSH port."""
|
|
global_config = {
|
|
"ssh_host": "dojo.pwncollege.com",
|
|
"ssh_user": "hacker",
|
|
"ssh_port": 22,
|
|
"ssh_key": "",
|
|
}
|
|
for i in range(128):
|
|
tid = f"task-{i}"
|
|
register_task_env_overrides(tid, {"ssh_port": 2222 + i})
|
|
|
|
for i in range(128):
|
|
tid = f"task-{i}"
|
|
result = self._build_ssh_config(tid, global_config)
|
|
assert result["port"] == 2222 + i
|
|
|
|
def test_clear_overrides_reverts_to_global(self):
|
|
"""After clearing, config falls back to global."""
|
|
global_config = {"ssh_port": 22}
|
|
register_task_env_overrides("task-99", {"ssh_port": 9999})
|
|
assert self._build_ssh_config("task-99", global_config)["port"] == 9999
|
|
|
|
clear_task_env_overrides("task-99")
|
|
assert self._build_ssh_config("task-99", global_config)["port"] == 22
|
|
|
|
def test_persistent_false_not_clobbered_by_or(self):
|
|
"""ssh_persistent=False override must not be skipped due to falsy `or`."""
|
|
global_config = {"ssh_persistent": True}
|
|
register_task_env_overrides("task-x", {"ssh_persistent": False})
|
|
result = self._build_ssh_config("task-x", global_config)
|
|
assert result["persistent"] is False
|