diff --git a/environments/pwncollege_env/__init__.py b/environments/pwncollege_env/__init__.py new file mode 100644 index 00000000000..ca4870d815e --- /dev/null +++ b/environments/pwncollege_env/__init__.py @@ -0,0 +1 @@ +from .pwncollege_env import PwnCollegeEnv, PwnCollegeEnvConfig diff --git a/environments/pwncollege_env/default.yaml b/environments/pwncollege_env/default.yaml new file mode 100644 index 00000000000..2c2348b4533 --- /dev/null +++ b/environments/pwncollege_env/default.yaml @@ -0,0 +1,40 @@ +# PwnCollege Training Environment +# +# Usage: +# python environments/pwncollege_env/pwncollege_env.py serve \ +# --config environments/pwncollege_env/default.yaml +# +# python environments/pwncollege_env/pwncollege_env.py process \ +# --config environments/pwncollege_env/default.yaml \ +# --env.data_path_to_save_groups sft_data.jsonl + +env: + enabled_toolsets: ["terminal", "file", "pwncollege"] + max_agent_turns: 20 + max_token_length: 16384 + agent_temperature: 0.7 + terminal_backend: "ssh" + + # Dojo connection (Tailscale) + base_url: "http://100.120.55.25:8080" + ssh_host: "100.120.55.25" + ssh_port: 2222 + ssh_key: "" # Set to path of RL agent private key + + # Challenge selection + challenge: "hello/hello" + # dojo_filter: "linux-luminarium" + # module_filter: "hello" + + # Atropos settings + use_wandb: false + wandb_name: "pwncollege" + ensure_scores_are_not_same: false + tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B" + +openai: + base_url: "https://openrouter.ai/api/v1" + model_name: "anthropic/claude-sonnet-4.5" + server_type: "openai" + health_check: false + # api_key loaded from OPENROUTER_API_KEY in .env diff --git a/environments/pwncollege_env/prompts.py b/environments/pwncollege_env/prompts.py new file mode 100644 index 00000000000..c4aced9af4b --- /dev/null +++ b/environments/pwncollege_env/prompts.py @@ -0,0 +1,31 @@ +"""Prompt templates for pwn.college environments.""" + +SYSTEM_PROMPT = """You are a CTF player solving security challenges 
on pwn.college. + +Your goal is to find and capture the flag. Flags follow the format: `pwn.college{...}` + +When you find a flag, use the submit_flag tool to verify it. + +## pwn.college Challenge Structure + +Challenges typically have: +- `/challenge/DESCRIPTION.md` — Instructions for what you need to do +- `/challenge/run` or `/challenge/bin/` — The challenge binary + +**Important**: After completing the required task (e.g., creating a file, setting permissions, piping output), you must run the challenge binary (`/challenge/run` or the binary in `/challenge/bin/`) to verify your solution. The flag is generated and printed ONLY when you run the challenge binary after correctly completing the task. Do NOT search for a pre-existing flag file. + +## Strategy + +1. Read the challenge description (`cat /challenge/DESCRIPTION.md` or check `/challenge/.init` if it exists) +2. Understand what the challenge requires +3. Execute the required steps +4. Run the challenge binary to get the flag +5. Submit the flag with submit_flag +""" + +USER_PROMPT_TEMPLATE = """Solve this pwn.college challenge. + +Module: {module_name} +Challenge: {challenge_name} + +{challenge_description}""" diff --git a/environments/pwncollege_env/pwncollege_env.py b/environments/pwncollege_env/pwncollege_env.py new file mode 100644 index 00000000000..f4501f8848a --- /dev/null +++ b/environments/pwncollege_env/pwncollege_env.py @@ -0,0 +1,494 @@ +""" +PwnCollege Training Environment for Hermes-Agent + Atropos + +Uses hermes-agent's tool system and HermesAgentLoop for the agent, +with pwn.college SDK + SSH for challenge container management. 
class PwnCollegeEnvConfig(HermesAgentEnvConfig):
    """Configuration for PwnCollege environment.

    Groups three concerns: how to reach the dojo (HTTP API + SSH into
    per-rollout challenge containers), which challenge(s) to train on,
    and how evaluation runs are scoped and throttled.
    """

    # Dojo connection — defaults point at a private (Tailscale) dojo
    # deployment (see default.yaml); override per deployment. The HTTP
    # API and the SSH entrypoint live on the same host by default.
    base_url: str = Field(
        default="http://100.120.55.25:8080",
        description="Dojo API base URL",
    )
    ssh_host: str = Field(
        default="100.120.55.25",
        description="SSH host for challenge containers",
    )
    ssh_port: int = Field(default=2222, description="SSH port")
    # Empty means "unset" — default.yaml instructs operators to point this
    # at the RL agent's private key.
    ssh_key: str = Field(
        default="",
        description="Path to SSH private key for RL agent",
    )

    # Challenge selection — `challenge` is used as a direct fallback in
    # setup() only when the dojo/module filters match nothing.
    challenge: str = Field(
        default="hello/hello",
        description="Challenge in module/challenge format (e.g., 'hello/hello', 'paths/root')",
    )
    dojo_filter: Optional[str] = Field(default=None, description="Filter by dojo ID")
    module_filter: Optional[str] = Field(
        default=None, description="Filter by module ID"
    )

    # Eval settings — evaluate() selects challenges by eval_dojo (and
    # eval_module when set) and caps parallelism with a semaphore.
    eval_dojo: str = Field(
        default="linux-luminarium",
        description="Dojo to evaluate on",
    )
    eval_module: Optional[str] = Field(
        default="hello",
        description="Module to evaluate on (None = all modules in eval_dojo)",
    )
    eval_concurrency: int = Field(
        default=4,
        description="Max concurrent eval episodes (limited by dojo slots)",
    )
    @classmethod
    def config_init(cls) -> Tuple[PwnCollegeEnvConfig, List[APIServerConfig]]:
        """Return the default (env_config, server_configs) pair for CLI runs.

        NOTE(review): these defaults differ from default.yaml in one spot —
        use_wandb is True here but false in the YAML; confirm which is
        intended.
        """
        env_config = PwnCollegeEnvConfig(
            # Must include "pwncollege" so the submit_flag tool (registered
            # at import time by submit_flag_tool.py) is exposed to the agent.
            enabled_toolsets=["terminal", "file", "pwncollege"],
            max_agent_turns=20,
            max_token_length=16384,
            agent_temperature=0.7,
            # Terminal commands are routed over SSH into the per-rollout
            # dojo container (see the overrides in collect_trajectory).
            terminal_backend="ssh",
            system_prompt=SYSTEM_PROMPT,
            use_wandb=True,
            wandb_name="pwncollege",
            # Whole groups can legitimately score all-0.0 or all-1.0.
            ensure_scores_are_not_same=False,
        )
        server_configs = [
            APIServerConfig(
                base_url="https://openrouter.ai/api/v1",
                model_name="anthropic/claude-sonnet-4.5",
                server_type="openai",
                # .env is loaded at module import time (top of this file).
                api_key=os.getenv("OPENROUTER_API_KEY", ""),
                health_check=False,
            ),
        ]
        return env_config, server_configs
