mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 15:01:34 +08:00
Compare commits
6 Commits
codex-port
...
sid/tb2-ev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f7d57c0108 | ||
|
|
d77783d198 | ||
|
|
4af69097f2 | ||
|
|
59471b79e5 | ||
|
|
0e459f2b7b | ||
|
|
3befb9389f |
@@ -138,6 +138,7 @@ class HermesAgentLoop:
|
||||
max_turns: int = 30,
|
||||
task_id: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
budget_config: Optional["BudgetConfig"] = None,
|
||||
@@ -153,6 +154,7 @@ class HermesAgentLoop:
|
||||
max_turns: Maximum number of LLM calls before stopping
|
||||
task_id: Unique ID for terminal/browser session isolation
|
||||
temperature: Sampling temperature for generation
|
||||
top_p: Nucleus sampling top_p (None = omit, use provider default)
|
||||
max_tokens: Max tokens per generation (None for server default)
|
||||
extra_body: Extra parameters passed to the OpenAI client's create() call.
|
||||
Used for OpenRouter provider preferences, transforms, etc.
|
||||
@@ -168,6 +170,7 @@ class HermesAgentLoop:
|
||||
self.max_turns = max_turns
|
||||
self.task_id = task_id or str(uuid.uuid4())
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.max_tokens = max_tokens
|
||||
self.extra_body = extra_body
|
||||
self.budget_config = budget_config or DEFAULT_BUDGET
|
||||
@@ -211,6 +214,9 @@ class HermesAgentLoop:
|
||||
"temperature": self.temperature,
|
||||
}
|
||||
|
||||
if self.top_p is not None:
|
||||
chat_kwargs["top_p"] = self.top_p
|
||||
|
||||
# Only pass tools if we have them
|
||||
if self.tool_schemas:
|
||||
chat_kwargs["tools"] = self.tool_schemas
|
||||
@@ -225,20 +231,35 @@ class HermesAgentLoop:
|
||||
chat_kwargs["extra_body"] = self.extra_body
|
||||
|
||||
# Make the API call -- standard OpenAI spec
|
||||
# Retry on timeout/connection errors (provider queuing, rate limits)
|
||||
api_start = _time.monotonic()
|
||||
try:
|
||||
response = await self.server.chat_completion(**chat_kwargs)
|
||||
except Exception as e:
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
response = None
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = await self.server.chat_completion(**chat_kwargs)
|
||||
break
|
||||
except Exception as e:
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
is_retryable = "timeout" in type(e).__name__.lower() or "connection" in type(e).__name__.lower()
|
||||
if is_retryable and attempt < max_retries - 1:
|
||||
wait = 2 ** attempt
|
||||
logger.warning(
|
||||
"[%s] API call timed out on turn %d attempt %d (%.1fs), retrying in %ds: %s",
|
||||
self.task_id[:8], turn + 1, attempt + 1, api_elapsed, wait, e,
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
api_start = _time.monotonic()
|
||||
continue
|
||||
logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
|
||||
|
||||
@@ -15,15 +15,15 @@
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 60
|
||||
max_agent_turns: 100
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.8
|
||||
agent_temperature: 1.0
|
||||
terminal_backend: "modal"
|
||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
||||
tool_pool_size: 128 # thread pool for 89 parallel tasks
|
||||
dataset_name: "NousResearch/terminal-bench-2"
|
||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
||||
tool_pool_size: 128 # thread pool for 89 parallel tasks
|
||||
dataset_name: "NousResearch/terminal-bench-2-verified-flattened"
|
||||
test_timeout: 600
|
||||
task_timeout: 1800 # 30 min wall-clock per task, auto-FAIL if exceeded
|
||||
task_timeout: 900 # 15 min wall-clock per task, auto-FAIL if exceeded
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: true
|
||||
wandb_name: "terminal-bench-2"
|
||||
@@ -33,10 +33,15 @@ env:
|
||||
# Modal's blocking calls (App.lookup, etc.) deadlock when too many sandboxes
|
||||
# are created simultaneously inside thread pool workers via asyncio.run().
|
||||
max_concurrent_tasks: 8
|
||||
extra_body:
|
||||
provider:
|
||||
order: ["DeepInfra"]
|
||||
allow_fallbacks: false
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-opus-4.6"
|
||||
model_name: "nvidia/nemotron-3-super-120b-a12b"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
timeout: 300 # 5 min per API call (default 1200s causes 20min stalls)
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
|
||||
@@ -32,8 +32,8 @@ export PYTHONUNBUFFERED=1
|
||||
# These go to the log file; tqdm + [START]/[PASS]/[FAIL] go to terminal
|
||||
export LOGLEVEL=INFO
|
||||
|
||||
python terminalbench2_env.py evaluate \
|
||||
--config default.yaml \
|
||||
uv run python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--config environments/benchmarks/terminalbench_2/default.yaml \
|
||||
"$@" \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
|
||||
@@ -52,18 +52,18 @@ _repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from pydantic import Field
|
||||
|
||||
from agent.prompt_builder import DEFAULT_AGENT_IDENTITY
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
from tools.terminal_tool import (
|
||||
register_task_env_overrides,
|
||||
clear_task_env_overrides,
|
||||
cleanup_vm,
|
||||
clear_task_env_overrides,
|
||||
register_task_env_overrides,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -73,6 +73,7 @@ logger = logging.getLogger(__name__)
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TerminalBench2EvalConfig(HermesAgentEnvConfig):
|
||||
"""
|
||||
Configuration for the Terminal-Bench 2.0 evaluation environment.
|
||||
@@ -138,11 +139,27 @@ class TerminalBench2EvalConfig(HermesAgentEnvConfig):
|
||||
|
||||
# Tasks that cannot run properly on Modal and are excluded from scoring.
|
||||
MODAL_INCOMPATIBLE_TASKS = {
|
||||
"qemu-startup", # Needs KVM/hardware virtualization
|
||||
"qemu-alpine-ssh", # Needs KVM/hardware virtualization
|
||||
"crack-7z-hash", # Password brute-force -- too slow for cloud sandbox timeouts
|
||||
"qemu-startup", # Needs KVM/hardware virtualization
|
||||
"qemu-alpine-ssh", # Needs KVM/hardware virtualization
|
||||
"crack-7z-hash", # Password brute-force -- too slow for cloud sandbox timeouts
|
||||
}
|
||||
|
||||
# Injected as a user message when the model responds with plain text instead of
|
||||
# calling a tool or including a <task_status> tag.
|
||||
_FORMAT_NUDGE_MESSAGE = (
|
||||
"You wrote a plain text response instead of using your tools. "
|
||||
"Plain text responses do not affect the environment — nothing was executed or saved.\n\n"
|
||||
"You MUST use your tools (terminal, read_file, write_file) to actually complete the task. "
|
||||
"Do not describe what you would do — execute it now by making tool calls.\n\n"
|
||||
"If you have already completed all required work using tools in previous turns, "
|
||||
"respond with exactly: <task_status>DONE</task_status>\n"
|
||||
"If you have exhausted all approaches and cannot make further progress, "
|
||||
"respond with exactly: <task_status>UNFINISHED</task_status>"
|
||||
)
|
||||
|
||||
# Maximum number of format nudges before giving up and moving on to scoring.
|
||||
_MAX_FORMAT_NUDGES = 3
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tar extraction helper
|
||||
@@ -203,7 +220,6 @@ def _safe_extract_tar(tar: tarfile.TarFile, target_dir: Path) -> None:
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _extract_base64_tar(b64_data: str, target_dir: Path):
|
||||
"""Extract a base64-encoded tar.gz archive into target_dir."""
|
||||
if not b64_data:
|
||||
@@ -218,6 +234,7 @@ def _extract_base64_tar(b64_data: str, target_dir: Path):
|
||||
# Main Environment
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
Terminal-Bench 2.0 evaluation environment (eval-only, no training).
|
||||
@@ -262,23 +279,18 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
|
||||
# Agent settings -- TB2 tasks are complex, need many turns
|
||||
max_agent_turns=60,
|
||||
max_token_length=16000,
|
||||
agent_temperature=0.6,
|
||||
system_prompt=None,
|
||||
|
||||
system_prompt=DEFAULT_AGENT_IDENTITY,
|
||||
# Modal backend for per-task cloud-isolated sandboxes
|
||||
terminal_backend="modal",
|
||||
terminal_timeout=300, # 5 min per command (builds, pip install, etc.)
|
||||
|
||||
terminal_timeout=300, # 5 min per command (builds, pip install, etc.)
|
||||
# Test execution timeout (TB2 test scripts can install deps like pytest)
|
||||
test_timeout=180,
|
||||
|
||||
# 89 tasks run in parallel, each needs a thread for tool calls
|
||||
tool_pool_size=128,
|
||||
|
||||
# --- Eval-only Atropos settings ---
|
||||
# These settings make the env work as an eval-only environment:
|
||||
# - STOP_TRAIN: pauses training during eval (standard for eval envs)
|
||||
@@ -288,7 +300,6 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
group_size=1,
|
||||
steps_per_eval=1,
|
||||
total_steps=1,
|
||||
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
use_wandb=True,
|
||||
wandb_name="terminal-bench-2",
|
||||
@@ -336,7 +347,11 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
|
||||
# Skip tasks incompatible with the current backend (e.g., QEMU on Modal)
|
||||
# plus any user-specified skip_tasks
|
||||
skip = set(MODAL_INCOMPATIBLE_TASKS) if self.config.terminal_backend == "modal" else set()
|
||||
skip = (
|
||||
set(MODAL_INCOMPATIBLE_TASKS)
|
||||
if self.config.terminal_backend == "modal"
|
||||
else set()
|
||||
)
|
||||
if self.config.skip_tasks:
|
||||
skip |= {name.strip() for name in self.config.skip_tasks.split(",")}
|
||||
if skip:
|
||||
@@ -344,7 +359,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
tasks = [t for t in tasks if t["task_name"] not in skip]
|
||||
skipped = before - len(tasks)
|
||||
if skipped > 0:
|
||||
print(f" Skipped {skipped} incompatible tasks: {sorted(skip & {t['task_name'] for t in ds})}")
|
||||
print(
|
||||
f" Skipped {skipped} incompatible tasks: {sorted(skip & {t['task_name'] for t in ds})}"
|
||||
)
|
||||
|
||||
self.all_eval_items = tasks
|
||||
self.iter = 0
|
||||
@@ -354,6 +371,16 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
for i, task in enumerate(self.all_eval_items):
|
||||
self.category_index[task.get("category", "unknown")].append(i)
|
||||
|
||||
# Pre-compute which tasks need Modal's add_python (avoids re-decoding
|
||||
# multi-MB environment_tar blobs during per-task rollouts).
|
||||
self._needs_add_python: Dict[str, bool] = {
|
||||
task["task_name"]: self._image_needs_add_python(task)
|
||||
for task in self.all_eval_items
|
||||
}
|
||||
add_py_count = sum(self._needs_add_python.values())
|
||||
if add_py_count:
|
||||
print(f" {add_py_count} tasks need add_python (non-python base image)")
|
||||
|
||||
# Reward tracking for wandb logging
|
||||
self.eval_metrics: List[Tuple[str, float]] = []
|
||||
|
||||
@@ -361,15 +388,30 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# immediately on completion so data is preserved even on Ctrl+C.
|
||||
# Timestamped filename so each run produces a unique file.
|
||||
import datetime
|
||||
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
|
||||
model_name = self.server.servers[0].config.model_name
|
||||
model_slug = model_name.replace("/", "_").replace(":", "_")
|
||||
self._streaming_path = os.path.join(
|
||||
log_dir, f"samples_{run_ts}_{model_slug}.jsonl"
|
||||
)
|
||||
self._streaming_file = open(self._streaming_path, "w")
|
||||
self._streaming_lock = __import__("threading").Lock()
|
||||
self._run_meta = {
|
||||
"model_name": model_name,
|
||||
"temperature": self.config.agent_temperature,
|
||||
"top_p": self.config.agent_top_p,
|
||||
"max_agent_turns": self.config.max_agent_turns,
|
||||
"task_timeout": self.config.task_timeout,
|
||||
"terminal_backend": self.config.terminal_backend,
|
||||
}
|
||||
print(f" Streaming results to: {self._streaming_path}")
|
||||
|
||||
print(f"TB2 ready: {len(self.all_eval_items)} tasks across {len(self.category_index)} categories")
|
||||
print(
|
||||
f"TB2 ready: {len(self.all_eval_items)} tasks across {len(self.category_index)} categories"
|
||||
)
|
||||
for cat, indices in sorted(self.category_index.items()):
|
||||
print(f" {cat}: {len(indices)} tasks")
|
||||
|
||||
@@ -378,7 +420,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
|
||||
return
|
||||
with self._streaming_lock:
|
||||
self._streaming_file.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
|
||||
self._streaming_file.write(
|
||||
json.dumps(result, ensure_ascii=False, default=str) + "\n"
|
||||
)
|
||||
self._streaming_file.flush()
|
||||
|
||||
# =========================================================================
|
||||
@@ -414,6 +458,36 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# Docker image resolution
|
||||
# =========================================================================
|
||||
|
||||
@staticmethod
|
||||
def _image_needs_add_python(item: Dict[str, Any]) -> bool:
|
||||
"""Check if the task's base image lacks `python` on PATH.
|
||||
|
||||
Parses the Dockerfile FROM line in environment_tar. Returns True
|
||||
for non-python base images (ubuntu, debian, etc.) that need
|
||||
Modal's add_python parameter.
|
||||
"""
|
||||
environment_tar = item.get("environment_tar", "")
|
||||
if not environment_tar:
|
||||
return False
|
||||
try:
|
||||
raw = base64.b64decode(environment_tar)
|
||||
buf = io.BytesIO(raw)
|
||||
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
|
||||
for member in tar:
|
||||
if not member.isfile() or "Dockerfile" not in member.name:
|
||||
continue
|
||||
f = tar.extractfile(member)
|
||||
if not f:
|
||||
continue
|
||||
for line in f.read().decode("utf-8", errors="ignore").splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.upper().startswith("FROM "):
|
||||
base = stripped.split()[1].lower()
|
||||
return not base.startswith("python:")
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _resolve_task_image(
|
||||
self, item: Dict[str, Any], task_name: str
|
||||
) -> Tuple[str, Optional[Path]]:
|
||||
@@ -446,7 +520,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
if dockerfile_path.exists():
|
||||
logger.info(
|
||||
"Task %s: building from Dockerfile (force_build=%s, docker_image=%s)",
|
||||
task_name, self.config.force_build, bool(docker_image),
|
||||
task_name,
|
||||
self.config.force_build,
|
||||
bool(docker_image),
|
||||
)
|
||||
return str(dockerfile_path), task_dir
|
||||
|
||||
@@ -454,12 +530,80 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
if docker_image:
|
||||
logger.warning(
|
||||
"Task %s: force_build=True but no environment_tar, "
|
||||
"falling back to docker_image %s", task_name, docker_image,
|
||||
"falling back to docker_image %s",
|
||||
task_name,
|
||||
docker_image,
|
||||
)
|
||||
return docker_image, None
|
||||
|
||||
return "", None
|
||||
|
||||
# =========================================================================
|
||||
# Agent loop with format nudging
|
||||
# =========================================================================
|
||||
|
||||
async def _run_with_nudges(
|
||||
self,
|
||||
server,
|
||||
tools: List[Dict[str, Any]],
|
||||
valid_names: set,
|
||||
messages: List[Dict[str, Any]],
|
||||
task_id: str,
|
||||
task_name: str,
|
||||
) -> Tuple["AgentResult", int]:
|
||||
"""Run the agent loop, nudging if the model returns plain text without task_status tag."""
|
||||
total_turns_used = 0
|
||||
nudge_count = 0
|
||||
result = None
|
||||
|
||||
while total_turns_used < self.config.max_agent_turns:
|
||||
remaining = self.config.max_agent_turns - total_turns_used
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=remaining,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
top_p=self.config.agent_top_p,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
total_turns_used += result.turns_used
|
||||
|
||||
if not result.finished_naturally:
|
||||
break
|
||||
|
||||
last_content = next(
|
||||
(
|
||||
m.get("content", "") or ""
|
||||
for m in reversed(messages)
|
||||
if m.get("role") == "assistant"
|
||||
),
|
||||
"",
|
||||
)
|
||||
if "<task_status>" in last_content:
|
||||
break
|
||||
|
||||
if nudge_count >= _MAX_FORMAT_NUDGES:
|
||||
logger.warning(
|
||||
"Task %s: model ignored %d format nudges; stopping.",
|
||||
task_name,
|
||||
nudge_count,
|
||||
)
|
||||
break
|
||||
nudge_count += 1
|
||||
logger.info(
|
||||
"Task %s: nudging model (nudge %d/%d) — no tool calls and no task_status",
|
||||
task_name,
|
||||
nudge_count,
|
||||
_MAX_FORMAT_NUDGES,
|
||||
)
|
||||
messages.append({"role": "user", "content": _FORMAT_NUDGE_MESSAGE})
|
||||
|
||||
return result, total_turns_used
|
||||
|
||||
# =========================================================================
|
||||
# Per-task evaluation -- agent loop + test verification
|
||||
# =========================================================================
|
||||
@@ -488,6 +632,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
task_dir = None # Set if we extract a Dockerfile (needs cleanup)
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
tqdm.write(f" [START] {task_name} (task_id={task_id[:8]})")
|
||||
task_start = time.time()
|
||||
|
||||
@@ -495,24 +640,32 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# --- 1. Resolve Docker image ---
|
||||
modal_image, task_dir = self._resolve_task_image(eval_item, task_name)
|
||||
if not modal_image:
|
||||
logger.error("Task %s: no docker_image or environment_tar, skipping", task_name)
|
||||
logger.error(
|
||||
"Task %s: no docker_image or environment_tar, skipping", task_name
|
||||
)
|
||||
return {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
"passed": False,
|
||||
"reward": 0.0,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"error": "no_image",
|
||||
}
|
||||
|
||||
# --- 2. Register per-task image override ---
|
||||
# Set both modal_image and docker_image so the task image is used
|
||||
# regardless of which backend is configured.
|
||||
register_task_env_overrides(task_id, {
|
||||
overrides = {
|
||||
"modal_image": modal_image,
|
||||
"docker_image": modal_image,
|
||||
"cwd": "/app",
|
||||
})
|
||||
}
|
||||
if self._needs_add_python.get(task_name, False):
|
||||
overrides["add_python"] = "3.12"
|
||||
register_task_env_overrides(task_id, overrides)
|
||||
logger.info(
|
||||
"Task %s: registered image override for task_id %s",
|
||||
task_name, task_id[:8],
|
||||
task_name,
|
||||
task_id[:8],
|
||||
)
|
||||
|
||||
# --- 3. Resolve tools and build messages ---
|
||||
@@ -520,53 +673,48 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
|
||||
messages: List[Dict[str, Any]] = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append(
|
||||
{"role": "system", "content": self.config.system_prompt}
|
||||
)
|
||||
messages.append({"role": "user", "content": self.format_prompt(eval_item)})
|
||||
|
||||
# --- 4. Run agent loop ---
|
||||
# Use ManagedServer (Phase 2) for vLLM/SGLang backends to get
|
||||
# token-level tracking via /generate. Falls back to direct
|
||||
# ServerManager (Phase 1) for OpenAI endpoints.
|
||||
# --- 4. Run agent loop with format enforcement ---
|
||||
# The model must either call a tool or end with <task_status>DONE/UNFINISHED</task_status>.
|
||||
# If it returns plain text without the tag, inject a nudge user message and
|
||||
# continue with the remaining turn budget (up to _MAX_FORMAT_NUDGES times).
|
||||
if self._use_managed_server():
|
||||
async with self.server.managed_server(
|
||||
tokenizer=self.tokenizer,
|
||||
preserve_think_blocks=bool(self.config.thinking_mode),
|
||||
) as managed:
|
||||
agent = HermesAgentLoop(
|
||||
result, total_turns_used = await self._run_with_nudges(
|
||||
server=managed,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
tools=tools,
|
||||
valid_names=valid_names,
|
||||
messages=messages,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
task_name=task_name,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
else:
|
||||
agent = HermesAgentLoop(
|
||||
result, total_turns_used = await self._run_with_nudges(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
tools=tools,
|
||||
valid_names=valid_names,
|
||||
messages=messages,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
task_name=task_name,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# --- 5. Verify -- run test suite in the agent's sandbox ---
|
||||
# Skip verification if the agent produced no meaningful output
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
msg.get("role") in ("system", "user") for msg in messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
if total_turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Task %s: agent produced no output (turns=%d). Reward=0.",
|
||||
task_name, result.turns_used,
|
||||
task_name,
|
||||
total_turns_used,
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
@@ -578,7 +726,10 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
loop = asyncio.get_event_loop()
|
||||
reward = await loop.run_in_executor(
|
||||
None, # default thread pool
|
||||
self._run_tests, eval_item, ctx, task_name,
|
||||
self._run_tests,
|
||||
eval_item,
|
||||
ctx,
|
||||
task_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Task %s: test verification failed: %s", task_name, e)
|
||||
@@ -589,20 +740,26 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
passed = reward == 1.0
|
||||
status = "PASS" if passed else "FAIL"
|
||||
elapsed = time.time() - task_start
|
||||
tqdm.write(f" [{status}] {task_name} (turns={result.turns_used}, {elapsed:.0f}s)")
|
||||
tqdm.write(
|
||||
f" [{status}] {task_name} (turns={total_turns_used}, {elapsed:.0f}s)"
|
||||
)
|
||||
logger.info(
|
||||
"Task %s: reward=%.1f, turns=%d, finished=%s",
|
||||
task_name, reward, result.turns_used, result.finished_naturally,
|
||||
task_name,
|
||||
reward,
|
||||
total_turns_used,
|
||||
result.finished_naturally,
|
||||
)
|
||||
|
||||
out = {
|
||||
**self._run_meta,
|
||||
"passed": passed,
|
||||
"reward": reward,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"turns_used": result.turns_used,
|
||||
"turns_used": total_turns_used,
|
||||
"finished_naturally": result.finished_naturally,
|
||||
"messages": result.messages,
|
||||
"messages": messages,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
@@ -612,8 +769,11 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
logger.error("Task %s: rollout failed: %s", task_name, e, exc_info=True)
|
||||
tqdm.write(f" [ERROR] {task_name}: {e} ({elapsed:.0f}s)")
|
||||
out = {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
**self._run_meta,
|
||||
"passed": False,
|
||||
"reward": 0.0,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"error": str(e),
|
||||
}
|
||||
self._save_result(out)
|
||||
@@ -686,7 +846,8 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# Execute the test suite
|
||||
logger.info(
|
||||
"Task %s: running test suite (timeout=%ds)",
|
||||
task_name, self.config.test_timeout,
|
||||
task_name,
|
||||
self.config.test_timeout,
|
||||
)
|
||||
test_result = ctx.terminal(
|
||||
"bash /tests/test.sh",
|
||||
@@ -719,7 +880,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
logger.warning(
|
||||
"Task %s: reward.txt content unexpected (%r), "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, content, exit_code,
|
||||
task_name,
|
||||
content,
|
||||
exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
else:
|
||||
@@ -727,14 +890,17 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
logger.warning(
|
||||
"Task %s: reward.txt not found after download, "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, exit_code,
|
||||
task_name,
|
||||
exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Task %s: failed to download verifier dir: %s, "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, e, exit_code,
|
||||
task_name,
|
||||
e,
|
||||
exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
finally:
|
||||
@@ -745,7 +911,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
output_preview = output[-500:] if output else "(no output)"
|
||||
logger.info(
|
||||
"Task %s: FAIL (exit_code=%d)\n%s",
|
||||
task_name, exit_code, output_preview,
|
||||
task_name,
|
||||
exit_code,
|
||||
output_preview,
|
||||
)
|
||||
|
||||
return reward
|
||||
@@ -770,12 +938,18 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
from tqdm import tqdm
|
||||
|
||||
elapsed = self.config.task_timeout
|
||||
tqdm.write(f" [TIMEOUT] {task_name} (exceeded {elapsed}s wall-clock limit)")
|
||||
tqdm.write(
|
||||
f" [TIMEOUT] {task_name} (exceeded {elapsed}s wall-clock limit)"
|
||||
)
|
||||
logger.error("Task %s: wall-clock timeout after %ds", task_name, elapsed)
|
||||
out = {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
**self._run_meta,
|
||||
"passed": False,
|
||||
"reward": 0.0,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"error": f"timeout ({elapsed}s)",
|
||||
}
|
||||
self._save_result(out)
|
||||
@@ -809,23 +983,25 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
self.handleError(record)
|
||||
|
||||
handler = _TqdmHandler()
|
||||
handler.setFormatter(logging.Formatter(
|
||||
"%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
))
|
||||
handler.setFormatter(
|
||||
logging.Formatter(
|
||||
"%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
)
|
||||
root = logging.getLogger()
|
||||
root.handlers = [handler] # Replace any existing handlers
|
||||
root.setLevel(logging.INFO)
|
||||
|
||||
# Silence noisy third-party loggers that flood the output
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING) # Every HTTP request
|
||||
logging.getLogger("openai").setLevel(logging.WARNING) # OpenAI client retries
|
||||
logging.getLogger("rex-deploy").setLevel(logging.WARNING) # Swerex deployment
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING) # Every HTTP request
|
||||
logging.getLogger("openai").setLevel(logging.WARNING) # OpenAI client retries
|
||||
logging.getLogger("rex-deploy").setLevel(logging.WARNING) # Swerex deployment
|
||||
logging.getLogger("rex_image_builder").setLevel(logging.WARNING) # Image builds
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"\n{'=' * 60}")
|
||||
print("Starting Terminal-Bench 2.0 Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f" Dataset: {self.config.dataset_name}")
|
||||
print(f" Total tasks: {len(self.all_eval_items)}")
|
||||
print(f" Max agent turns: {self.config.max_agent_turns}")
|
||||
@@ -833,9 +1009,11 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
print(f" Terminal backend: {self.config.terminal_backend}")
|
||||
print(f" Tool thread pool: {self.config.tool_pool_size}")
|
||||
print(f" Terminal timeout: {self.config.terminal_timeout}s/cmd")
|
||||
print(f" Terminal lifetime: {self.config.terminal_lifetime}s (auto: task_timeout + 120)")
|
||||
print(
|
||||
f" Terminal lifetime: {self.config.terminal_lifetime}s (auto: task_timeout + 120)"
|
||||
)
|
||||
print(f" Max concurrent tasks: {self.config.max_concurrent_tasks}")
|
||||
print(f"{'='*60}\n")
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
# Semaphore to limit concurrent Modal sandbox creations.
|
||||
# Without this, all 86 tasks fire simultaneously, each creating a Modal
|
||||
@@ -877,6 +1055,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
await asyncio.gather(*eval_tasks, return_exceptions=True)
|
||||
# Belt-and-suspenders: clean up any remaining sandboxes
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
|
||||
cleanup_all_environments()
|
||||
print("All sandboxes cleaned up.")
|
||||
return
|
||||
@@ -922,9 +1101,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
|
||||
|
||||
# ---- Print summary ----
|
||||
print(f"\n{'='*60}")
|
||||
print(f"\n{'=' * 60}")
|
||||
print("Terminal-Bench 2.0 Evaluation Results")
|
||||
print(f"{'='*60}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"Overall Pass Rate: {overall_pass_rate:.4f} ({passed}/{total})")
|
||||
print(f"Evaluation Time: {end_time - start_time:.1f} seconds")
|
||||
|
||||
@@ -944,7 +1123,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
extra = f" (error: {error})" if error else ""
|
||||
print(f" [{status}] {r['task_name']} (turns={turns}){extra}")
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
# Build sample records for evaluate_log (includes full conversations)
|
||||
samples = [
|
||||
@@ -969,6 +1148,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": self.config.agent_temperature,
|
||||
"top_p": self.config.agent_top_p,
|
||||
"max_tokens": self.config.max_token_length,
|
||||
"max_agent_turns": self.config.max_agent_turns,
|
||||
"terminal_backend": self.config.terminal_backend,
|
||||
@@ -985,6 +1165,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# Kill all remaining sandboxes. Timed-out tasks leave orphaned thread
|
||||
# pool workers still executing commands -- cleanup_all stops them.
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
|
||||
print("\nCleaning up all sandboxes...")
|
||||
cleanup_all_environments()
|
||||
|
||||
@@ -992,6 +1173,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# tasks are killed immediately instead of retrying against dead
|
||||
# sandboxes and spamming the console with TimeoutError warnings.
|
||||
from environments.agent_loop import _tool_executor
|
||||
|
||||
_tool_executor.shutdown(wait=False, cancel_futures=True)
|
||||
print("Done.")
|
||||
|
||||
|
||||
@@ -115,6 +115,10 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
default=1.0,
|
||||
description="Sampling temperature for agent generation during rollouts.",
|
||||
)
|
||||
agent_top_p: Optional[float] = Field(
|
||||
default=None,
|
||||
description="Nucleus sampling top_p for agent generation. None = provider default.",
|
||||
)
|
||||
|
||||
# --- Terminal backend ---
|
||||
terminal_backend: str = Field(
|
||||
@@ -529,6 +533,7 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
top_p=self.config.agent_top_p,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
@@ -547,6 +552,7 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
top_p=self.config.agent_top_p,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
@@ -561,6 +567,7 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
top_p=self.config.agent_top_p,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
|
||||
401
tests/tools/test_modal_sandbox_timeout.py
Normal file
401
tests/tools/test_modal_sandbox_timeout.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""Tests verifying the Modal sandbox timeout bug fix.
|
||||
|
||||
Bug: `lifetime_seconds` from container_config was never passed through to
|
||||
`sandbox_kwargs["timeout"]`, so Modal always used its default of 3600s.
|
||||
|
||||
Fix applied to:
|
||||
- tools/terminal_tool.py: `_create_environment()` now sets
|
||||
`sandbox_kwargs["timeout"]` from `cc.get("lifetime_seconds", 3600)`
|
||||
- tools/terminal_tool.py: `container_config` dict now includes
|
||||
`"lifetime_seconds"` from config
|
||||
- tools/environments/managed_modal.py: `_create_sandbox()` reads timeout
|
||||
from `self._sandbox_kwargs` instead of hardcoding 3_600_000
|
||||
"""
|
||||
|
||||
import sys
|
||||
import types
|
||||
import tempfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Repo root on sys.path
|
||||
# ---------------------------------------------------------------------------
|
||||
_repo_root = Path(__file__).resolve().parents[2]
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load terminal_tool (may be skipped if deps are missing)
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
import tools.terminal_tool as _tt_mod
|
||||
except ImportError:
|
||||
pytest.skip("tools.terminal_tool not importable (missing deps)", allow_module_level=True)
|
||||
|
||||
TOOLS_DIR = _repo_root / "tools"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers shared with test_managed_modal_environment.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _reset_modules(prefixes: tuple):
|
||||
for name in list(sys.modules):
|
||||
if name.startswith(prefixes):
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
|
||||
def _install_fake_tools_package(*, credential_mounts=None):
|
||||
"""Install a minimal fake tools package so managed_modal.py can be loaded
|
||||
without network access or real Modal credentials."""
|
||||
_reset_modules(("tools", "agent", "hermes_cli"))
|
||||
|
||||
hermes_cli = types.ModuleType("hermes_cli")
|
||||
hermes_cli.__path__ = [] # type: ignore[attr-defined]
|
||||
sys.modules["hermes_cli"] = hermes_cli
|
||||
sys.modules["hermes_cli.config"] = types.SimpleNamespace(
|
||||
get_hermes_home=lambda: Path(tempfile.gettempdir()) / "hermes-home",
|
||||
)
|
||||
|
||||
tools_package = types.ModuleType("tools")
|
||||
tools_package.__path__ = [str(TOOLS_DIR)] # type: ignore[attr-defined]
|
||||
sys.modules["tools"] = tools_package
|
||||
|
||||
env_package = types.ModuleType("tools.environments")
|
||||
env_package.__path__ = [str(TOOLS_DIR / "environments")] # type: ignore[attr-defined]
|
||||
sys.modules["tools.environments"] = env_package
|
||||
|
||||
interrupt_event = threading.Event()
|
||||
sys.modules["tools.interrupt"] = types.SimpleNamespace(
|
||||
set_interrupt=lambda value=True: interrupt_event.set() if value else interrupt_event.clear(),
|
||||
is_interrupted=lambda: interrupt_event.is_set(),
|
||||
_interrupt_event=interrupt_event,
|
||||
)
|
||||
|
||||
class _DummyBaseEnvironment:
|
||||
def __init__(self, cwd: str = "/root", timeout: int = 60, env=None):
|
||||
self.cwd = cwd
|
||||
self.timeout = timeout
|
||||
self.env = env or {}
|
||||
|
||||
sys.modules["tools.environments.base"] = types.SimpleNamespace(
|
||||
BaseEnvironment=_DummyBaseEnvironment,
|
||||
)
|
||||
sys.modules["tools.managed_tool_gateway"] = types.SimpleNamespace(
|
||||
resolve_managed_tool_gateway=lambda vendor: types.SimpleNamespace(
|
||||
vendor=vendor,
|
||||
gateway_origin="https://modal-gateway.example.com",
|
||||
nous_user_token="user-token",
|
||||
managed_mode=True,
|
||||
)
|
||||
)
|
||||
sys.modules["tools.credential_files"] = types.SimpleNamespace(
|
||||
get_credential_file_mounts=lambda: list(credential_mounts or []),
|
||||
)
|
||||
|
||||
return interrupt_event
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
"""Minimal requests.Response substitute."""
|
||||
|
||||
def __init__(self, status_code: int, payload=None):
|
||||
self.status_code = status_code
|
||||
self._payload = payload
|
||||
self.text = ""
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests: _create_environment (direct modal path)
|
||||
# ===========================================================================
|
||||
|
||||
class TestCreateEnvironmentTimeoutPassthrough:
|
||||
"""_create_environment() must set sandbox_kwargs['timeout'] from lifetime_seconds."""
|
||||
|
||||
def test_lifetime_seconds_7200_reaches_modal_environment(self, monkeypatch):
|
||||
"""When container_config has lifetime_seconds=7200, ModalEnvironment gets timeout=7200."""
|
||||
captured_kwargs = {}
|
||||
sentinel = object()
|
||||
|
||||
def _fake_modal_env(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return sentinel
|
||||
|
||||
# Force the direct backend so we hit ModalEnvironment, not ManagedModalEnvironment
|
||||
monkeypatch.setattr(
|
||||
_tt_mod,
|
||||
"_get_modal_backend_state",
|
||||
lambda _: {"selected_backend": "direct"},
|
||||
)
|
||||
monkeypatch.setattr(_tt_mod, "_ModalEnvironment", _fake_modal_env)
|
||||
|
||||
result = _tt_mod._create_environment(
|
||||
env_type="modal",
|
||||
image="python:3.11",
|
||||
cwd="/root",
|
||||
timeout=60,
|
||||
container_config={"lifetime_seconds": 7200},
|
||||
)
|
||||
|
||||
assert result is sentinel, "Should have used our fake ModalEnvironment"
|
||||
modal_sandbox_kwargs = captured_kwargs.get("modal_sandbox_kwargs", {})
|
||||
assert modal_sandbox_kwargs.get("timeout") == 7200, (
|
||||
f"Expected timeout=7200 in modal_sandbox_kwargs, got: {modal_sandbox_kwargs}"
|
||||
)
|
||||
|
||||
def test_lifetime_seconds_defaults_to_3600_when_absent(self, monkeypatch):
|
||||
"""When lifetime_seconds is not in container_config, timeout defaults to 3600."""
|
||||
captured_kwargs = {}
|
||||
sentinel = object()
|
||||
|
||||
def _fake_modal_env(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return sentinel
|
||||
|
||||
monkeypatch.setattr(
|
||||
_tt_mod,
|
||||
"_get_modal_backend_state",
|
||||
lambda _: {"selected_backend": "direct"},
|
||||
)
|
||||
monkeypatch.setattr(_tt_mod, "_ModalEnvironment", _fake_modal_env)
|
||||
|
||||
result = _tt_mod._create_environment(
|
||||
env_type="modal",
|
||||
image="python:3.11",
|
||||
cwd="/root",
|
||||
timeout=60,
|
||||
container_config={}, # no lifetime_seconds
|
||||
)
|
||||
|
||||
assert result is sentinel
|
||||
modal_sandbox_kwargs = captured_kwargs.get("modal_sandbox_kwargs", {})
|
||||
assert modal_sandbox_kwargs.get("timeout") == 3600, (
|
||||
f"Expected default timeout=3600, got: {modal_sandbox_kwargs}"
|
||||
)
|
||||
|
||||
def test_lifetime_seconds_none_container_config_defaults_to_3600(self, monkeypatch):
|
||||
"""When container_config is None, timeout defaults to 3600."""
|
||||
captured_kwargs = {}
|
||||
sentinel = object()
|
||||
|
||||
def _fake_modal_env(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return sentinel
|
||||
|
||||
monkeypatch.setattr(
|
||||
_tt_mod,
|
||||
"_get_modal_backend_state",
|
||||
lambda _: {"selected_backend": "direct"},
|
||||
)
|
||||
monkeypatch.setattr(_tt_mod, "_ModalEnvironment", _fake_modal_env)
|
||||
|
||||
result = _tt_mod._create_environment(
|
||||
env_type="modal",
|
||||
image="python:3.11",
|
||||
cwd="/root",
|
||||
timeout=60,
|
||||
container_config=None, # None container_config
|
||||
)
|
||||
|
||||
assert result is sentinel
|
||||
modal_sandbox_kwargs = captured_kwargs.get("modal_sandbox_kwargs", {})
|
||||
assert modal_sandbox_kwargs.get("timeout") == 3600, (
|
||||
f"Expected default timeout=3600, got: {modal_sandbox_kwargs}"
|
||||
)
|
||||
|
||||
def test_lifetime_seconds_7200_reaches_managed_modal_environment(self, monkeypatch):
|
||||
"""When managed backend is selected, ManagedModalEnvironment also gets timeout=7200."""
|
||||
captured_kwargs = {}
|
||||
sentinel = object()
|
||||
|
||||
def _fake_managed_env(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return sentinel
|
||||
|
||||
monkeypatch.setattr(
|
||||
_tt_mod,
|
||||
"_get_modal_backend_state",
|
||||
lambda _: {"selected_backend": "managed"},
|
||||
)
|
||||
monkeypatch.setattr(_tt_mod, "_ManagedModalEnvironment", _fake_managed_env)
|
||||
|
||||
result = _tt_mod._create_environment(
|
||||
env_type="modal",
|
||||
image="python:3.11",
|
||||
cwd="/root",
|
||||
timeout=60,
|
||||
container_config={"lifetime_seconds": 7200},
|
||||
)
|
||||
|
||||
assert result is sentinel
|
||||
modal_sandbox_kwargs = captured_kwargs.get("modal_sandbox_kwargs", {})
|
||||
assert modal_sandbox_kwargs.get("timeout") == 7200, (
|
||||
f"Expected timeout=7200 in modal_sandbox_kwargs for managed env, got: {modal_sandbox_kwargs}"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests: container_config includes lifetime_seconds from _get_env_config
|
||||
# ===========================================================================
|
||||
|
||||
class TestContainerConfigLifetimeSeconds:
|
||||
"""container_config dict built in terminal_tool must include lifetime_seconds."""
|
||||
|
||||
def test_container_config_includes_lifetime_seconds_from_env(self, monkeypatch):
|
||||
"""TERMINAL_LIFETIME_SECONDS env var flows into container_config."""
|
||||
monkeypatch.setenv("TERMINAL_ENV", "modal")
|
||||
monkeypatch.setenv("TERMINAL_LIFETIME_SECONDS", "7200")
|
||||
config = _tt_mod._get_env_config()
|
||||
assert config.get("lifetime_seconds") == 7200, (
|
||||
f"Expected lifetime_seconds=7200 in config, got: {config.get('lifetime_seconds')}"
|
||||
)
|
||||
|
||||
def test_container_config_lifetime_seconds_default_is_300(self, monkeypatch):
|
||||
"""Without TERMINAL_LIFETIME_SECONDS, the default should be 300 (cleanup thread default)."""
|
||||
monkeypatch.setenv("TERMINAL_ENV", "modal")
|
||||
monkeypatch.delenv("TERMINAL_LIFETIME_SECONDS", raising=False)
|
||||
config = _tt_mod._get_env_config()
|
||||
assert "lifetime_seconds" in config, "lifetime_seconds must be present in config"
|
||||
# Default from code is 300
|
||||
assert config["lifetime_seconds"] == 300, (
|
||||
f"Expected default lifetime_seconds=300, got: {config['lifetime_seconds']}"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests: ManagedModalEnvironment._create_sandbox uses sandbox_kwargs timeout
|
||||
# ===========================================================================
|
||||
|
||||
class TestManagedModalTimeoutPassthrough:
|
||||
"""ManagedModalEnvironment must read timeout from sandbox_kwargs, not hardcode 3_600_000."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _restore_modules(self):
|
||||
"""Save and restore sys.modules so fake package doesn't leak."""
|
||||
saved = {
|
||||
name: mod for name, mod in sys.modules.items()
|
||||
if name.startswith(("tools", "hermes_cli"))
|
||||
}
|
||||
yield
|
||||
_reset_modules(("tools", "hermes_cli"))
|
||||
sys.modules.update(saved)
|
||||
|
||||
def test_sandbox_created_with_7200_timeout(self, monkeypatch):
|
||||
"""ManagedModalEnvironment with lifetime_seconds=7200 sends timeoutMs=7_200_000."""
|
||||
_install_fake_tools_package()
|
||||
|
||||
# Load managed_modal fresh after installing fake package
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
spec = spec_from_file_location(
|
||||
"tools.environments.managed_modal",
|
||||
TOOLS_DIR / "environments" / "managed_modal.py",
|
||||
)
|
||||
managed_modal = module_from_spec(spec)
|
||||
sys.modules["tools.environments.managed_modal"] = managed_modal
|
||||
spec.loader.exec_module(managed_modal)
|
||||
|
||||
create_payloads = []
|
||||
|
||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||
if method == "POST" and url.endswith("/v1/sandboxes"):
|
||||
create_payloads.append(json)
|
||||
return _FakeResponse(200, {"id": "sandbox-1"})
|
||||
if method == "POST" and url.endswith("/terminate"):
|
||||
return _FakeResponse(200, {"status": "terminated"})
|
||||
raise AssertionError(f"Unexpected request: {method} {url}")
|
||||
|
||||
monkeypatch.setattr(managed_modal.requests, "request", fake_request)
|
||||
|
||||
env = managed_modal.ManagedModalEnvironment(
|
||||
image="python:3.11",
|
||||
modal_sandbox_kwargs={"timeout": 7200},
|
||||
)
|
||||
env.cleanup()
|
||||
|
||||
assert len(create_payloads) == 1
|
||||
payload = create_payloads[0]
|
||||
assert payload["timeoutMs"] == 7_200_000, (
|
||||
f"Expected timeoutMs=7_200_000 (7200s * 1000), got: {payload['timeoutMs']}. "
|
||||
"ManagedModalEnvironment must read timeout from sandbox_kwargs, not hardcode 3600."
|
||||
)
|
||||
|
||||
def test_sandbox_created_with_default_3600_timeout(self, monkeypatch):
|
||||
"""ManagedModalEnvironment with no explicit timeout sends timeoutMs=3_600_000."""
|
||||
_install_fake_tools_package()
|
||||
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
spec = spec_from_file_location(
|
||||
"tools.environments.managed_modal",
|
||||
TOOLS_DIR / "environments" / "managed_modal.py",
|
||||
)
|
||||
managed_modal = module_from_spec(spec)
|
||||
sys.modules["tools.environments.managed_modal"] = managed_modal
|
||||
spec.loader.exec_module(managed_modal)
|
||||
|
||||
create_payloads = []
|
||||
|
||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||
if method == "POST" and url.endswith("/v1/sandboxes"):
|
||||
create_payloads.append(json)
|
||||
return _FakeResponse(200, {"id": "sandbox-1"})
|
||||
if method == "POST" and url.endswith("/terminate"):
|
||||
return _FakeResponse(200, {"status": "terminated"})
|
||||
raise AssertionError(f"Unexpected request: {method} {url}")
|
||||
|
||||
monkeypatch.setattr(managed_modal.requests, "request", fake_request)
|
||||
|
||||
env = managed_modal.ManagedModalEnvironment(
|
||||
image="python:3.11",
|
||||
modal_sandbox_kwargs={}, # no timeout key — should default to 3600
|
||||
)
|
||||
env.cleanup()
|
||||
|
||||
assert len(create_payloads) == 1
|
||||
payload = create_payloads[0]
|
||||
assert payload["timeoutMs"] == 3_600_000, (
|
||||
f"Expected default timeoutMs=3_600_000, got: {payload['timeoutMs']}"
|
||||
)
|
||||
|
||||
def test_sandbox_created_with_none_kwargs_defaults_to_3600(self, monkeypatch):
|
||||
"""ManagedModalEnvironment with modal_sandbox_kwargs=None defaults to 3600."""
|
||||
_install_fake_tools_package()
|
||||
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
spec = spec_from_file_location(
|
||||
"tools.environments.managed_modal",
|
||||
TOOLS_DIR / "environments" / "managed_modal.py",
|
||||
)
|
||||
managed_modal = module_from_spec(spec)
|
||||
sys.modules["tools.environments.managed_modal"] = managed_modal
|
||||
spec.loader.exec_module(managed_modal)
|
||||
|
||||
create_payloads = []
|
||||
|
||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||
if method == "POST" and url.endswith("/v1/sandboxes"):
|
||||
create_payloads.append(json)
|
||||
return _FakeResponse(200, {"id": "sandbox-1"})
|
||||
if method == "POST" and url.endswith("/terminate"):
|
||||
return _FakeResponse(200, {"status": "terminated"})
|
||||
raise AssertionError(f"Unexpected request: {method} {url}")
|
||||
|
||||
monkeypatch.setattr(managed_modal.requests, "request", fake_request)
|
||||
|
||||
env = managed_modal.ManagedModalEnvironment(
|
||||
image="python:3.11",
|
||||
modal_sandbox_kwargs=None,
|
||||
)
|
||||
env.cleanup()
|
||||
|
||||
assert len(create_payloads) == 1
|
||||
payload = create_payloads[0]
|
||||
assert payload["timeoutMs"] == 3_600_000, (
|
||||
f"Expected default timeoutMs=3_600_000, got: {payload['timeoutMs']}"
|
||||
)
|
||||
@@ -185,7 +185,7 @@ class ManagedModalEnvironment(BaseModalExecutionEnvironment):
|
||||
"cwd": self.cwd,
|
||||
"cpu": cpu,
|
||||
"memoryMiB": memory,
|
||||
"timeoutMs": 3_600_000,
|
||||
"timeoutMs": int(self._sandbox_kwargs.get("timeout", 3600)) * 1000,
|
||||
"idleTimeoutMs": max(300_000, int(self.timeout * 1000)),
|
||||
"persistentFilesystem": self._persistent,
|
||||
"logicalKey": self._task_id,
|
||||
|
||||
@@ -153,6 +153,7 @@ class ModalEnvironment(BaseEnvironment):
|
||||
modal_sandbox_kwargs: Optional[Dict[str, Any]] = None,
|
||||
persistent_filesystem: bool = True,
|
||||
task_id: str = "default",
|
||||
add_python: Optional[str] = None,
|
||||
):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
|
||||
|
||||
@@ -214,6 +214,7 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
|
||||
image = ""
|
||||
|
||||
cwd = overrides.get("cwd") or config["cwd"]
|
||||
add_python = overrides.get("add_python")
|
||||
logger.info("Creating new %s environment for task %s...", env_type, task_id[:8])
|
||||
|
||||
container_config = None
|
||||
@@ -252,6 +253,7 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
|
||||
local_config=local_config,
|
||||
task_id=task_id,
|
||||
host_cwd=config.get("host_cwd"),
|
||||
add_python=add_python,
|
||||
)
|
||||
|
||||
with _env_lock:
|
||||
|
||||
@@ -458,6 +458,7 @@ def register_task_env_overrides(task_id: str, overrides: Dict[str, Any]):
|
||||
- modal_image: str -- Path to Dockerfile or Docker Hub image name
|
||||
- docker_image: str -- Docker image name
|
||||
- cwd: str -- Working directory inside the sandbox
|
||||
- add_python: str -- Python version for Modal's add_python (for images without python on PATH)
|
||||
|
||||
Args:
|
||||
task_id: The rollout's unique task identifier
|
||||
@@ -584,7 +585,8 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
||||
ssh_config: dict = None, container_config: dict = None,
|
||||
local_config: dict = None,
|
||||
task_id: str = "default",
|
||||
host_cwd: str = None):
|
||||
host_cwd: str = None,
|
||||
add_python: str = None):
|
||||
"""
|
||||
Create an execution environment for sandboxed command execution.
|
||||
|
||||
@@ -634,6 +636,8 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
||||
|
||||
elif env_type == "modal":
|
||||
sandbox_kwargs = {}
|
||||
lifetime = cc.get("lifetime_seconds", 3600)
|
||||
sandbox_kwargs["timeout"] = int(lifetime)
|
||||
if cpu > 0:
|
||||
sandbox_kwargs["cpu"] = cpu
|
||||
if memory > 0:
|
||||
@@ -682,6 +686,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
||||
image=image, cwd=cwd, timeout=timeout,
|
||||
modal_sandbox_kwargs=sandbox_kwargs,
|
||||
persistent_filesystem=persistent, task_id=task_id,
|
||||
add_python=add_python,
|
||||
)
|
||||
|
||||
elif env_type == "daytona":
|
||||
@@ -1057,6 +1062,7 @@ def terminal_tool(
|
||||
image = ""
|
||||
|
||||
cwd = overrides.get("cwd") or config["cwd"]
|
||||
add_python = overrides.get("add_python")
|
||||
default_timeout = config["timeout"]
|
||||
effective_timeout = timeout or default_timeout
|
||||
|
||||
@@ -1115,6 +1121,7 @@ def terminal_tool(
|
||||
"modal_mode": config.get("modal_mode", "auto"),
|
||||
"docker_volumes": config.get("docker_volumes", []),
|
||||
"docker_mount_cwd_to_workspace": config.get("docker_mount_cwd_to_workspace", False),
|
||||
"lifetime_seconds": config.get("lifetime_seconds", 3600),
|
||||
}
|
||||
|
||||
local_config = None
|
||||
@@ -1133,6 +1140,7 @@ def terminal_tool(
|
||||
local_config=local_config,
|
||||
task_id=effective_task_id,
|
||||
host_cwd=config.get("host_cwd"),
|
||||
add_python=add_python,
|
||||
)
|
||||
except ImportError as e:
|
||||
return json.dumps({
|
||||
|
||||
Reference in New Issue
Block a user