diff --git a/environments/pwncollege_env/default.yaml b/environments/pwncollege_env/default.yaml index 7ccdc77bfe6..72b0f526e9f 100644 --- a/environments/pwncollege_env/default.yaml +++ b/environments/pwncollege_env/default.yaml @@ -19,7 +19,7 @@ env: base_url: "http://100.120.55.25:8080" ssh_host: "100.120.55.25" ssh_port: 2222 - ssh_key: "environments/pwncollege_env/keys/rl_agent_00" + ssh_key: "environments/pwncollege_env/keys/rl_test_key" # Training: challenge selection # challenge: "hello/hello" # Single challenge (training fallback) diff --git a/environments/pwncollege_env/evaluate_config.yaml b/environments/pwncollege_env/evaluate_config.yaml new file mode 100644 index 00000000000..94e5128a479 --- /dev/null +++ b/environments/pwncollege_env/evaluate_config.yaml @@ -0,0 +1,73 @@ +env: + group_size: 4 + max_num_workers: -1 + max_eval_workers: 16 + max_num_workers_per_node: 8 + steps_per_eval: 100 + max_token_length: 16384 + eval_handling: STOP_TRAIN + eval_limit_ratio: 0.5 + inference_weight: 1.0 + batch_size: -1 + max_batches_offpolicy: 3 + tokenizer_name: NousResearch/Hermes-3-Llama-3.1-8B + use_wandb: false + rollout_server_url: http://localhost:8000 + total_steps: 1000 + wandb_name: pwncollege + num_rollouts_to_keep: 32 + num_rollouts_per_group_for_logging: 1 + ensure_scores_are_not_same: false + data_path_to_save_groups: null + data_dir_to_save_evals: eval_output/pwncollege + min_items_sent_before_logging: 2 + include_messages: false + min_batch_allocation: null + worker_timeout: 600.0 + thinking_mode: false + reasoning_effort: null + max_reasoning_tokens: null + custom_thinking_prompt: null + enabled_toolsets: + - terminal + - file + - pwncollege + disabled_toolsets: null + distribution: null + max_agent_turns: 50 + agent_temperature: 0.7 + terminal_backend: ssh + terminal_timeout: 120 + terminal_lifetime: 3600 + dataset_name: null + dataset_split: train + prompt_field: prompt + tool_pool_size: 128 + tool_call_parser: hermes + extra_body: null + base_url: http://100.120.55.25:8080 + ssh_host: 100.120.55.25 + ssh_port: 2222 + ssh_key: environments/pwncollege_env/keys/rl_test_key + challenge: hello/hello + dojo_filter: null + module_filter: null + eval_dojo: linux-luminarium + eval_exclude_dojos: + - archive + eval_module: hello + eval_concurrency: 16 +openai: + - timeout: 1200 + num_max_requests_at_once: 512 + num_requests_for_eval: 64 + model_name: xiaomi/mimo-v2-flash + rolling_buffer_length: 1000 + server_type: openai + tokenizer_name: none + api_key: ${OPENROUTER_API_KEY} + base_url: https://openrouter.ai/api/v1 + n_kwarg_is_ignored: false + health_check: false +slurm: false +testing: false diff --git a/environments/pwncollege_env/prompts.py b/environments/pwncollege_env/prompts.py index 112ccb38e3f..ff1c5d7b362 100644 --- a/environments/pwncollege_env/prompts.py +++ b/environments/pwncollege_env/prompts.py @@ -17,6 +17,12 @@ When you find a flag, use the submit_flag tool to verify it. - If a challenge needs a long-running process (e.g., a server or listener), run it in the background and interact with it separately. - For binary exploitation, check protections with `checksec` and use `python3` with `pwntools` if available. - Read error messages carefully — they often hint at what the challenge expects. +- Sometimes `/challenge/run` drops you into an interactive shell inside a network namespace. To interact with it, use `tmux`: + ``` + tmux new-session -d -s challenge '/challenge/run' + tmux send-keys -t challenge 'your-command-here' Enter + tmux capture-pane -t challenge -p + ``` """ USER_PROMPT_TEMPLATE = """Solve this pwn.college challenge. diff --git a/environments/pwncollege_env/pwncollege_env.py b/environments/pwncollege_env/pwncollege_env.py index 4c882bafe64..e9666ad3c94 100644 --- a/environments/pwncollege_env/pwncollege_env.py +++ b/environments/pwncollege_env/pwncollege_env.py @@ -16,6 +16,7 @@ Usage: --config environments/pwncollege_env/default.yaml """ +import asyncio import atexit import json import logging @@ -24,6 +25,8 @@ import re import signal import sys import uuid + +import httpx from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union @@ -109,6 +112,14 @@ class PwnCollegeEnvConfig(HermesAgentEnvConfig): default=None, description="Module to evaluate on (None = all modules)", ) + eval_exclude_modules: List[str] = Field( + default_factory=list, + description="Modules to exclude from evaluation", + ) + eval_challenges: Optional[List[str]] = Field( + default=None, + description="Specific challenges to evaluate (format: module_id/challenge_id). Overrides dojo/module filters.", + ) eval_concurrency: int = Field( default=4, description="Max concurrent eval episodes (limited by dojo slots)", @@ -280,11 +291,32 @@ class PwnCollegeEnv(HermesAgentBaseEnv): task_id = str(uuid.uuid4()) challenge_key = self._get_challenge_key(item) - try: - inst = await self.client.create_instance(challenge_key) - except Exception as e: - logger.error("Failed to create instance for %s: %s", challenge_key, e) - return None, [] + max_retries = 5 + inst = None + for attempt in range(max_retries): + try: + inst = await self.client.create_instance(challenge_key) + break + except Exception as e: + err_str = str(e) + is_transient = ( + isinstance(e, httpx.HTTPStatusError) and e.response.status_code >= 500 + or isinstance(e, (httpx.ReadTimeout, httpx.ConnectTimeout, httpx.ConnectError)) + or "No available slots" in err_str + ) + if is_transient and attempt < max_retries - 1: + wait = min(2 ** (attempt + 1), 30) + logger.warning( + "Transient error creating instance for %s (attempt %d/%d): %s, retrying in %ds", + challenge_key, attempt + 1, max_retries, err_str[:100], wait, + ) + await asyncio.sleep(wait) + else: + logger.error( + "Failed to create instance for %s after %d attempts: %s", + challenge_key, attempt + 1, e, + ) + return None, [] slot = inst.slot self._active_slots.add(slot) @@ -416,7 +448,6 @@ class PwnCollegeEnv(HermesAgentBaseEnv): Fetches challenges matching eval_dojo/eval_module, runs each through the agent loop with concurrency control, and logs results. """ - import asyncio import time if not self.client: @@ -427,12 +458,17 @@ class PwnCollegeEnv(HermesAgentBaseEnv): # Fetch and filter eval challenges all_challenges = await self.client.list_challenges() - eval_challenges = [ - c for c in all_challenges - if (self.config.eval_dojo is None or c.dojo_id == self.config.eval_dojo) - and (self.config.eval_module is None or c.module_id == self.config.eval_module) - and c.dojo_id not in self.config.eval_exclude_dojos - ] + if self.config.eval_challenges: + challenge_set = set(self.config.eval_challenges) + eval_challenges = [c for c in all_challenges if c.challenge_key in challenge_set] + else: + eval_challenges = [ + c for c in all_challenges + if (self.config.eval_dojo is None or c.dojo_id == self.config.eval_dojo) + and (self.config.eval_module is None or c.module_id == self.config.eval_module) + and c.dojo_id not in self.config.eval_exclude_dojos + and c.module_id not in self.config.eval_exclude_modules + ] if not eval_challenges: logger.warning( @@ -467,18 +503,32 @@ class PwnCollegeEnv(HermesAgentBaseEnv): f"(reward={reward:.1f})", flush=True, ) - return { + result = { "challenge": challenge_key, "name": challenge.name, "solved": solved, "reward": reward, } + # Stream-write sample with full conversation for HTML viewer + self.log_eval_sample({ + "score": reward, + "challenge": challenge_key, + "solved": solved, + "messages": scored.get("messages", []) if scored else [], + }) + return result except Exception as e: completed += 1 print( f" [{completed}/{total}] [ERR ] {challenge_key}: {e}", flush=True, ) + self.log_eval_sample({ + "score": 0.0, + "challenge": challenge_key, + "solved": False, + "messages": [{"role": "system", "content": f"Error: {e}"}], + }) return { "challenge": challenge_key, "name": challenge.name, @@ -511,19 +561,8 @@ class PwnCollegeEnv(HermesAgentBaseEnv): "eval/total": n, } - samples = [ - { - "prompt": r["challenge"], - "response": "SOLVED" if r["solved"] else "FAILED", - "expected": "SOLVED", - "reward": r["reward"], - } - for r in results - ] - await self.evaluate_log( metrics=eval_metrics, - samples=samples, start_time=start_time, end_time=end_time, ) diff --git a/environments/pwncollege_env/stress_test.py b/environments/pwncollege_env/stress_test.py new file mode 100644 index 00000000000..4ecd15684ff --- /dev/null +++ b/environments/pwncollege_env/stress_test.py @@ -0,0 +1,513 @@ +""" +Capability verification test for pwn-dojo RL infrastructure. + +Verifies that RL containers are provisioned with the correct Linux capabilities, +resource limits, and host configuration for each challenge type. + +Usage: + python environments/pwncollege_env/stress_test.py -y + python environments/pwncollege_env/stress_test.py -y -o report.json --verbose +""" + +import argparse +import asyncio +import json +import sys +import time +from dataclasses import asdict, dataclass, field +from pathlib import Path + +_repo_root = Path(__file__).resolve().parent.parent.parent +if str(_repo_root) not in sys.path: + sys.path.insert(0, str(_repo_root)) + +from environments.pwncollege_env.sdk import DojoRLClient + + +@dataclass +class SSHConfig: + host: str + port: int + key: str + + +@dataclass +class CheckResult: + name: str + passed: bool + message: str + duration: float = 0.0 + + +@dataclass +class TestResult: + name: str + challenge: str + checks: list[CheckResult] = field(default_factory=list) + passed: bool = False + skipped: bool = False + error: str | None = None + duration: float = 0.0 + + +@dataclass +class TestCase: + name: str + challenge: str + checks: list + + +async def ssh_run( + cfg: SSHConfig, user: str, command: str, timeout: float = 30.0 +) -> tuple[int, str]: + """Run a command over SSH via subprocess. Returns (returncode, output).""" + cmd = [ + "ssh", + "-o", + "BatchMode=yes", + "-o", + "StrictHostKeyChecking=accept-new", + "-o", + "UserKnownHostsFile=/dev/null", + "-o", + "ConnectTimeout=10", + "-o", + "LogLevel=ERROR", + "-p", + str(cfg.port), + "-i", + cfg.key, + f"{user}@{cfg.host}", + command, + ] + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + try: + stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=timeout) + return proc.returncode, stdout.decode(errors="replace") + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + return -1, f"[SSH timeout after {timeout}s]" + + +async def wait_ssh_ready(cfg: SSHConfig, user: str, retries: int = 10) -> bool: + for i in range(retries): + rc, out = await ssh_run(cfg, user, "echo ready", timeout=10) + if rc == 0 and "ready" in out: + return True + await asyncio.sleep(1) + return False + + +# ── Check functions ────────────────────────────────────────────────────────── + + +async def check_ssh_echo(cfg: SSHConfig, user: str) -> CheckResult: + t0 = time.monotonic() + rc, out = await ssh_run(cfg, user, "echo ok") + dur = time.monotonic() - t0 + if rc == 0 and "ok" in out: + return CheckResult("ssh_echo", True, "connected", dur) + return CheckResult("ssh_echo", False, f"rc={rc}: {out.strip()[:100]}", dur) + + +async def check_unshare_net(cfg: SSHConfig, user: str) -> CheckResult: + t0 = time.monotonic() + rc, out = await ssh_run(cfg, user, "unshare --net echo ok") + dur = time.monotonic() - t0 + if rc == 0 and "ok" in out: + return CheckResult("unshare_net", True, "namespace creation works", dur) + return CheckResult("unshare_net", False, f"rc={rc}: {out.strip()[:120]}", dur) + + +async def check_unshare_user(cfg: SSHConfig, user: str) -> CheckResult: + t0 = time.monotonic() + rc, out = await ssh_run(cfg, user, "unshare --user --map-root-user bash -c 'id'") + dur = time.monotonic() - t0 + if rc == 0 and "uid=0" in out: + return CheckResult("unshare_user", True, "user namespace works", dur) + return CheckResult("unshare_user", False, f"rc={rc}: {out.strip()[:120]}", dur) + + +async def check_capeff(cfg: SSHConfig, user: str) -> CheckResult: + """Check that the container init (PID 1) has SYS_ADMIN capability.""" + t0 = time.monotonic() + rc, out = await ssh_run(cfg, user, "cat /proc/1/status") + dur = time.monotonic() - t0 + if rc != 0: + return CheckResult( + "capeff", False, f"Cannot read /proc/1/status: {out.strip()[:80]}", dur + ) + for line in out.splitlines(): + if line.startswith("CapEff:") or line.startswith("CapBnd:"): + hex_val = line.split(":")[1].strip() + try: + val = int(hex_val, 16) + has_sysadmin = bool(val & (1 << 21)) + if has_sysadmin: + label = line.split(":")[0] + return CheckResult( + "capeff", True, f"{label}={hex_val} has SYS_ADMIN", dur + ) + except ValueError: + pass + return CheckResult( + "capeff", False, "SYS_ADMIN (bit 21) not found in capabilities", dur + ) + + +async def check_hosts_resolution(cfg: SSHConfig, user: str) -> CheckResult: + t0 = time.monotonic() + rc, out = await ssh_run(cfg, user, "getent hosts challenge.localhost") + dur = time.monotonic() - t0 + if rc == 0 and out.strip(): + return CheckResult( + "hosts_resolution", True, f"resolves to {out.strip()[:40]}", dur + ) + rc2, out2 = await ssh_run(cfg, user, "grep challenge.localhost /etc/hosts") + dur = time.monotonic() - t0 + if rc2 == 0 and "challenge.localhost" in out2: + return CheckResult( + "hosts_resolution", True, "/etc/hosts has entry", dur + ) + return CheckResult( + "hosts_resolution", False, "challenge.localhost not resolvable", dur + ) + + +async def check_pids_limit(cfg: SSHConfig, user: str) -> CheckResult: + t0 = time.monotonic() + rc, out = await ssh_run( + cfg, + user, + "cat /sys/fs/cgroup/pids.max 2>/dev/null || cat /sys/fs/cgroup/pids/pids.max 2>/dev/null", + ) + dur = time.monotonic() - t0 + val = out.strip() + if val == "max": + return CheckResult("pids_limit", True, "unlimited", dur) + try: + limit = int(val) + if limit >= 1024: + return CheckResult("pids_limit", True, f"pids_limit={limit}", dur) + return CheckResult( + "pids_limit", False, f"pids_limit={limit} (need >= 1024)", dur + ) + except ValueError: + return CheckResult("pids_limit", False, f"Cannot parse: {val[:60]}", dur) + + +async def check_mem_limit(cfg: SSHConfig, user: str) -> CheckResult: + t0 = time.monotonic() + rc, out = await ssh_run( + cfg, + user, + "cat /sys/fs/cgroup/memory.max 2>/dev/null || cat /sys/fs/cgroup/memory/memory.limit_in_bytes 2>/dev/null", + ) + dur = time.monotonic() - t0 + val = out.strip() + if val == "max": + return CheckResult("mem_limit", True, "unlimited", dur) + try: + limit = int(val) + limit_gb = limit / (1024**3) + if ( + limit_gb >= 1.8 + ): # 2GB for privileged RL containers (not 4GB to manage memory pressure) + return CheckResult("mem_limit", True, f"mem={limit_gb:.1f}GB", dur) + return CheckResult( + "mem_limit", False, f"mem={limit_gb:.1f}GB (need >= 2GB)", dur + ) + except ValueError: + return CheckResult("mem_limit", False, f"Cannot parse: {val[:60]}", dur) + + +async def check_challenge_run(cfg: SSHConfig, user: str) -> CheckResult: + """Run /challenge/run and verify no PermissionError.""" + t0 = time.monotonic() + rc, out = await ssh_run(cfg, user, "/challenge/run < /dev/null", timeout=15) + dur = time.monotonic() - t0 + if "PermissionError" in out or "Operation not permitted" in out: + snippet = [l for l in out.splitlines() if "Permission" in l or "Operation" in l] + return CheckResult( + "challenge_run", + False, + snippet[0][:120] if snippet else "PermissionError", + dur, + ) + return CheckResult("challenge_run", True, f"No permission errors (rc={rc})", dur) + + +# ── Test cases ─────────────────────────────────────────────────────────────── + +TEST_CASES = [ + TestCase("unprivileged_basic", "hello/hello", [check_ssh_echo]), + TestCase( + "privileged_caps", + "intercepting-communication/udp-1", + [check_ssh_echo, check_capeff], + ), + TestCase( + "privileged_challenge_run", + "intercepting-communication/udp-1", + [check_challenge_run], + ), + TestCase( + "web_challenge_hosts", + "web-security/path-traversal-1", + [check_ssh_echo, check_hosts_resolution], + ), + TestCase( + "resource_limits", + "intercepting-communication/udp-1", + [check_pids_limit, check_mem_limit], + ), +] + + +# ── Runner ─────────────────────────────────────────────────────────────────── + + +async def run_tests(args) -> dict: + cfg = SSHConfig(host=args.ssh_host, port=args.ssh_port, key=args.ssh_key) + client = DojoRLClient(args.base_url) + + status = await client.status() + print( + f"Server: {args.base_url} (RL={'enabled' if status.enabled else 'DISABLED'}, " + f"{status.max_instances} max, {status.running} running)" + ) + if status.running > 0: + n = await client.destroy_all() + print(f"Cleaned up {n} instance(s)") + print() + + results: list[TestResult] = [] + test_num = 0 + total = len(TEST_CASES) + (0 if args.skip_concurrent else 1) + start_time = time.monotonic() + + for tc in TEST_CASES: + test_num += 1 + t0 = time.monotonic() + tr = TestResult(name=tc.name, challenge=tc.challenge) + print(f"[{test_num}/{total}] {tc.name} ({tc.challenge})") + + try: + inst = await client.create_instance(tc.challenge) + except Exception as e: + err = str(e) + if "404" in err or "not found" in err.lower() or "Invalid" in err: + tr.skipped = True + tr.error = f"Challenge not available: {err[:80]}" + print(f" SKIP {tr.error}") + else: + tr.error = f"create_instance failed: {err[:100]}" + print(f" ERR {tr.error}") + tr.duration = time.monotonic() - t0 + results.append(tr) + print(f" --- {'SKIP' if tr.skipped else 'FAIL'} ({tr.duration:.1f}s)\n") + continue + + try: + ready = await wait_ssh_ready(cfg, inst.ssh_user) + if not ready: + tr.error = "SSH not ready after 10 retries" + tr.checks.append( + CheckResult("ssh_ready", False, tr.error, time.monotonic() - t0) + ) + print(f" FAIL ssh_ready: {tr.error}") + else: + for check_fn in tc.checks: + cr = await check_fn(cfg, inst.ssh_user) + tr.checks.append(cr) + tag = "PASS" if cr.passed else "FAIL" + extra = f" ({cr.message})" if args.verbose or not cr.passed else "" + print(f" {tag} {cr.name:30s} {cr.duration:.1f}s{extra}") + if not cr.passed: + break + finally: + try: + await client.destroy_instance(inst.slot) + except Exception as e: + print(f" WARN destroy failed: {e}") + + tr.passed = all(c.passed for c in tr.checks) and not tr.error + tr.duration = time.monotonic() - t0 + results.append(tr) + print(f" --- {'PASS' if tr.passed else 'FAIL'} ({tr.duration:.1f}s)\n") + + if not args.skip_concurrent: + test_num += 1 + t0 = time.monotonic() + tr = TestResult(name="concurrent_lifecycle", challenge="8x hello/hello") + n_concurrent = min(8, status.max_instances) + print( + f"[{test_num}/{total}] concurrent_lifecycle ({n_concurrent}x hello/hello)" + ) + + try: + ct0 = time.monotonic() + tasks = [client.create_instance("hello/hello") for _ in range(n_concurrent)] + instances = await asyncio.gather(*tasks, return_exceptions=True) + create_dur = time.monotonic() - ct0 + + created = [i for i in instances if not isinstance(i, Exception)] + errors = [i for i in instances if isinstance(i, Exception)] + if errors: + tr.checks.append( + CheckResult( + "create_all", + False, + f"{len(errors)}/{n_concurrent} failed: {errors[0]}", + create_dur, + ) + ) + else: + tr.checks.append( + CheckResult( + "create_all", True, f"{n_concurrent} created", create_dur + ) + ) + + if created: + await asyncio.sleep(3) + et0 = time.monotonic() + echo_tasks = [ + ssh_run(cfg, i.ssh_user, "echo ok", timeout=15) for i in created + ] + echo_results = await asyncio.gather(*echo_tasks, return_exceptions=True) + echo_ok = sum( + 1 + for r in echo_results + if not isinstance(r, Exception) and r[0] == 0 + ) + tr.checks.append( + CheckResult( + "ssh_echo_all", + echo_ok == len(created), + f"{echo_ok}/{len(created)} connected", + time.monotonic() - et0, + ) + ) + + dt0 = time.monotonic() + destroyed = await client.destroy_all() + tr.checks.append( + CheckResult( + "destroy_all", + True, + f"destroyed {destroyed}", + time.monotonic() - dt0, + ) + ) + + st = await client.status() + live = sum(1 for i in st.instances if i.status == "running") + tr.checks.append( + CheckResult( + "slot_cleanup", + live == 0, + f"running={live} (total listed={st.running})", + 0.0, + ) + ) + except Exception as e: + tr.error = str(e)[:200] + tr.checks.append(CheckResult("concurrent", False, str(e)[:100], 0.0)) + + tr.passed = all(c.passed for c in tr.checks) and not tr.error + tr.duration = time.monotonic() - t0 + results.append(tr) + for cr in tr.checks: + tag = "PASS" if cr.passed else "FAIL" + extra = f" ({cr.message})" if args.verbose or not cr.passed else "" + print(f" {tag} {cr.name:30s} {cr.duration:.1f}s{extra}") + print(f" --- {'PASS' if tr.passed else 'FAIL'} ({tr.duration:.1f}s)\n") + + total_dur = time.monotonic() - start_time + passed = sum(1 for r in results if r.passed) + failed = sum(1 for r in results if not r.passed and not r.skipped) + skipped = sum(1 for r in results if r.skipped) + + print("=" * 50) + parts = [f"{passed}/{len(results)} passed"] + if failed: + parts.append(f"{failed} failed") + if skipped: + parts.append(f"{skipped} skipped") + print(f"RESULTS: {', '.join(parts)} in {total_dur:.0f}s") + print("=" * 50) + + return { + "test": "capability_verification", + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S%z"), + "server": args.base_url, + "summary": { + "total": len(results), + "passed": passed, + "failed": failed, + "skipped": skipped, + "duration_seconds": round(total_dur, 1), + }, + "tests": [ + { + "name": r.name, + "challenge": r.challenge, + "passed": r.passed, + "skipped": r.skipped, + "error": r.error, + "duration": round(r.duration, 1), + "checks": [asdict(c) for c in r.checks], + } + for r in results + ], + } + + +def main(): + parser = argparse.ArgumentParser( + description="Capability verification test for pwn-dojo RL infrastructure", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--base-url", default="http://100.120.55.25:8080") + parser.add_argument("--ssh-host", default="100.120.55.25") + parser.add_argument("--ssh-port", type=int, default=2222) + parser.add_argument( + "--ssh-key", default="environments/pwncollege_env/keys/rl_test_key" + ) + parser.add_argument("--output", "-o", help="Write JSON report") + parser.add_argument("--skip-concurrent", action="store_true") + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--yes", "-y", action="store_true", help="Skip confirmation") + args = parser.parse_args() + + key = Path(args.ssh_key) + if not key.exists(): + key = _repo_root / args.ssh_key + if not key.exists(): + print(f"SSH key not found: {args.ssh_key}") + sys.exit(1) + args.ssh_key = str(key) + + if not args.yes: + print(f"Will test against {args.base_url}") + if input("Continue? [y/N] ").lower() != "y": + sys.exit(0) + + report = asyncio.run(run_tests(args)) + + if args.output: + with open(args.output, "w") as f: + json.dump(report, f, indent=2) + print(f"\nJSON report: {args.output}") + + sys.exit(0 if report["summary"]["failed"] == 0 else 1) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index c0a7078ee38..d670d044f23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ sms = ["aiohttp>=3.9.0,<4"] acp = ["agent-client-protocol>=0.8.1,<0.9"] dingtalk = ["dingtalk-stream>=0.1.0,<1"] rl = [ - "atroposlib @ git+https://github.com/NousResearch/atropos.git", + "atroposlib @ git+https://github.com/NousResearch/atropos.git@sid/traj-saving-eval-mode", "tinker @ git+https://github.com/thinking-machines-lab/tinker.git", "fastapi>=0.104.0,<1", "uvicorn[standard]>=0.24.0,<1", diff --git a/tools/environments/persistent_shell.py b/tools/environments/persistent_shell.py index b1280bf4e0d..b43da71081d 100644 --- a/tools/environments/persistent_shell.py +++ b/tools/environments/persistent_shell.py @@ -12,6 +12,9 @@ from tools.interrupt import is_interrupted logger = logging.getLogger(__name__) +_SENTINEL_PREFIX = "__HERMES_DONE_" +_SENTINEL_SUFFIX = "__" + class PersistentShellMixin: """Mixin that adds persistent shell capability to any BaseEnvironment. @@ -40,8 +43,6 @@ class PersistentShellMixin: def _cleanup_temp_files(self): ... _session_id: str = "" - _poll_interval_start: float = 0.01 # initial poll interval (10ms) - _poll_interval_max: float = 0.25 # max poll interval (250ms) — reduces I/O for long commands @property def _temp_prefix(self) -> str: @@ -56,6 +57,8 @@ class PersistentShellMixin: self._shell_proc: subprocess.Popen | None = None self._shell_alive: bool = False self._shell_pid: int | None = None + self._sentinel_event = threading.Event() + self._sentinel_cmd_id: str | None = None self._session_id = uuid.uuid4().hex[:12] p = self._temp_prefix @@ -73,33 +76,44 @@ class PersistentShellMixin: ) self._drain_thread.start() + init_cmd_id = "init" init_script = ( + # Disable echo so sentinel markers aren't duplicated in stdout + f"stty -echo 2>/dev/null\n" f"export TERM=${{TERM:-dumb}}\n" f"touch {self._pshell_stdout} {self._pshell_stderr} " f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n" f"echo $$ > {self._pshell_pid_file}\n" f"pwd > {self._pshell_cwd}\n" + f"echo '{_SENTINEL_PREFIX}{init_cmd_id}{_SENTINEL_SUFFIX}'\n" ) + self._sentinel_event.clear() + self._sentinel_cmd_id = init_cmd_id self._send_to_shell(init_script) - deadline = time.monotonic() + 3.0 + deadline = time.monotonic() + 10.0 while time.monotonic() < deadline: - pid_str = self._read_temp_files(self._pshell_pid_file)[0].strip() - if pid_str.isdigit(): - self._shell_pid = int(pid_str) + remaining = deadline - time.monotonic() + if self._sentinel_event.wait(timeout=min(remaining, 0.5)): break - time.sleep(0.05) else: - logger.warning("Could not read persistent shell PID") - self._shell_pid = None + logger.warning("Persistent shell init sentinel not received") - if self._shell_pid: + pid_str, reported_cwd = self._read_temp_files( + self._pshell_pid_file, self._pshell_cwd, + ) + pid_str = pid_str.strip() + if pid_str.isdigit(): + self._shell_pid = int(pid_str) logger.info( "Persistent shell started (session=%s, pid=%d)", self._session_id, self._shell_pid, ) + else: + logger.warning("Could not read persistent shell PID") + self._shell_pid = None - reported_cwd = self._read_temp_files(self._pshell_cwd)[0].strip() + reported_cwd = reported_cwd.strip() if reported_cwd: self.cwd = reported_cwd @@ -151,11 +165,19 @@ class PersistentShellMixin: def _drain_shell_output(self): try: - for _ in self._shell_proc.stdout: - pass + for line in self._shell_proc.stdout: + stripped = line.rstrip('\r\n') + if ( + stripped.startswith(_SENTINEL_PREFIX) + and stripped.endswith(_SENTINEL_SUFFIX) + ): + inner = stripped[len(_SENTINEL_PREFIX):-len(_SENTINEL_SUFFIX)] + if inner == self._sentinel_cmd_id: + self._sentinel_event.set() except Exception: pass self._shell_alive = False + self._sentinel_event.set() def _send_to_shell(self, text: str): if not self._shell_alive or self._shell_proc is None: @@ -222,21 +244,16 @@ class PersistentShellMixin: f"__EC=$?\n" f"pwd > {self._pshell_cwd}\n" f"echo {cmd_id}:$__EC > {self._pshell_status}\n" + f"echo '{_SENTINEL_PREFIX}{cmd_id}{_SENTINEL_SUFFIX}'\n" ) + self._sentinel_event.clear() + self._sentinel_cmd_id = cmd_id self._send_to_shell(ipc_script) deadline = time.monotonic() + timeout - poll_interval = self._poll_interval_start # starts at 10ms, backs off to 250ms while True: - if is_interrupted(): - self._kill_shell_children() - output, _, _ = self._read_persistent_output() - return { - "output": output + "\n[Command interrupted]", - "returncode": 130, - } - - if time.monotonic() > deadline: + remaining = deadline - time.monotonic() + if remaining <= 0: self._kill_shell_children() output, _, _ = self._read_persistent_output() if output: @@ -246,22 +263,23 @@ class PersistentShellMixin: } return self._timeout_result(timeout) + if is_interrupted(): + self._kill_shell_children() + output, _, _ = self._read_persistent_output() + return { + "output": output + "\n[Command interrupted]", + "returncode": 130, + } + if not self._shell_alive: return { "output": "Persistent shell died during execution", "returncode": 1, } - status_content = self._read_temp_files(self._pshell_status)[0].strip() - if status_content.startswith(cmd_id + ":"): + if self._sentinel_event.wait(timeout=min(remaining, 0.5)): break - time.sleep(poll_interval) - # Exponential backoff: fast start (10ms) for quick commands, - # ramps up to 250ms for long-running commands — reduces I/O by 10-25x - # on WSL2 where polling keeps the VM hot and memory pressure high. - poll_interval = min(poll_interval * 1.5, self._poll_interval_max) - output, exit_code, new_cwd = self._read_persistent_output() if new_cwd: self.cwd = new_cwd diff --git a/tools/environments/ssh.py b/tools/environments/ssh.py index 83891fe2006..d21f83fc16f 100644 --- a/tools/environments/ssh.py +++ b/tools/environments/ssh.py @@ -83,21 +83,19 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): cmd = self._build_ssh_command() cmd.append("echo 'SSH connection established'") try: - result = subprocess.run(cmd, capture_output=True, text=True, timeout=15) + result = subprocess.run(cmd, capture_output=True, text=True, errors="replace", timeout=15) if result.returncode != 0: error_msg = result.stderr.strip() or result.stdout.strip() raise RuntimeError(f"SSH connection failed: {error_msg}") except subprocess.TimeoutExpired: raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out") - _poll_interval_start: float = 0.15 # SSH: higher initial interval (150ms) for network latency - @property def _temp_prefix(self) -> str: return f"/tmp/hermes-ssh-{self._session_id}" def _spawn_shell_process(self) -> subprocess.Popen: - cmd = self._build_ssh_command() + cmd = self._build_ssh_command(extra_args=["-tt"]) cmd.append("bash -l") return subprocess.Popen( cmd, @@ -105,6 +103,7 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, + errors="replace", ) def _read_temp_files(self, *paths: str) -> list[str]: @@ -113,7 +112,7 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): cmd.append(f"cat {paths[0]} 2>/dev/null") try: result = subprocess.run( - cmd, capture_output=True, text=True, timeout=10, + cmd, capture_output=True, text=True, errors="replace", timeout=10, ) return [result.stdout] except (subprocess.TimeoutExpired, OSError): @@ -127,7 +126,7 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): cmd.append(script) try: result = subprocess.run( - cmd, capture_output=True, text=True, timeout=10, + cmd, capture_output=True, text=True, errors="replace", timeout=10, ) parts = result.stdout.split(delim + "\n") return [parts[i] if i < len(parts) else "" for i in range(len(paths))] @@ -179,6 +178,7 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): stderr=subprocess.STDOUT, stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL, text=True, + errors="replace", ) if effective_stdin: diff --git a/uv.lock b/uv.lock index 48720c67fba..64a81ba0e67 100644 --- a/uv.lock +++ b/uv.lock @@ -262,7 +262,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/87/c6/53da25344e3e3a9c0 [[package]] name = "atroposlib" version = "0.4.0" -source = { git = "https://github.com/NousResearch/atropos.git#c421582b6f7ce8a32f751aab3117d3824ac8f709" } +source = { git = "https://github.com/NousResearch/atropos.git?rev=sid%2Ftraj-saving-eval-mode#0ab46d65b015363148fccf49b2dc0e2ebd47e9e9" } dependencies = [ { name = "aiofiles" }, { name = "aiohttp" }, @@ -1758,12 +1758,12 @@ yc-bench = [ [package.metadata] requires-dist = [ - { name = "agent-client-protocol", marker = "extra == 'acp'", specifier = ">=0.8.1,<1.0" }, + { name = "agent-client-protocol", marker = "extra == 'acp'", specifier = ">=0.8.1,<0.9" }, { name = "aiohttp", marker = "extra == 'homeassistant'", specifier = ">=3.9.0,<4" }, { name = "aiohttp", marker = "extra == 'messaging'", specifier = ">=3.13.3,<4" }, { name = "aiohttp", marker = "extra == 'sms'", specifier = ">=3.9.0,<4" }, { name = "anthropic", specifier = ">=0.39.0,<1" }, - { name = "atroposlib", marker = "extra == 'rl'", git = "https://github.com/NousResearch/atropos.git" }, + { name = "atroposlib", marker = "extra == 'rl'", git = "https://github.com/NousResearch/atropos.git?rev=sid%2Ftraj-saving-eval-mode" }, { name = "croniter", marker = "extra == 'cron'", specifier = ">=6.0.0,<7" }, { name = "daytona", marker = "extra == 'daytona'", specifier = ">=0.148.0,<1" }, { name = "dingtalk-stream", marker = "extra == 'dingtalk'", specifier = ">=0.1.0,<1" },