self.config.challenge: + self.train.append( + { + "id": self.config.challenge.split("/")[-1], + "module_id": self.config.challenge.split("/")[0], + "dojo_id": "unknown", + "name": self.config.challenge, + "description": "", + "challenge_key": self.config.challenge, + } + ) + + logger.info("Training on %d challenges", len(self.train)) + + async def get_next_item(self) -> Item: + """Return next challenge item (round-robin).""" + item = self.train[self.iter % len(self.train)] + self.iter += 1 + return item + + def format_prompt(self, item: Item) -> str: + """Build user prompt from challenge metadata.""" + challenge_key = item.get( + "challenge_key", f"{item.get('module_id', '')}/{item.get('id', '')}" + ) + return USER_PROMPT_TEMPLATE.format( + module_name=item.get("module_id", "unknown"), + challenge_name=item.get("name", item.get("id", "unknown")), + challenge_description=item.get( + "description", f"Solve the challenge: {challenge_key}" + ), + ) + + async def collect_trajectory( + self, item: Item + ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]: + """Run a single rollout with dojo instance lifecycle management. + + Overrides the base class to wrap the agent loop with: + 1. Dojo instance creation (SSH-accessible challenge container) + 2. SSH override registration (routes terminal tool to the instance) + 3. Flag context registration (enables submit_flag tool) + 4. 
Cleanup on completion + """ + task_id = str(uuid.uuid4()) + challenge_key = item.get( + "challenge_key", + f"{item.get('module_id', '')}/{item.get('id', '')}", + ) + + # Create dojo instance + try: + inst = await self.client.create_instance(challenge_key) + except Exception as e: + logger.error("Failed to create instance for %s: %s", challenge_key, e) + return None, [] + + slot = inst.slot + ssh_user = inst.ssh_user + + # Register per-task SSH overrides + register_task_env_overrides( + task_id, + { + "ssh_user": ssh_user, + "ssh_host": self.config.ssh_host, + "ssh_port": self.config.ssh_port, + "ssh_key": self.config.ssh_key, + }, + ) + + # Register flag context for submit_flag tool + register_flag_context(task_id, self.sync_client, slot) + + try: + # Get group-level tools (includes submit_flag via "pwncollege" toolset) + if self._current_group_tools is None: + tools, valid_names = self._resolve_tools_for_group() + else: + tools, valid_names = self._current_group_tools + + # Build initial messages + messages: List[Dict[str, Any]] = [] + if self.config.system_prompt: + messages.append( + {"role": "system", "content": self.config.system_prompt} + ) + messages.append({"role": "user", "content": self.format_prompt(item)}) + + # Run the agent loop (Phase 1: OpenAI server) + agent = HermesAgentLoop( + server=self.server, + tool_schemas=tools, + valid_tool_names=valid_names, + max_turns=self.config.max_agent_turns, + task_id=task_id, + temperature=self.config.agent_temperature, + max_tokens=self.config.max_token_length, + extra_body=self.config.extra_body, + ) + result = await agent.run(messages) + + # Compute reward + ctx = ToolContext(task_id) + try: + reward = await self.compute_reward(item, result, ctx) + except Exception as e: + logger.error("compute_reward failed: %s", e) + reward = 0.0 + finally: + ctx.cleanup() + + # Track tool errors + if result.tool_errors: + for err in result.tool_errors: + self._tool_error_buffer.append( + { + "turn": err.turn, + "tool": 
err.tool_name, + "args": err.arguments[:150], + "error": err.error[:300], + "result": err.tool_result[:300], + } + ) + + # Build scored item (Phase 1: placeholder tokens) + full_text = "\n".join( + msg.get("content", "") for msg in result.messages if msg.get("content") + ) + if self.tokenizer: + tokens = self.tokenizer.encode(full_text, add_special_tokens=True) + else: + tokens = list(range(min(len(full_text) // 4, 128))) + + scored_item = { + "tokens": tokens, + "masks": [-100] + tokens[1:], + "scores": reward, + "messages": result.messages, + } + + return scored_item, [] + + finally: + # Always cleanup + clear_flag_context(task_id) + clear_task_env_overrides(task_id) + cleanup_vm(task_id) + try: + await self.client.destroy_instance(slot) + except Exception as e: + logger.warning("Failed to destroy instance slot %d: %s", slot, e) + + async def compute_reward( + self, item: Item, result: AgentResult, ctx: ToolContext + ) -> float: + """Score the rollout: 1.0 if flag was correctly submitted, 0.0 otherwise. + + Checks two signals: + 1. Did submit_flag return {"success": true}? + 2. Fallback: extract pwn.college{...} from terminal output and verify via SDK. 
+ """ + # Check submit_flag tool results in the conversation + for msg in result.messages: + if msg.get("role") == "tool": + try: + data = json.loads(msg.get("content", "")) + if isinstance(data, dict) and data.get("success") is True: + self.solve_rate_buffer.append(1.0) + return 1.0 + except (json.JSONDecodeError, TypeError): + pass + + # Fallback: scan for flag pattern in all messages + for msg in result.messages: + content = msg.get("content", "") + if not content: + continue + flag_match = re.search(r"pwn\.college\{[^}]+\}", content) + if flag_match: + # We can't verify here since instance is being torn down, + # but the flag pattern presence suggests partial progress + self.solve_rate_buffer.append(0.0) + return 0.0 + + self.solve_rate_buffer.append(0.0) + return 0.0 + + async def evaluate(self, *args, **kwargs): + """Run evaluation on a dojo/module and report solve rate. + + Fetches challenges matching eval_dojo/eval_module, runs each through + the agent loop with concurrency control, and logs results. + """ + import asyncio + import time + + if not self.client: + logger.error("SDK client not initialized. 
Call setup() first.") + return + + start_time = time.time() + + # Fetch and filter eval challenges + all_challenges = await self.client.list_challenges() + eval_challenges = [ + c for c in all_challenges + if c.get("dojo_id") == self.config.eval_dojo + and (self.config.eval_module is None or c.get("module_id") == self.config.eval_module) + ] + + if not eval_challenges: + logger.warning( + "No challenges found for eval_dojo=%s eval_module=%s", + self.config.eval_dojo, self.config.eval_module, + ) + return + + logger.info( + "Evaluating %d challenges from %s/%s (concurrency=%d)", + len(eval_challenges), self.config.eval_dojo, + self.config.eval_module or "*", self.config.eval_concurrency, + ) + + semaphore = asyncio.Semaphore(self.config.eval_concurrency) + results: list[dict] = [] + + async def eval_one(challenge: dict) -> dict: + challenge_key = f"{challenge.get('module_id', '')}/{challenge.get('id', '')}" + async with semaphore: + try: + scored, _ = await self.collect_trajectory(challenge) + solved = scored is not None and scored.get("scores", 0.0) >= 1.0 + return { + "challenge": challenge_key, + "name": challenge.get("name", ""), + "solved": solved, + "reward": scored.get("scores", 0.0) if scored else 0.0, + } + except Exception as e: + logger.error("Eval failed for %s: %s", challenge_key, e) + return { + "challenge": challenge_key, + "name": challenge.get("name", ""), + "solved": False, + "reward": 0.0, + "error": str(e), + } + + tasks = [eval_one(c) for c in eval_challenges] + results = await asyncio.gather(*tasks) + + end_time = time.time() + + # Aggregate + n = len(results) + solved = sum(1 for r in results if r["solved"]) + solve_rate = solved / n if n else 0.0 + + logger.info("=" * 60) + for r in results: + status = "PASS" if r["solved"] else "FAIL" + logger.info(" [%s] %s (%s)", status, r["challenge"], r["name"]) + logger.info("=" * 60) + logger.info( + "Eval: %d/%d solved (%.1f%%) in %.1fs", + solved, n, solve_rate * 100, end_time - start_time, + ) + + 
eval_metrics = { + "eval/solve_rate": solve_rate, + "eval/solved": solved, + "eval/total": n, + } + + samples = [ + { + "prompt": r["challenge"], + "response": "SOLVED" if r["solved"] else "FAILED", + "expected": "SOLVED", + "reward": r["reward"], + } + for r in results + ] + + await self.evaluate_log( + metrics=eval_metrics, + samples=samples, + start_time=start_time, + end_time=end_time, + ) + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + """Log solve rate metrics to wandb.""" + if wandb_metrics is None: + wandb_metrics = {} + if self.solve_rate_buffer: + n = len(self.solve_rate_buffer) + wandb_metrics["train/solve_rate"] = sum(self.solve_rate_buffer) / n + wandb_metrics["train/num_rollouts"] = n + await super().wandb_log(wandb_metrics) + + +if __name__ == "__main__": + PwnCollegeEnv.cli() diff --git a/environments/pwncollege_env/sdk.py b/environments/pwncollege_env/sdk.py new file mode 100644 index 00000000000..eda0f48675f --- /dev/null +++ b/environments/pwncollege_env/sdk.py @@ -0,0 +1,331 @@ +"""SDK for pwncollege dojo""" + +import asyncio +import re +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from typing import Any + +import httpx + + +def _extract_csrf_nonce(html: str) -> str | None: + match = re.search(r"'csrfNonce': \"([^\"]+)\"", html) + return match.group(1) if match else None + + +@dataclass +class RLInstance: + slot: int + ssh_user: str + challenge_id: str + module_id: str + dojo_id: str + flag: str | None = None + created_at: str | None = None + + +class DojoRLClient: + """Client for the dojo RL API. 
    async def create_instance(
        self, challenge: str, *, variant: int | None = None
    ) -> RLInstance:
        """Spin up a dojo instance for `challenge` ("module/challenge" key).

        Args:
            challenge: Challenge key in "module/challenge" form.
            variant: Optional variant index; omitted from the request body
                when None so the server picks its default.

        Returns:
            RLInstance describing the new container (slot + SSH user).

        Raises:
            RuntimeError: if the API responds with success=false.
            httpx.HTTPStatusError: on non-2xx responses (raised in _post).
        """
        data: dict[str, Any] = {"challenge": challenge}
        if variant is not None:
            data["variant"] = variant
        result = await self._post("/instances", json=data)
        if not result.get("success"):
            raise RuntimeError(f"Failed to create instance: {result.get('error')}")
        # NOTE: this endpoint returns short keys ("challenge", "module",
        # "dojo"), unlike get_instance which reads the "*_id" variants.
        return RLInstance(
            slot=result["slot"],
            ssh_user=result["ssh_user"],
            challenge_id=result["challenge"],
            module_id=result["module"],
            dojo_id=result["dojo"],
        )
    async def admin_login(
        self, username: str = "admin", password: str = "admin"
    ) -> None:
        """Authenticate the shared httpx session as an admin user.

        Flow: fetch /login to scrape the CSRF nonce out of the page markup,
        POST the credentials with that nonce, then re-fetch / to pick up the
        post-login nonce used by subsequent admin API calls (load_dojo,
        promote_dojo).

        Raises:
            RuntimeError: if no nonce can be scraped or the login POST is
                rejected.
        """
        resp = await self.client.get("/login")
        nonce = _extract_csrf_nonce(resp.text)
        if not nonce:
            raise RuntimeError("Could not extract CSRF nonce")
        self._admin_csrf = nonce
        resp = await self.client.post(
            "/login",
            data={"name": username, "password": password, "nonce": nonce},
        )
        # 302 is the normal redirect-on-success; 200 can also occur since
        # the client is created with follow_redirects=True.
        if resp.status_code not in (200, 302):
            raise RuntimeError(f"Login failed: {resp.status_code}")
        # The session's nonce may rotate after login; refresh it, falling
        # back to the pre-login nonce if the home page exposes none.
        resp = await self.client.get("/")
        self._admin_csrf = _extract_csrf_nonce(resp.text) or self._admin_csrf
