diff --git a/atropos/api/tool_executor_server.py b/atropos/api/tool_executor_server.py index 21a98c1b16..1c1ba345c1 100644 --- a/atropos/api/tool_executor_server.py +++ b/atropos/api/tool_executor_server.py @@ -1,7 +1,7 @@ """ Tool Executor API (Phase 4) -This service provides a queued, batched execution layer on top of SlotPool. +This service provides a queued, batched execution layer on top of a ToolBackend. It mirrors the stateful FastAPI + app.state pattern used in: atropos/atroposlib/api/server.py @@ -18,7 +18,7 @@ from pathlib import Path from fastapi import FastAPI, Header, HTTPException, status from pydantic import BaseModel, Field -from ..slots import SlotPool, SlotPoolConfig +from ..backends.nomad_backend import NomadBackendConfig, NomadToolBackend from ..tools import ToolRegistry, build_tool_registry from ..tools.base import ( ArtifactArchiveRequestPayload, @@ -123,22 +123,23 @@ async def _startup() -> None: tool_server_url=cfg.tool_server_url, ) - pool = SlotPool( - SlotPoolConfig( + backend = NomadToolBackend( + NomadBackendConfig( nomad_address=cfg.nomad_address, - job_id=cfg.job_id, - image=cfg.image, + sandbox_job_id=cfg.job_id, + sandbox_image=cfg.image, slots_per_container=cfg.slots_per_container, min_containers=cfg.min_containers, max_containers=cfg.max_containers, privileged=cfg.privileged, - acquire_timeout=cfg.acquire_timeout_s, + acquire_timeout_s=cfg.acquire_timeout_s, + purge_job_on_start=False, ) ) - await pool.start() + await backend.start() executor = ToolExecutor( - pool=pool, + backend=backend, tools=tools, config=ToolExecutorConfig( batch_window_ms=cfg.batch_window_ms, @@ -151,21 +152,21 @@ async def _startup() -> None: await executor.start() app.state.cfg = cfg - app.state.pool = pool + app.state.backend = backend app.state.executor = executor @app.on_event("shutdown") async def _shutdown() -> None: executor: Optional[ToolExecutor] = getattr(app.state, "executor", None) - pool: Optional[SlotPool] = getattr(app.state, "pool", None) + 
def create_tool_backend(cfg: Any) -> ToolBackend:
    """Build the tool backend selected by ``cfg.tool_pool_mode``.

    The mode is read with ``getattr`` (falling back to ``"nomad"`` when the
    attribute is absent), then stripped and lower-cased before matching, so a
    value such as ``" Nomad "`` is accepted.

    Args:
        cfg: Configuration object. It is passed unchanged to the selected
            backend config's ``from_agent_env_config``.

    Returns:
        A ``NomadToolBackend`` for ``"nomad"`` or a ``ModalToolBackend`` for
        ``"modal"``. This function only constructs the backend; it does not
        call ``start()``.

    Raises:
        ValueError: If the normalized mode is neither ``"nomad"`` nor
            ``"modal"``.
    """
    mode = str(getattr(cfg, "tool_pool_mode", "nomad")).strip().lower()
    if mode == "nomad":
        return NomadToolBackend(NomadBackendConfig.from_agent_env_config(cfg))
    if mode == "modal":
        return ModalToolBackend(ModalBackendConfig.from_agent_env_config(cfg))
    # repr() keeps an empty / whitespace-only mode visible in the message, and
    # the accepted values are listed so the fix is obvious from the error alone.
    raise ValueError(
        f"Unknown tool_pool_mode: {mode!r} (expected 'nomad' or 'modal')"
    )
class ToolBackend(Protocol):
    """
    Minimal interface required by ToolExecutor.

    Backends provide:
    - lifecycle (start/stop)
    - slot acquisition/release (workspace affinity)
    - batched tool execution across slots
    - optional artifact helpers (for env verification / demos)

    Concrete backends may subclass this protocol explicitly, in which case the
    bodies below are inherited as defaults: `start`, `stop` and `release` are
    no-ops, `default_timeout_s` is None, and every member that has to produce
    a value raises NotImplementedError until it is overridden.
    """

    @property
    def default_timeout_s(self) -> Optional[float]:
        """Default sandbox execution timeout in seconds (if any)."""
        return None

    async def start(self) -> None:
        """Start the backend (provision workers/containers, health checks, etc)."""
        return None

    async def stop(self, *, purge: bool = False) -> None:
        """Stop the backend and optionally purge remote resources."""
        return None

    async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
        """Acquire a slot for a trajectory (workspace affinity).

        Raises:
            NotImplementedError: Always, unless overridden by the backend.
        """
        # There is no meaningful default Slot; falling through would return
        # None and break callers far away from the real mistake.
        raise NotImplementedError(f"{type(self).__name__} must implement acquire()")

    async def release(self, slot: Slot, *, reset_workspace: bool = False) -> None:
        """Release a slot back to the pool."""
        return None

    async def execute_batch(
        self,
        requests: List[Tuple[Slot, str, Dict[str, Any]]],
        *,
        timeout_s: Optional[float] = None,
    ) -> List[ExecutionResult]:
        """Execute a batch of sandbox tool calls and return results in order.

        Each request is a `(slot, tool name, arguments)` tuple.

        Raises:
            NotImplementedError: Always, unless overridden by the backend.
        """
        # Same reasoning as acquire(): a silent None is not a result list.
        raise NotImplementedError(
            f"{type(self).__name__} must implement execute_batch()"
        )

    # ---------------------------------------------------------------------
    # Optional artifact helpers (supported by the Nomad sandbox-server today)
    # ---------------------------------------------------------------------

    async def read_artifact(
        self,
        slot: Slot,
        path: str,
        *,
        encoding: str = "text",
        max_bytes: Optional[int] = None,
        include_sha256: bool = False,
        timeout_s: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Read one artifact from a slot's workspace (optional capability)."""
        raise NotImplementedError(
            f"{type(self).__name__} does not support read_artifact()"
        )

    async def list_artifacts(
        self,
        slot: Slot,
        path: str = ".",
        *,
        recursive: bool = False,
        max_entries: Optional[int] = None,
        timeout_s: Optional[float] = None,
    ) -> Dict[str, Any]:
        """List artifacts under `path` in a slot's workspace (optional capability)."""
        raise NotImplementedError(
            f"{type(self).__name__} does not support list_artifacts()"
        )

    async def archive_artifacts(
        self,
        slot: Slot,
        path: str = ".",
        *,
        archive_format: str = "tar.gz",
        max_bytes: Optional[int] = None,
        max_entries: Optional[int] = None,
        timeout_s: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Archive artifacts under `path` in a slot's workspace (optional capability)."""
        raise NotImplementedError(
            f"{type(self).__name__} does not support archive_artifacts()"
        )
@dataclass(frozen=True)
class ModalBackendConfig:
    """Settings for the (not yet implemented) Modal backend."""

    # Placeholders for future implementation.
    app_name: str = "atropos-sandbox"
    function_name: str = "sandbox_server"
    volume_name: Optional[str] = None
    volume_mount_path: str = "/data"

    @classmethod
    def from_agent_env_config(cls, cfg: Any) -> "ModalBackendConfig":
        """Build a config from the ``modal_*`` attributes of `cfg`.

        A missing attribute, or one explicitly set to None, falls back to the
        dataclass default for that field. An empty ``modal_volume_name`` is
        normalized to None.
        """

        def _text(attr: str, default: str) -> str:
            # Treat None like "unset": str(None) would otherwise yield the
            # literal string "None" as an app / function name.
            value = getattr(cfg, attr, None)
            return default if value is None else str(value)

        return cls(
            app_name=_text("modal_app_name", cls.app_name),
            function_name=_text("modal_function_name", cls.function_name),
            volume_name=(getattr(cfg, "modal_volume_name", None) or None),
            volume_mount_path=_text("modal_volume_mount_path", cls.volume_mount_path),
        )


class ModalToolBackend(ToolBackend):
    """Placeholder backend: every lifecycle and execution call raises.

    The artifact helpers are not overridden here.
    """

    def __init__(self, config: ModalBackendConfig):
        self.config = config

    @property
    def default_timeout_s(self) -> Optional[float]:
        """Return None; the stub defines no default timeout."""
        return None

    def _unavailable(self) -> RuntimeError:
        """Build the error raised by every unimplemented operation."""
        # NotImplementedError is a RuntimeError subclass, so callers that
        # catch RuntimeError keep working while the intent is now explicit.
        return NotImplementedError(
            "Modal tool backend is not implemented yet. "
            "Keep `--env.tool_pool_mode nomad` for now."
        )

    async def start(self) -> None:
        """Always raises: the Modal backend cannot be started yet."""
        raise self._unavailable()

    async def stop(self, *, purge: bool = False) -> None:  # noqa: ARG002
        """Always raises, mirroring start()."""
        # If start() isn't implemented, stop() is also unavailable.
        raise self._unavailable()

    async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:  # noqa: ARG002
        """Always raises: there are no slots to hand out."""
        raise self._unavailable()

    async def release(self, slot: Slot, *, reset_workspace: bool = False) -> None:  # noqa: ARG002
        """Always raises: there are no slots to return."""
        raise self._unavailable()

    async def execute_batch(
        self,
        requests: List[Tuple[Slot, str, Dict[str, Any]]],
        *,
        timeout_s: Optional[float] = None,  # noqa: ARG002
    ) -> List[ExecutionResult]:
        """Always raises: nothing can be executed on the stub."""
        raise self._unavailable()
@dataclass(frozen=True)
class NomadBackendConfig:
    """Immutable settings used to build the SlotPool behind NomadToolBackend."""

    nomad_address: str
    sandbox_job_id: str
    sandbox_image: str
    slots_per_container: int
    min_containers: int
    max_containers: int
    privileged: bool
    acquire_timeout_s: float
    purge_job_on_start: bool

    @classmethod
    def from_agent_env_config(cls, cfg: Any) -> "NomadBackendConfig":
        """Copy the Nomad settings off `cfg`, coercing each to its field type.

        Every attribute except ``purge_job_on_start`` (default False) is
        required.

        Raises:
            AttributeError: If a required attribute is missing from `cfg`.
        """
        return cls(
            nomad_address=str(cfg.nomad_address),
            sandbox_job_id=str(cfg.sandbox_job_id),
            sandbox_image=str(cfg.sandbox_image),
            slots_per_container=int(cfg.slots_per_container),
            min_containers=int(cfg.min_containers),
            max_containers=int(cfg.max_containers),
            privileged=bool(cfg.privileged),
            acquire_timeout_s=float(cfg.acquire_timeout_s),
            purge_job_on_start=bool(getattr(cfg, "purge_job_on_start", False)),
        )


class NomadToolBackend(ToolBackend):
    """ToolBackend that delegates every operation to a SlotPool."""

    def __init__(self, config: NomadBackendConfig):
        self.config = config
        # Translate backend field names to the SlotPoolConfig spelling
        # (sandbox_job_id -> job_id, acquire_timeout_s -> acquire_timeout, ...).
        self.pool = SlotPool(
            SlotPoolConfig(
                nomad_address=config.nomad_address,
                job_id=config.sandbox_job_id,
                image=config.sandbox_image,
                slots_per_container=config.slots_per_container,
                min_containers=config.min_containers,
                max_containers=config.max_containers,
                privileged=config.privileged,
                acquire_timeout=config.acquire_timeout_s,
                purge_job_on_start=bool(config.purge_job_on_start),
            )
        )

    @property
    def default_timeout_s(self) -> Optional[float]:
        """Return ``pool.executor.timeout.total`` as a float, or None.

        None is returned when the executor has no ``timeout``, the timeout has
        no ``total``, or ``total`` cannot be converted to a float.
        """
        timeout = getattr(self.pool.executor, "timeout", None)
        total = getattr(timeout, "total", None)
        if total is None:
            return None
        try:
            return float(total)
        except (TypeError, ValueError, OverflowError):
            # Only conversion failures mean "no usable default"; anything else
            # is a real bug and should surface.
            return None

    async def start(self) -> None:
        """Start the underlying slot pool."""
        await self.pool.start()

    async def stop(self, *, purge: bool = False) -> None:
        """Stop the slot pool; `purge` is forwarded as ``purge_job``."""
        await self.pool.stop(purge_job=purge)

    async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
        """Acquire a slot from the pool for `trajectory_id`."""
        return await self.pool.acquire(trajectory_id)

    async def release(self, slot: Slot, *, reset_workspace: bool = False) -> None:
        """Return `slot` to the pool."""
        await self.pool.release(slot, reset_workspace=reset_workspace)

    async def execute_batch(
        self,
        requests: List[Tuple[Slot, str, Dict[str, Any]]],
        *,
        timeout_s: Optional[float] = None,
    ) -> List[ExecutionResult]:
        """Run `requests` through the pool; `timeout_s` is forwarded as ``timeout``."""
        return await self.pool.execute_batch(requests, timeout=timeout_s)

    async def read_artifact(
        self,
        slot: Slot,
        path: str,
        *,
        encoding: str = "text",
        max_bytes: Optional[int] = None,
        include_sha256: bool = False,
        timeout_s: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Delegate to ``pool.executor.read_artifact`` and return its result."""
        return await self.pool.executor.read_artifact(
            slot,
            path,
            encoding=encoding,
            max_bytes=max_bytes,
            include_sha256=include_sha256,
            timeout=timeout_s,
        )

    async def list_artifacts(
        self,
        slot: Slot,
        path: str = ".",
        *,
        recursive: bool = False,
        max_entries: Optional[int] = None,
        timeout_s: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Delegate to ``pool.executor.list_artifacts`` and return its result."""
        return await self.pool.executor.list_artifacts(
            slot,
            path,
            recursive=recursive,
            max_entries=max_entries,
            timeout=timeout_s,
        )

    async def archive_artifacts(
        self,
        slot: Slot,
        path: str = ".",
        *,
        archive_format: str = "tar.gz",
        max_bytes: Optional[int] = None,
        max_entries: Optional[int] = None,
        timeout_s: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Delegate to ``pool.executor.archive_artifacts`` and return its result."""
        return await self.pool.executor.archive_artifacts(
            slot,
            path,
            archive_format=archive_format,
            max_bytes=max_bytes,
            max_entries=max_entries,
            timeout=timeout_s,
        )

    def get_stats(self) -> Dict[str, Any]:
        """Return the slot pool's statistics unchanged."""
        return self.pool.get_stats()
atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, Item, from atroposlib.envs.server_handling.server_baseline import AsyncSemWithAdaptiveWeight from ..agent import AgentConfig, AgentResult, AtroposAgent -from ..slots import SlotPool, SlotPoolConfig +from ..backends import ToolBackend, create_tool_backend from ..tools import ToolRegistry, build_tool_registry from ..tools.tool_executor import ToolExecutor, ToolExecutorConfig # Main BaseEnv child classes. Child class THESE to get agent+tooling functionality easily. class AgentEnvConfig(BaseEnvConfig): - tool_pool_mode: str = Field(default="nomad", description="Tool execution backend (only 'nomad' is supported)") + tool_pool_mode: str = Field(default="nomad", description="Tool execution backend ('nomad' or 'modal')") allow_network: bool = Field( default=True, @@ -61,6 +61,12 @@ class AgentEnvConfig(BaseEnvConfig): ) purge_job_on_shutdown: bool = Field(default=True, description="Nomad mode: stop/purge job on shutdown") + # modal mode settings (stub; implementation pending) + modal_app_name: str = Field(default="atropos-sandbox", description="Modal app name (stub)") + modal_function_name: str = Field(default="sandbox_server", description="Modal function/actor name (stub)") + modal_volume_name: Optional[str] = Field(default=None, description="Modal Volume name for persistent storage (stub)") + modal_volume_mount_path: str = Field(default="/data", description="Modal Volume mount path (stub)") + # basic agent defaults agent_max_steps: int = Field(default=50, description="Max ReACT steps per trajectory") agent_temperature: float = Field(default=0.7, description="Sampling temperature") @@ -108,7 +114,7 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): self.tools: ToolRegistry = self.build_tools() - self._pool: Optional[Any] = None + self._backend: Optional[ToolBackend] = None self._tool_executor: Optional[ToolExecutor] = None self._tool_server_inprocess: bool = False self._trajectory_workspace_meta: 
Dict[str, Dict[str, Any]] = {} @@ -263,27 +269,11 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): tool_server_url = "http://toolserver" self._tool_server_inprocess = True - if self.config.tool_pool_mode != "nomad": - # TODO Add Modal here, maybe in-process, but not safe to have that tbh - raise RuntimeError("tool_pool_mode must be 'nomad' (local/in-process pools are not supported)") - - pool = SlotPool( - SlotPoolConfig( - nomad_address=self.config.nomad_address, - job_id=self.config.sandbox_job_id, - image=self.config.sandbox_image, - slots_per_container=self.config.slots_per_container, - min_containers=self.config.min_containers, - max_containers=self.config.max_containers, - privileged=self.config.privileged, - acquire_timeout=self.config.acquire_timeout_s, - purge_job_on_start=bool(self.config.purge_job_on_start), - ) - ) - await pool.start() + backend = create_tool_backend(self.config) + await backend.start() executor = ToolExecutor( - pool=pool, + backend=backend, tools=self.tools, config=ToolExecutorConfig( batch_window_ms=self.config.tool_batch_window_ms, @@ -299,21 +289,21 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): if tool_server_client is not None: executor._tool_server_client = tool_server_client # type: ignore[attr-defined] - self._pool = pool + self._backend = backend self._tool_executor = executor async def shutdown_tool_backend(self) -> None: executor = self._tool_executor - pool = self._pool + backend = self._backend inprocess_tool_server = self._tool_server_inprocess self._tool_executor = None - self._pool = None + self._backend = None self._tool_server_inprocess = False if executor is not None: await executor.close() - if pool is not None: - await pool.stop(purge_job=bool(self.config.purge_job_on_shutdown)) + if backend is not None: + await backend.stop(purge=bool(self.config.purge_job_on_shutdown)) if inprocess_tool_server: from ..api.tool_server import app as tool_server_app diff --git a/atropos/tools/tool_executor.py 
b/atropos/tools/tool_executor.py index 148354db3b..0f8d9c1aa0 100644 --- a/atropos/tools/tool_executor.py +++ b/atropos/tools/tool_executor.py @@ -35,7 +35,8 @@ from .base import ( ToolResultPayload, ToolServerExecuteRequest, ) -from ..slots import Slot, SlotPool +from ..backends.base import ToolBackend +from ..slots import Slot @dataclass @@ -60,11 +61,11 @@ class _QueuedToolRequest: class ToolExecutor: def __init__( self, - pool: SlotPool, + backend: ToolBackend, tools: ToolRegistry, config: Optional[ToolExecutorConfig] = None, ) -> None: - self.pool = pool + self.backend = backend self.tools = tools self.config = config or ToolExecutorConfig() @@ -109,7 +110,7 @@ class ToolExecutor: for _, slot in slots: try: - await self.pool.release(slot, reset_workspace=False) + await self.backend.release(slot, reset_workspace=False) except Exception: pass @@ -146,7 +147,7 @@ class ToolExecutor: slot = self._slot_by_trajectory.pop(trajectory_id, None) if slot is not None: - await self.pool.release(slot, reset_workspace=reset_workspace) + await self.backend.release(slot, reset_workspace=reset_workspace) async def _get_slot_if_present(self, trajectory_id: str) -> Optional[Slot]: async with self._slots_lock: @@ -160,7 +161,7 @@ class ToolExecutor: slot = await self._get_slot_if_present(req.trajectory_id) if slot is None: return ArtifactReadResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)") - data = await self.pool.executor.read_artifact( + data = await self.backend.read_artifact( slot, req.path, encoding=req.encoding, @@ -179,7 +180,7 @@ class ToolExecutor: slot = await self._get_slot_if_present(req.trajectory_id) if slot is None: return ArtifactListResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)") - data = await self.pool.executor.list_artifacts( + data = await self.backend.list_artifacts( slot, req.path, recursive=req.recursive, @@ -197,7 +198,7 @@ class ToolExecutor: slot = await 
self._get_slot_if_present(req.trajectory_id) if slot is None: return ArtifactArchiveResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)") - data = await self.pool.executor.archive_artifacts( + data = await self.backend.archive_artifacts( slot, req.path, archive_format=req.format, @@ -218,13 +219,13 @@ class ToolExecutor: if existing is not None: return existing - slot = await self.pool.acquire(trajectory_id) + slot = await self.backend.acquire(trajectory_id) async with self._slots_lock: existing = self._slot_by_trajectory.get(trajectory_id) if existing is not None: # Another coroutine won the race; return its slot. - await self.pool.release(slot, reset_workspace=False) + await self.backend.release(slot, reset_workspace=False) return existing self._slot_by_trajectory[trajectory_id] = slot return slot @@ -400,9 +401,7 @@ class ToolExecutor: # Group by timeout so we don't accidentally make short timeouts wait on long ones. by_timeout: Dict[float, List[_QueuedToolRequest]] = {} - default_timeout = None - if self.pool.executor.timeout.total is not None: - default_timeout = float(self.pool.executor.timeout.total) + default_timeout = self.backend.default_timeout_s for it in sandbox_items: t = it.timeout_s @@ -476,7 +475,7 @@ class ToolExecutor: try: if not dispatched: continue - results = await self.pool.execute_batch(requests, timeout=timeout_s) + results = await self.backend.execute_batch(requests, timeout_s=timeout_s) except Exception as e: for it in items: self.total_requests += 1