mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-02 08:47:26 +08:00
feat(environments): add pwncollege RL environment with per-task SSH overrides
This commit is contained in:
98
tests/tools/test_ssh_overrides.py
Normal file
98
tests/tools/test_ssh_overrides.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Tests for per-task SSH environment overrides."""
|
||||
|
||||
from tools.terminal_tool import (
|
||||
register_task_env_overrides,
|
||||
clear_task_env_overrides,
|
||||
_task_env_overrides,
|
||||
)
|
||||
|
||||
|
||||
class TestSSHOverridesInConfig:
|
||||
"""Verify SSH config assembly respects per-task overrides."""
|
||||
|
||||
def setup_method(self):
|
||||
self._saved = dict(_task_env_overrides)
|
||||
_task_env_overrides.clear()
|
||||
|
||||
def teardown_method(self):
|
||||
_task_env_overrides.clear()
|
||||
_task_env_overrides.update(self._saved)
|
||||
|
||||
def _build_ssh_config(self, task_id: str, global_config: dict) -> dict:
|
||||
"""Replicate the SSH config assembly logic from terminal_tool.py."""
|
||||
overrides = _task_env_overrides.get(task_id, {})
|
||||
return {
|
||||
"host": overrides.get("ssh_host") or global_config.get("ssh_host", ""),
|
||||
"user": overrides.get("ssh_user") or global_config.get("ssh_user", ""),
|
||||
"port": overrides.get("ssh_port") or global_config.get("ssh_port", 22),
|
||||
"key": overrides.get("ssh_key") or global_config.get("ssh_key", ""),
|
||||
"persistent": overrides.get("ssh_persistent", global_config.get("ssh_persistent", False)),
|
||||
}
|
||||
|
||||
def test_no_overrides_uses_global(self):
|
||||
"""Without per-task overrides, global config is used."""
|
||||
global_config = {
|
||||
"ssh_host": "global.example.com",
|
||||
"ssh_user": "root",
|
||||
"ssh_port": 22,
|
||||
"ssh_key": "/root/.ssh/id_rsa",
|
||||
"ssh_persistent": True,
|
||||
}
|
||||
result = self._build_ssh_config("task-1", global_config)
|
||||
assert result["host"] == "global.example.com"
|
||||
assert result["user"] == "root"
|
||||
assert result["port"] == 22
|
||||
assert result["key"] == "/root/.ssh/id_rsa"
|
||||
assert result["persistent"] is True
|
||||
|
||||
def test_override_port_and_key(self):
|
||||
"""Per-task overrides for port and key take precedence."""
|
||||
global_config = {
|
||||
"ssh_host": "dojo.pwncollege.com",
|
||||
"ssh_user": "hacker",
|
||||
"ssh_port": 22,
|
||||
"ssh_key": "/default/key",
|
||||
}
|
||||
register_task_env_overrides("task-42", {
|
||||
"ssh_port": 2264,
|
||||
"ssh_key": "/tmp/keys/episode_42",
|
||||
})
|
||||
result = self._build_ssh_config("task-42", global_config)
|
||||
assert result["port"] == 2264
|
||||
assert result["key"] == "/tmp/keys/episode_42"
|
||||
# Non-overridden fields fall through to global
|
||||
assert result["host"] == "dojo.pwncollege.com"
|
||||
assert result["user"] == "hacker"
|
||||
|
||||
def test_different_tasks_get_different_ports(self):
|
||||
"""128 parallel rollouts each get their own SSH port."""
|
||||
global_config = {
|
||||
"ssh_host": "dojo.pwncollege.com",
|
||||
"ssh_user": "hacker",
|
||||
"ssh_port": 22,
|
||||
"ssh_key": "",
|
||||
}
|
||||
for i in range(128):
|
||||
tid = f"task-{i}"
|
||||
register_task_env_overrides(tid, {"ssh_port": 2222 + i})
|
||||
|
||||
for i in range(128):
|
||||
tid = f"task-{i}"
|
||||
result = self._build_ssh_config(tid, global_config)
|
||||
assert result["port"] == 2222 + i
|
||||
|
||||
def test_clear_overrides_reverts_to_global(self):
|
||||
"""After clearing, config falls back to global."""
|
||||
global_config = {"ssh_port": 22}
|
||||
register_task_env_overrides("task-99", {"ssh_port": 9999})
|
||||
assert self._build_ssh_config("task-99", global_config)["port"] == 9999
|
||||
|
||||
clear_task_env_overrides("task-99")
|
||||
assert self._build_ssh_config("task-99", global_config)["port"] == 22
|
||||
|
||||
def test_persistent_false_not_clobbered_by_or(self):
|
||||
"""ssh_persistent=False override must not be skipped due to falsy `or`."""
|
||||
global_config = {"ssh_persistent": True}
|
||||
register_task_env_overrides("task-x", {"ssh_persistent": False})
|
||||
result = self._build_ssh_config("task-x", global_config)
|
||||
assert result["persistent"] is False
|
||||
Reference in New Issue
Block a user