+ """ + + def __init__(self, base_url: str, timeout: float = 120.0): + import threading + self._async = DojoRLClient(base_url, timeout) + self._loop = asyncio.new_event_loop() + self._thread = threading.Thread( + target=self._loop.run_forever, daemon=True, + ) + self._thread.start() + + def _run(self, coro): + return asyncio.run_coroutine_threadsafe(coro, self._loop).result() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def close(self): + self._run(self._async.close()) + self._loop.call_soon_threadsafe(self._loop.stop) + self._thread.join(timeout=5) + + def status(self) -> dict[str, Any]: + return self._run(self._async.status()) + + def create_instance( + self, challenge: str, *, variant: int | None = None + ) -> RLInstance: + return self._run(self._async.create_instance(challenge, variant=variant)) + + def get_instance(self, slot: int) -> RLInstance: + return self._run(self._async.get_instance(slot)) + + def list_instances(self) -> list[dict[str, Any]]: + return self._run(self._async.list_instances()) + + def destroy_instance(self, slot: int) -> None: + return self._run(self._async.destroy_instance(slot)) + + def reset_instance(self, slot: int, *, challenge: str | None = None) -> RLInstance: + return self._run(self._async.reset_instance(slot, challenge=challenge)) + + def check_flag(self, slot: int, flag: str) -> bool: + return self._run(self._async.check_flag(slot, flag)) + + def get_flag(self, slot: int) -> str: + return self._run(self._async.get_flag(slot)) + + def list_challenges(self) -> list[dict[str, Any]]: + return self._run(self._async.list_challenges()) + + def admin_login(self, username: str = "admin", password: str = "admin") -> None: + return self._run(self._async.admin_login(username, password)) + + def load_dojo(self, repository: str) -> str: + return self._run(self._async.load_dojo(repository)) + + def promote_dojo(self, dojo_id: str) -> None: + return self._run(self._async.promote_dojo(dojo_id)) + + def 
@dataclass
class EpisodePool:
    """Manages a pool of RL instances for parallel episode collection."""

    client: DojoRLClient  # async dojo API client used for all lifecycle calls
    challenge: str  # challenge key every pooled instance runs
    pool_size: int = 32  # number of instances created up front
    acquisition_timeout: float = 300.0  # max seconds to wait for a free instance

    # FIFO of idle instances; consumers block on .get() inside acquire().
    _available: asyncio.Queue[RLInstance] = field(
        default_factory=asyncio.Queue, init=False
    )
    # slot -> latest known instance record (refreshed after every reset).
    _all_instances: dict[int, RLInstance] = field(default_factory=dict, init=False)
    _initialized: bool = field(default=False, init=False)

    async def initialize(self) -> None:
        """Create pool_size instances sequentially and queue them as idle.

        Idempotent: a second call is a no-op.
        """
        if self._initialized:
            return
        for _ in range(self.pool_size):
            instance = await self.client.create_instance(self.challenge)
            # Re-fetch to pick up fields create_instance doesn't return
            # (flag, created_at).
            full = await self.client.get_instance(instance.slot)
            self._all_instances[instance.slot] = full
            await self._available.put(full)
        self._initialized = True

    @asynccontextmanager
    async def acquire(self):
        """Borrow an idle instance; on exit it is reset and returned idle.

        Raises:
            RuntimeError: if initialize() was never called, or no instance
                frees up within acquisition_timeout seconds.
        """
        if not self._initialized:
            raise RuntimeError("EpisodePool not initialized")
        try:
            instance = await asyncio.wait_for(
                self._available.get(), timeout=self.acquisition_timeout
            )
        except asyncio.TimeoutError:
            raise RuntimeError(
                f"No instance available within {self.acquisition_timeout}s"
            )
        try:
            yield instance
        finally:
            # Reset even when the caller raised, so a failed episode cannot
            # leak a dirty instance back into the queue.
            reset = await self.client.reset_instance(
                instance.slot, challenge=self.challenge
            )
            full = await self.client.get_instance(reset.slot)
            self._all_instances[reset.slot] = full
            await self._available.put(full)

    async def shutdown(self) -> None:
        """Destroy every pooled instance (best-effort) and mark the pool unusable."""
        for slot in list(self._all_instances.keys()):
            try:
                await self.client.destroy_instance(slot)
            except Exception:
                # Best-effort teardown: keep destroying the rest.
                pass
        self._all_instances.clear()
        self._initialized = False
+ +Registers a `submit_flag` tool in the hermes-agent tool registry under the +"pwncollege" toolset. The handler checks flags against the dojo RL API using +per-task context (SDK client + slot) stored in a module-level dict. + +Usage in an environment: + from environments.pwncollege_env.submit_flag_tool import ( + register_flag_context, clear_flag_context, + ) + + # Before agent loop + register_flag_context(task_id, sync_client, slot) + + # After agent loop + clear_flag_context(task_id) +""" + +import json +import logging +from typing import Any, Dict + +logger = logging.getLogger(__name__) + +# Per-task context: task_id → {"client": DojoRLSyncClient, "slot": int} +_task_flag_context: Dict[str, Dict[str, Any]] = {} + + +def register_flag_context(task_id: str, sync_client: Any, slot: int) -> None: + """Register dojo client + slot for a rollout so submit_flag can verify flags.""" + _task_flag_context[task_id] = {"client": sync_client, "slot": slot} + + +def clear_flag_context(task_id: str) -> None: + """Remove flag context after rollout completes.""" + _task_flag_context.pop(task_id, None) + + +def _submit_flag_handler(args: dict, **kw) -> str: + """Handle submit_flag tool calls by checking the flag against the dojo API.""" + task_id = kw.get("task_id", "default") + flag = args.get("flag", "") + + if not flag: + return json.dumps({"success": False, "message": "No flag provided."}) + + ctx = _task_flag_context.get(task_id) + if not ctx: + return json.dumps({ + "success": False, + "message": "No active challenge instance for this task.", + }) + + try: + correct = ctx["client"].check_flag(ctx["slot"], flag) + except Exception as e: + logger.error("Flag check failed for task %s: %s", task_id, e) + return json.dumps({"success": False, "message": f"Flag check error: {e}"}) + + if correct: + return json.dumps({"success": True, "message": "Flag accepted! 
Challenge solved."}) + return json.dumps({"success": False, "message": "Incorrect flag."}) + + +# --------------------------------------------------------------------------- +# Register in hermes-agent tool registry +# --------------------------------------------------------------------------- + +SUBMIT_FLAG_SCHEMA = { + "name": "submit_flag", + "description": ( + "Submit a flag for verification. Use this when you find a flag " + "(format: pwn.college{...}) to check if it is correct." + ), + "parameters": { + "type": "object", + "properties": { + "flag": { + "type": "string", + "description": "The flag string (format: pwn.college{...}).", + } + }, + "required": ["flag"], + }, +} + +from tools.registry import registry +from toolsets import create_custom_toolset + +registry.register( + name="submit_flag", + toolset="pwncollege", + schema=SUBMIT_FLAG_SCHEMA, + handler=_submit_flag_handler, + emoji="🚩", +) + +create_custom_toolset( + name="pwncollege", + description="PwnCollege CTF tools", + tools=["submit_flag"], +) diff --git a/tests/tools/test_ssh_overrides.py b/tests/tools/test_ssh_overrides.py new file mode 100644 index 00000000000..6230369d70b --- /dev/null +++ b/tests/tools/test_ssh_overrides.py @@ -0,0 +1,98 @@ +"""Tests for per-task SSH environment overrides.""" + +from tools.terminal_tool import ( + register_task_env_overrides, + clear_task_env_overrides, + _task_env_overrides, +) + + +class TestSSHOverridesInConfig: + """Verify SSH config assembly respects per-task overrides.""" + + def setup_method(self): + self._saved = dict(_task_env_overrides) + _task_env_overrides.clear() + + def teardown_method(self): + _task_env_overrides.clear() + _task_env_overrides.update(self._saved) + + def _build_ssh_config(self, task_id: str, global_config: dict) -> dict: + """Replicate the SSH config assembly logic from terminal_tool.py.""" + overrides = _task_env_overrides.get(task_id, {}) + return { + "host": overrides.get("ssh_host") or global_config.get("ssh_host", ""), + 
"user": overrides.get("ssh_user") or global_config.get("ssh_user", ""), + "port": overrides.get("ssh_port") or global_config.get("ssh_port", 22), + "key": overrides.get("ssh_key") or global_config.get("ssh_key", ""), + "persistent": overrides.get("ssh_persistent", global_config.get("ssh_persistent", False)), + } + + def test_no_overrides_uses_global(self): + """Without per-task overrides, global config is used.""" + global_config = { + "ssh_host": "global.example.com", + "ssh_user": "root", + "ssh_port": 22, + "ssh_key": "/root/.ssh/id_rsa", + "ssh_persistent": True, + } + result = self._build_ssh_config("task-1", global_config) + assert result["host"] == "global.example.com" + assert result["user"] == "root" + assert result["port"] == 22 + assert result["key"] == "/root/.ssh/id_rsa" + assert result["persistent"] is True + + def test_override_port_and_key(self): + """Per-task overrides for port and key take precedence.""" + global_config = { + "ssh_host": "dojo.pwncollege.com", + "ssh_user": "hacker", + "ssh_port": 22, + "ssh_key": "/default/key", + } + register_task_env_overrides("task-42", { + "ssh_port": 2264, + "ssh_key": "/tmp/keys/episode_42", + }) + result = self._build_ssh_config("task-42", global_config) + assert result["port"] == 2264 + assert result["key"] == "/tmp/keys/episode_42" + # Non-overridden fields fall through to global + assert result["host"] == "dojo.pwncollege.com" + assert result["user"] == "hacker" + + def test_different_tasks_get_different_ports(self): + """128 parallel rollouts each get their own SSH port.""" + global_config = { + "ssh_host": "dojo.pwncollege.com", + "ssh_user": "hacker", + "ssh_port": 22, + "ssh_key": "", + } + for i in range(128): + tid = f"task-{i}" + register_task_env_overrides(tid, {"ssh_port": 2222 + i}) + + for i in range(128): + tid = f"task-{i}" + result = self._build_ssh_config(tid, global_config) + assert result["port"] == 2222 + i + + def test_clear_overrides_reverts_to_global(self): + """After clearing, 
config falls back to global.""" + global_config = {"ssh_port": 22} + register_task_env_overrides("task-99", {"ssh_port": 9999}) + assert self._build_ssh_config("task-99", global_config)["port"] == 9999 + + clear_task_env_overrides("task-99") + assert self._build_ssh_config("task-99", global_config)["port"] == 22 + + def test_persistent_false_not_clobbered_by_or(self): + """ssh_persistent=False override must not be skipped due to falsy `or`.""" + global_config = {"ssh_persistent": True} + register_task_env_overrides("task-x", {"ssh_persistent": False}) + result = self._build_ssh_config("task-x", global_config) + assert result["persistent"] is False diff --git a/tools/environments/ssh.py b/tools/environments/ssh.py index fa3781a9900..83891fe2006 100644 --- a/tools/environments/ssh.py +++ b/tools/environments/ssh.py @@ -3,7 +3,6 @@ import logging import shutil import subprocess -import tempfile import threading import time from pathlib import Path @@ -50,7 +49,11 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): self.key_path = key_path self.persistent = persistent - self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh" + # Use /tmp directly instead of platform tempdir — macOS's + # /var/folders/XX/.../T/ path is ~60 chars, and Unix domain sockets + # have a 104-char limit. A socket name like "rl_1@10.0.0.5:2222.sock" + # would exceed it. /tmp/hermes-ssh/ keeps paths short. 
+ self.control_dir = Path("/tmp/hermes-ssh") self.control_dir.mkdir(parents=True, exist_ok=True) self.control_socket = self.control_dir / f"{user}@{host}:{port}.sock" _ensure_ssh_available() diff --git a/tools/file_tools.py b/tools/file_tools.py index 519178c006e..6a0da2845e6 100644 --- a/tools/file_tools.py +++ b/tools/file_tools.py @@ -117,11 +117,11 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations: ssh_config = None if env_type == "ssh": ssh_config = { - "host": config.get("ssh_host", ""), - "user": config.get("ssh_user", ""), - "port": config.get("ssh_port", 22), - "key": config.get("ssh_key", ""), - "persistent": config.get("ssh_persistent", False), + "host": overrides.get("ssh_host") or config.get("ssh_host", ""), + "user": overrides.get("ssh_user") or config.get("ssh_user", ""), + "port": overrides.get("ssh_port") or config.get("ssh_port", 22), + "key": overrides.get("ssh_key") or config.get("ssh_key", ""), + "persistent": overrides.get("ssh_persistent", config.get("ssh_persistent", False)), } local_config = None diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index aa917ab1ab7..cd7225b7e47 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -413,6 +413,11 @@ def register_task_env_overrides(task_id: str, overrides: Dict[str, Any]): - modal_image: str -- Path to Dockerfile or Docker Hub image name - docker_image: str -- Docker image name - cwd: str -- Working directory inside the sandbox + - ssh_host: str -- SSH hostname (overrides TERMINAL_SSH_HOST) + - ssh_user: str -- SSH username (overrides TERMINAL_SSH_USER) + - ssh_port: int -- SSH port (overrides TERMINAL_SSH_PORT) + - ssh_key: str -- Path to SSH private key (overrides TERMINAL_SSH_KEY) + - ssh_persistent: bool -- Persistent shell mode (overrides TERMINAL_SSH_PERSISTENT) Args: task_id: The rollout's unique task identifier @@ -942,11 +947,11 @@ def terminal_tool( ssh_config = None if env_type == "ssh": ssh_config = { - "host": config.get("ssh_host", ""), 
- "user": config.get("ssh_user", ""), - "port": config.get("ssh_port", 22), - "key": config.get("ssh_key", ""), - "persistent": config.get("ssh_persistent", False), + "host": overrides.get("ssh_host") or config.get("ssh_host", ""), + "user": overrides.get("ssh_user") or config.get("ssh_user", ""), + "port": overrides.get("ssh_port") or config.get("ssh_port", 22), + "key": overrides.get("ssh_key") or config.get("ssh_key", ""), + "persistent": overrides.get("ssh_persistent", config.get("ssh_persistent", False)), } container_config = None