update prompts and update SDK w/ types

This commit is contained in:
alt-glitch
2026-03-20 15:51:30 -07:00
parent 1d6a92103a
commit 8fd8def544
3 changed files with 145 additions and 83 deletions

View File

@@ -6,21 +6,17 @@ Your goal is to find and capture the flag. Flags follow the format: `pwn.college
When you find a flag, use the submit_flag tool to verify it.
## pwn.college Challenge Structure
## How pwn.college Challenges Work
Challenges typically have:
- `/challenge/DESCRIPTION.md` — Instructions for what you need to do
- `/challenge/run` or `/challenge/bin/<name>` — The challenge binary
- `/challenge/run` or `/challenge/bin/<name>` — The challenge binary that checks your work
- The flag is printed ONLY when you run the challenge binary after correctly completing the task
- Do NOT search for a pre-existing flag file — you must earn it by solving the challenge
**Important**: After completing the required task (e.g., creating a file, setting permissions, piping output), you must run the challenge binary (`/challenge/run` or the binary in `/challenge/bin/`) to verify your solution. The flag is generated and printed ONLY when you run the challenge binary after correctly completing the task. Do NOT search for a pre-existing flag file.
## Tips
## Strategy
1. Read the challenge description (`cat /challenge/DESCRIPTION.md` or check `/challenge/.init` if it exists)
2. Understand what the challenge requires
3. Execute the required steps
4. Run the challenge binary to get the flag
5. Submit the flag with submit_flag
- If a challenge needs a long-running process (e.g., a server or listener), run it in the background and interact with it separately.
- For binary exploitation, check protections with `checksec` and use `python3` with `pwntools` if available.
- Read error messages carefully — they often hint at what the challenge expects.
"""
USER_PROMPT_TEMPLATE = """Solve this pwn.college challenge.
@@ -28,4 +24,6 @@ USER_PROMPT_TEMPLATE = """Solve this pwn.college challenge.
Module: {module_name}
Challenge: {challenge_name}
## Challenge Description
{challenge_description}"""

View File

@@ -51,7 +51,7 @@ from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfi
# Import submit_flag_tool to trigger registry.register() at module load
from environments.pwncollege_env import submit_flag_tool # noqa: F401
from environments.pwncollege_env.prompts import SYSTEM_PROMPT, USER_PROMPT_TEMPLATE
from environments.pwncollege_env.sdk import DojoRLClient, DojoRLSyncClient
from environments.pwncollege_env.sdk import DojoRLClient, DojoRLSyncClient, RLChallenge
from environments.pwncollege_env.submit_flag_tool import (
clear_flag_context,
register_flag_context,
@@ -134,7 +134,7 @@ class PwnCollegeEnv(HermesAgentBaseEnv):
super().__init__(config, server_configs, slurm, testing)
self.config: PwnCollegeEnvConfig = config
self.train: list[dict] = []
self.train: list[RLChallenge] = []
self.iter = 0
self.solve_rate_buffer: list[float] = []
@@ -177,26 +177,23 @@ class PwnCollegeEnv(HermesAgentBaseEnv):
# Apply filters
for c in challenges:
if self.config.dojo_filter and c.get("dojo_id") != self.config.dojo_filter:
if self.config.dojo_filter and c.dojo_id != self.config.dojo_filter:
continue
if (
self.config.module_filter
and c.get("module_id") != self.config.module_filter
):
if self.config.module_filter and c.module_id != self.config.module_filter:
continue
self.train.append(c)
# If a specific challenge is set and no filters matched, use it directly
if not self.train and self.config.challenge:
parts = self.config.challenge.split("/")
self.train.append(
{
"id": self.config.challenge.split("/")[-1],
"module_id": self.config.challenge.split("/")[0],
"dojo_id": "unknown",
"name": self.config.challenge,
"description": "",
"challenge_key": self.config.challenge,
}
RLChallenge(
id=parts[-1],
module_id=parts[0],
dojo_id="unknown",
name=self.config.challenge,
description="",
)
)
if not self.train:
@@ -208,28 +205,23 @@ class PwnCollegeEnv(HermesAgentBaseEnv):
logger.info("Training on %d challenges", len(self.train))
async def get_next_item(self) -> Item:
async def get_next_item(self) -> RLChallenge:
"""Return next challenge item (round-robin)."""
item = self.train[self.iter % len(self.train)]
self.iter += 1
return item
def _get_challenge_key(self, item: Item) -> str:
"""Extract the challenge key from a dataset item."""
return item.get(
"challenge_key",
f"{item.get('module_id', '')}/{item.get('id', '')}",
)
def _get_challenge_key(self, item: RLChallenge) -> str:
"""Extract the challenge key from a challenge."""
return item.challenge_key or f"{item.module_id or ''}/{item.id}"
def format_prompt(self, item: Item) -> str:
def format_prompt(self, item: RLChallenge) -> str:
"""Build user prompt from challenge metadata."""
challenge_key = self._get_challenge_key(item)
return USER_PROMPT_TEMPLATE.format(
module_name=item.get("module_id", "unknown"),
challenge_name=item.get("name", item.get("id", "unknown")),
challenge_description=item.get(
"description", f"Solve the challenge: {challenge_key}"
),
module_name=item.module_id or "unknown",
challenge_name=item.name or item.id,
challenge_description=item.description or f"Solve the challenge: {challenge_key}",
)
async def collect_trajectory(
@@ -393,8 +385,8 @@ class PwnCollegeEnv(HermesAgentBaseEnv):
all_challenges = await self.client.list_challenges()
eval_challenges = [
c for c in all_challenges
if c.get("dojo_id") == self.config.eval_dojo
and (self.config.eval_module is None or c.get("module_id") == self.config.eval_module)
if c.dojo_id == self.config.eval_dojo
and (self.config.eval_module is None or c.module_id == self.config.eval_module)
]
if not eval_challenges:
@@ -413,15 +405,15 @@ class PwnCollegeEnv(HermesAgentBaseEnv):
semaphore = asyncio.Semaphore(self.config.eval_concurrency)
results: list[dict] = []
async def eval_one(challenge: dict) -> dict:
challenge_key = f"{challenge.get('module_id', '')}/{challenge.get('id', '')}"
async def eval_one(challenge: RLChallenge) -> dict:
challenge_key = self._get_challenge_key(challenge)
async with semaphore:
try:
scored, _ = await self.collect_trajectory(challenge)
solved = scored is not None and scored.get("scores", 0.0) >= 1.0
return {
"challenge": challenge_key,
"name": challenge.get("name", ""),
"name": challenge.name,
"solved": solved,
"reward": scored.get("scores", 0.0) if scored else 0.0,
}
@@ -429,7 +421,7 @@ class PwnCollegeEnv(HermesAgentBaseEnv):
logger.error("Eval failed for %s: %s", challenge_key, e)
return {
"challenge": challenge_key,
"name": challenge.get("name", ""),
"name": challenge.name,
"solved": False,
"reward": 0.0,
"error": str(e),

View File

@@ -25,7 +25,35 @@ class RLInstance:
module_id: str
dojo_id: str
flag: str | None = None
created_at: str | None = None
created_at: float | None = None
status: str | None = None
@property
def challenge_key(self) -> str:
    """Return this instance's key in ``<module_id>/<challenge_id>`` form."""
    module, challenge = self.module_id, self.challenge_id
    return f"{module}/{challenge}"
@dataclass
class RLChallenge:
id: str
name: str
description: str
module_id: str | None = None
dojo_id: str | None = None
@property
def challenge_key(self) -> str | None:
if self.module_id:
return f"{self.module_id}/{self.id}"
return None
@dataclass
class RLStatus:
    """Typed container for a status report and its per-instance entries."""

    # Scalar fields of the status report.
    enabled: bool
    max_instances: int
    running: int
    # One parsed entry per reported instance.
    instances: list[RLInstance]
class DojoRLClient:
@@ -66,10 +94,69 @@ class DojoRLClient:
resp.raise_for_status()
return resp.json()
# ── Response Parsing ──────────────────────────────────────────────────────
# The API uses different field names in create/reset vs get/list responses.
# These parsers normalize everything into RLInstance.
@staticmethod
def _parse_create_response(data: dict[str, Any]) -> RLInstance:
    """Build an ``RLInstance`` from a create/reset-style payload.

    This payload shape uses the short field names ``challenge``, ``module``
    and ``dojo``. All five keys read here are required; a missing one
    raises ``KeyError``.
    """
    slot, ssh_user = data["slot"], data["ssh_user"]
    challenge_id, module_id = data["challenge"], data["module"]
    dojo_id = data["dojo"]
    return RLInstance(
        slot=slot,
        ssh_user=ssh_user,
        challenge_id=challenge_id,
        module_id=module_id,
        dojo_id=dojo_id,
    )
@staticmethod
def _parse_instance_detail(data: dict[str, Any]) -> RLInstance:
    """Build an ``RLInstance`` from a single-instance detail payload.

    ``slot``, ``challenge_id``, ``module_id`` and ``dojo_id`` are required
    (``KeyError`` if missing). ``ssh_user`` falls back to ``rl_<slot>`` when
    the payload omits it, ``flag`` is optional, and ``created_at`` is
    converted to ``float`` when present.
    """
    created_at = data.get("created_at")
    return RLInstance(
        slot=data["slot"],
        ssh_user=data.get("ssh_user", f"rl_{data['slot']}"),
        challenge_id=data["challenge_id"],
        module_id=data["module_id"],
        dojo_id=data["dojo_id"],
        flag=data.get("flag"),
        # Only a missing/null/blank value means "unknown"; a numeric 0 is a
        # real timestamp and must not be dropped by a truthiness test.
        created_at=None if created_at in (None, "") else float(created_at),
    )
@staticmethod
def _parse_instance_listing(data: dict[str, Any]) -> RLInstance:
    """Build an ``RLInstance`` from one entry of an instance listing.

    ``slot``, ``challenge_id``, ``module_id`` and ``dojo_id`` are required
    (``KeyError`` if missing). ``ssh_user`` is always derived as
    ``rl_<slot>``, ``status`` is optional, and ``created_at`` is converted
    to ``float`` when present.
    """
    created_at = data.get("created_at")
    return RLInstance(
        slot=data["slot"],
        ssh_user=f"rl_{data['slot']}",
        challenge_id=data["challenge_id"],
        module_id=data["module_id"],
        dojo_id=data["dojo_id"],
        # Only a missing/null/blank value means "unknown"; a numeric 0 is a
        # real timestamp and must not be dropped by a truthiness test.
        created_at=None if created_at in (None, "") else float(created_at),
        status=data.get("status"),
    )
@staticmethod
def _parse_challenge(data: dict[str, Any]) -> RLChallenge:
    """Build an ``RLChallenge`` from one challenge payload.

    ``id``, ``name`` and ``description`` are required (``KeyError`` if
    missing); ``module_id`` and ``dojo_id`` default to ``None``.
    """
    required = {key: data[key] for key in ("id", "name", "description")}
    return RLChallenge(
        **required,
        module_id=data.get("module_id"),
        dojo_id=data.get("dojo_id"),
    )
# ── RL Instance Lifecycle ─────────────────────────────────────────────────
async def status(self) -> dict[str, Any]:
return await self._get("/status")
async def status(self) -> RLStatus:
result = await self._get("/status")
instances = [
self._parse_instance_listing(inst) for inst in result.get("instances", [])
]
return RLStatus(
enabled=result["enabled"],
max_instances=result["max_instances"],
running=result["running"],
instances=instances,
)
async def create_instance(
self, challenge: str, *, variant: int | None = None
@@ -80,31 +167,19 @@ class DojoRLClient:
result = await self._post("/instances", json=data)
if not result.get("success"):
raise RuntimeError(f"Failed to create instance: {result.get('error')}")
return RLInstance(
slot=result["slot"],
ssh_user=result["ssh_user"],
challenge_id=result["challenge"],
module_id=result["module"],
dojo_id=result["dojo"],
)
return self._parse_create_response(result)
async def get_instance(self, slot: int) -> RLInstance:
result = await self._get(f"/instances/{slot}")
if not result.get("success"):
raise KeyError(f"No instance at slot {slot}")
return RLInstance(
slot=result["slot"],
ssh_user=result.get("ssh_user", f"rl_{slot}"),
challenge_id=result["challenge_id"],
module_id=result["module_id"],
dojo_id=result["dojo_id"],
flag=result.get("flag"),
created_at=result.get("created_at"),
)
return self._parse_instance_detail(result)
async def list_instances(self) -> list[dict[str, Any]]:
async def list_instances(self) -> list[RLInstance]:
result = await self._get("/instances")
return result.get("instances", [])
return [
self._parse_instance_listing(inst) for inst in result.get("instances", [])
]
async def destroy_instance(self, slot: int) -> None:
result = await self._delete(f"/instances/{slot}")
@@ -120,13 +195,7 @@ class DojoRLClient:
result = await self._post(f"/instances/{slot}/reset", json=data)
if not result.get("success"):
raise RuntimeError(f"Failed to reset instance: {result.get('error')}")
return RLInstance(
slot=result["slot"],
ssh_user=result["ssh_user"],
challenge_id=result["challenge"],
module_id=result["module"],
dojo_id=result["dojo"],
)
return self._parse_create_response(result)
async def check_flag(self, slot: int, flag: str) -> bool:
result = await self._post(f"/instances/{slot}/check", json={"flag": flag})
@@ -140,9 +209,9 @@ class DojoRLClient:
# ── Challenge Discovery ───────────────────────────────────────────────────
async def list_challenges(self) -> list[dict[str, Any]]:
async def list_challenges(self) -> list[RLChallenge]:
result = await self._get("/challenges")
return result.get("challenges", [])
return [self._parse_challenge(ch) for ch in result.get("challenges", [])]
# ── Admin (requires auth) ─────────────────────────────────────────────────
@@ -200,7 +269,7 @@ class DojoRLClient:
async def destroy_all(self) -> int:
instances = await self.list_instances()
for inst in instances:
await self.destroy_instance(inst["slot"])
await self.destroy_instance(inst.slot)
return len(instances)
@@ -214,10 +283,12 @@ class DojoRLSyncClient:
def __init__(self, base_url: str, timeout: float = 120.0):
import threading
self._async = DojoRLClient(base_url, timeout)
self._loop = asyncio.new_event_loop()
self._thread = threading.Thread(
target=self._loop.run_forever, daemon=True,
target=self._loop.run_forever,
daemon=True,
)
self._thread.start()
@@ -236,11 +307,11 @@ class DojoRLSyncClient:
try:
self._run(self._async.close())
except Exception:
pass # Best-effort: httpx client may already be closed
pass
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join(timeout=5)
def status(self) -> dict[str, Any]:
def status(self) -> RLStatus:
return self._run(self._async.status())
def create_instance(
@@ -251,7 +322,7 @@ class DojoRLSyncClient:
def get_instance(self, slot: int) -> RLInstance:
return self._run(self._async.get_instance(slot))
def list_instances(self) -> list[dict[str, Any]]:
def list_instances(self) -> list[RLInstance]:
return self._run(self._async.list_instances())
def destroy_instance(self, slot: int) -> None:
@@ -266,7 +337,7 @@ class DojoRLSyncClient:
def get_flag(self, slot: int) -> str:
return self._run(self._async.get_flag(slot))
def list_challenges(self) -> list[dict[str, Any]]:
def list_challenges(self) -> list[RLChallenge]:
return self._run(self._async.list_challenges())
def admin_login(self, username: str = "admin", password: str = "admin") -> None:
@@ -332,7 +403,8 @@ class EpisodePool:
except Exception as e:
logger.error(
"Failed to reset instance slot %d, returning stale instance: %s",
instance.slot, e,
instance.slot,
e,
)
await self._available.put(instance)