From 8fd8def544af9c681e81a398b4f8a4b56fa3ad4d Mon Sep 17 00:00:00 2001 From: alt-glitch Date: Fri, 20 Mar 2026 15:51:30 -0700 Subject: [PATCH] update prompts and update SDK w/ types --- environments/pwncollege_env/prompts.py | 22 ++- environments/pwncollege_env/pwncollege_env.py | 60 ++++--- environments/pwncollege_env/sdk.py | 146 +++++++++++++----- 3 files changed, 145 insertions(+), 83 deletions(-) diff --git a/environments/pwncollege_env/prompts.py b/environments/pwncollege_env/prompts.py index c4aced9af4b..112ccb38e3f 100644 --- a/environments/pwncollege_env/prompts.py +++ b/environments/pwncollege_env/prompts.py @@ -6,21 +6,17 @@ Your goal is to find and capture the flag. Flags follow the format: `pwn.college When you find a flag, use the submit_flag tool to verify it. -## pwn.college Challenge Structure +## How pwn.college Challenges Work -Challenges typically have: -- `/challenge/DESCRIPTION.md` — Instructions for what you need to do -- `/challenge/run` or `/challenge/bin/` — The challenge binary +- `/challenge/run` or `/challenge/bin/` — The challenge binary that checks your work +- The flag is printed ONLY when you run the challenge binary after correctly completing the task +- Do NOT search for a pre-existing flag file — you must earn it by solving the challenge -**Important**: After completing the required task (e.g., creating a file, setting permissions, piping output), you must run the challenge binary (`/challenge/run` or the binary in `/challenge/bin/`) to verify your solution. The flag is generated and printed ONLY when you run the challenge binary after correctly completing the task. Do NOT search for a pre-existing flag file. +## Tips -## Strategy - -1. Read the challenge description (`cat /challenge/DESCRIPTION.md` or check `/challenge/.init` if it exists) -2. Understand what the challenge requires -3. Execute the required steps -4. Run the challenge binary to get the flag -5. Submit the flag with submit_flag +- If a challenge needs a long-running process (e.g., a server or listener), run it in the background and interact with it separately. +- For binary exploitation, check protections with `checksec` and use `python3` with `pwntools` if available. +- Read error messages carefully — they often hint at what the challenge expects. """ USER_PROMPT_TEMPLATE = """Solve this pwn.college challenge. @@ -28,4 +24,6 @@ USER_PROMPT_TEMPLATE = """Solve this pwn.college challenge. Module: {module_name} Challenge: {challenge_name} +## Challenge Description + {challenge_description}""" diff --git a/environments/pwncollege_env/pwncollege_env.py b/environments/pwncollege_env/pwncollege_env.py index 41b59d10eb5..25f3eb9e66f 100644 --- a/environments/pwncollege_env/pwncollege_env.py +++ b/environments/pwncollege_env/pwncollege_env.py @@ -51,7 +51,7 @@ from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfi # Import submit_flag_tool to trigger registry.register() at module load from environments.pwncollege_env import submit_flag_tool # noqa: F401 from environments.pwncollege_env.prompts import SYSTEM_PROMPT, USER_PROMPT_TEMPLATE -from environments.pwncollege_env.sdk import DojoRLClient, DojoRLSyncClient +from environments.pwncollege_env.sdk import DojoRLClient, DojoRLSyncClient, RLChallenge from environments.pwncollege_env.submit_flag_tool import ( clear_flag_context, register_flag_context, @@ -134,7 +134,7 @@ class PwnCollegeEnv(HermesAgentBaseEnv): super().__init__(config, server_configs, slurm, testing) self.config: PwnCollegeEnvConfig = config - self.train: list[dict] = [] + self.train: list[RLChallenge] = [] self.iter = 0 self.solve_rate_buffer: list[float] = [] @@ -177,26 +177,23 @@ class PwnCollegeEnv(HermesAgentBaseEnv): # Apply filters for c in challenges: - if self.config.dojo_filter and c.get("dojo_id") != self.config.dojo_filter: + if self.config.dojo_filter and c.dojo_id != self.config.dojo_filter: continue - if ( - self.config.module_filter - and c.get("module_id") != self.config.module_filter - ): + if self.config.module_filter and c.module_id != self.config.module_filter: continue self.train.append(c) # If a specific challenge is set and no filters matched, use it directly if not self.train and self.config.challenge: + parts = self.config.challenge.split("/") self.train.append( - { - "id": self.config.challenge.split("/")[-1], - "module_id": self.config.challenge.split("/")[0], - "dojo_id": "unknown", - "name": self.config.challenge, - "description": "", - "challenge_key": self.config.challenge, - } + RLChallenge( + id=parts[-1], + module_id=parts[0], + dojo_id="unknown", + name=self.config.challenge, + description="", + ) ) if not self.train: @@ -208,28 +205,23 @@ class PwnCollegeEnv(HermesAgentBaseEnv): logger.info("Training on %d challenges", len(self.train)) - async def get_next_item(self) -> Item: + async def get_next_item(self) -> RLChallenge: """Return next challenge item (round-robin).""" item = self.train[self.iter % len(self.train)] self.iter += 1 return item - def _get_challenge_key(self, item: Item) -> str: - """Extract the challenge key from a dataset item.""" - return item.get( - "challenge_key", - f"{item.get('module_id', '')}/{item.get('id', '')}", - ) + def _get_challenge_key(self, item: RLChallenge) -> str: + """Extract the challenge key from a challenge.""" + return item.challenge_key or f"{item.module_id or ''}/{item.id}" - def format_prompt(self, item: Item) -> str: + def format_prompt(self, item: RLChallenge) -> str: """Build user prompt from challenge metadata.""" challenge_key = self._get_challenge_key(item) return USER_PROMPT_TEMPLATE.format( - module_name=item.get("module_id", "unknown"), - challenge_name=item.get("name", item.get("id", "unknown")), - challenge_description=item.get( - "description", f"Solve the challenge: {challenge_key}" - ), + module_name=item.module_id or "unknown", + challenge_name=item.name or item.id, + challenge_description=item.description or f"Solve the challenge: {challenge_key}", ) async def collect_trajectory( @@ -393,8 +385,8 @@ class PwnCollegeEnv(HermesAgentBaseEnv): all_challenges = await self.client.list_challenges() eval_challenges = [ c for c in all_challenges - if c.get("dojo_id") == self.config.eval_dojo - and (self.config.eval_module is None or c.get("module_id") == self.config.eval_module) + if c.dojo_id == self.config.eval_dojo + and (self.config.eval_module is None or c.module_id == self.config.eval_module) ] if not eval_challenges: @@ -413,15 +405,15 @@ class PwnCollegeEnv(HermesAgentBaseEnv): semaphore = asyncio.Semaphore(self.config.eval_concurrency) results: list[dict] = [] - async def eval_one(challenge: dict) -> dict: - challenge_key = f"{challenge.get('module_id', '')}/{challenge.get('id', '')}" + async def eval_one(challenge: RLChallenge) -> dict: + challenge_key = self._get_challenge_key(challenge) async with semaphore: try: scored, _ = await self.collect_trajectory(challenge) solved = scored is not None and scored.get("scores", 0.0) >= 1.0 return { "challenge": challenge_key, - "name": challenge.get("name", ""), + "name": challenge.name, "solved": solved, "reward": scored.get("scores", 0.0) if scored else 0.0, } @@ -429,7 +421,7 @@ class PwnCollegeEnv(HermesAgentBaseEnv): logger.error("Eval failed for %s: %s", challenge_key, e) return { "challenge": challenge_key, - "name": challenge.get("name", ""), + "name": challenge.name, "solved": False, "reward": 0.0, "error": str(e), diff --git a/environments/pwncollege_env/sdk.py b/environments/pwncollege_env/sdk.py index 9ddb1594515..d8ceac3a5a6 100644 --- a/environments/pwncollege_env/sdk.py +++ b/environments/pwncollege_env/sdk.py @@ -25,7 +25,35 @@ class RLInstance: module_id: str dojo_id: str flag: str | None = None - created_at: str | None = None + created_at: float | None = None + status: str | None = None + + @property + def challenge_key(self) -> str: + return f"{self.module_id}/{self.challenge_id}" + + +@dataclass +class RLChallenge: + id: str + name: str + description: str + module_id: str | None = None + dojo_id: str | None = None + + @property + def challenge_key(self) -> str | None: + if self.module_id: + return f"{self.module_id}/{self.id}" + return None + + +@dataclass +class RLStatus: + enabled: bool + max_instances: int + running: int + instances: list[RLInstance] class DojoRLClient: @@ -66,10 +94,69 @@ class DojoRLClient: resp.raise_for_status() return resp.json() + # ── Response Parsing ────────────────────────────────────────────────────── + # The API uses different field names in create/reset vs get/list responses. + # These parsers normalize everything into RLInstance. + + @staticmethod + def _parse_create_response(data: dict[str, Any]) -> RLInstance: + return RLInstance( + slot=data["slot"], + ssh_user=data["ssh_user"], + challenge_id=data["challenge"], + module_id=data["module"], + dojo_id=data["dojo"], + ) + + @staticmethod + def _parse_instance_detail(data: dict[str, Any]) -> RLInstance: + created_at = data.get("created_at") + return RLInstance( + slot=data["slot"], + ssh_user=data.get("ssh_user", f"rl_{data['slot']}"), + challenge_id=data["challenge_id"], + module_id=data["module_id"], + dojo_id=data["dojo_id"], + flag=data.get("flag"), + created_at=float(created_at) if created_at else None, + ) + + @staticmethod + def _parse_instance_listing(data: dict[str, Any]) -> RLInstance: + created_at = data.get("created_at") + return RLInstance( + slot=data["slot"], + ssh_user=f"rl_{data['slot']}", + challenge_id=data["challenge_id"], + module_id=data["module_id"], + dojo_id=data["dojo_id"], + created_at=float(created_at) if created_at else None, + status=data.get("status"), + ) + + @staticmethod + def _parse_challenge(data: dict[str, Any]) -> RLChallenge: + return RLChallenge( + id=data["id"], + name=data["name"], + description=data["description"], + module_id=data.get("module_id"), + dojo_id=data.get("dojo_id"), + ) + # ── RL Instance Lifecycle ───────────────────────────────────────────────── - async def status(self) -> dict[str, Any]: - return await self._get("/status") + async def status(self) -> RLStatus: + result = await self._get("/status") + instances = [ + self._parse_instance_listing(inst) for inst in result.get("instances", []) + ] + return RLStatus( + enabled=result["enabled"], + max_instances=result["max_instances"], + running=result["running"], + instances=instances, + ) async def create_instance( self, challenge: str, *, variant: int | None = None @@ -80,31 +167,19 @@ class DojoRLClient: result = await self._post("/instances", json=data) if not result.get("success"): raise RuntimeError(f"Failed to create instance: {result.get('error')}") - return RLInstance( - slot=result["slot"], - ssh_user=result["ssh_user"], - challenge_id=result["challenge"], - module_id=result["module"], - dojo_id=result["dojo"], - ) + return self._parse_create_response(result) async def get_instance(self, slot: int) -> RLInstance: result = await self._get(f"/instances/{slot}") if not result.get("success"): raise KeyError(f"No instance at slot {slot}") - return RLInstance( - slot=result["slot"], - ssh_user=result.get("ssh_user", f"rl_{slot}"), - challenge_id=result["challenge_id"], - module_id=result["module_id"], - dojo_id=result["dojo_id"], - flag=result.get("flag"), - created_at=result.get("created_at"), - ) + return self._parse_instance_detail(result) - async def list_instances(self) -> list[dict[str, Any]]: + async def list_instances(self) -> list[RLInstance]: result = await self._get("/instances") - return result.get("instances", []) + return [ + self._parse_instance_listing(inst) for inst in result.get("instances", []) + ] async def destroy_instance(self, slot: int) -> None: result = await self._delete(f"/instances/{slot}") @@ -120,13 +195,7 @@ class DojoRLClient: result = await self._post(f"/instances/{slot}/reset", json=data) if not result.get("success"): raise RuntimeError(f"Failed to reset instance: {result.get('error')}") - return RLInstance( - slot=result["slot"], - ssh_user=result["ssh_user"], - challenge_id=result["challenge"], - module_id=result["module"], - dojo_id=result["dojo"], - ) + return self._parse_create_response(result) async def check_flag(self, slot: int, flag: str) -> bool: result = await self._post(f"/instances/{slot}/check", json={"flag": flag}) @@ -140,9 +209,9 @@ class DojoRLClient: # ── Challenge Discovery ─────────────────────────────────────────────────── - async def list_challenges(self) -> list[dict[str, Any]]: + async def list_challenges(self) -> list[RLChallenge]: result = await self._get("/challenges") - return result.get("challenges", []) + return [self._parse_challenge(ch) for ch in result.get("challenges", [])] # ── Admin (requires auth) ───────────────────────────────────────────────── @@ -200,7 +269,7 @@ class DojoRLClient: async def destroy_all(self) -> int: instances = await self.list_instances() for inst in instances: - await self.destroy_instance(inst["slot"]) + await self.destroy_instance(inst.slot) return len(instances) @@ -214,10 +283,12 @@ class DojoRLSyncClient: def __init__(self, base_url: str, timeout: float = 120.0): import threading + self._async = DojoRLClient(base_url, timeout) self._loop = asyncio.new_event_loop() self._thread = threading.Thread( - target=self._loop.run_forever, daemon=True, + target=self._loop.run_forever, + daemon=True, ) self._thread.start() @@ -236,11 +307,11 @@ class DojoRLSyncClient: try: self._run(self._async.close()) except Exception: - pass # Best-effort: httpx client may already be closed + pass self._loop.call_soon_threadsafe(self._loop.stop) self._thread.join(timeout=5) - def status(self) -> dict[str, Any]: + def status(self) -> RLStatus: return self._run(self._async.status()) def create_instance( @@ -251,7 +322,7 @@ class DojoRLSyncClient: def get_instance(self, slot: int) -> RLInstance: return self._run(self._async.get_instance(slot)) - def list_instances(self) -> list[dict[str, Any]]: + def list_instances(self) -> list[RLInstance]: return self._run(self._async.list_instances()) def destroy_instance(self, slot: int) -> None: @@ -266,7 +337,7 @@ class DojoRLSyncClient: def get_flag(self, slot: int) -> str: return self._run(self._async.get_flag(slot)) - def list_challenges(self) -> list[dict[str, Any]]: + def list_challenges(self) -> list[RLChallenge]: return self._run(self._async.list_challenges()) def admin_login(self, username: str = "admin", password: str = "admin") -> None: @@ -332,7 +403,8 @@ class EpisodePool: except Exception as e: logger.error( "Failed to reset instance slot %d, returning stale instance: %s", - instance.slot, e, + instance.slot, + e, ) await self._available.put(instance)