Mirror of https://github.com/NousResearch/hermes-agent.git
Synced 2026-04-28 15:01:34 +08:00

Compare commits: 50 commits, `feat/volce` ... `atropos-in`
| SHA1 |
|---|
| 24c13bc412 |
| 06e9422324 |
| 907616a692 |
| 33a00d9b8e |
| a2312076da |
| 499490d06a |
| 35b2250b36 |
| 395392e5de |
| 2041b354a9 |
| 3951eab399 |
| 62001e3bf5 |
| c8b30e9efa |
| f82c3081f2 |
| a69924631c |
| 4619d1c8ef |
| 98d945f6de |
| 507b77c4ac |
| b99c2a2644 |
| 975c849308 |
| 9dc27880cd |
| 3b9c53e6db |
| 05dd31131f |
| 36ea883d45 |
| 6be8cdeeca |
| 0bc914b00c |
| 411e7f8ff4 |
| eb2e6b73fe |
| 664acf7426 |
| fd1c3da305 |
| 4d619bcd21 |
| beac2ee06a |
| 487487406d |
| 87464821d8 |
| 661d8f4d6c |
| bf13a848ef |
| 88286f6da3 |
| 5b82190460 |
| ea7aa0b0d4 |
| 7130fa50cb |
| 5a9c98a771 |
| 6cb4fe948a |
| 30221d8c20 |
| b5b1fef20a |
| 16fb41f9cc |
| 4939130485 |
| 8dccd6569e |
| db348dc467 |
| 88722e230d |
| 68fb0efe0e |
| e38c274f8d |
.clinerules (new file, 115 lines)
@@ -0,0 +1,115 @@
# Cline's Memory Bank

I am Cline, an expert software engineer with a unique characteristic: my memory resets completely between sessions. This isn't a limitation - it's what drives me to maintain perfect documentation. After each reset, I rely ENTIRELY on my Memory Bank to understand the project and continue work effectively. I MUST read ALL memory bank files at the start of EVERY task - this is not optional.

## Memory Bank Structure

The Memory Bank consists of core files and optional context files, all in Markdown format. Files build upon each other in a clear hierarchy:

    flowchart TD
        PB[projectbrief.md] --> PC[productContext.md]
        PB --> SP[systemPatterns.md]
        PB --> TC[techContext.md]

        PC --> AC[activeContext.md]
        SP --> AC
        TC --> AC

        AC --> P[progress.md]

### Core Files (Required)
1. `projectbrief.md`
   - Foundation document that shapes all other files
   - Created at project start if it doesn't exist
   - Defines core requirements and goals
   - Source of truth for project scope

2. `productContext.md`
   - Why this project exists
   - Problems it solves
   - How it should work
   - User experience goals

3. `activeContext.md`
   - Current work focus
   - Recent changes
   - Next steps
   - Active decisions and considerations
   - Important patterns and preferences
   - Learnings and project insights

4. `systemPatterns.md`
   - System architecture
   - Key technical decisions
   - Design patterns in use
   - Component relationships
   - Critical implementation paths

5. `techContext.md`
   - Technologies used
   - Development setup
   - Technical constraints
   - Dependencies
   - Tool usage patterns

6. `progress.md`
   - What works
   - What's left to build
   - Current status
   - Known issues
   - Evolution of project decisions

### Additional Context
Create additional files/folders within memory-bank/ when they help organize:
- Complex feature documentation
- Integration specifications
- API documentation
- Testing strategies
- Deployment procedures

## Core Workflows

### Plan Mode

    flowchart TD
        Start[Start] --> ReadFiles[Read Memory Bank]
        ReadFiles --> CheckFiles{Files Complete?}

        CheckFiles -->|No| Plan[Create Plan]
        Plan --> Document[Document in Chat]

        CheckFiles -->|Yes| Verify[Verify Context]
        Verify --> Strategy[Develop Strategy]
        Strategy --> Present[Present Approach]

### Act Mode

    flowchart TD
        Start[Start] --> Context[Check Memory Bank]
        Context --> Update[Update Documentation]
        Update --> Execute[Execute Task]
        Execute --> Document[Document Changes]

## Documentation Updates

Memory Bank updates occur when:
1. Discovering new project patterns
2. After implementing significant changes
3. When the user requests it with **update memory bank** (MUST review ALL files)
4. When context needs clarification

    flowchart TD
        Start[Update Process]

        subgraph Process
            P1[Review ALL Files]
            P2[Document Current State]
            P3[Clarify Next Steps]
            P4[Document Insights & Patterns]

            P1 --> P2 --> P3 --> P4
        end

        Start --> Process

Note: When triggered by **update memory bank**, I MUST review every memory bank file, even if some don't require updates. Focus particularly on activeContext.md and progress.md as they track current state.

REMEMBER: After every memory reset, I begin completely fresh. The Memory Bank is my only link to previous work. It must be maintained with precision and clarity, as my effectiveness depends entirely on its accuracy.
.env.example (141 lines changed)
@@ -1,12 +1,68 @@
# Hermes Agent Environment Configuration
# Copy this file to .env and fill in your API keys

# =============================================================================
# CORE SETTINGS
# =============================================================================
# Agent backend:
# - openai  : default Hermes-Agent loop (OpenAI function-calling via OpenAI SDK)
# - atropos : Atroposlib ServerManager/ManagedServer-backed loop (training/env integration)
HERMES_BACKEND=openai


# =============================================================================
# ATROPOS SERVER SETTINGS (local or hosted Nous inference)
# =============================================================================
# For local development (matches the Atropos test env defaults):
# ATROPOS_SERVER_BASE_URL=http://127.0.0.1:8080
# ATROPOS_SERVER_MODEL=hermes-4-36b
# For hosted inference (Nous Research inference API):
ATROPOS_SERVER_BASE_URL=
ATROPOS_SERVER_MODEL=
ATROPOS_TOKENIZER_NAME=
# Set this to your Nous API key (Bearer token).
ATROPOS_SERVER_API_KEY=

# Debugging (prints to stdout; use with care)
# HERMES_DEBUG_ATROPOS_REQUEST=1
# HERMES_DEBUG_ATROPOS_RESPONSE=1
# HERMES_DEBUG_OPENAI_REQUEST=1
# HERMES_DEBUG_OPENAI_RESPONSE=1

# =============================================================================
# LOCAL / SELF-HOSTED OPENAI-COMPATIBLE ENDPOINTS (vLLM, SGLang, llama.cpp, etc.)
# =============================================================================
# If you set ATROPOS_SERVER_BASE_URL or OPENAI_BASE_URL, Hermes will use it instead
# of OpenRouter.
#
# Local server convenience (base URL without /v1):
# llama.cpp example (see `Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh`):
# ATROPOS_SERVER_BASE_URL=http://127.0.0.1:8080
# ATROPOS_SERVER_MODEL=hermes-4-36b
# ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B
# ATROPOS_SERVER_API_KEY=local
#
# Hosted Nous inference API:
# ATROPOS_SERVER_BASE_URL=https://inference-api.nousresearch.com
# ATROPOS_SERVER_MODEL=Hermes-4.3-36B
# ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B
# ATROPOS_SERVER_API_KEY=sk-... (Bearer token)
#
# If you plan to run GRPO-style group sampling (e.g. `--env.group_size 4`) against
# llama.cpp, start the server with at least that many slots, e.g.:
# LLAMA_CPP_PARALLEL=4 Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh
#
# Generic OpenAI-compatible (base URL should include /v1):
# OPENAI_BASE_URL=http://127.0.0.1:8080/v1
# OPENAI_API_KEY=local

# =============================================================================
# LLM PROVIDER (OpenRouter)
# =============================================================================
# OpenRouter provides access to many models through one API.
# All LLM calls go through OpenRouter - no direct provider keys needed.
# Get your key at: https://openrouter.ai/keys
OPENROUTER_BASE_URL=https://openrouter.ai/api/v1
OPENROUTER_API_KEY=

# Default model to use (OpenRouter format: provider/model)
@@ -92,12 +148,87 @@ TERMINAL_LIFETIME_SECONDS=300
# SUDO_PASSWORD=your_password_here

# =============================================================================
# MODAL CLOUD BACKEND (for TERMINAL_ENV=modal)
# =============================================================================
# Modal provides cloud sandboxes with per-second billing and auto-scaling.
# This implementation uses a warm pool of sandboxes for cost efficiency.
#
# SETUP:
#   pip install modal && modal setup
#   (Authenticates via browser, stores credentials locally)
#   No API key needed in .env - Modal handles auth automatically.
#
# FEATURES:
# - Auto-scaling warm sandbox pool (no cold start after first use)
# - Named sandbox recovery (reconnects after restart)
# - Profile-based heterogeneous environments (CPU, GPU, different images)
# - Server-side idle_timeout protection against orphaned sandboxes

# Modal app name (groups all sandboxes, used for recovery)
TERMINAL_MODAL_APP_NAME=hermes-sandbox

# Default profile when none specified
TERMINAL_MODAL_DEFAULT_PROFILE=default

# Profile config file (optional - YAML format, see modal_profiles.yaml)
# TERMINAL_MODAL_PROFILES_FILE=modal_profiles.yaml

# --- Default Profile Settings (used if no YAML file) ---
# These apply when no profile is specified or for the "default" profile
TERMINAL_MODAL_IMAGE=python:3.11
TERMINAL_MODAL_MIN_POOL=1
TERMINAL_MODAL_MAX_POOL=5
TERMINAL_MODAL_IDLE_TIMEOUT=120
TERMINAL_MODAL_MAX_LIFETIME=3600
TERMINAL_MODAL_SCALE_DOWN_IDLE=180

# --- Custom Profile Example: pytorch-gpu ---
# Uncomment to enable a GPU profile for ML tasks
# Usage: terminal_tool("python train.py", profile="pytorch-gpu")
#
# TERMINAL_MODAL_PROFILE_pytorch_gpu_IMAGE=pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
# TERMINAL_MODAL_PROFILE_pytorch_gpu_GPU=T4
# TERMINAL_MODAL_PROFILE_pytorch_gpu_MEMORY=16384
# TERMINAL_MODAL_PROFILE_pytorch_gpu_MIN_POOL=0
# TERMINAL_MODAL_PROFILE_pytorch_gpu_MAX_POOL=2
# TERMINAL_MODAL_PROFILE_pytorch_gpu_IDLE_TIMEOUT=60

# --- Custom Profile Example: node ---
# Uncomment to enable a Node.js profile
# Usage: terminal_tool("npm test", profile="node")
#
# TERMINAL_MODAL_PROFILE_node_IMAGE=node:18
# TERMINAL_MODAL_PROFILE_node_MIN_POOL=0
# TERMINAL_MODAL_PROFILE_node_MAX_POOL=3

# =============================================================================
# MODAL SECRETS (Secure credential injection)
# =============================================================================
# Modal Secrets allow you to securely pass API keys, passwords, and other
# sensitive data to your sandboxes without exposing them in code or logs.
#
# SETUP SECRETS:
#   1. Via Dashboard: https://modal.com/secrets
#   2. Via CLI: modal secret create my-secret KEY1=value1 KEY2=value2
#   3. Via CLI with env: modal secret create my-secret API_KEY="$API_KEY"
#
# LIST SECRETS:
#   modal secret list
#
# DELETE SECRETS:
#   modal secret delete my-secret

# Global secrets applied to ALL profiles (comma-separated secret names)
# These secrets must be created on the Modal dashboard or via the CLI first
# TERMINAL_MODAL_SECRETS=my-api-keys,database-creds

# Per-profile secrets (comma-separated secret names)
# TERMINAL_MODAL_PROFILE_pytorch_gpu_SECRETS=huggingface-token,wandb-key

# Per-profile environment variables (semicolon-separated KEY=VALUE pairs)
# TERMINAL_MODAL_PROFILE_default_ENV_VARS=DEBUG=1;LOG_LEVEL=info

# Load local .env file into sandbox (useful for development)
# TERMINAL_MODAL_PROFILE_default_USE_DOTENV=true

# =============================================================================
# BROWSER TOOL CONFIGURATION (agent-browser + Browserbase)
||||
24
.gitignore
vendored
24
.gitignore
vendored
@@ -46,3 +46,27 @@ testlogs

# CLI config (may contain sensitive SSH paths)
cli-config.yaml

.DS_Store

# artifacts
*.jsonl
*.html
*.json
*.log
*.csv

# Singularity/Apptainer images (large binary files)
*.sif

# Test files
test_singularity_*.py
test_*.py
!tests/test_*.py

# Nomad data
/tmp/NomadClient*/

*.egg-info*
wandb
logs
README.md (131 lines changed)
@@ -995,6 +995,137 @@ All variables go in `~/.hermes/.env`. Run `hermes config set VAR value` to set t

---

## RL Training with Tinker

Hermes-Agent includes an RL training integration with [Tinker](https://thinkingmachines.ai/tinker/) (Thinking Machines) and [Atropos](https://github.com/NousResearch/atropos) for training language models with reinforcement learning from agent trajectories.

### Prerequisites

1. **Install with Atropos extras** (includes the Tinker SDK, atroposlib, torch, and wandb):
   ```bash
   pip install -e ".[atropos]"
   ```

2. **Initialize the tinker-atropos submodule**:
   ```bash
   git submodule update --init
   pip install -e ./tinker-atropos
   ```

3. **Get API keys**:
   - `TINKER_API_KEY` from [Tinker Console](https://tinker-console.thinkingmachines.ai/keys) (requires billing setup)
   - `WANDB_API_KEY` from [Weights & Biases](https://wandb.ai/settings) (for metrics tracking)

4. **Add keys to your `.env` file**:
   ```bash
   # Add to .env or ~/.hermes/.env
   TINKER_API_KEY=your_tinker_key
   WANDB_API_KEY=your_wandb_key
   ```
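Before launching, it can help to fail fast on missing keys. A minimal sketch, assuming these two variables are the only ones this workflow requires:

```python
import os

# Fail fast if a required key is missing from the environment.
required = ("TINKER_API_KEY", "WANDB_API_KEY")
missing = [key for key in required if not os.getenv(key)]
print(missing if missing else "all keys set")
```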

### Architecture

The RL training pipeline uses three processes that communicate over HTTP:

```
┌──────────────────────┐   ┌─────────────────────┐   ┌────────────────────────┐
│ Atropos Rollout API  │   │ Tinker Trainer      │   │ Environment            │
│ (port 8000)          │◄──│ (port 8001)         │◄──│ (worker)               │
│                      │   │                     │   │                        │
│ • Collects batches   │   │ • LoRA training     │   │ • Generates prompts    │
│ • Coordinates env    │   │ • Inference server  │   │ • Calls inference API  │
│   and trainer        │   │ • Weight updates    │   │ • Scores responses     │
│                      │   │ • WandB logging     │   │ • Sends scored batches │
└──────────────────────┘   └─────────────────────┘   └────────────────────────┘
```
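The division of labor above can be sketched as a generic rollout loop. This is purely illustrative: the function names are hypothetical stand-ins, not the actual atroposlib or Tinker APIs.

```python
def run_rollout_step(generate, score, post_batch, prompts):
    """One environment step: generate, score, and ship a scored batch."""
    batch = []
    for prompt in prompts:
        response = generate(prompt)        # calls the trainer's inference server (port 8001)
        reward = score(prompt, response)   # environment-specific scoring
        batch.append({"prompt": prompt, "response": response, "score": reward})
    post_batch(batch)                      # hands the scored batch to the rollout API (port 8000)
    return batch
```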

### Quick Start: GSM8k Agent Training

This example trains a model on math problems using a Python REPL tool: the model learns to write and execute Python code to solve them.

```bash
# Terminal 1: Start the Atropos Rollout API
cd tinker-atropos
source ../.venv/bin/activate
set -a && source ../.env && set +a   # export every variable from .env
run-api

# Terminal 2: Start the Tinker Trainer + Inference Server
cd tinker-atropos
source ../.venv/bin/activate
set -a && source ../.env && set +a
python launch_training.py --config configs/gsm8k_agent.yaml

# Terminal 3: Start the GSM8k Agent Environment
cd tinker-atropos
source ../.venv/bin/activate
set -a && source ../.env && set +a
python tinker_atropos/environments/gsm8k_agent.py serve --config configs/gsm8k_agent.yaml
```

### Available Environments

| Environment | File | Description |
|------------|------|-------------|
| `gsm8k` | `gsm8k_tinker.py` | Standard GSM8k math (no tools) |
| `gsm8k_agent` | `gsm8k_agent.py` | GSM8k with Python REPL tool calling |

### Configuration

Configs are YAML files in `tinker-atropos/configs/` with three sections:

```yaml
env:  # Atropos environment settings
  group_size: 4  # Parallel rollouts per problem
  batch_size: 16  # Training batch size
  tokenizer_name: "Qwen/Qwen3-4B-Instruct-2507"
  max_token_length: 2048  # Max generation length
  total_steps: 20  # Training steps

openai:  # Inference server (served by the Tinker trainer)
  - model_name: "Qwen/Qwen3-4B-Instruct-2507"
    base_url: "http://localhost:8001/v1"

tinker:  # Tinker training parameters
  lora_rank: 16  # LoRA rank (lower = faster, less capacity)
  learning_rate: 0.00005
  max_token_trainer_length: 4096  # Max tokens for training
  wandb_project: "hermes-agent-rl"
```
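As a sanity check on the `env` section, a small sketch of how these numbers relate, assuming `batch_size` counts individual rollouts and `group_size` rollouts are collected per problem (hypothetical helper, not part of the training code):

```python
# Mirror of the env section above, as a plain dict.
config = {"env": {"group_size": 4, "batch_size": 16, "total_steps": 20}}

# With these assumptions, a batch of 16 rollouts covers 4 problem groups.
groups_per_batch = config["env"]["batch_size"] // config["env"]["group_size"]
total_rollouts = config["env"]["batch_size"] * config["env"]["total_steps"]
print(groups_per_batch, total_rollouts)  # 4 320
```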

### RL CLI (Agent-Driven Training)

For interactive training management via the Hermes agent:

```bash
# Interactive mode - let the agent manage training
python rl_cli.py --interactive

# List available environments
python rl_cli.py --list-environments

# Direct task
python rl_cli.py "Train a model on GSM8k with tool use"
```

### Sandbox Backends for Agent Environments

For agent environments that need isolated tool execution (e.g., SWE tasks), Hermes-Agent supports multiple sandbox backends:

| Backend | Use Case | Command |
|---------|----------|---------|
| **Nomad + Docker** | Default, local development | `--env.tool_pool_mode nomad` |
| **Nomad + Singularity** | HPC clusters without Docker | `--env.tool_pool_mode nomad --env.driver singularity` |
| **Modal** | Cloud-based, auto-scaling | `--env.tool_pool_mode modal` |

See [docs/MODAL_BACKEND.md](docs/MODAL_BACKEND.md) for Modal backend details.

### Cost

Check the [Tinker Rate Card](https://tinker-console.thinkingmachines.ai/rate-card) for available models and pricing.

---

## Troubleshooting
atropos/Dockerfile (new file, 41 lines)
@@ -0,0 +1,41 @@
# Dockerfile for atropos-agent sandbox server
# Runs inside Nomad containers to handle tool execution
# Includes bubblewrap for namespace-based slot isolation

FROM python:3.11-slim

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    # Bubblewrap for namespace isolation
    bubblewrap \
    # `script` for PTY allocation (used for stable tmux+asciinema startup)
    util-linux \
    # Git for SWE-style tasks (cloning repos)
    git \
    # tmux for stateful terminal sessions (Phase 4.7+)
    tmux \
    # Common tools agents might need
    curl \
    wget \
    jq \
    # Cleanup
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies (sandbox server + optional terminal recording)
RUN pip install --no-cache-dir aiohttp asciinema

# Copy the sandbox server
COPY sandbox_server.py /app/sandbox_server.py

WORKDIR /app

# Create data directory for slot workspaces
RUN mkdir -p /data

# Verify bubblewrap is installed and working
RUN bwrap --version

EXPOSE 8080

# Default command - can be overridden by Nomad job spec
CMD ["python", "sandbox_server.py", "--port", "8080", "--slots", "10", "--data-dir", "/data"]
atropos/__init__.py (new file, 47 lines)
@@ -0,0 +1,47 @@
"""
Atropos integration for Hermes-Agent.

This package is intentionally optional: Hermes-Agent should work without Atropos.
If you import anything from `atropos.*` without having `atroposlib` installed,
we raise a clear error with install instructions.

Install (recommended, from a repo checkout):
    uv sync --extra atropos

Or (pip / editable):
    pip install -e '.[atropos]'
"""

from __future__ import annotations


def _require_atroposlib() -> None:
    try:
        import atroposlib  # noqa: F401
    except ModuleNotFoundError as exc:  # pragma: no cover
        raise ModuleNotFoundError(
            "Hermes-Agent Atropos integration requires `atroposlib`, but it is not installed.\n"
            "Install it with:\n"
            "  uv sync --extra atropos\n"
            "or:\n"
            "  pip install -e '.[atropos]'\n"
        ) from exc


_require_atroposlib()

# Re-export the most commonly used pieces for convenience.
# Agent imports are eager (always available).
from .agent import AgentConfig, AgentResult, AgentStep, AtroposAgent, SequenceData  # noqa: E402

# Env imports are lazy to avoid pulling in deleted atropos.tools dependencies.
# Use: from atropos.envs import AgentEnv, AgentEnvConfig (if needed)

__all__ = [
    "AtroposAgent",
    "AgentConfig",
    "AgentResult",
    "AgentStep",
    "SequenceData",
]
atropos/agent/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
"""
Agent abstractions for atropos-agent.

Provides the core AtroposAgent class for running ReACT-style agent loops.
"""

from .atropos_agent import AgentConfig, AgentResult, AgentStep, AtroposAgent, SequenceData

__all__ = [
    "AtroposAgent",
    "AgentConfig",
    "AgentResult",
    "AgentStep",
    "SequenceData",
]
atropos/agent/atropos_agent.py (new file, 850 lines)
@@ -0,0 +1,850 @@
"""
ReACT-style agent implementation for atropos-agent.

This module provides the core AtroposAgent class that implements a basic
Reason-Act-Observe loop with tool calling capabilities.

Uses ManagedServer from atroposlib for automatic token/logprob tracking,
making trajectories ready for RL training.

The agent uses Hermes-style XML tags for tool calls:
- <think>...</think> for reasoning
- <tool_call>{"name": "...", "arguments": {...}}</tool_call> for actions
- <tool_response>...</tool_response> for observations
"""

import asyncio
import json
import os
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Optional, Union
from uuid import uuid4

import httpx
from dotenv import load_dotenv

from ..tools import ToolCall, ToolRegistry, ToolResult
from atroposlib.envs.server_handling.managed_server import ManagedServer

load_dotenv()


# Default system prompt with tool calling instructions.
AGENT_SYSTEM_PROMPT = """You are a deep thinking AI. You MUST enclose your internal reasoning inside <think>...</think> tags.

You are a function calling AI model.

You are provided with function signatures within <tools></tools> XML tags.
You must call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.
You can ONLY respond without a tool call if you are totally certain you have the final answer to the user's question or task.
After calling & executing a function, you will be provided with function results within <tool_response></tool_response> XML tags.

Here are the available tools:
<tools>
{tools_json}
</tools>

Use the following JSON schema for each tool call you will make:
{"title": "FunctionCall", "type": "object", "properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"]}

## REQUIRED TOOL FORMAT

When you decide to call a tool, your assistant message MUST be:
1) exactly one <think>...</think> block, followed by
2) one or more <tool_call>...</tool_call> blocks,
and NOTHING else in that message.

If you need to explain anything, put it inside <think>. Do NOT write natural language outside <think> or <tool_call>.

For each function call return a JSON object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"name": "<function-name>", "arguments": {"arg1": "value1"}}
</tool_call>

Each <tool_call> must be on its own and contain ONLY the JSON object (no extra text).
The JSON inside <tool_call> MUST be valid JSON with double quotes.

Do NOT output <tool_response> in an assistant message.

After you receive tool results, you may either call more tools (same required format) or provide the final answer.
When providing the final answer, do NOT include any <tool_call> blocks.

## TERMINAL TOOL NOTES

- Commands execute under POSIX `/bin/sh` (not bash).
- Each tool call runs in a fresh shell: environment changes (like `cd` or venv activation) do not persist across tool calls.
- Avoid bash-only features like `source`, `[[ ... ]]`, or process substitution.
- Prefer explicit venv usage:
  - `python -m venv .venv && . .venv/bin/activate && python -m pip install -e .` (POSIX `.` activation), or
  - `.venv/bin/python -m pip install -e .` (no activation required).

## ICL (examples)

User: Show the current directory.
Assistant:
<think>I should run pwd.</think>
<tool_call>
{"name": "terminal", "arguments": {"command": "pwd"}}
</tool_call>
User: <tool_response>{"success": true, "output": "/tmp\\n"}</tool_response>
Assistant: /tmp

User: List files, then count them.
Assistant:
<think>I should count files.</think>
<tool_call>
{"name": "terminal", "arguments": {"command": "ls -1 | wc -l"}}
</tool_call>
User: <tool_response>{"success": true, "output": "3\\n"}</tool_response>
Assistant: 3

User: Run pwd, then print ok (two tool calls).
Assistant:
<think>I should run two commands.</think>
<tool_call>
{"name": "terminal", "arguments": {"command": "pwd"}}
</tool_call>
<tool_call>
{"name": "terminal", "arguments": {"command": "echo ok"}}
</tool_call>
User: <tool_response>{"success": true, "output": "/tmp\\n"}</tool_response>
User: <tool_response>{"success": true, "output": "ok\\n"}</tool_response>
Assistant: ok
"""


@dataclass
class AgentConfig:
    """Configuration for the AtroposAgent."""

    # Generation parameters
    temperature: Optional[float] = 0.7
    # Default to "let the backend decide" (important for tool-tag completions that may be longer).
    max_tokens: Optional[int] = None

    # Agent behavior
    max_steps: int = 50
    system_prompt: Optional[str] = None
    tool_delay_s: float = 0.0

    # Working directory for tools
    working_dir: Optional[str] = None

@dataclass
class SequenceData:
    """Token/logprob data from a single completion."""

    full_text: str
    tokens: List[int]
    masked_tokens: List[int]  # -100 for prompt, actual IDs for completion
    logprobs: List[float]  # 1.0 for prompt, actual values for completion
    metadata: Optional[Dict[str, Any]] = None

    @classmethod
    def from_sequence_node(cls, node) -> "SequenceData":
        """Create from a ManagedServer SequenceNode."""
        return cls(
            full_text=node.full_text,
            tokens=node.tokens,
            masked_tokens=node.masked_tokens,
            logprobs=node.logprobs,
            metadata=getattr(node, "metadata", None),
        )
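The `masked_tokens` convention above matches the widespread `ignore_index=-100` pattern (the default `ignore_index` of PyTorch's `CrossEntropyLoss`). A tiny standalone sketch, separate from the module, of what the mask selects:

```python
# Illustration of the -100 masking convention: prompt positions are masked
# out and excluded from the RL loss; completion positions keep real token IDs.
tokens = [101, 42, 7, 99]
masked_tokens = [-100, -100, 7, 99]

# Only non-masked positions contribute to training.
trained_positions = [i for i, m in enumerate(masked_tokens) if m != -100]
print(trained_positions)  # [2, 3]
```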
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStep:
|
||||
"""A single step in the agent's trajectory."""
|
||||
|
||||
step_number: int
|
||||
assistant_message: str
|
||||
tool_calls: List[ToolCall] = field(default_factory=list)
|
||||
tool_results: List[ToolResult] = field(default_factory=list)
|
||||
sequence_data: Optional[SequenceData] = None # Token data from this step
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
return len(self.tool_calls) > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResult:
|
||||
"""Result of running an agent trajectory."""
|
||||
|
||||
success: bool
|
||||
final_response: str
|
||||
steps: List[AgentStep] = field(default_factory=list)
|
||||
total_tokens: int = 0
|
||||
error: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Full trajectory token data for RL training
|
||||
trajectory_data: Optional[SequenceData] = None
|
||||
|
||||
@property
|
||||
def num_steps(self) -> int:
|
||||
return len(self.steps)
|
||||
|
||||
@property
|
||||
def total_tool_calls(self) -> int:
|
||||
return sum(len(step.tool_calls) for step in self.steps)
|
||||
|
||||
def to_messages(self) -> List[Dict[str, str]]:
|
||||
"""Convert trajectory to messages format for logging."""
|
||||
messages = []
|
||||
for step in self.steps:
|
||||
messages.append({"role": "assistant", "content": step.assistant_message})
|
||||
if step.tool_results:
|
||||
# Combine all tool responses
|
||||
responses = "\n".join(r.to_xml() for r in step.tool_results)
|
||||
messages.append({"role": "user", "content": responses})
|
||||
return messages
|
||||
|
||||
def to_scored_data(self, score: float) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Convert to format suitable for ScoredDataGroup.
|
||||
|
||||
Args:
|
||||
score: The score for this trajectory
|
||||
|
||||
Returns:
|
||||
Dict with tokens, masks, scores suitable for training, or None if no data
|
||||
"""
|
||||
if self.trajectory_data is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"tokens": self.trajectory_data.tokens,
|
||||
"masks": self.trajectory_data.masked_tokens,
|
||||
"scores": score,
|
||||
"logprobs": self.trajectory_data.logprobs,
|
||||
}
|
||||
|
||||
|
||||
class AtroposAgent:
    """
    A ReACT-style agent that uses LLMs with tool calling.

    This implementation wraps ManagedServer for automatic token/logprob tracking,
    making trajectories ready for RL training.

    Example:
        # `server` may be an Atropos `ServerManager` (recommended) or a single `APIServer`.
        # In practice, environments usually construct this via `BaseEnv`.
        server = ...
        tools = ToolRegistry()
        tools.register(BashTool())

        agent = AtroposAgent(server=server, tools=tools)
        result = await agent.run("List the files in the current directory")

        # Access token data for training
        if result.trajectory_data:
            print(f"Tokens: {result.trajectory_data.tokens}")
            print(f"Masked: {result.trajectory_data.masked_tokens}")
    """

    def __init__(
        self,
        server,  # ServerManager or APIServer
        tools: Optional[ToolRegistry] = None,
        config: Optional[AgentConfig] = None,
        tokenizer: Optional[Any] = None,
        execute_tool: Optional[Callable[[ToolCall], Awaitable[ToolResult]]] = None,
    ):
        self.server = server
        self.tools = tools or ToolRegistry()
        self.config = config or AgentConfig()
        self.tokenizer = tokenizer or getattr(server, "tokenizer", None)
        self.execute_tool = execute_tool or self.tools.execute

    @asynccontextmanager
    async def _managed(self) -> AsyncGenerator[Any, None]:
        """
        Yield a ManagedServer-like object.

        - If `self.server` is a ServerManager, use its `managed_server()` context manager.
        - If `self.server` is a single APIServer, wrap it in `ManagedServer` directly.
        """
        if os.getenv("ATROPOS_BYPASS_MANAGED_SERVER") == "1":
            yield _DirectChatCompletionClient(server=self.server)
            return
        if hasattr(self.server, "managed_server"):
            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                yield managed
        else:
            managed = ManagedServer(server=self.server, tokenizer=self.tokenizer)
            try:
                yield managed
            finally:
                managed.reset()

    def _build_system_prompt(self) -> str:
        """Build the system prompt with tool descriptions."""
        if self.config.system_prompt:
            return self.config.system_prompt

        tools_json = self.tools.get_prompt_tool_definitions_json()
        # Avoid `str.format()` here because the prompt contains many literal `{}` braces
        # in JSON examples; we only want to substitute the single `{tools_json}` token.
        return AGENT_SYSTEM_PROMPT.replace("{tools_json}", tools_json)

    def _infer_server_model_for_debug(self) -> Optional[str]:
        """
        Best-effort inference of the configured model name for debug payload saving.

        ManagedServer/server_manager typically injects `model` internally, so `chat_kwargs`
        may not contain it. For replaying saved payloads via curl, it's useful to persist it.
        """
        servers = getattr(self.server, "servers", None)
        if isinstance(servers, list) and servers:
            s0 = servers[0]
            cfg = getattr(s0, "config", None)
            model = getattr(cfg, "model_name", None) or getattr(s0, "model_name", None)
            if isinstance(model, str) and model:
                return model
        model = getattr(self.server, "model_name", None) or getattr(self.server, "model", None)
        if isinstance(model, str) and model:
            return model
        return None

    def _infer_server_base_url_for_debug(self) -> Optional[str]:
        """
        Best-effort inference of the configured base_url for debug logging.

        This is helpful when diagnosing hangs / retries at the transport layer.
        """
        servers = getattr(self.server, "servers", None)
        if isinstance(servers, list) and servers:
            s0 = servers[0]
            cfg = getattr(s0, "config", None)
            base_url = getattr(cfg, "base_url", None) or getattr(s0, "base_url", None)
            if isinstance(base_url, str) and base_url:
                return base_url
        base_url = getattr(self.server, "base_url", None)
        if isinstance(base_url, str) and base_url:
            return base_url
        return None

    def _extract_response_metadata(self, response: Any) -> Dict[str, Any]:
        """
        Extract lightweight, JSON-serializable metadata from an OpenAI-style response.

        This is useful for debugging training runs, especially when ManagedServer state
        tracking is unavailable (e.g. OpenAI-compatible chat endpoints).
        """
        meta: Dict[str, Any] = {}
        try:
            rid = getattr(response, "id", None)
            if isinstance(rid, str) and rid:
                meta["id"] = rid
            model = getattr(response, "model", None)
            if isinstance(model, str) and model:
                meta["model"] = model
            created = getattr(response, "created", None)
            if isinstance(created, int):
                meta["created"] = created
            system_fingerprint = getattr(response, "system_fingerprint", None)
            if isinstance(system_fingerprint, str) and system_fingerprint:
                meta["system_fingerprint"] = system_fingerprint

            choices = getattr(response, "choices", None)
            if isinstance(choices, list) and choices:
                fr = getattr(choices[0], "finish_reason", None)
                if isinstance(fr, str) and fr:
                    meta["finish_reason"] = fr

            usage = getattr(response, "usage", None)
            if usage is not None:
                if hasattr(usage, "model_dump"):
                    meta["usage"] = usage.model_dump()
                elif isinstance(usage, dict):
                    meta["usage"] = usage
        except Exception:
            pass
        return meta

    def _debug_dump_request(self, *, step_num: int, chat_kwargs: Dict[str, Any]) -> None:
        if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST") != "1":
            return
        try:
            # Avoid dumping megabytes by default; messages can be huge.
            meta = {
                "step": step_num,
                "base_url": self._infer_server_base_url_for_debug(),
                "model": chat_kwargs.get("model") or self._infer_server_model_for_debug(),
                "chat_kwargs_keys": sorted(list(chat_kwargs.keys())),
                "n": chat_kwargs.get("n"),
                "max_tokens": chat_kwargs.get("max_tokens"),
                "temperature": chat_kwargs.get("temperature"),
                "num_messages": len(chat_kwargs.get("messages") or []),
            }
            print("\n=== ATROPOS_DEBUG_AGENT_REQUEST ===", flush=True)
            print(meta, flush=True)

            if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST_FULL") == "1":
                payload = dict(chat_kwargs)
                # Make the payload more legible and less huge.
                try:
                    dumped = json.dumps(payload, ensure_ascii=False, indent=2)
                except Exception:
                    dumped = repr(payload)
                print("\n=== ATROPOS_DEBUG_AGENT_REQUEST_FULL ===", flush=True)
                print(dumped[:200_000], flush=True)

            # Optional: save the FULL request payload to disk (no truncation).
            save_dir = os.getenv("ATROPOS_DEBUG_AGENT_REQUEST_SAVE_DIR")
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
                payload: Dict[str, Any] = dict(chat_kwargs)
                if "model" not in payload:
                    model = self._infer_server_model_for_debug()
                    if model:
                        payload["model"] = model
                # Use a unique filename so parallel trajectories don't clobber each other.
                fname = os.path.join(
                    save_dir,
                    f"atropos_agent_request_step{step_num}_{int(time.time()*1000)}_{os.getpid()}_{uuid4().hex}.json",
                )
                with open(fname, "w", encoding="utf-8") as f:
                    json.dump(payload, f, ensure_ascii=False, indent=2)
                print(f"[AtroposAgent] saved request payload: {fname}", flush=True)
        except Exception:
            return

    def _debug_dump_response(self, *, step_num: int, response: Any) -> None:
        if os.getenv("ATROPOS_DEBUG_AGENT_RESPONSE") != "1":
            return
        print("\n=== ATROPOS_DEBUG_AGENT_RESPONSE ===", flush=True)
        print({"step": step_num, "type": type(response).__name__}, flush=True)
        try:
            dumped = response.model_dump()  # openai pydantic model
        except Exception:
            dumped = getattr(response, "__dict__", {"repr": repr(response)})
        # Keep the dump bounded; we only need enough to see the assistant message content.
        text = str(dumped)
        print(text[:200_000], flush=True)

    async def _chat_completion_with_debug(
        self, *, managed: Any, step_num: int, chat_kwargs: Dict[str, Any]
    ) -> Any:
        """
        Call `managed.chat_completion()` with a timeout + richer failure logging.

        Debug env vars:
        - `ATROPOS_AGENT_CHAT_TIMEOUT_S`: overrides the default 240s timeout (capped at 240s).
        - `ATROPOS_DEBUG_AGENT_WAIT_EVERY_S`: if set, prints a heartbeat while waiting.
        """
        # Hard guardrail: never allow a single chat completion to block for too long.
        # This is essential for RL data-gen stability; long hangs should be treated as failures (score=0).
        timeout_s_raw = os.getenv("ATROPOS_AGENT_CHAT_TIMEOUT_S")
        timeout_s_default = 240.0
        timeout_s = float(timeout_s_raw) if timeout_s_raw else timeout_s_default
        timeout_s = min(timeout_s, 240.0)

        wait_every_raw = os.getenv("ATROPOS_DEBUG_AGENT_WAIT_EVERY_S")
        wait_every_s = float(wait_every_raw) if wait_every_raw else None

        async def _await_call() -> Any:
            if not wait_every_s or wait_every_s <= 0:
                return await managed.chat_completion(**chat_kwargs)

            # Heartbeat mode: wait in chunks without cancelling the underlying request.
            # NOTE: do NOT use `asyncio.wait_for(task, timeout=...)` here, because a timeout
            # would cancel the task and surface as `CancelledError` on the next loop.
            task = asyncio.create_task(managed.chat_completion(**chat_kwargs))
            t0 = time.perf_counter()
            try:
                while True:
                    done, _pending = await asyncio.wait({task}, timeout=wait_every_s)
                    if task in done:
                        return task.result()

                    waited = time.perf_counter() - t0
                    print(
                        f"[AtroposAgent] step={step_num} still waiting for chat_completion... ({waited:.1f}s)",
                        flush=True,
                    )
            except asyncio.CancelledError:
                task.cancel()
                raise

        try:
            return await asyncio.wait_for(_await_call(), timeout=timeout_s)
        except asyncio.TimeoutError as e:
            print("\n=== ATROPOS_DEBUG_AGENT_CHAT_TIMEOUT ===", flush=True)
            print({"step": step_num, "timeout_s": timeout_s}, flush=True)
            raise RuntimeError(f"chat_completion timed out after {timeout_s:.1f}s") from e
        except asyncio.CancelledError:
            # Treat cancellation as a hard failure rather than crashing the whole env run.
            # (Atropos/BaseEnv may cancel tasks during shutdown or retries.)
            raise RuntimeError("chat_completion cancelled") from None
        except Exception as e:
            detail: Dict[str, Any] = {
                "step": step_num,
                "exc_type": type(e).__name__,
                "exc_str": str(e),
            }
            if isinstance(e, httpx.HTTPStatusError):
                try:
                    detail["status_code"] = e.response.status_code
                    detail["response_text"] = e.response.text[:20_000]
                except Exception:
                    pass
            elif isinstance(e, httpx.RequestError):
                detail["request"] = repr(getattr(e, "request", None))

            print("\n=== ATROPOS_DEBUG_AGENT_CHAT_FAILURE ===", flush=True)
            print(detail, flush=True)
            raise

    async def run(
        self,
        task: str,
        initial_messages: Optional[List[Dict[str, str]]] = None,
    ) -> AgentResult:
        """
        Run the agent on a task using ManagedServer for token tracking.

        Args:
            task: The task/prompt for the agent
            initial_messages: Optional additional context messages

        Returns:
            AgentResult with the trajectory, final response, and token data
        """
        messages = [
            {"role": "system", "content": self._build_system_prompt()},
        ]

        if initial_messages:
            messages.extend(initial_messages)

        messages.append({"role": "user", "content": task})

        steps = []
        final_response = ""
        final_node = None
        final_prompt_messages: Optional[List[Dict[str, str]]] = None
        last_node = None
        last_prompt_messages: Optional[List[Dict[str, str]]] = None
        last_response_text: str = ""

        # Use ManagedServer for automatic token tracking
        async with self._managed() as managed:
            for step_num in range(self.config.max_steps):
                # ReACT loop iteration: model call -> tool execution -> observation,
                # repeated until the model responds without calling any tools.
                try:
                    # Keep a copy of the prompt messages used for this completion.
                    # Useful for reconstructing tokens/masks when state tracking is unavailable.
                    prompt_messages = list(messages)
                    chat_kwargs: Dict[str, Any] = {"messages": messages, "n": 1}
                    if self.config.max_tokens is not None:
                        chat_kwargs["max_tokens"] = self.config.max_tokens
                    if self.config.temperature is not None:
                        chat_kwargs["temperature"] = self.config.temperature

                    t_req = time.perf_counter()
                    print(
                        f"[AtroposAgent] step={step_num+1} chat_completion start "
                        f"(messages={len(messages)}, max_tokens={self.config.max_tokens}, temp={self.config.temperature})",
                        flush=True,
                    )
                    self._debug_dump_request(step_num=step_num + 1, chat_kwargs=chat_kwargs)
                    response = await self._chat_completion_with_debug(
                        managed=managed, step_num=step_num + 1, chat_kwargs=chat_kwargs
                    )
                    self._debug_dump_response(step_num=step_num + 1, response=response)
                    response_meta = self._extract_response_metadata(response)
                    print(
                        f"[AtroposAgent] step={step_num+1} chat_completion done in {time.perf_counter() - t_req:.2f}s",
                        flush=True,
                    )

                    current_node = None
                    if hasattr(managed, "get_state"):
                        state = managed.get_state()
                        nodes = state.get("nodes", [])
                        current_node = nodes[-1] if nodes else None

                except Exception as e:
                    return AgentResult(
                        success=False,
                        final_response="",
                        steps=steps,
                        error=f"Generation error: {str(e)}",
                    )

                msg = response.choices[0].message
                # Some OpenAI-compatible servers populate `message.reasoning` and leave `content=""`.
                response_text = (msg.content or "") or (getattr(msg, "reasoning", None) or "")
                tool_calls = ToolCall.parse_from_text(response_text)
                last_node = current_node
                last_prompt_messages = prompt_messages
                last_response_text = response_text

                step_sequence_data = SequenceData.from_sequence_node(current_node) if current_node else None
                if step_sequence_data is None:
                    if response_meta:
                        # We still want metadata for debugging even if token/logprob state tracking is unavailable.
                        step_sequence_data = SequenceData(
                            full_text=response_text,
                            tokens=[],
                            masked_tokens=[],
                            logprobs=[],
                            metadata=response_meta,
                        )
                else:
                    merged = dict(response_meta)
                    node_meta = step_sequence_data.metadata
                    if isinstance(node_meta, dict):
                        merged.update(node_meta)
                    step_sequence_data.metadata = merged or step_sequence_data.metadata

                step = AgentStep(
                    step_number=step_num + 1,
                    assistant_message=response_text,
                    tool_calls=tool_calls,
                    sequence_data=step_sequence_data,
                )

                if not tool_calls:
                    steps.append(step)
                    final_response = response_text
                    final_node = current_node
                    final_prompt_messages = prompt_messages
                    break

                messages.append({"role": "assistant", "content": response_text})

                tool_responses = []
                for call in tool_calls:
                    result = await self.execute_tool(call)
                    step.tool_results.append(result)
                    tool_responses.append(result.to_xml())
                    if self.config.tool_delay_s > 0:
                        await asyncio.sleep(self.config.tool_delay_s)

                steps.append(step)

                responses_text = "\n".join(tool_responses)
                # Tool observations are represented as user content with Hermes-style tags.
                # This is compatible with most OpenAI-compatible chat APIs and ensures
                # tokenizers/chat templates include tool outputs during training.
                messages.append({"role": "user", "content": responses_text})

            else:
                # Reached max steps without completing.
                # Return a failure result but include the last observed completion so callers can
                # record the trajectory (score=0) without triggering retries.
                final_response = last_response_text or final_response
                final_node = last_node
                final_prompt_messages = last_prompt_messages
                trajectory_data = None
                if final_node:
                    trajectory_data = SequenceData.from_sequence_node(final_node)
                elif final_prompt_messages is not None and self.tokenizer is not None:
                    if hasattr(self.tokenizer, "apply_chat_template"):
                        prompt_text = self.tokenizer.apply_chat_template(
                            final_prompt_messages, tokenize=False, add_generation_prompt=True
                        )
                        prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False)
                    else:
                        prompt_text = "\n".join([f"{m['role']}: {m['content']}" for m in final_prompt_messages])
                        prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=True)
                    output_tokens = self.tokenizer.encode(final_response, add_special_tokens=False)
                    tokens = prompt_tokens + output_tokens
                    masked_tokens = ([-100] * len(prompt_tokens)) + output_tokens
                    logprobs = ([1.0] * len(prompt_tokens)) + ([0.0] * len(output_tokens))
                    trajectory_data = SequenceData(
                        full_text=f"{prompt_text}{final_response}",
                        tokens=tokens,
                        masked_tokens=masked_tokens,
                        logprobs=logprobs,
                    )
                # Preserve response metadata (if any) even on failure trajectories.
                try:
                    if trajectory_data is not None and steps:
                        last_step = steps[-1]
                        if last_step.sequence_data and isinstance(last_step.sequence_data.metadata, dict):
                            trajectory_data.metadata = dict(last_step.sequence_data.metadata)
                except Exception:
                    pass
                return AgentResult(
                    success=False,
                    final_response=final_response,
                    steps=steps,
                    error=f"Reached maximum steps ({self.config.max_steps})",
                    trajectory_data=trajectory_data,
                )

            # Build result with trajectory data
            trajectory_data = None
            if final_node:
                trajectory_data = SequenceData.from_sequence_node(final_node)
            elif final_prompt_messages is not None and self.tokenizer is not None:
                if hasattr(self.tokenizer, "apply_chat_template"):
                    prompt_text = self.tokenizer.apply_chat_template(
                        final_prompt_messages, tokenize=False, add_generation_prompt=True
                    )
                    prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False)
                else:
                    prompt_text = "\n".join([f"{m['role']}: {m['content']}" for m in final_prompt_messages])
                    prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=True)
                output_tokens = self.tokenizer.encode(final_response, add_special_tokens=False)
                tokens = prompt_tokens + output_tokens
                masked_tokens = ([-100] * len(prompt_tokens)) + output_tokens
                logprobs = ([1.0] * len(prompt_tokens)) + ([0.0] * len(output_tokens))
                trajectory_data = SequenceData(
                    full_text=f"{prompt_text}{final_response}",
                    tokens=tokens,
                    masked_tokens=masked_tokens,
                    logprobs=logprobs,
                )

            # Ensure trajectory_data carries the most recent metadata we observed (if any).
            try:
                if trajectory_data is not None and steps:
                    last_step = steps[-1]
                    if last_step.sequence_data and isinstance(last_step.sequence_data.metadata, dict):
                        trajectory_data.metadata = dict(last_step.sequence_data.metadata)
            except Exception:
                pass

            return AgentResult(
                success=True,
                final_response=final_response,
                steps=steps,
                trajectory_data=trajectory_data,
            )

    async def run_single_turn(
        self,
        messages: List[Dict[str, str]],
        execute_tools: bool = True,
    ) -> tuple[str, List[ToolResult], Optional[SequenceData]]:
        """
        Run a single turn of the agent (one LLM call + tool execution).

        This is useful for integration with BaseEnv where you want more
        control over the loop.

        Args:
            messages: The conversation history
            execute_tools: Whether to execute parsed tool calls

        Returns:
            Tuple of (response_text, tool_results, sequence_data)
        """
        async with self._managed() as managed:
            chat_kwargs: Dict[str, Any] = {"messages": messages, "n": 1}
            if self.config.max_tokens is not None:
                chat_kwargs["max_tokens"] = self.config.max_tokens
            if self.config.temperature is not None:
                chat_kwargs["temperature"] = self.config.temperature

            self._debug_dump_request(step_num=1, chat_kwargs=chat_kwargs)
            response = await self._chat_completion_with_debug(managed=managed, step_num=1, chat_kwargs=chat_kwargs)
            self._debug_dump_response(step_num=1, response=response)

            current_node = None
            if hasattr(managed, "get_state"):
                state = managed.get_state()
                nodes = state.get("nodes", [])
                current_node = nodes[-1] if nodes else None

            msg = response.choices[0].message
            response_text = (msg.content or "") or (getattr(msg, "reasoning", None) or "")
            tool_results = []

            if execute_tools:
                tool_calls = ToolCall.parse_from_text(response_text)
                for call in tool_calls:
                    result = await self.execute_tool(call)
                    tool_results.append(result)

            sequence_data = SequenceData.from_sequence_node(current_node) if current_node else None

            return response_text, tool_results, sequence_data


class _DirectChatCompletionClient:
    """
    Minimal stand-in for ManagedServer that calls the OpenAI-compatible endpoint directly.

    This is for isolating issues where `ManagedServer.chat_completion()` hangs or misbehaves.
    It intentionally does NOT do token/logprob tracking.
    """

    def __init__(self, server: Any):
        self._server = server

    def _server_config(self) -> tuple[str, str, str]:
        # ServerManager case: first configured server.
        servers = getattr(self._server, "servers", None)
        if isinstance(servers, list) and servers:
            s0 = servers[0]
            cfg = getattr(s0, "config", None)
            base_url = getattr(cfg, "base_url", None) or getattr(s0, "base_url", None)
            api_key = getattr(cfg, "api_key", None) or getattr(s0, "api_key", None)
            model = getattr(cfg, "model_name", None) or getattr(s0, "model_name", None)
            if isinstance(base_url, str) and isinstance(api_key, str) and isinstance(model, str):
                return base_url.rstrip("/"), api_key, model

        # APIServer-like fallback.
        base_url = getattr(self._server, "base_url", None)
        api_key = getattr(self._server, "api_key", None)
        model = getattr(self._server, "model_name", None) or getattr(self._server, "model", None)
        if isinstance(base_url, str) and isinstance(api_key, str) and isinstance(model, str):
            return base_url.rstrip("/"), api_key, model

        raise RuntimeError("Unable to resolve server base_url/api_key/model for direct chat completion")

    async def chat_completion(self, *, messages: List[Dict[str, str]], n: int = 1, **kwargs: Any) -> Any:
        base_url, api_key, model = self._server_config()
        url = f"{base_url}/chat/completions"

        payload: Dict[str, Any] = {
            "model": model,
            "messages": messages,
            "n": n,
        }
        # Pass through common generation kwargs.
        for k in ("max_tokens", "temperature", "top_p", "presence_penalty", "frequency_penalty", "stop"):
            if k in kwargs and kwargs[k] is not None:
                payload[k] = kwargs[k]

        timeout_s = float(os.getenv("ATROPOS_DIRECT_REQUEST_TIMEOUT_S") or "120")
        print(f"[AtroposAgent] DIRECT chat_completion POST {url} (timeout={timeout_s}s)", flush=True)
        async with httpx.AsyncClient(timeout=timeout_s) as client:
            resp = await client.post(
                url,
                headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
                json=payload,
            )
            resp.raise_for_status()
            data = resp.json()

        # Return a very small object compatible with the code paths that read
        # `response.choices[0].message.content`.
        class _Msg:
            def __init__(self, d: Dict[str, Any]):
                self.content = d.get("content")
                self.reasoning = d.get("reasoning")

        class _Choice:
            def __init__(self, d: Dict[str, Any]):
                self.message = _Msg(d.get("message") or {})

        class _Resp:
            def __init__(self, d: Dict[str, Any]):
                self._d = d
                self.choices = [_Choice(c) for c in (d.get("choices") or [])]

            def model_dump(self) -> Dict[str, Any]:
                return self._d

        return _Resp(data)
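The fallback path in `run()` above builds training sequences by concatenating prompt and output token ids, masking prompt positions with `-100`, and filling placeholder logprobs. A minimal self-contained sketch of that masking scheme — the word-level "tokenizer" here is a toy stand-in for the real HF tokenizer, and `build_fallback_sequence` is an illustrative name, not part of the diff:

```python
# Sketch of the tokens/masks fallback built in `run()` when ManagedServer
# state tracking is unavailable. Only the masking scheme is the point here;
# the whitespace tokenizer is a toy substitute for a real tokenizer.

def build_fallback_sequence(prompt_text: str, response_text: str):
    vocab = {}

    def encode(text):
        # Toy tokenization: one integer id per whitespace-separated word.
        return [vocab.setdefault(w, len(vocab)) for w in text.split()]

    prompt_tokens = encode(prompt_text)
    output_tokens = encode(response_text)

    tokens = prompt_tokens + output_tokens
    # Prompt positions are masked with -100 so the trainer ignores them;
    # completion positions keep their token ids.
    masked_tokens = [-100] * len(prompt_tokens) + output_tokens
    # Placeholder logprobs: 1.0 marks prompt positions, 0.0 marks outputs.
    logprobs = [1.0] * len(prompt_tokens) + [0.0] * len(output_tokens)
    return tokens, masked_tokens, logprobs


tokens, masks, logprobs = build_fallback_sequence("user: list files", "ls -la done")
```

Because the placeholder logprobs are not real model logprobs, trajectories built this way are mainly useful for recording failures (score=0) rather than for importance-weighted updates.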
atropos/api/__init__.py (new file, +6 lines)
@@ -0,0 +1,6 @@
"""
FastAPI services for atropos-agent.

- tool_executor_server: queued/batched sandbox tool execution (Phase 4)
"""
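The `tool_executor_server.py` diff below guards requests with an optional bearer token via `_check_auth`. A standalone sketch of the same parsing/validation rules, with FastAPI's `HTTPException` replaced by a returned `(ok, reason)` tuple so it runs without the framework (`check_auth` is an illustrative name, not part of the diff):

```python
# Standalone sketch of the bearer-token check used by the Tool Executor API
# (`_check_auth`): failures return (False, reason) instead of raising
# HTTPException, so the logic can be exercised without FastAPI.

def check_auth(expected_token, authorization):
    if not expected_token:
        # No token configured: auth is disabled (dev mode).
        return (True, None)
    if not authorization:
        return (False, "Missing Authorization header")
    if not authorization.lower().startswith("bearer "):
        return (False, "Invalid Authorization header")
    # Split off the scheme, keep the rest of the header as the token.
    token = authorization.split(" ", 1)[1].strip()
    if token != expected_token:
        return (False, "Invalid token")
    return (True, None)
```

Note that the scheme comparison is case-insensitive (`"bearer "` matches `"Bearer "`), while the token comparison is exact.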
atropos/api/tool_executor_server.py (new file, +254 lines)
@@ -0,0 +1,254 @@
"""
|
||||
Tool Executor API (Phase 4)
|
||||
|
||||
This service provides a queued, batched execution layer on top of a ToolBackend.
|
||||
It mirrors the stateful FastAPI + app.state pattern used in:
|
||||
atropos/atroposlib/api/server.py
|
||||
|
||||
Run (dev):
|
||||
uv run uvicorn atropos_agent.api.tool_executor_server:app --host 0.0.0.0 --port 9001
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, Header, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..backends.nomad_backend import NomadBackendConfig, NomadToolBackend
|
||||
from ..tools import ToolRegistry, build_tool_registry
|
||||
from ..tools.base import (
|
||||
ArtifactArchiveRequestPayload,
|
||||
ArtifactArchiveResponsePayload,
|
||||
ArtifactListRequestPayload,
|
||||
ArtifactListResponsePayload,
|
||||
ArtifactReadRequestPayload,
|
||||
ArtifactReadResponsePayload,
|
||||
ToolExecutorExecuteRequest,
|
||||
ToolExecutorReleaseRequest,
|
||||
ToolResultPayload,
|
||||
)
|
||||
from ..tools.tool_executor import ToolExecutor, ToolExecutorConfig
|
||||
|
||||
|
||||
class ToolExecutorServerConfig(BaseModel):
|
||||
nomad_address: str = Field(default="http://localhost:4646")
|
||||
job_id: str = Field(default="atropos-sandbox-tool-executor")
|
||||
image: str = Field(default="atropos-sandbox:local")
|
||||
slots_per_container: int = Field(default=10)
|
||||
min_containers: int = Field(default=1)
|
||||
max_containers: int = Field(default=10)
|
||||
privileged: bool = Field(default=False)
|
||||
acquire_timeout_s: float = Field(default=30.0)
|
||||
|
||||
batch_window_ms: int = Field(default=20)
|
||||
max_batch_size: int = Field(default=200)
|
||||
allow_network: bool = Field(default=True)
|
||||
|
||||
tool_server_url: Optional[str] = Field(default=None)
|
||||
tool_server_token: Optional[str] = Field(default=None)
|
||||
|
||||
token: Optional[str] = Field(default=None, description="Bearer token required for requests (optional in dev).")
|
||||
|
||||
purge_job_on_shutdown: bool = Field(default=True)
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "ToolExecutorServerConfig":
|
||||
# In dev, prefer loading secrets/config from the repo-local `.env` (not committed).
|
||||
try:
|
||||
from dotenv import load_dotenv # type: ignore
|
||||
except Exception: # pragma: no cover
|
||||
load_dotenv = None # type: ignore[assignment]
|
||||
if load_dotenv is not None:
|
||||
env_path = Path(__file__).resolve().parents[2] / ".env"
|
||||
if env_path.exists():
|
||||
load_dotenv(dotenv_path=env_path)
|
||||
|
||||
def _get_bool(name: str, default: bool) -> bool:
|
||||
raw = os.getenv(name)
|
||||
if raw is None:
|
||||
return default
|
||||
return raw.strip().lower() in {"1", "true", "yes", "y", "on"}
|
||||
|
||||
return cls(
|
||||
nomad_address=os.getenv("TOOL_EXECUTOR_NOMAD_ADDRESS", "http://localhost:4646"),
|
||||
job_id=os.getenv("TOOL_EXECUTOR_JOB_ID", "atropos-sandbox-tool-executor"),
|
||||
image=os.getenv("TOOL_EXECUTOR_IMAGE", "atropos-sandbox:local"),
|
||||
slots_per_container=int(os.getenv("TOOL_EXECUTOR_SLOTS", "10")),
|
||||
min_containers=int(os.getenv("TOOL_EXECUTOR_MIN_CONTAINERS", "1")),
|
||||
max_containers=int(os.getenv("TOOL_EXECUTOR_MAX_CONTAINERS", "10")),
|
||||
privileged=_get_bool("TOOL_EXECUTOR_PRIVILEGED", False),
|
||||
acquire_timeout_s=float(os.getenv("TOOL_EXECUTOR_ACQUIRE_TIMEOUT_S", "30.0")),
|
||||
batch_window_ms=int(os.getenv("TOOL_EXECUTOR_BATCH_WINDOW_MS", "20")),
|
||||
max_batch_size=int(os.getenv("TOOL_EXECUTOR_MAX_BATCH_SIZE", "200")),
|
||||
allow_network=_get_bool("TOOL_EXECUTOR_ALLOW_NETWORK", True),
|
||||
tool_server_url=os.getenv("TOOL_EXECUTOR_TOOL_SERVER_URL") or None,
|
||||
tool_server_token=os.getenv("TOOL_EXECUTOR_TOOL_SERVER_TOKEN") or None,
|
||||
token=os.getenv("TOOL_EXECUTOR_TOKEN") or None,
|
||||
purge_job_on_shutdown=_get_bool("TOOL_EXECUTOR_PURGE_JOB_ON_SHUTDOWN", True),
|
||||
)
|
||||
|
||||
|
||||
app = FastAPI(title="Atropos-Agent Tool Executor")
|
||||
|
||||
|
||||
@app.get("/")
async def root() -> Dict[str, str]:
    return {"message": "Atropos-Agent Tool Executor"}


def _check_auth(cfg: ToolExecutorServerConfig, authorization: Optional[str]) -> None:
    if not cfg.token:
        return
    if not authorization:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header")
    if not authorization.lower().startswith("bearer "):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Authorization header")
    token = authorization.split(" ", 1)[1].strip()
    if token != cfg.token:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")


@app.on_event("startup")
async def _startup() -> None:
    cfg = ToolExecutorServerConfig.from_env()

    # Default to Atropos "full" tool surface: sandbox + external (if tool_server_url provided).
    tools: ToolRegistry = build_tool_registry(
        enabled_toolsets=["full"],
        disabled_toolsets=None,
        tool_server_url=cfg.tool_server_url,
    )

    backend = NomadToolBackend(
        NomadBackendConfig(
            nomad_address=cfg.nomad_address,
            sandbox_job_id=cfg.job_id,
            sandbox_image=cfg.image,
            slots_per_container=cfg.slots_per_container,
            min_containers=cfg.min_containers,
            max_containers=cfg.max_containers,
            privileged=cfg.privileged,
            acquire_timeout_s=cfg.acquire_timeout_s,
            purge_job_on_start=False,
        )
    )
    await backend.start()

    executor = ToolExecutor(
        backend=backend,
        tools=tools,
        config=ToolExecutorConfig(
            batch_window_ms=cfg.batch_window_ms,
            max_batch_size=cfg.max_batch_size,
            allow_network=cfg.allow_network,
            tool_server_url=cfg.tool_server_url,
            tool_server_token=cfg.tool_server_token,
        ),
    )
    await executor.start()

    app.state.cfg = cfg
    app.state.backend = backend
    app.state.executor = executor


@app.on_event("shutdown")
async def _shutdown() -> None:
    executor: Optional[ToolExecutor] = getattr(app.state, "executor", None)
    backend: Optional[NomadToolBackend] = getattr(app.state, "backend", None)
    cfg: Optional[ToolExecutorServerConfig] = getattr(app.state, "cfg", None)

    if executor is not None:
        await executor.close()

    if backend is not None:
        await backend.stop(purge=bool(cfg.purge_job_on_shutdown) if cfg else False)


@app.get("/health")
async def health() -> Dict[str, Any]:
    return {"status": "ok"}


@app.get("/status")
async def status_endpoint() -> Dict[str, Any]:
    executor: ToolExecutor = app.state.executor
    backend: NomadToolBackend = app.state.backend

    return {
        "queue_size": executor.queue_size(),
        "total_requests": executor.total_requests,
        "total_errors": executor.total_errors,
        "pool": backend.get_stats(),
    }


@app.post("/execute", response_model=ToolResultPayload)
async def execute_tool(
    req: ToolExecutorExecuteRequest,
    authorization: Optional[str] = Header(default=None),
    status_code: int = status.HTTP_200_OK,  # noqa: B008
) -> ToolResultPayload:
    cfg: ToolExecutorServerConfig = app.state.cfg
    _check_auth(cfg, authorization)

    executor: ToolExecutor = app.state.executor
    result = await executor.execute(
        trajectory_id=req.trajectory_id,
        call=req.tool.to_tool_call(),
        timeout_s=req.timeout_s,
    )
    return ToolResultPayload.from_tool_result(result)


@app.post("/release")
async def release_trajectory(
    req: ToolExecutorReleaseRequest,
    authorization: Optional[str] = Header(default=None),
) -> Dict[str, Any]:
    cfg: ToolExecutorServerConfig = app.state.cfg
    _check_auth(cfg, authorization)

    executor: ToolExecutor = app.state.executor
    await executor.release_trajectory(req.trajectory_id, reset_workspace=req.reset_workspace)
    return {"status": "ok"}


@app.post("/artifacts/read", response_model=ArtifactReadResponsePayload)
async def artifacts_read(
    req: ArtifactReadRequestPayload,
    authorization: Optional[str] = Header(default=None),
) -> ArtifactReadResponsePayload:
    cfg: ToolExecutorServerConfig = app.state.cfg
    _check_auth(cfg, authorization)

    executor: ToolExecutor = app.state.executor
    return await executor.read_artifact(req)


@app.post("/artifacts/list", response_model=ArtifactListResponsePayload)
async def artifacts_list(
    req: ArtifactListRequestPayload,
    authorization: Optional[str] = Header(default=None),
) -> ArtifactListResponsePayload:
    cfg: ToolExecutorServerConfig = app.state.cfg
    _check_auth(cfg, authorization)

    executor: ToolExecutor = app.state.executor
    return await executor.list_artifacts(req)


@app.post("/artifacts/archive", response_model=ArtifactArchiveResponsePayload)
async def artifacts_archive(
    req: ArtifactArchiveRequestPayload,
    authorization: Optional[str] = Header(default=None),
) -> ArtifactArchiveResponsePayload:
    cfg: ToolExecutorServerConfig = app.state.cfg
    _check_auth(cfg, authorization)

    executor: ToolExecutor = app.state.executor
    return await executor.archive_artifacts(req)
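A client of the `/execute` endpoint above sends a bearer token plus a JSON body with `trajectory_id`, `tool`, and `timeout_s`. The sketch below only builds the request; the inner tool fields (`name`, `arguments`, `uniq_id`) are assumptions inferred from how the sibling ToolServer reads `req.tool`, not a confirmed schema.

```python
import json

# Hypothetical values; `name`/`arguments`/`uniq_id` are assumed field names.
token = "dev-secret"
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
body = json.dumps(
    {
        "trajectory_id": "traj-123",
        "tool": {"name": "bash", "arguments": {"command": "ls"}, "uniq_id": "call-1"},
        "timeout_s": 30.0,
    }
)
print(headers["Authorization"])  # Bearer dev-secret
```

An actual client would POST `body` with these headers to the executor's `/execute` URL.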
atropos/api/tool_server.py (normal file, 140 lines)
@@ -0,0 +1,140 @@
"""
External ToolServer (Phase 4.5+).

This server executes tools that must NOT run inside the sandbox, typically
because they require credentials or access to external services.

Run (dev):
    uv run uvicorn atropos_agent.api.tool_server:app --host 0.0.0.0 --port 9002
"""

from __future__ import annotations

import asyncio
import os
import inspect
from typing import Any, Dict, List, Optional
from pathlib import Path

from fastapi import FastAPI, Header, HTTPException, status
from pydantic import BaseModel, Field

from ..tools import ToolRegistry, build_tool_registry
from ..tools.base import ToolResultPayload, ToolServerExecuteRequest


class ToolServerConfig(BaseModel):
    token: Optional[str] = Field(
        default=None,
        description="Bearer token required for requests (optional in dev).",
    )
    max_concurrency: int = Field(default=16, ge=1, description="Max concurrent tool executions.")

    @classmethod
    def from_env(cls) -> "ToolServerConfig":
        # In dev, prefer loading secrets from the repo-local `.env` (not committed).
        try:
            from dotenv import load_dotenv  # type: ignore
        except Exception:  # pragma: no cover
            load_dotenv = None  # type: ignore[assignment]
        if load_dotenv is not None:
            env_path = Path(__file__).resolve().parents[2] / ".env"
            if env_path.exists():
                load_dotenv(dotenv_path=env_path)

        token = os.getenv("TOOL_SERVER_TOKEN") or None
        max_concurrency = int(os.getenv("TOOL_SERVER_MAX_CONCURRENCY", "16"))
        return cls(token=token, max_concurrency=max_concurrency)


app = FastAPI(title="Atropos-Agent Tool Server")


@app.get("/")
async def root() -> Dict[str, str]:
    return {"message": "Atropos-Agent Tool Server"}


@app.on_event("startup")
async def _startup() -> None:
    cfg = ToolServerConfig.from_env()

    # External-only registry. It will only include tools that are enabled by toolsets and
    # whose Hermes requirements/keys are satisfied in this process.
    tools: ToolRegistry = build_tool_registry(
        enabled_toolsets=["all"],
        disabled_toolsets=["terminal", "sandbox", "filesystem", "terminal_stateful", "default"],
        tool_server_url="enabled",
    )

    app.state.cfg = cfg
    app.state.tools = tools
    app.state.semaphore = asyncio.Semaphore(cfg.max_concurrency)


@app.get("/health")
async def health() -> Dict[str, Any]:
    return {"status": "ok"}


@app.get("/tools")
async def list_tools() -> Dict[str, Any]:
    tools: ToolRegistry = app.state.tools
    return {"tools": [s.to_dict() for s in tools.get_schemas()]}


def _check_auth(cfg: ToolServerConfig, authorization: Optional[str]) -> None:
    if not cfg.token:
        return
    if not authorization:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header")
    if not authorization.lower().startswith("bearer "):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Authorization header")
    token = authorization.split(" ", 1)[1].strip()
    if token != cfg.token:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")


@app.post("/execute", response_model=ToolResultPayload)
async def execute_tool(
    req: ToolServerExecuteRequest,
    authorization: Optional[str] = Header(default=None),
) -> ToolResultPayload:
    cfg: ToolServerConfig = app.state.cfg
    _check_auth(cfg, authorization)

    tools: ToolRegistry = app.state.tools
    sem: asyncio.Semaphore = app.state.semaphore

    tool = tools.get(req.tool.name)
    if tool is None:
        return ToolResultPayload(
            success=False,
            error=f"Unknown tool: {req.tool.name}",
            uniq_id=req.tool.uniq_id,
        )

    async with sem:
        try:
            kwargs = dict(req.tool.arguments)
            sig = inspect.signature(tool.execute).parameters
            # Some tools can benefit from extra context.
            if req.trajectory_id and "trajectory_id" in sig:
                kwargs["trajectory_id"] = req.trajectory_id
            if req.slot_id and "slot_id" in sig:
                kwargs["slot_id"] = req.slot_id
            if req.container_addr and "container_addr" in sig:
                kwargs["container_addr"] = req.container_addr
            if "task_id" in sig:
                kwargs["task_id"] = req.trajectory_id
            result = await tool.execute(**kwargs)
        except Exception as e:
            return ToolResultPayload(
                success=False,
                error=f"Tool execution error: {e}",
                uniq_id=req.tool.uniq_id,
            )

    if result.uniq_id is None:
        result.uniq_id = req.tool.uniq_id
    return ToolResultPayload.from_tool_result(result)
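Both servers share the same `_check_auth` logic: auth is skipped when no token is configured, a missing or non-`Bearer` header yields 401, and a wrong token yields 403. A standalone sketch of that decision table (returning status codes instead of raising `HTTPException`):

```python
from typing import Optional


def check_bearer(expected: Optional[str], authorization: Optional[str]) -> int:
    """Mirror _check_auth's checks, returning an HTTP-style status code."""
    if not expected:
        return 200  # auth disabled (dev mode)
    if not authorization:
        return 401  # missing Authorization header
    if not authorization.lower().startswith("bearer "):
        return 401  # wrong scheme
    token = authorization.split(" ", 1)[1].strip()
    return 200 if token == expected else 403


print(check_bearer("s3cret", None))             # 401
print(check_bearer("s3cret", "Basic s3cret"))   # 401
print(check_bearer("s3cret", "Bearer nope"))    # 403
print(check_bearer("s3cret", "Bearer s3cret"))  # 200
```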
atropos/backends/__init__.py (normal file, 27 lines)
@@ -0,0 +1,27 @@
from __future__ import annotations

from typing import Any

from .base import ToolBackend
from .modal_backend import ModalSandboxConfig, ModalToolBackend
from .nomad_backend import NomadBackendConfig, NomadToolBackend


def create_tool_backend(cfg: Any) -> ToolBackend:
    mode = str(getattr(cfg, "tool_pool_mode", "nomad")).strip().lower()
    if mode == "nomad":
        return NomadToolBackend(NomadBackendConfig.from_agent_env_config(cfg))
    if mode == "modal":
        return ModalToolBackend(ModalSandboxConfig.from_agent_env_config(cfg))
    raise ValueError(f"Unknown tool_pool_mode: {mode}")


__all__ = [
    "ToolBackend",
    "create_tool_backend",
    "NomadBackendConfig",
    "NomadToolBackend",
    "ModalSandboxConfig",
    "ModalToolBackend",
]
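`create_tool_backend` dispatches on a normalized `tool_pool_mode` attribute and fails loudly on anything else. A minimal stand-in for just the dispatch step (the real function goes on to build `NomadToolBackend`/`ModalToolBackend`):

```python
from types import SimpleNamespace
from typing import Any


def select_backend_mode(cfg: Any) -> str:
    """Simplified analogue of create_tool_backend's mode selection."""
    mode = str(getattr(cfg, "tool_pool_mode", "nomad")).strip().lower()
    if mode not in ("nomad", "modal"):
        raise ValueError(f"Unknown tool_pool_mode: {mode}")
    return mode


print(select_backend_mode(SimpleNamespace()))                          # nomad (default)
print(select_backend_mode(SimpleNamespace(tool_pool_mode=" Modal ")))  # modal
```

Normalizing with `strip().lower()` means CLI/env values like `" Modal "` still resolve correctly.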
atropos/backends/base.py (normal file, 89 lines)
@@ -0,0 +1,89 @@
"""
Backend interfaces for AgentEnv tool execution.

The goal of this module is to decouple ToolExecutor / AgentEnv from any single
execution backend (Nomad/Docker today; Modal later).
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Protocol, Tuple

from ..slots.executor import ExecutionResult
from ..slots.slot import Slot


class ToolBackend(Protocol):
    """
    Minimal interface required by ToolExecutor.

    Backends provide:
      - lifecycle (start/stop)
      - slot acquisition/release (workspace affinity)
      - batched tool execution across slots
      - optional artifact helpers (for env verification / demos)
    """

    @property
    def default_timeout_s(self) -> Optional[float]:
        """Default sandbox execution timeout in seconds (if any)."""

    async def start(self) -> None:
        """Start the backend (provision workers/containers, health checks, etc)."""

    async def stop(self, *, purge: bool = False) -> None:
        """Stop the backend and optionally purge remote resources."""

    async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
        """Acquire a slot for a trajectory (workspace affinity)."""

    async def release(self, slot: Slot, *, reset_workspace: bool = False) -> None:
        """Release a slot back to the pool."""

    async def execute_batch(
        self,
        requests: List[Tuple[Slot, str, Dict[str, Any]]],
        *,
        timeout_s: Optional[float] = None,
    ) -> List[ExecutionResult]:
        """Execute a batch of sandbox tool calls and return results in order."""

    # ---------------------------------------------------------------------
    # Optional artifact helpers (supported by the Nomad sandbox-server today)
    # ---------------------------------------------------------------------

    async def read_artifact(
        self,
        slot: Slot,
        path: str,
        *,
        encoding: str = "text",
        max_bytes: Optional[int] = None,
        include_sha256: bool = False,
        timeout_s: Optional[float] = None,
    ) -> Dict[str, Any]:
        raise NotImplementedError

    async def list_artifacts(
        self,
        slot: Slot,
        path: str = ".",
        *,
        recursive: bool = False,
        max_entries: Optional[int] = None,
        timeout_s: Optional[float] = None,
    ) -> Dict[str, Any]:
        raise NotImplementedError

    async def archive_artifacts(
        self,
        slot: Slot,
        path: str = ".",
        *,
        archive_format: str = "tar.gz",
        max_bytes: Optional[int] = None,
        max_entries: Optional[int] = None,
        timeout_s: Optional[float] = None,
    ) -> Dict[str, Any]:
        raise NotImplementedError
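Because `ToolBackend` is a `typing.Protocol`, backends match it structurally: any class with the right attributes satisfies the interface without inheriting from it. A tiny self-contained sketch of that mechanism (names here are illustrative, not from the repo):

```python
from typing import Optional, Protocol


class Lifecycle(Protocol):
    """Miniature analogue of ToolBackend; matched by shape, not inheritance."""

    default_timeout_s: Optional[float]

    def start(self) -> None: ...


class FakeBackend:  # note: does NOT inherit from Lifecycle
    default_timeout_s = 30.0

    def __init__(self) -> None:
        self.started = False

    def start(self) -> None:
        self.started = True


def boot(backend: Lifecycle) -> Optional[float]:
    # Type checkers accept FakeBackend here via structural typing.
    backend.start()
    return backend.default_timeout_s


print(boot(FakeBackend()))  # 30.0
```

The same property lets `NomadToolBackend` (and later `ModalToolBackend`) plug into `ToolExecutor` without a shared base class.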
atropos/backends/modal_backend.py (normal file, 1179 lines)
File diff suppressed because it is too large
atropos/backends/nomad_backend.py (normal file, 156 lines)
@@ -0,0 +1,156 @@
"""
Nomad/Docker tool backend.

This backend is the current default for AgentEnv: it provisions a Nomad job
running `sandbox_server.py` and multiplexes stateless slots inside each container.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from ..slots import Slot, SlotPool, SlotPoolConfig
from ..slots.executor import ExecutionResult
from .base import ToolBackend


@dataclass(frozen=True)
class NomadBackendConfig:
    nomad_address: str
    sandbox_job_id: str
    sandbox_image: str
    slots_per_container: int
    min_containers: int
    max_containers: int
    privileged: bool
    acquire_timeout_s: float
    purge_job_on_start: bool
    # Driver selection: "docker" or "singularity"
    driver: str = "docker"
    # Path to .sif file for singularity driver (required if driver="singularity")
    singularity_image: Optional[str] = None

    @classmethod
    def from_agent_env_config(cls, cfg: Any) -> "NomadBackendConfig":
        return cls(
            nomad_address=str(getattr(cfg, "nomad_address")),
            sandbox_job_id=str(getattr(cfg, "sandbox_job_id")),
            sandbox_image=str(getattr(cfg, "sandbox_image")),
            slots_per_container=int(getattr(cfg, "slots_per_container")),
            min_containers=int(getattr(cfg, "min_containers")),
            max_containers=int(getattr(cfg, "max_containers")),
            privileged=bool(getattr(cfg, "privileged")),
            acquire_timeout_s=float(getattr(cfg, "acquire_timeout_s")),
            purge_job_on_start=bool(getattr(cfg, "purge_job_on_start", False)),
            driver=str(getattr(cfg, "driver", "docker")),
            singularity_image=getattr(cfg, "singularity_image", None),
        )


class NomadToolBackend(ToolBackend):
    def __init__(self, config: NomadBackendConfig):
        self.config = config
        self.pool = SlotPool(
            SlotPoolConfig(
                nomad_address=config.nomad_address,
                job_id=config.sandbox_job_id,
                image=config.sandbox_image,
                slots_per_container=config.slots_per_container,
                min_containers=config.min_containers,
                max_containers=config.max_containers,
                privileged=config.privileged,
                acquire_timeout=config.acquire_timeout_s,
                purge_job_on_start=bool(config.purge_job_on_start),
                driver=config.driver,
                singularity_image=config.singularity_image,
            )
        )

    @property
    def default_timeout_s(self) -> Optional[float]:
        t = getattr(self.pool.executor, "timeout", None)
        total = getattr(t, "total", None)
        try:
            return float(total) if total is not None else None
        except Exception:
            return None

    async def start(self) -> None:
        await self.pool.start()

    async def stop(self, *, purge: bool = False) -> None:
        await self.pool.stop(purge_job=purge)

    async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
        return await self.pool.acquire(trajectory_id)

    async def release(self, slot: Slot, *, reset_workspace: bool = False) -> None:
        await self.pool.release(slot, reset_workspace=reset_workspace)

    async def execute_batch(
        self,
        requests: List[Tuple[Slot, str, Dict[str, Any]]],
        *,
        timeout_s: Optional[float] = None,
    ) -> List[ExecutionResult]:
        return await self.pool.execute_batch(requests, timeout=timeout_s)

    async def read_artifact(
        self,
        slot: Slot,
        path: str,
        *,
        encoding: str = "text",
        max_bytes: Optional[int] = None,
        include_sha256: bool = False,
        timeout_s: Optional[float] = None,
    ) -> Dict[str, Any]:
        return await self.pool.executor.read_artifact(
            slot,
            path,
            encoding=encoding,
            max_bytes=max_bytes,
            include_sha256=include_sha256,
            timeout=timeout_s,
        )

    async def list_artifacts(
        self,
        slot: Slot,
        path: str = ".",
        *,
        recursive: bool = False,
        max_entries: Optional[int] = None,
        timeout_s: Optional[float] = None,
    ) -> Dict[str, Any]:
        return await self.pool.executor.list_artifacts(
            slot,
            path,
            recursive=recursive,
            max_entries=max_entries,
            timeout=timeout_s,
        )

    async def archive_artifacts(
        self,
        slot: Slot,
        path: str = ".",
        *,
        archive_format: str = "tar.gz",
        max_bytes: Optional[int] = None,
        max_entries: Optional[int] = None,
        timeout_s: Optional[float] = None,
    ) -> Dict[str, Any]:
        return await self.pool.executor.archive_artifacts(
            slot,
            path,
            archive_format=archive_format,
            max_bytes=max_bytes,
            max_entries=max_entries,
            timeout=timeout_s,
        )

    def get_stats(self) -> Dict[str, Any]:
        return self.pool.get_stats()
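`from_agent_env_config` adapts any config object via `getattr`, coercing required fields (which raise `AttributeError` if absent) and defaulting optional ones like `driver`/`singularity_image`. A reduced, self-contained sketch of that adapter pattern (the `MiniConfig` name is invented for illustration):

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Optional


@dataclass(frozen=True)
class MiniConfig:
    driver: str = "docker"
    singularity_image: Optional[str] = None

    @classmethod
    def from_env_config(cls, cfg: Any) -> "MiniConfig":
        # Same pattern as from_agent_env_config: coerce types, tolerate
        # missing *optional* attributes with getattr defaults.
        return cls(
            driver=str(getattr(cfg, "driver", "docker")),
            singularity_image=getattr(cfg, "singularity_image", None),
        )


print(MiniConfig.from_env_config(SimpleNamespace()))
print(MiniConfig.from_env_config(SimpleNamespace(driver="singularity", singularity_image="/img/sandbox.sif")))
```

This keeps the backend decoupled from any specific config class: a `SimpleNamespace`, pydantic model, or dataclass all work.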
atropos/envs/__init__.py (normal file, 18 lines)
@@ -0,0 +1,18 @@
"""
Environment implementations for atropos-agent.

NOTE: AgentEnv is the OLD environment system, replaced by
environments/hermes_base_env.py (HermesAgentBaseEnv).
Import is lazy to avoid pulling in deleted dependencies.
"""


def __getattr__(name):
    """Lazy import to avoid breaking when old dependencies are removed."""
    if name in ("AgentEnv", "AgentEnvConfig"):
        from .agent_env import AgentEnv, AgentEnvConfig

        return {"AgentEnv": AgentEnv, "AgentEnvConfig": AgentEnvConfig}[name]
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


__all__ = ["AgentEnv", "AgentEnvConfig"]
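The module-level `__getattr__` above is the PEP 562 lazy-import hook: the submodule is only imported when someone actually accesses `AgentEnv` or `AgentEnvConfig`. The mechanism can be demonstrated on a synthetic module (the `demo_envs` name and returned value are invented for the demo):

```python
import types

# Build a throwaway module so the demo runs standalone.
mod = types.ModuleType("demo_envs")


def _lazy(name):
    # In the real module this branch performs the deferred import.
    if name == "AgentEnv":
        return "loaded-lazily"
    raise AttributeError(f"module 'demo_envs' has no attribute {name!r}")


mod.__getattr__ = _lazy  # PEP 562: consulted when normal lookup fails

print(mod.AgentEnv)  # loaded-lazily
```

Attribute access falls through to `__getattr__` only on a miss, so already-set module attributes are unaffected.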
atropos/envs/agent_env.py (normal file, 537 lines)
@@ -0,0 +1,537 @@
"""
AgentEnv - Atropos BaseEnv extension for agent/tool-call workloads.

AgentEnv is responsible for starting the sandbox tool execution backend and
providing helpers for running agent trajectories with queued/batched tool calls.
"""

from __future__ import annotations

import os
import asyncio
import time
import uuid
from abc import ABC, abstractmethod
from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Tuple, TypeVar

from pydantic import Field

from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, Item, ScoredDataGroup, ScoredDataItem
from atroposlib.envs.server_handling.server_baseline import AsyncSemWithAdaptiveWeight

from ..agent import AgentConfig, AgentResult, AtroposAgent
from ..backends import ToolBackend, create_tool_backend
from ..tools import ToolRegistry, build_tool_registry
from ..tools.tool_executor import ToolExecutor, ToolExecutorConfig

# Main BaseEnv child classes. Subclass THESE to get agent + tooling functionality easily.


class AgentEnvConfig(BaseEnvConfig):
    tool_pool_mode: str = Field(default="nomad", description="Tool execution backend ('nomad' or 'modal')")

    allow_network: bool = Field(
        default=True,
        description="Whether sandbox bash commands may access the network (env policy).",
    )
    require_sandbox: bool = Field(
        default=False,
        description="Fail closed if bubblewrap sandboxing is unavailable/unusable for stateless sandbox tools.",
    )
    require_stateful_sandbox: bool = Field(
        default=False,
        description="Fail closed if bubblewrap/PID isolation is unavailable for stateful terminal tools (tmux).",
    )
    tool_batch_window_ms: int = Field(default=20, description="ToolExecutor batching window (ms)")
    tool_max_batch_size: int = Field(default=200, description="ToolExecutor maximum batch size")

    # Nomad mode settings. TODO: add Modal support, split this into its own config.
    nomad_address: str = Field(default="http://localhost:4646", description="Nomad API address")
    sandbox_job_id: str = Field(default="atropos-sandbox-agent-env", description="Nomad job id for sandbox containers")
    sandbox_image: str = Field(default="atropos-sandbox:local", description="Docker image for sandbox containers")
    slots_per_container: int = Field(default=10, description="Nomad mode: slots per container")
    min_containers: int = Field(default=1, description="Nomad mode: minimum containers")
    max_containers: int = Field(default=10, description="Nomad mode: maximum containers")
    privileged: bool = Field(default=False, description="Nomad mode: run container privileged")
    acquire_timeout_s: float = Field(default=30.0, description="Slot acquisition timeout (seconds)")
    purge_job_on_start: bool = Field(
        default=False,
        description=(
            "Nomad mode: stop/purge the sandbox job on startup. This is helpful in local dev and training runs "
            "to recover from previous crashes that leave the job in a restart backoff state."
        ),
    )
    purge_job_on_shutdown: bool = Field(default=True, description="Nomad mode: stop/purge job on shutdown")

    # Nomad driver selection (docker or singularity)
    driver: str = Field(
        default="docker",
        description="Nomad task driver: 'docker' (default) or 'singularity' (for HPC without sudo Docker)",
    )
    singularity_image: Optional[str] = Field(
        default=None,
        description="Path to .sif file for Singularity driver (required if driver='singularity')",
    )

    # Modal mode settings
    modal_app_name: str = Field(default="atropos-sandbox", description="Modal app name prefix")
    modal_image: str = Field(default="python:3.11", description="Modal: container image")
    modal_gpu: Optional[str] = Field(default=None, description="Modal: GPU type (None, 'T4', 'A10G', 'A100', 'H100')")
    modal_cpu: float = Field(default=1.0, description="Modal: CPU cores")
    modal_memory: int = Field(default=2048, description="Modal: memory in MB")
    modal_slots_per_sandbox: int = Field(default=10, description="Modal: slots per sandbox")
    modal_min_sandboxes: int = Field(default=1, description="Modal: minimum sandboxes")
    modal_max_sandboxes: int = Field(default=5, description="Modal: maximum sandboxes")
    modal_idle_timeout: int = Field(default=120, description="Modal: server-side idle timeout (seconds)")
    modal_max_lifetime: int = Field(default=3600, description="Modal: max sandbox lifetime (seconds)")
    modal_acquire_timeout: float = Field(default=60.0, description="Modal: slot acquisition timeout (seconds)")
    modal_execution_timeout: float = Field(default=30.0, description="Modal: default command execution timeout (seconds)")
    modal_secrets: str = Field(default="", description="Modal: comma-separated list of Modal Secret names")
    modal_env_vars: str = Field(default="", description="Modal: semicolon-separated KEY=VALUE pairs for env vars")
    modal_workspace_base: str = Field(default="/data", description="Modal: workspace base directory in sandbox")

    # Basic agent defaults
    agent_max_steps: int = Field(default=50, description="Max ReACT steps per trajectory")
    agent_temperature: float = Field(default=0.7, description="Sampling temperature")
    agent_max_tokens: Optional[int] = Field(
        default=None,
        description="Max tokens per model response (default: let backend decide)",
    )
    agent_tool_delay_s: float = Field(default=0.0, description="Delay between tool calls (seconds)")

    # Tool selection
    enabled_toolsets: List[str] = Field(
        default_factory=lambda: ["default"],
        description="Toolsets to enable (Hermes-style grouping).",
    )
    disabled_toolsets: List[str] = Field(
        default_factory=list,
        description="Toolsets to disable (applied after enabled_toolsets).",
    )

    # External ToolServer routing (Phase 4.5+)
    tool_server_url: Optional[str] = Field(
        default=None,
        description="Base URL for external ToolServer (enables external tools).",
    )
    tool_server_token: Optional[str] = Field(
        default=None,
        description="Bearer token for ToolServer auth (optional in dev).",
    )


AgentEnvConfigT = TypeVar("AgentEnvConfigT", bound="AgentEnvConfig")
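Several of the Modal settings are packed strings, e.g. `modal_env_vars` is documented as semicolon-separated `KEY=VALUE` pairs. A hedged, illustrative parser for that format (`parse_env_vars` is not a repo function; the actual parsing may live elsewhere):

```python
from typing import Dict


def parse_env_vars(spec: str) -> Dict[str, str]:
    """Parse the semicolon-separated KEY=VALUE format described by modal_env_vars."""
    out: Dict[str, str] = {}
    for pair in spec.split(";"):
        pair = pair.strip()
        if not pair:
            continue  # tolerate trailing/double semicolons
        key, sep, value = pair.partition("=")
        if sep:  # skip malformed entries with no '='
            out[key.strip()] = value.strip()
    return out


print(parse_env_vars("HTTP_PROXY=http://proxy:3128; DEBUG=1"))
```

`modal_secrets` (comma-separated names) could be split the same way with `","`.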
class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
    env_config_cls = AgentEnvConfig

    def __init__(
        self,
        config: AgentEnvConfigT,
        server_configs: List[APIServerConfig],
        slurm: bool = False,
        testing: bool = False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.config: AgentEnvConfigT = config

        self.tools: ToolRegistry = self.build_tools()

        self._backend: Optional[ToolBackend] = None
        self._tool_executor: Optional[ToolExecutor] = None
        self._tool_server_inprocess: bool = False
        self._trajectory_workspace_meta: Dict[str, Dict[str, Any]] = {}

    def build_tools(self) -> ToolRegistry:
        """Wraps the original Hermes-Agent ToolRegistry for Atropos AgentEnv use.

        See the Hermes-Agent docs for toolsets and available tools.
        """
        return build_tool_registry(
            enabled_toolsets=self.config.enabled_toolsets or ["default"],
            disabled_toolsets=self.config.disabled_toolsets or None,
            tool_server_url=self.config.tool_server_url,
        )

    @abstractmethod
    def build_task(self, item: Item) -> str:
        """Return the user-facing task string for the agent."""

    @abstractmethod
    async def score_trajectory(self, item: Item, final_response: str) -> float:
        """Return a scalar score for this trajectory."""

    async def setup_trajectory_workspace(
        self,
        item: Item,
        *,
        trajectory_id: str,
        exec_tool: Callable[["ToolCall"], Awaitable["ToolResult"]],
    ) -> Dict[str, Any]:
        """
        Optional hook: prepare the sandbox workspace before the agent starts.

        Examples:
          - clone a repo and checkout a commit
          - write fixture files (e.g. images) for external-tool demos
          - pre-install dependencies

        Default: no-op.
        """
        _ = (item, trajectory_id, exec_tool)
        return {}

    async def verify_and_score_trajectory(
        self,
        item: Item,
        final_response: str,
        *,
        trajectory_id: str,
        exec_tool: Callable[["ToolCall"], Awaitable["ToolResult"]],
        agent_result: Optional[AgentResult] = None,
        workspace_meta: Optional[Dict[str, Any]] = None,
    ) -> tuple[float, Dict[str, Any]]:
        """
        Optional hook: run in-sandbox verification before scoring.

        Many agent envs need to execute verification inside the same trajectory
        workspace (e.g. pytest) before releasing/resetting the slot.

        Default: calls `score_trajectory()` and returns empty metadata.
        """
        _ = (trajectory_id, exec_tool, agent_result, workspace_meta)  # default ignores in-workspace verification
        score = await self.score_trajectory(item, final_response)
        return score, {}

    def build_agent_config(self, item: Item) -> AgentConfig:  # noqa: ARG002
        return AgentConfig(
            max_steps=self.config.agent_max_steps,
            temperature=self.config.agent_temperature,
            max_tokens=self.config.agent_max_tokens,
            tool_delay_s=self.config.agent_tool_delay_s,
        )

    async def setup(self) -> None:
        print(f"[AgentEnv] setup(): starting tool backend ({self.config.tool_pool_mode})", flush=True)
        await self._start_tool_backend()
        print("[AgentEnv] setup(): configuring server concurrency", flush=True)
        self._configure_server_concurrency()
        print("[AgentEnv] setup(): running env-specific setup_agent_env()", flush=True)
        await self.setup_agent_env()
        print("[AgentEnv] setup(): done", flush=True)

    def _configure_server_concurrency(self) -> None:
        """
        Ensure the LLM server concurrency isn't accidentally capped below `group_size`.

        In `BaseEnv` process mode, groups are collected concurrently, and if the underlying
        ServerManager/OpenAIServer semaphore is left at 1, we serialize inference even
        when `--env.group_size` is > 1.
        """
        desired = int(getattr(self.config, "group_size", 1) or 1)
        if desired <= 1:
            return

        servers = getattr(self.server, "servers", None)
        if not isinstance(servers, list) or not servers:
            return

        for s in servers:
            sem = getattr(s, "sem", None)
            eval_sem = getattr(s, "eval_sem", None)
            # Only increase; never shrink.
            if sem is not None and getattr(sem, "max_val", 0) < desired:
                s.sem = AsyncSemWithAdaptiveWeight(desired)
                if hasattr(s, "config") and hasattr(s.config, "num_max_requests_at_once"):
                    s.config.num_max_requests_at_once = desired
            if eval_sem is not None and getattr(eval_sem, "max_val", 0) < desired:
                s.eval_sem = AsyncSemWithAdaptiveWeight(desired)
                if hasattr(s, "config") and hasattr(s.config, "num_requests_for_eval"):
                    s.config.num_requests_for_eval = desired
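The reason `_configure_server_concurrency` widens the semaphore is visible in a small experiment: with a limit of 1, concurrently launched requests are forced through one at a time, while a limit equal to the group size lets the whole group run in flight together. A self-contained sketch (plain `asyncio.Semaphore` stands in for `AsyncSemWithAdaptiveWeight`):

```python
import asyncio


async def measure_max_inflight(limit: int, tasks: int = 4) -> int:
    """Launch `tasks` workers behind a semaphore and report peak concurrency."""
    sem = asyncio.Semaphore(limit)
    inflight = 0
    peak = 0

    async def worker() -> None:
        nonlocal inflight, peak
        async with sem:
            inflight += 1
            peak = max(peak, inflight)
            await asyncio.sleep(0.01)  # stand-in for an LLM request
            inflight -= 1

    await asyncio.gather(*(worker() for _ in range(tasks)))
    return peak


print(asyncio.run(measure_max_inflight(1)))  # 1  (serialized)
print(asyncio.run(measure_max_inflight(4)))  # 4  (full group concurrency)
```

This is exactly the failure mode described in the docstring: a semaphore left at 1 silently serializes a `group_size > 1` collection.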
    @abstractmethod
    async def setup_agent_env(self) -> None:
        """Subclass hook for env-specific setup."""

    async def evaluate(self, *args, **kwargs):  # noqa: ARG002
        """
        Default eval hook (no-op).

        Atropos BaseEnv requires an `evaluate()` implementation. Many agent envs
        won't have a meaningful evaluation path during early PoC work; they can
        override this when needed.
        """
        return {}

    async def env_manager(self):
        try:
            return await super().env_manager()
        finally:
            await self.shutdown_tool_backend()

    async def process_manager(self):
        try:
            return await super().process_manager()
        finally:
            await self.shutdown_tool_backend()

    async def _start_tool_backend(self) -> None:
        if self._tool_executor is not None:
            return

        tool_server_url = self.config.tool_server_url
        tool_server_client = None
        if tool_server_url == "inprocess":
            import httpx

            from ..api.tool_server import app as tool_server_app

            await tool_server_app.router.startup()
            tool_server_client = httpx.AsyncClient(
                transport=httpx.ASGITransport(app=tool_server_app),
                base_url="http://toolserver",
            )
            tool_server_url = "http://toolserver"
            self._tool_server_inprocess = True

        backend = create_tool_backend(self.config)
        await backend.start()

        executor = ToolExecutor(
            backend=backend,
            tools=self.tools,
            config=ToolExecutorConfig(
                batch_window_ms=self.config.tool_batch_window_ms,
                max_batch_size=self.config.tool_max_batch_size,
                allow_network=self.config.allow_network,
                require_sandbox=self.config.require_sandbox,
                require_stateful_sandbox=self.config.require_stateful_sandbox,
                tool_server_url=tool_server_url,
                tool_server_token=self.config.tool_server_token,
            ),
        )
        await executor.start()
        if tool_server_client is not None:
            executor._tool_server_client = tool_server_client  # type: ignore[attr-defined]

        self._backend = backend
        self._tool_executor = executor

    async def shutdown_tool_backend(self) -> None:
        executor = self._tool_executor
        backend = self._backend
        inprocess_tool_server = self._tool_server_inprocess
        self._tool_executor = None
        self._backend = None
        self._tool_server_inprocess = False

        if executor is not None:
            await executor.close()
        if backend is not None:
            await backend.stop(purge=bool(self.config.purge_job_on_shutdown))
        if inprocess_tool_server:
            from ..api.tool_server import app as tool_server_app

            await tool_server_app.router.shutdown()

    async def collect_trajectory(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataItem], List[Item]]:
|
||||
if self._tool_executor is None:
|
||||
raise RuntimeError("Tool backend not started")
|
||||
|
||||
trajectory_id = str(uuid.uuid4())
|
||||
t0 = time.perf_counter()
|
||||
print(f"[AgentEnv] collect_trajectory(): tid={trajectory_id} start", flush=True)
|
||||
task = self.build_task(item)
|
||||
agent_config = self.build_agent_config(item)
|
||||
if os.getenv("ATROPOS_DEBUG_PRINT_TASK") == "1":
|
||||
print(f"Starting trajectory {trajectory_id} with task: {task}", flush=True)
|
||||
else:
|
||||
# Avoid printing the full task prompt by default (can be huge/noisy).
|
||||
one_line = " ".join(str(task).splitlines()).strip()
|
||||
preview = one_line[:240] + ("…" if len(one_line) > 240 else "")
|
||||
print(f"Starting trajectory {trajectory_id} (task preview): {preview}", flush=True)
|
||||
|
||||
async def _exec(call):
|
||||
return await self._tool_executor.execute(trajectory_id, call)
|
||||
|
||||
agent = AtroposAgent(
|
||||
server=self.server,
|
||||
tokenizer=self.tokenizer,
|
||||
tools=self.tools,
|
||||
config=agent_config,
|
||||
execute_tool=_exec,
|
||||
)
|
||||
|
||||
try:
|
||||
print(f"[AgentEnv] tid={trajectory_id} setup_trajectory_workspace() start", flush=True)
|
||||
workspace_meta = await self.setup_trajectory_workspace(item, trajectory_id=trajectory_id, exec_tool=_exec)
|
||||
if not isinstance(workspace_meta, dict):
|
||||
workspace_meta = {}
|
||||
self._trajectory_workspace_meta[trajectory_id] = workspace_meta
|
||||
print(
|
||||
f"[AgentEnv] tid={trajectory_id} setup_trajectory_workspace() done in {time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
print(f"[AgentEnv] tid={trajectory_id} agent.run() start", flush=True)
|
||||
result = await agent.run(task)
|
||||
print(
|
||||
f"[AgentEnv] tid={trajectory_id} agent.run() done in {time.perf_counter() - t0:.2f}s "
|
||||
f"success={result.success} tool_calls={result.total_tool_calls}",
|
||||
flush=True,
|
||||
)
|
||||
if not result.success or result.trajectory_data is None:
|
||||
# Do not trigger BaseEnv retries for agent failures.
|
||||
# Record the trajectory with score 0.0 so training/eval can see the failure mode.
|
||||
messages = [{"role": "system", "content": agent._build_system_prompt()}] # noqa: SLF001
|
||||
messages.append({"role": "user", "content": task})
|
||||
for step in result.steps:
|
||||
messages.append({"role": "assistant", "content": step.assistant_message})
|
||||
if step.tool_results:
|
||||
tool_text = "\n".join(r.to_xml() for r in step.tool_results)
|
||||
messages.append({"role": "user", "content": tool_text})
|
||||
|
||||
scored: ScoredDataItem = {
|
||||
"tokens": (result.trajectory_data.tokens if result.trajectory_data else []),
|
||||
"masks": (result.trajectory_data.masked_tokens if result.trajectory_data else []),
|
||||
"scores": 0.0,
|
||||
}
|
||||
if result.trajectory_data is not None:
|
||||
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
||||
if getattr(result.trajectory_data, "metadata", None):
|
||||
scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata}
|
||||
if self.config.include_messages:
|
||||
# Record a final failure marker as a user-side tool_response-like block so it survives templates.
|
||||
import json
|
||||
|
||||
err = result.error or "agent_failed"
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"<tool_response>{json.dumps({'success': False, 'error': err})}</tool_response>",
|
||||
}
|
||||
)
|
||||
scored["messages"] = messages
|
||||
return scored, []
|
||||
|
||||
print(f"[AgentEnv] tid={trajectory_id} verify_and_score_trajectory() start", flush=True)
|
||||
score, score_metadata = await self.verify_and_score_trajectory(
|
||||
item,
|
||||
result.final_response,
|
||||
trajectory_id=trajectory_id,
|
||||
exec_tool=_exec,
|
||||
agent_result=result,
|
||||
workspace_meta=workspace_meta,
|
||||
)
|
||||
print(
|
||||
f"[AgentEnv] tid={trajectory_id} verify_and_score_trajectory() done in {time.perf_counter() - t0:.2f}s "
|
||||
f"score={score}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": agent._build_system_prompt()}] # noqa: SLF001
|
||||
messages.append({"role": "user", "content": task})
|
||||
for step in result.steps:
|
||||
messages.append({"role": "assistant", "content": step.assistant_message})
|
||||
if step.tool_results:
|
||||
tool_text = "\n".join(r.to_xml() for r in step.tool_results)
|
||||
messages.append({"role": "user", "content": tool_text})
|
||||
|
||||
# Optional: allow env verification to attach additional messages (e.g. install logs).
|
||||
if self.config.include_messages and isinstance(score_metadata, dict):
|
||||
extra = score_metadata.get("verification_messages")
|
||||
if isinstance(extra, list):
|
||||
for m in extra:
|
||||
if isinstance(m, dict) and isinstance(m.get("role"), str) and isinstance(m.get("content"), str):
|
||||
messages.append({"role": m["role"], "content": m["content"]})
|
||||
|
||||
scored: ScoredDataItem = {
|
||||
"tokens": result.trajectory_data.tokens,
|
||||
"masks": result.trajectory_data.masked_tokens,
|
||||
"scores": score,
|
||||
}
|
||||
# Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`.
|
||||
# We stash per-item values here and lift them into the group in `collect_trajectories()`.
|
||||
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
||||
if getattr(result.trajectory_data, "metadata", None):
|
||||
scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata}
|
||||
if self.config.include_messages:
|
||||
scored["messages"] = messages
|
||||
|
||||
return scored, []
|
||||
finally:
|
||||
self._trajectory_workspace_meta.pop(trajectory_id, None)
|
||||
print(f"[AgentEnv] tid={trajectory_id} release_trajectory(reset_workspace=True)", flush=True)
|
||||
await self._tool_executor.release_trajectory(trajectory_id, reset_workspace=True)
|
||||
print(f"[AgentEnv] collect_trajectory(): tid={trajectory_id} done in {time.perf_counter() - t0:.2f}s", flush=True)
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
|
||||
tasks = [self.collect_trajectory(item) for _ in range(self.config.group_size)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
backlog: List[Item] = []
|
||||
items: List[ScoredDataItem] = []
|
||||
for scored, b in results:
|
||||
backlog.extend(b)
|
||||
if scored is not None:
|
||||
items.append(scored)
|
||||
|
||||
if len(items) != self.config.group_size:
|
||||
return None, backlog
|
||||
|
||||
group: ScoredDataGroup = ScoredDataGroup(
|
||||
tokens=[],
|
||||
masks=[],
|
||||
scores=[],
|
||||
advantages=[],
|
||||
ref_logprobs=[],
|
||||
messages=[] if self.config.include_messages else None,
|
||||
inference_logprobs=[],
|
||||
group_overrides={},
|
||||
overrides=[],
|
||||
images=[],
|
||||
generation_params=None,
|
||||
)
|
||||
|
||||
for it in items:
|
||||
group["tokens"].append(it["tokens"])
|
||||
group["masks"].append(it["masks"])
|
||||
group["scores"].append(it["scores"])
|
||||
# policy logprobs (for PPO/GRPO training) if present
|
||||
lp = it.get("inference_logprobs") # type: ignore[typeddict-item]
|
||||
if lp is not None:
|
||||
group["inference_logprobs"].append(lp)
|
||||
group["overrides"].append(it.get("overrides") or {}) # type: ignore[typeddict-item]
|
||||
if group.get("messages") is not None and it.get("messages") is not None:
|
||||
group["messages"].append(it["messages"])
|
||||
|
||||
return group, backlog
|
||||
|
||||
async def run_agent(self, task: str, *, trajectory_id: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
Run the AtroposAgent on a single task and return (final_response, debug).
|
||||
|
||||
This is a helper intended for simple environments and tests.
|
||||
"""
|
||||
if self._tool_executor is None:
|
||||
raise RuntimeError("Tool backend not started")
|
||||
|
||||
tid = trajectory_id or str(uuid.uuid4())
|
||||
|
||||
async def _exec(call):
|
||||
return await self._tool_executor.execute(tid, call)
|
||||
|
||||
agent = AtroposAgent(
|
||||
server=self.server,
|
||||
tokenizer=self.tokenizer,
|
||||
tools=self.tools,
|
||||
config=AgentConfig(
|
||||
max_steps=self.config.agent_max_steps,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.agent_max_tokens,
|
||||
),
|
||||
execute_tool=_exec,
|
||||
)
|
||||
result = await agent.run(task)
|
||||
await self._tool_executor.release_trajectory(tid, reset_workspace=True)
|
||||
return result.final_response, {"success": result.success, "error": result.error, "tool_calls": result.total_tool_calls}
|
||||
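The widen-only semaphore adjustment at the top of this file can be sketched in isolation. This is a hypothetical, dependency-free distillation: `FakeServer` and `widen_concurrency` are illustrative names, and plain `asyncio.Semaphore` stands in for atroposlib's `AsyncSemWithAdaptiveWeight`.

```python
import asyncio


class FakeServer:
    """Hypothetical stand-in for an OpenAIServer holding a request semaphore."""

    def __init__(self, capacity: int = 1):
        self.max_val = capacity
        self.sem = asyncio.Semaphore(capacity)


def widen_concurrency(servers, desired: int) -> None:
    """Raise each server's semaphore to `desired` permits; never shrink."""
    if desired <= 1:
        return
    for s in servers:
        if s.max_val < desired:
            # Replace the semaphore rather than releasing extra permits,
            # mirroring how the env hook swaps in a wider semaphore object.
            s.sem = asyncio.Semaphore(desired)
            s.max_val = desired


servers = [FakeServer(1), FakeServer(4)]
widen_concurrency(servers, 2)
print([s.max_val for s in servers])  # → [2, 4]: the already-wider server is untouched
```

The key design choice mirrored here is that widening is monotonic: a server configured for more concurrency than the group size keeps its larger limit.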
atropos/envs/hermes_compat_test_env.py (new file, 171 lines)
@@ -0,0 +1,171 @@
"""
|
||||
Hermes-Agent + Atropos (Nomad sandbox) compatibility smoke environment.
|
||||
|
||||
This environment is intended to validate, end-to-end:
|
||||
BaseEnv.process -> AgentEnv -> ToolExecutor (batched) -> Nomad SlotPool -> sandbox_server
|
||||
|
||||
It forces the model to use a sandbox tool by asking it to run a command that
|
||||
generates a high-entropy token inside the sandbox, then repeat it exactly.
|
||||
|
||||
Run (process mode):
|
||||
uv run python -m atropos.envs.hermes_compat_test_env process --env.use_wandb false --env.total_steps 2 --env.group_size 1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig, AgentResult
|
||||
from ..tools import ToolCall
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _forced_tool_item() -> Item:
|
||||
# Use double quotes in the shell command and show JSON escaping explicitly.
|
||||
# This avoids invalid JSON escapes like `\\'` (not valid JSON) that some models produce.
|
||||
cmd = 'python -c "import secrets; print(secrets.token_hex(16))"'
|
||||
return {
|
||||
"command": cmd,
|
||||
"prompt": (
|
||||
"You are acting as an agent inside a sandboxed environment.\n"
|
||||
"You MUST use the terminal tool to execute commands.\n"
|
||||
"Run this exact command:\n"
|
||||
f"{cmd}\n"
|
||||
"When you call the tool, use valid JSON inside <tool_call>. Example:\n"
|
||||
'<tool_call>{"name": "terminal", "arguments": {"command": '
|
||||
'"python -c \\\\"import secrets; print(secrets.token_hex(16))\\\\""}}'
|
||||
"</tool_call>\n"
|
||||
"Then respond with EXACTLY what it printed (the hex token) and nothing else.\n"
|
||||
"Do not guess. Do not explain."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class HermesCompatTestEnvConfig(AgentEnvConfig):
|
||||
server_base_url: str = Field(
|
||||
default="http://127.0.0.1:8080",
|
||||
description="Base URL for an OpenAI-compatible chat server (without /v1).",
|
||||
)
|
||||
server_model: str = Field(default="hermes-4-36b", description="Model name")
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class HermesCompatTestEnv(AgentEnv[HermesCompatTestEnvConfig]):
|
||||
name = "hermes_compat_test_env"
|
||||
env_config_cls = HermesCompatTestEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: HermesCompatTestEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._iter = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[HermesCompatTestEnvConfig, List[APIServerConfig]]:
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = HermesCompatTestEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
include_messages=True,
|
||||
ensure_scores_are_not_same=False,
|
||||
total_steps=2,
|
||||
batch_size=1,
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
# Tooling: sandbox-only terminal.
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=[],
|
||||
# Default to Nomad sandboxing; users can override via --env.* args.
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
# In local dev it's common for a previous crash to leave the job in backoff.
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=f"{base_url.rstrip('/')}/v1",
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=120,
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
self._iter += 1
|
||||
return _forced_tool_item()
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
return str(item.get("prompt") or "")
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
# Avoid imposing max_tokens by default; tool-tag responses can be long for some models.
|
||||
return AgentConfig(
|
||||
max_steps=min(8, int(self.config.agent_max_steps)),
|
||||
temperature=0.2,
|
||||
max_tokens=None,
|
||||
)
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
# Scoring happens in verify_and_score_trajectory so we can inspect tool results.
|
||||
_ = (item, final_response)
|
||||
return 0.0
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str, # noqa: ARG002
|
||||
exec_tool, # noqa: ARG002
|
||||
agent_result: AgentResult | None = None,
|
||||
workspace_meta: Dict[str, Any] | None = None, # noqa: ARG002
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
if agent_result is None:
|
||||
return 0.0, {"error": "Missing agent_result"}
|
||||
|
||||
observed: str = ""
|
||||
tool_ok = False
|
||||
for step in agent_result.steps:
|
||||
for res in step.tool_results:
|
||||
if not res.success:
|
||||
return 0.0, {"error": res.error, "output": res.output}
|
||||
out = (res.output or "").strip()
|
||||
if out:
|
||||
observed = out.splitlines()[-1].strip()
|
||||
tool_ok = True
|
||||
|
||||
final = (final_response or "").strip()
|
||||
score = 1.0 if tool_ok and agent_result.total_tool_calls > 0 and observed and final == observed else 0.0
|
||||
return score, {"observed": observed, "tool_calls": agent_result.total_tool_calls, "command": item.get("command")}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
HermesCompatTestEnv.cli()
|
||||
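The exact-echo scoring rule used by `verify_and_score_trajectory` above can be distilled into a standalone function. This is a hypothetical sketch (`score_echo` is not a name from the repo): it takes raw tool outputs instead of `AgentResult` steps, and awards 1.0 only when the final reply exactly equals the last line of the last non-empty tool output.

```python
def score_echo(tool_outputs: list, final_response: str) -> float:
    """Return 1.0 iff a tool produced output and the final reply repeats
    the last line of the most recent non-empty output exactly."""
    observed = ""
    for out in tool_outputs:
        out = (out or "").strip()
        if out:
            # Keep only the last line, mirroring the env's use of splitlines()[-1].
            observed = out.splitlines()[-1].strip()
    final = (final_response or "").strip()
    return 1.0 if tool_outputs and observed and final == observed else 0.0


print(score_echo(["abc123\n"], "abc123"))        # → 1.0 (exact echo)
print(score_echo(["abc123\n"], "Token: abc123"))  # → 0.0 (extra prose fails)
print(score_echo([], "abc123"))                   # → 0.0 (no tool use)
```

Binary exact-match rewards like this make the smoke test unambiguous: any guessing, paraphrasing, or skipped tool call scores zero.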
atropos/envs/sandbox_terminal_smoke_env.py (new file, 172 lines)
@@ -0,0 +1,172 @@
"""
|
||||
Nomad sandbox terminal smoke environment (training-oriented).
|
||||
|
||||
Validates, end-to-end:
|
||||
BaseEnv.process -> AgentEnv -> ToolExecutor (batched) -> Nomad SlotPool -> sandbox_server
|
||||
|
||||
It forces the model to use a sandbox tool by asking it to run a command that
|
||||
generates a high-entropy token inside the sandbox, then repeat it exactly.
|
||||
|
||||
Run (process mode):
|
||||
uv run python -m atropos.envs.sandbox_terminal_smoke_env process --env.use_wandb false --env.total_steps 2 --env.group_size 1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig, AgentResult
|
||||
from ..tools import ToolCall
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
STRICT_TOOLCALL_SYSTEM_PROMPT = None
|
||||
|
||||
|
||||
def _forced_tool_item() -> Item:
|
||||
# Use double quotes in the shell command and show JSON escaping explicitly.
|
||||
# This avoids invalid JSON escapes like `\\'` (not valid JSON) that some models produce.
|
||||
cmd = 'python -c "import secrets; print(secrets.token_hex(16))"'
|
||||
return {
|
||||
"command": cmd,
|
||||
"prompt": (
|
||||
"You MUST use the terminal tool.\n"
|
||||
"Run this exact command:\n"
|
||||
f"{cmd}\n"
|
||||
"When you call the tool, use valid JSON inside <tool_call>. Example:\n"
|
||||
'<tool_call>{"name": "terminal", "arguments": {"command": '
|
||||
'"python -c \\\\"import secrets; print(secrets.token_hex(16))\\\\""}}'
|
||||
"</tool_call>\n"
|
||||
"Then respond with EXACTLY what it printed (the hex token) and nothing else.\n"
|
||||
"Do not guess. Do not explain."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class SandboxTerminalSmokeEnvConfig(AgentEnvConfig):
|
||||
server_base_url: str = Field(
|
||||
default="http://127.0.0.1:8080",
|
||||
description="Base URL for an OpenAI-compatible chat server (without /v1).",
|
||||
)
|
||||
server_model: str = Field(default="hermes-4-36b", description="Model name")
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class SandboxTerminalSmokeEnv(AgentEnv[SandboxTerminalSmokeEnvConfig]):
|
||||
name = "sandbox_terminal_smoke_env"
|
||||
env_config_cls = SandboxTerminalSmokeEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SandboxTerminalSmokeEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._iter = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[SandboxTerminalSmokeEnvConfig, List[APIServerConfig]]:
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = SandboxTerminalSmokeEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
include_messages=True,
|
||||
ensure_scores_are_not_same=False,
|
||||
total_steps=2,
|
||||
batch_size=1,
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
# Tooling: sandbox-only terminal.
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=[],
|
||||
# Default to Nomad sandboxing; users can override via --env.* args.
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=f"{base_url.rstrip('/')}/v1",
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=120,
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
self._iter += 1
|
||||
return _forced_tool_item()
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
return str(item.get("prompt") or "")
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
# Avoid imposing max_tokens by default; tool-tag responses can be long for some models.
|
||||
return AgentConfig(
|
||||
max_steps=min(8, int(self.config.agent_max_steps)),
|
||||
temperature=0.2,
|
||||
max_tokens=None,
|
||||
system_prompt=STRICT_TOOLCALL_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
# Scoring happens in verify_and_score_trajectory so we can inspect tool results.
|
||||
_ = (item, final_response)
|
||||
return 0.0
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str, # noqa: ARG002
|
||||
exec_tool, # noqa: ARG002
|
||||
agent_result: AgentResult | None = None,
|
||||
workspace_meta: Dict[str, Any] | None = None, # noqa: ARG002
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
if agent_result is None:
|
||||
return 0.0, {"error": "Missing agent_result"}
|
||||
|
||||
observed: str = ""
|
||||
tool_ok = False
|
||||
for step in agent_result.steps:
|
||||
for res in step.tool_results:
|
||||
if not res.success:
|
||||
return 0.0, {"error": res.error, "output": res.output}
|
||||
out = (res.output or "").strip()
|
||||
if out:
|
||||
observed = out.splitlines()[-1].strip()
|
||||
tool_ok = True
|
||||
|
||||
final = (final_response or "").strip()
|
||||
score = 1.0 if tool_ok and agent_result.total_tool_calls > 0 and observed and final == observed else 0.0
|
||||
return score, {"observed": observed, "tool_calls": agent_result.total_tool_calls, "command": item.get("command")}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SandboxTerminalSmokeEnv.cli()
|
||||
atropos/envs/swe_smith_oracle_env.py (new file, 418 lines)
@@ -0,0 +1,418 @@
"""
|
||||
SWE-smith-oracle environment.
|
||||
|
||||
This environment is intentionally minimal:
|
||||
- prepares a sandbox workspace by cloning a public GitHub repo at `base_commit`
|
||||
- runs an AtroposAgent tool loop to apply a fix
|
||||
- verifies by running pytest nodeids from the dataset (reward = pass/fail)
|
||||
- Python only (no multi-language support currently, need to properly bauild & add to dropbox)
|
||||
- TODO: Get the other nonpython sandboxes up and running, then add a config knob to switch between them per row
|
||||
- oh and add to dockerhub
|
||||
|
||||
Dataset: NousResearch/SWE-smith-oracle (train; does NOT use SWE-bench eval set).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig
|
||||
from ..tools import ToolCall
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
|
||||
class SweSmithOracleEnvConfig(AgentEnvConfig):
|
||||
dataset_name: str = Field(default="NousResearch/SWE-smith-oracle")
|
||||
dataset_split: str = Field(default="train")
|
||||
max_items: int = Field(default=0, description="0 = no limit")
|
||||
shuffle: bool = Field(default=True)
|
||||
seed: int = Field(default=0)
|
||||
|
||||
python_only: bool = Field(default=True, description="Filter to Python-evaluable rows")
|
||||
score_include_fail_to_pass: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (default), score tests on PASS_TO_PASS ∪ FAIL_TO_PASS. "
|
||||
"Disable to only run PASS_TO_PASS (faster but weaker signal)."
|
||||
),
|
||||
)
|
||||
|
||||
prompt_mode: str = Field(
|
||||
default="problem_statement",
|
||||
description="Task prompt content: 'problem_statement' (fast) or 'problem_statement+text' (slower, includes dataset 'text').",
|
||||
)
|
||||
|
||||
repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
|
||||
install_timeout_s: float = Field(default=600.0)
|
||||
test_timeout_s: float = Field(default=600.0)
|
||||
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
|
||||
"""
|
||||
SWE-smith-oracle AgentEnv.
|
||||
|
||||
This is designed for benchmarking multiplexed slot execution vs naive container-per-trajectory.
|
||||
"""
|
||||
|
||||
name = "swe_smith_oracle_env"
|
||||
env_config_cls = SweSmithOracleEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SweSmithOracleEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._dataset = None
|
||||
self._indices: List[int] = []
|
||||
self._cursor = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]:
|
||||
# Defaults for running the env via CLI in offline `process` mode.
|
||||
# Override via env vars or `--env.*` flags as needed.
|
||||
base_url_raw = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
base_url = base_url_raw.rstrip("/")
|
||||
if not base_url.endswith("/v1"):
|
||||
base_url = f"{base_url}/v1"
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = SweSmithOracleEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1,
|
||||
batch_size=1,
|
||||
steps_per_eval=1,
|
||||
max_token_length=8192,
|
||||
inference_weight=1.0,
|
||||
wandb_name="swe_smith_oracle",
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=[],
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=int(os.getenv("ATROPOS_SERVER_TIMEOUT_S") or "300"),
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
from datasets import load_dataset
|
||||
|
||||
t0 = time.perf_counter()
|
||||
print(
|
||||
f"[SweSmithOracleEnv] loading dataset {self.config.dataset_name}:{self.config.dataset_split} "
|
||||
f"(python_only={self.config.python_only}, max_items={self.config.max_items or 'all'})",
|
||||
flush=True,
|
||||
)
|
||||
ds = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
|
||||
self._dataset = ds
|
||||
|
||||
indices: List[int] = []
|
||||
for idx in range(len(ds)):
|
||||
row = ds[idx]
|
||||
if self.config.python_only and not self._is_python_row(row):
|
||||
continue
|
||||
indices.append(idx)
|
||||
|
||||
if self.config.shuffle:
|
||||
rnd = random.Random(self.config.seed)
|
||||
rnd.shuffle(indices)
|
||||
|
||||
if self.config.max_items and self.config.max_items > 0:
|
||||
indices = indices[: self.config.max_items]
|
||||
|
||||
self._indices = indices
|
||||
self._cursor = 0
|
||||
|
||||
print(
|
||||
            f"[SweSmithOracleEnv] loaded {len(self._indices)} items from {self.config.dataset_name}:{self.config.dataset_split} "
            f"in {time.perf_counter() - t0:.2f}s",
            flush=True,
        )

    def _is_python_row(self, row: Dict[str, Any]) -> bool:
        nodeids = row.get("PASS_TO_PASS")
        if not isinstance(nodeids, list) or not nodeids:
            return False
        for nid in nodeids:
            if not isinstance(nid, str) or ".py::" not in nid:
                return False
        return True

    async def get_next_item(self) -> Item:
        print(f"[SweSmithOracleEnv] get_next_item() cursor={self._cursor}/{len(self._indices)}", flush=True)
        if not self._dataset or not self._indices:
            raise RuntimeError("Dataset not initialized (did setup() run?)")
        if self._cursor >= len(self._indices):
            self._cursor = 0
        idx = self._indices[self._cursor]
        self._cursor += 1
        return dict(self._dataset[idx])

    def _repo_name(self, item: Item) -> str:
        repo = item.get("repo") or ""
        if isinstance(repo, str) and "/" in repo:
            return repo.split("/")[-1]
        return "repo"

    def build_task(self, item: Item) -> str:
        repo = item.get("repo") or ""
        base_commit = item.get("base_commit") or ""
        problem = str(item.get("problem_statement") or "")
        context = str(item.get("text") or "")

        nodeids = self._tests_for_item(item)
        tests_list = "\n".join(f"- {t}" for t in nodeids)

        repo_dir = self._repo_name(item)

        tests_block = (
            "Run these tests to verify:\n"
            f"{tests_list}\n\n"
            "When done, briefly describe what you changed and confirm tests pass."
        )

        prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower()
        if prompt_mode not in {"problem_statement", "problem_statement+text"}:
            raise ValueError(
                f"Invalid prompt_mode={self.config.prompt_mode!r}. "
                "Expected 'problem_statement' or 'problem_statement+text'."
            )

        context_block = ""
        if prompt_mode == "problem_statement+text" and context:
            # Note: We intentionally do NOT truncate/cap here. This mode is for debugging / richer prompts and can be slow.
            context_block = f"\nAdditional context:\n{context}\n"

        return (
            "You are a senior software engineer. Fix the repository so the specified tests pass.\n\n"
            f"Repository: {repo} (checked out at base_commit={base_commit})\n"
            f"Workspace path: ./{repo_dir}\n\n"
            "Constraints:\n"
            "- You MUST use the terminal tool to inspect, edit, and verify the repository. Do not respond with a patch file.\n"
            f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n"
            "- Use a workspace-local virtualenv (e.g. inside the repo at ./.venv) to avoid cross-run contamination.\n"
            "- Use non-interactive commands only.\n\n"
            "- Terminal commands run under POSIX /bin/sh and each tool call runs in a fresh shell (no persisted env vars).\n"
            "  Avoid bash-only `source`; prefer `. .venv/bin/activate` or `.venv/bin/python ...`.\n\n"
            "Problem statement:\n"
            f"{problem}\n\n"
            f"{context_block}\n"
            f"{tests_block}"
        )

    def build_agent_config(self, item: Item) -> AgentConfig:  # noqa: ARG002
        # SWE tasks are longer than the simple test env.
        return AgentConfig(
            max_steps=self.config.agent_max_steps,
            temperature=self.config.agent_temperature,
            max_tokens=self.config.agent_max_tokens,
            tool_delay_s=self.config.agent_tool_delay_s,
        )

    async def setup_trajectory_workspace(self, item: Item, *, trajectory_id: str, exec_tool) -> Dict[str, Any]:
        t0 = time.perf_counter()
        repo = item.get("repo")
        base_commit = item.get("base_commit")
        instance_id = item.get("instance_id") or item.get("id") or item.get("problem_id")
        if not isinstance(repo, str) or not isinstance(base_commit, str):
            raise RuntimeError("Invalid dataset row: missing repo/base_commit")

        repo_dir = self._repo_name(item)
        clone_url = f"{self.config.repo_base_url.rstrip('/')}/{repo}.git"
        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
            f"repo={repo} base_commit={base_commit} instance_id={instance_id} dir=./{repo_dir}",
            flush=True,
        )

        # Repo setup strategy:
        # - Maintain a shared, per-container bare repo cache under /data/repo_cache
        # - For each trajectory, create an isolated git worktree under the slot workspace
        # This avoids cloning/fetching full repos per trajectory and is crucial for multiplexing.

        def _repo_cache_slug(repo_name: str) -> str:
            return repo_name.replace("/", "__")

        repo_slug = _repo_cache_slug(repo)
        cache_root = "/data/repo_cache"
        bare_repo = f"{cache_root}/{repo_slug}.git"
        lock_file = f"{cache_root}/.locks/{repo_slug}.lock"

        # Use flock to serialize operations that mutate the shared bare repo (fetch/worktree).
        # util-linux (flock) is included in the sandbox image.
        worktree_cmd = (
            "set -e; "
            f"rm -rf {repo_dir}; "
            f"mkdir -p {cache_root}/.locks; "
            f": > {lock_file}; "
            f"flock -x {lock_file} sh -lc '"
            "set -e; "
            "export GIT_TERMINAL_PROMPT=0; "
            "export GIT_LFS_SKIP_SMUDGE=1; "
            f"if [ ! -d \"{bare_repo}\" ]; then "
            f"  git init --bare \"{bare_repo}\"; "
            f"  git -C \"{bare_repo}\" remote add origin \"{clone_url}\"; "
            "fi; "
            f"git -C \"{bare_repo}\" remote set-url origin \"{clone_url}\"; "
            f"git -C \"{bare_repo}\" worktree prune || true; "
            f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
            f"  git -C \"{bare_repo}\" fetch --depth 1 origin \"{base_commit}\" || true; "
            "fi; "
            f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
            f"  git -C \"{bare_repo}\" fetch --prune origin; "
            "fi; "
            f"git --git-dir=\"{bare_repo}\" worktree add --detach \"{repo_dir}\" \"{base_commit}\"; "
            "'"
        )

        print(f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache", flush=True)
        res = await exec_tool(
            ToolCall(
                name="terminal",
                arguments={"command": worktree_cmd, "timeout": self.config.install_timeout_s},
            )
        )
        if not res.success:
            raise RuntimeError(
                "git worktree setup failed "
                f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {res.error}\n{res.output}"
            )

        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): worktree ready in {time.perf_counter() - t0:.2f}s",
            flush=True,
        )
        return {"repo_dir": repo_dir, "base_commit": base_commit}

    def _tests_for_item(self, item: Item) -> List[str]:
        tests: List[str] = []
        if self.config.score_include_fail_to_pass:
            for key in ("PASS_TO_PASS", "FAIL_TO_PASS"):
                nodeids = item.get(key)
                if isinstance(nodeids, list):
                    tests.extend([n for n in nodeids if isinstance(n, str)])
        else:
            nodeids = item.get("PASS_TO_PASS")
            if isinstance(nodeids, list):
                tests.extend([n for n in nodeids if isinstance(n, str)])
        # Stable order for reproducibility.
        return sorted(dict.fromkeys(tests))

    def _chunk_nodeids(self, nodeids: List[str], max_per_chunk: int = 50) -> List[List[str]]:
        chunks: List[List[str]] = []
        for i in range(0, len(nodeids), max_per_chunk):
            chunks.append(nodeids[i : i + max_per_chunk])
        return chunks

    async def verify_and_score_trajectory(
        self,
        item: Item,
        final_response: str,  # noqa: ARG002
        *,
        trajectory_id: str,
        exec_tool,
        agent_result=None,
        workspace_meta: Optional[Dict[str, Any]] = None,
    ) -> tuple[float, Dict[str, Any]]:
        _ = trajectory_id
        repo_dir = self._repo_name(item)

        # Training correctness: do not reward trajectories that never actually used tools.
        if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0:
            print(
                f"[SweSmithOracleEnv] tid={trajectory_id} verify (dataset_tests): no tool calls; score=0.0",
                flush=True,
            )
            return 0.0, {
                "verification_mode": "dataset_tests",
                "error": "No tool calls were made by the agent",
            }

        nodeids = self._tests_for_item(item)
        if not nodeids:
            return 0.0, {"error": "No tests provided"}

        print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (dataset_tests): ensuring venv + deps", flush=True)
        setup_cmd = (
            f"cd {repo_dir} && "
            "python -m venv .venv && "
            ". .venv/bin/activate && "
            "python -m pip install -U pip setuptools wheel && "
            "python -m pip install -e . && "
            "python -m pip install pytest"
        )
        setup_res = await exec_tool(
            ToolCall(name="terminal", arguments={"command": setup_cmd, "timeout": self.config.install_timeout_s})
        )
        verification_messages = [{"role": "user", "content": setup_res.to_xml()}]
        if not setup_res.success:
            return 0.0, {
                "verification_mode": "dataset_tests",
                "phase": "install",
                "error": setup_res.error,
                "output": setup_res.output,
                "verification_messages": verification_messages,
            }

        chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
        for chunk_idx, chunk in enumerate(chunks):
            joined = " ".join(chunk)
            cmd = f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}"
            res = await exec_tool(
                ToolCall(
                    name="terminal",
                    arguments={"command": cmd, "timeout": self.config.test_timeout_s},
                )
            )
            verification_messages.append({"role": "user", "content": res.to_xml()})
            if not res.success:
                return 0.0, {
                    "verification_mode": "dataset_tests",
                    "phase": "pytest",
                    "failed_chunk": chunk_idx,
                    "error": res.error,
                    "output": res.output,
                    "verification_messages": verification_messages,
                }

        return 1.0, {"verification_mode": "dataset_tests", "passed": True, "verification_messages": verification_messages}

    async def score_trajectory(self, item: Item, final_response: str) -> float:
        # Not used; scoring happens in verify_and_score_trajectory.
        _ = (item, final_response)
        return 0.0


if __name__ == "__main__":
    SweSmithOracleEnv.cli()
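The chunked pytest invocation used in verification above can be sketched as standalone helpers. This is a hypothetical, illustrative rewrite (the names `chunk_nodeids` and `pytest_commands` are not part of the repo); it shows why node IDs are split into bounded chunks: the command line stays short, and a failure can be attributed to a chunk index.

```python
from typing import List


def chunk_nodeids(nodeids: List[str], max_per_chunk: int = 50) -> List[List[str]]:
    """Split pytest node IDs into chunks of at most max_per_chunk."""
    return [nodeids[i : i + max_per_chunk] for i in range(0, len(nodeids), max_per_chunk)]


def pytest_commands(repo_dir: str, nodeids: List[str]) -> List[str]:
    """Build one shell command per chunk, mirroring the verify loop above."""
    return [
        f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q " + " ".join(chunk)
        for chunk in chunk_nodeids(nodeids)
    ]
```

With 120 node IDs and the default chunk size of 50 this yields three commands, so a failing run reports `failed_chunk` 0, 1, or 2.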
atropos/envs/test_env.py (new file, 217 lines)
@@ -0,0 +1,217 @@
"""
Simple test environment for validating the atropos-agent setup.

This environment uses a local OpenAI-compatible server for LLM testing to verify:
- BaseEnv extension works correctly
- API communication via OpenAI-compatible endpoint
- Basic trajectory collection

This is a minimal environment for testing, not production use.
"""

import os
from typing import Dict, List, Optional, Tuple

from dotenv import load_dotenv
from pydantic import Field

from atroposlib.envs.base import (
    APIServerConfig,
    Item,
)

from ..agent import AgentConfig
from .agent_env import AgentEnv, AgentEnvConfig

# Load environment variables from .env file
load_dotenv()


# Simple test prompts for validation
TEST_PROMPTS = [
    {
        "prompt": "What is 2 + 2? Answer with just the number.",
        "expected": "4",
    },
    {
        "prompt": "What is the capital of France? Answer with just the city name.",
        "expected": "Paris",
    },
    {
        "prompt": "What color is the sky on a clear day? Answer with just the color.",
        "expected": "Blue",
    },
    {
        "prompt": "How many days are in a week? Answer with just the number.",
        "expected": "7",
    },
    {
        "prompt": "What is 10 * 5? Answer with just the number.",
        "expected": "50",
    },
]

SYSTEM_PROMPT = (
    "You are a helpful assistant. Answer questions concisely and directly. "
    "When asked for a simple answer, provide just that answer without explanation."
)


class SimpleTestEnvConfig(AgentEnvConfig):
    """Configuration for the simple test environment."""

    server_base_url: str = Field(
        default="http://127.0.0.1:8080",
        description="Base URL for an OpenAI-compatible server (without /v1)",
    )
    server_model: str = Field(
        default="hermes-4-36b",
        description="Model name",
    )
    tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")


class SimpleTestEnv(AgentEnv[SimpleTestEnvConfig]):
    """
    A simple test environment to validate the atropos-agent setup.

    Uses a local OpenAI-compatible LLM endpoint with basic question-answering tasks.
    Scoring is based on whether the response contains the expected answer.
    """

    name = "simple_test_env"
    env_config_cls = SimpleTestEnvConfig

    def __init__(
        self,
        config: SimpleTestEnvConfig,
        server_configs: List[APIServerConfig],
        slurm: bool = False,
        testing: bool = False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.iter = 0
        self.test_prompts = TEST_PROMPTS
        self.percent_correct_buffer: List[float] = []

    @classmethod
    def config_init(cls) -> Tuple[SimpleTestEnvConfig, List[APIServerConfig]]:
        """
        Initialize configuration with local server settings from environment variables.
        """
        base_url = (
            os.getenv("ATROPOS_SERVER_BASE_URL")
            or os.getenv("OPENAI_BASE_URL")
            or os.getenv("LLM_BASE_URL")
            or "http://127.0.0.1:8080"
        )
        model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
        api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"

        env_config = SimpleTestEnvConfig(
            tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
            group_size=4,
            use_wandb=False,  # Disable wandb for simple testing
            rollout_server_url="http://localhost:8000",
            total_steps=10,
            batch_size=16,
            steps_per_eval=5,
            max_token_length=2048,
            inference_weight=1.0,
            wandb_name="simple_test",
            server_base_url=base_url,
            server_model=model,
        )

        # OpenAI-compatible servers typically expose chat completions at /v1.
        server_configs = [
            APIServerConfig(
                model_name=model,
                base_url=f"{base_url}/v1",
                api_key=api_key,
                num_max_requests_at_once=4,
                num_requests_for_eval=8,
                timeout=120,  # Local models may be slower
            ),
        ]

        return env_config, server_configs

    async def setup_agent_env(self):
        """Setup the environment - load test data."""
        print(f"SimpleTestEnv setup complete. {len(self.test_prompts)} test prompts loaded.")
        print(f"Using server at: {self.config.server_base_url}")
        print(f"Model: {self.config.server_model}")

    async def get_next_item(self) -> Item:
        """Get the next test prompt."""
        item = self.test_prompts[self.iter % len(self.test_prompts)]
        self.iter += 1
        return item

    def build_task(self, item: Item) -> str:
        return item["prompt"]

    def build_agent_config(self, item: Item) -> AgentConfig:  # noqa: ARG002
        return AgentConfig(
            max_steps=5,
            temperature=0.7,
            max_tokens=256,
            system_prompt=SYSTEM_PROMPT,
        )

    async def score_trajectory(self, item: Item, final_response: str) -> float:
        expected = item["expected"].lower()
        response_lower = (final_response or "").lower()
        score = 1.0 if expected in response_lower else 0.0
        self.percent_correct_buffer.append(score)
        return score

    async def evaluate(self, *args, **kwargs):
        """
        Simple evaluation - run through all test prompts once.
        """
        correct = 0
        total = len(self.test_prompts)

        for item in self.test_prompts:
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": item["prompt"]},
            ]

            response = await self.server.chat_completion(
                messages=messages,
                n=1,
                max_tokens=256,
                temperature=0.0,  # Greedy for eval
                split="eval",
            )

            response_text = response.choices[0].message.content or ""
            expected = item["expected"].lower()

            if expected in response_text.lower():
                correct += 1

        accuracy = correct / total
        print(f"Evaluation: {correct}/{total} = {accuracy:.2%} accuracy")
        return {"eval_accuracy": accuracy}

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log metrics (simplified for testing)."""
        if wandb_metrics is None:
            wandb_metrics = {}

        if self.percent_correct_buffer:
            avg_correct = sum(self.percent_correct_buffer) / len(self.percent_correct_buffer)
            wandb_metrics["train/percent_correct"] = avg_correct
            print(f"Train accuracy: {avg_correct:.2%}")
            self.percent_correct_buffer = []

        await super().wandb_log(wandb_metrics)


if __name__ == "__main__":
    # Allow running as CLI
    SimpleTestEnv.cli()
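The substring-based scoring used by SimpleTestEnv can be reduced to a single pure function. A minimal sketch, with a hypothetical helper name (`score_response` does not appear in the repo): a trajectory earns 1.0 when the expected answer appears, case-insensitively, anywhere in the final response.

```python
def score_response(expected: str, response: "str | None") -> float:
    """Return 1.0 if `expected` occurs (case-insensitively) in `response`, else 0.0."""
    return 1.0 if expected.lower() in (response or "").lower() else 0.0
```

Note this is deliberately lenient: a response like "The answer is 4." scores 1.0 for expected "4", which suits a smoke test more than a real benchmark.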
atropos/envs/toolserver_smoke_env.py (new file, 165 lines)
@@ -0,0 +1,165 @@
"""
ToolServer routing smoke environment.

Validates that:
- sandbox tools run through Nomad SlotPool (terminal -> bash in sandbox)
- external tools run through ToolServer (skills_list)

This env uses ToolServer in-process by default (`tool_server_url="inprocess"`),
so it is self-contained for local testing.

Run:
    uv run python -m atropos.envs.toolserver_smoke_env process --env.use_wandb false --env.total_steps 1 --env.group_size 1
"""

from __future__ import annotations

import os
from typing import Any, Dict, List, Tuple

from dotenv import load_dotenv
from pydantic import Field

from atroposlib.envs.base import APIServerConfig, Item

from ..agent import AgentConfig, AgentResult
from .agent_env import AgentEnv, AgentEnvConfig

load_dotenv()


class ToolServerSmokeEnvConfig(AgentEnvConfig):
    server_base_url: str = Field(
        default="http://127.0.0.1:8080",
        description="Base URL for an OpenAI-compatible chat server (without /v1).",
    )
    server_model: str = Field(default="hermes-4-36b", description="Model name")
    tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")


class ToolServerSmokeEnv(AgentEnv[ToolServerSmokeEnvConfig]):
    name = "toolserver_smoke_env"
    env_config_cls = ToolServerSmokeEnvConfig

    def __init__(
        self,
        config: ToolServerSmokeEnvConfig,
        server_configs: List[APIServerConfig],
        slurm: bool = False,
        testing: bool = False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self._iter = 0

    @classmethod
    def config_init(cls) -> Tuple[ToolServerSmokeEnvConfig, List[APIServerConfig]]:
        base_url = (
            os.getenv("ATROPOS_SERVER_BASE_URL")
            or os.getenv("OPENAI_BASE_URL")
            or os.getenv("LLM_BASE_URL")
            or "http://127.0.0.1:8080"
        )
        model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
        api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"

        env_config = ToolServerSmokeEnvConfig(
            tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
            group_size=1,
            use_wandb=False,
            include_messages=True,
            ensure_scores_are_not_same=False,
            total_steps=1,
            batch_size=1,
            server_base_url=base_url,
            server_model=model,
            enabled_toolsets=["terminal", "skills"],
            disabled_toolsets=[],
            # Self-contained ToolServer for local smoke.
            tool_server_url="inprocess",
            sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
            purge_job_on_start=True,
            purge_job_on_shutdown=True,
        )

        server_configs = [
            APIServerConfig(
                model_name=model,
                base_url=f"{base_url.rstrip('/')}/v1",
                api_key=api_key,
                num_max_requests_at_once=1,
                num_requests_for_eval=1,
                timeout=120,
            )
        ]
        return env_config, server_configs

    async def setup_agent_env(self) -> None:
        return None

    async def get_next_item(self) -> Item:
        self._iter += 1
        return {
            "prompt": (
                "You MUST call exactly one tool per assistant message.\n"
                "\n"
                "Step 1) Call the skills_list tool (no arguments), then stop.\n"
                "Step 2) After you receive the tool response, call the terminal tool to run:\n"
                "python -c \"print('ok')\"\n"
                "Step 3) After you receive the terminal tool response, answer with just: ok\n"
                "\n"
                "Tool call format requirements:\n"
                "- Every tool call MUST be a complete XML block with a closing tag.\n"
                "- Do NOT emit a second <tool_call> in the same assistant message.\n"
                "\n"
                "Example:\n"
                "<tool_call>{\"name\": \"skills_list\", \"arguments\": {}}</tool_call>\n"
                "Do not include anything else in your final answer."
            )
        }

    def build_task(self, item: Item) -> str:
        return str(item.get("prompt") or "")

    def build_agent_config(self, item: Item) -> AgentConfig:  # noqa: ARG002
        return AgentConfig(
            max_steps=min(10, int(self.config.agent_max_steps)),
            temperature=0.2,
            max_tokens=None,
        )

    async def score_trajectory(self, item: Item, final_response: str) -> float:
        _ = (item, final_response)
        return 0.0

    async def verify_and_score_trajectory(
        self,
        item: Item,
        final_response: str,
        *,
        trajectory_id: str,  # noqa: ARG002
        exec_tool,  # noqa: ARG002
        agent_result: AgentResult | None = None,
        workspace_meta: Dict[str, Any] | None = None,  # noqa: ARG002
    ) -> tuple[float, Dict[str, Any]]:
        if agent_result is None:
            return 0.0, {"error": "Missing agent_result"}

        called = {c.name for s in agent_result.steps for c in s.tool_calls}
        need = {"skills_list", "terminal"}
        if not need.issubset(called):
            return 0.0, {"error": f"Missing tool calls: {sorted(need - called)}", "called": sorted(called)}

        terminal_ok = False
        for step in agent_result.steps:
            for call, res in zip(step.tool_calls, step.tool_results):
                if call.name != "terminal":
                    continue
                # Guard against empty output before indexing the last line.
                lines = (res.output or "").strip().splitlines()
                if res.success and lines and lines[-1].strip() == "ok":
                    terminal_ok = True

        score = 1.0 if terminal_ok and (final_response or "").strip() == "ok" else 0.0
        return score, {"called": sorted(called), "final": (final_response or "").strip()}


if __name__ == "__main__":
    ToolServerSmokeEnv.cli()
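The pass/fail rule in `verify_and_score_trajectory` above can be summarized as one pure function. A minimal sketch with a hypothetical name (`smoke_score` is not in the repo): the trajectory scores 1.0 only when both required tools were called, the terminal run printed "ok", and the final answer is exactly "ok".

```python
from typing import Iterable, Set


def smoke_score(called: Iterable[str], terminal_ok: bool, final_response: "str | None") -> float:
    """Mirror of the smoke check: both tools used, terminal printed ok, answer is ok."""
    need: Set[str] = {"skills_list", "terminal"}
    if not need.issubset(set(called)):
        return 0.0
    return 1.0 if terminal_ok and (final_response or "").strip() == "ok" else 0.0
```

Making the check this strict keeps the smoke env deterministic: any routing failure in either tool path drops the score to 0.0.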
atropos/nomad/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
"""
Nomad integration for atropos-agent.

Provides:
- NomadClient: Client for Nomad HTTP API
- Job templates for sandbox containers
"""

from .client import NomadClient

__all__ = ["NomadClient"]
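The NomadClient exported here (defined in atropos/nomad/client.py) locates a sandbox's HTTP endpoint inside an allocation's network block. That lookup can be sketched as a standalone helper; `http_endpoint` is a hypothetical name, and the dict shape follows Nomad's allocation JSON (IP plus DynamicPorts/ReservedPorts entries with Label/Value).

```python
from typing import Any, Dict, Optional, Tuple


def http_endpoint(network: Dict[str, Any]) -> Tuple[Optional[str], Optional[int]]:
    """Return (IP, port) for the port labelled "http" in a Nomad network block.

    Dynamic ports are checked first, then reserved ports (raw_exec/Singularity
    setups typically use reserved ports)."""
    address = network.get("IP")
    for p in (network.get("DynamicPorts") or []) + (network.get("ReservedPorts") or []):
        if p.get("Label") == "http":
            return address, p.get("Value")
    return address, None
```

If no "http" label is present the address is still returned, mirroring the client's behavior of leaving `port=None` on the Allocation.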
atropos/nomad/client.py (new file, 500 lines)
@@ -0,0 +1,500 @@
"""
|
||||
Nomad API Client for atropos-agent.
|
||||
|
||||
Provides a simple async client for interacting with the Nomad HTTP API:
|
||||
- Submit/stop jobs
|
||||
- Query allocations
|
||||
- Get allocation addresses
|
||||
- Scale jobs up/down
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
class AllocationStatus(Enum):
|
||||
"""Nomad allocation status."""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETE = "complete"
|
||||
FAILED = "failed"
|
||||
LOST = "lost"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Allocation:
|
||||
"""Information about a Nomad allocation."""
|
||||
id: str
|
||||
job_id: str
|
||||
task_group: str
|
||||
node_id: str
|
||||
status: AllocationStatus
|
||||
# Network info for reaching the allocation
|
||||
address: Optional[str] = None
|
||||
port: Optional[int] = None
|
||||
|
||||
@property
|
||||
def http_address(self) -> Optional[str]:
|
||||
"""Get full HTTP address for the allocation."""
|
||||
if self.address and self.port:
|
||||
return f"http://{self.address}:{self.port}"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobStatus:
|
||||
"""Status of a Nomad job."""
|
||||
id: str
|
||||
name: str
|
||||
status: str
|
||||
allocations: List[Allocation] = field(default_factory=list)
|
||||
count: int = 0 # Number of task groups
|
||||
|
||||
|
||||
class NomadClient:
|
||||
"""
|
||||
Async client for Nomad HTTP API.
|
||||
|
||||
Usage:
|
||||
client = NomadClient(address="http://localhost:4646")
|
||||
|
||||
# Submit a job
|
||||
await client.submit_job(job_spec)
|
||||
|
||||
# Get allocations
|
||||
allocs = await client.get_job_allocations("sandbox-python")
|
||||
|
||||
# Scale job
|
||||
await client.scale_job("sandbox-python", count=5)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
address: str = "http://localhost:4646",
|
||||
token: Optional[str] = None,
|
||||
timeout: float = 30.0,
|
||||
):
|
||||
self.address = address.rstrip("/")
|
||||
self.token = token or os.environ.get("NOMAD_TOKEN")
|
||||
self.timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create HTTP session."""
|
||||
if self._session is None or self._session.closed:
|
||||
headers = {}
|
||||
if self.token:
|
||||
headers["X-Nomad-Token"] = self.token
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=self.timeout,
|
||||
headers=headers,
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Make an HTTP request to Nomad API."""
|
||||
session = await self._get_session()
|
||||
url = f"{self.address}{path}"
|
||||
|
||||
try:
|
||||
async with session.request(method, url, json=data) as response:
|
||||
if response.status == 404:
|
||||
return {"error": "not_found", "status": 404}
|
||||
|
||||
text = await response.text()
|
||||
if not text:
|
||||
return {"status": response.status}
|
||||
|
||||
try:
|
||||
result = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return {"text": text, "status": response.status}
|
||||
|
||||
if response.status >= 400:
|
||||
return {"error": result, "status": response.status}
|
||||
|
||||
return result if isinstance(result, dict) else {"data": result, "status": response.status}
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
return {"error": str(e), "status": 0}
|
||||
|
||||
# Job Operations
|
||||
|
||||
async def submit_job(self, job_spec: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Submit a job to Nomad.
|
||||
|
||||
Args:
|
||||
job_spec: Job specification dict (HCL converted to JSON)
|
||||
|
||||
Returns:
|
||||
Response with EvalID if successful
|
||||
"""
|
||||
return await self._request("POST", "/v1/jobs", {"Job": job_spec})
|
||||
|
||||
async def stop_job(self, job_id: str, purge: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Stop (and optionally purge) a job.
|
||||
|
||||
Args:
|
||||
job_id: Job identifier
|
||||
purge: If True, completely remove the job
|
||||
"""
|
||||
path = f"/v1/job/{job_id}"
|
||||
if purge:
|
||||
path += "?purge=true"
|
||||
return await self._request("DELETE", path)
|
||||
|
||||
async def get_job(self, job_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get job details."""
|
||||
result = await self._request("GET", f"/v1/job/{job_id}")
|
||||
if "error" in result and result.get("status") == 404:
|
||||
return None
|
||||
return result
|
||||
|
||||
async def get_job_status(self, job_id: str) -> Optional[JobStatus]:
|
||||
"""Get job status with allocations."""
|
||||
job = await self.get_job(job_id)
|
||||
if not job:
|
||||
return None
|
||||
|
||||
allocs = await self.get_job_allocations(job_id)
|
||||
|
||||
# Get count from task groups
|
||||
count = 0
|
||||
task_groups = job.get("TaskGroups", [])
|
||||
for tg in task_groups:
|
||||
count += tg.get("Count", 1)
|
||||
|
||||
return JobStatus(
|
||||
id=job_id,
|
||||
name=job.get("Name", job_id),
|
||||
status=job.get("Status", "unknown"),
|
||||
allocations=allocs,
|
||||
count=count,
|
||||
)
|
||||
|
||||
# Allocation Operations
|
||||
|
||||
async def get_job_allocations(self, job_id: str) -> List[Allocation]:
|
||||
"""Get all allocations for a job."""
|
||||
result = await self._request("GET", f"/v1/job/{job_id}/allocations")
|
||||
|
||||
if "error" in result:
|
||||
return []
|
||||
|
||||
allocs_data = result.get("data", result) if isinstance(result, dict) else result
|
||||
if not isinstance(allocs_data, list):
|
||||
return []
|
||||
|
||||
allocations = []
|
||||
for alloc_data in allocs_data:
|
||||
# Parse allocation info
|
||||
alloc_id = alloc_data.get("ID", "")
|
||||
status_str = alloc_data.get("ClientStatus", "unknown")
|
||||
|
||||
try:
|
||||
status = AllocationStatus(status_str)
|
||||
except ValueError:
|
||||
status = AllocationStatus.PENDING
|
||||
|
||||
# Get network info - need to fetch detailed allocation for this
|
||||
address = None
|
||||
port = None
|
||||
|
||||
# First try the summary data
|
||||
resources = alloc_data.get("AllocatedResources") or {}
|
||||
shared = resources.get("Shared") or {}
|
||||
networks = shared.get("Networks") or []
|
||||
|
||||
# If no networks in summary, fetch detailed allocation
|
||||
if not networks and alloc_id:
|
||||
detailed = await self.get_allocation(alloc_id)
|
||||
if detailed:
|
||||
resources = detailed.get("AllocatedResources") or {}
|
||||
shared = resources.get("Shared") or {}
|
            networks = shared.get("Networks") or []

            if networks:
                network = networks[0]
                address = network.get("IP")
                # Look for dynamic ports OR reserved ports (Singularity/raw_exec uses reserved)
                dyn_ports = network.get("DynamicPorts") or []
                reserved_ports = network.get("ReservedPorts") or []
                for dp in dyn_ports + reserved_ports:
                    if dp.get("Label") == "http":
                        port = dp.get("Value")
                        break

            allocations.append(Allocation(
                id=alloc_id,
                job_id=job_id,
                task_group=alloc_data.get("TaskGroup", ""),
                node_id=alloc_data.get("NodeID", ""),
                status=status,
                address=address,
                port=port,
            ))

        return allocations

    async def get_allocation(self, alloc_id: str) -> Optional[Dict[str, Any]]:
        """Get detailed allocation info."""
        result = await self._request("GET", f"/v1/allocation/{alloc_id}")
        if "error" in result and result.get("status") == 404:
            return None
        return result

    # Scaling Operations

    async def scale_job(self, job_id: str, count: int, task_group: str = "sandbox") -> Dict[str, Any]:
        """
        Scale a job's task group to specified count.

        Args:
            job_id: Job identifier
            count: Desired number of allocations
            task_group: Name of task group to scale
        """
        payload = {
            "Count": count,
            "Target": {
                "Group": task_group,
            },
        }
        return await self._request("POST", f"/v1/job/{job_id}/scale", payload)

    async def get_job_scale_status(self, job_id: str) -> Dict[str, int]:
        """
        Get current scale status for a job.

        Returns:
            Dict mapping task group name to count
        """
        result = await self._request("GET", f"/v1/job/{job_id}/scale")

        if "error" in result:
            return {}

        task_groups = result.get("TaskGroups", {})
        return {
            name: info.get("Running", 0)
            for name, info in task_groups.items()
        }

    # Health Check

    async def is_healthy(self) -> bool:
        """Check if Nomad is reachable and healthy."""
        try:
            result = await self._request("GET", "/v1/status/leader")
            return "error" not in result
        except Exception:
            return False

    async def get_leader(self) -> Optional[str]:
        """Get current Nomad leader address."""
        result = await self._request("GET", "/v1/status/leader")
        if isinstance(result, dict) and "data" in result:
            return result["data"]
        return None


def load_job_template(
    template_name: str = "sandbox",
    **kwargs,
) -> Dict[str, Any]:
    """
    Load and configure a job template.

    Args:
        template_name: Name of template (e.g., "sandbox")
        **kwargs: Template variables to substitute

    Returns:
        Job specification dict ready for Nomad API
    """
    # Default job template for sandbox container
    if template_name == "sandbox":
        return create_sandbox_job(**kwargs)
    else:
        raise ValueError(f"Unknown template: {template_name}")

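The scaling helpers above reduce to one POST payload shape and one GET response shape. A minimal self-contained sketch of the response parsing (the response dict below is a hypothetical example, not captured from a real cluster, and `parse_scale_status` is an illustrative stand-in for `get_job_scale_status` without the HTTP call):

```python
from typing import Any, Dict


def parse_scale_status(response: Dict[str, Any]) -> Dict[str, int]:
    """Map task group name -> running count, mirroring get_job_scale_status."""
    if "error" in response:
        return {}
    return {
        name: info.get("Running", 0)
        for name, info in response.get("TaskGroups", {}).items()
    }


# Hypothetical /v1/job/:job_id/scale response body
response = {"TaskGroups": {"sandbox": {"Running": 3, "Healthy": 3}}}
print(parse_scale_status(response))  # {'sandbox': 3}
```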
def create_sandbox_job(
    job_id: str = "atropos-sandbox",
    image: str = "atropos-sandbox:local",  # Use :local tag to avoid registry pull
    count: int = 1,
    slots_per_container: int = 10,
    privileged: bool = False,
    cpu: int = 500,
    memory: int = 512,
    port: int = 8080,
    datacenter: str = "dc1",
    driver: str = "docker",  # "docker" or "singularity"
    singularity_image: Optional[str] = None,  # Path to .sif file for singularity driver
) -> Dict[str, Any]:
    """
    Create a sandbox job specification.

    This job runs the sandbox_server.py inside a container,
    with the specified number of slots for agent workspaces.

    Args:
        job_id: Unique job identifier
        image: Docker image to use (for docker driver)
        count: Number of container instances
        slots_per_container: Number of slots per container
        privileged: Run container in privileged mode (recommended for bubblewrap)
        cpu: CPU allocation in MHz
        memory: Memory allocation in MB
        port: HTTP port for sandbox server
        datacenter: Nomad datacenter
        driver: Container driver - "docker" or "singularity"
        singularity_image: Path to .sif file (required if driver="singularity")

    Returns:
        Job specification dict
    """
    # Build task config based on driver
    if driver == "singularity":
        if not singularity_image:
            raise ValueError("singularity_image path required when driver='singularity'")

        # Use raw_exec driver to run apptainer via shell for variable expansion.
        # The container binds the allocation directory for workspace persistence.
        # For raw_exec, we use a static port since Nomad's dynamic port mapping doesn't
        # work the same as Docker - the process runs directly on the host.
        shell_cmd = (
            f'apptainer run '
            f'--bind "$NOMAD_ALLOC_DIR/data:/data" '
            f'--pwd /app '
            f'--env PYTHONUNBUFFERED=1 '
            f'{singularity_image} '
            f'python sandbox_server.py '
            f'--port {port} '
            f'--slots {slots_per_container} '
            f'--data-dir /data'
        )
        task_config = {
            "command": "/bin/sh",
            "args": ["-c", shell_cmd],
        }
        task_driver = "raw_exec"
    else:
        # Docker driver (default)
        task_config = {
            "image": image,
            "force_pull": False,  # Use local image, don't try to pull
            "ports": ["http"],
            "privileged": privileged,
            "command": "python",
            "args": [
                "sandbox_server.py",
                "--port", str(port),
                "--slots", str(slots_per_container),
                "--data-dir", "/data",
            ],
            # Note: On Linux, you can mount persistent storage:
            # "volumes": ["${NOMAD_ALLOC_DIR}/data:/data"],
            # On macOS/Docker Desktop, skip volumes for PoC
            # (container /data is ephemeral but works for testing)
        }
        task_driver = "docker"

    # For Singularity/raw_exec, use static ports since the process runs directly on host.
    # For Docker, use dynamic ports with port mapping.
    if driver == "singularity":
        network_config = {
            "Mode": "host",
            "ReservedPorts": [
                {
                    "Label": "http",
                    "Value": port,
                }
            ],
        }
    else:
        network_config = {
            "Mode": "host",
            "DynamicPorts": [
                {
                    "Label": "http",
                    "To": port,
                }
            ],
        }

    return {
        "ID": job_id,
        "Name": job_id,
        "Type": "service",
        "Datacenters": [datacenter],
        "TaskGroups": [
            {
                "Name": "sandbox",
                "Count": count,
                # Speed up deployments and avoid Consul checks. Without this, Nomad may
                # keep an "active deployment" around for the default MinHealthyTime,
                # which blocks immediate scaling under load.
                "Update": {
                    "HealthCheck": "task_states",
                    "MinHealthyTime": 0,
                },
                "Networks": [network_config],
                "Tasks": [
                    {
                        "Name": "sandbox-server",
                        "Driver": task_driver,
                        "Config": task_config,
                        "Env": {
                            "PYTHONUNBUFFERED": "1",
                            "NOMAD_ALLOC_DIR": "${NOMAD_ALLOC_DIR}",
                        },
                        "Resources": {
                            "CPU": cpu,
                            "MemoryMB": memory,
                        },
                        # Note: Services with Checks require Consul, which we skip for the PoC
                    }
                ],
                "RestartPolicy": {
                    "Attempts": 3,
                    "Interval": 300_000_000_000,  # 5 minutes
                    "Delay": 10_000_000_000,  # 10 seconds
                    "Mode": "delay",
                },
                "ReschedulePolicy": {
                    "Attempts": 5,
                    "Interval": 3600_000_000_000,  # 1 hour
                    "Delay": 30_000_000_000,  # 30 seconds
                    "DelayFunction": "exponential",
                    "MaxDelay": 300_000_000_000,  # 5 minutes
                    "Unlimited": False,
                },
            }
        ],
    }
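The driver-dependent network stanza is the part of this job spec that is easiest to get wrong. A standalone sketch of just that branch (the function name `choose_network` is illustrative, not part of the module): raw_exec/Singularity runs directly on the host and must reserve the exact port, while Docker can map a Nomad-chosen dynamic host port onto the container port via `"To"`.

```python
from typing import Any, Dict


def choose_network(driver: str, port: int) -> Dict[str, Any]:
    # raw_exec (Singularity) binds the host port directly -> reserve it.
    if driver == "singularity":
        return {"Mode": "host", "ReservedPorts": [{"Label": "http", "Value": port}]}
    # Docker maps a dynamic host port to the container port -> "To".
    return {"Mode": "host", "DynamicPorts": [{"Label": "http", "To": port}]}


print(choose_network("singularity", 8080)["ReservedPorts"][0]["Value"])  # 8080
print("DynamicPorts" in choose_network("docker", 8080))  # True
```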
1912
atropos/sandbox_server.py
Normal file
File diff suppressed because it is too large
20
atropos/slots/__init__.py
Normal file
@@ -0,0 +1,20 @@
"""
Slot-based multiplexing for atropos-agent.

Provides:
- Slot: Isolated workspace for a single trajectory
- SlotPool: Manages slots across Nomad allocations
- SandboxExecutor: Executes tools in sandbox containers
"""

from .executor import SandboxExecutor
from .pool import SlotPool, SlotPoolConfig
from .slot import Slot, SlotState

__all__ = [
    "Slot",
    "SlotState",
    "SlotPool",
    "SlotPoolConfig",
    "SandboxExecutor",
]
457
atropos/slots/executor.py
Normal file
@@ -0,0 +1,457 @@
"""
SandboxExecutor - HTTP client for sandbox container communication.

Sends tool execution requests to sandbox_server.py running inside Nomad containers.
Supports single and batch execution for efficiency.
"""

import asyncio
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import aiohttp

from .slot import Slot, SlotState
from ..tools.base import ToolCall, ToolResult


@dataclass
class ExecutionRequest:
    """Request to execute a tool in a slot."""
    slot: Slot
    tool_name: str
    args: Dict[str, Any]
    execution_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    timeout: float = 30.0


@dataclass
class ExecutionResult:
    """Result from sandbox execution."""
    success: bool
    output: str = ""
    error: str = ""
    execution_id: str = ""
    slot_id: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_tool_result(self) -> ToolResult:
        """Convert to ToolResult for agent consumption."""
        return ToolResult(
            success=self.success,
            output=self.output,
            error=self.error,
            metadata=self.metadata,
            uniq_id=self.execution_id,
        )

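The `to_tool_result` adapter is a plain field-for-field mapping. A self-contained sketch with simplified stand-in dataclasses (the real `ToolResult` lives in `..tools.base` and may carry more fields; the `*Stub` names here are invented for illustration):

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ToolResultStub:  # simplified stand-in for tools.base.ToolResult
    success: bool
    output: str = ""
    error: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)
    uniq_id: str = ""


@dataclass
class ExecResultStub:  # simplified stand-in for ExecutionResult
    success: bool
    output: str = ""
    error: str = ""
    execution_id: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_tool_result(self) -> ToolResultStub:
        # Note the rename: execution_id becomes uniq_id on the tool side.
        return ToolResultStub(self.success, self.output, self.error,
                              self.metadata, self.execution_id)


r = ExecResultStub(True, output="ok", execution_id="abc").to_tool_result()
print(r.uniq_id)  # abc
```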
class SandboxExecutor:
    """
    HTTP client for executing tools in sandbox containers.

    Communicates with sandbox_server.py running inside Nomad allocations.
    Supports both single execution and batched parallel execution.

    Usage:
        executor = SandboxExecutor()

        # Single execution
        result = await executor.execute(slot, "bash", {"command": "ls"})

        # Batch execution
        results = await executor.execute_batch([
            (slot1, "bash", {"command": "ls"}),
            (slot2, "write_file", {"path": "test.txt", "content": "hello"}),
        ])
    """

    def __init__(
        self,
        timeout: float = 30.0,
        max_retries: int = 3,
        retry_delay: float = 1.0,
    ):
        self.timeout = aiohttp.ClientTimeout(total=timeout)
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self._session: Optional[aiohttp.ClientSession] = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Get or create HTTP session."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(timeout=self.timeout)
        return self._session

    async def close(self):
        """Close HTTP session."""
        if self._session and not self._session.closed:
            await self._session.close()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def execute(
        self,
        slot: Slot,
        tool_name: str,
        args: Dict[str, Any],
        timeout: Optional[float] = None,
    ) -> ExecutionResult:
        """
        Execute a tool in a slot's workspace.

        Args:
            slot: Slot to execute in
            tool_name: Name of tool (bash, read_file, write_file)
            args: Tool arguments
            timeout: Optional timeout override

        Returns:
            ExecutionResult with output or error
        """
        execution_id = str(uuid.uuid4())
        exec_timeout = timeout or self.timeout.total or 30.0

        # Mark slot as executing
        original_state = slot.state
        try:
            if slot.state == SlotState.ACQUIRED:
                slot.start_execution(execution_id)

            result = await self._send_execute_request(
                container_addr=slot.container_addr,
                slot_id=slot.slot_id,
                tool_name=tool_name,
                args=args,
                execution_id=execution_id,
                timeout=exec_timeout,
            )
            result.slot_id = slot.slot_id
            return result

        finally:
            # Restore slot state
            if slot.state == SlotState.EXECUTING:
                slot.end_execution()

    async def _send_execute_request(
        self,
        container_addr: str,
        slot_id: str,
        tool_name: str,
        args: Dict[str, Any],
        execution_id: str,
        timeout: float,
    ) -> ExecutionResult:
        """Send execution request to sandbox server with retry logic."""
        session = await self._get_session()
        url = f"{container_addr}/execute"

        payload = {
            "slot_id": slot_id,
            "tool": tool_name,
            "args": args,
            "execution_id": execution_id,
            "timeout": timeout,
        }

        last_error = None
        for attempt in range(self.max_retries):
            try:
                async with session.post(url, json=payload) as response:
                    data = await response.json()

                    return ExecutionResult(
                        success=data.get("success", False),
                        output=data.get("output", ""),
                        error=data.get("error", ""),
                        execution_id=data.get("execution_id", execution_id),
                        metadata=data.get("metadata", {}),
                    )

            except aiohttp.ClientError as e:
                last_error = str(e)
                if attempt < self.max_retries - 1:
                    await asyncio.sleep(self.retry_delay * (attempt + 1))
                    continue
            except asyncio.TimeoutError:
                last_error = f"Request timed out after {timeout}s"
                break
            except Exception as e:
                last_error = str(e)
                break

        return ExecutionResult(
            success=False,
            error=f"Failed after {self.max_retries} attempts: {last_error}",
            execution_id=execution_id,
        )

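The retry loop in `_send_execute_request` (linear backoff on transient client errors, immediate bail-out on timeouts and unknown exceptions) can be sketched in isolation. `with_retries` and `flaky` below are made-up names; `ConnectionError` stands in for `aiohttp.ClientError` so the sketch needs no network:

```python
import asyncio


async def with_retries(fn, max_retries=3, retry_delay=0.01):
    """Retry fn() on transient errors with linearly increasing delay."""
    last_error = None
    for attempt in range(max_retries):
        try:
            return await fn()
        except ConnectionError as e:  # stands in for aiohttp.ClientError
            last_error = str(e)
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay * (attempt + 1))  # linear backoff
                continue
    return f"failed after {max_retries} attempts: {last_error}"


calls = {"n": 0}


async def flaky():
    # Fails twice, then succeeds - exercises the backoff path.
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("refused")
    return "ok"


print(asyncio.run(with_retries(flaky)))  # ok
```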
    async def execute_batch(
        self,
        requests: List[Tuple[Slot, str, Dict[str, Any]]],
        timeout: Optional[float] = None,
    ) -> List[ExecutionResult]:
        """
        Execute multiple tools in parallel across slots.

        This is the key optimization - we batch tool calls to maximize
        container utilization while agents are waiting for LLM responses.

        Args:
            requests: List of (slot, tool_name, args) tuples
            timeout: Optional timeout override

        Returns:
            List of ExecutionResults in same order as requests
        """
        if not requests:
            return []

        # Group requests by container address for batch API
        by_container: Dict[str, List[Tuple[int, Slot, str, Dict[str, Any], str]]] = {}

        for idx, (slot, tool_name, args) in enumerate(requests):
            execution_id = str(uuid.uuid4())
            container = slot.container_addr

            if container not in by_container:
                by_container[container] = []
            by_container[container].append((idx, slot, tool_name, args, execution_id))

            # Mark slots as executing
            if slot.state == SlotState.ACQUIRED:
                slot.start_execution(execution_id)

        # Execute batches in parallel
        exec_timeout = timeout or self.timeout.total or 30.0
        batch_tasks = []

        for container_addr, batch_requests in by_container.items():
            task = self._send_batch_request(
                container_addr=container_addr,
                batch_requests=batch_requests,
                timeout=exec_timeout,
            )
            batch_tasks.append(task)

        # Gather all batch results
        batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)

        # Collect results in original order
        results: List[Optional[ExecutionResult]] = [None] * len(requests)

        for batch_result in batch_results:
            if isinstance(batch_result, Exception):
                # Mark all in this batch as failed
                continue

            for idx, result in batch_result:
                results[idx] = result

        # Fill in any missing results
        for idx, result in enumerate(results):
            if result is None:
                slot, _, _ = requests[idx]
                results[idx] = ExecutionResult(
                    success=False,
                    error="Batch execution failed",
                    slot_id=slot.slot_id,
                )

        # End execution on all slots
        for slot, _, _ in requests:
            if slot.state == SlotState.EXECUTING:
                slot.end_execution()

        return results  # type: ignore

    async def _send_batch_request(
        self,
        container_addr: str,
        batch_requests: List[Tuple[int, Slot, str, Dict[str, Any], str]],
        timeout: float,
    ) -> List[Tuple[int, ExecutionResult]]:
        """Send batch execution request to a single container."""
        session = await self._get_session()
        url = f"{container_addr}/batch"

        # Build batch payload
        payload = [
            {
                "slot_id": slot.slot_id,
                "tool": tool_name,
                "args": args,
                "execution_id": execution_id,
                "timeout": timeout,
            }
            for _, slot, tool_name, args, execution_id in batch_requests
        ]

        try:
            async with session.post(url, json=payload) as response:
                data = await response.json()

                if not isinstance(data, list):
                    raise ValueError(f"Expected list response, got {type(data)}")

                results = []
                for i, (idx, slot, _, _, execution_id) in enumerate(batch_requests):
                    if i < len(data):
                        item = data[i]
                        result = ExecutionResult(
                            success=item.get("success", False),
                            output=item.get("output", ""),
                            error=item.get("error", ""),
                            execution_id=item.get("execution_id", execution_id),
                            slot_id=slot.slot_id,
                            metadata=item.get("metadata", {}),
                        )
                    else:
                        result = ExecutionResult(
                            success=False,
                            error="Missing result in batch response",
                            execution_id=execution_id,
                            slot_id=slot.slot_id,
                        )
                    results.append((idx, result))

                return results

        except Exception as e:
            # Return error for all requests in batch
            return [
                (idx, ExecutionResult(
                    success=False,
                    error=str(e),
                    execution_id=execution_id,
                    slot_id=slot.slot_id,
                ))
                for idx, slot, _, _, execution_id in batch_requests
            ]

    async def reset_slot(self, slot: Slot) -> ExecutionResult:
        """
        Reset a slot's workspace (delete all files).

        Useful when reusing a slot for a new trajectory.
        """
        session = await self._get_session()
        url = f"{slot.container_addr}/reset"

        try:
            async with session.post(url, json={"slot_id": slot.slot_id}) as response:
                data = await response.json()
                return ExecutionResult(
                    success=data.get("success", False),
                    output=data.get("output", ""),
                    error=data.get("error", ""),
                    slot_id=slot.slot_id,
                )
        except Exception as e:
            return ExecutionResult(
                success=False,
                error=str(e),
                slot_id=slot.slot_id,
            )

    async def health_check(self, container_addr: str) -> bool:
        """Check if a sandbox container is healthy."""
        session = await self._get_session()
        url = f"{container_addr}/health"

        try:
            async with session.get(url) as response:
                data = await response.json()
                return data.get("status") == "ok"
        except Exception:
            return False

    async def get_container_status(
        self,
        container_addr: str
    ) -> Optional[Dict[str, Any]]:
        """Get status info from a sandbox container."""
        session = await self._get_session()
        url = f"{container_addr}/health"

        try:
            async with session.get(url) as response:
                return await response.json()
        except Exception:
            return None

    # -------------------------------------------------------------------------
    # Artifact helpers (optional)
    # -------------------------------------------------------------------------

    async def _post_json(
        self,
        url: str,
        payload: Dict[str, Any],
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        session = await self._get_session()
        try:
            async with session.post(url, json=payload, timeout=timeout) as response:
                data = await response.json()
                if isinstance(data, dict):
                    data.setdefault("http_status", response.status)
                    return data
                return {"success": False, "error": f"Unexpected response type: {type(data)}", "http_status": response.status}
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def read_artifact(
        self,
        slot: Slot,
        path: str,
        *,
        encoding: str = "text",
        max_bytes: Optional[int] = None,
        include_sha256: bool = False,
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        url = f"{slot.container_addr}/artifacts/read"
        payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "encoding": encoding, "include_sha256": include_sha256}
        if max_bytes is not None:
            payload["max_bytes"] = max_bytes
        return await self._post_json(url, payload, timeout=timeout)

    async def list_artifacts(
        self,
        slot: Slot,
        path: str = ".",
        *,
        recursive: bool = False,
        max_entries: Optional[int] = None,
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        url = f"{slot.container_addr}/artifacts/list"
        payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "recursive": recursive}
        if max_entries is not None:
            payload["max_entries"] = max_entries
        return await self._post_json(url, payload, timeout=timeout)

    async def archive_artifacts(
        self,
        slot: Slot,
        path: str = ".",
        *,
        archive_format: str = "tar.gz",
        max_bytes: Optional[int] = None,
        max_entries: Optional[int] = None,
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        url = f"{slot.container_addr}/artifacts/archive"
        payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "format": archive_format}
        if max_bytes is not None:
            payload["max_bytes"] = max_bytes
        if max_entries is not None:
            payload["max_entries"] = max_entries
        return await self._post_json(url, payload, timeout=timeout)
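The batching in `execute_batch` hinges on tagging each request with its original index before grouping by container, so results can be reassembled in input order no matter which container answers first. The pattern in isolation (`fan_out` and the fake `send` coroutine are illustrative, not part of the module):

```python
import asyncio
from collections import defaultdict


async def fan_out(requests):
    """requests: list of (container, payload) -> results in input order."""
    by_container = defaultdict(list)
    for idx, (container, payload) in enumerate(requests):
        by_container[container].append((idx, payload))  # remember original index

    async def send(container, batch):
        await asyncio.sleep(0)  # stands in for one HTTP round-trip per container
        return [(idx, f"{container}:{payload}") for idx, payload in batch]

    # One task per container, all in flight concurrently.
    batches = await asyncio.gather(*(send(c, b) for c, b in by_container.items()))

    results = [None] * len(requests)
    for batch in batches:
        for idx, result in batch:
            results[idx] = result  # reassemble in input order
    return results


out = asyncio.run(fan_out([("a", 1), ("b", 2), ("a", 3)]))
print(out)  # ['a:1', 'b:2', 'a:3']
```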
659
atropos/slots/pool.py
Normal file
@@ -0,0 +1,659 @@
"""
SlotPool - Manages slots across Nomad allocations.

The SlotPool is the core abstraction for slot-based multiplexing:
- Tracks available/acquired slots across containers
- Handles slot acquisition and release
- Auto-scales Nomad job count based on demand
- Provides batched tool execution
"""

import asyncio
import logging
import os
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from ..nomad.client import (
    Allocation,
    AllocationStatus,
    NomadClient,
    create_sandbox_job,
)
from .executor import ExecutionResult, SandboxExecutor
from .slot import Slot, SlotState, create_slots_for_allocation

logger = logging.getLogger(__name__)

@dataclass
class SlotPoolConfig:
    """Configuration for SlotPool."""

    # Nomad settings
    nomad_address: str = "http://localhost:4646"
    job_id: str = "atropos-sandbox"
    datacenter: str = "dc1"

    # Container settings
    image: str = "atropos-sandbox:local"  # Use :local tag to avoid registry pull
    slots_per_container: int = 10
    privileged: bool = False
    cpu: int = 500  # MHz
    memory: int = 512  # MB

    # Driver selection: "docker" or "singularity"
    driver: str = "docker"
    # Path to .sif file for singularity driver (required if driver="singularity")
    singularity_image: Optional[str] = None

    # Scaling settings
    min_containers: int = 1
    max_containers: int = 10

    # Timeouts
    acquire_timeout: float = 30.0  # Seconds between acquire polls (also triggers scale-up attempts)
    health_check_interval: float = 30.0  # Seconds between health checks
    scale_cooldown: float = 60.0  # Seconds between scale operations

    # Job lifecycle
    purge_job_on_start: bool = False  # Purge any pre-existing job before starting (local dev/training friendly)

    # Local Docker image convenience (macOS/Nomad dev mode)
    auto_build_local_image: bool = True  # If image ends with :local and is missing, build it from the bundled Dockerfile.
    dockerfile_path: Optional[str] = None  # Override Dockerfile path (default: Hermes-Agent/atropos/Dockerfile).
    docker_build_context: Optional[str] = None  # Override build context (default: Hermes-Agent/atropos).

class SlotPool:
    """
    Manages a pool of slots across Nomad allocations.

    The SlotPool:
    - Deploys sandbox containers to Nomad
    - Tracks slots across all running containers
    - Handles slot acquisition/release
    - Auto-scales based on demand
    - Provides batched execution via SandboxExecutor

    Usage:
        config = SlotPoolConfig(
            nomad_address="http://localhost:4646",
            job_id="my-sandbox",
            slots_per_container=10,
        )

        pool = SlotPool(config)
        await pool.start()

        # Acquire a slot
        slot = await pool.acquire()

        # Execute tool
        result = await pool.execute(slot, "bash", {"command": "ls"})

        # Release slot
        await pool.release(slot)

        # Shutdown
        await pool.stop()
    """

    def __init__(self, config: Optional[SlotPoolConfig] = None):
        self.config = config or SlotPoolConfig()

        # Nomad client
        self.nomad = NomadClient(address=self.config.nomad_address)

        # Sandbox executor for tool execution
        self.executor = SandboxExecutor()

        # Slot tracking
        self._slots: Dict[str, Slot] = {}  # slot_key -> Slot
        self._available_queue: asyncio.Queue[str] = asyncio.Queue()
        self._lock = asyncio.Lock()
        self._scale_lock = asyncio.Lock()

        # State
        self._started = False
        self._health_task: Optional[asyncio.Task] = None
        self._scale_task: Optional[asyncio.Task] = None
        self._last_scale_time = 0.0

    def _default_dockerfile_path(self) -> Path:
        # Hermes-Agent/atropos/Dockerfile lives next to this module in source checkouts.
        return Path(__file__).resolve().parents[1] / "Dockerfile"

    def _default_build_context(self) -> Path:
        return Path(__file__).resolve().parents[1]

    def _docker_image_exists(self, image: str) -> bool:
        try:
            proc = subprocess.run(
                ["docker", "image", "inspect", image],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=False,
                env={**os.environ, "DOCKER_CLI_HINTS": "false"},
            )
            return proc.returncode == 0
        except FileNotFoundError:
            return False

    def _try_build_local_image(self, image: str) -> None:
        dockerfile = Path(self.config.dockerfile_path) if self.config.dockerfile_path else self._default_dockerfile_path()
        context = Path(self.config.docker_build_context) if self.config.docker_build_context else self._default_build_context()

        if not dockerfile.exists():
            raise RuntimeError(
                f"Sandbox Dockerfile not found at {dockerfile}. "
                "Build the sandbox image manually or set --env.purge_job_on_start false and provide a non-local image."
            )
        if not context.exists():
            raise RuntimeError(f"Docker build context not found at {context}")

        # Prefer buildx+--load to ensure the image ends up in the local daemon (required by Nomad's docker driver).
        buildx_cmd = [
            "docker",
            "buildx",
            "build",
            "--load",
            "-t",
            image,
            "-f",
            str(dockerfile),
            str(context),
        ]
        proc = subprocess.run(buildx_cmd, check=False, env={**os.environ, "DOCKER_CLI_HINTS": "false"})
        if proc.returncode == 0:
            return

        # Fall back to classic docker build if buildx isn't available.
        build_cmd = ["docker", "build", "-t", image, "-f", str(dockerfile), str(context)]
        proc2 = subprocess.run(build_cmd, check=False, env={**os.environ, "DOCKER_CLI_HINTS": "false"})
        if proc2.returncode != 0:
            raise RuntimeError(
                f"Failed to build local sandbox image {image}. "
                f"Tried: {' '.join(buildx_cmd)} and {' '.join(build_cmd)}"
            )

    def _ensure_local_image(self) -> None:
        image = (self.config.image or "").strip()
        if not image.endswith(":local"):
            return
        if not self.config.auto_build_local_image:
            return

        if self._docker_image_exists(image):
            return

        logger.info(f"Local sandbox image {image} not found; building it now...")
        self._try_build_local_image(image)

    def _slot_key(self, alloc_id: str, slot_id: str) -> str:
        """Generate unique key for a slot."""
        return f"{alloc_id}:{slot_id}"

    @property
    def total_slots(self) -> int:
        """Total number of slots in pool."""
        return len(self._slots)

    @property
    def available_slots(self) -> int:
        """Number of available slots."""
        return sum(1 for s in self._slots.values() if s.is_available)

    @property
    def acquired_slots(self) -> int:
        """Number of acquired slots."""
        return sum(1 for s in self._slots.values() if s.is_acquired)

    async def start(self) -> None:
        """
        Start the slot pool.

        - Checks if Nomad is healthy
        - Deploys sandbox job if not running
        - Discovers existing allocations
        - Starts health check background task
        """
        if self._started:
            return

        logger.info(f"Starting SlotPool (job_id={self.config.job_id})")

        try:
            # Make sure local sandbox images exist before Nomad tries to pull them.
            # This is a common footgun in macOS dev mode with :local tags.
            self._ensure_local_image()

            # Check Nomad health
            if not await self.nomad.is_healthy():
                raise RuntimeError(f"Nomad is not reachable at {self.config.nomad_address}")

            if self.config.purge_job_on_start:
                logger.info(f"Purging any existing Nomad job: {self.config.job_id}")
                await self.nomad.stop_job(self.config.job_id, purge=True)

            # Check if job exists (after optional purge)
            job = await self.nomad.get_job(self.config.job_id)

            if job is None:
                # Deploy new job
                logger.info(f"Deploying sandbox job: {self.config.job_id} (driver={self.config.driver})")
                job_spec = create_sandbox_job(
                    job_id=self.config.job_id,
                    image=self.config.image,
                    count=self.config.min_containers,
                    slots_per_container=self.config.slots_per_container,
                    privileged=self.config.privileged,
                    cpu=self.config.cpu,
                    memory=self.config.memory,
                    datacenter=self.config.datacenter,
                    driver=self.config.driver,
                    singularity_image=self.config.singularity_image,
                )
                result = await self.nomad.submit_job(job_spec)
                if "error" in result:
                    raise RuntimeError(f"Failed to submit job: {result}")

            # Wait for allocations to be running (even if the job already existed).
            await self._wait_for_healthy_allocations(self.config.min_containers)

            # Discover existing allocations and slots
            await self._refresh_slots()

            # Start health check task
            self._health_task = asyncio.create_task(self._health_check_loop())

            self._started = True
            logger.info(f"SlotPool started: {self.total_slots} slots available")
        except Exception:
            # Ensure aiohttp sessions are not leaked if we fail to start.
            await self.stop(purge_job=False)
            raise

    async def stop(self, purge_job: bool = False) -> None:
        """
        Stop the slot pool.

        Args:
            purge_job: If True, also stop the Nomad job
        """
        logger.info("Stopping SlotPool")

        # Cancel health check task
        if self._health_task:
            self._health_task.cancel()
            try:
                await self._health_task
            except asyncio.CancelledError:
                pass
            finally:
                self._health_task = None

        if self._scale_task:
            self._scale_task.cancel()
            try:
                await self._scale_task
            except asyncio.CancelledError:
                pass
            finally:
                self._scale_task = None

        # Optionally stop the job (do this even if start() never completed).
        if purge_job:
            logger.info(f"Stopping Nomad job: {self.config.job_id}")
            await self.nomad.stop_job(self.config.job_id, purge=True)

        # Close connections
        await self.executor.close()
        await self.nomad.close()

        self._started = False
        self._slots.clear()

        # Clear the queue
        while not self._available_queue.empty():
            try:
                self._available_queue.get_nowait()
            except asyncio.QueueEmpty:
                break

    async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
        """
        Acquire an available slot.

        If no slots are available, waits up to acquire_timeout seconds.
        If still no slots, attempts to scale up.

        Args:
            trajectory_id: Optional ID of trajectory acquiring the slot

        Returns:
            Acquired Slot

        Raises:
            asyncio.TimeoutError: If no slot becomes available
        """
        if not self._started:
            raise RuntimeError("SlotPool not started")

        while True:
            try:
                # Try to get an available slot
                slot_key = await asyncio.wait_for(
                    self._available_queue.get(),
                    timeout=self.config.acquire_timeout,
                )
            except asyncio.TimeoutError:
                # Try to scale up, but keep waiting even if scaling isn't possible.
                # In practice, slots may become available shortly (e.g. contention),
                # and scaling may be temporarily blocked by Nomad deployments.
                await self._try_scale_up()
|
||||
continue
|
||||
|
||||
slot = self._slots.get(slot_key)
|
||||
if slot is None:
|
||||
# Slot was removed; discard stale queue entry and retry.
|
||||
continue
|
||||
|
||||
try:
|
||||
slot.acquire(trajectory_id)
|
||||
except RuntimeError:
|
||||
# Slot isn't actually available (e.g. duplicate queue entry); retry.
|
||||
continue
|
||||
|
||||
logger.debug(f"Acquired slot {slot.slot_id} (alloc={slot.alloc_id[:8]})")
|
||||
return slot
|
||||
|
||||
async def release(self, slot: Slot, reset_workspace: bool = False) -> None:
|
||||
"""
|
||||
Release a slot back to the pool.
|
||||
|
||||
Args:
|
||||
slot: Slot to release
|
||||
reset_workspace: If True, clear the workspace files
|
||||
"""
|
||||
slot_key = self._slot_key(slot.alloc_id, slot.slot_id)
|
||||
|
||||
if slot_key not in self._slots:
|
||||
logger.warning(f"Releasing unknown slot: {slot_key}")
|
||||
return
|
||||
|
||||
# Optionally reset workspace
|
||||
if reset_workspace:
|
||||
await self.executor.reset_slot(slot)
|
||||
|
||||
slot.release()
|
||||
await self._available_queue.put(slot_key)
|
||||
|
||||
logger.debug(f"Released slot {slot.slot_id}")
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
slot: Slot,
|
||||
tool_name: str,
|
||||
args: Dict[str, Any],
|
||||
timeout: Optional[float] = None,
|
||||
) -> ExecutionResult:
|
||||
"""
|
||||
Execute a tool in a slot's workspace.
|
||||
|
||||
Args:
|
||||
slot: Slot to execute in
|
||||
tool_name: Name of tool (bash, read_file, write_file)
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
ExecutionResult
|
||||
"""
|
||||
return await self.executor.execute(slot, tool_name, args, timeout)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
requests: List[Tuple[Slot, str, Dict[str, Any]]],
|
||||
timeout: Optional[float] = None,
|
||||
) -> List[ExecutionResult]:
|
||||
"""
|
||||
Execute multiple tools in parallel.
|
||||
|
||||
This is the key optimization - batch execution across multiple slots
|
||||
maximizes container utilization.
|
||||
|
||||
Args:
|
||||
requests: List of (slot, tool_name, args) tuples
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
List of ExecutionResults in same order
|
||||
"""
|
||||
return await self.executor.execute_batch(requests, timeout)
|
||||
|
||||
async def _refresh_slots(self) -> None:
|
||||
"""Refresh slot inventory from Nomad allocations."""
|
||||
async with self._lock:
|
||||
allocs = await self.nomad.get_job_allocations(self.config.job_id)
|
||||
|
||||
# Track which slots we've seen
|
||||
seen_keys = set()
|
||||
|
||||
for alloc in allocs:
|
||||
if alloc.status != AllocationStatus.RUNNING:
|
||||
continue
|
||||
|
||||
if not alloc.http_address:
|
||||
continue
|
||||
|
||||
# Check container health
|
||||
healthy = await self.executor.health_check(alloc.http_address)
|
||||
if not healthy:
|
||||
continue
|
||||
|
||||
# Create slots for this allocation
|
||||
for i in range(self.config.slots_per_container):
|
||||
slot_id = f"slot_{i}"
|
||||
slot_key = self._slot_key(alloc.id, slot_id)
|
||||
seen_keys.add(slot_key)
|
||||
|
||||
if slot_key not in self._slots:
|
||||
# New slot
|
||||
slot = Slot(
|
||||
slot_id=slot_id,
|
||||
alloc_id=alloc.id,
|
||||
container_addr=alloc.http_address,
|
||||
)
|
||||
self._slots[slot_key] = slot
|
||||
await self._available_queue.put(slot_key)
|
||||
logger.debug(f"Added slot: {slot_key}")
|
||||
|
||||
# Remove slots from dead allocations
|
||||
for slot_key in list(self._slots.keys()):
|
||||
if slot_key not in seen_keys:
|
||||
slot = self._slots.pop(slot_key)
|
||||
logger.debug(f"Removed slot: {slot_key}")
|
||||
|
||||
async def _wait_for_healthy_allocations(
|
||||
self,
|
||||
min_count: int,
|
||||
timeout: float = 120.0
|
||||
) -> None:
|
||||
"""Wait for allocations to become healthy."""
|
||||
import time
|
||||
start = time.time()
|
||||
|
||||
def _summarize_alloc_detail(detail: Dict[str, Any]) -> str:
|
||||
task_states = detail.get("TaskStates") or {}
|
||||
parts: List[str] = []
|
||||
if isinstance(task_states, dict):
|
||||
for task_name, st in task_states.items():
|
||||
events = (st or {}).get("Events") or []
|
||||
if isinstance(events, list) and events:
|
||||
# Include a few recent events; the latest can be a generic restart message
|
||||
# while the true root cause is slightly earlier (e.g. image pull failure).
|
||||
recent = events[-3:]
|
||||
msgs: List[str] = []
|
||||
for ev in recent:
|
||||
desc = ev.get("DisplayMessage") or ev.get("Message") or ev.get("Type") or ""
|
||||
if desc:
|
||||
msgs.append(desc)
|
||||
if msgs:
|
||||
parts.append(f"{task_name}: " + " | ".join(msgs))
|
||||
return "; ".join(parts)
|
||||
|
||||
def _alloc_events_lower(detail: Dict[str, Any]) -> str:
|
||||
task_states = detail.get("TaskStates") or {}
|
||||
texts: List[str] = []
|
||||
if isinstance(task_states, dict):
|
||||
for _task_name, st in task_states.items():
|
||||
events = (st or {}).get("Events") or []
|
||||
if isinstance(events, list):
|
||||
for ev in events[-10:]:
|
||||
desc = ev.get("DisplayMessage") or ev.get("Message") or ev.get("Type") or ""
|
||||
if desc:
|
||||
texts.append(desc)
|
||||
return " ".join(texts).lower()
|
||||
|
||||
while time.time() - start < timeout:
|
||||
allocs = await self.nomad.get_job_allocations(self.config.job_id)
|
||||
|
||||
healthy_count = 0
|
||||
for alloc in allocs:
|
||||
if alloc.status == AllocationStatus.RUNNING and alloc.http_address:
|
||||
if await self.executor.health_check(alloc.http_address):
|
||||
healthy_count += 1
|
||||
|
||||
# Fast-fail on obvious driver/image errors to avoid waiting out the full timeout.
|
||||
if alloc.id:
|
||||
detail = await self.nomad.get_allocation(alloc.id)
|
||||
if isinstance(detail, dict):
|
||||
summary = _summarize_alloc_detail(detail)
|
||||
lowered = _alloc_events_lower(detail) or summary.lower()
|
||||
if "failed to pull" in lowered or "pull access denied" in lowered:
|
||||
raise RuntimeError(
|
||||
"Nomad allocation failed to start due to a Docker image pull error. "
|
||||
f"Allocation {alloc.id[:8]}: {summary}\n"
|
||||
"If you're using a local image tag (e.g. `atropos-sandbox:local`) on macOS, "
|
||||
"make sure the image is loaded into Docker, e.g.:\n"
|
||||
" docker buildx build --load -t atropos-sandbox:local -f Hermes-Agent/atropos/Dockerfile Hermes-Agent/atropos"
|
||||
)
|
||||
if "exceeded allowed attempts" in lowered:
|
||||
raise RuntimeError(
|
||||
"Nomad allocation is crash-looping and has entered restart backoff. "
|
||||
f"Allocation {alloc.id[:8]}: {summary}\n"
|
||||
"Inspect logs with:\n"
|
||||
f" nomad alloc logs -stderr -task sandbox-server {alloc.id}\n"
|
||||
"Common causes include: missing local Docker image tag, container entrypoint error, "
|
||||
"or sandbox-server startup failure."
|
||||
)
|
||||
|
||||
if healthy_count >= min_count:
|
||||
return
|
||||
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
# Timed out: include allocation status detail to help debugging.
|
||||
allocs = await self.nomad.get_job_allocations(self.config.job_id)
|
||||
alloc_lines: List[str] = []
|
||||
for alloc in allocs[:10]:
|
||||
addr = alloc.http_address or "-"
|
||||
line = f"{alloc.id[:8]} status={alloc.status.value} http={addr}"
|
||||
detail = await self.nomad.get_allocation(alloc.id)
|
||||
if isinstance(detail, dict):
|
||||
summary = _summarize_alloc_detail(detail)
|
||||
if summary:
|
||||
line += f" detail={summary}"
|
||||
alloc_lines.append(line)
|
||||
|
||||
hint = (
|
||||
"Timed out waiting for healthy sandbox allocations.\n"
|
||||
f"Job: {self.config.job_id}, desired_healthy: {min_count}\n"
|
||||
"Allocations:\n - " + "\n - ".join(alloc_lines)
|
||||
)
|
||||
raise RuntimeError(hint)
|
||||
|
||||
async def _try_scale_up(self) -> bool:
|
||||
"""Attempt to scale up the job."""
|
||||
import time
|
||||
|
||||
async with self._scale_lock:
|
||||
# Check cooldown
|
||||
if time.time() - self._last_scale_time < self.config.scale_cooldown:
|
||||
return False
|
||||
|
||||
# Check max containers
|
||||
status = await self.nomad.get_job_status(self.config.job_id)
|
||||
if status is None:
|
||||
return False
|
||||
|
||||
current_count = status.count
|
||||
if current_count >= self.config.max_containers:
|
||||
logger.warning(f"Cannot scale up: already at max ({self.config.max_containers})")
|
||||
return False
|
||||
|
||||
# Scale up
|
||||
new_count = min(current_count + 1, self.config.max_containers)
|
||||
logger.info(f"Scaling up from {current_count} to {new_count} containers")
|
||||
|
||||
scale_resp = await self.nomad.scale_job(
|
||||
self.config.job_id,
|
||||
count=new_count,
|
||||
task_group="sandbox",
|
||||
)
|
||||
|
||||
# Nomad may return non-JSON errors (e.g. plain text) with a status field.
|
||||
if isinstance(scale_resp, dict) and scale_resp.get("status", 200) >= 400:
|
||||
logger.warning(f"Scale request rejected: {scale_resp}")
|
||||
self._last_scale_time = time.time()
|
||||
return False
|
||||
|
||||
self._last_scale_time = time.time()
|
||||
|
||||
# Wait for new allocation in the background so contended acquires can still
|
||||
# make progress (e.g. by grabbing slots released by other trajectories).
|
||||
if self._scale_task is None or self._scale_task.done():
|
||||
self._scale_task = asyncio.create_task(self._wait_for_scale(new_count))
|
||||
|
||||
return True
|
||||
|
||||
async def _wait_for_scale(self, desired_count: int) -> None:
|
||||
try:
|
||||
await self._wait_for_healthy_allocations(desired_count, timeout=60.0)
|
||||
await self._refresh_slots()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to scale up: {e}")
|
||||
|
||||
async def _health_check_loop(self) -> None:
|
||||
"""Background task to monitor container health."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.config.health_check_interval)
|
||||
await self._refresh_slots()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Health check error: {e}")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get pool statistics."""
|
||||
slots_by_state = {}
|
||||
for slot in self._slots.values():
|
||||
state = slot.state.value
|
||||
slots_by_state[state] = slots_by_state.get(state, 0) + 1
|
||||
|
||||
container_count = len({s.alloc_id for s in self._slots.values()}) if self._slots else 0
|
||||
|
||||
return {
|
||||
"total_slots": self.total_slots,
|
||||
"available_slots": self.available_slots,
|
||||
"acquired_slots": self.acquired_slots,
|
||||
"containers": container_count,
|
||||
"slots_by_state": slots_by_state,
|
||||
"started": self._started,
|
||||
}
|
||||
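The `acquire()` loop above combines an `asyncio.wait_for` on the availability queue with retries that silently discard stale queue entries (keys whose slot has since been removed). A minimal self-contained sketch of that pattern, with a hypothetical `live` dict standing in for `_slots` and no Nomad involvement:

```python
import asyncio


async def acquire(queue: asyncio.Queue, live: dict, timeout: float = 0.05) -> str:
    """Pop keys from the queue until one refers to a live slot."""
    while True:
        try:
            key = await asyncio.wait_for(queue.get(), timeout=timeout)
        except asyncio.TimeoutError:
            # The real pool would attempt to scale up here; we just keep waiting.
            continue
        if key in live:  # stale entries are silently discarded
            return key


async def main() -> str:
    queue: asyncio.Queue = asyncio.Queue()
    live = {"alloc1:slot_0": object()}
    await queue.put("dead:slot_9")    # stale key from a removed allocation
    await queue.put("alloc1:slot_0")  # live key
    return await acquire(queue, live)


print(asyncio.run(main()))  # → alloc1:slot_0
```

The same discard-and-retry shape also covers the duplicate-queue-entry case, where `slot.acquire()` raises `RuntimeError` and the loop simply continues.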
159	atropos/slots/slot.py	Normal file
@@ -0,0 +1,159 @@
"""
Slot abstraction for atropos-agent.

A Slot represents an isolated workspace for a single agent trajectory.
Slots are hosted on Nomad allocations and provide workspace isolation
via filesystem directories.
"""

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Optional
import uuid


class SlotState(Enum):
    """State of a slot in the pool."""
    AVAILABLE = "available"   # Ready to be acquired
    ACQUIRED = "acquired"     # Assigned to a trajectory
    EXECUTING = "executing"   # Currently executing a tool
    RELEASING = "releasing"   # Being released back to pool
    ERROR = "error"           # In error state


@dataclass
class Slot:
    """
    An isolated workspace for a single agent trajectory.

    Slots are the unit of scheduling - each trajectory runs in its own slot,
    with an isolated workspace directory. Multiple slots share a container.

    Attributes:
        slot_id: Unique identifier for this slot (e.g., "slot_0")
        alloc_id: Nomad allocation ID hosting this slot
        container_addr: HTTP address of the sandbox server (e.g., "http://10.0.0.1:8080")
        workspace_dir: Path to workspace in container (e.g., "/data/slot_0")
        state: Current state of the slot
        trajectory_id: ID of trajectory currently using this slot (if acquired)
        metadata: Additional metadata
    """
    slot_id: str
    alloc_id: str
    container_addr: str
    workspace_dir: str = ""
    state: SlotState = SlotState.AVAILABLE
    trajectory_id: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Set default workspace_dir if not provided."""
        if not self.workspace_dir:
            self.workspace_dir = f"/data/{self.slot_id}"

    @property
    def is_available(self) -> bool:
        """Check if slot is available for acquisition."""
        return self.state == SlotState.AVAILABLE

    @property
    def is_acquired(self) -> bool:
        """Check if slot is currently acquired."""
        return self.state in (SlotState.ACQUIRED, SlotState.EXECUTING)

    def acquire(self, trajectory_id: Optional[str] = None) -> None:
        """
        Mark slot as acquired by a trajectory.

        Args:
            trajectory_id: Optional ID of acquiring trajectory
        """
        if not self.is_available:
            raise RuntimeError(f"Cannot acquire slot {self.slot_id}: state is {self.state}")

        self.state = SlotState.ACQUIRED
        self.trajectory_id = trajectory_id or str(uuid.uuid4())

    def start_execution(self, execution_id: Optional[str] = None) -> None:
        """Mark slot as executing."""
        if self.state != SlotState.ACQUIRED:
            raise RuntimeError(f"Cannot start execution on slot {self.slot_id}: state is {self.state}")

        self.state = SlotState.EXECUTING
        if execution_id:
            self.metadata["current_execution_id"] = execution_id

    def end_execution(self) -> None:
        """Mark execution as complete, return to acquired state."""
        if self.state != SlotState.EXECUTING:
            raise RuntimeError(f"Cannot end execution on slot {self.slot_id}: state is {self.state}")

        self.state = SlotState.ACQUIRED
        self.metadata.pop("current_execution_id", None)

    def release(self) -> None:
        """Release slot back to available state."""
        self.state = SlotState.AVAILABLE
        self.trajectory_id = None
        self.metadata.pop("current_execution_id", None)

    def mark_error(self, error: str) -> None:
        """Mark slot as in error state."""
        self.state = SlotState.ERROR
        self.metadata["error"] = error

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {
            "slot_id": self.slot_id,
            "alloc_id": self.alloc_id,
            "container_addr": self.container_addr,
            "workspace_dir": self.workspace_dir,
            "state": self.state.value,
            "trajectory_id": self.trajectory_id,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Slot":
        """Create from dictionary."""
        return cls(
            slot_id=data["slot_id"],
            alloc_id=data["alloc_id"],
            container_addr=data["container_addr"],
            workspace_dir=data.get("workspace_dir", ""),
            state=SlotState(data.get("state", "available")),
            trajectory_id=data.get("trajectory_id"),
            metadata=data.get("metadata", {}),
        )

    def __repr__(self) -> str:
        return f"Slot({self.slot_id}, state={self.state.value}, alloc={self.alloc_id[:8]}...)"


def create_slots_for_allocation(
    alloc_id: str,
    container_addr: str,
    num_slots: int = 10,
) -> list["Slot"]:
    """
    Create slots for a Nomad allocation.

    Args:
        alloc_id: Nomad allocation ID
        container_addr: HTTP address of sandbox server
        num_slots: Number of slots to create

    Returns:
        List of Slot objects
    """
    slots = []
    for i in range(num_slots):
        slot_id = f"slot_{i}"
        slots.append(Slot(
            slot_id=slot_id,
            alloc_id=alloc_id,
            container_addr=container_addr,
            workspace_dir=f"/data/{slot_id}",
        ))
    return slots
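The `Slot` methods above enforce a small state machine: `acquire` is only legal from `AVAILABLE`, `start_execution` only from `ACQUIRED`, and so on. A minimal standalone sketch of those transitions (a re-declared illustrative enum, not the `SlotState` class itself):

```python
from enum import Enum


class State(Enum):
    AVAILABLE = "available"
    ACQUIRED = "acquired"
    EXECUTING = "executing"


# Legal transitions mirrored from Slot.acquire / start_execution / end_execution / release.
TRANSITIONS = {
    (State.AVAILABLE, "acquire"): State.ACQUIRED,
    (State.ACQUIRED, "start_execution"): State.EXECUTING,
    (State.EXECUTING, "end_execution"): State.ACQUIRED,
    (State.ACQUIRED, "release"): State.AVAILABLE,
}


def step(state: State, event: str) -> State:
    """Apply one event, raising on illegal transitions like the Slot methods do."""
    try:
        return TRANSITIONS[(state, event)]
    except KeyError:
        raise RuntimeError(f"Cannot {event} in state {state.value}") from None


s = State.AVAILABLE
for event in ("acquire", "start_execution", "end_execution", "release"):
    s = step(s, event)
print(s.value)  # → available
```

Note that a full trajectory lifecycle returns the slot to `AVAILABLE`, which is what lets `SlotPool.release` re-enqueue the slot key.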
2	atropos/terminal/__init__.py	Normal file
@@ -0,0 +1,2 @@
"""Terminal helpers for stateful sandbox interactions."""
115	atropos/terminal/asciinema_stream.py	Normal file
@@ -0,0 +1,115 @@
from __future__ import annotations

import json
from typing import Any

import pyte


class AsciinemaStreamDecoder:
    def __init__(self, *, default_width: int = 80, default_height: int = 24) -> None:
        self._default_width = max(1, int(default_width))
        self._default_height = max(1, int(default_height))
        self._buffer = ""
        self._has_header = False
        self.width = self._default_width
        self.height = self._default_height
        self._screen = pyte.Screen(self.width, self.height)
        self._stream = pyte.Stream(self._screen)

    def reset(self) -> None:
        self._buffer = ""
        self._has_header = False
        self.width = self._default_width
        self.height = self._default_height
        self._screen = pyte.Screen(self.width, self.height)
        self._stream = pyte.Stream(self._screen)

    def feed(self, chunk: str | bytes) -> None:
        if not chunk:
            return
        if isinstance(chunk, bytes):
            chunk = chunk.decode("utf-8", errors="replace")
        self._buffer += chunk
        while True:
            line, sep, rest = self._buffer.partition("\n")
            if not sep:
                break
            self._buffer = rest
            line = line.strip()
            if not line:
                continue
            parsed = self._parse_json_line(line)
            if parsed is None:
                continue
            if not self._has_header:
                if isinstance(parsed, dict):
                    self._init_from_header(parsed)
                    continue
                if isinstance(parsed, list):
                    self._has_header = True
                    self._apply_event(parsed)
                    continue
                continue
            if isinstance(parsed, list):
                self._apply_event(parsed)

    def render(self) -> str:
        return "\n".join(self._screen.display)

    def _parse_json_line(self, line: str) -> Any | None:
        try:
            return json.loads(line)
        except json.JSONDecodeError:
            return None

    def _init_from_header(self, header: dict[str, Any]) -> None:
        width = _coerce_int(
            header.get("width") or header.get("columns") or header.get("cols"),
            self._default_width,
        )
        height = _coerce_int(
            header.get("height") or header.get("rows") or header.get("lines"),
            self._default_height,
        )
        self.width = max(1, width)
        self.height = max(1, height)
        self._screen = pyte.Screen(self.width, self.height)
        self._stream = pyte.Stream(self._screen)
        self._has_header = True

    def _apply_event(self, event: list[Any]) -> None:
        if len(event) < 2:
            return
        event_type = event[1]
        payload = event[2] if len(event) > 2 else ""
        if event_type == "o":
            if isinstance(payload, str):
                self._stream.feed(payload)
        elif event_type == "r":
            width, height = _parse_resize(payload)
            if width and height:
                self.width = width
                self.height = height
                self._screen.resize(width, height)


def _coerce_int(value: Any, default: int) -> int:
    try:
        return int(value)
    except (TypeError, ValueError):
        return int(default)


def _parse_resize(payload: Any) -> tuple[int, int]:
    if isinstance(payload, str) and "x" in payload:
        left, right = payload.lower().split("x", 1)
        return _coerce_int(left, 0), _coerce_int(right, 0)
    if isinstance(payload, dict):
        width = _coerce_int(payload.get("width") or payload.get("columns") or payload.get("cols"), 0)
        height = _coerce_int(payload.get("height") or payload.get("rows") or payload.get("lines"), 0)
        return width, height
    if isinstance(payload, list) and len(payload) >= 2:
        return _coerce_int(payload[0], 0), _coerce_int(payload[1], 0)
    return 0, 0
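The decoder above leans on `pyte` for full terminal emulation. For intuition about the wire format it consumes, here is a minimal self-contained sketch of the asciinema v2 line protocol: a JSON header object on the first line, then one `[time, type, payload]` event per line, of which only `"o"` (output) events carry terminal data. This is a simplified parser for illustration, not the class above:

```python
import json


def decode_cast(text: str) -> tuple[int, int, str]:
    """Return (width, height, concatenated output) from an asciinema v2 cast."""
    width, height, out = 80, 24, []
    lines = [ln for ln in text.splitlines() if ln.strip()]
    header = json.loads(lines[0])  # first line is the JSON header object
    width = int(header.get("width", width))
    height = int(header.get("height", height))
    for line in lines[1:]:
        event = json.loads(line)  # [timestamp, event_type, payload]
        if isinstance(event, list) and len(event) >= 3 and event[1] == "o":
            out.append(event[2])
    return width, height, "".join(out)


cast = (
    '{"version": 2, "width": 100, "height": 30}\n'
    '[0.1, "o", "hello "]\n'
    '[0.2, "o", "world"]\n'
    '[0.3, "i", "x"]'  # input events are ignored, as in _apply_event
)
print(decode_cast(cast))  # → (100, 30, 'hello world')
```

`AsciinemaStreamDecoder.feed` additionally handles partial chunks (buffering until a newline) and `"r"` resize events, which this sketch omits.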
31	atropos/tools/__init__.py	Normal file
@@ -0,0 +1,31 @@
"""
Tool abstractions for atropos-agent.

Provides base Tool class, ToolCall/ToolResult types, and specialized tools.

Kept modules:
- base.py: ToolSchema, ToolCall, ToolResult, Tool ABC, ToolRegistry
- tool_executor.py: Batched execution queue with slot routing
- terminal_stateful_tool.py: Persistent terminal sessions
- tmux_tool.py: Tmux-based streaming terminal

Removed (replaced by hermes-agent equivalents):
- build_registry.py → model_tools.py + toolsets.py
- sandbox_stubs.py → atropos/backends/ execute() methods
- hermes_external_tools.py → environments/agent_loop.py handle_function_call()
- toolset_resolver.py → toolsets.py
"""

from .base import Tool, ToolCall, ToolRegistry, ToolResult, ToolSchema
from .terminal_stateful_tool import TerminalStatefulTool
from .tmux_tool import TmuxTool

__all__ = [
    "Tool",
    "ToolCall",
    "ToolRegistry",
    "ToolResult",
    "ToolSchema",
    "TerminalStatefulTool",
    "TmuxTool",
]
423	atropos/tools/base.py	Normal file
@@ -0,0 +1,423 @@
"""
Base Tool abstraction for atropos-agent.

Tools follow a simple pattern:
1. Define schema (name, description, parameters)
2. Implement execute() method
3. Return ToolResult with output/error

Tool calls use Hermes-style XML tags:
<tool_call>{"name": "bash", "arguments": {"command": "ls"}}</tool_call>
"""

import json
import re
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, Field


@dataclass
class ToolSchema:
    """JSON Schema for a tool's parameters."""

    name: str
    description: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    required: List[str] = field(default_factory=list)
    # Whether the tool must be executed via an external ToolServer (secret proxy)
    # and not inside the sandbox.
    external: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Convert to OpenAI-compatible function schema."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": self.parameters,
                    "required": self.required,
                },
            },
        }

    def to_prompt_description(self) -> str:
        """Convert to human-readable description for system prompt."""
        params_desc = []
        for name, spec in self.parameters.items():
            req = "(required)" if name in self.required else "(optional)"
            desc = spec.get("description", "")
            param_type = spec.get("type", "string")
            params_desc.append(f"  - {name} ({param_type}) {req}: {desc}")

        params_str = "\n".join(params_desc) if params_desc else "  (no parameters)"
        return f"**{self.name}**: {self.description}\nParameters:\n{params_str}"


@dataclass
class ToolCall:
    """A parsed tool call from model output."""

    name: str
    arguments: Dict[str, Any]
    raw_text: str = ""  # Original XML/JSON text
    # Unique tool-call id for traceability/reconstruction.
    uniq_id: str = field(default_factory=lambda: str(uuid.uuid4()))

    @classmethod
    def parse_from_text(cls, text: str) -> List["ToolCall"]:
        """
        Extract tool calls from text using Hermes-style XML tags.

        Supported formats (STRICT: requires well-formed closing tags):
        - Hermes JSON wrapper:
            <tool_call>{"name": "...", "arguments": {...}}</tool_call>
        - GLM/llama.cpp style:
            <tool_call>terminal{"command":"ls -la"}</tool_call>
        """
        calls: List["ToolCall"] = []

        if not text:
            return calls

        def _append_from_payload(*, name: str, arguments: Dict[str, Any], raw: str, uniq_id: Optional[str] = None) -> None:
            if not isinstance(name, str) or not name:
                return
            if not isinstance(arguments, dict):
                return
            calls.append(
                cls(
                    name=name,
                    arguments=arguments,
                    raw_text=raw,
                    uniq_id=uniq_id or str(uuid.uuid4()),
                )
            )

        # STRICT parsing: only accept well-formed <tool_call>...</tool_call> blocks.
        pattern = r"<tool_call>\s*(.*?)\s*</tool_call>"
        for inner in re.findall(pattern, text, re.DOTALL):
            cleaned = (inner or "").strip()
            if not cleaned:
                continue

            # Hermes JSON wrapper.
            if cleaned.startswith("{"):
                try:
                    data = json.loads(cleaned)
                except json.JSONDecodeError:
                    continue
                uniq_id = data.get("uniq_id") or data.get("id") or None
                _append_from_payload(
                    name=data.get("name", ""),
                    arguments=data.get("arguments", {}),
                    raw=inner,
                    uniq_id=uniq_id,
                )
                continue

            # GLM/llama.cpp style: terminal{...}
            m = re.match(r"^\s*([A-Za-z0-9_.:-]+)\s*(\{.*\})\s*$", cleaned, re.DOTALL)
            if not m:
                continue
            name = m.group(1)
            args_text = m.group(2)
            try:
                args = json.loads(args_text)
            except json.JSONDecodeError:
                continue
            _append_from_payload(name=name, arguments=args, raw=inner)

        return calls

    @classmethod
    def has_tool_call(cls, text: str) -> bool:
        """Check if text contains any tool calls."""
        return bool(re.search(r"<tool_call>", text))
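The strict parsing rule above (only well-formed `<tool_call>...</tool_call>` blocks; malformed JSON is skipped, never repaired) is easy to exercise standalone. This sketch mirrors the Hermes JSON branch only, without the dataclass plumbing or the GLM fallback:

```python
import json
import re


def extract_tool_calls(text: str) -> list[dict]:
    """Return parsed {name, arguments} dicts from well-formed <tool_call> blocks."""
    calls = []
    for inner in re.findall(r"<tool_call>\s*(.*?)\s*</tool_call>", text, re.DOTALL):
        try:
            data = json.loads(inner)
        except json.JSONDecodeError:
            continue  # malformed JSON is skipped, not repaired
        if (
            isinstance(data, dict)
            and isinstance(data.get("name"), str)
            and isinstance(data.get("arguments"), dict)
        ):
            calls.append({"name": data["name"], "arguments": data["arguments"]})
    return calls


text = (
    'ok <tool_call>{"name": "bash", "arguments": {"command": "ls"}}</tool_call>'
    " and an unterminated <tool_call>broken"
)
print(extract_tool_calls(text))
# → [{'name': 'bash', 'arguments': {'command': 'ls'}}]
```

The unterminated block is dropped entirely because the regex requires the closing tag, which is exactly the "STRICT" behavior documented in `parse_from_text`.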


@dataclass
class ToolResult:
    """Result from executing a tool."""

    success: bool
    output: str = ""
    error: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)
    uniq_id: Optional[str] = None  # Should match ToolCall.uniq_id for async execution tracking.

    def to_xml(self) -> str:
        """Format as XML for including in conversation."""
        data = {
            "success": self.success,
            "output": self.output,
        }
        if self.uniq_id:
            data["uniq_id"] = self.uniq_id
        if self.error:
            data["error"] = self.error
        if self.metadata:
            data["metadata"] = self.metadata
        return f"<tool_response>{json.dumps(data)}</tool_response>"

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        return {
            "success": self.success,
            "output": self.output,
            "error": self.error,
            "metadata": self.metadata,
            "uniq_id": self.uniq_id,
        }


class Tool(ABC):
    """
    Abstract base class for tools.

    Subclasses must implement:
    - schema: ToolSchema describing the tool
    - execute(): async method that performs the tool action
    """

    @property
    @abstractmethod
    def schema(self) -> ToolSchema:
        """Return the tool's schema."""
        pass

    @property
    def name(self) -> str:
        """Tool name (from schema)."""
        return self.schema.name

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """
        Execute the tool with given arguments.

        Args:
            **kwargs: Tool-specific arguments

        Returns:
            ToolResult with success/failure and output
        """
        pass

    def is_available(self) -> tuple[bool, str | None]:
        """
        Return whether this tool should be exposed/executable in the current process.

        Tools that depend on optional binaries/services/env vars can override this
        to avoid advertising a tool that will fail at runtime.
        """
        return True, None

    async def __call__(self, **kwargs) -> ToolResult:
        """Allow calling tool instance directly."""
        return await self.execute(**kwargs)


# Note: this registry only wraps declarations for the external ToolServer (tools
# executed in an external process) and for tools preinstalled in environments.
class ToolRegistry:
    """Registry of available tools."""

    def __init__(self):
        self._tools: Dict[str, Tool] = {}

    def register(self, tool: Tool) -> None:
        """Register a tool."""
        self._tools[tool.name] = tool

    def get(self, name: str) -> Optional[Tool]:
        """Get a tool by name."""
        return self._tools.get(name)

    def list_tools(self) -> List[Tool]:
        """List all registered tools."""
        return list(self._tools.values())

    def get_schemas(self) -> List[ToolSchema]:
        """Get schemas for all registered tools."""
        return [tool.schema for tool in self._tools.values()]

    def get_prompt_description(self) -> str:
        """Generate tool descriptions for system prompt."""
        descriptions = [tool.schema.to_prompt_description() for tool in self._tools.values()]
        return "\n\n".join(descriptions)

    def get_prompt_tool_definitions_json(self) -> str:
        """
        Return a Hermes-style JSON list of tool definitions for use inside a `<tools>...</tools>` block.

        Hermes trajectories historically use a simplified schema list:
        [{"name": ..., "description": ..., "parameters": {...}, "required": null}, ...]
        """
        formatted: List[Dict[str, Any]] = []
        for tool in self._tools.values():
            fn = tool.schema.to_dict().get("function", {})
            formatted.append(
                {
                    "name": fn.get("name", tool.name),
                    "description": fn.get("description", ""),
                    "parameters": fn.get("parameters", {}),
                    # Keep parity with Hermes saved trajectories (required is typically null there).
                    "required": None,
                }
            )
        return json.dumps(formatted, ensure_ascii=False)

    async def execute(self, call: ToolCall) -> ToolResult:
        """Execute a tool call."""
        tool = self.get(call.name)
        if tool is None:
            return ToolResult(
                success=False,
                error=f"Unknown tool: {call.name}",
                uniq_id=call.uniq_id,
            )

        try:
            result = await tool.execute(**call.arguments)
            if result.uniq_id is None:
                result.uniq_id = call.uniq_id
            return result
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"Tool execution error: {str(e)}",
                uniq_id=call.uniq_id,
            )


# =============================================================================
# FastAPI / transport models
# =============================================================================


class ToolCallPayload(BaseModel):
    name: str
    arguments: Dict[str, Any] = Field(default_factory=dict)
    uniq_id: str

    @classmethod
    def from_tool_call(cls, call: ToolCall) -> "ToolCallPayload":
        return cls(name=call.name, arguments=call.arguments, uniq_id=call.uniq_id)
|
||||
|
||||
def to_tool_call(self) -> ToolCall:
|
||||
return ToolCall(name=self.name, arguments=self.arguments, uniq_id=self.uniq_id)
|
||||
|
||||
|
||||
class ToolResultPayload(BaseModel):
|
||||
success: bool
|
||||
output: str = ""
|
||||
error: str = ""
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
uniq_id: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_tool_result(cls, result: ToolResult) -> "ToolResultPayload":
|
||||
return cls(
|
||||
success=result.success,
|
||||
output=result.output,
|
||||
error=result.error,
|
||||
metadata=result.metadata,
|
||||
uniq_id=result.uniq_id,
|
||||
)
|
||||
|
||||
def to_tool_result(self) -> ToolResult:
|
||||
return ToolResult(
|
||||
success=self.success,
|
||||
output=self.output,
|
||||
error=self.error,
|
||||
metadata=self.metadata,
|
||||
uniq_id=self.uniq_id,
|
||||
)
|
||||
|
||||
|
||||
class ToolExecutorExecuteRequest(BaseModel):
|
||||
trajectory_id: str
|
||||
tool: ToolCallPayload
|
||||
timeout_s: Optional[float] = None
|
||||
|
||||
|
||||
class ToolExecutorReleaseRequest(BaseModel):
|
||||
trajectory_id: str
|
||||
reset_workspace: bool = False
|
||||
|
||||
|
||||
class ToolServerExecuteRequest(BaseModel):
|
||||
trajectory_id: Optional[str] = None
|
||||
tool: ToolCallPayload
|
||||
timeout_s: Optional[float] = None
|
||||
# Optional sandbox context for tools that need workspace artifacts.
|
||||
# This is set by ToolExecutor and is NOT model-controlled.
|
||||
slot_id: Optional[str] = None
|
||||
container_addr: Optional[str] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Artifact transport models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ArtifactReadRequestPayload(BaseModel):
|
||||
trajectory_id: str
|
||||
path: str
|
||||
encoding: Literal["text", "base64"] = "text"
|
||||
max_bytes: Optional[int] = None
|
||||
include_sha256: bool = False
|
||||
|
||||
|
||||
class ArtifactReadResponsePayload(BaseModel):
|
||||
success: bool
|
||||
content: str = ""
|
||||
error: str = ""
|
||||
encoding: str = "text"
|
||||
truncated: bool = False
|
||||
bytes: int = 0
|
||||
file_size: Optional[int] = None
|
||||
path: str = ""
|
||||
mime: Optional[str] = None
|
||||
sha256: Optional[str] = None
|
||||
|
||||
|
||||
class ArtifactListRequestPayload(BaseModel):
|
||||
trajectory_id: str
|
||||
path: str = "."
|
||||
recursive: bool = False
|
||||
max_entries: Optional[int] = None
|
||||
|
||||
|
||||
class ArtifactListEntryPayload(BaseModel):
|
||||
path: str
|
||||
is_dir: bool
|
||||
size: int
|
||||
mtime: float
|
||||
|
||||
|
||||
class ArtifactListResponsePayload(BaseModel):
|
||||
success: bool
|
||||
entries: List[ArtifactListEntryPayload] = Field(default_factory=list)
|
||||
truncated: bool = False
|
||||
error: str = ""
|
||||
|
||||
|
||||
class ArtifactArchiveRequestPayload(BaseModel):
|
||||
trajectory_id: str
|
||||
path: str = "."
|
||||
format: Literal["tar.gz", "tgz"] = "tar.gz"
|
||||
max_bytes: Optional[int] = None
|
||||
max_entries: Optional[int] = None
|
||||
|
||||
|
||||
class ArtifactArchiveResponsePayload(BaseModel):
|
||||
success: bool
|
||||
content: str = ""
|
||||
error: str = ""
|
||||
encoding: str = "base64"
|
||||
format: str = "tar.gz"
|
||||
bytes: int = 0
|
||||
entry_count: int = 0
|
||||
45
atropos/tools/terminal_stateful_tool.py
Normal file
@@ -0,0 +1,45 @@
"""
Stateful terminal tool schema.

This is a sandbox tool that routes to the sandbox server as `bash_stateful`
via ToolExecutor mapping. It exists to expose an explicit, opt-in terminal
primitive suitable for stateful workflows (e.g. tmux sessions / TUIs).
"""

from __future__ import annotations

from typing import Optional

from .base import Tool, ToolResult, ToolSchema


class TerminalStatefulTool(Tool):
    @property
    def schema(self) -> ToolSchema:
        return ToolSchema(
            name="terminal_stateful",
            description=(
                "Execute a command in the sandbox, allowing stateful/background processes to persist "
                "across tool calls within the same trajectory slot (e.g. tmux sessions). "
                "Use sparingly; output is still non-interactive."
            ),
            parameters={
                "command": {"type": "string", "description": "The command to execute"},
                "timeout": {
                    "type": "integer",
                    "description": "Command timeout in seconds (optional).",
                    "minimum": 1,
                },
            },
            required=["command"],
        )

    def is_available(self) -> tuple[bool, str | None]:
        return True, None

    async def execute(self, command: str, timeout: Optional[int] = None) -> ToolResult:
        _ = (command, timeout)
        return ToolResult(
            success=False,
            error="terminal_stateful must be executed via ToolExecutor inside the sandbox",
        )
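The `timeout` parameter above is advisory only: per the ToolExecutor code later in this diff, the executor clamps any model-suggested timeout into the [1, 600] second range before dispatch. A standalone sketch of that clamping rule (`clamp_timeout` is a hypothetical helper name, not part of the codebase):

```python
def clamp_timeout(timeout_s, raw_timeout=None, lo=1.0, hi=600.0):
    """Prefer an explicit timeout; else fall back to the tool's own `timeout`
    argument; never let the model pick an unbounded value."""
    if timeout_s is None and isinstance(raw_timeout, (int, float)):
        timeout_s = float(raw_timeout)
    if timeout_s is not None:
        timeout_s = max(lo, min(float(timeout_s), hi))
    return timeout_s


print(clamp_timeout(None, raw_timeout=99999))  # clamped down to 600.0
print(clamp_timeout(0.25))                     # raised up to 1.0
print(clamp_timeout(None))                     # no timeout requested: stays None
```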
89
atropos/tools/tmux_tool.py
Normal file
@@ -0,0 +1,89 @@
"""
tmux tool schema (sandbox).

This is a sandbox tool that provides basic tmux session control suitable for
TUI-style terminal interactions:
- send keys (arrow keys, enter, etc.)
- capture the current screen buffer

Execution is routed by ToolExecutor to the sandbox server's `tmux` backend.
"""

from __future__ import annotations

from typing import Any, Dict, Optional

from .base import Tool, ToolResult, ToolSchema


class TmuxTool(Tool):
    @property
    def schema(self) -> ToolSchema:
        return ToolSchema(
            name="tmux",
            description=(
                "Control a per-trajectory tmux session inside the sandbox (stateful terminal). "
                "Use this for TUI-style interactions: send keys and capture the current screen."
            ),
            parameters={
                "action": {
                    "type": "string",
                    "description": "Action to perform: start | send_keys | stream | capture | stop.",
                    "enum": ["start", "send_keys", "stream", "stop", "capture"],
                },
                "keys": {
                    "description": "Keys to send (string or list of strings) when action=send_keys.",
                },
                "block": {
                    "type": "boolean",
                    "description": "If true, wait for shell command completion (only valid at a shell prompt).",
                    "default": False,
                },
                "min_wait_s": {
                    "type": "number",
                    "description": "For non-blocking send_keys, sleep this long after sending keys (seconds).",
                    "default": 0.0,
                },
                "max_wait_s": {
                    "type": "number",
                    "description": "For blocking send_keys, max time to wait for completion (seconds).",
                },
                "capture_entire": {
                    "type": "boolean",
                    "description": "Deprecated. Streaming is preferred.",
                    "default": False,
                },
                "max_bytes": {
                    "type": "integer",
                    "description": "Max bytes to return per stream call.",
                },
                "reset": {
                    "type": "boolean",
                    "description": "If true, reset stream offset to the beginning of the asciinema recording.",
                    "default": False,
                },
                "pane_width": {
                    "type": "integer",
                    "description": "Pane width for action=start (columns).",
                    "minimum": 20,
                },
                "pane_height": {
                    "type": "integer",
                    "description": "Pane height for action=start (rows).",
                    "minimum": 10,
                },
            },
            required=["action"],
        )

    def is_available(self) -> tuple[bool, str | None]:
        return True, None

    async def execute(self, **kwargs: Dict[str, Any]) -> ToolResult:
        # This tool is intended to be executed via ToolExecutor -> sandbox server.
        # We keep a safe fallback for non-sandbox contexts.
        action = str(kwargs.get("action") or "").strip()
        return ToolResult(
            success=False,
            error=f"tmux tool must be executed in the sandbox (got action={action!r})",
        )
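Schemas like the two tools above are flattened into the Hermes `<tools>` prompt block by `ToolRegistry.get_prompt_tool_definitions_json` (in `atropos/tools/base.py` earlier in this diff). A self-contained sketch of that transformation; the sample schema dicts here are hypothetical, mirroring the OpenAI function-calling shape that `ToolSchema.to_dict()` produces:

```python
import json

# Hypothetical schema dicts in the OpenAI function-calling shape.
schemas = [
    {"function": {"name": "tmux", "description": "Control a tmux session.",
                  "parameters": {"type": "object", "properties": {}}}},
    {"function": {"name": "bash", "description": "Run a shell command.",
                  "parameters": {"type": "object", "properties": {}}}},
]

formatted = [
    {
        "name": fn.get("name", ""),
        "description": fn.get("description", ""),
        "parameters": fn.get("parameters", {}),
        # Parity with Hermes saved trajectories: `required` is typically null.
        "required": None,
    }
    for fn in (s.get("function", {}) for s in schemas)
]
tools_block = "<tools>\n" + json.dumps(formatted, ensure_ascii=False) + "\n</tools>"
print([f["name"] for f in formatted])
```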
500
atropos/tools/tool_executor.py
Normal file
@@ -0,0 +1,500 @@
"""
ToolExecutor - queued, batched tool dispatch for multiplexed agent trajectories.

This component is responsible for:
- Maintaining trajectory -> Slot affinity (workspace continuity)
- Batching sandbox tool calls across trajectories to maximize container utilization
- Routing external tools (ToolSchema.external=True) to a ToolServer (Phase 4.5)

For now, only sandbox tools are executed:
- bash
- read_file
- write_file
"""

from __future__ import annotations

import asyncio
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import httpx

from .base import (
    ArtifactArchiveRequestPayload,
    ArtifactArchiveResponsePayload,
    ArtifactListRequestPayload,
    ArtifactListResponsePayload,
    ArtifactReadRequestPayload,
    ArtifactReadResponsePayload,
    ToolCall,
    ToolCallPayload,
    ToolRegistry,
    ToolResult,
    ToolResultPayload,
    ToolServerExecuteRequest,
)
from ..backends.base import ToolBackend
from ..slots import Slot


@dataclass
class ToolExecutorConfig:
    batch_window_ms: int = 20
    max_batch_size: int = 200
    allow_network: bool = True
    require_sandbox: bool = False
    require_stateful_sandbox: bool = False
    tool_server_url: Optional[str] = None
    tool_server_token: Optional[str] = None


@dataclass
class _QueuedToolRequest:
    trajectory_id: str
    call: ToolCall
    timeout_s: Optional[float]
    future: asyncio.Future


class ToolExecutor:
    def __init__(
        self,
        backend: ToolBackend,
        tools: ToolRegistry,
        config: Optional[ToolExecutorConfig] = None,
    ) -> None:
        self.backend = backend
        self.tools = tools
        self.config = config or ToolExecutorConfig()

        self._queue: asyncio.Queue[Optional[_QueuedToolRequest]] = asyncio.Queue()
        self._task: Optional[asyncio.Task] = None
        self._stopping = asyncio.Event()

        self._slots_lock = asyncio.Lock()
        self._slot_by_trajectory: Dict[str, Slot] = {}

        self._tool_server_client: Optional[httpx.AsyncClient] = None
        self._tool_server_lock = asyncio.Lock()

        # Lightweight stats for status endpoints.
        self.total_requests: int = 0
        self.total_errors: int = 0
        self.latencies_s: List[float] = []

    async def start(self) -> None:
        if self._task is None:
            self._task = asyncio.create_task(self._run_loop())

    def queue_size(self) -> int:
        return self._queue.qsize()

    async def close(self) -> None:
        self._stopping.set()
        await self._queue.put(None)
        if self._task:
            await self._task
            self._task = None

        client = self._tool_server_client
        self._tool_server_client = None
        if client is not None:
            await client.aclose()

        # Best-effort release of any remaining slots.
        async with self._slots_lock:
            slots = list(self._slot_by_trajectory.items())
            self._slot_by_trajectory.clear()

        for _, slot in slots:
            try:
                await self.backend.release(slot, reset_workspace=False)
            except Exception:
                pass

    async def execute(
        self,
        trajectory_id: str,
        call: ToolCall,
        timeout_s: Optional[float] = None,
    ) -> ToolResult:
        if self._task is None:
            raise RuntimeError("ToolExecutor not started (call start() first)")

        # Allow tool args to suggest a timeout (Hermes-compatible terminal tool),
        # but never let the model choose "infinite" timeouts.
        if timeout_s is None:
            raw_timeout = call.arguments.get("timeout")
            if isinstance(raw_timeout, (int, float)):
                timeout_s = float(raw_timeout)
        if timeout_s is not None:
            timeout_s = max(1.0, min(float(timeout_s), 600.0))

        loop = asyncio.get_running_loop()
        fut: asyncio.Future = loop.create_future()
        started = time.perf_counter()
        await self._queue.put(
            _QueuedToolRequest(trajectory_id=trajectory_id, call=call, timeout_s=timeout_s, future=fut)
        )
        try:
            result: ToolResult = await fut
            return result
        finally:
            self.latencies_s.append(time.perf_counter() - started)

    async def release_trajectory(self, trajectory_id: str, reset_workspace: bool = False) -> None:
        async with self._slots_lock:
            slot = self._slot_by_trajectory.pop(trajectory_id, None)

        if slot is not None:
            await self.backend.release(slot, reset_workspace=reset_workspace)

    async def _get_slot_if_present(self, trajectory_id: str) -> Optional[Slot]:
        async with self._slots_lock:
            return self._slot_by_trajectory.get(trajectory_id)

    # ---------------------------------------------------------------------
    # Artifact helpers (optional)
    # ---------------------------------------------------------------------

    async def read_artifact(self, req: ArtifactReadRequestPayload) -> ArtifactReadResponsePayload:
        slot = await self._get_slot_if_present(req.trajectory_id)
        if slot is None:
            return ArtifactReadResponsePayload(
                success=False, error="No active slot for trajectory (run a sandbox tool first)"
            )
        data = await self.backend.read_artifact(
            slot,
            req.path,
            encoding=req.encoding,
            max_bytes=req.max_bytes,
            include_sha256=req.include_sha256,
        )
        if isinstance(data, dict):
            data = dict(data)
            data.pop("http_status", None)
        try:
            return ArtifactReadResponsePayload(**(data or {}))
        except Exception as e:
            return ArtifactReadResponsePayload(success=False, error=f"Invalid artifact read response: {e}")

    async def list_artifacts(self, req: ArtifactListRequestPayload) -> ArtifactListResponsePayload:
        slot = await self._get_slot_if_present(req.trajectory_id)
        if slot is None:
            return ArtifactListResponsePayload(
                success=False, error="No active slot for trajectory (run a sandbox tool first)"
            )
        data = await self.backend.list_artifacts(
            slot,
            req.path,
            recursive=req.recursive,
            max_entries=req.max_entries,
        )
        if isinstance(data, dict):
            data = dict(data)
            data.pop("http_status", None)
        try:
            return ArtifactListResponsePayload(**(data or {}))
        except Exception as e:
            return ArtifactListResponsePayload(success=False, error=f"Invalid artifact list response: {e}")

    async def archive_artifacts(self, req: ArtifactArchiveRequestPayload) -> ArtifactArchiveResponsePayload:
        slot = await self._get_slot_if_present(req.trajectory_id)
        if slot is None:
            return ArtifactArchiveResponsePayload(
                success=False, error="No active slot for trajectory (run a sandbox tool first)"
            )
        data = await self.backend.archive_artifacts(
            slot,
            req.path,
            archive_format=req.format,
            max_bytes=req.max_bytes,
            max_entries=req.max_entries,
        )
        if isinstance(data, dict):
            data = dict(data)
            data.pop("http_status", None)
        try:
            return ArtifactArchiveResponsePayload(**(data or {}))
        except Exception as e:
            return ArtifactArchiveResponsePayload(success=False, error=f"Invalid artifact archive response: {e}")

    async def _get_or_acquire_slot(self, trajectory_id: str) -> Slot:
        async with self._slots_lock:
            existing = self._slot_by_trajectory.get(trajectory_id)
            if existing is not None:
                return existing

        slot = await self.backend.acquire(trajectory_id)

        async with self._slots_lock:
            existing = self._slot_by_trajectory.get(trajectory_id)
            if existing is not None:
                # Another coroutine won the race; return its slot.
                await self.backend.release(slot, reset_workspace=False)
                return existing
            self._slot_by_trajectory[trajectory_id] = slot
            return slot

    async def _run_loop(self) -> None:
        pending: List[_QueuedToolRequest] = []
        deadline: Optional[float] = None

        batch_window_s = max(0.0, self.config.batch_window_ms / 1000.0)
        max_batch = max(1, self.config.max_batch_size)

        while True:
            if self._stopping.is_set() and self._queue.empty() and not pending:
                break

            timeout = None
            if pending and deadline is not None:
                timeout = max(0.0, deadline - time.perf_counter())

            try:
                item = await asyncio.wait_for(self._queue.get(), timeout=timeout)
                if item is None:
                    continue
                pending.append(item)
                if len(pending) == 1:
                    deadline = time.perf_counter() + batch_window_s
                if len(pending) < max_batch:
                    continue
            except asyncio.TimeoutError:
                # Batch window elapsed.
                pass

            if not pending:
                deadline = None
                continue

            batch = pending
            pending = []
            deadline = None

            await self._execute_batch(batch)

    async def _get_tool_server_client(self) -> httpx.AsyncClient:
        url = self.config.tool_server_url
        if not url:
            raise RuntimeError("ToolServer not configured")

        if self._tool_server_client is not None:
            return self._tool_server_client

        async with self._tool_server_lock:
            if self._tool_server_client is None:
                self._tool_server_client = httpx.AsyncClient(base_url=url.rstrip("/"))
            return self._tool_server_client

    def _tool_server_headers(self) -> Dict[str, str]:
        token = self.config.tool_server_token
        if not token:
            return {}
        return {"Authorization": f"Bearer {token}"}

    async def _execute_external(self, req: _QueuedToolRequest) -> ToolResult:
        client = await self._get_tool_server_client()
        slot_id: Optional[str] = None
        container_addr: Optional[str] = None
        slot = await self._get_slot_if_present(req.trajectory_id)
        if slot is not None:
            slot_id = slot.slot_id
            container_addr = slot.container_addr

        payload = ToolServerExecuteRequest(
            trajectory_id=req.trajectory_id,
            tool=ToolCallPayload.from_tool_call(req.call),
            timeout_s=req.timeout_s,
            slot_id=slot_id,
            container_addr=container_addr,
        )

        try:
            resp = await client.post(
                "/execute",
                json=payload.model_dump(),
                headers=self._tool_server_headers(),
                timeout=req.timeout_s,
            )
            resp.raise_for_status()
            data = resp.json()
            parsed = ToolResultPayload(**data)
            result = parsed.to_tool_result()
            if result.uniq_id is None:
                result.uniq_id = req.call.uniq_id
            return result
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"External tool failed: {e}",
                uniq_id=req.call.uniq_id,
            )

    async def _execute_batch(self, batch: List[_QueuedToolRequest]) -> None:
        # Resolve tool schemas once per request and separate sandbox/external/unknown.
        sandbox_items: List[_QueuedToolRequest] = []
        external_items: List[_QueuedToolRequest] = []
        unknown_items: List[_QueuedToolRequest] = []

        for it in batch:
            tool = self.tools.get(it.call.name)
            if tool is None:
                unknown_items.append(it)
                continue

            schema = tool.schema
            if not schema.external:
                sandbox_items.append(it)
            else:
                external_items.append(it)

        for it in unknown_items:
            self.total_requests += 1
            self.total_errors += 1
            if not it.future.done():
                it.future.set_result(
                    ToolResult(
                        success=False,
                        error=f"Unknown tool: {it.call.name}",
                        uniq_id=it.call.uniq_id,
                    )
                )

        if external_items:
            if not self.config.tool_server_url:
                for it in external_items:
                    self.total_requests += 1
                    self.total_errors += 1
                    if not it.future.done():
                        it.future.set_result(
                            ToolResult(
                                success=False,
                                error=f"External tool not available (ToolServer not configured): {it.call.name}",
                                uniq_id=it.call.uniq_id,
                            )
                        )
            else:
                results = await asyncio.gather(*[self._execute_external(it) for it in external_items])
                for it, res in zip(external_items, results):
                    self.total_requests += 1
                    if not getattr(res, "success", False):
                        self.total_errors += 1
                    if not it.future.done():
                        it.future.set_result(res)

        if not sandbox_items:
            return

        # Acquire slots for the distinct trajectories in this batch.
        try:
            traj_ids = list({it.trajectory_id for it in sandbox_items})
            slots = await asyncio.gather(*[self._get_or_acquire_slot(tid) for tid in traj_ids])
            slot_by_traj = dict(zip(traj_ids, slots))
        except Exception as e:
            for it in sandbox_items:
                self.total_requests += 1
                self.total_errors += 1
                if not it.future.done():
                    it.future.set_result(
                        ToolResult(
                            success=False,
                            error=f"Failed to acquire slot: {e}",
                            uniq_id=it.call.uniq_id,
                        )
                    )
            return

        # Group by timeout so we don't accidentally make short timeouts wait on long ones.
        by_timeout: Dict[float, List[_QueuedToolRequest]] = {}
        default_timeout = self.backend.default_timeout_s

        for it in sandbox_items:
            t = it.timeout_s
            if t is None:
                t = default_timeout
            if t is None:
                t = 30.0
            by_timeout.setdefault(float(t), []).append(it)

        for timeout_s, items in by_timeout.items():
            requests = []
            dispatched: List[_QueuedToolRequest] = []
            for it in items:
                slot = slot_by_traj[it.trajectory_id]
                tool_name = it.call.name
                args = dict(it.call.arguments)

                # Hermes compatibility: treat `terminal` as an alias of sandbox `bash`.
                if tool_name == "terminal":
                    if args.get("background"):
                        self.total_requests += 1
                        self.total_errors += 1
                        if not it.future.done():
                            it.future.set_result(
                                ToolResult(
                                    success=False,
                                    error="terminal background execution is not supported in sandbox",
                                    uniq_id=it.call.uniq_id,
                                )
                            )
                        continue
                    tool_name = "bash"
                    # `timeout` is handled at the ToolExecutor level, not passed to the sandbox tool args.
                    args.pop("timeout", None)
                elif tool_name == "terminal_stateful":
                    tool_name = "bash_stateful"
                    args.pop("timeout", None)
                elif tool_name == "tmux":
                    # `tmux` is a sandbox tool backed by the stateful session manager.
                    # Network policy is env-controlled.
                    args.pop("allow_network", None)

                if tool_name == "bash":
                    # Network policy is set by the environment/executor, not by the model.
                    args.pop("allow_network", None)
                    args.pop("require_sandbox", None)
                    args["allow_network"] = bool(self.config.allow_network)
                    args["require_sandbox"] = bool(self.config.require_sandbox)
                    # `timeout` is handled at the ToolExecutor level, not passed to the sandbox tool args.
                    args.pop("timeout", None)
                elif tool_name == "bash_stateful":
                    # Network policy is set by the environment/executor, not by the model.
                    args.pop("allow_network", None)
                    args.pop("require_sandbox", None)
                    args.pop("require_stateful_sandbox", None)
                    args["allow_network"] = bool(self.config.allow_network)
                    args["require_stateful_sandbox"] = bool(self.config.require_stateful_sandbox)
                    args.pop("timeout", None)
                elif tool_name == "tmux":
                    # Network policy applies to the underlying stateful session.
                    args.pop("allow_network", None)
                    args.pop("require_sandbox", None)
                    args.pop("require_stateful_sandbox", None)
                    args["allow_network"] = bool(self.config.allow_network)
                    args["require_stateful_sandbox"] = bool(self.config.require_stateful_sandbox)

                requests.append((slot, tool_name, args))
                dispatched.append(it)

            results = None
            try:
                if not dispatched:
                    continue
                results = await self.backend.execute_batch(requests, timeout_s=timeout_s)
            except Exception as e:
                # Only fail the requests that were actually dispatched; rejected
                # items (e.g. terminal background) were already resolved above.
                for it in dispatched:
                    self.total_requests += 1
                    self.total_errors += 1
                    if not it.future.done():
                        it.future.set_result(
                            ToolResult(
                                success=False,
                                error=f"Batch execution failed: {e}",
                                uniq_id=it.call.uniq_id,
                            )
                        )
                continue

            for it, res in zip(dispatched, results):
                self.total_requests += 1
                if not getattr(res, "success", False):
                    self.total_errors += 1
                tool_result = res.to_tool_result()
                tool_result.uniq_id = it.call.uniq_id
                if not it.future.done():
                    it.future.set_result(tool_result)
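The batching discipline in `_run_loop` (flush when the batch fills, or when the window measured from the first pending item elapses; `None` as a shutdown sentinel) can be isolated into a small self-contained sketch; `batcher` is an illustrative function, not code from the repo:

```python
import asyncio
import time


async def batcher(queue: asyncio.Queue, batches: list,
                  batch_window_s: float = 0.02, max_batch: int = 3) -> None:
    """Collect queue items into batches. A batch is flushed when it reaches
    max_batch, or when batch_window_s has elapsed since its first item.
    A `None` item shuts the loop down."""
    pending = []
    deadline = None
    while True:
        timeout = None
        if pending and deadline is not None:
            timeout = max(0.0, deadline - time.perf_counter())
        try:
            item = await asyncio.wait_for(queue.get(), timeout=timeout)
            if item is None:
                if pending:
                    batches.append(pending)
                return
            pending.append(item)
            if len(pending) == 1:
                deadline = time.perf_counter() + batch_window_s
            if len(pending) < max_batch:
                continue
        except asyncio.TimeoutError:
            pass  # batch window elapsed
        if pending:
            batches.append(pending)
            pending, deadline = [], None


async def main() -> list:
    q: asyncio.Queue = asyncio.Queue()
    batches: list = []
    task = asyncio.create_task(batcher(q, batches))
    for i in range(4):          # 4 items arrive faster than the window
        await q.put(i)
    await asyncio.sleep(0.1)    # let the window elapse for the leftover item
    await q.put(None)           # shutdown sentinel
    await task
    return batches


batches = asyncio.run(main())
print(batches)  # [0, 1, 2] fills a batch; 3 flushes when the window elapses
```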
415
atropos_compatible_agent.py
Normal file
@@ -0,0 +1,415 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Atropos-compatible Hermes agent runner.
|
||||
|
||||
This is a minimal subclass of Hermes-Agent's `AIAgent` that swaps the OpenAI
|
||||
function-calling backend for Atroposlib's `ManagedServer`/`ServerManager` backend
|
||||
and uses Hermes-style XML tool tags:
|
||||
|
||||
- <tool_call>{"name": "...", "arguments": {...}}</tool_call>
|
||||
- <tool_response>{...}</tool_response>
|
||||
|
||||
Tool observations are appended as `role="user"` messages containing one or more
|
||||
`<tool_response>` blocks so they survive common chat templates during tokenization.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import warnings
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from model_tools import cleanup_vm, handle_function_call
|
||||
from run_agent import AIAgent
|
||||
|
||||
_TOOL_CALL_RE = re.compile(r"<tool_call>\\s*(.*?)\\s*</tool_call>", re.DOTALL)
|
||||
|
||||
|
||||
ATROPOS_TOOL_SYSTEM_PROMPT = """You are a helpful AI assistant with access to tools.
|
||||
|
||||
## Available Tools
|
||||
<tools>
|
||||
{tool_descriptions}
|
||||
</tools>
|
||||
|
||||
## How to Use Tools
|
||||
To call a tool, output:
|
||||
<tool_call>{{"name": "tool_name", "arguments": {{"arg1": "value1"}}}}</tool_call>
|
||||
|
||||
You may include optional reasoning in <think>...</think> before tool calls.
|
||||
|
||||
After each tool call, you will receive tool results as:
|
||||
<tool_response>{{...}}</tool_response>
|
||||
|
||||
Continue until finished, then provide a final response with no <tool_call> blocks.
|
||||
"""
|
||||
|
||||
|
||||
class AtroposAIAgent(AIAgent):
|
||||
"""
|
||||
Hermes `AIAgent` variant that uses Atroposlib ServerManager/ManagedServer.
|
||||
|
||||
Notes:
|
||||
- The default Hermes `AIAgent` remains unchanged; this class is opt-in.
|
||||
- The underlying server must expose `managed_server(tokenizer=...)` OR be a single
|
||||
APIServer-compatible object usable by Atroposlib's `ManagedServer`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
server: Any,
|
||||
tokenizer: Any = None,
|
||||
model: str = "local",
|
||||
max_iterations: int = 10,
|
||||
tool_delay: float = 0.0,
|
||||
enabled_toolsets: Optional[List[str]] = None,
|
||||
disabled_toolsets: Optional[List[str]] = None,
|
||||
save_trajectories: bool = False,
|
||||
verbose_logging: bool = False,
|
||||
quiet_mode: bool = False,
|
||||
ephemeral_system_prompt: Optional[str] = None,
|
||||
log_prefix_chars: int = 100,
|
||||
log_prefix: str = "",
|
||||
session_id: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
# Call parent init mainly to reuse tool selection + trajectory saving utilities.
|
||||
super().__init__(
|
||||
base_url="http://unused",
|
||||
api_key="dummy-key",
|
||||
model=model,
|
||||
max_iterations=max_iterations,
|
||||
tool_delay=tool_delay,
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
disabled_toolsets=disabled_toolsets,
|
||||
save_trajectories=save_trajectories,
|
||||
verbose_logging=verbose_logging,
|
||||
quiet_mode=quiet_mode,
|
||||
ephemeral_system_prompt=ephemeral_system_prompt,
|
||||
log_prefix_chars=log_prefix_chars,
|
||||
log_prefix=log_prefix,
|
||||
session_id=session_id,
|
||||
)

        self.server = server
        self.tokenizer = tokenizer
        self.temperature = temperature
        self.max_tokens = max_tokens

    @asynccontextmanager
    async def _managed(self) -> AsyncGenerator[Any, None]:
        if hasattr(self.server, "managed_server"):
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    message=r"Using OpenAIServer with managed_server does not allow for state tracking",
                    category=UserWarning,
                )
                async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                    yield managed
            return

        # Fall back to directly wrapping a single server object.
        from atroposlib.envs.server_handling.managed_server import ManagedServer

        managed = ManagedServer(server=self.server, tokenizer=self.tokenizer)
        try:
            yield managed
        finally:
            managed.reset()

    def _tool_descriptions_text(self) -> str:
        if not self.tools:
            return "(no tools available)"

        parts: List[str] = []
        for tool in self.tools:
            fn = (tool or {}).get("function", {})
            name = fn.get("name", "")
            desc = (fn.get("description") or "").strip()
            if not name:
                continue
            if desc:
                parts.append(f"- {name}: {desc}")
            else:
                parts.append(f"- {name}")
        return "\n".join(parts) if parts else "(no tools available)"

    def _build_system_prompt(self, system_message: Optional[str]) -> Optional[str]:
        tool_prompt = ATROPOS_TOOL_SYSTEM_PROMPT.format(
            tool_descriptions=self._tool_descriptions_text()
        )

        parts: List[str] = []
        if system_message:
            parts.append(system_message)
        if self.ephemeral_system_prompt:
            parts.append(self.ephemeral_system_prompt)
        parts.append(tool_prompt)

        return "\n\n".join(parts)

    def _parse_tool_calls(self, content: str) -> Tuple[List[Tuple[str, Dict[str, Any]]], List[str]]:
        """
        Returns:
            (calls, errors)
        """
        calls: List[Tuple[str, Dict[str, Any]]] = []
        errors: List[str] = []

        for raw in _TOOL_CALL_RE.findall(content or ""):
            try:
                payload = json.loads(raw)
            except json.JSONDecodeError as exc:
                errors.append(f"Invalid JSON inside <tool_call>: {exc}")
                continue

            name = payload.get("name")
            args = payload.get("arguments", {})
            if not isinstance(name, str) or not name:
                errors.append("Tool call missing 'name' string")
                continue
            if not isinstance(args, dict):
                errors.append("Tool call 'arguments' must be an object")
                continue

            calls.append((name, args))

        return calls, errors

    async def run_conversation_async(
        self,
        user_message: str,
        system_message: Optional[str] = None,
        conversation_history: Optional[List[Dict[str, Any]]] = None,
        task_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        import uuid

        effective_task_id = task_id or str(uuid.uuid4())

        messages: List[Dict[str, Any]] = conversation_history.copy() if conversation_history else []
        messages.append({"role": "user", "content": user_message})

        active_system_prompt = self._build_system_prompt(system_message)

        api_call_count = 0
        final_response: Optional[str] = None
        managed_state: Optional[Dict[str, Any]] = None
        completed = False

        try:
            async with self._managed() as managed:
                while api_call_count < self.max_iterations:
                    api_call_count += 1

                    api_messages = messages.copy()
                    if active_system_prompt:
                        api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages

                    chat_kwargs: Dict[str, Any] = {"messages": api_messages, "n": 1}
                    if self.max_tokens is not None:
                        chat_kwargs["max_tokens"] = self.max_tokens
                    if self.temperature is not None:
                        chat_kwargs["temperature"] = self.temperature

                    # Prefer OpenAI tool calling when supported by the backend:
                    # - Many providers normalize Hermes-style <tool_call> tags into tool_calls when `tools` is provided.
                    # - ManagedServer (atroposlib) does prompt->completion conversion and does not support `tools`.
                    # Only pass `tools` when we're calling an OpenAI-compatible chat endpoint directly.
                    tool_schemas = self.tools if self.tools else None
                    managed_cls = type(managed).__name__
                    if tool_schemas and managed_cls != "ManagedServer":
                        chat_kwargs["tools"] = tool_schemas

                    if os.getenv("HERMES_DEBUG_ATROPOS_REQUEST") == "1":
                        meta = {
                            "managed_type": managed_cls,
                            "model": getattr(getattr(managed, "config", None), "model_name", self.model),
                            "base_url": getattr(getattr(managed, "config", None), "base_url", None),
                            "kwargs": chat_kwargs,
                        }
                        # Avoid dumping megabytes of data accidentally.
                        # (Messages can be large; this is still "full" but bounded.)
                        print("\n=== HERMES_DEBUG_ATROPOS_REQUEST ===", flush=True)
                        print(json.dumps(meta, ensure_ascii=False, indent=2)[:200_000], flush=True)

                    response = await managed.chat_completion(**chat_kwargs)

                    if os.getenv("HERMES_DEBUG_ATROPOS_RESPONSE") == "1":
                        try:
                            dumped = response.model_dump()  # openai pydantic model
                        except Exception:
                            dumped = getattr(response, "__dict__", {"repr": repr(response)})
                        print("\n=== HERMES_DEBUG_ATROPOS_RESPONSE: ChatCompletion (raw) ===", flush=True)
                        print(json.dumps(dumped, ensure_ascii=False, indent=2), flush=True)

                    if hasattr(managed, "get_state"):
                        managed_state = managed.get_state()

                    msg = response.choices[0].message
                    assistant_content = (msg.content or "")
                    msg_reasoning = getattr(msg, "reasoning", None)

                    # Use tool_calls if the backend provides them (preferred).
                    structured_tool_calls = getattr(msg, "tool_calls", None)

                    # If the backend emits content="" but includes useful text in reasoning,
                    # use it for parsing *only if needed* (e.g. tool tags).
                    if assistant_content == "" and isinstance(msg_reasoning, str) and msg_reasoning:
                        if os.getenv("HERMES_DEBUG_ATROPOS_RESPONSE") == "1":
                            print("\n=== HERMES_DEBUG_ATROPOS_RESPONSE: message.reasoning present (content empty) ===", flush=True)
                            print(msg_reasoning, flush=True)

                    assistant_msg: Dict[str, Any] = {"role": "assistant", "content": assistant_content}
                    if structured_tool_calls:
                        # Preserve tool_calls so the next request is consistent with OpenAI protocol.
                        try:
                            assistant_msg["tool_calls"] = [
                                {
                                    "id": tc.id,
                                    "type": tc.type,
                                    "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                                }
                                for tc in structured_tool_calls
                            ]
                        except Exception:
                            # Best-effort; keep conversation moving.
                            pass
                    messages.append(assistant_msg)

                    # Mode A: OpenAI tool calling (preferred when supported)
                    if structured_tool_calls:
                        for tc in structured_tool_calls:
                            tool_start = time.time()
                            try:
                                tool_args = json.loads(tc.function.arguments or "{}")
                            except Exception:
                                tool_args = {}
                            tool_result = handle_function_call(tc.function.name, tool_args, effective_task_id)
                            tool_duration = time.time() - tool_start

                            # Keep the raw tool result as tool content (OpenAI protocol expects role=tool).
                            messages.append(
                                {
                                    "role": "tool",
                                    "tool_call_id": tc.id,
                                    "content": tool_result,
                                }
                            )

                            if self.tool_delay and self.tool_delay > 0:
                                await asyncio.sleep(self.tool_delay)

                        # Continue loop after tool execution.
                        continue

                    # Mode B: Hermes XML tool tags in assistant text (fallback).
                    parse_source = assistant_content or (msg_reasoning or "")
                    tool_calls, parse_errors = self._parse_tool_calls(parse_source)

                    if parse_errors and not tool_calls:
                        # Ask the model to retry with valid tool JSON.
                        err_text = "; ".join(parse_errors[:3])
                        messages.append(
                            {
                                "role": "user",
                                "content": (
                                    f"<tool_response>{json.dumps({'error': err_text}, ensure_ascii=False)}</tool_response>\n"
                                    "The previous <tool_call> blocks were invalid. Please output valid JSON inside <tool_call>."
                                ),
                            }
                        )
                        continue

                    if not tool_calls:
                        # No tool calls: treat as final answer.
                        final_response = (assistant_content or "").strip()
                        completed = True
                        break

                    tool_responses: List[str] = []
                    for tool_name, tool_args in tool_calls:
                        tool_start = time.time()
                        tool_result = handle_function_call(tool_name, tool_args, effective_task_id)
                        tool_duration = time.time() - tool_start

                        try:
                            parsed = json.loads(tool_result)
                            payload: Any = parsed
                        except Exception:
                            payload = tool_result

                        tool_payload = {
                            "name": tool_name,
                            "duration_s": round(tool_duration, 3),
                            "result": payload,
                        }
                        tool_responses.append(
                            f"<tool_response>{json.dumps(tool_payload, ensure_ascii=False)}</tool_response>"
                        )

                        if self.tool_delay and self.tool_delay > 0:
                            await asyncio.sleep(self.tool_delay)

                    messages.append({"role": "user", "content": "\n".join(tool_responses)})

                if final_response is None:
                    final_response = "I've reached the maximum number of iterations."

        finally:
            try:
                cleanup_vm(effective_task_id)
            except Exception:
                pass

        # Save trajectory using Hermes formatting (optional).
        self._save_trajectory(messages, user_message, completed=completed)

        return {
            "final_response": final_response,
            "messages": messages,
            "api_calls": api_call_count,
            "completed": completed,
            "managed_state": managed_state,
            "system_prompt": active_system_prompt,
            "task_id": effective_task_id,
        }

    def run_conversation(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """
        Sync wrapper for convenience.

        If called from within a running event loop (e.g. prompt_toolkit), this
        runs the async conversation in a dedicated thread to avoid nested loops.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(self.run_conversation_async(*args, **kwargs))

        import queue
        import threading

        out: "queue.Queue[object]" = queue.Queue(maxsize=1)

        def runner() -> None:
            try:
                out.put(asyncio.run(self.run_conversation_async(*args, **kwargs)))
            except BaseException as exc:  # noqa: BLE001
                out.put(exc)

        thread = threading.Thread(target=runner, daemon=True)
        thread.start()

        result = out.get()
        if isinstance(result, BaseException):
            raise result
        return result  # type: ignore[return-value]
1228 batch_runner_threaded.py Normal file
File diff suppressed because it is too large
224 docs/MODAL_BACKEND.md Normal file
@@ -0,0 +1,224 @@
# Modal Backend

Hermes Agent uses [Modal](https://modal.com) for scalable, isolated cloud execution environments. There are two Modal integrations:

1. **Terminal Tool** (`tools/terminal_tool.py`) - For CLI/agent command execution
2. **Atropos Backend** (`atropos/backends/modal_backend.py`) - For batch RL training workloads

---

## Terminal Tool (CLI/Agent)

The terminal tool provides a simple interface for executing commands in Modal sandboxes.

### Configuration

Set environment variables:

```bash
export TERMINAL_ENV=modal
export TERMINAL_MODAL_IMAGE=python:3.11
export TERMINAL_MODAL_APP_NAME=hermes-sandbox
```

Or use a YAML config file (`modal_profiles.yaml`):

```yaml
profiles:
  default:
    image: python:3.11
    cpu: 1.0
    memory: 2048
    min_pool: 1
    max_pool: 5
    idle_timeout: 120

  gpu:
    image: pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
    gpu: T4
    memory: 16384
    min_pool: 0
    max_pool: 2
```

### Features

| Feature | Description |
|---------|-------------|
| **Sandbox Pool** | Pre-warmed sandboxes for low latency |
| **Auto-scaling** | Grows/shrinks pool based on demand |
| **Idle Timeout** | Sandboxes auto-terminate when unused |
| **Profile Selection** | Different configs for different workloads |
| **Credential Injection** | `modal.Secret` integration |

### Usage

```python
from tools.terminal_tool import terminal_tool

# Simple command
output = terminal_tool("echo hello", task_id="my-task")

# With profile selection
output = terminal_tool("python train.py", task_id="training", profile="gpu")

# Cleanup when done
from tools.terminal_tool import cleanup_vm
cleanup_vm("my-task")
```

### Architecture

```
_ModalPoolManager (singleton)
├── "default" pool → [sandbox-0, sandbox-1, ...]
└── "gpu" pool → [sandbox-0, ...]

Each pool:
- Maintains min_pool warm sandboxes
- Scales up to max_pool on demand
- Background thread scales down idle sandboxes
```

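The pool lifecycle above (warm minimum, on-demand growth, idle scale-down) can be sketched in a few lines. This is an illustrative toy, not the actual `_ModalPoolManager` implementation; the class and field names below are hypothetical.

```python
import threading
import time

class SandboxPool:
    """Toy pool: each warm sandbox is modeled by the time it went idle."""

    def __init__(self, min_pool=1, max_pool=5, idle_timeout=60.0):
        self.min_pool = min_pool
        self.max_pool = max_pool
        self.idle_timeout = idle_timeout
        self.lock = threading.Lock()
        self.idle = [time.monotonic() for _ in range(min_pool)]  # warm sandboxes
        self.in_use = 0

    def acquire(self):
        with self.lock:
            if self.idle:
                self.idle.pop()                       # reuse a warm sandbox
            elif self.in_use >= self.max_pool:
                raise RuntimeError("pool exhausted")  # a real pool would queue
            self.in_use += 1

    def release(self):
        with self.lock:
            self.in_use -= 1
            self.idle.append(time.monotonic())        # becomes warm + idle

    def scale_down(self):
        """One pass of the background thread: drop sandboxes idle past timeout."""
        now = time.monotonic()
        with self.lock:
            fresh = [t for t in self.idle if now - t < self.idle_timeout]
            # Keep at least min_pool warm sandboxes even if they are stale.
            missing = self.min_pool - len(fresh)
            if missing > 0:
                stale = [t for t in self.idle if now - t >= self.idle_timeout]
                fresh.extend(stale[:missing])
            self.idle = fresh
```

In the real backend a daemon thread would call something like `scale_down()` periodically; here it is a single pass so the behavior is easy to test.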
---

## Atropos Backend (RL Training)

The Atropos backend is designed for high-throughput batch execution during reinforcement learning training.

### Key Concept: Slot-based Multiplexing

Instead of one sandbox per trajectory, multiple trajectories share sandboxes via **slots**:

```
Sandbox (1 container)
├── Slot 0 → Trajectory A (workspace: /data/slot_0)
├── Slot 1 → Trajectory B (workspace: /data/slot_1)
└── Slot 2 → Trajectory C (workspace: /data/slot_2)
```

**Benefits**:
- Fewer containers = lower cost
- Shared warm-up time
- Better GPU utilization

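The slot bookkeeping this implies can be sketched as a free-list of `(sandbox_id, slot_id)` pairs; `SlotAllocator` below is a hypothetical illustration, not the real `_ModalSandboxPool` internals.

```python
from collections import deque

class SlotAllocator:
    """Illustrative slot multiplexing: trajectories share sandboxes via slots."""

    def __init__(self, slots_per_sandbox=3):
        self.slots_per_sandbox = slots_per_sandbox
        self.free = deque()    # available (sandbox_id, slot_id) pairs
        self.assigned = {}     # trajectory_id -> (sandbox_id, slot_id)
        self.num_sandboxes = 0

    def acquire(self, trajectory_id):
        if not self.free:
            # All slots busy: "launch" a new sandbox and register its slots.
            sid = self.num_sandboxes
            self.num_sandboxes += 1
            for slot in range(self.slots_per_sandbox):
                self.free.append((sid, slot))
        self.assigned[trajectory_id] = self.free.popleft()
        return self.assigned[trajectory_id]

    def workspace(self, trajectory_id):
        _, slot = self.assigned[trajectory_id]
        return f"/data/slot_{slot}"

    def release(self, trajectory_id):
        self.free.append(self.assigned.pop(trajectory_id))
```

With 3 slots per sandbox, four trajectories only need two containers, which is the cost saving the diagram above describes.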
### Configuration

```python
from atropos.backends.modal_backend import ModalSandboxConfig, ModalToolBackend

config = ModalSandboxConfig(
    name="default",
    image="python:3.11",
    cpu=1.0,
    memory=2048,
    slots_per_sandbox=10,  # 10 trajectories per container
    min_sandboxes=1,
    max_sandboxes=5,
)

backend = ModalToolBackend(config.with_app_name("my-training"))
```

### Multi-Profile Support

Different trajectory types can request different resources:

```python
backend = ModalToolBackend.with_profiles(
    app_name="rl-training",
    profiles={
        "default": ModalSandboxConfig(
            name="default",
            cpu=1.0,
            memory=2048,
        ),
        "pytorch-gpu": ModalSandboxConfig(
            name="pytorch-gpu",
            image="pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime",
            gpu="T4",
            memory=16384,
        ),
    }
)

# CPU task
slot1 = await backend.acquire("traj-1", profile="default")

# GPU task
slot2 = await backend.acquire("traj-2", profile="pytorch-gpu")
```

### Batched Execution

The key optimization - execute many commands in parallel:

```python
# Acquire slots for multiple trajectories
slots = [await backend.acquire(f"traj-{i}") for i in range(50)]

# Execute batch across all slots in parallel
results = await backend.execute_batch([
    (slot, "bash", {"command": "python step.py"})
    for slot in slots
])

# Release slots
for slot in slots:
    await backend.release(slot)
```

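Under the hood, a batch like this amounts to fanning the per-slot calls out concurrently. A minimal sketch of that pattern with `asyncio.gather`, where `run_in_slot` is a stand-in for dispatching a tool call into a Modal sandbox (not the backend's actual helper):

```python
import asyncio

async def run_in_slot(slot, tool, args):
    # Stand-in for executing a tool call inside a sandbox slot.
    await asyncio.sleep(0)  # simulate I/O
    return {"slot": slot, "tool": tool, "ok": True}

async def execute_batch(calls):
    """Run all (slot, tool, args) calls concurrently; results keep input order."""
    return await asyncio.gather(
        *(run_in_slot(slot, tool, args) for slot, tool, args in calls)
    )

results = asyncio.run(execute_batch(
    [(i, "bash", {"command": "echo hi"}) for i in range(3)]
))
```

`asyncio.gather` preserves input order, which is why the backend can zip results back to the trajectories that requested them.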
### Architecture

```
ModalToolBackend
└── _ModalMultiProfileManager
    ├── "default" → _ModalSandboxPool
    │   ├── Sandbox 0 (slots 0-9)
    │   └── Sandbox 1 (slots 0-9)
    │
    └── "pytorch-gpu" → _ModalSandboxPool
        └── Sandbox 0 (slots 0-9)
```

---

## Credentials

Inject secrets securely using Modal's secret management:

```bash
# Create secret in Modal dashboard or CLI
modal secret create my-api-key API_KEY=sk-xxx
```

```python
# Reference in config
config = ModalSandboxConfig(
    secrets=["my-api-key"],   # Modal secret names
    env_vars={"DEBUG": "1"},  # Additional env vars
)
```

## Troubleshooting

### "Modal package not installed"
```bash
pip install modal
modal token new  # Authenticate
```

### "Sandbox creation failed"
- Check Modal dashboard for quota limits
- Verify image exists and is accessible
- Check secret names are correct

### Shutdown errors
These are harmless warnings during Python interpreter shutdown:
```
[Modal] Error terminating ...: cannot schedule new futures after interpreter shutdown
```

The sandboxes will auto-terminate via Modal's idle_timeout anyway.
@@ -57,6 +57,12 @@ class AgentResult:
    # Tool errors encountered during the loop
    tool_errors: List[ToolError] = field(default_factory=list)

    # Tool-call metrics (for reward shaping + debugging)
    tool_calls_attempted: int = 0      # Valid tool name + attempted dispatch
    tool_calls_schema_valid: int = 0   # Arguments matched schema (no coercion)
    tool_calls_executed_ok: int = 0    # Tool ran and returned no error
    tool_calls_exec_error: int = 0     # Unknown tool / exception / tool returned error


def _extract_reasoning_from_message(message) -> Optional[str]:
    """
@@ -119,6 +125,8 @@ class HermesAgentLoop:
        task_id: Optional[str] = None,
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        tool_handler=None,
        max_context_tokens: Optional[int] = None,
    ):
        """
        Initialize the agent loop.
@@ -132,6 +140,13 @@ class HermesAgentLoop:
            task_id: Unique ID for terminal/browser session isolation
            temperature: Sampling temperature for generation
            max_tokens: Max tokens per generation (None for server default)
            tool_handler: Optional async callable(tool_name, args, task_id) -> str.
                When provided, used INSTEAD of handle_function_call() for
                tool dispatch. This allows sandbox backends (Modal, Nomad)
                to route tool calls through their slot-based execution.
            max_context_tokens: Maximum prompt tokens before truncation.
                If None, no truncation is applied.
                Recommended: set to max_model_len - max_tokens - 512 (safety margin).
        """
        self.server = server
        self.tool_schemas = tool_schemas
@@ -140,6 +155,139 @@ class HermesAgentLoop:
        self.task_id = task_id or str(uuid.uuid4())
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.tool_handler = tool_handler
        self.max_context_tokens = max_context_tokens

    def _truncate_context(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Truncate conversation history to fit within max_context_tokens.

        Strategy:
        - Keep system message (index 0) and initial user message (index 1) always
        - Keep last 6 messages (recent context) always
        - For everything in between, progressively truncate tool result content
        - If still too long, drop oldest middle messages entirely

        Uses rough char/4 token estimate (fast, no tokenizer needed).
        """
        if self.max_context_tokens is None:
            return messages

        def estimate_tokens(msgs):
            total = 0
            for m in msgs:
                content = m.get("content", "") or ""
                total += len(content) // 4 + 10  # ~4 chars per token + overhead
                if "tool_calls" in m:
                    total += 50 * len(m["tool_calls"])  # tool call overhead
            return total

        est = estimate_tokens(messages)
        if est <= self.max_context_tokens:
            return messages

        # Phase 1: Truncate tool result content in middle messages
        # Keep first 2 and last 6 messages untouched
        protect_head = 2
        protect_tail = max(0, min(6, len(messages) - protect_head))
        middle_start = protect_head
        middle_end = len(messages) - protect_tail

        if middle_start < middle_end:
            # Truncate tool results from oldest first
            for i in range(middle_start, middle_end):
                if messages[i].get("role") == "tool":
                    content = messages[i].get("content", "") or ""
                    if len(content) > 200:
                        messages[i] = dict(messages[i])  # copy
                        messages[i]["content"] = content[:100] + "\n...[truncated]...\n" + content[-50:]

        est = estimate_tokens(messages)
        if est <= self.max_context_tokens:
            logger.debug("Context truncated (phase 1: tool results): %d tokens", est)
            return messages

        # Phase 2: Drop oldest middle messages entirely
        while middle_start < middle_end and estimate_tokens(messages) > self.max_context_tokens:
            # Remove the oldest middle message
            # But keep assistant+tool pairs together
            msg = messages[middle_start]
            messages.pop(middle_start)
            middle_end -= 1
            # If we removed an assistant with tool_calls, also remove matching tool responses
            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                tool_ids = {tc.get("id") or tc.get("tool_call_id", "") for tc in msg.get("tool_calls", []) if isinstance(tc, dict)}
                # Remove tool responses for those IDs
                i = middle_start
                while i < middle_end:
                    if messages[i].get("role") == "tool" and messages[i].get("tool_call_id", "") in tool_ids:
                        messages.pop(i)
                        middle_end -= 1
                    else:
                        i += 1

        est = estimate_tokens(messages)
        logger.info("Context truncated (phase 2: dropped messages): %d estimated tokens, %d messages remaining", est, len(messages))
        return messages
    def _normalize_tool_args(self, tool_name: str, tool_args_raw: str) -> Tuple[Dict[str, Any], bool]:
        """Normalize tool arguments into a dict.

        Returns:
            (args_dict, schema_valid)

        schema_valid is True only when the arguments decode directly into a dict
        (i.e. no double-decoding and no coercion/wrapping was needed).

        This lets us keep the environment robust (never crash due to args format)
        while still scoring down malformed tool-call argument formats.
        """
        try:
            decoded = json.loads(tool_args_raw)
        except json.JSONDecodeError:
            # Not valid JSON at all. Be robust: treat it as a plain string.
            # (Some parsers/providers may pass through non-JSON strings.)
            if tool_name == "terminal":
                return {"command": tool_args_raw}, False
            return {"input": tool_args_raw}, False

        # Canonical case: decoded is already a dict
        if isinstance(decoded, dict):
            # For terminal tool, require a command key
            if tool_name == "terminal":
                cmd = decoded.get("command")
                if isinstance(cmd, str) and cmd.strip():
                    return decoded, True
                # Common alternate key
                if isinstance(decoded.get("input"), str):
                    return {"command": decoded.get("input")}, False
                return decoded, False
            return decoded, True

        # Common drift case: decoded is a JSON string of an object
        if isinstance(decoded, str):
            s = decoded.strip()
            if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
                try:
                    decoded2 = json.loads(s)
                except json.JSONDecodeError:
                    decoded2 = None
                if isinstance(decoded2, dict):
                    # Terminal tool: ensure command
                    if tool_name == "terminal" and isinstance(decoded2.get("command"), str):
                        return decoded2, False
                    return decoded2, False

            # Plain string (not JSON) -- coerce to expected shape
            if tool_name == "terminal":
                return {"command": decoded}, False
            return {"input": decoded}, False

        # Other JSON types (list/number/etc.) -- wrap
        if tool_name == "terminal":
            return {"command": str(decoded)}, False
        return {"input": decoded}, False
    async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
        """
@@ -147,7 +295,12 @@ class HermesAgentLoop:

        Args:
            messages: Initial conversation messages (system + user).
                Modified in-place as the conversation progresses.
                This list is treated as the FULL trajectory and is
                appended to as the conversation progresses.

                Prompt truncation (to avoid context overflow) is applied
                on a copy of this list per turn, so we do not lose
                earlier messages for reward computation/debugging.

        Returns:
            AgentResult with full conversation history, managed state, and metadata
@@ -155,10 +308,21 @@ class HermesAgentLoop:
        reasoning_per_turn = []
        tool_errors: List[ToolError] = []

        # Metrics to separate "attempted tool use" from "schema-valid tool use"
        tool_calls_attempted = 0
        tool_calls_schema_valid = 0
        tool_calls_executed_ok = 0
        tool_calls_exec_error = 0

        for turn in range(self.max_turns):
            # Truncate context if approaching limit.
            # IMPORTANT: do this on a copy so we keep the full trajectory in `messages`
            # for reward computation + debugging, while only trimming the prompt view.
            prompt_messages = self._truncate_context(list(messages))

            # Build the chat_completion kwargs
            chat_kwargs = {
                "messages": messages,
                "messages": prompt_messages,
                "n": 1,
                "temperature": self.temperature,
            }
@@ -183,6 +347,10 @@ class HermesAgentLoop:
                    finished_naturally=False,
                    reasoning_per_turn=reasoning_per_turn,
                    tool_errors=tool_errors,
                    tool_calls_attempted=tool_calls_attempted,
                    tool_calls_schema_valid=tool_calls_schema_valid,
                    tool_calls_executed_ok=tool_calls_executed_ok,
                    tool_calls_exec_error=tool_calls_exec_error,
                )

            if not response or not response.choices:
@@ -194,6 +362,10 @@ class HermesAgentLoop:
                    finished_naturally=False,
                    reasoning_per_turn=reasoning_per_turn,
                    tool_errors=tool_errors,
                    tool_calls_attempted=tool_calls_attempted,
                    tool_calls_schema_valid=tool_calls_schema_valid,
                    tool_calls_executed_ok=tool_calls_executed_ok,
                    tool_calls_exec_error=tool_calls_exec_error,
                )

            assistant_msg = response.choices[0].message
@@ -252,35 +424,45 @@ class HermesAgentLoop:
                        "Model called unknown tool '%s' on turn %d",
                        tool_name, turn + 1,
                    )
                    tool_calls_exec_error += 1
                else:
                    # Parse arguments and dispatch
                    try:
                        args = json.loads(tool_args_raw)
                    except json.JSONDecodeError:
                        args = {}
                        logger.warning(
                            "Invalid JSON in tool call arguments for '%s': %s",
                            tool_name, tool_args_raw[:200],
                        )
                    tool_calls_attempted += 1

                    # Normalize args into a dict so we never crash due to formatting.
                    # Track schema_valid separately so reward shaping can penalize
                    # non-canonical formats (e.g. stringified JSON).
                    args, schema_valid = self._normalize_tool_args(tool_name, tool_args_raw)
                    if schema_valid:
                        tool_calls_schema_valid += 1

                    try:
                        if tool_name == "terminal":
                            import os
                            backend = os.getenv("TERMINAL_ENV", "local")
                            cmd_preview = args.get("command", "")[:80]
                            if self.tool_handler:
                                backend = "sandbox"
                            cmd_preview = str(args.get("command", ""))[:80]
                            print(f" 🖥️ [{backend}] $ {cmd_preview}")

                        # Run tool calls in a thread pool so backends that use
                        # asyncio.run() internally (modal, docker) get a clean
                        # event loop instead of deadlocking inside Atropos's loop.
                        loop = asyncio.get_event_loop()
                        tool_result = await loop.run_in_executor(
                            _tool_executor,
                            lambda: handle_function_call(
                                tool_name, args, task_id=self.task_id
                            ),
                        )
                        if self.tool_handler:
                            # Use custom tool handler (sandbox backend routing)
                            tool_result = await self.tool_handler(
                                tool_name, args, self.task_id
                            )
                        else:
                            # Default: run via hermes-agent's handle_function_call
                            # in a thread pool so backends that use asyncio.run()
                            # internally (modal, docker) get a clean event loop
                            # instead of deadlocking inside Atropos's loop.
                            loop = asyncio.get_event_loop()
                            tool_result = await loop.run_in_executor(
                                _tool_executor,
                                lambda: handle_function_call(
                                    tool_name, args, task_id=self.task_id
                                ),
                            )
                    except Exception as e:
                        tool_calls_exec_error += 1
                        tool_result = json.dumps(
                            {"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
                        )
@@ -294,22 +476,34 @@ class HermesAgentLoop:
                            "Tool '%s' execution failed on turn %d: %s",
                            tool_name, turn + 1, e,
                        )
                    else:
                        # Count tool result errors (if tool returns structured JSON error)
                        tool_err = False
                        try:
                            result_data = json.loads(tool_result)
                            if isinstance(result_data, dict):
                                err = result_data.get("error")
                                if err:
                                    tool_err = True

                        # Also check if the tool returned an error in its JSON result
                        try:
                            result_data = json.loads(tool_result)
                            if isinstance(result_data, dict):
                                err = result_data.get("error")
                                exit_code = result_data.get("exit_code")
                                if err and exit_code and exit_code < 0:
                                    tool_errors.append(ToolError(
                                        turn=turn + 1, tool_name=tool_name,
                                        arguments=tool_args_raw[:200],
                                        error=str(err),
                                        tool_result=tool_result[:500],
                                    ))
                        except (json.JSONDecodeError, TypeError):
                            pass
                                # Keep existing behavior: treat negative exit_code as tool error
                                exit_code = result_data.get("exit_code")
                                if exit_code is not None and isinstance(exit_code, int) and exit_code < 0:
                                    tool_err = True
                                    tool_errors.append(ToolError(
                                        turn=turn + 1, tool_name=tool_name,
                                        arguments=tool_args_raw[:200],
                                        error=str(err) if err else "nonzero exit_code",
                                        tool_result=tool_result[:500],
                                    ))
                        except (json.JSONDecodeError, TypeError):
                            # Non-JSON tool output -- assume ok
                            pass

                        if tool_err:
                            tool_calls_exec_error += 1
                        else:
                            tool_calls_executed_ok += 1

                    # Add tool response to conversation
                    messages.append(
@@ -347,6 +541,10 @@ class HermesAgentLoop:
                finished_naturally=True,
                reasoning_per_turn=reasoning_per_turn,
                tool_errors=tool_errors,
                tool_calls_attempted=tool_calls_attempted,
                tool_calls_schema_valid=tool_calls_schema_valid,
                tool_calls_executed_ok=tool_calls_executed_ok,
                tool_calls_exec_error=tool_calls_exec_error,
            )

        # Hit max turns without the model stopping
@@ -358,6 +556,10 @@ class HermesAgentLoop:
                finished_naturally=False,
                reasoning_per_turn=reasoning_per_turn,
                tool_errors=tool_errors,
                tool_calls_attempted=tool_calls_attempted,
                tool_calls_schema_valid=tool_calls_schema_valid,
                tool_calls_executed_ok=tool_calls_executed_ok,
                tool_calls_exec_error=tool_calls_exec_error,
            )

    def _get_managed_state(self) -> Optional[Dict[str, Any]]:
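The exec-error bookkeeping in the loop above hinges on one rule: a JSON tool result with a negative `exit_code` counts as an execution error, and non-JSON output is treated as ok. A standalone sketch of that classification (the `ToolError` record is simplified away here; only the decision logic is shown):

```python
import json


def classify_tool_result(tool_result: str) -> str:
    """Mirror the env's check: a JSON dict with a negative exit_code
    marks the tool call as an execution error; anything else is ok."""
    try:
        data = json.loads(tool_result)
    except (json.JSONDecodeError, TypeError):
        return "ok"  # Non-JSON tool output -- assume ok
    if not isinstance(data, dict):
        return "ok"
    exit_code = data.get("exit_code")
    if isinstance(exit_code, int) and exit_code < 0:
        return "exec_error"
    return "ok"


print(classify_tool_result('{"output": "hi", "exit_code": 0}'))      # ok
print(classify_tool_result('{"error": "timeout", "exit_code": -9}'))  # exec_error
print(classify_tool_result('plain text output'))                      # ok
```

This is why the loop increments `tool_calls_exec_error` only when `tool_err` was set: a tool that ran and returned exit code 0 still counts toward `tool_calls_executed_ok` even if its output happens not to be JSON.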
350  environments/gsm8k_agent_env.py  Normal file
@@ -0,0 +1,350 @@
"""
|
||||
GSM8kAgentEnv -- Math Reasoning with Tool Use (Python REPL)
|
||||
|
||||
An agentic RL environment where models solve GSM8k math problems using
|
||||
a Python interpreter tool. Uses proper OpenAI-spec tool calling via
|
||||
HermesAgentBaseEnv (not ICL).
|
||||
|
||||
The model:
|
||||
1. Receives a math problem
|
||||
2. Can call the `terminal` tool to run Python code (`python3 -c "..."`)
|
||||
3. Provides a final answer in \\boxed{} format
|
||||
4. Gets reward: 1.0 if correct, 0.0 if wrong
|
||||
|
||||
Usage:
|
||||
# Phase 1 (OpenRouter, no training):
|
||||
python environments/gsm8k_agent_env.py process \\
|
||||
--env.data_path_to_save_groups gsm8k_agent_output.jsonl
|
||||
|
||||
# Phase 2 (VLLM + Tinker training):
|
||||
run-api
|
||||
python launch_training.py --config configs/gsm8k_agent.yaml
|
||||
python environments/gsm8k_agent_env.py serve
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Math verification helpers
|
||||
# =============================================================================
|
||||
|
||||
def _verify_math_answer(model_response: str, gold_answer: str) -> bool:
|
||||
"""
|
||||
Verify if the model's response contains the correct answer.
|
||||
Uses math_verify for robust LaTeX comparison, falls back to string matching.
|
||||
"""
|
||||
try:
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
||||
gold_parsed = parse(
|
||||
f"\\boxed{{{gold_answer}}}",
|
||||
extraction_mode="first_match",
|
||||
extraction_config=[LatexExtractionConfig()],
|
||||
)
|
||||
|
||||
# Strip <think> blocks if present
|
||||
answer_text = model_response
|
||||
if "</think>" in answer_text:
|
||||
answer_text = answer_text.split("</think>")[-1]
|
||||
|
||||
answer_parsed = parse(
|
||||
answer_text,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
|
||||
return bool(verify(answer_parsed, gold_parsed))
|
||||
|
||||
except ImportError:
|
||||
# Fallback: simple string matching for \\boxed{answer}
|
||||
import re
|
||||
pattern = r'\\boxed\{([^}]+)\}'
|
||||
matches = re.findall(pattern, model_response)
|
||||
if matches:
|
||||
model_answer = matches[-1].strip().replace(",", "")
|
||||
gold_clean = gold_answer.strip().replace(",", "")
|
||||
return model_answer == gold_clean
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Environment Config
|
||||
# =============================================================================
|
||||
|
||||
class GSM8kAgentEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config with defaults for GSM8k agent environment."""
|
||||
pass
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Environment
|
||||
# =============================================================================
|
||||
|
||||
class GSM8kAgentEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
GSM8k math environment with Python REPL tool calling.
|
||||
|
||||
Models solve grade-school math problems by reasoning step by step
|
||||
and using Python (via the terminal tool) for calculations.
|
||||
|
||||
Exercises the full agentic RL training loop:
|
||||
- Model receives math problem
|
||||
- Makes tool calls to compute (python3 -c "...")
|
||||
- Provides final answer in \\boxed{}
|
||||
- Reward: binary (1.0 correct, 0.0 wrong)
|
||||
"""
|
||||
|
||||
name = "gsm8k-agent"
|
||||
env_config_cls = GSM8kAgentEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[GSM8kAgentEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default config using terminal tool.
|
||||
|
||||
Reads from environment variables (set in .env):
|
||||
ATROPOS_SERVER_BASE_URL - Inference server URL
|
||||
ATROPOS_SERVER_MODEL - Model name on the server
|
||||
ATROPOS_TOKENIZER_NAME - HuggingFace tokenizer name
|
||||
ATROPOS_SERVER_API_KEY - API key for the server
|
||||
"""
|
||||
# Resolve inference server settings from env
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "https://openrouter.ai/api/v1"
|
||||
)
|
||||
if not base_url.rstrip("/").endswith("/v1"):
|
||||
base_url = base_url.rstrip("/") + "/v1"
|
||||
|
||||
model = (
|
||||
os.getenv("ATROPOS_SERVER_MODEL")
|
||||
or os.getenv("LLM_MODEL")
|
||||
or "Hermes-4.3-36B"
|
||||
)
|
||||
|
||||
api_key = (
|
||||
os.getenv("ATROPOS_SERVER_API_KEY")
|
||||
or os.getenv("NOUS_API_KEY")
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or ""
|
||||
)
|
||||
|
||||
tokenizer = (
|
||||
os.getenv("ATROPOS_TOKENIZER_NAME")
|
||||
or os.getenv("ATROPOS_TOKENIZER")
|
||||
or "NousResearch/Hermes-4.3-36B"
|
||||
)
|
||||
|
||||
env_config = GSM8kAgentEnvConfig(
|
||||
# Terminal + file toolsets (same as terminal_test_env.py)
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
# Agent settings
|
||||
max_agent_turns=5, # Math problems don't need many turns
|
||||
max_token_length=2048, # Room for reasoning + code
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a helpful math assistant. You have access to a terminal "
|
||||
"where you can run Python code to help solve problems.\n\n"
|
||||
"When you need to calculate something, use the terminal tool with "
|
||||
"a command like: python3 -c \"print(2 + 2)\"\n\n"
|
||||
"When you have the final answer, write it inside \\boxed{} like: \\boxed{42}\n\n"
|
||||
"Work step by step. Use Python to verify your reasoning."
|
||||
),
|
||||
# Terminal backend (local for testing, modal for production)
|
||||
terminal_backend=os.getenv("TERMINAL_ENV", "local"),
|
||||
# Parser -- hermes format for Hermes models
|
||||
tool_call_parser="hermes",
|
||||
# Atropos settings
|
||||
group_size=4,
|
||||
tokenizer_name=tokenizer,
|
||||
steps_per_eval=5,
|
||||
total_steps=10,
|
||||
use_wandb=bool(os.getenv("WANDB_API_KEY")),
|
||||
wandb_name="gsm8k-agent",
|
||||
ensure_scores_are_not_same=False,
|
||||
# No external dataset (we load GSM8k ourselves)
|
||||
dataset_name=None,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url=base_url,
|
||||
model_name=model,
|
||||
server_type="openai",
|
||||
api_key=api_key,
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
"""Load GSM8k dataset."""
|
||||
from datasets import load_dataset
|
||||
|
||||
self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
|
||||
test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
|
||||
self.test = [
|
||||
{
|
||||
"question": item["question"],
|
||||
"gold_answer": item["answer"].split("#")[-1].strip().replace(",", ""),
|
||||
}
|
||||
for item in test_data
|
||||
]
|
||||
self.iter = 0
|
||||
self.reward_buffer: List[float] = []
|
||||
self.tool_use_buffer: List[int] = []
|
||||
print(f"[GSM8kAgentEnv] Loaded {len(self.train)} train, {len(self.test)} test examples")
|
||||
|
||||
async def get_next_item(self) -> Dict[str, str]:
|
||||
"""Cycle through training problems."""
|
||||
item = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
return {
|
||||
"question": item["question"],
|
||||
"gold_answer": item["answer"].split("#")[-1].strip().replace(",", ""),
|
||||
}
|
||||
|
||||
def format_prompt(self, item: Dict[str, str]) -> str:
|
||||
"""Format the math problem as a user message."""
|
||||
return item["question"]
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Dict[str, str], result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Score: verify the model's \\boxed{} answer against the gold answer.
|
||||
|
||||
The agent has full access to terminal via ctx, but for GSM8k we just
|
||||
check the final answer from the conversation.
|
||||
"""
|
||||
# Get the last assistant message content
|
||||
final_text = ""
|
||||
for msg in reversed(result.messages):
|
||||
if msg.get("role") == "assistant" and msg.get("content"):
|
||||
final_text = msg["content"]
|
||||
break
|
||||
|
||||
correct = _verify_math_answer(final_text, item["gold_answer"])
|
||||
reward = 1.0 if correct else 0.0
|
||||
|
||||
self.reward_buffer.append(reward)
|
||||
# Count tool calls in this trajectory
|
||||
tool_call_count = sum(
|
||||
len(msg.get("tool_calls", []))
|
||||
for msg in result.messages
|
||||
if msg.get("role") == "assistant"
|
||||
)
|
||||
self.tool_use_buffer.append(tool_call_count)
|
||||
|
||||
return reward
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""Evaluate on a subset of the test set (greedy, no tools for speed)."""
|
||||
start_time = time.time()
|
||||
correct = 0
|
||||
total = 0
|
||||
samples = []
|
||||
|
||||
eval_subset = self.test[:30] # Small subset for quick eval
|
||||
|
||||
for item in eval_subset:
|
||||
try:
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": self.config.system_prompt or ""},
|
||||
{"role": "user", "content": item["question"]},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
response = completion.choices[0].message.content or ""
|
||||
is_correct = _verify_math_answer(response, item["gold_answer"])
|
||||
|
||||
if is_correct:
|
||||
correct += 1
|
||||
total += 1
|
||||
|
||||
samples.append({
|
||||
"question": item["question"],
|
||||
"gold_answer": item["gold_answer"],
|
||||
"response": response[:500],
|
||||
"correct": is_correct,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Eval failed: %s", e)
|
||||
total += 1
|
||||
|
||||
percent_correct = correct / total if total > 0 else 0
|
||||
end_time = time.time()
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics={"eval/percent_correct": percent_correct, "eval/total": total},
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log training metrics."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.reward_buffer:
|
||||
wandb_metrics["train/percent_correct"] = sum(self.reward_buffer) / len(self.reward_buffer)
|
||||
wandb_metrics["train/total_rollouts"] = len(self.reward_buffer)
|
||||
self.reward_buffer = []
|
||||
|
||||
if self.tool_use_buffer:
|
||||
wandb_metrics["train/avg_tool_calls"] = sum(self.tool_use_buffer) / len(self.tool_use_buffer)
|
||||
wandb_metrics["train/tool_use_rate"] = sum(1 for t in self.tool_use_buffer if t > 0) / len(self.tool_use_buffer)
|
||||
self.tool_use_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GSM8kAgentEnv.cli()
|
||||
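When `math_verify` is not installed, `_verify_math_answer` above falls back to a plain regex comparison of the last `\boxed{...}` in the response. That fallback reduces to a few lines and is easy to sanity-check standalone (the function name here is illustrative, not from the file):

```python
import re


def boxed_answer_matches(model_response: str, gold_answer: str) -> bool:
    """Fallback check: compare the last \\boxed{...} in the response
    against the gold answer, ignoring commas and surrounding whitespace."""
    matches = re.findall(r'\\boxed\{([^}]+)\}', model_response)
    if not matches:
        return False
    model_answer = matches[-1].strip().replace(",", "")
    return model_answer == gold_answer.strip().replace(",", "")


print(boxed_answer_matches(r"The total is \boxed{1,200}", "1200"))  # True
print(boxed_answer_matches(r"\boxed{41}", "42"))                    # False
```

Taking the last match matters because a chain-of-thought answer may box intermediate values; only the final box is scored. Stripping commas makes `\boxed{1,200}` agree with the GSM8k gold string `1200`.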
@@ -45,7 +45,7 @@ if _env_path.exists():
# This patches SwerexModalEnvironment to use a background thread instead of
# asyncio.run(), which would deadlock inside Atropos. Safe for normal CLI too.
from environments.patches import apply_patches
apply_patches()
# apply_patches()  # DISABLED: sglang patch breaks native vLLM /generate

from atroposlib.envs.base import (
    BaseEnv,
@@ -64,7 +64,7 @@ from environments.agent_loop import AgentResult, HermesAgentLoop
from environments.tool_context import ToolContext

# Import hermes-agent toolset infrastructure
from model_tools import get_tool_definitions
from model_tools import get_tool_definitions, handle_function_call
from toolset_distributions import sample_toolsets_from_distribution

logger = logging.getLogger(__name__)
@@ -140,6 +140,48 @@ class HermesAgentEnvConfig(BaseEnvConfig):
        "Options: hermes, mistral, llama3_json, qwen, deepseek_v3, etc.",
    )

    # --- Sandbox pool mode (optional, for scaled environments) ---
    tool_pool_mode: str = Field(
        default="default",
        description="Tool execution mode: 'default' (terminal tool per task_id), "
        "'nomad' (slot pool via Nomad/Docker/Singularity), or 'modal' (Modal sandbox pool).",
    )

    # Sandbox pool: shared settings
    allow_network: bool = Field(default=True, description="Whether sandbox bash commands may access the network.")
    require_sandbox: bool = Field(default=False, description="Fail closed if bubblewrap is unavailable.")
    purge_job_on_start: bool = Field(default=False, description="Purge existing sandbox job on startup.")
    purge_job_on_shutdown: bool = Field(default=True, description="Purge sandbox job on shutdown.")
    acquire_timeout_s: float = Field(default=30.0, description="Slot acquisition timeout (seconds).")

    # Sandbox pool: Nomad settings
    nomad_address: str = Field(default="http://localhost:4646", description="Nomad API address.")
    sandbox_job_id: str = Field(default="atropos-sandbox", description="Nomad job id for sandbox containers.")
    sandbox_image: str = Field(default="atropos-sandbox:local", description="Docker image for sandbox containers.")
    slots_per_container: int = Field(default=10, description="Nomad: slots per container.")
    min_containers: int = Field(default=1, description="Nomad: minimum containers.")
    max_containers: int = Field(default=10, description="Nomad: maximum containers.")
    privileged: bool = Field(default=False, description="Nomad: run container privileged.")
    driver: str = Field(default="docker", description="Nomad task driver: 'docker' or 'singularity'.")
    singularity_image: Optional[str] = Field(default=None, description="Path to .sif file for Singularity driver.")

    # Sandbox pool: Modal settings
    modal_app_name: str = Field(default="atropos-sandbox", description="Modal app name prefix.")
    modal_image: str = Field(default="python:3.11", description="Modal: container image.")
    modal_gpu: Optional[str] = Field(default=None, description="Modal: GPU type (None, 'T4', 'A10G', 'A100', 'H100').")
    modal_cpu: float = Field(default=1.0, description="Modal: CPU cores.")
    modal_memory: int = Field(default=2048, description="Modal: memory in MB.")
    modal_slots_per_sandbox: int = Field(default=10, description="Modal: slots per sandbox.")
    modal_min_sandboxes: int = Field(default=1, description="Modal: minimum sandboxes.")
    modal_max_sandboxes: int = Field(default=5, description="Modal: maximum sandboxes.")
    modal_idle_timeout: int = Field(default=120, description="Modal: idle timeout (seconds).")
    modal_max_lifetime: int = Field(default=3600, description="Modal: max sandbox lifetime (seconds).")
    modal_acquire_timeout: float = Field(default=60.0, description="Modal: slot acquisition timeout (seconds).")
    modal_execution_timeout: float = Field(default=30.0, description="Modal: command execution timeout (seconds).")
    modal_secrets: str = Field(default="", description="Modal: comma-separated Modal Secret names.")
    modal_env_vars: str = Field(default="", description="Modal: semicolon-separated KEY=VALUE pairs.")
    modal_workspace_base: str = Field(default="/data", description="Modal: workspace base directory.")
class HermesAgentBaseEnv(BaseEnv):
    """
@@ -186,6 +228,9 @@ class HermesAgentBaseEnv(BaseEnv):
        # Tool error tracking for wandb logging
        self._tool_error_buffer: List[Dict[str, Any]] = []

        # Sandbox pool backend (only used when tool_pool_mode != "default")
        self._sandbox_backend = None

    # =========================================================================
    # Toolset resolution (per-group)
    # =========================================================================
@@ -225,6 +270,12 @@ class HermesAgentBaseEnv(BaseEnv):
    # =========================================================================

    def _use_managed_server(self) -> bool:
        import sys
        result = self._use_managed_server_inner()
        print(f"HERMES_DEBUG _use_managed_server={result}, servers={len(self.server.servers) if hasattr(self.server, 'servers') else 'N/A'}, type={type(self.server.servers[0]).__name__ if hasattr(self.server, 'servers') and self.server.servers else 'N/A'}", file=sys.stderr, flush=True)
        return result

    def _use_managed_server_inner(self) -> bool:
        """
        Determine if we should use ManagedServer (Phase 2) or direct server (Phase 1).

@@ -242,6 +293,154 @@ class HermesAgentBaseEnv(BaseEnv):
        from atroposlib.envs.server_handling.openai_server import OpenAIServer
        return not isinstance(server, OpenAIServer)

    # =========================================================================
    # Sandbox pool backend (tool_pool_mode != "default")
    # =========================================================================

    async def _start_sandbox_backend(self) -> None:
        """
        Configure the slot pool backend if tool_pool_mode is not 'default'.

        Sets TERMINAL_ENV=slot_pool and configures env vars so that ALL hermes
        tools (terminal, file, etc.) automatically route through the sandbox
        pool via _SlotPoolEnvironment in terminal_tool.py.
        """
        if self.config.tool_pool_mode == "default":
            return

        mode = self.config.tool_pool_mode
        logger.info("Configuring slot pool backend (mode=%s)", mode)

        # Set TERMINAL_ENV=slot_pool so terminal_tool.py uses _SlotPoolEnvironment
        os.environ["TERMINAL_ENV"] = "slot_pool"

        # Set the backend type (modal or nomad)
        if mode == "modal":
            os.environ["TERMINAL_SLOT_BACKEND"] = "modal"
            # Forward modal config from env config to slot pool env vars
            os.environ.setdefault("TERMINAL_MODAL_IMAGE", self.config.modal_image)
            os.environ.setdefault("TERMINAL_MODAL_SLOTS", str(self.config.modal_slots_per_sandbox))
            os.environ.setdefault("TERMINAL_MODAL_MIN", str(self.config.modal_min_sandboxes))
            os.environ.setdefault("TERMINAL_MODAL_MAX", str(self.config.modal_max_sandboxes))
            os.environ.setdefault("TERMINAL_MODAL_IDLE_TIMEOUT", str(self.config.modal_idle_timeout))
            os.environ.setdefault("TERMINAL_MODAL_MAX_LIFETIME", str(self.config.modal_max_lifetime))
            os.environ.setdefault("TERMINAL_MODAL_ACQUIRE_TIMEOUT", str(self.config.modal_acquire_timeout))
            os.environ.setdefault("TERMINAL_MODAL_EXEC_TIMEOUT", str(self.config.modal_execution_timeout))
            os.environ.setdefault("TERMINAL_MODAL_WORKSPACE", self.config.modal_workspace_base)
            if self.config.modal_gpu:
                os.environ.setdefault("TERMINAL_MODAL_GPU", self.config.modal_gpu)
        elif mode == "nomad":
            os.environ["TERMINAL_SLOT_BACKEND"] = "nomad"
            os.environ.setdefault("TERMINAL_NOMAD_ADDRESS", self.config.nomad_address)
            os.environ.setdefault("TERMINAL_NOMAD_IMAGE", self.config.sandbox_image)
            os.environ.setdefault("TERMINAL_NOMAD_DRIVER", self.config.driver)
            os.environ.setdefault("TERMINAL_NOMAD_SLOTS", str(self.config.slots_per_container))
            os.environ.setdefault("TERMINAL_NOMAD_MIN", str(self.config.min_containers))
            os.environ.setdefault("TERMINAL_NOMAD_MAX", str(self.config.max_containers))

        # Eagerly start the _SlotPoolManager so the backend is ready
        # before any trajectories try to use it
        from tools.terminal_tool import _SlotPoolManager
        _SlotPoolManager.get_instance()  # Triggers _start() which creates sandboxes

        self._sandbox_backend = True  # Flag that sandbox mode is active
        print(f"🔧 Slot pool started: TERMINAL_ENV=slot_pool, backend={mode}")

    async def _stop_sandbox_backend(self) -> None:
        """Stop the slot pool backend."""
        if self._sandbox_backend:
            logger.info("Stopping slot pool backend")
            try:
                from tools.terminal_tool import _SlotPoolManager
                _SlotPoolManager.reset_instance()
            except Exception as e:
                logger.warning("Slot pool shutdown: %s", e)
            self._sandbox_backend = None

    # =========================================================================
    # Optional hooks for sandbox environments
    # =========================================================================

    async def setup_trajectory_workspace(
        self,
        item: Item,
        *,
        trajectory_id: str,
        exec_tool,
    ) -> Dict[str, Any]:
        """
        Optional hook: prepare the sandbox workspace before the agent starts.

        Override in subclasses for environments that need workspace setup
        (e.g., git clone, worktree creation, dependency installation).

        Args:
            item: The dataset item being rolled out
            trajectory_id: Unique ID for this trajectory
            exec_tool: Callable to execute tool calls in the sandbox

        Returns:
            Dict of workspace metadata (passed to verify_and_score_trajectory)
        """
        return {}

    async def verify_and_score_trajectory(
        self,
        item: Item,
        result: AgentResult,
        *,
        trajectory_id: str,
        exec_tool,
        workspace_meta: Optional[Dict[str, Any]] = None,
    ) -> Tuple[float, Dict[str, Any]]:
        """
        Optional hook: run in-sandbox verification before scoring.

        Override in subclasses for environments that need to verify results
        inside the sandbox (e.g., run pytest, check file contents).

        Default: calls compute_reward() with ToolContext.

        Args:
            item: The dataset item
            result: The agent's rollout result
            trajectory_id: Unique ID for this trajectory
            exec_tool: Callable to execute tool calls in the sandbox
            workspace_meta: Metadata from setup_trajectory_workspace

        Returns:
            Tuple of (reward, metadata_dict)
        """
        ctx = ToolContext(trajectory_id)
        try:
            reward = await self.compute_reward(item, result, ctx)
        except Exception as e:
            logger.error("compute_reward failed: %s", e)
            reward = 0.0
        finally:
            ctx.cleanup()
        return reward, {}

    # =========================================================================
    # Lifecycle hooks for env_manager/process_manager cleanup
    # =========================================================================

    async def env_manager(self):
        """Start sandbox backend, run env, then clean up."""
        await self._start_sandbox_backend()
        try:
            return await super().env_manager()
        finally:
            await self._stop_sandbox_backend()

    async def process_manager(self):
        """Start sandbox backend, run process, then clean up."""
        await self._start_sandbox_backend()
        try:
            return await super().process_manager()
        finally:
            await self._stop_sandbox_backend()

    # =========================================================================
    # Core Atropos integration
    # =========================================================================
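The two hooks above form a template-method pair around the agent loop: setup before, verify after, with `exec_tool` as the sandbox escape hatch. A hypothetical subclass sketch, shown standalone with a stubbed `exec_tool` (the real one returns an `_ExecResult`-like object; plain dicts are used here as a simplifying assumption, and the workspace path is illustrative):

```python
import asyncio


# Hypothetical hook implementations mirroring the base-class signatures.
async def setup_trajectory_workspace(item, *, trajectory_id, exec_tool):
    # e.g., create (or clone into) a per-trajectory workspace in the sandbox
    res = await exec_tool("terminal", {"command": f"mkdir -p /work/{trajectory_id}"})
    return {"workspace": f"/work/{trajectory_id}", "setup_ok": res["success"]}


async def verify_and_score_trajectory(item, result, *, trajectory_id, exec_tool,
                                      workspace_meta=None):
    # e.g., run the test suite inside the sandbox and score pass/fail
    res = await exec_tool("terminal", {"command": "pytest -q"})
    return (1.0 if res["success"] else 0.0), {"tests_ran": True}


async def _demo():
    async def fake_exec_tool(tool_name, args, timeout=300):
        # Stub standing in for the sandbox-backed callable
        return {"success": True, "output": "", "error": ""}

    meta = await setup_trajectory_workspace({}, trajectory_id="t1",
                                            exec_tool=fake_exec_tool)
    reward, info = await verify_and_score_trajectory(
        {}, None, trajectory_id="t1",
        exec_tool=fake_exec_tool, workspace_meta=meta,
    )
    print(meta["workspace"], reward)  # /work/t1 1.0


asyncio.run(_demo())
```

The design choice to pass `exec_tool` into the hooks (rather than having subclasses open their own connection) keeps setup, the agent's own tool calls, and verification all pinned to the same sandbox slot for the trajectory.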
@@ -385,6 +584,13 @@ class HermesAgentBaseEnv(BaseEnv):

        await super().wandb_log(wandb_metrics)

    def _use_sandbox_backend(self) -> bool:
        """Check if we should route tool execution through a sandbox backend."""
        return (
            self.config.tool_pool_mode != "default"
            and self._sandbox_backend is not None
        )

    async def collect_trajectory(
        self, item: Item
    ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
@@ -393,12 +599,19 @@ class HermesAgentBaseEnv(BaseEnv):
        """
        This is called group_size times in parallel by collect_trajectories().
        Each call gets its own task_id for terminal/browser session isolation.

        When tool_pool_mode != "default", routes tool execution through the
        sandbox backend (Modal, Nomad) with slot-based multiplexing:
        1. Acquire a slot from the sandbox pool
        2. Setup workspace via subclass hook (e.g., git clone + worktree)
        3. Run agent loop with terminal calls routed through sandbox
        4. Verify and score in-sandbox via subclass hook (e.g., pytest)
        5. Release the slot
        """
        task_id = str(uuid.uuid4())

        # Get group-level tools (resolved once in collect_trajectories)
        if self._current_group_tools is None:
            # Fallback: resolve per-trajectory if called outside collect_trajectories
            tools, valid_names = self._resolve_tools_for_group()
        else:
            tools, valid_names = self._current_group_tools
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
||||
|
||||
# Run the agent loop
|
||||
result: AgentResult
|
||||
# Dispatch to the appropriate path
|
||||
if self._use_sandbox_backend():
|
||||
return await self._collect_trajectory_sandbox(
|
||||
item, task_id, tools, valid_names, messages
|
||||
)
|
||||
else:
|
||||
return await self._collect_trajectory_local(
|
||||
item, task_id, tools, valid_names, messages
|
||||
)
|
||||
|
||||
async def _collect_trajectory_local(
|
||||
self,
|
||||
item: Item,
|
||||
task_id: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
valid_names: Set[str],
|
||||
messages: List[Dict[str, Any]],
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
"""
|
||||
Default (local) trajectory collection path.
|
||||
|
||||
Uses hermes-agent's handle_function_call() for tool execution.
|
||||
Reward computed via compute_reward() with ToolContext.
|
||||
"""
|
||||
result = await self._run_agent_loop(
|
||||
task_id, tools, valid_names, messages, tool_handler=None
|
||||
)
|
||||
|
||||
# Skip reward if the agent loop produced no meaningful work
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
|
||||
result.turns_used, len(result.messages),
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
except Exception as e:
|
||||
logger.error("compute_reward failed: %s", e)
|
||||
reward = 0.0
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
return self._build_scored_item(item, result, reward)
|
||||
|
||||
async def _collect_trajectory_sandbox(
|
||||
self,
|
||||
item: Item,
|
||||
task_id: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
valid_names: Set[str],
|
||||
messages: List[Dict[str, Any]],
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
"""
|
||||
Sandbox trajectory collection path (Modal, Nomad).
|
||||
|
||||
Uses TERMINAL_ENV=slot_pool so ALL hermes tools (terminal, file, web)
|
||||
automatically route through the sandbox pool via _SlotPoolEnvironment.
|
||||
No per-tool routing needed — the slot pool is the terminal backend.
|
||||
|
||||
Flow:
|
||||
1. Pre-warm terminal env (acquires a slot in the pool)
|
||||
2. Setup workspace via subclass hook (e.g., git clone + worktree)
|
||||
3. Run agent loop with tool_handler=None (all tools use handle_function_call)
|
||||
4. Verify and score in-sandbox via subclass hook (e.g., pytest)
|
||||
5. Release the slot via cleanup_vm()
|
||||
"""
|
||||
from tools.terminal_tool import _SlotPoolManager, cleanup_vm
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class _ExecResult:
|
||||
"""Lightweight result for exec_tool compatibility with env hooks."""
|
||||
success: bool
|
||||
output: str = ""
|
||||
error: str = ""
|
||||
metadata: Dict[str, Any] = None
|
||||
def __post_init__(self):
|
||||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
|
||||
try:
|
||||
            # 1. Pre-warm: trigger terminal env creation → acquires slot
            logger.info("Pre-warming sandbox slot for task %s", task_id)
            loop = asyncio.get_event_loop()
            warmup = await loop.run_in_executor(
                None,
                lambda: handle_function_call(
                    "terminal", {"command": "echo slot_ready"}, task_id=task_id
                ),
            )
            logger.info("Sandbox slot acquired for task %s", task_id)

            # 2. Create exec_tool for setup/verify hooks
            # Routes through handle_function_call → terminal_tool → same _SlotPoolEnvironment
            async def exec_tool(tool_name: str, args: Dict[str, Any], timeout: float = 300) -> _ExecResult:
                command = args.get("command", "")
                result_json = await loop.run_in_executor(
                    None,
                    lambda: handle_function_call(
                        "terminal",
                        {"command": command, "timeout": int(timeout)},
                        task_id=task_id,
                    ),
                )
                try:
                    result_dict = json.loads(result_json)
                except (json.JSONDecodeError, TypeError):
                    result_dict = {"output": str(result_json), "exit_code": 1}
                returncode = result_dict.get("exit_code", result_dict.get("returncode", 1))
                output = result_dict.get("output", "")
                return _ExecResult(
                    success=(returncode == 0),
                    output=output,
                    error=result_dict.get("error", "") if returncode != 0 else "",
                    metadata={"returncode": returncode},
                )

            # 3. Set up the workspace (subclass hook: git clone, worktree, etc.)
            workspace_meta = await self.setup_trajectory_workspace(
                item, trajectory_id=task_id, exec_tool=exec_tool
            )

            # 4. Run agent loop — tool_handler=None means ALL tools go through
            # handle_function_call() → terminal_tool() → _SlotPoolEnvironment
            # → same sandbox slot. File tools also route through the same env.
            result = await self._run_agent_loop(
                task_id, tools, valid_names, messages,
                tool_handler=None,
            )

            # 5. Skip verification if no meaningful work was done
            only_system_and_user = all(
                msg.get("role") in ("system", "user") for msg in result.messages
            )
            if result.turns_used == 0 or only_system_and_user:
                logger.warning(
                    "Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
                    result.turns_used, len(result.messages),
                )
                reward = 0.0
            else:
                # 6. Verify and score in-sandbox (subclass hook: pytest, etc.)
                reward, score_meta = await self.verify_and_score_trajectory(
                    item, result,
                    trajectory_id=task_id,
                    exec_tool=exec_tool,
                    workspace_meta=workspace_meta,
                )
                logger.info("Sandbox reward for task %s: %.2f", task_id, reward)

            return self._build_scored_item(item, result, reward)

        except Exception as e:
            logger.error("Sandbox trajectory failed for task %s: %s", task_id, e, exc_info=True)
            dummy_result = AgentResult(
                messages=messages, turns_used=0, finished_naturally=False
            )
            return self._build_scored_item(item, dummy_result, 0.0)

        finally:
            # Release the slot back to the pool
            try:
                cleanup_vm(task_id)
                logger.info("Released sandbox slot for task %s", task_id)
            except Exception as e:
                logger.error("Failed to release slot for task %s: %s", task_id, e)
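The `exec_tool` result normalization above (terminal-tool JSON payload into `_ExecResult` fields) can be sanity-checked with a standalone sketch. This is a minimal sketch, not part of the diff; the payloads and the `normalize_result` helper name are hypothetical.

```python
import json

def normalize_result(result_json):
    # Terminal tool returns a JSON string; fall back to exit_code=1
    # when the payload is not valid JSON (or not a string at all).
    try:
        result_dict = json.loads(result_json)
    except (json.JSONDecodeError, TypeError):
        result_dict = {"output": str(result_json), "exit_code": 1}
    # Accept either "exit_code" or legacy "returncode"; default to failure.
    returncode = result_dict.get("exit_code", result_dict.get("returncode", 1))
    return (returncode == 0), returncode, result_dict.get("output", "")

ok, rc, out = normalize_result('{"output": "done", "exit_code": 0}')
bad_ok, bad_rc, bad_out = normalize_result("not json")
```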
    async def _run_agent_loop(
        self,
        task_id: str,
        tools: List[Dict[str, Any]],
        valid_names: Set[str],
        messages: List[Dict[str, Any]],
        tool_handler=None,
    ) -> AgentResult:
        """
        Run the agent loop in either Phase 1 or Phase 2 mode.

        Shared between local and sandbox paths -- the only difference is
        the tool_handler parameter (None for local, sandbox callable for sandbox).
        """
        if self._use_managed_server():
            # Phase 2: ManagedServer with parser -- exact tokens + logprobs
            # Load the tool call parser from registry based on config
            from environments.tool_call_parsers import get_parser
            try:
                tc_parser = get_parser(self.config.tool_call_parser)
@@ -429,6 +825,13 @@ class HermesAgentBaseEnv(BaseEnv):
                    tokenizer=self.tokenizer,
                    tool_call_parser=tc_parser,
                ) as managed:
                    # Calculate max prompt tokens
                    # Context budget = max_token_length (prompt can be as long as generation budget)
                    # This ensures prompt + generation stays under typical model context limits
                    # E.g., max_token_length=16384 → 16384 prompt + 16384 gen = 32K < 40960 model limit
                    _max_ctx = None
                    if self.config.max_token_length and self.config.max_token_length > 0:
                        _max_ctx = self.config.max_token_length
                    agent = HermesAgentLoop(
                        server=managed,
                        tool_schemas=tools,
@@ -437,14 +840,18 @@ class HermesAgentBaseEnv(BaseEnv):
                        task_id=task_id,
                        temperature=self.config.agent_temperature,
                        max_tokens=self.config.max_token_length,
                        tool_handler=tool_handler,
                        max_context_tokens=_max_ctx,
                    )
                    result = await agent.run(messages)
                    return await agent.run(messages)
            except NotImplementedError:
                # DummyManagedServer not allowed -- fall back to Phase 1
                logger.warning(
                    "ManagedServer not available (OpenAI server?). "
                    "Falling back to direct server mode."
                )
                _max_ctx = None
                if self.config.max_token_length and self.config.max_token_length > 0:
                    _max_ctx = self.config.max_token_length
                agent = HermesAgentLoop(
                    server=self.server,
                    tool_schemas=tools,
@@ -453,10 +860,14 @@ class HermesAgentBaseEnv(BaseEnv):
                    task_id=task_id,
                    temperature=self.config.agent_temperature,
                    max_tokens=self.config.max_token_length,
                    tool_handler=tool_handler,
                    max_context_tokens=_max_ctx,
                )
                result = await agent.run(messages)
                return await agent.run(messages)
        else:
            # Phase 1: OpenAI server -- native tool_calls, placeholder tokens
            _max_ctx = None
            if self.config.max_token_length and self.config.max_token_length > 0:
                _max_ctx = self.config.max_token_length
            agent = HermesAgentLoop(
                server=self.server,
                tool_schemas=tools,
@@ -465,32 +876,22 @@ class HermesAgentBaseEnv(BaseEnv):
                task_id=task_id,
                temperature=self.config.agent_temperature,
                max_tokens=self.config.max_token_length,
                tool_handler=tool_handler,
                max_context_tokens=_max_ctx,
            )
            result = await agent.run(messages)
            return await agent.run(messages)

        # Skip reward computation if the agent loop produced no meaningful work
        # (e.g., API call failed on turn 1). No point spinning up a Modal sandbox
        # just to verify files that were never created.
        only_system_and_user = all(
            msg.get("role") in ("system", "user") for msg in result.messages
        )
        if result.turns_used == 0 or only_system_and_user:
            logger.warning(
                "Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
                result.turns_used, len(result.messages),
            )
            reward = 0.0
        else:
            # Compute reward using ToolContext (gives verifier full tool access)
            ctx = ToolContext(task_id)
            try:
                reward = await self.compute_reward(item, result, ctx)
            except Exception as e:
                logger.error("compute_reward failed: %s", e)
                reward = 0.0
            finally:
                ctx.cleanup()
    def _build_scored_item(
        self,
        item: Item,
        result: AgentResult,
        reward: float,
    ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
        """
        Build a ScoredDataItem from an AgentResult and reward.

        Shared between local and sandbox paths.
        """
        # Track tool errors for wandb logging
        if result.tool_errors:
            for err in result.tool_errors:
@@ -503,28 +904,19 @@ class HermesAgentBaseEnv(BaseEnv):
                })

        # Build ScoredDataItem from ManagedServer state
        # Phase 2: real tokens/masks/logprobs from SequenceNodes
        # Phase 1: placeholder tokens (still need a valid ScoredDataItem for the pipeline)
        nodes = (result.managed_state or {}).get("nodes", [])

        if nodes:
            # Phase 2 (or DummyManagedServer): use actual node data
            node = nodes[-1]  # Final sequence node = full trajectory
            node = nodes[-1]
            scored_item: Dict[str, Any] = {
                "tokens": node.tokens,
                "masks": node.masked_tokens,
                "scores": reward,
            }

            # Include logprobs if available (Phase 2)
            if hasattr(node, "logprobs") and node.logprobs:
                scored_item["advantages"] = None  # Computed by trainer
                scored_item["advantages"] = None
                scored_item["ref_logprobs"] = None
        else:
            # Phase 1 with no managed state: create placeholder tokens
            # so the data pipeline doesn't break. These are NOT suitable
            # for training but allow process mode (SFT data gen) to work.
            # Tokenize the full conversation to get approximate tokens.
            full_text = "\n".join(
                msg.get("content", "") for msg in result.messages if msg.get("content")
            )
@@ -535,13 +927,11 @@ class HermesAgentBaseEnv(BaseEnv):

            scored_item = {
                "tokens": tokens,
                "masks": [-100] + tokens[1:],  # Mask first token as prompt
                "masks": [-100] + tokens[1:],
                "scores": reward,
            }

        # Always include messages for wandb rollout display and data logging
        scored_item["messages"] = result.messages

        return scored_item, []

    # =========================================================================
@@ -171,6 +171,126 @@ def _patch_swerex_modal():
    logger.debug("Patched SwerexModalEnvironment for async-safe operation")


def _patch_vllm_server_for_sglang():
    """
    (Mainly for RunPod serverless compat.)

    Monkey-patch VLLMServer._tokens_and_logprobs_completion_wrapper to handle
    SGLang's /generate response format.

    VLLMServer expects:
        Request:  {"prompt": {"prompt_token_ids": [...]}, "logprobs": 0}
        Response: {"logprobs": [[{token_id: logprob}]], "finish_reasons": [...]}

    SGLang returns:
        Request:  {"input_ids": [...], "sampling_params": {...}, "return_logprob": true}
        Response: {"text": "...", "meta_info": {"output_token_logprobs": [[logprob, token_id, text], ...]}}

    This patch makes VLLMServer work with SGLang endpoints (e.g., RunPod SGLang workers).
    """
    try:
        import aiohttp
        from atroposlib.envs.server_handling.vllm_server import VLLMServer
    except ImportError:
        logger.debug("atroposlib VLLMServer not available, skipping SGLang patch")
        return

    # Save the original method
    _original_wrapper = VLLMServer._tokens_and_logprobs_completion_wrapper

    async def _sglang_compatible_wrapper(self, **kwargs):
        """
        Replacement wrapper that sends requests in SGLang's /generate format
        and converts the response back into the (prompt_tokens, tokens,
        logprobs, finish_reasons) tuple the caller expects.
        """
        assert kwargs.get("model") is not None, "Model is required!"
        assert kwargs.get("prompt") is not None or kwargs.get("input_ids") is not None, "Prompt or input_ids required!"

        # Get prompt tokens
        if "input_ids" in kwargs:
            prompt_tokens = kwargs.pop("input_ids")
            kwargs.pop("prompt", None)
        else:
            prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))

        # Check for double BOS
        if (len(prompt_tokens) >= 2
                and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]):
            prompt_tokens = prompt_tokens[1:]

        # Normalize kwargs
        max_tokens = kwargs.pop("max_new_tokens", kwargs.pop("max_completion_tokens", kwargs.pop("max_tokens", 2048)))
        n = kwargs.pop("n", 1)
        temperature = kwargs.pop("temperature", 1.0)
        kwargs.pop("model", None)

        # Build SGLang-compatible request
        request_data = {
            "input_ids": prompt_tokens,
            "sampling_params": {
                "max_new_tokens": max_tokens,
                "temperature": temperature,
                "n": n,
            },
            "return_logprob": True,
            "top_logprobs_num": 0,
        }

        generate_url = f"{self.config.base_url.replace('/v1', '')}/generate"

        headers = {}
        if self.config.api_key:
            headers["Authorization"] = f"Bearer {self.config.api_key}"
        headers["Content-Type"] = "application/json"

        async with aiohttp.ClientSession() as session:
            async with session.post(
                generate_url,
                json=request_data,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=self.config.timeout),
            ) as response:
                response.raise_for_status()
                raw_text = await response.text()

        # RunPod wraps JSON responses in quotes — may need a double parse
        import json
        results = json.loads(raw_text)
        if isinstance(results, str):
            results = json.loads(results)

        # Parse SGLang response format
        meta = results.get("meta_info", {})
        output_token_logprobs_raw = meta.get("output_token_logprobs", [])

        # SGLang format: [[logprob, token_id, token_text], ...]
        output_tokens = []
        output_logprobs = []
        for entry in output_token_logprobs_raw:
            if isinstance(entry, (list, tuple)) and len(entry) >= 2:
                logprob, token_id = entry[0], entry[1]
                output_tokens.append(int(token_id))
                output_logprobs.append(float(logprob))

        # Get finish reason
        finish_reason_raw = meta.get("finish_reason", "stop")
        if isinstance(finish_reason_raw, dict):
            finish_reason = finish_reason_raw.get("type", "stop")
        else:
            finish_reason = str(finish_reason_raw)

        return (
            prompt_tokens,
            [output_tokens],
            [output_logprobs],
            [finish_reason],
        )

    # Apply the patch
    VLLMServer._tokens_and_logprobs_completion_wrapper = _sglang_compatible_wrapper
    logger.info("Patched VLLMServer for SGLang /generate compatibility")


def apply_patches():
    """
    Apply all monkey patches needed for Atropos compatibility.
@@ -184,5 +304,6 @@ def apply_patches():
        return

    _patch_swerex_modal()
    # _patch_vllm_server_for_sglang()

    _patches_applied = True
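The `output_token_logprobs` parsing performed by the patched wrapper above can be exercised in isolation. This is a minimal sketch, not part of the diff; the sample payload and the `parse_output_token_logprobs` helper name are hypothetical.

```python
def parse_output_token_logprobs(entries):
    # SGLang entries look like [logprob, token_id, token_text]; collect
    # parallel token/logprob lists, skipping malformed entries.
    tokens, logprobs = [], []
    for entry in entries:
        if isinstance(entry, (list, tuple)) and len(entry) >= 2:
            lp, tok = entry[0], entry[1]
            tokens.append(int(tok))
            logprobs.append(float(lp))
    return tokens, logprobs

sample = [[-0.12, 42, "Hello"], [-1.5, 7, ","], ["bad"]]
toks, lps = parse_output_token_logprobs(sample)
# Malformed single-element entries are dropped.
```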
environments/swe_smith_oracle_env.py  (new file, 620 lines)
@@ -0,0 +1,620 @@
"""
SWE-smith-oracle environment (ported to HermesAgentBaseEnv).

Trains models to fix real GitHub repositories:
- Clones a public GitHub repo at a specific commit
- Runs an agent loop with a terminal tool to apply a fix
- Verifies by running pytest with nodeids from the dataset
- Reward: 1.0 if all tests pass, 0.0 otherwise

Dataset: NousResearch/SWE-smith-oracle (train split; does NOT use the SWE-bench eval set).

Usage:
    # Process mode (OpenAI server, no training):
    python environments/swe_smith_oracle_env.py process \\
        --env.data_path_to_save_groups data/swe_oracle_output.jsonl

    # With Modal sandbox backend:
    python environments/swe_smith_oracle_env.py process \\
        --env.tool_pool_mode modal \\
        --env.modal_image python:3.11
"""

from __future__ import annotations

import logging
import os
import random
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))

from pydantic import Field

from atroposlib.envs.base import ScoredDataGroup
from atroposlib.envs.server_handling.server_manager import APIServerConfig
from atroposlib.type_definitions import Item

from environments.agent_loop import AgentResult
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
from environments.tool_context import ToolContext

logger = logging.getLogger(__name__)


# =============================================================================
# Config
# =============================================================================
class SweSmithOracleEnvConfig(HermesAgentEnvConfig):
    """Config for the SWE-smith-oracle environment."""

    dataset_name: str = Field(default="NousResearch/SWE-smith-oracle")
    dataset_split: str = Field(default="train")
    max_items: int = Field(default=0, description="0 = no limit")
    shuffle: bool = Field(default=True)
    seed: int = Field(default=0)

    python_only: bool = Field(default=True, description="Filter to Python-evaluable rows")
    score_include_fail_to_pass: bool = Field(
        default=True,
        description="Score tests on PASS_TO_PASS ∪ FAIL_TO_PASS. "
        "Disable to only run PASS_TO_PASS (faster but weaker signal).",
    )

    prompt_mode: str = Field(
        default="problem_statement",
        description="'problem_statement' (fast) or 'problem_statement+text' (includes dataset 'text').",
    )

    repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
    install_timeout_s: float = Field(default=600.0)
    test_timeout_s: float = Field(default=600.0)


# =============================================================================
# Environment
# =============================================================================
class SweSmithOracleEnv(HermesAgentBaseEnv):
    """
    SWE-smith-oracle environment for training models to fix real GitHub repos.

    Uses proper OpenAI-spec tool calling via HermesAgentBaseEnv.
    The model gets terminal access to inspect, edit, and test the repository.
    """

    name = "swe-smith-oracle"
    env_config_cls = SweSmithOracleEnvConfig

    def __init__(
        self,
        config: SweSmithOracleEnvConfig,
        server_configs,
        slurm=False,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self._dataset = None
        self._indices: List[int] = []
        self._cursor = 0

    @classmethod
    def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]:
        """Default config — reads from ATROPOS_SERVER_* env vars."""
        base_url = (
            os.getenv("ATROPOS_SERVER_BASE_URL")
            or os.getenv("OPENAI_BASE_URL")
            or os.getenv("LLM_BASE_URL")
            or "http://127.0.0.1:8080"
        )
        if not base_url.rstrip("/").endswith("/v1"):
            base_url = base_url.rstrip("/") + "/v1"

        model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "Hermes-4.3-36B"
        api_key = (
            os.getenv("ATROPOS_SERVER_API_KEY")
            or os.getenv("NOUS_API_KEY")
            or os.getenv("OPENAI_API_KEY")
            or "local"
        )

        env_config = SweSmithOracleEnvConfig(
            tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
            group_size=1,
            use_wandb=False,
            rollout_server_url="http://localhost:8000",
            total_steps=1,
            batch_size=1,
            steps_per_eval=1,
            max_token_length=8192,
            wandb_name="swe_smith_oracle",
            enabled_toolsets=["terminal", "file"],
            terminal_backend=os.getenv("TERMINAL_ENV", "local"),
            # Longer agent turns for SWE tasks
            max_agent_turns=50,
            agent_temperature=0.7,
            system_prompt=(
                "You are a senior software engineer. You have access to a terminal "
                "to inspect and fix repositories. Use non-interactive commands only. "
                "Each terminal command runs in a fresh shell."
            ),
            tool_call_parser="hermes",
            # Sandbox settings (used when tool_pool_mode != "default")
            sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
            purge_job_on_start=True,
            purge_job_on_shutdown=True,
        )

        server_configs = [
            APIServerConfig(
                model_name=model,
                base_url=base_url,
                api_key=api_key,
                server_type="vllm",
                health_check=False,
                timeout=int(os.getenv("ATROPOS_SERVER_TIMEOUT_S") or "300"),
            ),
        ]

        return env_config, server_configs

    # =========================================================================
    # Dataset loading
    # =========================================================================
    async def setup(self):
        """Load the SWE-smith-oracle dataset."""
        from datasets import load_dataset

        t0 = time.perf_counter()
        print(
            f"[SweSmithOracleEnv] loading dataset {self.config.dataset_name}:{self.config.dataset_split} "
            f"(python_only={self.config.python_only}, max_items={self.config.max_items or 'all'})",
            flush=True,
        )
        ds = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
        self._dataset = ds

        indices: List[int] = []
        for idx in range(len(ds)):
            row = ds[idx]
            if self.config.python_only and not self._is_python_row(row):
                continue
            indices.append(idx)

        if self.config.shuffle:
            rnd = random.Random(self.config.seed)
            rnd.shuffle(indices)

        if self.config.max_items and self.config.max_items > 0:
            indices = indices[: self.config.max_items]

        self._indices = indices
        self._cursor = 0

        print(
            f"[SweSmithOracleEnv] loaded {len(self._indices)} items "
            f"in {time.perf_counter() - t0:.2f}s",
            flush=True,
        )

    def _is_python_row(self, row: Dict[str, Any]) -> bool:
        nodeids = row.get("PASS_TO_PASS")
        if not isinstance(nodeids, list) or not nodeids:
            return False
        return all(isinstance(nid, str) and ".py::" in nid for nid in nodeids)
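The python-only filter above accepts a row only when every PASS_TO_PASS entry looks like a pytest nodeid (`<file>.py::<test>`). A minimal standalone sketch of that predicate, with hypothetical rows:

```python
def is_python_row(row):
    # Mirror of the filter above: a non-empty list of string nodeids,
    # each containing the ".py::" pytest nodeid separator.
    nodeids = row.get("PASS_TO_PASS")
    if not isinstance(nodeids, list) or not nodeids:
        return False
    return all(isinstance(n, str) and ".py::" in n for n in nodeids)

accepted = is_python_row({"PASS_TO_PASS": ["tests/test_a.py::test_x"]})
rejected = is_python_row({"PASS_TO_PASS": ["src/main_test.go::TestX"]})
```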
    async def get_next_item(self) -> Item:
        if not self._dataset or not self._indices:
            raise RuntimeError("Dataset not initialized")
        if self._cursor >= len(self._indices):
            self._cursor = 0
        idx = self._indices[self._cursor]
        self._cursor += 1
        return dict(self._dataset[idx])

    # =========================================================================
    # Prompt formatting
    # =========================================================================
    def _repo_name(self, item: Item) -> str:
        repo = item.get("repo") or ""
        if isinstance(repo, str) and "/" in repo:
            return repo.split("/")[-1]
        return "repo"

    def format_prompt(self, item: Item) -> str:
        """Build the SWE task prompt."""
        repo = item.get("repo") or ""
        base_commit = item.get("base_commit") or ""
        problem = str(item.get("problem_statement") or "")
        context = str(item.get("text") or "")
        repo_dir = self._repo_name(item)

        nodeids = self._tests_for_item(item)
        tests_list = "\n".join(f"- {t}" for t in nodeids)

        context_block = ""
        prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower()
        if prompt_mode == "problem_statement+text" and context:
            context_block = f"\nAdditional context:\n{context}\n"

        return (
            f"Fix the repository so the specified tests pass.\n\n"
            f"Repository: {repo} (checked out at base_commit={base_commit})\n"
            f"Workspace path: ./{repo_dir}\n\n"
            "Constraints:\n"
            "- Use the terminal tool to inspect, edit, and verify the repository.\n"
            f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n"
            "- Use a workspace-local virtualenv (.venv) to avoid cross-run contamination.\n"
            "- Use non-interactive commands only.\n"
            "- Prefer `. .venv/bin/activate` or `.venv/bin/python ...` (POSIX compatible).\n\n"
            f"Problem statement:\n{problem}\n\n"
            f"{context_block}"
            f"Run these tests to verify:\n{tests_list}\n\n"
            "When done, briefly describe what you changed and confirm tests pass."
        )

    # =========================================================================
    # Test helpers
    # =========================================================================
    def _tests_for_item(self, item: Item) -> List[str]:
        tests: List[str] = []
        if self.config.score_include_fail_to_pass:
            for key in ("PASS_TO_PASS", "FAIL_TO_PASS"):
                nodeids = item.get(key)
                if isinstance(nodeids, list):
                    tests.extend([n for n in nodeids if isinstance(n, str)])
        else:
            nodeids = item.get("PASS_TO_PASS")
            if isinstance(nodeids, list):
                tests.extend([n for n in nodeids if isinstance(n, str)])
        return sorted(dict.fromkeys(tests))

    def _chunk_nodeids(self, nodeids: List[str], max_per_chunk: int = 50) -> List[List[str]]:
        return [nodeids[i : i + max_per_chunk] for i in range(0, len(nodeids), max_per_chunk)]
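The chunking helper above caps each pytest invocation at 50 nodeids so command lines stay short. A minimal standalone sketch of the same slicing, with hypothetical nodeids:

```python
def chunk_nodeids(nodeids, max_per_chunk=50):
    # Same slicing as the method above: fixed-size windows over the list.
    return [nodeids[i : i + max_per_chunk] for i in range(0, len(nodeids), max_per_chunk)]

nodeids = [f"tests/test_mod.py::test_{i}" for i in range(120)]
chunks = chunk_nodeids(nodeids)
# 120 nodeids split into chunks of at most 50: sizes 50, 50, 20.
```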
    # =========================================================================
    # Sandbox hooks: setup_trajectory_workspace + verify_and_score_trajectory
    # =========================================================================

    async def setup_trajectory_workspace(
        self, item: Item, *, trajectory_id: str, exec_tool
    ) -> Dict[str, Any]:
        """
        Prepare a sandbox workspace: bare repo cache + git worktree.

        Uses a flock-serialized bare repo cache under /data/repo_cache so
        multiple trajectories sharing a sandbox don't clone the same repo
        in parallel. Each trajectory gets an isolated worktree at the
        specified base_commit.

        Args:
            item: Dataset row with repo, base_commit, etc.
            trajectory_id: Unique trajectory ID
            exec_tool: async callable(tool_name, args, timeout) -> ExecutionResult

        Returns:
            Dict with repo_dir, base_commit metadata
        """
        import time as _time

        t0 = _time.perf_counter()
        repo = item.get("repo")
        base_commit = item.get("base_commit")
        instance_id = item.get("instance_id") or item.get("id") or item.get("problem_id")
        if not isinstance(repo, str) or not isinstance(base_commit, str):
            raise RuntimeError("Invalid dataset row: missing repo/base_commit")

        repo_dir = self._repo_name(item)
        clone_url = f"{self.config.repo_base_url.rstrip('/')}/{repo}.git"
        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
            f"repo={repo} base_commit={base_commit} instance_id={instance_id} dir=./{repo_dir}",
            flush=True,
        )

        # Bare repo cache + worktree strategy (same as atropos/envs/swe_smith_oracle_env.py)
        repo_slug = repo.replace("/", "__")
        cache_root = "/data/repo_cache"
        bare_repo = f"{cache_root}/{repo_slug}.git"
        lock_file = f"{cache_root}/.locks/{repo_slug}.lock"

        worktree_cmd = (
            "set -e; "
            f"rm -rf {repo_dir}; "
            f"mkdir -p {cache_root}/.locks; "
            f": > {lock_file}; "
            f"flock -x {lock_file} sh -lc '"
            f"set -e; "
            "export GIT_TERMINAL_PROMPT=0; "
            "export GIT_LFS_SKIP_SMUDGE=1; "
            f"if [ ! -d \"{bare_repo}\" ]; then "
            f"  git init --bare \"{bare_repo}\"; "
            f"  git -C \"{bare_repo}\" remote add origin \"{clone_url}\"; "
            "fi; "
            f"git -C \"{bare_repo}\" remote set-url origin \"{clone_url}\"; "
            f"git -C \"{bare_repo}\" worktree prune || true; "
            f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
            f"  git -C \"{bare_repo}\" fetch --depth 1 origin \"{base_commit}\" || true; "
            "fi; "
            f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
            f"  git -C \"{bare_repo}\" fetch --prune origin; "
            "fi; "
            f"git --git-dir=\"{bare_repo}\" worktree add --detach \"{repo_dir}\" \"{base_commit}\"; "
            "'"
        )

        print(f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache", flush=True)
        res = await exec_tool(
            "bash",
            {"command": worktree_cmd},
            timeout=self.config.install_timeout_s,
        )
        if not res.success:
            raise RuntimeError(
                f"git worktree setup failed "
                f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): "
                f"{res.error}\n{res.output}"
            )

        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
            f"worktree ready in {_time.perf_counter() - t0:.2f}s",
            flush=True,
        )
        return {"repo_dir": repo_dir, "base_commit": base_commit}
    async def verify_and_score_trajectory(
        self,
        item: Item,
        result: AgentResult,
        *,
        trajectory_id: str,
        exec_tool,
        workspace_meta: Optional[Dict[str, Any]] = None,
    ) -> Tuple[float, Dict[str, Any]]:
        """
        In-sandbox verification: install deps + run pytest with dataset nodeids.

        Args:
            item: Dataset row
            result: Agent's rollout result
            trajectory_id: Unique trajectory ID
            exec_tool: async callable(tool_name, args, timeout) -> ExecutionResult
            workspace_meta: From setup_trajectory_workspace (has repo_dir)

        Returns:
            (reward, metadata) tuple
        """
        repo_dir = (workspace_meta or {}).get("repo_dir") or self._repo_name(item)

        # Don't reward trajectories that never used tools
        tool_call_count = sum(
            len(msg.get("tool_calls", []))
            for msg in result.messages
            if msg.get("role") == "assistant"
        )
        if tool_call_count == 0:
            print(
                f"[SweSmithOracleEnv] tid={trajectory_id} verify: no tool calls; score=0.0",
                flush=True,
            )
            return 0.0, {"error": "No tool calls were made by the agent"}

        nodeids = self._tests_for_item(item)
        if not nodeids:
            return 0.0, {"error": "No tests provided"}

        # Install dependencies
        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} verify: installing deps + running tests",
            flush=True,
        )
        setup_cmd = (
            f"cd {repo_dir} && "
            "python -m venv .venv && "
            ". .venv/bin/activate && "
            "python -m pip install -U pip setuptools wheel && "
            "python -m pip install -e . && "
            "python -m pip install pytest"
        )
        setup_res = await exec_tool(
            "bash", {"command": setup_cmd}, timeout=self.config.install_timeout_s
        )
        if not setup_res.success:
            print(
                f"[SweSmithOracleEnv] tid={trajectory_id} install failed; score=0.0",
                flush=True,
            )
            return 0.0, {
                "phase": "install",
                "error": setup_res.error,
                "output": setup_res.output,
            }

        # Run test chunks
        chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
        for chunk_idx, chunk in enumerate(chunks):
            joined = " ".join(chunk)
            cmd = f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}"
            res = await exec_tool(
                "bash", {"command": cmd}, timeout=self.config.test_timeout_s
            )
            if not res.success:
                print(
                    f"[SweSmithOracleEnv] tid={trajectory_id} tests failed (chunk {chunk_idx}); score=0.0",
                    flush=True,
                )
                return 0.0, {
                    "phase": "pytest",
                    "failed_chunk": chunk_idx,
                    "error": res.error,
                    "output": res.output,
                }

        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} all tests passed; score=1.0",
            flush=True,
        )
        return 1.0, {"passed": True}
    # =========================================================================
    # Reward: run pytest in the terminal (local / non-sandbox path)
    # =========================================================================

    async def compute_reward(
        self, item: Item, result: AgentResult, ctx: ToolContext
    ) -> float:
        """
        Verify by running pytest with the dataset's nodeids.

        Reward structure (shaped to give training signal even when the model
        can't yet solve tasks):
        - 0.0: No tool calls at all
        - 0.05: Per valid tool call (up to 0.3 max for tool-call shaping)
        - 0.4: Successfully installed deps
        - 1.0: All tests pass

        The partial rewards for tool calls help the model learn to USE tools
        before it can learn to use them CORRECTLY. This is critical for cold-start
        training where the base model barely makes any tool calls.
        """
        repo_dir = self._repo_name(item)

        # Count tool calls (assistant messages that have tool_calls).
        # NOTE: we keep the scoring policy here intentionally simple and env-specific.
        # The agent loop exposes additional tool-call metrics (attempted/schema_valid/
        # executed_ok/exec_error) that other environments may choose to use for
        # reward shaping, but we don't hard-require any particular calling format here.
        tool_call_count = sum(
            len(msg.get("tool_calls", []))
            for msg in result.messages
            if msg.get("role") == "assistant"
        )

        if tool_call_count == 0:
            print("[SweSmithOracleEnv] No tool calls made; score=0.0", flush=True)
            return 0.0

        # Partial reward: 0.05 per tool call, capped at 0.3
        tool_call_reward = min(tool_call_count * 0.05, 0.3)

        # Debug: log tool-call quality metrics if present
        attempted = getattr(result, "tool_calls_attempted", None)
        schema_valid = getattr(result, "tool_calls_schema_valid", None)
        executed_ok = getattr(result, "tool_calls_executed_ok", None)
        exec_error = getattr(result, "tool_calls_exec_error", None)
        if attempted is not None:
            print(
                f"[SweSmithOracleEnv] Tool calls: total={tool_call_count}, attempted={attempted}, "
                f"schema_valid={schema_valid}, ok={executed_ok}, err={exec_error}",
                flush=True,
            )

        nodeids = self._tests_for_item(item)
        if not nodeids:
            # No tests defined — just reward tool usage
            print(f"[SweSmithOracleEnv] No tests defined; score={tool_call_reward:.2f} (tool calls)", flush=True)
            return tool_call_reward

        # Install deps + run tests
        print("[SweSmithOracleEnv] Verifying: installing deps + running tests", flush=True)
        setup_result = ctx.terminal(
            f"cd {repo_dir} && "
            "python -m venv .venv && "
            ". .venv/bin/activate && "
            "python -m pip install -U pip setuptools wheel && "
            "python -m pip install -e . && "
            "python -m pip install pytest",
            timeout=int(self.config.install_timeout_s),
        )
        if setup_result.get("exit_code", 1) != 0:
            print(f"[SweSmithOracleEnv] Install failed; score={tool_call_reward:.2f} (tool calls only)", flush=True)
            return tool_call_reward

        # Partial reward for successful install
        install_reward = 0.4

        # Run test chunks
        chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
|
||||
for chunk_idx, chunk in enumerate(chunks):
|
||||
joined = " ".join(chunk)
|
||||
test_result = ctx.terminal(
|
||||
f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}",
|
||||
timeout=int(self.config.test_timeout_s),
|
||||
)
|
||||
if test_result.get("exit_code", 1) != 0:
|
||||
print(f"[SweSmithOracleEnv] Tests failed (chunk {chunk_idx}); score={install_reward:.2f} (install ok)", flush=True)
|
||||
return install_reward
|
||||
|
||||
print(f"[SweSmithOracleEnv] All tests passed; score=1.0", flush=True)
|
||||
return 1.0
|
||||
|
||||
# =========================================================================
|
||||
# Token truncation — keep start of trajectory, truncate from end
|
||||
# =========================================================================
|
||||
|
||||
def _build_scored_item(self, item, result, reward):
|
||||
"""
|
||||
Override to truncate tokens/masks from the END to fit within max_token_len.
|
||||
|
||||
Intuition (from NeurIPS finding): the start of the trajectory is most important
|
||||
for shifting the model distribution. Truncating from the end only costs ~2-3%
|
||||
vs handling the full sequence, but avoids the "Token length is too long" discard
|
||||
that throws away entire groups including valid training signal.
|
||||
"""
|
||||
scored_item, remaining = super()._build_scored_item(item, result, reward)
|
||||
if scored_item is None:
|
||||
return scored_item, remaining
|
||||
|
||||
# Use config.max_token_length as the truncation limit.
|
||||
# self.max_token_len comes from the trainer via /info, but may be -1
|
||||
# if the trainer hasn't registered yet (race condition).
|
||||
max_len = self.max_token_len
|
||||
if max_len <= 0:
|
||||
# Fallback to config value
|
||||
max_len = getattr(self.config, 'max_token_length', 0)
|
||||
if max_len <= 0:
|
||||
return scored_item, remaining
|
||||
|
||||
# Leave some margin (64 tokens) to avoid edge cases with padding alignment
|
||||
truncate_to = max_len - 64
|
||||
|
||||
tokens = scored_item.get("tokens")
|
||||
masks = scored_item.get("masks")
|
||||
|
||||
if tokens is not None and len(tokens) >= max_len:
|
||||
orig_len = len(tokens)
|
||||
scored_item["tokens"] = tokens[:truncate_to]
|
||||
if masks is not None and len(masks) >= max_len:
|
||||
scored_item["masks"] = masks[:truncate_to]
|
||||
logger.info(
|
||||
"Truncated trajectory from %d to %d tokens (max_token_len=%d)",
|
||||
orig_len, truncate_to, max_len,
|
||||
)
|
||||
|
||||
return scored_item, remaining
|
||||
|
||||
# =========================================================================
|
||||
# Evaluation (minimal for now)
|
||||
# =========================================================================
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""Placeholder evaluation — SWE tasks are too expensive for frequent eval."""
|
||||
start_time = time.time()
|
||||
await self.evaluate_log(
|
||||
metrics={"eval/placeholder": 0.0},
|
||||
samples=[],
|
||||
start_time=start_time,
|
||||
end_time=time.time(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SweSmithOracleEnv.cli()
|
||||
@@ -49,15 +49,22 @@ class HermesToolCallParser(ToolCallParser):
                 continue

             tc_data = json.loads(raw_json)
+            # Handle arguments: could be dict or already a JSON string
+            raw_args = tc_data.get("arguments", {})
+            if isinstance(raw_args, str):
+                # Already a string — pass through as-is.
+                # It may be a JSON string ("{...}") or a plain string ("ls").
+                args_str = raw_args
+            else:
+                # Dict — serialize to JSON
+                args_str = json.dumps(raw_args, ensure_ascii=False)
             tool_calls.append(
                 ChatCompletionMessageToolCall(
                     id=f"call_{uuid.uuid4().hex[:8]}",
                     type="function",
                     function=Function(
                         name=tc_data["name"],
-                        arguments=json.dumps(
-                            tc_data.get("arguments", {}), ensure_ascii=False
-                        ),
+                        arguments=args_str,
                     ),
                 )
             )

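In isolation, the dict-vs-string normalization introduced by this hunk can be sketched as follows (a standalone illustration, not the parser's actual code path; `normalize_arguments` is a hypothetical helper name):

```python
import json

def normalize_arguments(raw_args) -> str:
    """Return tool-call arguments as a string, mirroring the hunk above."""
    if isinstance(raw_args, str):
        return raw_args  # may be JSON ("{...}") or plain text ("ls")
    return json.dumps(raw_args, ensure_ascii=False)

print(normalize_arguments({"command": "ls"}))  # {"command": "ls"}
print(normalize_arguments("ls"))               # ls
```

The key point of the change: a pre-serialized string is passed through untouched instead of being double-encoded by `json.dumps`.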
34
hermes
@@ -7,6 +7,40 @@ Usage: ./hermes [options]
"""

if __name__ == "__main__":
    """
    Fire (google/python-fire) does not support POSIX-style short flags like `-p`.
    We translate the most common shorthands to their long equivalents so wrapper
    scripts can reliably use:
      - `-p "..."` -> `--prompt "..."` (no TUI/banner; print result and exit)
      - `-q "..."` -> `--query "..."` (single-shot with banner UX)
    """

    import sys

    def _rewrite_short_flags(argv: list[str]) -> list[str]:
        rewritten: list[str] = []
        i = 0
        while i < len(argv):
            arg = argv[i]
            if arg == "-p":
                rewritten.append("--prompt")
                if i + 1 < len(argv):
                    rewritten.append(argv[i + 1])
                i += 2
                continue
            if arg == "-q":
                rewritten.append("--query")
                if i + 1 < len(argv):
                    rewritten.append(argv[i + 1])
                i += 2
                continue
            rewritten.append(arg)
            i += 1
        return rewritten

    sys.argv = [sys.argv[0]] + _rewrite_short_flags(sys.argv[1:])

    from cli import main
    import fire

    fire.Fire(main)

659
hermes_agent.egg-info/PKG-INFO
Normal file
@@ -0,0 +1,659 @@
Metadata-Version: 2.4
Name: hermes-agent
Version: 0.1.0
Summary: AI agent with advanced tool-calling and toolsets
Author: Nous Research
License: MIT
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: openai
Requires-Dist: python-dotenv
Requires-Dist: fire
Requires-Dist: httpx
Requires-Dist: rich
Requires-Dist: tenacity
Requires-Dist: pyyaml
Requires-Dist: prompt_toolkit
Requires-Dist: requests
Requires-Dist: jinja2
Requires-Dist: pydantic>=2.0
Requires-Dist: firecrawl-py
Requires-Dist: fal-client
Requires-Dist: litellm>=1.75.5
Requires-Dist: typer
Requires-Dist: platformdirs
Provides-Extra: modal
Requires-Dist: modal; extra == "modal"
Requires-Dist: boto3; extra == "modal"
Provides-Extra: dev
Requires-Dist: pytest; extra == "dev"
Requires-Dist: pytest-asyncio; extra == "dev"
Provides-Extra: atropos
Requires-Dist: atroposlib @ git+https://github.com/NousResearch/atropos.git ; extra == "atropos"
Requires-Dist: aiohttp; extra == "atropos"
Requires-Dist: fastapi; extra == "atropos"
Requires-Dist: uvicorn; extra == "atropos"
Requires-Dist: pyte; extra == "atropos"

# Hermes Agent

An AI agent with advanced tool-calling capabilities, featuring a flexible toolsets system for organizing and managing tools.

## Features

- **Interactive CLI**: Beautiful terminal interface with animated feedback, personalities, and session management
- **Web Tools**: Search, extract content, and crawl websites
- **Terminal Tools**: Execute commands via local, Docker, Singularity, Modal, or SSH backends
- **Browser Tools**: Automate web browsers to navigate, click, type, and extract content
- **Vision Tools**: Analyze images from URLs
- **Reasoning Tools**: Advanced multi-model reasoning (Mixture of Agents)
- **Creative Tools**: Generate images from text prompts
- **Skills Tools**: On-demand knowledge documents with progressive disclosure
- **Toolsets System**: Organize tools into logical groups for different scenarios
- **Batch Processing**: Process datasets in parallel with checkpointing and statistics tracking
- **Ephemeral System Prompts**: Guide model behavior without polluting training datasets

## Quick Start (CLI)

```bash
# After setup (see below), just run:
./hermes

# Or with options:
./hermes --model "anthropic/claude-sonnet-4" --toolsets "web,terminal"
```

The CLI provides:
- Animated spinners during thinking and tool execution
- Kawaii-style feedback messages
- `/commands` for configuration, history, and session management
- Customizable personalities (`/personality kawaii`, `/personality pirate`, etc.)
- Persistent configuration via `cli-config.yaml`

## Setup

### 1. Clone the Repository
```bash
# Clone with submodules (recommended)
git clone --recurse-submodules https://github.com/NousResearch/Hermes-Agent.git
cd Hermes-Agent

# Or if already cloned without submodules:
git submodule update --init --recursive
```

### 2. Install Dependencies
```bash
# Create and activate virtual environment (recommended)
python3 -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install Python packages
pip install -r requirements.txt

# Install mini-swe-agent for terminal tools
pip install -e ./mini-swe-agent

# Install Node.js dependencies for browser tools (requires Node.js)
npm install
```

### 3. Configure Environment Variables
```bash
# Copy the example environment file
cp .env.example .env

# Edit .env and add your API keys
nano .env  # or use your preferred editor
```

**Required API Keys:**
- `OPENROUTER_API_KEY` - LLM access via OpenRouter (get at: https://openrouter.ai/keys)
- `FIRECRAWL_API_KEY` - Web tools (get at: https://firecrawl.dev/)
- `NOUS_API_KEY` - Vision & reasoning tools (get at: https://inference-api.nousresearch.com/)
- `FAL_KEY` - Image generation (get at: https://fal.ai/)

**Optional API Keys (for specific features):**
- `BROWSERBASE_API_KEY` - Browser automation (get at: https://browserbase.com/)
- `BROWSERBASE_PROJECT_ID` - From Browserbase dashboard
- `MORPH_API_KEY` - For legacy Hecate terminal backend (get at: https://morph.so/)

### 4. Configure Terminal Backend

The terminal tool uses **mini-swe-agent** environments. Configure in `.env` or `cli-config.yaml`:

```bash
# Backend: "local", "docker", "singularity", "modal", or "ssh"
TERMINAL_ENV=local        # Default: runs on host machine (no isolation)
TERMINAL_ENV=ssh          # Remote execution via SSH (agent code stays local)
TERMINAL_ENV=singularity  # Recommended for HPC: Apptainer/Singularity containers
TERMINAL_ENV=docker       # Isolated Docker containers
TERMINAL_ENV=modal        # Cloud execution via Modal

# Container image (for docker/singularity/modal backends)
TERMINAL_DOCKER_IMAGE=python:3.11-slim
TERMINAL_SINGULARITY_IMAGE=docker://python:3.11-slim
TERMINAL_TIMEOUT=60

# SSH backend (for ssh)
TERMINAL_SSH_HOST=my-server.example.com
TERMINAL_SSH_USER=myuser
TERMINAL_SSH_KEY=~/.ssh/id_rsa  # Optional, uses ssh-agent if not set
```

**Backend Requirements:**
- **local**: No extra setup (runs directly on your machine, no isolation)
- **ssh**: SSH access to remote machine (great for sandboxing - agent can't touch its own code)
- **singularity**: Requires Apptainer or Singularity installed (common on HPC clusters, no root needed)
- **docker**: Requires Docker installed and user in `docker` group
- **modal**: Requires Modal account (see setup below)

### Singularity/Apptainer Setup (Recommended for HPC)

Singularity/Apptainer provides rootless container execution, ideal for HPC clusters:

```bash
# 1. Verify Apptainer is installed
apptainer --version  # or: singularity --version

# 2. Set up cache directories (important for parallel workers)
# Use /scratch if available (HPC), otherwise /tmp
export APPTAINER_CACHEDIR=/scratch/$USER/.apptainer
export APPTAINER_TMPDIR=/scratch/$USER/.apptainer/tmp
mkdir -p "$APPTAINER_CACHEDIR" "$APPTAINER_TMPDIR"

# 3. Pre-build SIF image (recommended for parallel batch processing)
# This avoids race conditions when multiple workers start simultaneously
apptainer build $APPTAINER_CACHEDIR/python-nodejs.sif docker://nikolaik/python-nodejs:python3.11-nodejs20

# 4. Configure .env to use the local SIF
TERMINAL_ENV=singularity
TERMINAL_SINGULARITY_IMAGE=/scratch/$USER/.apptainer/python-nodejs.sif
```

**Tip:** The batch scripts in `configs/` automatically handle SIF pre-building if `/scratch` is available.

### Modal Cloud Backend Setup

[Modal](https://modal.com) provides serverless cloud compute for running sandboxed environments at scale.

```bash
# 1. Install Modal and dependencies
pip install modal boto3

# 2. Authenticate with Modal (opens browser)
modal setup

# 3. Set terminal backend to modal in .env
TERMINAL_ENV=modal
```

Modal uses CLI-based authentication (stored in `~/.modal/`), so no API key is needed in `.env`. After running `modal setup`, commands will automatically execute in Modal's cloud sandboxes.

### Browser Tools Setup

Browser tools enable the agent to navigate websites, fill forms, click buttons, and extract content. They use the [agent-browser](https://github.com/vercel-labs/agent-browser) CLI with [Browserbase](https://browserbase.com) cloud execution.

```bash
# 1. Install Node.js (if not already installed)
# Use nvm (recommended) or your package manager

# 2. Install agent-browser CLI (choose one option):
npm install -g agent-browser  # Option A: Global install (recommended)
npm install                   # Option B: Local install (uses npx fallback)

# 3. Get Browserbase credentials
# Sign up at https://browserbase.com/ and get your:
#   - API Key (from Settings → API Keys)
#   - Project ID (from your project dashboard)

# 4. Add to your .env file:
BROWSERBASE_API_KEY=your_api_key_here
BROWSERBASE_PROJECT_ID=your_project_id_here
```

**Available Browser Tools:**

| Tool | Description |
|------|-------------|
| `browser_navigate` | Navigate to a URL |
| `browser_snapshot` | Get text-based page snapshot with element refs |
| `browser_click` | Click an element by ref (e.g., `@e5`) |
| `browser_type` | Type text into an input field |
| `browser_scroll` | Scroll up or down |
| `browser_back` | Go back in browser history |
| `browser_press` | Press a keyboard key (Enter, Tab, etc.) |
| `browser_close` | Close the browser session |
| `browser_get_images` | Get list of images on the page |

**Example Usage:**
```bash
# Use browser tools with web search and vision
python run_agent.py \
  --query "Go to amazon.com and find the price of the latest Kindle" \
  --enabled_toolsets=browser,web,vision

# Use browser-focused distribution
python batch_runner.py \
  --dataset_file=browser_tasks.jsonl \
  --distribution=browser_use \
  --run_name=browser_run
```

See `.env.example` for all available configuration options including debug settings.

### Skills Tools

Skills are on-demand knowledge documents the agent can load when needed. They follow a **progressive disclosure** pattern to minimize token usage:

```
skills/
├── mlops/                 # Category folder
│   ├── axolotl/           # Skill folder
│   │   ├── SKILL.md       # Main instructions (required)
│   │   ├── references/    # Additional docs, API specs
│   │   └── templates/     # Output formats, configs
│   └── vllm/
│       └── SKILL.md
```

**Available Skills Tools:**

| Tool | Description |
|------|-------------|
| `skills_categories` | List available skill categories (~50 tokens) |
| `skills_list` | List skills with name + description (~3k tokens for 40 skills) |
| `skill_view` | Load full skill content, tags, and linked files |

**Example Usage:**
```bash
# Use skills tools
python run_agent.py \
  --query "What skills do you have for fine-tuning? Show me the axolotl skill." \
  --enabled_toolsets=skills
```

**Creating Skills:**

Skills use YAML frontmatter for metadata:
```yaml
---
name: my-skill
description: Brief description shown in skills_list
tags: [tag1, tag2]
related_skills: [other-skill]
version: 1.0.0
---
# Skill Content

Instructions, examples, and guidelines here...
```

Skills can include:
- `references/` - Additional documentation, API specs, examples
- `templates/` - Output formats, config files, boilerplate code
- `scripts/` - Executable helpers (Python, shell scripts)
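As an illustration of the frontmatter layout above, here is a minimal sketch of splitting metadata from the body (assuming flat `key: value` pairs; the actual skills loader in hermes-agent may differ and can rely on full YAML parsing via pyyaml):

```python
# Illustrative sketch, not the actual hermes-agent skills loader.
def split_frontmatter(text: str) -> tuple[dict, str]:
    """Split '---'-delimited frontmatter from the markdown body."""
    lines = text.splitlines()
    if not lines or lines[0].strip() != "---":
        return {}, text  # no frontmatter block
    try:
        end = lines[1:].index("---") + 1  # locate closing delimiter
    except ValueError:
        return {}, text
    meta: dict = {}
    for line in lines[1:end]:
        if ":" in line:
            key, _, value = line.partition(":")
            meta[key.strip()] = value.strip()
    body = "\n".join(lines[end + 1:])
    return meta, body

skill = """---
name: my-skill
description: Brief description shown in skills_list
---
# Skill Content
Instructions here...
"""
meta, body = split_frontmatter(skill)
print(meta["name"])  # my-skill
```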

## Session Logging

Every conversation is automatically logged to `logs/` for debugging and inspection:

```
logs/
├── session_20260201_143052_a1b2c3.json
├── session_20260201_150217_d4e5f6.json
└── ...
```

**Log Format:**
```json
{
  "session_id": "20260201_143052_a1b2c3",
  "model": "anthropic/claude-sonnet-4",
  "session_start": "2026-02-01T14:30:52.123456",
  "last_updated": "2026-02-01T14:35:12.789012",
  "message_count": 8,
  "conversations": [
    {"from": "system", "value": "..."},
    {"from": "human", "value": "..."},
    {"from": "gpt", "value": "..."},
    {"from": "tool", "value": "..."}
  ]
}
```

- **Automatic**: Logs are created and updated automatically after each conversation turn
- **Session ID in Banner**: The CLI displays the session ID in the welcome banner
- **Trajectory Format**: Uses the same format as batch processing for consistency
- **Git Ignored**: `logs/` is in `.gitignore` so logs aren't committed

## Interactive CLI

The CLI provides a rich interactive experience for working with the agent.

### Running the CLI

```bash
# Basic usage
./hermes

# With specific model
./hermes --model "anthropic/claude-sonnet-4"

# With specific toolsets
./hermes --toolsets "web,terminal,skills"
```

### CLI Commands

| Command | Description |
|---------|-------------|
| `/help` | Show available commands |
| `/tools` | List available tools by toolset |
| `/toolsets` | List available toolsets |
| `/model [name]` | Show or change the current model |
| `/prompt [text]` | View/set custom system prompt |
| `/personality [name]` | Set a predefined personality |
| `/clear` | Clear screen and reset conversation |
| `/reset` | Reset conversation only |
| `/history` | Show conversation history |
| `/save` | Save current conversation to file |
| `/config` | Show current configuration |
| `/quit` | Exit the CLI |

### Configuration

Copy `cli-config.yaml.example` to `cli-config.yaml` and customize:

```yaml
# Model settings
model:
  default: "anthropic/claude-sonnet-4"

# Terminal backend (local, docker, singularity, modal, or ssh)
terminal:
  env_type: "local"
  cwd: "."  # Use current directory

# Or use SSH for remote execution (keeps agent code isolated)
# terminal:
#   env_type: "ssh"
#   ssh_host: "my-server.example.com"
#   ssh_user: "myuser"
#   ssh_key: "~/.ssh/id_rsa"
#   cwd: "/home/myuser/project"

# Enable specific toolsets
toolsets:
  - all  # or: web, terminal, browser, vision, etc.

# Custom personalities (use with /personality command)
agent:
  personalities:
    helpful: "You are a helpful assistant."
    kawaii: "You are a kawaii assistant! Use cute expressions..."
```

### Personalities

Built-in personalities available via `/personality`:
- `helpful`, `concise`, `technical`, `creative`, `teacher`
- `kawaii`, `catgirl`, `pirate`, `shakespeare`, `surfer`
- `noir`, `uwu`, `philosopher`, `hype`

## Toolsets System

The agent uses a toolsets system for organizing and managing tools. All tools must be part of a toolset to be accessible - individual tool selection is not supported. This ensures consistent and logical grouping of capabilities.

### Key Concepts

- **Toolsets**: Logical groups of tools for specific use cases (e.g., "research", "development", "debugging")
- **Composition**: Toolsets can include other toolsets for powerful combinations
- **Custom Toolsets**: Create your own toolsets at runtime or by editing `toolsets.py`
- **Toolset-Only Access**: Tools are only accessible through toolsets, not individually

### Available Toolsets

See `toolsets.py` for the complete list of predefined toolsets including:
- Basic toolsets (web, terminal, vision, creative, reasoning)
- Composite toolsets (research, development, analysis, etc.)
- Scenario-specific toolsets (debugging, documentation, API testing, etc.)
- Special toolsets (safe mode without terminal, minimal, offline)

### Using Toolsets

```bash
# Use a predefined toolset
python run_agent.py --enabled_toolsets=research --query "Find latest AI papers"

# Combine multiple toolsets
python run_agent.py --enabled_toolsets=web,vision --query "Analyze this website"

# Enable all toolsets explicitly (same as omitting the flag)
python run_agent.py --enabled_toolsets=all --query "Do web research and run commands if helpful"

# Safe mode (no terminal access)
python run_agent.py --enabled_toolsets=safe --query "Help without running commands"

# List all available toolsets and tools
python run_agent.py --list_tools
```

See `toolsets.py` for the complete list of available toolsets and how to create custom ones.

## Basic Usage

### Default (all tools enabled)
```bash
# Uses OpenRouter by default - just set OPENROUTER_API_KEY in .env
python run_agent.py \
  --query "search up the latest docs on jit in python 3.13 and write me basic example that's not in their docs. profile its perf" \
  --max_turns 20 \
  --model anthropic/claude-sonnet-4-20250514
```

### With specific toolset
```bash
python run_agent.py \
  --query "Debug this Python error" \
  --enabled_toolsets=debugging \
  --model anthropic/claude-sonnet-4-20250514
```

### Python API
```python
from run_agent import AIAgent

# Uses OpenRouter by default (reads OPENROUTER_API_KEY from .env)
agent = AIAgent(
    model="anthropic/claude-sonnet-4-20250514",
    enabled_toolsets=["research"]
)
response = agent.chat("Find information about quantum computing")

# Create custom toolset at runtime
from toolsets import create_custom_toolset

create_custom_toolset(
    name="my_tools",
    description="My custom toolkit",
    tools=["web_search"],
    includes=["terminal", "vision"]
)

agent = AIAgent(enabled_toolsets=["my_tools"])
```

## Batch Processing

Process multiple prompts from a dataset in parallel with automatic checkpointing and statistics tracking:

```bash
# Basic batch processing
python batch_runner.py \
  --dataset_file=prompts.jsonl \
  --batch_size=20 \
  --run_name=my_run

# With specific distribution
python batch_runner.py \
  --dataset_file=prompts.jsonl \
  --batch_size=20 \
  --run_name=image_run \
  --distribution=image_gen \
  --num_workers=4
```

**Key Features:**
- Parallel processing with configurable workers
- Toolset distributions for varied data generation
- Automatic checkpointing and resume capability
- Combined output in `data/<run_name>/trajectories.jsonl`
- Tool usage statistics and success rates

Use `--list_distributions` to see available toolset distributions for varied data generation.
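The checkpoint-and-resume behavior above might look roughly like this (a minimal sketch with hypothetical helpers, not `batch_runner.py`'s actual implementation):

```python
import json
import tempfile
from pathlib import Path

def load_done(ckpt: Path) -> set[str]:
    """Collect ids of prompts already completed in a previous run."""
    if not ckpt.exists():
        return set()
    return {json.loads(line)["id"] for line in ckpt.read_text().splitlines() if line}

def pending(prompts: list[dict], ckpt: Path) -> list[dict]:
    """Return only the prompts that still need processing."""
    done = load_done(ckpt)
    return [p for p in prompts if p["id"] not in done]

with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "checkpoint.jsonl"
    ckpt.write_text('{"id": "p1"}\n')  # p1 finished before the interruption
    prompts = [{"id": "p1"}, {"id": "p2"}, {"id": "p3"}]
    print([p["id"] for p in pending(prompts, ckpt)])  # ['p2', 'p3']
```

Appending one JSON line per finished prompt means a crash loses at most the in-flight work, and `--resume` only replays what is missing.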

### Trajectory Compression

Post-process trajectories to fit within token budgets for training:

```bash
# Compress a directory of JSONL files
python trajectory_compressor.py --input=data/my_run

# Compress a single JSONL file
python trajectory_compressor.py --input=data/trajectories.jsonl

# Compress a 15% sample (useful for creating smaller training sets)
python trajectory_compressor.py --input=data/trajectories.jsonl --sample_percent=15

# Custom output and token target
python trajectory_compressor.py \
  --input=data/trajectories.jsonl \
  --output=data/compressed.jsonl \
  --target_max_tokens=16000
```

**Features:**
- Protects first turns (system, human, first GPT response, first tool call)
- Protects last N turns (configurable)
- Summarizes middle turns using LLM to fit target token budget
- Supports both directory and single file input
- Optional random sampling with `--sample_percent`
- Configurable via `configs/trajectory_compression.yaml`
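The turn-protection policy above can be sketched as a simple split (assumed semantics; `split_turns` is a hypothetical helper, not `trajectory_compressor.py` verbatim):

```python
# Sketch: keep the first 4 and last N turns intact; only the middle
# slice is a candidate for LLM summarization.
def split_turns(turns: list[dict], protect_last: int = 2) -> tuple[list, list, list]:
    protect_first = 4  # system, human, first GPT response, first tool call
    head = turns[:protect_first]
    tail = turns[len(turns) - protect_last:] if protect_last else []
    middle = turns[protect_first:len(turns) - protect_last]
    return head, middle, tail

turns = [{"from": r} for r in
         ["system", "human", "gpt", "tool", "gpt", "tool", "gpt", "human", "gpt"]]
head, middle, tail = split_turns(turns)
print(len(head), len(middle), len(tail))  # 4 3 2
```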

### Ephemeral System Prompts

The ephemeral system prompt feature allows you to guide the model's behavior during batch processing **without** saving that prompt to the training dataset trajectories. This is useful for:

- Guiding model behavior during data collection
- Adding task-specific instructions
- Keeping saved trajectories clean and focused on tool-calling format

**Example:**
```bash
python batch_runner.py \
  --dataset_file=prompts.jsonl \
  --batch_size=10 \
  --run_name=my_run \
  --ephemeral_system_prompt="You are a helpful assistant focused on image generation."
```

The ephemeral prompt influences the model's behavior during execution, but **only the standard tool-calling system prompt** is saved in the trajectory files.
|
||||
## Command Line Arguments
|
||||
|
||||
**Single Agent (`run_agent.py`):**
|
||||
- `--query`: The question or task for the agent
|
||||
- `--model`: Model to use (default: claude-opus-4-20250514)
|
||||
- `--api_key`: API key for authentication
|
||||
- `--base_url`: API endpoint URL
|
||||
- `--max_turns`: Maximum number of tool-calling iterations
|
||||
- `--enabled_toolsets`: Comma-separated list of toolsets to enable. Use `all` (or `*`) to enable everything. If omitted, all toolsets are enabled by default.
|
||||
- `--disabled_toolsets`: Comma-separated list of toolsets to disable
|
||||
- `--list_tools`: List all available toolsets and tools
|
||||
- `--save_trajectories`: Save conversation trajectories to JSONL files
|
||||
|
||||
**Batch Processing (`batch_runner.py`):**
|
||||
- `--dataset_file`: Path to JSONL file with prompts
|
||||
- `--batch_size`: Number of prompts per batch
|
||||
- `--run_name`: Name for this run (for output/checkpointing)
|
||||
- `--distribution`: Toolset distribution to use (default: "default")
|
||||
- `--num_workers`: Number of parallel workers (default: 4)
|
||||
- `--resume`: Resume from checkpoint if interrupted
|
||||
- `--ephemeral_system_prompt`: System prompt used during execution but NOT saved to trajectories
|
||||
- `--list_distributions`: List available toolset distributions
|
||||
## Environment Variables

All environment variables can be configured in the `.env` file (copy from `.env.example`).

**LLM Provider (OpenRouter):**

- `OPENROUTER_API_KEY`: Primary LLM access via OpenRouter (supports Claude, GPT-4, Gemini, etc.)
- `LLM_MODEL`: Default model (e.g., `anthropic/claude-sonnet-4`, `openai/gpt-4o`)

**Tool API Keys:**

- `FIRECRAWL_API_KEY`: Web tools (search, extract, crawl)
- `NOUS_API_KEY`: Vision and reasoning tools
- `FAL_KEY`: Image generation tools

**Terminal Tool Configuration (mini-swe-agent backend):**

- `TERMINAL_ENV`: Backend type - `local`, `docker`, `singularity`, `modal`, or `ssh` (default: `local`)
- `TERMINAL_DOCKER_IMAGE`: Docker image for the docker backend (default: `python:3.11-slim`)
- `TERMINAL_SINGULARITY_IMAGE`: Singularity/Apptainer image (either a `docker://...` URL or a local `.sif` path)
- `TERMINAL_TIMEOUT`: Command timeout in seconds (default: `60`)
- `TERMINAL_LIFETIME_SECONDS`: Clean up inactive environments after this many seconds (default: `300`)
- `TERMINAL_CWD`: Working directory inside containers (default: `/tmp`)
- `TERMINAL_SCRATCH_DIR`: Custom scratch directory for sandbox storage (optional; auto-detects `/scratch`)
- `SUDO_PASSWORD`: Enable sudo commands by piping the password via `sudo -S` (works with all backends)
  - If unset in CLI mode, you'll be prompted interactively when sudo is needed (45s timeout)
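The terminal settings above resolve from the process environment with the documented defaults. A minimal sketch of how a tool might read them — the actual parsing in `tools/terminal_tool.py` may differ:

```python
import os

def terminal_config() -> dict:
    """Read TERMINAL_* settings, falling back to the documented defaults."""
    return {
        "env": os.getenv("TERMINAL_ENV", "local"),
        "docker_image": os.getenv("TERMINAL_DOCKER_IMAGE", "python:3.11-slim"),
        "timeout": int(os.getenv("TERMINAL_TIMEOUT", "60")),
        "lifetime_seconds": int(os.getenv("TERMINAL_LIFETIME_SECONDS", "300")),
        "cwd": os.getenv("TERMINAL_CWD", "/tmp"),
    }

cfg = terminal_config()
print(cfg["env"], cfg["timeout"])
```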
**SSH Backend Configuration (for remote execution):**

- `TERMINAL_SSH_HOST`: Remote server hostname or IP
- `TERMINAL_SSH_USER`: SSH username
- `TERMINAL_SSH_PORT`: SSH port (default: `22`)
- `TERMINAL_SSH_KEY`: Path to SSH private key (optional; uses ssh-agent if not set)

**Browser Tool Configuration (agent-browser + Browserbase):**

- `BROWSERBASE_API_KEY`: Browserbase API key for cloud browser execution
- `BROWSERBASE_PROJECT_ID`: Browserbase project ID
- `BROWSER_SESSION_TIMEOUT`: Session timeout in seconds (default: `300`)

**Legacy Hecate Terminal Backend (optional):**

- `MORPH_API_KEY`: For the Hecate/MorphCloud terminal backend
- `HECATE_VM_LIFETIME_SECONDS`: VM lifetime in seconds (default: `300`)
- `HECATE_DEFAULT_SNAPSHOT_ID`: Default snapshot (default: `snapshot_p5294qxt`)

**Debug Options:**

- `WEB_TOOLS_DEBUG`, `VISION_TOOLS_DEBUG`, `MOA_TOOLS_DEBUG`, `IMAGE_TOOLS_DEBUG`: Enable debug logging
## Key Files

| File | Purpose |
|------|---------|
| `hermes` | CLI launcher script (run with `./hermes`) |
| `cli.py` | Interactive CLI implementation |
| `cli-config.yaml` | CLI configuration (copy from `.example`) |
| `run_agent.py` | Main agent runner - single-query execution |
| `batch_runner.py` | Parallel batch processing with checkpointing |
| `model_tools.py` | Core tool definitions and handlers |
| `toolsets.py` | Toolset definitions and composition |
| `toolset_distributions.py` | Probability distributions for data generation |
| `trajectory_compressor.py` | Post-process trajectories for training |
| `tools/` | Individual tool implementations |
| `tools/skills_tool.py` | Skills system with progressive disclosure |
| `skills/` | On-demand knowledge documents |
| `docs/` | Documentation |
| `configs/` | Example batch run scripts |
# Atropos Integrations & RL Training

## Nomad Setup

Follow the official deployment guide: https://developer.hashicorp.com/nomad/docs/deploy

## Atropos Dependencies

python3 -m venv .venv
source .venv/bin/activate
pip install -e '.[atropos]'
70  hermes_agent.egg-info/SOURCES.txt  Normal file
@@ -0,0 +1,70 @@
README.md
atropos_compatible_agent.py
batch_runner.py
local_server.py
model_tools.py
pyproject.toml
run_agent.py
toolset_distributions.py
toolsets.py
trajectory_compressor.py
atropos/__init__.py
atropos/sandbox_server.py
atropos/agent/__init__.py
atropos/agent/atropos_agent.py
atropos/api/__init__.py
atropos/api/tool_executor_server.py
atropos/api/tool_server.py
atropos/backends/__init__.py
atropos/backends/base.py
atropos/backends/modal_backend.py
atropos/backends/nomad_backend.py
atropos/envs/__init__.py
atropos/envs/agent_env.py
atropos/envs/hermes_compat_test_env.py
atropos/envs/sandbox_terminal_smoke_env.py
atropos/envs/swe_smith_oracle_env.py
atropos/envs/test_env.py
atropos/envs/toolserver_smoke_env.py
atropos/nomad/__init__.py
atropos/nomad/client.py
atropos/slots/__init__.py
atropos/slots/executor.py
atropos/slots/pool.py
atropos/slots/slot.py
atropos/terminal/__init__.py
atropos/terminal/asciinema_stream.py
atropos/tools/__init__.py
atropos/tools/base.py
atropos/tools/build_registry.py
atropos/tools/hermes_external_tools.py
atropos/tools/sandbox_stubs.py
atropos/tools/terminal_stateful_tool.py
atropos/tools/tmux_tool.py
atropos/tools/tool_executor.py
atropos/tools/toolset_resolver.py
hermes_agent.egg-info/PKG-INFO
hermes_agent.egg-info/SOURCES.txt
hermes_agent.egg-info/dependency_links.txt
hermes_agent.egg-info/entry_points.txt
hermes_agent.egg-info/requires.txt
hermes_agent.egg-info/top_level.txt
tests/test_batch_runner.py
tests/test_checkpoint_resumption.py
tests/test_modal_integration.py
tests/test_modal_stress.py
tests/test_modal_terminal.py
tests/test_nous_api_limits.py
tests/test_nous_api_pattern.py
tests/test_temperature_fix.py
tests/test_tool_call_parsing.py
tests/test_web_tools.py
tools/__init__.py
tools/browser_tool.py
tools/image_generation_tool.py
tools/mixture_of_agents_tool.py
tools/skills_tool.py
tools/terminal_hecate.py
tools/terminal_tool.py
tools/vision_tools.py
tools/web_tools.py
1  hermes_agent.egg-info/dependency_links.txt  Normal file
@@ -0,0 +1 @@

4  hermes_agent.egg-info/entry_points.txt  Normal file
@@ -0,0 +1,4 @@
[console_scripts]
hermes-agent = run_agent:main
hermes-atropos-sandbox-smoke = atropos.envs.sandbox_terminal_smoke_env:SandboxTerminalSmokeEnv.cli
hermes-atropos-toolserver-smoke = atropos.envs.toolserver_smoke_env:ToolServerSmokeEnv.cli
31  hermes_agent.egg-info/requires.txt  Normal file
@@ -0,0 +1,31 @@
openai
python-dotenv
fire
httpx
rich
tenacity
pyyaml
prompt_toolkit
requests
jinja2
pydantic>=2.0
firecrawl-py
fal-client
litellm>=1.75.5
typer
platformdirs

[atropos]
atroposlib @ git+https://github.com/NousResearch/atropos.git
aiohttp
fastapi
uvicorn
pyte

[dev]
pytest
pytest-asyncio

[modal]
modal
boto3
10  hermes_agent.egg-info/top_level.txt  Normal file
@@ -0,0 +1,10 @@
atropos
atropos_compatible_agent
batch_runner
local_server
model_tools
run_agent
tools
toolset_distributions
toolsets
trajectory_compressor
353  local_server.py  Normal file
@@ -0,0 +1,353 @@
"""
|
||||
Local OpenAI-compatible server implementation for Hermes-Agent (Atropos integration).
|
||||
|
||||
Extends the Atropos APIServer to work with local OpenAI-compatible APIs (e.g. vLLM, SGLang),
|
||||
providing tokens_and_logprobs_completion support via client-side tokenization.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import openai
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.completion import Completion
|
||||
|
||||
from atroposlib.envs.server_handling.server_baseline import (
|
||||
APIServer,
|
||||
APIServerConfig,
|
||||
ReasoningConfig,
|
||||
)
|
||||
|
||||
|
||||
class LocalServer(APIServer):
|
||||
"""
|
||||
OpenAI-compatible local server with tokens_and_logprobs support.
|
||||
|
||||
Uses an OpenAI-compatible API (typically at a /v1 endpoint) and handles
|
||||
token extraction via client-side tokenization.
|
||||
|
||||
Note: Many local servers don't return per-token logprobs in the standard API,
|
||||
so this implementation uses placeholder logprobs (0.0) for PoC purposes.
|
||||
For production training, use vLLM/SGLang servers that return real logprobs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: APIServerConfig,
|
||||
tokenizer: Optional[Any] = None,
|
||||
tokenizer_name: str = "gpt2",
|
||||
reasoning_config: Optional[ReasoningConfig] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the local server.
|
||||
|
||||
Args:
|
||||
config: Server configuration
|
||||
tokenizer: Pre-initialized tokenizer (optional)
|
||||
tokenizer_name: Name of tokenizer to load if tokenizer not provided
|
||||
reasoning_config: Optional reasoning configuration
|
||||
"""
|
||||
# Build the OpenAI client pointing to the server's /v1 endpoint
|
||||
base_url = config.base_url
|
||||
if base_url and not base_url.endswith("/v1"):
|
||||
base_url = f"{base_url.rstrip('/')}/v1"
|
||||
|
||||
self.openai = openai.AsyncClient(
|
||||
api_key=config.api_key or "local", # Local servers often ignore auth
|
||||
base_url=base_url,
|
||||
timeout=config.timeout,
|
||||
)
|
||||
|
||||
# Initialize tokenizer
|
||||
if tokenizer is not None:
|
||||
self.tokenizer = tokenizer
|
||||
else:
|
||||
try:
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ModuleNotFoundError(
|
||||
"Missing optional dependency 'transformers'. Pass a tokenizer instance to LocalServer, "
|
||||
"or install transformers to enable `tokenizer_name` auto-loading."
|
||||
) from exc
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
|
||||
# Add a simple chat template if the tokenizer doesn't have one
|
||||
# This is needed for ManagedServer's chat_completion to work
|
||||
if not hasattr(self.tokenizer, 'chat_template') or self.tokenizer.chat_template is None:
|
||||
# Simple ChatML-style template
|
||||
self.tokenizer.chat_template = (
|
||||
"{% for message in messages %}"
|
||||
"{% if message['role'] == 'system' %}<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% elif message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
||||
)
|
||||
|
||||
super().__init__(config, reasoning_config=reasoning_config)
|
||||
# Local servers are treated as always-healthy unless a status task is enabled.
|
||||
self.server_healthy = True
|
||||
|
||||

    @classmethod
    def from_env(
        cls,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        tokenizer_name: str = "gpt2",
        **kwargs,
    ) -> "LocalServer":
        """
        Create a LocalServer from environment variables (or explicit overrides).

        Env vars (checked in order):
        - base URL: ATROPOS_SERVER_BASE_URL, OPENAI_BASE_URL, LOCAL_LLM_BASE_URL, LLM_BASE_URL
        - model: ATROPOS_SERVER_MODEL, LLM_MODEL, LOCAL_LLM_MODEL
        - api key: ATROPOS_SERVER_API_KEY, OPENAI_API_KEY, LOCAL_LLM_API_KEY, LLM_API_KEY
        """
        from dotenv import load_dotenv
        load_dotenv()

        base_url = (
            base_url
            or os.getenv("ATROPOS_SERVER_BASE_URL")
            or os.getenv("OPENAI_BASE_URL")
            or os.getenv("LOCAL_LLM_BASE_URL")
            or os.getenv("LLM_BASE_URL")
            or "http://localhost:11434"
        )
        model = (
            model
            or os.getenv("ATROPOS_SERVER_MODEL")
            or os.getenv("LLM_MODEL")
            or os.getenv("LOCAL_LLM_MODEL")
            or "hermes3:8b"
        )
        api_key = (
            api_key
            or os.getenv("ATROPOS_SERVER_API_KEY")
            or os.getenv("OPENAI_API_KEY")
            or os.getenv("LOCAL_LLM_API_KEY")
            or os.getenv("LLM_API_KEY")
        )

        config = APIServerConfig(
            model_name=model,
            base_url=base_url,
            api_key=api_key or "local",
            timeout=kwargs.get("timeout", 120),
            num_max_requests_at_once=kwargs.get("num_max_requests_at_once", 4),
            num_requests_for_eval=kwargs.get("num_requests_for_eval", 4),
            health_check=False,  # Local dev servers often lack /health
        )

        return cls(config, tokenizer_name=tokenizer_name)

    async def check_server_status_task(self, chat_completion: bool = True):
        """
        Check if the server is healthy.

        For local development, we generally assume the server is healthy.
        """
        while True:
            try:
                # Simple health check via a minimal completion
                if chat_completion:
                    await self.openai.chat.completions.create(
                        model=self.config.model_name,
                        messages=[{"role": "user", "content": "hi"}],
                        max_tokens=1,
                    )
                else:
                    await self.openai.completions.create(
                        model=self.config.model_name,
                        prompt="hi",
                        max_tokens=1,
                    )
                self.server_healthy = True
            except Exception:
                self.server_healthy = False
            await asyncio.sleep(5)

    async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
        """
        Wrapper for chat completion using an OpenAI-compatible API.
        """
        assert kwargs.get("model") is not None, "Model is required!"
        assert kwargs.get("messages") is not None, "Messages are required!"

        n = kwargs.get("n", 1)

        # Some OpenAI-compatible servers don't support n > 1, so we make multiple requests.
        if n > 1:
            completion_list = await asyncio.gather(
                *[self.openai.chat.completions.create(**{**kwargs, "n": 1}) for _ in range(n)]
            )
            # Merge completions
            completions = completion_list[0]
            for c in completion_list[1:]:
                for choice in c.choices:
                    choice.index = len(completions.choices)
                    completions.choices.append(choice)
            return completions
        else:
            return await self.openai.chat.completions.create(**kwargs)

    async def _completion_wrapper(self, **kwargs) -> Completion:
        """
        Wrapper for completion using an OpenAI-compatible API.
        """
        assert kwargs.get("model") is not None, "Model is required!"
        assert kwargs.get("prompt") is not None, "Prompt is required!"

        n = kwargs.get("n", 1)

        # Some OpenAI-compatible servers don't support n > 1.
        if n > 1:
            completion_list = await asyncio.gather(
                *[self.openai.completions.create(**{**kwargs, "n": 1}) for _ in range(n)]
            )
            completions = completion_list[0]
            for c in completion_list[1:]:
                for choice in c.choices:
                    choice.index = len(completions.choices)
                    completions.choices.append(choice)
            return completions
        else:
            return await self.openai.completions.create(**kwargs)

    async def _tokens_and_logprobs_completion_wrapper(
        self, **kwargs
    ) -> tuple[List[int], List[List[int]], List[List[float]], List[str]]:
        """
        Wrapper for tokens and logprobs completion.

        Returns:
            Tuple of (prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons)

        Note: Many OpenAI-compatible local servers don't return per-token logprobs,
        so we use placeholder logprobs (0.0). For real training, use vLLM/SGLang.
        """
        model = kwargs.get("model")
        assert model is not None, "Model is required!"

        # Handle input_ids (from ManagedServer) or prompt
        if "input_ids" in kwargs:
            prompt_tokens = kwargs.pop("input_ids")
            prompt = self.tokenizer.decode(prompt_tokens)
            kwargs.pop("prompt", None)
        else:
            prompt = kwargs.pop("prompt", "")
            prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=True)

        n = kwargs.pop("n", 1)
        max_tokens = kwargs.pop("max_tokens", 256)
        temperature = kwargs.pop("temperature", 0.7)
        stop = kwargs.pop("stop", None)

        # Make completion requests
        completions = []
        for _ in range(n):
            try:
                response = await self.openai.completions.create(
                    model=model,
                    prompt=prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    stop=stop,
                )
                completions.append(response)
            except Exception as e:
                # Fall back to chat completion if the completion endpoint is not supported
                warnings.warn(f"Completion API failed, trying chat: {e}")
                response = await self.openai.chat.completions.create(
                    model=model,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    temperature=temperature,
                    stop=stop,
                )
                # Convert to completion-like response
                completions.append(response)

        output_tokens_list = []
        output_logprobs_list = []
        finish_reasons = []

        for completion in completions:
            # Extract text from response
            if hasattr(completion.choices[0], "text"):
                # Completion API response
                text = completion.choices[0].text
                finish_reason = completion.choices[0].finish_reason or "stop"
            else:
                # Chat completion API response
                text = completion.choices[0].message.content or ""
                finish_reason = completion.choices[0].finish_reason or "stop"

            # Tokenize output
            output_tokens = self.tokenizer.encode(text, add_special_tokens=False)

            # Placeholder logprobs (many local servers don't provide per-token logprobs).
            # In production, use vLLM/SGLang which return real logprobs.
            output_logprobs = [0.0] * len(output_tokens)

            output_tokens_list.append(output_tokens)
            output_logprobs_list.append(output_logprobs)
            finish_reasons.append(finish_reason)

        return prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons

    def managed_server(self, tokenizer=None, track_tree: bool = False):
        """
        Create a ManagedServer context manager for this server.

        Args:
            tokenizer: Optional tokenizer override
            track_tree: Whether to maintain tree structure for multi-turn

        Returns:
            ManagedServer context manager
        """
        from atroposlib.envs.server_handling.managed_server import ManagedServer

        return ManagedServerContext(
            self,
            tokenizer=tokenizer or self.tokenizer,
            track_tree=track_tree,
        )


class ManagedServerContext:
    """
    Context manager wrapper for ManagedServer.

    Usage:
        async with server.managed_server(tokenizer=tokenizer) as managed:
            response = await managed.chat_completion(...)
            state = managed.get_state()
    """

    def __init__(self, server: LocalServer, tokenizer, track_tree: bool = False):
        self.server = server
        self.tokenizer = tokenizer
        self.track_tree = track_tree
        self.managed = None

    async def __aenter__(self):
        from atroposlib.envs.server_handling.managed_server import ManagedServer

        self.managed = ManagedServer(
            self.server,
            tokenizer=self.tokenizer,
            track_tree=self.track_tree,
        )
        return self.managed

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.managed:
            self.managed.reset()
        return False
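Two details of `LocalServer` above are easy to get wrong in isolation: the `/v1` suffixing of the base URL and the env-var fallback chain in `from_env`. A standalone sketch of both rules, mirroring the code:

```python
import os

def normalize_base_url(base_url: str) -> str:
    """Append /v1 unless the URL already ends with it, as LocalServer.__init__ does."""
    if base_url and not base_url.endswith("/v1"):
        base_url = f"{base_url.rstrip('/')}/v1"
    return base_url

def resolve_base_url() -> str:
    """First matching env var wins, mirroring the chain in LocalServer.from_env."""
    for var in ("ATROPOS_SERVER_BASE_URL", "OPENAI_BASE_URL",
                "LOCAL_LLM_BASE_URL", "LLM_BASE_URL"):
        value = os.getenv(var)
        if value:
            return value
    return "http://localhost:11434"  # default when nothing is set

print(normalize_base_url(resolve_base_url()))
```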
136  memory-bank/activeContext.md  Normal file
@@ -0,0 +1,136 @@
# Active Context

## Current Task: SWE Smith Oracle Env with Modal Backend

### Goal
Run this command:
```bash
python environments/swe_smith_oracle_env.py process \
    --env.use_wandb false \
    --env.total_steps 2 \
    --env.group_size 1 \
    --env.max_items 2 \
    --env.tool_pool_mode modal \
    --env.modal_image python:3.11 \
    --env.modal_slots_per_sandbox 10 \
    --env.modal_min_sandboxes 1
```

### What's Done
1. ✅ **agent_loop.py** - Added `tool_handler` parameter
   - New param: `tool_handler=None` in `__init__`
   - When `self.tool_handler` is set, it's called INSTEAD of `handle_function_call()`
   - Signature: `async tool_handler(tool_name, args, task_id) -> str`
   - Shows `[sandbox]` instead of the backend name in the terminal preview

2. ✅ **Phase 2 ManagedServer + SGLang** - Fully working (previous session)

3. ✅ **hermes_base_env.py** - Sandbox routing in collect_trajectory() (THIS SESSION)
   - Refactored `collect_trajectory()` into:
     - `_use_sandbox_backend()` - checks whether the sandbox should be used
     - `_collect_trajectory_local()` - existing path (ToolContext + handle_function_call)
     - `_collect_trajectory_sandbox()` - NEW sandbox path with slot lifecycle
     - `_run_agent_loop()` - shared agent loop for Phase 1/2, accepts tool_handler
     - `_build_scored_item()` - shared scored item construction
   - Sandbox path:
     1. `backend.acquire(task_id)` → Slot
     2. `exec_tool` callable wrapping `backend.execute_batch([(slot, tool_name, args)])`
     3. `setup_trajectory_workspace(item, exec_tool=exec_tool)` → workspace_meta
     4. `sandbox_tool_handler` routes terminal→sandbox, other→local
     5. `_run_agent_loop(tool_handler=sandbox_tool_handler)`
     6. `verify_and_score_trajectory(item, result, exec_tool=exec_tool)`
     7. `backend.release(slot, reset_workspace=True)` in finally
   - Added `handle_function_call` import for the non-terminal tool fallback

4. ✅ **swe_smith_oracle_env.py** - Sandbox hooks (THIS SESSION)
   - `setup_trajectory_workspace()` - bare repo cache + git worktree (ported from atropos/envs/swe_smith_oracle_env.py)
   - `verify_and_score_trajectory()` - install deps + run pytest in the sandbox
   - `compute_reward()` retained for the local (non-sandbox) path
   - Uses `exec_tool("bash", {"command": cmd}, timeout=600)` → `ExecutionResult`

5. ✅ **All tests pass**:
   - Syntax checks (ast.parse) on both files
   - Import checks (both modules import cleanly)
   - Method existence checks (all new methods present)
   - Signature checks (exec_tool, trajectory_id, workspace_meta params)
   - Backend integration (ModalSandboxConfig.from_agent_env_config, create_tool_backend)
   - `_use_sandbox_backend()` logic (True when modal+backend set, False otherwise)

6. ✅ **End-to-end test with Qwen 3 8B + Modal sandbox** (THIS SESSION)
   - RunPod endpoint: `0tx0ruuuo4f10c` (Qwen/Qwen3-8B via SGLang)
   - 5 terminal tool calls executed IN the sandbox: `ls`, `git status`, `git log`, `cat parse.py`, `cat tests/`
   - In-sandbox verification: install deps + pytest → score=0.0 (model inspected but didn't fix)
   - Full token tracking with logprobs via the Phase 2 ManagedServer
   - Key finding: the Llama-3-8B template silently drops the `tools=` param; Qwen 3 has full Hermes format support

### Current Task: Integrate Slot Pool Backend into tools/terminal_tool.py

#### Step 1: Add `_SlotPoolEnvironment` to `tools/terminal_tool.py`
- New class alongside the existing `_LocalEnvironment`, `_DockerEnvironment`, etc.
- Routes through `atropos/backends/` (ModalToolBackend or NomadToolBackend)
- N:M slot multiplexing: 5-10 sandboxes × 10 slots each = 50-100 concurrent
- Singleton `_SlotPoolManager` (like `_ModalPoolManager`) manages the backend lifecycle
- `execute()` acquires a slot → `backend.execute_batch([(slot, "bash", ...)])` → returns `{"output": ..., "returncode": ...}`
- `cleanup()` releases the slot back to the pool
#### Step 2: Wire into `_create_environment()`
- `TERMINAL_ENV=slot_pool` → `_SlotPoolEnvironment(...)`
- Sub-config: `TERMINAL_SLOT_BACKEND=modal` or `TERMINAL_SLOT_BACKEND=nomad`
- Reuse the existing `TERMINAL_MODAL_*` and Nomad env vars for configuration

#### Step 3: Remove redundant `atropos/tools/` files
- DELETE: `hermes_external_tools.py`, `build_registry.py`, `sandbox_stubs.py`, `toolset_resolver.py`
- KEEP: `base.py` (ToolCall/ToolResult types), `tool_executor.py` (batched queue), `terminal_stateful_tool.py`, `tmux_tool.py`

#### Step 4: Clean up `atropos/envs/` and `atropos/agent/` (defer)
- Remove `atropos/envs/agent_env.py` → replaced by `environments/hermes_base_env.py`
- Remove `atropos/agent/atropos_agent.py` → replaced by `environments/agent_loop.py`

#### Later
- Test with the Tinker trainer (blocked on billing)
- Add more environments (endless-terminals, terminalbench 2)

### Key Architecture Insight
Two separate sandbox integration points:
1. **`tools/terminal_tool.py` with `TERMINAL_ENV=slot_pool`** — for the hermes CLI, batch_runner, and any code using `handle_function_call("terminal", ...)`. Uses `_SlotPoolEnvironment`, which wraps `atropos/backends/`.
2. **`environments/hermes_base_env.py` with `tool_pool_mode=modal/nomad`** — for RL environments. Uses `_collect_trajectory_sandbox()`, which directly acquires slots and creates `sandbox_tool_handler`.

Both use the same underlying `atropos/backends/` (ModalToolBackend, NomadToolBackend) with the same slot pool.

### Architecture Summary

```
environments/hermes_base_env.py (HermesAgentBaseEnv)
│
├── tool_pool_mode="default" (existing path)
│   └── collect_trajectory() → HermesAgentLoop(tool_handler=None)
│       → handle_function_call() → hermes terminal tool (local)
│
└── tool_pool_mode="modal" or "nomad" (new path)
    └── collect_trajectory():
        1. slot = backend.acquire(task_id)
        2. exec_tool = lambda routing through backend.execute_batch
        3. setup_trajectory_workspace(item, exec_tool=exec_tool)   [subclass hook]
        4. HermesAgentLoop(tool_handler=sandbox_tool_handler)
           → terminal calls → backend.execute_batch(slot, "bash", ...)
        5. verify_and_score_trajectory(item, result, exec_tool=exec_tool)   [subclass hook]
        6. backend.release(slot, reset_workspace=True)

atropos/backends/modal_backend.py (ModalToolBackend)
└── acquire(trajectory_id) → Slot
    └── execute_batch([(slot, "bash", {"command": "..."})]) → [ExecutionResult]
    └── release(slot, reset_workspace=True)
```

### Key Files to Modify
1. `environments/hermes_base_env.py` - Add the sandbox path in `collect_trajectory()`
2. `environments/swe_smith_oracle_env.py` - Override `setup_trajectory_workspace()` and `verify_and_score_trajectory()` to use exec_tool

### Important Notes
- `exec_tool` returns `ExecutionResult` (from `atropos/slots/executor.py`) with `.success`, `.output`, `.error`, `.metadata`
- `tool_handler` returns a JSON string (for the agent loop message format)
- These are DIFFERENT interfaces for different purposes:
  - `exec_tool`: used by env hooks (setup/verify) - returns a structured result
  - `tool_handler`: used by the agent loop - returns a JSON string like hermes tools do
- `ModalToolBackend.execute_batch` calls `_ModalSandboxWithSlots.execute`, which runs `sandbox.exec("bash", "-c", command)` on Modal
- For the SWE env, the worktree setup pattern from `atropos/envs/swe_smith_oracle_env.py` should be reused (bare repo cache + worktree add)
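The exec_tool/tool_handler split can be sketched concretely. The `ExecutionResult` fields match the note above; the conversion wrapper itself is illustrative:

```python
import json
from dataclasses import dataclass, field

@dataclass
class ExecutionResult:
    # Fields as described above; the real class lives in atropos/slots/executor.py.
    success: bool
    output: str
    error: str = ""
    metadata: dict = field(default_factory=dict)

def to_tool_message(result: ExecutionResult) -> str:
    """Convert an env-hook ExecutionResult into the JSON string a tool_handler returns."""
    payload = {"output": result.output, "returncode": 0 if result.success else 1}
    if result.error:
        payload["error"] = result.error
    return json.dumps(payload)

msg = to_tool_message(ExecutionResult(success=True, output="ok\n"))
```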
55  memory-bank/productContext.md  Normal file
@@ -0,0 +1,55 @@
# Product Context: Hermes-Agent

## Why This Project Exists

Hermes-Agent addresses several key challenges in the AI agent space:

1. **Unified Tool Interface** - Provides a clean, consistent interface for LLMs to use various tools (web, terminal, browser, vision, etc.) without requiring custom integration for each model provider.

2. **Training Data Generation** - Enables efficient generation of high-quality tool-calling trajectories for fine-tuning LLMs, with features like batch processing, checkpointing, and trajectory compression.

3. **Flexible Deployment** - Supports multiple execution environments (local, Docker, Singularity, Modal, SSH) to accommodate different security and isolation requirements.

4. **Developer Experience** - Offers a beautiful, interactive CLI with kawaii-style feedback that makes working with AI agents enjoyable.

## Problems It Solves

### For AI Researchers
- **Data Generation at Scale**: Parallel batch processing with content-based checkpointing for fault tolerance
- **Clean Trajectories**: Trajectory compression to fit token budgets while preserving important information
- **Toolset Distributions**: Probability-based tool selection for varied training data

### For Developers
- **Tool Orchestration**: Logical grouping of tools into toolsets (research, development, debugging, etc.)
- **Session Persistence**: Conversation history and session logging for debugging
- **Multi-Model Support**: Works with any OpenAI-compatible API (OpenRouter, local models, etc.)

### For MLOps
- **Skills System**: On-demand knowledge documents for specific tools/frameworks (Axolotl, vLLM, TRL, etc.)
- **Sandboxed Execution**: Terminal commands can run in isolated environments (Docker, Singularity, Modal)
- **Configurable Backends**: Easy switching between local and cloud execution

## How It Should Work

### User Flow (CLI)
1. User launches `./hermes`
2. A welcome banner displays with the caduceus logo, model info, and available tools
3. User types a natural language request
4. Agent processes the request, potentially calling tools with animated feedback
5. Agent responds with results; the conversation continues
6. The session is automatically logged for debugging

### User Flow (Batch Processing)
1. User prepares a JSONL file with prompts
2. Runs `batch_runner.py` with a distribution and worker count
3. System processes prompts in parallel and saves checkpoints
4. Completed trajectories are saved to `data/<run_name>/trajectories.jsonl`
5. Optional: compress trajectories with `trajectory_compressor.py`

## User Experience Goals

- **Delightful Interaction**: Kawaii ASCII faces, animated spinners, cute messages
- **Informative Feedback**: Clear progress indication during tool execution
- **Configurable Personalities**: From "helpful" to "pirate" to "Shakespeare"
- **Easy Configuration**: YAML config file + environment variables + CLI flags
- **Graceful Degradation**: Missing tools/APIs don't break the system, they just disable features
134  memory-bank/progress.md  Normal file
@@ -0,0 +1,134 @@
# Progress
|
||||
|
||||
## Current Sprint: Phase 2 ManagedServer + SGLang Working (Feb 10, 2026)
|
||||
|
||||
### ✅ Phase 2 End-to-End Pipeline VERIFIED
|
||||
Full pipeline working: GSM8k env → collect_trajectory → ManagedServer → VLLMServer (SGLang patched) → tokens + logprobs + masks.
|
||||
|
||||
Test results:
|
||||
- 212 tokens with logprobs and masks from single trajectory
|
||||
- Reward: 1.0 (correct answer)
|
||||
- ScoredDataItem has all required fields: tokens, masks, scores, advantages, ref_logprobs, messages
|
||||
- RunPod SGLang endpoint (b9zmuyn1carwya) with Llama-3-8B-Instruct
|
||||
|
||||
### Consolidation Checklist
|
||||
- [x] Install atropos `tool_call_support` branch (PR #366)
|
||||
- [x] Create `environments/gsm8k_agent_env.py` using `HermesAgentBaseEnv`
|
||||
- [x] Create `environments/agent_loop.py` with proper OpenAI-spec tool calling
|
||||
- [x] Create `environments/tool_call_parsers/` with 13 parsers
|
||||
- [x] Create `environments/patches.py` for SGLang compatibility
|
||||
- [x] Add sandbox pool support to `HermesAgentBaseEnv`
|
||||
- [x] Test Phase 1 (OpenAI server type) with Nous API — WORKS
|
||||
- [x] Test Phase 2 (ManagedServer) with RunPod SGLang — WORKS
|
||||
- [x] Port SWE env to `HermesAgentBaseEnv` with multiplexed sandboxing
|
||||
- [x] End-to-end test: Qwen 3 8B + Modal sandbox + tool calls in sandbox + pytest verification
|
||||
- [x] Add `_SlotPoolEnvironment` to `tools/terminal_tool.py` (TERMINAL_ENV=slot_pool)
|
||||
- [x] Remove redundant `atropos/tools/` files (4 of 8)
|
||||
- [ ] Remove redundant `atropos/agent/` and `atropos/envs/agent_env.py` (deferred)
|
||||
- [ ] Test end-to-end with Tinker trainer (blocked on billing)
|
||||
|
||||
### ✅ End-to-End SWE + Modal Sandbox Verified (Feb 10, 2026)
- Qwen 3 8B on RunPod SGLang (endpoint `0tx0ruuuo4f10c`)
- Phase 2 ManagedServer with hermes tool call parser
- 5 terminal commands executed in Modal sandbox: ls, git status, git log, cat parse.py, cat tests/
- In-sandbox verification: install deps + pytest → score 0.0 (model inspected but didn't fix)
- Full token tracking with logprobs via /generate endpoint
- Key finding: Llama-3-8B template drops tools= silently; Qwen 3 has full Hermes tool format

## Completed Features

### ✅ Phase 2 ManagedServer + SGLang (Feb 10, 2026)
- SGLang patch in `environments/patches.py` monkey-patches VLLMServer
- Handles SGLang's different request/response format vs VLLM
- Handles RunPod's double-JSON wrapping
- Full chain verified: ManagedServer → VLLMServer → _tokens_and_logprobs_comp (retry) → patched wrapper → /generate endpoint
- SequenceNode tracking: tokens, logprobs, masked_tokens all populated
- **Key discovery**: The AttributeError from earlier was NOT in our current code — likely from a prior code state

### ✅ Phase 1 OpenAI Server Mode (Feb 9-10, 2026)
- GSM8k env works with Nous API (OpenRouter-style endpoint)
- Terminal tool calls properly dispatched
- Tool call parsing handled natively by server (VLLM/SGLang /v1/chat/completions)
- Reward computation verified (math_verify for robust LaTeX comparison)

### ✅ Sandbox Pool Integration (Feb 10, 2026)
- Config fields added to `HermesAgentEnvConfig` for Nomad and Modal
- `_start_sandbox_backend()` / `_stop_sandbox_backend()` lifecycle methods
- Optional hooks: `setup_trajectory_workspace()`, `verify_and_score_trajectory()`
- Integrated into `env_manager()` and `process_manager()` cleanup
### ✅ Tool Call Parsers (Feb 9-10, 2026)
- 13 parsers: hermes, llama3_json, llama4_json, qwen, qwen3_coder, deepseek_v3, deepseek_v31, glm45, glm47, mistral, kimi_k2, longcat
- Registry pattern: `get_parser("hermes")` returns parser instance
- Each parser: `.parse(text) → (content, tool_calls)`
- Used by ManagedServer in Phase 2 to extract structured tool_calls from raw completion
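
A minimal sketch of that registry pattern and the `.parse` contract, assuming the hermes `<tool_call>` tag format; the real parsers in `environments/tool_call_parsers/` are more robust than this:

```python
import json
import re

_PARSERS = {}

def register(name):
    """Class decorator that adds a parser class to the registry."""
    def deco(cls):
        _PARSERS[name] = cls
        return cls
    return deco

def get_parser(name):
    """Return a fresh parser instance, mirroring get_parser('hermes')."""
    return _PARSERS[name]()

@register("hermes")
class HermesParser:
    def parse(self, text):
        """Split a raw completion into (content, tool_calls)."""
        calls = [json.loads(m) for m in re.findall(r"<tool_call>(.*?)</tool_call>", text, re.S)]
        content = re.sub(r"<tool_call>.*?</tool_call>", "", text, flags=re.S).strip()
        return content, calls
```

Registering one class per model format keeps ManagedServer's call site a single `get_parser(config.tool_call_parser)` lookup.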
### ✅ Modal Backend Integration (Feb 8, 2026)
- `ModalToolBackend` with slot-based multiplexing
- Multi-profile support (CPU, GPU, high-memory)
- Auto-scaling sandbox pool via Modal Sandboxes

### ✅ Main Branch Merge (Feb 9, 2026)
- Merged 22,560 lines, 79 files, 5 conflicts resolved
- New: hermes_cli/, file_operations, RL training tools, gateway, cron

### ✅ Tinker RL Training Setup (Feb 9, 2026)
- tinker 0.12.0 + tinker-atropos installed
- GSM8k agent config created
- Pipeline verified: Tinker API connection works, all imports pass
- **Blocked on billing** (Tinker 402 error)

### ✅ Singularity/Apptainer Sandbox (Feb 6, 2026)
- Nomad raw_exec driver for HPC clusters
- All sandbox operations tested and working

### ✅ Memory Bank (Feb 5, 2026)
- Project documentation structure initialized
## What to KEEP vs REMOVE

### KEEP (valuable infrastructure):

| Component | Location | Purpose |
|-----------|----------|---------|
| Modal backend | `atropos/backends/modal_backend.py` | Cloud sandbox pool |
| Nomad backend | `atropos/backends/nomad_backend.py` | Docker/Singularity sandboxes |
| Slot pool | `atropos/slots/` | Container multiplexing |
| Nomad client | `atropos/nomad/` | Nomad API |
| Sandbox server | `atropos/sandbox_server.py` | HTTP server in containers |
| Dockerfile | `atropos/Dockerfile` | Container image |
| Agent loop | `environments/agent_loop.py` | Proper OpenAI-spec tool calling |
| Base env | `environments/hermes_base_env.py` | Phase 1/2 with parsers |
| Tool parsers | `environments/tool_call_parsers/` | 13 model parsers |
| SGLang patch | `environments/patches.py` | SGLang compatibility |

### REMOVE (redundant with environments/):

| Component | Location | Replaced By |
|-----------|----------|-------------|
| ICL agent | `atropos/agent/atropos_agent.py` | `environments/agent_loop.py` |
| AgentEnv | `atropos/envs/agent_env.py` | `environments/hermes_base_env.py` |
| Tool registry | `atropos/tools/` | `model_tools.py` + `tools/` |
| GSM8k ICL env | `tinker-atropos/.../gsm8k_agent.py` | `environments/gsm8k_agent_env.py` |
## Known Issues
- Tinker billing (402 error) - user's payment didn't process
- `bwrap_available: false` in Singularity containers
- Llama-3-8B-Instruct doesn't reliably produce tool calls via Phase 2 (needs a Hermes-format model)
- Model answered GSM8k correctly but didn't actually USE the terminal tool (computed mentally)

## Evolution of Decisions

### Agent Architecture
- **v1 (our branch)**: ICL-based agent with `<tool_call>` XML tags in the system prompt
- **v2 (Teknium's)**: Proper OpenAI-spec tool calling with the `tools=` parameter
- **Decision**: Adopt v2, consolidate into `environments/`, keep sandbox backends from v1

### Environment Organization
- **Before**: Two parallel systems (`atropos/envs/` and `environments/`)
- **After**: Single system in `environments/`, using `HermesAgentBaseEnv` as the base class
- Sandbox backends remain in `atropos/backends/` but integrate via terminal backend config

### Phase 2 SGLang Support
- **Problem**: VLLMServer was hardcoded for VLLM's /generate format; SGLang's is different
- **Solution**: Monkey-patch `_tokens_and_logprobs_completion_wrapper` in `environments/patches.py`
- **Applied**: Automatically at import time via `apply_patches()` in `hermes_base_env.py`
- **Handles**: SGLang format differences AND RunPod's double-JSON wrapping
44
memory-bank/projectbrief.md
Normal file
@@ -0,0 +1,44 @@
# Project Brief: Hermes-Agent

## Overview
Hermes-Agent is an AI agent harness for LLMs with advanced tool-calling capabilities, featuring a flexible toolsets system for organizing and managing tools. Named after Hermes, the Greek messenger god, it serves as a bridge between human intent and AI-powered task execution.

## Core Requirements

### Primary Goals
1. **Interactive CLI Experience** - Beautiful terminal interface with animated feedback, personalities, and session management
2. **Flexible Tool System** - Modular tools organized into logical toolsets for different use cases
3. **Batch Processing** - Process multiple prompts in parallel with checkpointing and statistics
4. **Multi-Backend Support** - Support for local, Docker, Singularity, Modal, and SSH terminal backends
5. **Training Data Generation** - Save conversation trajectories in formats suitable for LLM fine-tuning

### Target Users
- AI researchers generating training data
- Developers needing an AI assistant with tool access
- MLOps practitioners automating workflows
- Anyone needing a powerful CLI-based AI agent

## Scope

### In Scope
- Interactive CLI with rich formatting and kawaii-style feedback
- Web tools (search, extract, crawl via Firecrawl)
- Terminal tools (command execution across multiple backends)
- Browser automation (via agent-browser + Browserbase)
- Vision tools (image analysis)
- Image generation (FLUX via FAL.ai)
- Mixture-of-Agents reasoning
- Skills system for on-demand knowledge
- Batch processing with parallel workers
- Trajectory compression for training

### Out of Scope (Current)
- Proactive suggestions (agent only runs on request)
- Clipboard integration (no local system access)
- Real-time streaming of thinking/reasoning (deferred)

## Success Metrics
- Clean, maintainable tool architecture
- Reliable tool execution with proper error handling
- Efficient context management for long conversations
- High-quality trajectory data for training
267
memory-bank/systemPatterns.md
Normal file
@@ -0,0 +1,267 @@
# System Patterns: Hermes-Agent

## Architecture Overview

```
┌─────────────────────────────────────────────────────────────────┐
│                          CLI (cli.py)                           │
│  - Rich welcome banner with caduceus                            │
│  - prompt_toolkit for input with history                        │
│  - Kawaii-style feedback and personalities                      │
└────────────────────────────┬────────────────────────────────────┘
                             │
                             ▼
┌─────────────────────────────────────────────────────────────────┐
│                      AIAgent (run_agent.py)                     │
│  - Conversation loop with tool calling                          │
│  - KawaiiSpinner for animated feedback                          │
│  - Retry logic with exponential backoff                         │
│  - Session logging to logs/ directory                           │
└────────────────────────────┬────────────────────────────────────┘
                             │
                             ▼
┌─────────────────────────────────────────────────────────────────┐
│                  Tool Routing (model_tools.py)                  │
│  - get_tool_definitions() - returns tools for API calls         │
│  - handle_function_call() - dispatches to tool handlers         │
│  - Toolset filtering (enabled/disabled)                         │
└────────────────────────────┬────────────────────────────────────┘
                             │
           ┌─────────────────┼─────────────────┐
           ▼                 ▼                 ▼
     ┌───────────┐     ┌───────────┐     ┌───────────┐
     │ Web Tools │     │ Terminal  │     │  Browser  │
     │(Firecrawl)│     │(mini-swe) │     │(agent-brw)│
     └───────────┘     └───────────┘     └───────────┘
           │                 │                 │
           └─────────────────┼─────────────────┘
                             ▼
                     ┌───────────────┐
                     │   Toolsets    │
                     │ (toolsets.py) │
                     │  Composition  │
                     └───────────────┘
```
## Key Design Patterns

### 1. Toolset Composition Pattern
Toolsets can include other toolsets, allowing flexible composition:

```python
TOOLSETS = {
    "web": {"tools": ["web_search", "web_extract"], "includes": []},
    "debugging": {"tools": ["terminal"], "includes": ["web"]},
    "full_stack": {"tools": [], "includes": ["web", "terminal", "vision", "browser"]}
}
```

Resolution is recursive with cycle detection.
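
That resolution can be sketched as follows, using the `TOOLSETS` shape above; the visited-set argument stands in for whatever cycle detection `toolsets.py` actually uses:

```python
TOOLSETS = {
    "web": {"tools": ["web_search", "web_extract"], "includes": []},
    "debugging": {"tools": ["terminal"], "includes": ["web"]},
    "looping": {"tools": [], "includes": ["looping"]},  # pathological self-include
}

def resolve_toolset(name, _seen=None):
    """Return every tool name reachable from a toolset, following includes."""
    seen = _seen if _seen is not None else set()
    if name in seen:            # cycle detection: a revisited toolset adds nothing
        return set()
    seen.add(name)
    spec = TOOLSETS[name]
    tools = set(spec["tools"])
    for included in spec["includes"]:
        tools |= resolve_toolset(included, seen)
    return tools
```

Because the visited set is threaded through the recursion, a self-referential or mutually-including toolset terminates instead of recursing forever.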
### 2. Graceful Degradation Pattern
Each tool module has a `check_*_requirements()` function:
- Tools are only loaded if requirements are met
- Missing API keys disable tools, not crash the system
- Import errors are caught and tools marked unavailable

```python
try:
    from tools.web_tools import web_search_tool, check_firecrawl_api_key
except ModuleNotFoundError:
    web_search_tool = None
    def check_firecrawl_api_key(): return False
```
### 3. Session Isolation Pattern (task_id)
Stateful tools (terminal, browser) use `task_id` to isolate concurrent sessions:
- Each batch worker gets a unique task_id
- VMs and browser sessions are tracked per task_id
- Cleanup functions release resources: `cleanup_vm(task_id)`, `cleanup_browser(task_id)`
### 4. Trajectory Format Pattern
Conversations are saved in ShareGPT format for training:

```json
{"from": "system", "value": "System prompt with <tools>...</tools>"}
{"from": "human", "value": "User message"}
{"from": "gpt", "value": "<think>reasoning</think>\n<tool_call>{...}</tool_call>"}
{"from": "tool", "value": "<tool_response>{...}</tool_response>"}
{"from": "gpt", "value": "Final response"}
```
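
Writing one such conversation as a JSONL record might look like this; the top-level `conversations` key is an assumption about the on-disk wrapper, not confirmed by the notes:

```python
import json

conversation = [
    {"from": "system", "value": "System prompt with <tools>...</tools>"},
    {"from": "human", "value": "User message"},
    {"from": "gpt", "value": "<think>reasoning</think>\n<tool_call>{}</tool_call>"},
    {"from": "tool", "value": "<tool_response>{}</tool_response>"},
    {"from": "gpt", "value": "Final response"},
]

# Append one trajectory per line; "conversations" as the wrapper key is assumed.
with open("trajectories.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps({"conversations": conversation}, ensure_ascii=False) + "\n")
```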
### 5. Ephemeral System Prompt Pattern
Guide model behavior during data collection without saving to trajectories:
- `ephemeral_system_prompt` influences execution
- Only the standard tool-calling system prompt is saved to trajectories
- Keeps training data clean

### 6. Retry with Validation Pattern
The agent validates responses before accepting them:
- Check tool names against the `valid_tool_names` set
- Validate that JSON arguments can be parsed
- Check for content after `<think>` blocks
- Roll back to the last valid state on persistent failures
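
The first two checks can be sketched directly; the message shape below follows the OpenAI tool-call layout, and `valid_tool_names` is a stand-in for the agent's real set:

```python
import json

valid_tool_names = {"terminal", "web_search"}

def response_is_valid(message):
    """Accept only tool calls whose name is known and whose arguments parse as JSON."""
    for call in message.get("tool_calls") or []:
        if call["function"]["name"] not in valid_tool_names:
            return False
        try:
            json.loads(call["function"]["arguments"])
        except json.JSONDecodeError:
            return False
    return True
```

A failed check triggers a retry; only after persistent failures does the agent roll the conversation back to the last valid state.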
## Component Relationships

### AIAgent Class
- Central orchestrator for conversations
- Manages conversation history
- Calls OpenAI-compatible API
- Routes tool calls to handlers
- Provides animated feedback (KawaiiSpinner)

### Tool Modules (tools/*.py)
- Self-contained tool implementations
- Export: handler function + check function + schema
- Return JSON strings (never raw dicts)
- Accept optional `task_id` for stateful tools

### Toolsets System (toolsets.py)
- Defines logical groupings of tools
- Supports composition via `includes`
- `resolve_toolset()` recursively resolves all tools
- `validate_toolset()` checks if a name is valid

### Model Tools (model_tools.py)
- Aggregates all tool definitions
- Routes function calls to correct handlers
- Filters tools based on enabled/disabled toolsets
- Bridge between agent and tool implementations
## Critical Implementation Paths

### Tool Execution Flow
1. AIAgent receives tool_calls from the API response
2. Validates tool names against `valid_tool_names`
3. Validates that JSON arguments can be parsed
4. Calls `handle_function_call()` with tool name, args, task_id
5. `handle_function_call()` routes to the appropriate handler
6. Tool executes, returns a JSON string
7. Result is added to the conversation as a tool message
8. Loop continues until a natural language response
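
Steps 1-8 compress to a single loop; `call_api` and `handle_function_call` are injected stand-ins for the real API client and dispatcher:

```python
def run_turn(call_api, handle_function_call, messages, valid_tool_names, task_id=None):
    """Drive one request until the model returns a plain natural-language reply."""
    while True:
        response = call_api(messages)                       # step 1: model turn
        tool_calls = response.get("tool_calls") or []
        if not tool_calls:                                  # step 8: no tools → done
            return response["content"]
        for call in tool_calls:
            if call["name"] not in valid_tool_names:        # step 2: validate name
                raise ValueError(f"unknown tool: {call['name']}")
            result = handle_function_call(call["name"], call["arguments"], task_id)
            messages.append({"role": "tool", "content": result})  # step 7
```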
### Configuration Loading Flow
1. `cli.py` calls `load_cli_config()`
2. Loads `cli-config.yaml`, merges with defaults
3. Sets environment variables for terminal config
4. `AIAgent` reads env vars when initializing the terminal tool
5. Terminal tool creates the appropriate backend based on `TERMINAL_ENV`
## RL Training Architecture (Consolidated)

### Environment System (`environments/`)

The canonical way to build agentic RL environments in Hermes-Agent:

```
environments/
├── agent_loop.py          ← HermesAgentLoop: OpenAI-spec tool calling
├── hermes_base_env.py     ← HermesAgentBaseEnv: base class for all envs
├── tool_context.py        ← ToolContext: reward function tool access
├── tool_call_parsers/     ← 13 model parsers (hermes, qwen, deepseek, etc.)
├── terminal_test_env.py   ← Example: file creation tasks
├── hermes_swe_env.py      ← SWE environment
└── gsm8k_agent_env.py     ← GSM8k with Python REPL (TODO)
```
### Two-Phase Operation
- **Phase 1 (OpenAI server)**: Native tool_calls from VLLM/SGLang/OpenRouter
  - Good for: SFT data gen, testing, evaluation
  - Server handles tool call parsing via `/v1/chat/completions`
- **Phase 2 (ManagedServer)**: Client-side tool call parser + logprob tracking
  - Required for: RL training (exact token IDs + logprobs for GRPO/PPO)
  - Uses `/generate` endpoint for raw token output
  - Parser registry selects a per-model parser (hermes, qwen, llama, etc.)
  - **Verified working** with RunPod SGLang endpoint (Feb 10, 2026)
### Phase 2 Call Chain (Verified)
```
collect_trajectory()
  → ServerManager.managed_server(tokenizer, tool_call_parser)
  → ManagedServer(server=VLLMServer)
  → ManagedServer.chat_completion(messages, tools, n, max_tokens, temp)
    → _convert_messages_to_prompt(messages, tools=tools)    [apply_chat_template]
    → _compute_input_ids(prompt, extending_node)
    → VLLMServer.tokens_and_logprobs_completion(**kwargs)   [public method]
      → _tokens_and_logprobs_comp(stat_dict, **kwargs)      [retry decorator, semaphore]
        → _tokens_and_logprobs_completion_wrapper(**kwargs) [patched for SGLang]
          → aiohttp POST to /generate
          → Returns (prompt_tokens, [output_tokens], [output_logprobs], [finish_reasons])
    → _create_sequence_node(...)                            [stores in current_nodes]
    → tool_call_parser.parse(completion_text)               [if parser configured]
    → Returns ChatCompletion with tool_calls
```
### SGLang Compatibility Patch (`environments/patches.py`)
VLLMServer's `_tokens_and_logprobs_completion_wrapper` is monkey-patched to handle SGLang's
different request/response format. Applied automatically at import time via `apply_patches()`.

```
SGLang request:  {"input_ids": [...], "sampling_params": {...}, "return_logprob": true}
SGLang response: {"meta_info": {"output_token_logprobs": [[logprob, token_id, text], ...]}}

VLLM request:  {"prompt": {"prompt_token_ids": [...]}, "logprobs": 0}
VLLM response: {"logprobs": [[{token_id: logprob}]], "finish_reasons": [...]}
```

Also handles RunPod serverless double-JSON wrapping (response body wrapped in quotes).
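
The core of that translation can be sketched from the two formats above; everything beyond the documented fields (the function name, the return shape) is an assumption, not the actual patch:

```python
import json

def normalize_sglang_body(body: str):
    """Unwrap a /generate response body into (token_ids, logprobs)."""
    data = json.loads(body)
    if isinstance(data, str):     # RunPod double-JSON wrapping: body is a quoted JSON string
        data = json.loads(data)
    triples = data["meta_info"]["output_token_logprobs"]   # [[logprob, token_id, text], ...]
    return [t[1] for t in triples], [t[0] for t in triples]
```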
### Key Design: Proper Tool Calling (NOT ICL)
```python
# CORRECT: pass tools= to chat_completion()
response = await server.chat_completion(
    messages=messages,
    tools=tool_schemas,  # ← tokenizer.apply_chat_template(tools=...) formats these
    temperature=1.0,
)
# Response has response.choices[0].message.tool_calls (structured objects)

# WRONG (old approach): embed tools in system prompt as XML
system_prompt = f"<tools>{json.dumps(tools)}</tools>"  # ← ICL, not proper training format
```
### Sandbox Backends (`atropos/backends/`)

Infrastructure for scaled sandbox execution, integrated into HermesAgentBaseEnv:

```
ToolBackend (Protocol)
├── NomadToolBackend → SlotPool → NomadClient + SandboxExecutor (HTTP)
│   ├── Docker driver (default)
│   └── Singularity driver (HPC)
└── ModalToolBackend → _ModalSandboxPool → modal.Sandbox.exec() (direct)
    └── _ModalMultiProfileManager (multi-profile support)
```

Two execution modes in HermesAgentBaseEnv (controlled by `tool_pool_mode` config):
- `default` - Local tool execution via handle_function_call() + ToolContext
- `modal` / `nomad` - Sandbox routing: slot acquire → setup workspace → agent loop → verify → release
Sandbox routing architecture:
```
collect_trajectory()
├── tool_pool_mode="default" → _collect_trajectory_local()
│   └── _run_agent_loop(tool_handler=None) → compute_reward(ctx)
│
└── tool_pool_mode="modal"/"nomad" → _collect_trajectory_sandbox()
    ├── backend.acquire(task_id) → Slot
    ├── exec_tool = backend.execute_batch wrapper → ExecutionResult
    ├── setup_trajectory_workspace(item, exec_tool)          [subclass hook]
    ├── _run_agent_loop(tool_handler=sandbox_tool_handler)
    │   ├── terminal → backend.execute_batch → JSON string
    │   └── other tools → handle_function_call (local)
    ├── verify_and_score_trajectory(item, result, exec_tool) [subclass hook]
    └── backend.release(slot, reset_workspace=True)          [finally]
```

Key interfaces:
- `exec_tool(tool_name, args, timeout)` → `ExecutionResult` (for env hooks)
- `tool_handler(tool_name, args, task_id)` → JSON string (for agent loop)
### Training Pipeline (Tinker + Atropos)
```
Terminal 1: run-api (port 8000)             ← Atropos Rollout API
Terminal 2: launch_training.py (port 8001)  ← Tinker Trainer + inference
Terminal 3: environment.py serve            ← Environment (rollouts)
```
113
memory-bank/techContext.md
Normal file
@@ -0,0 +1,113 @@
# Technical Context: Hermes-Agent

## Technologies Used

### Core Stack
- **Python 3.11+** - Primary language
- **OpenAI SDK** - For LLM API interactions (OpenAI-compatible)
- **OpenRouter** - Default LLM provider (supports multiple models)
- **Rich** - Terminal formatting and panels
- **prompt_toolkit** - Interactive input with history
- **Fire** - CLI argument parsing
- **PyYAML** - Configuration files
- **python-dotenv** - Environment variable management

### Tool Dependencies
- **Firecrawl** - Web search and extraction (`FIRECRAWL_API_KEY`)
- **mini-swe-agent** - Terminal tool backend (local/docker/singularity/modal/ssh)
- **agent-browser** - Browser automation (npm package)
- **Browserbase** - Cloud browser execution (`BROWSERBASE_API_KEY`)
- **FAL.ai** - Image generation with FLUX (`FAL_KEY`)
- **Nous API** - Vision and MoA tools (`NOUS_API_KEY`)

### Optional Dependencies
- **Modal** - Cloud compute for sandboxed environments
- **Singularity/Apptainer** - Rootless containers (HPC environments)
- **Docker** - Container isolation

## Development Setup
### Quick Start
```bash
# Clone with submodules
git clone --recurse-submodules https://github.com/NousResearch/Hermes-Agent.git
cd Hermes-Agent

# Create virtual environment
python3 -m venv venv
source venv/bin/activate

# Install dependencies
pip install -r requirements.txt
pip install -e ./mini-swe-agent

# Install browser tools (optional)
npm install

# Configure environment
cp .env.example .env
# Edit .env with your API keys
```
### Key Configuration Files
- `.env` - API keys and secrets
- `cli-config.yaml` - CLI configuration (model, terminal, toolsets, personalities)
- `configs/` - Batch run scripts and configuration

### Environment Variables

**Required for Full Functionality:**
- `OPENROUTER_API_KEY` - Primary LLM access
- `FIRECRAWL_API_KEY` - Web tools
- `NOUS_API_KEY` - Vision and reasoning tools
- `FAL_KEY` - Image generation

**Terminal Backend:**
- `TERMINAL_ENV` - Backend type: `local`, `docker`, `singularity`, `modal`, `ssh`
- `TERMINAL_CWD` - Working directory
- `TERMINAL_DOCKER_IMAGE` / `TERMINAL_SINGULARITY_IMAGE` - Container images
- `TERMINAL_SSH_HOST/USER/KEY` - SSH backend config
- `SUDO_PASSWORD` - Optional sudo support

**Browser:**
- `BROWSERBASE_API_KEY` - Browser automation
- `BROWSERBASE_PROJECT_ID` - Browserbase project
## Technical Constraints

1. **Context Window Limits** - Long tool outputs can exhaust context; trajectory compression helps
2. **API Rate Limits** - OpenRouter and tool APIs have rate limits; exponential backoff implemented
3. **Tool Availability** - Tools gracefully degrade if dependencies/keys are missing
4. **Async Compatibility** - Some tools are async, handled via `asyncio.run()` in sync context
## Dependency Graph

```
tools/*.py → tools/__init__.py → model_tools.py → toolsets.py → toolset_distributions.py
                                       ↑
run_agent.py ──────────────────────────┘
cli.py → run_agent.py (uses AIAgent with quiet_mode=True)
batch_runner.py → run_agent.py + toolset_distributions.py
```
## Tool Usage Patterns

### Adding a New Tool
1. Create `tools/your_tool.py` with handler + requirements check
2. Export in `tools/__init__.py`
3. Register in `model_tools.py` (definitions + handler routing)
4. Add to a toolset in `toolsets.py`
5. Optionally add to `toolset_distributions.py` for batch processing
### Tool Handler Pattern
```python
import json

def your_tool(param: str, task_id: str = None) -> str:
    """Execute the tool and return a JSON string result."""
    try:
        result = {"success": True, "data": "..."}
        return json.dumps(result, ensure_ascii=False)
    except Exception as e:
        return json.dumps({"error": str(e)}, ensure_ascii=False)
```

All tool handlers MUST return a JSON string, never raw dicts.
Submodule mini-swe-agent updated: 07aa6a7385...ee36b3d4e5
134
modal_profiles.yaml.example
Normal file
@@ -0,0 +1,134 @@
# Modal Sandbox Profiles Configuration
# =====================================
# This file defines different sandbox profiles for heterogeneous workloads.
# Copy to modal_profiles.yaml and customize as needed.
#
# Usage:
#   terminal_tool("python train.py", profile="pytorch-gpu")
#   terminal_tool("npm test", profile="node")
#
# Each profile can specify:
#   - image: Docker image to use
#   - gpu: GPU type (null, "T4", "A10G", "A100", "H100")
#   - cpu: CPU cores (float)
#   - memory: Memory in MB
#   - min_pool: Minimum warm sandboxes (cost vs latency tradeoff)
#   - max_pool: Maximum sandboxes (hard cost cap)
#   - idle_timeout: Server-side auto-cleanup in seconds
#   - max_lifetime: Maximum sandbox lifetime in seconds
#   - scale_down_idle: Client-side scale-down threshold in seconds
#   - workdir: Working directory inside container
#   - secrets: List of Modal Secret names to inject (created via dashboard/CLI)
#   - env_vars: Dict of environment variables to pass directly
#   - use_dotenv: If true, loads local .env file into sandbox
#
# SECRETS SETUP:
# Create secrets via Modal dashboard or CLI:
#   modal secret create huggingface-token HF_TOKEN=hf_xxx
#   modal secret create openai-key OPENAI_API_KEY=sk-xxx
# Then reference by name in the profile's secrets list.

# Default profile used when no profile is specified
default_profile: default
profiles:
  # Default Python environment - good for most tasks
  default:
    image: python:3.11
    gpu: null
    cpu: 1.0
    memory: 2048
    min_pool: 1          # Keep 1 warm for fast response
    max_pool: 5
    idle_timeout: 120    # Modal terminates if idle 2 min
    max_lifetime: 3600   # Max 1 hour
    scale_down_idle: 180
    workdir: /workspace
    secrets: []          # Add secret names here: ["my-api-keys"]
    env_vars: {}         # Add env vars here: {DEBUG: "1"}
    use_dotenv: false    # Set to true to load local .env

  # PyTorch with GPU for ML training/inference
  pytorch-gpu:
    image: pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
    gpu: T4              # Options: T4, A10G, A100, H100
    cpu: 4.0
    memory: 16384        # 16GB
    min_pool: 0          # Don't keep GPU sandboxes warm (expensive!)
    max_pool: 2
    idle_timeout: 60     # Shorter idle timeout for GPU (cost)
    max_lifetime: 1800   # 30 min max for GPU tasks
    scale_down_idle: 60
    workdir: /workspace
    # ML-specific secrets
    secrets:
      - huggingface-token   # HF_TOKEN env var
      - wandb-key           # WANDB_API_KEY env var
    env_vars:
      CUDA_VISIBLE_DEVICES: "0"
      PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"

  # High-end GPU for large models
  pytorch-a100:
    image: pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
    gpu: A100
    cpu: 8.0
    memory: 65536        # 64GB
    min_pool: 0
    max_pool: 1          # Only 1 at a time (very expensive)
    idle_timeout: 30
    max_lifetime: 3600
    scale_down_idle: 30
    workdir: /workspace
  # Node.js for JavaScript/TypeScript tasks
  node:
    image: node:18
    gpu: null
    cpu: 1.0
    memory: 2048
    min_pool: 0          # Create on-demand
    max_pool: 3
    idle_timeout: 120
    max_lifetime: 3600
    scale_down_idle: 180
    workdir: /workspace

  # High memory for data processing
  high-memory:
    image: python:3.11
    gpu: null
    cpu: 4.0
    memory: 32768        # 32GB
    min_pool: 0
    max_pool: 2
    idle_timeout: 120
    max_lifetime: 3600
    scale_down_idle: 180
    workdir: /workspace

  # Rust development environment
  rust:
    image: rust:1.75
    gpu: null
    cpu: 2.0
    memory: 4096
    min_pool: 0
    max_pool: 2
    idle_timeout: 120
    max_lifetime: 3600
    scale_down_idle: 180
    workdir: /workspace

  # Go development environment
  golang:
    image: golang:1.21
    gpu: null
    cpu: 2.0
    memory: 4096
    min_pool: 0
    max_pool: 2
    idle_timeout: 120
    max_lifetime: 3600
    scale_down_idle: 180
    workdir: /workspace
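
Profile lookup with the `default_profile` fallback described above might look like this; a subset of the example file is inlined as a plain dict rather than parsed from YAML:

```python
# Inlined subset of modal_profiles.yaml for illustration.
CONFIG = {
    "default_profile": "default",
    "profiles": {
        "default": {"image": "python:3.11", "gpu": None, "cpu": 1.0},
        "pytorch-gpu": {"image": "pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime", "gpu": "T4"},
    },
}

def get_profile(name=None):
    """Return the named profile, falling back to default_profile when name is None."""
    return CONFIG["profiles"][name or CONFIG["default_profile"]]
```

A call such as `terminal_tool("python train.py", profile="pytorch-gpu")` would resolve its sandbox settings through a lookup like this.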
37
nomad-dev.hcl
Normal file
@@ -0,0 +1,37 @@
# Nomad Development Configuration (Hermes-Agent)
# Run with: nomad agent -dev -config=nomad-dev.hcl
#
# This is intended for local development only.

client {
  enabled = true

  options {
    # Enable Docker volume mounts for persistent slot workspaces
    "docker.volumes.enabled" = "true"
  }
}

# Docker driver plugin configuration
plugin "docker" {
  config {
    # CRITICAL: Enable volume mounts
    volumes {
      enabled = true
    }

    # Allow privileged containers if needed
    allow_privileged = false

    # Garbage collection settings
    gc {
      image = true
      # NOTE: For local dev we often rely on locally built images like `atropos-sandbox:local`.
      # A short image GC delay can delete these between runs, causing confusing "Failed to pull"
      # crash loops. Keep this comfortably long; tighten it for CI/production if needed.
      image_delay = "24h"
      container   = true
    }
  }
}
31
nomad-singularity.hcl
Normal file
@@ -0,0 +1,31 @@
# Nomad Configuration for Singularity/Apptainer Sandbox
# Run with: nomad agent -dev -config=nomad-singularity.hcl
#
# This uses the raw_exec driver to run Apptainer containers.
# Suitable for HPC environments where Docker cannot run without sudo.

client {
  enabled = true

  options {
    # Enable raw_exec driver for Singularity/Apptainer
    "driver.raw_exec.enable" = "1"
  }
}

# raw_exec driver plugin configuration
plugin "raw_exec" {
  config {
    enabled = true
  }
}

# Optional: If you have the nomad-driver-singularity plugin installed,
# uncomment the following instead of using raw_exec:
# plugin "singularity" {
#   config {
#     enabled = true
#     # Allow bind mounts
#     bind_paths = ["/tmp", "/var/tmp"]
#   }
# }
@@ -19,6 +19,7 @@ dependencies = [
    "rich",
    "tenacity",
    "pyyaml",
    "prompt_toolkit",
    "requests",
    "jinja2",
    "pydantic>=2.0",
@@ -39,6 +40,20 @@ dev = ["pytest", "pytest-asyncio"]
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0"]
cron = ["croniter"]
cli = ["simple-term-menu"]
# Install Atropos + Tinker training integration from source.
# Uses tool_call_support branch for ManagedServer tool calling (PR #366).
atropos = [
    "atroposlib @ git+https://github.com/NousResearch/atropos.git@tool_call_support",
    "tinker @ git+https://github.com/thinking-machines-lab/tinker.git",
    # Atropos integration runtime deps (kept optional for Hermes-only users)
    "aiohttp",
    "fastapi",
    "uvicorn",
    "pyte",
    "torch",
    "wandb",
    "math-verify",
]
all = [
    "hermes-agent[modal]",
    "hermes-agent[messaging]",
@@ -50,9 +65,21 @@ all = [
[project.scripts]
hermes = "hermes_cli.main:main"
hermes-agent = "run_agent:main"
hermes-atropos-sandbox-smoke = "atropos.envs.sandbox_terminal_smoke_env:SandboxTerminalSmokeEnv.cli"
hermes-atropos-toolserver-smoke = "atropos.envs.toolserver_smoke_env:ToolServerSmokeEnv.cli"

[tool.setuptools]
py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajectory_compressor", "toolset_distributions", "cli"]
py-modules = [
    "run_agent",
    "model_tools",
    "toolsets",
    "batch_runner",
    "trajectory_compressor",
    "toolset_distributions",
    "atropos_compatible_agent",
    "local_server",
    "cli",
]

[tool.setuptools.packages.find]
include = ["tools", "hermes_cli", "gateway", "cron"]
include = ["tools", "hermes_cli", "gateway", "cron", "atropos", "atropos.*"]
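The `atropos` extra above keeps the training integration optional for Hermes-only users. A minimal sketch of the guard such code paths can use at runtime, assuming only that the optional package's import name is `atroposlib` (the `extra_available` helper is hypothetical, not part of the repo):

```python
import importlib.util

def extra_available(module_name: str) -> bool:
    """True if an optional dependency is importable in this environment."""
    return importlib.util.find_spec(module_name) is not None

# Guard Atropos-only code paths so base installs keep working:
HAS_ATROPOS = extra_available("atroposlib")
```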
44 run_agent.py
@@ -30,7 +30,6 @@ import threading
import uuid
from typing import List, Dict, Any, Optional
from openai import OpenAI
import fire
from datetime import datetime
from pathlib import Path

@@ -1581,6 +1580,16 @@ class AIAgent:
        if active_system_prompt:
            # Insert system message at the beginning
            api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages

        if os.getenv("HERMES_DEBUG_OPENAI_REQUEST") == "1":
            meta = {
                "model": self.model,
                "base_url": self.base_url,
                "messages": api_messages,
                "tools": self.tools if self.tools else None,
            }
            print("\n=== HERMES_DEBUG_OPENAI_REQUEST ===", flush=True)
            print(json.dumps(meta, ensure_ascii=False, indent=2)[:200_000], flush=True)

        # Calculate approximate request size for logging
        total_chars = sum(len(str(msg)) for msg in api_messages)
@@ -1594,12 +1603,13 @@ class AIAgent:
            print(f"{self.log_prefix} 📊 Request size: {len(api_messages)} messages, ~{approx_tokens:,} tokens (~{total_chars:,} chars)")
            print(f"{self.log_prefix} 🔧 Available tools: {len(self.tools) if self.tools else 0}")
        else:
            # Animated thinking spinner in quiet mode
            face = random.choice(KawaiiSpinner.KAWAII_THINKING)
            verb = random.choice(KawaiiSpinner.THINKING_VERBS)
            spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star'])
            thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type)
            thinking_spinner.start()
            # Animated thinking spinner in quiet mode (disable for wrappers/non-TTY usage)
            if os.getenv("HERMES_DISABLE_SPINNER") != "1":
                face = random.choice(KawaiiSpinner.KAWAII_THINKING)
                verb = random.choice(KawaiiSpinner.THINKING_VERBS)
                spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star'])
                thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type)
                thinking_spinner.start()

        # Log request details if verbose
        if self.verbose_logging:
@@ -1659,6 +1669,14 @@ class AIAgent:
            api_kwargs["extra_body"] = extra_body

        response = self.client.chat.completions.create(**api_kwargs)

        if os.getenv("HERMES_DEBUG_OPENAI_RESPONSE") == "1":
            try:
                dumped = response.model_dump()
            except Exception:
                dumped = getattr(response, "__dict__", {"repr": repr(response)})
            print("\n=== HERMES_DEBUG_OPENAI_RESPONSE: ChatCompletion (raw) ===", flush=True)
            print(json.dumps(dumped, ensure_ascii=False, indent=2), flush=True)

        api_duration = time.time() - api_start_time

@@ -2137,7 +2155,7 @@ class AIAgent:
        tool_start_time = time.time()

        # Execute the tool - with animated spinner in quiet mode
        if self.quiet_mode:
        if self.quiet_mode and os.getenv("HERMES_DISABLE_SPINNER") != "1":
            # Tool-specific spinner animations
            tool_spinners = {
                'web_search': ('arrows', ['🔍', '🌐', '📡', '🔎']),
@@ -2167,6 +2185,9 @@ class AIAgent:
            tool_duration = time.time() - tool_start_time
            cute_msg = self._get_cute_tool_message(function_name, function_args, tool_duration)
            spinner.stop(cute_msg)
        elif self.quiet_mode:
            function_result = handle_function_call(function_name, function_args, effective_task_id)
            tool_duration = time.time() - tool_start_time
        else:
            function_result = handle_function_call(function_name, function_args, effective_task_id)
            tool_duration = time.time() - tool_start_time
@@ -2635,4 +2656,11 @@ def main(


if __name__ == "__main__":
    try:
        import fire  # type: ignore
    except ModuleNotFoundError as exc:
        raise SystemExit(
            "Missing optional dependency 'fire'. Install hermes-agent with its CLI extras or add `fire` "
            f"to your environment. Original error: {exc}"
        ) from exc
    fire.Fire(main)
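The hunks above gate both the request/response dumps and the spinner on environment variables, so wrappers and non-TTY callers can turn them off. A small sketch of that env-gated debug pattern in isolation (the `maybe_dump_request` helper is illustrative; `HERMES_DEBUG_OPENAI_REQUEST` and the 200 000-character truncation come from the diff):

```python
import json
import os
from typing import Optional

def maybe_dump_request(meta: dict) -> Optional[str]:
    """Return a pretty-printed request dump when the debug flag is set, else None."""
    if os.getenv("HERMES_DEBUG_OPENAI_REQUEST") != "1":
        return None
    # Truncate like run_agent.py does, to keep huge message lists manageable.
    return json.dumps(meta, ensure_ascii=False, indent=2)[:200_000]

os.environ["HERMES_DEBUG_OPENAI_REQUEST"] = "1"
dump = maybe_dump_request({"model": "example-model", "messages": []})
```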
62 scripts/launch_llama_cpp_glm47_flash.sh Executable file
@@ -0,0 +1,62 @@
#!/usr/bin/env bash
set -euo pipefail

# Launch a local llama.cpp OpenAI-compatible server running GLM-4.7-Flash (GGUF).
#
# Requires:
#   - `llama-server` installed (e.g. `brew install llama.cpp`)
#
# Default settings are chosen to avoid clashing with Atropos sandbox_server
# (which commonly uses port 8080 in local dev).
#
# Usage:
#   Hermes-Agent/scripts/launch_llama_cpp_glm47_flash.sh
#
# Override defaults:
#   LLAMA_CPP_HOST=127.0.0.1 LLAMA_CPP_PORT=8082 \
#   LLAMA_CPP_HF_REPO=ggml-org/GLM-4.7-Flash-GGUF \
#   LLAMA_CPP_HF_FILE=GLM-4.7-Flash-Q4_K.gguf \
#   Hermes-Agent/scripts/launch_llama_cpp_glm47_flash.sh

HOST="${LLAMA_CPP_HOST:-127.0.0.1}"
PORT="${LLAMA_CPP_PORT:-8080}"
HF_REPO="${LLAMA_CPP_HF_REPO:-ggml-org/GLM-4.7-Flash-GGUF}"
HF_FILE="${LLAMA_CPP_HF_FILE:-GLM-4.7-Flash-Q4_K.gguf}"
ALIAS="${LLAMA_CPP_ALIAS:-glm-4.7-flash}"

if ! command -v llama-server >/dev/null 2>&1; then
  echo "Error: llama-server not found in PATH."
  echo "Install via Homebrew: brew install llama.cpp"
  exit 1
fi

echo "Launching llama.cpp server..."
echo "  host:  $HOST"
echo "  port:  $PORT"
echo "  repo:  $HF_REPO"
echo "  file:  $HF_FILE"
echo "  alias: $ALIAS"
echo
echo "Suggested env vars for Hermes/Atropos integration:"
echo "  export ATROPOS_SERVER_BASE_URL=http://${HOST}:${PORT}"
echo "  export ATROPOS_SERVER_MODEL=${ALIAS}"
echo "  export ATROPOS_SERVER_API_KEY=local"
echo

if command -v lsof >/dev/null 2>&1; then
  if lsof -nP -iTCP:"$PORT" -sTCP:LISTEN >/dev/null 2>&1; then
    echo "Error: port $PORT is already in use."
    echo "Pick a different port, e.g.:"
    echo "  LLAMA_CPP_PORT=8082 Hermes-Agent/scripts/launch_llama_cpp_glm47_flash.sh"
    exit 1
  fi
fi

exec llama-server \
  --host "$HOST" \
  --port "$PORT" \
  --hf-repo "$HF_REPO" \
  --hf-file "$HF_FILE" \
  --alias "$ALIAS" \
  -c 32768 \
  -n -1
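The script guards startup with an `lsof` check for a listener on the chosen port. A rough Python equivalent of that check, useful where `lsof` is unavailable (the `port_in_use` helper is a sketch, not part of the repo; a failing bind is treated as "busy"):

```python
import socket

def port_in_use(host: str, port: int) -> bool:
    """Rough stand-in for the script's lsof check: a bind that fails means busy."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.bind((host, port))
            return False
        except OSError:
            return True

# Example: hold a listener open, then probe its port.
listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listener.bind(("127.0.0.1", 0))  # port 0 = let the OS pick a free port
listener.listen(1)
busy_port = listener.getsockname()[1]
```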
70 scripts/launch_llama_cpp_hermes_4_36b.sh Executable file
@@ -0,0 +1,70 @@
#!/usr/bin/env bash
set -euo pipefail

# Launch a local llama.cpp OpenAI-compatible server running Hermes 4.3 36B (GGUF).
#
# Requires:
#   - `llama-server` installed (e.g. `brew install llama.cpp`)
#
# Note: Port choice can conflict with other local dev servers. If 8080 is already
# in use, override via `LLAMA_CPP_PORT=...`.
#
# Usage:
#   Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh
#
# Override defaults:
#   LLAMA_CPP_HOST=127.0.0.1 LLAMA_CPP_PORT=8082 \
#   LLAMA_CPP_HF_REPO=NousResearch/Hermes-4.3-36B-GGUF \
#   LLAMA_CPP_HF_FILE=hermes-4_3_36b-Q4_K_M.gguf \
#   LLAMA_CPP_ALIAS=hermes-4-36b \
#   LLAMA_CPP_PARALLEL=4 LLAMA_CPP_THREADS_HTTP=4 \
#   Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh

HOST="${LLAMA_CPP_HOST:-127.0.0.1}"
PORT="${LLAMA_CPP_PORT:-8080}"
HF_REPO="${LLAMA_CPP_HF_REPO:-NousResearch/Hermes-4.3-36B-GGUF}"
HF_FILE="${LLAMA_CPP_HF_FILE:-hermes-4_3_36b-Q4_K_M.gguf}"
ALIAS="${LLAMA_CPP_ALIAS:-hermes-4-36b}"
PARALLEL="${LLAMA_CPP_PARALLEL:-4}"
THREADS_HTTP="${LLAMA_CPP_THREADS_HTTP:-4}"

if ! command -v llama-server >/dev/null 2>&1; then
  echo "Error: llama-server not found in PATH."
  echo "Install via Homebrew: brew install llama.cpp"
  exit 1
fi

echo "Launching llama.cpp server..."
echo "  host:  $HOST"
echo "  port:  $PORT"
echo "  repo:  $HF_REPO"
echo "  file:  $HF_FILE"
echo "  alias: $ALIAS"
echo "  slots: $PARALLEL"
echo
echo "Suggested env vars for Hermes/Atropos integration:"
echo "  export ATROPOS_SERVER_BASE_URL=http://${HOST}:${PORT}"
echo "  export ATROPOS_SERVER_MODEL=${ALIAS}"
echo "  export ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B"
echo "  export ATROPOS_SERVER_API_KEY=local"
echo

if command -v lsof >/dev/null 2>&1; then
  if lsof -nP -iTCP:"$PORT" -sTCP:LISTEN >/dev/null 2>&1; then
    echo "Error: port $PORT is already in use."
    echo "Pick a different port, e.g.:"
    echo "  LLAMA_CPP_PORT=8082 Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh"
    exit 1
  fi
fi

exec llama-server \
  --host "$HOST" \
  --port "$PORT" \
  --hf-repo "$HF_REPO" \
  --hf-file "$HF_FILE" \
  --alias "$ALIAS" \
  --parallel "$PARALLEL" \
  --threads-http "$THREADS_HTTP" \
  -c 32768 \
  -n -1
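Both launch scripts print the same family of `ATROPOS_*` exports. A tiny sketch of deriving that mapping programmatically, e.g. to feed a subprocess environment (the `atropos_env` helper is hypothetical, not part of the repo):

```python
def atropos_env(host: str, port: int, alias: str, tokenizer: str = "") -> dict:
    """Build the ATROPOS_* variables the launch scripts suggest exporting."""
    env = {
        "ATROPOS_SERVER_BASE_URL": f"http://{host}:{port}",
        "ATROPOS_SERVER_MODEL": alias,
        "ATROPOS_SERVER_API_KEY": "local",
    }
    if tokenizer:
        env["ATROPOS_TOKENIZER_NAME"] = tokenizer
    return env

env = atropos_env("127.0.0.1", 8080, "hermes-4-36b", "NousResearch/Hermes-4.3-36B")
```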
15 tests/test_data/checkpoint_test_dataset.jsonl Normal file
@@ -0,0 +1,15 @@
{"prompt": "Test prompt 0: What is 2+2? Just answer briefly.", "test_id": 0}
{"prompt": "Test prompt 1: What is 2+2? Just answer briefly.", "test_id": 1}
{"prompt": "Test prompt 2: What is 2+2? Just answer briefly.", "test_id": 2}
{"prompt": "Test prompt 3: What is 2+2? Just answer briefly.", "test_id": 3}
{"prompt": "Test prompt 4: What is 2+2? Just answer briefly.", "test_id": 4}
{"prompt": "Test prompt 5: What is 2+2? Just answer briefly.", "test_id": 5}
{"prompt": "Test prompt 6: What is 2+2? Just answer briefly.", "test_id": 6}
{"prompt": "Test prompt 7: What is 2+2? Just answer briefly.", "test_id": 7}
{"prompt": "Test prompt 8: What is 2+2? Just answer briefly.", "test_id": 8}
{"prompt": "Test prompt 9: What is 2+2? Just answer briefly.", "test_id": 9}
{"prompt": "Test prompt 10: What is 2+2? Just answer briefly.", "test_id": 10}
{"prompt": "Test prompt 11: What is 2+2? Just answer briefly.", "test_id": 11}
{"prompt": "Test prompt 12: What is 2+2? Just answer briefly.", "test_id": 12}
{"prompt": "Test prompt 13: What is 2+2? Just answer briefly.", "test_id": 13}
{"prompt": "Test prompt 14: What is 2+2? Just answer briefly.", "test_id": 14}
tests/test_data/checkpoint_test_resume_partial.jsonl
Normal file
5
tests/test_data/checkpoint_test_resume_partial.jsonl
Normal file
@@ -0,0 +1,5 @@
|
{"prompt": "Test prompt 0: What is 2+2? Just answer briefly.", "test_id": 0}
{"prompt": "Test prompt 1: What is 2+2? Just answer briefly.", "test_id": 1}
{"prompt": "Test prompt 2: What is 2+2? Just answer briefly.", "test_id": 2}
{"prompt": "Test prompt 3: What is 2+2? Just answer briefly.", "test_id": 3}
{"prompt": "Test prompt 4: What is 2+2? Just answer briefly.", "test_id": 4}
1082 tests/test_modal_integration.py Normal file
File diff suppressed because it is too large
923 tests/test_modal_stress.py Normal file
@@ -0,0 +1,923 @@
#!/usr/bin/env python3
"""
Modal Integration Stress Tests & Full Integration Tests

This test suite includes:
1. Stress tests for Modal sandbox pools (concurrent load, scaling)
2. Atropos backend tests (requires atroposlib)
3. mini-swe-agent integration tests

Prerequisites:
    # Install dev dependencies
    pip install -e '.[dev,modal]'

    # Install atroposlib for Atropos tests
    pip install -e '.[atropos]'

    # Clone mini-swe-agent (if not present)
    git clone https://github.com/anthropics/mini-swe-agent.git mini-swe-agent
    # Or as submodule:
    git submodule add https://github.com/anthropics/mini-swe-agent.git mini-swe-agent

Run with:
    # All tests
    python tests/test_modal_stress.py

    # Stress tests only
    python tests/test_modal_stress.py --category stress

    # Atropos tests only
    python tests/test_modal_stress.py --category atropos

    # Mini-swe-agent tests only
    python tests/test_modal_stress.py --category miniswe

    # Dry run (no Modal calls)
    python tests/test_modal_stress.py --dry-run
"""

import asyncio
import json
import os
import sys
import time
import random
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass

# Add parent to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))


# =============================================================================
# Test Configuration
# =============================================================================

@dataclass
class StressTestConfig:
    dry_run: bool = False
    verbose: bool = True
    category: Optional[str] = None
    # Stress test parameters (reduced defaults for faster first-run)
    concurrent_tasks: int = 3  # Start small - Modal cold starts are slow
    total_operations: int = 10
    max_sandboxes: int = 3
    slots_per_sandbox: int = 3


# =============================================================================
# Test Results Tracking
# =============================================================================

class TestResults:
    def __init__(self):
        self.passed: List[str] = []
        self.failed: List[Tuple[str, str]] = []
        self.skipped: List[Tuple[str, str]] = []
        self.metrics: Dict[str, Any] = {}

    def record_pass(self, name: str, metrics: Optional[Dict] = None):
        self.passed.append(name)
        if metrics:
            self.metrics[name] = metrics
        print(f"  ✅ {name}")
        if metrics:
            for k, v in metrics.items():
                print(f"     📊 {k}: {v}")

    def record_fail(self, name: str, error: str):
        self.failed.append((name, error))
        print(f"  ❌ {name}: {error}")

    def record_skip(self, name: str, reason: str):
        self.skipped.append((name, reason))
        print(f"  ⏭️ {name}: {reason}")

    def summary(self):
        total = len(self.passed) + len(self.failed) + len(self.skipped)
        print(f"\n{'='*70}")
        print(f"STRESS TEST RESULTS: {len(self.passed)}/{total} passed")
        print(f"  Passed:  {len(self.passed)}")
        print(f"  Failed:  {len(self.failed)}")
        print(f"  Skipped: {len(self.skipped)}")

        if self.failed:
            print(f"\nFailed tests:")
            for name, error in self.failed:
                print(f"  - {name}: {error}")

        if self.metrics:
            print(f"\nPerformance Metrics:")
            for test, metrics in self.metrics.items():
                print(f"  {test}:")
                for k, v in metrics.items():
                    print(f"    - {k}: {v}")

        return len(self.failed) == 0


results = TestResults()
# =============================================================================
# Helper: Atropos Import
# =============================================================================

def try_import_atropos():
    """Try importing Atropos backend components."""
    try:
        from atropos.backends.modal_backend import (
            ModalToolBackend, ModalSandboxConfig,
            _ModalMultiProfileManager
        )
        from atropos.slots.slot import Slot, SlotState
        return ModalToolBackend, ModalSandboxConfig, Slot, SlotState
    except (ImportError, ModuleNotFoundError) as e:
        return None


def try_import_miniswe():
    """Try importing mini-swe-agent components."""
    try:
        # Check if mini-swe-agent path exists and has content
        mini_swe_path = Path(__file__).parent.parent / "mini-swe-agent" / "src"
        if mini_swe_path.exists() and list(mini_swe_path.iterdir()):
            sys.path.insert(0, str(mini_swe_path))
            import minisweagent
            return minisweagent
        return None
    except (ImportError, ModuleNotFoundError) as e:
        return None
# =============================================================================
# CATEGORY 1: Stress Tests (Terminal Tool)
# =============================================================================

def test_stress_concurrent_tasks(config: StressTestConfig):
    """Stress test: Multiple concurrent task_ids hitting the pool."""
    if config.dry_run:
        results.record_skip("test_stress_concurrent_tasks", "Dry run mode")
        return

    from tools.terminal_tool import terminal_tool, cleanup_vm

    original_env = os.environ.get("TERMINAL_ENV")
    os.environ["TERMINAL_ENV"] = "modal"

    try:
        num_tasks = config.concurrent_tasks
        task_ids = [f"stress-concurrent-{i}-{int(time.time())}" for i in range(num_tasks)]

        start_time = time.time()
        errors = []
        successes = 0

        def run_task(task_id: str) -> Tuple[bool, str]:
            try:
                result = json.loads(terminal_tool(
                    f"echo 'Hello from {task_id}' && sleep 0.5",
                    task_id=task_id,
                ))
                success = result["exit_code"] == 0

                # IMPORTANT: Clean up immediately after task completes
                # This releases the sandbox back to the pool for other tasks
                try:
                    cleanup_vm(task_id)
                except:
                    pass

                if success:
                    return True, ""
                # Include more details for debugging
                error_detail = result.get("error", "no error message")
                output = result.get("output", "")[:100]  # First 100 chars
                return False, f"Exit code: {result['exit_code']}, error: {error_detail}, output: {output}"
            except Exception as e:
                # Clean up even on failure
                try:
                    cleanup_vm(task_id)
                except:
                    pass
                import traceback
                return False, f"Exception: {str(e)}\n{traceback.format_exc()}"

        # Run all tasks concurrently using threads
        with ThreadPoolExecutor(max_workers=num_tasks) as executor:
            futures = {executor.submit(run_task, tid): tid for tid in task_ids}

            for future in as_completed(futures):
                task_id = futures[future]
                try:
                    success, error = future.result(timeout=60)
                    if success:
                        successes += 1
                    else:
                        errors.append(f"{task_id}: {error}")
                except Exception as e:
                    errors.append(f"{task_id}: {str(e)}")

        elapsed = time.time() - start_time

        # No need for cleanup here - each task cleans up immediately

        # Report
        success_rate = successes / num_tasks * 100

        if success_rate >= 90:  # Allow 10% failure rate for stress test
            results.record_pass("test_stress_concurrent_tasks", {
                "concurrent_tasks": num_tasks,
                "successes": successes,
                "failures": len(errors),
                "success_rate": f"{success_rate:.1f}%",
                "total_time": f"{elapsed:.2f}s",
                "avg_time_per_task": f"{elapsed/num_tasks:.2f}s",
            })
        else:
            results.record_fail(
                "test_stress_concurrent_tasks",
                f"Success rate {success_rate:.1f}% < 90%. Errors: {errors[:3]}"
            )

    except Exception as e:
        results.record_fail("test_stress_concurrent_tasks", str(e))
    finally:
        if original_env:
            os.environ["TERMINAL_ENV"] = original_env
        elif "TERMINAL_ENV" in os.environ:
            del os.environ["TERMINAL_ENV"]
def test_stress_rapid_fire(config: StressTestConfig):
    """Stress test: Rapid sequential commands to same task_id."""
    if config.dry_run:
        results.record_skip("test_stress_rapid_fire", "Dry run mode")
        return

    from tools.terminal_tool import terminal_tool, cleanup_vm

    original_env = os.environ.get("TERMINAL_ENV")
    os.environ["TERMINAL_ENV"] = "modal"

    try:
        task_id = f"stress-rapid-{int(time.time())}"
        num_commands = config.total_operations

        start_time = time.time()
        successes = 0
        errors = []

        for i in range(num_commands):
            try:
                result = json.loads(terminal_tool(f"echo {i}", task_id=task_id))
                if result["exit_code"] == 0 and str(i) in result["output"]:
                    successes += 1
                else:
                    errors.append(f"Command {i}: unexpected result")
            except Exception as e:
                errors.append(f"Command {i}: {str(e)}")

        elapsed = time.time() - start_time
        cleanup_vm(task_id)

        success_rate = successes / num_commands * 100
        commands_per_second = num_commands / elapsed

        if success_rate >= 95:
            results.record_pass("test_stress_rapid_fire", {
                "total_commands": num_commands,
                "successes": successes,
                "success_rate": f"{success_rate:.1f}%",
                "total_time": f"{elapsed:.2f}s",
                "commands_per_second": f"{commands_per_second:.1f}",
            })
        else:
            results.record_fail(
                "test_stress_rapid_fire",
                f"Success rate {success_rate:.1f}% < 95%"
            )

    except Exception as e:
        results.record_fail("test_stress_rapid_fire", str(e))
    finally:
        if original_env:
            os.environ["TERMINAL_ENV"] = original_env
        elif "TERMINAL_ENV" in os.environ:
            del os.environ["TERMINAL_ENV"]
def test_stress_pool_scaling(config: StressTestConfig):
    """Stress test: Force pool to scale up and down by running tasks in batches."""
    if config.dry_run:
        results.record_skip("test_stress_pool_scaling", "Dry run mode")
        return

    from tools.terminal_tool import terminal_tool, cleanup_vm, _ModalPoolManager

    original_env = os.environ.get("TERMINAL_ENV")
    os.environ["TERMINAL_ENV"] = "modal"

    try:
        # Run tasks in batches matching max_sandboxes to test pool reuse
        # This verifies sandboxes can be acquired, used, released, and reused
        batch_size = config.max_sandboxes
        num_batches = 3
        total_tasks = batch_size * num_batches

        start_time = time.time()
        successes = 0

        for batch in range(num_batches):
            task_ids = [f"stress-scale-{batch}-{i}-{int(time.time())}" for i in range(batch_size)]

            def run_task(task_id: str):
                try:
                    result = json.loads(terminal_tool(
                        "echo done",  # Fast command to test scaling
                        task_id=task_id,
                    ))
                    success = result["exit_code"] == 0
                    try:
                        cleanup_vm(task_id)
                    except:
                        pass
                    return success
                except:
                    try:
                        cleanup_vm(task_id)
                    except:
                        pass
                    return False

            # Run batch concurrently
            with ThreadPoolExecutor(max_workers=batch_size) as executor:
                batch_results = list(executor.map(run_task, task_ids))
            successes += sum(batch_results)

        elapsed = time.time() - start_time

        # Check pool status
        try:
            manager = _ModalPoolManager.get_instance()
            pool_status = manager.get_status() if hasattr(manager, 'get_status') else {}
        except:
            pool_status = {}

        success_rate = successes / total_tasks * 100

        if success_rate >= 80:  # Allow some tolerance
            results.record_pass("test_stress_pool_scaling", {
                "total_tasks": total_tasks,
                "num_batches": num_batches,
                "batch_size": batch_size,
                "successes": successes,
                "success_rate": f"{success_rate:.1f}%",
                "total_time": f"{elapsed:.2f}s",
                "pool_status": pool_status,
            })
        else:
            results.record_fail(
                "test_stress_pool_scaling",
                f"Success rate {success_rate:.1f}% < 80%"
            )

    except Exception as e:
        results.record_fail("test_stress_pool_scaling", str(e))
    finally:
        if original_env:
            os.environ["TERMINAL_ENV"] = original_env
        elif "TERMINAL_ENV" in os.environ:
            del os.environ["TERMINAL_ENV"]
def test_stress_large_output(config: StressTestConfig):
    """Stress test: Commands producing large output."""
    if config.dry_run:
        results.record_skip("test_stress_large_output", "Dry run mode")
        return

    from tools.terminal_tool import terminal_tool, cleanup_vm

    original_env = os.environ.get("TERMINAL_ENV")
    os.environ["TERMINAL_ENV"] = "modal"

    try:
        task_id = f"stress-large-{int(time.time())}"

        # First verify basic connectivity with simple command
        warmup = json.loads(terminal_tool("echo warmup", task_id=task_id))
        if warmup["exit_code"] != 0:
            results.record_fail(
                "test_stress_large_output",
                f"Warmup failed: {warmup.get('error', 'unknown')}"
            )
            return

        # Generate output - use seq which is more portable
        start_time = time.time()
        result = json.loads(terminal_tool(
            'seq 1 500 | while read i; do echo "Line $i: This is test content for large output"; done',
            task_id=task_id,
            timeout=60,
        ))
        elapsed = time.time() - start_time

        cleanup_vm(task_id)

        output_size = len(result.get("output", ""))
        error_msg = result.get("error", "")

        if result["exit_code"] == 0 and output_size > 5000:
            results.record_pass("test_stress_large_output", {
                "output_size": f"{output_size:,} bytes",
                "time": f"{elapsed:.2f}s",
                "throughput": f"{output_size/elapsed/1024:.1f} KB/s" if elapsed > 0 else "N/A",
            })
        else:
            results.record_fail(
                "test_stress_large_output",
                f"Exit code: {result['exit_code']}, output size: {output_size}, error: {error_msg}"
            )

    except Exception as e:
        import traceback
        results.record_fail("test_stress_large_output", f"{str(e)}\n{traceback.format_exc()}")
    finally:
        try:
            cleanup_vm(task_id)
        except:
            pass
        if original_env:
            os.environ["TERMINAL_ENV"] = original_env
        elif "TERMINAL_ENV" in os.environ:
            del os.environ["TERMINAL_ENV"]
def test_stress_error_recovery(config: StressTestConfig):
    """Stress test: Commands that fail and verify sandbox continues working."""
    if config.dry_run:
        results.record_skip("test_stress_error_recovery", "Dry run mode")
        return

    from tools.terminal_tool import terminal_tool, cleanup_vm

    original_env = os.environ.get("TERMINAL_ENV")
    os.environ["TERMINAL_ENV"] = "modal"

    try:
        task_id = f"stress-error-{int(time.time())}"

        # Run some failing commands
        failing_commands = [
            "exit 1",
            "false",
            "cat /nonexistent/file",
            "command_that_does_not_exist",
        ]

        for cmd in failing_commands:
            result = json.loads(terminal_tool(cmd, task_id=task_id))
            # These should fail but not crash
            assert result["exit_code"] != 0 or result.get("error"), f"Expected failure for: {cmd}"

        # Now run a command that should succeed
        result = json.loads(terminal_tool("echo 'recovery success'", task_id=task_id))

        cleanup_vm(task_id)

        if result["exit_code"] == 0 and "recovery success" in result["output"]:
            results.record_pass("test_stress_error_recovery", {
                "failed_commands": len(failing_commands),
                "recovery": "success",
            })
        else:
            results.record_fail(
                "test_stress_error_recovery",
                f"Recovery failed: {result}"
            )

    except Exception as e:
        results.record_fail("test_stress_error_recovery", str(e))
    finally:
        if original_env:
            os.environ["TERMINAL_ENV"] = original_env
        elif "TERMINAL_ENV" in os.environ:
            del os.environ["TERMINAL_ENV"]
# =============================================================================
# CATEGORY 2: Atropos Backend Stress Tests
# =============================================================================


async def test_atropos_stress_slot_churn(config: StressTestConfig):
    """Atropos stress test: Rapid slot acquire/release cycles."""
    if config.dry_run:
        results.record_skip("test_atropos_stress_slot_churn", "Dry run mode")
        return

    imports = try_import_atropos()
    if imports is None:
        results.record_skip("test_atropos_stress_slot_churn", "Requires atroposlib")
        return

    ModalToolBackend, ModalSandboxConfig, _, _ = imports

    try:
        backend_config = ModalSandboxConfig(
            app_name=f"stress-churn-{int(time.time())}",
            min_sandboxes=1,
            max_sandboxes=3,
            slots_per_sandbox=5,
        )

        backend = ModalToolBackend(backend_config)
        await backend.start()

        try:
            num_cycles = config.total_operations
            start_time = time.time()
            successes = 0

            for i in range(num_cycles):
                try:
                    slot = await backend.acquire(f"churn-{i}")

                    # Quick command
                    results_list = await backend.execute_batch([
                        (slot, "bash", {"command": f"echo {i}"})
                    ])

                    if results_list[0].success:
                        successes += 1

                    await backend.release(slot, reset_workspace=(i % 5 == 0))
                except Exception:
                    pass  # Count as failure

            elapsed = time.time() - start_time
            success_rate = successes / num_cycles * 100

            if success_rate >= 90:
                results.record_pass("test_atropos_stress_slot_churn", {
                    "cycles": num_cycles,
                    "successes": successes,
                    "success_rate": f"{success_rate:.1f}%",
                    "total_time": f"{elapsed:.2f}s",
                    "cycles_per_second": f"{num_cycles/elapsed:.1f}",
                })
            else:
                results.record_fail(
                    "test_atropos_stress_slot_churn",
                    f"Success rate {success_rate:.1f}% < 90%"
                )

        finally:
            await backend.stop(purge=True)

    except Exception as e:
        results.record_fail("test_atropos_stress_slot_churn", str(e))


async def test_atropos_stress_parallel_batches(config: StressTestConfig):
    """Atropos stress test: Multiple parallel batch executions."""
    if config.dry_run:
        results.record_skip("test_atropos_stress_parallel_batches", "Dry run mode")
        return

    imports = try_import_atropos()
    if imports is None:
        results.record_skip("test_atropos_stress_parallel_batches", "Requires atroposlib")
        return

    ModalToolBackend, ModalSandboxConfig, _, _ = imports

    try:
        backend_config = ModalSandboxConfig(
            app_name=f"stress-batch-{int(time.time())}",
            min_sandboxes=2,
            max_sandboxes=4,
            slots_per_sandbox=5,
        )

        backend = ModalToolBackend(backend_config)
        await backend.start()

        try:
            num_slots = 10
            slots = []

            # Acquire multiple slots
            for i in range(num_slots):
                slot = await backend.acquire(f"batch-{i}")
                slots.append(slot)

            # Run multiple batches in parallel
            start_time = time.time()
            num_batches = 5

            async def run_batch(batch_id: int):
                requests = [
                    (slot, "bash", {"command": f"echo 'batch{batch_id}-slot{i}'"})
                    for i, slot in enumerate(slots)
                ]
                return await backend.execute_batch(requests)

            batch_tasks = [run_batch(i) for i in range(num_batches)]
            all_results = await asyncio.gather(*batch_tasks)

            elapsed = time.time() - start_time

            # Count successes
            total_commands = num_batches * num_slots
            successes = sum(
                1 for batch_result in all_results
                for r in batch_result
                if r.success
            )

            # Release slots
            for slot in slots:
                await backend.release(slot)

            success_rate = successes / total_commands * 100

            if success_rate >= 90:
                results.record_pass("test_atropos_stress_parallel_batches", {
                    "batches": num_batches,
                    "slots": num_slots,
                    "total_commands": total_commands,
                    "successes": successes,
                    "success_rate": f"{success_rate:.1f}%",
                    "total_time": f"{elapsed:.2f}s",
                    "commands_per_second": f"{total_commands/elapsed:.1f}",
                })
            else:
                results.record_fail(
                    "test_atropos_stress_parallel_batches",
                    f"Success rate {success_rate:.1f}% < 90%"
                )

        finally:
            await backend.stop(purge=True)

    except Exception as e:
        results.record_fail("test_atropos_stress_parallel_batches", str(e))


async def test_atropos_stress_multi_profile_load(config: StressTestConfig):
    """Atropos stress test: Load across multiple profiles."""
    if config.dry_run:
        results.record_skip("test_atropos_stress_multi_profile_load", "Dry run mode")
        return

    imports = try_import_atropos()
    if imports is None:
        results.record_skip("test_atropos_stress_multi_profile_load", "Requires atroposlib")
        return

    ModalToolBackend, ModalSandboxConfig, _, _ = imports

    try:
        backend = ModalToolBackend.with_profiles(
            app_name=f"stress-multiprofile-{int(time.time())}",
            profiles={
                "cpu-light": ModalSandboxConfig(
                    name="cpu-light",
                    cpu=0.5,
                    memory=1024,
                    min_sandboxes=1,
                    max_sandboxes=2,
                    slots_per_sandbox=5,
                ),
                "cpu-heavy": ModalSandboxConfig(
                    name="cpu-heavy",
                    cpu=2.0,
                    memory=4096,
                    min_sandboxes=0,
                    max_sandboxes=2,
                    slots_per_sandbox=3,
                ),
            }
        )

        await backend.start(profiles_to_start=["cpu-light", "cpu-heavy"])

        try:
            num_tasks_per_profile = 5
            slots = []

            # Acquire from both profiles
            for i in range(num_tasks_per_profile):
                light_slot = await backend.acquire(f"light-{i}", profile="cpu-light")
                heavy_slot = await backend.acquire(f"heavy-{i}", profile="cpu-heavy")
                slots.append((light_slot, "cpu-light"))
                slots.append((heavy_slot, "cpu-heavy"))

            # Execute batch across all profiles
            start_time = time.time()

            requests = [
                (slot, "bash", {"command": f"echo 'profile={profile}'"})
                for slot, profile in slots
            ]

            batch_results = await backend.execute_batch(requests)
            elapsed = time.time() - start_time

            successes = sum(1 for r in batch_results if r.success)

            # Release all
            for slot, _ in slots:
                await backend.release(slot)

            status = backend.get_status()

            success_rate = successes / len(slots) * 100

            if success_rate >= 90:
                results.record_pass("test_atropos_stress_multi_profile_load", {
                    "profiles": 2,
                    "tasks_per_profile": num_tasks_per_profile,
                    "total_tasks": len(slots),
                    "successes": successes,
                    "success_rate": f"{success_rate:.1f}%",
                    "time": f"{elapsed:.2f}s",
                    "status": status,
                })
            else:
                results.record_fail(
                    "test_atropos_stress_multi_profile_load",
                    f"Success rate {success_rate:.1f}% < 90%"
                )

        finally:
            await backend.stop(purge=True)

    except Exception as e:
        results.record_fail("test_atropos_stress_multi_profile_load", str(e))


# =============================================================================
# CATEGORY 3: Mini-SWE-Agent Integration Tests
# =============================================================================


def test_miniswe_environment_available():
    """Check if mini-swe-agent is properly set up."""
    mini_swe_path = Path(__file__).parent.parent / "mini-swe-agent" / "src"

    if not mini_swe_path.exists():
        results.record_skip(
            "test_miniswe_environment_available",
            "mini-swe-agent not found. Run: git clone https://github.com/anthropics/mini-swe-agent.git mini-swe-agent"
        )
        return

    if not list(mini_swe_path.iterdir()):
        results.record_skip(
            "test_miniswe_environment_available",
            "mini-swe-agent directory is empty. Run: git submodule update --init"
        )
        return

    miniswe = try_import_miniswe()
    if miniswe is None:
        results.record_fail(
            "test_miniswe_environment_available",
            "Failed to import minisweagent module"
        )
        return

    results.record_pass("test_miniswe_environment_available", {
        "path": str(mini_swe_path),
        "module": miniswe.__name__,
    })


def test_miniswe_modal_backend(config: StressTestConfig):
    """Test mini-swe-agent with Modal backend."""
    if config.dry_run:
        results.record_skip("test_miniswe_modal_backend", "Dry run mode")
        return

    miniswe = try_import_miniswe()
    if miniswe is None:
        results.record_skip(
            "test_miniswe_modal_backend",
            "mini-swe-agent not available"
        )
        return

    try:
        # Check if ModalEnvironment exists in minisweagent
        if not hasattr(miniswe, 'ModalEnvironment'):
            results.record_skip(
                "test_miniswe_modal_backend",
                "minisweagent.ModalEnvironment not found"
            )
            return

        # Create Modal environment
        env = miniswe.ModalEnvironment(
            image="python:3.11",
            timeout=60,
        )

        # Execute a command
        result = env.execute("echo 'Hello from mini-swe-agent Modal'")

        env.cleanup()

        if "Hello from mini-swe-agent Modal" in str(result):
            results.record_pass("test_miniswe_modal_backend")
        else:
            results.record_fail(
                "test_miniswe_modal_backend",
                f"Unexpected result: {result}"
            )

    except Exception as e:
        results.record_fail("test_miniswe_modal_backend", str(e))


# =============================================================================
# Test Runner
# =============================================================================


def run_sync_tests(config: StressTestConfig):
    """Run synchronous tests."""
    if config.category in (None, "stress"):
        print("\n" + "="*70)
        print("STRESS TESTS (Terminal Tool)")
        print("="*70)

        test_stress_concurrent_tasks(config)
        test_stress_rapid_fire(config)
        test_stress_pool_scaling(config)
        test_stress_large_output(config)
        test_stress_error_recovery(config)

    if config.category in (None, "miniswe"):
        print("\n" + "="*70)
        print("MINI-SWE-AGENT INTEGRATION TESTS")
        print("="*70)

        test_miniswe_environment_available()
        test_miniswe_modal_backend(config)


async def run_async_tests(config: StressTestConfig):
    """Run asynchronous tests."""
    if config.category in (None, "atropos"):
        print("\n" + "="*70)
        print("ATROPOS BACKEND STRESS TESTS")
        print("="*70)

        await test_atropos_stress_slot_churn(config)
        await test_atropos_stress_parallel_batches(config)
        await test_atropos_stress_multi_profile_load(config)


def main():
    import argparse

    parser = argparse.ArgumentParser(description="Modal Stress Test Suite")
    parser.add_argument("--dry-run", action="store_true", help="Skip tests requiring Modal")
    parser.add_argument("--category", choices=["stress", "atropos", "miniswe"], help="Run specific category")
    parser.add_argument("--concurrent", type=int, default=10, help="Number of concurrent tasks")
    parser.add_argument("--operations", type=int, default=50, help="Total operations for stress tests")
    parser.add_argument("--verbose", action="store_true", default=True)
    args = parser.parse_args()

    config = StressTestConfig(
        dry_run=args.dry_run,
        verbose=args.verbose,
        category=args.category,
        concurrent_tasks=args.concurrent,
        total_operations=args.operations,
    )

    print("="*70)
    print("MODAL STRESS & INTEGRATION TEST SUITE")
    print("="*70)
    print(f"Mode: {'DRY RUN' if config.dry_run else 'LIVE'}")
    print(f"Category: {config.category or 'ALL'}")
    print(f"Concurrent tasks: {config.concurrent_tasks}")
    print(f"Total operations: {config.total_operations}")

    # Run sync tests
    run_sync_tests(config)

    # Run async tests
    asyncio.run(run_async_tests(config))

    # Summary
    success = results.summary()
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
@@ -236,6 +236,63 @@ def test_environment_isolation():
    return isolated


def test_pool_status():
    """Test that the Modal pool manager reports status correctly."""
    print("\n" + "=" * 60)
    print("TEST 7: Pool Status")
    print("=" * 60)

    try:
        # Import pool manager
        _ModalPoolManager = terminal_module._ModalPoolManager

        # Get pool manager instance
        manager = _ModalPoolManager.get_instance()
        status = manager.get_status()

        print(f"\nPool Manager Status:")
        print(f"  App name: {manager.app_name}")
        print(f"  Default profile: {manager.default_profile}")
        print(f"  Available profiles: {list(manager.profiles.keys())}")
        print(f"  Active pools: {list(status.keys())}")

        for pool_name, pool_status in status.items():
            print(f"\n  Pool '{pool_name}':")
            print(f"    Size: {pool_status['pool_size']}/{pool_status['max_pool']}")
            print(f"    In use: {pool_status['in_use']}")
            print(f"    Min pool: {pool_status['min_pool']}")

        print(f"\nTest: ✅ Passed")
        return True
    except Exception as e:
        print(f"\nError: {e}")
        print(f"\nTest: ❌ Failed")
        return False


def test_profile_selection():
    """Test that profile parameter is accepted (even if profile doesn't exist)."""
    print("\n" + "=" * 60)
    print("TEST 8: Profile Selection")
    print("=" * 60)

    test_task_id = "modal_test_profile"

    # Test with default profile (no profile specified)
    print("Testing with default profile...")
    result = terminal_tool("echo 'default profile'", task_id=test_task_id)
    result_json = json.loads(result)

    success = result_json.get('exit_code') == 0
    print(f"  Default profile: {'✅' if success else '❌'} (exit code: {result_json.get('exit_code')})")

    # Cleanup
    cleanup_vm(test_task_id)

    print(f"\nTest: {'✅ Passed' if success else '❌ Failed'}")
    return success


def main():
    """Run all Modal terminal tests."""
    print("🧪 Modal Terminal Tool Test Suite")
@@ -247,6 +304,8 @@ def main():
    print(f"  TERMINAL_ENV: {config['env_type']}")
    print(f"  TERMINAL_MODAL_IMAGE: {config['modal_image']}")
    print(f"  TERMINAL_TIMEOUT: {config['timeout']}s")
    print(f"  TERMINAL_MODAL_APP_NAME: {os.getenv('TERMINAL_MODAL_APP_NAME', 'hermes-sandbox')}")
    print(f"  TERMINAL_MODAL_DEFAULT_PROFILE: {os.getenv('TERMINAL_MODAL_DEFAULT_PROFILE', 'default')}")

    if config['env_type'] != 'modal':
        print(f"\n⚠️ WARNING: TERMINAL_ENV is set to '{config['env_type']}', not 'modal'")
@@ -270,6 +329,8 @@ def main():
    results['pip_install'] = test_pip_install()
    results['filesystem_persistence'] = test_filesystem_persistence()
    results['environment_isolation'] = test_environment_isolation()
    results['pool_status'] = test_pool_status()
    results['profile_selection'] = test_profile_selection()

    # Summary
    print("\n" + "=" * 60)
||||
31
tests/test_tool_call_parsing.py
Normal file
31
tests/test_tool_call_parsing.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from atropos.tools.base import ToolCall
|
||||
|
||||
|
||||
def test_parse_tool_call_json_wrapper() -> None:
|
||||
text = '<tool_call>{"name":"terminal","arguments":{"command":"pwd"}}</tool_call>'
|
||||
calls = ToolCall.parse_from_text(text)
|
||||
assert len(calls) == 1
|
||||
assert calls[0].name == "terminal"
|
||||
assert calls[0].arguments == {"command": "pwd"}
|
||||
|
||||
|
||||
def test_parse_tool_call_glm_style() -> None:
|
||||
text = '<tool_call>terminal{"command":"ls -la"}</tool_call>'
|
||||
calls = ToolCall.parse_from_text(text)
|
||||
assert len(calls) == 1
|
||||
assert calls[0].name == "terminal"
|
||||
assert calls[0].arguments == {"command": "ls -la"}
|
||||
|
||||
|
||||
def test_parse_tool_call_missing_close_tag() -> None:
|
||||
text = '<tool_call>terminal{"command":"echo hi"}'
|
||||
calls = ToolCall.parse_from_text(text)
|
||||
assert calls == []
|
||||
|
||||
|
||||
def test_parse_tool_call_strips_accidental_xml() -> None:
|
||||
text = '<tool_call>terminal{"command":"ls -la"}</arg_value></tool_call>'
|
||||
calls = ToolCall.parse_from_text(text)
|
||||
assert calls == []
|
||||
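The behaviour these tests pin down can be sketched as a minimal, hypothetical parser. This is an illustrative re-implementation only; the real `ToolCall.parse_from_text` lives in `atropos/tools/base.py` and may be implemented differently:

```python
import json
import re
from dataclasses import dataclass


@dataclass
class ToolCall:
    name: str
    arguments: dict

    @staticmethod
    def parse_from_text(text: str) -> list["ToolCall"]:
        calls = []
        # Only well-formed <tool_call>...</tool_call> pairs are considered;
        # a missing close tag therefore yields no calls at all.
        for body in re.findall(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL):
            try:
                if body.lstrip().startswith("{"):
                    # JSON wrapper form: {"name": ..., "arguments": {...}}
                    obj = json.loads(body)
                    calls.append(ToolCall(obj["name"], obj["arguments"]))
                else:
                    # GLM-style form: tool_name{...json arguments...}
                    m = re.match(r"\s*([\w.-]+)\s*(\{.*\})\s*$", body, re.DOTALL)
                    if m is None:
                        # Trailing accidental XML (e.g. </arg_value>) rejects the call
                        return []
                    calls.append(ToolCall(m.group(1), json.loads(m.group(2))))
            except (json.JSONDecodeError, KeyError):
                return []
        return calls
```

The strict `$` anchor on the arguments object is what makes the `</arg_value>` case in the last test come back empty rather than half-parsed.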
@@ -16,14 +16,6 @@ The tools are imported into model_tools.py which provides a unified interface
for the AI agent to access all capabilities.
"""

# Export all tools for easy importing
from .web_tools import (
    web_search_tool,
    web_extract_tool,
    web_crawl_tool,
    check_firecrawl_api_key
)

# Primary terminal tool (mini-swe-agent backend: local/docker/singularity/modal)
from .terminal_tool import (
    terminal_tool,
@@ -34,54 +26,106 @@ from .terminal_tool import (
    TERMINAL_TOOL_DESCRIPTION
)

# Alternative terminal tool (Hecate/MorphCloud cloud VMs)
from .terminal_hecate import (
    terminal_hecate_tool,
    check_hecate_requirements,
    TERMINAL_HECATE_DESCRIPTION
)
# Optional toolsets: keep imports soft so users can run subsets of tools without
# installing every dependency (requirements gating lives in model_tools.py).
try:
    from .web_tools import check_firecrawl_api_key, web_crawl_tool, web_extract_tool, web_search_tool
except ModuleNotFoundError:  # pragma: no cover
    web_search_tool = None  # type: ignore[assignment]
    web_extract_tool = None  # type: ignore[assignment]
    web_crawl_tool = None  # type: ignore[assignment]

from .vision_tools import (
    vision_analyze_tool,
    check_vision_requirements
)
    def check_firecrawl_api_key() -> bool:  # type: ignore[no-redef]
        return False

from .mixture_of_agents_tool import (
    mixture_of_agents_tool,
    check_moa_requirements
)
try:
    # Alternative terminal tool (Hecate/MorphCloud cloud VMs)
    from .terminal_hecate import TERMINAL_HECATE_DESCRIPTION, check_hecate_requirements, terminal_hecate_tool
except ModuleNotFoundError:  # pragma: no cover
    terminal_hecate_tool = None  # type: ignore[assignment]
    TERMINAL_HECATE_DESCRIPTION = ""

from .image_generation_tool import (
    image_generate_tool,
    check_image_generation_requirements
)
    def check_hecate_requirements() -> bool:  # type: ignore[no-redef]
        return False

from .skills_tool import (
    skills_categories,
    skills_list,
    skill_view,
    check_skills_requirements,
    SKILLS_TOOL_DESCRIPTION
)
try:
    from .vision_tools import check_vision_requirements, vision_analyze_tool
except ModuleNotFoundError:  # pragma: no cover
    vision_analyze_tool = None  # type: ignore[assignment]

# Browser automation tools (agent-browser + Browserbase)
from .browser_tool import (
    browser_navigate,
    browser_snapshot,
    browser_click,
    browser_type,
    browser_scroll,
    browser_back,
    browser_press,
    browser_close,
    browser_get_images,
    browser_vision,
    cleanup_browser,
    cleanup_all_browsers,
    get_active_browser_sessions,
    check_browser_requirements,
    BROWSER_TOOL_SCHEMAS
)
    def check_vision_requirements() -> bool:  # type: ignore[no-redef]
        return False

try:
    from .mixture_of_agents_tool import check_moa_requirements, mixture_of_agents_tool
except ModuleNotFoundError:  # pragma: no cover
    mixture_of_agents_tool = None  # type: ignore[assignment]

    def check_moa_requirements() -> bool:  # type: ignore[no-redef]
        return False

try:
    from .image_generation_tool import check_image_generation_requirements, image_generate_tool
except ModuleNotFoundError:  # pragma: no cover
    image_generate_tool = None  # type: ignore[assignment]

    def check_image_generation_requirements() -> bool:  # type: ignore[no-redef]
        return False

try:
    from .skills_tool import (
        SKILLS_TOOL_DESCRIPTION,
        check_skills_requirements,
        skill_view,
        skills_categories,
        skills_list,
    )
except ModuleNotFoundError:  # pragma: no cover
    skills_categories = None  # type: ignore[assignment]
    skills_list = None  # type: ignore[assignment]
    skill_view = None  # type: ignore[assignment]
    SKILLS_TOOL_DESCRIPTION = ""

    def check_skills_requirements() -> bool:  # type: ignore[no-redef]
        return False

try:
    # Browser automation tools (agent-browser + Browserbase)
    from .browser_tool import (
        BROWSER_TOOL_SCHEMAS,
        browser_back,
        browser_click,
        browser_close,
        browser_get_images,
        browser_navigate,
        browser_press,
        browser_scroll,
        browser_snapshot,
        browser_type,
        browser_vision,
        check_browser_requirements,
        cleanup_all_browsers,
        cleanup_browser,
        get_active_browser_sessions,
    )
except ModuleNotFoundError:  # pragma: no cover
    browser_navigate = None  # type: ignore[assignment]
    browser_snapshot = None  # type: ignore[assignment]
    browser_click = None  # type: ignore[assignment]
    browser_type = None  # type: ignore[assignment]
    browser_scroll = None  # type: ignore[assignment]
    browser_back = None  # type: ignore[assignment]
    browser_press = None  # type: ignore[assignment]
    browser_close = None  # type: ignore[assignment]
    browser_get_images = None  # type: ignore[assignment]
    browser_vision = None  # type: ignore[assignment]
    cleanup_browser = None  # type: ignore[assignment]
    cleanup_all_browsers = None  # type: ignore[assignment]
    get_active_browser_sessions = None  # type: ignore[assignment]
    BROWSER_TOOL_SCHEMAS = []

    def check_browser_requirements() -> bool:  # type: ignore[no-redef]
        return False

# Cronjob management tools (CLI-only, hermes-cli toolset)
from .cronjob_tools import (
@@ -206,4 +250,3 @@ __all__ = [
    'clear_file_ops_cache',
    'check_file_requirements',
]
File diff suppressed because it is too large.

wandb/latest-run (new symbolic link, 1 line)
@@ -0,0 +1 @@
run-20260206_003827-82b0oahi

wandb/run-20260206_003827-82b0oahi/files/config.yaml (new file, 180 lines)
@@ -0,0 +1,180 @@
_wandb:
  value:
    cli_version: 0.24.2
    e:
      2gw7xuffca69jbm2b60l3w5ymo5pb5lf:
        args:
        - process
        - --env.driver
        - singularity
        - --env.singularity_image
        - /root/Hermes-Agent/atropos/atropos-sandbox.sif
        email: shannon@nousresearch.com
        executable: /root/Hermes-Agent/.venv/bin/python
        git:
          commit: 4d619bcd21feedc9eed36c53c038585d97e7295e
          remote: https://github.com/NousResearch/Hermes-Agent.git
        host: vultr
        os: Linux-6.8.0-90-generic-x86_64-with-glibc2.39
        program: -m atropos.envs.swe_smith_oracle_env
        python: CPython 3.12.3
        root: /root/Hermes-Agent
        startedAt: "2026-02-06T00:38:27.351013Z"
        writerId: 2gw7xuffca69jbm2b60l3w5ymo5pb5lf
    m: []
    python_version: 3.12.3
    t:
      "1":
      - 11
      - 49
      - 51
      - 95
      "3":
      - 13
      - 16
      "4": 3.12.3
      "5": 0.24.2
      "6": 5.0.0
      "12": 0.24.2
      "13": linux-x86_64
acquire_timeout_s:
  value: 30
agent_max_steps:
  value: 50
agent_max_tokens:
  value: null
agent_temperature:
  value: 0.7
agent_tool_delay_s:
  value: 0
allow_network:
  value: true
batch_size:
  value: 1
custom_thinking_prompt:
  value: null
data_dir_to_save_evals:
  value: null
data_path_to_save_groups:
  value: data/swe_smith_oracle_env_2.jsonl
dataset_name:
  value: NousResearch/SWE-smith-oracle
dataset_split:
  value: train
disabled_toolsets:
  value: []
driver:
  value: singularity
enabled_toolsets:
  value:
  - terminal
ensure_scores_are_not_same:
  value: false
eval_handling:
  value: STOP_TRAIN
eval_limit_ratio:
  value: 0.5
group_size:
  value: 1
include_messages:
  value: true
inference_weight:
  value: 1
install_timeout_s:
  value: 600
max_batches_offpolicy:
  value: 3
max_containers:
  value: 10
max_eval_workers:
  value: 16
max_items:
  value: 0
max_num_workers:
  value: -1
max_num_workers_per_node:
  value: 8
max_reasoning_tokens:
  value: null
max_token_length:
  value: 8192
min_batch_allocation:
  value: null
min_containers:
  value: 1
min_items_sent_before_logging:
  value: 2
modal_app_name:
  value: atropos-sandbox
modal_function_name:
  value: sandbox_server
modal_volume_mount_path:
  value: /data
modal_volume_name:
  value: null
nomad_address:
  value: http://localhost:4646
num_rollouts_per_group_for_logging:
  value: 1
num_rollouts_to_keep:
  value: 32
privileged:
  value: false
prompt_mode:
  value: problem_statement
purge_job_on_shutdown:
  value: true
purge_job_on_start:
  value: true
python_only:
  value: true
reasoning_effort:
  value: null
repo_base_url:
  value: https://github.com
require_sandbox:
  value: false
require_stateful_sandbox:
  value: false
rollout_server_url:
  value: http://localhost:8000
sandbox_image:
  value: atropos-sandbox:local
sandbox_job_id:
  value: atropos-sandbox-agent-env
score_include_fail_to_pass:
  value: true
seed:
  value: 0
shuffle:
  value: true
singularity_image:
  value: /root/Hermes-Agent/atropos/atropos-sandbox.sif
slots_per_container:
  value: 10
steps_per_eval:
  value: 1
test_timeout_s:
  value: 600
thinking_mode:
  value: false
tokenizer_name:
  value: NousResearch/Hermes-4.3-36B
tool_batch_window_ms:
  value: 20
tool_max_batch_size:
  value: 200
tool_pool_mode:
  value: nomad
tool_server_token:
  value: null
tool_server_url:
  value: null
total_steps:
  value: 1
use_wandb:
  value: true
wandb_name:
  value: swe_smith_oracle
worker_timeout:
  value: 600
wandb/run-20260206_003827-82b0oahi/run-82b0oahi.wandb (new binary file, not shown)