mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-07 19:26:56 +08:00
Compare commits
13 Commits
UI
...
fix-termin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a6ec79730c | ||
|
|
faecbddd9b | ||
|
|
de9c0edc51 | ||
|
|
8d256779d8 | ||
|
|
d36790de91 | ||
|
|
a398d320b7 | ||
|
|
22b6d5866c | ||
|
|
0e2e69a71d | ||
|
|
bc5f0e62d9 | ||
|
|
6fac6fecde | ||
|
|
c42d9055ed | ||
|
|
a7ff4d49e9 | ||
|
|
0411ca1880 |
23
.cursorrules
Normal file
23
.cursorrules
Normal file
@@ -0,0 +1,23 @@
|
||||
Hermes-Agent is an agent harness for LLMs.
|
||||
|
||||
When building, the tool functionality is in the tools/ directory, where each specific tool (or in some cases, tools that are built for the same execution category or api) are placed in a script each their own.
|
||||
|
||||
Each tool is then consolidated in the model_tools.py file in the repo root.
|
||||
|
||||
There is also a way to consolidate sets of tools in toolsets.py for the agent to use.
|
||||
|
||||
The primary agent runner code is in run_agent, but other runners could be developed using the tools and framework.
|
||||
|
||||
Always ensure consistency between tools, the model_tools.py and toolsets.py when changing any of them, otherwise they could become desynced in a way that is detrimental to functionality.
|
||||
|
||||
The expected pathway for using API keys is to setup and place them in a .env file in the repo root.
|
||||
|
||||
Test scripts will be placed in tests/
|
||||
|
||||
The run_agent loop is setup to:
|
||||
- Process the enabled toolsets to provide to the model,
|
||||
- Pipe in a prompt or problem from the input to the agent,
|
||||
- Loop the LLM each time it calls a tool, until the model decides no more tools are needed and provides a natural language response,
|
||||
- Return that response.
|
||||
|
||||
There are additional caveats for logging, where we restructure the "tools" as a system prompt for storage later into a format that can be used and handled properly later.
|
||||
49
.env.example
Normal file
49
.env.example
Normal file
@@ -0,0 +1,49 @@
|
||||
# Hermes Agent Environment Configuration
|
||||
# Copy this file to .env and fill in your API keys
|
||||
# Get API keys from the URLs listed below
|
||||
|
||||
# =============================================================================
|
||||
# REQUIRED API KEYS
|
||||
# =============================================================================
|
||||
|
||||
# Anthropic API Key - Main agent model
|
||||
# Get at: https://console.anthropic.com/
|
||||
ANTHROPIC_API_KEY=
|
||||
|
||||
# Firecrawl API Key - Web search, extract, and crawl
|
||||
# Get at: https://firecrawl.dev/
|
||||
FIRECRAWL_API_KEY=
|
||||
|
||||
# Nous Research API Key - Vision analysis and multi-model reasoning
|
||||
# Get at: https://inference-api.nousresearch.com/
|
||||
NOUS_API_KEY=
|
||||
|
||||
# Morph API Key - Terminal/command execution tools
|
||||
# Get at: https://morph.so/
|
||||
MORPH_API_KEY=
|
||||
|
||||
# FAL.ai API Key - Image generation
|
||||
# Get at: https://fal.ai/
|
||||
FAL_KEY=
|
||||
|
||||
# =============================================================================
|
||||
# OPTIONAL API KEYS
|
||||
# =============================================================================
|
||||
|
||||
# OpenAI API Key - Optional, for enhanced Hecate features
|
||||
# Get at: https://platform.openai.com/
|
||||
OPENAI_API_KEY=
|
||||
|
||||
# =============================================================================
|
||||
# OPTIONAL CONFIGURATION
|
||||
# =============================================================================
|
||||
|
||||
# Terminal Tool Settings
|
||||
HECATE_VM_LIFETIME_SECONDS=300
|
||||
HECATE_DEFAULT_SNAPSHOT_ID=snapshot_p5294qxt
|
||||
|
||||
# Debug Logging (set to "true" to enable, logs saved to ./logs/)
|
||||
WEB_TOOLS_DEBUG=false
|
||||
VISION_TOOLS_DEBUG=false
|
||||
MOA_TOOLS_DEBUG=false
|
||||
IMAGE_TOOLS_DEBUG=false
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -16,4 +16,8 @@ __pycache__/
|
||||
export*
|
||||
__pycache__/model_tools.cpython-310.pyc
|
||||
__pycache__/web_tools.cpython-310.pyc
|
||||
logs/
|
||||
logs/
|
||||
data/
|
||||
.pytest_cache/
|
||||
tmp/
|
||||
temp_vision_images/
|
||||
472
README.md
472
README.md
@@ -1,295 +1,243 @@
|
||||
# Hermes Agent
|
||||
|
||||
AI Agent with advanced tool calling capabilities, real-time logging, and extensible toolsets.
|
||||
An AI agent with advanced tool-calling capabilities, featuring a flexible toolsets system for organizing and managing tools.
|
||||
|
||||
## Features
|
||||
|
||||
- 🤖 **Multi-model Support**: Works with Claude, GPT-4, and other OpenAI-compatible models
|
||||
- 🔧 **Rich Tool Library**: Web search, content extraction, vision analysis, terminal execution, and more
|
||||
- 📊 **Real-time Logging**: WebSocket-based logging system for monitoring agent execution
|
||||
- 🖥️ **Desktop UI**: Modern PySide6 frontend with real-time event streaming
|
||||
- 🎯 **Flexible Toolsets**: Predefined toolset combinations for different use cases
|
||||
- 💾 **Trajectory Saving**: Save conversation flows for training and analysis
|
||||
- 🔄 **Auto-retry**: Built-in error handling and retry logic
|
||||
- **Web Tools**: Search, extract content, and crawl websites
|
||||
- **Terminal Tools**: Execute commands with interactive session support
|
||||
- **Vision Tools**: Analyze images from URLs
|
||||
- **Reasoning Tools**: Advanced multi-model reasoning (Mixture of Agents)
|
||||
- **Creative Tools**: Generate images from text prompts
|
||||
- **Toolsets System**: Organize tools into logical groups for different scenarios
|
||||
- **Batch Processing**: Process datasets in parallel with checkpointing and statistics tracking
|
||||
- **Ephemeral System Prompts**: Guide model behavior without polluting training datasets
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Installation
|
||||
## Setup
|
||||
|
||||
### 1. Install Dependencies
|
||||
```bash
|
||||
# Create and activate virtual environment (recommended)
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate # On Windows: venv\Scripts\activate
|
||||
|
||||
# Install required packages
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Install Hecate for terminal tools
|
||||
git clone git@github.com:NousResearch/hecate.git
|
||||
cd hecate
|
||||
pip install -e .
|
||||
cd ..
|
||||
```
|
||||
|
||||
### Basic Usage
|
||||
### 2. Configure Environment Variables
|
||||
```bash
|
||||
# Copy the example environment file
|
||||
cp .env.example .env
|
||||
|
||||
# Edit .env and add your API keys
|
||||
nano .env # or use your preferred editor
|
||||
```
|
||||
|
||||
**Required API Keys:**
|
||||
- `ANTHROPIC_API_KEY` - Main agent model (get at: https://console.anthropic.com/)
|
||||
- `FIRECRAWL_API_KEY` - Web tools (get at: https://firecrawl.dev/)
|
||||
- `NOUS_API_KEY` - Vision & reasoning tools (get at: https://inference-api.nousresearch.com/)
|
||||
- `MORPH_API_KEY` - Terminal tools (get at: https://morph.so/)
|
||||
- `FAL_KEY` - Image generation (get at: https://fal.ai/)
|
||||
- `OPENAI_API_KEY` - Optional, for some Hecate features
|
||||
|
||||
See `.env.example` for all available configuration options including debug settings and terminal tool configuration.
|
||||
|
||||
## Toolsets System
|
||||
|
||||
The agent uses a toolsets system for organizing and managing tools. All tools must be part of a toolset to be accessible - individual tool selection is not supported. This ensures consistent and logical grouping of capabilities.
|
||||
|
||||
### Key Concepts
|
||||
|
||||
- **Toolsets**: Logical groups of tools for specific use cases (e.g., "research", "development", "debugging")
|
||||
- **Composition**: Toolsets can include other toolsets for powerful combinations
|
||||
- **Custom Toolsets**: Create your own toolsets at runtime or by editing `toolsets.py`
|
||||
- **Toolset-Only Access**: Tools are only accessible through toolsets, not individually
|
||||
|
||||
### Available Toolsets
|
||||
|
||||
See `toolsets.py` for the complete list of predefined toolsets including:
|
||||
- Basic toolsets (web, terminal, vision, creative, reasoning)
|
||||
- Composite toolsets (research, development, analysis, etc.)
|
||||
- Scenario-specific toolsets (debugging, documentation, API testing, etc.)
|
||||
- Special toolsets (safe mode without terminal, minimal, offline)
|
||||
|
||||
### Using Toolsets
|
||||
|
||||
```bash
|
||||
python run_agent.py \
|
||||
--enabled_toolsets web \
|
||||
--query "Search for the latest AI news"
|
||||
```
|
||||
|
||||
### With Real-time Logging
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start API endpoint server
|
||||
python api_endpoint/logging_server.py
|
||||
|
||||
# Terminal 2: Run agent
|
||||
python run_agent.py \
|
||||
--enabled_toolsets web \
|
||||
--enable_websocket_logging \
|
||||
--query "Your question here"
|
||||
```
|
||||
|
||||
### With Desktop UI (Recommended)
|
||||
|
||||
The easiest way to use Hermes Agent is through the desktop UI:
|
||||
|
||||
```bash
|
||||
# One-command launch (starts server + UI)
|
||||
cd ui && ./start_hermes_ui.sh
|
||||
|
||||
# Or manually:
|
||||
# Terminal 1: Start server
|
||||
python api_endpoint/logging_server.py
|
||||
|
||||
# Terminal 2: Start UI
|
||||
python ui/hermes_ui.py
|
||||
```
|
||||
|
||||
The UI provides:
|
||||
- 🖱️ Point-and-click query submission
|
||||
- 🎛️ Easy model and tool selection
|
||||
- 📊 Real-time event visualization
|
||||
- 🔄 Automatic WebSocket connection
|
||||
- 📝 Session history
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
Hermes-Agent/
|
||||
├── run_agent.py # Main agent runner
|
||||
├── model_tools.py # Tool definitions and handling
|
||||
├── toolsets.py # Predefined toolset combinations
|
||||
├── requirements.txt # Python dependencies
|
||||
│
|
||||
├── ui/ # Desktop UI ⭐ NEW
|
||||
│ ├── hermes_ui.py # PySide6 desktop application
|
||||
│ ├── start_hermes_ui.sh # UI launcher script
|
||||
│ └── test_ui_flow.py # UI integration tests
|
||||
│
|
||||
├── tools/ # Tool implementations
|
||||
│ ├── web_tools.py # Web search, extract, crawl
|
||||
│ ├── vision_tools.py # Image analysis
|
||||
│ ├── terminal_tool.py # Command execution
|
||||
│ ├── image_generation_tool.py
|
||||
│ └── ...
|
||||
│
|
||||
├── api_endpoint/ # FastAPI + WebSocket logging endpoint
|
||||
│ ├── logging_server.py # WebSocket server + Agent API ⭐ ENHANCED
|
||||
│ ├── websocket_logger.py # Client library
|
||||
│ ├── README.md # API endpoint docs
|
||||
│ └── ...
|
||||
│
|
||||
├── logs/ # Log files
|
||||
│ └── realtime/ # WebSocket session logs
|
||||
│
|
||||
└── tests/ # Test files
|
||||
```
|
||||
|
||||
## Available Toolsets
|
||||
|
||||
### Basic Toolsets
|
||||
- **web**: Web search, extract, and crawl
|
||||
- **terminal**: Command execution
|
||||
- **vision**: Image analysis
|
||||
- **creative**: Image generation
|
||||
- **reasoning**: Mixture of agents
|
||||
|
||||
### Composite Toolsets
|
||||
- **research**: Web + vision tools
|
||||
- **development**: Web + terminal + vision
|
||||
- **analysis**: Web + vision + reasoning
|
||||
- **full_stack**: All tools enabled
|
||||
|
||||
### Usage Examples
|
||||
|
||||
```bash
|
||||
# Research with web and vision
|
||||
python run_agent.py --enabled_toolsets research --query "..."
|
||||
|
||||
# Development with terminal access
|
||||
python run_agent.py --enabled_toolsets development --query "..."
|
||||
# Use a predefined toolset
|
||||
python run_agent.py --enabled_toolsets=research --query "Find latest AI papers"
|
||||
|
||||
# Combine multiple toolsets
|
||||
python run_agent.py --enabled_toolsets web,vision --query "..."
|
||||
python run_agent.py --enabled_toolsets=web,vision --query "Analyze this website"
|
||||
|
||||
# Enable all toolsets explicitly (same as omitting the flag)
|
||||
python run_agent.py --enabled_toolsets=all --query "Do web research and run commands if helpful"
|
||||
|
||||
# Safe mode (no terminal access)
|
||||
python run_agent.py --enabled_toolsets=safe --query "Help without running commands"
|
||||
|
||||
# List all available toolsets and tools
|
||||
python run_agent.py --list_tools
|
||||
```
|
||||
|
||||
## Real-time Logging System
|
||||
For detailed documentation on toolsets, see `TOOLSETS_README.md`.
|
||||
|
||||
Monitor your agent's execution in real-time with the FastAPI WebSocket endpoint using a **persistent connection pool** architecture.
|
||||
## Basic Usage
|
||||
|
||||
### Architecture
|
||||
|
||||
The logging system uses a **singleton WebSocket connection** that persists across multiple agent runs:
|
||||
- ✅ **No timeouts** - connection stays alive indefinitely
|
||||
- ✅ **No reconnection overhead** - connect once, reuse forever
|
||||
- ✅ **Parallel execution** - multiple agents share one connection
|
||||
- ✅ **Production-ready** - graceful shutdown with signal handlers
|
||||
|
||||
See [`api_endpoint/PERSISTENT_CONNECTION_GUIDE.md`](api_endpoint/PERSISTENT_CONNECTION_GUIDE.md) for technical details.
|
||||
|
||||
### Features
|
||||
- Track all API calls and responses
|
||||
- **Persistent connection** - one WebSocket for all sessions
|
||||
- Monitor tool executions with parameters and timing
|
||||
- Capture errors and completion status
|
||||
- REST API for querying sessions
|
||||
- Real-time WebSocket broadcasting
|
||||
|
||||
### Documentation
|
||||
See [`api_endpoint/README.md`](api_endpoint/README.md) for complete documentation.
|
||||
|
||||
### Quick Start
|
||||
### Default (all tools enabled)
|
||||
```bash
|
||||
# Start API endpoint server
|
||||
python api_endpoint/logging_server.py
|
||||
|
||||
# Run agent with logging
|
||||
python run_agent.py --enable_websocket_logging --query "..."
|
||||
|
||||
# View logs
|
||||
curl http://localhost:8000/sessions
|
||||
python run_agent.py \
|
||||
--query "search up the latest docs on jit in python 3.13 and write me basic example that's not in their docs. profile its perf" \
|
||||
--max_turns 20 \
|
||||
--model claude-sonnet-4-20250514 \
|
||||
--base_url https://api.anthropic.com/v1/ \
|
||||
--api_key $ANTHROPIC_API_KEY
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
Create a `.env` file in the project root:
|
||||
|
||||
### With specific toolset
|
||||
```bash
|
||||
# API Keys
|
||||
ANTHROPIC_API_KEY=your_key_here
|
||||
FIRECRAWL_API_KEY=your_key_here
|
||||
NOUS_API_KEY=your_key_here
|
||||
FAL_KEY=your_key_here
|
||||
|
||||
# Optional
|
||||
WEB_TOOLS_DEBUG=true # Enable web tools debug logging
|
||||
python run_agent.py \
|
||||
--query "Debug this Python error" \
|
||||
--enabled_toolsets=debugging \
|
||||
--model claude-sonnet-4-20250514 \
|
||||
--api_key $ANTHROPIC_API_KEY
|
||||
```
|
||||
|
||||
### Command-Line Options
|
||||
|
||||
```bash
|
||||
python run_agent.py --help
|
||||
```
|
||||
|
||||
Key options:
|
||||
- `--query`: Your question/task
|
||||
- `--model`: Model to use (default: claude-sonnet-4-5-20250929)
|
||||
- `--enabled_toolsets`: Toolsets to enable
|
||||
- `--max_turns`: Maximum conversation turns
|
||||
- `--enable_websocket_logging`: Enable real-time logging
|
||||
- `--verbose`: Verbose debug output
|
||||
- `--save_trajectories`: Save conversation trajectories
|
||||
|
||||
## Parallel Execution
|
||||
|
||||
The persistent connection pool enables true parallel agent execution. Multiple agents can run simultaneously, all sharing the same WebSocket connection for logging.
|
||||
|
||||
### Test Parallel Execution
|
||||
|
||||
```bash
|
||||
python test_parallel_execution.py
|
||||
```
|
||||
|
||||
This script runs three tests:
|
||||
1. **Sequential** - baseline (3 queries one after another)
|
||||
2. **Parallel** - 3 queries simultaneously
|
||||
3. **High Concurrency** - 10 queries simultaneously
|
||||
|
||||
**Expected Results:**
|
||||
- ⚡ ~3x speedup with parallel execution
|
||||
- ✅ All queries logged to same connection
|
||||
- ✅ No connection timeouts or errors
|
||||
|
||||
### Custom Parallel Code
|
||||
|
||||
### Python API
|
||||
```python
|
||||
import asyncio
|
||||
from run_agent import AIAgent
|
||||
|
||||
async def main():
|
||||
agent1 = AIAgent(enable_websocket_logging=True)
|
||||
agent2 = AIAgent(enable_websocket_logging=True)
|
||||
|
||||
# Run in parallel - both use shared connection!
|
||||
results = await asyncio.gather(
|
||||
agent1.run_conversation("Query 1"),
|
||||
agent2.run_conversation("Query 2")
|
||||
)
|
||||
# Use a specific toolset
|
||||
agent = AIAgent(
|
||||
model="claude-opus-4-20250514",
|
||||
enabled_toolsets=["research"]
|
||||
)
|
||||
response = agent.chat("Find information about quantum computing")
|
||||
|
||||
asyncio.run(main())
|
||||
# Create custom toolset at runtime
|
||||
from toolsets import create_custom_toolset
|
||||
|
||||
create_custom_toolset(
|
||||
name="my_tools",
|
||||
description="My custom toolkit",
|
||||
tools=["web_search"],
|
||||
includes=["terminal", "vision"]
|
||||
)
|
||||
|
||||
agent = AIAgent(enabled_toolsets=["my_tools"])
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
Process multiple prompts from a dataset in parallel with automatic checkpointing and statistics tracking:
|
||||
|
||||
```bash
|
||||
# Basic batch processing
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--run_name=my_run
|
||||
|
||||
# With specific distribution
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--run_name=image_run \
|
||||
--distribution=image_gen \
|
||||
--num_workers=4
|
||||
```
|
||||
|
||||
**Key Features:**
|
||||
- Parallel processing with configurable workers
|
||||
- Toolset distributions for varied data generation
|
||||
- Automatic checkpointing and resume capability
|
||||
- Combined output in `data/<run_name>/trajectories.jsonl`
|
||||
- Tool usage statistics and success rates
|
||||
|
||||
**Quick Start:** See [QUICKSTART_BATCH.md](QUICKSTART_BATCH.md) for a 5-minute getting started guide.
|
||||
**Full Documentation:** See [BATCH_PROCESSING.md](BATCH_PROCESSING.md) for comprehensive documentation.
|
||||
|
||||
### Ephemeral System Prompts
|
||||
|
||||
The ephemeral system prompt feature allows you to guide the model's behavior during batch processing **without** saving that prompt to the training dataset trajectories. This is useful for:
|
||||
|
||||
- Guiding model behavior during data collection
|
||||
- Adding task-specific instructions
|
||||
- Keeping saved trajectories clean and focused on tool-calling format
|
||||
|
||||
**Example:**
|
||||
```bash
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=10 \
|
||||
--run_name=my_run \
|
||||
--ephemeral_system_prompt="You are a helpful assistant focused on image generation."
|
||||
```
|
||||
|
||||
The ephemeral prompt will influence the model's behavior during execution, but **only the standard tool-calling system prompt** will be saved in the trajectory files.
|
||||
|
||||
**Documentation:** See [docs/ephemeral_system_prompt.md](docs/ephemeral_system_prompt.md) for complete details.
|
||||
|
||||
## Command Line Arguments
|
||||
|
||||
**Single Agent (`run_agent.py`):**
|
||||
- `--query`: The question or task for the agent
|
||||
- `--model`: Model to use (default: claude-opus-4-20250514)
|
||||
- `--api_key`: API key for authentication
|
||||
- `--base_url`: API endpoint URL
|
||||
- `--max_turns`: Maximum number of tool-calling iterations
|
||||
- `--enabled_toolsets`: Comma-separated list of toolsets to enable. Use `all` (or `*`) to enable everything. If omitted, all toolsets are enabled by default.
|
||||
- `--disabled_toolsets`: Comma-separated list of toolsets to disable
|
||||
- `--list_tools`: List all available toolsets and tools
|
||||
- `--save_trajectories`: Save conversation trajectories to JSONL files
|
||||
|
||||
**Batch Processing (`batch_runner.py`):**
|
||||
- `--dataset_file`: Path to JSONL file with prompts
|
||||
- `--batch_size`: Number of prompts per batch
|
||||
- `--run_name`: Name for this run (for output/checkpointing)
|
||||
- `--distribution`: Toolset distribution to use (default: "default")
|
||||
- `--num_workers`: Number of parallel workers (default: 4)
|
||||
- `--resume`: Resume from checkpoint if interrupted
|
||||
- `--ephemeral_system_prompt`: System prompt used during execution but NOT saved to trajectories
|
||||
- `--list_distributions`: List available toolset distributions
|
||||
|
||||
## Environment Variables
|
||||
|
||||
All environment variables can be configured in the `.env` file (copy from `.env.example`).
|
||||
|
||||
**Core API Keys:**
|
||||
- `ANTHROPIC_API_KEY`: Main agent model
|
||||
- `FIRECRAWL_API_KEY`: Web tools (search, extract, crawl)
|
||||
- `NOUS_API_KEY`: Vision and reasoning tools
|
||||
- `MORPH_API_KEY`: Terminal tools
|
||||
- `FAL_KEY`: Image generation tools
|
||||
- `OPENAI_API_KEY`: Optional, for some Hecate features
|
||||
|
||||
**Configuration Options:**
|
||||
- `HECATE_VM_LIFETIME_SECONDS`: VM lifetime (default: 300)
|
||||
- `HECATE_DEFAULT_SNAPSHOT_ID`: Default snapshot (default: snapshot_p5294qxt)
|
||||
- `WEB_TOOLS_DEBUG`, `VISION_TOOLS_DEBUG`, `MOA_TOOLS_DEBUG`, `IMAGE_TOOLS_DEBUG`: Enable debug logging
|
||||
|
||||
## Documentation
|
||||
|
||||
**Single Agent Usage:**
|
||||
- `TOOLSETS_README.md`: Comprehensive guide to the toolsets system
|
||||
- `toolsets.py`: View and modify available toolsets
|
||||
- `model_tools.py`: Core tool definitions and handlers
|
||||
|
||||
**Batch Processing:**
|
||||
- `QUICKSTART_BATCH.md`: 5-minute quick start guide
|
||||
- `BATCH_PROCESSING.md`: Complete batch processing documentation
|
||||
- `toolset_distributions.py`: Toolset distributions for data generation
|
||||
|
||||
## Examples
|
||||
|
||||
### Investment Research
|
||||
```bash
|
||||
python run_agent.py \
|
||||
--enabled_toolsets web \
|
||||
--query "Find publicly traded companies in renewable energy"
|
||||
```
|
||||
|
||||
### Code Analysis
|
||||
```bash
|
||||
python run_agent.py \
|
||||
--enabled_toolsets development \
|
||||
--query "Analyze the codebase and suggest improvements"
|
||||
```
|
||||
|
||||
### Image Analysis
|
||||
```bash
|
||||
python run_agent.py \
|
||||
--enabled_toolsets vision \
|
||||
--query "Analyze this chart and explain the trends"
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Adding New Tools
|
||||
|
||||
1. Create tool in `tools/` directory
|
||||
2. Register in `model_tools.py`
|
||||
3. Add to appropriate toolset in `toolsets.py`
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
# Test web tools
|
||||
python tests/test_web_tools.py
|
||||
|
||||
# Test API endpoint / logging
|
||||
cd api_endpoint
|
||||
./test_websocket_logging.sh
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT License - see LICENSE file for details
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions welcome! Please open an issue or PR.
|
||||
|
||||
## Support
|
||||
|
||||
For questions or issues:
|
||||
1. Check documentation in `api_endpoint/`
|
||||
2. Review example usage in this README
|
||||
3. Open a GitHub issue
|
||||
|
||||
---
|
||||
|
||||
Built with ❤️ for advanced AI agent workflows
|
||||
See `TOOLSETS_README.md` for extensive examples of using different toolsets for various scenarios.
|
||||
|
||||
BIN
__pycache__/model_tools.cpython-310.pyc
Normal file
BIN
__pycache__/model_tools.cpython-310.pyc
Normal file
Binary file not shown.
BIN
__pycache__/web_tools.cpython-310.pyc
Normal file
BIN
__pycache__/web_tools.cpython-310.pyc
Normal file
Binary file not shown.
@@ -1,26 +0,0 @@
|
||||
"""
|
||||
Hermes Agent - API Endpoint & Real-time Logging
|
||||
|
||||
This package provides a FastAPI WebSocket endpoint for real-time logging of the Hermes Agent.
|
||||
|
||||
Components:
|
||||
- logging_server: FastAPI server that receives and stores events
|
||||
- websocket_logger: Client library for sending events from the agent
|
||||
|
||||
Usage:
|
||||
# Start the API endpoint server
|
||||
python api_endpoint/logging_server.py
|
||||
|
||||
# Use in agent code
|
||||
from api_endpoint.websocket_logger import WebSocketLogger
|
||||
|
||||
For more information, see:
|
||||
- WEBSOCKET_LOGGING_GUIDE.md - User guide
|
||||
- IMPLEMENTATION_SUMMARY.md - Technical details
|
||||
"""
|
||||
|
||||
from .websocket_logger import WebSocketLogger, SyncWebSocketLogger
|
||||
|
||||
__all__ = ['WebSocketLogger', 'SyncWebSocketLogger']
|
||||
__version__ = '1.0.0'
|
||||
|
||||
@@ -1,603 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Hermes Agent - Real-time Logging Server
|
||||
|
||||
A FastAPI server with WebSocket support that listens for agent execution events
|
||||
and logs them to JSON files in real-time.
|
||||
|
||||
Events tracked:
|
||||
- User queries
|
||||
- API calls (requests to the model)
|
||||
- Assistant responses
|
||||
- Tool calls (name, parameters, timing)
|
||||
- Tool results (outputs, errors, duration)
|
||||
- Final responses
|
||||
- Session metadata
|
||||
|
||||
Usage:
|
||||
python logging_server.py
|
||||
|
||||
Or with uvicorn directly:
|
||||
uvicorn logging_server:app --host 0.0.0.0 --port 8000 --reload
|
||||
|
||||
The server will listen for WebSocket connections at ws://localhost:8000/ws
|
||||
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, BackgroundTasks
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
|
||||
|
||||
# Configuration
|
||||
LOGS_DIR = Path(__file__).parent / "logs" / "realtime"
|
||||
LOGS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(
|
||||
title="Hermes Agent API Endpoint",
|
||||
description="Manage interface between agent and user",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
class SessionLogger:
|
||||
"""
|
||||
Manages logging for a single agent session.
|
||||
|
||||
Each agent execution gets its own SessionLogger instance.
|
||||
Responsible for:
|
||||
- Collecting all events for the session
|
||||
- Saving events to JSON file in real-time
|
||||
- Managing session lifecycle (start -> events -> finalize)
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self.session_id = session_id
|
||||
self.start_time = datetime.now()
|
||||
self.events: List[Dict[str, Any]] = [] # In-memory list of all events
|
||||
self.log_file = LOGS_DIR / f"session_{session_id}.json" # Where to save on disk
|
||||
|
||||
# Initialize session data structure
|
||||
# This is what gets saved to the JSON file
|
||||
self.session_data = {
|
||||
"session_id": session_id,
|
||||
"start_time": self.start_time.isoformat(),
|
||||
"end_time": None, # Set when session completes
|
||||
"events": [], # Will be populated as events come in
|
||||
"metadata": {} # Model, toolsets, etc. (set via session_start event)
|
||||
}
|
||||
|
||||
def add_event(self, event: Dict[str, Any]):
|
||||
"""
|
||||
Add an event to the session log.
|
||||
|
||||
Called every time a new event arrives (query, api_call, tool_call, etc).
|
||||
IMMEDIATELY saves to file for real-time persistence.
|
||||
"""
|
||||
# Add timestamp if not present (should always be added, but safety check)
|
||||
if "timestamp" not in event:
|
||||
event["timestamp"] = datetime.now().isoformat()
|
||||
|
||||
# Add to in-memory event list
|
||||
self.events.append(event)
|
||||
self.session_data["events"] = self.events
|
||||
|
||||
# CRITICAL: Save to file immediately (real-time logging)
|
||||
# This ensures events are persisted even if agent crashes
|
||||
self._save()
|
||||
|
||||
def set_metadata(self, metadata: Dict[str, Any]):
|
||||
"""Set session metadata (model, toolsets, etc.)."""
|
||||
self.session_data["metadata"].update(metadata)
|
||||
self._save()
|
||||
|
||||
def finalize(self):
|
||||
"""Finalize the session and save."""
|
||||
self.session_data["end_time"] = datetime.now().isoformat()
|
||||
self._save()
|
||||
|
||||
def _save(self):
|
||||
"""
|
||||
Save current session data to JSON file.
|
||||
|
||||
Called after EVERY event is added - provides real-time persistence.
|
||||
If file writing fails, logs error but continues (doesn't crash server).
|
||||
"""
|
||||
try:
|
||||
# Write complete session data to JSON file
|
||||
# indent=2 makes it human-readable
|
||||
# ensure_ascii=False preserves Unicode characters
|
||||
with open(self.log_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.session_data, f, indent=2, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
print(f"❌ Error saving session log: {e}")
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""
|
||||
Manages WebSocket connections and active sessions.
|
||||
|
||||
Global singleton that:
|
||||
- Tracks all active WebSocket connections (for broadcasting)
|
||||
- Manages all SessionLogger instances (one per agent session)
|
||||
- Coordinates between WebSocket events and file logging
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: List[WebSocket] = [] # All connected WebSocket clients
|
||||
self.sessions: Dict[str, SessionLogger] = {} # session_id -> SessionLogger mapping
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
"""Accept a new WebSocket connection."""
|
||||
await websocket.accept()
|
||||
self.active_connections.append(websocket)
|
||||
print(f"✅ WebSocket connected. Active connections: {len(self.active_connections)}")
|
||||
|
||||
def disconnect(self, websocket: WebSocket):
|
||||
"""Remove a WebSocket connection."""
|
||||
if websocket in self.active_connections:
|
||||
self.active_connections.remove(websocket)
|
||||
print(f"❌ WebSocket disconnected. Active connections: {len(self.active_connections)}")
|
||||
|
||||
def get_or_create_session(self, session_id: str) -> SessionLogger:
|
||||
"""
|
||||
Get existing session logger or create a new one.
|
||||
|
||||
Called when an event arrives for a session. Creates SessionLogger
|
||||
on first event, reuses it for subsequent events from same session.
|
||||
"""
|
||||
if session_id not in self.sessions:
|
||||
# First time seeing this session - create new logger
|
||||
self.sessions[session_id] = SessionLogger(session_id)
|
||||
print(f"📝 Created new session: {session_id}")
|
||||
return self.sessions[session_id]
|
||||
|
||||
def finalize_session(self, session_id: str):
|
||||
"""Finalize and clean up a session."""
|
||||
if session_id in self.sessions:
|
||||
self.sessions[session_id].finalize()
|
||||
print(f"✅ Session finalized: {session_id}")
|
||||
|
||||
async def broadcast(self, message: Dict[str, Any]):
|
||||
"""
|
||||
Broadcast a message to all connected WebSocket clients.
|
||||
|
||||
Allows multiple clients (e.g., multiple browser tabs) to watch
|
||||
the same agent session in real-time. Future UI feature.
|
||||
"""
|
||||
disconnected = []
|
||||
for connection in self.active_connections:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except Exception:
|
||||
# Connection closed - mark for removal
|
||||
disconnected.append(connection)
|
||||
|
||||
# Clean up disconnected clients silently
|
||||
for conn in disconnected:
|
||||
if conn in self.active_connections:
|
||||
self.active_connections.remove(conn)
|
||||
|
||||
|
||||
# Global connection manager
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
# Request/Response models for API endpoints
class AgentRequest(BaseModel):
    """Request model for starting an agent run."""
    # Natural-language task for the agent (required, no default).
    query: str
    # Model identifier passed through to the provider API.
    model: str = "claude-sonnet-4-5-20250929"
    # Base URL of the provider-compatible API endpoint.
    base_url: str = "https://api.anthropic.com/v1/"
    # Toolset names to enable/disable; None means "use agent defaults".
    enabled_toolsets: Optional[List[str]] = None
    disabled_toolsets: Optional[List[str]] = None
    # Hard cap on agent loop iterations (mapped to AIAgent.max_iterations).
    max_turns: int = 10
    # Use mocked web tools (with an artificial delay, seconds) instead of
    # real network calls.
    mock_web_tools: bool = False
    mock_delay: int = 60
    # Verbose agent-side logging.
    verbose: bool = False
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
    """Response model for agent run request."""
    # "started" on success - the run continues in the background.
    status: str
    # UUID identifying the run; use it with GET /sessions/{session_id}.
    session_id: str
    # Human-readable confirmation message.
    message: str
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Root endpoint - server status."""
    # Assemble a small status snapshot for quick health checks.
    payload = {}
    payload["status"] = "running"
    payload["service"] = "Hermes Agent Logging Server"
    payload["websocket_url"] = "ws://localhost:8000/ws"
    payload["active_connections"] = len(manager.active_connections)
    payload["active_sessions"] = len(manager.sessions)
    payload["logs_directory"] = str(LOGS_DIR)
    return payload
|
||||
|
||||
|
||||
@app.get("/sessions")
async def list_sessions():
    """List all active and recent sessions."""
    # Newest log files first (sorted by file modification time).
    log_files = sorted(
        LOGS_DIR.glob("session_*.json"),
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )

    summaries = []
    for session_file in log_files:
        try:
            with open(session_file, 'r', encoding='utf-8') as fh:
                payload = json.load(fh)
        except Exception as e:
            # Unreadable/corrupt file: report it and keep going.
            print(f"⚠️ Error reading session file {session_file}: {e}")
            continue
        summaries.append({
            "session_id": payload.get("session_id"),
            "start_time": payload.get("start_time"),
            "end_time": payload.get("end_time"),
            "event_count": len(payload.get("events", [])),
            "file": str(session_file),
        })

    return {
        "total_sessions": len(summaries),
        "sessions": summaries,
    }
|
||||
|
||||
|
||||
@app.get("/sessions/{session_id}")
async def get_session(session_id: str):
    """
    Get detailed data for a specific session.

    Raises:
        HTTPException 404 if no log file exists for the session.
        HTTPException 500 if the log file cannot be parsed.

    BUGFIX: the previous implementation returned `({...}, 404)` tuples;
    FastAPI serializes such a tuple as a two-element JSON array with HTTP
    status 200, so clients never saw a real 404/500. Raising HTTPException
    sets the actual status code.
    """
    # Local import keeps this fix self-contained; fastapi is already a
    # dependency of this module (WebSocket, BackgroundTasks, app routes).
    from fastapi import HTTPException

    session_file = LOGS_DIR / f"session_{session_id}.json"

    if not session_file.exists():
        raise HTTPException(status_code=404, detail="Session not found")

    try:
        with open(session_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to load session: {str(e)}")
|
||||
|
||||
|
||||
@app.post("/agent/run", response_model=AgentResponse)
async def run_agent(request: AgentRequest, background_tasks: BackgroundTasks):
    """
    Start an agent run with specified parameters.

    This endpoint triggers an agent execution in the background and returns immediately.
    The agent will connect to the WebSocket endpoint to send real-time events.

    Args:
        request: AgentRequest with query and configuration
        background_tasks: FastAPI background tasks for async execution
            (NOTE(review): currently unused - the run is launched on a
            daemon thread below instead)

    Returns:
        AgentResponse with session_id for tracking
    """
    import uuid
    import sys
    import os

    # Generate session ID for this run - we'll pass it to the agent so the
    # WebSocket events it emits match the id we return to the caller.
    session_id = str(uuid.uuid4())

    # Add parent directory to path to import run_agent (this module lives in
    # a subdirectory of the repo root, where run_agent.py is).
    parent_dir = str(Path(__file__).parent.parent)
    if parent_dir not in sys.path:
        sys.path.insert(0, parent_dir)

    from run_agent import AIAgent

    # Run agent in background thread (not blocking the API)
    def run_agent_background():
        """Run agent in a separate thread."""
        try:
            # Initialize agent with WebSocket logging enabled
            agent = AIAgent(
                base_url=request.base_url,
                model=request.model,
                api_key=os.getenv("ANTHROPIC_API_KEY"),
                max_iterations=request.max_turns,
                enabled_toolsets=request.enabled_toolsets,
                disabled_toolsets=request.disabled_toolsets,
                save_trajectories=False,
                verbose_logging=request.verbose,
                enable_websocket_logging=True,  # Always enable for UI
                websocket_server="ws://localhost:8000/ws",
                mock_web_tools=request.mock_web_tools,
                mock_delay=request.mock_delay
            )

            # Run conversation with our session_id
            result = agent.run_conversation(
                request.query,
                session_id=session_id  # Pass session_id so it matches
            )

            print(f"✅ Agent run completed: {session_id[:8]}...")
            print(f" Final response: {result['final_response'][:100] if result.get('final_response') else 'No response'}...")

        except Exception as e:
            # Errors here are invisible to the HTTP caller (the request has
            # already returned) - print with traceback for server-side debugging.
            print(f"❌ Error running agent {session_id[:8]}...: {e}")
            import traceback
            traceback.print_exc()

    # Start agent in background thread.
    # daemon=True: the thread must not block server shutdown.
    thread = threading.Thread(target=run_agent_background, daemon=True)
    thread.start()

    return AgentResponse(
        status="started",
        session_id=session_id,
        message=f"Agent started with session ID: {session_id}"
    )
|
||||
|
||||
|
||||
@app.get("/tools")
async def get_available_tools():
    """Get list of available toolsets and tools."""
    try:
        import sys
        # Repo root (one level up) hosts toolsets.py.
        parent_dir = str(Path(__file__).parent.parent)
        if parent_dir not in sys.path:
            sys.path.insert(0, parent_dir)

        from toolsets import get_all_toolsets, get_toolset_info

        toolsets_info = []
        for name in get_all_toolsets().keys():
            info = get_toolset_info(name)
            if not info:
                continue
            toolsets_info.append({
                "name": name,
                "description": info['description'],
                "tool_count": info['tool_count'],
                "resolved_tools": info['resolved_tools'],
            })

        return {"toolsets": toolsets_info}
    except Exception as e:
        return {"error": f"Failed to load tools: {str(e)}"}
|
||||
|
||||
|
||||
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for receiving real-time agent events.

    This is the main entry point for all logging. Agents connect here and send events.

    Message Flow:
    1. Agent connects to ws://localhost:8000/ws
    2. Agent sends events as JSON messages
    3. Server parses event_type and routes to appropriate handler
    4. Event is added to SessionLogger (saved to file)
    5. Event is broadcast to all connected clients
    6. Acknowledgment sent back to agent

    Expected message format:
    {
        "session_id": "unique-session-id",        // Links event to specific session
        "event_type": "query" | "api_call" | ..., // What kind of event
        "data": { ... event-specific data ... }   // Event payload
    }
    """
    # Accept the WebSocket connection
    await manager.connect(websocket)

    try:
        # Main event loop - runs until client disconnects
        while True:
            # Receive message from client (agent)
            # This is a blocking call - waits for next message
            message = await websocket.receive_json()

            # Parse the standard message structure
            session_id = message.get("session_id")  # Which agent session
            event_type = message.get("event_type")  # What kind of event
            data = message.get("data", {})          # Event payload

            # Validate: session_id is required
            if not session_id:
                await websocket.send_json({
                    "error": "session_id is required"
                })
                continue

            # Get or create SessionLogger for this session
            # First event creates it, subsequent events reuse it
            session = manager.get_or_create_session(session_id)

            # Route event to appropriate handler based on event_type
            # Each handler extracts relevant data and adds to session log

            if event_type == "session_start":
                # Initial event - sent when agent first connects
                # Contains metadata about the session (model, toolsets, etc.)
                session.set_metadata(data)
                print(f"🚀 Session started: {session_id}")

            elif event_type == "query":
                # User query
                session.add_event({
                    "type": "query",
                    "query": data.get("query"),
                    "toolsets": data.get("toolsets"),
                    "model": data.get("model")
                })
                print(f"📝 Query logged: {data.get('query', '')[:60]}...")

            elif event_type == "api_call":
                # API call to model
                session.add_event({
                    "type": "api_call",
                    "call_number": data.get("call_number"),
                    "message_count": data.get("message_count"),
                    "has_tools": data.get("has_tools")
                })
                print(f"🔄 API call #{data.get('call_number')} logged")

            elif event_type == "response":
                # Assistant response
                session.add_event({
                    "type": "response",
                    "call_number": data.get("call_number"),
                    "content": data.get("content"),
                    "has_tool_calls": data.get("has_tool_calls"),
                    "tool_call_count": data.get("tool_call_count"),
                    "duration": data.get("duration")
                })
                print(f"🤖 Response logged: {data.get('content', '')[:60]}...")

            elif event_type == "tool_call":
                # Tool execution
                session.add_event({
                    "type": "tool_call",
                    "call_number": data.get("call_number"),
                    "tool_index": data.get("tool_index"),
                    "tool_name": data.get("tool_name"),
                    "parameters": data.get("parameters"),
                    "tool_call_id": data.get("tool_call_id")
                })
                print(f"🔧 Tool call logged: {data.get('tool_name')}")

            elif event_type == "tool_result":
                # Tool result - captures output from tool execution
                # Includes BOTH truncated preview AND full raw result
                session.add_event({
                    "type": "tool_result",
                    "call_number": data.get("call_number"),
                    "tool_index": data.get("tool_index"),
                    "tool_name": data.get("tool_name"),
                    "result": data.get("result"),          # Truncated preview (1000 chars)
                    "raw_result": data.get("raw_result"),  # Full untruncated result
                    "error": data.get("error"),
                    "duration": data.get("duration"),
                    "tool_call_id": data.get("tool_call_id")
                })

                # Enhanced logging with size information
                if data.get("error"):
                    print(f"❌ Tool error logged: {data.get('tool_name')}")
                else:
                    # Show size of raw result to indicate data volume
                    # (falls back to the truncated preview when raw is absent)
                    raw_size = len(data.get("raw_result", "")) if data.get("raw_result") else len(data.get("result", ""))
                    size_kb = raw_size / 1024
                    print(f"✅ Tool result logged: {data.get('tool_name')} ({size_kb:.1f} KB)")

            elif event_type == "error":
                # Error event
                session.add_event({
                    "type": "error",
                    "error_message": data.get("error_message"),
                    "call_number": data.get("call_number")
                })
                print(f"❌ Error logged: {data.get('error_message', '')[:60]}...")

            elif event_type == "complete":
                # Session complete - flush the session log to disk
                session.add_event({
                    "type": "complete",
                    "final_response": data.get("final_response"),
                    "total_calls": data.get("total_calls"),
                    "completed": data.get("completed")
                })
                manager.finalize_session(session_id)
                print(f"🎉 Session complete: {session_id}")

            else:
                # Unknown event type - log it anyway (merge payload keys in,
                # so nothing the agent sent is silently dropped)
                session.add_event({
                    "type": event_type or "unknown",
                    **data
                })
                print(f"⚠️ Unknown event type: {event_type}")

            # Broadcast event to all connected clients (for future real-time UI)
            # Allows multiple browsers/dashboards to watch same session live
            await manager.broadcast({
                "session_id": session_id,
                "event_type": event_type,
                "timestamp": datetime.now().isoformat(),
                "data": data
            })

            # Send acknowledgment back to sender
            # Confirms event was received and logged
            # Handle case where client disconnects before we can ack
            try:
                await websocket.send_json({
                    "status": "logged",
                    "session_id": session_id,
                    "event_type": event_type
                })
            except Exception:
                # Connection closed before ack - this is normal for "complete" event
                # Client disconnects after sending, so we can't ack
                pass

    except WebSocketDisconnect:
        # Normal termination: client hung up.
        manager.disconnect(websocket)
    except Exception as e:
        # Anything else (bad JSON, handler failure) also ends this connection.
        print(f"❌ WebSocket error: {e}")
        manager.disconnect(websocket)
|
||||
|
||||
|
||||
def main(host: str = "0.0.0.0", port: int = 8000, reload: bool = False):
    """
    Start the logging server.

    Args:
        host: Host to bind to (default: 0.0.0.0)
        port: Port to run on (default: 8000)
        reload: Enable auto-reload on file changes (default: False)
    """
    # Startup banner - printed before uvicorn takes over logging.
    banner = [
        "🚀 Hermes Agent Logging Server",
        "=" * 50,
        f"📂 Logs directory: {LOGS_DIR}",
        f"🌐 Server starting at http://{host}:{port}",
        f"🔌 WebSocket endpoint: ws://{host}:{port}/ws",
        f"🔄 Auto-reload: {'enabled' if reload else 'disabled'}",
        "\n📡 Ready to receive agent events...",
        "=" * 50,
    ]
    for line in banner:
        print(line)

    uvicorn.run(
        "logging_server:app",
        host=host,
        port=port,
        reload=reload,
        log_level="info",
        timeout_keep_alive=600  # Keep HTTP/WS connections alive for 10 minutes of inactivity
        # Note: WebSocket ping/pong disabled in client to avoid timeout during blocked event loop
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Expose main() as a CLI, e.g. `python logging_server.py --port 8000`.
    import fire
    fire.Fire(main)
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
#!/bin/bash
# Test script for WebSocket logging system
#
# This script demonstrates the complete WebSocket logging workflow:
# 1. Starts the logging server
# 2. Runs the agent with WebSocket logging enabled
# 3. Shows the logged data
#
# Usage: ./test_websocket_logging.sh

set -e  # Exit on error

# BUGFIX: always stop the background logging server when the script exits.
# Previously, `set -e` would abort the script on any failure (e.g. the agent
# run erroring out) before reaching the final `kill`, leaking the server.
SERVER_PID=""
cleanup() {
    if [ -n "$SERVER_PID" ]; then
        kill "$SERVER_PID" 2>/dev/null || true
    fi
}
trap cleanup EXIT

echo "🧪 Testing WebSocket Logging System"
echo "===================================="
echo ""

# Check if required packages are installed
echo "📦 Checking dependencies..."
python -c "import fastapi; import uvicorn; import websockets" 2>/dev/null || {
    echo "❌ Missing dependencies. Installing..."
    pip install fastapi uvicorn websockets
}
echo "✅ Dependencies OK"
echo ""

# Start the logging server in the background
echo "🚀 Starting logging server..."
python api_endpoint/logging_server.py --port 8000 &
SERVER_PID=$!

# Give server time to start
sleep 2

# Check if server is running
if ps -p $SERVER_PID > /dev/null; then
    echo "✅ Logging server started (PID: $SERVER_PID)"
else
    echo "❌ Failed to start logging server"
    exit 1
fi

echo ""
echo "🤖 Running agent with WebSocket logging..."
echo ""

# Run the agent with WebSocket logging
python run_agent.py \
    --enabled_toolsets web \
    --enable_websocket_logging \
    --query "What are the top 3 programming languages in 2025?" \
    --max_turns 5

echo ""
echo "✅ Agent execution complete!"
echo ""

# Show the most recent log file
echo "📊 Viewing logged session data..."
echo ""

LATEST_LOG=$(ls -t logs/realtime/session_*.json 2>/dev/null | head -1)

if [ -f "$LATEST_LOG" ]; then
    echo "📄 Log file: $LATEST_LOG"
    echo ""

    # Pretty print the JSON if jq is available
    if command -v jq &> /dev/null; then
        echo "Event summary:"
        jq '.events[] | {type: .type, timestamp: .timestamp}' "$LATEST_LOG"
        echo ""
        echo "Total events: $(jq '.events | length' "$LATEST_LOG")"
    else
        echo "Content (install 'jq' for pretty printing):"
        cat "$LATEST_LOG"
    fi
else
    echo "⚠️ No log files found in logs/realtime/"
fi

echo ""
echo "🛑 Stopping logging server..."
# The EXIT trap also covers this, but stop explicitly on the happy path.
kill $SERVER_PID 2>/dev/null || true
SERVER_PID=""

echo "✅ Test complete!"
echo ""
echo "Next steps:"
echo " 1. Start server: python api_endpoint/logging_server.py"
echo " 2. Run agent: python run_agent.py --enable_websocket_logging --query \"...\""
echo " 3. View logs: http://localhost:8000/sessions"
||||
@@ -1,457 +0,0 @@
|
||||
"""
|
||||
WebSocket Connection Pool - Persistent Connection Manager
|
||||
|
||||
This module provides a singleton WebSocket connection that persists across
|
||||
multiple agent runs. This is a more robust architecture than creating a new
|
||||
connection for each run.
|
||||
|
||||
Benefits:
|
||||
- No timeout issues (connection stays alive indefinitely)
|
||||
- No reconnection overhead (connect once)
|
||||
- Supports parallel agent runs (multiple sessions share one socket)
|
||||
- Proper shutdown handling (SIGTERM/SIGINT)
|
||||
- Thread-safe concurrent sends
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
import websockets
|
||||
from typing import Optional, Dict, Any
|
||||
import json
|
||||
import atexit
|
||||
import sys
|
||||
import threading
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class WebSocketConnectionPool:
    """
    Singleton WebSocket connection manager.

    Maintains a single persistent connection to the logging server
    that all agent sessions can use. Handles graceful shutdown.

    Usage:
        # Get singleton instance
        pool = WebSocketConnectionPool()

        # Connect (idempotent - safe to call multiple times)
        await pool.connect()

        # Send events (thread-safe, multiple sessions can call concurrently)
        await pool.send_event("query", session_id, {...})

        # Shutdown handled automatically on SIGTERM/SIGINT
    """

    # Class-level slot holding the one allowed instance (singleton pattern).
    _instance: Optional['WebSocketConnectionPool'] = None

    def __new__(cls):
        """Ensure only one instance exists (singleton pattern)."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Flag lets __init__ run its body exactly once.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Initialize the connection pool (only once)."""
        if getattr(self, '_initialized', False):
            return

        # The underlying websockets client connection (None until connect()).
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        self.server_url: str = "ws://localhost:8000/ws"
        self.connected: bool = False
        # Store reference to loop for signal handlers
        # Agent code should never close event loops when using persistent connections
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        # Locks are created lazily when event loop exists
        self._send_lock: Optional[asyncio.Lock] = None
        self._connect_lock: Optional[asyncio.Lock] = None
        self._locks_loop: Optional[asyncio.AbstractEventLoop] = None  # Track which loop created locks
        self._init_lock = threading.Lock()  # Thread-safe lock initialization
        self._shutdown_in_progress = False
        self._initialized = True

        # Register shutdown handlers for graceful cleanup
        # These ensure WebSocket is closed properly on exit
        # NOTE(review): signal.signal is only legal from the main thread -
        # assumes the singleton is first constructed at module import time
        # on the main thread; verify if construction ever moves to a worker.
        signal.signal(signal.SIGTERM, self._signal_handler)
        signal.signal(signal.SIGINT, self._signal_handler)
        atexit.register(self._cleanup_sync)

        print("🔌 WebSocket connection pool initialized")

    def _ensure_locks(self):
        """
        Lazy initialization of asyncio locks with thread safety and loop tracking.

        Locks must be created when an event loop exists, not at import time.
        If the event loop changes between runs, locks must be recreated because
        asyncio.Lock objects are bound to the loop that created them.

        This is called before any async operation that needs locks.
        Uses a threading.Lock to prevent race conditions during initialization.
        """
        with self._init_lock:  # Thread-safe initialization
            try:
                # NOTE(review): get_event_loop() is deprecated outside a
                # running loop on newer Pythons; callers here invoke this
                # from within coroutines, where it returns the running loop.
                current_loop = asyncio.get_event_loop()
            except RuntimeError:
                # No event loop in current thread
                return

            # Recreate locks if:
            # 1. Locks don't exist yet, OR
            # 2. Event loop has changed (locks are bound to the loop that created them)
            if self._locks_loop is not current_loop or self._send_lock is None:
                self._send_lock = asyncio.Lock()
                self._connect_lock = asyncio.Lock()
                self._locks_loop = current_loop

    async def connect(self, server_url: str = "ws://localhost:8000/ws") -> bool:
        """
        Connect to WebSocket server.

        This is idempotent - safe to call multiple times. If already connected,
        does nothing. If connection failed previously, will retry.

        Args:
            server_url: WebSocket server URL (default: ws://localhost:8000/ws)

        Returns:
            bool: True if connected successfully, False otherwise
        """
        # Ensure locks exist (lazy initialization)
        self._ensure_locks()

        async with self._connect_lock:
            # Always update loop reference to current loop (even if already connected)
            # This ensures signal handlers and cleanup use the correct loop
            self.loop = asyncio.get_event_loop()

            # Already connected - nothing to do
            if self.connected and self.websocket:
                return True

            try:
                self.server_url = server_url

                # Establish persistent WebSocket connection
                # No ping/pong needed since connection stays open indefinitely
                self.websocket = await websockets.connect(
                    server_url,
                    ping_interval=None,  # Disable ping/pong (not needed for persistent connection)
                    max_size=10 * 1024 * 1024,  # 10MB max message size for large tool results
                    open_timeout=10,  # 10s timeout for initial connection
                    close_timeout=5  # 5s timeout for close handshake
                )

                self.connected = True

                print(f"✅ Connected to logging server (persistent): {server_url}")
                return True

            except Exception as e:
                # Connection failure is non-fatal: callers degrade to "not logged".
                print(f"⚠️ Failed to connect to logging server: {e}")
                self.connected = False
                self.websocket = None
                return False

    async def send_event(
        self,
        event_type: str,
        session_id: str,
        data: Dict[str, Any],
        retry: bool = True
    ) -> bool:
        """
        Send event to logging server (thread-safe).

        Multiple agent runs can call this concurrently. The send lock ensures
        only one message is sent at a time (WebSocket protocol requirement).

        Args:
            event_type: Type of event (query, api_call, response, tool_call, tool_result, error, complete)
            session_id: Unique session identifier
            data: Event-specific data dictionary
            retry: Whether to retry connection if disconnected (default: True)

        Returns:
            bool: True if sent successfully, False otherwise
        """
        # Try to connect if not connected (or reconnect if disconnected)
        if not self.connected or not self.websocket:
            if retry:
                await self.connect()
            if not self.connected:
                return False  # Give up if connection fails

        # Ensure locks exist (lazy initialization)
        self._ensure_locks()

        # Lock to prevent concurrent sends (WebSocket requires sequential sends)
        async with self._send_lock:
            try:
                # Create standardized message format
                message = {
                    "session_id": session_id,
                    "event_type": event_type,
                    "data": data,
                    "timestamp": datetime.now().isoformat()
                }

                # Send message as JSON
                await self.websocket.send(json.dumps(message))

                # Wait for server acknowledgment (with timeout)
                # This confirms the server received and processed the event
                try:
                    # Ack payload itself is not inspected - receipt is enough.
                    response = await asyncio.wait_for(
                        self.websocket.recv(),
                        timeout=2.0  # Increased to 2s for busy servers
                    )
                    # Successfully received acknowledgment
                    return True

                except asyncio.TimeoutError:
                    # No response within timeout - that's OK, message likely sent
                    # Server might be busy processing
                    return True

            except websockets.exceptions.ConnectionClosed:
                print(f"⚠️ WebSocket connection closed unexpectedly")
                self.connected = False

                # Try to reconnect and resend (one retry)
                if retry:
                    print("🔄 Attempting to reconnect...")
                    if await self.connect():
                        # Recursively call with retry=False to avoid infinite loop
                        return await self.send_event(event_type, session_id, data, retry=False)

                return False

            except Exception as e:
                print(f"⚠️ Error sending event: {e}")
                self.connected = False
                return False

    async def disconnect(self):
        """
        Gracefully close the WebSocket connection.

        Called on shutdown (SIGTERM/SIGINT/exit). Ensures proper cleanup.
        """
        if self._shutdown_in_progress:
            return  # Already shutting down

        self._shutdown_in_progress = True

        if self.websocket and self.connected:
            try:
                await self.websocket.close()
                self.connected = False
                print("✅ WebSocket connection pool closed gracefully")
            except Exception as e:
                print(f"⚠️ Error closing WebSocket: {e}")

        self._shutdown_in_progress = False

    def _signal_handler(self, signum, frame):
        """
        Handle SIGTERM/SIGINT signals for graceful shutdown.

        When user presses Ctrl+C or system sends SIGTERM, this ensures
        the WebSocket is closed properly before exit.
        """
        print(f"\n🛑 Received signal {signum}, closing WebSocket connection pool...")

        # Check if we have a valid loop and are connected
        if self.loop and not self.loop.is_closed() and self.connected and not self._shutdown_in_progress:
            try:
                # If loop is not running, we can wait for disconnect
                if not self.loop.is_running():
                    self.loop.run_until_complete(self.disconnect())
                else:
                    # Loop is running, can't wait for task - just mark disconnected
                    # The disconnect task would be cancelled when we exit anyway
                    self.connected = False
                    print("⚠️ Loop is running, marking disconnected without waiting")
            except Exception as e:
                print(f"⚠️ Error during signal handler cleanup: {e}")

        # Exit gracefully
        sys.exit(0)

    def _cleanup_sync(self):
        """
        Cleanup at exit (atexit handler).

        This is a fallback in case signal handlers don't fire.
        Called when Python interpreter shuts down normally.
        """
        if self.loop and not self.loop.is_closed() and self.connected and not self._shutdown_in_progress:
            try:
                # Try to run disconnect synchronously
                # NOTE(review): raises if the loop is still running in another
                # thread - swallowed below, so cleanup is best-effort only.
                self.loop.run_until_complete(self.disconnect())
            except Exception:
                # Ignore errors during exit cleanup
                pass

    def is_connected(self) -> bool:
        """Check if currently connected to server."""
        return self.connected and self.websocket is not None

    def get_stats(self) -> Dict[str, Any]:
        """Get connection statistics for debugging."""
        return {
            "connected": self.connected,
            "server_url": self.server_url,
            "shutdown_in_progress": self._shutdown_in_progress,
            "has_websocket": self.websocket is not None,
            "has_loop": self.loop is not None
        }
|
||||
|
||||
|
||||
# Global singleton instance, created at import time (main thread), so the
# signal handlers registered in __init__ are installed exactly once.
# Import this in other modules: from websocket_connection_pool import ws_pool
ws_pool = WebSocketConnectionPool()
|
||||
|
||||
|
||||
# Convenience functions for direct usage
|
||||
async def connect(server_url: str = "ws://localhost:8000/ws") -> bool:
    """Module-level shortcut for ``ws_pool.connect``."""
    ok = await ws_pool.connect(server_url)
    return ok
|
||||
|
||||
|
||||
async def send_event(event_type: str, session_id: str, data: Dict[str, Any]) -> bool:
    """Module-level shortcut for ``ws_pool.send_event``."""
    sent = await ws_pool.send_event(event_type, session_id, data)
    return sent
|
||||
|
||||
|
||||
async def disconnect():
    """Module-level shortcut for ``ws_pool.disconnect``."""
    await ws_pool.disconnect()
|
||||
|
||||
|
||||
def is_connected() -> bool:
    """Module-level shortcut for ``ws_pool.is_connected``."""
    return ws_pool.is_connected()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SYNCHRONOUS API FOR AGENT LAYER
|
||||
# ============================================================================
|
||||
# These functions provide a clean abstraction that hides event loop management
|
||||
# from the agent layer. Agent code should ONLY use these functions.
|
||||
|
||||
def connect_sync(server_url: str = "ws://localhost:8000/ws") -> bool:
    """
    Synchronous connect - handles event loop internally.

    Creates a persistent event loop in a background thread if needed.
    Thread-safe; callable from any thread (including agent background threads).

    Args:
        server_url: WebSocket server URL.

    Returns:
        bool: True if the pool connected successfully, False otherwise.
    """
    import threading
    import time

    if not ws_pool.loop or ws_pool.loop.is_closed():
        # No usable loop yet: start one in a dedicated background thread.
        result_container = {"success": False, "error": None, "connected": False}

        def run_in_thread():
            loop = None  # Guard: new_event_loop() itself may fail below.
            try:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                ws_pool.loop = loop  # Publish the loop for run_coroutine_threadsafe callers

                # Connect to WebSocket
                result_container["success"] = loop.run_until_complete(ws_pool.connect(server_url))
                result_container["connected"] = True

                # Keep loop running forever for future send_event calls.
                # This is critical - the loop must stay alive for
                # run_coroutine_threadsafe to work.
                loop.run_forever()

            except Exception as e:
                result_container["error"] = str(e)
                print(f"❌ Error in WebSocket connection thread: {e}")
            finally:
                # BUGFIX: the original closed the loop when it WAS running
                # (`if loop.is_running(): loop.close()`), which raises
                # RuntimeError and never closed a stopped loop. Close only
                # a loop that exists and has stopped.
                if loop is not None and not loop.is_running():
                    loop.close()

        thread = threading.Thread(target=run_in_thread, daemon=True, name="WebSocket-EventLoop")
        thread.start()

        # Wait for the connection attempt to finish (not for the loop itself,
        # which runs forever). Bail out early if the thread recorded an error.
        timeout = 10.0
        start = time.time()
        while (not result_container["connected"]
               and result_container["error"] is None
               and (time.time() - start) < timeout):
            time.sleep(0.1)

        if result_container["error"]:
            print(f"⚠️ Connection failed: {result_container['error']}")

        return result_container["success"]
    else:
        # Pool already has a loop, use run_coroutine_threadsafe
        try:
            future = asyncio.run_coroutine_threadsafe(
                ws_pool.connect(server_url),
                ws_pool.loop
            )
            return future.result(timeout=10.0)
        except Exception as e:
            print(f"⚠️ Connection failed: {e}")
            return False
|
||||
|
||||
|
||||
def send_event_sync(event_type: str, session_id: str, data: Dict[str, Any]) -> bool:
    """
    Synchronous send event - handles event loop internally.

    Submits the coroutine to the WebSocket pool's own event loop via
    asyncio.run_coroutine_threadsafe, so it is safe to call from any
    thread (e.g. background agent-execution threads) without clashing
    with the caller's event loop.

    Args:
        event_type: Type of event (query, api_call, response, tool_call, ...)
        session_id: Session the event belongs to
        data: Event payload

    Returns:
        bool: True if the event was delivered, False on timeout or error.
    """
    import concurrent.futures

    if not ws_pool.loop or not ws_pool.loop.is_running():
        # No event loop running - can't send
        print("⚠️ WebSocket pool has no running event loop")
        return False

    try:
        # Use run_coroutine_threadsafe to submit to the WebSocket pool's loop.
        # This works across threads - submits the coroutine to the correct loop.
        future = asyncio.run_coroutine_threadsafe(
            ws_pool.send_event(event_type, session_id, data),
            ws_pool.loop  # ← Use the pool's loop, not current thread's loop
        )

        # Wait for completion (with timeout to avoid hanging)
        return future.result(timeout=5.0)

    except (TimeoutError, concurrent.futures.TimeoutError):
        # BUGFIX: future.result() raises concurrent.futures.TimeoutError,
        # which is only an alias of the builtin TimeoutError from Python 3.11
        # onward - catch both so the timeout branch also fires on older
        # interpreters instead of falling through to the generic handler.
        print(f"⚠️ Timeout sending event {event_type}")
        return False
    except Exception as e:
        print(f"⚠️ Error sending event: {e}")
        return False
|
||||
|
||||
|
||||
def disconnect_sync():
    """
    Synchronous disconnect - handles event loop internally.

    Thread-safe disconnect that works from any thread.
    """
    loop = ws_pool.loop
    if not loop or not loop.is_running():
        # No live loop means there is nothing to tear down -
        # report success, matching a clean disconnect.
        return True
    try:
        # Hand the disconnect coroutine to the pool's own loop and block
        # until it completes, bounded by a timeout so we never hang.
        pending = asyncio.run_coroutine_threadsafe(ws_pool.disconnect(), loop)
        return pending.result(timeout=5.0)
    except Exception as e:
        print(f"⚠️ Error disconnecting: {e}")
        return False
|
||||
@@ -1,387 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
WebSocket Logger Client
|
||||
|
||||
Simple client for sending agent events to the logging server via WebSocket.
|
||||
Used by the agent to log events in real-time during execution.
|
||||
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
import websockets
|
||||
|
||||
|
||||
class WebSocketLogger:
    """
    Client for logging agent events via WebSocket.

    Usage:
        logger = WebSocketLogger("unique-session-id")
        await logger.connect()
        await logger.log_query("What is Python?", model="gpt-4")
        await logger.log_api_call(call_number=1)
        await logger.log_response(call_number=1, content="Python is...")
        await logger.disconnect()
    """

    def __init__(
        self,
        session_id: str,
        server_url: str = "ws://localhost:8000/ws",
        enabled: bool = True
    ):
        """
        Initialize WebSocket logger.

        Args:
            session_id: Unique identifier for this agent session
            server_url: WebSocket server URL (default: ws://localhost:8000/ws)
            enabled: Whether logging is enabled (default: True)
        """
        self.session_id = session_id
        self.server_url = server_url
        self.enabled = enabled
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        self.connected = False
        self.reconnect_count = 0  # Track reconnections for debugging

    async def connect(self):
        """
        Connect to the WebSocket logging server.

        Establishes WebSocket connection and sends initial session_start event.
        If connection fails, gracefully disables logging (agent continues normally).
        """
        if not self.enabled:
            return

        try:
            # Establish WebSocket connection to the server
            # Use VERY LONG ping intervals to avoid timeout during long tool execution
            # The event loop is blocked during tool execution, so we can't process pings
            # Setting to very large values (1 hour) effectively disables it
            self.websocket = await websockets.connect(
                self.server_url,
                ping_interval=3600,   # 1 hour - effectively disabled (event loop blocked anyway)
                ping_timeout=3600,    # 1 hour timeout for pong response
                close_timeout=10,     # Timeout for close handshake
                max_size=10 * 1024 * 1024,  # 10MB max message size (for large raw_results)
                open_timeout=10       # Timeout for initial connection
            )
            self.connected = True
            print(f"✅ Connected to logging server (ping/pong: 3600s intervals): {self.server_url}")

            # Send initial session_start event
            # This tells the server to create a new SessionLogger for this session
            await self._send_event("session_start", {
                "session_id": self.session_id,
                "start_time": datetime.now().isoformat()
            })

        except Exception as e:
            # Connection failed - disable logging but don't crash the agent
            print(f"⚠️ Failed to connect to logging server: {e}")
            print("   Logging will be disabled for this session.")
            self.enabled = False
            self.connected = False

    async def disconnect(self):
        """Disconnect from the WebSocket server."""
        if self.websocket and self.connected:
            try:
                await self.websocket.close()
                self.connected = False
                print("✅ Disconnected from logging server")
            except Exception as e:
                print(f"⚠️ Error disconnecting: {e}")

    async def _send_event(self, event_type: str, data: Dict[str, Any]):
        """
        Send an event to the logging server.

        This is the core method that sends all events via WebSocket.
        Creates a standardized message format and handles acknowledgments.

        Args:
            event_type: Type of event (query, api_call, response, tool_call, tool_result, error, complete)
            data: Event data dictionary containing event-specific information
        """
        # Safety check: Don't send if logging is disabled
        if not self.enabled:
            return

        # Auto-reconnect if connection was lost
        if not self.connected or not self.websocket:
            try:
                self.reconnect_count += 1
                print(f"🔄 Reconnecting to logging server (attempt #{self.reconnect_count})...")
                await self.connect()
                # BUGFIX: connect() swallows its own errors (it just flips
                # self.enabled / self.connected), so a failed reconnect never
                # raises here. Verify the flag instead of relying on an
                # exception - otherwise we'd print success and try to send
                # on a dead socket.
                if not self.connected:
                    return
                print("✅ Reconnected successfully!")
            except Exception as e:
                print(f"⚠️ Failed to reconnect: {e}")
                self.enabled = False  # Disable logging after failed reconnect
                return

        try:
            # Create standardized message structure
            # All events follow this format for consistent server-side handling
            message = {
                "session_id": self.session_id,  # Links event to specific agent session
                "event_type": event_type,       # Identifies what kind of event this is
                "data": data                    # Event-specific payload
            }

            # Send message as JSON string over WebSocket
            await self.websocket.send(json.dumps(message))

            # Wait for server acknowledgment (with 1 second timeout)
            # This ensures the server received and processed the event
            try:
                response = await asyncio.wait_for(
                    self.websocket.recv(),
                    timeout=1.0
                )
                # Server sends back: {"status": "logged", "session_id": "...", "event_type": "..."}
                # We don't need to process it, just confirms receipt
            except asyncio.TimeoutError:
                # No response within 1 second - that's okay, continue anyway
                # Server might be busy or network slow, but event was likely sent
                pass

        except Exception as e:
            # Log error but don't crash - graceful degradation
            # Agent should continue working even if logging fails
            error_str = str(e)

            # Check if connection was closed (error 1011 = keepalive ping timeout)
            if "1011" in error_str or "closed" in error_str.lower():
                print(f"⚠️ WebSocket connection closed: {error_str}")
                self.connected = False  # Mark as disconnected
                # Don't try to send more events - connection is dead
            else:
                print(f"⚠️ Error sending event to logging server: {e}")
                # Don't disable entirely or try to reconnect - just continue with logging disabled

    # Convenience methods for specific event types

    async def log_query(
        self,
        query: str,
        model: str = None,
        toolsets: list = None
    ):
        """
        Log a user query (the question/task given to the agent).

        This is typically the first event in a session after connection.
        Captures what the user asked and which model/tools will be used.
        """
        await self._send_event("query", {
            "query": query,        # The user's question/instruction
            "model": model,        # Which AI model is being used
            "toolsets": toolsets   # Which tool categories are enabled
        })

    async def log_api_call(
        self,
        call_number: int,
        message_count: int = None,
        has_tools: bool = None
    ):
        """
        Log an API call to the AI model.

        Called right before sending a request to the model (OpenAI/Anthropic/etc).
        Helps track how many API calls are being made and conversation length.
        """
        await self._send_event("api_call", {
            "call_number": call_number,      # Sequential number (1, 2, 3...)
            "message_count": message_count,  # How many messages in conversation so far
            "has_tools": has_tools           # Whether tools are available to the model
        })

    async def log_response(
        self,
        call_number: int,
        content: str = None,
        has_tool_calls: bool = False,
        tool_call_count: int = 0,
        duration: float = None
    ):
        """
        Log an assistant response from the AI model.

        Called after receiving a response from the API.
        Captures what the model said and whether it wants to use tools.
        """
        await self._send_event("response", {
            "call_number": call_number,          # Which API call this response is from
            "content": content,                  # What the model said (text response)
            "has_tool_calls": has_tool_calls,    # Did model request tool execution?
            "tool_call_count": tool_call_count,  # How many tools does it want to call?
            "duration": duration                 # How long the API call took (seconds)
        })

    async def log_tool_call(
        self,
        call_number: int,
        tool_index: int,
        tool_name: str,
        parameters: Dict[str, Any],
        tool_call_id: str = None
    ):
        """
        Log a tool call (before executing the tool).

        Captures which tool is being called and with what parameters.
        This happens BEFORE the tool runs, so no results yet.
        """
        await self._send_event("tool_call", {
            "call_number": call_number,    # Which API call requested this tool
            "tool_index": tool_index,      # Which tool in the sequence (if multiple)
            "tool_name": tool_name,        # Name of tool (e.g., "web_search", "web_extract")
            "parameters": parameters,      # Arguments passed to the tool (e.g., {"query": "Python", "limit": 5})
            "tool_call_id": tool_call_id   # Unique ID to link call with result
        })

    async def log_tool_result(
        self,
        call_number: int,
        tool_index: int,
        tool_name: str,
        result: str = None,
        error: str = None,
        duration: float = None,
        tool_call_id: str = None,
        raw_result: str = None  # NEW: Full untruncated result for verification
    ):
        """
        Log a tool result (output from tool execution).

        Captures both a truncated preview (for UI display) and the full raw result
        (for verification and debugging). This is especially important for web tools
        where you want to see what was scraped vs what the LLM processed.

        Args:
            call_number: Which API call this tool was part of
            tool_index: Which tool in the sequence (1st, 2nd, etc.)
            tool_name: Name of the tool that was executed
            result: Tool output (will be truncated to 1000 chars for preview)
            error: Error message if tool failed
            duration: How long the tool took to execute (seconds)
            tool_call_id: Unique ID linking this result to the tool call
            raw_result: NEW - Full untruncated result for verification/debugging
        """
        await self._send_event("tool_result", {
            "call_number": call_number,
            "tool_index": tool_index,
            "tool_name": tool_name,
            "result": result[:1000] if result else None,  # Truncated preview (1000 chars max)
            "raw_result": raw_result,  # NEW: Full result - can be 100KB+ for web scraping
            "error": error,
            "duration": duration,
            "tool_call_id": tool_call_id
        })

    async def log_error(
        self,
        error_message: str,
        call_number: int = None
    ):
        """
        Log an error that occurred during agent execution.

        Captures exceptions, API failures, or other issues.
        """
        await self._send_event("error", {
            "error_message": error_message,  # Description of what went wrong
            "call_number": call_number       # Which API call caused the error (if applicable)
        })

    async def log_complete(
        self,
        final_response: str = None,
        total_calls: int = None,
        completed: bool = True
    ):
        """
        Log session completion (final event before disconnecting).

        Marks the end of the agent's execution and provides summary info.
        """
        await self._send_event("complete", {
            "final_response": final_response[:500] if final_response else None,  # Truncated summary of final answer
            "total_calls": total_calls,  # How many API calls were made total
            "completed": completed       # Did it complete successfully? (true/false)
        })
|
||||
|
||||
|
||||
# Synchronous wrapper for convenience
|
||||
class SyncWebSocketLogger:
    """
    Synchronous wrapper around WebSocketLogger.

    For use in synchronous code - creates an event loop internally.
    """

    def __init__(self, session_id: str, server_url: str = "ws://localhost:8000/ws", enabled: bool = True):
        # Delegate all actual WebSocket work to the async logger.
        self.logger = WebSocketLogger(session_id, server_url, enabled)
        # Private event loop, created lazily in connect(); None until then.
        self.loop = None

    def connect(self):
        """Connect to server (synchronous)."""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.loop.run_until_complete(self.logger.connect())

    def disconnect(self):
        """Disconnect from server (synchronous)."""
        if self.loop:
            self.loop.run_until_complete(self.logger.disconnect())
            # NOTE(review): the loop is closed here, so any later log_* call
            # on this instance will fail inside _run_async - confirm callers
            # never log after disconnect().
            self.loop.close()

    def _run_async(self, coro):
        """
        Run an async coroutine synchronously.

        Bridge between sync code (agent) and async code (WebSocket).
        Uses event loop to execute async operations in sync context.
        """
        if self.loop and self.loop.is_running():
            # Already in event loop, just await
            # NOTE(review): asyncio.create_task requires the *current thread's*
            # running loop; if self.loop runs in a different thread this raises,
            # and the task is fire-and-forget (no result/error handling) -
            # confirm this branch is only hit from the loop's own thread.
            asyncio.create_task(coro)
        else:
            # Run in current loop
            # NOTE(review): if connect() was never called, self.loop is None
            # and the coroutine is silently dropped (never awaited).
            if self.loop:
                self.loop.run_until_complete(coro)

    def log_query(self, query: str, model: str = None, toolsets: list = None):
        # Sync facade over WebSocketLogger.log_query.
        self._run_async(self.logger.log_query(query, model, toolsets))

    def log_api_call(self, call_number: int, message_count: int = None, has_tools: bool = None):
        # Sync facade over WebSocketLogger.log_api_call.
        self._run_async(self.logger.log_api_call(call_number, message_count, has_tools))

    def log_response(self, call_number: int, content: str = None, has_tool_calls: bool = False,
                     tool_call_count: int = 0, duration: float = None):
        # Sync facade over WebSocketLogger.log_response.
        self._run_async(self.logger.log_response(call_number, content, has_tool_calls,
                                                 tool_call_count, duration))

    def log_tool_call(self, call_number: int, tool_index: int, tool_name: str,
                      parameters: Dict[str, Any], tool_call_id: str = None):
        # Sync facade over WebSocketLogger.log_tool_call.
        self._run_async(self.logger.log_tool_call(call_number, tool_index, tool_name,
                                                  parameters, tool_call_id))

    def log_tool_result(self, call_number: int, tool_index: int, tool_name: str,
                        result: str = None, error: str = None, duration: float = None,
                        tool_call_id: str = None, raw_result: str = None):
        # Sync facade over WebSocketLogger.log_tool_result.
        self._run_async(self.logger.log_tool_result(call_number, tool_index, tool_name,
                                                    result, error, duration, tool_call_id, raw_result))

    def log_error(self, error_message: str, call_number: int = None):
        # Sync facade over WebSocketLogger.log_error.
        self._run_async(self.logger.log_error(error_message, call_number))

    def log_complete(self, final_response: str = None, total_calls: int = None, completed: bool = True):
        # Sync facade over WebSocketLogger.log_complete.
        self._run_async(self.logger.log_complete(final_response, total_calls, completed))
|
||||
|
||||
746
batch_runner.py
Normal file
746
batch_runner.py
Normal file
@@ -0,0 +1,746 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Batch Agent Runner
|
||||
|
||||
This module provides parallel batch processing capabilities for running the agent
|
||||
across multiple prompts from a dataset. It includes:
|
||||
- Dataset loading and batching
|
||||
- Parallel batch processing with multiprocessing
|
||||
- Checkpointing for fault tolerance and resumption
|
||||
- Trajectory saving in the proper format (from/value pairs)
|
||||
- Tool usage statistics aggregation across all batches
|
||||
|
||||
Usage:
|
||||
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
|
||||
|
||||
# Resume an interrupted run
|
||||
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume
|
||||
|
||||
# Use a specific toolset distribution
|
||||
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from multiprocessing import Pool, Manager, Lock
|
||||
import traceback
|
||||
|
||||
import fire
|
||||
|
||||
from run_agent import AIAgent
|
||||
from toolset_distributions import (
|
||||
get_distribution,
|
||||
list_distributions,
|
||||
sample_toolsets_from_distribution,
|
||||
validate_distribution
|
||||
)
|
||||
|
||||
|
||||
# Global configuration for worker processes
|
||||
_WORKER_CONFIG = {}
|
||||
|
||||
|
||||
def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
|
||||
"""
|
||||
Extract tool usage statistics from message history.
|
||||
|
||||
Args:
|
||||
messages (List[Dict]): Message history
|
||||
|
||||
Returns:
|
||||
Dict: Tool statistics with counts and success/failure rates
|
||||
"""
|
||||
tool_stats = {}
|
||||
|
||||
# Track tool calls and their results
|
||||
tool_calls_map = {} # Map tool_call_id to tool name
|
||||
|
||||
for msg in messages:
|
||||
# Track tool calls from assistant messages
|
||||
if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]:
|
||||
for tool_call in msg["tool_calls"]:
|
||||
tool_name = tool_call["function"]["name"]
|
||||
tool_call_id = tool_call["id"]
|
||||
|
||||
# Initialize stats for this tool if not exists
|
||||
if tool_name not in tool_stats:
|
||||
tool_stats[tool_name] = {
|
||||
"count": 0,
|
||||
"success": 0,
|
||||
"failure": 0
|
||||
}
|
||||
|
||||
tool_stats[tool_name]["count"] += 1
|
||||
tool_calls_map[tool_call_id] = tool_name
|
||||
|
||||
# Track tool responses
|
||||
elif msg["role"] == "tool":
|
||||
tool_call_id = msg.get("tool_call_id", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
# Determine if tool call was successful
|
||||
is_success = True
|
||||
try:
|
||||
# Try to parse as JSON and check for actual error values
|
||||
content_json = json.loads(content) if isinstance(content, str) else content
|
||||
|
||||
if isinstance(content_json, dict):
|
||||
# Check if error field exists AND has a non-null value
|
||||
if "error" in content_json and content_json["error"] is not None:
|
||||
is_success = False
|
||||
|
||||
# Special handling for terminal tool responses
|
||||
# Terminal wraps its response in a "content" field
|
||||
if "content" in content_json and isinstance(content_json["content"], dict):
|
||||
inner_content = content_json["content"]
|
||||
# Check for actual error (non-null error field or non-zero exit code)
|
||||
has_error = (inner_content.get("error") is not None or
|
||||
inner_content.get("exit_code", 0) != 0)
|
||||
if has_error:
|
||||
is_success = False
|
||||
|
||||
# Check for "success": false pattern used by some tools
|
||||
if content_json.get("success") is False:
|
||||
is_success = False
|
||||
|
||||
except:
|
||||
# If not JSON, check if content is empty or explicitly states an error
|
||||
# Note: We avoid simple substring matching to prevent false positives
|
||||
if not content:
|
||||
is_success = False
|
||||
# Only mark as failure if it explicitly starts with "Error:" or "ERROR:"
|
||||
elif content.strip().lower().startswith("error:"):
|
||||
is_success = False
|
||||
|
||||
# Update success/failure count
|
||||
if tool_call_id in tool_calls_map:
|
||||
tool_name = tool_calls_map[tool_call_id]
|
||||
if is_success:
|
||||
tool_stats[tool_name]["success"] += 1
|
||||
else:
|
||||
tool_stats[tool_name]["failure"] += 1
|
||||
|
||||
return tool_stats
|
||||
|
||||
|
||||
def _process_single_prompt(
    prompt_index: int,
    prompt_data: Dict[str, Any],
    batch_num: int,
    config: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Process a single prompt with the agent.

    Samples a toolset from the configured distribution, runs one full agent
    conversation for the prompt, and packages the trajectory plus tool-usage
    statistics for the batch writer. All exceptions are caught and turned
    into a failure record so one bad prompt never kills the batch.

    Args:
        prompt_index (int): Index of prompt in dataset
        prompt_data (Dict): Prompt data containing 'prompt' field
        batch_num (int): Batch number
        config (Dict): Configuration dict with agent parameters
            (expects keys: distribution, model, max_iterations; optional:
            base_url, api_key, verbose, ephemeral_system_prompt)

    Returns:
        Dict: Result containing trajectory, stats, and metadata
    """
    # Raises KeyError if 'prompt' is missing; the dataset loader filters
    # out such entries before they reach this point.
    prompt = prompt_data["prompt"]

    try:
        # Sample toolsets from distribution for this prompt
        selected_toolsets = sample_toolsets_from_distribution(config["distribution"])

        if config.get("verbose"):
            print(f"   Prompt {prompt_index}: Using toolsets {selected_toolsets}")

        # Initialize agent with sampled toolsets
        agent = AIAgent(
            base_url=config.get("base_url"),
            api_key=config.get("api_key"),
            model=config["model"],
            max_iterations=config["max_iterations"],
            enabled_toolsets=selected_toolsets,
            save_trajectories=False,  # We handle saving ourselves
            verbose_logging=config.get("verbose", False),
            ephemeral_system_prompt=config.get("ephemeral_system_prompt")
        )

        # Run the agent
        result = agent.run_conversation(prompt)

        # Extract tool usage statistics
        tool_stats = _extract_tool_stats(result["messages"])

        # Convert to trajectory format (using existing method)
        # NOTE(review): depends on AIAgent's private _convert_to_trajectory_format -
        # confirm this stays available across agent versions.
        trajectory = agent._convert_to_trajectory_format(
            result["messages"],
            prompt,
            result["completed"]
        )

        return {
            "success": True,
            "prompt_index": prompt_index,
            "trajectory": trajectory,
            "tool_stats": tool_stats,
            "completed": result["completed"],
            "api_calls": result["api_calls"],
            "toolsets_used": selected_toolsets,
            "metadata": {
                "batch_num": batch_num,
                "timestamp": datetime.now().isoformat(),
                "model": config["model"]
            }
        }

    except Exception as e:
        # Failure path: record the error but keep the batch alive.
        print(f"❌ Error processing prompt {prompt_index}: {e}")
        if config.get("verbose"):
            traceback.print_exc()

        return {
            "success": False,
            "prompt_index": prompt_index,
            "error": str(e),
            "trajectory": None,
            "tool_stats": {},
            "toolsets_used": [],
            "metadata": {
                "batch_num": batch_num,
                "timestamp": datetime.now().isoformat()
            }
        }
|
||||
|
||||
|
||||
def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
    """
    Worker function to process a single batch of prompts.

    Intended as a multiprocessing worker: takes a single tuple argument so it
    can be mapped over a Pool. Skips prompts already recorded in the
    checkpoint set, appends each successful trajectory to the batch's JSONL
    file, and aggregates per-tool statistics for the batch.

    Args:
        args (Tuple): (batch_num, batch_data, output_dir, completed_prompts, config)

    Returns:
        Dict: Batch results with statistics
    """
    batch_num, batch_data, output_dir, completed_prompts_set, config = args

    output_dir = Path(output_dir)
    print(f"\n🔄 Batch {batch_num}: Starting ({len(batch_data)} prompts)")

    # Output file for this batch
    batch_output_file = output_dir / f"batch_{batch_num}.jsonl"

    # Filter out already completed prompts
    prompts_to_process = [
        (idx, data) for idx, data in batch_data
        if idx not in completed_prompts_set
    ]

    if not prompts_to_process:
        print(f"✅ Batch {batch_num}: Already completed (skipping)")
        return {
            "batch_num": batch_num,
            "processed": 0,
            "skipped": len(batch_data),
            "tool_stats": {},
            "completed_prompts": []
        }

    print(f"   Processing {len(prompts_to_process)} prompts (skipping {len(batch_data) - len(prompts_to_process)} already completed)")

    # Initialize aggregated stats for this batch
    batch_tool_stats = {}
    completed_in_batch = []

    # Process each prompt sequentially in this batch
    for prompt_index, prompt_data in prompts_to_process:
        # Process the prompt
        result = _process_single_prompt(
            prompt_index,
            prompt_data,
            batch_num,
            config
        )

        # Save trajectory if successful
        if result["success"] and result["trajectory"]:
            trajectory_entry = {
                "prompt_index": prompt_index,
                "conversations": result["trajectory"],
                "metadata": result["metadata"],
                "completed": result["completed"],
                "api_calls": result["api_calls"],
                "toolsets_used": result["toolsets_used"]
            }

            # Append to batch output file
            with open(batch_output_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n")

        # Aggregate tool statistics (failed prompts contribute an empty dict)
        for tool_name, stats in result.get("tool_stats", {}).items():
            if tool_name not in batch_tool_stats:
                batch_tool_stats[tool_name] = {
                    "count": 0,
                    "success": 0,
                    "failure": 0
                }

            batch_tool_stats[tool_name]["count"] += stats["count"]
            batch_tool_stats[tool_name]["success"] += stats["success"]
            batch_tool_stats[tool_name]["failure"] += stats["failure"]

        # NOTE(review): this also marks failed prompts as completed, so they
        # will not be retried on a resumed run - confirm that is intentional.
        completed_in_batch.append(prompt_index)
        print(f"   ✅ Prompt {prompt_index} completed")

    print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)")

    return {
        "batch_num": batch_num,
        "processed": len(prompts_to_process),
        "skipped": len(batch_data) - len(prompts_to_process),
        "tool_stats": batch_tool_stats,
        "completed_prompts": completed_in_batch
    }
|
||||
|
||||
|
||||
class BatchRunner:
|
||||
"""
|
||||
Manages batch processing of agent prompts with checkpointing and statistics.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    dataset_file: str,
    batch_size: int,
    run_name: str,
    distribution: str = "default",
    max_iterations: int = 10,
    base_url: str = None,
    api_key: str = None,
    model: str = "claude-opus-4-20250514",
    num_workers: int = 4,
    verbose: bool = False,
    ephemeral_system_prompt: str = None
):
    """
    Initialize the batch runner.

    Validates the toolset distribution, prepares the output/checkpoint
    directory under data/<run_name>, loads the dataset, and pre-computes
    the batches. Prints a configuration summary.

    Args:
        dataset_file (str): Path to the dataset JSONL file with 'prompt' field
        batch_size (int): Number of prompts per batch
        run_name (str): Name for this run (used for checkpointing and output)
        distribution (str): Toolset distribution to use (default: "default")
        max_iterations (int): Max iterations per agent run
        base_url (str): Base URL for model API
        api_key (str): API key for model
        model (str): Model name to use
        num_workers (int): Number of parallel workers
        verbose (bool): Enable verbose logging
        ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)

    Raises:
        ValueError: If the distribution name is unknown.
        FileNotFoundError: If the dataset file does not exist (via _load_dataset).
    """
    self.dataset_file = Path(dataset_file)
    self.batch_size = batch_size
    self.run_name = run_name
    self.distribution = distribution
    self.max_iterations = max_iterations
    self.base_url = base_url
    self.api_key = api_key
    self.model = model
    self.num_workers = num_workers
    self.verbose = verbose
    self.ephemeral_system_prompt = ephemeral_system_prompt

    # Validate distribution before doing any filesystem work
    if not validate_distribution(distribution):
        raise ValueError(f"Unknown distribution: {distribution}. Available: {list(list_distributions().keys())}")

    # Setup output directory (data/<run_name>, created if missing)
    self.output_dir = Path("data") / run_name
    self.output_dir.mkdir(parents=True, exist_ok=True)

    # Checkpoint file
    self.checkpoint_file = self.output_dir / "checkpoint.json"

    # Statistics file
    self.stats_file = self.output_dir / "statistics.json"

    # Load dataset
    self.dataset = self._load_dataset()

    # Create batches
    self.batches = self._create_batches()

    print(f"📊 Batch Runner Initialized")
    print(f"   Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
    print(f"   Batch size: {self.batch_size}")
    print(f"   Total batches: {len(self.batches)}")
    print(f"   Run name: {self.run_name}")
    print(f"   Distribution: {self.distribution}")
    print(f"   Output directory: {self.output_dir}")
    print(f"   Workers: {self.num_workers}")
    if self.ephemeral_system_prompt:
        # Show only a short preview so huge prompts don't flood the console
        prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
        print(f"   🔒 Ephemeral system prompt: '{prompt_preview}'")
|
||||
|
||||
def _load_dataset(self) -> List[Dict[str, Any]]:
    """
    Load dataset from JSONL file.

    Skips blank lines, lines that are not valid JSON, and entries missing
    the required 'prompt' field (each skip prints a warning with the line
    number).

    Returns:
        List[Dict]: List of dataset entries

    Raises:
        FileNotFoundError: If the dataset file does not exist.
        ValueError: If no valid entries remain after filtering.
    """
    if not self.dataset_file.exists():
        raise FileNotFoundError(f"Dataset file not found: {self.dataset_file}")

    dataset = []
    with open(self.dataset_file, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue

            try:
                entry = json.loads(line)
                # 'prompt' is required downstream (_process_single_prompt
                # indexes it directly), so filter here.
                if 'prompt' not in entry:
                    print(f"⚠️ Warning: Line {line_num} missing 'prompt' field, skipping")
                    continue
                dataset.append(entry)
            except json.JSONDecodeError as e:
                print(f"⚠️ Warning: Invalid JSON on line {line_num}: {e}")
                continue

    if not dataset:
        raise ValueError(f"No valid entries found in dataset file: {self.dataset_file}")

    return dataset
|
||||
|
||||
def _create_batches(self) -> List[List[Tuple[int, Dict[str, Any]]]]:
|
||||
"""
|
||||
Split dataset into batches with indices.
|
||||
|
||||
Returns:
|
||||
List of batches, where each batch is a list of (index, entry) tuples
|
||||
"""
|
||||
batches = []
|
||||
for i in range(0, len(self.dataset), self.batch_size):
|
||||
batch = [(idx, entry) for idx, entry in enumerate(self.dataset[i:i + self.batch_size], start=i)]
|
||||
batches.append(batch)
|
||||
|
||||
return batches
|
||||
|
||||
def _load_checkpoint(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Load checkpoint data if it exists.
|
||||
|
||||
Returns:
|
||||
Dict: Checkpoint data with completed prompt indices
|
||||
"""
|
||||
if not self.checkpoint_file.exists():
|
||||
return {
|
||||
"run_name": self.run_name,
|
||||
"completed_prompts": [],
|
||||
"batch_stats": {},
|
||||
"last_updated": None
|
||||
}
|
||||
|
||||
try:
|
||||
with open(self.checkpoint_file, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
print(f"⚠️ Warning: Failed to load checkpoint: {e}")
|
||||
return {
|
||||
"run_name": self.run_name,
|
||||
"completed_prompts": [],
|
||||
"batch_stats": {},
|
||||
"last_updated": None
|
||||
}
|
||||
|
||||
def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None):
|
||||
"""
|
||||
Save checkpoint data.
|
||||
|
||||
Args:
|
||||
checkpoint_data (Dict): Checkpoint data to save
|
||||
lock (Lock): Optional lock for thread-safe access
|
||||
"""
|
||||
checkpoint_data["last_updated"] = datetime.now().isoformat()
|
||||
|
||||
if lock:
|
||||
with lock:
|
||||
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(checkpoint_data, f, indent=2)
|
||||
else:
|
||||
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(checkpoint_data, f, indent=2)
|
||||
|
||||
|
||||
def run(self, resume: bool = False):
    """
    Run the batch processing pipeline.

    Dispatches batches to a multiprocessing Pool, aggregates per-tool call
    statistics returned by the workers, combines the per-batch output files
    into a single trajectories.jsonl, then writes a final stats file and a
    checkpoint listing completed prompt indices.

    Args:
        resume (bool): Whether to resume from checkpoint. When False a
            fresh checkpoint is used, and the final save overwrites any
            checkpoint left by a previous run of the same name.
    """
    print("\n" + "=" * 70)
    print("🚀 Starting Batch Processing")
    print("=" * 70)

    # Load checkpoint (or start from a pristine one when not resuming)
    checkpoint_data = self._load_checkpoint() if resume else {
        "run_name": self.run_name,
        "completed_prompts": [],
        "batch_stats": {},
        "last_updated": None
    }

    if resume and checkpoint_data.get("completed_prompts"):
        print(f"📂 Resuming from checkpoint ({len(checkpoint_data['completed_prompts'])} prompts already completed)")

    # Prepare configuration for workers
    # (plain dict of primitives so it pickles cleanly across processes)
    config = {
        "distribution": self.distribution,
        "model": self.model,
        "max_iterations": self.max_iterations,
        "base_url": self.base_url,
        "api_key": self.api_key,
        "verbose": self.verbose,
        "ephemeral_system_prompt": self.ephemeral_system_prompt
    }

    # Get completed prompts set (workers use it to skip finished prompts)
    completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))

    # Aggregate statistics across all batches
    total_tool_stats = {}

    start_time = time.time()

    # Process batches in parallel
    with Pool(processes=self.num_workers) as pool:
        # Create tasks for each batch
        tasks = [
            (
                batch_num,
                batch_data,
                str(self.output_dir),  # Convert Path to string for pickling
                completed_prompts_set,
                config
            )
            for batch_num, batch_data in enumerate(self.batches)
        ]

        # Use map to process batches in parallel
        # NOTE(review): _process_batch_worker is defined elsewhere in this
        # file; each worker writes its own batch_<n>.jsonl and returns a
        # dict with "completed_prompts" and "tool_stats".
        results = pool.map(_process_batch_worker, tasks)

    # Aggregate all batch statistics and update checkpoint
    all_completed_prompts = list(completed_prompts_set)
    for batch_result in results:
        # Add newly completed prompts
        all_completed_prompts.extend(batch_result.get("completed_prompts", []))

        # Aggregate tool stats
        for tool_name, stats in batch_result.get("tool_stats", {}).items():
            if tool_name not in total_tool_stats:
                total_tool_stats[tool_name] = {
                    "count": 0,
                    "success": 0,
                    "failure": 0
                }

            total_tool_stats[tool_name]["count"] += stats["count"]
            total_tool_stats[tool_name]["success"] += stats["success"]
            total_tool_stats[tool_name]["failure"] += stats["failure"]

    # Save final checkpoint
    checkpoint_data["completed_prompts"] = all_completed_prompts
    self._save_checkpoint(checkpoint_data)

    # Calculate success rates (percentages rounded to two decimals)
    for tool_name in total_tool_stats:
        stats = total_tool_stats[tool_name]
        total_calls = stats["success"] + stats["failure"]
        if total_calls > 0:
            stats["success_rate"] = round(stats["success"] / total_calls * 100, 2)
            stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2)
        else:
            stats["success_rate"] = 0.0
            stats["failure_rate"] = 0.0

    # Combine all batch files into a single trajectories.jsonl file
    combined_file = self.output_dir / "trajectories.jsonl"
    print(f"\n📦 Combining batch files into {combined_file.name}...")

    with open(combined_file, 'w', encoding='utf-8') as outfile:
        for batch_num in range(len(self.batches)):
            batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
            if batch_file.exists():
                with open(batch_file, 'r', encoding='utf-8') as infile:
                    for line in infile:
                        outfile.write(line)

    print(f"✅ Combined {len(self.batches)} batch files into trajectories.jsonl")

    # Save final statistics
    final_stats = {
        "run_name": self.run_name,
        "distribution": self.distribution,
        "total_prompts": len(self.dataset),
        "total_batches": len(self.batches),
        "batch_size": self.batch_size,
        "model": self.model,
        "completed_at": datetime.now().isoformat(),
        "duration_seconds": round(time.time() - start_time, 2),
        "tool_statistics": total_tool_stats
    }

    with open(self.stats_file, 'w', encoding='utf-8') as f:
        json.dump(final_stats, f, indent=2)

    # Print summary
    print("\n" + "=" * 70)
    print("📊 BATCH PROCESSING COMPLETE")
    print("=" * 70)
    print(f"✅ Total prompts processed: {len(self.dataset)}")
    print(f"✅ Total batches: {len(self.batches)}")
    print(f"⏱️  Total duration: {round(time.time() - start_time, 2)}s")
    print(f"\n📈 Tool Usage Statistics:")
    print("-" * 70)

    if total_tool_stats:
        # Sort by count descending
        sorted_tools = sorted(
            total_tool_stats.items(),
            key=lambda x: x[1]["count"],
            reverse=True
        )

        print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
        print("-" * 70)
        for tool_name, stats in sorted_tools:
            print(
                f"{tool_name:<25} "
                f"{stats['count']:<10} "
                f"{stats['success']:<10} "
                f"{stats['failure']:<10} "
                f"{stats['success_rate']:.1f}%"
            )
    else:
        print("No tool calls were made during this run.")

    print(f"\n💾 Results saved to: {self.output_dir}")
    print(f"   - Trajectories: trajectories.jsonl (combined)")
    print(f"   - Individual batches: batch_*.jsonl (for debugging)")
    print(f"   - Statistics: {self.stats_file.name}")
    print(f"   - Checkpoint: {self.checkpoint_file.name}")
|
||||
|
||||
|
||||
def main(
    dataset_file: str = None,
    batch_size: int = None,
    run_name: str = None,
    distribution: str = "default",
    model: str = "claude-opus-4-20250514",
    api_key: str = None,
    base_url: str = "https://api.anthropic.com/v1/",
    max_turns: int = 10,
    num_workers: int = 4,
    resume: bool = False,
    verbose: bool = False,
    list_distributions: bool = False,
    ephemeral_system_prompt: str = None
):
    """
    Run batch processing of agent prompts from a dataset.

    Args:
        dataset_file (str): Path to JSONL file with 'prompt' field in each entry
        batch_size (int): Number of prompts per batch
        run_name (str): Name for this run (used for output and checkpointing)
        distribution (str): Toolset distribution to use (default: "default")
        model (str): Model name to use (default: "claude-opus-4-20250514")
        api_key (str): API key for model authentication
        base_url (str): Base URL for model API
        max_turns (int): Maximum number of tool calling iterations per prompt (default: 10)
        num_workers (int): Number of parallel worker processes (default: 4)
        resume (bool): Resume from checkpoint if run was interrupted (default: False)
        verbose (bool): Enable verbose logging (default: False)
        list_distributions (bool): List available toolset distributions and exit
        ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)

    Returns:
        1 on validation failure or fatal error; None on success. (Previously
        validation failures returned None, making CLI error detection
        inconsistent with the fatal-error path.)

    Examples:
        # Basic usage
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run

        # Resume interrupted run
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume

        # Use specific distribution
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen

        # With ephemeral system prompt (not saved to dataset)
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
            --ephemeral_system_prompt="You are a helpful assistant focused on image generation."

        # List available distributions
        python batch_runner.py --list_distributions
    """
    # Handle list distributions (informational mode: print and exit)
    if list_distributions:
        from toolset_distributions import list_distributions as get_all_dists, print_distribution_info

        print("📊 Available Toolset Distributions")
        print("=" * 70)

        all_dists = get_all_dists()
        for dist_name in sorted(all_dists.keys()):
            print_distribution_info(dist_name)

        print("\n💡 Usage:")
        print("   python batch_runner.py --dataset_file=data.jsonl --batch_size=10 \\")
        print("       --run_name=my_run --distribution=<name>")
        return

    # Validate required arguments.
    # Each failure returns 1 so shell callers can detect it, matching the
    # fatal-error path below.
    if not dataset_file:
        print("❌ Error: --dataset_file is required")
        return 1

    if not batch_size or batch_size < 1:
        print("❌ Error: --batch_size must be a positive integer")
        return 1

    if not run_name:
        print("❌ Error: --run_name is required")
        return 1

    # Initialize and run batch runner
    try:
        runner = BatchRunner(
            dataset_file=dataset_file,
            batch_size=batch_size,
            run_name=run_name,
            distribution=distribution,
            max_iterations=max_turns,
            base_url=base_url,
            api_key=api_key,
            model=model,
            num_workers=num_workers,
            verbose=verbose,
            ephemeral_system_prompt=ephemeral_system_prompt
        )

        runner.run(resume=resume)

    except Exception as e:
        # Top-level boundary: report, optionally dump the traceback, and
        # signal failure via a nonzero return value.
        print(f"\n❌ Fatal error: {e}")
        if verbose:
            traceback.print_exc()
        return 1
|
||||
|
||||
|
||||
# CLI entry point: python-fire exposes main()'s keyword arguments as
# command-line flags (e.g. --dataset_file, --batch_size, --resume).
if __name__ == "__main__":
    fire.Fire(main)
|
||||
|
||||
@@ -1,243 +0,0 @@
|
||||
"""
|
||||
Mock Web Tools for Testing WebSocket Reconnection
|
||||
|
||||
This module provides mock implementations of web_search and web_extract
|
||||
that simulate long-running operations without making real API calls.
|
||||
|
||||
Perfect for testing WebSocket timeout/reconnection behavior without:
|
||||
- Wasting API credits
|
||||
- Waiting for real web crawling
|
||||
- Network dependencies
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
|
||||
def mock_web_search(query: str, delay: int = 2) -> str:
    """
    Fake web search: sleeps briefly, then returns canned results.

    Args:
        query: Search query (echoed into the last result for traceability)
        delay: Seconds to sleep (default: 2s)

    Returns:
        JSON string with fake search results
    """
    print(f"🔍 [MOCK] Searching for: '{query}' (will take {delay}s)...")
    time.sleep(delay)

    # (url, title, description) triples for the three canned hits.
    hits = [
        ("https://example.com/article1",
         "Mock Article 1 - Water Utilities",
         "This is a mock search result for testing purposes. Real data would appear here."),
        ("https://example.com/article2",
         "Mock Article 2 - AI Data Centers",
         "Another mock result. This simulates web_search without making real API calls."),
        ("https://example.com/article3",
         "Mock Article 3 - Investment Opportunities",
         "Third mock result for testing. Query was: " + query),
    ]

    payload = {
        "success": True,
        "data": {
            "web": [
                {"url": u, "title": t, "description": d, "category": None}
                for u, t, d in hits
            ]
        }
    }

    return json.dumps(payload, indent=2)
|
||||
|
||||
|
||||
def mock_web_extract(urls: List[str], delay: int = 60) -> str:
    """
    Fake web extraction that simulates a long-running crawl.

    Useful for exercising WebSocket timeout/reconnection because:
    - The default 60s delay outlasts the ~30s WebSocket timeout
    - No actual web requests are made
    - No API credits are consumed
    - Behavior is deterministic and repeatable

    Args:
        urls: List of URLs to "extract" (only echoed back, never fetched)
        delay: Seconds to sleep (default: 60s to trigger timeout)

    Returns:
        JSON string with fake extraction results
    """
    print(f"🌐 [MOCK] Extracting {len(urls)} URLs (will take {delay}s)...")
    print(f"📊 [MOCK] This will test WebSocket reconnection (timeout at ~30s)")

    # Sleep one second at a time so progress can be reported every 10s.
    for elapsed in range(delay):
        if elapsed % 10 == 0 and elapsed > 0:
            print(f"   ⏱️  [MOCK] {elapsed}/{delay}s elapsed...")
        time.sleep(1)

    # Build fake but realistic-looking page entries.
    extracted = []
    for position, url in enumerate(urls, 1):
        page_text = (
            f"# Mock Content from {url}\n\n"
            f"This is simulated extracted content for testing purposes. "
            f"In a real scenario, this would contain the full text from the webpage. "
            f"\n\n## Key Points\n"
            f"- Mock point 1 about water utilities\n"
            f"- Mock point 2 about AI data centers\n"
            f"- Mock point 3 about investment opportunities\n"
            f"\n\nThis content took {delay} seconds to 'extract', which is long enough "
            f"to trigger WebSocket timeout and test reconnection logic."
        )
        extracted.append({
            "url": url,
            "title": f"Mock Extracted Content {position}",
            "content": page_text * 10,  # pad to roughly real-extraction size
            "extracted_at": "2025-10-10T14:00:00Z"
        })

    result = {"success": True, "data": extracted}

    json_result = json.dumps(result, indent=2)
    size_kb = len(json_result) / 1024

    print(f"✅ [MOCK] Extraction completed: {len(urls)} URLs, {size_kb:.1f} KB")
    return json_result
|
||||
|
||||
|
||||
def mock_web_crawl(start_url: str, max_pages: int = 10, delay: int = 30) -> str:
    """
    Fake website crawl producing a handful of canned pages.

    Args:
        start_url: Starting URL (only used to derive fake page URLs)
        max_pages: Max pages to crawl (caps the result count at 5)
        delay: Seconds to sleep (default: 30s)

    Returns:
        JSON string with fake crawl results
    """
    print(f"🕷️ [MOCK] Crawling from: {start_url} (max {max_pages} pages, {delay}s)...")
    time.sleep(delay)

    page_total = min(max_pages, 5)  # the mock never "finds" more than 5 pages
    pages = [
        {
            "url": f"{start_url}/page{n}",
            "title": f"Mock Page {n}",
            "content": f"Mock content from page {n}. " * 50
        }
        for n in range(1, page_total + 1)
    ]

    result = {
        "success": True,
        "data": {
            "start_url": start_url,
            "pages_crawled": page_total,
            "pages": pages
        }
    }

    print(f"✅ [MOCK] Crawl completed: {len(result['data']['pages'])} pages")
    return json.dumps(result, indent=2)
|
||||
|
||||
|
||||
# Tool definitions for the agent (same format as real tools)
|
||||
# Tool definitions for the agent (same format as real tools)
# Each entry is a tool schema (name, description, JSON-Schema input_schema)
# mirroring the real web tools, so the mocks can be swapped in transparently.
# Descriptions are tagged [MOCK] so transcripts make clear no real web
# access happened. The extra "delay" parameter on every tool lets tests
# control how long each call blocks.
MOCK_WEB_TOOLS = [
    {
        "name": "web_search",
        "description": "[MOCK] Search the web for information. Returns fake results after 2s delay. Perfect for quick tests.",
        "input_schema": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The search query"
                },
                "delay": {
                    "type": "integer",
                    "description": "Seconds to delay (default: 2)",
                    "default": 2
                }
            },
            "required": ["query"]
        }
    },
    {
        "name": "web_extract",
        "description": "[MOCK] Extract content from URLs. Simulates 60s delay to test WebSocket timeout/reconnection. Returns fake content without making real requests. PERFECT FOR TESTING!",
        "input_schema": {
            "type": "object",
            "properties": {
                "urls": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "List of URLs to extract"
                },
                "delay": {
                    "type": "integer",
                    "description": "Seconds to delay (default: 60 to trigger timeout)",
                    "default": 60
                }
            },
            "required": ["urls"]
        }
    },
    {
        "name": "web_crawl",
        "description": "[MOCK] Crawl website starting from URL. Returns fake results after 30s delay.",
        "input_schema": {
            "type": "object",
            "properties": {
                "start_url": {
                    "type": "string",
                    "description": "Starting URL for crawl"
                },
                "max_pages": {
                    "type": "integer",
                    "description": "Max pages to crawl (default: 10)",
                    "default": 10
                },
                "delay": {
                    "type": "integer",
                    "description": "Seconds to delay (default: 30)",
                    "default": 30
                }
            },
            "required": ["start_url"]
        }
    }
]
|
||||
|
||||
|
||||
# Map function names to implementations
|
||||
# Map function names to implementations
# Dispatch table: the agent loop looks up the tool name it received and
# calls the matching mock implementation. Keys must match the "name"
# fields in MOCK_WEB_TOOLS.
MOCK_TOOL_FUNCTIONS = {
    "web_search": mock_web_search,
    "web_extract": mock_web_extract,
    "web_crawl": mock_web_crawl
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Demo/test the mock tools
    print("Testing Mock Web Tools")
    print("=" * 60)

    # Quick search demo: 2s is the normal mock search delay.
    print("\n1. Mock web_search (2s delay):")
    result = mock_web_search("test query", delay=2)
    print(f"Result length: {len(result)} chars\n")

    # Extraction demo: shortened to 5s here; the default 60s is meant to
    # outlast the ~30s WebSocket timeout in real reconnection tests.
    print("\n2. Mock web_extract (5s delay for demo - normally 60s):")
    result = mock_web_extract(["https://example.com"], delay=5)
    print(f"Result length: {len(result)} chars\n")

    print("\n✅ All mock tools working!")
|
||||
|
||||
@@ -30,11 +30,11 @@ import json
|
||||
import asyncio
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from web_tools import web_search_tool, web_extract_tool, web_crawl_tool, check_firecrawl_api_key
|
||||
from terminal_tool import terminal_tool, check_hecate_requirements, TERMINAL_TOOL_DESCRIPTION
|
||||
from vision_tools import vision_analyze_tool, check_vision_requirements
|
||||
from mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements
|
||||
from image_generation_tool import image_generate_tool, check_image_generation_requirements
|
||||
from tools.web_tools import web_search_tool, web_extract_tool, web_crawl_tool, check_firecrawl_api_key
|
||||
from tools.terminal_tool import terminal_tool, check_hecate_requirements, TERMINAL_TOOL_DESCRIPTION
|
||||
from tools.vision_tools import vision_analyze_tool, check_vision_requirements
|
||||
from tools.mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements
|
||||
from tools.image_generation_tool import image_generate_tool, check_image_generation_requirements
|
||||
from toolsets import (
|
||||
get_toolset, resolve_toolset, resolve_multiple_toolsets,
|
||||
get_all_toolsets, get_toolset_names, validate_toolset,
|
||||
@@ -581,8 +581,21 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
|
||||
allow_nsfw_images = True
|
||||
seed = None
|
||||
|
||||
# Run async function in event loop
|
||||
return asyncio.run(image_generate_tool(
|
||||
# Run async function in event loop with proper handling for multiprocessing
|
||||
try:
|
||||
# Try to get existing event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
# If closed, create a new one
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
# No event loop in current thread, create one
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Run the coroutine in the event loop
|
||||
result = loop.run_until_complete(image_generate_tool(
|
||||
prompt=prompt,
|
||||
image_size=image_size,
|
||||
num_inference_steps=num_inference_steps,
|
||||
@@ -594,6 +607,8 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
|
||||
allow_nsfw_images=allow_nsfw_images,
|
||||
seed=seed
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown image generation function: {function_name}"})
|
||||
|
||||
527
output.txt
527
output.txt
File diff suppressed because one or more lines are too long
28
pyproject.toml
Normal file
28
pyproject.toml
Normal file
@@ -0,0 +1,28 @@
|
||||
# Build backend: plain setuptools (PEP 517/518).
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "hermes-agent"
version = "0.1.0"
description = "AI agent with advanced tool-calling and toolsets"
readme = "README.md"
requires-python = ">=3.10"
authors = [{ name = "Hermes Agent" }]
license = { text = "MIT" }
# Runtime dependencies; API keys for these services are read from a
# .env file in the repo root.
dependencies = [
    "firecrawl-py",
    "openai",
    "fal-client",
    "python-dotenv",
    "fire"
]

# Console entry point: `hermes-agent` runs run_agent.main (fire CLI).
[project.scripts]
hermes-agent = "run_agent:main"

# Top-level single-file modules shipped alongside the tools package.
[tool.setuptools]
py-modules = ["run_agent", "model_tools", "toolsets"]

# Only the tools/ package is auto-discovered.
[tool.setuptools.packages.find]
include = ["tools"]
|
||||
@@ -3,12 +3,4 @@ openai
|
||||
fal-client
|
||||
python-dotenv
|
||||
fire
|
||||
httpx
|
||||
yt-dlp
|
||||
streamlit
|
||||
fastapi
|
||||
uvicorn
|
||||
websockets
|
||||
PySide6>=6.6.0
|
||||
websocket-client>=1.7.0
|
||||
requests>=2.31.0
|
||||
requests
|
||||
315
run_agent.py
315
run_agent.py
@@ -24,28 +24,25 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
from openai import OpenAI
|
||||
import fire
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# Load environment variables from .env file
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load .env file if it exists
|
||||
env_path = Path(__file__).parent / '.env'
|
||||
if env_path.exists():
|
||||
load_dotenv(dotenv_path=env_path)
|
||||
print(f"✅ Loaded environment variables from {env_path}")
|
||||
else:
|
||||
print(f"ℹ️ No .env file found at {env_path}. Using system environment variables.")
|
||||
|
||||
# Import our tool system
|
||||
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
|
||||
from mock_web_tools import MOCK_TOOL_FUNCTIONS, MOCK_WEB_TOOLS
|
||||
|
||||
# Import WebSocket connection pool (optional dependency)
|
||||
# Use synchronous API to avoid event loop management in agent layer
|
||||
try:
|
||||
from api_endpoint.websocket_connection_pool import connect_sync, send_event_sync, is_connected
|
||||
WEBSOCKET_LOGGER_AVAILABLE = True
|
||||
except ImportError:
|
||||
WEBSOCKET_LOGGER_AVAILABLE = False
|
||||
connect_sync = None
|
||||
send_event_sync = None
|
||||
is_connected = None
|
||||
print("⚠️ WebSocket logger not available (missing websockets package)")
|
||||
|
||||
|
||||
class AIAgent:
|
||||
@@ -67,10 +64,7 @@ class AIAgent:
|
||||
disabled_toolsets: List[str] = None,
|
||||
save_trajectories: bool = False,
|
||||
verbose_logging: bool = False,
|
||||
enable_websocket_logging: bool = False,
|
||||
websocket_server: str = "ws://localhost:8000/ws",
|
||||
mock_web_tools: bool = False,
|
||||
mock_delay: int = 60
|
||||
ephemeral_system_prompt: str = None
|
||||
):
|
||||
"""
|
||||
Initialize the AI Agent.
|
||||
@@ -85,21 +79,14 @@ class AIAgent:
|
||||
disabled_toolsets (List[str]): Disable tools from these toolsets (optional)
|
||||
save_trajectories (bool): Whether to save conversation trajectories to JSONL files (default: False)
|
||||
verbose_logging (bool): Enable verbose logging for debugging (default: False)
|
||||
enable_websocket_logging (bool): Enable real-time WebSocket logging (default: False)
|
||||
websocket_server (str): WebSocket server URL (default: ws://localhost:8000/ws)
|
||||
mock_web_tools (bool): Use mock web tools for testing (no API calls, configurable delays) (default: False)
|
||||
mock_delay (int): Delay in seconds for mock web_extract to test timeout (default: 60)
|
||||
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
|
||||
"""
|
||||
self.model = model
|
||||
self.max_iterations = max_iterations
|
||||
self.tool_delay = tool_delay
|
||||
self.save_trajectories = save_trajectories
|
||||
self.verbose_logging = verbose_logging
|
||||
self.enable_websocket_logging = enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE
|
||||
self.websocket_server = websocket_server
|
||||
self.mock_web_tools = mock_web_tools
|
||||
self.mock_delay = mock_delay
|
||||
# Note: We use global ws_pool instead of per-instance connection
|
||||
self.ephemeral_system_prompt = ephemeral_system_prompt
|
||||
|
||||
# Store toolset filtering options
|
||||
self.enabled_toolsets = enabled_toolsets
|
||||
@@ -112,10 +99,11 @@ class AIAgent:
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
datefmt='%H:%M:%S'
|
||||
)
|
||||
# Also set OpenAI client logging to debug
|
||||
logging.getLogger('openai').setLevel(logging.DEBUG)
|
||||
logging.getLogger('httpx').setLevel(logging.DEBUG)
|
||||
print("🔍 Verbose logging enabled")
|
||||
# Keep OpenAI and httpx at INFO level to avoid massive base64 logs
|
||||
# Even in verbose mode, we don't want to see full request/response bodies
|
||||
logging.getLogger('openai').setLevel(logging.INFO)
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
print("🔍 Verbose logging enabled (OpenAI/httpx request bodies suppressed)")
|
||||
else:
|
||||
# Set logging to INFO level for important messages only
|
||||
logging.basicConfig(
|
||||
@@ -174,10 +162,10 @@ class AIAgent:
|
||||
if self.save_trajectories:
|
||||
print("📝 Trajectory saving enabled")
|
||||
|
||||
# Show mock tools status
|
||||
if self.mock_web_tools:
|
||||
print(f"🧪 MOCK MODE ENABLED - Web tools will use fake data (delay: {self.mock_delay}s)")
|
||||
print(f" Perfect for testing WebSocket reconnection without API costs!")
|
||||
# Show ephemeral system prompt status
|
||||
if self.ephemeral_system_prompt:
|
||||
prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
|
||||
print(f"🔒 Ephemeral system prompt: '{prompt_preview}' (not saved to trajectories)")
|
||||
|
||||
def _format_tools_for_system_message(self) -> str:
|
||||
"""
|
||||
@@ -353,71 +341,23 @@ class AIAgent:
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to save trajectory: {e}")
|
||||
|
||||
def _init_websocket_connection(self, session_id: str):
|
||||
"""
|
||||
Initialize WebSocket connection pool if enabled.
|
||||
|
||||
Connects to logging server using persistent connection pool.
|
||||
Connection is shared across all agent runs - no per-run overhead!
|
||||
|
||||
Uses synchronous API - no event loop management needed in agent layer.
|
||||
"""
|
||||
if self.enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE and connect_sync:
|
||||
try:
|
||||
# Connect to server (idempotent - safe if already connected)
|
||||
# API layer handles all event loop management internally
|
||||
connect_sync(self.websocket_server)
|
||||
|
||||
# Send session_start event for this specific session
|
||||
send_event_sync("session_start", session_id, {
|
||||
"session_id": session_id,
|
||||
"start_time": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
print(f"📡 WebSocket logging enabled (session: {session_id[:8]}...)")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to initialize WebSocket connection: {e}")
|
||||
self.enable_websocket_logging = False
|
||||
|
||||
def run_conversation(
|
||||
self,
|
||||
user_message: str,
|
||||
system_message: str = None,
|
||||
conversation_history: List[Dict[str, Any]] = None,
|
||||
session_id: str = None
|
||||
conversation_history: List[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run a complete conversation with tool calling until completion.
|
||||
|
||||
Args:
|
||||
user_message (str): The user's message/question
|
||||
system_message (str): Custom system message (optional)
|
||||
system_message (str): Custom system message (optional, overrides ephemeral_system_prompt if provided)
|
||||
conversation_history (List[Dict]): Previous conversation messages (optional)
|
||||
session_id (str): Optional session ID (generated if not provided)
|
||||
|
||||
Returns:
|
||||
Dict: Complete conversation result with final response and message history
|
||||
"""
|
||||
if session_id is None:
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Uses synchronous API - no event loop management in agent layer
|
||||
if self.enable_websocket_logging:
|
||||
try:
|
||||
# Connect to logging server and log initial query
|
||||
# All event loop management is handled inside the API layer
|
||||
self._init_websocket_connection(session_id)
|
||||
send_event_sync("query", session_id, {
|
||||
"query": user_message,
|
||||
"model": self.model,
|
||||
"toolsets": self.enabled_toolsets
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"⚠️ WebSocket logging initialization failed: {e}")
|
||||
import traceback
|
||||
if self.verbose_logging:
|
||||
traceback.print_exc()
|
||||
|
||||
# Initialize conversation
|
||||
messages = conversation_history or []
|
||||
|
||||
@@ -429,6 +369,10 @@ class AIAgent:
|
||||
|
||||
print(f"💬 Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'")
|
||||
|
||||
# Determine which system prompt to use for API calls (ephemeral)
|
||||
# Priority: explicit system_message > ephemeral_system_prompt > None
|
||||
active_system_prompt = system_message if system_message is not None else self.ephemeral_system_prompt
|
||||
|
||||
# Main conversation loop
|
||||
api_call_count = 0
|
||||
final_response = None
|
||||
@@ -437,22 +381,6 @@ class AIAgent:
|
||||
api_call_count += 1
|
||||
print(f"\n🔄 Making API call #{api_call_count}...")
|
||||
|
||||
# ============================================================
|
||||
# WEBSOCKET LOGGING: API Call Start
|
||||
# ============================================================
|
||||
# Log that we're about to make an API call to the model
|
||||
# Captures: which call number, how many messages, whether tools available
|
||||
if self.enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE and send_event_sync:
|
||||
try:
|
||||
send_event_sync("api_call", session_id, {
|
||||
"call_number": api_call_count,
|
||||
"message_count": len(messages),
|
||||
"has_tools": bool(self.tools)
|
||||
})
|
||||
except Exception as e:
|
||||
if self.verbose_logging:
|
||||
print(f"⚠️ WebSocket logging error: {e}")
|
||||
|
||||
# Log request details if verbose
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"API Request - Model: {self.model}, Messages: {len(messages)}, Tools: {len(self.tools) if self.tools else 0}")
|
||||
@@ -464,38 +392,24 @@ class AIAgent:
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# Prepare messages for API call
|
||||
# If we have an ephemeral system prompt, prepend it to the messages
|
||||
api_messages = messages.copy()
|
||||
if active_system_prompt:
|
||||
# Insert system message at the beginning
|
||||
api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages
|
||||
|
||||
# Make API call with tools
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
messages=api_messages,
|
||||
tools=self.tools if self.tools else None,
|
||||
timeout=60.0 # Add explicit timeout
|
||||
)
|
||||
|
||||
print(f"🔧 Response: {response}")
|
||||
|
||||
api_duration = time.time() - api_start_time
|
||||
print(f"⏱️ API call completed in {api_duration:.2f}s")
|
||||
|
||||
# ============================================================
|
||||
# WEBSOCKET LOGGING: API Response
|
||||
# ============================================================
|
||||
# Log the response we got back from the AI model
|
||||
# Captures: what the model said, whether it wants tools, how long it took
|
||||
if self.enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE and send_event_sync:
|
||||
try:
|
||||
assistant_msg = response.choices[0].message
|
||||
send_event_sync("response", session_id, {
|
||||
"call_number": api_call_count,
|
||||
"content": assistant_msg.content if hasattr(assistant_msg, 'content') else None,
|
||||
"has_tool_calls": hasattr(assistant_msg, 'tool_calls') and bool(assistant_msg.tool_calls),
|
||||
"tool_call_count": len(assistant_msg.tool_calls) if hasattr(assistant_msg, 'tool_calls') and assistant_msg.tool_calls else 0,
|
||||
"duration": api_duration
|
||||
})
|
||||
except Exception as e:
|
||||
if self.verbose_logging:
|
||||
print(f"⚠️ WebSocket logging error: {e}")
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"API Response received - Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}")
|
||||
|
||||
@@ -517,12 +431,10 @@ class AIAgent:
|
||||
|
||||
# Handle assistant response
|
||||
if assistant_message.content:
|
||||
print(f"🤖 Assistant: {assistant_message.content}")
|
||||
print(f"🤖 Assistant: {assistant_message.content[:100]}{'...' if len(assistant_message.content) > 100 else ''}")
|
||||
|
||||
# Check for tool calls
|
||||
if assistant_message.tool_calls:
|
||||
|
||||
print(f"🔧 Tool calls: {assistant_message.tool_calls}")
|
||||
print(f"🔧 Processing {len(assistant_message.tool_calls)} tool call(s)...")
|
||||
|
||||
if self.verbose_logging:
|
||||
@@ -558,37 +470,10 @@ class AIAgent:
|
||||
|
||||
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())})")
|
||||
|
||||
# ============================================================
|
||||
# WEBSOCKET LOGGING: Tool Call (Before Execution)
|
||||
# ============================================================
|
||||
# Log which tool we're about to execute and with what parameters
|
||||
# This happens BEFORE tool runs, so we know what was requested
|
||||
if self.enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE and send_event_sync:
|
||||
try:
|
||||
send_event_sync("tool_call", session_id, {
|
||||
"call_number": api_call_count,
|
||||
"tool_index": i,
|
||||
"tool_name": function_name,
|
||||
"parameters": function_args, # E.g., {"query": "Python", "limit": 5}
|
||||
"tool_call_id": tool_call.id
|
||||
})
|
||||
except Exception as e:
|
||||
if self.verbose_logging:
|
||||
print(f"⚠️ WebSocket logging error: {e}")
|
||||
|
||||
tool_start_time = time.time()
|
||||
|
||||
# Execute the tool (mock or real based on configuration)
|
||||
if self.mock_web_tools and function_name in MOCK_TOOL_FUNCTIONS:
|
||||
# Use mock implementation (no API calls, configurable delay)
|
||||
mock_function = MOCK_TOOL_FUNCTIONS[function_name]
|
||||
# Inject mock_delay for web_extract if not provided
|
||||
if function_name == "web_extract" and "delay" not in function_args:
|
||||
function_args["delay"] = self.mock_delay
|
||||
function_result = mock_function(**function_args)
|
||||
else:
|
||||
# Use real tool implementation
|
||||
function_result = handle_function_call(function_name, function_args)
|
||||
# Execute the tool
|
||||
function_result = handle_function_call(function_name, function_args)
|
||||
|
||||
tool_duration = time.time() - tool_start_time
|
||||
result_preview = function_result[:200] if len(function_result) > 200 else function_result
|
||||
@@ -606,36 +491,6 @@ class AIAgent:
|
||||
|
||||
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s")
|
||||
|
||||
# ============================================================
|
||||
# WEBSOCKET LOGGING: Tool Result (After Execution)
|
||||
# ============================================================
|
||||
# Log the result we got back from the tool
|
||||
# IMPORTANT: Logs BOTH truncated preview AND full raw result
|
||||
#
|
||||
# Why both?
|
||||
# - result: Truncated to 1000 chars for quick preview in UI
|
||||
# - raw_result: FULL untruncated output for verification
|
||||
#
|
||||
# This is crucial for web tools where you want to see:
|
||||
# - What the scraper actually returned (raw_result)
|
||||
# - What the LLM processed it into (compare against raw)
|
||||
# - Verify the LLM isn't losing important information
|
||||
if self.enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE and send_event_sync:
|
||||
try:
|
||||
send_event_sync("tool_result", session_id, {
|
||||
"call_number": api_call_count,
|
||||
"tool_index": i,
|
||||
"tool_name": function_name,
|
||||
"result": function_result[:1000] if function_result else None, # Truncated preview
|
||||
"raw_result": function_result, # Full untruncated result (can be 100KB+)
|
||||
"error": None,
|
||||
"duration": tool_duration,
|
||||
"tool_call_id": tool_call.id
|
||||
})
|
||||
except Exception as e:
|
||||
if self.verbose_logging:
|
||||
print(f"⚠️ WebSocket logging error: {e}")
|
||||
|
||||
# Delay between tool calls
|
||||
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
|
||||
time.sleep(self.tool_delay)
|
||||
@@ -660,21 +515,6 @@ class AIAgent:
|
||||
error_msg = f"Error during API call #{api_call_count}: {str(e)}"
|
||||
print(f"❌ {error_msg}")
|
||||
|
||||
# ============================================================
|
||||
# WEBSOCKET LOGGING: Error Event
|
||||
# ============================================================
|
||||
# Log any errors that occur during API calls or tool execution
|
||||
# Helps track failures and debug issues
|
||||
if self.enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE and send_event_sync:
|
||||
try:
|
||||
send_event_sync("error", session_id, {
|
||||
"error_message": error_msg,
|
||||
"call_number": api_call_count
|
||||
})
|
||||
except Exception as ws_error:
|
||||
if self.verbose_logging:
|
||||
print(f"⚠️ WebSocket logging error: {ws_error}")
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.exception("Detailed error information:")
|
||||
|
||||
@@ -701,37 +541,14 @@ class AIAgent:
|
||||
# Save trajectory if enabled
|
||||
self._save_trajectory(messages, user_message, completed)
|
||||
|
||||
# ============================================================
|
||||
# WEBSOCKET LOGGING: Session Complete
|
||||
# ============================================================
|
||||
# Log final completion event for this session
|
||||
# Note: WebSocket connection stays open for future runs (persistent pool)
|
||||
# Uses synchronous API - no event loop management in agent layer
|
||||
if self.enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE and send_event_sync:
|
||||
try:
|
||||
# Log completion with summary information
|
||||
# API layer handles event loop management internally
|
||||
send_event_sync("complete", session_id, {
|
||||
"final_response": final_response, # What the agent finally answered
|
||||
"total_calls": api_call_count, # How many API calls were made
|
||||
"completed": completed # Did it finish successfully?
|
||||
})
|
||||
# Connection persists automatically - agent has no control over lifecycle
|
||||
except Exception as e:
|
||||
if self.verbose_logging:
|
||||
print(f"⚠️ WebSocket logging error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
return {
|
||||
"final_response": final_response,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": completed,
|
||||
"session_id": session_id if self.enable_websocket_logging else None
|
||||
"completed": completed
|
||||
}
|
||||
|
||||
def chat(self, message: str) -> str: # After we connect the UI we can put whatever we want here
|
||||
def chat(self, message: str) -> str:
|
||||
"""
|
||||
Simple chat interface that returns just the final response.
|
||||
|
||||
@@ -747,7 +564,7 @@ class AIAgent:
|
||||
|
||||
def main(
|
||||
query: str = None,
|
||||
model: str = "claude-sonnet-4-5-20250929",
|
||||
model: str = "claude-opus-4-20250514",
|
||||
api_key: str = None,
|
||||
base_url: str = "https://api.anthropic.com/v1/",
|
||||
max_turns: int = 10,
|
||||
@@ -755,11 +572,7 @@ def main(
|
||||
disabled_toolsets: str = None,
|
||||
list_tools: bool = False,
|
||||
save_trajectories: bool = False,
|
||||
verbose: bool = False,
|
||||
enable_websocket_logging: bool = False,
|
||||
websocket_server: str = "ws://localhost:8000/ws",
|
||||
mock_web_tools: bool = False,
|
||||
mock_delay: int = 60
|
||||
verbose: bool = False
|
||||
):
|
||||
"""
|
||||
Main function for running the agent directly.
|
||||
@@ -777,24 +590,9 @@ def main(
|
||||
list_tools (bool): Just list available tools and exit
|
||||
save_trajectories (bool): Save conversation trajectories to JSONL files. Defaults to False.
|
||||
verbose (bool): Enable verbose logging for debugging. Defaults to False.
|
||||
enable_websocket_logging (bool): Enable real-time WebSocket logging. Defaults to False.
|
||||
websocket_server (str): WebSocket server URL. Defaults to ws://localhost:8000/ws.
|
||||
mock_web_tools (bool): Use mock web tools for testing (no API calls, configurable delays). Defaults to False.
|
||||
mock_delay (int): Delay in seconds for mock web_extract (default: 60s to test timeout). Defaults to 60.
|
||||
|
||||
Toolset Examples:
|
||||
- "research": Web search, extract, crawl + vision tools
|
||||
|
||||
Mock Tools (Testing):
|
||||
Use --mock_web_tools to test WebSocket reconnection without API calls:
|
||||
- web_search: Returns fake results after 2s
|
||||
- web_extract: Returns fake content after 60s (tests timeout)
|
||||
- web_crawl: Returns fake pages after 30s
|
||||
|
||||
WebSocket Logging:
|
||||
1. Start logging server: python logging_server.py
|
||||
2. Run agent with --enable_websocket_logging flag
|
||||
3. View logs in realtime at http://localhost:8000
|
||||
"""
|
||||
print("🤖 AI Agent with Tool Calling")
|
||||
print("=" * 50)
|
||||
@@ -899,11 +697,6 @@ def main(
|
||||
print(f" - Successful conversations → trajectory_samples.jsonl")
|
||||
print(f" - Failed conversations → failed_trajectories.jsonl")
|
||||
|
||||
if enable_websocket_logging:
|
||||
print(f"📡 WebSocket logging: ENABLED")
|
||||
print(f" - Server: {websocket_server}")
|
||||
print(f" - Make sure logging server is running: python logging_server.py")
|
||||
|
||||
# Initialize agent with provided parameters
|
||||
try:
|
||||
agent = AIAgent(
|
||||
@@ -914,11 +707,7 @@ def main(
|
||||
enabled_toolsets=enabled_toolsets_list,
|
||||
disabled_toolsets=disabled_toolsets_list,
|
||||
save_trajectories=save_trajectories,
|
||||
verbose_logging=verbose,
|
||||
enable_websocket_logging=enable_websocket_logging,
|
||||
websocket_server=websocket_server,
|
||||
mock_web_tools=mock_web_tools,
|
||||
mock_delay=mock_delay
|
||||
verbose_logging=verbose
|
||||
)
|
||||
except RuntimeError as e:
|
||||
print(f"❌ Failed to initialize agent: {e}")
|
||||
@@ -932,9 +721,6 @@ def main(
|
||||
)
|
||||
else:
|
||||
user_query = query
|
||||
|
||||
# There needs to be a multi-turn conversation here
|
||||
# Hermes Agent needs to be multi-turn to be useful
|
||||
|
||||
print(f"\n📝 User Query: {user_query}")
|
||||
print("\n" + "=" * 50)
|
||||
@@ -959,12 +745,3 @@ def main(
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
|
||||
|
||||
# Order of operations:
|
||||
# First track the ways in which information flows through the agent in realtime
|
||||
# Create a FastAPI endpoint that is first able to listen for the logging through sockets
|
||||
# Create the UI through there and now you have you have a pretty UI. CHECKPOINT 1
|
||||
# Now that you have better visualization write out the chat interface and allow it to be controlled through the UI as well as the main program
|
||||
# Now decide how the information flows through the agent you may need to do some trial and error to get this part right
|
||||
# Implement multiturn conversation now and then CHECKPOINT 2 is now done with multiturn conversations
|
||||
12
run_datagen_images.sh
Normal file
12
run_datagen_images.sh
Normal file
@@ -0,0 +1,12 @@
|
||||
python batch_runner.py \
|
||||
--dataset_file="hermes-agent-imagen-data/hermes_agent_imagen_eval.jsonl" \
|
||||
--batch_size=10 \
|
||||
--run_name="imagen_eval_gpt5" \
|
||||
--distribution="image_gen" \
|
||||
--model="gpt-5" \
|
||||
--base_url="https://api.openai.com/v1" \
|
||||
--api_key="${OPENAI_API_KEY}" \
|
||||
--num_workers=4 \
|
||||
--max_turns=5 \
|
||||
--verbose \
|
||||
--ephemeral_system_prompt="When generating an image for the user view the image by using the vision_analyze tool to ensure it is what the user wanted. If it isn't feel free to retry a few times. If none are perfect, choose the best option that is the closest match, and explain its imperfections. If the image generation tool fails, try again a few times. If the vision analyze tool fails, provide the image to the user and explain it is your best effort attempt."
|
||||
@@ -1,122 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Test Script for Mock Web Tools & WebSocket Reconnection
|
||||
#
|
||||
# This script tests:
|
||||
# 1. Mock web tools (no API calls, fake data)
|
||||
# 2. WebSocket timeout/reconnection during long operations
|
||||
# 3. Complete logging capture
|
||||
#
|
||||
# Perfect for development/testing without wasting API credits!
|
||||
|
||||
set -e
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
echo "=========================================="
|
||||
echo "🧪 Mock Mode Test Script"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Check if logging server is running
|
||||
if ! curl -s http://localhost:8000/health > /dev/null 2>&1; then
|
||||
echo "⚠️ Logging server not detected!"
|
||||
echo " Starting logging server in background..."
|
||||
python api_endpoint/logging_server.py &
|
||||
SERVER_PID=$!
|
||||
echo " Server PID: $SERVER_PID"
|
||||
sleep 3
|
||||
else
|
||||
echo "✅ Logging server already running"
|
||||
SERVER_PID=""
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "📋 Test Configuration:"
|
||||
echo " - Mock web tools: ENABLED"
|
||||
echo " - Mock delay: 60 seconds (triggers WebSocket timeout)"
|
||||
echo " - WebSocket logging: ENABLED"
|
||||
echo " - Expected behavior: Connection timeout + auto-reconnect"
|
||||
echo ""
|
||||
echo "🔄 Running agent with mock mode..."
|
||||
echo " (This will take ~60 seconds to test reconnection)"
|
||||
echo ""
|
||||
|
||||
# Run agent with mock mode
|
||||
python run_agent.py \
|
||||
--enabled_toolsets web \
|
||||
--enable_websocket_logging \
|
||||
--mock_web_tools \
|
||||
--mock_delay 60 \
|
||||
--query "Find publicly traded water companies benefiting from AI data centers"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "✅ Test Complete!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Find most recent log file
|
||||
LATEST_LOG=$(ls -t api_endpoint/logs/realtime/session_*.json 2>/dev/null | head -1)
|
||||
|
||||
if [ -n "$LATEST_LOG" ]; then
|
||||
echo "📊 Log Analysis:"
|
||||
echo " File: $LATEST_LOG"
|
||||
echo ""
|
||||
|
||||
# Count events
|
||||
echo " Event Counts:"
|
||||
python3 -c "
|
||||
import json
|
||||
import sys
|
||||
|
||||
with open('$LATEST_LOG') as f:
|
||||
data = json.load(f)
|
||||
events = data.get('events', [])
|
||||
|
||||
# Count by type
|
||||
counts = {}
|
||||
for e in events:
|
||||
etype = e.get('type', 'unknown')
|
||||
counts[etype] = counts.get(etype, 0) + 1
|
||||
|
||||
for etype, count in sorted(counts.items()):
|
||||
print(f' - {etype}: {count}')
|
||||
|
||||
# Check completeness
|
||||
has_complete = any(e.get('type') == 'complete' for e in events)
|
||||
print()
|
||||
if has_complete:
|
||||
print(' ✅ Session completed successfully!')
|
||||
else:
|
||||
print(' ⚠️ Session incomplete (may have been interrupted)')
|
||||
|
||||
# Check for reconnections
|
||||
tool_results = [e for e in events if e.get('type') == 'tool_result']
|
||||
tool_calls = [e for e in events if e.get('type') == 'tool_call']
|
||||
|
||||
if len(tool_results) == len(tool_calls):
|
||||
print(' ✅ All tool calls have results (no missing events)')
|
||||
else:
|
||||
print(f' ⚠️ Tool calls: {len(tool_calls)}, Results: {len(tool_results)}')
|
||||
"
|
||||
else
|
||||
echo "⚠️ No log files found"
|
||||
fi
|
||||
|
||||
# Cleanup
|
||||
if [ -n "$SERVER_PID" ]; then
|
||||
echo ""
|
||||
echo "🛑 Stopping logging server (PID: $SERVER_PID)..."
|
||||
kill $SERVER_PID 2>/dev/null || true
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "💡 Key Observations to Look For:"
|
||||
echo " 1. '[MOCK]' prefix on tool execution messages"
|
||||
echo " 2. '🔄 Reconnecting to logging server' after long tool"
|
||||
echo " 3. '✅ Reconnected successfully!' confirmation"
|
||||
echo " 4. Complete log file with all events captured"
|
||||
echo ""
|
||||
echo "🎉 Mock mode test completed!"
|
||||
|
||||
@@ -1,242 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Parallel Execution with Persistent WebSocket Connection Pool
|
||||
|
||||
This script demonstrates that multiple agent runs can execute in parallel,
|
||||
all sharing a single WebSocket connection for logging.
|
||||
|
||||
Benefits:
|
||||
- No connection overhead (single persistent connection)
|
||||
- No timeout issues (connection stays alive)
|
||||
- True parallel execution (multiple sessions simultaneously)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from run_agent import AIAgent
|
||||
import time
|
||||
|
||||
|
||||
async def run_agent_query(query: str, agent_name: str, mock_delay: int = 10):
|
||||
"""
|
||||
Run a single agent query with logging.
|
||||
|
||||
Args:
|
||||
query: Query to send to agent
|
||||
agent_name: Name for logging purposes
|
||||
mock_delay: Delay for mock tools (seconds)
|
||||
"""
|
||||
print(f"🚀 [{agent_name}] Starting query: '{query[:40]}...'")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
agent = AIAgent(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_iterations=5,
|
||||
enabled_toolsets=["web"],
|
||||
enable_websocket_logging=True,
|
||||
websocket_server="ws://localhost:8000/ws",
|
||||
mock_web_tools=True, # Use mock tools for fast testing
|
||||
mock_delay=mock_delay
|
||||
)
|
||||
|
||||
result = await agent.run_conversation(query)
|
||||
|
||||
duration = time.time() - start_time
|
||||
print(f"✅ [{agent_name}] Completed in {duration:.1f}s - {result['api_calls']} API calls")
|
||||
|
||||
return {
|
||||
"agent": agent_name,
|
||||
"query": query,
|
||||
"success": True,
|
||||
"duration": duration,
|
||||
"api_calls": result['api_calls'],
|
||||
"session_id": result.get('session_id')
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
duration = time.time() - start_time
|
||||
print(f"❌ [{agent_name}] Failed in {duration:.1f}s: {e}")
|
||||
return {
|
||||
"agent": agent_name,
|
||||
"query": query,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"duration": duration
|
||||
}
|
||||
|
||||
|
||||
async def test_sequential():
|
||||
"""
|
||||
Test 1: Sequential execution (baseline).
|
||||
|
||||
Runs 3 queries one after another. This shows how long it takes
|
||||
without parallelization.
|
||||
"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 1: Sequential Execution (Baseline)")
|
||||
print("="*60)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
results = []
|
||||
for i in range(3):
|
||||
result = await run_agent_query(
|
||||
query=f"Find information about water companies #{i+1}",
|
||||
agent_name=f"Agent{i+1}",
|
||||
mock_delay=5 # Short delay for quick test
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
|
||||
print(f"\n📊 Sequential Results:")
|
||||
print(f" Total time: {total_time:.1f}s")
|
||||
print(f" Successful: {sum(1 for r in results if r['success'])}/3")
|
||||
print(f" Average per query: {total_time/3:.1f}s")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def test_parallel():
|
||||
"""
|
||||
Test 2: Parallel execution.
|
||||
|
||||
Runs 3 queries simultaneously using asyncio.gather().
|
||||
All queries share the same WebSocket connection for logging.
|
||||
"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 2: Parallel Execution (Shared Connection)")
|
||||
print("="*60)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Run all queries in parallel!
|
||||
results = await asyncio.gather(
|
||||
run_agent_query(
|
||||
query="Find publicly traded water utility companies",
|
||||
agent_name="Agent1",
|
||||
mock_delay=5
|
||||
),
|
||||
run_agent_query(
|
||||
query="Find energy infrastructure companies",
|
||||
agent_name="Agent2",
|
||||
mock_delay=5
|
||||
),
|
||||
run_agent_query(
|
||||
query="Find AI data center operators",
|
||||
agent_name="Agent3",
|
||||
mock_delay=5
|
||||
)
|
||||
)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
|
||||
print(f"\n📊 Parallel Results:")
|
||||
print(f" Total time: {total_time:.1f}s")
|
||||
print(f" Successful: {sum(1 for r in results if r['success'])}/3")
|
||||
print(f" Speedup: ~{(sum(r['duration'] for r in results) / total_time):.1f}x")
|
||||
print(f" Sessions logged: {[r.get('session_id', 'N/A')[:8] for r in results]}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def test_high_concurrency():
|
||||
"""
|
||||
Test 3: High concurrency (stress test).
|
||||
|
||||
Runs 10 queries simultaneously to test connection pool under load.
|
||||
"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 3: High Concurrency (10 Parallel Agents)")
|
||||
print("="*60)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
tasks = [
|
||||
run_agent_query(
|
||||
query=f"Test query #{i+1}",
|
||||
agent_name=f"Agent{i+1}",
|
||||
mock_delay=3 # Very short for stress test
|
||||
)
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
successful = sum(1 for r in results if r['success'])
|
||||
|
||||
print(f"\n📊 High Concurrency Results:")
|
||||
print(f" Total time: {total_time:.1f}s")
|
||||
print(f" Successful: {successful}/10")
|
||||
print(f" Failed: {10 - successful}/10")
|
||||
print(f" Queries per second: {10 / total_time:.2f}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all tests."""
|
||||
print("\n🧪 WebSocket Connection Pool - Parallel Execution Tests")
|
||||
print("="*60)
|
||||
print("\nPREREQUISITE: Make sure logging server is running:")
|
||||
print(" python api_endpoint/logging_server.py")
|
||||
print("\nPress Ctrl+C to stop at any time\n")
|
||||
|
||||
await asyncio.sleep(2) # Give user time to read
|
||||
|
||||
try:
|
||||
# Test 1: Sequential (baseline)
|
||||
seq_results = await test_sequential()
|
||||
|
||||
# Test 2: Parallel (main test)
|
||||
par_results = await test_parallel()
|
||||
|
||||
# Test 3: High concurrency
|
||||
stress_results = await test_high_concurrency()
|
||||
|
||||
# Summary
|
||||
print("\n" + "="*60)
|
||||
print("SUMMARY")
|
||||
print("="*60)
|
||||
print(f"\n✅ All tests completed!")
|
||||
print(f"\nKey Findings:")
|
||||
print(f" • Sequential (3 queries): {sum(r['duration'] for r in seq_results):.1f}s total")
|
||||
print(f" • Parallel (3 queries): {max(r['duration'] for r in par_results):.1f}s total")
|
||||
print(f" • Speedup: ~{sum(r['duration'] for r in seq_results) / max(r['duration'] for r in par_results):.1f}x")
|
||||
print(f" • High concurrency (10 queries): ✅ Handled successfully")
|
||||
print(f"\n💡 All queries used the same persistent WebSocket connection!")
|
||||
print(f" No connection overhead, no timeouts, true parallelization.")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n⚠️ Tests interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\n\n❌ Tests failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("\n" + "="*60)
|
||||
print("SETUP CHECK")
|
||||
print("="*60)
|
||||
|
||||
# Check if logging server is running
|
||||
import socket
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
result = sock.connect_ex(('localhost', 8000))
|
||||
sock.close()
|
||||
|
||||
if result == 0:
|
||||
print("✅ Logging server is running on port 8000")
|
||||
else:
|
||||
print("⚠️ Logging server not detected on port 8000")
|
||||
print(" Start it with: python api_endpoint/logging_server.py")
|
||||
print("\nContinuing anyway (tests will fail gracefully)...")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Could not check server status: {e}")
|
||||
|
||||
# Run tests
|
||||
asyncio.run(main())
|
||||
|
||||
12
test_run.sh
Normal file → Executable file
12
test_run.sh
Normal file → Executable file
@@ -17,15 +17,7 @@ export WEB_TOOLS_DEBUG=true
|
||||
python run_agent.py \
|
||||
--query "$PROMPT" \
|
||||
--max_turns 30 \
|
||||
--model claude-sonnet-4-20250514 \
|
||||
--model claude-sonnet-4-5-20250929 \
|
||||
--base_url https://api.anthropic.com/v1/ \
|
||||
--api_key $ANTHROPIC_API_KEY \
|
||||
--save_trajectories \
|
||||
--enabled_toolsets=web
|
||||
|
||||
# --model claude-sonnet-4-20250514 \
|
||||
#
|
||||
#Possible Toolsets:
|
||||
#web_tools
|
||||
#vision_tools
|
||||
#terminal_tools
|
||||
--save_trajectories
|
||||
264
test_ui_flow.py
264
test_ui_flow.py
@@ -1,264 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify UI flow works correctly.
|
||||
|
||||
This tests:
|
||||
1. API server is running
|
||||
2. WebSocket connection works
|
||||
3. Agent can be started via API
|
||||
4. Events are broadcast properly
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
import websocket
|
||||
import threading
|
||||
|
||||
API_URL = "http://localhost:8000"
|
||||
WS_URL = "ws://localhost:8000/ws"
|
||||
|
||||
def test_api_server():
|
||||
"""Test if API server is running."""
|
||||
print("🔍 Testing API server...")
|
||||
try:
|
||||
response = requests.get(f"{API_URL}/", timeout=5)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
print(f"✅ API server is running: {data.get('service')}")
|
||||
print(f" Active connections: {data.get('active_connections')}")
|
||||
return True
|
||||
else:
|
||||
print(f"❌ API server returned: {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ API server not accessible: {e}")
|
||||
return False
|
||||
|
||||
def test_tools_endpoint():
|
||||
"""Test if tools endpoint works."""
|
||||
print("\n🔍 Testing tools endpoint...")
|
||||
try:
|
||||
response = requests.get(f"{API_URL}/tools", timeout=5)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
toolsets = data.get("toolsets", [])
|
||||
print(f"✅ Tools endpoint works - {len(toolsets)} toolsets available")
|
||||
for ts in toolsets[:3]:
|
||||
print(f" • {ts.get('name')} ({ts.get('tool_count')} tools)")
|
||||
return True
|
||||
else:
|
||||
print(f"❌ Tools endpoint failed: {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ Tools endpoint error: {e}")
|
||||
return False
|
||||
|
||||
def test_websocket():
|
||||
"""Test WebSocket connection."""
|
||||
print("\n🔍 Testing WebSocket connection...")
|
||||
|
||||
connected = threading.Event()
|
||||
message_received = threading.Event()
|
||||
messages = []
|
||||
|
||||
def on_open(ws):
|
||||
print("✅ WebSocket connected")
|
||||
connected.set()
|
||||
|
||||
def on_message(ws, message):
|
||||
data = json.loads(message)
|
||||
messages.append(data)
|
||||
message_received.set()
|
||||
print(f"📨 Received: {data.get('event_type', 'unknown')}")
|
||||
|
||||
def on_error(ws, error):
|
||||
print(f"❌ WebSocket error: {error}")
|
||||
|
||||
def on_close(ws, close_status_code, close_msg):
|
||||
print(f"🔌 WebSocket closed: {close_status_code}")
|
||||
|
||||
ws = websocket.WebSocketApp(
|
||||
WS_URL,
|
||||
on_open=on_open,
|
||||
on_message=on_message,
|
||||
on_error=on_error,
|
||||
on_close=on_close
|
||||
)
|
||||
|
||||
# Run WebSocket in background
|
||||
ws_thread = threading.Thread(target=lambda: ws.run_forever(), daemon=True)
|
||||
ws_thread.start()
|
||||
|
||||
# Wait for connection
|
||||
if connected.wait(timeout=5):
|
||||
print("✅ WebSocket connection established")
|
||||
ws.close()
|
||||
return True
|
||||
else:
|
||||
print("❌ WebSocket connection timeout")
|
||||
ws.close()
|
||||
return False
|
||||
|
||||
def test_agent_run():
|
||||
"""Test running agent via API."""
|
||||
print("\n🔍 Testing agent run via API (mock mode)...")
|
||||
|
||||
# Start listening for events first
|
||||
events = []
|
||||
ws_connected = threading.Event()
|
||||
session_complete = threading.Event()
|
||||
|
||||
def on_message(ws, message):
|
||||
data = json.loads(message)
|
||||
events.append(data)
|
||||
event_type = data.get("event_type")
|
||||
print(f" 📨 Event: {event_type}")
|
||||
|
||||
if event_type == "complete":
|
||||
session_complete.set()
|
||||
|
||||
def on_open(ws):
|
||||
ws_connected.set()
|
||||
|
||||
# Connect WebSocket
|
||||
ws = websocket.WebSocketApp(
|
||||
WS_URL,
|
||||
on_open=on_open,
|
||||
on_message=on_message
|
||||
)
|
||||
|
||||
ws_thread = threading.Thread(target=lambda: ws.run_forever(), daemon=True)
|
||||
ws_thread.start()
|
||||
|
||||
# Wait for WebSocket connection
|
||||
if not ws_connected.wait(timeout=5):
|
||||
print("❌ WebSocket didn't connect")
|
||||
ws.close()
|
||||
return False
|
||||
|
||||
print("✅ WebSocket connected, starting agent...")
|
||||
|
||||
# Submit agent run
|
||||
payload = {
|
||||
"query": "Test query for UI flow verification",
|
||||
"model": "claude-sonnet-4-5-20250929",
|
||||
"base_url": "https://api.anthropic.com/v1/",
|
||||
"enabled_toolsets": ["web"],
|
||||
"max_turns": 5,
|
||||
"mock_web_tools": True, # Use mock mode to avoid API costs
|
||||
"mock_delay": 2, # Fast for testing
|
||||
"verbose": False
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(f"{API_URL}/agent/run", json=payload, timeout=10)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
session_id = result.get("session_id")
|
||||
print(f"✅ Agent started: {session_id[:8]}...")
|
||||
|
||||
# Wait for completion (or timeout)
|
||||
print("⏳ Waiting for agent to complete (up to 30s)...")
|
||||
if session_complete.wait(timeout=30):
|
||||
print(f"✅ Agent completed! Received {len(events)} events:")
|
||||
|
||||
# Count event types
|
||||
event_counts = {}
|
||||
for evt in events:
|
||||
evt_type = evt.get("event_type", "unknown")
|
||||
event_counts[evt_type] = event_counts.get(evt_type, 0) + 1
|
||||
|
||||
for evt_type, count in event_counts.items():
|
||||
print(f" • {evt_type}: {count}")
|
||||
|
||||
# Check we got expected events
|
||||
expected_events = ["query", "api_call", "response", "complete"]
|
||||
missing = [e for e in expected_events if e not in event_counts]
|
||||
|
||||
if missing:
|
||||
print(f"⚠️ Missing expected events: {missing}")
|
||||
else:
|
||||
print("✅ All expected event types received!")
|
||||
|
||||
ws.close()
|
||||
return True
|
||||
else:
|
||||
print(f"⚠️ Timeout waiting for completion. Got {len(events)} events so far.")
|
||||
ws.close()
|
||||
return False
|
||||
|
||||
else:
|
||||
print(f"❌ Agent start failed: {response.status_code}")
|
||||
print(f" Response: {response.text}")
|
||||
ws.close()
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Agent run error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
ws.close()
|
||||
return False
|
||||
|
||||
def main():
    """Run the full UI-flow test suite and return a process exit code (0 = success)."""
    banner = "=" * 60
    print(banner)
    print("🧪 Hermes Agent UI Flow Test")
    print(banner)
    print("\nThis will test the complete flow:")
    print(" 1. API server connectivity")
    print(" 2. Tools endpoint")
    print(" 3. WebSocket connection")
    print(" 4. Agent execution via API (mock mode)")
    print(" 5. Event streaming to UI")
    print("\n" + banner)

    # Each entry pairs a human-readable name with its pass/fail outcome;
    # the tests run in dependency order (server first, agent run last).
    results = [
        ("API Server", test_api_server()),
        ("Tools Endpoint", test_tools_endpoint()),
        ("WebSocket Connection", test_websocket()),
        ("Agent Execution + Events", test_agent_run()),
    ]

    print("\n" + banner)
    print("📊 TEST SUMMARY")
    print(banner)

    for test_name, passed in results:
        print(f"{'✅ PASS' if passed else '❌ FAIL'} - {test_name}")

    all_passed = all(passed for _, passed in results)

    print("\n" + banner)
    if all_passed:
        print("🎉 ALL TESTS PASSED!")
        print("\n✅ The UI flow is working correctly!")
        print(" You can now use the UI to:")
        print(" • Submit queries")
        print(" • View real-time events")
        print(" • See tool executions")
        print(" • Get final responses")
    else:
        print("❌ SOME TESTS FAILED")
        print("\nMake sure:")
        print(" 1. API server is running: python api_endpoint/logging_server.py")
        print(" 2. ANTHROPIC_API_KEY is set in environment")
        print(" 3. All dependencies are installed: pip install -r requirements.txt")
    print(banner)

    # Conventional exit codes: 0 on success, 1 on any failure.
    return 0 if all_passed else 1
|
||||
|
||||
if __name__ == "__main__":
    # Use `raise SystemExit` rather than the builtin exit(): exit() is
    # injected by the site module for interactive use and may be absent
    # (e.g. under `python -S` or in frozen builds), while SystemExit is
    # always available and propagates main()'s return value as the
    # process exit code exactly as exit() would.
    raise SystemExit(main())
|
||||
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
129
tests/test_batch_runner.py
Normal file
129
tests/test_batch_runner.py
Normal file
@@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for batch runner
|
||||
|
||||
This script tests the batch runner with a small sample dataset
|
||||
to verify functionality before running large batches.
|
||||
"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def create_test_dataset():
    """Write a tiny three-prompt JSONL dataset for smoke-testing the batch runner.

    Returns:
        Path: Location of the generated ``tests/test_dataset.jsonl`` file.
    """
    dataset_path = Path("tests/test_dataset.jsonl")
    dataset_path.parent.mkdir(exist_ok=True)

    sample_prompts = (
        {"prompt": "What is 2 + 2?"},
        {"prompt": "What is the capital of France?"},
        {"prompt": "Explain what Python is in one sentence."},
    )

    # JSONL format: one JSON object per line, as the batch runner expects.
    dataset_path.write_text(
        "".join(json.dumps(entry) + "\n" for entry in sample_prompts)
    )

    print(f"✅ Created test dataset: {dataset_path}")
    return dataset_path
|
||||
|
||||
|
||||
def cleanup_test_run(run_name):
    """Delete the output directory for *run_name* under ``data/`` if it exists.

    Args:
        run_name: Name of the run whose output directory should be removed.
    """
    run_dir = Path("data") / run_name
    if not run_dir.exists():
        # Nothing to clean — silently succeed so tests can call this freely.
        return
    shutil.rmtree(run_dir)
    print(f"🗑️ Cleaned up test output: {run_dir}")
|
||||
|
||||
|
||||
def verify_output(run_name):
    """Check that a batch-runner run produced the expected files under ``data/<run_name>``.

    Verifies the output directory, checkpoint file, statistics file, and at
    least one ``batch_*.jsonl`` file exist, then prints a short statistics
    summary read from ``statistics.json``.

    Args:
        run_name: Name of the run directory under ``data/`` to inspect.

    Returns:
        bool: True when all expected artifacts are present, False otherwise.
    """
    run_dir = Path("data") / run_name

    # Guard clauses: report the first missing artifact and bail out.
    if not run_dir.exists():
        print(f"❌ Output directory not found: {run_dir}")
        return False

    checkpoint_path = run_dir / "checkpoint.json"
    if not checkpoint_path.exists():
        print(f"❌ Checkpoint file not found: {checkpoint_path}")
        return False

    statistics_path = run_dir / "statistics.json"
    if not statistics_path.exists():
        print(f"❌ Statistics file not found: {statistics_path}")
        return False

    batch_paths = list(run_dir.glob("batch_*.jsonl"))
    if not batch_paths:
        print(f"❌ No batch files found in: {run_dir}")
        return False

    print(f"✅ Output verification passed:")
    print(f" - Checkpoint: {checkpoint_path}")
    print(f" - Statistics: {statistics_path}")
    print(f" - Batch files: {len(batch_paths)}")

    # Summarize the recorded run statistics.
    stats = json.loads(statistics_path.read_text())

    print(f"\n📊 Statistics Summary:")
    print(f" - Total prompts: {stats['total_prompts']}")
    print(f" - Total batches: {stats['total_batches']}")
    print(f" - Duration: {stats['duration_seconds']}s")

    if stats.get('tool_statistics'):
        print(f" - Tool calls:")
        for tool, tool_stats in stats['tool_statistics'].items():
            print(f" • {tool}: {tool_stats['count']} calls, {tool_stats['success_rate']:.1f}% success")

    return True
|
||||
|
||||
|
||||
def main():
    """Prepare the sample dataset and print manual batch-runner instructions."""
    print("🧪 Batch Runner Test")
    print("=" * 60)

    run_name = "test_run"

    # Start from a clean slate, then generate the three sample prompts.
    cleanup_test_run(run_name)
    test_file = create_test_dataset()

    instructions = [
        "\n📝 To run the test manually:",
        " python batch_runner.py \\",
        f" --dataset_file={test_file} \\",
        " --batch_size=2 \\",
        f" --run_name={run_name} \\",
        " --distribution=minimal \\",
        " --num_workers=2",
        "\n💡 Or test with different distributions:",
        " python batch_runner.py --list_distributions",
        "\n🔍 After running, you can verify output with:",
        " python tests/test_batch_runner.py --verify",
    ]
    for line in instructions:
        print(line)

    # Deliberately do NOT invoke the batch runner here: that would make real
    # API calls. Users run it manually with their API keys configured.
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    # "--verify" checks artifacts of a previously completed run;
    # the default path prepares a fresh test dataset instead.
    if "--verify" in sys.argv:
        verify_output("test_run")
    else:
        main()
|
||||
|
||||
@@ -23,8 +23,8 @@ import argparse
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
|
||||
# Import the web tools to test
|
||||
from web_tools import (
|
||||
# Import the web tools to test (updated path after moving tools/)
|
||||
from tools.web_tools import (
|
||||
web_search_tool,
|
||||
web_extract_tool,
|
||||
web_crawl_tool,
|
||||
67
tools/__init__.py
Normal file
67
tools/__init__.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tools Package
|
||||
|
||||
This package contains all the specific tool implementations for the Hermes Agent.
|
||||
Each module provides specialized functionality for different capabilities:
|
||||
|
||||
- web_tools: Web search, content extraction, and crawling
|
||||
- terminal_tool: Command execution on virtual machines
|
||||
- vision_tools: Image analysis and understanding
|
||||
- mixture_of_agents_tool: Multi-model collaborative reasoning
|
||||
- image_generation_tool: Text-to-image generation with upscaling
|
||||
|
||||
The tools are imported into model_tools.py which provides a unified interface
|
||||
for the AI agent to access all capabilities.
|
||||
"""
|
||||
|
||||
# Export all tools for easy importing
|
||||
from .web_tools import (
|
||||
web_search_tool,
|
||||
web_extract_tool,
|
||||
web_crawl_tool,
|
||||
check_firecrawl_api_key
|
||||
)
|
||||
|
||||
from .terminal_tool import (
|
||||
terminal_tool,
|
||||
check_hecate_requirements,
|
||||
TERMINAL_TOOL_DESCRIPTION
|
||||
)
|
||||
|
||||
from .vision_tools import (
|
||||
vision_analyze_tool,
|
||||
check_vision_requirements
|
||||
)
|
||||
|
||||
from .mixture_of_agents_tool import (
|
||||
mixture_of_agents_tool,
|
||||
check_moa_requirements
|
||||
)
|
||||
|
||||
from .image_generation_tool import (
|
||||
image_generate_tool,
|
||||
check_image_generation_requirements
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Web tools
|
||||
'web_search_tool',
|
||||
'web_extract_tool',
|
||||
'web_crawl_tool',
|
||||
'check_firecrawl_api_key',
|
||||
# Terminal tools
|
||||
'terminal_tool',
|
||||
'check_hecate_requirements',
|
||||
'TERMINAL_TOOL_DESCRIPTION',
|
||||
# Vision tools
|
||||
'vision_analyze_tool',
|
||||
'check_vision_requirements',
|
||||
# MoA tools
|
||||
'mixture_of_agents_tool',
|
||||
'check_moa_requirements',
|
||||
# Image generation tools
|
||||
'image_generate_tool',
|
||||
'check_image_generation_requirements',
|
||||
]
|
||||
|
||||
@@ -319,9 +319,6 @@ async def image_generate_tool(
|
||||
if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0:
|
||||
raise ValueError("Prompt is required and must be a non-empty string")
|
||||
|
||||
if len(prompt) > 1000:
|
||||
raise ValueError("Prompt must be 1000 characters or less")
|
||||
|
||||
# Check API key availability
|
||||
if not os.getenv("FAL_KEY"):
|
||||
raise ValueError("FAL_KEY environment variable not set")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,26 +4,27 @@ Terminal Tool Module
|
||||
|
||||
This module provides a single terminal tool using Hecate's VM infrastructure.
|
||||
It wraps Hecate's functionality to provide a simple interface for executing commands
|
||||
on Morph VMs with automatic lifecycle management.
|
||||
on Morph VMs with automatic lifecycle management. VMs live for 5 minutes after last use.
|
||||
Timer resets with each use.
|
||||
|
||||
Available tool:
|
||||
- terminal_tool: Execute commands with optional interactive session support
|
||||
|
||||
Usage:
|
||||
from terminal_tool import terminal_tool
|
||||
|
||||
|
||||
# Execute a single command
|
||||
result = terminal_tool("ls -la")
|
||||
|
||||
|
||||
# Execute in an interactive session
|
||||
result = terminal_tool("python", input_keys="print('hello')\\nexit()\\n")
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import threading
|
||||
from typing import Optional, Dict, Any
|
||||
# from hecate import run_tool_with_lifecycle_management
|
||||
# from morphcloud._llm import ToolCall
|
||||
|
||||
# Detailed description for the terminal tool based on Hermes Terminal system prompt
|
||||
TERMINAL_TOOL_DESCRIPTION = """Execute commands on a secure, persistent Linux VM environment with full interactive application support.
|
||||
@@ -72,6 +73,12 @@ When commands enter interactive mode (vim, nano, less, git prompts, package mana
|
||||
- Test components incrementally with mock inputs
|
||||
- Install whatever tools needed - full system access provided"""
|
||||
|
||||
# Global state for VM lifecycle management
|
||||
# These persist across tool calls to enable session continuity
|
||||
_active_instance = None
|
||||
_active_context = None
|
||||
_instance_lock = threading.Lock()
|
||||
|
||||
def terminal_tool(
|
||||
command: Optional[str] = None,
|
||||
input_keys: Optional[str] = None,
|
||||
@@ -113,10 +120,60 @@ def terminal_tool(
|
||||
# Run a background task
|
||||
>>> result = terminal_tool(command="sleep 60", background=True)
|
||||
"""
|
||||
global _active_instance, _active_context
|
||||
|
||||
try:
|
||||
# Import required modules lazily so this module can be imported
|
||||
# even when hecate is not installed
|
||||
try:
|
||||
from morphcloud._llm import ToolCall
|
||||
from morphcloud.api import MorphCloudClient
|
||||
from hecate.cli import run_tool, ExecutionContext
|
||||
from rich.console import Console
|
||||
import io
|
||||
except ImportError as import_error:
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"screen": "",
|
||||
"session_id": None,
|
||||
"exit_code": -1,
|
||||
"error": f"Terminal tool is disabled due to import error: {import_error}",
|
||||
"status": "disabled"
|
||||
})
|
||||
|
||||
# Get configuration from environment
|
||||
vm_lifetime_seconds = int(os.getenv("HECATE_VM_LIFETIME_SECONDS", "300"))
|
||||
snapshot_id = os.getenv("HECATE_DEFAULT_SNAPSHOT_ID", "python-2025-10-31")
|
||||
|
||||
# Check API key
|
||||
morph_api_key = os.getenv("MORPH_API_KEY")
|
||||
if not morph_api_key:
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"screen": "",
|
||||
"session_id": None,
|
||||
"exit_code": -1,
|
||||
"error": "MORPH_API_KEY environment variable not set",
|
||||
"status": "disabled"
|
||||
})
|
||||
|
||||
# Get or create VM instance and execution context
|
||||
# This is critical for interactive session support - the context must persist!
|
||||
with _instance_lock:
|
||||
if _active_instance is None:
|
||||
morph_client = MorphCloudClient(api_key=morph_api_key)
|
||||
_active_instance = morph_client.instances.start(snapshot_id=snapshot_id)
|
||||
|
||||
# Get or create persistent execution context
|
||||
if _active_context is None:
|
||||
_active_context = ExecutionContext()
|
||||
|
||||
instance = _active_instance
|
||||
ctx = _active_context
|
||||
|
||||
# Build tool input based on provided parameters
|
||||
tool_input = {}
|
||||
|
||||
|
||||
if command:
|
||||
tool_input["command"] = command
|
||||
if input_keys:
|
||||
@@ -130,30 +187,40 @@ def terminal_tool(
|
||||
if timeout is not None:
|
||||
tool_input["timeout"] = timeout
|
||||
|
||||
# THIS IS BROKEN FOR NOW ~!!!!!!!
|
||||
|
||||
# tool_call = ToolCall(
|
||||
# name="run_command",
|
||||
# input=tool_input
|
||||
# )
|
||||
|
||||
# # Execute with lifecycle management
|
||||
# result = run_tool_with_lifecycle_management(tool_call)
|
||||
tool_call = ToolCall(
|
||||
name="run_command",
|
||||
input=tool_input
|
||||
)
|
||||
|
||||
# Create a console for output (redirect to string buffer to avoid printing)
|
||||
console_output = io.StringIO()
|
||||
console = Console(file=console_output, force_terminal=False, legacy_windows=False)
|
||||
|
||||
# Generate unique tool block ID
|
||||
tool_block_id = f"tool_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Execute the tool with hecate
|
||||
result = run_tool(
|
||||
tool_call=tool_call,
|
||||
instance=instance,
|
||||
console=console,
|
||||
tool_block_id=tool_block_id,
|
||||
ctx=ctx
|
||||
)
|
||||
|
||||
# Format the result with all possible fields
|
||||
# Map hecate's "stdout" to "output" for compatibility
|
||||
formatted_result = {
|
||||
"output": result.get("stdout", result.get("output", "")),
|
||||
"screen": result.get("screen", ""),
|
||||
"session_id": result.get("session_id"),
|
||||
"exit_code": result.get("returncode", result.get("exit_code", -1)),
|
||||
"error": result.get("error"),
|
||||
"status": "active" if result.get("session_id") else "ended"
|
||||
}
|
||||
|
||||
return json.dumps(formatted_result)
|
||||
|
||||
|
||||
# # Format the result with all possible fields
|
||||
# # Map hecate's "stdout" to "output" for compatibility
|
||||
# formatted_result = {
|
||||
# "output": result.get("stdout", result.get("output", "")),
|
||||
# "screen": result.get("screen", ""),
|
||||
# "session_id": result.get("session_id"),
|
||||
# "exit_code": result.get("returncode", result.get("exit_code", -1)),
|
||||
# "error": result.get("error"),
|
||||
# "status": "active" if result.get("session_id") else "ended"
|
||||
# }
|
||||
|
||||
return json.dumps({})
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
@@ -186,12 +253,16 @@ def check_hecate_requirements() -> bool:
|
||||
print(f"Warning: Missing optional environment variables: {', '.join(missing_optional)}")
|
||||
print(" (Some Hecate features may be limited)")
|
||||
|
||||
# Check if Hecate is importable
|
||||
# Check if Hecate and required modules are importable
|
||||
try:
|
||||
import hecate
|
||||
from morphcloud._llm import ToolCall
|
||||
from morphcloud.api import MorphCloudClient
|
||||
from hecate.cli import run_tool, ExecutionContext
|
||||
from rich.console import Console
|
||||
return True
|
||||
except ImportError:
|
||||
print("Hecate is not installed. Please install it with: pip install hecate")
|
||||
except Exception as e:
|
||||
print(f"Hecate not available: {e}")
|
||||
print(f"Make sure hecate is installed and MORPH_API_KEY is set.")
|
||||
return False
|
||||
|
||||
# Module-level initialization check
|
||||
@@ -234,4 +305,4 @@ if __name__ == "__main__":
|
||||
print(f" MORPH_API_KEY: {'Set' if os.getenv('MORPH_API_KEY') else 'Not set'}")
|
||||
print(f" OPENAI_API_KEY: {'Set' if os.getenv('OPENAI_API_KEY') else 'Not set (optional)'}")
|
||||
print(f" HECATE_VM_LIFETIME_SECONDS: {os.getenv('HECATE_VM_LIFETIME_SECONDS', '300')} (default: 300)")
|
||||
print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_p5294qxt')} (default: snapshot_p5294qxt)")
|
||||
print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_p5294qxt')} (default: snapshot_p5294qxt)")
|
||||
@@ -1,349 +1,471 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Vision Tools Module
|
||||
|
||||
This module provides vision analysis tools that work with image URLs.
|
||||
Uses Gemini Flash via Nous Research API for intelligent image understanding.
|
||||
|
||||
Available tools:
|
||||
- vision_analyze_tool: Analyze images from URLs with custom prompts
|
||||
|
||||
Features:
|
||||
- Comprehensive image description
|
||||
- Context-aware analysis based on user queries
|
||||
- Proper error handling and validation
|
||||
- Debug logging support
|
||||
|
||||
Usage:
|
||||
from vision_tools import vision_analyze_tool
|
||||
import asyncio
|
||||
|
||||
# Analyze an image
|
||||
result = await vision_analyze_tool(
|
||||
image_url="https://example.com/image.jpg",
|
||||
user_prompt="What architectural style is this building?"
|
||||
)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
import uuid
|
||||
from dotenv import load_dotenv
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Initialize Nous Research API client for vision processing
|
||||
nous_client = AsyncOpenAI(
|
||||
api_key=os.getenv("NOUS_API_KEY"),
|
||||
base_url="https://inference-api.nousresearch.com/v1"
|
||||
)
|
||||
|
||||
# Configuration for vision processing
|
||||
DEFAULT_VISION_MODEL = "gemini-2.5-flash"
|
||||
|
||||
# Debug mode configuration
|
||||
DEBUG_MODE = os.getenv("VISION_TOOLS_DEBUG", "false").lower() == "true"
|
||||
DEBUG_SESSION_ID = str(uuid.uuid4())
|
||||
DEBUG_LOG_PATH = Path("./logs")
|
||||
DEBUG_DATA = {
|
||||
"session_id": DEBUG_SESSION_ID,
|
||||
"start_time": datetime.datetime.now().isoformat(),
|
||||
"debug_enabled": DEBUG_MODE,
|
||||
"tool_calls": []
|
||||
} if DEBUG_MODE else None
|
||||
|
||||
# Create logs directory if debug mode is enabled
|
||||
if DEBUG_MODE:
|
||||
DEBUG_LOG_PATH.mkdir(exist_ok=True)
|
||||
print(f"🐛 Vision debug mode enabled - Session ID: {DEBUG_SESSION_ID}")
|
||||
|
||||
|
||||
def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None:
    """
    Append one tool-call record to the in-memory debug session data.

    No-op unless debug mode is active.

    Args:
        tool_name (str): Name of the tool being called
        call_data (Dict[str, Any]): Data about the call including parameters and results
    """
    if not (DEBUG_MODE and DEBUG_DATA):
        return

    # Keys from call_data intentionally win over the generated
    # timestamp/tool_name, mirroring dict-unpacking override order.
    entry = {
        "timestamp": datetime.datetime.now().isoformat(),
        "tool_name": tool_name,
    }
    entry.update(call_data)
    DEBUG_DATA["tool_calls"].append(entry)
|
||||
|
||||
|
||||
def _save_debug_log() -> None:
    """
    Save the current debug data to a JSON file in the logs directory.

    No-op unless debug mode is active; write failures are printed, not raised.
    """
    if not (DEBUG_MODE and DEBUG_DATA):
        return

    try:
        log_file = DEBUG_LOG_PATH / f"vision_tools_debug_{DEBUG_SESSION_ID}.json"

        # Stamp the session with its end time and final call count.
        DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat()
        DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"])

        with open(log_file, 'w', encoding='utf-8') as fh:
            json.dump(DEBUG_DATA, fh, indent=2, ensure_ascii=False)

        print(f"🐛 Vision debug log saved: {log_file}")

    except Exception as e:
        print(f"❌ Error saving vision debug log: {str(e)}")
|
||||
|
||||
|
||||
def _validate_image_url(url: str) -> bool:
|
||||
"""
|
||||
Basic validation of image URL format.
|
||||
|
||||
Args:
|
||||
url (str): The URL to validate
|
||||
|
||||
Returns:
|
||||
bool: True if URL appears to be valid, False otherwise
|
||||
"""
|
||||
if not url or not isinstance(url, str):
|
||||
return False
|
||||
|
||||
# Check if it's a valid URL format
|
||||
if not (url.startswith('http://') or url.startswith('https://')):
|
||||
return False
|
||||
|
||||
# Check for common image extensions (optional, as URLs may not have extensions)
|
||||
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.svg']
|
||||
|
||||
return True # Allow all HTTP/HTTPS URLs for flexibility
|
||||
|
||||
|
||||
async def vision_analyze_tool(
    image_url: str,
    user_prompt: str,
    model: str = DEFAULT_VISION_MODEL
) -> str:
    """
    Analyze an image from a URL using vision AI.

    This tool processes images using Gemini Flash via Nous Research API.
    The user_prompt parameter is expected to be pre-formatted by the calling
    function (typically model_tools.py) to include both full description
    requests and specific questions.

    Args:
        image_url (str): The URL of the image to analyze
        user_prompt (str): The pre-formatted prompt for the vision model
        model (str): The vision model to use (default: gemini-2.5-flash)

    Returns:
        str: JSON string containing the analysis results with the following structure:
            {
                "success": bool,
                "analysis": str (defaults to error message if None)
            }

    Raises:
        Exception: If analysis fails or API key is not set
    """
    # Record parameters up front so the debug log captures the call even when
    # it fails partway through; mutated in place before logging below.
    debug_call_data = {
        "parameters": {
            "image_url": image_url,
            "user_prompt": user_prompt,
            "model": model
        },
        "error": None,
        "success": False,
        "analysis_length": 0,
        "model_used": model
    }

    try:
        # Truncate long values in console output for readability only.
        print(f"🔍 Analyzing image from URL: {image_url[:60]}{'...' if len(image_url) > 60 else ''}")
        print(f"📝 User prompt: {user_prompt[:100]}{'...' if len(user_prompt) > 100 else ''}")

        # Validate image URL (raises so the shared except-path handles it).
        if not _validate_image_url(image_url):
            raise ValueError("Invalid image URL format. Must start with http:// or https://")

        # Check API key availability before attempting the network call.
        if not os.getenv("NOUS_API_KEY"):
            raise ValueError("NOUS_API_KEY environment variable not set")

        # Use the prompt as provided (model_tools.py now handles full description formatting)
        comprehensive_prompt = user_prompt

        # Prepare the message in the OpenAI-compatible multimodal format:
        # one text part plus one image_url part.
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": comprehensive_prompt
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": image_url
                        }
                    }
                ]
            }
        ]

        print(f"🧠 Processing image with {model}...")

        # Call the vision API (async; nous_client is the module-level
        # AsyncOpenAI client pointed at the Nous Research endpoint).
        response = await nous_client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.1,  # Low temperature for consistent analysis
            max_tokens=2000  # Generous limit for detailed analysis
        )

        # Extract the analysis text from the first (only) completion choice.
        analysis = response.choices[0].message.content.strip()
        analysis_length = len(analysis)

        print(f"✅ Image analysis completed ({analysis_length} characters)")

        # Prepare successful response; fall back to a fixed message when the
        # model returned an empty string.
        result = {
            "success": True,
            "analysis": analysis or "There was a problem with the request and the image could not be analyzed."
        }

        debug_call_data["success"] = True
        debug_call_data["analysis_length"] = analysis_length

        # Log debug information (both are no-ops unless VISION_TOOLS_DEBUG
        # was enabled at import time).
        _log_debug_call("vision_analyze_tool", debug_call_data)
        _save_debug_log()

        return json.dumps(result, indent=2)

    except Exception as e:
        # Single error path for validation failures, missing key, and API
        # errors: report a uniform payload rather than propagating.
        error_msg = f"Error analyzing image: {str(e)}"
        print(f"❌ {error_msg}")

        # Prepare error response (same shape as the success payload so
        # callers can parse it uniformly).
        result = {
            "success": False,
            "analysis": "There was a problem with the request and the image could not be analyzed."
        }

        debug_call_data["error"] = error_msg
        _log_debug_call("vision_analyze_tool", debug_call_data)
        _save_debug_log()

        return json.dumps(result, indent=2)
|
||||
|
||||
|
||||
def check_nous_api_key() -> bool:
    """
    Check if the Nous Research API key is available in environment variables.

    Returns:
        bool: True if API key is set, False otherwise
    """
    # An empty string counts as unset, matching bool() truthiness.
    key = os.getenv("NOUS_API_KEY")
    return key is not None and key != ""
|
||||
|
||||
|
||||
def check_vision_requirements() -> bool:
    """
    Check if all requirements for vision tools are met.

    Currently the only prerequisite is a configured Nous Research API key.

    Returns:
        bool: True if requirements are met, False otherwise
    """
    # Delegate to the key check; extend here if more prerequisites appear.
    requirements_met = check_nous_api_key()
    return requirements_met
|
||||
|
||||
|
||||
def get_debug_session_info() -> Dict[str, Any]:
    """
    Get information about the current debug session.

    Returns:
        Dict[str, Any]: Dictionary containing debug session information
    """
    if DEBUG_MODE and DEBUG_DATA:
        return {
            "enabled": True,
            "session_id": DEBUG_SESSION_ID,
            "log_path": str(DEBUG_LOG_PATH / f"vision_tools_debug_{DEBUG_SESSION_ID}.json"),
            "total_calls": len(DEBUG_DATA["tool_calls"]),
        }

    # Debug mode off: report a disabled, empty session.
    return {
        "enabled": False,
        "session_id": None,
        "log_path": None,
        "total_calls": 0,
    }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    """
    Simple test/demo when run directly
    """
    # CLI self-check: reports configuration status and prints usage examples;
    # performs no network calls.
    print("👁️ Vision Tools Module")
    print("=" * 40)

    # Check if API key is available
    api_available = check_nous_api_key()

    if not api_available:
        # Without the key the module is unusable, so exit non-zero.
        print("❌ NOUS_API_KEY environment variable not set")
        print("Please set your API key: export NOUS_API_KEY='your-key-here'")
        print("Get API key at: https://inference-api.nousresearch.com/")
        exit(1)
    else:
        print("✅ Nous Research API key found")

    print("🛠️ Vision tools ready for use!")
    print(f"🧠 Using model: {DEFAULT_VISION_MODEL}")

    # Show debug mode status (DEBUG_MODE is fixed at import time from the
    # VISION_TOOLS_DEBUG environment variable)
    if DEBUG_MODE:
        print(f"🐛 Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}")
        print(f" Debug logs will be saved to: ./logs/vision_tools_debug_{DEBUG_SESSION_ID}.json")
    else:
        print("🐛 Debug mode disabled (set VISION_TOOLS_DEBUG=true to enable)")

    # Minimal async usage example for copy/paste.
    print("\nBasic usage:")
    print(" from vision_tools import vision_analyze_tool")
    print(" import asyncio")
    print("")
    print(" async def main():")
    print(" result = await vision_analyze_tool(")
    print(" image_url='https://example.com/image.jpg',")
    print(" user_prompt='What do you see in this image?'")
    print(" )")
    print(" print(result)")
    print(" asyncio.run(main())")

    print("\nExample prompts:")
    print(" - 'What architectural style is this building?'")
    print(" - 'Describe the emotions and mood in this image'")
    print(" - 'What text can you read in this image?'")
    print(" - 'Identify any safety hazards visible'")
    print(" - 'What products or brands are shown?'")

    print("\nDebug mode:")
    print(" # Enable debug logging")
    print(" export VISION_TOOLS_DEBUG=true")
    print(" # Debug logs capture all vision analysis calls and results")
    print(" # Logs saved to: ./logs/vision_tools_debug_UUID.json")
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Vision Tools Module
|
||||
|
||||
This module provides vision analysis tools that work with image URLs.
|
||||
Uses Gemini Flash via Nous Research API for intelligent image understanding.
|
||||
|
||||
Available tools:
|
||||
- vision_analyze_tool: Analyze images from URLs with custom prompts
|
||||
|
||||
Features:
|
||||
- Downloads images from URLs and converts to base64 for API compatibility
|
||||
- Comprehensive image description
|
||||
- Context-aware analysis based on user queries
|
||||
- Automatic temporary file cleanup
|
||||
- Proper error handling and validation
|
||||
- Debug logging support
|
||||
|
||||
Usage:
|
||||
from vision_tools import vision_analyze_tool
|
||||
import asyncio
|
||||
|
||||
# Analyze an image
|
||||
result = await vision_analyze_tool(
|
||||
image_url="https://example.com/image.jpg",
|
||||
user_prompt="What architectural style is this building?"
|
||||
)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
import uuid
|
||||
import datetime
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from openai import AsyncOpenAI
|
||||
import httpx # Use httpx for async HTTP requests
|
||||
|
||||
# Initialize Nous Research API client for vision processing
|
||||
nous_client = AsyncOpenAI(
|
||||
api_key=os.getenv("NOUS_API_KEY"),
|
||||
base_url="https://inference-api.nousresearch.com/v1"
|
||||
)
|
||||
|
||||
# Configuration for vision processing
|
||||
DEFAULT_VISION_MODEL = "gemini-2.5-flash"
|
||||
|
||||
# Debug mode configuration
|
||||
DEBUG_MODE = os.getenv("VISION_TOOLS_DEBUG", "false").lower() == "true"
|
||||
DEBUG_SESSION_ID = str(uuid.uuid4())
|
||||
DEBUG_LOG_PATH = Path("./logs")
|
||||
DEBUG_DATA = {
|
||||
"session_id": DEBUG_SESSION_ID,
|
||||
"start_time": datetime.datetime.now().isoformat(),
|
||||
"debug_enabled": DEBUG_MODE,
|
||||
"tool_calls": []
|
||||
} if DEBUG_MODE else None
|
||||
|
||||
# Create logs directory if debug mode is enabled
|
||||
if DEBUG_MODE:
|
||||
DEBUG_LOG_PATH.mkdir(exist_ok=True)
|
||||
print(f"🐛 Vision debug mode enabled - Session ID: {DEBUG_SESSION_ID}")
|
||||
|
||||
|
||||
def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None:
    """
    Record a single tool invocation in the session debug buffer.

    Does nothing when debug mode is disabled.

    Args:
        tool_name (str): Name of the tool being called
        call_data (Dict[str, Any]): Data about the call including parameters and results
    """
    if DEBUG_MODE and DEBUG_DATA:
        # call_data keys intentionally override the generated
        # timestamp/tool_name, matching dict-unpacking override order.
        entry = {
            "timestamp": datetime.datetime.now().isoformat(),
            "tool_name": tool_name,
        }
        entry.update(call_data)
        DEBUG_DATA["tool_calls"].append(entry)
|
||||
|
||||
|
||||
def _save_debug_log() -> None:
    """
    Persist the debug session buffer to ``./logs`` as a JSON file.

    Silently returns when debug mode is off; save errors are printed,
    never raised.
    """
    if not DEBUG_MODE or not DEBUG_DATA:
        return

    try:
        # Finalize session metadata before writing.
        DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat()
        DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"])

        target = DEBUG_LOG_PATH / f"vision_tools_debug_{DEBUG_SESSION_ID}.json"
        with open(target, 'w', encoding='utf-8') as handle:
            json.dump(DEBUG_DATA, handle, indent=2, ensure_ascii=False)

        print(f"🐛 Vision debug log saved: {target}")

    except Exception as e:
        print(f"❌ Error saving vision debug log: {str(e)}")
|
||||
|
||||
|
||||
def _validate_image_url(url: str) -> bool:
|
||||
"""
|
||||
Basic validation of image URL format.
|
||||
|
||||
Args:
|
||||
url (str): The URL to validate
|
||||
|
||||
Returns:
|
||||
bool: True if URL appears to be valid, False otherwise
|
||||
"""
|
||||
if not url or not isinstance(url, str):
|
||||
return False
|
||||
|
||||
# Check if it's a valid URL format
|
||||
if not (url.startswith('http://') or url.startswith('https://')):
|
||||
return False
|
||||
|
||||
# Check for common image extensions (optional, as URLs may not have extensions)
|
||||
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.svg']
|
||||
|
||||
return True # Allow all HTTP/HTTPS URLs for flexibility
|
||||
|
||||
|
||||
async def _download_image(image_url: str, destination: Path) -> Path:
    """
    Fetch an image over HTTP(S) and write it to a local file (async).

    Args:
        image_url (str): The URL of the image to download
        destination (Path): The path where the image should be saved

    Returns:
        Path: The path to the downloaded image

    Raises:
        Exception: If download fails or response is invalid
    """
    # Make sure the target directory exists before writing.
    destination.parent.mkdir(parents=True, exist_ok=True)

    # Fetch the bytes with an identifying User-Agent and a 30s timeout;
    # raise_for_status surfaces HTTP errors as exceptions.
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.get(
            image_url,
            headers={"User-Agent": "hermes-agent-vision/1.0"},
        )
        response.raise_for_status()
        destination.write_bytes(response.content)

    return destination
|
||||
|
||||
|
||||
def _determine_mime_type(image_path: Path) -> str:
|
||||
"""
|
||||
Determine the MIME type of an image based on its file extension.
|
||||
|
||||
Args:
|
||||
image_path (Path): Path to the image file
|
||||
|
||||
Returns:
|
||||
str: The MIME type (defaults to image/jpeg if unknown)
|
||||
"""
|
||||
extension = image_path.suffix.lower()
|
||||
mime_types = {
|
||||
'.jpg': 'image/jpeg',
|
||||
'.jpeg': 'image/jpeg',
|
||||
'.png': 'image/png',
|
||||
'.gif': 'image/gif',
|
||||
'.bmp': 'image/bmp',
|
||||
'.webp': 'image/webp',
|
||||
'.svg': 'image/svg+xml'
|
||||
}
|
||||
return mime_types.get(extension, 'image/jpeg')
|
||||
|
||||
|
||||
def _image_to_base64_data_url(image_path: Path, mime_type: Optional[str] = None) -> str:
|
||||
"""
|
||||
Convert an image file to a base64-encoded data URL.
|
||||
|
||||
Args:
|
||||
image_path (Path): Path to the image file
|
||||
mime_type (Optional[str]): MIME type of the image (auto-detected if None)
|
||||
|
||||
Returns:
|
||||
str: Base64-encoded data URL (e.g., "data:image/jpeg;base64,...")
|
||||
"""
|
||||
# Read the image as bytes
|
||||
data = image_path.read_bytes()
|
||||
|
||||
# Encode to base64
|
||||
encoded = base64.b64encode(data).decode("ascii")
|
||||
|
||||
# Determine MIME type
|
||||
mime = mime_type or _determine_mime_type(image_path)
|
||||
|
||||
# Create data URL
|
||||
data_url = f"data:{mime};base64,{encoded}"
|
||||
|
||||
return data_url
|
||||
|
||||
|
||||
async def vision_analyze_tool(
    image_url: str,
    user_prompt: str,
    model: str = DEFAULT_VISION_MODEL
) -> str:
    """
    Analyze an image from a URL using vision AI.

    This tool downloads images from URLs, converts them to base64, and processes
    them using Gemini Flash via Nous Research API. The image is downloaded to a
    temporary location and automatically cleaned up after processing.

    The user_prompt parameter is expected to be pre-formatted by the calling
    function (typically model_tools.py) to include both full description
    requests and specific questions.

    Args:
        image_url (str): The URL of the image to analyze (must be http:// or https://)
        user_prompt (str): The pre-formatted prompt for the vision model
        model (str): The vision model to use (default: gemini-2.5-flash)

    Returns:
        str: JSON string containing the analysis results with the following structure:
        {
            "success": bool,
            "analysis": str (defaults to error message if None)
        }

    Raises:
        Exception: If download fails, analysis fails, or API key is not set

    Note:
        - Temporary images are stored in ./temp_vision_images/
        - Images are automatically deleted after processing
        - Supports common image formats (JPEG, PNG, GIF, WebP, etc.)
    """
    # Per-call debug record; fields are filled in as the call progresses and
    # flushed via _log_debug_call/_save_debug_log on both success and failure.
    debug_call_data = {
        "parameters": {
            "image_url": image_url,
            # NOTE: the conditional binds after "+", so long prompts are
            # truncated to 200 chars plus "..." in the debug record.
            "user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt,
            "model": model
        },
        "error": None,
        "success": False,
        "analysis_length": 0,
        "model_used": model,
        "image_size_bytes": 0
    }

    # Tracked so the finally-block can clean up even on early failure.
    temp_image_path = None

    try:
        print(f"🔍 Analyzing image from URL: {image_url[:60]}{'...' if len(image_url) > 60 else ''}", flush=True)
        print(f"📝 User prompt: {user_prompt[:100]}{'...' if len(user_prompt) > 100 else ''}", flush=True)

        # Validate image URL
        if not _validate_image_url(image_url):
            raise ValueError("Invalid image URL format. Must start with http:// or https://")

        # Check API key availability
        if not os.getenv("NOUS_API_KEY"):
            raise ValueError("NOUS_API_KEY environment variable not set")

        # Download the image to a temporary location
        print(f"⬇️ Downloading image from URL...", flush=True)
        temp_dir = Path("./temp_vision_images")
        # .jpg suffix is used regardless of real format; MIME is re-detected later.
        temp_image_path = temp_dir / f"temp_image_{uuid.uuid4()}.jpg"

        await _download_image(image_url, temp_image_path)

        # Get image file size for logging
        image_size_bytes = temp_image_path.stat().st_size
        image_size_kb = image_size_bytes / 1024
        print(f"✅ Image downloaded successfully ({image_size_kb:.1f} KB)", flush=True)

        # Convert image to base64 data URL
        print(f"🔄 Converting image to base64...", flush=True)
        image_data_url = _image_to_base64_data_url(temp_image_path)
        # Calculate size in KB for better readability
        data_size_kb = len(image_data_url) / 1024
        print(f"✅ Image converted to base64 ({data_size_kb:.1f} KB)", flush=True)

        debug_call_data["image_size_bytes"] = image_size_bytes

        # Use the prompt as provided (model_tools.py now handles full description formatting)
        comprehensive_prompt = user_prompt

        # Prepare the message with base64-encoded image
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": comprehensive_prompt
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": image_data_url
                        }
                    }
                ]
            }
        ]

        print(f"🧠 Processing image with {model}...", flush=True)

        # Call the vision API
        response = await nous_client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.1,  # Low temperature for consistent analysis
            max_tokens=2000  # Generous limit for detailed analysis
        )

        # Extract the analysis
        analysis = response.choices[0].message.content.strip()
        analysis_length = len(analysis)

        print(f"✅ Image analysis completed ({analysis_length} characters)", flush=True)

        # Prepare successful response; an empty analysis string is replaced by
        # the generic failure message so callers always get usable text.
        result = {
            "success": True,
            "analysis": analysis or "There was a problem with the request and the image could not be analyzed."
        }

        debug_call_data["success"] = True
        debug_call_data["analysis_length"] = analysis_length

        # Log debug information
        _log_debug_call("vision_analyze_tool", debug_call_data)
        _save_debug_log()

        return json.dumps(result, indent=2)

    except Exception as e:
        error_msg = f"Error analyzing image: {str(e)}"
        print(f"❌ {error_msg}", flush=True)

        # Prepare error response (errors are reported in-band, never raised)
        result = {
            "success": False,
            "analysis": "There was a problem with the request and the image could not be analyzed."
        }

        debug_call_data["error"] = error_msg
        _log_debug_call("vision_analyze_tool", debug_call_data)
        _save_debug_log()

        return json.dumps(result, indent=2)

    finally:
        # Clean up temporary image file
        if temp_image_path and temp_image_path.exists():
            try:
                temp_image_path.unlink()
                print(f"🧹 Cleaned up temporary image file", flush=True)
            except Exception as cleanup_error:
                print(f"⚠️ Warning: Could not delete temporary file: {cleanup_error}", flush=True)
|
||||
|
||||
|
||||
def check_nous_api_key() -> bool:
    """
    Report whether the Nous Research API key is configured.

    Returns:
        bool: True if API key is set (non-empty), False otherwise
    """
    key = os.getenv("NOUS_API_KEY")
    return key is not None and key != ""
|
||||
|
||||
|
||||
def check_vision_requirements() -> bool:
    """
    Check if all requirements for vision tools are met.

    Currently the only requirement is a configured Nous Research API key.

    Returns:
        bool: True if requirements are met, False otherwise
    """
    api_key_present = check_nous_api_key()
    return api_key_present
|
||||
|
||||
|
||||
def get_debug_session_info() -> Dict[str, Any]:
    """
    Get information about the current debug session.

    Returns:
        Dict[str, Any]: Session metadata; all fields are None/0/False when
            debug mode is disabled.
    """
    if DEBUG_MODE and DEBUG_DATA:
        log_file = DEBUG_LOG_PATH / f"vision_tools_debug_{DEBUG_SESSION_ID}.json"
        return {
            "enabled": True,
            "session_id": DEBUG_SESSION_ID,
            "log_path": str(log_file),
            "total_calls": len(DEBUG_DATA["tool_calls"])
        }

    return {
        "enabled": False,
        "session_id": None,
        "log_path": None,
        "total_calls": 0
    }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    """
    Simple test/demo when run directly
    """
    print("👁️ Vision Tools Module")
    print("=" * 40)

    # Check if API key is available
    api_available = check_nous_api_key()

    if not api_available:
        # Fail fast with setup instructions when no key is configured.
        print("❌ NOUS_API_KEY environment variable not set")
        print("Please set your API key: export NOUS_API_KEY='your-key-here'")
        print("Get API key at: https://inference-api.nousresearch.com/")
        exit(1)
    else:
        print("✅ Nous Research API key found")

    print("🛠️ Vision tools ready for use!")
    print(f"🧠 Using model: {DEFAULT_VISION_MODEL}")

    # Show debug mode status
    if DEBUG_MODE:
        print(f"🐛 Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}")
        print(f" Debug logs will be saved to: ./logs/vision_tools_debug_{DEBUG_SESSION_ID}.json")
    else:
        print("🐛 Debug mode disabled (set VISION_TOOLS_DEBUG=true to enable)")

    # Usage walkthrough printed for interactive users
    print("\nBasic usage:")
    print(" from vision_tools import vision_analyze_tool")
    print(" import asyncio")
    print("")
    print(" async def main():")
    print(" result = await vision_analyze_tool(")
    print(" image_url='https://example.com/image.jpg',")
    print(" user_prompt='What do you see in this image?'")
    print(" )")
    print(" print(result)")
    print(" asyncio.run(main())")

    print("\nExample prompts:")
    print(" - 'What architectural style is this building?'")
    print(" - 'Describe the emotions and mood in this image'")
    print(" - 'What text can you read in this image?'")
    print(" - 'Identify any safety hazards visible'")
    print(" - 'What products or brands are shown?'")

    print("\nDebug mode:")
    print(" # Enable debug logging")
    print(" export VISION_TOOLS_DEBUG=true")
    print(" # Debug logs capture all vision analysis calls and results")
    print(" # Logs saved to: ./logs/vision_tools_debug_UUID.json")
|
||||
File diff suppressed because it is too large
Load Diff
270
toolset_distributions.py
Normal file
270
toolset_distributions.py
Normal file
@@ -0,0 +1,270 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Toolset Distributions Module
|
||||
|
||||
This module defines distributions of toolsets for data generation runs.
|
||||
Each distribution specifies which toolsets should be used and their probability
|
||||
of being selected for any given prompt during the batch processing.
|
||||
|
||||
A distribution is a dictionary mapping toolset names to their selection probability (%).
|
||||
Probabilities should sum to 100, but the system will normalize if they don't.
|
||||
|
||||
Usage:
|
||||
from toolset_distributions import get_distribution, list_distributions
|
||||
|
||||
# Get a specific distribution
|
||||
dist = get_distribution("image_gen")
|
||||
|
||||
# List all available distributions
|
||||
all_dists = list_distributions()
|
||||
"""
|
||||
|
||||
import random
from typing import Any, Dict, List, Optional

from toolsets import validate_toolset
|
||||
|
||||
|
||||
# Distribution definitions
# Each key is a distribution name, and the value is a dict of toolset_name: probability_percentage
DISTRIBUTIONS = {
    # Default: All tools available 100% of the time
    "default": {
        "description": "All available tools, all the time",
        "toolsets": {
            "web": 100,
            "vision": 100,
            "image_gen": 100,
            "terminal": 100,
            "moa": 100
        }
    },

    # Image generation focused distribution
    "image_gen": {
        "description": "Heavy focus on image generation with vision and web support",
        "toolsets": {
            "image_gen": 90,  # 90% chance of image generation tools
            "vision": 90,     # 90% chance of vision tools
            "web": 55,        # 55% chance of web tools
            "terminal": 45,   # 45% chance of terminal tools
            "moa": 10         # 10% chance of reasoning tools
        }
    },

    # Research-focused distribution
    "research": {
        "description": "Web research with vision analysis and reasoning",
        "toolsets": {
            "web": 90,       # 90% chance of web tools
            "vision": 50,    # 50% chance of vision tools
            "moa": 40,       # 40% chance of reasoning tools
            "terminal": 10   # 10% chance of terminal tools
        }
    },

    # Development-focused distribution
    "development": {
        "description": "Terminal and reasoning with occasional web lookup",
        "toolsets": {
            "terminal": 80,  # 80% chance of terminal tools
            "moa": 60,       # 60% chance of reasoning tools
            "web": 30,       # 30% chance of web tools
            "vision": 10     # 10% chance of vision tools
        }
    },

    # Safe mode (no terminal)
    "safe": {
        "description": "All tools except terminal for safety",
        "toolsets": {
            "web": 80,
            "vision": 60,
            "image_gen": 60,
            "moa": 50
        }
    },

    # Balanced distribution
    "balanced": {
        "description": "Equal probability of all toolsets",
        "toolsets": {
            "web": 50,
            "vision": 50,
            "image_gen": 50,
            "terminal": 50,
            "moa": 50
        }
    },

    # Minimal (web only)
    "minimal": {
        "description": "Only web tools for basic research",
        "toolsets": {
            "web": 100
        }
    },

    # Creative (vision + image generation)
    "creative": {
        "description": "Image generation and vision analysis focus",
        "toolsets": {
            "image_gen": 90,
            "vision": 90,
            "web": 30
        }
    },

    # Reasoning heavy
    "reasoning": {
        "description": "Heavy mixture of agents usage with minimal other tools",
        "toolsets": {
            "moa": 90,
            "web": 30,
            "terminal": 20
        }
    }
}
|
||||
|
||||
|
||||
def get_distribution(name: str) -> Optional[Dict[str, Any]]:
    """
    Get a toolset distribution by name.

    Args:
        name (str): Name of the distribution

    Returns:
        Optional[Dict[str, Any]]: Distribution definition with description
            and toolsets, or None if the distribution is not found.
    """
    # Annotation fixed: the original used builtin `any` (a function) where
    # typing.Any was intended.
    return DISTRIBUTIONS.get(name)
|
||||
|
||||
|
||||
def list_distributions() -> Dict[str, Dict]:
    """
    List all available distributions.

    Returns:
        Dict: A shallow copy of every distribution definition, so callers
            cannot accidentally mutate the module-level registry mapping.
    """
    return dict(DISTRIBUTIONS)
|
||||
|
||||
|
||||
def sample_toolsets_from_distribution(distribution_name: str) -> List[str]:
    """
    Sample toolsets based on a distribution's probabilities.

    Every toolset in the distribution gets an independent percentage-chance
    draw, so several toolsets may be active at once. If nothing is drawn,
    the highest-probability valid toolset is used as a fallback so the
    result is never empty for a non-empty distribution.

    Args:
        distribution_name (str): Name of the distribution to sample from

    Returns:
        List[str]: List of sampled toolset names

    Raises:
        ValueError: If distribution name is not found
    """
    dist = get_distribution(distribution_name)
    if not dist:
        raise ValueError(f"Unknown distribution: {distribution_name}")

    toolset_probs = dist["toolsets"]
    chosen: List[str] = []

    for name, chance in toolset_probs.items():
        # Skip (with a warning) entries that are not registered toolsets.
        if not validate_toolset(name):
            print(f"⚠️ Warning: Toolset '{name}' in distribution '{distribution_name}' is not valid")
            continue

        # Independent Bernoulli draw: include when the roll lands under the chance.
        if random.random() * 100 < chance:
            chosen.append(name)

    # Guarantee at least one toolset when the distribution is non-empty.
    if not chosen and toolset_probs:
        fallback = max(toolset_probs.items(), key=lambda item: item[1])[0]
        if validate_toolset(fallback):
            chosen.append(fallback)

    return chosen
|
||||
|
||||
|
||||
def validate_distribution(distribution_name: str) -> bool:
    """
    Check if a distribution name is valid.

    Args:
        distribution_name (str): Distribution name to validate

    Returns:
        bool: True if valid, False otherwise
    """
    known_names = set(DISTRIBUTIONS)
    return distribution_name in known_names
|
||||
|
||||
|
||||
def print_distribution_info(distribution_name: str) -> None:
    """
    Print detailed information about a distribution.

    Args:
        distribution_name (str): Distribution name
    """
    dist = get_distribution(distribution_name)
    if not dist:
        print(f"❌ Unknown distribution: {distribution_name}")
        return

    print(f"\n📊 Distribution: {distribution_name}")
    print(f" Description: {dist['description']}")
    print(f" Toolsets:")

    # Show toolsets from most to least likely.
    ranked = sorted(dist["toolsets"].items(), key=lambda pair: pair[1], reverse=True)
    for toolset, prob in ranked:
        print(f" • {toolset:15} : {prob:3}% chance")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    """
    Demo and testing of the distributions system
    """
    print("📊 Toolset Distributions Demo")
    print("=" * 60)

    # List all distributions
    print("\n📋 Available Distributions:")
    print("-" * 40)
    for name, dist in list_distributions().items():
        print(f"\n {name}:")
        print(f" {dist['description']}")
        toolset_list = ", ".join([f"{ts}({p}%)" for ts, p in dist["toolsets"].items()])
        print(f" Toolsets: {toolset_list}")

    # Demo sampling
    print("\n\n🎲 Sampling Examples:")
    print("-" * 40)

    test_distributions = ["image_gen", "research", "balanced", "default"]

    for dist_name in test_distributions:
        print(f"\n{dist_name}:")
        # Sample 5 times to show variability
        samples = []
        for _ in range(5):
            sampled = sample_toolsets_from_distribution(dist_name)
            samples.append(sorted(sampled))

        print(f" Sample 1: {samples[0]}")
        print(f" Sample 2: {samples[1]}")
        print(f" Sample 3: {samples[2]}")
        print(f" Sample 4: {samples[3]}")
        print(f" Sample 5: {samples[4]}")

    # Show detailed info
    print("\n\n📊 Detailed Distribution Info:")
    print("-" * 40)
    print_distribution_info("image_gen")
    print_distribution_info("research")
|
||||
13
toolsets.py
13
toolsets.py
@@ -110,6 +110,16 @@ def resolve_toolset(name: str, visited: Set[str] = None) -> List[str]:
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
# Special aliases that represent all tools across every toolset
|
||||
# This ensures future toolsets are automatically included without changes.
|
||||
if name in {"all", "*"}:
|
||||
all_tools: Set[str] = set()
|
||||
for toolset_name in get_toolset_names():
|
||||
# Use a fresh visited set per branch to avoid cross-branch contamination
|
||||
resolved = resolve_toolset(toolset_name, visited.copy())
|
||||
all_tools.update(resolved)
|
||||
return list(all_tools)
|
||||
|
||||
# Check for cycles
|
||||
if name in visited:
|
||||
print(f"⚠️ Circular dependency detected in toolset '{name}'")
|
||||
@@ -184,6 +194,9 @@ def validate_toolset(name: str) -> bool:
|
||||
Returns:
|
||||
bool: True if valid, False otherwise
|
||||
"""
|
||||
# Accept special alias names for convenience
|
||||
if name in {"all", "*"}:
|
||||
return True
|
||||
return name in TOOLSETS
|
||||
|
||||
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
"""
|
||||
Hermes Agent UI Package
|
||||
|
||||
A modular PySide6 UI for the Hermes AI Agent with real-time event streaming.
|
||||
|
||||
Modules:
|
||||
- websocket_client: WebSocket communication
|
||||
- event_widgets: Event display components
|
||||
- main_window: Main application window
|
||||
- hermes_ui: Application entry point
|
||||
"""
|
||||
|
||||
from .websocket_client import WebSocketClient
|
||||
from .event_widgets import CollapsibleEventWidget, InteractiveEventDisplayWidget
|
||||
from .main_window import HermesMainWindow
|
||||
|
||||
__all__ = [
|
||||
'WebSocketClient',
|
||||
'CollapsibleEventWidget',
|
||||
'InteractiveEventDisplayWidget',
|
||||
'HermesMainWindow',
|
||||
]
|
||||
|
||||
@@ -1,334 +0,0 @@
|
||||
"""
|
||||
Event display widgets for Hermes Agent UI.
|
||||
|
||||
This module provides widgets for displaying and managing real-time agent events
|
||||
in a collapsible, filterable interface.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
from PySide6.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton,
|
||||
QCheckBox, QGroupBox, QFrame, QScrollArea
|
||||
)
|
||||
from PySide6.QtCore import Qt, QTimer
|
||||
from PySide6.QtGui import QFont
|
||||
|
||||
|
||||
class CollapsibleEventWidget(QFrame):
    """
    A single collapsible event with expand/collapse functionality.

    Shows a one-line summary header (event type + timestamp) that toggles a
    hidden details pane containing the event's raw JSON payload. The frame
    is color-coded by event type.
    """

    def __init__(self, event: Dict[str, Any], parent=None):
        super().__init__(parent)
        # Raw event payload; read by setup_ui/populate_details.
        self.event = event
        # Details pane starts collapsed.
        self.is_expanded = False
        self.event_type = event.get("event_type", "unknown")

        self.setFrameStyle(QFrame.Box | QFrame.Raised)
        self.setLineWidth(1)
        self.setup_ui()

    def setup_ui(self):
        """Initialize UI components."""
        layout = QVBoxLayout()
        layout.setContentsMargins(8, 8, 8, 8)
        layout.setSpacing(4)

        # Header (clickable)
        self.header_widget = QWidget()
        header_layout = QHBoxLayout()
        header_layout.setContentsMargins(0, 0, 0, 0)

        # Arrow glyph flipped by toggle_expand().
        self.expand_indicator = QLabel("▶")
        self.expand_indicator.setFixedWidth(20)
        header_layout.addWidget(self.expand_indicator)

        self.summary_label = QLabel()
        self.summary_label.setFont(QFont("Arial", 10, QFont.Bold))
        self.update_summary()
        header_layout.addWidget(self.summary_label, 1)

        # Timestamp; 'Z' suffix is normalized so fromisoformat accepts it.
        timestamp = self.event.get("timestamp", datetime.now().isoformat())
        time_str = datetime.fromisoformat(timestamp.replace('Z', '+00:00')).strftime("%H:%M:%S")
        time_label = QLabel(time_str)
        time_label.setStyleSheet("color: #888;")
        header_layout.addWidget(time_label)

        self.header_widget.setLayout(header_layout)
        # Whole header acts as the expand/collapse toggle.
        self.header_widget.mousePressEvent = lambda e: self.toggle_expand()
        self.header_widget.setCursor(Qt.PointingHandCursor)

        layout.addWidget(self.header_widget)

        # Details (collapsible)
        self.details_widget = QWidget()
        self.details_layout = QVBoxLayout()
        self.details_layout.setContentsMargins(25, 5, 5, 5)
        self.populate_details()
        self.details_widget.setLayout(self.details_layout)
        self.details_widget.setVisible(False)

        layout.addWidget(self.details_widget)

        self.setLayout(layout)
        self.apply_colors()

    def apply_colors(self):
        """Apply color scheme based on event type."""
        colors = {
            "query": "#E8F5E9",  # Light green
            "api_call": "#E3F2FD",  # Light blue
            "response": "#F3E5F5",  # Light purple
            "tool_call": "#FFF3E0",  # Light orange
            "tool_result": "#E8F5E9",  # Light green
            "complete": "#E8F5E9",  # Light green
            "error": "#FFEBEE",  # Light red
            "session_start": "#F5F5F5"  # Light gray
        }

        bg_color = colors.get(self.event_type, "#FAFAFA")
        self.setStyleSheet(f"""
            CollapsibleEventWidget {{
                background-color: {bg_color};
                border: 1px solid #ddd;
                border-radius: 4px;
            }}
        """)

    def update_summary(self):
        """Update the summary label with event type."""
        self.summary_label.setText(f"- {self.event_type.upper()}")

    def populate_details(self):
        """Populate the details section with event data."""
        data = self.event.get("data", {})

        # Clear existing details
        while self.details_layout.count():
            item = self.details_layout.takeAt(0)
            if item.widget():
                item.widget().deleteLater()

        self.add_detail("Raw Data", json.dumps(data, indent=2), multiline=True)

    def add_detail(self, label: str, value: str, multiline: bool = True):
        """Add a detail row to the details section."""
        detail_widget = QWidget()
        # Multiline details stack label over value; inline ones sit side by side.
        detail_layout = QVBoxLayout() if multiline else QHBoxLayout()
        detail_layout.setContentsMargins(0, 2, 0, 2)

        label_widget = QLabel(f"<b>{label}:</b>")
        label_widget.setTextFormat(Qt.RichText)

        value_widget = QLabel(value)
        value_widget.setWordWrap(True)
        value_widget.setTextInteractionFlags(Qt.TextSelectableByMouse)

        if multiline:
            # Monospace styling for JSON/code-like content.
            font = QFont()
            font.setStyleHint(QFont.Monospace)
            font.setPointSize(9)
            value_widget.setFont(font)
            value_widget.setStyleSheet("background-color: #f5f5f5; padding: 5px; border-radius: 3px;")
            detail_layout.addWidget(label_widget)
            detail_layout.addWidget(value_widget)
        else:
            detail_layout.addWidget(label_widget)
            detail_layout.addWidget(value_widget, 1)

        detail_widget.setLayout(detail_layout)
        self.details_layout.addWidget(detail_widget)

    def toggle_expand(self):
        """Toggle expanded/collapsed state."""
        self.is_expanded = not self.is_expanded
        self.details_widget.setVisible(self.is_expanded)
        self.expand_indicator.setText("▼" if self.is_expanded else "▶")
|
||||
|
||||
|
||||
class InteractiveEventDisplayWidget(QWidget):
|
||||
"""
|
||||
Interactive widget for displaying real-time agent events.
|
||||
|
||||
Features:
|
||||
- Collapsible event items
|
||||
- Event type filtering
|
||||
- Expand/collapse all
|
||||
- Auto-scroll to latest events
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.events = []
|
||||
self.event_widgets = []
|
||||
self.current_session = None
|
||||
self.filters = {
|
||||
"query": True,
|
||||
"api_call": True,
|
||||
"response": True,
|
||||
"tool_call": True,
|
||||
"tool_result": True,
|
||||
"complete": True,
|
||||
"error": True,
|
||||
"session_start": True
|
||||
}
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
"""Initialize the UI components."""
|
||||
layout = QVBoxLayout()
|
||||
layout.setContentsMargins(5, 5, 5, 5)
|
||||
|
||||
# Header with controls
|
||||
header_layout = QHBoxLayout()
|
||||
|
||||
title = QLabel("📡 Real-time Event Stream")
|
||||
title.setFont(QFont("Arial", 12, QFont.Bold))
|
||||
header_layout.addWidget(title)
|
||||
|
||||
header_layout.addStretch()
|
||||
|
||||
# Expand/Collapse All buttons
|
||||
expand_all_btn = QPushButton("Expand All")
|
||||
expand_all_btn.clicked.connect(self.expand_all)
|
||||
header_layout.addWidget(expand_all_btn)
|
||||
|
||||
collapse_all_btn = QPushButton("Collapse All")
|
||||
collapse_all_btn.clicked.connect(self.collapse_all)
|
||||
header_layout.addWidget(collapse_all_btn)
|
||||
|
||||
# Clear button
|
||||
clear_btn = QPushButton("🗑️ Clear")
|
||||
clear_btn.clicked.connect(self.clear_events)
|
||||
header_layout.addWidget(clear_btn)
|
||||
|
||||
layout.addLayout(header_layout)
|
||||
|
||||
# Filter controls
|
||||
filter_group = QGroupBox("Event Filters (Show/Hide)")
|
||||
filter_layout = QHBoxLayout()
|
||||
filter_layout.setSpacing(10)
|
||||
|
||||
self.filter_checkboxes = {}
|
||||
filter_configs = [
|
||||
("query", "📝 Queries"),
|
||||
("api_call", "🔄 API Calls"),
|
||||
("response", "🤖 Responses"),
|
||||
("tool_call", "🔧 Tool Calls"),
|
||||
("tool_result", "✅ Results"),
|
||||
("complete", "🎉 Complete"),
|
||||
("error", "❌ Errors"),
|
||||
]
|
||||
|
||||
for event_type, label in filter_configs:
|
||||
checkbox = QCheckBox(label)
|
||||
checkbox.setChecked(True)
|
||||
checkbox.stateChanged.connect(lambda state, et=event_type: self.toggle_filter(et, state))
|
||||
self.filter_checkboxes[event_type] = checkbox
|
||||
filter_layout.addWidget(checkbox)
|
||||
|
||||
filter_group.setLayout(filter_layout)
|
||||
layout.addWidget(filter_group)
|
||||
|
||||
# Scroll area for events
|
||||
scroll_area = QScrollArea()
|
||||
scroll_area.setWidgetResizable(True)
|
||||
scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
|
||||
|
||||
# Container for event widgets
|
||||
self.events_container = QWidget()
|
||||
self.events_layout = QVBoxLayout()
|
||||
self.events_layout.setSpacing(5)
|
||||
self.events_layout.addStretch() # Push events to top
|
||||
self.events_container.setLayout(self.events_layout)
|
||||
|
||||
scroll_area.setWidget(self.events_container)
|
||||
layout.addWidget(scroll_area)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def clear_events(self):
|
||||
"""Clear all displayed events."""
|
||||
self.events.clear()
|
||||
self.event_widgets.clear()
|
||||
|
||||
# Remove all widgets
|
||||
while self.events_layout.count() > 1: # Keep the stretch
|
||||
item = self.events_layout.takeAt(0)
|
||||
if item.widget():
|
||||
item.widget().deleteLater()
|
||||
|
||||
self.current_session = None
|
||||
|
||||
def add_event(self, event: Dict[str, Any]):
|
||||
"""Add an event to the display."""
|
||||
event_type = event.get("event_type", "unknown")
|
||||
session_id = event.get("session_id", "")
|
||||
|
||||
# Track session changes - add session start event
|
||||
if self.current_session != session_id:
|
||||
self.current_session = session_id
|
||||
session_event = {
|
||||
"event_type": "session_start",
|
||||
"session_id": session_id,
|
||||
"timestamp": event.get("timestamp", datetime.now().isoformat()),
|
||||
"data": {
|
||||
"session_id": session_id,
|
||||
"start_time": event.get("timestamp", datetime.now().isoformat())
|
||||
}
|
||||
}
|
||||
self._add_event_widget(session_event)
|
||||
|
||||
# Add the actual event
|
||||
self._add_event_widget(event)
|
||||
|
||||
def _add_event_widget(self, event: Dict[str, Any]):
|
||||
"""Internal method to add event widget."""
|
||||
event_widget = CollapsibleEventWidget(event)
|
||||
|
||||
# Apply filter visibility
|
||||
event_type = event.get("event_type", "unknown")
|
||||
event_widget.setVisible(self.filters.get(event_type, True))
|
||||
|
||||
# Insert before the stretch
|
||||
self.events_layout.insertWidget(self.events_layout.count() - 1, event_widget)
|
||||
|
||||
self.events.append(event)
|
||||
self.event_widgets.append(event_widget)
|
||||
|
||||
# Auto-scroll to bottom after widget is rendered
|
||||
QTimer.singleShot(50, self._scroll_to_bottom)
|
||||
|
||||
def _scroll_to_bottom(self):
|
||||
"""Scroll to the bottom of the events list."""
|
||||
scroll_area = self.events_container.parent()
|
||||
if isinstance(scroll_area, QScrollArea):
|
||||
scroll_bar = scroll_area.verticalScrollBar()
|
||||
scroll_bar.setValue(scroll_bar.maximum())
|
||||
|
||||
def expand_all(self):
|
||||
"""Expand all event widgets."""
|
||||
for widget in self.event_widgets:
|
||||
if not widget.is_expanded:
|
||||
widget.toggle_expand()
|
||||
|
||||
def collapse_all(self):
|
||||
"""Collapse all event widgets."""
|
||||
for widget in self.event_widgets:
|
||||
if widget.is_expanded:
|
||||
widget.toggle_expand()
|
||||
|
||||
def toggle_filter(self, event_type: str, state: int):
|
||||
"""Toggle visibility of events by type."""
|
||||
self.filters[event_type] = bool(state)
|
||||
|
||||
# Update visibility of existing widgets
|
||||
for event, widget in zip(self.events, self.event_widgets):
|
||||
if event.get("event_type") == event_type:
|
||||
widget.setVisible(self.filters[event_type])
|
||||
|
||||
102
ui/hermes_ui.py
102
ui/hermes_ui.py
@@ -1,102 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Hermes Agent - PySide6 Frontend
|
||||
|
||||
A modern desktop UI for the Hermes AI Agent with real-time event streaming.
|
||||
|
||||
Features:
|
||||
- Query input with multi-line support
|
||||
- Tool/toolset selection
|
||||
- Model and API configuration
|
||||
- Real-time event display via WebSocket
|
||||
- Beautiful, responsive UI with dark theme
|
||||
- Session history
|
||||
- Safe exit handling (no segfaults)
|
||||
|
||||
Usage:
|
||||
python hermes_ui.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import signal
|
||||
import os
|
||||
|
||||
# Suppress Qt logging warnings BEFORE importing Qt
|
||||
os.environ['QT_LOGGING_RULES'] = 'qt.qpa.*=false'
|
||||
|
||||
from PySide6.QtWidgets import QApplication
|
||||
from PySide6.QtCore import QTimer
|
||||
|
||||
from main_window import HermesMainWindow
|
||||
|
||||
|
||||
def setup_signal_handlers(app: QApplication) -> QTimer:
|
||||
"""
|
||||
Setup signal handlers for graceful shutdown on Ctrl+C.
|
||||
|
||||
This prevents segmentation faults by:
|
||||
1. Catching SIGINT/SIGTERM signals
|
||||
2. Creating a timer that keeps Python responsive to signals
|
||||
3. Calling app.quit() for proper Qt cleanup
|
||||
|
||||
Args:
|
||||
app: The QApplication instance
|
||||
|
||||
Returns:
|
||||
Timer that keeps Python interpreter responsive to signals
|
||||
"""
|
||||
def signal_handler(signum, frame):
|
||||
"""Handle interrupt signals gracefully."""
|
||||
print("\n🛑 Interrupt received, shutting down gracefully...")
|
||||
app.quit()
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # Termination signal
|
||||
|
||||
# CRITICAL: Create a timer to wake up Python interpreter periodically
|
||||
# This allows Python to process signals while Qt's event loop is running
|
||||
# Without this, Ctrl+C will not work and may cause segfaults
|
||||
timer = QTimer()
|
||||
timer.timeout.connect(lambda: None) # Empty callback just to wake up Python
|
||||
timer.start(100) # Check every 100ms
|
||||
|
||||
return timer
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the application."""
|
||||
# Create application
|
||||
app = QApplication(sys.argv)
|
||||
|
||||
# Set application metadata
|
||||
app.setApplicationName("Hermes Agent")
|
||||
app.setOrganizationName("Hermes")
|
||||
app.setApplicationVersion("1.0.0")
|
||||
|
||||
# Setup signal handlers for safe Ctrl+C handling (prevents segfaults!)
|
||||
timer = setup_signal_handlers(app)
|
||||
|
||||
# Apply dark theme (optional)
|
||||
# Uncomment to enable dark mode
|
||||
# app.setStyle("Fusion")
|
||||
# palette = QPalette()
|
||||
# palette.setColor(QPalette.Window, QColor(53, 53, 53))
|
||||
# palette.setColor(QPalette.WindowText, Qt.white)
|
||||
# app.setPalette(palette)
|
||||
|
||||
# Create and show main window
|
||||
window = HermesMainWindow()
|
||||
window.show()
|
||||
|
||||
print("✨ Hermes Agent UI started")
|
||||
print(" Press Ctrl+C to exit gracefully")
|
||||
|
||||
# Start event loop
|
||||
exit_code = app.exec()
|
||||
|
||||
print("👋 Hermes Agent UI closed")
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,375 +0,0 @@
|
||||
"""
|
||||
Main window for Hermes Agent UI.
|
||||
|
||||
This module provides the main application window with controls for
|
||||
submitting queries, configuring settings, and viewing real-time events.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Dict, Any
|
||||
|
||||
from PySide6.QtWidgets import (
|
||||
QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QTextEdit,
|
||||
QPushButton, QLabel, QLineEdit, QComboBox, QCheckBox,
|
||||
QGroupBox, QSplitter, QListWidget, QListWidgetItem,
|
||||
QSpinBox, QMessageBox
|
||||
)
|
||||
from PySide6.QtCore import Qt, Slot, QTimer
|
||||
from PySide6.QtGui import QFont
|
||||
|
||||
from .websocket_client import WebSocketClient
|
||||
from .event_widgets import InteractiveEventDisplayWidget
|
||||
|
||||
|
||||
class HermesMainWindow(QMainWindow):
|
||||
"""
|
||||
Main window for Hermes Agent UI.
|
||||
|
||||
Provides interface for:
|
||||
- Submitting queries
|
||||
- Configuring agent settings
|
||||
- Viewing real-time events
|
||||
- Managing sessions
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.api_base_url = "http://localhost:8000"
|
||||
self.ws_client = None
|
||||
self.current_session_id = None
|
||||
self.available_toolsets = []
|
||||
self.is_closing = False # Flag to prevent reconnection during shutdown
|
||||
|
||||
self.init_ui()
|
||||
self.setup_websocket()
|
||||
self.load_available_tools()
|
||||
|
||||
def init_ui(self):
|
||||
"""Initialize the user interface."""
|
||||
self.setWindowTitle("Hermes Agent - AI Assistant UI")
|
||||
self.setGeometry(100, 100, 1400, 900)
|
||||
|
||||
# Central widget
|
||||
central_widget = QWidget()
|
||||
self.setCentralWidget(central_widget)
|
||||
|
||||
# Main layout (horizontal split)
|
||||
main_layout = QHBoxLayout()
|
||||
|
||||
# Left panel: Controls
|
||||
left_panel = self.create_control_panel()
|
||||
|
||||
# Right panel: Event display
|
||||
right_panel = self.create_event_panel()
|
||||
|
||||
# Splitter for resizable panels
|
||||
splitter = QSplitter(Qt.Horizontal)
|
||||
splitter.addWidget(left_panel)
|
||||
splitter.addWidget(right_panel)
|
||||
splitter.setStretchFactor(0, 1) # Control panel
|
||||
splitter.setStretchFactor(1, 2) # Event panel (larger)
|
||||
|
||||
main_layout.addWidget(splitter)
|
||||
central_widget.setLayout(main_layout)
|
||||
|
||||
# Status bar
|
||||
self.statusBar().showMessage("Ready")
|
||||
|
||||
def create_control_panel(self) -> QWidget:
|
||||
"""Create the left control panel."""
|
||||
panel = QWidget()
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# Title
|
||||
title = QLabel("🤖 Hermes Agent Control")
|
||||
title.setFont(QFont("Arial", 14, QFont.Bold))
|
||||
title.setAlignment(Qt.AlignCenter)
|
||||
layout.addWidget(title)
|
||||
|
||||
# Query input group
|
||||
query_group = QGroupBox("Query Input")
|
||||
query_layout = QVBoxLayout()
|
||||
|
||||
self.query_input = QTextEdit()
|
||||
self.query_input.setPlaceholderText("Enter your query here...")
|
||||
self.query_input.setMaximumHeight(150)
|
||||
query_layout.addWidget(self.query_input)
|
||||
|
||||
self.submit_btn = QPushButton("🚀 Submit Query")
|
||||
self.submit_btn.setFont(QFont("Arial", 11, QFont.Bold))
|
||||
self.submit_btn.setStyleSheet("QPushButton { background-color: #4CAF50; color: white; padding: 10px; }")
|
||||
self.submit_btn.clicked.connect(self.submit_query)
|
||||
query_layout.addWidget(self.submit_btn)
|
||||
|
||||
query_group.setLayout(query_layout)
|
||||
layout.addWidget(query_group)
|
||||
|
||||
# Model configuration group
|
||||
model_group = QGroupBox("Model Configuration")
|
||||
model_layout = QVBoxLayout()
|
||||
|
||||
# Model selection
|
||||
model_layout.addWidget(QLabel("Model:"))
|
||||
self.model_combo = QComboBox()
|
||||
self.model_combo.addItems([
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-opus-4-20250514",
|
||||
"gpt-4",
|
||||
"gpt-4-turbo"
|
||||
])
|
||||
model_layout.addWidget(self.model_combo)
|
||||
|
||||
# API Base URL
|
||||
model_layout.addWidget(QLabel("API Base URL:"))
|
||||
self.base_url_input = QLineEdit("https://api.anthropic.com/v1/")
|
||||
model_layout.addWidget(self.base_url_input)
|
||||
|
||||
# Max turns
|
||||
model_layout.addWidget(QLabel("Max Turns:"))
|
||||
self.max_turns_spin = QSpinBox()
|
||||
self.max_turns_spin.setMinimum(1)
|
||||
self.max_turns_spin.setMaximum(50)
|
||||
self.max_turns_spin.setValue(10)
|
||||
model_layout.addWidget(self.max_turns_spin)
|
||||
|
||||
model_group.setLayout(model_layout)
|
||||
layout.addWidget(model_group)
|
||||
|
||||
# Tools configuration group
|
||||
tools_group = QGroupBox("Tools & Toolsets")
|
||||
tools_layout = QVBoxLayout()
|
||||
|
||||
tools_layout.addWidget(QLabel("Select Toolsets:"))
|
||||
self.toolsets_list = QListWidget()
|
||||
self.toolsets_list.setSelectionMode(QListWidget.MultiSelection)
|
||||
self.toolsets_list.setMaximumHeight(150)
|
||||
tools_layout.addWidget(self.toolsets_list)
|
||||
|
||||
tools_group.setLayout(tools_layout)
|
||||
layout.addWidget(tools_group)
|
||||
|
||||
# Options group
|
||||
options_group = QGroupBox("Options")
|
||||
options_layout = QVBoxLayout()
|
||||
|
||||
self.mock_mode_checkbox = QCheckBox("Mock Web Tools (Testing)")
|
||||
options_layout.addWidget(self.mock_mode_checkbox)
|
||||
|
||||
self.verbose_checkbox = QCheckBox("Verbose Logging")
|
||||
options_layout.addWidget(self.verbose_checkbox)
|
||||
|
||||
options_layout.addWidget(QLabel("Mock Delay (seconds):"))
|
||||
self.mock_delay_spin = QSpinBox()
|
||||
self.mock_delay_spin.setMinimum(1)
|
||||
self.mock_delay_spin.setMaximum(300)
|
||||
self.mock_delay_spin.setValue(60)
|
||||
options_layout.addWidget(self.mock_delay_spin)
|
||||
|
||||
options_group.setLayout(options_layout)
|
||||
layout.addWidget(options_group)
|
||||
|
||||
# Connection status
|
||||
self.connection_status = QLabel("🔴 Disconnected")
|
||||
self.connection_status.setAlignment(Qt.AlignCenter)
|
||||
self.connection_status.setStyleSheet("QLabel { padding: 5px; background-color: #F44336; color: white; border-radius: 3px; }")
|
||||
layout.addWidget(self.connection_status)
|
||||
|
||||
# Add stretch to push everything to top
|
||||
layout.addStretch()
|
||||
|
||||
panel.setLayout(layout)
|
||||
return panel
|
||||
|
||||
def create_event_panel(self) -> QWidget:
|
||||
"""Create the right event display panel."""
|
||||
panel = QWidget()
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# Event display widget
|
||||
self.event_widget = InteractiveEventDisplayWidget()
|
||||
layout.addWidget(self.event_widget)
|
||||
|
||||
panel.setLayout(layout)
|
||||
return panel
|
||||
|
||||
def setup_websocket(self):
|
||||
"""Setup WebSocket connection for real-time events."""
|
||||
self.ws_client = WebSocketClient("ws://localhost:8000/ws")
|
||||
|
||||
# Connect signals
|
||||
self.ws_client.connected.connect(self.on_ws_connected)
|
||||
self.ws_client.disconnected.connect(self.on_ws_disconnected)
|
||||
self.ws_client.error.connect(self.on_ws_error)
|
||||
self.ws_client.event_received.connect(self.on_event_received)
|
||||
|
||||
# Start connection
|
||||
self.ws_client.connect()
|
||||
|
||||
@Slot()
|
||||
def on_ws_connected(self):
|
||||
"""Called when WebSocket connection is established."""
|
||||
self.connection_status.setText("🟢 Connected")
|
||||
self.connection_status.setStyleSheet("QLabel { padding: 5px; background-color: #4CAF50; color: white; border-radius: 3px; }")
|
||||
self.statusBar().showMessage("WebSocket connected")
|
||||
|
||||
@Slot()
|
||||
def on_ws_disconnected(self):
|
||||
"""Called when WebSocket connection is lost."""
|
||||
# Don't attempt reconnection if we're closing the application
|
||||
if self.is_closing:
|
||||
return
|
||||
|
||||
self.connection_status.setText("🔴 Disconnected")
|
||||
self.connection_status.setStyleSheet("QLabel { padding: 5px; background-color: #F44336; color: white; border-radius: 3px; }")
|
||||
self.statusBar().showMessage("WebSocket disconnected - attempting reconnect...")
|
||||
|
||||
# Attempt reconnect after 5 seconds
|
||||
QTimer.singleShot(5000, self.ws_client.connect)
|
||||
|
||||
@Slot(str)
|
||||
def on_ws_error(self, error: str):
|
||||
"""Called when WebSocket error occurs."""
|
||||
self.statusBar().showMessage(f"WebSocket error: {error}")
|
||||
|
||||
@Slot(dict)
|
||||
def on_event_received(self, event: Dict[str, Any]):
|
||||
"""
|
||||
Called when an event is received from WebSocket.
|
||||
|
||||
Args:
|
||||
event: Event data from server
|
||||
"""
|
||||
self.event_widget.add_event(event)
|
||||
|
||||
# Update status for specific events
|
||||
event_type = event.get("event_type")
|
||||
if event_type == "query":
|
||||
self.statusBar().showMessage("Query received - agent processing...")
|
||||
elif event_type == "complete":
|
||||
self.statusBar().showMessage("Agent completed!")
|
||||
self.submit_btn.setEnabled(True)
|
||||
|
||||
def load_available_tools(self):
|
||||
"""Load available toolsets from the API."""
|
||||
try:
|
||||
response = requests.get(f"{self.api_base_url}/tools", timeout=5)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
toolsets = data.get("toolsets", [])
|
||||
|
||||
self.available_toolsets = toolsets
|
||||
self.toolsets_list.clear()
|
||||
|
||||
for toolset in toolsets:
|
||||
name = toolset.get("name", "")
|
||||
description = toolset.get("description", "")
|
||||
tool_count = toolset.get("tool_count", 0)
|
||||
|
||||
item_text = f"{name} ({tool_count} tools) - {description}"
|
||||
item = QListWidgetItem(item_text)
|
||||
item.setData(Qt.UserRole, name) # Store toolset name
|
||||
self.toolsets_list.addItem(item)
|
||||
|
||||
self.statusBar().showMessage(f"Loaded {len(toolsets)} toolsets")
|
||||
else:
|
||||
self.statusBar().showMessage("Failed to load toolsets from API")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
self.statusBar().showMessage(f"Error loading toolsets: {str(e)}")
|
||||
# Add some default toolsets
|
||||
default_toolsets = ["web", "vision", "terminal", "research"]
|
||||
for ts in default_toolsets:
|
||||
item = QListWidgetItem(f"{ts} (default)")
|
||||
item.setData(Qt.UserRole, ts)
|
||||
self.toolsets_list.addItem(item)
|
||||
|
||||
@Slot()
|
||||
def submit_query(self):
|
||||
"""Submit query to the agent API."""
|
||||
query = self.query_input.toPlainText().strip()
|
||||
|
||||
if not query:
|
||||
QMessageBox.warning(self, "No Query", "Please enter a query first!")
|
||||
return
|
||||
|
||||
# Get selected toolsets
|
||||
selected_toolsets = []
|
||||
for i in range(self.toolsets_list.count()):
|
||||
item = self.toolsets_list.item(i)
|
||||
if item.isSelected():
|
||||
toolset_name = item.data(Qt.UserRole)
|
||||
selected_toolsets.append(toolset_name)
|
||||
|
||||
# Build request payload
|
||||
payload = {
|
||||
"query": query,
|
||||
"model": self.model_combo.currentText(),
|
||||
"base_url": self.base_url_input.text(),
|
||||
"max_turns": self.max_turns_spin.value(),
|
||||
"enabled_toolsets": selected_toolsets if selected_toolsets else None,
|
||||
"mock_web_tools": self.mock_mode_checkbox.isChecked(),
|
||||
"mock_delay": self.mock_delay_spin.value(),
|
||||
"verbose": self.verbose_checkbox.isChecked()
|
||||
}
|
||||
|
||||
# Disable submit button during execution
|
||||
self.submit_btn.setEnabled(False)
|
||||
self.submit_btn.setText("⏳ Running...")
|
||||
self.statusBar().showMessage("Submitting query to agent...")
|
||||
|
||||
# Submit to API
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.api_base_url}/agent/run",
|
||||
json=payload,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
session_id = result.get("session_id", "")
|
||||
self.current_session_id = session_id
|
||||
|
||||
self.statusBar().showMessage(f"Agent started! Session: {session_id[:8]}...")
|
||||
|
||||
# Clear event display for new session (optional)
|
||||
# self.event_widget.clear_events()
|
||||
|
||||
else:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"API Error",
|
||||
f"Failed to start agent: {response.status_code}\n{response.text}"
|
||||
)
|
||||
self.submit_btn.setEnabled(True)
|
||||
self.submit_btn.setText("🚀 Submit Query")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Connection Error",
|
||||
f"Failed to connect to API server:\n{str(e)}\n\nMake sure the server is running:\npython logging_server.py"
|
||||
)
|
||||
self.submit_btn.setEnabled(True)
|
||||
self.submit_btn.setText("🚀 Submit Query")
|
||||
|
||||
# Re-enable button after short delay (UI feedback)
|
||||
QTimer.singleShot(2000, lambda: self.submit_btn.setText("🚀 Submit Query"))
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up resources before exit."""
|
||||
print("Cleaning up resources...")
|
||||
self.is_closing = True
|
||||
|
||||
if self.ws_client:
|
||||
try:
|
||||
self.ws_client.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting WebSocket: {e}")
|
||||
|
||||
def closeEvent(self, event):
|
||||
"""Handle window close event - ensures clean shutdown."""
|
||||
print("Closing application...")
|
||||
self.cleanup()
|
||||
event.accept()
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Hermes Agent UI Launcher
|
||||
#
|
||||
# This script starts both the API server and UI application.
|
||||
# It will run them in the background and provide a clean shutdown.
|
||||
|
||||
set -e
|
||||
|
||||
# Colors for output
|
||||
GREEN='\033[0;32m'
|
||||
BLUE='\033[0;34m'
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
echo -e "${BLUE}🚀 Hermes Agent UI Launcher${NC}"
|
||||
echo "================================"
|
||||
echo ""
|
||||
|
||||
# Check if Python is available
|
||||
if ! command -v python3 &> /dev/null; then
|
||||
echo -e "${RED}❌ Python 3 not found. Please install Python 3.${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if virtual environment exists
|
||||
if [ -d "../../env" ]; then
|
||||
echo -e "${GREEN}✓ Activating virtual environment${NC}"
|
||||
source ../../env/bin/activate
|
||||
else
|
||||
echo -e "${BLUE}ℹ No virtual environment found, using system Python${NC}"
|
||||
fi
|
||||
|
||||
# Check dependencies
|
||||
echo -e "${BLUE}Checking dependencies...${NC}"
|
||||
python3 -c "import PySide6" 2>/dev/null || {
|
||||
echo -e "${RED}❌ PySide6 not installed${NC}"
|
||||
echo -e "${BLUE}Installing dependencies...${NC}"
|
||||
pip install -r ../requirements.txt
|
||||
}
|
||||
|
||||
# Check for API keys
|
||||
if [ -z "$ANTHROPIC_API_KEY" ]; then
|
||||
echo -e "${RED}⚠️ Warning: ANTHROPIC_API_KEY not set${NC}"
|
||||
echo " Set it with: export ANTHROPIC_API_KEY='your-key'"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# Function to cleanup on exit
|
||||
cleanup() {
|
||||
echo ""
|
||||
echo -e "${BLUE}🛑 Shutting down Hermes Agent...${NC}"
|
||||
if [ ! -z "$SERVER_PID" ]; then
|
||||
kill $SERVER_PID 2>/dev/null || true
|
||||
echo -e "${GREEN}✓ API Server stopped${NC}"
|
||||
fi
|
||||
if [ ! -z "$UI_PID" ]; then
|
||||
kill $UI_PID 2>/dev/null || true
|
||||
echo -e "${GREEN}✓ UI Application stopped${NC}"
|
||||
fi
|
||||
echo -e "${GREEN}✓ Cleanup complete${NC}"
|
||||
exit 0
|
||||
}
|
||||
|
||||
# Set up trap for cleanup
|
||||
trap cleanup SIGINT SIGTERM EXIT
|
||||
|
||||
# Start API server in background
|
||||
echo -e "${BLUE}Starting API Server...${NC}"
|
||||
cd ../api_endpoint
|
||||
python3 logging_server.py > /tmp/hermes_server.log 2>&1 &
|
||||
SERVER_PID=$!
|
||||
cd ../ui
|
||||
|
||||
# Wait for server to start
|
||||
echo -e "${BLUE}Waiting for server to start...${NC}"
|
||||
sleep 3
|
||||
|
||||
# Check if server is running
|
||||
if ! kill -0 $SERVER_PID 2>/dev/null; then
|
||||
echo -e "${RED}❌ Server failed to start. Check /tmp/hermes_server.log${NC}"
|
||||
tail -20 /tmp/hermes_server.log
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if server is responding
|
||||
if curl -s http://localhost:8000/ > /dev/null; then
|
||||
echo -e "${GREEN}✓ API Server running on http://localhost:8000${NC}"
|
||||
else
|
||||
echo -e "${RED}❌ Server not responding. Check /tmp/hermes_server.log${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Start UI application
|
||||
echo -e "${BLUE}Starting UI Application...${NC}"
|
||||
python3 hermes_ui.py &
|
||||
UI_PID=$!
|
||||
|
||||
echo ""
|
||||
echo -e "${GREEN}================================${NC}"
|
||||
echo -e "${GREEN}✓ Hermes Agent UI is running!${NC}"
|
||||
echo -e "${GREEN}================================${NC}"
|
||||
echo ""
|
||||
echo -e "${BLUE}📊 Component Status:${NC}"
|
||||
echo -e " API Server: http://localhost:8000 (PID: $SERVER_PID)"
|
||||
echo -e " UI App: Running (PID: $UI_PID)"
|
||||
echo -e " Server Log: /tmp/hermes_server.log"
|
||||
echo ""
|
||||
echo -e "${BLUE}Press Ctrl+C to stop all services${NC}"
|
||||
echo ""
|
||||
|
||||
# Wait for UI to exit
|
||||
wait $UI_PID
|
||||
|
||||
# Cleanup will be triggered by trap
|
||||
|
||||
@@ -1,264 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify UI flow works correctly.
|
||||
|
||||
This tests:
|
||||
1. API server is running
|
||||
2. WebSocket connection works
|
||||
3. Agent can be started via API
|
||||
4. Events are broadcast properly
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
import websocket
|
||||
import threading
|
||||
|
||||
API_URL = "http://localhost:8000"
|
||||
WS_URL = "ws://localhost:8000/ws"
|
||||
|
||||
def test_api_server():
|
||||
"""Test if API server is running."""
|
||||
print("🔍 Testing API server...")
|
||||
try:
|
||||
response = requests.get(f"{API_URL}/", timeout=5)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
print(f"✅ API server is running: {data.get('service')}")
|
||||
print(f" Active connections: {data.get('active_connections')}")
|
||||
return True
|
||||
else:
|
||||
print(f"❌ API server returned: {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ API server not accessible: {e}")
|
||||
return False
|
||||
|
||||
def test_tools_endpoint():
|
||||
"""Test if tools endpoint works."""
|
||||
print("\n🔍 Testing tools endpoint...")
|
||||
try:
|
||||
response = requests.get(f"{API_URL}/tools", timeout=5)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
toolsets = data.get("toolsets", [])
|
||||
print(f"✅ Tools endpoint works - {len(toolsets)} toolsets available")
|
||||
for ts in toolsets[:3]:
|
||||
print(f" • {ts.get('name')} ({ts.get('tool_count')} tools)")
|
||||
return True
|
||||
else:
|
||||
print(f"❌ Tools endpoint failed: {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ Tools endpoint error: {e}")
|
||||
return False
|
||||
|
||||
def test_websocket():
|
||||
"""Test WebSocket connection."""
|
||||
print("\n🔍 Testing WebSocket connection...")
|
||||
|
||||
connected = threading.Event()
|
||||
message_received = threading.Event()
|
||||
messages = []
|
||||
|
||||
def on_open(ws):
|
||||
print("✅ WebSocket connected")
|
||||
connected.set()
|
||||
|
||||
def on_message(ws, message):
|
||||
data = json.loads(message)
|
||||
messages.append(data)
|
||||
message_received.set()
|
||||
print(f"📨 Received: {data.get('event_type', 'unknown')}")
|
||||
|
||||
def on_error(ws, error):
|
||||
print(f"❌ WebSocket error: {error}")
|
||||
|
||||
def on_close(ws, close_status_code, close_msg):
|
||||
print(f"🔌 WebSocket closed: {close_status_code}")
|
||||
|
||||
ws = websocket.WebSocketApp(
|
||||
WS_URL,
|
||||
on_open=on_open,
|
||||
on_message=on_message,
|
||||
on_error=on_error,
|
||||
on_close=on_close
|
||||
)
|
||||
|
||||
# Run WebSocket in background
|
||||
ws_thread = threading.Thread(target=lambda: ws.run_forever(), daemon=True)
|
||||
ws_thread.start()
|
||||
|
||||
# Wait for connection
|
||||
if connected.wait(timeout=5):
|
||||
print("✅ WebSocket connection established")
|
||||
ws.close()
|
||||
return True
|
||||
else:
|
||||
print("❌ WebSocket connection timeout")
|
||||
ws.close()
|
||||
return False
|
||||
|
||||
def test_agent_run():
|
||||
"""Test running agent via API."""
|
||||
print("\n🔍 Testing agent run via API (mock mode)...")
|
||||
|
||||
# Start listening for events first
|
||||
events = []
|
||||
ws_connected = threading.Event()
|
||||
session_complete = threading.Event()
|
||||
|
||||
def on_message(ws, message):
|
||||
data = json.loads(message)
|
||||
events.append(data)
|
||||
event_type = data.get("event_type")
|
||||
print(f" 📨 Event: {event_type}")
|
||||
|
||||
if event_type == "complete":
|
||||
session_complete.set()
|
||||
|
||||
def on_open(ws):
|
||||
ws_connected.set()
|
||||
|
||||
# Connect WebSocket
|
||||
ws = websocket.WebSocketApp(
|
||||
WS_URL,
|
||||
on_open=on_open,
|
||||
on_message=on_message
|
||||
)
|
||||
|
||||
ws_thread = threading.Thread(target=lambda: ws.run_forever(), daemon=True)
|
||||
ws_thread.start()
|
||||
|
||||
# Wait for WebSocket connection
|
||||
if not ws_connected.wait(timeout=5):
|
||||
print("❌ WebSocket didn't connect")
|
||||
ws.close()
|
||||
return False
|
||||
|
||||
print("✅ WebSocket connected, starting agent...")
|
||||
|
||||
# Submit agent run
|
||||
payload = {
|
||||
"query": "Test query for UI flow verification",
|
||||
"model": "claude-sonnet-4-5-20250929",
|
||||
"base_url": "https://api.anthropic.com/v1/",
|
||||
"enabled_toolsets": ["web"],
|
||||
"max_turns": 5,
|
||||
"mock_web_tools": True, # Use mock mode to avoid API costs
|
||||
"mock_delay": 2, # Fast for testing
|
||||
"verbose": False
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(f"{API_URL}/agent/run", json=payload, timeout=10)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
session_id = result.get("session_id")
|
||||
print(f"✅ Agent started: {session_id[:8]}...")
|
||||
|
||||
# Wait for completion (or timeout)
|
||||
print("⏳ Waiting for agent to complete (up to 30s)...")
|
||||
if session_complete.wait(timeout=30):
|
||||
print(f"✅ Agent completed! Received {len(events)} events:")
|
||||
|
||||
# Count event types
|
||||
event_counts = {}
|
||||
for evt in events:
|
||||
evt_type = evt.get("event_type", "unknown")
|
||||
event_counts[evt_type] = event_counts.get(evt_type, 0) + 1
|
||||
|
||||
for evt_type, count in event_counts.items():
|
||||
print(f" • {evt_type}: {count}")
|
||||
|
||||
# Check we got expected events
|
||||
expected_events = ["query", "api_call", "response", "complete"]
|
||||
missing = [e for e in expected_events if e not in event_counts]
|
||||
|
||||
if missing:
|
||||
print(f"⚠️ Missing expected events: {missing}")
|
||||
else:
|
||||
print("✅ All expected event types received!")
|
||||
|
||||
ws.close()
|
||||
return True
|
||||
else:
|
||||
print(f"⚠️ Timeout waiting for completion. Got {len(events)} events so far.")
|
||||
ws.close()
|
||||
return False
|
||||
|
||||
else:
|
||||
print(f"❌ Agent start failed: {response.status_code}")
|
||||
print(f" Response: {response.text}")
|
||||
ws.close()
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Agent run error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
ws.close()
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Run all tests."""
|
||||
print("=" * 60)
|
||||
print("🧪 Hermes Agent UI Flow Test")
|
||||
print("=" * 60)
|
||||
print("\nThis will test the complete flow:")
|
||||
print(" 1. API server connectivity")
|
||||
print(" 2. Tools endpoint")
|
||||
print(" 3. WebSocket connection")
|
||||
print(" 4. Agent execution via API (mock mode)")
|
||||
print(" 5. Event streaming to UI")
|
||||
print("\n" + "=" * 60)
|
||||
|
||||
results = []
|
||||
|
||||
# Test 1: API server
|
||||
results.append(("API Server", test_api_server()))
|
||||
|
||||
# Test 2: Tools endpoint
|
||||
results.append(("Tools Endpoint", test_tools_endpoint()))
|
||||
|
||||
# Test 3: WebSocket
|
||||
results.append(("WebSocket Connection", test_websocket()))
|
||||
|
||||
# Test 4: Agent run
|
||||
results.append(("Agent Execution + Events", test_agent_run()))
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("📊 TEST SUMMARY")
|
||||
print("=" * 60)
|
||||
|
||||
for test_name, passed in results:
|
||||
status = "✅ PASS" if passed else "❌ FAIL"
|
||||
print(f"{status} - {test_name}")
|
||||
|
||||
all_passed = all(r[1] for r in results)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
if all_passed:
|
||||
print("🎉 ALL TESTS PASSED!")
|
||||
print("\n✅ The UI flow is working correctly!")
|
||||
print(" You can now use the UI to:")
|
||||
print(" • Submit queries")
|
||||
print(" • View real-time events")
|
||||
print(" • See tool executions")
|
||||
print(" • Get final responses")
|
||||
else:
|
||||
print("❌ SOME TESTS FAILED")
|
||||
print("\nMake sure:")
|
||||
print(" 1. API server is running: python api_endpoint/logging_server.py")
|
||||
print(" 2. ANTHROPIC_API_KEY is set in environment")
|
||||
print(" 3. All dependencies are installed: pip install -r requirements.txt")
|
||||
print("=" * 60)
|
||||
|
||||
return 0 if all_passed else 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
"""
|
||||
WebSocket client for real-time event streaming from Hermes Agent.
|
||||
|
||||
This module provides a WebSocket client that runs in a separate thread
|
||||
and emits Qt signals when events are received from the server.
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import websocket
|
||||
from PySide6.QtCore import QObject, Signal
|
||||
|
||||
|
||||
class WebSocketClient(QObject):
    """
    WebSocket client for receiving real-time agent events.

    Runs the ``websocket-client`` event loop in a background daemon thread and
    re-emits everything it receives as Qt signals, so the GUI thread can react
    without touching the socket directly.
    """

    # Signals for event communication.
    # NOTE(review): `connect`/`disconnect` below shadow Qt's own static
    # QObject.connect/disconnect naming — presumably intentional; confirm no
    # caller relies on the Qt static forms through this class.
    event_received = Signal(dict)  # Emits parsed event data (decoded JSON dict)
    connected = Signal()
    disconnected = Signal()
    error = Signal(str)

    def __init__(self, url: str = "ws://localhost:8000/ws"):
        """
        Args:
            url: WebSocket endpoint of the event-streaming server.
        """
        super().__init__()
        self.url = url
        self.ws = None        # websocket.WebSocketApp, created lazily in _run()
        self.running = False  # True while the background loop should stay alive
        self.thread = None    # background thread executing _run()

    def connect(self):
        """Start the WebSocket connection in a background daemon thread.

        No-op if a connection loop is already running.
        """
        if self.running:
            return

        self.running = True
        self.thread = threading.Thread(target=self._run, daemon=True)
        self.thread.start()

    def disconnect(self):
        """Stop the WebSocket connection and wait briefly for the loop to exit."""
        self.running = False
        if self.ws:
            try:
                self.ws.close()
            except Exception as e:
                print(f"Error closing WebSocket: {e}")
        # Join the worker so an immediate reconnect cannot race the dying loop.
        # Guarded against joining from within the worker's own callbacks.
        if (
            self.thread is not None
            and self.thread.is_alive()
            and self.thread is not threading.current_thread()
        ):
            self.thread.join(timeout=2.0)

    def _run(self):
        """WebSocket event loop (runs in the background thread)."""
        try:
            self.ws = websocket.WebSocketApp(
                self.url,
                on_open=self._on_open,
                on_message=self._on_message,
                on_error=self._on_error,
                on_close=self._on_close
            )

            # NOTE: run_forever() does NOT reconnect on its own — it returns
            # once the connection drops or close() is called. The keep-alive
            # pings only detect dead peers.
            self.ws.run_forever(ping_interval=300, ping_timeout=60)

        except Exception as e:
            self.error.emit(f"WebSocket error: {str(e)}")
        finally:
            # BUG FIX: `running` previously stayed True forever once the loop
            # exited (e.g. the server went away), so the guard in connect()
            # refused every subsequent reconnection attempt. Resetting it here
            # lets connect() start a fresh loop after a drop.
            self.running = False

    def _on_open(self, ws):
        """Called when the WebSocket connection is established."""
        print("WebSocket connected")
        self.connected.emit()

    def _on_message(self, ws, message):
        """Parse a JSON message from the server and re-emit it as a Qt signal."""
        try:
            data = json.loads(message)
            self.event_received.emit(data)
        except json.JSONDecodeError as e:
            print(f" Failed to parse WebSocket message: {e}")

    def _on_error(self, ws, error):
        """Called when a socket-level error occurs; forwards it as a Qt signal."""
        print(f"WebSocket error: {error}")
        self.error.emit(str(error))

    def _on_close(self, ws, close_status_code, close_msg):
        """Called when the WebSocket connection is closed."""
        print(f"🔌 WebSocket disconnected: {close_status_code} - {close_msg}")
        self.disconnected.emit()
|
||||
|
||||
Reference in New Issue
Block a user