diff --git a/environments/agent_loop.py b/environments/agent_loop.py index 11a8a01f3a9..41da13a2cac 100644 --- a/environments/agent_loop.py +++ b/environments/agent_loop.py @@ -18,7 +18,7 @@ import logging import os import uuid from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Set +from typing import Any, Callable, Dict, List, Optional, Set from model_tools import handle_function_call @@ -138,6 +138,7 @@ class HermesAgentLoop: temperature: float = 1.0, max_tokens: Optional[int] = None, extra_body: Optional[Dict[str, Any]] = None, + early_stop_check: Optional[Callable[[List[Dict[str, Any]]], bool]] = None, ): """ Initialize the agent loop. @@ -154,6 +155,9 @@ class HermesAgentLoop: extra_body: Extra parameters passed to the OpenAI client's create() call. Used for OpenRouter provider preferences, transforms, etc. e.g. {"provider": {"ignore": ["DeepInfra"]}} + early_stop_check: Optional callback that inspects messages after each tool + turn. If it returns True, the loop ends with finished_naturally=True. + Used for environment-level completion signals (e.g., flag accepted). """ self.server = server self.tool_schemas = tool_schemas @@ -163,6 +167,7 @@ class HermesAgentLoop: self.temperature = temperature self.max_tokens = max_tokens self.extra_body = extra_body + self.early_stop_check = early_stop_check async def run(self, messages: List[Dict[str, Any]]) -> AgentResult: """ @@ -456,6 +461,23 @@ class HermesAgentLoop: } ) + # Check if environment signals early stop (e.g., flag accepted) + if self.early_stop_check and self.early_stop_check(messages): + turn_elapsed = _time.monotonic() - turn_start + logger.info( + "[%s] turn %d: early stop triggered after %d tools (%.1fs)", + self.task_id[:8], turn + 1, + len(assistant_msg.tool_calls), turn_elapsed, + ) + return AgentResult( + messages=messages, + managed_state=self._get_managed_state(), + turns_used=turn + 1, + finished_naturally=True, + reasoning_per_turn=reasoning_per_turn, + tool_errors=tool_errors, + ) + turn_elapsed = _time.monotonic() - turn_start logger.info( "[%s] turn %d: api=%.1fs, %d tools, turn_total=%.1fs", diff --git a/environments/pwncollege_env/evaluate_config.yaml b/environments/pwncollege_env/evaluate_config.yaml index 94e5128a479..a323cd85470 100644 --- a/environments/pwncollege_env/evaluate_config.yaml +++ b/environments/pwncollege_env/evaluate_config.yaml @@ -65,7 +65,7 @@ openai: rolling_buffer_length: 1000 server_type: openai tokenizer_name: none - api_key: ${OPENROUTER_API_KEY} + api_key: "" base_url: https://openrouter.ai/api/v1 n_kwarg_is_ignored: false health_check: false diff --git a/environments/pwncollege_env/sdk.py b/environments/pwncollege_env/sdk.py index 7811c38258c..4f0959619ae 100644 --- a/environments/pwncollege_env/sdk.py +++ b/environments/pwncollege_env/sdk.py @@ -236,6 +236,15 @@ class DojoRLClient: raise RuntimeError(f"No flag available for slot {slot}") return instance.flag + # ── SSH Key Management ──────────────────────────────────────────────────── + + async def register_ssh_key(self, public_key: str) -> bool: + result = await self._post("/ssh_key", json={"public_key": public_key}) + return result.get("success", False) + + async def get_ssh_key(self) -> dict[str, Any]: + return await self._get("/ssh_key") + # ── Challenge Discovery ─────────────────────────────────────────────────── async def list_challenges(self) -> list[RLChallenge]: @@ -369,6 +378,12 @@ class DojoRLSyncClient: def list_challenges(self) -> list[RLChallenge]: return self._run(self._async.list_challenges()) + def register_ssh_key(self, public_key: str) -> bool: + return self._run(self._async.register_ssh_key(public_key)) + + def get_ssh_key(self) -> dict[str, Any]: + return self._run(self._async.get_ssh_key()) + def admin_login(self, username: str = "admin", password: str = "admin") -> None: return self._run(self._async.admin_login(username, password)) diff --git a/environments/pwncollege_env/smoke_hello.yaml b/environments/pwncollege_env/smoke_hello.yaml index 627856c30f5..a028f149afa 100644 --- a/environments/pwncollege_env/smoke_hello.yaml +++ b/environments/pwncollege_env/smoke_hello.yaml @@ -57,7 +57,7 @@ env: eval_exclude_dojos: - archive eval_module: hello - eval_concurrency: 2 + eval_concurrency: 3 openai: - timeout: 1200 num_max_requests_at_once: 512 diff --git a/pyproject.toml b/pyproject.toml index d670d044f23..e5a71a4739b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ sms = ["aiohttp>=3.9.0,<4"] acp = ["agent-client-protocol>=0.8.1,<0.9"] dingtalk = ["dingtalk-stream>=0.1.0,<1"] rl = [ - "atroposlib @ git+https://github.com/NousResearch/atropos.git@sid/traj-saving-eval-mode", + "atroposlib @ git+https://github.com/NousResearch/atropos.git@main", "tinker @ git+https://github.com/thinking-machines-lab/tinker.git", "fastapi>=0.104.0,<1", "uvicorn[standard]>=0.24.0,<1", diff --git a/tools/environments/persistent_shell.py b/tools/environments/persistent_shell.py index b43da71081d..e0332f0685d 100644 --- a/tools/environments/persistent_shell.py +++ b/tools/environments/persistent_shell.py @@ -15,6 +15,8 @@ logger = logging.getLogger(__name__) _SENTINEL_PREFIX = "__HERMES_DONE_" _SENTINEL_SUFFIX = "__" +from tools.ansi_strip import strip_ansi + class PersistentShellMixin: """Mixin that adds persistent shell capability to any BaseEnvironment. @@ -80,7 +82,11 @@ class PersistentShellMixin: init_script = ( # Disable echo so sentinel markers aren't duplicated in stdout f"stty -echo 2>/dev/null\n" - f"export TERM=${{TERM:-dumb}}\n" + # Disable history recording — IPC scaffolding pollutes it. + # Agent commands are added explicitly via `history -s` below. + f"set +o history\n" + f"export HISTFILE=/dev/null\n" + f"export TERM=dumb\n" f"touch {self._pshell_stdout} {self._pshell_stderr} " f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n" f"echo $$ > {self._pshell_pid_file}\n" @@ -99,21 +105,28 @@ class PersistentShellMixin: else: logger.warning("Persistent shell init sentinel not received") - pid_str, reported_cwd = self._read_temp_files( - self._pshell_pid_file, self._pshell_cwd, - ) - pid_str = pid_str.strip() - if pid_str.isdigit(): - self._shell_pid = int(pid_str) + # Retry reading PID file — temp files may not be flushed yet + self._shell_pid = None + reported_cwd = "" + for _ in range(5): + time.sleep(0.2) + pid_str, reported_cwd = self._read_temp_files( + self._pshell_pid_file, self._pshell_cwd, + ) + pid_str = pid_str.strip() + if pid_str.isdigit(): + self._shell_pid = int(pid_str) + break + + if self._shell_pid: logger.info( "Persistent shell started (session=%s, pid=%d)", self._session_id, self._shell_pid, ) else: logger.warning("Could not read persistent shell PID") - self._shell_pid = None - reported_cwd = reported_cwd.strip() + reported_cwd = (reported_cwd or "").strip() if reported_cwd: self.cwd = reported_cwd @@ -166,14 +179,10 @@ class PersistentShellMixin: def _drain_shell_output(self): try: for line in self._shell_proc.stdout: - stripped = line.rstrip('\r\n') - if ( - stripped.startswith(_SENTINEL_PREFIX) - and stripped.endswith(_SENTINEL_SUFFIX) - ): - inner = stripped[len(_SENTINEL_PREFIX):-len(_SENTINEL_SUFFIX)] - if inner == self._sentinel_cmd_id: - self._sentinel_event.set() + clean = strip_ansi(line).strip('\r\n\x00') + expected = f"{_SENTINEL_PREFIX}{self._sentinel_cmd_id}{_SENTINEL_SUFFIX}" + if clean.endswith(expected): + self._sentinel_event.set() except Exception: pass self._shell_alive = False @@ -240,6 +249,7 @@ class PersistentShellMixin: ipc_script = ( f"cd {shlex.quote(work_dir)}\n" + f"history -s {shlex.quote(command)}\n" f"eval '{escaped}' < /dev/null > {self._pshell_stdout} 2> {self._pshell_stderr}\n" f"__EC=$?\n" f"pwd > {self._pshell_cwd}\n" diff --git a/tools/environments/ssh.py b/tools/environments/ssh.py index d21f83fc16f..102085f3f8e 100644 --- a/tools/environments/ssh.py +++ b/tools/environments/ssh.py @@ -68,12 +68,15 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): cmd.extend(["-o", "ControlMaster=auto"]) cmd.extend(["-o", "ControlPersist=300"]) cmd.extend(["-o", "BatchMode=yes"]) - cmd.extend(["-o", "StrictHostKeyChecking=accept-new"]) + cmd.extend(["-o", "StrictHostKeyChecking=no"]) + cmd.extend(["-o", "UserKnownHostsFile=/dev/null"]) + cmd.extend(["-o", "LogLevel=ERROR"]) cmd.extend(["-o", "ConnectTimeout=10"]) if self.port != 22: cmd.extend(["-p", str(self.port)]) if self.key_path: cmd.extend(["-i", self.key_path]) + cmd.extend(["-o", "IdentitiesOnly=yes"]) if extra_args: cmd.extend(extra_args) cmd.append(f"{self.user}@{self.host}") diff --git a/uv.lock b/uv.lock index 64a81ba0e67..7202246faf3 100644 --- a/uv.lock +++ b/uv.lock @@ -262,7 +262,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/87/c6/53da25344e3e3a9c0 [[package]] name = "atroposlib" version = "0.4.0" -source = { git = "https://github.com/NousResearch/atropos.git?rev=sid%2Ftraj-saving-eval-mode#0ab46d65b015363148fccf49b2dc0e2ebd47e9e9" } +source = { git = "https://github.com/NousResearch/atropos.git?rev=main#c20c85256e5a45ad31edf8b7276e9c5ee1995a30" } dependencies = [ { name = "aiofiles" }, { name = "aiohttp" }, @@ -1763,7 +1763,7 @@ requires-dist = [ { name = "aiohttp", marker = "extra == 'messaging'", specifier = ">=3.13.3,<4" }, { name = "aiohttp", marker = "extra == 'sms'", specifier = ">=3.9.0,<4" }, { name = "anthropic", specifier = ">=0.39.0,<1" }, - { name = "atroposlib", marker = "extra == 'rl'", git = "https://github.com/NousResearch/atropos.git?rev=sid%2Ftraj-saving-eval-mode" }, + { name = "atroposlib", marker = "extra == 'rl'", git = "https://github.com/NousResearch/atropos.git?rev=main" }, { name = "croniter", marker = "extra == 'cron'", specifier = ">=6.0.0,<7" }, { name = "daytona", marker = "extra == 'daytona'", specifier = ">=0.148.0,<1" }, { name = "dingtalk-stream", marker = "extra == 'dingtalk'", specifier = ">=0.1.0,<1" },