""" ToolExecutor - queued, batched tool dispatch for multiplexed agent trajectories. This component is responsible for: - Maintaining trajectory -> Slot affinity (workspace continuity) - Batching sandbox tool calls across trajectories to maximize container utilization - Routing external tools (ToolSchema.external=True) to a ToolServer (Phase 4.5) For now, only sandbox tools are executed: - bash - read_file - write_file """ from __future__ import annotations import asyncio import time from dataclasses import dataclass from typing import Any, Dict, List, Optional import httpx from .base import ( ArtifactArchiveRequestPayload, ArtifactArchiveResponsePayload, ArtifactListRequestPayload, ArtifactListResponsePayload, ArtifactReadRequestPayload, ArtifactReadResponsePayload, ToolCall, ToolCallPayload, ToolRegistry, ToolResult, ToolResultPayload, ToolServerExecuteRequest, ) from ..slots import Slot, SlotPool @dataclass class ToolExecutorConfig: batch_window_ms: int = 20 max_batch_size: int = 200 allow_network: bool = True require_sandbox: bool = False require_stateful_sandbox: bool = False tool_server_url: Optional[str] = None tool_server_token: Optional[str] = None @dataclass class _QueuedToolRequest: trajectory_id: str call: ToolCall timeout_s: Optional[float] future: asyncio.Future class ToolExecutor: def __init__( self, pool: SlotPool, tools: ToolRegistry, config: Optional[ToolExecutorConfig] = None, ) -> None: self.pool = pool self.tools = tools self.config = config or ToolExecutorConfig() self._queue: asyncio.Queue[Optional[_QueuedToolRequest]] = asyncio.Queue() self._task: Optional[asyncio.Task] = None self._stopping = asyncio.Event() self._slots_lock = asyncio.Lock() self._slot_by_trajectory: Dict[str, Slot] = {} self._tool_server_client: Optional[httpx.AsyncClient] = None self._tool_server_lock = asyncio.Lock() # lightweight stats for status endpoints self.total_requests: int = 0 self.total_errors: int = 0 self.latencies_s: List[float] = [] async def start(self) -> None: if self._task is None: self._task = asyncio.create_task(self._run_loop()) def queue_size(self) -> int: return self._queue.qsize() async def close(self) -> None: self._stopping.set() await self._queue.put(None) if self._task: await self._task self._task = None client = self._tool_server_client self._tool_server_client = None if client is not None: await client.aclose() # Best-effort release any remaining slots. async with self._slots_lock: slots = list(self._slot_by_trajectory.items()) self._slot_by_trajectory.clear() for _, slot in slots: try: await self.pool.release(slot, reset_workspace=False) except Exception: pass async def execute( self, trajectory_id: str, call: ToolCall, timeout_s: Optional[float] = None, ) -> ToolResult: if self._task is None: raise RuntimeError("ToolExecutor not started (call start() first)") # Allow tool args to suggest a timeout (Hermes-compatible terminal tool), # but never let the model choose "infinite" timeouts. if timeout_s is None: raw_timeout = call.arguments.get("timeout") if isinstance(raw_timeout, (int, float)): timeout_s = float(raw_timeout) if timeout_s is not None: timeout_s = max(1.0, min(float(timeout_s), 600.0)) loop = asyncio.get_running_loop() fut: asyncio.Future = loop.create_future() started = time.perf_counter() await self._queue.put(_QueuedToolRequest(trajectory_id=trajectory_id, call=call, timeout_s=timeout_s, future=fut)) try: result: ToolResult = await fut return result finally: self.latencies_s.append(time.perf_counter() - started) async def release_trajectory(self, trajectory_id: str, reset_workspace: bool = False) -> None: async with self._slots_lock: slot = self._slot_by_trajectory.pop(trajectory_id, None) if slot is not None: await self.pool.release(slot, reset_workspace=reset_workspace) async def _get_slot_if_present(self, trajectory_id: str) -> Optional[Slot]: async with self._slots_lock: return self._slot_by_trajectory.get(trajectory_id) # --------------------------------------------------------------------- # Artifact helpers (optional) # --------------------------------------------------------------------- async def read_artifact(self, req: ArtifactReadRequestPayload) -> ArtifactReadResponsePayload: slot = await self._get_slot_if_present(req.trajectory_id) if slot is None: return ArtifactReadResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)") data = await self.pool.executor.read_artifact( slot, req.path, encoding=req.encoding, max_bytes=req.max_bytes, include_sha256=req.include_sha256, ) if isinstance(data, dict): data = dict(data) data.pop("http_status", None) try: return ArtifactReadResponsePayload(**(data or {})) except Exception as e: return ArtifactReadResponsePayload(success=False, error=f"Invalid artifact read response: {e}") async def list_artifacts(self, req: ArtifactListRequestPayload) -> ArtifactListResponsePayload: slot = await self._get_slot_if_present(req.trajectory_id) if slot is None: return ArtifactListResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)") data = await self.pool.executor.list_artifacts( slot, req.path, recursive=req.recursive, max_entries=req.max_entries, ) if isinstance(data, dict): data = dict(data) data.pop("http_status", None) try: return ArtifactListResponsePayload(**(data or {})) except Exception as e: return ArtifactListResponsePayload(success=False, error=f"Invalid artifact list response: {e}") async def archive_artifacts(self, req: ArtifactArchiveRequestPayload) -> ArtifactArchiveResponsePayload: slot = await self._get_slot_if_present(req.trajectory_id) if slot is None: return ArtifactArchiveResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)") data = await self.pool.executor.archive_artifacts( slot, req.path, archive_format=req.format, max_bytes=req.max_bytes, max_entries=req.max_entries, ) if isinstance(data, dict): data = dict(data) data.pop("http_status", None) try: return ArtifactArchiveResponsePayload(**(data or {})) except Exception as e: return ArtifactArchiveResponsePayload(success=False, error=f"Invalid artifact archive response: {e}") async def _get_or_acquire_slot(self, trajectory_id: str) -> Slot: async with self._slots_lock: existing = self._slot_by_trajectory.get(trajectory_id) if existing is not None: return existing slot = await self.pool.acquire(trajectory_id) async with self._slots_lock: existing = self._slot_by_trajectory.get(trajectory_id) if existing is not None: # Another coroutine won the race; return its slot. await self.pool.release(slot, reset_workspace=False) return existing self._slot_by_trajectory[trajectory_id] = slot return slot async def _run_loop(self) -> None: pending: List[_QueuedToolRequest] = [] deadline: Optional[float] = None batch_window_s = max(0.0, self.config.batch_window_ms / 1000.0) max_batch = max(1, self.config.max_batch_size) while True: if self._stopping.is_set() and self._queue.empty() and not pending: break timeout = None if pending and deadline is not None: timeout = max(0.0, deadline - time.perf_counter()) try: item = await asyncio.wait_for(self._queue.get(), timeout=timeout) if item is None: continue pending.append(item) if len(pending) == 1: deadline = time.perf_counter() + batch_window_s if len(pending) < max_batch: continue except asyncio.TimeoutError: # batch window elapsed pass if not pending: deadline = None continue batch = pending pending = [] deadline = None await self._execute_batch(batch) async def _get_tool_server_client(self) -> httpx.AsyncClient: url = self.config.tool_server_url if not url: raise RuntimeError("ToolServer not configured") if self._tool_server_client is not None: return self._tool_server_client async with self._tool_server_lock: if self._tool_server_client is None: self._tool_server_client = httpx.AsyncClient(base_url=url.rstrip("/")) return self._tool_server_client def _tool_server_headers(self) -> Dict[str, str]: token = self.config.tool_server_token if not token: return {} return {"Authorization": f"Bearer {token}"} async def _execute_external(self, req: _QueuedToolRequest) -> ToolResult: client = await self._get_tool_server_client() slot_id: Optional[str] = None container_addr: Optional[str] = None slot = await self._get_slot_if_present(req.trajectory_id) if slot is not None: slot_id = slot.slot_id container_addr = slot.container_addr payload = ToolServerExecuteRequest( trajectory_id=req.trajectory_id, tool=ToolCallPayload.from_tool_call(req.call), timeout_s=req.timeout_s, slot_id=slot_id, container_addr=container_addr, ) try: resp = await client.post( "/execute", json=payload.model_dump(), headers=self._tool_server_headers(), timeout=req.timeout_s, ) resp.raise_for_status() data = resp.json() parsed = ToolResultPayload(**data) result = parsed.to_tool_result() if result.uniq_id is None: result.uniq_id = req.call.uniq_id return result except Exception as e: return ToolResult( success=False, error=f"External tool failed: {e}", uniq_id=req.call.uniq_id, ) async def _execute_batch(self, batch: List[_QueuedToolRequest]) -> None: # Resolve tool schemas once per request and separate sandbox/external/unknown. sandbox_items: List[_QueuedToolRequest] = [] external_items: List[_QueuedToolRequest] = [] unknown_items: List[_QueuedToolRequest] = [] for it in batch: tool = self.tools.get(it.call.name) if tool is None: unknown_items.append(it) continue schema = tool.schema if not schema.external: sandbox_items.append(it) else: external_items.append(it) for it in unknown_items: self.total_requests += 1 self.total_errors += 1 if not it.future.done(): it.future.set_result( ToolResult( success=False, error=f"Unknown tool: {it.call.name}", uniq_id=it.call.uniq_id, ) ) if external_items: if not self.config.tool_server_url: for it in external_items: self.total_requests += 1 self.total_errors += 1 if not it.future.done(): it.future.set_result( ToolResult( success=False, error=f"External tool not available (ToolServer not configured): {it.call.name}", uniq_id=it.call.uniq_id, ) ) else: results = await asyncio.gather(*[self._execute_external(it) for it in external_items]) for it, res in zip(external_items, results): self.total_requests += 1 if not getattr(res, "success", False): self.total_errors += 1 if not it.future.done(): it.future.set_result(res) if not sandbox_items: return # Acquire slots for the distinct trajectories in this batch. try: traj_ids = list({it.trajectory_id for it in sandbox_items}) slots = await asyncio.gather(*[self._get_or_acquire_slot(tid) for tid in traj_ids]) slot_by_traj = dict(zip(traj_ids, slots)) except Exception as e: for it in sandbox_items: self.total_requests += 1 self.total_errors += 1 if not it.future.done(): it.future.set_result( ToolResult( success=False, error=f"Failed to acquire slot: {e}", uniq_id=it.call.uniq_id, ) ) return # Group by timeout so we don't accidentally make short timeouts wait on long ones. by_timeout: Dict[float, List[_QueuedToolRequest]] = {} default_timeout = None if self.pool.executor.timeout.total is not None: default_timeout = float(self.pool.executor.timeout.total) for it in sandbox_items: t = it.timeout_s if t is None: t = default_timeout if t is None: t = 30.0 by_timeout.setdefault(float(t), []).append(it) for timeout_s, items in by_timeout.items(): requests = [] dispatched: List[_QueuedToolRequest] = [] for it in items: slot = slot_by_traj[it.trajectory_id] tool_name = it.call.name args = dict(it.call.arguments) # Hermes compatibility: treat `terminal` as an alias of sandbox `bash`. if tool_name == "terminal": if args.get("background"): self.total_requests += 1 self.total_errors += 1 if not it.future.done(): it.future.set_result( ToolResult( success=False, error="terminal background execution is not supported in sandbox", uniq_id=it.call.uniq_id, ) ) continue tool_name = "bash" # `timeout` is handled at the ToolExecutor level, not passed to the sandbox tool args. args.pop("timeout", None) elif tool_name == "terminal_stateful": tool_name = "bash_stateful" args.pop("timeout", None) elif tool_name == "tmux": # `tmux` is a sandbox tool backed by the stateful session manager. # Network policy is env-controlled. args.pop("allow_network", None) if tool_name == "bash": # Network policy is set by the environment/executor, not by the model. args.pop("allow_network", None) args.pop("require_sandbox", None) args["allow_network"] = bool(self.config.allow_network) args["require_sandbox"] = bool(self.config.require_sandbox) # `timeout` is handled at the ToolExecutor level, not passed to the sandbox tool args. args.pop("timeout", None) elif tool_name == "bash_stateful": # Network policy is set by the environment/executor, not by the model. args.pop("allow_network", None) args.pop("require_sandbox", None) args.pop("require_stateful_sandbox", None) args["allow_network"] = bool(self.config.allow_network) args["require_stateful_sandbox"] = bool(self.config.require_stateful_sandbox) args.pop("timeout", None) elif tool_name == "tmux": # Network policy applies to the underlying stateful session. args.pop("allow_network", None) args.pop("require_sandbox", None) args.pop("require_stateful_sandbox", None) args["allow_network"] = bool(self.config.allow_network) args["require_stateful_sandbox"] = bool(self.config.require_stateful_sandbox) requests.append((slot, tool_name, args)) dispatched.append(it) results = None try: if not dispatched: continue results = await self.pool.execute_batch(requests, timeout=timeout_s) except Exception as e: for it in items: self.total_requests += 1 self.total_errors += 1 if not it.future.done(): it.future.set_result( ToolResult( success=False, error=f"Batch execution failed: {e}", uniq_id=it.call.uniq_id, ) ) continue for it, res in zip(dispatched, results): self.total_requests += 1 if not getattr(res, "success", False): self.total_errors += 1 tool_result = res.to_tool_result() tool_result.uniq_id = it.call.uniq_id if not it.future.done(): it.future.set_result(tool_result)