mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-07 19:26:56 +08:00
Compare commits
3 Commits
atropos-in
...
endless-te
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
735723803f | ||
|
|
1472cc302d | ||
|
|
9c200abdb1 |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -66,7 +66,3 @@ test_*.py
|
||||
|
||||
# Nomad data
|
||||
/tmp/NomadClient*/
|
||||
|
||||
*.egg-info*
|
||||
wandb
|
||||
logs
|
||||
@@ -31,11 +31,8 @@ def _require_atroposlib() -> None:
|
||||
_require_atroposlib()
|
||||
|
||||
# Re-export the most commonly used pieces for convenience.
|
||||
# Agent imports are eager (always available).
|
||||
from .agent import AgentConfig, AgentResult, AgentStep, AtroposAgent, SequenceData # noqa: E402
|
||||
|
||||
# Env imports are lazy to avoid pulling in deleted atropos.tools dependencies.
|
||||
# Use: from atropos.envs import AgentEnv, AgentEnvConfig (if needed)
|
||||
from .envs import AgentEnv, AgentEnvConfig # noqa: E402
|
||||
|
||||
__all__ = [
|
||||
"AtroposAgent",
|
||||
@@ -43,5 +40,7 @@ __all__ = [
|
||||
"AgentResult",
|
||||
"AgentStep",
|
||||
"SequenceData",
|
||||
"AgentEnv",
|
||||
"AgentEnvConfig",
|
||||
]
|
||||
|
||||
|
||||
@@ -1,18 +1,10 @@
|
||||
"""
|
||||
Environment implementations for atropos-agent.
|
||||
|
||||
NOTE: AgentEnv is the OLD environment system, replaced by
|
||||
environments/hermes_base_env.py (HermesAgentBaseEnv).
|
||||
Import is lazy to avoid pulling in deleted dependencies.
|
||||
"""
|
||||
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
def __getattr__(name):
|
||||
"""Lazy import to avoid breaking when old dependencies are removed."""
|
||||
if name in ("AgentEnv", "AgentEnvConfig"):
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
return {"AgentEnv": AgentEnv, "AgentEnvConfig": AgentEnvConfig}[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
# NOTE: Additional example envs exist as modules (e.g. `test_env`, `swe_smith_oracle_env`),
|
||||
# but are intentionally not imported here to avoid pulling heavy optional deps at import time.
|
||||
|
||||
__all__ = ["AgentEnv", "AgentEnvConfig"]
|
||||
|
||||
873
atropos/envs/endless_terminals_env.py
Normal file
873
atropos/envs/endless_terminals_env.py
Normal file
@@ -0,0 +1,873 @@
|
||||
"""
|
||||
Endless Terminals Environment for Hermes-Agent + Atropos RL.
|
||||
|
||||
Runs terminal tasks from the Endless Terminals dataset.
|
||||
Supports three modes:
|
||||
1. Local directory: tasks from a local folder of task_* dirs (default)
|
||||
2. HuggingFace dataset: tasks from a HF dataset
|
||||
3. Procedural: generate tasks on-the-fly via LLM (requires vLLM)
|
||||
|
||||
Each task provides a Dockerfile that defines the initial environment.
|
||||
The agent solves the task using terminal commands inside a Docker container.
|
||||
Scoring is done by running pytest on `test_final_state.py` in the container.
|
||||
|
||||
Run (standalone process mode):
|
||||
python -m atropos.envs.endless_terminals_env process \
|
||||
--env.use_wandb false \
|
||||
--env.total_steps 100 \
|
||||
--env.group_size 4
|
||||
|
||||
Run (Tinker serve mode):
|
||||
# Terminal 1: run-api
|
||||
# Terminal 2: python launch_training.py --config configs/endless_terminals.yaml
|
||||
# Terminal 3:
|
||||
TINKER_CONFIG=configs/endless_terminals.yaml \
|
||||
ENDLESS_TERMINALS_DIR=/path/to/endless-terminals \
|
||||
python -m atropos.envs.endless_terminals_env serve
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig
|
||||
from ..backends.docker_direct_backend import (
|
||||
DockerDirectBackend,
|
||||
build_docker_image,
|
||||
docker_image_exists,
|
||||
)
|
||||
from ..tools import ToolCall
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tinker integration
|
||||
# ---------------------------------------------------------------------------
|
||||
# When TINKER_CONFIG is set, we load model/training params from the Tinker YAML.
|
||||
# Custom env fields (ENDLESS_TERMINALS_DIR, etc.) are always read from env vars.
|
||||
TINKER_CONFIG = os.getenv("TINKER_CONFIG", "")
|
||||
|
||||
|
||||
def _load_tinker_config():
|
||||
"""Load TinkerAtroposConfig if available, else return None."""
|
||||
if not TINKER_CONFIG:
|
||||
return None
|
||||
config_path = Path(TINKER_CONFIG)
|
||||
if not config_path.exists():
|
||||
print(f"[EndlessTerminalsEnv] TINKER_CONFIG={TINKER_CONFIG} not found, ignoring", flush=True)
|
||||
return None
|
||||
try:
|
||||
from tinker_atropos.config import TinkerAtroposConfig
|
||||
config = TinkerAtroposConfig.from_yaml(config_path)
|
||||
print(f"[EndlessTerminalsEnv] Loaded Tinker config from {config_path}", flush=True)
|
||||
return config
|
||||
except ImportError:
|
||||
print("[EndlessTerminalsEnv] tinker_atropos not installed, ignoring TINKER_CONFIG", flush=True)
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"[EndlessTerminalsEnv] Error loading Tinker config: {e}", flush=True)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class EndlessTerminalsEnvConfig(AgentEnvConfig):
    """Configuration for Endless Terminals environment.

    Three task sources are supported; the env checks them in order:
    local directory (use_local_dir), HuggingFace dataset (use_dataset),
    then procedural generation as the fallback.
    """

    # ---- Local directory mode (primary) ----
    use_local_dir: bool = Field(
        default=True,
        description="Load tasks from a local directory of task_* folders.",
    )
    local_tasks_dir: str = Field(
        default="",
        description="Path to directory containing task_* folders. Required if use_local_dir=True.",
    )
    prebuild_images: bool = Field(
        default=False,
        description="Pre-build ALL Docker images during setup (slow but avoids build-during-training).",
    )
    max_concurrent_builds: int = Field(
        default=4,
        description="Max parallel Docker image builds during pre-build.",
    )

    # ---- HuggingFace dataset mode ----
    use_dataset: bool = Field(
        default=False,
        description="Load tasks from HuggingFace dataset.",
    )
    dataset_name: str = Field(
        default="obiwan96/endless-terminals-train",
        description="HuggingFace dataset name (if use_dataset=True)",
    )
    # Dataset split to load and where the HF cache lives on disk.
    dataset_split: str = Field(default="train")
    dataset_cache_dir: str = Field(default="~/.cache/huggingface/datasets")
    tasks_base_dir: str = Field(
        default="",
        description="Base directory containing task_* folders (for dataset mode path resolution).",
    )

    # ---- Procedural generation mode ----
    # Sampling parameters for the task-generator model (procedural mode only).
    task_gen_model: str = Field(default="Qwen/Qwen3-32B")
    task_gen_temperature: float = Field(default=1.0)
    task_gen_max_tokens: int = Field(default=2048)

    # ---- Container / scoring ----
    container_build_timeout_s: float = Field(default=600.0, description="Docker build timeout")
    test_timeout_s: int = Field(default=120, description="Test execution timeout (seconds)")
    # presumably retains artifacts of failed tasks for debugging — not
    # referenced in this chunk; confirm where it is consumed.
    keep_failed_tasks: bool = Field(default=False)

    # ---- Agent defaults ----
    # Per-trajectory agent budget and sampling temperature.
    agent_max_steps: int = Field(default=32)
    agent_temperature: float = Field(default=0.7)

    # ---- Docker image prefix ----
    docker_image_prefix: str = Field(
        default="endless-terminals",
        description="Docker image name prefix for built task images.",
    )

    # ---- Server defaults ----
    # Inference server endpoint/model used in standalone mode.
    server_base_url: str = Field(default="http://127.0.0.1:8080")
    server_model: str = Field(default="hermes-4-36b")
    tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Env
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class EndlessTerminalsEnv(AgentEnv[EndlessTerminalsEnvConfig]):
|
||||
"""
|
||||
Endless Terminals environment.
|
||||
|
||||
Each task:
|
||||
1. Has a Dockerfile defining the initial container state
|
||||
2. Has an instruction.md describing what the agent should do
|
||||
3. Has tests/test_final_state.py to verify completion
|
||||
|
||||
Flow per trajectory:
|
||||
1. get_next_item() → picks a task
|
||||
2. setup_trajectory_workspace() → builds Docker image, registers with backend
|
||||
3. Agent solves task via terminal commands (docker exec in the container)
|
||||
4. verify_and_score_trajectory() → runs pytest in container, returns binary reward
|
||||
"""
|
||||
|
||||
name = "endless_terminals_env"
|
||||
env_config_cls = EndlessTerminalsEnvConfig
|
||||
|
||||
def __init__(
    self,
    config: EndlessTerminalsEnvConfig,
    server_configs: List[APIServerConfig],
    slurm: bool = False,
    testing: bool = False,
):
    """Initialize in-memory state; heavy setup happens in setup_agent_env()."""
    super().__init__(config, server_configs, slurm, testing)
    # Monotonic counter used to build unique task_ids in get_next_item().
    self._iteration = 0

    # Local dir mode: task dicts plus a shuffled index cursor for epoch cycling.
    self._local_tasks: List[Dict[str, Any]] = []
    self._local_task_indices: List[int] = []
    self._local_current_index = 0

    # Eval split (held-out tasks, populated by _setup_local_dir).
    self._eval_tasks: List[Dict[str, Any]] = []

    # Training metrics buffers, drained by wandb_log().
    self._train_scores_buffer: List[float] = []
    # List of (key, value) pairs merged into wandb metrics on the next log.
    self._eval_metrics: List[tuple] = []

    # HF dataset mode: dataset handle plus shuffled index cursor.
    self._dataset = None
    self._dataset_indices: List[int] = []
    self._dataset_current_index = 0

    # Docker image cache: task_name -> image_tag. The lock serializes
    # builds so concurrent trajectories never build the same image twice.
    self._image_cache: Dict[str, str] = {}
    self._build_lock = asyncio.Lock()
|
||||
|
||||
# ---- Config init (CLI) ----
|
||||
|
||||
@classmethod
def config_init(cls) -> Tuple[EndlessTerminalsEnvConfig, List[APIServerConfig]]:
    """
    Initialize config.

    Two modes:
    1. Tinker mode: TINKER_CONFIG env var points to a Tinker YAML.
       Model, training params, and server config come from the YAML.
    2. Standalone mode: Everything from env vars (ATROPOS_SERVER_*, etc.)

    In both modes, Endless Terminals-specific fields (ENDLESS_TERMINALS_DIR,
    PREBUILD_IMAGES, etc.) are always read from env vars.

    Returns:
        (env_config, server_configs) tuple consumed by the Atropos runner.
    """
    tinker_cfg = _load_tinker_config()

    # ── Endless Terminals-specific fields (always from env vars) ──
    local_tasks_dir = os.getenv("ENDLESS_TERMINALS_DIR", "")
    # Local-dir mode is implied by the presence of ENDLESS_TERMINALS_DIR.
    use_local_dir = bool(local_tasks_dir)

    if tinker_cfg is not None:
        # ── Tinker mode ─────────────────────────────────────────
        print("[EndlessTerminalsEnv] Using Tinker config", flush=True)

        env_config = EndlessTerminalsEnvConfig(
            # Standard Atropos fields from Tinker YAML
            tokenizer_name=tinker_cfg.base_model,
            group_size=tinker_cfg.group_size,
            use_wandb=tinker_cfg.use_wandb,
            rollout_server_url=tinker_cfg.atropos_api_url,
            total_steps=tinker_cfg.num_steps,
            batch_size=tinker_cfg.batch_size,
            steps_per_eval=tinker_cfg.steps_per_eval,
            max_token_length=tinker_cfg.max_token_env_length,
            max_num_workers=tinker_cfg.max_num_workers,
            max_batches_offpolicy=tinker_cfg.max_batches_offpolicy,
            ensure_scores_are_not_same=tinker_cfg.ensure_scores_are_not_same,
            wandb_name=f"{tinker_cfg.wandb_run_name}-env",
            include_messages=True,

            # Tooling: terminal only
            enabled_toolsets=["terminal"],
            disabled_toolsets=[],

            # Agent config (env-var overridable even in Tinker mode)
            agent_max_steps=int(os.getenv("AGENT_MAX_STEPS", "32")),
            agent_temperature=float(os.getenv("AGENT_TEMPERATURE", "0.7")),

            # Docker-direct backend (no Nomad needed)
            tool_pool_mode="docker_direct",
            sandbox_image="ubuntu:22.04",
            purge_job_on_start=False,
            purge_job_on_shutdown=False,

            # Endless Terminals fields
            use_local_dir=use_local_dir,
            local_tasks_dir=local_tasks_dir,
            prebuild_images=os.getenv("PREBUILD_IMAGES", "false").lower() == "true",
            use_dataset=os.getenv("USE_DATASET", "false").lower() == "true",
            dataset_name=os.getenv("ENDLESS_DATASET", "obiwan96/endless-terminals-train"),
            container_build_timeout_s=float(os.getenv("CONTAINER_BUILD_TIMEOUT", "600")),
            test_timeout_s=int(os.getenv("TEST_TIMEOUT", "120")),
        )

        server_configs = [
            APIServerConfig(
                model_name=tinker_cfg.base_model,
                base_url=tinker_cfg.inference_api_url + "/v1",
                api_key="x",
                server_type="sglang",
                num_requests_for_eval=tinker_cfg.num_requests_for_eval,
                timeout=600,  # Longer timeout for multi-step agent trajectories
            ),
        ]
        return env_config, server_configs

    else:
        # ── Standalone mode (env vars) ──────────────────────────
        # Accept several conventional env var names for the endpoint/creds.
        base_url = (
            os.getenv("ATROPOS_SERVER_BASE_URL")
            or os.getenv("OPENAI_BASE_URL")
            or os.getenv("LLM_BASE_URL")
            or "http://127.0.0.1:8080"
        )
        model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
        api_key = (
            os.getenv("ATROPOS_SERVER_API_KEY")
            or os.getenv("NOUS_API_KEY")
            or os.getenv("OPENAI_API_KEY")
            or "local"
        )

        env_config = EndlessTerminalsEnvConfig(
            tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
            group_size=int(os.getenv("ATROPOS_GROUP_SIZE", "4")),
            use_wandb=os.getenv("USE_WANDB", "false").lower() == "true",
            include_messages=True,
            total_steps=int(os.getenv("ATROPOS_TOTAL_STEPS", "1000")),
            batch_size=int(os.getenv("ATROPOS_BATCH_SIZE", "32")),
            server_base_url=base_url,
            server_model=model,

            # Tooling
            enabled_toolsets=["terminal"],
            disabled_toolsets=[],

            # Agent
            agent_max_steps=int(os.getenv("AGENT_MAX_STEPS", "32")),
            agent_temperature=float(os.getenv("AGENT_TEMPERATURE", "0.7")),

            # Docker-direct backend
            tool_pool_mode="docker_direct",
            sandbox_image="ubuntu:22.04",
            purge_job_on_start=False,
            purge_job_on_shutdown=False,

            # Endless Terminals fields
            use_local_dir=use_local_dir,
            local_tasks_dir=local_tasks_dir,
            prebuild_images=os.getenv("PREBUILD_IMAGES", "false").lower() == "true",
            use_dataset=os.getenv("USE_DATASET", "false").lower() == "true",
            dataset_name=os.getenv("ENDLESS_DATASET", "obiwan96/endless-terminals-train"),
            task_gen_model=os.getenv("TASK_GEN_MODEL", "Qwen/Qwen3-32B"),
            container_build_timeout_s=float(os.getenv("CONTAINER_BUILD_TIMEOUT", "600")),
            test_timeout_s=int(os.getenv("TEST_TIMEOUT", "120")),
        )

        server_configs = [
            APIServerConfig(
                model_name=model,
                base_url=f"{base_url.rstrip('/')}/v1",
                api_key=api_key,
                num_max_requests_at_once=int(os.getenv("MAX_CONCURRENT_REQUESTS", "4")),
                num_requests_for_eval=int(os.getenv("MAX_EVAL_REQUESTS", "4")),
                timeout=300,
            )
        ]
        return env_config, server_configs
|
||||
|
||||
# ---- Setup ----
|
||||
|
||||
async def setup_agent_env(self) -> None:
    """Prepare the task source: local dir scan, HF dataset load, or procedural."""
    cfg = self.config
    if cfg.use_local_dir:
        await self._setup_local_dir()
        return
    if cfg.use_dataset:
        await self._setup_hf_dataset()
        return
    # Neither source configured: tasks will be generated on the fly.
    print("[EndlessTerminalsEnv] Using procedural task generation", flush=True)
|
||||
|
||||
async def _setup_local_dir(self) -> None:
    """Scan local_tasks_dir for task_* folders and build the train/eval split.

    A folder counts as a valid task only if it contains all three of
    environment/Dockerfile, instruction.md, and tests/test_final_state.py.

    Raises:
        RuntimeError: if the directory is missing or contains no valid tasks.
    """
    tasks_dir = Path(self.config.local_tasks_dir).expanduser().resolve()
    if not tasks_dir.is_dir():
        raise RuntimeError(f"local_tasks_dir does not exist: {tasks_dir}")

    print(f"[EndlessTerminalsEnv] Scanning {tasks_dir} for tasks...", flush=True)

    tasks = []
    for entry in sorted(tasks_dir.iterdir()):
        if not entry.is_dir() or not entry.name.startswith("task_"):
            continue

        # Validate required files; incomplete task folders are skipped silently.
        dockerfile = entry / "environment" / "Dockerfile"
        instruction = entry / "instruction.md"
        test_final = entry / "tests" / "test_final_state.py"

        if not dockerfile.exists():
            continue
        if not instruction.exists():
            continue
        if not test_final.exists():
            continue

        # Read task metadata: the agent-facing prompt comes from instruction.md.
        task_json_path = entry / "environment" / "task.json"
        description = instruction.read_text(encoding="utf-8").strip()

        truth = ""
        if task_json_path.exists():
            try:
                task_json = json.loads(task_json_path.read_text(encoding="utf-8"))
                # task.json may have a richer description; prefer instruction.md
                truth = task_json.get("truth", "")
            except Exception:
                # Malformed task.json is non-fatal; "truth" just stays empty.
                pass

        tasks.append({
            "task_name": entry.name,
            "task_dir": str(entry),
            "dockerfile": str(dockerfile),
            "description": description,
            "truth": truth,
            "test_final": str(test_final),
        })

    if not tasks:
        raise RuntimeError(f"No valid task_* directories found in {tasks_dir}")

    # Split into train and eval (hold out ~5% for eval, min 10, max 50).
    # NOTE(review): the shuffle is unseeded, so the train/eval split differs
    # between runs — confirm that is intended for resumed training.
    random.shuffle(tasks)
    eval_count = max(10, min(50, len(tasks) // 20))
    eval_count = min(eval_count, len(tasks) // 2)  # Never more than half

    self._eval_tasks = tasks[:eval_count]
    self._local_tasks = tasks[eval_count:]
    self._local_task_indices = list(range(len(self._local_tasks)))
    random.shuffle(self._local_task_indices)
    self._local_current_index = 0

    print(
        f"[EndlessTerminalsEnv] Found {len(tasks)} valid tasks "
        f"({len(self._local_tasks)} train, {len(self._eval_tasks)} eval)",
        flush=True,
    )

    # Optionally pre-build all Docker images up front.
    if self.config.prebuild_images:
        await self._prebuild_images()
|
||||
|
||||
async def _prebuild_images(self) -> None:
    """Pre-build Docker images for all training tasks.

    Builds run concurrently, bounded by ``max_concurrent_builds``. Images
    already present in Docker are recorded as cached and skipped. Failures
    are counted but not fatal here: _ensure_image() will retry (and raise)
    if a failed task is actually sampled during training.
    """
    # Fixed: the original used an f-string with no placeholders (ruff F541);
    # the printed text is unchanged.
    print("[EndlessTerminalsEnv] Pre-building Docker images...", flush=True)
    sem = asyncio.Semaphore(self.config.max_concurrent_builds)
    built = 0
    skipped = 0
    failed = 0

    async def _build_one(task: Dict[str, Any]) -> None:
        # Counter updates are safe without a lock: all mutations happen on
        # the single event-loop thread.
        nonlocal built, skipped, failed
        image_tag = self._image_tag_for_task(task["task_name"])

        # Already built by a previous run: just record it.
        if docker_image_exists(image_tag):
            self._image_cache[task["task_name"]] = image_tag
            skipped += 1
            return

        async with sem:
            ok = await build_docker_image(
                task["dockerfile"], image_tag,
                timeout_s=self.config.container_build_timeout_s,
            )
            if ok:
                self._image_cache[task["task_name"]] = image_tag
                built += 1
            else:
                failed += 1

    await asyncio.gather(*[_build_one(t) for t in self._local_tasks])
    print(
        f"[EndlessTerminalsEnv] Pre-build: {built} built, {skipped} cached, {failed} failed",
        flush=True,
    )
|
||||
|
||||
async def _setup_hf_dataset(self) -> None:
    """Load the HuggingFace dataset and shuffle its iteration order.

    The blocking ``load_dataset`` call is pushed onto the default thread
    pool so the event loop stays responsive during the download/load.

    Raises:
        Exception: re-raised on any load failure — the env cannot run
        without a task source.
    """
    print(f"[EndlessTerminalsEnv] Loading dataset: {self.config.dataset_name}", flush=True)
    try:
        from datasets import load_dataset

        # Fixed: use get_running_loop() — get_event_loop() inside a
        # coroutine is deprecated (and an error in newer Python versions).
        loop = asyncio.get_running_loop()
        self._dataset = await loop.run_in_executor(
            None,
            lambda: load_dataset(
                self.config.dataset_name,
                split=self.config.dataset_split,
                cache_dir=os.path.expanduser(self.config.dataset_cache_dir),
            ),
        )
        self._dataset_indices = list(range(len(self._dataset)))
        random.shuffle(self._dataset_indices)
        self._dataset_current_index = 0
        print(f"[EndlessTerminalsEnv] Loaded {len(self._dataset)} tasks from dataset", flush=True)
    except Exception as e:
        print(f"[EndlessTerminalsEnv] ERROR loading dataset: {e}", flush=True)
        raise
|
||||
|
||||
# ---- Image helpers ----
|
||||
|
||||
def _image_tag_for_task(self, task_name: str) -> str:
    """Derive the Docker image tag for a task: '<prefix>:<task_name>'."""
    prefix = self.config.docker_image_prefix
    return "{}:{}".format(prefix, task_name)
|
||||
|
||||
async def _ensure_image(self, task: Dict[str, Any]) -> str:
    """Ensure the Docker image for a task is built. Returns image tag.

    Coroutine-safe: builds are serialized by ``_build_lock`` with a
    double-check of the cache, so concurrent trajectories hitting the
    same task never build its image twice.

    Args:
        task: Dict with at least "task_name" and "dockerfile" keys.

    Raises:
        RuntimeError: if the Docker build fails.
    """
    task_name = task["task_name"]
    image_tag = self._image_tag_for_task(task_name)

    # Fast path: already cached (no lock needed on the event-loop thread).
    if task_name in self._image_cache:
        return self._image_cache[task_name]

    async with self._build_lock:
        # Double-check after acquiring lock: another coroutine may have
        # built/cached the image while we waited.
        if task_name in self._image_cache:
            return self._image_cache[task_name]

        # Check if image exists in Docker (e.g. built by a previous run).
        if docker_image_exists(image_tag):
            self._image_cache[task_name] = image_tag
            return image_tag

        # Build it
        print(f"[EndlessTerminalsEnv] Building image {image_tag}...", flush=True)
        ok = await build_docker_image(
            task["dockerfile"], image_tag,
            timeout_s=self.config.container_build_timeout_s,
        )
        if not ok:
            raise RuntimeError(f"Failed to build Docker image for {task_name}")

        self._image_cache[task_name] = image_tag
        return image_tag
|
||||
|
||||
# ---- Item generation ----
|
||||
|
||||
async def get_next_item(self) -> Item:
    """Return the next task item, dispatching on the configured source."""
    self._iteration += 1

    # Source precedence: local directory, then HF dataset, then fallback.
    if self.config.use_local_dir and self._local_tasks:
        return self._get_next_local_item()
    if self.config.use_dataset and self._dataset is not None:
        return self._get_next_dataset_item()
    return self._get_fallback_item()
|
||||
|
||||
def _get_next_local_item(self) -> Item:
    """Return the next local task, reshuffling when an epoch completes."""
    cursor = self._local_current_index
    task = self._local_tasks[self._local_task_indices[cursor]]

    cursor += 1
    if cursor >= len(self._local_task_indices):
        # Every task has been served once: reshuffle and restart the cycle.
        random.shuffle(self._local_task_indices)
        cursor = 0
        print("[EndlessTerminalsEnv] Reshuffled local tasks (epoch complete)", flush=True)
    self._local_current_index = cursor

    return {
        "task_id": f"local_{self._iteration:06d}_{task['task_name']}",
        "task_name": task["task_name"],
        "description": task["description"],
        "truth": task.get("truth", ""),
        "task_dir": task["task_dir"],
        "dockerfile": task["dockerfile"],
        "test_final": task["test_final"],
        "from_local_dir": True,
    }
|
||||
|
||||
def _get_next_dataset_item(self) -> Item:
    """Pick the next task from HuggingFace dataset.

    Resolves the on-disk task directory from the dataset row, optionally
    re-rooting it under ``tasks_base_dir`` so paths work on this machine.
    """
    idx = self._dataset_indices[self._dataset_current_index]
    task = self._dataset[idx]

    self._dataset_current_index += 1
    if self._dataset_current_index >= len(self._dataset_indices):
        # Epoch complete: reshuffle for the next pass over the dataset.
        random.shuffle(self._dataset_indices)
        self._dataset_current_index = 0
        print("[EndlessTerminalsEnv] Reshuffled dataset (epoch complete)", flush=True)

    # Resolve task directory.
    # NOTE(review): assumes rows carry extra_info.task_dir or
    # reward_spec.ground_truth — confirm against the dataset schema.
    task_dir = task.get("extra_info", {}).get("task_dir") or task.get("reward_spec", {}).get("ground_truth", "")
    if self.config.tasks_base_dir:
        # Keep only the folder name and re-root it under tasks_base_dir.
        task_name = Path(task_dir).name
        task_dir = str(Path(self.config.tasks_base_dir) / task_name)

    task_dir_path = Path(task_dir)
    return {
        "task_id": f"dataset_{self._iteration:06d}_{task_dir_path.name}",
        "task_name": task_dir_path.name,
        "description": task.get("description", ""),
        "task_dir": task_dir,
        "dockerfile": str(task_dir_path / "environment" / "Dockerfile"),
        "test_final": str(task_dir_path / "tests" / "test_final_state.py"),
        "from_dataset": True,
    }
|
||||
|
||||
def _get_fallback_item(self) -> Item:
    """Built-in trivial task, used when no task source is configured."""
    description = (
        "Create a file named 'hello.txt' in /home/user/ containing "
        "the text 'Hello, World!' on a single line."
    )
    # Empty dockerfile/test paths: workspace setup falls back to a plain
    # image and scoring reports "No test file".
    return {
        "task_id": f"fallback_{self._iteration:06d}",
        "task_name": "fallback",
        "description": description,
        "task_dir": "",
        "dockerfile": "",
        "test_final": "",
    }
|
||||
|
||||
# ---- AgentEnv hooks ----
|
||||
|
||||
def build_task(self, item: Item) -> str:
    """Return the natural-language task prompt shown to the agent."""
    description = item.get("description", "")
    return str(description)
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig:
    """Build the per-trajectory agent config from env-level settings.

    ``item`` is accepted per the AgentEnv hook signature but unused:
    every task shares the same step budget and sampling parameters.
    """
    return AgentConfig(
        max_steps=self.config.agent_max_steps,
        temperature=self.config.agent_temperature,
        max_tokens=self.config.agent_max_tokens,
        tool_delay_s=self.config.agent_tool_delay_s,
    )
|
||||
|
||||
async def setup_trajectory_workspace(
    self,
    item: Item,
    *,
    trajectory_id: str,
    exec_tool,
) -> Dict[str, Any]:
    """
    Build the Docker image for this task and register it with the backend.

    The DockerDirectBackend will start a container from this image when the
    agent makes its first tool call (lazy acquisition via ToolExecutor).

    Args:
        item: Task item from get_next_item(); reads "task_name"/"dockerfile".
        trajectory_id: Key under which the image is registered.
        exec_tool: Tool-execution callable (unused here; the container is
            started lazily by the backend).

    Returns:
        Workspace metadata dict containing at least an "image" key.
    """
    task_name = item.get("task_name", "unknown")
    dockerfile = item.get("dockerfile", "")

    # No Dockerfile (e.g. the fallback item): use a plain Ubuntu image.
    if not dockerfile or not Path(dockerfile).exists():
        print(f"[EndlessTerminalsEnv] WARNING: No Dockerfile for {task_name}", flush=True)
        return {"image": "ubuntu:22.04"}

    # Build/get Docker image (cached + locked inside _ensure_image).
    image_tag = await self._ensure_image({
        "task_name": task_name,
        "dockerfile": dockerfile,
    })

    # Register image with the DockerDirect backend
    if isinstance(self._backend, DockerDirectBackend):
        self._backend.register_image(trajectory_id, image_tag)

    return {"image": image_tag, "task_name": task_name}
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
    """Unused AgentEnv hook — real scoring lives in verify_and_score_trajectory()."""
    del item, final_response  # intentionally ignored
    return 0.0
|
||||
|
||||
async def verify_and_score_trajectory(
    self,
    item: Item,
    final_response: str,
    *,
    trajectory_id: str,
    exec_tool,
    agent_result=None,
    workspace_meta=None,
) -> tuple[float, Dict[str, Any]]:
    """
    Run test_final_state.py inside the container and return binary reward.

    The test file is read on the host, base64-encoded, written into the
    container's /tmp via the terminal tool, and executed with pytest.
    Success is judged by the shell exit code echoed after pytest.

    Returns:
        (score, metadata): score is 1.0 if pytest passed, else 0.0;
        metadata carries (truncated) test output for debugging.
    """
    task_id = item.get("task_id", "unknown")
    test_final = item.get("test_final", "")

    if not test_final or not Path(test_final).exists():
        print(f"[EndlessTerminalsEnv] No test file for {task_id}", flush=True)
        return 0.0, {"error": "No test file"}

    print(f"[EndlessTerminalsEnv] Scoring {task_id}...", flush=True)

    try:
        # Read the test file and base64-encode it for safe transfer
        test_content = Path(test_final).read_text(encoding="utf-8")
        encoded = base64.b64encode(test_content.encode("utf-8")).decode("ascii")

        # Write test file into the container and run pytest.
        # We write to /tmp to avoid interfering with the agent's workspace.
        # base64 output never contains single quotes, so the single-quoted
        # printf argument is shell-safe.
        verify_cmd = (
            f"printf '%s' '{encoded}' | base64 -d > /tmp/_test_final_state.py && "
            f"cd /home/user && "
            f"python3 -m pytest /tmp/_test_final_state.py -v --tb=short 2>&1; "
            f"echo \"EXIT_CODE=$?\""
        )

        result = await exec_tool(ToolCall(
            name="terminal",
            arguments={"command": verify_cmd},
        ))

        output = result.output if hasattr(result, "output") else str(result)

        # Check if pytest passed:
        # look for EXIT_CODE=0 at the end (most reliable signal).
        success = "EXIT_CODE=0" in output

        score = 1.0 if success else 0.0

        metadata = {
            "task_id": task_id,
            "success": success,
            # Keep only the tail of long outputs to bound metadata size.
            "test_output": output[-2000:] if len(output) > 2000 else output,
            "total_tool_calls": agent_result.total_tool_calls if agent_result else 0,
        }

        self._train_scores_buffer.append(score)
        print(f"[EndlessTerminalsEnv] {task_id} → score={score}", flush=True)
        return score, metadata

    except Exception as e:
        # Best-effort: any scoring failure yields zero reward, not a crash.
        print(f"[EndlessTerminalsEnv] Error scoring {task_id}: {e}", flush=True)
        return 0.0, {"error": str(e)}
|
||||
|
||||
# ---- WandB logging ----
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
    """Flush buffered train/eval metrics into the dict and delegate upward."""
    metrics = {} if wandb_metrics is None else wandb_metrics

    # Pass rate over trajectories scored since the previous log call.
    buffered = self._train_scores_buffer
    if buffered:
        metrics["train/percent_correct"] = sum(buffered) / len(buffered)
        metrics["train/num_trajectories"] = len(buffered)
        self._train_scores_buffer = []

    # Drain eval metrics recorded by evaluate(); later keys overwrite
    # earlier ones exactly as the original per-pair assignment loop did.
    metrics.update(dict(self._eval_metrics))
    self._eval_metrics = []

    await super().wandb_log(metrics)
|
||||
|
||||
# ---- Evaluation ----
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Run the agent on held-out eval tasks and report pass rate.
|
||||
|
||||
Each eval task: build Docker container → run agent (temp=0) → pytest → score.
|
||||
This is expensive (full agent trajectories), so we only eval a subset.
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
if not self._eval_tasks:
|
||||
return {}
|
||||
|
||||
start_time = _time.time()
|
||||
eval_sample_size = min(len(self._eval_tasks), 20)
|
||||
eval_subset = random.sample(self._eval_tasks, eval_sample_size)
|
||||
|
||||
print(
|
||||
f"[EndlessTerminalsEnv] Running evaluation on {eval_sample_size} tasks...",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
scores = []
|
||||
samples = []
|
||||
|
||||
for task_info in eval_subset:
|
||||
task_name = task_info["task_name"]
|
||||
description = task_info["description"]
|
||||
|
||||
try:
|
||||
# Build Docker image
|
||||
image_tag = await self._ensure_image(task_info)
|
||||
|
||||
# Run agent with temp=0 for deterministic eval
|
||||
eval_tid = f"eval_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Register image with backend
|
||||
if isinstance(self._backend, DockerDirectBackend):
|
||||
self._backend.register_image(eval_tid, image_tag)
|
||||
|
||||
async def _exec(call, _tid=eval_tid):
|
||||
return await self._tool_executor.execute(_tid, call)
|
||||
|
||||
from ..agent import AtroposAgent as _AtroposAgent
|
||||
|
||||
agent = _AtroposAgent(
|
||||
server=self.server,
|
||||
tokenizer=self.tokenizer,
|
||||
tools=self.tools,
|
||||
config=AgentConfig(
|
||||
max_steps=self.config.agent_max_steps,
|
||||
temperature=0.0, # Deterministic for eval
|
||||
max_tokens=self.config.agent_max_tokens,
|
||||
),
|
||||
execute_tool=_exec,
|
||||
)
|
||||
|
||||
result = await agent.run(description)
|
||||
|
||||
# Score: run pytest in the container
|
||||
score = 0.0
|
||||
test_final = task_info.get("test_final", "")
|
||||
if result.success and test_final and Path(test_final).exists():
|
||||
test_content = Path(test_final).read_text(encoding="utf-8")
|
||||
encoded = base64.b64encode(test_content.encode("utf-8")).decode("ascii")
|
||||
verify_cmd = (
|
||||
f"printf '%s' '{encoded}' | base64 -d > /tmp/_test_final_state.py && "
|
||||
f"cd /home/user && "
|
||||
f"python3 -m pytest /tmp/_test_final_state.py -v --tb=short 2>&1; "
|
||||
f'echo "EXIT_CODE=$?"'
|
||||
)
|
||||
test_result = await _exec(ToolCall(
|
||||
name="terminal",
|
||||
arguments={"command": verify_cmd},
|
||||
))
|
||||
test_output = test_result.output if hasattr(test_result, "output") else ""
|
||||
if "EXIT_CODE=0" in test_output:
|
||||
score = 1.0
|
||||
|
||||
scores.append(score)
|
||||
samples.append({
|
||||
"task": task_name,
|
||||
"score": score,
|
||||
"tool_calls": result.total_tool_calls,
|
||||
"success": result.success,
|
||||
})
|
||||
|
||||
# Cleanup
|
||||
await self._tool_executor.release_trajectory(eval_tid, reset_workspace=True)
|
||||
|
||||
print(f" [eval] {task_name} → {score}", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f" [eval] {task_name} → ERROR: {e}", flush=True)
|
||||
scores.append(0.0)
|
||||
samples.append({"task": task_name, "score": 0.0, "error": str(e)})
|
||||
|
||||
end_time = _time.time()
|
||||
|
||||
percent_correct = sum(scores) / len(scores) if scores else 0.0
|
||||
|
||||
print(
|
||||
f"[EndlessTerminalsEnv] Eval: {percent_correct:.1%} pass rate "
|
||||
f"({sum(scores):.0f}/{len(scores)}) in {end_time - start_time:.0f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Store for wandb_log to pick up
|
||||
self._eval_metrics.append(("eval/percent_correct", percent_correct))
|
||||
self._eval_metrics.append(("eval/num_tasks", len(scores)))
|
||||
self._eval_metrics.append(("eval/duration_s", end_time - start_time))
|
||||
|
||||
# Log via atroposlib
|
||||
eval_metrics = {
|
||||
"eval/percent_correct": percent_correct,
|
||||
"eval/num_tasks": len(scores),
|
||||
}
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": 0.0,
|
||||
"max_tokens": self.config.agent_max_tokens,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
EndlessTerminalsEnv.cli()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,22 +1,12 @@
|
||||
"""
|
||||
Tool abstractions for atropos-agent.
|
||||
|
||||
Provides base Tool class, ToolCall/ToolResult types, and specialized tools.
|
||||
|
||||
Kept modules:
|
||||
- base.py: ToolSchema, ToolCall, ToolResult, Tool ABC, ToolRegistry
|
||||
- tool_executor.py: Batched execution queue with slot routing
|
||||
- terminal_stateful_tool.py: Persistent terminal sessions
|
||||
- tmux_tool.py: Tmux-based streaming terminal
|
||||
|
||||
Removed (replaced by hermes-agent equivalents):
|
||||
- build_registry.py → model_tools.py + toolsets.py
|
||||
- sandbox_stubs.py → atropos/backends/ execute() methods
|
||||
- hermes_external_tools.py → environments/agent_loop.py handle_function_call()
|
||||
- toolset_resolver.py → toolsets.py
|
||||
Provides base Tool class and common tool implementations.
|
||||
"""
|
||||
|
||||
from .base import Tool, ToolCall, ToolRegistry, ToolResult, ToolSchema
|
||||
from .build_registry import build_tool_registry
|
||||
from .sandbox_stubs import BashTool, ReadFileTool, TerminalTool, WriteFileTool
|
||||
from .terminal_stateful_tool import TerminalStatefulTool
|
||||
from .tmux_tool import TmuxTool
|
||||
|
||||
@@ -26,6 +16,11 @@ __all__ = [
|
||||
"ToolRegistry",
|
||||
"ToolResult",
|
||||
"ToolSchema",
|
||||
"BashTool",
|
||||
"ReadFileTool",
|
||||
"WriteFileTool",
|
||||
"TerminalTool",
|
||||
"TerminalStatefulTool",
|
||||
"TmuxTool",
|
||||
"build_tool_registry",
|
||||
]
|
||||
|
||||
64
atropos/tools/build_registry.py
Normal file
64
atropos/tools/build_registry.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Unified tool registry builder for Hermes-Agent Atropos integration.
|
||||
|
||||
This composes:
|
||||
- sandbox tool stubs (terminal/bash/read_file/write_file + stateful terminal/tmux)
|
||||
- Hermes external tools (web/vision/image/moa/skills/browser), executed via ToolServer
|
||||
|
||||
ToolExecutor only needs the schema + `external` routing bit; ToolServer executes
|
||||
the external tools via Hermes' existing implementations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from .base import ToolRegistry
|
||||
from .hermes_external_tools import build_external_tools
|
||||
from .sandbox_stubs import BashTool, ReadFileTool, TerminalTool, WriteFileTool
|
||||
from .terminal_stateful_tool import TerminalStatefulTool
|
||||
from .tmux_tool import TmuxTool
|
||||
from .toolset_resolver import resolve_multiple_toolsets
|
||||
|
||||
|
||||
def build_tool_registry(
|
||||
*,
|
||||
enabled_toolsets: Optional[List[str]] = None,
|
||||
disabled_toolsets: Optional[List[str]] = None,
|
||||
tool_server_url: Optional[str] = None,
|
||||
) -> ToolRegistry:
|
||||
"""
|
||||
Build a ToolRegistry for AgentEnv / ToolExecutor / ToolServer.
|
||||
|
||||
If `tool_server_url` is not provided, external tools will be omitted so we do
|
||||
not advertise tools that cannot execute.
|
||||
"""
|
||||
enabled_toolsets = enabled_toolsets or ["default"]
|
||||
|
||||
# Resolve tool names using Hermes toolsets plus Atropos additions.
|
||||
selected = set(resolve_multiple_toolsets(enabled_toolsets))
|
||||
if disabled_toolsets:
|
||||
selected -= set(resolve_multiple_toolsets(disabled_toolsets))
|
||||
|
||||
reg = ToolRegistry()
|
||||
|
||||
# Always register sandbox tools if selected.
|
||||
sandbox_by_name = {
|
||||
"terminal": TerminalTool(),
|
||||
"bash": BashTool(),
|
||||
"read_file": ReadFileTool(),
|
||||
"write_file": WriteFileTool(),
|
||||
"terminal_stateful": TerminalStatefulTool(),
|
||||
"tmux": TmuxTool(),
|
||||
}
|
||||
for name, tool in sandbox_by_name.items():
|
||||
if name in selected:
|
||||
reg.register(tool)
|
||||
|
||||
# External tools: only include when ToolServer is configured.
|
||||
if tool_server_url:
|
||||
for tool in build_external_tools(selected_tool_names=selected):
|
||||
if tool.name in selected:
|
||||
reg.register(tool)
|
||||
|
||||
return reg
|
||||
90
atropos/tools/hermes_external_tools.py
Normal file
90
atropos/tools/hermes_external_tools.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Hermes external tool adapter for Atropos ToolServer.
|
||||
|
||||
These tools reuse Hermes-Agent's existing tool runner (`model_tools.handle_function_call`)
|
||||
so we don't duplicate external tool implementations.
|
||||
|
||||
Important:
|
||||
- These are marked `external=True` and should be executed ONLY by ToolServer.
|
||||
- We run `handle_function_call` in a worker thread because the Hermes implementation
|
||||
uses `asyncio.run()` internally for some async tools (web_extract, vision, MoA, etc).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import model_tools
|
||||
|
||||
from .base import Tool, ToolResult, ToolSchema
|
||||
|
||||
|
||||
def _schema_from_openai_tool_dict(tool: Dict[str, Any], *, external: bool) -> ToolSchema:
|
||||
fn = tool.get("function") or {}
|
||||
name = str(fn.get("name") or "")
|
||||
description = str(fn.get("description") or "")
|
||||
params = fn.get("parameters") or {}
|
||||
properties = params.get("properties") or {}
|
||||
required = params.get("required") or []
|
||||
if not isinstance(required, list):
|
||||
required = []
|
||||
return ToolSchema(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters=dict(properties),
|
||||
required=[str(x) for x in required if isinstance(x, (str, int))],
|
||||
external=external,
|
||||
)
|
||||
|
||||
|
||||
class HermesExternalTool(Tool):
|
||||
def __init__(self, schema: ToolSchema):
|
||||
self._schema = schema
|
||||
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return self._schema
|
||||
|
||||
async def execute(self, task_id: Optional[str] = None, **kwargs: Any) -> ToolResult:
|
||||
# `model_tools.handle_function_call` returns a JSON string (success or error).
|
||||
# Run in a thread because some Hermes tool handlers call `asyncio.run()`.
|
||||
raw = await asyncio.to_thread(model_tools.handle_function_call, self.name, kwargs, task_id)
|
||||
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except Exception:
|
||||
# Keep as plain string.
|
||||
return ToolResult(success=True, output=str(raw))
|
||||
|
||||
if isinstance(parsed, dict) and parsed.get("error"):
|
||||
return ToolResult(success=False, error=str(parsed.get("error")), output="")
|
||||
|
||||
return ToolResult(success=True, output=json.dumps(parsed, ensure_ascii=False))
|
||||
|
||||
|
||||
def build_external_tools(
|
||||
*,
|
||||
selected_tool_names: Optional[set[str]] = None,
|
||||
) -> List[HermesExternalTool]:
|
||||
"""
|
||||
Build external tool wrappers from Hermes tool declarations.
|
||||
|
||||
Filters out sandbox-oriented tools (e.g. `terminal`) since those should run
|
||||
inside the sandbox via ToolExecutor.
|
||||
"""
|
||||
# IMPORTANT: Hermes' `model_tools.get_tool_definitions()` only understands Hermes toolsets.
|
||||
# Atropos envs add extra toolsets (filesystem/sandbox/stateful). To avoid noisy "Unknown toolset"
|
||||
# prints and accidental filtering, we fetch ALL Hermes tool definitions here and filter by name.
|
||||
tools = model_tools.get_tool_definitions(enabled_toolsets=None, disabled_toolsets=None, quiet_mode=True)
|
||||
|
||||
wrappers: List[HermesExternalTool] = []
|
||||
for t in tools:
|
||||
schema = _schema_from_openai_tool_dict(t, external=True)
|
||||
if schema.name in {"terminal"}:
|
||||
continue
|
||||
if selected_tool_names is not None and schema.name not in selected_tool_names:
|
||||
continue
|
||||
wrappers.append(HermesExternalTool(schema))
|
||||
return wrappers
|
||||
99
atropos/tools/sandbox_stubs.py
Normal file
99
atropos/tools/sandbox_stubs.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Sandbox tool stubs for Atropos ToolExecutor.
|
||||
|
||||
These tools are executed inside the sandbox containers via:
|
||||
ToolExecutor -> SlotPool -> sandbox_server.py
|
||||
|
||||
They intentionally do NOT execute anything on the host process. If they are
|
||||
called directly (outside ToolExecutor), they return a clear error.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .base import Tool, ToolResult, ToolSchema
|
||||
|
||||
|
||||
class TerminalTool(Tool):
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name="terminal",
|
||||
description=(
|
||||
"Execute a command inside the sandbox slot workspace and return stdout/stderr. "
|
||||
"Filesystem persists within a trajectory slot. Background processes are not supported "
|
||||
"in stateless mode. Commands run under POSIX /bin/sh and each tool call runs in a fresh "
|
||||
"shell (no persisted env vars). Avoid bash-only syntax like `source`; prefer `. .venv/bin/activate` "
|
||||
"or invoke `.venv/bin/python ...` directly."
|
||||
),
|
||||
parameters={
|
||||
"command": {"type": "string", "description": "The command to execute"},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Command timeout in seconds (optional).",
|
||||
"minimum": 1,
|
||||
},
|
||||
"background": {
|
||||
"type": "boolean",
|
||||
"description": "Not supported in sandbox terminal (always false).",
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
required=["command"],
|
||||
external=False,
|
||||
)
|
||||
|
||||
async def execute(self, **_kwargs) -> ToolResult:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error="terminal must be executed via ToolExecutor inside the sandbox",
|
||||
)
|
||||
|
||||
|
||||
class BashTool(Tool):
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name="bash",
|
||||
description="Execute a bash command inside the sandbox slot workspace.",
|
||||
parameters={"command": {"type": "string", "description": "The bash command to execute"}},
|
||||
required=["command"],
|
||||
external=False,
|
||||
)
|
||||
|
||||
async def execute(self, **_kwargs) -> ToolResult:
|
||||
return ToolResult(success=False, error="bash must be executed via ToolExecutor inside the sandbox")
|
||||
|
||||
|
||||
class ReadFileTool(Tool):
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name="read_file",
|
||||
description="Read a file from the sandbox slot workspace.",
|
||||
parameters={"path": {"type": "string", "description": "Path to the file"}},
|
||||
required=["path"],
|
||||
external=False,
|
||||
)
|
||||
|
||||
async def execute(self, **_kwargs) -> ToolResult:
|
||||
return ToolResult(success=False, error="read_file must be executed via ToolExecutor inside the sandbox")
|
||||
|
||||
|
||||
class WriteFileTool(Tool):
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name="write_file",
|
||||
description="Write a file into the sandbox slot workspace.",
|
||||
parameters={
|
||||
"path": {"type": "string", "description": "Path to the file"},
|
||||
"content": {"type": "string", "description": "File content"},
|
||||
},
|
||||
required=["path", "content"],
|
||||
external=False,
|
||||
)
|
||||
|
||||
async def execute(self, **_kwargs) -> ToolResult:
|
||||
return ToolResult(success=False, error="write_file must be executed via ToolExecutor inside the sandbox")
|
||||
88
atropos/tools/toolset_resolver.py
Normal file
88
atropos/tools/toolset_resolver.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Toolset resolution for Hermes-Agent Atropos integration.
|
||||
|
||||
We primarily reuse Hermes-Agent toolsets (`toolsets.py`), but Atropos training/envs
|
||||
need a few extra sandbox-oriented toolsets that Hermes doesn't expose by default
|
||||
(e.g. filesystem + stateful terminal).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import toolsets as hermes_toolsets
|
||||
|
||||
|
||||
ATROPOS_TOOLSETS: Dict[str, Dict[str, Any]] = {
|
||||
"filesystem": {
|
||||
"description": "Read/write files in the sandbox workspace.",
|
||||
"tools": ["read_file", "write_file"],
|
||||
"includes": [],
|
||||
},
|
||||
"terminal_stateful": {
|
||||
"description": "Stateful terminal execution (tmux/TUI support) inside the sandbox.",
|
||||
"tools": ["terminal_stateful", "tmux"],
|
||||
"includes": [],
|
||||
},
|
||||
"sandbox": {
|
||||
"description": "Sandbox tools (terminal + filesystem).",
|
||||
"tools": [],
|
||||
"includes": ["terminal", "filesystem"],
|
||||
},
|
||||
"default": {
|
||||
"description": "Default toolset for Atropos AgentEnv tasks.",
|
||||
"tools": [],
|
||||
"includes": ["sandbox"],
|
||||
},
|
||||
"full": {
|
||||
"description": "All Hermes tools plus Atropos sandbox additions.",
|
||||
"tools": [],
|
||||
"includes": ["all", "filesystem", "sandbox", "terminal_stateful"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def validate_toolset(name: str) -> bool:
|
||||
if name in {"all", "*"}:
|
||||
return True
|
||||
return hermes_toolsets.validate_toolset(name) or name in ATROPOS_TOOLSETS
|
||||
|
||||
|
||||
def resolve_toolset(name: str, visited: Optional[Set[str]] = None) -> List[str]:
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
if name in {"all", "*"}:
|
||||
# Union Hermes + Atropos toolsets.
|
||||
all_tools: Set[str] = set()
|
||||
for tname in hermes_toolsets.get_toolset_names():
|
||||
all_tools.update(resolve_toolset(tname, visited=set()))
|
||||
for tname, spec in ATROPOS_TOOLSETS.items():
|
||||
# Avoid recursion: some Atropos toolsets (e.g. "full") include "all".
|
||||
if tname == "full" or "all" in (spec.get("includes") or []):
|
||||
continue
|
||||
all_tools.update(resolve_toolset(tname, visited=set()))
|
||||
return sorted(all_tools)
|
||||
|
||||
if name in ATROPOS_TOOLSETS:
|
||||
if name in visited:
|
||||
return []
|
||||
visited.add(name)
|
||||
spec = ATROPOS_TOOLSETS[name]
|
||||
tools: Set[str] = set(spec.get("tools", []))
|
||||
for inc in spec.get("includes", []):
|
||||
tools.update(resolve_toolset(inc, visited=set(visited)))
|
||||
return sorted(tools)
|
||||
|
||||
# Fall back to Hermes toolsets.
|
||||
# IMPORTANT: do not pre-add `name` to `visited` here; Hermes' resolver uses
|
||||
# `visited` for its own cycle detection and will treat the presence of `name`
|
||||
# as a circular dependency.
|
||||
return sorted(hermes_toolsets.resolve_toolset(name, visited=set(visited)))
|
||||
|
||||
|
||||
def resolve_multiple_toolsets(names: List[str]) -> List[str]:
|
||||
tools: Set[str] = set()
|
||||
for name in names:
|
||||
tools.update(resolve_toolset(name, visited=set()))
|
||||
return sorted(tools)
|
||||
File diff suppressed because it is too large
Load Diff
83
configs/endless_terminals.yaml
Normal file
83
configs/endless_terminals.yaml
Normal file
@@ -0,0 +1,83 @@
|
||||
# Endless Terminals Environment Configuration
|
||||
#
|
||||
# Two modes:
|
||||
# 1. Dataset mode (default): Load pre-generated tasks from HuggingFace
|
||||
# 2. Procedural mode: Generate tasks on-demand via LLM
|
||||
#
|
||||
# Usage:
|
||||
# python -m atropos.envs.endless_terminals_env process \
|
||||
# --config configs/endless_terminals.yaml
|
||||
|
||||
# Environment settings
|
||||
env:
|
||||
# Dataset mode (primary - recommended)
|
||||
use_dataset: true # Load from HuggingFace (fast, no vLLM needed)
|
||||
dataset_name: "obiwan96/endless-terminals-train"
|
||||
dataset_split: "train"
|
||||
dataset_cache_dir: "~/.cache/huggingface/datasets"
|
||||
tasks_base_dir: "" # Set to dir containing task_* folders if not using default paths
|
||||
# Example: "/path/to/endless-terminals-train"
|
||||
|
||||
# Task generation (fallback if use_dataset=false)
|
||||
task_gen_model: "Qwen/Qwen3-32B" # Only needed if use_dataset=false
|
||||
task_gen_temperature: 1.0
|
||||
task_gen_max_tokens: 2048
|
||||
|
||||
# Container settings
|
||||
base_container_image: "ubuntu:22.04"
|
||||
container_timeout_s: 180
|
||||
test_timeout_s: 60
|
||||
|
||||
# Workspace
|
||||
workspace_dir: "/tmp/endless_terminals_workspace"
|
||||
keep_failed_tasks: false # Set true to debug failed tasks
|
||||
|
||||
# Agent config (increased for long traces)
|
||||
agent_max_steps: 32
|
||||
agent_temperature: 0.7
|
||||
agent_max_tokens: null # Let backend decide
|
||||
|
||||
# Tooling: terminal only
|
||||
enabled_toolsets: ["terminal"]
|
||||
disabled_toolsets: []
|
||||
|
||||
# Training settings
|
||||
group_size: 4 # Parallel trajectory collection
|
||||
batch_size: 32
|
||||
total_steps: 1000 # Total training episodes
|
||||
use_wandb: false # Enable for experiment tracking
|
||||
include_messages: true
|
||||
|
||||
# Tool execution backend (nomad or modal)
|
||||
tool_pool_mode: "nomad"
|
||||
|
||||
# Nomad settings (if using nomad)
|
||||
nomad_address: "http://localhost:4646"
|
||||
sandbox_job_id: "atropos-sandbox-endless"
|
||||
sandbox_image: "atropos-sandbox:local"
|
||||
slots_per_container: 10
|
||||
min_containers: 1
|
||||
max_containers: 10
|
||||
privileged: false
|
||||
acquire_timeout_s: 30.0
|
||||
purge_job_on_start: true
|
||||
purge_job_on_shutdown: true
|
||||
|
||||
# Modal settings (if using modal instead)
|
||||
# modal_app_name: "atropos-endless"
|
||||
# modal_image: "python:3.11"
|
||||
# modal_slots_per_sandbox: 10
|
||||
# modal_min_sandboxes: 1
|
||||
# modal_max_sandboxes: 5
|
||||
|
||||
# Server config
|
||||
server_base_url: "http://127.0.0.1:8080"
|
||||
server_model: "hermes-4-36b"
|
||||
tokenizer_name: "NousResearch/Hermes-4.3-36B"
|
||||
|
||||
# Server configs are auto-generated from env vars and env.server_* settings
|
||||
# Override via environment variables:
|
||||
# ATROPOS_SERVER_BASE_URL
|
||||
# ATROPOS_SERVER_MODEL
|
||||
# ATROPOS_SERVER_API_KEY
|
||||
# ATROPOS_TOKENIZER_NAME
|
||||
@@ -57,12 +57,6 @@ class AgentResult:
|
||||
# Tool errors encountered during the loop
|
||||
tool_errors: List[ToolError] = field(default_factory=list)
|
||||
|
||||
# Tool-call metrics (for reward shaping + debugging)
|
||||
tool_calls_attempted: int = 0 # Valid tool name + attempted dispatch
|
||||
tool_calls_schema_valid: int = 0 # Arguments matched schema (no coercion)
|
||||
tool_calls_executed_ok: int = 0 # Tool ran and returned no error
|
||||
tool_calls_exec_error: int = 0 # Unknown tool / exception / tool returned error
|
||||
|
||||
|
||||
def _extract_reasoning_from_message(message) -> Optional[str]:
|
||||
"""
|
||||
@@ -125,8 +119,6 @@ class HermesAgentLoop:
|
||||
task_id: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
tool_handler=None,
|
||||
max_context_tokens: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the agent loop.
|
||||
@@ -140,13 +132,6 @@ class HermesAgentLoop:
|
||||
task_id: Unique ID for terminal/browser session isolation
|
||||
temperature: Sampling temperature for generation
|
||||
max_tokens: Max tokens per generation (None for server default)
|
||||
tool_handler: Optional async callable(tool_name, args, task_id) -> str.
|
||||
When provided, used INSTEAD of handle_function_call() for
|
||||
tool dispatch. This allows sandbox backends (Modal, Nomad)
|
||||
to route tool calls through their slot-based execution.
|
||||
max_context_tokens: Maximum prompt tokens before truncation.
|
||||
If None, no truncation is applied.
|
||||
Recommended: set to max_model_len - max_tokens - 512 (safety margin).
|
||||
"""
|
||||
self.server = server
|
||||
self.tool_schemas = tool_schemas
|
||||
@@ -155,139 +140,6 @@ class HermesAgentLoop:
|
||||
self.task_id = task_id or str(uuid.uuid4())
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.tool_handler = tool_handler
|
||||
self.max_context_tokens = max_context_tokens
|
||||
|
||||
|
||||
def _truncate_context(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Truncate conversation history to fit within max_context_tokens.
|
||||
|
||||
Strategy:
|
||||
- Keep system message (index 0) and initial user message (index 1) always
|
||||
- Keep last 6 messages (recent context) always
|
||||
- For everything in between, progressively truncate tool result content
|
||||
- If still too long, drop oldest middle messages entirely
|
||||
|
||||
Uses rough char/4 token estimate (fast, no tokenizer needed).
|
||||
"""
|
||||
if self.max_context_tokens is None:
|
||||
return messages
|
||||
|
||||
def estimate_tokens(msgs):
|
||||
total = 0
|
||||
for m in msgs:
|
||||
content = m.get("content", "") or ""
|
||||
total += len(content) // 4 + 10 # ~4 chars per token + overhead
|
||||
if "tool_calls" in m:
|
||||
total += 50 * len(m["tool_calls"]) # tool call overhead
|
||||
return total
|
||||
|
||||
est = estimate_tokens(messages)
|
||||
if est <= self.max_context_tokens:
|
||||
return messages
|
||||
|
||||
# Phase 1: Truncate tool result content in middle messages
|
||||
# Keep first 2 and last 6 messages untouched
|
||||
protect_head = 2
|
||||
protect_tail = max(0, min(6, len(messages) - protect_head))
|
||||
middle_start = protect_head
|
||||
middle_end = len(messages) - protect_tail
|
||||
|
||||
if middle_start < middle_end:
|
||||
# Truncate tool results from oldest first
|
||||
for i in range(middle_start, middle_end):
|
||||
if messages[i].get("role") == "tool":
|
||||
content = messages[i].get("content", "") or ""
|
||||
if len(content) > 200:
|
||||
messages[i] = dict(messages[i]) # copy
|
||||
messages[i]["content"] = content[:100] + "\n...[truncated]...\n" + content[-50:]
|
||||
|
||||
est = estimate_tokens(messages)
|
||||
if est <= self.max_context_tokens:
|
||||
logger.debug("Context truncated (phase 1: tool results): %d tokens", est)
|
||||
return messages
|
||||
|
||||
# Phase 2: Drop oldest middle messages entirely
|
||||
while middle_start < middle_end and estimate_tokens(messages) > self.max_context_tokens:
|
||||
# Remove the oldest middle message
|
||||
# But keep assistant+tool pairs together
|
||||
msg = messages[middle_start]
|
||||
messages.pop(middle_start)
|
||||
middle_end -= 1
|
||||
# If we removed an assistant with tool_calls, also remove matching tool responses
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
tool_ids = {tc.get("id") or tc.get("tool_call_id", "") for tc in msg.get("tool_calls", []) if isinstance(tc, dict)}
|
||||
# Remove tool responses for those IDs
|
||||
i = middle_start
|
||||
while i < middle_end:
|
||||
if messages[i].get("role") == "tool" and messages[i].get("tool_call_id", "") in tool_ids:
|
||||
messages.pop(i)
|
||||
middle_end -= 1
|
||||
else:
|
||||
i += 1
|
||||
|
||||
est = estimate_tokens(messages)
|
||||
logger.info("Context truncated (phase 2: dropped messages): %d estimated tokens, %d messages remaining", est, len(messages))
|
||||
return messages
|
||||
|
||||
def _normalize_tool_args(self, tool_name: str, tool_args_raw: str) -> (Dict[str, Any], bool):
|
||||
"""Normalize tool arguments into a dict.
|
||||
|
||||
Returns:
|
||||
(args_dict, schema_valid)
|
||||
|
||||
schema_valid is True only when the arguments decode directly into a dict
|
||||
(i.e. no double-decoding and no coercion/wrapping was needed).
|
||||
|
||||
This lets us keep the environment robust (never crash due to args format)
|
||||
while still scoring down malformed tool-call argument formats.
|
||||
"""
|
||||
try:
|
||||
decoded = json.loads(tool_args_raw)
|
||||
except json.JSONDecodeError:
|
||||
# Not valid JSON at all. Be robust: treat it as a plain string.
|
||||
# (Some parsers/providers may pass through non-JSON strings.)
|
||||
if tool_name == "terminal":
|
||||
return {"command": tool_args_raw}, False
|
||||
return {"input": tool_args_raw}, False
|
||||
|
||||
# Canonical case: decoded is already a dict
|
||||
if isinstance(decoded, dict):
|
||||
# For terminal tool, require a command key
|
||||
if tool_name == "terminal":
|
||||
cmd = decoded.get("command")
|
||||
if isinstance(cmd, str) and cmd.strip():
|
||||
return decoded, True
|
||||
# Common alternate key
|
||||
if isinstance(decoded.get("input"), str):
|
||||
return {"command": decoded.get("input")}, False
|
||||
return decoded, False
|
||||
return decoded, True
|
||||
|
||||
# Common drift case: decoded is a JSON string of an object
|
||||
if isinstance(decoded, str):
|
||||
s = decoded.strip()
|
||||
if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
|
||||
try:
|
||||
decoded2 = json.loads(s)
|
||||
except json.JSONDecodeError:
|
||||
decoded2 = None
|
||||
if isinstance(decoded2, dict):
|
||||
# Terminal tool: ensure command
|
||||
if tool_name == "terminal" and isinstance(decoded2.get("command"), str):
|
||||
return decoded2, False
|
||||
return decoded2, False
|
||||
|
||||
# Plain string (not JSON) — coerce to expected shape
|
||||
if tool_name == "terminal":
|
||||
return {"command": decoded}, False
|
||||
return {"input": decoded}, False
|
||||
|
||||
# Other JSON types (list/number/etc.) — wrap
|
||||
if tool_name == "terminal":
|
||||
return {"command": str(decoded)}, False
|
||||
return {"input": decoded}, False
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
|
||||
"""
|
||||
@@ -295,12 +147,7 @@ class HermesAgentLoop:
|
||||
|
||||
Args:
|
||||
messages: Initial conversation messages (system + user).
|
||||
This list is treated as the FULL trajectory and is
|
||||
appended to as the conversation progresses.
|
||||
|
||||
Prompt truncation (to avoid context overflow) is applied
|
||||
on a copy of this list per turn, so we do not lose
|
||||
earlier messages for reward computation/debugging.
|
||||
Modified in-place as the conversation progresses.
|
||||
|
||||
Returns:
|
||||
AgentResult with full conversation history, managed state, and metadata
|
||||
@@ -308,21 +155,10 @@ class HermesAgentLoop:
|
||||
reasoning_per_turn = []
|
||||
tool_errors: List[ToolError] = []
|
||||
|
||||
# Metrics to separate "attempted tool use" from "schema-valid tool use"
|
||||
tool_calls_attempted = 0
|
||||
tool_calls_schema_valid = 0
|
||||
tool_calls_executed_ok = 0
|
||||
tool_calls_exec_error = 0
|
||||
|
||||
for turn in range(self.max_turns):
|
||||
# Truncate context if approaching limit.
|
||||
# IMPORTANT: do this on a copy so we keep the full trajectory in `messages`
|
||||
# for reward computation + debugging, while only trimming the prompt view.
|
||||
prompt_messages = self._truncate_context(list(messages))
|
||||
|
||||
# Build the chat_completion kwargs
|
||||
chat_kwargs = {
|
||||
"messages": prompt_messages,
|
||||
"messages": messages,
|
||||
"n": 1,
|
||||
"temperature": self.temperature,
|
||||
}
|
||||
@@ -347,10 +183,6 @@ class HermesAgentLoop:
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
tool_calls_attempted=tool_calls_attempted,
|
||||
tool_calls_schema_valid=tool_calls_schema_valid,
|
||||
tool_calls_executed_ok=tool_calls_executed_ok,
|
||||
tool_calls_exec_error=tool_calls_exec_error,
|
||||
)
|
||||
|
||||
if not response or not response.choices:
|
||||
@@ -362,10 +194,6 @@ class HermesAgentLoop:
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
tool_calls_attempted=tool_calls_attempted,
|
||||
tool_calls_schema_valid=tool_calls_schema_valid,
|
||||
tool_calls_executed_ok=tool_calls_executed_ok,
|
||||
tool_calls_exec_error=tool_calls_exec_error,
|
||||
)
|
||||
|
||||
assistant_msg = response.choices[0].message
|
||||
@@ -424,45 +252,35 @@ class HermesAgentLoop:
|
||||
"Model called unknown tool '%s' on turn %d",
|
||||
tool_name, turn + 1,
|
||||
)
|
||||
tool_calls_exec_error += 1
|
||||
else:
|
||||
tool_calls_attempted += 1
|
||||
|
||||
# Normalize args into a dict so we never crash due to formatting.
|
||||
# Track schema_valid separately so reward shaping can penalize
|
||||
# non-canonical formats (e.g. stringified JSON).
|
||||
args, schema_valid = self._normalize_tool_args(tool_name, tool_args_raw)
|
||||
if schema_valid:
|
||||
tool_calls_schema_valid += 1
|
||||
# Parse arguments and dispatch
|
||||
try:
|
||||
args = json.loads(tool_args_raw)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
logger.warning(
|
||||
"Invalid JSON in tool call arguments for '%s': %s",
|
||||
tool_name, tool_args_raw[:200],
|
||||
)
|
||||
|
||||
try:
|
||||
if tool_name == "terminal":
|
||||
import os
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
if self.tool_handler:
|
||||
backend = "sandbox"
|
||||
cmd_preview = str(args.get("command", ""))[:80]
|
||||
cmd_preview = args.get("command", "")[:80]
|
||||
print(f" 🖥️ [{backend}] $ {cmd_preview}")
|
||||
|
||||
if self.tool_handler:
|
||||
# Use custom tool handler (sandbox backend routing)
|
||||
tool_result = await self.tool_handler(
|
||||
tool_name, args, self.task_id
|
||||
)
|
||||
else:
|
||||
# Default: run via hermes-agent's handle_function_call
|
||||
# in a thread pool so backends that use asyncio.run()
|
||||
# internally (modal, docker) get a clean event loop
|
||||
# instead of deadlocking inside Atropos's loop.
|
||||
loop = asyncio.get_event_loop()
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
lambda: handle_function_call(
|
||||
tool_name, args, task_id=self.task_id
|
||||
),
|
||||
)
|
||||
# Run tool calls in a thread pool so backends that use
|
||||
# asyncio.run() internally (modal, docker) get a clean
|
||||
# event loop instead of deadlocking inside Atropos's loop.
|
||||
loop = asyncio.get_event_loop()
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
lambda: handle_function_call(
|
||||
tool_name, args, task_id=self.task_id
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
tool_calls_exec_error += 1
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
|
||||
)
|
||||
@@ -476,34 +294,22 @@ class HermesAgentLoop:
|
||||
"Tool '%s' execution failed on turn %d: %s",
|
||||
tool_name, turn + 1, e,
|
||||
)
|
||||
else:
|
||||
# Count tool result errors (if tool returns structured JSON error)
|
||||
tool_err = False
|
||||
try:
|
||||
result_data = json.loads(tool_result)
|
||||
if isinstance(result_data, dict):
|
||||
err = result_data.get("error")
|
||||
if err:
|
||||
tool_err = True
|
||||
|
||||
# Keep existing behavior: treat negative exit_code as tool error
|
||||
exit_code = result_data.get("exit_code")
|
||||
if exit_code is not None and isinstance(exit_code, int) and exit_code < 0:
|
||||
tool_err = True
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=str(err) if err else "nonzero exit_code",
|
||||
tool_result=tool_result[:500],
|
||||
))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# Non-JSON tool output — assume ok
|
||||
pass
|
||||
|
||||
if tool_err:
|
||||
tool_calls_exec_error += 1
|
||||
else:
|
||||
tool_calls_executed_ok += 1
|
||||
# Also check if the tool returned an error in its JSON result
|
||||
try:
|
||||
result_data = json.loads(tool_result)
|
||||
if isinstance(result_data, dict):
|
||||
err = result_data.get("error")
|
||||
exit_code = result_data.get("exit_code")
|
||||
if err and exit_code and exit_code < 0:
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=str(err),
|
||||
tool_result=tool_result[:500],
|
||||
))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Add tool response to conversation
|
||||
messages.append(
|
||||
@@ -541,10 +347,6 @@ class HermesAgentLoop:
|
||||
finished_naturally=True,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
tool_calls_attempted=tool_calls_attempted,
|
||||
tool_calls_schema_valid=tool_calls_schema_valid,
|
||||
tool_calls_executed_ok=tool_calls_executed_ok,
|
||||
tool_calls_exec_error=tool_calls_exec_error,
|
||||
)
|
||||
|
||||
# Hit max turns without the model stopping
|
||||
@@ -556,10 +358,6 @@ class HermesAgentLoop:
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
tool_calls_attempted=tool_calls_attempted,
|
||||
tool_calls_schema_valid=tool_calls_schema_valid,
|
||||
tool_calls_executed_ok=tool_calls_executed_ok,
|
||||
tool_calls_exec_error=tool_calls_exec_error,
|
||||
)
|
||||
|
||||
def _get_managed_state(self) -> Optional[Dict[str, Any]]:
|
||||
|
||||
@@ -1,350 +0,0 @@
|
||||
"""
|
||||
GSM8kAgentEnv -- Math Reasoning with Tool Use (Python REPL)
|
||||
|
||||
An agentic RL environment where models solve GSM8k math problems using
|
||||
a Python interpreter tool. Uses proper OpenAI-spec tool calling via
|
||||
HermesAgentBaseEnv (not ICL).
|
||||
|
||||
The model:
|
||||
1. Receives a math problem
|
||||
2. Can call the `terminal` tool to run Python code (`python3 -c "..."`)
|
||||
3. Provides a final answer in \\boxed{} format
|
||||
4. Gets reward: 1.0 if correct, 0.0 if wrong
|
||||
|
||||
Usage:
|
||||
# Phase 1 (OpenRouter, no training):
|
||||
python environments/gsm8k_agent_env.py process \\
|
||||
--env.data_path_to_save_groups gsm8k_agent_output.jsonl
|
||||
|
||||
# Phase 2 (VLLM + Tinker training):
|
||||
run-api
|
||||
python launch_training.py --config configs/gsm8k_agent.yaml
|
||||
python environments/gsm8k_agent_env.py serve
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Math verification helpers
|
||||
# =============================================================================
|
||||
|
||||
def _verify_math_answer(model_response: str, gold_answer: str) -> bool:
|
||||
"""
|
||||
Verify if the model's response contains the correct answer.
|
||||
Uses math_verify for robust LaTeX comparison, falls back to string matching.
|
||||
"""
|
||||
try:
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
||||
gold_parsed = parse(
|
||||
f"\\boxed{{{gold_answer}}}",
|
||||
extraction_mode="first_match",
|
||||
extraction_config=[LatexExtractionConfig()],
|
||||
)
|
||||
|
||||
# Strip <think> blocks if present
|
||||
answer_text = model_response
|
||||
if "</think>" in answer_text:
|
||||
answer_text = answer_text.split("</think>")[-1]
|
||||
|
||||
answer_parsed = parse(
|
||||
answer_text,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
|
||||
return bool(verify(answer_parsed, gold_parsed))
|
||||
|
||||
except ImportError:
|
||||
# Fallback: simple string matching for \\boxed{answer}
|
||||
import re
|
||||
pattern = r'\\boxed\{([^}]+)\}'
|
||||
matches = re.findall(pattern, model_response)
|
||||
if matches:
|
||||
model_answer = matches[-1].strip().replace(",", "")
|
||||
gold_clean = gold_answer.strip().replace(",", "")
|
||||
return model_answer == gold_clean
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Environment Config
|
||||
# =============================================================================
|
||||
|
||||
class GSM8kAgentEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config with defaults for GSM8k agent environment."""
|
||||
pass
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Environment
|
||||
# =============================================================================
|
||||
|
||||
class GSM8kAgentEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
GSM8k math environment with Python REPL tool calling.
|
||||
|
||||
Models solve grade-school math problems by reasoning step by step
|
||||
and using Python (via the terminal tool) for calculations.
|
||||
|
||||
Exercises the full agentic RL training loop:
|
||||
- Model receives math problem
|
||||
- Makes tool calls to compute (python3 -c "...")
|
||||
- Provides final answer in \\boxed{}
|
||||
- Reward: binary (1.0 correct, 0.0 wrong)
|
||||
"""
|
||||
|
||||
name = "gsm8k-agent"
|
||||
env_config_cls = GSM8kAgentEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[GSM8kAgentEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default config using terminal tool.
|
||||
|
||||
Reads from environment variables (set in .env):
|
||||
ATROPOS_SERVER_BASE_URL - Inference server URL
|
||||
ATROPOS_SERVER_MODEL - Model name on the server
|
||||
ATROPOS_TOKENIZER_NAME - HuggingFace tokenizer name
|
||||
ATROPOS_SERVER_API_KEY - API key for the server
|
||||
"""
|
||||
# Resolve inference server settings from env
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "https://openrouter.ai/api/v1"
|
||||
)
|
||||
if not base_url.rstrip("/").endswith("/v1"):
|
||||
base_url = base_url.rstrip("/") + "/v1"
|
||||
|
||||
model = (
|
||||
os.getenv("ATROPOS_SERVER_MODEL")
|
||||
or os.getenv("LLM_MODEL")
|
||||
or "Hermes-4.3-36B"
|
||||
)
|
||||
|
||||
api_key = (
|
||||
os.getenv("ATROPOS_SERVER_API_KEY")
|
||||
or os.getenv("NOUS_API_KEY")
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or ""
|
||||
)
|
||||
|
||||
tokenizer = (
|
||||
os.getenv("ATROPOS_TOKENIZER_NAME")
|
||||
or os.getenv("ATROPOS_TOKENIZER")
|
||||
or "NousResearch/Hermes-4.3-36B"
|
||||
)
|
||||
|
||||
env_config = GSM8kAgentEnvConfig(
|
||||
# Terminal + file toolsets (same as terminal_test_env.py)
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
# Agent settings
|
||||
max_agent_turns=5, # Math problems don't need many turns
|
||||
max_token_length=2048, # Room for reasoning + code
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a helpful math assistant. You have access to a terminal "
|
||||
"where you can run Python code to help solve problems.\n\n"
|
||||
"When you need to calculate something, use the terminal tool with "
|
||||
"a command like: python3 -c \"print(2 + 2)\"\n\n"
|
||||
"When you have the final answer, write it inside \\boxed{} like: \\boxed{42}\n\n"
|
||||
"Work step by step. Use Python to verify your reasoning."
|
||||
),
|
||||
# Terminal backend (local for testing, modal for production)
|
||||
terminal_backend=os.getenv("TERMINAL_ENV", "local"),
|
||||
# Parser -- hermes format for Hermes models
|
||||
tool_call_parser="hermes",
|
||||
# Atropos settings
|
||||
group_size=4,
|
||||
tokenizer_name=tokenizer,
|
||||
steps_per_eval=5,
|
||||
total_steps=10,
|
||||
use_wandb=bool(os.getenv("WANDB_API_KEY")),
|
||||
wandb_name="gsm8k-agent",
|
||||
ensure_scores_are_not_same=False,
|
||||
# No external dataset (we load GSM8k ourselves)
|
||||
dataset_name=None,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url=base_url,
|
||||
model_name=model,
|
||||
server_type="openai",
|
||||
api_key=api_key,
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
"""Load GSM8k dataset."""
|
||||
from datasets import load_dataset
|
||||
|
||||
self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
|
||||
test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
|
||||
self.test = [
|
||||
{
|
||||
"question": item["question"],
|
||||
"gold_answer": item["answer"].split("#")[-1].strip().replace(",", ""),
|
||||
}
|
||||
for item in test_data
|
||||
]
|
||||
self.iter = 0
|
||||
self.reward_buffer: List[float] = []
|
||||
self.tool_use_buffer: List[int] = []
|
||||
print(f"[GSM8kAgentEnv] Loaded {len(self.train)} train, {len(self.test)} test examples")
|
||||
|
||||
async def get_next_item(self) -> Dict[str, str]:
|
||||
"""Cycle through training problems."""
|
||||
item = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
return {
|
||||
"question": item["question"],
|
||||
"gold_answer": item["answer"].split("#")[-1].strip().replace(",", ""),
|
||||
}
|
||||
|
||||
def format_prompt(self, item: Dict[str, str]) -> str:
|
||||
"""Format the math problem as a user message."""
|
||||
return item["question"]
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Dict[str, str], result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Score: verify the model's \\boxed{} answer against the gold answer.
|
||||
|
||||
The agent has full access to terminal via ctx, but for GSM8k we just
|
||||
check the final answer from the conversation.
|
||||
"""
|
||||
# Get the last assistant message content
|
||||
final_text = ""
|
||||
for msg in reversed(result.messages):
|
||||
if msg.get("role") == "assistant" and msg.get("content"):
|
||||
final_text = msg["content"]
|
||||
break
|
||||
|
||||
correct = _verify_math_answer(final_text, item["gold_answer"])
|
||||
reward = 1.0 if correct else 0.0
|
||||
|
||||
self.reward_buffer.append(reward)
|
||||
# Count tool calls in this trajectory
|
||||
tool_call_count = sum(
|
||||
len(msg.get("tool_calls", []))
|
||||
for msg in result.messages
|
||||
if msg.get("role") == "assistant"
|
||||
)
|
||||
self.tool_use_buffer.append(tool_call_count)
|
||||
|
||||
return reward
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""Evaluate on a subset of the test set (greedy, no tools for speed)."""
|
||||
start_time = time.time()
|
||||
correct = 0
|
||||
total = 0
|
||||
samples = []
|
||||
|
||||
eval_subset = self.test[:30] # Small subset for quick eval
|
||||
|
||||
for item in eval_subset:
|
||||
try:
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": self.config.system_prompt or ""},
|
||||
{"role": "user", "content": item["question"]},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
response = completion.choices[0].message.content or ""
|
||||
is_correct = _verify_math_answer(response, item["gold_answer"])
|
||||
|
||||
if is_correct:
|
||||
correct += 1
|
||||
total += 1
|
||||
|
||||
samples.append({
|
||||
"question": item["question"],
|
||||
"gold_answer": item["gold_answer"],
|
||||
"response": response[:500],
|
||||
"correct": is_correct,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Eval failed: %s", e)
|
||||
total += 1
|
||||
|
||||
percent_correct = correct / total if total > 0 else 0
|
||||
end_time = time.time()
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics={"eval/percent_correct": percent_correct, "eval/total": total},
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log training metrics."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.reward_buffer:
|
||||
wandb_metrics["train/percent_correct"] = sum(self.reward_buffer) / len(self.reward_buffer)
|
||||
wandb_metrics["train/total_rollouts"] = len(self.reward_buffer)
|
||||
self.reward_buffer = []
|
||||
|
||||
if self.tool_use_buffer:
|
||||
wandb_metrics["train/avg_tool_calls"] = sum(self.tool_use_buffer) / len(self.tool_use_buffer)
|
||||
wandb_metrics["train/tool_use_rate"] = sum(1 for t in self.tool_use_buffer if t > 0) / len(self.tool_use_buffer)
|
||||
self.tool_use_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GSM8kAgentEnv.cli()
|
||||
@@ -45,7 +45,7 @@ if _env_path.exists():
|
||||
# This patches SwerexModalEnvironment to use a background thread instead of
|
||||
# asyncio.run(), which would deadlock inside Atropos. Safe for normal CLI too.
|
||||
from environments.patches import apply_patches
|
||||
# apply_patches() # DISABLED: sglang patch breaks native vLLM /generate
|
||||
apply_patches()
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
BaseEnv,
|
||||
@@ -64,7 +64,7 @@ from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
# Import hermes-agent toolset infrastructure
|
||||
from model_tools import get_tool_definitions, handle_function_call
|
||||
from model_tools import get_tool_definitions
|
||||
from toolset_distributions import sample_toolsets_from_distribution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -140,48 +140,6 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
"Options: hermes, mistral, llama3_json, qwen, deepseek_v3, etc.",
|
||||
)
|
||||
|
||||
# --- Sandbox pool mode (optional, for scaled environments) ---
|
||||
tool_pool_mode: str = Field(
|
||||
default="default",
|
||||
description="Tool execution mode: 'default' (terminal tool per task_id), "
|
||||
"'nomad' (slot pool via Nomad/Docker/Singularity), or 'modal' (Modal sandbox pool).",
|
||||
)
|
||||
|
||||
# Sandbox pool: shared settings
|
||||
allow_network: bool = Field(default=True, description="Whether sandbox bash commands may access the network.")
|
||||
require_sandbox: bool = Field(default=False, description="Fail closed if bubblewrap is unavailable.")
|
||||
purge_job_on_start: bool = Field(default=False, description="Purge existing sandbox job on startup.")
|
||||
purge_job_on_shutdown: bool = Field(default=True, description="Purge sandbox job on shutdown.")
|
||||
acquire_timeout_s: float = Field(default=30.0, description="Slot acquisition timeout (seconds).")
|
||||
|
||||
# Sandbox pool: Nomad settings
|
||||
nomad_address: str = Field(default="http://localhost:4646", description="Nomad API address.")
|
||||
sandbox_job_id: str = Field(default="atropos-sandbox", description="Nomad job id for sandbox containers.")
|
||||
sandbox_image: str = Field(default="atropos-sandbox:local", description="Docker image for sandbox containers.")
|
||||
slots_per_container: int = Field(default=10, description="Nomad: slots per container.")
|
||||
min_containers: int = Field(default=1, description="Nomad: minimum containers.")
|
||||
max_containers: int = Field(default=10, description="Nomad: maximum containers.")
|
||||
privileged: bool = Field(default=False, description="Nomad: run container privileged.")
|
||||
driver: str = Field(default="docker", description="Nomad task driver: 'docker' or 'singularity'.")
|
||||
singularity_image: Optional[str] = Field(default=None, description="Path to .sif file for Singularity driver.")
|
||||
|
||||
# Sandbox pool: Modal settings
|
||||
modal_app_name: str = Field(default="atropos-sandbox", description="Modal app name prefix.")
|
||||
modal_image: str = Field(default="python:3.11", description="Modal: container image.")
|
||||
modal_gpu: Optional[str] = Field(default=None, description="Modal: GPU type (None, 'T4', 'A10G', 'A100', 'H100').")
|
||||
modal_cpu: float = Field(default=1.0, description="Modal: CPU cores.")
|
||||
modal_memory: int = Field(default=2048, description="Modal: memory in MB.")
|
||||
modal_slots_per_sandbox: int = Field(default=10, description="Modal: slots per sandbox.")
|
||||
modal_min_sandboxes: int = Field(default=1, description="Modal: minimum sandboxes.")
|
||||
modal_max_sandboxes: int = Field(default=5, description="Modal: maximum sandboxes.")
|
||||
modal_idle_timeout: int = Field(default=120, description="Modal: idle timeout (seconds).")
|
||||
modal_max_lifetime: int = Field(default=3600, description="Modal: max sandbox lifetime (seconds).")
|
||||
modal_acquire_timeout: float = Field(default=60.0, description="Modal: slot acquisition timeout (seconds).")
|
||||
modal_execution_timeout: float = Field(default=30.0, description="Modal: command execution timeout (seconds).")
|
||||
modal_secrets: str = Field(default="", description="Modal: comma-separated Modal Secret names.")
|
||||
modal_env_vars: str = Field(default="", description="Modal: semicolon-separated KEY=VALUE pairs.")
|
||||
modal_workspace_base: str = Field(default="/data", description="Modal: workspace base directory.")
|
||||
|
||||
|
||||
class HermesAgentBaseEnv(BaseEnv):
|
||||
"""
|
||||
@@ -228,9 +186,6 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
# Tool error tracking for wandb logging
|
||||
self._tool_error_buffer: List[Dict[str, Any]] = []
|
||||
|
||||
# Sandbox pool backend (only used when tool_pool_mode != "default")
|
||||
self._sandbox_backend = None
|
||||
|
||||
# =========================================================================
|
||||
# Toolset resolution (per-group)
|
||||
# =========================================================================
|
||||
@@ -270,12 +225,6 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
# =========================================================================
|
||||
|
||||
def _use_managed_server(self) -> bool:
|
||||
import sys
|
||||
result = self._use_managed_server_inner()
|
||||
print(f"HERMES_DEBUG _use_managed_server={result}, servers={len(self.server.servers) if hasattr(self.server, 'servers') else 'N/A'}, type={type(self.server.servers[0]).__name__ if hasattr(self.server, 'servers') and self.server.servers else 'N/A'}", file=sys.stderr, flush=True)
|
||||
return result
|
||||
|
||||
def _use_managed_server_inner(self) -> bool:
|
||||
"""
|
||||
Determine if we should use ManagedServer (Phase 2) or direct server (Phase 1).
|
||||
|
||||
@@ -293,154 +242,6 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
from atroposlib.envs.server_handling.openai_server import OpenAIServer
|
||||
return not isinstance(server, OpenAIServer)
|
||||
|
||||
# =========================================================================
|
||||
# Sandbox pool backend (tool_pool_mode != "default")
|
||||
# =========================================================================
|
||||
|
||||
async def _start_sandbox_backend(self) -> None:
|
||||
"""
|
||||
Configure the slot pool backend if tool_pool_mode is not 'default'.
|
||||
|
||||
Sets TERMINAL_ENV=slot_pool and configures env vars so that ALL hermes
|
||||
tools (terminal, file, etc.) automatically route through the sandbox
|
||||
pool via _SlotPoolEnvironment in terminal_tool.py.
|
||||
"""
|
||||
if self.config.tool_pool_mode == "default":
|
||||
return
|
||||
|
||||
mode = self.config.tool_pool_mode
|
||||
logger.info("Configuring slot pool backend (mode=%s)", mode)
|
||||
|
||||
# Set TERMINAL_ENV=slot_pool so terminal_tool.py uses _SlotPoolEnvironment
|
||||
os.environ["TERMINAL_ENV"] = "slot_pool"
|
||||
|
||||
# Set the backend type (modal or nomad)
|
||||
if mode == "modal":
|
||||
os.environ["TERMINAL_SLOT_BACKEND"] = "modal"
|
||||
# Forward modal config from env config to slot pool env vars
|
||||
os.environ.setdefault("TERMINAL_MODAL_IMAGE", self.config.modal_image)
|
||||
os.environ.setdefault("TERMINAL_MODAL_SLOTS", str(self.config.modal_slots_per_sandbox))
|
||||
os.environ.setdefault("TERMINAL_MODAL_MIN", str(self.config.modal_min_sandboxes))
|
||||
os.environ.setdefault("TERMINAL_MODAL_MAX", str(self.config.modal_max_sandboxes))
|
||||
os.environ.setdefault("TERMINAL_MODAL_IDLE_TIMEOUT", str(self.config.modal_idle_timeout))
|
||||
os.environ.setdefault("TERMINAL_MODAL_MAX_LIFETIME", str(self.config.modal_max_lifetime))
|
||||
os.environ.setdefault("TERMINAL_MODAL_ACQUIRE_TIMEOUT", str(self.config.modal_acquire_timeout))
|
||||
os.environ.setdefault("TERMINAL_MODAL_EXEC_TIMEOUT", str(self.config.modal_execution_timeout))
|
||||
os.environ.setdefault("TERMINAL_MODAL_WORKSPACE", self.config.modal_workspace_base)
|
||||
if self.config.modal_gpu:
|
||||
os.environ.setdefault("TERMINAL_MODAL_GPU", self.config.modal_gpu)
|
||||
elif mode == "nomad":
|
||||
os.environ["TERMINAL_SLOT_BACKEND"] = "nomad"
|
||||
os.environ.setdefault("TERMINAL_NOMAD_ADDRESS", self.config.nomad_address)
|
||||
os.environ.setdefault("TERMINAL_NOMAD_IMAGE", self.config.sandbox_image)
|
||||
os.environ.setdefault("TERMINAL_NOMAD_DRIVER", self.config.driver)
|
||||
os.environ.setdefault("TERMINAL_NOMAD_SLOTS", str(self.config.slots_per_container))
|
||||
os.environ.setdefault("TERMINAL_NOMAD_MIN", str(self.config.min_containers))
|
||||
os.environ.setdefault("TERMINAL_NOMAD_MAX", str(self.config.max_containers))
|
||||
|
||||
# Eagerly start the _SlotPoolManager so the backend is ready
|
||||
# before any trajectories try to use it
|
||||
from tools.terminal_tool import _SlotPoolManager
|
||||
_SlotPoolManager.get_instance() # Triggers _start() which creates sandboxes
|
||||
|
||||
self._sandbox_backend = True # Flag that sandbox mode is active
|
||||
print(f"🔧 Slot pool started: TERMINAL_ENV=slot_pool, backend={mode}")
|
||||
|
||||
async def _stop_sandbox_backend(self) -> None:
|
||||
"""Stop the slot pool backend."""
|
||||
if self._sandbox_backend:
|
||||
logger.info("Stopping slot pool backend")
|
||||
try:
|
||||
from tools.terminal_tool import _SlotPoolManager
|
||||
_SlotPoolManager.reset_instance()
|
||||
except Exception as e:
|
||||
logger.warning("Slot pool shutdown: %s", e)
|
||||
self._sandbox_backend = None
|
||||
|
||||
# =========================================================================
|
||||
# Optional hooks for sandbox environments
|
||||
# =========================================================================
|
||||
|
||||
async def setup_trajectory_workspace(
|
||||
self,
|
||||
item: Item,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Optional hook: prepare the sandbox workspace before the agent starts.
|
||||
|
||||
Override in subclasses for environments that need workspace setup
|
||||
(e.g., git clone, worktree creation, dependency installation).
|
||||
|
||||
Args:
|
||||
item: The dataset item being rolled out
|
||||
trajectory_id: Unique ID for this trajectory
|
||||
exec_tool: Callable to execute tool calls in the sandbox
|
||||
|
||||
Returns:
|
||||
Dict of workspace metadata (passed to verify_and_score_trajectory)
|
||||
"""
|
||||
return {}
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
result: AgentResult,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool,
|
||||
workspace_meta: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[float, Dict[str, Any]]:
|
||||
"""
|
||||
Optional hook: run in-sandbox verification before scoring.
|
||||
|
||||
Override in subclasses for environments that need to verify results
|
||||
inside the sandbox (e.g., run pytest, check file contents).
|
||||
|
||||
Default: calls compute_reward() with ToolContext.
|
||||
|
||||
Args:
|
||||
item: The dataset item
|
||||
result: The agent's rollout result
|
||||
trajectory_id: Unique ID for this trajectory
|
||||
exec_tool: Callable to execute tool calls in the sandbox
|
||||
workspace_meta: Metadata from setup_trajectory_workspace
|
||||
|
||||
Returns:
|
||||
Tuple of (reward, metadata_dict)
|
||||
"""
|
||||
ctx = ToolContext(trajectory_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
except Exception as e:
|
||||
logger.error("compute_reward failed: %s", e)
|
||||
reward = 0.0
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
return reward, {}
|
||||
|
||||
# =========================================================================
|
||||
# Lifecycle hooks for env_manager/process_manager cleanup
|
||||
# =========================================================================
|
||||
|
||||
async def env_manager(self):
|
||||
"""Start sandbox backend, run env, then clean up."""
|
||||
await self._start_sandbox_backend()
|
||||
try:
|
||||
return await super().env_manager()
|
||||
finally:
|
||||
await self._stop_sandbox_backend()
|
||||
|
||||
async def process_manager(self):
|
||||
"""Start sandbox backend, run process, then clean up."""
|
||||
await self._start_sandbox_backend()
|
||||
try:
|
||||
return await super().process_manager()
|
||||
finally:
|
||||
await self._stop_sandbox_backend()
|
||||
|
||||
# =========================================================================
|
||||
# Core Atropos integration
|
||||
# =========================================================================
|
||||
@@ -584,13 +385,6 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
def _use_sandbox_backend(self) -> bool:
|
||||
"""Check if we should route tool execution through a sandbox backend."""
|
||||
return (
|
||||
self.config.tool_pool_mode != "default"
|
||||
and self._sandbox_backend is not None
|
||||
)
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
@@ -599,19 +393,12 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
|
||||
This is called group_size times in parallel by collect_trajectories().
|
||||
Each call gets its own task_id for terminal/browser session isolation.
|
||||
|
||||
When tool_pool_mode != "default", routes tool execution through the
|
||||
sandbox backend (Modal, Nomad) with slot-based multiplexing:
|
||||
1. Acquire a slot from the sandbox pool
|
||||
2. Setup workspace via subclass hook (e.g., git clone + worktree)
|
||||
3. Run agent loop with terminal calls routed through sandbox
|
||||
4. Verify and score in-sandbox via subclass hook (e.g., pytest)
|
||||
5. Release the slot
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# Get group-level tools (resolved once in collect_trajectories)
|
||||
if self._current_group_tools is None:
|
||||
# Fallback: resolve per-trajectory if called outside collect_trajectories
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
else:
|
||||
tools, valid_names = self._current_group_tools
|
||||
@@ -622,194 +409,11 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
||||
|
||||
# Dispatch to the appropriate path
|
||||
if self._use_sandbox_backend():
|
||||
return await self._collect_trajectory_sandbox(
|
||||
item, task_id, tools, valid_names, messages
|
||||
)
|
||||
else:
|
||||
return await self._collect_trajectory_local(
|
||||
item, task_id, tools, valid_names, messages
|
||||
)
|
||||
|
||||
async def _collect_trajectory_local(
|
||||
self,
|
||||
item: Item,
|
||||
task_id: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
valid_names: Set[str],
|
||||
messages: List[Dict[str, Any]],
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
"""
|
||||
Default (local) trajectory collection path.
|
||||
|
||||
Uses hermes-agent's handle_function_call() for tool execution.
|
||||
Reward computed via compute_reward() with ToolContext.
|
||||
"""
|
||||
result = await self._run_agent_loop(
|
||||
task_id, tools, valid_names, messages, tool_handler=None
|
||||
)
|
||||
|
||||
# Skip reward if the agent loop produced no meaningful work
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
|
||||
result.turns_used, len(result.messages),
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
except Exception as e:
|
||||
logger.error("compute_reward failed: %s", e)
|
||||
reward = 0.0
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
return self._build_scored_item(item, result, reward)
|
||||
|
||||
async def _collect_trajectory_sandbox(
|
||||
self,
|
||||
item: Item,
|
||||
task_id: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
valid_names: Set[str],
|
||||
messages: List[Dict[str, Any]],
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
"""
|
||||
Sandbox trajectory collection path (Modal, Nomad).
|
||||
|
||||
Uses TERMINAL_ENV=slot_pool so ALL hermes tools (terminal, file, web)
|
||||
automatically route through the sandbox pool via _SlotPoolEnvironment.
|
||||
No per-tool routing needed — the slot pool is the terminal backend.
|
||||
|
||||
Flow:
|
||||
1. Pre-warm terminal env (acquires a slot in the pool)
|
||||
2. Setup workspace via subclass hook (e.g., git clone + worktree)
|
||||
3. Run agent loop with tool_handler=None (all tools use handle_function_call)
|
||||
4. Verify and score in-sandbox via subclass hook (e.g., pytest)
|
||||
5. Release the slot via cleanup_vm()
|
||||
"""
|
||||
from tools.terminal_tool import _SlotPoolManager, cleanup_vm
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class _ExecResult:
|
||||
"""Lightweight result for exec_tool compatibility with env hooks."""
|
||||
success: bool
|
||||
output: str = ""
|
||||
error: str = ""
|
||||
metadata: Dict[str, Any] = None
|
||||
def __post_init__(self):
|
||||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
|
||||
try:
|
||||
# 1. Pre-warm: trigger terminal env creation → acquires slot
|
||||
logger.info("Pre-warming sandbox slot for task %s", task_id)
|
||||
loop = asyncio.get_event_loop()
|
||||
warmup = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: handle_function_call(
|
||||
"terminal", {"command": "echo slot_ready"}, task_id=task_id
|
||||
),
|
||||
)
|
||||
logger.info("Sandbox slot acquired for task %s", task_id)
|
||||
|
||||
# 2. Create exec_tool for setup/verify hooks
|
||||
# Routes through handle_function_call → terminal_tool → same _SlotPoolEnvironment
|
||||
async def exec_tool(tool_name: str, args: Dict[str, Any], timeout: float = 300) -> _ExecResult:
|
||||
command = args.get("command", "")
|
||||
result_json = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: handle_function_call(
|
||||
"terminal",
|
||||
{"command": command, "timeout": int(timeout)},
|
||||
task_id=task_id,
|
||||
),
|
||||
)
|
||||
try:
|
||||
result_dict = json.loads(result_json)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
result_dict = {"output": str(result_json), "exit_code": 1}
|
||||
returncode = result_dict.get("exit_code", result_dict.get("returncode", 1))
|
||||
output = result_dict.get("output", "")
|
||||
return _ExecResult(
|
||||
success=(returncode == 0),
|
||||
output=output,
|
||||
error=result_dict.get("error", "") if returncode != 0 else "",
|
||||
metadata={"returncode": returncode},
|
||||
)
|
||||
|
||||
# 3. Setup workspace (subclass hook: git clone, worktree, etc.)
|
||||
workspace_meta = await self.setup_trajectory_workspace(
|
||||
item, trajectory_id=task_id, exec_tool=exec_tool
|
||||
)
|
||||
|
||||
# 4. Run agent loop — tool_handler=None means ALL tools go through
|
||||
# handle_function_call() → terminal_tool() → _SlotPoolEnvironment
|
||||
# → same sandbox slot. File tools also route through same env.
|
||||
result = await self._run_agent_loop(
|
||||
task_id, tools, valid_names, messages,
|
||||
tool_handler=None,
|
||||
)
|
||||
|
||||
# 5. Skip verification if no meaningful work
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
|
||||
result.turns_used, len(result.messages),
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
# 6. Verify and score in-sandbox (subclass hook: pytest, etc.)
|
||||
reward, score_meta = await self.verify_and_score_trajectory(
|
||||
item, result,
|
||||
trajectory_id=task_id,
|
||||
exec_tool=exec_tool,
|
||||
workspace_meta=workspace_meta,
|
||||
)
|
||||
logger.info("Sandbox reward for task %s: %.2f", task_id, reward)
|
||||
|
||||
return self._build_scored_item(item, result, reward)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Sandbox trajectory failed for task %s: %s", task_id, e, exc_info=True)
|
||||
dummy_result = AgentResult(
|
||||
messages=messages, turns_used=0, finished_naturally=False
|
||||
)
|
||||
return self._build_scored_item(item, dummy_result, 0.0)
|
||||
|
||||
finally:
|
||||
# Release the slot back to the pool
|
||||
try:
|
||||
cleanup_vm(task_id)
|
||||
logger.info("Released sandbox slot for task %s", task_id)
|
||||
except Exception as e:
|
||||
logger.error("Failed to release slot for task %s: %s", task_id, e)
|
||||
|
||||
async def _run_agent_loop(
|
||||
self,
|
||||
task_id: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
valid_names: Set[str],
|
||||
messages: List[Dict[str, Any]],
|
||||
tool_handler=None,
|
||||
) -> AgentResult:
|
||||
"""
|
||||
Run the agent loop in either Phase 1 or Phase 2 mode.
|
||||
|
||||
Shared between local and sandbox paths -- the only difference is
|
||||
the tool_handler parameter (None for local, sandbox callable for sandbox).
|
||||
"""
|
||||
# Run the agent loop
|
||||
result: AgentResult
|
||||
if self._use_managed_server():
|
||||
# Phase 2: ManagedServer with parser -- exact tokens + logprobs
|
||||
# Load the tool call parser from registry based on config
|
||||
from environments.tool_call_parsers import get_parser
|
||||
try:
|
||||
tc_parser = get_parser(self.config.tool_call_parser)
|
||||
@@ -825,13 +429,6 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
tokenizer=self.tokenizer,
|
||||
tool_call_parser=tc_parser,
|
||||
) as managed:
|
||||
# Calculate max prompt tokens
|
||||
# Context budget = max_token_length (prompt can be as long as generation budget)
|
||||
# This ensures prompt + generation stays under typical model context limits
|
||||
# E.g., max_token_length=16384 → 16384 prompt + 16384 gen = 32K < 40960 model limit
|
||||
_max_ctx = None
|
||||
if self.config.max_token_length and self.config.max_token_length > 0:
|
||||
_max_ctx = self.config.max_token_length
|
||||
agent = HermesAgentLoop(
|
||||
server=managed,
|
||||
tool_schemas=tools,
|
||||
@@ -840,18 +437,14 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
tool_handler=tool_handler,
|
||||
max_context_tokens=_max_ctx,
|
||||
)
|
||||
return await agent.run(messages)
|
||||
result = await agent.run(messages)
|
||||
except NotImplementedError:
|
||||
# DummyManagedServer not allowed -- fall back to Phase 1
|
||||
logger.warning(
|
||||
"ManagedServer not available (OpenAI server?). "
|
||||
"Falling back to direct server mode."
|
||||
)
|
||||
_max_ctx = None
|
||||
if self.config.max_token_length and self.config.max_token_length > 0:
|
||||
_max_ctx = self.config.max_token_length
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
@@ -860,14 +453,10 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
tool_handler=tool_handler,
|
||||
max_context_tokens=_max_ctx,
|
||||
)
|
||||
return await agent.run(messages)
|
||||
result = await agent.run(messages)
|
||||
else:
|
||||
_max_ctx = None
|
||||
if self.config.max_token_length and self.config.max_token_length > 0:
|
||||
_max_ctx = self.config.max_token_length
|
||||
# Phase 1: OpenAI server -- native tool_calls, placeholder tokens
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
@@ -876,22 +465,32 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
tool_handler=tool_handler,
|
||||
max_context_tokens=_max_ctx,
|
||||
)
|
||||
return await agent.run(messages)
|
||||
result = await agent.run(messages)
|
||||
|
||||
def _build_scored_item(
|
||||
self,
|
||||
item: Item,
|
||||
result: AgentResult,
|
||||
reward: float,
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
"""
|
||||
Build a ScoredDataItem from an AgentResult and reward.
|
||||
# Skip reward computation if the agent loop produced no meaningful work
|
||||
# (e.g., API call failed on turn 1). No point spinning up a Modal sandbox
|
||||
# just to verify files that were never created.
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
|
||||
result.turns_used, len(result.messages),
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
# Compute reward using ToolContext (gives verifier full tool access)
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
except Exception as e:
|
||||
logger.error("compute_reward failed: %s", e)
|
||||
reward = 0.0
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
Shared between local and sandbox paths.
|
||||
"""
|
||||
# Track tool errors for wandb logging
|
||||
if result.tool_errors:
|
||||
for err in result.tool_errors:
|
||||
@@ -904,19 +503,28 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
})
|
||||
|
||||
# Build ScoredDataItem from ManagedServer state
|
||||
# Phase 2: real tokens/masks/logprobs from SequenceNodes
|
||||
# Phase 1: placeholder tokens (still need a valid ScoredDataItem for the pipeline)
|
||||
nodes = (result.managed_state or {}).get("nodes", [])
|
||||
|
||||
if nodes:
|
||||
node = nodes[-1]
|
||||
# Phase 2 (or DummyManagedServer): use actual node data
|
||||
node = nodes[-1] # Final sequence node = full trajectory
|
||||
scored_item: Dict[str, Any] = {
|
||||
"tokens": node.tokens,
|
||||
"masks": node.masked_tokens,
|
||||
"scores": reward,
|
||||
}
|
||||
|
||||
# Include logprobs if available (Phase 2)
|
||||
if hasattr(node, "logprobs") and node.logprobs:
|
||||
scored_item["advantages"] = None
|
||||
scored_item["advantages"] = None # Computed by trainer
|
||||
scored_item["ref_logprobs"] = None
|
||||
else:
|
||||
# Phase 1 with no managed state: create placeholder tokens
|
||||
# so the data pipeline doesn't break. These are NOT suitable
|
||||
# for training but allow process mode (SFT data gen) to work.
|
||||
# Tokenize the full conversation to get approximate tokens.
|
||||
full_text = "\n".join(
|
||||
msg.get("content", "") for msg in result.messages if msg.get("content")
|
||||
)
|
||||
@@ -927,11 +535,13 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
|
||||
scored_item = {
|
||||
"tokens": tokens,
|
||||
"masks": [-100] + tokens[1:],
|
||||
"masks": [-100] + tokens[1:], # Mask first token as prompt
|
||||
"scores": reward,
|
||||
}
|
||||
|
||||
# Always include messages for wandb rollout display and data logging
|
||||
scored_item["messages"] = result.messages
|
||||
|
||||
return scored_item, []
|
||||
|
||||
# =========================================================================
|
||||
|
||||
@@ -171,126 +171,6 @@ def _patch_swerex_modal():
|
||||
logger.debug("Patched SwerexModalEnvironment for async-safe operation")
|
||||
|
||||
|
||||
def _patch_vllm_server_for_sglang():
|
||||
"""
|
||||
(Mainly for Runpod serverless compat)
|
||||
|
||||
Monkey patch VLLMServer._tokens_and_logprobs_completion_wrapper to handle
|
||||
SGLang's /generate response format.
|
||||
|
||||
VLLMServer expects:
|
||||
Request: {"prompt": {"prompt_token_ids": [...]}, "logprobs": 0}
|
||||
Response: {"logprobs": [[{token_id: logprob}]], "finish_reasons": [...]}
|
||||
|
||||
SGLang returns:
|
||||
Request: {"input_ids": [...], "sampling_params": {...}, "return_logprob": true}
|
||||
Response: {"text": "...", "meta_info": {"output_token_logprobs": [[logprob, token_id, text], ...]}}
|
||||
|
||||
This patch makes VLLMServer work with SGLang endpoints (e.g., RunPod SGLang workers).
|
||||
"""
|
||||
try:
|
||||
import aiohttp
|
||||
from atroposlib.envs.server_handling.vllm_server import VLLMServer
|
||||
except ImportError:
|
||||
logger.debug("atroposlib VLLMServer not available, skipping SGLang patch")
|
||||
return
|
||||
|
||||
# Save the original method
|
||||
_original_wrapper = VLLMServer._tokens_and_logprobs_completion_wrapper
|
||||
|
||||
async def _sglang_compatible_wrapper(self, **kwargs):
|
||||
"""
|
||||
Patched wrapper that tries the original VLLMServer format first,
|
||||
then falls back to SGLang format if that fails.
|
||||
"""
|
||||
assert kwargs.get("model") is not None, "Model is required!"
|
||||
assert kwargs.get("prompt") is not None or kwargs.get("input_ids") is not None, "Prompt or input_ids required!"
|
||||
|
||||
# Get prompt tokens
|
||||
if "input_ids" in kwargs:
|
||||
prompt_tokens = kwargs.pop("input_ids")
|
||||
kwargs.pop("prompt", None)
|
||||
else:
|
||||
prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
|
||||
|
||||
# Check for double BOS
|
||||
if (len(prompt_tokens) >= 2
|
||||
and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]):
|
||||
prompt_tokens = prompt_tokens[1:]
|
||||
|
||||
# Normalize kwargs
|
||||
max_tokens = kwargs.pop("max_new_tokens", kwargs.pop("max_completion_tokens", kwargs.pop("max_tokens", 2048)))
|
||||
n = kwargs.pop("n", 1)
|
||||
temperature = kwargs.pop("temperature", 1.0)
|
||||
kwargs.pop("model", None)
|
||||
|
||||
# Build SGLang-compatible request
|
||||
request_data = {
|
||||
"input_ids": prompt_tokens,
|
||||
"sampling_params": {
|
||||
"max_new_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"n": n,
|
||||
},
|
||||
"return_logprob": True,
|
||||
"top_logprobs_num": 0,
|
||||
}
|
||||
|
||||
generate_url = f"{self.config.base_url.replace('/v1', '')}/generate"
|
||||
|
||||
headers = {}
|
||||
if self.config.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.config.api_key}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
generate_url,
|
||||
json=request_data,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
raw_text = await response.text()
|
||||
|
||||
# RunPod wraps JSON responses in quotes — may need double-parse
|
||||
import json
|
||||
results = json.loads(raw_text)
|
||||
if isinstance(results, str):
|
||||
results = json.loads(results)
|
||||
|
||||
# Parse SGLang response format
|
||||
meta = results.get("meta_info", {})
|
||||
output_token_logprobs_raw = meta.get("output_token_logprobs", [])
|
||||
|
||||
# SGLang format: [[logprob, token_id, token_text], ...]
|
||||
output_tokens = []
|
||||
output_logprobs = []
|
||||
for entry in output_token_logprobs_raw:
|
||||
if isinstance(entry, (list, tuple)) and len(entry) >= 2:
|
||||
logprob, token_id = entry[0], entry[1]
|
||||
output_tokens.append(int(token_id))
|
||||
output_logprobs.append(float(logprob))
|
||||
|
||||
# Get finish reason
|
||||
finish_reason_raw = meta.get("finish_reason", "stop")
|
||||
if isinstance(finish_reason_raw, dict):
|
||||
finish_reason = finish_reason_raw.get("type", "stop")
|
||||
else:
|
||||
finish_reason = str(finish_reason_raw)
|
||||
|
||||
return (
|
||||
prompt_tokens,
|
||||
[output_tokens],
|
||||
[output_logprobs],
|
||||
[finish_reason],
|
||||
)
|
||||
|
||||
# Apply the patch
|
||||
VLLMServer._tokens_and_logprobs_completion_wrapper = _sglang_compatible_wrapper
|
||||
logger.info("Patched VLLMServer for SGLang /generate compatibility")
|
||||
|
||||
|
||||
def apply_patches():
|
||||
"""
|
||||
Apply all monkey patches needed for Atropos compatibility.
|
||||
@@ -304,6 +184,5 @@ def apply_patches():
|
||||
return
|
||||
|
||||
_patch_swerex_modal()
|
||||
# _patch_vllm_server_for_sglang()
|
||||
|
||||
_patches_applied = True
|
||||
|
||||
@@ -1,620 +0,0 @@
|
||||
"""
|
||||
SWE-smith-oracle environment (ported to HermesAgentBaseEnv).
|
||||
|
||||
Trains models to fix real GitHub repositories:
|
||||
- Clones a public GitHub repo at a specific commit
|
||||
- Runs an agent loop with terminal tool to apply a fix
|
||||
- Verifies by running pytest with nodeids from the dataset
|
||||
- Reward: 1.0 if all tests pass, 0.0 otherwise
|
||||
|
||||
Dataset: NousResearch/SWE-smith-oracle (train split; does NOT use SWE-bench eval set).
|
||||
|
||||
Usage:
|
||||
# Process mode (OpenAI server, no training):
|
||||
python environments/swe_smith_oracle_env.py process \\
|
||||
--env.data_path_to_save_groups data/swe_oracle_output.jsonl
|
||||
|
||||
# With Modal sandbox backend:
|
||||
python environments/swe_smith_oracle_env.py process \\
|
||||
--env.tool_pool_mode modal \\
|
||||
--env.modal_image python:3.11
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Config
|
||||
# =============================================================================
|
||||
|
||||
class SweSmithOracleEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config for SWE-smith-oracle environment."""
|
||||
|
||||
dataset_name: str = Field(default="NousResearch/SWE-smith-oracle")
|
||||
dataset_split: str = Field(default="train")
|
||||
max_items: int = Field(default=0, description="0 = no limit")
|
||||
shuffle: bool = Field(default=True)
|
||||
seed: int = Field(default=0)
|
||||
|
||||
python_only: bool = Field(default=True, description="Filter to Python-evaluable rows")
|
||||
score_include_fail_to_pass: bool = Field(
|
||||
default=True,
|
||||
description="Score tests on PASS_TO_PASS ∪ FAIL_TO_PASS. "
|
||||
"Disable to only run PASS_TO_PASS (faster but weaker signal).",
|
||||
)
|
||||
|
||||
prompt_mode: str = Field(
|
||||
default="problem_statement",
|
||||
description="'problem_statement' (fast) or 'problem_statement+text' (includes dataset 'text').",
|
||||
)
|
||||
|
||||
repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
|
||||
install_timeout_s: float = Field(default=600.0)
|
||||
test_timeout_s: float = Field(default=600.0)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Environment
|
||||
# =============================================================================
|
||||
|
||||
class SweSmithOracleEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
SWE-smith-oracle environment for training models to fix real GitHub repos.
|
||||
|
||||
Uses proper OpenAI-spec tool calling via HermesAgentBaseEnv.
|
||||
The model gets terminal access to inspect, edit, and test the repository.
|
||||
"""
|
||||
|
||||
name = "swe-smith-oracle"
|
||||
env_config_cls = SweSmithOracleEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SweSmithOracleEnvConfig,
|
||||
server_configs,
|
||||
slurm=False,
|
||||
testing=False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._dataset = None
|
||||
self._indices: List[int] = []
|
||||
self._cursor = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]:
|
||||
"""Default config — reads from ATROPOS_SERVER_* env vars."""
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
if not base_url.rstrip("/").endswith("/v1"):
|
||||
base_url = base_url.rstrip("/") + "/v1"
|
||||
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "Hermes-4.3-36B"
|
||||
api_key = (
|
||||
os.getenv("ATROPOS_SERVER_API_KEY")
|
||||
or os.getenv("NOUS_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or "local"
|
||||
)
|
||||
|
||||
env_config = SweSmithOracleEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1,
|
||||
batch_size=1,
|
||||
steps_per_eval=1,
|
||||
max_token_length=8192,
|
||||
wandb_name="swe_smith_oracle",
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
terminal_backend=os.getenv("TERMINAL_ENV", "local"),
|
||||
# Longer agent turns for SWE tasks
|
||||
max_agent_turns=50,
|
||||
agent_temperature=0.7,
|
||||
system_prompt=(
|
||||
"You are a senior software engineer. You have access to a terminal "
|
||||
"to inspect and fix repositories. Use non-interactive commands only. "
|
||||
"Each terminal command runs in a fresh shell."
|
||||
),
|
||||
tool_call_parser="hermes",
|
||||
# Sandbox settings (used when tool_pool_mode != "default")
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
server_type="vllm",
|
||||
health_check=False,
|
||||
timeout=int(os.getenv("ATROPOS_SERVER_TIMEOUT_S") or "300"),
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
# =========================================================================
|
||||
# Dataset loading
|
||||
# =========================================================================
|
||||
|
||||
async def setup(self):
|
||||
"""Load SWE-smith-oracle dataset."""
|
||||
from datasets import load_dataset
|
||||
|
||||
t0 = time.perf_counter()
|
||||
print(
|
||||
f"[SweSmithOracleEnv] loading dataset {self.config.dataset_name}:{self.config.dataset_split} "
|
||||
f"(python_only={self.config.python_only}, max_items={self.config.max_items or 'all'})",
|
||||
flush=True,
|
||||
)
|
||||
ds = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
|
||||
self._dataset = ds
|
||||
|
||||
indices: List[int] = []
|
||||
for idx in range(len(ds)):
|
||||
row = ds[idx]
|
||||
if self.config.python_only and not self._is_python_row(row):
|
||||
continue
|
||||
indices.append(idx)
|
||||
|
||||
if self.config.shuffle:
|
||||
rnd = random.Random(self.config.seed)
|
||||
rnd.shuffle(indices)
|
||||
|
||||
if self.config.max_items and self.config.max_items > 0:
|
||||
indices = indices[: self.config.max_items]
|
||||
|
||||
self._indices = indices
|
||||
self._cursor = 0
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] loaded {len(self._indices)} items "
|
||||
f"in {time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _is_python_row(self, row: Dict[str, Any]) -> bool:
|
||||
nodeids = row.get("PASS_TO_PASS")
|
||||
if not isinstance(nodeids, list) or not nodeids:
|
||||
return False
|
||||
return all(isinstance(nid, str) and ".py::" in nid for nid in nodeids)
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
if not self._dataset or not self._indices:
|
||||
raise RuntimeError("Dataset not initialized")
|
||||
if self._cursor >= len(self._indices):
|
||||
self._cursor = 0
|
||||
idx = self._indices[self._cursor]
|
||||
self._cursor += 1
|
||||
return dict(self._dataset[idx])
|
||||
|
||||
# =========================================================================
|
||||
# Prompt formatting
|
||||
# =========================================================================
|
||||
|
||||
def _repo_name(self, item: Item) -> str:
|
||||
repo = item.get("repo") or ""
|
||||
if isinstance(repo, str) and "/" in repo:
|
||||
return repo.split("/")[-1]
|
||||
return "repo"
|
||||
|
||||
def format_prompt(self, item: Item) -> str:
|
||||
"""Build the SWE task prompt."""
|
||||
repo = item.get("repo") or ""
|
||||
base_commit = item.get("base_commit") or ""
|
||||
problem = str(item.get("problem_statement") or "")
|
||||
context = str(item.get("text") or "")
|
||||
repo_dir = self._repo_name(item)
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
tests_list = "\n".join(f"- {t}" for t in nodeids)
|
||||
|
||||
context_block = ""
|
||||
prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower()
|
||||
if prompt_mode == "problem_statement+text" and context:
|
||||
context_block = f"\nAdditional context:\n{context}\n"
|
||||
|
||||
return (
|
||||
f"Fix the repository so the specified tests pass.\n\n"
|
||||
f"Repository: {repo} (checked out at base_commit={base_commit})\n"
|
||||
f"Workspace path: ./{repo_dir}\n\n"
|
||||
"Constraints:\n"
|
||||
"- Use the terminal tool to inspect, edit, and verify the repository.\n"
|
||||
f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n"
|
||||
"- Use a workspace-local virtualenv (.venv) to avoid cross-run contamination.\n"
|
||||
"- Use non-interactive commands only.\n"
|
||||
"- Prefer `. .venv/bin/activate` or `.venv/bin/python ...` (POSIX compatible).\n\n"
|
||||
f"Problem statement:\n{problem}\n\n"
|
||||
f"{context_block}"
|
||||
f"Run these tests to verify:\n{tests_list}\n\n"
|
||||
"When done, briefly describe what you changed and confirm tests pass."
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Test helpers
|
||||
# =========================================================================
|
||||
|
||||
def _tests_for_item(self, item: Item) -> List[str]:
|
||||
tests: List[str] = []
|
||||
if self.config.score_include_fail_to_pass:
|
||||
for key in ("PASS_TO_PASS", "FAIL_TO_PASS"):
|
||||
nodeids = item.get(key)
|
||||
if isinstance(nodeids, list):
|
||||
tests.extend([n for n in nodeids if isinstance(n, str)])
|
||||
else:
|
||||
nodeids = item.get("PASS_TO_PASS")
|
||||
if isinstance(nodeids, list):
|
||||
tests.extend([n for n in nodeids if isinstance(n, str)])
|
||||
return sorted(dict.fromkeys(tests))
|
||||
|
||||
def _chunk_nodeids(self, nodeids: List[str], max_per_chunk: int = 50) -> List[List[str]]:
|
||||
return [nodeids[i : i + max_per_chunk] for i in range(0, len(nodeids), max_per_chunk)]
|
||||
|
||||
# =========================================================================
|
||||
# Sandbox hooks: setup_trajectory_workspace + verify_and_score_trajectory
|
||||
# =========================================================================
|
||||
|
||||
async def setup_trajectory_workspace(
|
||||
self, item: Item, *, trajectory_id: str, exec_tool
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Prepare a sandbox workspace: bare repo cache + git worktree.
|
||||
|
||||
Uses flock-serialized bare repo cache under /data/repo_cache so
|
||||
multiple trajectories sharing a sandbox don't clone the same repo
|
||||
in parallel. Each trajectory gets an isolated worktree at the
|
||||
specified base_commit.
|
||||
|
||||
Args:
|
||||
item: Dataset row with repo, base_commit, etc.
|
||||
trajectory_id: Unique trajectory ID
|
||||
exec_tool: async callable(tool_name, args, timeout) -> ExecutionResult
|
||||
|
||||
Returns:
|
||||
Dict with repo_dir, base_commit metadata
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
t0 = _time.perf_counter()
|
||||
repo = item.get("repo")
|
||||
base_commit = item.get("base_commit")
|
||||
instance_id = item.get("instance_id") or item.get("id") or item.get("problem_id")
|
||||
if not isinstance(repo, str) or not isinstance(base_commit, str):
|
||||
raise RuntimeError("Invalid dataset row: missing repo/base_commit")
|
||||
|
||||
repo_dir = self._repo_name(item)
|
||||
clone_url = f"{self.config.repo_base_url.rstrip('/')}/{repo}.git"
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
|
||||
f"repo={repo} base_commit={base_commit} instance_id={instance_id} dir=./{repo_dir}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Bare repo cache + worktree strategy (same as atropos/envs/swe_smith_oracle_env.py)
|
||||
repo_slug = repo.replace("/", "__")
|
||||
cache_root = "/data/repo_cache"
|
||||
bare_repo = f"{cache_root}/{repo_slug}.git"
|
||||
lock_file = f"{cache_root}/.locks/{repo_slug}.lock"
|
||||
|
||||
worktree_cmd = (
|
||||
"set -e; "
|
||||
f"rm -rf {repo_dir}; "
|
||||
f"mkdir -p {cache_root}/.locks; "
|
||||
f": > {lock_file}; "
|
||||
f"flock -x {lock_file} sh -lc '"
|
||||
f"set -e; "
|
||||
"export GIT_TERMINAL_PROMPT=0; "
|
||||
"export GIT_LFS_SKIP_SMUDGE=1; "
|
||||
f"if [ ! -d \"{bare_repo}\" ]; then "
|
||||
f" git init --bare \"{bare_repo}\"; "
|
||||
f" git -C \"{bare_repo}\" remote add origin \"{clone_url}\"; "
|
||||
"fi; "
|
||||
f"git -C \"{bare_repo}\" remote set-url origin \"{clone_url}\"; "
|
||||
f"git -C \"{bare_repo}\" worktree prune || true; "
|
||||
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
|
||||
f" git -C \"{bare_repo}\" fetch --depth 1 origin \"{base_commit}\" || true; "
|
||||
"fi; "
|
||||
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
|
||||
f" git -C \"{bare_repo}\" fetch --prune origin; "
|
||||
"fi; "
|
||||
f"git --git-dir=\"{bare_repo}\" worktree add --detach \"{repo_dir}\" \"{base_commit}\"; "
|
||||
"'"
|
||||
)
|
||||
|
||||
print(f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache", flush=True)
|
||||
res = await exec_tool(
|
||||
"bash",
|
||||
{"command": worktree_cmd},
|
||||
timeout=self.config.install_timeout_s,
|
||||
)
|
||||
if not res.success:
|
||||
raise RuntimeError(
|
||||
f"git worktree setup failed "
|
||||
f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): "
|
||||
f"{res.error}\n{res.output}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
|
||||
f"worktree ready in {_time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
return {"repo_dir": repo_dir, "base_commit": base_commit}
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
result: AgentResult,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool,
|
||||
workspace_meta: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[float, Dict[str, Any]]:
|
||||
"""
|
||||
In-sandbox verification: install deps + run pytest with dataset nodeids.
|
||||
|
||||
Args:
|
||||
item: Dataset row
|
||||
result: Agent's rollout result
|
||||
trajectory_id: Unique trajectory ID
|
||||
exec_tool: async callable(tool_name, args, timeout) -> ExecutionResult
|
||||
workspace_meta: From setup_trajectory_workspace (has repo_dir)
|
||||
|
||||
Returns:
|
||||
(reward, metadata) tuple
|
||||
"""
|
||||
repo_dir = (workspace_meta or {}).get("repo_dir") or self._repo_name(item)
|
||||
|
||||
# Don't reward trajectories that never used tools
|
||||
tool_call_count = sum(
|
||||
len(msg.get("tool_calls", []))
|
||||
for msg in result.messages
|
||||
if msg.get("role") == "assistant"
|
||||
)
|
||||
if tool_call_count == 0:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} verify: no tool calls; score=0.0",
|
||||
flush=True,
|
||||
)
|
||||
return 0.0, {"error": "No tool calls were made by the agent"}
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
if not nodeids:
|
||||
return 0.0, {"error": "No tests provided"}
|
||||
|
||||
# Install dependencies
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} verify: installing deps + running tests",
|
||||
flush=True,
|
||||
)
|
||||
setup_cmd = (
|
||||
f"cd {repo_dir} && "
|
||||
"python -m venv .venv && "
|
||||
". .venv/bin/activate && "
|
||||
"python -m pip install -U pip setuptools wheel && "
|
||||
"python -m pip install -e . && "
|
||||
"python -m pip install pytest"
|
||||
)
|
||||
setup_res = await exec_tool(
|
||||
"bash", {"command": setup_cmd}, timeout=self.config.install_timeout_s
|
||||
)
|
||||
if not setup_res.success:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} install failed; score=0.0",
|
||||
flush=True,
|
||||
)
|
||||
return 0.0, {
|
||||
"phase": "install",
|
||||
"error": setup_res.error,
|
||||
"output": setup_res.output,
|
||||
}
|
||||
|
||||
# Run test chunks
|
||||
chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
|
||||
for chunk_idx, chunk in enumerate(chunks):
|
||||
joined = " ".join(chunk)
|
||||
cmd = f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}"
|
||||
res = await exec_tool(
|
||||
"bash", {"command": cmd}, timeout=self.config.test_timeout_s
|
||||
)
|
||||
if not res.success:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} tests failed (chunk {chunk_idx}); score=0.0",
|
||||
flush=True,
|
||||
)
|
||||
return 0.0, {
|
||||
"phase": "pytest",
|
||||
"failed_chunk": chunk_idx,
|
||||
"error": res.error,
|
||||
"output": res.output,
|
||||
}
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} all tests passed; score=1.0",
|
||||
flush=True,
|
||||
)
|
||||
return 1.0, {"passed": True}
|
||||
|
||||
# =========================================================================
|
||||
# Reward: run pytest in the terminal (local / non-sandbox path)
|
||||
# =========================================================================
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Item, result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Verify by running pytest with the dataset's nodeids.
|
||||
|
||||
Reward structure (shaped to give training signal even when model can't solve tasks):
|
||||
- 0.0: No tool calls at all
|
||||
- 0.05: Per valid tool call (up to 0.3 max for tool-call shaping)
|
||||
- 0.4: Successfully installed deps
|
||||
- 1.0: All tests pass
|
||||
|
||||
The partial rewards for tool calls help the model learn to USE tools
|
||||
before it can learn to use them CORRECTLY. This is critical for cold-start
|
||||
training where the base model barely makes any tool calls.
|
||||
"""
|
||||
repo_dir = self._repo_name(item)
|
||||
|
||||
# Count tool calls (assistant messages that have tool_calls).
|
||||
# NOTE: we keep scoring policy here intentionally simple and env-specific.
|
||||
# The agent loop exposes additional tool-call metrics (attempted/schema_valid/
|
||||
# executed_ok/exec_error) that other environments may choose to use for
|
||||
# reward shaping, but we don't hard-require any particular calling format here.
|
||||
tool_call_count = sum(
|
||||
len(msg.get("tool_calls", []))
|
||||
for msg in result.messages
|
||||
if msg.get("role") == "assistant"
|
||||
)
|
||||
|
||||
if tool_call_count == 0:
|
||||
print(f"[SweSmithOracleEnv] No tool calls made; score=0.0", flush=True)
|
||||
return 0.0
|
||||
|
||||
# Partial reward: 0.05 per tool call, capped at 0.3
|
||||
tool_call_reward = min(tool_call_count * 0.05, 0.3)
|
||||
|
||||
# Debug: log tool-call quality metrics if present
|
||||
attempted = getattr(result, "tool_calls_attempted", None)
|
||||
schema_valid = getattr(result, "tool_calls_schema_valid", None)
|
||||
executed_ok = getattr(result, "tool_calls_executed_ok", None)
|
||||
exec_error = getattr(result, "tool_calls_exec_error", None)
|
||||
if attempted is not None:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] Tool calls: total={tool_call_count}, attempted={attempted}, schema_valid={schema_valid}, ok={executed_ok}, err={exec_error}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
if not nodeids:
|
||||
# No tests defined — just reward tool usage
|
||||
print(f"[SweSmithOracleEnv] No tests defined; score={tool_call_reward:.2f} (tool calls)", flush=True)
|
||||
return tool_call_reward
|
||||
|
||||
# Install deps + run tests
|
||||
print(f"[SweSmithOracleEnv] Verifying: installing deps + running tests", flush=True)
|
||||
setup_result = ctx.terminal(
|
||||
f"cd {repo_dir} && "
|
||||
"python -m venv .venv && "
|
||||
". .venv/bin/activate && "
|
||||
"python -m pip install -U pip setuptools wheel && "
|
||||
"python -m pip install -e . && "
|
||||
"python -m pip install pytest",
|
||||
timeout=int(self.config.install_timeout_s),
|
||||
)
|
||||
if setup_result.get("exit_code", 1) != 0:
|
||||
print(f"[SweSmithOracleEnv] Install failed; score={tool_call_reward:.2f} (tool calls only)", flush=True)
|
||||
return tool_call_reward
|
||||
|
||||
# Partial reward for successful install
|
||||
install_reward = 0.4
|
||||
|
||||
# Run test chunks
|
||||
chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
|
||||
for chunk_idx, chunk in enumerate(chunks):
|
||||
joined = " ".join(chunk)
|
||||
test_result = ctx.terminal(
|
||||
f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}",
|
||||
timeout=int(self.config.test_timeout_s),
|
||||
)
|
||||
if test_result.get("exit_code", 1) != 0:
|
||||
print(f"[SweSmithOracleEnv] Tests failed (chunk {chunk_idx}); score={install_reward:.2f} (install ok)", flush=True)
|
||||
return install_reward
|
||||
|
||||
print(f"[SweSmithOracleEnv] All tests passed; score=1.0", flush=True)
|
||||
return 1.0
|
||||
|
||||
# =========================================================================
|
||||
# Token truncation — keep start of trajectory, truncate from end
|
||||
# =========================================================================
|
||||
|
||||
def _build_scored_item(self, item, result, reward):
|
||||
"""
|
||||
Override to truncate tokens/masks from the END to fit within max_token_len.
|
||||
|
||||
Intuition (from NeurIPS finding): the start of the trajectory is most important
|
||||
for shifting the model distribution. Truncating from the end only costs ~2-3%
|
||||
vs handling the full sequence, but avoids the "Token length is too long" discard
|
||||
that throws away entire groups including valid training signal.
|
||||
"""
|
||||
scored_item, remaining = super()._build_scored_item(item, result, reward)
|
||||
if scored_item is None:
|
||||
return scored_item, remaining
|
||||
|
||||
# Use config.max_token_length as the truncation limit.
|
||||
# self.max_token_len comes from the trainer via /info, but may be -1
|
||||
# if the trainer hasn't registered yet (race condition).
|
||||
max_len = self.max_token_len
|
||||
if max_len <= 0:
|
||||
# Fallback to config value
|
||||
max_len = getattr(self.config, 'max_token_length', 0)
|
||||
if max_len <= 0:
|
||||
return scored_item, remaining
|
||||
|
||||
# Leave some margin (64 tokens) to avoid edge cases with padding alignment
|
||||
truncate_to = max_len - 64
|
||||
|
||||
tokens = scored_item.get("tokens")
|
||||
masks = scored_item.get("masks")
|
||||
|
||||
if tokens is not None and len(tokens) >= max_len:
|
||||
orig_len = len(tokens)
|
||||
scored_item["tokens"] = tokens[:truncate_to]
|
||||
if masks is not None and len(masks) >= max_len:
|
||||
scored_item["masks"] = masks[:truncate_to]
|
||||
logger.info(
|
||||
"Truncated trajectory from %d to %d tokens (max_token_len=%d)",
|
||||
orig_len, truncate_to, max_len,
|
||||
)
|
||||
|
||||
return scored_item, remaining
|
||||
|
||||
# =========================================================================
|
||||
# Evaluation (minimal for now)
|
||||
# =========================================================================
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""Placeholder evaluation — SWE tasks are too expensive for frequent eval."""
|
||||
start_time = time.time()
|
||||
await self.evaluate_log(
|
||||
metrics={"eval/placeholder": 0.0},
|
||||
samples=[],
|
||||
start_time=start_time,
|
||||
end_time=time.time(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: delegates to the class-level `cli()` — presumably
    # inherited from the Atropos base environment (exposing subcommands such
    # as `serve` / `process`); confirm against the base class.
    SweSmithOracleEnv.cli()
|
||||
@@ -49,22 +49,15 @@ class HermesToolCallParser(ToolCallParser):
|
||||
continue
|
||||
|
||||
tc_data = json.loads(raw_json)
|
||||
# Handle arguments: could be dict or already a JSON string
|
||||
raw_args = tc_data.get("arguments", {})
|
||||
if isinstance(raw_args, str):
|
||||
# Already a string — pass through as-is.
|
||||
# It may be a JSON string ("{...}") or a plain string ("ls").
|
||||
args_str = raw_args
|
||||
else:
|
||||
# Dict — serialize to JSON
|
||||
args_str = json.dumps(raw_args, ensure_ascii=False)
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc_data["name"],
|
||||
arguments=args_str,
|
||||
arguments=json.dumps(
|
||||
tc_data.get("arguments", {}), ensure_ascii=False
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,136 +1,61 @@
|
||||
# Active Context
|
||||
|
||||
## Current Task: SWE Smith Oracle Env with Modal Backend
|
||||
## Current Focus
|
||||
Tinker RL training integration - pipeline fully wired up, waiting on Tinker billing to test.
|
||||
|
||||
### Goal
|
||||
Run this command:
|
||||
```bash
|
||||
python environments/swe_smith_oracle_env.py process \
|
||||
--env.use_wandb false \
|
||||
--env.total_steps 2 \
|
||||
--env.group_size 1 \
|
||||
--env.max_items 2 \
|
||||
--env.tool_pool_mode modal \
|
||||
--env.modal_image python:3.11 \
|
||||
--env.modal_slots_per_sandbox 10 \
|
||||
--env.modal_min_sandboxes 1
|
||||
```
|
||||
## Recently Completed (Feb 9, 2026)
|
||||
|
||||
### What's Done
|
||||
1. ✅ **agent_loop.py** - Added `tool_handler` parameter
|
||||
- New param: `tool_handler=None` in `__init__`
|
||||
- When `self.tool_handler` is set, it's called INSTEAD of `handle_function_call()`
|
||||
- Signature: `async tool_handler(tool_name, args, task_id) -> str`
|
||||
- Shows `[sandbox]` instead of backend name in terminal preview
|
||||
### Tinker RL Training Integration
|
||||
Created a complete agent training pipeline using Tinker (Thinking Machines) + Atropos:
|
||||
|
||||
2. ✅ **Phase 2 ManagedServer + SGLang** - Fully working (previous session)
|
||||
**New Files Created:**
|
||||
1. `tinker-atropos/tinker_atropos/environments/gsm8k_agent.py` - Agent GSM8k environment with:
|
||||
- Python REPL tool calling (Hermes-style `<tool_call>` format)
|
||||
- Multi-step agent loop within `collect_trajectories()`
|
||||
- Math answer verification via `math_verify`
|
||||
- Subprocess-based Python execution
|
||||
- WandB metrics (percent_correct, tool_use_rate)
|
||||
2. `tinker-atropos/configs/gsm8k_agent.yaml` - Config for Qwen3-4B-Instruct training
|
||||
|
||||
3. ✅ **hermes_base_env.py** - Sandbox routing in collect_trajectory() (THIS SESSION)
|
||||
- Refactored `collect_trajectory()` into:
|
||||
- `_use_sandbox_backend()` - checks if sandbox should be used
|
||||
- `_collect_trajectory_local()` - existing path (ToolContext + handle_function_call)
|
||||
- `_collect_trajectory_sandbox()` - NEW sandbox path with slot lifecycle
|
||||
- `_run_agent_loop()` - shared agent loop for Phase 1/2, accepts tool_handler
|
||||
- `_build_scored_item()` - shared scored item construction
|
||||
- Sandbox path:
|
||||
1. `backend.acquire(task_id)` → Slot
|
||||
2. `exec_tool` callable wrapping `backend.execute_batch([(slot, tool_name, args)])`
|
||||
3. `setup_trajectory_workspace(item, exec_tool=exec_tool)` → workspace_meta
|
||||
4. `sandbox_tool_handler` routes terminal→sandbox, other→local
|
||||
5. `_run_agent_loop(tool_handler=sandbox_tool_handler)`
|
||||
6. `verify_and_score_trajectory(item, result, exec_tool=exec_tool)`
|
||||
7. `backend.release(slot, reset_workspace=True)` in finally
|
||||
- Added `handle_function_call` import for non-terminal tool fallback
|
||||
**Dependencies Updated:**
|
||||
- `pyproject.toml` `[atropos]` extra now includes: tinker SDK, torch, wandb, math-verify
|
||||
- Installed: tinker 0.12.0, tinker-atropos 0.1.0, torch (CPU)
|
||||
|
||||
4. ✅ **swe_smith_oracle_env.py** - Sandbox hooks (THIS SESSION)
|
||||
- `setup_trajectory_workspace()` - bare repo cache + git worktree (ported from atropos/envs/swe_smith_oracle_env.py)
|
||||
- `verify_and_score_trajectory()` - install deps + run pytest in sandbox
|
||||
- `compute_reward()` retained for local (non-sandbox) path
|
||||
- Uses `exec_tool("bash", {"command": cmd}, timeout=600)` → `ExecutionResult`
|
||||
**README Updated:**
|
||||
- Added comprehensive "RL Training with Tinker" section with architecture diagram, quick start, config docs
|
||||
- Added TINKER_API_KEY and WANDB_API_KEY to optional keys table
|
||||
|
||||
5. ✅ **All tests pass**:
|
||||
- Syntax checks (ast.parse) on both files
|
||||
- Import checks (both modules import cleanly)
|
||||
- Method existence checks (all new methods present)
|
||||
- Signature checks (exec_tool, trajectory_id, workspace_meta params)
|
||||
- Backend integration (ModalSandboxConfig.from_agent_env_config, create_tool_backend)
|
||||
- `_use_sandbox_backend()` logic (True when modal+backend set, False otherwise)
|
||||
**Verified Working:**
|
||||
- Tinker SDK connection ✅
|
||||
- All imports (tinker, tinker_atropos, trainer, environment) ✅
|
||||
- Python REPL execution + tool call parsing ✅
|
||||
- Math verification ✅
|
||||
- Atropos run-api (port 8000) ✅
|
||||
- Tinker trainer starts, loads config, creates inference server (port 8001) ✅
|
||||
|
||||
6. ✅ **End-to-end test with Qwen 3 8B + Modal sandbox** (THIS SESSION)
|
||||
- RunPod endpoint: `0tx0ruuuo4f10c` (Qwen/Qwen3-8B via SGLang)
|
||||
- 5 terminal tool calls executed IN sandbox: `ls`, `git status`, `git log`, `cat parse.py`, `cat tests/`
|
||||
- In-sandbox verification: install deps + pytest → score=0.0 (model inspected but didn't fix)
|
||||
- Full token tracking with logprobs via Phase 2 ManagedServer
|
||||
- Key finding: Llama-3-8B template silently drops `tools=` param, Qwen 3 has full Hermes format support
|
||||
**Blocked:** Tinker billing (402 error) - user's payment didn't process (possibly regional card issue)
|
||||
|
||||
### Current Task: Integrate Slot Pool Backend into tools/terminal_tool.py
|
||||
### Main Branch Merge (Feb 9, 2026)
|
||||
Merged `origin/main` into `atropos-integrations` - 22,560 lines, 79 files, 5 conflicts resolved.
|
||||
|
||||
#### Step 1: Add `_SlotPoolEnvironment` to `tools/terminal_tool.py`
|
||||
- New class alongside existing `_LocalEnvironment`, `_DockerEnvironment`, etc.
|
||||
- Routes through `atropos/backends/` (ModalToolBackend or NomadToolBackend)
|
||||
- N:M slot multiplexing: 5-10 sandboxes × 10 slots each = 50-100 concurrent
|
||||
- Singleton `_SlotPoolManager` (like `_ModalPoolManager`) manages backend lifecycle
|
||||
- `execute()` acquires slot → `backend.execute_batch([(slot, "bash", ...)])` → returns `{"output": ..., "returncode": ...}`
|
||||
- `cleanup()` releases slot back to pool
|
||||
### Modal Backend (Feb 8, 2026)
|
||||
Merged modal-integration branch, working with Modal Sandboxes.
|
||||
|
||||
#### Step 2: Wire into `_create_environment()`
|
||||
- `TERMINAL_ENV=slot_pool` → `_SlotPoolEnvironment(...)`
|
||||
- Sub-config: `TERMINAL_SLOT_BACKEND=modal` or `TERMINAL_SLOT_BACKEND=nomad`
|
||||
- Reuse existing `TERMINAL_MODAL_*` and Nomad env vars for configuration
|
||||
### Singularity/Apptainer (Feb 6, 2026)
|
||||
Completed and tested.
|
||||
|
||||
#### Step 3: Remove redundant `atropos/tools/` files
|
||||
- DELETE: `hermes_external_tools.py`, `build_registry.py`, `sandbox_stubs.py`, `toolset_resolver.py`
|
||||
- KEEP: `base.py` (ToolCall/ToolResult types), `tool_executor.py` (batched queue), `terminal_stateful_tool.py`, `tmux_tool.py`
|
||||
|
||||
#### Step 4: Clean up `atropos/envs/` and `atropos/agent/` (defer)
|
||||
- Remove `atropos/envs/agent_env.py` → replaced by `environments/hermes_base_env.py`
|
||||
- Remove `atropos/agent/atropos_agent.py` → replaced by `environments/agent_loop.py`
|
||||
|
||||
#### Later
|
||||
- Test with Tinker trainer (blocked on billing)
|
||||
- Add more environments (endless-terminals, terminalbench 2)
|
||||
|
||||
### Key Architecture Insight
|
||||
Two separate sandbox integration points:
|
||||
1. **`tools/terminal_tool.py` with `TERMINAL_ENV=slot_pool`** — for hermes CLI, batch_runner, any code using `handle_function_call("terminal", ...)`. Uses `_SlotPoolEnvironment` which wraps `atropos/backends/`.
|
||||
2. **`environments/hermes_base_env.py` with `tool_pool_mode=modal/nomad`** — for RL environments. Uses `_collect_trajectory_sandbox()` which directly acquires slots and creates `sandbox_tool_handler`.
|
||||
|
||||
Both use the same underlying `atropos/backends/` (ModalToolBackend, NomadToolBackend) with the same slot pool.
|
||||
|
||||
### Architecture Summary
|
||||
## Architecture: Training Pipeline
|
||||
|
||||
```
|
||||
environments/hermes_base_env.py (HermesAgentBaseEnv)
|
||||
│
|
||||
├── tool_pool_mode="default" (existing path)
|
||||
│ └── collect_trajectory() → HermesAgentLoop(tool_handler=None)
|
||||
│ → handle_function_call() → hermes terminal tool (local)
|
||||
│
|
||||
└── tool_pool_mode="modal" or "nomad" (new path)
|
||||
└── collect_trajectory():
|
||||
1. slot = backend.acquire(task_id)
|
||||
2. exec_tool = lambda routing through backend.execute_batch
|
||||
3. setup_trajectory_workspace(item, exec_tool=exec_tool) [subclass hook]
|
||||
4. HermesAgentLoop(tool_handler=sandbox_tool_handler)
|
||||
→ terminal calls → backend.execute_batch(slot, "bash", ...)
|
||||
5. verify_and_score_trajectory(item, result, exec_tool=exec_tool) [subclass hook]
|
||||
6. backend.release(slot, reset_workspace=True)
|
||||
|
||||
atropos/backends/modal_backend.py (ModalToolBackend)
|
||||
└── acquire(trajectory_id) → Slot
|
||||
└── execute_batch([(slot, "bash", {"command": "..."})]) → [ExecutionResult]
|
||||
└── release(slot, reset_workspace=True)
|
||||
Terminal 1: run-api (port 8000) - Atropos Rollout API
|
||||
Terminal 2: launch_training.py (port 8001) - Tinker Trainer + FastAPI inference
|
||||
Terminal 3: gsm8k_agent.py serve - Environment (generates trajectories)
|
||||
```
|
||||
|
||||
### Key Files to Modify
|
||||
1. `environments/hermes_base_env.py` - Add sandbox path in `collect_trajectory()`
|
||||
2. `environments/swe_smith_oracle_env.py` - Override `setup_trajectory_workspace()` and `verify_and_score_trajectory()` to use exec_tool
|
||||
The agent env gets math problems → model calls Python REPL tool → scores answer → sends to Atropos → Tinker does LoRA training → updates sampling weights → repeat.
|
||||
|
||||
### Important Notes
|
||||
- `exec_tool` returns `ExecutionResult` (from `atropos/slots/executor.py`) with `.success`, `.output`, `.error`, `.metadata`
|
||||
- `tool_handler` returns JSON string (for agent loop message format)
|
||||
- These are DIFFERENT interfaces for different purposes:
|
||||
- `exec_tool`: used by env hooks (setup/verify) - returns structured result
|
||||
- `tool_handler`: used by agent loop - returns JSON string like hermes tools do
|
||||
- The ModalToolBackend.execute_batch calls _ModalSandboxWithSlots.execute which runs `sandbox.exec("bash", "-c", command)` on Modal
|
||||
- For the SWE env, the worktree setup pattern from `atropos/envs/swe_smith_oracle_env.py` should be reused (bare repo cache + worktree add)
|
||||
## Next Steps
|
||||
- [ ] Resolve Tinker billing to test full training loop
|
||||
- [ ] Run GSM8k agent training for ~20 steps (proof of concept)
|
||||
- [ ] Monitor WandB for reward improvement
|
||||
- [ ] Graduate to more complex agent envs (SWE tasks with Modal backend)
|
||||
|
||||
@@ -1,134 +1,96 @@
|
||||
# Progress
|
||||
|
||||
## Current Sprint: Phase 2 ManagedServer + SGLang Working (Feb 10, 2026)
|
||||
|
||||
### ✅ Phase 2 End-to-End Pipeline VERIFIED
|
||||
Full pipeline working: GSM8k env → collect_trajectory → ManagedServer → VLLMServer (SGLang patched) → tokens + logprobs + masks.
|
||||
|
||||
Test results:
|
||||
- 212 tokens with logprobs and masks from single trajectory
|
||||
- Reward: 1.0 (correct answer)
|
||||
- ScoredDataItem has all required fields: tokens, masks, scores, advantages, ref_logprobs, messages
|
||||
- RunPod SGLang endpoint (b9zmuyn1carwya) with Llama-3-8B-Instruct
|
||||
|
||||
### Consolidation Checklist
|
||||
- [x] Install atropos `tool_call_support` branch (PR #366)
|
||||
- [x] Create `environments/gsm8k_agent_env.py` using `HermesAgentBaseEnv`
|
||||
- [x] Create `environments/agent_loop.py` with proper OpenAI-spec tool calling
|
||||
- [x] Create `environments/tool_call_parsers/` with 13 parsers
|
||||
- [x] Create `environments/patches.py` for SGLang compatibility
|
||||
- [x] Add sandbox pool support to `HermesAgentBaseEnv`
|
||||
- [x] Test Phase 1 (OpenAI server type) with Nous API — WORKS
|
||||
- [x] Test Phase 2 (ManagedServer) with RunPod SGLang — WORKS
|
||||
- [x] Port SWE env to `HermesAgentBaseEnv` with multiplexed sandboxing
|
||||
- [x] End-to-end test: Qwen 3 8B + Modal sandbox + tool calls in sandbox + pytest verification
|
||||
- [x] Add `_SlotPoolEnvironment` to `tools/terminal_tool.py` (TERMINAL_ENV=slot_pool)
|
||||
- [x] Remove redundant `atropos/tools/` files (4 of 8)
|
||||
- [ ] Remove redundant `atropos/agent/` and `atropos/envs/agent_env.py` (deferred)
|
||||
- [ ] Test end-to-end with Tinker trainer (blocked on billing)
|
||||
|
||||
### ✅ End-to-End SWE + Modal Sandbox Verified (Feb 10, 2026)
|
||||
- Qwen 3 8B on RunPod SGLang (endpoint `0tx0ruuuo4f10c`)
|
||||
- Phase 2 ManagedServer with hermes tool call parser
|
||||
- 5 terminal commands executed in Modal sandbox: ls, git status, git log, cat parse.py, cat tests/
|
||||
- In-sandbox verification: install deps + pytest → score 0.0 (model inspected but didn't fix)
|
||||
- Full token tracking with logprobs via /generate endpoint
|
||||
- Key finding: Llama-3-8B template drops tools= silently; Qwen 3 has full Hermes tool format
|
||||
|
||||
## Completed Features
|
||||
|
||||
### ✅ Phase 2 ManagedServer + SGLang (Feb 10, 2026)
|
||||
- SGLang patch in `environments/patches.py` monkey-patches VLLMServer
|
||||
- Handles SGLang's different request/response format vs VLLM
|
||||
- Handles RunPod's double-JSON wrapping
|
||||
- Full chain verified: ManagedServer → VLLMServer → _tokens_and_logprobs_comp (retry) → patched wrapper → /generate endpoint
|
||||
- SequenceNode tracking: tokens, logprobs, masked_tokens all populated
|
||||
- **Key discovery**: The AttributeError from earlier was NOT in our current code — likely from a prior code state
|
||||
### ✅ Modal Backend Integration (Feb 8, 2026 - MERGED & TESTED)
|
||||
Merged the `modal-integration` branch and fixed integration issues.
|
||||
|
||||
### ✅ Phase 1 OpenAI Server Mode (Feb 9-10, 2026)
|
||||
- GSM8k env works with Nous API (OpenRouter-style endpoint)
|
||||
- Terminal tool calls properly dispatched
|
||||
- Tool call parsing handled natively by server (VLLM/SGLang /v1/chat/completions)
|
||||
- Reward computation verified (math_verify for robust LaTeX comparison)
|
||||
**What Works:**
|
||||
- `ModalToolBackend` implements full `ToolBackend` interface (start, stop, acquire, release, execute_batch)
|
||||
- Modal Sandboxes used for long-lived containers (not Functions)
|
||||
- `sandbox.exec()` for direct command execution (no HTTP server needed)
|
||||
- Slot-based multiplexing matching Nomad pattern
|
||||
- Multi-profile support (`ModalSandboxConfig`, `_ModalMultiProfileManager`)
|
||||
- YAML profile loading (`modal_profiles.yaml`)
|
||||
- `AgentEnvConfig` fields for all Modal settings (`--env.modal_*`)
|
||||
- `create_tool_backend()` supports `tool_pool_mode="modal"`
|
||||
- Terminal tool (`tools/terminal_tool.py`) native Modal integration with pool management
|
||||
- Named sandbox recovery via `Sandbox.from_name()`
|
||||
- Auto-scaling sandbox pool per profile
|
||||
- Artifact helpers (read, list, archive)
|
||||
|
||||
### ✅ Sandbox Pool Integration (Feb 10, 2026)
|
||||
- Config fields added to `HermesAgentEnvConfig` for Nomad and Modal
|
||||
- `_start_sandbox_backend()` / `_stop_sandbox_backend()` lifecycle methods
|
||||
- Optional hooks: `setup_trajectory_workspace()`, `verify_and_score_trajectory()`
|
||||
- Integrated into `env_manager()` and `process_manager()` cleanup
|
||||
**CLI Usage:**
|
||||
```bash
|
||||
# Atropos backend
|
||||
python -m atropos.envs.swe_smith_oracle_env process \
|
||||
--env.tool_pool_mode modal \
|
||||
--env.modal_image python:3.11
|
||||
|
||||
### ✅ Tool Call Parsers (Feb 9-10, 2026)
|
||||
- 13 parsers, including: hermes, llama3_json, llama4_json, qwen, qwen3_coder, deepseek_v3, deepseek_v31, glm45, glm47, mistral, kimi_k2, longcat
|
||||
- Registry pattern: `get_parser("hermes")` returns parser instance
|
||||
- Each parser: `.parse(text) → (content, tool_calls)`
|
||||
- Used by ManagedServer in Phase 2 to extract structured tool_calls from raw completion
|
||||
# Terminal tool
|
||||
TERMINAL_ENV=modal ./hermes
|
||||
```
|
||||
|
||||
### ✅ Modal Backend Integration (Feb 8, 2026)
|
||||
- `ModalToolBackend` with slot-based multiplexing
|
||||
- Multi-profile support (CPU, GPU, high-memory)
|
||||
- Auto-scaling sandbox pool via Modal Sandboxes
|
||||
**Files Modified/Created:**
|
||||
- `atropos/backends/modal_backend.py` - Full implementation (~1200 lines)
|
||||
- `atropos/backends/__init__.py` - `create_tool_backend()` updated
|
||||
- `atropos/envs/agent_env.py` - 15 Modal config fields added
|
||||
- `tools/terminal_tool.py` - Native Modal sandbox pool
|
||||
- `docs/MODAL_BACKEND.md` - Documentation
|
||||
- `modal_profiles.yaml.example` - Example profiles
|
||||
- `tests/test_modal_integration.py` - Integration tests
|
||||
- `tests/test_modal_stress.py` - Stress tests
|
||||
- `tests/test_modal_terminal.py` - Terminal tool tests
|
||||
|
||||
### ✅ Main Branch Merge (Feb 9, 2026)
|
||||
- Merged 22,560 lines, 79 files, 5 conflicts resolved
|
||||
- New: hermes_cli/, file_operations, RL training tools, gateway, cron
|
||||
### ✅ Singularity/Apptainer Sandbox Integration (Feb 6, 2026 - FULLY TESTED)
|
||||
Adapted the Atropos sandbox environment from Docker to Singularity/Apptainer for HPC clusters.
|
||||
|
||||
### ✅ Tinker RL Training Setup (Feb 9, 2026)
|
||||
- tinker 0.12.0 + tinker-atropos installed
|
||||
- GSM8k agent config created
|
||||
- Pipeline verified: Tinker API connection works, all imports pass
|
||||
- **Blocked on billing** (Tinker 402 error)
|
||||
**What Works:**
|
||||
- `create_sandbox_job()` supports both `driver="docker"` and `driver="singularity"`
|
||||
- SlotPoolConfig and NomadBackendConfig propagate driver settings
|
||||
- Singularity container runs sandbox_server.py via Nomad's raw_exec driver
|
||||
- All sandbox operations work: bash execution, file read/write
|
||||
- **CLI arguments** `--env.driver` and `--env.singularity_image` for AgentEnvConfig
|
||||
- **Static port binding** for Singularity (ReservedPorts vs DynamicPorts)
|
||||
|
||||
### ✅ Singularity/Apptainer Sandbox (Feb 6, 2026)
|
||||
- Nomad raw_exec driver for HPC clusters
|
||||
- All sandbox operations tested and working
|
||||
### ✅ Memory Bank Initialized (Feb 5, 2026)
|
||||
Set up project documentation structure for context persistence.
|
||||
|
||||
### ✅ Memory Bank (Feb 5, 2026)
|
||||
- Project documentation structure initialized
|
||||
|
||||
## What to KEEP vs REMOVE
|
||||
|
||||
### KEEP (valuable infrastructure):
|
||||
| Component | Location | Purpose |
|
||||
|-----------|----------|---------|
|
||||
| Modal backend | `atropos/backends/modal_backend.py` | Cloud sandbox pool |
|
||||
| Nomad backend | `atropos/backends/nomad_backend.py` | Docker/Singularity sandboxes |
|
||||
| Slot pool | `atropos/slots/` | Container multiplexing |
|
||||
| Nomad client | `atropos/nomad/` | Nomad API |
|
||||
| Sandbox server | `atropos/sandbox_server.py` | HTTP server in containers |
|
||||
| Dockerfile | `atropos/Dockerfile` | Container image |
|
||||
| Agent loop | `environments/agent_loop.py` | Proper OpenAI-spec tool calling |
|
||||
| Base env | `environments/hermes_base_env.py` | Phase 1/2 with parsers |
|
||||
| Tool parsers | `environments/tool_call_parsers/` | 13 model parsers |
|
||||
| SGLang patch | `environments/patches.py` | SGLang compatibility |
|
||||
|
||||
### REMOVE (redundant with environments/):
|
||||
| Component | Location | Replaced By |
|
||||
|-----------|----------|-------------|
|
||||
| ICL agent | `atropos/agent/atropos_agent.py` | `environments/agent_loop.py` |
|
||||
| AgentEnv | `atropos/envs/agent_env.py` | `environments/hermes_base_env.py` |
|
||||
| Tool registry | `atropos/tools/` | `model_tools.py` + `tools/` |
|
||||
| GSM8k ICL env | `tinker-atropos/.../gsm8k_agent.py` | `environments/gsm8k_agent_env.py` |
|
||||
## In Progress
|
||||
None currently.
|
||||
|
||||
## Known Issues
|
||||
- Tinker billing (402 error) - user's payment didn't process
|
||||
- Modal backend not yet live-tested with actual Modal cloud credentials
|
||||
- `bwrap_available: false` in Singularity containers
|
||||
- Llama-3-8B-Instruct doesn't reliably produce tool calls via Phase 2 (needs Hermes-format model)
|
||||
- Model answered GSM8k correctly but didn't actually USE the terminal tool (computed mentally)
|
||||
- Health check timing - may need longer wait for container startup on slower systems
|
||||
|
||||
## What's Left to Build
|
||||
|
||||
### Modal Backend
|
||||
- [ ] Live test with Modal credentials on actual cloud
|
||||
- [ ] Test multi-profile GPU workflows
|
||||
- [ ] Test sandbox recovery after restart
|
||||
- [ ] Integrate with SWE-smith-oracle env for GRPO training loop
|
||||
- [ ] Performance benchmarking vs Nomad backend
|
||||
|
||||
### HPC Deployment
|
||||
- [ ] Test on actual HPC cluster with Slurm/PBS integration
|
||||
- [ ] Document cluster-specific deployment procedures
|
||||
|
||||
### Documentation
|
||||
- [ ] Add Singularity deployment to README
|
||||
- [ ] Create HPC deployment skill in skills/mlops/
|
||||
|
||||
## Evolution of Decisions
|
||||
|
||||
### Agent Architecture
|
||||
- **v1 (our branch)**: ICL-based agent with `<tool_call>` XML tags in system prompt
|
||||
- **v2 (Teknium's)**: Proper OpenAI-spec tool calling with `tools=` parameter
|
||||
- **Decision**: Adopt v2, consolidate into `environments/`, keep sandbox backends from v1
|
||||
### Container Runtime Selection
|
||||
- **Initial**: Docker-only via Nomad docker driver
|
||||
- **Problem**: HPC clusters don't allow Docker without sudo
|
||||
- **Solution**: Added Singularity/Apptainer support via raw_exec driver
|
||||
- **Result**: Both runtimes now supported with same API
|
||||
|
||||
### Environment Organization
|
||||
- **Before**: Two parallel systems (`atropos/envs/` and `environments/`)
|
||||
- **After**: Single system in `environments/`, using `HermesAgentBaseEnv` as base class
|
||||
- Sandbox backends remain in `atropos/backends/` but integrate via terminal backend config
|
||||
|
||||
### Phase 2 SGLang Support
|
||||
- **Problem**: VLLMServer hardcoded for VLLM's /generate format, SGLang is different
|
||||
- **Solution**: Monkey-patch `_tokens_and_logprobs_completion_wrapper` in `environments/patches.py`
|
||||
- **Applied**: Automatically at import time via `apply_patches()` in `hermes_base_env.py`
|
||||
- **Handles**: SGLang format differences AND RunPod's double-JSON wrapping
|
||||
### Modal Backend Architecture
|
||||
- **Initial**: Stub placeholder raising RuntimeError
|
||||
- **Investigation**: Modal Sandboxes vs Functions - chose Sandboxes for long-lived containers
|
||||
- **Design**: Direct `sandbox.exec()` instead of HTTP/sandbox_server.py (simpler, no networking needed)
|
||||
- **Implementation**: Merged from `modal-integration` branch, fixed agent_env.py config fields
|
||||
- **Result**: Three backends now supported: Nomad/Docker, Nomad/Singularity, Modal
|
||||
|
||||
@@ -148,85 +148,11 @@ The agent validates responses before accepting:
|
||||
4. `AIAgent` reads env vars when initializing terminal tool
|
||||
5. Terminal tool creates appropriate backend based on `TERMINAL_ENV`
|
||||
|
||||
## RL Training Architecture (Consolidated)
|
||||
|
||||
### Environment System (`environments/`)
|
||||
|
||||
The canonical way to build agentic RL environments in Hermes-Agent:
|
||||
## Atropos Backend Architecture
|
||||
|
||||
### Backend Hierarchy
|
||||
```
|
||||
environments/
|
||||
├── agent_loop.py ← HermesAgentLoop: OpenAI-spec tool calling
|
||||
├── hermes_base_env.py ← HermesAgentBaseEnv: base class for all envs
|
||||
├── tool_context.py ← ToolContext: reward function tool access
|
||||
├── tool_call_parsers/ ← 11+ model parsers (hermes, qwen, deepseek, etc.)
|
||||
├── terminal_test_env.py ← Example: file creation tasks
|
||||
├── hermes_swe_env.py ← SWE environment
|
||||
└── gsm8k_agent_env.py ← GSM8k with Python REPL (TODO)
|
||||
```
|
||||
|
||||
### Two-Phase Operation
|
||||
- **Phase 1 (OpenAI server)**: Native tool_calls from VLLM/SGLang/OpenRouter
|
||||
- Good for: SFT data gen, testing, evaluation
|
||||
- Server handles tool call parsing via `/v1/chat/completions`
|
||||
- **Phase 2 (ManagedServer)**: Client-side tool call parser + logprob tracking
|
||||
- Required for: RL training (exact token IDs + logprobs for GRPO/PPO)
|
||||
- Uses `/generate` endpoint for raw token output
|
||||
- Parser registry selects per-model parser (hermes, qwen, llama, etc.)
|
||||
- **Verified working** with RunPod SGLang endpoint (Feb 10, 2026)
|
||||
|
||||
### Phase 2 Call Chain (Verified)
|
||||
```
|
||||
collect_trajectory()
|
||||
→ ServerManager.managed_server(tokenizer, tool_call_parser)
|
||||
→ ManagedServer(server=VLLMServer)
|
||||
→ ManagedServer.chat_completion(messages, tools, n, max_tokens, temp)
|
||||
→ _convert_messages_to_prompt(messages, tools=tools) [apply_chat_template]
|
||||
→ _compute_input_ids(prompt, extending_node)
|
||||
→ VLLMServer.tokens_and_logprobs_completion(**kwargs) [public method]
|
||||
→ _tokens_and_logprobs_comp(stat_dict, **kwargs) [retry decorator, semaphore]
|
||||
→ _tokens_and_logprobs_completion_wrapper(**kwargs) [patched for SGLang]
|
||||
→ aiohttp POST to /generate
|
||||
→ Returns (prompt_tokens, [output_tokens], [output_logprobs], [finish_reasons])
|
||||
→ _create_sequence_node(...) [stores in current_nodes]
|
||||
→ tool_call_parser.parse(completion_text) [if parser configured]
|
||||
→ Returns ChatCompletion with tool_calls
|
||||
```
|
||||
|
||||
### SGLang Compatibility Patch (`environments/patches.py`)
|
||||
VLLMServer's `_tokens_and_logprobs_completion_wrapper` is monkey-patched to handle SGLang's
|
||||
different request/response format. Applied automatically at import time via `apply_patches()`.
|
||||
|
||||
```
|
||||
SGLang request: {"input_ids": [...], "sampling_params": {...}, "return_logprob": true}
|
||||
SGLang response: {"meta_info": {"output_token_logprobs": [[logprob, token_id, text], ...]}}
|
||||
|
||||
VLLM request: {"prompt": {"prompt_token_ids": [...]}, "logprobs": 0}
|
||||
VLLM response: {"logprobs": [[{token_id: logprob}]], "finish_reasons": [...]}
|
||||
```
|
||||
|
||||
Also handles RunPod serverless double-JSON wrapping (response body wrapped in quotes).
|
||||
|
||||
### Key Design: Proper Tool Calling (NOT ICL)
|
||||
```python
|
||||
# CORRECT: pass tools= to chat_completion()
|
||||
response = await server.chat_completion(
|
||||
messages=messages,
|
||||
tools=tool_schemas, # ← tokenizer.apply_chat_template(tools=...) formats these
|
||||
temperature=1.0,
|
||||
)
|
||||
# Response has response.choices[0].message.tool_calls (structured objects)
|
||||
|
||||
# WRONG (old approach): embed tools in system prompt as XML
|
||||
system_prompt = f"<tools>{json.dumps(tools)}</tools>" # ← ICL, not proper training format
|
||||
```
|
||||
|
||||
### Sandbox Backends (`atropos/backends/`)
|
||||
|
||||
Infrastructure for scaled sandbox execution, integrated into HermesAgentBaseEnv:
|
||||
|
||||
```
|
||||
ToolBackend (Protocol)
|
||||
ToolBackend (Protocol - base.py)
|
||||
├── NomadToolBackend → SlotPool → NomadClient + SandboxExecutor (HTTP)
|
||||
│ ├── Docker driver (default)
|
||||
│ └── Singularity driver (HPC)
|
||||
@@ -234,34 +160,32 @@ ToolBackend (Protocol)
|
||||
└── _ModalMultiProfileManager (multi-profile support)
|
||||
```
|
||||
|
||||
Two execution modes in HermesAgentBaseEnv (controlled by `tool_pool_mode` config):
|
||||
- `default` - Local tool execution via handle_function_call() + ToolContext
|
||||
- `modal` / `nomad` - Sandbox routing: slot acquire → setup workspace → agent loop → verify → release
|
||||
### Slot-Based Multiplexing Pattern
|
||||
All backends share the same slot multiplexing concept:
|
||||
- **Sandbox/Container**: Long-lived compute unit
|
||||
- **Slot**: Isolated workspace directory within a sandbox (e.g., `/data/slot_0`)
|
||||
- **Trajectory**: One agent task using one slot
|
||||
- Multiple trajectories share a sandbox via different slots
|
||||
|
||||
Sandbox routing architecture:
|
||||
```
|
||||
collect_trajectory()
|
||||
├── tool_pool_mode="default" → _collect_trajectory_local()
|
||||
│ └── _run_agent_loop(tool_handler=None) → compute_reward(ctx)
|
||||
│
|
||||
└── tool_pool_mode="modal"/"nomad" → _collect_trajectory_sandbox()
|
||||
├── backend.acquire(task_id) → Slot
|
||||
├── exec_tool = backend.execute_batch wrapper → ExecutionResult
|
||||
├── setup_trajectory_workspace(item, exec_tool) [subclass hook]
|
||||
├── _run_agent_loop(tool_handler=sandbox_tool_handler)
|
||||
│ └── terminal → backend.execute_batch → JSON string
|
||||
│ └── other tools → handle_function_call (local)
|
||||
├── verify_and_score_trajectory(item, result, exec_tool) [subclass hook]
|
||||
└── backend.release(slot, reset_workspace=True) [finally]
|
||||
```
|
||||
### Nomad Backend (HTTP-based)
|
||||
- Deploys `sandbox_server.py` inside containers (Docker or Singularity)
|
||||
- Uses `SandboxExecutor` for HTTP communication (POST /execute, POST /batch)
|
||||
- Nomad manages container lifecycle (scaling, health checks)
|
||||
- Tools: bash, bash_stateful, read_file, write_file, tmux
|
||||
|
||||
Key interfaces:
|
||||
- `exec_tool(tool_name, args, timeout)` → `ExecutionResult` (for env hooks)
|
||||
- `tool_handler(tool_name, args, task_id)` → JSON string (for agent loop)
|
||||
### Modal Backend (exec-based)
|
||||
- Creates `modal.Sandbox` instances (long-lived containers)
|
||||
- Uses `sandbox.exec("bash", "-c", command)` directly (no HTTP server)
|
||||
- Modal manages container lifecycle (idle_timeout, max_lifetime)
|
||||
- Multi-profile support: different resource configs (CPU, GPU, memory)
|
||||
- Named sandboxes for recovery: `Sandbox.from_name(app_name, sandbox_name)`
|
||||
- YAML config via `modal_profiles.yaml`
|
||||
|
||||
### Training Pipeline (Tinker + Atropos)
|
||||
```
|
||||
Terminal 1: run-api (port 8000) ← Atropos Rollout API
|
||||
Terminal 2: launch_training.py (port 8001) ← Tinker Trainer + inference
|
||||
Terminal 3: environment.py serve ← Environment (rollouts)
|
||||
### Backend Selection
|
||||
```python
|
||||
# In agent_env.py / create_tool_backend()
|
||||
if mode == "nomad":
|
||||
return NomadToolBackend(NomadBackendConfig.from_agent_env_config(cfg))
|
||||
if mode == "modal":
|
||||
return ModalToolBackend(ModalSandboxConfig.from_agent_env_config(cfg))
|
||||
```
|
||||
|
||||
Submodule mini-swe-agent updated: ee36b3d4e5...9ddd61b62d
@@ -41,9 +41,8 @@ messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0"]
|
||||
cron = ["croniter"]
|
||||
cli = ["simple-term-menu"]
|
||||
# Install Atropos + Tinker training integration from source.
|
||||
# Uses tool_call_support branch for ManagedServer tool calling (PR #366).
|
||||
atropos = [
|
||||
"atroposlib @ git+https://github.com/NousResearch/atropos.git@tool_call_support",
|
||||
"atroposlib @ git+https://github.com/NousResearch/atropos.git",
|
||||
"tinker @ git+https://github.com/thinking-machines-lab/tinker.git",
|
||||
# Atropos integration runtime deps (kept optional for Hermes-only users)
|
||||
"aiohttp",
|
||||
|
||||
@@ -1545,311 +1545,6 @@ class _ModalSandboxEnvironment:
|
||||
pass
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Slot Pool Environment — routes through atropos/backends/ for multiplexed
|
||||
# sandbox execution. Supports Modal, Nomad (Docker + Singularity/Apptainer).
|
||||
#
|
||||
# Usage: TERMINAL_ENV=slot_pool TERMINAL_SLOT_BACKEND=modal
|
||||
# =============================================================================
|
||||
|
||||
class _SlotPoolAsyncWorker:
    """Background thread running a private asyncio event loop.

    The slot-pool backends expose async APIs (acquire/execute/release), but
    the terminal tool is driven from synchronous code. This worker owns a
    dedicated event loop on a daemon thread so callers can submit coroutines
    and block on their results via run().
    """

    def __init__(self):
        # Both are populated by start(); run() refuses to work before that.
        self._loop = None
        self._thread = None

    def start(self):
        """Create the event loop and start the daemon thread that drives it."""
        import asyncio as _aio
        self._loop = _aio.new_event_loop()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self):
        """Thread target: bind the loop to this thread and run it forever."""
        import asyncio as _aio
        _aio.set_event_loop(self._loop)
        self._loop.run_forever()

    def run(self, coro, timeout=300):
        """Run an async coroutine synchronously on the worker thread.

        Args:
            coro: Coroutine object to execute on the worker loop.
            timeout: Seconds to wait for the result before the underlying
                future raises a timeout error.

        Raises:
            RuntimeError: If start() was never called (or stop() already ran).
        """
        import asyncio as _aio
        if self._loop is None or self._thread is None:
            raise RuntimeError("SlotPoolAsyncWorker not started")
        future = _aio.run_coroutine_threadsafe(coro, self._loop)
        return future.result(timeout=timeout)

    def stop(self):
        """Stop the loop, join the worker thread, and release loop resources.

        Fix: the event loop is now close()d once the thread has exited, so its
        selector and internal socket pair are not leaked; previously the loop
        was only stopped, never closed.
        """
        if self._loop is not None:
            self._loop.call_soon_threadsafe(self._loop.stop)
        if self._thread is not None:
            self._thread.join(timeout=5)
        # Only close once the loop has actually stopped running; closing a
        # running loop raises RuntimeError. Clearing the references also makes
        # a post-stop run() fail fast instead of hanging on a dead loop.
        if self._loop is not None and not self._loop.is_running():
            self._loop.close()
            self._loop = None
            self._thread = None
|
||||
|
||||
|
||||
class _SlotPoolManager:
    """
    Singleton manager for the slot-pool sandbox backend.

    Wraps atropos/backends/ (ModalToolBackend or NomadToolBackend) and provides
    synchronous acquire/execute/release operations by delegating every async
    backend call to a background event-loop worker (_SlotPoolAsyncWorker).

    Config via environment variables:
        TERMINAL_SLOT_BACKEND = modal | nomad (default: modal)
        # Modal settings (reuses TERMINAL_MODAL_* vars):
        TERMINAL_MODAL_IMAGE = python:3.11
        TERMINAL_MODAL_SLOTS = 10
        TERMINAL_MODAL_MIN = 1
        TERMINAL_MODAL_MAX = 5
        # Nomad settings:
        TERMINAL_NOMAD_ADDRESS = http://localhost:4646
        TERMINAL_NOMAD_DRIVER = docker | singularity
        TERMINAL_NOMAD_IMAGE = atropos-sandbox:local
    """

    _instance: Optional["_SlotPoolManager"] = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls) -> "_SlotPoolManager":
        """Return the singleton, creating and starting it on first use.

        Fix: the instance is only published to ``cls._instance`` after
        ``_start()`` succeeds. Previously the instance was cached before
        starting, so a failed backend start left a broken, never-started
        singleton that every later caller received.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = cls()
                    instance._start()  # may raise; publish only on success
                    cls._instance = instance
        return cls._instance

    @classmethod
    def reset_instance(cls):
        """Tear down and forget the singleton (used by the atexit hook)."""
        with cls._lock:
            if cls._instance is not None:
                cls._instance._stop()
                cls._instance = None

    def __init__(self):
        self._backend = None                   # ModalToolBackend | NomadToolBackend
        self._worker = _SlotPoolAsyncWorker()  # runs async backend calls
        self._slots: Dict[str, Any] = {}       # task_id -> Slot
        self._slot_lock = threading.Lock()     # guards self._slots
        self._started = False

    def _start(self):
        """Initialize the async worker, build the backend, and start it.

        Raises:
            ValueError: If TERMINAL_SLOT_BACKEND names an unknown backend.
        """
        self._worker.start()

        backend_type = os.getenv("TERMINAL_SLOT_BACKEND", "modal").strip().lower()
        print(f"[SlotPool] Starting {backend_type} backend...")

        if backend_type == "modal":
            self._backend = self._create_modal_backend()
        elif backend_type == "nomad":
            self._backend = self._create_nomad_backend()
        else:
            raise ValueError(
                f"Unknown TERMINAL_SLOT_BACKEND: {backend_type}. Use 'modal' or 'nomad'."
            )

        self._worker.run(self._backend.start(), timeout=120)
        self._started = True
        print(f"[SlotPool] {backend_type} backend started")

    def _create_modal_backend(self):
        """Build a ModalToolBackend from TERMINAL_MODAL_* environment vars."""
        from atropos.backends.modal_backend import ModalSandboxConfig, ModalToolBackend

        config = ModalSandboxConfig(
            name="default",
            app_name=os.getenv("TERMINAL_SLOT_APP_NAME", "hermes-slot-pool"),
            # Fall back to the generic docker image var so one setting can
            # drive both docker and modal configurations.
            image=os.getenv("TERMINAL_MODAL_IMAGE") or os.getenv("TERMINAL_DOCKER_IMAGE", "python:3.11"),
            gpu=os.getenv("TERMINAL_MODAL_GPU") or None,
            cpu=float(os.getenv("TERMINAL_MODAL_CPU", "1.0")),
            memory=int(os.getenv("TERMINAL_MODAL_MEMORY", "2048")),
            slots_per_sandbox=int(os.getenv("TERMINAL_MODAL_SLOTS", "10")),
            min_sandboxes=int(os.getenv("TERMINAL_MODAL_MIN", "1")),
            max_sandboxes=int(os.getenv("TERMINAL_MODAL_MAX", "5")),
            idle_timeout=int(os.getenv("TERMINAL_MODAL_IDLE_TIMEOUT", "120")),
            max_lifetime=int(os.getenv("TERMINAL_MODAL_MAX_LIFETIME", "3600")),
            acquire_timeout_s=float(os.getenv("TERMINAL_MODAL_ACQUIRE_TIMEOUT", "60.0")),
            execution_timeout_s=float(os.getenv("TERMINAL_MODAL_EXEC_TIMEOUT", "300.0")),
            workspace_base=os.getenv("TERMINAL_MODAL_WORKSPACE", "/data"),
        )
        return ModalToolBackend(config)

    def _create_nomad_backend(self):
        """Build a NomadToolBackend from TERMINAL_NOMAD_* environment vars."""
        from atropos.backends.nomad_backend import NomadBackendConfig, NomadToolBackend

        config = NomadBackendConfig(
            nomad_address=os.getenv("TERMINAL_NOMAD_ADDRESS", "http://localhost:4646"),
            job_id=os.getenv("TERMINAL_NOMAD_JOB_ID", "hermes-slot-pool"),
            image=os.getenv("TERMINAL_NOMAD_IMAGE") or os.getenv("TERMINAL_DOCKER_IMAGE", "atropos-sandbox:local"),
            driver=os.getenv("TERMINAL_NOMAD_DRIVER", "docker"),
            slots_per_container=int(os.getenv("TERMINAL_NOMAD_SLOTS", "10")),
            min_containers=int(os.getenv("TERMINAL_NOMAD_MIN", "1")),
            max_containers=int(os.getenv("TERMINAL_NOMAD_MAX", "10")),
        )
        return NomadToolBackend(config)

    def _stop(self):
        """Shut down the backend and async worker, releasing any held slots."""
        if self._started and self._backend:
            try:
                # Release all held slots (best effort; shutdown continues on
                # per-slot errors).
                with self._slot_lock:
                    for task_id, slot in list(self._slots.items()):
                        try:
                            self._worker.run(
                                self._backend.release(slot, reset_workspace=True),
                                timeout=10,
                            )
                        except Exception:
                            pass
                    self._slots.clear()

                self._worker.run(self._backend.stop(purge=False), timeout=30)
            except Exception as e:
                print(f"[SlotPool] Warning: shutdown error: {e}")
            finally:
                self._started = False

        self._worker.stop()
        print("[SlotPool] Backend stopped")

    def acquire(self, task_id: str, timeout: float = 60.0):
        """Acquire a slot for *task_id*; idempotent per task. Returns the Slot.

        Fix: if two threads race to acquire the same task_id, only the first
        slot is kept; the duplicate is released back to the pool instead of
        being leaked for the lifetime of the process.
        """
        with self._slot_lock:
            if task_id in self._slots:
                return self._slots[task_id]

        slot = self._worker.run(
            self._backend.acquire(task_id), timeout=timeout
        )

        with self._slot_lock:
            # Another thread may have registered a slot for this task while we
            # were waiting on the backend; keep whichever landed first.
            winner = self._slots.setdefault(task_id, slot)

        if winner is not slot:
            # Lost the race: return the duplicate slot (best effort).
            try:
                self._worker.run(
                    self._backend.release(slot, reset_workspace=True), timeout=30
                )
            except Exception:
                pass

        return winner

    def execute(self, task_id: str, command: str, cwd: str = "", timeout: float = 300.0) -> dict:
        """Execute a command in the task's slot.

        Returns:
            {"output": str, "returncode": int}; never raises for the common
            "no slot" case, which is reported as returncode 1.
        """
        import shlex

        with self._slot_lock:
            slot = self._slots.get(task_id)
        if slot is None:
            return {"output": "Error: no slot acquired for this task", "returncode": 1}

        # Prefix with a cd when a working directory was requested. Fix: the
        # cwd is shell-quoted so paths containing spaces or shell
        # metacharacters cannot break (or inject into) the composed command.
        full_command = f"cd {shlex.quote(cwd)} && {command}" if cwd else command

        results = self._worker.run(
            self._backend.execute_batch(
                [(slot, "bash", {"command": full_command})],
                timeout_s=timeout,
            ),
            timeout=timeout + 30,  # Extra margin for network overhead
        )

        r = results[0]
        output = r.output if r.success else (
            f"{r.output}\n{r.error}" if r.output else r.error
        )
        returncode = r.metadata.get("returncode", 0 if r.success else 1)
        return {"output": output, "returncode": returncode}

    def release(self, task_id: str, reset_workspace: bool = True):
        """Release a task's slot back to the pool (no-op if none is held)."""
        with self._slot_lock:
            slot = self._slots.pop(task_id, None)
        if slot is None:
            return

        try:
            self._worker.run(
                self._backend.release(slot, reset_workspace=reset_workspace),
                timeout=30,
            )
        except Exception as e:
            print(f"[SlotPool] Warning: release failed for {task_id}: {e}")

    def get_status(self) -> Dict[str, Any]:
        """Return the backend's pool status, or a stub if not yet started."""
        if not self._started or not self._backend:
            return {"status": "not started"}
        return self._backend.get_status()
|
||||
|
||||
|
||||
class _SlotPoolEnvironment:
|
||||
"""
|
||||
Slot-pool based execution environment.
|
||||
|
||||
Routes terminal commands through atropos/backends/ (Modal, Nomad/Docker,
|
||||
Nomad/Singularity) with N:M slot multiplexing. Multiple tasks share a
|
||||
smaller number of sandboxes via slot assignment.
|
||||
|
||||
Usage:
|
||||
TERMINAL_ENV=slot_pool
|
||||
TERMINAL_SLOT_BACKEND=modal # or nomad
|
||||
TERMINAL_MODAL_IMAGE=python:3.11
|
||||
TERMINAL_MODAL_SLOTS=10
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cwd: str = "/data",
|
||||
timeout: int = 300,
|
||||
task_id: str = "",
|
||||
):
|
||||
self.cwd = cwd
|
||||
self.timeout = timeout
|
||||
self.task_id = task_id or str(uuid.uuid4())
|
||||
self._released = False
|
||||
|
||||
# Acquire a slot from the pool
|
||||
manager = _SlotPoolManager.get_instance()
|
||||
manager.acquire(self.task_id, timeout=60.0)
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *, timeout: int | None = None) -> dict:
|
||||
"""Execute a command in the slot's workspace."""
|
||||
exec_command = _transform_sudo_command(command)
|
||||
work_dir = cwd or self.cwd
|
||||
|
||||
try:
|
||||
return _SlotPoolManager.get_instance().execute(
|
||||
self.task_id,
|
||||
exec_command,
|
||||
cwd=work_dir,
|
||||
timeout=float(timeout or self.timeout),
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if "timeout" in error_msg.lower():
|
||||
return {"output": f"Command timed out after {timeout or self.timeout}s", "returncode": 124}
|
||||
return {"output": f"SlotPool execution error: {error_msg}", "returncode": 1}
|
||||
|
||||
def cleanup(self):
|
||||
"""Release slot back to the pool (workspace reset for reuse)."""
|
||||
if not self._released:
|
||||
self._released = True
|
||||
_SlotPoolManager.get_instance().release(self.task_id, reset_workspace=True)
|
||||
|
||||
def stop(self):
|
||||
"""Same as cleanup for slot pool."""
|
||||
self.cleanup()
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.cleanup()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def _shutdown_slot_pool():
    """atexit hook: best-effort teardown of the slot-pool manager."""
    try:
        _SlotPoolManager.reset_instance()
    except Exception:
        # Teardown problems must never propagate out of an atexit handler.
        pass


# Register slot-pool shutdown alongside the modal pool shutdown hook.
import atexit as _atexit_slot

_atexit_slot.register(_shutdown_slot_pool)
|
||||
|
||||
|
||||
# Tool description for LLM
|
||||
TERMINAL_TOOL_DESCRIPTION = """Execute commands on a secure Linux environment.
|
||||
|
||||
@@ -1969,21 +1664,8 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, ssh_c
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
elif env_type == "slot_pool":
|
||||
# Multiplexed sandbox pool via atropos/backends/ (Modal, Nomad/Docker, Nomad/Singularity)
|
||||
# N:M slot multiplexing for high-throughput parallel execution
|
||||
workspace = os.getenv("TERMINAL_MODAL_WORKSPACE", "/data")
|
||||
return _SlotPoolEnvironment(
|
||||
cwd=cwd or workspace,
|
||||
timeout=timeout,
|
||||
task_id=task_id if 'task_id' in dir() else "",
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown environment type: {env_type}. "
|
||||
"Use 'local', 'docker', 'singularity', 'modal', 'ssh', or 'slot_pool'"
|
||||
)
|
||||
raise ValueError(f"Unknown environment type: {env_type}. Use 'local', 'docker', 'singularity', 'modal', or 'ssh'")
|
||||
|
||||
|
||||
def _cleanup_inactive_envs(lifetime_seconds: int = 300):
|
||||
@@ -2411,14 +2093,17 @@ def check_terminal_requirements() -> bool:
|
||||
|
||||
try:
|
||||
if env_type == "local":
|
||||
from minisweagent.environments.local import LocalEnvironment
|
||||
return True
|
||||
elif env_type == "docker":
|
||||
# check actually available..
|
||||
from minisweagent.environments.docker import DockerEnvironment
|
||||
# Check if docker is available
|
||||
import subprocess
|
||||
result = subprocess.run(["docker", "version"], capture_output=True, timeout=5)
|
||||
return result.returncode == 0
|
||||
elif env_type == "singularity":
|
||||
# Check if singularity/apptainer is available (doesn't work on mac)
|
||||
from minisweagent.environments.singularity import SingularityEnvironment
|
||||
# Check if singularity/apptainer is available
|
||||
import subprocess
|
||||
import shutil
|
||||
executable = shutil.which("apptainer") or shutil.which("singularity")
|
||||
@@ -2427,11 +2112,9 @@ def check_terminal_requirements() -> bool:
|
||||
return result.returncode == 0
|
||||
return False
|
||||
elif env_type == "modal":
|
||||
# check modal is actually configured
|
||||
from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment
|
||||
# Check for modal token
|
||||
return os.getenv("MODAL_TOKEN_ID") is not None or Path.home().joinpath(".modal.toml").exists()
|
||||
elif env_type == "slot_pool":
|
||||
# Slot pool uses atropos/backends/ & always available if modal/nomad is configured
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user