mirror of https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
Enhance async tool execution and error handling in Hermes agent for Atropos integration
- Updated `.gitignore` to exclude the `testlogs` directory.
- Refactored `handle_web_function_call` in `model_tools.py` to support running async functions inside an existing event loop, improving compatibility with Atropos.
- Introduced a thread pool executor in `agent_loop.py` for running synchronous tool calls that internally use `asyncio.run()`, preventing deadlocks.
- Added a `ToolError` class to track tool execution errors, enhancing error reporting during agent loops.
- Updated the `wandb_log` method in `hermes_base_env.py` to log tool error statistics for better monitoring.
- Implemented patches in `patches.py` to ensure async-safe operation of tools within Atropos's event loop.
- Enhanced `ToolContext` and `terminal_tool.py` to use the new async handling, improving overall tool execution reliability.
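The recurring pattern described above, and repeated throughout the diff below, is to hand a synchronous tool call that internally uses `asyncio.run()` off to a thread pool, so it gets a fresh event loop on a worker thread instead of trying to nest a loop inside the caller's. A minimal, self-contained sketch of that idea (`blocking_tool_call` and `agent_turn` are illustrative stand-ins, not the repository's API; only the `_tool_executor` name mirrors the diff):

    import asyncio
    import concurrent.futures

    # Worker threads have no running event loop of their own, so a tool that
    # calls asyncio.run() internally can create a fresh loop there safely.
    _tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)


    def blocking_tool_call(command: str) -> str:
        """Stand-in for a sync tool whose backend uses asyncio.run() internally."""
        async def _work() -> str:
            await asyncio.sleep(0.1)  # pretend to talk to a sandbox backend
            return f"ran: {command}"
        # Would raise "asyncio.run() cannot be called from a running event loop"
        # if invoked directly on the caller's loop thread.
        return asyncio.run(_work())


    async def agent_turn() -> str:
        """Async caller (e.g. a training framework's loop) dispatching the tool."""
        loop = asyncio.get_running_loop()
        # run_in_executor keeps the caller's loop responsive and gives the tool
        # a clean thread; the lambda defers the call until a worker picks it up.
        return await loop.run_in_executor(
            _tool_executor, lambda: blocking_tool_call("echo hi")
        )


    if __name__ == "__main__":
        print(asyncio.run(agent_turn()))  # -> ran: echo hi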
.gitignore (vendored): 1 line added

@@ -42,6 +42,7 @@ images/
 __pycache__/
 hermes_agent.egg-info/
 wandb/
+testlogs
 
 # CLI config (may contain sensitive SSH paths)
 cli-config.yaml
agent_loop.py

@@ -11,6 +11,8 @@ identical to hermes-agent's run_agent.py. Tool execution is dispatched via
 handle_function_call() from model_tools.py.
 """
 
+import asyncio
+import concurrent.futures
 import json
 import logging
 import uuid
@@ -19,9 +21,25 @@ from typing import Any, Dict, List, Optional, Set
 
 from model_tools import handle_function_call
 
+# Thread pool for running sync tool calls that internally use asyncio.run()
+# (e.g., mini-swe-agent's modal/docker backends). Running them in a separate
+# thread gives them a clean event loop so they don't deadlock inside Atropos's loop.
+_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
+
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class ToolError:
+    """Record of a tool execution error during the agent loop."""
+
+    turn: int  # Which turn the error occurred on
+    tool_name: str  # Which tool was called
+    arguments: str  # The arguments passed (truncated)
+    error: str  # The error message
+    tool_result: str  # The raw result returned to the model
+
+
 @dataclass
 class AgentResult:
     """Result of running the agent loop."""
@@ -36,6 +54,8 @@ class AgentResult:
     finished_naturally: bool = False
     # Extracted reasoning content per turn (from PR #297 helpers)
     reasoning_per_turn: List[Optional[str]] = field(default_factory=list)
+    # Tool errors encountered during the loop
+    tool_errors: List[ToolError] = field(default_factory=list)
 
 
 def _extract_reasoning_from_message(message) -> Optional[str]:
@@ -133,6 +153,7 @@ class HermesAgentLoop:
             AgentResult with full conversation history, managed state, and metadata
         """
         reasoning_per_turn = []
+        tool_errors: List[ToolError] = []
 
         for turn in range(self.max_turns):
             # Build the chat_completion kwargs
@@ -161,6 +182,7 @@ class HermesAgentLoop:
                     turns_used=turn + 1,
                     finished_naturally=False,
                     reasoning_per_turn=reasoning_per_turn,
+                    tool_errors=tool_errors,
                 )
 
             if not response or not response.choices:
@@ -171,6 +193,7 @@ class HermesAgentLoop:
                     turns_used=turn + 1,
                     finished_naturally=False,
                     reasoning_per_turn=reasoning_per_turn,
+                    tool_errors=tool_errors,
                 )
 
             assistant_msg = response.choices[0].message
@@ -209,6 +232,7 @@ class HermesAgentLoop:
             # Execute each tool call via hermes-agent's dispatch
             for tc in assistant_msg.tool_calls:
                 tool_name = tc.function.name
+                tool_args_raw = tc.function.arguments
 
                 # Validate tool name
                 if tool_name not in self.valid_tool_names:
@@ -218,35 +242,75 @@ class HermesAgentLoop:
                             f"Available tools: {sorted(self.valid_tool_names)}"
                         }
                     )
+                    tool_errors.append(ToolError(
+                        turn=turn + 1, tool_name=tool_name,
+                        arguments=tool_args_raw[:200],
+                        error=f"Unknown tool '{tool_name}'",
+                        tool_result=tool_result,
+                    ))
                     logger.warning(
                         "Model called unknown tool '%s' on turn %d",
-                        tool_name,
-                        turn + 1,
+                        tool_name, turn + 1,
                     )
                 else:
                     # Parse arguments and dispatch
                     try:
-                        args = json.loads(tc.function.arguments)
+                        args = json.loads(tool_args_raw)
                     except json.JSONDecodeError:
                         args = {}
                         logger.warning(
                             "Invalid JSON in tool call arguments for '%s': %s",
-                            tool_name,
-                            tc.function.arguments[:200],
+                            tool_name, tool_args_raw[:200],
                         )
 
                     try:
-                        tool_result = handle_function_call(
-                            tool_name, args, task_id=self.task_id
+                        if tool_name == "terminal":
+                            import os
+                            backend = os.getenv("TERMINAL_ENV", "local")
+                            cmd_preview = args.get("command", "")[:80]
+                            print(f" 🖥️ [{backend}] $ {cmd_preview}")
+
+                        # Run tool calls in a thread pool so backends that use
+                        # asyncio.run() internally (modal, docker) get a clean
+                        # event loop instead of deadlocking inside Atropos's loop.
+                        loop = asyncio.get_event_loop()
+                        tool_result = await loop.run_in_executor(
+                            _tool_executor,
+                            lambda: handle_function_call(
+                                tool_name, args, task_id=self.task_id
+                            ),
                         )
                     except Exception as e:
                         tool_result = json.dumps(
-                            {"error": f"Tool execution failed: {str(e)}"}
+                            {"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
                         )
+                        tool_errors.append(ToolError(
+                            turn=turn + 1, tool_name=tool_name,
+                            arguments=tool_args_raw[:200],
+                            error=f"{type(e).__name__}: {str(e)}",
+                            tool_result=tool_result,
+                        ))
                         logger.error(
-                            "Tool '%s' execution failed: %s", tool_name, e
+                            "Tool '%s' execution failed on turn %d: %s",
+                            tool_name, turn + 1, e,
                         )
 
+                    # Also check if the tool returned an error in its JSON result
+                    try:
+                        result_data = json.loads(tool_result)
+                        if isinstance(result_data, dict):
+                            err = result_data.get("error")
+                            exit_code = result_data.get("exit_code")
+                            if err and exit_code and exit_code < 0:
+                                tool_errors.append(ToolError(
+                                    turn=turn + 1, tool_name=tool_name,
+                                    arguments=tool_args_raw[:200],
+                                    error=str(err),
+                                    tool_result=tool_result[:500],
+                                ))
+                    except (json.JSONDecodeError, TypeError):
+                        pass
+
                 # Add tool response to conversation
                 messages.append(
                     {
@@ -282,6 +346,7 @@ class HermesAgentLoop:
                 turns_used=turn + 1,
                 finished_naturally=True,
                 reasoning_per_turn=reasoning_per_turn,
+                tool_errors=tool_errors,
             )
 
         # Hit max turns without the model stopping
@@ -292,6 +357,7 @@ class HermesAgentLoop:
             turns_used=self.max_turns,
             finished_naturally=False,
             reasoning_per_turn=reasoning_per_turn,
+            tool_errors=tool_errors,
        )
 
    def _get_managed_state(self) -> Optional[Dict[str, Any]]:
hermes_base_env.py

@@ -41,6 +41,12 @@ _env_path = _repo_root / ".env"
 if _env_path.exists():
     load_dotenv(dotenv_path=_env_path)
 
+# Apply monkey patches for async-safe tool operation inside Atropos's event loop.
+# This patches SwerexModalEnvironment to use a background thread instead of
+# asyncio.run(), which would deadlock inside Atropos. Safe for normal CLI too.
+from environments.patches import apply_patches
+apply_patches()
+
 from atroposlib.envs.base import (
     BaseEnv,
     BaseEnvConfig,
@@ -172,10 +178,14 @@ class HermesAgentBaseEnv(BaseEnv):
         # Set terminal backend environment variable so hermes tools pick it up
         if config.terminal_backend:
             os.environ["TERMINAL_ENV"] = config.terminal_backend
+            print(f"🖥️ Terminal backend: {config.terminal_backend}")
 
         # Current group's resolved tools (set in collect_trajectories)
         self._current_group_tools: Optional[Tuple[List[Dict], Set[str]]] = None
 
+        # Tool error tracking for wandb logging
+        self._tool_error_buffer: List[Dict[str, Any]] = []
+
     # =========================================================================
     # Toolset resolution (per-group)
     # =========================================================================
@@ -348,6 +358,33 @@ class HermesAgentBaseEnv(BaseEnv):
         if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
             self.rollouts_for_wandb.pop(0)
 
+    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
+        """Log base metrics including tool errors to wandb."""
+        if wandb_metrics is None:
+            wandb_metrics = {}
+
+        # Log tool error stats
+        if self._tool_error_buffer:
+            wandb_metrics["train/tool_errors_count"] = len(self._tool_error_buffer)
+
+            # Log error details as a summary string (tables can crash wandb on tmp cleanup)
+            error_summaries = []
+            for err in self._tool_error_buffer:
+                error_summaries.append(
+                    f"[turn {err['turn']}] {err['tool']}({err['args'][:80]}) -> {err['error'][:150]}"
+                )
+            wandb_metrics["train/tool_error_details"] = "\n".join(error_summaries)
+
+            # Also print to stdout for immediate visibility
+            for summary in error_summaries:
+                print(f" Tool Error: {summary}")
+
+            self._tool_error_buffer = []
+        else:
+            wandb_metrics["train/tool_errors_count"] = 0
+
+        await super().wandb_log(wandb_metrics)
+
     async def collect_trajectory(
         self, item: Item
     ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
@@ -376,8 +413,22 @@ class HermesAgentBaseEnv(BaseEnv):
         result: AgentResult
         if self._use_managed_server():
             # Phase 2: ManagedServer with parser -- exact tokens + logprobs
+            # Load the tool call parser from registry based on config
+            from environments.tool_call_parsers import get_parser
             try:
-                async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+                tc_parser = get_parser(self.config.tool_call_parser)
+            except KeyError:
+                logger.warning(
+                    "Tool call parser '%s' not found, falling back to 'hermes'",
+                    self.config.tool_call_parser,
+                )
+                tc_parser = get_parser("hermes")
+
+            try:
+                async with self.server.managed_server(
+                    tokenizer=self.tokenizer,
+                    tool_call_parser=tc_parser,
+                ) as managed:
                     agent = HermesAgentLoop(
                         server=managed,
                         tool_schemas=tools,
@@ -417,15 +468,39 @@ class HermesAgentBaseEnv(BaseEnv):
                     )
                     result = await agent.run(messages)
 
-        # Compute reward using ToolContext (gives verifier full tool access)
-        ctx = ToolContext(task_id)
-        try:
-            reward = await self.compute_reward(item, result, ctx)
-        except Exception as e:
-            logger.error("compute_reward failed: %s", e)
+        # Skip reward computation if the agent loop produced no meaningful work
+        # (e.g., API call failed on turn 1). No point spinning up a Modal sandbox
+        # just to verify files that were never created.
+        only_system_and_user = all(
+            msg.get("role") in ("system", "user") for msg in result.messages
+        )
+        if result.turns_used == 0 or only_system_and_user:
+            logger.warning(
+                "Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
+                result.turns_used, len(result.messages),
+            )
             reward = 0.0
-        finally:
-            ctx.cleanup()
+        else:
+            # Compute reward using ToolContext (gives verifier full tool access)
+            ctx = ToolContext(task_id)
+            try:
+                reward = await self.compute_reward(item, result, ctx)
+            except Exception as e:
+                logger.error("compute_reward failed: %s", e)
+                reward = 0.0
+            finally:
+                ctx.cleanup()
+
+        # Track tool errors for wandb logging
+        if result.tool_errors:
+            for err in result.tool_errors:
+                self._tool_error_buffer.append({
+                    "turn": err.turn,
+                    "tool": err.tool_name,
+                    "args": err.arguments[:150],
+                    "error": err.error[:300],
+                    "result": err.tool_result[:300],
+                })
 
         # Build ScoredDataItem from ManagedServer state
         # Phase 2: real tokens/masks/logprobs from SequenceNodes
environments/patches.py (new file, 188 lines)

@@ -0,0 +1,188 @@
+"""
+Monkey patches for making hermes-agent tools work inside async frameworks (Atropos).
+
+Problem:
+    Some tools use asyncio.run() internally (e.g., mini-swe-agent's Modal backend,
+    web_extract). This crashes when called from inside Atropos's event loop because
+    asyncio.run() can't be nested.
+
+Solution:
+    Replace the problematic methods with versions that use a dedicated background
+    thread with its own event loop. The calling code sees the same sync interface --
+    call a function, get a result -- but internally the async work happens on a
+    separate thread that doesn't conflict with Atropos's loop.
+
+These patches are safe for normal CLI use too: when there's no running event
+loop, the behavior is identical (the background thread approach works regardless).
+
+What gets patched:
+    - SwerexModalEnvironment.__init__ -- creates Modal deployment on a background thread
+    - SwerexModalEnvironment.execute -- runs commands on the same background thread
+    - SwerexModalEnvironment.stop -- stops deployment on the background thread
+
+Usage:
+    Call apply_patches() once at import time (done automatically by hermes_base_env.py).
+    This is idempotent -- calling it multiple times is safe.
+"""
+
+import asyncio
+import logging
+import threading
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+_patches_applied = False
+
+
+class _AsyncWorker:
+    """
+    A dedicated background thread with its own event loop.
+
+    Allows sync code to submit async coroutines and block for results,
+    even when called from inside another running event loop. Used to
+    bridge sync tool interfaces with async backends (Modal, SWE-ReX).
+    """
+
+    def __init__(self):
+        self._loop: asyncio.AbstractEventLoop = None
+        self._thread: threading.Thread = None
+        self._started = threading.Event()
+
+    def start(self):
+        """Start the background event loop thread."""
+        self._thread = threading.Thread(target=self._run_loop, daemon=True)
+        self._thread.start()
+        self._started.wait(timeout=30)
+
+    def _run_loop(self):
+        """Background thread entry point -- runs the event loop forever."""
+        self._loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self._loop)
+        self._started.set()
+        self._loop.run_forever()
+
+    def run_coroutine(self, coro, timeout=600):
+        """
+        Submit a coroutine to the background loop and block until it completes.
+
+        Safe to call from any thread, including threads that already have
+        a running event loop.
+        """
+        if self._loop is None or self._loop.is_closed():
+            raise RuntimeError("AsyncWorker loop is not running")
+        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
+        return future.result(timeout=timeout)
+
+    def stop(self):
+        """Stop the background event loop and join the thread."""
+        if self._loop and self._loop.is_running():
+            self._loop.call_soon_threadsafe(self._loop.stop)
+        if self._thread:
+            self._thread.join(timeout=10)
+
+
+def _patch_swerex_modal():
+    """
+    Monkey patch SwerexModalEnvironment to use a background thread event loop
+    instead of asyncio.run(). This makes it safe to call from inside Atropos's
+    async event loop.
+
+    The patched methods have the exact same interface and behavior -- the only
+    difference is HOW the async work is executed internally.
+    """
+    try:
+        from minisweagent.environments.extra.swerex_modal import (
+            SwerexModalEnvironment,
+            SwerexModalEnvironmentConfig,
+        )
+        from swerex.deployment.modal import ModalDeployment
+        from swerex.runtime.abstract import Command as RexCommand
+    except ImportError:
+        # mini-swe-agent or swe-rex not installed -- nothing to patch
+        logger.debug("mini-swe-agent Modal backend not available, skipping patch")
+        return
+
+    # Save original methods so we can refer to config handling
+    _original_init = SwerexModalEnvironment.__init__
+
+    def _patched_init(self, **kwargs):
+        """Patched __init__: creates Modal deployment on a background thread."""
+        self.config = SwerexModalEnvironmentConfig(**kwargs)
+
+        # Start a dedicated event loop thread for all Modal async operations
+        self._worker = _AsyncWorker()
+        self._worker.start()
+
+        # Create AND start the deployment entirely on the worker's loop/thread
+        # so all gRPC channels and async state are bound to that loop
+        async def _create_and_start():
+            deployment = ModalDeployment(
+                image=self.config.image,
+                startup_timeout=self.config.startup_timeout,
+                runtime_timeout=self.config.runtime_timeout,
+                deployment_timeout=self.config.deployment_timeout,
+                install_pipx=self.config.install_pipx,
+                modal_sandbox_kwargs=self.config.modal_sandbox_kwargs,
+            )
+            await deployment.start()
+            return deployment
+
+        self.deployment = self._worker.run_coroutine(_create_and_start())
+
+    def _patched_execute(self, command: str, cwd: str = "", *, timeout: int | None = None) -> dict[str, Any]:
+        """Patched execute: runs commands on the background thread's loop."""
+        async def _do_execute():
+            return await self.deployment.runtime.execute(
+                RexCommand(
+                    command=command,
+                    shell=True,
+                    check=False,
+                    cwd=cwd or self.config.cwd,
+                    timeout=timeout or self.config.timeout,
+                    merge_output_streams=True,
+                    env=self.config.env if self.config.env else None,
+                )
+            )
+
+        output = self._worker.run_coroutine(_do_execute())
+        return {
+            "output": output.stdout,
+            "returncode": output.exit_code,
+        }
+
+    def _patched_stop(self):
+        """Patched stop: stops deployment on the background thread, then stops the thread."""
+        try:
+            self._worker.run_coroutine(
+                asyncio.wait_for(self.deployment.stop(), timeout=10),
+                timeout=15,
+            )
+        except Exception:
+            pass
+        finally:
+            self._worker.stop()
+
+    # Apply the patches
+    SwerexModalEnvironment.__init__ = _patched_init
+    SwerexModalEnvironment.execute = _patched_execute
+    SwerexModalEnvironment.stop = _patched_stop
+
+    logger.debug("Patched SwerexModalEnvironment for async-safe operation")
+
+
+def apply_patches():
+    """
+    Apply all monkey patches needed for Atropos compatibility.
+
+    Safe to call multiple times -- patches are only applied once.
+    Safe for normal CLI use -- patched code works identically when
+    there is no running event loop.
+    """
+    global _patches_applied
+    if _patches_applied:
+        return
+
+    _patch_swerex_modal()
+
+    _patches_applied = True
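For reference on the failure mode this module works around: asyncio.run() raises a RuntimeError if the current thread already has a running event loop, which is exactly the situation when a synchronous tool wrapper is invoked from inside an async framework. A small standalone sketch (illustrative only, not project code) showing the failure and the background-loop escape hatch that _AsyncWorker above generalizes:

    import asyncio
    import threading


    async def fetch() -> str:
        await asyncio.sleep(0.05)
        return "ok"


    def sync_wrapper_naive() -> str:
        # Fails when called from a coroutine: this thread's loop is already running.
        return asyncio.run(fetch())


    # A background loop on its own daemon thread, the same idea as _AsyncWorker.
    _bg_loop = asyncio.new_event_loop()
    threading.Thread(target=_bg_loop.run_forever, daemon=True).start()


    def sync_wrapper_safe() -> str:
        # Schedules the coroutine on the background loop and blocks only the
        # calling thread until the result is ready (sync interface preserved).
        return asyncio.run_coroutine_threadsafe(fetch(), _bg_loop).result(timeout=10)


    async def main() -> None:
        try:
            sync_wrapper_naive()
        except RuntimeError as e:
            print("naive wrapper:", e)
        print("safe wrapper:", sync_wrapper_safe())


    if __name__ == "__main__":
        asyncio.run(main())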
@@ -132,7 +132,7 @@ class TerminalTestEnv(HermesAgentBaseEnv):
             terminal_backend="modal",
             # Atropos settings
             group_size=3,  # 3 rollouts per group
-            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
+            tokenizer_name="NousResearch/q-30b-t-h45-e1",
             tool_call_parser="hermes",
             steps_per_eval=3,  # Eval after all 3 steps
             total_steps=3,  # 3 groups total (1 group per step)
@@ -25,14 +25,43 @@ Example usage in a compute_reward():
 
 import json
 import logging
+import os
 from typing import Any, Dict, List, Optional
 
+import asyncio
+import concurrent.futures
+
 from model_tools import handle_function_call
 from tools.terminal_tool import cleanup_vm
 from tools.browser_tool import cleanup_browser
 
 logger = logging.getLogger(__name__)
 
+# Thread pool for running sync tool calls that internally use asyncio.run()
+_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
+
+
+def _run_tool_in_thread(tool_name: str, arguments: Dict[str, Any], task_id: str) -> str:
+    """
+    Run a tool call in a thread pool executor so backends that use asyncio.run()
+    internally (modal, docker) get a clean event loop.
+
+    If we're already in an async context, uses run_in_executor.
+    If not (e.g., called from sync code), runs directly.
+    """
+    try:
+        loop = asyncio.get_running_loop()
+        # We're in an async context -- need to run in thread
+        import concurrent.futures
+        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
+            future = pool.submit(
+                handle_function_call, tool_name, arguments, task_id
+            )
+            return future.result(timeout=300)
+    except RuntimeError:
+        # No running event loop -- safe to call directly
+        return handle_function_call(tool_name, arguments, task_id)
+
 
 class ToolContext:
     """
@@ -61,10 +90,15 @@ class ToolContext:
         Returns:
             Dict with 'exit_code' (int) and 'output' (str)
         """
-        result = handle_function_call(
+        import os
+        backend = os.getenv("TERMINAL_ENV", "local")
+        logger.debug("ToolContext.terminal [%s backend] task=%s: %s", backend, self.task_id[:8], command[:100])
+
+        # Run in thread pool so modal/docker backends' asyncio.run() doesn't deadlock
+        result = _run_tool_in_thread(
             "terminal",
             {"command": command, "timeout": timeout},
-            task_id=self.task_id,
+            self.task_id,
         )
         try:
             return json.loads(result)
@@ -222,7 +256,7 @@ class ToolContext:
         Returns:
             Raw JSON string result from the tool
         """
-        return handle_function_call(tool_name, arguments, task_id=self.task_id)
+        return _run_tool_in_thread(tool_name, arguments, self.task_id)
 
     # -------------------------------------------------------------------------
     # Cleanup
     # -------------------------------------------------------------------------
@@ -240,7 +274,16 @@ class ToolContext:
         except Exception as e:
             logger.debug("VM cleanup for task %s: %s", self.task_id, e)
 
+        # Suppress browser_tool's noisy debug prints during cleanup.
+        # The cleanup still runs (safe), it just doesn't spam the console.
+        _prev_quiet = os.environ.get("HERMES_QUIET")
+        os.environ["HERMES_QUIET"] = "1"
         try:
             cleanup_browser(self.task_id)
         except Exception as e:
             logger.debug("Browser cleanup for task %s: %s", self.task_id, e)
+        finally:
+            if _prev_quiet is None:
+                os.environ.pop("HERMES_QUIET", None)
+            else:
+                os.environ["HERMES_QUIET"] = _prev_quiet
model_tools.py

@@ -1191,8 +1191,19 @@ def handle_web_function_call(function_name: str, function_args: Dict[str, Any])
         urls = function_args.get("urls", [])
         # Limit URLs to prevent abuse
         urls = urls[:5] if isinstance(urls, list) else []
-        # Run async function in event loop
-        return asyncio.run(web_extract_tool(urls, "markdown"))
+        # Run async function -- use existing loop if available (Atropos),
+        # otherwise create one (normal CLI)
+        try:
+            loop = asyncio.get_running_loop()
+            # Already in an async context (Atropos) -- run in a thread
+            import concurrent.futures
+            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
+                return pool.submit(
+                    lambda: asyncio.run(web_extract_tool(urls, "markdown"))
+                ).result(timeout=120)
+        except RuntimeError:
+            # No running loop (normal CLI) -- use asyncio.run directly
+            return asyncio.run(web_extract_tool(urls, "markdown"))
 
     else:
         return json.dumps({"error": f"Unknown web function: {function_name}"}, ensure_ascii=False)
@@ -2,6 +2,7 @@
 """File Tools Module - LLM agent file manipulation tools."""
 
 import json
+import os
 import threading
 from typing import Optional
 from tools.file_operations import ShellFileOperations
@@ -11,23 +12,85 @@ _file_ops_cache: dict = {}
 
 
 def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
-    """Get or create ShellFileOperations for a terminal environment."""
-    from tools.terminal_tool import _active_environments, _env_lock, _LocalEnvironment
+    """Get or create ShellFileOperations for a terminal environment.
+
+    Respects the TERMINAL_ENV setting -- if the task_id doesn't have an
+    environment yet, creates one using the configured backend (local, docker,
+    modal, etc.) rather than always defaulting to local.
+    """
+    from tools.terminal_tool import (
+        _active_environments, _env_lock, _create_environment,
+        _get_env_config, _last_activity, _start_cleanup_thread,
+        _check_disk_usage_warning,
+    )
+    import time
 
+    # Fast path: check cache without heavy locks
     with _file_ops_lock:
         if task_id in _file_ops_cache:
             return _file_ops_cache[task_id]
 
-    with _env_lock:
-        if task_id not in _active_environments:
-            import os
-            env = _LocalEnvironment(cwd=os.getcwd(), timeout=60)
-            _active_environments[task_id] = env
-        terminal_env = _active_environments[task_id]
+    # Check if we need to create a new environment
+    needs_creation = False
+    with _env_lock:
+        if task_id not in _active_environments:
+            needs_creation = True
+
+    # Create environment OUTSIDE locks so we don't block other rollouts
+    # during slow Modal/Docker startup (~10s)
+    if needs_creation:
+        config = _get_env_config()
+        env_type = config["env_type"]
+
+        if env_type == "docker":
+            image = config["docker_image"]
+        elif env_type == "singularity":
+            image = config["singularity_image"]
+        elif env_type == "modal":
+            image = config["modal_image"]
+        else:
+            image = ""
+
+        cwd = config["cwd"]
+        _check_disk_usage_warning()
+        if not os.getenv("HERMES_QUIET"):
+            print(f"[FileTools] Creating new {env_type} environment for task {task_id[:8]}...", flush=True)
+
+        new_env = _create_environment(
+            env_type=env_type,
+            image=image,
+            cwd=cwd,
+            timeout=config["timeout"],
+        )
+
+        # Store under lock (brief) -- do NOT call _start_cleanup_thread inside
+        # the lock because it also acquires _env_lock (non-reentrant = deadlock)
+        created = False
+        with _env_lock:
+            if task_id not in _active_environments:
+                _active_environments[task_id] = new_env
+                created = True
+            else:
+                try:
+                    if hasattr(new_env, 'stop'):
+                        new_env.stop()
+                except Exception:
+                    pass
+
+        if created:
+            _start_cleanup_thread()
+            if not os.getenv("HERMES_QUIET"):
+                print(f"[FileTools] {env_type} environment ready for task {task_id[:8]}", flush=True)
+
+    # Now get the environment and build file_ops
+    with _env_lock:
+        _last_activity[task_id] = time.time()
+        terminal_env = _active_environments[task_id]
 
     file_ops = ShellFileOperations(terminal_env)
     with _file_ops_lock:
         _file_ops_cache[task_id] = file_ops
     return file_ops
@@ -56,6 +119,7 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
         result = file_ops.write_file(path, content)
         return json.dumps(result.to_dict(), ensure_ascii=False)
     except Exception as e:
+        print(f"[FileTools] write_file error: {type(e).__name__}: {e}", flush=True)
         return json.dumps({"error": str(e)}, ensure_ascii=False)
terminal_tool.py

@@ -1347,40 +1347,61 @@ def terminal_tool(
     _start_cleanup_thread()
 
     # Get or create environment
+    # Check under lock, but create OUTSIDE lock so we don't block
+    # other concurrent rollouts during slow Modal/Docker startup
+    needs_creation = False
     with _env_lock:
         if effective_task_id not in _active_environments:
-            # Check disk usage before creating new environment
-            _check_disk_usage_warning()
-
-            try:
-                # Build SSH config if using SSH environment
-                ssh_config = None
-                if env_type == "ssh":
-                    ssh_config = {
-                        "host": config.get("ssh_host", ""),
-                        "user": config.get("ssh_user", ""),
-                        "port": config.get("ssh_port", 22),
-                        "key": config.get("ssh_key", ""),
-                    }
-
-                _active_environments[effective_task_id] = _create_environment(
-                    env_type=env_type,
-                    image=image,
-                    cwd=cwd,
-                    timeout=effective_timeout,
-                    ssh_config=ssh_config
-                )
-            except ImportError as e:
-                return json.dumps({
-                    "output": "",
-                    "exit_code": -1,
-                    "error": f"Terminal tool disabled: mini-swe-agent not available ({e})",
-                    "status": "disabled"
-                }, ensure_ascii=False)
-
-        # Update last activity time
-        _last_activity[effective_task_id] = time.time()
-        env = _active_environments[effective_task_id]
+            needs_creation = True
+        else:
+            _last_activity[effective_task_id] = time.time()
+            env = _active_environments[effective_task_id]
+
+    if needs_creation:
+        _check_disk_usage_warning()
+        if not os.getenv("HERMES_QUIET"):
+            print(f"[Terminal] Creating new {env_type} environment for task {effective_task_id[:8]}...", flush=True)
+        try:
+            ssh_config = None
+            if env_type == "ssh":
+                ssh_config = {
+                    "host": config.get("ssh_host", ""),
+                    "user": config.get("ssh_user", ""),
+                    "port": config.get("ssh_port", 22),
+                    "key": config.get("ssh_key", ""),
+                }
+
+            new_env = _create_environment(
+                env_type=env_type,
+                image=image,
+                cwd=cwd,
+                timeout=effective_timeout,
+                ssh_config=ssh_config
+            )
+        except ImportError as e:
+            return json.dumps({
+                "output": "",
+                "exit_code": -1,
+                "error": f"Terminal tool disabled: mini-swe-agent not available ({e})",
+                "status": "disabled"
+            }, ensure_ascii=False)
+
+        # Store under lock (brief)
+        with _env_lock:
+            if effective_task_id not in _active_environments:
+                _active_environments[effective_task_id] = new_env
+            else:
+                # Another thread created it while we were building -- clean up ours
+                try:
+                    if hasattr(new_env, 'stop'):
+                        new_env.stop()
+                except Exception:
+                    pass
+
+        _last_activity[effective_task_id] = time.time()
+        env = _active_environments[effective_task_id]
+        if not os.getenv("HERMES_QUIET"):
+            print(f"[Terminal] {env_type} environment ready for task {effective_task_id[:8]}", flush=True)
 
     # Check for dangerous commands (only for local/ssh in interactive modes)
     # Skip check if force=True (user has confirmed they want to run it)
@@ -1435,13 +1456,20 @@ def terminal_tool(
                 retry_count += 1
                 wait_time = 2 ** retry_count
                 print(f"⚠️ Terminal: execution error, retrying in {wait_time}s (attempt {retry_count}/{max_retries})")
+                print(f" Command: {command[:200]}")
+                print(f" Error: {type(e).__name__}: {e}")
+                print(f" Task ID: {effective_task_id}, Backend: {env_type}")
                 time.sleep(wait_time)
                 continue
 
+            print(f"❌ Terminal: execution failed after {max_retries} retries")
+            print(f" Command: {command[:200]}")
+            print(f" Error: {type(e).__name__}: {e}")
+            print(f" Task ID: {effective_task_id}, Backend: {env_type}")
             return json.dumps({
                 "output": "",
                 "exit_code": -1,
-                "error": f"Command execution failed: {str(e)}"
+                "error": f"Command execution failed: {type(e).__name__}: {str(e)}"
             }, ensure_ascii=False)
 
     # Got a result