Mirror of https://github.com/NousResearch/hermes-agent.git, synced 2026-05-06 10:47:12 +08:00
Compare: 20 commits
`.gitignore` (19 lines changed)

```
/venv/
/_pycache/
hecate/
hecate-lib/
*.pyc*
__pycache__/
.venv/
.vscode/
.env
.env.local
.env.development.local
.env.test.local
.env.production.local
.env.development
.env.test
export*
__pycache__/model_tools.cpython-310.pyc
__pycache__/web_tools.cpython-310.pyc
logs/
```
`README.md` (302 lines changed)

# Hermes Agent

An AI agent with advanced tool-calling capabilities, real-time logging, and extensible toolsets.

## Features

- 🤖 **Multi-model Support**: Works with Claude, GPT-4, and other OpenAI-compatible models
- 🔧 **Rich Tool Library**: Web search, content extraction, vision analysis, terminal execution, and more
- 📊 **Real-time Logging**: WebSocket-based logging system for monitoring agent execution
- 🖥️ **Desktop UI**: Modern PySide6 frontend with real-time event streaming
- 🎯 **Flexible Toolsets**: Predefined toolset combinations for different use cases
- 💾 **Trajectory Saving**: Save conversation flows for training and analysis
- 🔄 **Auto-retry**: Built-in error handling and retry logic

## Quick Start

### Installation

```bash
pip install -r requirements.txt
```
### Basic Usage

```bash
python run_agent.py \
  --enabled_toolsets web \
  --query "Search for the latest AI news"
```
### With Real-time Logging

```bash
# Terminal 1: Start API endpoint server
python api_endpoint/logging_server.py

# Terminal 2: Run agent
python run_agent.py \
  --enabled_toolsets web \
  --enable_websocket_logging \
  --query "Your question here"
```
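Each logged event travels over the WebSocket as a small JSON message tagged with a session ID and an event type (the envelope the logging server documents). A minimal sketch of building one such message — the session ID and query text here are invented for illustration:

```python
import json
from datetime import datetime


def build_event(session_id: str, event_type: str, data: dict) -> str:
    """Serialize one logging event in the server's expected envelope."""
    message = {
        "session_id": session_id,    # links the event to a specific session
        "event_type": event_type,    # e.g. "query", "api_call", "tool_call"
        "data": {**data, "timestamp": datetime.now().isoformat()},
    }
    return json.dumps(message)


payload = build_event("demo-session", "query", {"query": "Search for the latest AI news"})
decoded = json.loads(payload)
print(decoded["event_type"])  # query
```

The real client library (`api_endpoint/websocket_logger.py`) handles the actual connection; this only shows the message shape.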
### With Desktop UI (Recommended)

The easiest way to use Hermes Agent is through the desktop UI:

```bash
# One-command launch (starts server + UI)
cd ui && ./start_hermes_ui.sh

# Or manually:
# Terminal 1: Start server
python api_endpoint/logging_server.py

# Terminal 2: Start UI
python ui/hermes_ui.py
```

The UI provides:
- 🖱️ Point-and-click query submission
- 🎛️ Easy model and tool selection
- 📊 Real-time event visualization
- 🔄 Automatic WebSocket connection
- 📝 Session history
## Project Structure

```
Hermes-Agent/
├── run_agent.py              # Main agent runner
├── model_tools.py            # Tool definitions and handling
├── toolsets.py               # Predefined toolset combinations
├── requirements.txt          # Python dependencies
│
├── ui/                       # Desktop UI ⭐ NEW
│   ├── hermes_ui.py          # PySide6 desktop application
│   ├── start_hermes_ui.sh    # UI launcher script
│   └── test_ui_flow.py       # UI integration tests
│
├── tools/                    # Tool implementations
│   ├── web_tools.py          # Web search, extract, crawl
│   ├── vision_tools.py       # Image analysis
│   ├── terminal_tool.py      # Command execution
│   ├── image_generation_tool.py
│   └── ...
│
├── api_endpoint/             # FastAPI + WebSocket logging endpoint
│   ├── logging_server.py     # WebSocket server + Agent API ⭐ ENHANCED
│   ├── websocket_logger.py   # Client library
│   ├── README.md             # API endpoint docs
│   └── ...
│
├── logs/                     # Log files
│   └── realtime/             # WebSocket session logs
│
└── tests/                    # Test files
```
## Available Toolsets

### Basic Toolsets
- **web**: Web search, extract, and crawl
- **terminal**: Command execution
- **vision**: Image analysis
- **creative**: Image generation
- **reasoning**: Mixture of agents

### Composite Toolsets
- **research**: Web + vision tools
- **development**: Web + terminal + vision
- **analysis**: Web + vision + reasoning
- **full_stack**: All tools enabled

### Usage Examples

```bash
# Research with web and vision
python run_agent.py --enabled_toolsets research --query "..."

# Development with terminal access
python run_agent.py --enabled_toolsets development --query "..."

# Combine multiple toolsets
python run_agent.py --enabled_toolsets web,vision --query "..."
```
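Composite toolsets are simply names that expand into lists of concrete tools. The real mapping lives in `toolsets.py` (via `get_all_toolsets` / `get_toolset_info`); the sketch below is a hypothetical illustration of that resolution, with invented tool names:

```python
# Hypothetical illustration of composite-toolset resolution; the real
# mapping lives in toolsets.py and its tool names may differ.
BASIC = {
    "web": ["web_search", "web_extract", "web_crawl"],
    "terminal": ["execute_command"],
    "vision": ["analyze_image"],
}
COMPOSITE = {
    "research": ["web", "vision"],
    "development": ["web", "terminal", "vision"],
}


def resolve(name: str) -> list:
    """Flatten a toolset name into concrete tool names."""
    if name in BASIC:
        return list(BASIC[name])
    tools = []
    for member in COMPOSITE.get(name, []):
        tools.extend(resolve(member))
    return tools


print(resolve("research"))  # ['web_search', 'web_extract', 'web_crawl', 'analyze_image']
```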
## Real-time Logging System

Monitor your agent's execution in real time with the FastAPI WebSocket endpoint, which uses a **persistent connection pool** architecture.

### Architecture

The logging system uses a **singleton WebSocket connection** that persists across multiple agent runs:
- ✅ **No timeouts** - the connection stays alive indefinitely
- ✅ **No reconnection overhead** - connect once, reuse forever
- ✅ **Parallel execution** - multiple agents share one connection
- ✅ **Production-ready** - graceful shutdown with signal handlers

See [`api_endpoint/PERSISTENT_CONNECTION_GUIDE.md`](api_endpoint/PERSISTENT_CONNECTION_GUIDE.md) for technical details.

### Features
- Track all API calls and responses
- **Persistent connection** - one WebSocket for all sessions
- Monitor tool executions with parameters and timing
- Capture errors and completion status
- REST API for querying sessions
- Real-time WebSocket broadcasting
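The singleton-connection pattern above can be sketched in plain Python. The `SharedConnection` class here is a stand-in for the real WebSocket client (whose API lives in `websocket_logger.py`); the point is that every caller gets the same process-wide instance:

```python
import threading
from typing import Optional


class SharedConnection:
    """Stand-in for a persistent WebSocket connection (illustrative only)."""

    def __init__(self) -> None:
        self.sends = 0

    def send(self, event: dict) -> None:
        self.sends += 1


_lock = threading.Lock()
_conn: Optional[SharedConnection] = None


def get_connection() -> SharedConnection:
    """Return the process-wide connection, creating it on first use."""
    global _conn
    with _lock:  # safe when agents run on multiple threads
        if _conn is None:
            _conn = SharedConnection()
        return _conn


a, b = get_connection(), get_connection()
a.send({"event_type": "query"})
b.send({"event_type": "response"})
print(a is b, a.sends)  # True 2
```

Because both "agents" hold the same object, all events flow through one connection — which is what makes the no-reconnection-overhead and parallel-sharing claims above work.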
### Documentation

See [`api_endpoint/README.md`](api_endpoint/README.md) for complete documentation.

### Quick Start

```bash
# Start API endpoint server
python api_endpoint/logging_server.py

# Run agent with logging
python run_agent.py --enable_websocket_logging --query "..."

# View logs
curl http://localhost:8000/sessions
```
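The `/sessions` listing is derived from the per-session JSON files the server writes under `logs/realtime/`. A sketch of that schema and of how the endpoint summarizes it — the field names match the server code, but the values here are made up:

```python
import json

# Example session record (values invented; field names match what the
# logging server writes to logs/realtime/session_<id>.json).
session = {
    "session_id": "demo-1234",
    "start_time": "2026-05-06T10:00:00",
    "end_time": None,  # still running
    "events": [
        {"type": "query", "query": "Search for the latest AI news"},
        {"type": "api_call", "call_number": 1},
    ],
    "metadata": {"model": "claude-sonnet-4-5-20250929"},
}

# The /sessions endpoint summarizes each file roughly like this:
summary = {
    "session_id": session["session_id"],
    "start_time": session["start_time"],
    "end_time": session["end_time"],
    "event_count": len(session["events"]),
}
print(json.dumps(summary))
```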
## Configuration

### Environment Variables

Create a `.env` file in the project root:

```bash
# API Keys
ANTHROPIC_API_KEY=your_key_here
FIRECRAWL_API_KEY=your_key_here
NOUS_API_KEY=your_key_here
FAL_KEY=your_key_here

# Optional
WEB_TOOLS_DEBUG=true  # Enable web tools debug logging
```
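The project loads this file with python-dotenv's `load_dotenv()`. As a rough illustration of what that does, here is a simplified stdlib-only parser (the real library additionally handles quoting, `export` prefixes, and interpolation):

```python
import os


def parse_env(text: str) -> dict:
    """Simplified KEY=value parsing, sketching what load_dotenv does."""
    values = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blanks and comments
        key, _, value = line.partition("=")
        # Drop trailing inline comments, as in `WEB_TOOLS_DEBUG=true  # debug`
        values[key.strip()] = value.split("#", 1)[0].strip()
    return values


env = parse_env("# API Keys\nANTHROPIC_API_KEY=your_key_here\nWEB_TOOLS_DEBUG=true  # debug\n")
os.environ.update(env)
print(env["WEB_TOOLS_DEBUG"])  # true
```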
### Command-Line Options

```bash
python run_agent.py --help
```

Key options:
- `--query`: Your question/task
- `--model`: Model to use (default: claude-sonnet-4-5-20250929)
- `--enabled_toolsets`: Toolsets to enable
- `--max_turns`: Maximum conversation turns
- `--enable_websocket_logging`: Enable real-time logging
- `--verbose`: Verbose debug output
- `--save_trajectories`: Save conversation trajectories
## Parallel Execution

The persistent connection pool enables true parallel agent execution: multiple agents can run simultaneously, all sharing the same WebSocket connection for logging.

### Test Parallel Execution

```bash
python test_parallel_execution.py
```

This script runs three tests:
1. **Sequential** - baseline (3 queries one after another)
2. **Parallel** - 3 queries simultaneously
3. **High Concurrency** - 10 queries simultaneously

**Expected Results:**
- ⚡ ~3x speedup with parallel execution
- ✅ All queries logged to the same connection
- ✅ No connection timeouts or errors
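The ~3x figure follows from overlapping the waits: three runs that mostly wait on network finish in roughly the time of one when run concurrently. A toy demonstration with `asyncio.sleep` standing in for model and tool latency:

```python
import asyncio
import time


async def fake_agent_run(query: str) -> str:
    await asyncio.sleep(0.1)  # stands in for model + tool latency
    return f"done: {query}"


async def main() -> float:
    start = time.perf_counter()
    # Three concurrent runs overlap their waits instead of adding them up.
    await asyncio.gather(*(fake_agent_run(f"q{i}") for i in range(3)))
    return time.perf_counter() - start


parallel = asyncio.run(main())
print(f"3 parallel runs took {parallel:.2f}s (sequential would be ~0.30s)")
```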
### Custom Parallel Code

```python
import asyncio
from run_agent import AIAgent


async def main():
    agent1 = AIAgent(enable_websocket_logging=True)
    agent2 = AIAgent(enable_websocket_logging=True)

    # Run in parallel - both use the shared connection
    results = await asyncio.gather(
        agent1.run_conversation("Query 1"),
        agent2.run_conversation("Query 2")
    )

asyncio.run(main())
```
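If your version of `run_conversation` is synchronous (the server code in `api_endpoint/logging_server.py` calls it without `await`), the calls can still overlap by pushing each one onto a worker thread with `asyncio.to_thread`. A sketch with a stand-in function in place of the real agent method:

```python
import asyncio


def run_conversation(query: str) -> str:
    """Stand-in for a blocking AIAgent.run_conversation call."""
    return f"answered: {query}"


async def main() -> list:
    # asyncio.to_thread runs each blocking call in the default executor,
    # so the two conversations overlap instead of running back to back.
    return await asyncio.gather(
        asyncio.to_thread(run_conversation, "Query 1"),
        asyncio.to_thread(run_conversation, "Query 2"),
    )


results = asyncio.run(main())
print(results)  # ['answered: Query 1', 'answered: Query 2']
```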
## Examples

### Investment Research
```bash
python run_agent.py \
  --enabled_toolsets web \
  --query "Find publicly traded companies in renewable energy"
```

### Code Analysis
```bash
python run_agent.py \
  --enabled_toolsets development \
  --query "Analyze the codebase and suggest improvements"
```

### Image Analysis
```bash
python run_agent.py \
  --enabled_toolsets vision \
  --query "Analyze this chart and explain the trends"
```
## Development

### Adding New Tools

1. Create the tool in the `tools/` directory
2. Register it in `model_tools.py`
3. Add it to the appropriate toolset in `toolsets.py`
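The exact registration API lives in `model_tools.py`; as a purely illustrative sketch, a tool is generally a callable plus a JSON schema describing it to the model (OpenAI-style function calling). Both names below are hypothetical:

```python
# Illustrative only: the real registration API lives in model_tools.py.
# A tool pairs a callable with a schema the model uses to decide when
# and how to call it.
def word_count(text: str) -> int:
    """Count whitespace-separated words."""
    return len(text.split())


WORD_COUNT_SCHEMA = {
    "type": "function",
    "function": {
        "name": "word_count",
        "description": "Count whitespace-separated words in a string.",
        "parameters": {
            "type": "object",
            "properties": {"text": {"type": "string"}},
            "required": ["text"],
        },
    },
}

print(word_count("Hermes Agent tool example"))  # 4
```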
### Running Tests

```bash
# Test web tools
python tests/test_web_tools.py

# Test API endpoint / logging
cd api_endpoint
./test_websocket_logging.sh
```
## License

MIT License - see the LICENSE file for details.

## Contributing

Contributions welcome! Please open an issue or PR.

## Support

For questions or issues:
1. Check the documentation in `api_endpoint/`
2. Review the example usage in this README
3. Open a GitHub issue

---

Built with ❤️ for advanced AI agent workflows
`api_endpoint/__init__.py` (new file, 26 lines)

```python
"""
Hermes Agent - API Endpoint & Real-time Logging

This package provides a FastAPI WebSocket endpoint for real-time logging of the Hermes Agent.

Components:
- logging_server: FastAPI server that receives and stores events
- websocket_logger: Client library for sending events from the agent

Usage:
    # Start the API endpoint server
    python api_endpoint/logging_server.py

    # Use in agent code
    from api_endpoint.websocket_logger import WebSocketLogger

For more information, see:
- WEBSOCKET_LOGGING_GUIDE.md - User guide
- IMPLEMENTATION_SUMMARY.md - Technical details
"""

from .websocket_logger import WebSocketLogger, SyncWebSocketLogger

__all__ = ['WebSocketLogger', 'SyncWebSocketLogger']
__version__ = '1.0.0'
```
`api_endpoint/logging_server.py` (new file, 603 lines)

```python
#!/usr/bin/env python3
"""
Hermes Agent - Real-time Logging Server

A FastAPI server with WebSocket support that listens for agent execution events
and logs them to JSON files in real-time.

Events tracked:
- User queries
- API calls (requests to the model)
- Assistant responses
- Tool calls (name, parameters, timing)
- Tool results (outputs, errors, duration)
- Final responses
- Session metadata

Usage:
    python logging_server.py

Or with uvicorn directly:
    uvicorn logging_server:app --host 0.0.0.0 --port 8000 --reload

The server will listen for WebSocket connections at ws://localhost:8000/ws
"""

import json
import asyncio
import threading
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List, Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, BackgroundTasks, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
from dotenv import load_dotenv

load_dotenv()


# Configuration
LOGS_DIR = Path(__file__).parent / "logs" / "realtime"
LOGS_DIR.mkdir(parents=True, exist_ok=True)

# Initialize FastAPI app
app = FastAPI(
    title="Hermes Agent API Endpoint",
    description="Manage interface between agent and user",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class SessionLogger:
    """
    Manages logging for a single agent session.

    Each agent execution gets its own SessionLogger instance.
    Responsible for:
    - Collecting all events for the session
    - Saving events to JSON file in real-time
    - Managing session lifecycle (start -> events -> finalize)
    """

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.start_time = datetime.now()
        self.events: List[Dict[str, Any]] = []  # In-memory list of all events
        self.log_file = LOGS_DIR / f"session_{session_id}.json"  # Where to save on disk

        # Initialize session data structure
        # This is what gets saved to the JSON file
        self.session_data = {
            "session_id": session_id,
            "start_time": self.start_time.isoformat(),
            "end_time": None,  # Set when session completes
            "events": [],      # Will be populated as events come in
            "metadata": {}     # Model, toolsets, etc. (set via session_start event)
        }

    def add_event(self, event: Dict[str, Any]):
        """
        Add an event to the session log.

        Called every time a new event arrives (query, api_call, tool_call, etc).
        IMMEDIATELY saves to file for real-time persistence.
        """
        # Add timestamp if not present (should always be added, but safety check)
        if "timestamp" not in event:
            event["timestamp"] = datetime.now().isoformat()

        # Add to in-memory event list
        self.events.append(event)
        self.session_data["events"] = self.events

        # CRITICAL: Save to file immediately (real-time logging)
        # This ensures events are persisted even if agent crashes
        self._save()

    def set_metadata(self, metadata: Dict[str, Any]):
        """Set session metadata (model, toolsets, etc.)."""
        self.session_data["metadata"].update(metadata)
        self._save()

    def finalize(self):
        """Finalize the session and save."""
        self.session_data["end_time"] = datetime.now().isoformat()
        self._save()

    def _save(self):
        """
        Save current session data to JSON file.

        Called after EVERY event is added - provides real-time persistence.
        If file writing fails, logs error but continues (doesn't crash server).
        """
        try:
            # Write complete session data to JSON file
            # indent=2 makes it human-readable
            # ensure_ascii=False preserves Unicode characters
            with open(self.log_file, 'w', encoding='utf-8') as f:
                json.dump(self.session_data, f, indent=2, ensure_ascii=False)
        except Exception as e:
            print(f"❌ Error saving session log: {e}")


class ConnectionManager:
    """
    Manages WebSocket connections and active sessions.

    Global singleton that:
    - Tracks all active WebSocket connections (for broadcasting)
    - Manages all SessionLogger instances (one per agent session)
    - Coordinates between WebSocket events and file logging
    """

    def __init__(self):
        self.active_connections: List[WebSocket] = []  # All connected WebSocket clients
        self.sessions: Dict[str, SessionLogger] = {}   # session_id -> SessionLogger mapping

    async def connect(self, websocket: WebSocket):
        """Accept a new WebSocket connection."""
        await websocket.accept()
        self.active_connections.append(websocket)
        print(f"✅ WebSocket connected. Active connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Remove a WebSocket connection."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        print(f"❌ WebSocket disconnected. Active connections: {len(self.active_connections)}")

    def get_or_create_session(self, session_id: str) -> SessionLogger:
        """
        Get existing session logger or create a new one.

        Called when an event arrives for a session. Creates SessionLogger
        on first event, reuses it for subsequent events from same session.
        """
        if session_id not in self.sessions:
            # First time seeing this session - create new logger
            self.sessions[session_id] = SessionLogger(session_id)
            print(f"📝 Created new session: {session_id}")
        return self.sessions[session_id]

    def finalize_session(self, session_id: str):
        """Finalize and clean up a session."""
        if session_id in self.sessions:
            self.sessions[session_id].finalize()
            print(f"✅ Session finalized: {session_id}")

    async def broadcast(self, message: Dict[str, Any]):
        """
        Broadcast a message to all connected WebSocket clients.

        Allows multiple clients (e.g., multiple browser tabs) to watch
        the same agent session in real-time. Future UI feature.
        """
        disconnected = []
        for connection in self.active_connections:
            try:
                await connection.send_json(message)
            except Exception:
                # Connection closed - mark for removal
                disconnected.append(connection)

        # Clean up disconnected clients silently
        for conn in disconnected:
            if conn in self.active_connections:
                self.active_connections.remove(conn)


# Global connection manager
manager = ConnectionManager()


# Request/Response models for API endpoints
class AgentRequest(BaseModel):
    """Request model for starting an agent run."""
    query: str
    model: str = "claude-sonnet-4-5-20250929"
    base_url: str = "https://api.anthropic.com/v1/"
    enabled_toolsets: Optional[List[str]] = None
    disabled_toolsets: Optional[List[str]] = None
    max_turns: int = 10
    mock_web_tools: bool = False
    mock_delay: int = 60
    verbose: bool = False


class AgentResponse(BaseModel):
    """Response model for agent run request."""
    status: str
    session_id: str
    message: str


@app.get("/")
async def root():
    """Root endpoint - server status."""
    return {
        "status": "running",
        "service": "Hermes Agent Logging Server",
        "websocket_url": "ws://localhost:8000/ws",
        "active_connections": len(manager.active_connections),
        "active_sessions": len(manager.sessions),
        "logs_directory": str(LOGS_DIR)
    }


@app.get("/sessions")
async def list_sessions():
    """List all active and recent sessions."""
    # Get all session log files
    session_files = list(LOGS_DIR.glob("session_*.json"))

    sessions = []
    for session_file in sorted(session_files, key=lambda x: x.stat().st_mtime, reverse=True):
        try:
            with open(session_file, 'r', encoding='utf-8') as f:
                session_data = json.load(f)
                sessions.append({
                    "session_id": session_data.get("session_id"),
                    "start_time": session_data.get("start_time"),
                    "end_time": session_data.get("end_time"),
                    "event_count": len(session_data.get("events", [])),
                    "file": str(session_file)
                })
        except Exception as e:
            print(f"⚠️ Error reading session file {session_file}: {e}")

    return {
        "total_sessions": len(sessions),
        "sessions": sessions
    }


@app.get("/sessions/{session_id}")
async def get_session(session_id: str):
    """Get detailed data for a specific session."""
    session_file = LOGS_DIR / f"session_{session_id}.json"

    if not session_file.exists():
        # Raise instead of returning a (body, status) tuple, which FastAPI
        # would serialize as a 200 response
        raise HTTPException(status_code=404, detail="Session not found")

    try:
        with open(session_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to load session: {e}")


@app.post("/agent/run", response_model=AgentResponse)
async def run_agent(request: AgentRequest, background_tasks: BackgroundTasks):
    """
    Start an agent run with specified parameters.

    This endpoint triggers an agent execution in the background and returns immediately.
    The agent will connect to the WebSocket endpoint to send real-time events.

    Args:
        request: AgentRequest with query and configuration
        background_tasks: FastAPI background tasks for async execution

    Returns:
        AgentResponse with session_id for tracking
    """
    import uuid
    import sys
    import os

    # Generate session ID for this run - we'll pass it to the agent
    session_id = str(uuid.uuid4())

    # Add parent directory to path to import run_agent
    parent_dir = str(Path(__file__).parent.parent)
    if parent_dir not in sys.path:
        sys.path.insert(0, parent_dir)

    from run_agent import AIAgent

    # Run agent in background thread (not blocking the API)
    def run_agent_background():
        """Run agent in a separate thread."""
        try:
            # Initialize agent with WebSocket logging enabled
            agent = AIAgent(
                base_url=request.base_url,
                model=request.model,
                api_key=os.getenv("ANTHROPIC_API_KEY"),
                max_iterations=request.max_turns,
                enabled_toolsets=request.enabled_toolsets,
                disabled_toolsets=request.disabled_toolsets,
                save_trajectories=False,
                verbose_logging=request.verbose,
                enable_websocket_logging=True,  # Always enable for UI
                websocket_server="ws://localhost:8000/ws",
                mock_web_tools=request.mock_web_tools,
                mock_delay=request.mock_delay
            )

            # Run conversation with our session_id
            result = agent.run_conversation(
                request.query,
                session_id=session_id  # Pass session_id so it matches
            )

            print(f"✅ Agent run completed: {session_id[:8]}...")
            print(f"   Final response: {result['final_response'][:100] if result.get('final_response') else 'No response'}...")

        except Exception as e:
            print(f"❌ Error running agent {session_id[:8]}...: {e}")
            import traceback
            traceback.print_exc()

    # Start agent in background thread
    thread = threading.Thread(target=run_agent_background, daemon=True)
    thread.start()

    return AgentResponse(
        status="started",
        session_id=session_id,
        message=f"Agent started with session ID: {session_id}"
    )


@app.get("/tools")
async def get_available_tools():
    """Get list of available toolsets and tools."""
    try:
        import sys
        parent_dir = str(Path(__file__).parent.parent)
        if parent_dir not in sys.path:
            sys.path.insert(0, parent_dir)

        from toolsets import get_all_toolsets, get_toolset_info

        all_toolsets = get_all_toolsets()
        toolsets_info = []

        for name in all_toolsets.keys():
            info = get_toolset_info(name)
            if info:
                toolsets_info.append({
                    "name": name,
                    "description": info['description'],
                    "tool_count": info['tool_count'],
                    "resolved_tools": info['resolved_tools']
                })

        return {
            "toolsets": toolsets_info
        }
    except Exception as e:
        return {"error": f"Failed to load tools: {str(e)}"}


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for receiving real-time agent events.

    This is the main entry point for all logging. Agents connect here and send events.

    Message Flow:
    1. Agent connects to ws://localhost:8000/ws
    2. Agent sends events as JSON messages
    3. Server parses event_type and routes to appropriate handler
    4. Event is added to SessionLogger (saved to file)
    5. Event is broadcast to all connected clients
    6. Acknowledgment sent back to agent

    Expected message format:
    {
        "session_id": "unique-session-id",        // Links event to specific session
        "event_type": "query" | "api_call" | ..., // What kind of event
        "data": { ... event-specific data ... }   // Event payload
    }
    """
    # Accept the WebSocket connection
    await manager.connect(websocket)

    try:
        # Main event loop - runs until client disconnects
        while True:
            # Receive message from client (agent)
            # This is a blocking call - waits for next message
            message = await websocket.receive_json()

            # Parse the standard message structure
            session_id = message.get("session_id")  # Which agent session
            event_type = message.get("event_type")  # What kind of event
            data = message.get("data", {})          # Event payload

            # Validate: session_id is required
            if not session_id:
                await websocket.send_json({
                    "error": "session_id is required"
                })
                continue

            # Get or create SessionLogger for this session
            # First event creates it, subsequent events reuse it
            session = manager.get_or_create_session(session_id)

            # Route event to appropriate handler based on event_type
            # Each handler extracts relevant data and adds to session log

            if event_type == "session_start":
                # Initial event - sent when agent first connects
                # Contains metadata about the session (model, toolsets, etc.)
                session.set_metadata(data)
                print(f"🚀 Session started: {session_id}")

            elif event_type == "query":
                # User query
                session.add_event({
                    "type": "query",
                    "query": data.get("query"),
                    "toolsets": data.get("toolsets"),
                    "model": data.get("model")
                })
                print(f"📝 Query logged: {data.get('query', '')[:60]}...")

            elif event_type == "api_call":
                # API call to model
                session.add_event({
                    "type": "api_call",
                    "call_number": data.get("call_number"),
                    "message_count": data.get("message_count"),
                    "has_tools": data.get("has_tools")
                })
                print(f"🔄 API call #{data.get('call_number')} logged")

            elif event_type == "response":
                # Assistant response
                session.add_event({
                    "type": "response",
                    "call_number": data.get("call_number"),
                    "content": data.get("content"),
                    "has_tool_calls": data.get("has_tool_calls"),
                    "tool_call_count": data.get("tool_call_count"),
                    "duration": data.get("duration")
                })
                print(f"🤖 Response logged: {data.get('content', '')[:60]}...")

            elif event_type == "tool_call":
                # Tool execution
                session.add_event({
                    "type": "tool_call",
                    "call_number": data.get("call_number"),
                    "tool_index": data.get("tool_index"),
                    "tool_name": data.get("tool_name"),
                    "parameters": data.get("parameters"),
                    "tool_call_id": data.get("tool_call_id")
                })
                print(f"🔧 Tool call logged: {data.get('tool_name')}")

            elif event_type == "tool_result":
                # Tool result - captures output from tool execution
                # Now includes BOTH truncated preview AND full raw result
                session.add_event({
                    "type": "tool_result",
                    "call_number": data.get("call_number"),
                    "tool_index": data.get("tool_index"),
                    "tool_name": data.get("tool_name"),
                    "result": data.get("result"),          # Truncated preview (1000 chars)
                    "raw_result": data.get("raw_result"),  # NEW: Full untruncated result
                    "error": data.get("error"),
                    "duration": data.get("duration"),
                    "tool_call_id": data.get("tool_call_id")
                })

                # Enhanced logging with size information
                if data.get("error"):
                    print(f"❌ Tool error logged: {data.get('tool_name')}")
                else:
                    # Show size of raw result to indicate data volume
                    raw_size = len(data.get("raw_result", "")) if data.get("raw_result") else len(data.get("result", ""))
                    size_kb = raw_size / 1024
                    print(f"✅ Tool result logged: {data.get('tool_name')} ({size_kb:.1f} KB)")

            elif event_type == "error":
                # Error event
                session.add_event({
                    "type": "error",
                    "error_message": data.get("error_message"),
                    "call_number": data.get("call_number")
                })
                print(f"❌ Error logged: {data.get('error_message', '')[:60]}...")

            elif event_type == "complete":
                # Session complete
```
|
||||||
|
session.add_event({
|
||||||
|
"type": "complete",
|
||||||
|
"final_response": data.get("final_response"),
|
||||||
|
"total_calls": data.get("total_calls"),
|
||||||
|
"completed": data.get("completed")
|
||||||
|
})
|
||||||
|
manager.finalize_session(session_id)
|
||||||
|
print(f"🎉 Session complete: {session_id}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Unknown event type - log it anyway
|
||||||
|
session.add_event({
|
||||||
|
"type": event_type or "unknown",
|
||||||
|
**data
|
||||||
|
})
|
||||||
|
print(f"⚠️ Unknown event type: {event_type}")
|
||||||
|
|
||||||
|
# Broadcast event to all connected clients (for future real-time UI)
|
||||||
|
# Allows multiple browsers/dashboards to watch same session live
|
||||||
|
await manager.broadcast({
|
||||||
|
"session_id": session_id,
|
||||||
|
"event_type": event_type,
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"data": data
|
||||||
|
})
|
||||||
|
|
||||||
|
# Send acknowledgment back to sender
|
||||||
|
# Confirms event was received and logged
|
||||||
|
# Handle case where client disconnects before we can ack
|
||||||
|
try:
|
||||||
|
await websocket.send_json({
|
||||||
|
"status": "logged",
|
||||||
|
"session_id": session_id,
|
||||||
|
"event_type": event_type
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
|
# Connection closed before ack - this is normal for "complete" event
|
||||||
|
# Client disconnects after sending, so we can't ack
|
||||||
|
pass
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
manager.disconnect(websocket)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ WebSocket error: {e}")
|
||||||
|
manager.disconnect(websocket)
|
||||||
|
|
||||||
|
|
||||||
|
def main(host: str = "0.0.0.0", port: int = 8000, reload: bool = False):
|
||||||
|
"""
|
||||||
|
Start the logging server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host: Host to bind to (default: 0.0.0.0)
|
||||||
|
port: Port to run on (default: 8000)
|
||||||
|
reload: Enable auto-reload on file changes (default: False)
|
||||||
|
"""
|
||||||
|
print("🚀 Hermes Agent Logging Server")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"📂 Logs directory: {LOGS_DIR}")
|
||||||
|
print(f"🌐 Server starting at http://{host}:{port}")
|
||||||
|
print(f"🔌 WebSocket endpoint: ws://{host}:{port}/ws")
|
||||||
|
print(f"🔄 Auto-reload: {'enabled' if reload else 'disabled'}")
|
||||||
|
print("\n📡 Ready to receive agent events...")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
uvicorn.run(
|
||||||
|
"logging_server:app",
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
reload=reload,
|
||||||
|
log_level="info",
|
||||||
|
timeout_keep_alive=600 # Keep HTTP/WS connections alive for 10 minutes of inactivity
|
||||||
|
# Note: WebSocket ping/pong disabled in client to avoid timeout during blocked event loop
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import fire
|
||||||
|
fire.Fire(main)
|
||||||
|
|
||||||
91 api_endpoint/test_websocket_logging.sh Executable file
@@ -0,0 +1,91 @@
#!/bin/bash
# Test script for WebSocket logging system
#
# This script demonstrates the complete WebSocket logging workflow:
# 1. Starts the logging server
# 2. Runs the agent with WebSocket logging enabled
# 3. Shows the logged data
#
# Usage: ./test_websocket_logging.sh

set -e  # Exit on error

echo "🧪 Testing WebSocket Logging System"
echo "===================================="
echo ""

# Check if required packages are installed
echo "📦 Checking dependencies..."
python -c "import fastapi; import uvicorn; import websockets" 2>/dev/null || {
    echo "❌ Missing dependencies. Installing..."
    pip install fastapi uvicorn websockets
}
echo "✅ Dependencies OK"
echo ""

# Start the logging server in the background
echo "🚀 Starting logging server..."
python api_endpoint/logging_server.py --port 8000 &
SERVER_PID=$!

# Give server time to start
sleep 2

# Check if server is running
if ps -p $SERVER_PID > /dev/null; then
    echo "✅ Logging server started (PID: $SERVER_PID)"
else
    echo "❌ Failed to start logging server"
    exit 1
fi

echo ""
echo "🤖 Running agent with WebSocket logging..."
echo ""

# Run the agent with WebSocket logging
python run_agent.py \
    --enabled_toolsets web \
    --enable_websocket_logging \
    --query "What are the top 3 programming languages in 2025?" \
    --max_turns 5

echo ""
echo "✅ Agent execution complete!"
echo ""

# Show the most recent log file
echo "📊 Viewing logged session data..."
echo ""

LATEST_LOG=$(ls -t logs/realtime/session_*.json 2>/dev/null | head -1)

if [ -f "$LATEST_LOG" ]; then
    echo "📄 Log file: $LATEST_LOG"
    echo ""

    # Pretty print the JSON if jq is available
    if command -v jq &> /dev/null; then
        echo "Event summary:"
        jq '.events[] | {type: .type, timestamp: .timestamp}' "$LATEST_LOG"
        echo ""
        echo "Total events: $(jq '.events | length' "$LATEST_LOG")"
    else
        echo "Content (install 'jq' for pretty printing):"
        cat "$LATEST_LOG"
    fi
else
    echo "⚠️ No log files found in logs/realtime/"
fi

echo ""
echo "🛑 Stopping logging server..."
kill $SERVER_PID 2>/dev/null || true

echo "✅ Test complete!"
echo ""
echo "Next steps:"
echo " 1. Start server: python api_endpoint/logging_server.py"
echo " 2. Run agent: python run_agent.py --enable_websocket_logging --query \"...\""
echo " 3. View logs: http://localhost:8000/sessions"
457 api_endpoint/websocket_connection_pool.py Normal file
@@ -0,0 +1,457 @@
"""
|
||||||
|
WebSocket Connection Pool - Persistent Connection Manager
|
||||||
|
|
||||||
|
This module provides a singleton WebSocket connection that persists across
|
||||||
|
multiple agent runs. This is a more robust architecture than creating a new
|
||||||
|
connection for each run.
|
||||||
|
|
||||||
|
Benefits:
|
||||||
|
- No timeout issues (connection stays alive indefinitely)
|
||||||
|
- No reconnection overhead (connect once)
|
||||||
|
- Supports parallel agent runs (multiple sessions share one socket)
|
||||||
|
- Proper shutdown handling (SIGTERM/SIGINT)
|
||||||
|
- Thread-safe concurrent sends
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import signal
|
||||||
|
import websockets
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
import json
|
||||||
|
import atexit
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketConnectionPool:
|
||||||
|
"""
|
||||||
|
Singleton WebSocket connection manager.
|
||||||
|
|
||||||
|
Maintains a single persistent connection to the logging server
|
||||||
|
that all agent sessions can use. Handles graceful shutdown.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Get singleton instance
|
||||||
|
pool = WebSocketConnectionPool()
|
||||||
|
|
||||||
|
# Connect (idempotent - safe to call multiple times)
|
||||||
|
await pool.connect()
|
||||||
|
|
||||||
|
# Send events (thread-safe, multiple sessions can call concurrently)
|
||||||
|
await pool.send_event("query", session_id, {...})
|
||||||
|
|
||||||
|
# Shutdown handled automatically on SIGTERM/SIGINT
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance: Optional['WebSocketConnectionPool'] = None
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
"""Ensure only one instance exists (singleton pattern)."""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._initialized = False
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the connection pool (only once)."""
|
||||||
|
if getattr(self, '_initialized', False):
|
||||||
|
return
|
||||||
|
|
||||||
|
self.websocket: Optional[websockets.WebSocketClientProtocol] = None
|
||||||
|
self.server_url: str = "ws://localhost:8000/ws"
|
||||||
|
self.connected: bool = False
|
||||||
|
# Store reference to loop for signal handlers
|
||||||
|
# Agent code should never close event loops when using persistent connections
|
||||||
|
self.loop: Optional[asyncio.AbstractEventLoop] = None
|
||||||
|
# Locks are created lazily when event loop exists
|
||||||
|
self._send_lock: Optional[asyncio.Lock] = None
|
||||||
|
self._connect_lock: Optional[asyncio.Lock] = None
|
||||||
|
self._locks_loop: Optional[asyncio.AbstractEventLoop] = None # Track which loop created locks
|
||||||
|
self._init_lock = threading.Lock() # Thread-safe lock initialization
|
||||||
|
self._shutdown_in_progress = False
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
# Register shutdown handlers for graceful cleanup
|
||||||
|
# These ensure WebSocket is closed properly on exit
|
||||||
|
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, self._signal_handler)
|
||||||
|
atexit.register(self._cleanup_sync)
|
||||||
|
|
||||||
|
print("🔌 WebSocket connection pool initialized")
|
||||||
|
|
||||||
|
def _ensure_locks(self):
|
||||||
|
"""
|
||||||
|
Lazy initialization of asyncio locks with thread safety and loop tracking.
|
||||||
|
|
||||||
|
Locks must be created when an event loop exists, not at import time.
|
||||||
|
If the event loop changes between runs, locks must be recreated because
|
||||||
|
asyncio.Lock objects are bound to the loop that created them.
|
||||||
|
|
||||||
|
This is called before any async operation that needs locks.
|
||||||
|
Uses a threading.Lock to prevent race conditions during initialization.
|
||||||
|
"""
|
||||||
|
with self._init_lock: # Thread-safe initialization
|
||||||
|
try:
|
||||||
|
current_loop = asyncio.get_event_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
# No event loop in current thread
|
||||||
|
return
|
||||||
|
|
||||||
|
# Recreate locks if:
|
||||||
|
# 1. Locks don't exist yet, OR
|
||||||
|
# 2. Event loop has changed (locks are bound to the loop that created them)
|
||||||
|
if self._locks_loop is not current_loop or self._send_lock is None:
|
||||||
|
self._send_lock = asyncio.Lock()
|
||||||
|
self._connect_lock = asyncio.Lock()
|
||||||
|
self._locks_loop = current_loop
|
||||||
|
|
||||||
|
async def connect(self, server_url: str = "ws://localhost:8000/ws") -> bool:
|
||||||
|
"""
|
||||||
|
Connect to WebSocket server.
|
||||||
|
|
||||||
|
This is idempotent - safe to call multiple times. If already connected,
|
||||||
|
does nothing. If connection failed previously, will retry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_url: WebSocket server URL (default: ws://localhost:8000/ws)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if connected successfully, False otherwise
|
||||||
|
"""
|
||||||
|
# Ensure locks exist (lazy initialization)
|
||||||
|
self._ensure_locks()
|
||||||
|
|
||||||
|
async with self._connect_lock:
|
||||||
|
# Always update loop reference to current loop (even if already connected)
|
||||||
|
# This ensures signal handlers and cleanup use the correct loop
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
# Already connected - nothing to do
|
||||||
|
if self.connected and self.websocket:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.server_url = server_url
|
||||||
|
|
||||||
|
# Establish persistent WebSocket connection
|
||||||
|
# No ping/pong needed since connection stays open indefinitely
|
||||||
|
self.websocket = await websockets.connect(
|
||||||
|
server_url,
|
||||||
|
ping_interval=None, # Disable ping/pong (not needed for persistent connection)
|
||||||
|
max_size=10 * 1024 * 1024, # 10MB max message size for large tool results
|
||||||
|
open_timeout=10, # 10s timeout for initial connection
|
||||||
|
close_timeout=5 # 5s timeout for close handshake
|
||||||
|
)
|
||||||
|
|
||||||
|
self.connected = True
|
||||||
|
|
||||||
|
print(f"✅ Connected to logging server (persistent): {server_url}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Failed to connect to logging server: {e}")
|
||||||
|
self.connected = False
|
||||||
|
self.websocket = None
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def send_event(
|
||||||
|
self,
|
||||||
|
event_type: str,
|
||||||
|
session_id: str,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
retry: bool = True
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Send event to logging server (thread-safe).
|
||||||
|
|
||||||
|
Multiple agent runs can call this concurrently. The send lock ensures
|
||||||
|
only one message is sent at a time (WebSocket protocol requirement).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_type: Type of event (query, api_call, response, tool_call, tool_result, error, complete)
|
||||||
|
session_id: Unique session identifier
|
||||||
|
data: Event-specific data dictionary
|
||||||
|
retry: Whether to retry connection if disconnected (default: True)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if sent successfully, False otherwise
|
||||||
|
"""
|
||||||
|
# Try to connect if not connected (or reconnect if disconnected)
|
||||||
|
if not self.connected or not self.websocket:
|
||||||
|
if retry:
|
||||||
|
await self.connect()
|
||||||
|
if not self.connected:
|
||||||
|
return False # Give up if connection fails
|
||||||
|
|
||||||
|
# Ensure locks exist (lazy initialization)
|
||||||
|
self._ensure_locks()
|
||||||
|
|
||||||
|
# Lock to prevent concurrent sends (WebSocket requires sequential sends)
|
||||||
|
async with self._send_lock:
|
||||||
|
try:
|
||||||
|
# Create standardized message format
|
||||||
|
message = {
|
||||||
|
"session_id": session_id,
|
||||||
|
"event_type": event_type,
|
||||||
|
"data": data,
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Send message as JSON
|
||||||
|
await self.websocket.send(json.dumps(message))
|
||||||
|
|
||||||
|
# Wait for server acknowledgment (with timeout)
|
||||||
|
# This confirms the server received and processed the event
|
||||||
|
try:
|
||||||
|
response = await asyncio.wait_for(
|
||||||
|
self.websocket.recv(),
|
||||||
|
timeout=2.0 # Increased to 2s for busy servers
|
||||||
|
)
|
||||||
|
# Successfully received acknowledgment
|
||||||
|
return True
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# No response within timeout - that's OK, message likely sent
|
||||||
|
# Server might be busy processing
|
||||||
|
return True
|
||||||
|
|
||||||
|
except websockets.exceptions.ConnectionClosed:
|
||||||
|
print(f"⚠️ WebSocket connection closed unexpectedly")
|
||||||
|
self.connected = False
|
||||||
|
|
||||||
|
# Try to reconnect and resend (one retry)
|
||||||
|
if retry:
|
||||||
|
print("🔄 Attempting to reconnect...")
|
||||||
|
if await self.connect():
|
||||||
|
# Recursively call with retry=False to avoid infinite loop
|
||||||
|
return await self.send_event(event_type, session_id, data, retry=False)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Error sending event: {e}")
|
||||||
|
self.connected = False
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
"""
|
||||||
|
Gracefully close the WebSocket connection.
|
||||||
|
|
||||||
|
Called on shutdown (SIGTERM/SIGINT/exit). Ensures proper cleanup.
|
||||||
|
"""
|
||||||
|
if self._shutdown_in_progress:
|
||||||
|
return # Already shutting down
|
||||||
|
|
||||||
|
self._shutdown_in_progress = True
|
||||||
|
|
||||||
|
if self.websocket and self.connected:
|
||||||
|
try:
|
||||||
|
await self.websocket.close()
|
||||||
|
self.connected = False
|
||||||
|
print("✅ WebSocket connection pool closed gracefully")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Error closing WebSocket: {e}")
|
||||||
|
|
||||||
|
self._shutdown_in_progress = False
|
||||||
|
|
||||||
|
def _signal_handler(self, signum, frame):
|
||||||
|
"""
|
||||||
|
Handle SIGTERM/SIGINT signals for graceful shutdown.
|
||||||
|
|
||||||
|
When user presses Ctrl+C or system sends SIGTERM, this ensures
|
||||||
|
the WebSocket is closed properly before exit.
|
||||||
|
"""
|
||||||
|
print(f"\n🛑 Received signal {signum}, closing WebSocket connection pool...")
|
||||||
|
|
||||||
|
# Check if we have a valid loop and are connected
|
||||||
|
if self.loop and not self.loop.is_closed() and self.connected and not self._shutdown_in_progress:
|
||||||
|
try:
|
||||||
|
# If loop is not running, we can wait for disconnect
|
||||||
|
if not self.loop.is_running():
|
||||||
|
self.loop.run_until_complete(self.disconnect())
|
||||||
|
else:
|
||||||
|
# Loop is running, can't wait for task - just mark disconnected
|
||||||
|
# The disconnect task would be cancelled when we exit anyway
|
||||||
|
self.connected = False
|
||||||
|
print("⚠️ Loop is running, marking disconnected without waiting")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Error during signal handler cleanup: {e}")
|
||||||
|
|
||||||
|
# Exit gracefully
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
def _cleanup_sync(self):
|
||||||
|
"""
|
||||||
|
Cleanup at exit (atexit handler).
|
||||||
|
|
||||||
|
This is a fallback in case signal handlers don't fire.
|
||||||
|
Called when Python interpreter shuts down normally.
|
||||||
|
"""
|
||||||
|
if self.loop and not self.loop.is_closed() and self.connected and not self._shutdown_in_progress:
|
||||||
|
try:
|
||||||
|
# Try to run disconnect synchronously
|
||||||
|
self.loop.run_until_complete(self.disconnect())
|
||||||
|
except Exception:
|
||||||
|
# Ignore errors during exit cleanup
|
||||||
|
pass
|
||||||
|
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""Check if currently connected to server."""
|
||||||
|
return self.connected and self.websocket is not None
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get connection statistics for debugging."""
|
||||||
|
return {
|
||||||
|
"connected": self.connected,
|
||||||
|
"server_url": self.server_url,
|
||||||
|
"shutdown_in_progress": self._shutdown_in_progress,
|
||||||
|
"has_websocket": self.websocket is not None,
|
||||||
|
"has_loop": self.loop is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Global singleton instance
|
||||||
|
# Import this in other modules: from websocket_connection_pool import ws_pool
|
||||||
|
ws_pool = WebSocketConnectionPool()
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience functions for direct usage
|
||||||
|
async def connect(server_url: str = "ws://localhost:8000/ws") -> bool:
|
||||||
|
"""Connect to logging server (convenience function)."""
|
||||||
|
return await ws_pool.connect(server_url)
|
||||||
|
|
||||||
|
|
||||||
|
async def send_event(event_type: str, session_id: str, data: Dict[str, Any]) -> bool:
|
||||||
|
"""Send event to logging server (convenience function)."""
|
||||||
|
return await ws_pool.send_event(event_type, session_id, data)
|
||||||
|
|
||||||
|
|
||||||
|
async def disconnect():
|
||||||
|
"""Disconnect from logging server (convenience function)."""
|
||||||
|
await ws_pool.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
def is_connected() -> bool:
|
||||||
|
"""Check if connected to logging server (convenience function)."""
|
||||||
|
return ws_pool.is_connected()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# SYNCHRONOUS API FOR AGENT LAYER
|
||||||
|
# ============================================================================
|
||||||
|
# These functions provide a clean abstraction that hides event loop management
|
||||||
|
# from the agent layer. Agent code should ONLY use these functions.
|
||||||
|
|
||||||
|
def connect_sync(server_url: str = "ws://localhost:8000/ws") -> bool:
|
||||||
|
"""
|
||||||
|
Synchronous connect - handles event loop internally.
|
||||||
|
|
||||||
|
Creates a persistent event loop in a background thread if needed.
|
||||||
|
This is thread-safe and can be called from any thread (including agent background threads).
|
||||||
|
"""
|
||||||
|
import threading
|
||||||
|
|
||||||
|
# If pool doesn't have a loop yet or it's closed, we need to start one
|
||||||
|
if not ws_pool.loop or ws_pool.loop.is_closed():
|
||||||
|
# Start connection in a background thread with its own loop
|
||||||
|
result_container = {"success": False, "error": None, "connected": False}
|
||||||
|
|
||||||
|
def run_in_thread():
|
||||||
|
try:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
ws_pool.loop = loop # Store the loop in the pool
|
||||||
|
|
||||||
|
# Connect to WebSocket
|
||||||
|
result_container["success"] = loop.run_until_complete(ws_pool.connect(server_url))
|
||||||
|
result_container["connected"] = True
|
||||||
|
|
||||||
|
# Keep loop running forever for future send_event calls
|
||||||
|
# This is critical - the loop must stay alive for run_coroutine_threadsafe to work
|
||||||
|
loop.run_forever()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result_container["error"] = str(e)
|
||||||
|
print(f"❌ Error in WebSocket connection thread: {e}")
|
||||||
|
finally:
|
||||||
|
# Clean up if loop stops
|
||||||
|
if loop.is_running():
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
thread = threading.Thread(target=run_in_thread, daemon=True, name="WebSocket-EventLoop")
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
# Wait for connection to complete (but not for loop to exit - it runs forever)
|
||||||
|
import time
|
||||||
|
timeout = 10.0
|
||||||
|
start = time.time()
|
||||||
|
while not result_container["connected"] and (time.time() - start) < timeout:
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
if result_container["error"]:
|
||||||
|
print(f"⚠️ Connection failed: {result_container['error']}")
|
||||||
|
|
||||||
|
return result_container["success"]
|
||||||
|
else:
|
||||||
|
# Pool already has a loop, use run_coroutine_threadsafe
|
||||||
|
try:
|
||||||
|
future = asyncio.run_coroutine_threadsafe(
|
||||||
|
ws_pool.connect(server_url),
|
||||||
|
ws_pool.loop
|
||||||
|
)
|
||||||
|
return future.result(timeout=10.0)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Connection failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def send_event_sync(event_type: str, session_id: str, data: Dict[str, Any]) -> bool:
|
||||||
|
"""
|
||||||
|
Synchronous send event - handles event loop internally.
|
||||||
|
|
||||||
|
Uses the WebSocket pool's own event loop to avoid loop conflicts.
|
||||||
|
This is critical when called from background threads (like agent execution).
|
||||||
|
This is thread-safe and works correctly even when agent runs in a different thread.
|
||||||
|
"""
|
||||||
|
if not ws_pool.loop or not ws_pool.loop.is_running():
|
||||||
|
# No event loop running - can't send
|
||||||
|
print("⚠️ WebSocket pool has no running event loop")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use run_coroutine_threadsafe to submit to the WebSocket pool's loop
|
||||||
|
# This works across threads - submits the coroutine to the correct loop
|
||||||
|
future = asyncio.run_coroutine_threadsafe(
|
||||||
|
ws_pool.send_event(event_type, session_id, data),
|
||||||
|
ws_pool.loop # ← Use the pool's loop, not current thread's loop
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for completion (with timeout to avoid hanging)
|
||||||
|
return future.result(timeout=5.0)
|
||||||
|
|
||||||
|
except TimeoutError:
|
||||||
|
print(f"⚠️ Timeout sending event {event_type}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Error sending event: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def disconnect_sync():
|
||||||
|
"""
|
||||||
|
Synchronous disconnect - handles event loop internally.
|
||||||
|
|
||||||
|
Thread-safe disconnect that works from any thread.
|
||||||
|
"""
|
||||||
|
if ws_pool.loop and ws_pool.loop.is_running():
|
||||||
|
try:
|
||||||
|
future = asyncio.run_coroutine_threadsafe(
|
||||||
|
ws_pool.disconnect(),
|
||||||
|
ws_pool.loop
|
||||||
|
)
|
||||||
|
return future.result(timeout=5.0)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Error disconnecting: {e}")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
387 api_endpoint/websocket_logger.py Normal file
@@ -0,0 +1,387 @@
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
WebSocket Logger Client
|
||||||
|
|
||||||
|
Simple client for sending agent events to the logging server via WebSocket.
|
||||||
|
Used by the agent to log events in real-time during execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketLogger:
|
||||||
|
"""
|
||||||
|
Client for logging agent events via WebSocket.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
logger = WebSocketLogger("unique-session-id")
|
||||||
|
await logger.connect()
|
||||||
|
await logger.log_query("What is Python?", model="gpt-4")
|
||||||
|
await logger.log_api_call(call_number=1)
|
||||||
|
await logger.log_response(call_number=1, content="Python is...")
|
||||||
|
await logger.disconnect()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
server_url: str = "ws://localhost:8000/ws",
|
||||||
|
enabled: bool = True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize WebSocket logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: Unique identifier for this agent session
|
||||||
|
server_url: WebSocket server URL (default: ws://localhost:8000/ws)
|
||||||
|
enabled: Whether logging is enabled (default: True)
|
||||||
|
"""
|
||||||
|
self.session_id = session_id
|
||||||
|
self.server_url = server_url
|
||||||
|
self.enabled = enabled
|
||||||
|
self.websocket: Optional[websockets.WebSocketClientProtocol] = None
|
||||||
|
self.connected = False
|
||||||
|
self.reconnect_count = 0 # Track reconnections for debugging
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
"""
|
||||||
|
Connect to the WebSocket logging server.
|
||||||
|
|
||||||
|
Establishes WebSocket connection and sends initial session_start event.
|
||||||
|
If connection fails, gracefully disables logging (agent continues normally).
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Establish WebSocket connection to the server
|
||||||
|
# Use VERY LONG ping intervals to avoid timeout during long tool execution
|
||||||
|
# The event loop is blocked during tool execution, so we can't process pings
|
||||||
|
# Setting to very large values (1 hour) effectively disables it
|
||||||
|
self.websocket = await websockets.connect(
|
||||||
|
self.server_url,
|
||||||
|
ping_interval=3600, # 1 hour - effectively disabled (event loop blocked anyway)
|
||||||
|
ping_timeout=3600, # 1 hour timeout for pong response
|
||||||
|
close_timeout=10, # Timeout for close handshake
|
||||||
|
max_size=10 * 1024 * 1024, # 10MB max message size (for large raw_results)
|
||||||
|
open_timeout=10 # Timeout for initial connection
|
||||||
|
)
|
||||||
|
self.connected = True
|
||||||
|
print(f"✅ Connected to logging server (ping/pong: 3600s intervals): {self.server_url}")
|
||||||
|
|
||||||
|
# Send initial session_start event
|
||||||
|
# This tells the server to create a new SessionLogger for this session
|
||||||
|
await self._send_event("session_start", {
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"start_time": datetime.now().isoformat()
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Connection failed - disable logging but don't crash the agent
|
||||||
|
print(f"⚠️ Failed to connect to logging server: {e}")
|
||||||
|
print(f" Logging will be disabled for this session.")
|
||||||
|
self.enabled = False
|
||||||
|
self.connected = False
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
"""Disconnect from the WebSocket server."""
|
||||||
|
if self.websocket and self.connected:
|
||||||
|
try:
|
||||||
|
await self.websocket.close()
|
||||||
|
self.connected = False
|
||||||
|
print(f"✅ Disconnected from logging server")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Error disconnecting: {e}")
|
||||||
|
|
||||||
|
    async def _send_event(self, event_type: str, data: Dict[str, Any]):
        """
        Send an event to the logging server.

        This is the core method that sends all events via WebSocket.
        It creates a standardized message format and handles acknowledgments.

        Args:
            event_type: Type of event (query, api_call, response, tool_call, tool_result, error, complete)
            data: Event data dictionary containing event-specific information
        """
        # Safety check: don't send if logging is disabled
        if not self.enabled:
            return

        # Auto-reconnect if the connection was lost
        if not self.connected or not self.websocket:
            try:
                self.reconnect_count += 1
                print(f"🔄 Reconnecting to logging server (attempt #{self.reconnect_count})...")
                await self.connect()
                print("✅ Reconnected successfully!")
            except Exception as e:
                print(f"⚠️ Failed to reconnect: {e}")
                self.enabled = False  # Disable logging after a failed reconnect
                return

        try:
            # Create the standardized message structure
            # All events follow this format for consistent server-side handling
            message = {
                "session_id": self.session_id,  # Links the event to a specific agent session
                "event_type": event_type,       # Identifies what kind of event this is
                "data": data                    # Event-specific payload
            }

            # Send the message as a JSON string over the WebSocket
            await self.websocket.send(json.dumps(message))

            # Wait for server acknowledgment (with a 1 second timeout)
            # This ensures the server received and processed the event
            try:
                await asyncio.wait_for(
                    self.websocket.recv(),
                    timeout=1.0
                )
                # Server sends back: {"status": "logged", "session_id": "...", "event_type": "..."}
                # We don't need to process it; it just confirms receipt
            except asyncio.TimeoutError:
                # No response within 1 second - that's okay, continue anyway
                # The server might be busy or the network slow, but the event was likely sent
                pass

        except Exception as e:
            # Log the error but don't crash - graceful degradation:
            # the agent should keep working even if logging fails
            error_str = str(e)

            # Check whether the connection was closed (error 1011 = keepalive ping timeout)
            if "1011" in error_str or "closed" in error_str.lower():
                print(f"⚠️ WebSocket connection closed: {error_str}")
                self.connected = False  # Mark as disconnected
                # Don't try to send more events - the connection is dead
            else:
                print(f"⚠️ Error sending event to logging server: {e}")
                # Don't disable entirely or try to reconnect - just continue with logging disabled

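Every logging helper below funnels through `_send_event`, so the on-wire format is just this three-key JSON envelope. A minimal standalone sketch of that envelope (the helper name `build_event_message` is illustrative, not part of the module):

```python
import json


def build_event_message(session_id: str, event_type: str, data: dict) -> str:
    """Build the JSON envelope that _send_event puts on the wire."""
    message = {
        "session_id": session_id,  # links the event to a session
        "event_type": event_type,  # e.g. "query", "tool_call", "complete"
        "data": data,              # event-specific payload
    }
    return json.dumps(message)


# Round-trip one event the way the server would decode it
wire = build_event_message("abc123", "query", {"query": "What is 2+2?"})
decoded = json.loads(wire)
```

Because the envelope is the same for every event type, the server can dispatch purely on `event_type` without knowing the payload schema in advance.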
    # Convenience methods for specific event types

    async def log_query(
        self,
        query: str,
        model: str = None,
        toolsets: list = None
    ):
        """
        Log a user query (the question/task given to the agent).

        This is typically the first event in a session after connection.
        Captures what the user asked and which model/tools will be used.
        """
        await self._send_event("query", {
            "query": query,        # The user's question/instruction
            "model": model,        # Which AI model is being used
            "toolsets": toolsets   # Which tool categories are enabled
        })

    async def log_api_call(
        self,
        call_number: int,
        message_count: int = None,
        has_tools: bool = None
    ):
        """
        Log an API call to the AI model.

        Called right before sending a request to the model (OpenAI/Anthropic/etc.).
        Helps track how many API calls are being made and conversation length.
        """
        await self._send_event("api_call", {
            "call_number": call_number,      # Sequential number (1, 2, 3...)
            "message_count": message_count,  # How many messages in the conversation so far
            "has_tools": has_tools           # Whether tools are available to the model
        })

    async def log_response(
        self,
        call_number: int,
        content: str = None,
        has_tool_calls: bool = False,
        tool_call_count: int = 0,
        duration: float = None
    ):
        """
        Log an assistant response from the AI model.

        Called after receiving a response from the API.
        Captures what the model said and whether it wants to use tools.
        """
        await self._send_event("response", {
            "call_number": call_number,          # Which API call this response is from
            "content": content,                  # What the model said (text response)
            "has_tool_calls": has_tool_calls,    # Did the model request tool execution?
            "tool_call_count": tool_call_count,  # How many tools it wants to call
            "duration": duration                 # How long the API call took (seconds)
        })

    async def log_tool_call(
        self,
        call_number: int,
        tool_index: int,
        tool_name: str,
        parameters: Dict[str, Any],
        tool_call_id: str = None
    ):
        """
        Log a tool call (before executing the tool).

        Captures which tool is being called and with what parameters.
        This happens BEFORE the tool runs, so there are no results yet.
        """
        await self._send_event("tool_call", {
            "call_number": call_number,    # Which API call requested this tool
            "tool_index": tool_index,      # Which tool in the sequence (if multiple)
            "tool_name": tool_name,        # Name of tool (e.g., "web_search", "web_extract")
            "parameters": parameters,      # Arguments passed to the tool (e.g., {"query": "Python", "limit": 5})
            "tool_call_id": tool_call_id   # Unique ID to link the call with its result
        })

    async def log_tool_result(
        self,
        call_number: int,
        tool_index: int,
        tool_name: str,
        result: str = None,
        error: str = None,
        duration: float = None,
        tool_call_id: str = None,
        raw_result: str = None  # Full untruncated result for verification
    ):
        """
        Log a tool result (output from tool execution).

        Captures both a truncated preview (for UI display) and the full raw result
        (for verification and debugging). This is especially important for web tools,
        where you want to see what was scraped vs. what the LLM processed.

        Args:
            call_number: Which API call this tool was part of
            tool_index: Which tool in the sequence (1st, 2nd, etc.)
            tool_name: Name of the tool that was executed
            result: Tool output (truncated to 1000 chars for the preview)
            error: Error message if the tool failed
            duration: How long the tool took to execute (seconds)
            tool_call_id: Unique ID linking this result to the tool call
            raw_result: Full untruncated result for verification/debugging
        """
        await self._send_event("tool_result", {
            "call_number": call_number,
            "tool_index": tool_index,
            "tool_name": tool_name,
            "result": result[:1000] if result else None,  # Truncated preview (1000 chars max)
            "raw_result": raw_result,  # Full result - can be 100KB+ for web scraping
            "error": error,
            "duration": duration,
            "tool_call_id": tool_call_id
        })

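`log_tool_result` is the one event that carries the same payload twice: a 1000-character preview for UI display and the untouched raw result for verification. A small sketch of that split (the helper name `preview_and_raw` is illustrative):

```python
def preview_and_raw(result, limit=1000):
    """Mirror how log_tool_result derives a truncated preview (for the UI)
    while passing the full raw result through untouched."""
    preview = result[:limit] if result else None
    return preview, result


big = "x" * 5000            # stand-in for a large scraped page
preview, raw = preview_and_raw(big)
```

The preview keeps event payloads small for streaming UIs, while `raw` preserves everything for offline debugging.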
    async def log_error(
        self,
        error_message: str,
        call_number: int = None
    ):
        """
        Log an error that occurred during agent execution.

        Captures exceptions, API failures, or other issues.
        """
        await self._send_event("error", {
            "error_message": error_message,  # Description of what went wrong
            "call_number": call_number       # Which API call caused the error (if applicable)
        })

    async def log_complete(
        self,
        final_response: str = None,
        total_calls: int = None,
        completed: bool = True
    ):
        """
        Log session completion (final event before disconnecting).

        Marks the end of the agent's execution and provides summary info.
        """
        await self._send_event("complete", {
            "final_response": final_response[:500] if final_response else None,  # Truncated summary of the final answer
            "total_calls": total_calls,  # How many API calls were made in total
            "completed": completed       # Did it complete successfully? (true/false)
        })


# Synchronous wrapper for convenience
class SyncWebSocketLogger:
    """
    Synchronous wrapper around WebSocketLogger.

    For use in synchronous code - creates an event loop internally.
    """

    def __init__(self, session_id: str, server_url: str = "ws://localhost:8000/ws", enabled: bool = True):
        self.logger = WebSocketLogger(session_id, server_url, enabled)
        self.loop = None

    def connect(self):
        """Connect to the server (synchronous)."""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.loop.run_until_complete(self.logger.connect())

    def disconnect(self):
        """Disconnect from the server (synchronous)."""
        if self.loop:
            self.loop.run_until_complete(self.logger.disconnect())
            self.loop.close()

    def _run_async(self, coro):
        """
        Run an async coroutine synchronously.

        Bridges sync code (the agent) and async code (the WebSocket client):
        uses the owned event loop to execute async operations in a sync context.
        """
        if self.loop and self.loop.is_running():
            # Already inside a running event loop - schedule as a task
            asyncio.create_task(coro)
        else:
            # Run to completion on the owned loop
            if self.loop:
                self.loop.run_until_complete(coro)

    def log_query(self, query: str, model: str = None, toolsets: list = None):
        self._run_async(self.logger.log_query(query, model, toolsets))

    def log_api_call(self, call_number: int, message_count: int = None, has_tools: bool = None):
        self._run_async(self.logger.log_api_call(call_number, message_count, has_tools))

    def log_response(self, call_number: int, content: str = None, has_tool_calls: bool = False,
                     tool_call_count: int = 0, duration: float = None):
        self._run_async(self.logger.log_response(call_number, content, has_tool_calls,
                                                 tool_call_count, duration))

    def log_tool_call(self, call_number: int, tool_index: int, tool_name: str,
                      parameters: Dict[str, Any], tool_call_id: str = None):
        self._run_async(self.logger.log_tool_call(call_number, tool_index, tool_name,
                                                  parameters, tool_call_id))

    def log_tool_result(self, call_number: int, tool_index: int, tool_name: str,
                        result: str = None, error: str = None, duration: float = None,
                        tool_call_id: str = None, raw_result: str = None):
        self._run_async(self.logger.log_tool_result(call_number, tool_index, tool_name,
                                                    result, error, duration, tool_call_id, raw_result))

    def log_error(self, error_message: str, call_number: int = None):
        self._run_async(self.logger.log_error(error_message, call_number))

    def log_complete(self, final_response: str = None, total_calls: int = None, completed: bool = True):
        self._run_async(self.logger.log_complete(final_response, total_calls, completed))
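The sync wrapper works by owning a private event loop and driving each coroutine to completion from synchronous code. A stripped-down sketch of the same sync-over-async bridge (the class and coroutine names here are illustrative, and `asyncio.sleep(0)` stands in for real WebSocket I/O):

```python
import asyncio


class SyncBridge:
    """Minimal sketch of the pattern SyncWebSocketLogger uses:
    own an event loop and run coroutines to completion from sync code."""

    def __init__(self):
        self.loop = asyncio.new_event_loop()

    def run(self, coro):
        # Block until the coroutine finishes, returning its result
        return self.loop.run_until_complete(coro)

    def close(self):
        self.loop.close()


async def send_stub(a, b):
    await asyncio.sleep(0)  # stand-in for an await on WebSocket I/O
    return a + b


bridge = SyncBridge()
result = bridge.run(send_stub(2, 3))
bridge.close()
```

Owning the loop (rather than calling `asyncio.run` per event) avoids tearing down and recreating a loop on every logged event.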
image_generation_tool.py (new file, 566 lines)
#!/usr/bin/env python3
"""
Image Generation Tools Module

This module provides image generation tools using FAL.ai's FLUX.1 Krea model with
automatic upscaling via FAL.ai's Clarity Upscaler for enhanced image quality.

Available tools:
- image_generate_tool: Generate images from text prompts with automatic upscaling

Features:
- High-quality image generation using the FLUX.1 Krea model
- Automatic 2x upscaling using Clarity Upscaler for enhanced quality
- Comprehensive parameter control (size, steps, guidance, etc.)
- Proper error handling and validation with fallback to original images
- Debug logging support
- Sync mode for immediate results

Usage:
    from image_generation_tool import image_generate_tool
    import asyncio

    # Generate and automatically upscale an image
    result = await image_generate_tool(
        prompt="A serene mountain landscape with cherry blossoms",
        image_size="landscape_4_3",
        num_images=1
    )
"""

import json
import os
import asyncio
import uuid
import datetime
from pathlib import Path
from typing import Dict, Any, Optional, Union

import fal_client

# Configuration for image generation
DEFAULT_MODEL = "fal-ai/flux/krea"
DEFAULT_IMAGE_SIZE = "landscape_4_3"
DEFAULT_NUM_INFERENCE_STEPS = 50
DEFAULT_GUIDANCE_SCALE = 4.5
DEFAULT_NUM_IMAGES = 1
DEFAULT_OUTPUT_FORMAT = "png"

# Configuration for automatic upscaling
UPSCALER_MODEL = "fal-ai/clarity-upscaler"
UPSCALER_FACTOR = 2
UPSCALER_SAFETY_CHECKER = False
UPSCALER_DEFAULT_PROMPT = "masterpiece, best quality, highres"
UPSCALER_NEGATIVE_PROMPT = "(worst quality, low quality, normal quality:2)"
UPSCALER_CREATIVITY = 0.35
UPSCALER_RESEMBLANCE = 0.6
UPSCALER_GUIDANCE_SCALE = 4
UPSCALER_NUM_INFERENCE_STEPS = 18

# Valid parameter values for validation, based on the FLUX Krea documentation
VALID_IMAGE_SIZES = [
    "square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"
]
VALID_OUTPUT_FORMATS = ["jpeg", "png"]
VALID_ACCELERATION_MODES = ["none", "regular", "high"]

# Debug mode configuration
DEBUG_MODE = os.getenv("IMAGE_TOOLS_DEBUG", "false").lower() == "true"
DEBUG_SESSION_ID = str(uuid.uuid4())
DEBUG_LOG_PATH = Path("./logs")
DEBUG_DATA = {
    "session_id": DEBUG_SESSION_ID,
    "start_time": datetime.datetime.now().isoformat(),
    "debug_enabled": DEBUG_MODE,
    "tool_calls": []
} if DEBUG_MODE else None

# Create the logs directory if debug mode is enabled
if DEBUG_MODE:
    DEBUG_LOG_PATH.mkdir(exist_ok=True)
    print(f"🐛 Image generation debug mode enabled - Session ID: {DEBUG_SESSION_ID}")


def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None:
    """
    Log a debug call entry to the global debug data structure.

    Args:
        tool_name (str): Name of the tool being called
        call_data (Dict[str, Any]): Data about the call, including parameters and results
    """
    if not DEBUG_MODE or not DEBUG_DATA:
        return

    call_entry = {
        "timestamp": datetime.datetime.now().isoformat(),
        "tool_name": tool_name,
        **call_data
    }

    DEBUG_DATA["tool_calls"].append(call_entry)


def _save_debug_log() -> None:
    """
    Save the current debug data to a JSON file in the logs directory.
    """
    if not DEBUG_MODE or not DEBUG_DATA:
        return

    try:
        debug_filename = f"image_tools_debug_{DEBUG_SESSION_ID}.json"
        debug_filepath = DEBUG_LOG_PATH / debug_filename

        # Update the end time and call count
        DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat()
        DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"])

        with open(debug_filepath, 'w', encoding='utf-8') as f:
            json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False)

        print(f"🐛 Image generation debug log saved: {debug_filepath}")

    except Exception as e:
        print(f"❌ Error saving image generation debug log: {str(e)}")

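The debug log written above is plain JSON, so it can be inspected or post-processed with standard tooling. A sketch of the file's shape, writing to a temp directory instead of `./logs` (all field values below are illustrative):

```python
import json
import tempfile
from pathlib import Path

# Illustrative stand-in for the structure _save_debug_log serializes
debug_data = {
    "session_id": "demo-session",
    "start_time": "2024-01-01T00:00:00",
    "debug_enabled": True,
    "tool_calls": [
        {"timestamp": "2024-01-01T00:00:01",
         "tool_name": "image_generate_tool",
         "success": True},
    ],
}
debug_data["total_calls"] = len(debug_data["tool_calls"])

# Write and read it back the same way the module does
log_dir = Path(tempfile.mkdtemp())
log_file = log_dir / "image_tools_debug_demo-session.json"
log_file.write_text(json.dumps(debug_data, indent=2, ensure_ascii=False),
                    encoding="utf-8")
loaded = json.loads(log_file.read_text(encoding="utf-8"))
```

Because each tool call is appended as a flat dict, the log can be loaded into analysis tools without any custom parsing.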
def _validate_parameters(
    image_size: Union[str, Dict[str, int]],
    num_inference_steps: int,
    guidance_scale: float,
    num_images: int,
    output_format: str,
    acceleration: str = "none"
) -> Dict[str, Any]:
    """
    Validate and normalize image generation parameters for the FLUX Krea model.

    Args:
        image_size: Either a preset string or a custom size dict
        num_inference_steps: Number of inference steps
        guidance_scale: Guidance scale value
        num_images: Number of images to generate
        output_format: Output format for images
        acceleration: Acceleration mode for generation speed

    Returns:
        Dict[str, Any]: Validated and normalized parameters

    Raises:
        ValueError: If any parameter is invalid
    """
    validated = {}

    # Validate image_size
    if isinstance(image_size, str):
        if image_size not in VALID_IMAGE_SIZES:
            raise ValueError(f"Invalid image_size '{image_size}'. Must be one of: {VALID_IMAGE_SIZES}")
        validated["image_size"] = image_size
    elif isinstance(image_size, dict):
        if "width" not in image_size or "height" not in image_size:
            raise ValueError("Custom image_size must contain 'width' and 'height' keys")
        if not isinstance(image_size["width"], int) or not isinstance(image_size["height"], int):
            raise ValueError("Custom image_size width and height must be integers")
        if image_size["width"] < 64 or image_size["height"] < 64:
            raise ValueError("Custom image_size dimensions must be at least 64x64")
        if image_size["width"] > 2048 or image_size["height"] > 2048:
            raise ValueError("Custom image_size dimensions must not exceed 2048x2048")
        validated["image_size"] = image_size
    else:
        raise ValueError("image_size must be either a preset string or a dict with width/height")

    # Validate num_inference_steps
    if not isinstance(num_inference_steps, int) or num_inference_steps < 1 or num_inference_steps > 100:
        raise ValueError("num_inference_steps must be an integer between 1 and 100")
    validated["num_inference_steps"] = num_inference_steps

    # Validate guidance_scale (FLUX Krea default is 4.5)
    if not isinstance(guidance_scale, (int, float)) or guidance_scale < 0.1 or guidance_scale > 20.0:
        raise ValueError("guidance_scale must be a number between 0.1 and 20.0")
    validated["guidance_scale"] = float(guidance_scale)

    # Validate num_images
    if not isinstance(num_images, int) or num_images < 1 or num_images > 4:
        raise ValueError("num_images must be an integer between 1 and 4")
    validated["num_images"] = num_images

    # Validate output_format
    if output_format not in VALID_OUTPUT_FORMATS:
        raise ValueError(f"Invalid output_format '{output_format}'. Must be one of: {VALID_OUTPUT_FORMATS}")
    validated["output_format"] = output_format

    # Validate acceleration
    if acceleration not in VALID_ACCELERATION_MODES:
        raise ValueError(f"Invalid acceleration '{acceleration}'. Must be one of: {VALID_ACCELERATION_MODES}")
    validated["acceleration"] = acceleration

    return validated

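The trickiest branch of the validation above is `image_size`, which accepts either a preset name or a bounded custom dict. A standalone restatement of just that rule, for illustration (the helper name `check_image_size` is illustrative; the module itself uses `_validate_parameters`):

```python
VALID_PRESETS = {"square_hd", "square", "portrait_4_3", "portrait_16_9",
                 "landscape_4_3", "landscape_16_9"}


def check_image_size(image_size):
    """Mirror the image_size rules: a preset string, or a dict with
    integer width/height each between 64 and 2048."""
    if isinstance(image_size, str):
        if image_size not in VALID_PRESETS:
            raise ValueError(f"unknown preset: {image_size}")
        return image_size
    if isinstance(image_size, dict):
        w, h = image_size.get("width"), image_size.get("height")
        if not (isinstance(w, int) and isinstance(h, int)):
            raise ValueError("width and height must be integers")
        if not (64 <= w <= 2048 and 64 <= h <= 2048):
            raise ValueError("dimensions must be within 64..2048")
        return image_size
    raise ValueError("image_size must be a preset string or a width/height dict")


ok = check_image_size({"width": 1024, "height": 768})
try:
    check_image_size({"width": 32, "height": 32})  # below the 64px floor
    too_small_rejected = False
except ValueError:
    too_small_rejected = True
```

Raising `ValueError` early keeps bad parameters from ever reaching the paid API call.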
async def _upscale_image(image_url: str, original_prompt: str) -> Optional[Dict[str, Any]]:
    """
    Upscale an image using FAL.ai's Clarity Upscaler.

    Args:
        image_url (str): URL of the image to upscale
        original_prompt (str): Original prompt used to generate the image

    Returns:
        Optional[Dict[str, Any]]: Upscaled image data, or None if upscaling fails
    """
    try:
        print("🔍 Upscaling image with Clarity Upscaler...")

        # Prepare arguments for the upscaler
        upscaler_arguments = {
            "image_url": image_url,
            "prompt": f"{UPSCALER_DEFAULT_PROMPT}, {original_prompt}",
            "upscale_factor": UPSCALER_FACTOR,
            "negative_prompt": UPSCALER_NEGATIVE_PROMPT,
            "creativity": UPSCALER_CREATIVITY,
            "resemblance": UPSCALER_RESEMBLANCE,
            "guidance_scale": UPSCALER_GUIDANCE_SCALE,
            "num_inference_steps": UPSCALER_NUM_INFERENCE_STEPS,
            "enable_safety_checker": UPSCALER_SAFETY_CHECKER
        }

        # Submit the upscaler request
        handler = await fal_client.submit_async(
            UPSCALER_MODEL,
            arguments=upscaler_arguments
        )

        # Get the upscaled result
        result = await handler.get()

        if result and "image" in result:
            upscaled_image = result["image"]
            print(f"✅ Image upscaled successfully to {upscaled_image.get('width', 'unknown')}x{upscaled_image.get('height', 'unknown')}")
            return {
                "url": upscaled_image["url"],
                "width": upscaled_image.get("width", 0),
                "height": upscaled_image.get("height", 0),
                "upscaled": True,
                "upscale_factor": UPSCALER_FACTOR
            }
        else:
            print("❌ Upscaler returned an invalid response")
            return None

    except Exception as e:
        print(f"❌ Error upscaling image: {str(e)}")
        return None

async def image_generate_tool(
    prompt: str,
    image_size: Union[str, Dict[str, int]] = DEFAULT_IMAGE_SIZE,
    num_inference_steps: int = DEFAULT_NUM_INFERENCE_STEPS,
    guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
    num_images: int = DEFAULT_NUM_IMAGES,
    enable_safety_checker: bool = True,
    output_format: str = DEFAULT_OUTPUT_FORMAT,
    acceleration: str = "none",
    allow_nsfw_images: bool = True,
    seed: Optional[int] = None
) -> str:
    """
    Generate images from text prompts using FAL.ai's FLUX.1 Krea model with automatic upscaling.

    This tool uses FAL.ai's FLUX.1 Krea model for high-quality text-to-image generation
    with extensive customization options. Generated images are automatically upscaled 2x
    using FAL.ai's Clarity Upscaler for enhanced quality. The final upscaled images are
    returned as URLs that can be displayed using <img src="{URL}"></img> tags.

    Args:
        prompt (str): The text prompt describing the desired image
        image_size (Union[str, Dict[str, int]]): Preset size or custom {"width": int, "height": int}
        num_inference_steps (int): Number of denoising steps (1-100, default: 50)
        guidance_scale (float): How closely to follow the prompt (0.1-20.0, default: 4.5)
        num_images (int): Number of images to generate (1-4, default: 1)
        enable_safety_checker (bool): Enable content safety filtering (default: True)
        output_format (str): Image format "jpeg" or "png" (default: "png")
        acceleration (str): Generation speed "none", "regular", or "high" (default: "none")
        allow_nsfw_images (bool): Allow generation of NSFW content (default: True)
        seed (Optional[int]): Random seed for reproducible results (optional)

    Returns:
        str: JSON string containing minimal generation results:
            {
                "success": bool,
                "image": str or None  # URL of the upscaled image, or None if failed
            }
    """
    debug_call_data = {
        "parameters": {
            "prompt": prompt,
            "image_size": image_size,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "num_images": num_images,
            "enable_safety_checker": enable_safety_checker,
            "output_format": output_format,
            "acceleration": acceleration,
            "allow_nsfw_images": allow_nsfw_images,
            "seed": seed
        },
        "error": None,
        "success": False,
        "images_generated": 0,
        "generation_time": 0
    }

    start_time = datetime.datetime.now()

    try:
        print(f"🎨 Generating {num_images} image(s) with FLUX Krea: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")

        # Validate the prompt
        if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0:
            raise ValueError("Prompt is required and must be a non-empty string")

        if len(prompt) > 1000:
            raise ValueError("Prompt must be 1000 characters or less")

        # Check API key availability
        if not os.getenv("FAL_KEY"):
            raise ValueError("FAL_KEY environment variable not set")

        # Validate parameters
        validated_params = _validate_parameters(
            image_size, num_inference_steps, guidance_scale, num_images, output_format, acceleration
        )

        # Prepare arguments for the FAL.ai FLUX Krea API
        arguments = {
            "prompt": prompt.strip(),
            "image_size": validated_params["image_size"],
            "num_inference_steps": validated_params["num_inference_steps"],
            "guidance_scale": validated_params["guidance_scale"],
            "num_images": validated_params["num_images"],
            "enable_safety_checker": enable_safety_checker,
            "output_format": validated_params["output_format"],
            "acceleration": validated_params["acceleration"],
            "allow_nsfw_images": allow_nsfw_images,
            "sync_mode": True  # Use sync mode for immediate results
        }

        # Add the seed if provided
        if seed is not None and isinstance(seed, int):
            arguments["seed"] = seed

        print("🚀 Submitting generation request to FAL.ai FLUX Krea...")
        print(f"   Model: {DEFAULT_MODEL}")
        print(f"   Size: {validated_params['image_size']}")
        print(f"   Steps: {validated_params['num_inference_steps']}")
        print(f"   Guidance: {validated_params['guidance_scale']}")
        print(f"   Acceleration: {validated_params['acceleration']}")

        # Submit the request to FAL.ai
        handler = await fal_client.submit_async(
            DEFAULT_MODEL,
            arguments=arguments
        )

        # Get the result
        result = await handler.get()

        generation_time = (datetime.datetime.now() - start_time).total_seconds()

        # Process the response
        if not result or "images" not in result:
            raise ValueError("Invalid response from FAL.ai API - no images returned")

        images = result.get("images", [])
        if not images:
            raise ValueError("No images were generated")

        # Format the image data and upscale each image
        formatted_images = []
        for img in images:
            if isinstance(img, dict) and "url" in img:
                original_image = {
                    "url": img["url"],
                    "width": img.get("width", 0),
                    "height": img.get("height", 0)
                }

                # Attempt to upscale the image
                upscaled_image = await _upscale_image(img["url"], prompt.strip())

                if upscaled_image:
                    # Use the upscaled image if successful
                    formatted_images.append(upscaled_image)
                else:
                    # Fall back to the original image if upscaling fails
                    print("⚠️ Using original image as fallback")
                    original_image["upscaled"] = False
                    formatted_images.append(original_image)

        if not formatted_images:
            raise ValueError("No valid image URLs returned from API")

        upscaled_count = sum(1 for img in formatted_images if img.get("upscaled", False))
        print(f"✅ Generated {len(formatted_images)} image(s) in {generation_time:.1f}s ({upscaled_count} upscaled)")

        # Prepare the successful response - minimal format
        response_data = {
            "success": True,
            "image": formatted_images[0]["url"] if formatted_images else None
        }

        debug_call_data["success"] = True
        debug_call_data["images_generated"] = len(formatted_images)
        debug_call_data["generation_time"] = generation_time

        # Log debug information
        _log_debug_call("image_generate_tool", debug_call_data)
        _save_debug_log()

        return json.dumps(response_data, indent=2)

    except Exception as e:
        generation_time = (datetime.datetime.now() - start_time).total_seconds()
        error_msg = f"Error generating image: {str(e)}"
        print(f"❌ {error_msg}")

        # Prepare the error response - minimal format
        response_data = {
            "success": False,
            "image": None
        }

        debug_call_data["error"] = error_msg
        debug_call_data["generation_time"] = generation_time
        _log_debug_call("image_generate_tool", debug_call_data)
        _save_debug_log()

        return json.dumps(response_data, indent=2)

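Callers receive a JSON string rather than a dict, so consuming the result means one `json.loads` plus a `success` check. A minimal consumer sketch (`extract_image_url` is illustrative, and the URL below is a placeholder):

```python
import json


def extract_image_url(tool_response: str):
    """Parse the minimal JSON envelope the tool returns and pull out
    the image URL (None when generation failed)."""
    payload = json.loads(tool_response)
    return payload["image"] if payload.get("success") else None


ok_response = json.dumps({"success": True,
                          "image": "https://example.com/img.png"})
failed_response = json.dumps({"success": False, "image": None})
```

Returning a string keeps the tool's output directly usable as a tool-result message in an LLM conversation, without the caller serializing it again.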
def check_fal_api_key() -> bool:
    """
    Check if the FAL.ai API key is available in environment variables.

    Returns:
        bool: True if the API key is set, False otherwise
    """
    return bool(os.getenv("FAL_KEY"))


def check_image_generation_requirements() -> bool:
    """
    Check if all requirements for the image generation tools are met.

    Returns:
        bool: True if requirements are met, False otherwise
    """
    try:
        # Check the API key
        if not check_fal_api_key():
            return False

        # Check that fal_client is available
        import fal_client
        return True

    except ImportError:
        return False


def get_debug_session_info() -> Dict[str, Any]:
    """
    Get information about the current debug session.

    Returns:
        Dict[str, Any]: Dictionary containing debug session information
    """
    if not DEBUG_MODE or not DEBUG_DATA:
        return {
            "enabled": False,
            "session_id": None,
            "log_path": None,
            "total_calls": 0
        }

    return {
        "enabled": True,
        "session_id": DEBUG_SESSION_ID,
        "log_path": str(DEBUG_LOG_PATH / f"image_tools_debug_{DEBUG_SESSION_ID}.json"),
        "total_calls": len(DEBUG_DATA["tool_calls"])
    }


if __name__ == "__main__":
    """
    Simple test/demo when run directly.
    """
    print("🎨 Image Generation Tools Module - FLUX.1 Krea + Auto Upscaling")
    print("=" * 60)

    # Check if the API key is available
    api_available = check_fal_api_key()

    if not api_available:
        print("❌ FAL_KEY environment variable not set")
        print("Please set your API key: export FAL_KEY='your-key-here'")
        print("Get an API key at: https://fal.ai/")
        exit(1)
    else:
        print("✅ FAL.ai API key found")

    # Check that fal_client is available
    try:
        import fal_client
        print("✅ fal_client library available")
    except ImportError:
        print("❌ fal_client library not found")
        print("Please install: pip install fal-client")
        exit(1)

    print("🛠️ Image generation tools ready for use!")
    print(f"🤖 Using model: {DEFAULT_MODEL}")
    print(f"🔍 Auto-upscaling with: {UPSCALER_MODEL} ({UPSCALER_FACTOR}x)")

    # Show debug mode status
    if DEBUG_MODE:
        print(f"🐛 Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}")
        print(f"   Debug logs will be saved to: ./logs/image_tools_debug_{DEBUG_SESSION_ID}.json")
    else:
        print("🐛 Debug mode disabled (set IMAGE_TOOLS_DEBUG=true to enable)")
|
||||||
|
|
||||||
|
print("\nBasic usage:")
|
||||||
|
print(" from image_generation_tool import image_generate_tool")
|
||||||
|
print(" import asyncio")
|
||||||
|
print("")
|
||||||
|
print(" async def main():")
|
||||||
|
print(" # Generate image with automatic 2x upscaling")
|
||||||
|
print(" result = await image_generate_tool(")
|
||||||
|
print(" prompt='A serene mountain landscape with cherry blossoms',")
|
||||||
|
print(" image_size='landscape_4_3',")
|
||||||
|
print(" num_images=1")
|
||||||
|
print(" )")
|
||||||
|
print(" print(result)")
|
||||||
|
print(" asyncio.run(main())")
|
||||||
|
|
||||||
|
print("\nSupported image sizes:")
|
||||||
|
for size in VALID_IMAGE_SIZES:
|
||||||
|
print(f" - {size}")
|
||||||
|
print(" - Custom: {'width': 512, 'height': 768} (if needed)")
|
||||||
|
|
||||||
|
print("\nAcceleration modes:")
|
||||||
|
for mode in VALID_ACCELERATION_MODES:
|
||||||
|
print(f" - {mode}")
|
||||||
|
|
||||||
|
print("\nExample prompts:")
|
||||||
|
print(" - 'A candid street photo of a woman with a pink bob and bold eyeliner'")
|
||||||
|
print(" - 'Modern architecture building with glass facade, sunset lighting'")
|
||||||
|
print(" - 'Abstract art with vibrant colors and geometric patterns'")
|
||||||
|
print(" - 'Portrait of a wise old owl perched on ancient tree branch'")
|
||||||
|
print(" - 'Futuristic cityscape with flying cars and neon lights'")
|
||||||
|
|
||||||
|
print("\nDebug mode:")
|
||||||
|
print(" # Enable debug logging")
|
||||||
|
print(" export IMAGE_TOOLS_DEBUG=true")
|
||||||
|
print(" # Debug logs capture all image generation calls and results")
|
||||||
|
print(" # Logs saved to: ./logs/image_tools_debug_UUID.json")
|
||||||
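The `check_image_generation_requirements` pattern above (an env var plus an importable client library) generalizes to any optional dependency. A minimal sketch, with an illustrative helper name and env var that are not part of the module:

```python
import importlib.util
import os


def requirements_met(env_var: str, package: str) -> bool:
    """Return True only if the env var is set and the package is importable."""
    if not os.getenv(env_var):
        return False
    # find_spec avoids actually importing the package just to probe for it
    return importlib.util.find_spec(package) is not None


# "json" stands in for fal_client here; any importable module works.
print(requirements_met("DEFINITELY_UNSET_VAR_123", "json"))  # → False
```

Using `importlib.util.find_spec` instead of a bare `import` keeps the probe side-effect free.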
589
mixture_of_agents_tool.py
Normal file
@@ -0,0 +1,589 @@
#!/usr/bin/env python3
"""
Mixture-of-Agents Tool Module

This module implements the Mixture-of-Agents (MoA) methodology that leverages
the collective strengths of multiple LLMs through a layered architecture to
achieve state-of-the-art performance on complex reasoning tasks.

Based on the research paper: "Mixture-of-Agents Enhances Large Language Model Capabilities"
by Junlin Wang et al. (arXiv:2406.04692v1)

Key Features:
- Multi-layer LLM collaboration for enhanced reasoning
- Parallel processing of reference models for efficiency
- Intelligent aggregation and synthesis of diverse responses
- Specialized for extremely difficult problems requiring intense reasoning
- Optimized for coding, mathematics, and complex analytical tasks

Available Tool:
- mixture_of_agents_tool: Process complex queries using multiple frontier models

Architecture:
1. Reference models generate diverse initial responses in parallel
2. Aggregator model synthesizes responses into a high-quality output
3. Multiple layers can be used for iterative refinement (future enhancement)

Models Used:
- Reference Models: claude-opus-4-20250514, gemini-2.5-pro, gpt-5, deepseek-r1
- Aggregator Model: claude-opus-4-20250514 (highest capability for synthesis)

Configuration:
To customize the MoA setup, modify the configuration constants at the top of this file:
- REFERENCE_MODELS: List of models for generating diverse initial responses
- AGGREGATOR_MODEL: Model used to synthesize the final response
- REFERENCE_TEMPERATURE/AGGREGATOR_TEMPERATURE: Sampling temperatures
- MIN_SUCCESSFUL_REFERENCES: Minimum successful models needed to proceed

Usage:
    from mixture_of_agents_tool import mixture_of_agents_tool
    import asyncio

    # Process a complex query
    result = await mixture_of_agents_tool(
        user_prompt="Solve this complex mathematical proof..."
    )
"""

import json
import os
import asyncio
import uuid
import datetime
from dotenv import load_dotenv
from pathlib import Path
from typing import Dict, Any, List, Optional
from openai import AsyncOpenAI

load_dotenv()

# Initialize Nous Research API client for MoA processing
# (key comes from the NOUS_API_KEY env var, which is validated before use)
nous_client = AsyncOpenAI(
    api_key=os.getenv("NOUS_API_KEY"),
    base_url="https://inference-api.nousresearch.com/v1"
)

# Configuration for MoA processing
# Reference models - these generate diverse initial responses in parallel
REFERENCE_MODELS = [
    "claude-opus-4-20250514",
    "gemini-2.5-pro",
    "gpt-5",
    "deepseek-r1"
]

# Aggregator model - synthesizes reference responses into final output
AGGREGATOR_MODEL = "claude-opus-4-20250514"  # Use highest capability model for aggregation

# Temperature settings optimized for MoA performance
REFERENCE_TEMPERATURE = 0.6   # Balanced creativity for diverse perspectives
AGGREGATOR_TEMPERATURE = 0.4  # Focused synthesis for consistency

# Failure handling configuration
MIN_SUCCESSFUL_REFERENCES = 1  # Minimum successful reference models needed to proceed

# System prompt for the aggregator model (from the research paper)
AGGREGATOR_SYSTEM_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.

Responses from models:"""

# Debug mode configuration
DEBUG_MODE = os.getenv("MOA_TOOLS_DEBUG", "false").lower() == "true"
DEBUG_SESSION_ID = str(uuid.uuid4())
DEBUG_LOG_PATH = Path("./logs")
DEBUG_DATA = {
    "session_id": DEBUG_SESSION_ID,
    "start_time": datetime.datetime.now().isoformat(),
    "debug_enabled": DEBUG_MODE,
    "tool_calls": []
} if DEBUG_MODE else None

# Create logs directory if debug mode is enabled
if DEBUG_MODE:
    DEBUG_LOG_PATH.mkdir(exist_ok=True)
    print(f"🐛 MoA debug mode enabled - Session ID: {DEBUG_SESSION_ID}")


def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None:
    """
    Log a debug call entry to the global debug data structure.

    Args:
        tool_name (str): Name of the tool being called
        call_data (Dict[str, Any]): Data about the call including parameters and results
    """
    if not DEBUG_MODE or not DEBUG_DATA:
        return

    call_entry = {
        "timestamp": datetime.datetime.now().isoformat(),
        "tool_name": tool_name,
        **call_data
    }

    DEBUG_DATA["tool_calls"].append(call_entry)


def _save_debug_log() -> None:
    """
    Save the current debug data to a JSON file in the logs directory.
    """
    if not DEBUG_MODE or not DEBUG_DATA:
        return

    try:
        debug_filename = f"moa_tools_debug_{DEBUG_SESSION_ID}.json"
        debug_filepath = DEBUG_LOG_PATH / debug_filename

        # Update end time
        DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat()
        DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"])

        with open(debug_filepath, 'w', encoding='utf-8') as f:
            json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False)

        print(f"🐛 MoA debug log saved: {debug_filepath}")

    except Exception as e:
        print(f"❌ Error saving MoA debug log: {str(e)}")


def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> str:
    """
    Construct the final system prompt for the aggregator including all model responses.

    Args:
        system_prompt (str): Base system prompt for aggregation
        responses (List[str]): List of responses from reference models

    Returns:
        str: Complete system prompt with enumerated responses
    """
    response_text = "\n".join([f"{i+1}. {response}" for i, response in enumerate(responses)])
    return f"{system_prompt}\n\n{response_text}"


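The enumerated-prompt shape that `_construct_aggregator_prompt` produces can be seen with a tiny standalone rerun of the same two lines (the function name and values below are placeholders, not the module's):

```python
def construct_aggregator_prompt(system_prompt: str, responses: list) -> str:
    # Same logic as _construct_aggregator_prompt above: number each
    # reference response and append the block to the system prompt.
    response_text = "\n".join(f"{i+1}. {r}" for i, r in enumerate(responses))
    return f"{system_prompt}\n\n{response_text}"


prompt = construct_aggregator_prompt(
    "Responses from models:",
    ["First reference answer", "Second reference answer"],
)
print(prompt)
# Responses from models:
#
# 1. First reference answer
# 2. Second reference answer
```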
async def _run_reference_model_safe(
    model: str,
    user_prompt: str,
    temperature: float = REFERENCE_TEMPERATURE,
    max_tokens: int = 32000,
    max_retries: int = 3
) -> tuple[str, str, bool]:
    """
    Run a single reference model with retry logic and graceful failure handling.

    Args:
        model (str): Model identifier to use
        user_prompt (str): The user's query
        temperature (float): Sampling temperature for response generation
        max_tokens (int): Maximum tokens in response
        max_retries (int): Maximum number of retry attempts

    Returns:
        tuple[str, str, bool]: (model_name, response_content_or_error, success_flag)
    """
    for attempt in range(max_retries):
        try:
            print(f"🤖 Querying {model} (attempt {attempt + 1}/{max_retries})")

            # Build parameters for the API call
            api_params = {
                "model": model,
                "messages": [{"role": "user", "content": user_prompt}]
            }

            # GPT models (especially gpt-4o-mini) don't support custom temperature values
            # Only include temperature for non-GPT models
            if not model.lower().startswith('gpt-'):
                api_params["temperature"] = temperature

            response = await nous_client.chat.completions.create(**api_params)

            content = response.choices[0].message.content.strip()
            print(f"✅ {model} responded ({len(content)} characters)")
            return model, content, True

        except Exception as e:
            error_str = str(e)
            # Log more detailed error information for debugging
            if "invalid" in error_str.lower():
                print(f"⚠️ {model} invalid request error (attempt {attempt + 1}): {error_str}")
            elif "rate" in error_str.lower() or "limit" in error_str.lower():
                print(f"⚠️ {model} rate limit error (attempt {attempt + 1}): {error_str}")
            else:
                print(f"⚠️ {model} unknown error (attempt {attempt + 1}): {error_str}")

            if attempt < max_retries - 1:
                # Exponential backoff for rate limiting
                sleep_time = 2 ** attempt
                print(f"   Retrying in {sleep_time}s...")
                await asyncio.sleep(sleep_time)
            else:
                error_msg = f"{model} failed after {max_retries} attempts: {error_str}"
                print(f"❌ {error_msg}")
                return model, error_msg, False


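A quick sketch of the retry timing used in `_run_reference_model_safe`: with the default `max_retries=3`, the sleep after failed attempt `i` is `2 ** i` seconds, and no sleep follows the final attempt (the helper name here is illustrative):

```python
def backoff_delays(max_retries: int) -> list:
    # One sleep after each failed attempt except the last, which returns.
    return [2 ** attempt for attempt in range(max_retries - 1)]


print(backoff_delays(3))  # → [1, 2]
```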
async def _run_aggregator_model(
    system_prompt: str,
    user_prompt: str,
    temperature: float = AGGREGATOR_TEMPERATURE,
    max_tokens: int = None
) -> str:
    """
    Run the aggregator model to synthesize the final response.

    Args:
        system_prompt (str): System prompt with all reference responses
        user_prompt (str): Original user query
        temperature (float): Focused temperature for consistent aggregation
        max_tokens (int): Maximum tokens in final response

    Returns:
        str: Synthesized final response
    """
    print(f"🧠 Running aggregator model: {AGGREGATOR_MODEL}")

    # Build parameters for the API call
    api_params = {
        "model": AGGREGATOR_MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
    }

    # GPT models (especially gpt-4o-mini) don't support custom temperature values
    # Only include temperature for non-GPT models
    if not AGGREGATOR_MODEL.lower().startswith('gpt-'):
        api_params["temperature"] = temperature

    response = await nous_client.chat.completions.create(**api_params)

    content = response.choices[0].message.content.strip()
    print(f"✅ Aggregation complete ({len(content)} characters)")
    return content


async def mixture_of_agents_tool(
    user_prompt: str,
    reference_models: Optional[List[str]] = None,
    aggregator_model: Optional[str] = None
) -> str:
    """
    Process a complex query using the Mixture-of-Agents methodology.

    This tool leverages multiple frontier language models to collaboratively solve
    extremely difficult problems requiring intense reasoning. It's particularly
    effective for:
    - Complex mathematical proofs and calculations
    - Advanced coding problems and algorithm design
    - Multi-step analytical reasoning tasks
    - Problems requiring diverse domain expertise
    - Tasks where single models show limitations

    The MoA approach uses a fixed 2-layer architecture:
    1. Layer 1: Multiple reference models generate diverse responses in parallel (temp=0.6)
    2. Layer 2: Aggregator model synthesizes the best elements into final response (temp=0.4)

    Args:
        user_prompt (str): The complex query or problem to solve
        reference_models (Optional[List[str]]): Custom reference models to use
        aggregator_model (Optional[str]): Custom aggregator model to use

    Returns:
        str: JSON string containing the MoA results with the following structure:
            {
                "success": bool,
                "response": str,
                "models_used": {
                    "reference_models": List[str],
                    "aggregator_model": str
                }
            }
            On failure, an additional "error" field describes what went wrong.

    Raises:
        Exception: If MoA processing fails or API key is not set
    """
    start_time = datetime.datetime.now()

    debug_call_data = {
        "parameters": {
            "user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt,
            "reference_models": reference_models or REFERENCE_MODELS,
            "aggregator_model": aggregator_model or AGGREGATOR_MODEL,
            "reference_temperature": REFERENCE_TEMPERATURE,
            "aggregator_temperature": AGGREGATOR_TEMPERATURE,
            "min_successful_references": MIN_SUCCESSFUL_REFERENCES
        },
        "error": None,
        "success": False,
        "reference_responses_count": 0,
        "failed_models_count": 0,
        "failed_models": [],
        "final_response_length": 0,
        "processing_time_seconds": 0,
        "models_used": {}
    }

    try:
        print("🚀 Starting Mixture-of-Agents processing...")
        print(f"📝 Query: {user_prompt[:100]}{'...' if len(user_prompt) > 100 else ''}")

        # Validate API key availability
        if not os.getenv("NOUS_API_KEY"):
            raise ValueError("NOUS_API_KEY environment variable not set")

        # Use provided models or defaults
        ref_models = reference_models or REFERENCE_MODELS
        agg_model = aggregator_model or AGGREGATOR_MODEL

        print(f"🔄 Using {len(ref_models)} reference models in 2-layer MoA architecture")

        # Layer 1: Generate diverse responses from reference models (with failure handling)
        print("📡 Layer 1: Generating reference responses...")
        model_results = await asyncio.gather(*[
            _run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE)
            for model in ref_models
        ])

        # Separate successful and failed responses
        successful_responses = []
        failed_models = []

        for model_name, content, success in model_results:
            if success:
                successful_responses.append(content)
            else:
                failed_models.append(model_name)

        successful_count = len(successful_responses)
        failed_count = len(failed_models)

        print(f"📊 Reference model results: {successful_count} successful, {failed_count} failed")

        if failed_models:
            print(f"⚠️ Failed models: {', '.join(failed_models)}")

        # Check if we have enough successful responses to proceed
        if successful_count < MIN_SUCCESSFUL_REFERENCES:
            raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.")

        debug_call_data["reference_responses_count"] = successful_count
        debug_call_data["failed_models_count"] = failed_count
        debug_call_data["failed_models"] = failed_models

        # Layer 2: Aggregate responses using the aggregator model
        print("🧠 Layer 2: Synthesizing final response...")
        aggregator_system_prompt = _construct_aggregator_prompt(
            AGGREGATOR_SYSTEM_PROMPT,
            successful_responses
        )

        final_response = await _run_aggregator_model(
            aggregator_system_prompt,
            user_prompt,
            AGGREGATOR_TEMPERATURE
        )

        # Calculate processing time
        end_time = datetime.datetime.now()
        processing_time = (end_time - start_time).total_seconds()

        print(f"✅ MoA processing completed in {processing_time:.2f} seconds")

        # Prepare successful response (only final aggregated result, minimal fields)
        result = {
            "success": True,
            "response": final_response,
            "models_used": {
                "reference_models": ref_models,
                "aggregator_model": agg_model
            }
        }

        debug_call_data["success"] = True
        debug_call_data["final_response_length"] = len(final_response)
        debug_call_data["processing_time_seconds"] = processing_time
        debug_call_data["models_used"] = result["models_used"]

        # Log debug information
        _log_debug_call("mixture_of_agents_tool", debug_call_data)
        _save_debug_log()

        return json.dumps(result, indent=2)

    except Exception as e:
        error_msg = f"Error in MoA processing: {str(e)}"
        print(f"❌ {error_msg}")

        # Calculate processing time even for errors
        end_time = datetime.datetime.now()
        processing_time = (end_time - start_time).total_seconds()

        # Prepare error response (minimal fields)
        result = {
            "success": False,
            "response": "MoA processing failed. Please try again or use a single model for this query.",
            "models_used": {
                "reference_models": reference_models or REFERENCE_MODELS,
                "aggregator_model": aggregator_model or AGGREGATOR_MODEL
            },
            "error": error_msg
        }

        debug_call_data["error"] = error_msg
        debug_call_data["processing_time_seconds"] = processing_time
        _log_debug_call("mixture_of_agents_tool", debug_call_data)
        _save_debug_log()

        return json.dumps(result, indent=2)


def check_nous_api_key() -> bool:
    """
    Check if the Nous Research API key is available in environment variables.

    Returns:
        bool: True if API key is set, False otherwise
    """
    return bool(os.getenv("NOUS_API_KEY"))


def check_moa_requirements() -> bool:
    """
    Check if all requirements for MoA tools are met.

    Returns:
        bool: True if requirements are met, False otherwise
    """
    return check_nous_api_key()


def get_debug_session_info() -> Dict[str, Any]:
    """
    Get information about the current debug session.

    Returns:
        Dict[str, Any]: Dictionary containing debug session information
    """
    if not DEBUG_MODE or not DEBUG_DATA:
        return {
            "enabled": False,
            "session_id": None,
            "log_path": None,
            "total_calls": 0
        }

    return {
        "enabled": True,
        "session_id": DEBUG_SESSION_ID,
        "log_path": str(DEBUG_LOG_PATH / f"moa_tools_debug_{DEBUG_SESSION_ID}.json"),
        "total_calls": len(DEBUG_DATA["tool_calls"])
    }


def get_available_models() -> Dict[str, List[str]]:
    """
    Get information about available models for MoA processing.

    Returns:
        Dict[str, List[str]]: Dictionary with reference and aggregator models
    """
    return {
        "reference_models": REFERENCE_MODELS,
        "aggregator_models": [AGGREGATOR_MODEL],
        "supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL]
    }


def get_moa_configuration() -> Dict[str, Any]:
    """
    Get the current MoA configuration settings.

    Returns:
        Dict[str, Any]: Dictionary containing all configuration parameters
    """
    return {
        "reference_models": REFERENCE_MODELS,
        "aggregator_model": AGGREGATOR_MODEL,
        "reference_temperature": REFERENCE_TEMPERATURE,
        "aggregator_temperature": AGGREGATOR_TEMPERATURE,
        "min_successful_references": MIN_SUCCESSFUL_REFERENCES,
        "total_reference_models": len(REFERENCE_MODELS),
        "failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail"
    }


if __name__ == "__main__":
    """
    Simple test/demo when run directly
    """
    print("🤖 Mixture-of-Agents Tool Module")
    print("=" * 50)

    # Check if API key is available
    api_available = check_nous_api_key()

    if not api_available:
        print("❌ NOUS_API_KEY environment variable not set")
        print("Please set your API key: export NOUS_API_KEY='your-key-here'")
        print("Get API key at: https://inference-api.nousresearch.com/")
        exit(1)
    else:
        print("✅ Nous Research API key found")

    print("🛠️ MoA tools ready for use!")

    # Show current configuration
    config = get_moa_configuration()
    print("\n⚙️ Current Configuration:")
    print(f"  🤖 Reference models ({len(config['reference_models'])}): {', '.join(config['reference_models'])}")
    print(f"  🧠 Aggregator model: {config['aggregator_model']}")
    print(f"  🌡️ Reference temperature: {config['reference_temperature']}")
    print(f"  🌡️ Aggregator temperature: {config['aggregator_temperature']}")
    print(f"  🛡️ Failure tolerance: {config['failure_tolerance']}")
    print(f"  📊 Minimum successful models: {config['min_successful_references']}")

    # Show debug mode status
    if DEBUG_MODE:
        print(f"\n🐛 Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}")
        print(f"   Debug logs will be saved to: ./logs/moa_tools_debug_{DEBUG_SESSION_ID}.json")
    else:
        print("\n🐛 Debug mode disabled (set MOA_TOOLS_DEBUG=true to enable)")

    print("\nBasic usage:")
    print("  from mixture_of_agents_tool import mixture_of_agents_tool")
    print("  import asyncio")
    print("")
    print("  async def main():")
    print("      result = await mixture_of_agents_tool(")
    print("          user_prompt='Solve this complex mathematical proof...'")
    print("      )")
    print("      print(result)")
    print("  asyncio.run(main())")

    print("\nBest use cases:")
    print("  - Complex mathematical proofs and calculations")
    print("  - Advanced coding problems and algorithm design")
    print("  - Multi-step analytical reasoning tasks")
    print("  - Problems requiring diverse domain expertise")
    print("  - Tasks where single models show limitations")

    print("\nPerformance characteristics:")
    print("  - Higher latency due to multiple model calls")
    print("  - Significantly improved quality for complex tasks")
    print("  - Parallel processing for efficiency")
    print(f"  - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation")
    print("  - Token-efficient: only returns final aggregated response")
    print("  - Resilient: continues with partial model failures")
    print("  - Configurable: easy to modify models and settings at top of file")
    print("  - State-of-the-art results on challenging benchmarks")

    print("\nDebug mode:")
    print("  # Enable debug logging")
    print("  export MOA_TOOLS_DEBUG=true")
    print("  # Debug logs capture all MoA processing steps and metrics")
    print("  # Logs saved to: ./logs/moa_tools_debug_UUID.json")
243
mock_web_tools.py
Normal file
@@ -0,0 +1,243 @@
"""
Mock Web Tools for Testing WebSocket Reconnection

This module provides mock implementations of web_search and web_extract
that simulate long-running operations without making real API calls.

Perfect for testing WebSocket timeout/reconnection behavior without:
- Wasting API credits
- Waiting for real web crawling
- Network dependencies
"""

import time
import json
from typing import List


def mock_web_search(query: str, delay: int = 2) -> str:
    """
    Mock web search that returns fake results after a delay.

    Args:
        query: Search query (ignored, just for API compatibility)
        delay: Seconds to sleep (default: 2s)

    Returns:
        JSON string with fake search results
    """
    print(f"🔍 [MOCK] Searching for: '{query}' (will take {delay}s)...")
    time.sleep(delay)

    result = {
        "success": True,
        "data": {
            "web": [
                {
                    "url": "https://example.com/article1",
                    "title": "Mock Article 1 - Water Utilities",
                    "description": "This is a mock search result for testing purposes. Real data would appear here.",
                    "category": None
                },
                {
                    "url": "https://example.com/article2",
                    "title": "Mock Article 2 - AI Data Centers",
                    "description": "Another mock result. This simulates web_search without making real API calls.",
                    "category": None
                },
                {
                    "url": "https://example.com/article3",
                    "title": "Mock Article 3 - Investment Opportunities",
                    "description": "Third mock result for testing. Query was: " + query,
                    "category": None
                }
            ]
        }
    }

    return json.dumps(result, indent=2)


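The payload `mock_web_search` returns can be consumed exactly like the real tool's output. The dict below just mirrors the structure built above with example values (not the module's actual data):

```python
import json

# Shape mirrors mock_web_search's result: {"success", "data": {"web": [...]}}
payload = {
    "success": True,
    "data": {"web": [
        {"url": "https://example.com/article1",
         "title": "Mock Article 1 - Water Utilities",
         "description": "example", "category": None},
    ]},
}
raw = json.dumps(payload, indent=2)    # what the mock hands back to the agent
hits = json.loads(raw)["data"]["web"]  # what a caller parses out again
titles = [hit["title"] for hit in hits]
print(titles)  # → ['Mock Article 1 - Water Utilities']
```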
def mock_web_extract(urls: List[str], delay: int = 60) -> str:
    """
    Mock web extraction that simulates a long-running crawl.

    This is perfect for testing WebSocket timeout/reconnection because:
    - Default 60s delay triggers the ~30s WebSocket timeout
    - No actual web requests made
    - No API credits consumed
    - Predictable, reproducible behavior

    Args:
        urls: List of URLs to "extract" (ignored)
        delay: Seconds to sleep (default: 60s to trigger timeout)

    Returns:
        JSON string with fake extraction results
    """
    print(f"🌐 [MOCK] Extracting {len(urls)} URLs (will take {delay}s)...")
    print("📊 [MOCK] This will test WebSocket reconnection (timeout at ~30s)")

    # Simulate long-running operation
    # Show progress so user knows it's working
    for i in range(delay):
        if i % 10 == 0 and i > 0:
            print(f"  ⏱️ [MOCK] {i}/{delay}s elapsed...")
        time.sleep(1)

    # Generate fake but realistic-looking content
    result = {
        "success": True,
        "data": []
    }

    for idx, url in enumerate(urls, 1):
        result["data"].append({
            "url": url,
            "title": f"Mock Extracted Content {idx}",
            "content": (f"# Mock Content from {url}\n\n"
                        f"This is simulated extracted content for testing purposes. "
                        f"In a real scenario, this would contain the full text from the webpage. "
                        f"\n\n## Key Points\n"
                        f"- Mock point 1 about water utilities\n"
                        f"- Mock point 2 about AI data centers\n"
                        f"- Mock point 3 about investment opportunities\n"
                        f"\n\nThis content took {delay} seconds to 'extract', which is long enough "
                        f"to trigger WebSocket timeout and test reconnection logic.")
                       * 10,  # Make it longer to simulate real extraction
            "extracted_at": "2025-10-10T14:00:00Z"
        })

    json_result = json.dumps(result, indent=2)
    size_kb = len(json_result) / 1024

    print(f"✅ [MOCK] Extraction completed: {len(urls)} URLs, {size_kb:.1f} KB")
    return json_result


def mock_web_crawl(start_url: str, max_pages: int = 10, delay: int = 30) -> str:
    """
    Mock web crawling that simulates a multi-page crawl.

    Args:
        start_url: Starting URL (ignored)
|
||||||
|
max_pages: Max pages to crawl (just affects result count)
|
||||||
|
delay: Seconds to sleep (default: 30s)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON string with fake crawl results
|
||||||
|
"""
|
||||||
|
print(f"🕷️ [MOCK] Crawling from: {start_url} (max {max_pages} pages, {delay}s)...")
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"start_url": start_url,
|
||||||
|
"pages_crawled": min(max_pages, 5),
|
||||||
|
"pages": []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in range(min(max_pages, 5)):
|
||||||
|
result["data"]["pages"].append({
|
||||||
|
"url": f"{start_url}/page{i+1}",
|
||||||
|
"title": f"Mock Page {i+1}",
|
||||||
|
"content": f"Mock content from page {i+1}. " * 50
|
||||||
|
})
|
||||||
|
|
||||||
|
print(f"✅ [MOCK] Crawl completed: {len(result['data']['pages'])} pages")
|
||||||
|
return json.dumps(result, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
# Tool definitions for the agent (same format as real tools)
|
||||||
|
MOCK_WEB_TOOLS = [
|
||||||
|
{
|
||||||
|
"name": "web_search",
|
||||||
|
"description": "[MOCK] Search the web for information. Returns fake results after 2s delay. Perfect for quick tests.",
|
||||||
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The search query"
|
||||||
|
},
|
||||||
|
"delay": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Seconds to delay (default: 2)",
|
||||||
|
"default": 2
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "web_extract",
|
||||||
|
"description": "[MOCK] Extract content from URLs. Simulates 60s delay to test WebSocket timeout/reconnection. Returns fake content without making real requests. PERFECT FOR TESTING!",
|
||||||
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"urls": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "List of URLs to extract"
|
||||||
|
},
|
||||||
|
"delay": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Seconds to delay (default: 60 to trigger timeout)",
|
||||||
|
"default": 60
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["urls"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "web_crawl",
|
||||||
|
"description": "[MOCK] Crawl website starting from URL. Returns fake results after 30s delay.",
|
||||||
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"start_url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Starting URL for crawl"
|
||||||
|
},
|
||||||
|
"max_pages": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Max pages to crawl (default: 10)",
|
||||||
|
"default": 10
|
||||||
|
},
|
||||||
|
"delay": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Seconds to delay (default: 30)",
|
||||||
|
"default": 30
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["start_url"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Map function names to implementations
|
||||||
|
MOCK_TOOL_FUNCTIONS = {
|
||||||
|
"web_search": mock_web_search,
|
||||||
|
"web_extract": mock_web_extract,
|
||||||
|
"web_crawl": mock_web_crawl
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Demo/test the mock tools
|
||||||
|
print("Testing Mock Web Tools")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
print("\n1. Mock web_search (2s delay):")
|
||||||
|
result = mock_web_search("test query", delay=2)
|
||||||
|
print(f"Result length: {len(result)} chars\n")
|
||||||
|
|
||||||
|
print("\n2. Mock web_extract (5s delay for demo - normally 60s):")
|
||||||
|
result = mock_web_extract(["https://example.com"], delay=5)
|
||||||
|
print(f"Result length: {len(result)} chars\n")
|
||||||
|
|
||||||
|
print("\n✅ All mock tools working!")
|
||||||
|
|
||||||
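The MOCK_TOOL_FUNCTIONS mapping above pairs each tool name with its implementation, so an agent loop can resolve a model-issued tool call by name. Here is a minimal, self-contained sketch of that dispatch pattern; `fake_search`, `TOOL_REGISTRY`, and `dispatch_tool_call` are illustrative stand-ins, not the module's real names:

```python
import json
import time

# Stand-in for a mock tool implementation (not the module's real function).
def fake_search(query: str, delay: int = 0) -> str:
    time.sleep(delay)
    return json.dumps({"success": True, "query": query})

# Minimal registry mirroring the MOCK_TOOL_FUNCTIONS name -> callable pattern.
TOOL_REGISTRY = {"web_search": fake_search}

def dispatch_tool_call(name: str, arguments: dict) -> str:
    """Look up a tool by name and invoke it with the model-supplied args."""
    func = TOOL_REGISTRY.get(name)
    if func is None:
        return json.dumps({"success": False, "error": f"Unknown tool: {name}"})
    try:
        return func(**arguments)
    except TypeError as e:  # bad or missing arguments from the model
        return json.dumps({"success": False, "error": str(e)})

print(dispatch_tool_call("web_search", {"query": "water utilities", "delay": 0}))
```

Because the tool schemas and the registry share the same names, swapping mock implementations for real ones only requires changing the registry values.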
model_tools.py (1125 lines changed): diff suppressed because it is too large.
output.txt (new file, 527 lines): diff suppressed because one or more lines are too long.
@@ -1,2 +1,14 @@
-tavily-python
+firecrawl-py
 openai
+fal-client
+python-dotenv
+fire
+httpx
+yt-dlp
+streamlit
+fastapi
+uvicorn
+websockets
+PySide6>=6.6.0
+websocket-client>=1.7.0
+requests>=2.31.0
run_agent.py (1313 lines changed): diff suppressed because it is too large.
@@ -22,8 +22,8 @@ Usage:
 import json
 import os
 from typing import Optional, Dict, Any
-from hecate import run_tool_with_lifecycle_management
-from morphcloud._llm import ToolCall
+# from hecate import run_tool_with_lifecycle_management
+# from morphcloud._llm import ToolCall
 
 # Detailed description for the terminal tool based on Hermes Terminal system prompt
 TERMINAL_TOOL_DESCRIPTION = """Execute commands on a secure, persistent Linux VM environment with full interactive application support.
@@ -129,27 +129,30 @@ def terminal_tool(
         tool_input["idle_threshold"] = idle_threshold
     if timeout is not None:
         tool_input["timeout"] = timeout
 
+        # THIS IS BROKEN FOR NOW ~!!!!!!!
+
-        tool_call = ToolCall(
-            name="run_command",
-            input=tool_input
-        )
+        # tool_call = ToolCall(
+        #     name="run_command",
+        #     input=tool_input
+        # )
 
-        # Execute with lifecycle management
-        result = run_tool_with_lifecycle_management(tool_call)
+        # # Execute with lifecycle management
+        # result = run_tool_with_lifecycle_management(tool_call)
 
 
-        # Format the result with all possible fields
-        # Map hecate's "stdout" to "output" for compatibility
-        formatted_result = {
-            "output": result.get("stdout", result.get("output", "")),
-            "screen": result.get("screen", ""),
-            "session_id": result.get("session_id"),
-            "exit_code": result.get("returncode", result.get("exit_code", -1)),
-            "error": result.get("error"),
-            "status": "active" if result.get("session_id") else "ended"
-        }
+        # # Format the result with all possible fields
+        # # Map hecate's "stdout" to "output" for compatibility
+        # formatted_result = {
+        #     "output": result.get("stdout", result.get("output", "")),
+        #     "screen": result.get("screen", ""),
+        #     "session_id": result.get("session_id"),
+        #     "exit_code": result.get("returncode", result.get("exit_code", -1)),
+        #     "error": result.get("error"),
+        #     "status": "active" if result.get("session_id") else "ended"
+        # }
 
-        return json.dumps(formatted_result)
+        return json.dumps({})
 
     except Exception as e:
         return json.dumps({
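The change above disables the broken terminal tool by returning `json.dumps({})`, which callers cannot distinguish from a legitimately empty result. A hedged alternative sketch (names are illustrative, not the repository's API): return a structured "disabled" payload that keeps the same result fields the formatter used, so the agent loop can surface the problem instead of silently treating the call as successful:

```python
import json

def disabled_tool_response(tool_name: str, reason: str) -> str:
    """Structured placeholder result for a temporarily disabled tool.

    Uses the same keys as the commented-out formatted_result above
    ("output", "error", "exit_code", "status") so downstream consumers
    need no changes when the real implementation returns.
    """
    return json.dumps({
        "output": "",
        "error": f"{tool_name} is temporarily disabled: {reason}",
        "exit_code": -1,
        "status": "ended",
    })

print(disabled_tool_response("run_command", "hecate integration is broken"))
```

A non-zero `exit_code` plus an explicit `error` string makes the failure visible in logs and trajectories rather than disappearing as an empty dict.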
test_mock_mode.sh (new executable file, 122 lines):

#!/bin/bash
#
# Test Script for Mock Web Tools & WebSocket Reconnection
#
# This script tests:
# 1. Mock web tools (no API calls, fake data)
# 2. WebSocket timeout/reconnection during long operations
# 3. Complete logging capture
#
# Ideal for development/testing without wasting API credits!

set -e

cd "$(dirname "$0")"

echo "=========================================="
echo "🧪 Mock Mode Test Script"
echo "=========================================="
echo ""

# Check if the logging server is running
if ! curl -s http://localhost:8000/health > /dev/null 2>&1; then
    echo "⚠️  Logging server not detected!"
    echo "   Starting logging server in background..."
    python api_endpoint/logging_server.py &
    SERVER_PID=$!
    echo "   Server PID: $SERVER_PID"
    sleep 3
else
    echo "✅ Logging server already running"
    SERVER_PID=""
fi

echo ""
echo "📋 Test Configuration:"
echo "   - Mock web tools: ENABLED"
echo "   - Mock delay: 60 seconds (triggers WebSocket timeout)"
echo "   - WebSocket logging: ENABLED"
echo "   - Expected behavior: Connection timeout + auto-reconnect"
echo ""
echo "🔄 Running agent with mock mode..."
echo "   (This will take ~60 seconds to test reconnection)"
echo ""

# Run the agent with mock mode
python run_agent.py \
    --enabled_toolsets web \
    --enable_websocket_logging \
    --mock_web_tools \
    --mock_delay 60 \
    --query "Find publicly traded water companies benefiting from AI data centers"

echo ""
echo "=========================================="
echo "✅ Test Complete!"
echo "=========================================="
echo ""

# Find the most recent log file
LATEST_LOG=$(ls -t api_endpoint/logs/realtime/session_*.json 2>/dev/null | head -1)

if [ -n "$LATEST_LOG" ]; then
    echo "📊 Log Analysis:"
    echo "   File: $LATEST_LOG"
    echo ""

    # Count events
    echo "   Event Counts:"
    python3 -c "
import json
import sys

with open('$LATEST_LOG') as f:
    data = json.load(f)
    events = data.get('events', [])

# Count by type
counts = {}
for e in events:
    etype = e.get('type', 'unknown')
    counts[etype] = counts.get(etype, 0) + 1

for etype, count in sorted(counts.items()):
    print(f'   - {etype}: {count}')

# Check completeness
has_complete = any(e.get('type') == 'complete' for e in events)
print()
if has_complete:
    print('   ✅ Session completed successfully!')
else:
    print('   ⚠️  Session incomplete (may have been interrupted)')

# Check for reconnections
tool_results = [e for e in events if e.get('type') == 'tool_result']
tool_calls = [e for e in events if e.get('type') == 'tool_call']

if len(tool_results) == len(tool_calls):
    print('   ✅ All tool calls have results (no missing events)')
else:
    print(f'   ⚠️  Tool calls: {len(tool_calls)}, Results: {len(tool_results)}')
"
else
    echo "⚠️  No log files found"
fi

# Cleanup
if [ -n "$SERVER_PID" ]; then
    echo ""
    echo "🛑 Stopping logging server (PID: $SERVER_PID)..."
    kill $SERVER_PID 2>/dev/null || true
fi

echo ""
echo "💡 Key Observations to Look For:"
echo "   1. '[MOCK]' prefix on tool execution messages"
echo "   2. '🔄 Reconnecting to logging server' after a long tool call"
echo "   3. '✅ Reconnected successfully!' confirmation"
echo "   4. Complete log file with all events captured"
echo ""
echo "🎉 Mock mode test completed!"
test_parallel_execution.py (new file, 242 lines):

#!/usr/bin/env python3
"""
Test Parallel Execution with Persistent WebSocket Connection Pool

This script demonstrates that multiple agent runs can execute in parallel,
all sharing a single WebSocket connection for logging.

Benefits:
- No connection overhead (single persistent connection)
- No timeout issues (connection stays alive)
- True parallel execution (multiple sessions simultaneously)
"""

import asyncio
from run_agent import AIAgent
import time


async def run_agent_query(query: str, agent_name: str, mock_delay: int = 10):
    """
    Run a single agent query with logging.

    Args:
        query: Query to send to the agent
        agent_name: Name for logging purposes
        mock_delay: Delay for mock tools (seconds)
    """
    print(f"🚀 [{agent_name}] Starting query: '{query[:40]}...'")
    start_time = time.time()

    try:
        agent = AIAgent(
            model="claude-sonnet-4-5-20250929",
            max_iterations=5,
            enabled_toolsets=["web"],
            enable_websocket_logging=True,
            websocket_server="ws://localhost:8000/ws",
            mock_web_tools=True,  # Use mock tools for fast testing
            mock_delay=mock_delay
        )

        result = await agent.run_conversation(query)

        duration = time.time() - start_time
        print(f"✅ [{agent_name}] Completed in {duration:.1f}s - {result['api_calls']} API calls")

        return {
            "agent": agent_name,
            "query": query,
            "success": True,
            "duration": duration,
            "api_calls": result['api_calls'],
            "session_id": result.get('session_id')
        }

    except Exception as e:
        duration = time.time() - start_time
        print(f"❌ [{agent_name}] Failed in {duration:.1f}s: {e}")
        return {
            "agent": agent_name,
            "query": query,
            "success": False,
            "error": str(e),
            "duration": duration
        }


async def test_sequential():
    """
    Test 1: Sequential execution (baseline).

    Runs 3 queries one after another. This shows how long it takes
    without parallelization.
    """
    print("\n" + "="*60)
    print("TEST 1: Sequential Execution (Baseline)")
    print("="*60)

    start_time = time.time()

    results = []
    for i in range(3):
        result = await run_agent_query(
            query=f"Find information about water companies #{i+1}",
            agent_name=f"Agent{i+1}",
            mock_delay=5  # Short delay for a quick test
        )
        results.append(result)

    total_time = time.time() - start_time

    print(f"\n📊 Sequential Results:")
    print(f"   Total time: {total_time:.1f}s")
    print(f"   Successful: {sum(1 for r in results if r['success'])}/3")
    print(f"   Average per query: {total_time/3:.1f}s")

    return results


async def test_parallel():
    """
    Test 2: Parallel execution.

    Runs 3 queries simultaneously using asyncio.gather().
    All queries share the same WebSocket connection for logging.
    """
    print("\n" + "="*60)
    print("TEST 2: Parallel Execution (Shared Connection)")
    print("="*60)

    start_time = time.time()

    # Run all queries in parallel!
    results = await asyncio.gather(
        run_agent_query(
            query="Find publicly traded water utility companies",
            agent_name="Agent1",
            mock_delay=5
        ),
        run_agent_query(
            query="Find energy infrastructure companies",
            agent_name="Agent2",
            mock_delay=5
        ),
        run_agent_query(
            query="Find AI data center operators",
            agent_name="Agent3",
            mock_delay=5
        )
    )

    total_time = time.time() - start_time

    print(f"\n📊 Parallel Results:")
    print(f"   Total time: {total_time:.1f}s")
    print(f"   Successful: {sum(1 for r in results if r['success'])}/3")
    print(f"   Speedup: ~{(sum(r['duration'] for r in results) / total_time):.1f}x")
    print(f"   Sessions logged: {[r.get('session_id', 'N/A')[:8] for r in results]}")

    return results


async def test_high_concurrency():
    """
    Test 3: High concurrency (stress test).

    Runs 10 queries simultaneously to test the connection pool under load.
    """
    print("\n" + "="*60)
    print("TEST 3: High Concurrency (10 Parallel Agents)")
    print("="*60)

    start_time = time.time()

    tasks = [
        run_agent_query(
            query=f"Test query #{i+1}",
            agent_name=f"Agent{i+1}",
            mock_delay=3  # Very short for the stress test
        )
        for i in range(10)
    ]

    results = await asyncio.gather(*tasks)

    total_time = time.time() - start_time
    successful = sum(1 for r in results if r['success'])

    print(f"\n📊 High Concurrency Results:")
    print(f"   Total time: {total_time:.1f}s")
    print(f"   Successful: {successful}/10")
    print(f"   Failed: {10 - successful}/10")
    print(f"   Queries per second: {10 / total_time:.2f}")

    return results


async def main():
    """Run all tests."""
    print("\n🧪 WebSocket Connection Pool - Parallel Execution Tests")
    print("="*60)
    print("\nPREREQUISITE: Make sure the logging server is running:")
    print("   python api_endpoint/logging_server.py")
    print("\nPress Ctrl+C to stop at any time\n")

    await asyncio.sleep(2)  # Give the user time to read

    try:
        # Test 1: Sequential (baseline)
        seq_results = await test_sequential()

        # Test 2: Parallel (main test)
        par_results = await test_parallel()

        # Test 3: High concurrency
        stress_results = await test_high_concurrency()

        # Summary
        print("\n" + "="*60)
        print("SUMMARY")
        print("="*60)
        print(f"\n✅ All tests completed!")
        print(f"\nKey Findings:")
        print(f"   • Sequential (3 queries): {sum(r['duration'] for r in seq_results):.1f}s total")
        print(f"   • Parallel (3 queries): {max(r['duration'] for r in par_results):.1f}s total")
        print(f"   • Speedup: ~{sum(r['duration'] for r in seq_results) / max(r['duration'] for r in par_results):.1f}x")
        print(f"   • High concurrency (10 queries): ✅ Handled successfully")
        print(f"\n💡 All queries used the same persistent WebSocket connection!")
        print(f"   No connection overhead, no timeouts, true parallelization.")

    except KeyboardInterrupt:
        print("\n\n⚠️  Tests interrupted by user")
    except Exception as e:
        print(f"\n\n❌ Tests failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    print("\n" + "="*60)
    print("SETUP CHECK")
    print("="*60)

    # Check if the logging server is running
    import socket
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        result = sock.connect_ex(('localhost', 8000))
        sock.close()

        if result == 0:
            print("✅ Logging server is running on port 8000")
        else:
            print("⚠️  Logging server not detected on port 8000")
            print("   Start it with: python api_endpoint/logging_server.py")
            print("\nContinuing anyway (tests will fail gracefully)...")
    except Exception as e:
        print(f"⚠️  Could not check server status: {e}")

    # Run tests
    asyncio.run(main())
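The sequential and parallel tests above reduce to a single pattern: awaiting coroutines one by one versus handing them all to `asyncio.gather()`. This self-contained sketch demonstrates the speedup measurement with `asyncio.sleep` standing in for the mock agent runs (no agent, server, or WebSocket required):

```python
import asyncio
import time

async def fake_agent_run(name: str, delay: float) -> str:
    # Stand-in for an agent query; the await yields control so the
    # other "agents" can make progress on the same event loop.
    await asyncio.sleep(delay)
    return name

async def measure():
    # Sequential baseline: total time is roughly the sum of the delays.
    t0 = time.perf_counter()
    for i in range(3):
        await fake_agent_run(f"Agent{i+1}", 0.2)
    sequential = time.perf_counter() - t0

    # Parallel: total time is roughly the max of the delays.
    t0 = time.perf_counter()
    await asyncio.gather(*(fake_agent_run(f"Agent{i+1}", 0.2) for i in range(3)))
    parallel = time.perf_counter() - t0
    return sequential, parallel

sequential, parallel = asyncio.run(measure())
print(f"sequential={sequential:.2f}s parallel={parallel:.2f}s")
```

The same ~3x ratio is what the real script reports, because the mock tools spend their time sleeping, not computing; the speedup holds whenever the work is I/O-bound.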
test_run.sh (new file, 31 lines):

#!/bin/bash

# Check if a prompt argument was provided
if [ $# -eq 0 ]; then
    echo "Error: Please provide a prompt as an argument"
    echo "Usage: $0 \"your prompt here\""
    exit 1
fi

# Get the prompt from the first argument
PROMPT="$1"

# Set debug mode for web tools
export WEB_TOOLS_DEBUG=true

# Run the agent with the provided prompt
python run_agent.py \
    --query "$PROMPT" \
    --max_turns 30 \
    --model claude-sonnet-4-20250514 \
    --base_url https://api.anthropic.com/v1/ \
    --api_key $ANTHROPIC_API_KEY \
    --save_trajectories \
    --enabled_toolsets=web

# --model claude-sonnet-4-20250514 \
#
# Possible toolsets:
# web_tools
# vision_tools
# terminal_tools
264
test_ui_flow.py
Normal file
264
test_ui_flow.py
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test script to verify UI flow works correctly.
|
||||||
|
|
||||||
|
This tests:
|
||||||
|
1. API server is running
|
||||||
|
2. WebSocket connection works
|
||||||
|
3. Agent can be started via API
|
||||||
|
4. Events are broadcast properly
|
||||||
|
"""
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import websocket
|
||||||
|
import threading
|
||||||
|
|
||||||
|
API_URL = "http://localhost:8000"
|
||||||
|
WS_URL = "ws://localhost:8000/ws"
|
||||||
|
|
||||||
|
def test_api_server():
|
||||||
|
"""Test if API server is running."""
|
||||||
|
print("🔍 Testing API server...")
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{API_URL}/", timeout=5)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
print(f"✅ API server is running: {data.get('service')}")
|
||||||
|
print(f" Active connections: {data.get('active_connections')}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print(f"❌ API server returned: {response.status_code}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ API server not accessible: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def test_tools_endpoint():
|
||||||
|
"""Test if tools endpoint works."""
|
||||||
|
print("\n🔍 Testing tools endpoint...")
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{API_URL}/tools", timeout=5)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
toolsets = data.get("toolsets", [])
|
||||||
|
print(f"✅ Tools endpoint works - {len(toolsets)} toolsets available")
|
||||||
|
for ts in toolsets[:3]:
|
||||||
|
print(f" • {ts.get('name')} ({ts.get('tool_count')} tools)")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print(f"❌ Tools endpoint failed: {response.status_code}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Tools endpoint error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def test_websocket():
|
||||||
|
"""Test WebSocket connection."""
|
||||||
|
print("\n🔍 Testing WebSocket connection...")
|
||||||
|
|
||||||
|
connected = threading.Event()
|
||||||
|
message_received = threading.Event()
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
def on_open(ws):
|
||||||
|
print("✅ WebSocket connected")
|
||||||
|
connected.set()
|
||||||
|
|
||||||
|
def on_message(ws, message):
|
||||||
|
data = json.loads(message)
|
||||||
|
messages.append(data)
|
||||||
|
message_received.set()
|
||||||
|
print(f"📨 Received: {data.get('event_type', 'unknown')}")
|
||||||
|
|
||||||
|
def on_error(ws, error):
|
||||||
|
print(f"❌ WebSocket error: {error}")
|
||||||
|
|
||||||
|
def on_close(ws, close_status_code, close_msg):
|
||||||
|
print(f"🔌 WebSocket closed: {close_status_code}")
|
||||||
|
|
||||||
|
ws = websocket.WebSocketApp(
|
||||||
|
WS_URL,
|
||||||
|
on_open=on_open,
|
||||||
|
on_message=on_message,
|
||||||
|
on_error=on_error,
|
||||||
|
on_close=on_close
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run WebSocket in background
|
||||||
|
ws_thread = threading.Thread(target=lambda: ws.run_forever(), daemon=True)
|
||||||
|
ws_thread.start()
|
||||||
|
|
||||||
|
# Wait for connection
|
||||||
|
if connected.wait(timeout=5):
|
||||||
|
print("✅ WebSocket connection established")
|
||||||
|
ws.close()
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("❌ WebSocket connection timeout")
|
||||||
|
ws.close()
|
||||||
|
return False
|
||||||
|
|
||||||
|
def test_agent_run():
|
||||||
|
"""Test running agent via API."""
|
||||||
|
print("\n🔍 Testing agent run via API (mock mode)...")
|
||||||
|
|
||||||
|
# Start listening for events first
|
||||||
|
events = []
|
||||||
|
ws_connected = threading.Event()
|
||||||
|
session_complete = threading.Event()
|
||||||
|
|
||||||
|
def on_message(ws, message):
|
||||||
|
data = json.loads(message)
|
||||||
|
events.append(data)
|
||||||
|
event_type = data.get("event_type")
|
||||||
|
print(f" 📨 Event: {event_type}")
|
||||||
|
|
||||||
|
if event_type == "complete":
|
||||||
|
session_complete.set()
|
||||||
|
|
||||||
|
def on_open(ws):
|
||||||
|
ws_connected.set()
|
||||||
|
|
||||||
|
# Connect WebSocket
|
||||||
|
ws = websocket.WebSocketApp(
|
||||||
|
WS_URL,
|
||||||
|
on_open=on_open,
|
||||||
|
on_message=on_message
|
||||||
|
)
|
||||||
|
|
||||||
|
ws_thread = threading.Thread(target=lambda: ws.run_forever(), daemon=True)
|
||||||
|
ws_thread.start()
|
||||||
|
|
||||||
|
# Wait for WebSocket connection
|
||||||
|
if not ws_connected.wait(timeout=5):
|
||||||
|
print("❌ WebSocket didn't connect")
|
||||||
|
ws.close()
|
||||||
|
return False
|
||||||
|
|
||||||
|
print("✅ WebSocket connected, starting agent...")
|
||||||
|
|
||||||
|
# Submit agent run
|
||||||
|
payload = {
|
||||||
|
"query": "Test query for UI flow verification",
|
||||||
|
"model": "claude-sonnet-4-5-20250929",
|
||||||
|
"base_url": "https://api.anthropic.com/v1/",
|
||||||
|
"enabled_toolsets": ["web"],
|
||||||
|
"max_turns": 5,
|
||||||
|
"mock_web_tools": True, # Use mock mode to avoid API costs
|
||||||
|
"mock_delay": 2, # Fast for testing
|
||||||
|
"verbose": False
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(f"{API_URL}/agent/run", json=payload, timeout=10)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
result = response.json()
|
||||||
|
session_id = result.get("session_id")
|
||||||
|
print(f"✅ Agent started: {session_id[:8]}...")
|
||||||
|
|
||||||
|
# Wait for completion (or timeout)
|
||||||
|
print("⏳ Waiting for agent to complete (up to 30s)...")
|
||||||
|
if session_complete.wait(timeout=30):
|
||||||
|
print(f"✅ Agent completed! Received {len(events)} events:")
|
||||||
|
|
||||||
|
# Count event types
|
||||||
|
event_counts = {}
|
||||||
|
for evt in events:
|
||||||
|
evt_type = evt.get("event_type", "unknown")
|
                    event_counts[evt_type] = event_counts.get(evt_type, 0) + 1
                for evt_type, count in event_counts.items():
                    print(f"   • {evt_type}: {count}")

                # Check we got expected events
                expected_events = ["query", "api_call", "response", "complete"]
                missing = [e for e in expected_events if e not in event_counts]

                if missing:
                    print(f"⚠️  Missing expected events: {missing}")
                else:
                    print("✅ All expected event types received!")

                ws.close()
                return True
            else:
                print(f"⚠️  Timeout waiting for completion. Got {len(events)} events so far.")
                ws.close()
                return False
        else:
            print(f"❌ Agent start failed: {response.status_code}")
            print(f"   Response: {response.text}")
            ws.close()
            return False

    except Exception as e:
        print(f"❌ Agent run error: {e}")
        import traceback
        traceback.print_exc()
        ws.close()
        return False


def main():
    """Run all tests."""
    print("=" * 60)
    print("🧪 Hermes Agent UI Flow Test")
    print("=" * 60)
    print("\nThis will test the complete flow:")
    print("  1. API server connectivity")
    print("  2. Tools endpoint")
    print("  3. WebSocket connection")
    print("  4. Agent execution via API (mock mode)")
    print("  5. Event streaming to UI")
    print("\n" + "=" * 60)

    results = []

    # Test 1: API server
    results.append(("API Server", test_api_server()))

    # Test 2: Tools endpoint
    results.append(("Tools Endpoint", test_tools_endpoint()))

    # Test 3: WebSocket
    results.append(("WebSocket Connection", test_websocket()))

    # Test 4: Agent run
    results.append(("Agent Execution + Events", test_agent_run()))

    # Summary
    print("\n" + "=" * 60)
    print("📊 TEST SUMMARY")
    print("=" * 60)

    for test_name, passed in results:
        status = "✅ PASS" if passed else "❌ FAIL"
        print(f"{status} - {test_name}")

    all_passed = all(r[1] for r in results)

    print("\n" + "=" * 60)
    if all_passed:
        print("🎉 ALL TESTS PASSED!")
        print("\n✅ The UI flow is working correctly!")
        print("   You can now use the UI to:")
        print("   • Submit queries")
        print("   • View real-time events")
        print("   • See tool executions")
        print("   • Get final responses")
    else:
        print("❌ SOME TESTS FAILED")
        print("\nMake sure:")
        print("  1. API server is running: python api_endpoint/logging_server.py")
        print("  2. ANTHROPIC_API_KEY is set in environment")
        print("  3. All dependencies are installed: pip install -r requirements.txt")
    print("=" * 60)

    return 0 if all_passed else 1


if __name__ == "__main__":
    exit(main())
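The completion check above reduces to tallying event types from the stream and diffing them against an expected set. A standalone sketch of just that logic, independent of the WebSocket plumbing (the function name and sample events here are illustrative, not part of the repo):

```python
import json
from collections import Counter

def summarize_events(raw_events, expected=("query", "api_call", "response", "complete")):
    """Tally event types from raw JSON event strings; report expected types that never arrived."""
    counts = Counter(json.loads(e)["type"] for e in raw_events)
    missing = [t for t in expected if t not in counts]
    return dict(counts), missing

sample = ['{"type": "query"}', '{"type": "api_call"}', '{"type": "api_call"}',
          '{"type": "response"}', '{"type": "complete"}']
counts, missing = summarize_events(sample)
print(counts)   # {'query': 1, 'api_call': 2, 'response': 1, 'complete': 1}
print(missing)  # []
```

An empty `missing` list corresponds to the "All expected event types received!" branch in the test above.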
test_web_tools.py — new file, 620 lines
@@ -0,0 +1,620 @@
#!/usr/bin/env python3
"""
Comprehensive Test Suite for Web Tools Module

This script tests all web tools functionality to ensure they work correctly.
Run this after any updates to the web_tools.py module or Firecrawl library.

Usage:
    python test_web_tools.py            # Run all tests
    python test_web_tools.py --no-llm   # Skip LLM processing tests
    python test_web_tools.py --verbose  # Show detailed output

Requirements:
- FIRECRAWL_API_KEY environment variable must be set
- NOUS_API_KEY environment variable (optional, for LLM tests)
"""

import json
import asyncio
import sys
import os
import argparse
from datetime import datetime
from typing import List, Dict, Any

# Import the web tools to test
from web_tools import (
    web_search_tool,
    web_extract_tool,
    web_crawl_tool,
    check_firecrawl_api_key,
    check_nous_api_key,
    get_debug_session_info
)


class Colors:
    """ANSI color codes for terminal output"""
    HEADER = '\033[95m'
    BLUE = '\033[94m'
    CYAN = '\033[96m'
    GREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'


def print_header(text: str):
    """Print a formatted header"""
    print(f"\n{Colors.HEADER}{Colors.BOLD}{'='*60}{Colors.ENDC}")
    print(f"{Colors.HEADER}{Colors.BOLD}{text}{Colors.ENDC}")
    print(f"{Colors.HEADER}{Colors.BOLD}{'='*60}{Colors.ENDC}")


def print_section(text: str):
    """Print a formatted section header"""
    print(f"\n{Colors.CYAN}{Colors.BOLD}📌 {text}{Colors.ENDC}")
    print(f"{Colors.CYAN}{'-'*50}{Colors.ENDC}")


def print_success(text: str):
    """Print success message"""
    print(f"{Colors.GREEN}✅ {text}{Colors.ENDC}")


def print_error(text: str):
    """Print error message"""
    print(f"{Colors.FAIL}❌ {text}{Colors.ENDC}")


def print_warning(text: str):
    """Print warning message"""
    print(f"{Colors.WARNING}⚠️  {text}{Colors.ENDC}")


def print_info(text: str, indent: int = 0):
    """Print info message"""
    indent_str = "  " * indent
    print(f"{indent_str}{Colors.BLUE}ℹ️  {text}{Colors.ENDC}")


class WebToolsTester:
    """Test suite for web tools"""

    def __init__(self, verbose: bool = False, test_llm: bool = True):
        self.verbose = verbose
        self.test_llm = test_llm
        self.test_results = {
            "passed": [],
            "failed": [],
            "skipped": []
        }
        self.start_time = None
        self.end_time = None

    def log_result(self, test_name: str, status: str, details: str = ""):
        """Log test result"""
        result = {
            "test": test_name,
            "status": status,
            "details": details,
            "timestamp": datetime.now().isoformat()
        }

        if status == "passed":
            self.test_results["passed"].append(result)
            print_success(f"{test_name}: {details}" if details else test_name)
        elif status == "failed":
            self.test_results["failed"].append(result)
            print_error(f"{test_name}: {details}" if details else test_name)
        elif status == "skipped":
            self.test_results["skipped"].append(result)
            print_warning(f"{test_name} skipped: {details}" if details else f"{test_name} skipped")

    def test_environment(self) -> bool:
        """Test environment setup and API keys"""
        print_section("Environment Check")

        # Check Firecrawl API key
        if not check_firecrawl_api_key():
            self.log_result("Firecrawl API Key", "failed", "FIRECRAWL_API_KEY not set")
            return False
        else:
            self.log_result("Firecrawl API Key", "passed", "Found")

        # Check Nous API key (optional)
        if not check_nous_api_key():
            self.log_result("Nous API Key", "skipped", "NOUS_API_KEY not set (LLM tests will be skipped)")
            self.test_llm = False
        else:
            self.log_result("Nous API Key", "passed", "Found")

        # Check debug mode
        debug_info = get_debug_session_info()
        if debug_info["enabled"]:
            print_info(f"Debug mode enabled - Session: {debug_info['session_id']}")
            print_info(f"Debug log: {debug_info['log_path']}")

        return True

    def test_web_search(self) -> List[str]:
        """Test web search functionality"""
        print_section("Test 1: Web Search")

        test_queries = [
            ("Python web scraping tutorial", 5),
            ("Firecrawl API documentation", 3),
            ("inflammatory arthritis symptoms treatment", 8)  # Test medical query from your example
        ]

        extracted_urls = []

        for query, limit in test_queries:
            try:
                print(f"\n  Testing search: '{query}' (limit={limit})")

                if self.verbose:
                    print(f"  Calling web_search_tool(query='{query}', limit={limit})")

                # Perform search
                result = web_search_tool(query, limit)

                # Parse result
                try:
                    data = json.loads(result)
                except json.JSONDecodeError as e:
                    self.log_result(f"Search: {query[:30]}...", "failed", f"Invalid JSON: {e}")
                    if self.verbose:
                        print(f"  Raw response (first 500 chars): {result[:500]}...")
                    continue

                if "error" in data:
                    self.log_result(f"Search: {query[:30]}...", "failed", f"API error: {data['error']}")
                    continue

                # Check structure
                if "success" not in data or "data" not in data:
                    self.log_result(f"Search: {query[:30]}...", "failed", "Missing success or data fields")
                    if self.verbose:
                        print(f"  Response keys: {list(data.keys())}")
                    continue

                web_results = data.get("data", {}).get("web", [])

                if not web_results:
                    self.log_result(f"Search: {query[:30]}...", "failed", "Empty web results array")
                    if self.verbose:
                        print(f"  data.web content: {data.get('data', {}).get('web')}")
                    continue

                # Validate each result
                valid_results = 0
                missing_fields = []

                for i, result in enumerate(web_results):
                    required_fields = ["url", "title", "description"]
                    has_all_fields = all(key in result for key in required_fields)

                    if has_all_fields:
                        valid_results += 1
                        # Collect URLs for extraction test
                        if len(extracted_urls) < 3:
                            extracted_urls.append(result["url"])

                        if self.verbose:
                            print(f"  Result {i+1}: ✓ {result['title'][:50]}...")
                            print(f"    URL: {result['url'][:60]}...")
                    else:
                        missing = [f for f in required_fields if f not in result]
                        missing_fields.append(f"Result {i+1} missing: {missing}")
                        if self.verbose:
                            print(f"  Result {i+1}: ✗ Missing fields: {missing}")

                # Log results
                if valid_results == len(web_results):
                    self.log_result(
                        f"Search: {query[:30]}...",
                        "passed",
                        f"All {valid_results} results valid"
                    )
                else:
                    self.log_result(
                        f"Search: {query[:30]}...",
                        "failed",
                        f"Only {valid_results}/{len(web_results)} valid. Issues: {'; '.join(missing_fields[:3])}"
                    )

            except Exception as e:
                self.log_result(f"Search: {query[:30]}...", "failed", f"Exception: {type(e).__name__}: {str(e)}")
                if self.verbose:
                    import traceback
                    print(f"  Traceback: {traceback.format_exc()}")

        if self.verbose and extracted_urls:
            print(f"\n  URLs collected for extraction test: {len(extracted_urls)}")
            for url in extracted_urls:
                print(f"    - {url}")

        return extracted_urls

    async def test_web_extract(self, urls: List[str] = None):
        """Test web content extraction"""
        print_section("Test 2: Web Extract (without LLM)")

        # Use provided URLs or defaults
        if not urls:
            urls = [
                "https://docs.firecrawl.dev/introduction",
                "https://www.python.org/about/"
            ]
            print(f"  Using default URLs for testing")
        else:
            print(f"  Using {len(urls)} URLs from search results")

        # Test extraction
        if urls:
            try:
                test_urls = urls[:2]  # Test with max 2 URLs
                print(f"\n  Extracting content from {len(test_urls)} URL(s)...")
                for url in test_urls:
                    print(f"    - {url}")

                if self.verbose:
                    print(f"  Calling web_extract_tool(urls={test_urls}, format='markdown', use_llm_processing=False)")

                result = await web_extract_tool(
                    test_urls,
                    format="markdown",
                    use_llm_processing=False
                )

                # Parse result
                try:
                    data = json.loads(result)
                except json.JSONDecodeError as e:
                    self.log_result("Extract (no LLM)", "failed", f"Invalid JSON: {e}")
                    if self.verbose:
                        print(f"  Raw response (first 500 chars): {result[:500]}...")
                    return

                if "error" in data:
                    self.log_result("Extract (no LLM)", "failed", f"API error: {data['error']}")
                    return

                results = data.get("results", [])

                if not results:
                    self.log_result("Extract (no LLM)", "failed", "No results in response")
                    if self.verbose:
                        print(f"  Response keys: {list(data.keys())}")
                    return

                # Validate each result
                valid_results = 0
                failed_results = 0
                total_content_length = 0
                extraction_details = []

                for i, result in enumerate(results):
                    title = result.get("title", "No title")
                    content = result.get("content", "")
                    error = result.get("error")

                    if error:
                        failed_results += 1
                        extraction_details.append(f"Page {i+1}: ERROR - {error}")
                        if self.verbose:
                            print(f"  Page {i+1}: ✗ Error - {error}")
                    elif content:
                        content_len = len(content)
                        total_content_length += content_len
                        valid_results += 1
                        extraction_details.append(f"Page {i+1}: {title[:40]}... ({content_len} chars)")
                        if self.verbose:
                            print(f"  Page {i+1}: ✓ {title[:50]}... - {content_len} characters")
                            print(f"    First 100 chars: {content[:100]}...")
                    else:
                        extraction_details.append(f"Page {i+1}: {title[:40]}... (EMPTY)")
                        if self.verbose:
                            print(f"  Page {i+1}: ⚠ {title[:50]}... - Empty content")

                # Log results
                if valid_results > 0:
                    self.log_result(
                        "Extract (no LLM)",
                        "passed",
                        f"{valid_results}/{len(results)} pages extracted, {total_content_length} total chars"
                    )
                else:
                    self.log_result(
                        "Extract (no LLM)",
                        "failed",
                        f"No valid content. {failed_results} errors, {len(results) - failed_results} empty"
                    )

                if self.verbose:
                    print(f"\n  Extraction details:")
                    for detail in extraction_details:
                        print(f"    {detail}")

            except Exception as e:
                self.log_result("Extract (no LLM)", "failed", f"Exception: {type(e).__name__}: {str(e)}")
                if self.verbose:
                    import traceback
                    print(f"  Traceback: {traceback.format_exc()}")

    async def test_web_extract_with_llm(self, urls: List[str] = None):
        """Test web extraction with LLM processing"""
        print_section("Test 3: Web Extract (with Gemini LLM)")

        if not self.test_llm:
            self.log_result("Extract (with LLM)", "skipped", "LLM testing disabled")
            return

        # Use a URL likely to have substantial content
        test_url = urls[0] if urls else "https://docs.firecrawl.dev/features/scrape"

        try:
            print(f"\n  Extracting and processing: {test_url}")

            result = await web_extract_tool(
                [test_url],
                format="markdown",
                use_llm_processing=True,
                min_length=1000  # Lower threshold for testing
            )

            data = json.loads(result)

            if "error" in data:
                self.log_result("Extract (with LLM)", "failed", data["error"])
                return

            results = data.get("results", [])

            if not results:
                self.log_result("Extract (with LLM)", "failed", "No results returned")
                return

            result = results[0]
            content = result.get("content", "")

            if content:
                content_len = len(content)

                # Check if content was actually processed (should be shorter than typical raw content)
                if content_len > 0:
                    self.log_result(
                        "Extract (with LLM)",
                        "passed",
                        f"Content processed: {content_len} chars"
                    )

                    if self.verbose:
                        print(f"\n  First 300 chars of processed content:")
                        print(f"  {content[:300]}...")
                else:
                    self.log_result("Extract (with LLM)", "failed", "No content after processing")
            else:
                self.log_result("Extract (with LLM)", "failed", "No content field in result")

        except json.JSONDecodeError as e:
            self.log_result("Extract (with LLM)", "failed", f"Invalid JSON: {e}")
        except Exception as e:
            self.log_result("Extract (with LLM)", "failed", str(e))

    async def test_web_crawl(self):
        """Test web crawling functionality"""
        print_section("Test 4: Web Crawl")

        test_sites = [
            ("https://docs.firecrawl.dev", None, 2),  # Test docs site
            ("https://firecrawl.dev", None, 3),       # Test main site
        ]

        for url, instructions, expected_min_pages in test_sites:
            try:
                print(f"\n  Testing crawl of: {url}")
                if instructions:
                    print(f"  Instructions: {instructions}")
                else:
                    print(f"  No instructions (general crawl)")
                print(f"  Expected minimum pages: {expected_min_pages}")

                # Show what's being called
                if self.verbose:
                    print(f"  Calling web_crawl_tool(url='{url}', instructions={instructions}, use_llm_processing=False)")

                result = await web_crawl_tool(
                    url,
                    instructions=instructions,
                    use_llm_processing=False  # Disable LLM for faster testing
                )

                # Check if result is valid JSON
                try:
                    data = json.loads(result)
                except json.JSONDecodeError as e:
                    self.log_result(f"Crawl: {url}", "failed", f"Invalid JSON response: {e}")
                    if self.verbose:
                        print(f"  Raw response (first 500 chars): {result[:500]}...")
                    continue

                # Check for errors
                if "error" in data:
                    self.log_result(f"Crawl: {url}", "failed", f"API error: {data['error']}")
                    continue

                # Get results
                results = data.get("results", [])

                if not results:
                    self.log_result(f"Crawl: {url}", "failed", "No pages in results array")
                    if self.verbose:
                        print(f"  Full response: {json.dumps(data, indent=2)[:1000]}...")
                    continue

                # Analyze pages
                valid_pages = 0
                empty_pages = 0
                total_content = 0
                page_details = []

                for i, page in enumerate(results):
                    content = page.get("content", "")
                    title = page.get("title", "Untitled")
                    error = page.get("error")

                    if error:
                        page_details.append(f"Page {i+1}: ERROR - {error}")
                    elif content:
                        valid_pages += 1
                        content_len = len(content)
                        total_content += content_len
                        page_details.append(f"Page {i+1}: {title[:40]}... ({content_len} chars)")
                    else:
                        empty_pages += 1
                        page_details.append(f"Page {i+1}: {title[:40]}... (EMPTY)")

                # Show detailed results if verbose
                if self.verbose:
                    print(f"\n  Crawl Results:")
                    print(f"  Total pages returned: {len(results)}")
                    print(f"  Valid pages (with content): {valid_pages}")
                    print(f"  Empty pages: {empty_pages}")
                    print(f"  Total content size: {total_content} characters")
                    print(f"\n  Page Details:")
                    for detail in page_details[:10]:  # Show first 10 pages
                        print(f"    - {detail}")
                    if len(page_details) > 10:
                        print(f"    ... and {len(page_details) - 10} more pages")

                # Determine pass/fail
                if valid_pages >= expected_min_pages:
                    self.log_result(
                        f"Crawl: {url}",
                        "passed",
                        f"{valid_pages}/{len(results)} valid pages, {total_content} chars total"
                    )
                else:
                    self.log_result(
                        f"Crawl: {url}",
                        "failed",
                        f"Only {valid_pages} valid pages (expected >= {expected_min_pages}), {empty_pages} empty, {len(results)} total"
                    )

            except Exception as e:
                self.log_result(f"Crawl: {url}", "failed", f"Exception: {type(e).__name__}: {str(e)}")
                if self.verbose:
                    import traceback
                    print(f"  Traceback:")
                    print("    " + "\n    ".join(traceback.format_exc().split("\n")))

    async def run_all_tests(self):
        """Run all tests"""
        self.start_time = datetime.now()

        print_header("WEB TOOLS TEST SUITE")
        print(f"Started at: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")

        # Test environment
        if not self.test_environment():
            print_error("\nCannot proceed without required API keys!")
            return False

        # Test search and collect URLs
        urls = self.test_web_search()

        # Test extraction
        await self.test_web_extract(urls if urls else None)

        # Test extraction with LLM
        if self.test_llm:
            await self.test_web_extract_with_llm(urls if urls else None)

        # Test crawling
        await self.test_web_crawl()

        # Print summary
        self.end_time = datetime.now()
        duration = (self.end_time - self.start_time).total_seconds()

        print_header("TEST SUMMARY")
        print(f"Duration: {duration:.2f} seconds")
        print(f"\n{Colors.GREEN}Passed: {len(self.test_results['passed'])}{Colors.ENDC}")
        print(f"{Colors.FAIL}Failed: {len(self.test_results['failed'])}{Colors.ENDC}")
        print(f"{Colors.WARNING}Skipped: {len(self.test_results['skipped'])}{Colors.ENDC}")

        # List failed tests
        if self.test_results["failed"]:
            print(f"\n{Colors.FAIL}{Colors.BOLD}Failed Tests:{Colors.ENDC}")
            for test in self.test_results["failed"]:
                print(f"  - {test['test']}: {test['details']}")

        # Save results to file
        self.save_results()

        return len(self.test_results["failed"]) == 0

    def save_results(self):
        """Save test results to a JSON file"""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"test_results_web_tools_{timestamp}.json"

        results = {
            "test_suite": "Web Tools",
            "start_time": self.start_time.isoformat() if self.start_time else None,
            "end_time": self.end_time.isoformat() if self.end_time else None,
            "duration_seconds": (self.end_time - self.start_time).total_seconds() if self.start_time and self.end_time else None,
            "summary": {
                "passed": len(self.test_results["passed"]),
                "failed": len(self.test_results["failed"]),
                "skipped": len(self.test_results["skipped"])
            },
            "results": self.test_results,
            "environment": {
                "firecrawl_api_key": check_firecrawl_api_key(),
                "nous_api_key": check_nous_api_key(),
                "debug_mode": get_debug_session_info()["enabled"]
            }
        }

        try:
            with open(filename, 'w') as f:
                json.dump(results, f, indent=2)
            print_info(f"Test results saved to: {filename}")
        except Exception as e:
            print_warning(f"Failed to save results: {e}")


async def main():
    """Main entry point"""
    parser = argparse.ArgumentParser(description="Test Web Tools Module")
    parser.add_argument("--no-llm", action="store_true", help="Skip LLM processing tests")
    parser.add_argument("--verbose", "-v", action="store_true", help="Show detailed output")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode for web tools")

    args = parser.parse_args()

    # Set debug mode if requested
    if args.debug:
        os.environ["WEB_TOOLS_DEBUG"] = "true"
        print_info("Debug mode enabled for web tools")

    # Create tester
    tester = WebToolsTester(
        verbose=args.verbose,
        test_llm=not args.no_llm
    )

    # Run tests
    success = await tester.run_all_tests()

    # Exit with appropriate code
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    asyncio.run(main())
toolsets.py — new file, 326 lines
@@ -0,0 +1,326 @@
#!/usr/bin/env python3
"""
Toolsets Module

This module provides a flexible system for defining and managing tool aliases/toolsets.
Toolsets allow you to group tools together for specific scenarios and can be composed
from individual tools or other toolsets.

Features:
- Define custom toolsets with specific tools
- Compose toolsets from other toolsets
- Built-in common toolsets for typical use cases
- Easy extension for new toolsets
- Support for dynamic toolset resolution

Usage:
    from toolsets import get_toolset, resolve_toolset, get_all_toolsets

    # Get tools for a specific toolset
    tools = get_toolset("research")

    # Resolve a toolset to get all tool names (including from composed toolsets)
    all_tools = resolve_toolset("full_stack")
"""

from typing import List, Dict, Any, Set, Optional
import json


# Core toolset definitions
# These can include individual tools or reference other toolsets
TOOLSETS = {
    # Basic toolsets - individual tool categories
    "web": {
        "description": "Web research and content extraction tools",
        "tools": ["web_search", "web_extract", "web_crawl"],
        "includes": []  # No other toolsets included
    },

    "vision": {
        "description": "Image analysis and vision tools",
        "tools": ["vision_analyze"],
        "includes": []
    },

    "image_gen": {
        "description": "Creative generation tools (images)",
        "tools": ["image_generate"],
        "includes": []
    },

    "terminal": {
        "description": "Terminal/command execution tools",
        "tools": ["terminal"],
        "includes": []
    },

    "moa": {
        "description": "Advanced reasoning and problem-solving tools",
        "tools": ["mixture_of_agents"],
        "includes": []
    },

    # Scenario-specific toolsets

    "debugging": {
        "description": "Debugging and troubleshooting toolkit",
        "tools": ["terminal"],
        "includes": ["web"]  # For searching error messages and solutions
    },

    "safe": {
        "description": "Safe toolkit without terminal access",
        "tools": ["mixture_of_agents"],
        "includes": ["web", "vision", "creative"]
    }
}


def get_toolset(name: str) -> Optional[Dict[str, Any]]:
    """
    Get a toolset definition by name.

    Args:
        name (str): Name of the toolset

    Returns:
        Dict: Toolset definition with description, tools, and includes
        None: If toolset not found
    """
    # Return toolset definition
    return TOOLSETS.get(name)


def resolve_toolset(name: str, visited: Set[str] = None) -> List[str]:
    """
    Recursively resolve a toolset to get all tool names.

    This function handles toolset composition by recursively resolving
    included toolsets and combining all tools.

    Args:
        name (str): Name of the toolset to resolve
        visited (Set[str]): Set of already visited toolsets (for cycle detection)

    Returns:
        List[str]: List of all tool names in the toolset
    """
    if visited is None:
        visited = set()

    # Check for cycles
    if name in visited:
        print(f"⚠️  Circular dependency detected in toolset '{name}'")
        return []

    visited.add(name)

    # Get toolset definition
    toolset = TOOLSETS.get(name)
    if not toolset:
        return []

    # Collect direct tools
    tools = set(toolset.get("tools", []))

    # Recursively resolve included toolsets
    for included_name in toolset.get("includes", []):
        included_tools = resolve_toolset(included_name, visited.copy())
        tools.update(included_tools)

    return list(tools)


def resolve_multiple_toolsets(toolset_names: List[str]) -> List[str]:
    """
    Resolve multiple toolsets and combine their tools.

    Args:
        toolset_names (List[str]): List of toolset names to resolve

    Returns:
        List[str]: Combined list of all tool names (deduplicated)
    """
    all_tools = set()

    for name in toolset_names:
        tools = resolve_toolset(name)
        all_tools.update(tools)

    return list(all_tools)


def get_all_toolsets() -> Dict[str, Dict[str, Any]]:
    """
    Get all available toolsets with their definitions.

    Returns:
        Dict: All toolset definitions
    """
    return TOOLSETS.copy()


def get_toolset_names() -> List[str]:
    """
    Get names of all available toolsets (excluding aliases).

    Returns:
        List[str]: List of toolset names
    """
    return list(TOOLSETS.keys())


def validate_toolset(name: str) -> bool:
    """
    Check if a toolset name is valid.

    Args:
        name (str): Toolset name to validate

    Returns:
        bool: True if valid, False otherwise
    """
    return name in TOOLSETS


def create_custom_toolset(
    name: str,
    description: str,
    tools: List[str] = None,
    includes: List[str] = None
) -> None:
    """
    Create a custom toolset at runtime.

    Args:
        name (str): Name for the new toolset
        description (str): Description of the toolset
        tools (List[str]): Direct tools to include
        includes (List[str]): Other toolsets to include
    """
    TOOLSETS[name] = {
        "description": description,
        "tools": tools or [],
        "includes": includes or []
    }


def get_toolset_info(name: str) -> Dict[str, Any]:
    """
    Get detailed information about a toolset including resolved tools.

    Args:
        name (str): Toolset name

    Returns:
        Dict: Detailed toolset information
    """
    toolset = get_toolset(name)
    if not toolset:
        return None

    resolved_tools = resolve_toolset(name)

    return {
        "name": name,
        "description": toolset["description"],
        "direct_tools": toolset["tools"],
        "includes": toolset["includes"],
        "resolved_tools": resolved_tools,
        "tool_count": len(resolved_tools),
        "is_composite": len(toolset["includes"]) > 0
    }


def print_toolset_tree(name: str, indent: int = 0) -> None:
    """
    Print a tree view of a toolset and its composition.

    Args:
        name (str): Toolset name
        indent (int): Current indentation level
    """
    prefix = "  " * indent
    toolset = get_toolset(name)

    if not toolset:
        print(f"{prefix}❌ Unknown toolset: {name}")
        return

    # Print toolset name and description
    print(f"{prefix}📦 {name}: {toolset['description']}")

    # Print direct tools
    if toolset["tools"]:
        print(f"{prefix}  🔧 Tools: {', '.join(toolset['tools'])}")

    # Print included toolsets
    if toolset["includes"]:
        print(f"{prefix}  📂 Includes:")
        for included in toolset["includes"]:
            print_toolset_tree(included, indent + 2)


if __name__ == "__main__":
    """
    Demo and testing of the toolsets system
    """
    print("🎯 Toolsets System Demo")
    print("=" * 60)

    # Show all available toolsets
    print("\n📦 Available Toolsets:")
    print("-" * 40)
    for name, toolset in get_all_toolsets().items():
|
||||||
|
info = get_toolset_info(name)
|
||||||
|
composite = "📂" if info["is_composite"] else "🔧"
|
||||||
|
print(f"{composite} {name:20} - {toolset['description']}")
|
||||||
|
print(f" Tools: {len(info['resolved_tools'])} total")
|
||||||
|
|
||||||
|
|
||||||
|
# Demo toolset resolution
|
||||||
|
print("\n🔍 Toolset Resolution Examples:")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
examples = ["research", "development", "full_stack", "minimal", "safe"]
|
||||||
|
for name in examples:
|
||||||
|
tools = resolve_toolset(name)
|
||||||
|
print(f"\n{name}:")
|
||||||
|
print(f" Resolved to {len(tools)} tools: {', '.join(sorted(tools))}")
|
||||||
|
|
||||||
|
# Show toolset composition tree
|
||||||
|
print("\n🌳 Toolset Composition Tree:")
|
||||||
|
print("-" * 40)
|
||||||
|
print("\nExample: 'content_creation' toolset:")
|
||||||
|
print_toolset_tree("content_creation")
|
||||||
|
|
||||||
|
print("\nExample: 'full_stack' toolset:")
|
||||||
|
print_toolset_tree("full_stack")
|
||||||
|
|
||||||
|
# Demo multiple toolset resolution
|
||||||
|
print("\n🔗 Multiple Toolset Resolution:")
|
||||||
|
print("-" * 40)
|
||||||
|
combined = resolve_multiple_toolsets(["minimal", "vision", "reasoning"])
|
||||||
|
print(f"Combining ['minimal', 'vision', 'reasoning']:")
|
||||||
|
print(f" Result: {', '.join(sorted(combined))}")
|
||||||
|
|
||||||
|
# Demo custom toolset creation
|
||||||
|
print("\n➕ Custom Toolset Creation:")
|
||||||
|
print("-" * 40)
|
||||||
|
create_custom_toolset(
|
||||||
|
name="my_custom",
|
||||||
|
description="My custom toolset for specific tasks",
|
||||||
|
tools=["web_search"],
|
||||||
|
includes=["terminal", "vision"]
|
||||||
|
)
|
||||||
|
|
||||||
|
custom_info = get_toolset_info("my_custom")
|
||||||
|
print(f"Created 'my_custom' toolset:")
|
||||||
|
print(f" Description: {custom_info['description']}")
|
||||||
|
print(f" Resolved tools: {', '.join(custom_info['resolved_tools'])}")
|
||||||
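Taken together, the resolution demo above reduces to one recursive walk with set-based deduplication. A minimal standalone sketch of that idea (the `TOOLSETS` registry below is illustrative, not the repository's actual definitions, and `resolve` is a simplified stand-in for `resolve_toolset`):

```python
# Minimal sketch of recursive toolset resolution with deduplication.
# The registry here is made up for illustration only.
TOOLSETS = {
    "minimal": {"tools": ["web_search"], "includes": []},
    "vision": {"tools": ["image_analyze"], "includes": []},
    "research": {"tools": ["extract_content"], "includes": ["minimal", "vision"]},
}


def resolve(name, seen=None):
    """Return the deduplicated set of tools for a toolset, following includes."""
    seen = seen if seen is not None else set()
    if name in seen:  # guard against include cycles
        return set()
    seen.add(name)
    entry = TOOLSETS.get(name)
    if entry is None:  # unknown toolsets resolve to nothing
        return set()
    tools = set(entry["tools"])
    for child in entry["includes"]:
        tools |= resolve(child, seen)
    return tools


print(sorted(resolve("research")))  # → ['extract_content', 'image_analyze', 'web_search']
```

Using a `set` makes combining overlapping toolsets idempotent, which is why `resolve_multiple_toolsets` can union results without double-counting shared tools.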
23
ui/__init__.py
Normal file
@@ -0,0 +1,23 @@
"""
Hermes Agent UI Package

A modular PySide6 UI for the Hermes AI Agent with real-time event streaming.

Modules:
- websocket_client: WebSocket communication
- event_widgets: Event display components
- main_window: Main application window
- hermes_ui: Application entry point
"""

from .websocket_client import WebSocketClient
from .event_widgets import CollapsibleEventWidget, InteractiveEventDisplayWidget
from .main_window import HermesMainWindow

__all__ = [
    'WebSocketClient',
    'CollapsibleEventWidget',
    'InteractiveEventDisplayWidget',
    'HermesMainWindow',
]
334
ui/event_widgets.py
Normal file
@@ -0,0 +1,334 @@
"""
Event display widgets for Hermes Agent UI.

This module provides widgets for displaying and managing real-time agent events
in a collapsible, filterable interface.
"""

import json
from datetime import datetime
from typing import Dict, Any

from PySide6.QtWidgets import (
    QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton,
    QCheckBox, QGroupBox, QFrame, QScrollArea
)
from PySide6.QtCore import Qt, QTimer
from PySide6.QtGui import QFont


class CollapsibleEventWidget(QFrame):
    """
    A single collapsible event with expand/collapse functionality.
    """

    def __init__(self, event: Dict[str, Any], parent=None):
        super().__init__(parent)
        self.event = event
        self.is_expanded = False
        self.event_type = event.get("event_type", "unknown")

        self.setFrameStyle(QFrame.Box | QFrame.Raised)
        self.setLineWidth(1)
        self.setup_ui()

    def setup_ui(self):
        """Initialize UI components."""
        layout = QVBoxLayout()
        layout.setContentsMargins(8, 8, 8, 8)
        layout.setSpacing(4)

        # Header (clickable)
        self.header_widget = QWidget()
        header_layout = QHBoxLayout()
        header_layout.setContentsMargins(0, 0, 0, 0)

        self.expand_indicator = QLabel("▶")
        self.expand_indicator.setFixedWidth(20)
        header_layout.addWidget(self.expand_indicator)

        self.summary_label = QLabel()
        self.summary_label.setFont(QFont("Arial", 10, QFont.Bold))
        self.update_summary()
        header_layout.addWidget(self.summary_label, 1)

        # Timestamp
        timestamp = self.event.get("timestamp", datetime.now().isoformat())
        time_str = datetime.fromisoformat(timestamp.replace('Z', '+00:00')).strftime("%H:%M:%S")
        time_label = QLabel(time_str)
        time_label.setStyleSheet("color: #888;")
        header_layout.addWidget(time_label)

        self.header_widget.setLayout(header_layout)
        self.header_widget.mousePressEvent = lambda e: self.toggle_expand()
        self.header_widget.setCursor(Qt.PointingHandCursor)

        layout.addWidget(self.header_widget)

        # Details (collapsible)
        self.details_widget = QWidget()
        self.details_layout = QVBoxLayout()
        self.details_layout.setContentsMargins(25, 5, 5, 5)
        self.populate_details()
        self.details_widget.setLayout(self.details_layout)
        self.details_widget.setVisible(False)

        layout.addWidget(self.details_widget)

        self.setLayout(layout)
        self.apply_colors()

    def apply_colors(self):
        """Apply color scheme based on event type."""
        colors = {
            "query": "#E8F5E9",         # Light green
            "api_call": "#E3F2FD",      # Light blue
            "response": "#F3E5F5",      # Light purple
            "tool_call": "#FFF3E0",     # Light orange
            "tool_result": "#E8F5E9",   # Light green
            "complete": "#E8F5E9",      # Light green
            "error": "#FFEBEE",         # Light red
            "session_start": "#F5F5F5"  # Light gray
        }

        bg_color = colors.get(self.event_type, "#FAFAFA")
        self.setStyleSheet(f"""
            CollapsibleEventWidget {{
                background-color: {bg_color};
                border: 1px solid #ddd;
                border-radius: 4px;
            }}
        """)

    def update_summary(self):
        """Update the summary label with event type."""
        self.summary_label.setText(f"- {self.event_type.upper()}")

    def populate_details(self):
        """Populate the details section with event data."""
        data = self.event.get("data", {})

        # Clear existing details
        while self.details_layout.count():
            item = self.details_layout.takeAt(0)
            if item.widget():
                item.widget().deleteLater()

        self.add_detail("Raw Data", json.dumps(data, indent=2), multiline=True)

    def add_detail(self, label: str, value: str, multiline: bool = True):
        """Add a detail row to the details section."""
        detail_widget = QWidget()
        detail_layout = QVBoxLayout() if multiline else QHBoxLayout()
        detail_layout.setContentsMargins(0, 2, 0, 2)

        label_widget = QLabel(f"<b>{label}:</b>")
        label_widget.setTextFormat(Qt.RichText)

        value_widget = QLabel(value)
        value_widget.setWordWrap(True)
        value_widget.setTextInteractionFlags(Qt.TextSelectableByMouse)

        if multiline:
            font = QFont()
            font.setStyleHint(QFont.Monospace)
            font.setPointSize(9)
            value_widget.setFont(font)
            value_widget.setStyleSheet("background-color: #f5f5f5; padding: 5px; border-radius: 3px;")
            detail_layout.addWidget(label_widget)
            detail_layout.addWidget(value_widget)
        else:
            detail_layout.addWidget(label_widget)
            detail_layout.addWidget(value_widget, 1)

        detail_widget.setLayout(detail_layout)
        self.details_layout.addWidget(detail_widget)

    def toggle_expand(self):
        """Toggle expanded/collapsed state."""
        self.is_expanded = not self.is_expanded
        self.details_widget.setVisible(self.is_expanded)
        self.expand_indicator.setText("▼" if self.is_expanded else "▶")


class InteractiveEventDisplayWidget(QWidget):
    """
    Interactive widget for displaying real-time agent events.

    Features:
    - Collapsible event items
    - Event type filtering
    - Expand/collapse all
    - Auto-scroll to latest events
    """

    def __init__(self):
        super().__init__()
        self.events = []
        self.event_widgets = []
        self.current_session = None
        self.filters = {
            "query": True,
            "api_call": True,
            "response": True,
            "tool_call": True,
            "tool_result": True,
            "complete": True,
            "error": True,
            "session_start": True
        }
        self.init_ui()

    def init_ui(self):
        """Initialize the UI components."""
        layout = QVBoxLayout()
        layout.setContentsMargins(5, 5, 5, 5)

        # Header with controls
        header_layout = QHBoxLayout()

        title = QLabel("📡 Real-time Event Stream")
        title.setFont(QFont("Arial", 12, QFont.Bold))
        header_layout.addWidget(title)

        header_layout.addStretch()

        # Expand/Collapse All buttons
        expand_all_btn = QPushButton("Expand All")
        expand_all_btn.clicked.connect(self.expand_all)
        header_layout.addWidget(expand_all_btn)

        collapse_all_btn = QPushButton("Collapse All")
        collapse_all_btn.clicked.connect(self.collapse_all)
        header_layout.addWidget(collapse_all_btn)

        # Clear button
        clear_btn = QPushButton("🗑️ Clear")
        clear_btn.clicked.connect(self.clear_events)
        header_layout.addWidget(clear_btn)

        layout.addLayout(header_layout)

        # Filter controls
        filter_group = QGroupBox("Event Filters (Show/Hide)")
        filter_layout = QHBoxLayout()
        filter_layout.setSpacing(10)

        self.filter_checkboxes = {}
        filter_configs = [
            ("query", "📝 Queries"),
            ("api_call", "🔄 API Calls"),
            ("response", "🤖 Responses"),
            ("tool_call", "🔧 Tool Calls"),
            ("tool_result", "✅ Results"),
            ("complete", "🎉 Complete"),
            ("error", "❌ Errors"),
        ]

        for event_type, label in filter_configs:
            checkbox = QCheckBox(label)
            checkbox.setChecked(True)
            checkbox.stateChanged.connect(lambda state, et=event_type: self.toggle_filter(et, state))
            self.filter_checkboxes[event_type] = checkbox
            filter_layout.addWidget(checkbox)

        filter_group.setLayout(filter_layout)
        layout.addWidget(filter_group)

        # Scroll area for events (kept as an attribute: after setWidget(),
        # the container's parent() is the viewport, not the QScrollArea, so
        # _scroll_to_bottom needs a direct reference)
        self.scroll_area = QScrollArea()
        self.scroll_area.setWidgetResizable(True)
        self.scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)

        # Container for event widgets
        self.events_container = QWidget()
        self.events_layout = QVBoxLayout()
        self.events_layout.setSpacing(5)
        self.events_layout.addStretch()  # Push events to top
        self.events_container.setLayout(self.events_layout)

        self.scroll_area.setWidget(self.events_container)
        layout.addWidget(self.scroll_area)

        self.setLayout(layout)

    def clear_events(self):
        """Clear all displayed events."""
        self.events.clear()
        self.event_widgets.clear()

        # Remove all widgets
        while self.events_layout.count() > 1:  # Keep the stretch
            item = self.events_layout.takeAt(0)
            if item.widget():
                item.widget().deleteLater()

        self.current_session = None

    def add_event(self, event: Dict[str, Any]):
        """Add an event to the display."""
        event_type = event.get("event_type", "unknown")
        session_id = event.get("session_id", "")

        # Track session changes - add session start event
        if self.current_session != session_id:
            self.current_session = session_id
            session_event = {
                "event_type": "session_start",
                "session_id": session_id,
                "timestamp": event.get("timestamp", datetime.now().isoformat()),
                "data": {
                    "session_id": session_id,
                    "start_time": event.get("timestamp", datetime.now().isoformat())
                }
            }
            self._add_event_widget(session_event)

        # Add the actual event
        self._add_event_widget(event)

    def _add_event_widget(self, event: Dict[str, Any]):
        """Internal method to add event widget."""
        event_widget = CollapsibleEventWidget(event)

        # Apply filter visibility
        event_type = event.get("event_type", "unknown")
        event_widget.setVisible(self.filters.get(event_type, True))

        # Insert before the stretch
        self.events_layout.insertWidget(self.events_layout.count() - 1, event_widget)

        self.events.append(event)
        self.event_widgets.append(event_widget)

        # Auto-scroll to bottom after widget is rendered
        QTimer.singleShot(50, self._scroll_to_bottom)

    def _scroll_to_bottom(self):
        """Scroll to the bottom of the events list."""
        scroll_bar = self.scroll_area.verticalScrollBar()
        scroll_bar.setValue(scroll_bar.maximum())

    def expand_all(self):
        """Expand all event widgets."""
        for widget in self.event_widgets:
            if not widget.is_expanded:
                widget.toggle_expand()

    def collapse_all(self):
        """Collapse all event widgets."""
        for widget in self.event_widgets:
            if widget.is_expanded:
                widget.toggle_expand()

    def toggle_filter(self, event_type: str, state: int):
        """Toggle visibility of events by type."""
        self.filters[event_type] = bool(state)

        # Update visibility of existing widgets
        for event, widget in zip(self.events, self.event_widgets):
            if event.get("event_type") == event_type:
                widget.setVisible(self.filters[event_type])
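Stripped of Qt, the filtering behavior in `toggle_filter` and `_add_event_widget` is just a per-type boolean map where unknown event types default to visible. A small sketch of only that decision logic (the event payloads here are made up for illustration):

```python
# Per-type visibility map, mirroring InteractiveEventDisplayWidget.filters.
filters = {"tool_call": True, "error": False}


def is_visible(event, filters):
    """Unknown event types default to visible, matching .get(event_type, True)."""
    return filters.get(event.get("event_type", "unknown"), True)


events = [
    {"event_type": "tool_call"},
    {"event_type": "error"},
    {"event_type": "session_start"},  # no filter entry -> shown by default
]
visible = [e["event_type"] for e in events if is_visible(e, filters)]
print(visible)  # → ['tool_call', 'session_start']
```

Defaulting unknown types to visible means newly added event types show up immediately without a matching checkbox, which is the same behavior the widget exhibits for `session_start`.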
102
ui/hermes_ui.py
Normal file
@@ -0,0 +1,102 @@
#!/usr/bin/env python3
"""
Hermes Agent - PySide6 Frontend

A modern desktop UI for the Hermes AI Agent with real-time event streaming.

Features:
- Query input with multi-line support
- Tool/toolset selection
- Model and API configuration
- Real-time event display via WebSocket
- Beautiful, responsive UI with dark theme
- Session history
- Safe exit handling (no segfaults)

Usage:
    python hermes_ui.py
"""

import sys
import signal
import os

# Suppress Qt logging warnings BEFORE importing Qt
os.environ['QT_LOGGING_RULES'] = 'qt.qpa.*=false'

from PySide6.QtWidgets import QApplication
from PySide6.QtCore import QTimer

from main_window import HermesMainWindow


def setup_signal_handlers(app: QApplication) -> QTimer:
    """
    Setup signal handlers for graceful shutdown on Ctrl+C.

    This prevents segmentation faults by:
    1. Catching SIGINT/SIGTERM signals
    2. Creating a timer that keeps Python responsive to signals
    3. Calling app.quit() for proper Qt cleanup

    Args:
        app: The QApplication instance

    Returns:
        Timer that keeps the Python interpreter responsive to signals
    """
    def signal_handler(signum, frame):
        """Handle interrupt signals gracefully."""
        print("\n🛑 Interrupt received, shutting down gracefully...")
        app.quit()

    signal.signal(signal.SIGINT, signal_handler)   # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Termination signal

    # CRITICAL: Create a timer to wake up the Python interpreter periodically.
    # This allows Python to process signals while Qt's event loop is running.
    # Without this, Ctrl+C will not work and may cause segfaults.
    timer = QTimer()
    timer.timeout.connect(lambda: None)  # Empty callback just to wake up Python
    timer.start(100)  # Check every 100ms

    return timer


def main():
    """Main entry point for the application."""
    # Create application
    app = QApplication(sys.argv)

    # Set application metadata
    app.setApplicationName("Hermes Agent")
    app.setOrganizationName("Hermes")
    app.setApplicationVersion("1.0.0")

    # Setup signal handlers for safe Ctrl+C handling (prevents segfaults!)
    timer = setup_signal_handlers(app)

    # Apply dark theme (optional)
    # Uncomment to enable dark mode
    # app.setStyle("Fusion")
    # palette = QPalette()
    # palette.setColor(QPalette.Window, QColor(53, 53, 53))
    # palette.setColor(QPalette.WindowText, Qt.white)
    # app.setPalette(palette)

    # Create and show main window
    window = HermesMainWindow()
    window.show()

    print("✨ Hermes Agent UI started")
    print("   Press Ctrl+C to exit gracefully")

    # Start event loop
    exit_code = app.exec()

    print("👋 Hermes Agent UI closed")
    sys.exit(exit_code)


if __name__ == "__main__":
    main()
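The handler-plus-timer pattern above works because Python-level signal handlers only run while the interpreter is executing bytecode, which is exactly what the 100 ms wake-up timer guarantees during Qt's event loop. The Qt-free half of the pattern can be sketched as follows (the flag-based handler is illustrative — the real UI calls `app.quit()` instead):

```python
import signal

shutdown_requested = False


def handler(signum, frame):
    # In the UI this would call app.quit(); here we just record the request.
    global shutdown_requested
    shutdown_requested = True


# Register the same pair of signals hermes_ui.py handles.
signal.signal(signal.SIGINT, handler)
signal.signal(signal.SIGTERM, handler)

# Delivering a signal to our own process shows the handler firing
# (signal.raise_signal is available on Python 3.8+).
signal.raise_signal(signal.SIGTERM)
print(shutdown_requested)  # → True
```

Under Qt there is no bytecode running while `app.exec()` blocks in C, so without the periodic timer the Python handler would never get a chance to execute — that is the segfault-prone case the comments in `setup_signal_handlers` warn about.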
375
ui/main_window.py
Normal file
@@ -0,0 +1,375 @@
"""
Main window for Hermes Agent UI.

This module provides the main application window with controls for
submitting queries, configuring settings, and viewing real-time events.
"""

import requests
from typing import Dict, Any

from PySide6.QtWidgets import (
    QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QTextEdit,
    QPushButton, QLabel, QLineEdit, QComboBox, QCheckBox,
    QGroupBox, QSplitter, QListWidget, QListWidgetItem,
    QSpinBox, QMessageBox
)
from PySide6.QtCore import Qt, Slot, QTimer
from PySide6.QtGui import QFont

from .websocket_client import WebSocketClient
from .event_widgets import InteractiveEventDisplayWidget


class HermesMainWindow(QMainWindow):
    """
    Main window for Hermes Agent UI.

    Provides interface for:
    - Submitting queries
    - Configuring agent settings
    - Viewing real-time events
    - Managing sessions
    """

    def __init__(self):
        super().__init__()
        self.api_base_url = "http://localhost:8000"
        self.ws_client = None
        self.current_session_id = None
        self.available_toolsets = []
        self.is_closing = False  # Flag to prevent reconnection during shutdown

        self.init_ui()
        self.setup_websocket()
        self.load_available_tools()

    def init_ui(self):
        """Initialize the user interface."""
        self.setWindowTitle("Hermes Agent - AI Assistant UI")
        self.setGeometry(100, 100, 1400, 900)

        # Central widget
        central_widget = QWidget()
        self.setCentralWidget(central_widget)

        # Main layout (horizontal split)
        main_layout = QHBoxLayout()

        # Left panel: Controls
        left_panel = self.create_control_panel()

        # Right panel: Event display
        right_panel = self.create_event_panel()

        # Splitter for resizable panels
        splitter = QSplitter(Qt.Horizontal)
        splitter.addWidget(left_panel)
        splitter.addWidget(right_panel)
        splitter.setStretchFactor(0, 1)  # Control panel
        splitter.setStretchFactor(1, 2)  # Event panel (larger)

        main_layout.addWidget(splitter)
        central_widget.setLayout(main_layout)

        # Status bar
        self.statusBar().showMessage("Ready")

    def create_control_panel(self) -> QWidget:
        """Create the left control panel."""
        panel = QWidget()
        layout = QVBoxLayout()

        # Title
        title = QLabel("🤖 Hermes Agent Control")
        title.setFont(QFont("Arial", 14, QFont.Bold))
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        # Query input group
        query_group = QGroupBox("Query Input")
        query_layout = QVBoxLayout()

        self.query_input = QTextEdit()
        self.query_input.setPlaceholderText("Enter your query here...")
        self.query_input.setMaximumHeight(150)
        query_layout.addWidget(self.query_input)

        self.submit_btn = QPushButton("🚀 Submit Query")
        self.submit_btn.setFont(QFont("Arial", 11, QFont.Bold))
        self.submit_btn.setStyleSheet("QPushButton { background-color: #4CAF50; color: white; padding: 10px; }")
        self.submit_btn.clicked.connect(self.submit_query)
        query_layout.addWidget(self.submit_btn)

        query_group.setLayout(query_layout)
        layout.addWidget(query_group)

        # Model configuration group
        model_group = QGroupBox("Model Configuration")
        model_layout = QVBoxLayout()

        # Model selection
        model_layout.addWidget(QLabel("Model:"))
        self.model_combo = QComboBox()
        self.model_combo.addItems([
            "claude-sonnet-4-5-20250929",
            "claude-opus-4-20250514",
            "gpt-4",
            "gpt-4-turbo"
        ])
        model_layout.addWidget(self.model_combo)

        # API Base URL
        model_layout.addWidget(QLabel("API Base URL:"))
        self.base_url_input = QLineEdit("https://api.anthropic.com/v1/")
        model_layout.addWidget(self.base_url_input)

        # Max turns
        model_layout.addWidget(QLabel("Max Turns:"))
        self.max_turns_spin = QSpinBox()
        self.max_turns_spin.setMinimum(1)
        self.max_turns_spin.setMaximum(50)
        self.max_turns_spin.setValue(10)
        model_layout.addWidget(self.max_turns_spin)

        model_group.setLayout(model_layout)
        layout.addWidget(model_group)

        # Tools configuration group
        tools_group = QGroupBox("Tools & Toolsets")
        tools_layout = QVBoxLayout()

        tools_layout.addWidget(QLabel("Select Toolsets:"))
        self.toolsets_list = QListWidget()
        self.toolsets_list.setSelectionMode(QListWidget.MultiSelection)
        self.toolsets_list.setMaximumHeight(150)
        tools_layout.addWidget(self.toolsets_list)

        tools_group.setLayout(tools_layout)
        layout.addWidget(tools_group)

        # Options group
        options_group = QGroupBox("Options")
        options_layout = QVBoxLayout()

        self.mock_mode_checkbox = QCheckBox("Mock Web Tools (Testing)")
        options_layout.addWidget(self.mock_mode_checkbox)

        self.verbose_checkbox = QCheckBox("Verbose Logging")
        options_layout.addWidget(self.verbose_checkbox)

        options_layout.addWidget(QLabel("Mock Delay (seconds):"))
        self.mock_delay_spin = QSpinBox()
        self.mock_delay_spin.setMinimum(1)
        self.mock_delay_spin.setMaximum(300)
        self.mock_delay_spin.setValue(60)
        options_layout.addWidget(self.mock_delay_spin)

        options_group.setLayout(options_layout)
        layout.addWidget(options_group)

        # Connection status
        self.connection_status = QLabel("🔴 Disconnected")
        self.connection_status.setAlignment(Qt.AlignCenter)
        self.connection_status.setStyleSheet("QLabel { padding: 5px; background-color: #F44336; color: white; border-radius: 3px; }")
        layout.addWidget(self.connection_status)

        # Add stretch to push everything to top
        layout.addStretch()

        panel.setLayout(layout)
        return panel

    def create_event_panel(self) -> QWidget:
        """Create the right event display panel."""
        panel = QWidget()
        layout = QVBoxLayout()

        # Event display widget
        self.event_widget = InteractiveEventDisplayWidget()
        layout.addWidget(self.event_widget)

        panel.setLayout(layout)
        return panel

    def setup_websocket(self):
        """Setup WebSocket connection for real-time events."""
        self.ws_client = WebSocketClient("ws://localhost:8000/ws")

        # Connect signals
        self.ws_client.connected.connect(self.on_ws_connected)
        self.ws_client.disconnected.connect(self.on_ws_disconnected)
        self.ws_client.error.connect(self.on_ws_error)
        self.ws_client.event_received.connect(self.on_event_received)

        # Start connection
        self.ws_client.connect()

    @Slot()
    def on_ws_connected(self):
        """Called when WebSocket connection is established."""
        self.connection_status.setText("🟢 Connected")
        self.connection_status.setStyleSheet("QLabel { padding: 5px; background-color: #4CAF50; color: white; border-radius: 3px; }")
        self.statusBar().showMessage("WebSocket connected")

    @Slot()
    def on_ws_disconnected(self):
        """Called when WebSocket connection is lost."""
        # Don't attempt reconnection if we're closing the application
        if self.is_closing:
            return

        self.connection_status.setText("🔴 Disconnected")
        self.connection_status.setStyleSheet("QLabel { padding: 5px; background-color: #F44336; color: white; border-radius: 3px; }")
        self.statusBar().showMessage("WebSocket disconnected - attempting reconnect...")

        # Attempt reconnect after 5 seconds
        QTimer.singleShot(5000, self.ws_client.connect)

    @Slot(str)
    def on_ws_error(self, error: str):
        """Called when WebSocket error occurs."""
        self.statusBar().showMessage(f"WebSocket error: {error}")

    @Slot(dict)
    def on_event_received(self, event: Dict[str, Any]):
        """
        Called when an event is received from WebSocket.

        Args:
            event: Event data from server
        """
        self.event_widget.add_event(event)

        # Update status for specific events
        event_type = event.get("event_type")
        if event_type == "query":
            self.statusBar().showMessage("Query received - agent processing...")
        elif event_type == "complete":
            self.statusBar().showMessage("Agent completed!")
            self.submit_btn.setEnabled(True)

    def load_available_tools(self):
        """Load available toolsets from the API."""
        try:
            response = requests.get(f"{self.api_base_url}/tools", timeout=5)
            if response.status_code == 200:
                data = response.json()
                toolsets = data.get("toolsets", [])

                self.available_toolsets = toolsets
                self.toolsets_list.clear()

                for toolset in toolsets:
                    name = toolset.get("name", "")
                    description = toolset.get("description", "")
                    tool_count = toolset.get("tool_count", 0)

                    item_text = f"{name} ({tool_count} tools) - {description}"
                    item = QListWidgetItem(item_text)
                    item.setData(Qt.UserRole, name)  # Store toolset name
                    self.toolsets_list.addItem(item)

                self.statusBar().showMessage(f"Loaded {len(toolsets)} toolsets")
            else:
                self.statusBar().showMessage("Failed to load toolsets from API")

        except requests.exceptions.RequestException as e:
            self.statusBar().showMessage(f"Error loading toolsets: {str(e)}")
            # Add some default toolsets
            default_toolsets = ["web", "vision", "terminal", "research"]
            for ts in default_toolsets:
                item = QListWidgetItem(f"{ts} (default)")
                item.setData(Qt.UserRole, ts)
                self.toolsets_list.addItem(item)

    @Slot()
    def submit_query(self):
        """Submit query to the agent API."""
        query = self.query_input.toPlainText().strip()

        if not query:
            QMessageBox.warning(self, "No Query", "Please enter a query first!")
            return

        # Get selected toolsets
        selected_toolsets = []
        for i in range(self.toolsets_list.count()):
            item = self.toolsets_list.item(i)
            if item.isSelected():
                toolset_name = item.data(Qt.UserRole)
                selected_toolsets.append(toolset_name)

        # Build request payload
        payload = {
            "query": query,
            "model": self.model_combo.currentText(),
            "base_url": self.base_url_input.text(),
            "max_turns": self.max_turns_spin.value(),
            "enabled_toolsets": selected_toolsets if selected_toolsets else None,
            "mock_web_tools": self.mock_mode_checkbox.isChecked(),
|
||||||
|
"mock_delay": self.mock_delay_spin.value(),
|
||||||
|
"verbose": self.verbose_checkbox.isChecked()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Disable submit button during execution
|
||||||
|
self.submit_btn.setEnabled(False)
|
||||||
|
self.submit_btn.setText("⏳ Running...")
|
||||||
|
self.statusBar().showMessage("Submitting query to agent...")
|
||||||
|
|
||||||
|
# Submit to API
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.api_base_url}/agent/run",
|
||||||
|
json=payload,
|
||||||
|
timeout=10
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
result = response.json()
|
||||||
|
session_id = result.get("session_id", "")
|
||||||
|
self.current_session_id = session_id
|
||||||
|
|
||||||
|
self.statusBar().showMessage(f"Agent started! Session: {session_id[:8]}...")
|
||||||
|
|
||||||
|
# Clear event display for new session (optional)
|
||||||
|
# self.event_widget.clear_events()
|
||||||
|
|
||||||
|
else:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"API Error",
|
||||||
|
f"Failed to start agent: {response.status_code}\n{response.text}"
|
||||||
|
)
|
||||||
|
self.submit_btn.setEnabled(True)
|
||||||
|
self.submit_btn.setText("🚀 Submit Query")
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Connection Error",
|
||||||
|
f"Failed to connect to API server:\n{str(e)}\n\nMake sure the server is running:\npython logging_server.py"
|
||||||
|
)
|
||||||
|
self.submit_btn.setEnabled(True)
|
||||||
|
self.submit_btn.setText("🚀 Submit Query")
|
||||||
|
|
||||||
|
# Re-enable button after short delay (UI feedback)
|
||||||
|
QTimer.singleShot(2000, lambda: self.submit_btn.setText("🚀 Submit Query"))
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Clean up resources before exit."""
|
||||||
|
print("Cleaning up resources...")
|
||||||
|
self.is_closing = True
|
||||||
|
|
||||||
|
if self.ws_client:
|
||||||
|
try:
|
||||||
|
self.ws_client.disconnect()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error disconnecting WebSocket: {e}")
|
||||||
|
|
||||||
|
def closeEvent(self, event):
|
||||||
|
"""Handle window close event - ensures clean shutdown."""
|
||||||
|
print("Closing application...")
|
||||||
|
self.cleanup()
|
||||||
|
event.accept()
|
||||||
|
|
||||||
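The `on_event_received` slot above dispatches on `event.get("event_type")`. A minimal, PySide6-free sketch of that dispatch, so the status-bar logic can be exercised standalone; the only key the code relies on is `event_type`, and the sample payloads here are illustrative assumptions rather than a documented schema:

```python
from typing import Optional

def status_for_event(event: dict) -> Optional[str]:
    """Mirror the status-bar dispatch in on_event_received for one event dict."""
    event_type = event.get("event_type")
    if event_type == "query":
        return "Query received - agent processing..."
    if event_type == "complete":
        return "Agent completed!"
    return None  # other event types only update the event list, not the status bar

events = [
    {"event_type": "query"},      # keys beyond event_type are illustrative
    {"event_type": "tool_call"},
    {"event_type": "complete"},
]
print([status_for_event(e) for e in events])
# → ['Query received - agent processing...', None, 'Agent completed!']
```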
115	ui/start_hermes_ui.sh	Executable file
@@ -0,0 +1,115 @@
#!/bin/bash
# Hermes Agent UI Launcher
#
# This script starts both the API server and UI application.
# It will run them in the background and provide a clean shutdown.

set -e

# Colors for output
GREEN='\033[0;32m'
BLUE='\033[0;34m'
RED='\033[0;31m'
NC='\033[0m' # No Color

echo -e "${BLUE}🚀 Hermes Agent UI Launcher${NC}"
echo "================================"
echo ""

# Check if Python is available
if ! command -v python3 &> /dev/null; then
    echo -e "${RED}❌ Python 3 not found. Please install Python 3.${NC}"
    exit 1
fi

# Check if virtual environment exists
if [ -d "../../env" ]; then
    echo -e "${GREEN}✓ Activating virtual environment${NC}"
    source ../../env/bin/activate
else
    echo -e "${BLUE}ℹ No virtual environment found, using system Python${NC}"
fi

# Check dependencies
echo -e "${BLUE}Checking dependencies...${NC}"
python3 -c "import PySide6" 2>/dev/null || {
    echo -e "${RED}❌ PySide6 not installed${NC}"
    echo -e "${BLUE}Installing dependencies...${NC}"
    pip install -r ../requirements.txt
}

# Check for API keys
if [ -z "$ANTHROPIC_API_KEY" ]; then
    echo -e "${RED}⚠️ Warning: ANTHROPIC_API_KEY not set${NC}"
    echo "   Set it with: export ANTHROPIC_API_KEY='your-key'"
    echo ""
fi

# Function to clean up on exit
cleanup() {
    echo ""
    echo -e "${BLUE}🛑 Shutting down Hermes Agent...${NC}"
    if [ ! -z "$SERVER_PID" ]; then
        kill $SERVER_PID 2>/dev/null || true
        echo -e "${GREEN}✓ API Server stopped${NC}"
    fi
    if [ ! -z "$UI_PID" ]; then
        kill $UI_PID 2>/dev/null || true
        echo -e "${GREEN}✓ UI Application stopped${NC}"
    fi
    echo -e "${GREEN}✓ Cleanup complete${NC}"
    exit 0
}

# Set up trap for cleanup
trap cleanup SIGINT SIGTERM EXIT

# Start API server in background
echo -e "${BLUE}Starting API Server...${NC}"
cd ../api_endpoint
python3 logging_server.py > /tmp/hermes_server.log 2>&1 &
SERVER_PID=$!
cd ../ui

# Wait for server to start
echo -e "${BLUE}Waiting for server to start...${NC}"
sleep 3

# Check if server is running
if ! kill -0 $SERVER_PID 2>/dev/null; then
    echo -e "${RED}❌ Server failed to start. Check /tmp/hermes_server.log${NC}"
    tail -20 /tmp/hermes_server.log
    exit 1
fi

# Check if server is responding
if curl -s http://localhost:8000/ > /dev/null; then
    echo -e "${GREEN}✓ API Server running on http://localhost:8000${NC}"
else
    echo -e "${RED}❌ Server not responding. Check /tmp/hermes_server.log${NC}"
    exit 1
fi

# Start UI application
echo -e "${BLUE}Starting UI Application...${NC}"
python3 hermes_ui.py &
UI_PID=$!

echo ""
echo -e "${GREEN}================================${NC}"
echo -e "${GREEN}✓ Hermes Agent UI is running!${NC}"
echo -e "${GREEN}================================${NC}"
echo ""
echo -e "${BLUE}📊 Component Status:${NC}"
echo -e "   API Server: http://localhost:8000 (PID: $SERVER_PID)"
echo -e "   UI App:     Running (PID: $UI_PID)"
echo -e "   Server Log: /tmp/hermes_server.log"
echo ""
echo -e "${BLUE}Press Ctrl+C to stop all services${NC}"
echo ""

# Wait for UI to exit
wait $UI_PID

# Cleanup will be triggered by trap
264	ui/test_ui_flow.py	Normal file
@@ -0,0 +1,264 @@
#!/usr/bin/env python3
"""
Test script to verify the UI flow works correctly.

This tests:
1. API server is running
2. WebSocket connection works
3. Agent can be started via API
4. Events are broadcast properly
"""

import requests
import json
import time
import websocket
import threading

API_URL = "http://localhost:8000"
WS_URL = "ws://localhost:8000/ws"


def test_api_server():
    """Test if the API server is running."""
    print("🔍 Testing API server...")
    try:
        response = requests.get(f"{API_URL}/", timeout=5)
        if response.status_code == 200:
            data = response.json()
            print(f"✅ API server is running: {data.get('service')}")
            print(f"   Active connections: {data.get('active_connections')}")
            return True
        else:
            print(f"❌ API server returned: {response.status_code}")
            return False
    except Exception as e:
        print(f"❌ API server not accessible: {e}")
        return False


def test_tools_endpoint():
    """Test if the tools endpoint works."""
    print("\n🔍 Testing tools endpoint...")
    try:
        response = requests.get(f"{API_URL}/tools", timeout=5)
        if response.status_code == 200:
            data = response.json()
            toolsets = data.get("toolsets", [])
            print(f"✅ Tools endpoint works - {len(toolsets)} toolsets available")
            for ts in toolsets[:3]:
                print(f"   • {ts.get('name')} ({ts.get('tool_count')} tools)")
            return True
        else:
            print(f"❌ Tools endpoint failed: {response.status_code}")
            return False
    except Exception as e:
        print(f"❌ Tools endpoint error: {e}")
        return False


def test_websocket():
    """Test WebSocket connection."""
    print("\n🔍 Testing WebSocket connection...")

    connected = threading.Event()
    message_received = threading.Event()
    messages = []

    def on_open(ws):
        print("✅ WebSocket connected")
        connected.set()

    def on_message(ws, message):
        data = json.loads(message)
        messages.append(data)
        message_received.set()
        print(f"📨 Received: {data.get('event_type', 'unknown')}")

    def on_error(ws, error):
        print(f"❌ WebSocket error: {error}")

    def on_close(ws, close_status_code, close_msg):
        print(f"🔌 WebSocket closed: {close_status_code}")

    ws = websocket.WebSocketApp(
        WS_URL,
        on_open=on_open,
        on_message=on_message,
        on_error=on_error,
        on_close=on_close
    )

    # Run WebSocket in background
    ws_thread = threading.Thread(target=lambda: ws.run_forever(), daemon=True)
    ws_thread.start()

    # Wait for connection
    if connected.wait(timeout=5):
        print("✅ WebSocket connection established")
        ws.close()
        return True
    else:
        print("❌ WebSocket connection timeout")
        ws.close()
        return False


def test_agent_run():
    """Test running the agent via API."""
    print("\n🔍 Testing agent run via API (mock mode)...")

    # Start listening for events first
    events = []
    ws_connected = threading.Event()
    session_complete = threading.Event()

    def on_message(ws, message):
        data = json.loads(message)
        events.append(data)
        event_type = data.get("event_type")
        print(f"   📨 Event: {event_type}")

        if event_type == "complete":
            session_complete.set()

    def on_open(ws):
        ws_connected.set()

    # Connect WebSocket
    ws = websocket.WebSocketApp(
        WS_URL,
        on_open=on_open,
        on_message=on_message
    )

    ws_thread = threading.Thread(target=lambda: ws.run_forever(), daemon=True)
    ws_thread.start()

    # Wait for WebSocket connection
    if not ws_connected.wait(timeout=5):
        print("❌ WebSocket didn't connect")
        ws.close()
        return False

    print("✅ WebSocket connected, starting agent...")

    # Submit agent run
    payload = {
        "query": "Test query for UI flow verification",
        "model": "claude-sonnet-4-5-20250929",
        "base_url": "https://api.anthropic.com/v1/",
        "enabled_toolsets": ["web"],
        "max_turns": 5,
        "mock_web_tools": True,  # Use mock mode to avoid API costs
        "mock_delay": 2,  # Fast for testing
        "verbose": False
    }

    try:
        response = requests.post(f"{API_URL}/agent/run", json=payload, timeout=10)

        if response.status_code == 200:
            result = response.json()
            session_id = result.get("session_id")
            print(f"✅ Agent started: {session_id[:8]}...")

            # Wait for completion (or timeout)
            print("⏳ Waiting for agent to complete (up to 30s)...")
            if session_complete.wait(timeout=30):
                print(f"✅ Agent completed! Received {len(events)} events:")

                # Count event types
                event_counts = {}
                for evt in events:
                    evt_type = evt.get("event_type", "unknown")
                    event_counts[evt_type] = event_counts.get(evt_type, 0) + 1

                for evt_type, count in event_counts.items():
                    print(f"   • {evt_type}: {count}")

                # Check we got expected events
                expected_events = ["query", "api_call", "response", "complete"]
                missing = [e for e in expected_events if e not in event_counts]

                if missing:
                    print(f"⚠️ Missing expected events: {missing}")
                else:
                    print("✅ All expected event types received!")

                ws.close()
                return True
            else:
                print(f"⚠️ Timeout waiting for completion. Got {len(events)} events so far.")
                ws.close()
                return False

        else:
            print(f"❌ Agent start failed: {response.status_code}")
            print(f"   Response: {response.text}")
            ws.close()
            return False

    except Exception as e:
        print(f"❌ Agent run error: {e}")
        import traceback
        traceback.print_exc()
        ws.close()
        return False


def main():
    """Run all tests."""
    print("=" * 60)
    print("🧪 Hermes Agent UI Flow Test")
    print("=" * 60)
    print("\nThis will test the complete flow:")
    print("  1. API server connectivity")
    print("  2. Tools endpoint")
    print("  3. WebSocket connection")
    print("  4. Agent execution via API (mock mode)")
    print("  5. Event streaming to UI")
    print("\n" + "=" * 60)

    results = []

    # Test 1: API server
    results.append(("API Server", test_api_server()))

    # Test 2: Tools endpoint
    results.append(("Tools Endpoint", test_tools_endpoint()))

    # Test 3: WebSocket
    results.append(("WebSocket Connection", test_websocket()))

    # Test 4: Agent run
    results.append(("Agent Execution + Events", test_agent_run()))

    # Summary
    print("\n" + "=" * 60)
    print("📊 TEST SUMMARY")
    print("=" * 60)

    for test_name, passed in results:
        status = "✅ PASS" if passed else "❌ FAIL"
        print(f"{status} - {test_name}")

    all_passed = all(r[1] for r in results)

    print("\n" + "=" * 60)
    if all_passed:
        print("🎉 ALL TESTS PASSED!")
        print("\n✅ The UI flow is working correctly!")
        print("   You can now use the UI to:")
        print("   • Submit queries")
        print("   • View real-time events")
        print("   • See tool executions")
        print("   • Get final responses")
    else:
        print("❌ SOME TESTS FAILED")
        print("\nMake sure:")
        print("  1. API server is running: python api_endpoint/logging_server.py")
        print("  2. ANTHROPIC_API_KEY is set in environment")
        print("  3. All dependencies are installed: pip install -r requirements.txt")
    print("=" * 60)

    return 0 if all_passed else 1


if __name__ == "__main__":
    exit(main())
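The event-type tally inside `test_agent_run` above can be factored out as a small sketch, so the expected-events check is exercisable without a running server; the function name `summarize_events` is illustrative, and the default `expected` tuple is the same list the test uses:

```python
def summarize_events(events, expected=("query", "api_call", "response", "complete")):
    """Tally event types and report which expected types never arrived."""
    counts = {}
    for evt in events:
        evt_type = evt.get("event_type", "unknown")
        counts[evt_type] = counts.get(evt_type, 0) + 1
    missing = [e for e in expected if e not in counts]
    return counts, missing

counts, missing = summarize_events([
    {"event_type": "query"},
    {"event_type": "api_call"},
    {"event_type": "api_call"},
    {"event_type": "complete"},
])
print(counts)   # → {'query': 1, 'api_call': 2, 'complete': 1}
print(missing)  # → ['response']
```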
91	ui/websocket_client.py	Normal file
@@ -0,0 +1,91 @@
"""
|
||||||
|
WebSocket client for real-time event streaming from Hermes Agent.
|
||||||
|
|
||||||
|
This module provides a WebSocket client that runs in a separate thread
|
||||||
|
and emits Qt signals when events are received from the server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import websocket
|
||||||
|
from PySide6.QtCore import QObject, Signal
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketClient(QObject):
|
||||||
|
"""
|
||||||
|
WebSocket client for receiving real-time agent events.
|
||||||
|
|
||||||
|
Runs in a separate thread and emits Qt signals when events arrive.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Signals for event communication
|
||||||
|
event_received = Signal(dict) # Emits parsed event data
|
||||||
|
connected = Signal()
|
||||||
|
disconnected = Signal()
|
||||||
|
error = Signal(str)
|
||||||
|
|
||||||
|
def __init__(self, url: str = "ws://localhost:8000/ws"):
|
||||||
|
super().__init__()
|
||||||
|
self.url = url
|
||||||
|
self.ws = None
|
||||||
|
self.running = False
|
||||||
|
self.thread = None
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
"""Start WebSocket connection in background thread."""
|
||||||
|
if self.running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.running = True
|
||||||
|
self.thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
"""Stop WebSocket connection."""
|
||||||
|
self.running = False
|
||||||
|
if self.ws:
|
||||||
|
try:
|
||||||
|
self.ws.close()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error closing WebSocket: {e}")
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
"""WebSocket event loop (runs in background thread)."""
|
||||||
|
try:
|
||||||
|
self.ws = websocket.WebSocketApp(
|
||||||
|
self.url,
|
||||||
|
on_open=self._on_open,
|
||||||
|
on_message=self._on_message,
|
||||||
|
on_error=self._on_error,
|
||||||
|
on_close=self._on_close
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run forever with reconnection
|
||||||
|
self.ws.run_forever(ping_interval=300, ping_timeout=60)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.error.emit(f"WebSocket error: {str(e)}")
|
||||||
|
|
||||||
|
def _on_open(self, ws):
|
||||||
|
"""Called when WebSocket connection is established."""
|
||||||
|
print("WebSocket connected")
|
||||||
|
self.connected.emit()
|
||||||
|
|
||||||
|
def _on_message(self, ws, message):
|
||||||
|
"""Called when a message is received from the server."""
|
||||||
|
try:
|
||||||
|
data = json.loads(message)
|
||||||
|
self.event_received.emit(data)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f" Failed to parse WebSocket message: {e}")
|
||||||
|
|
||||||
|
def _on_error(self, ws, error):
|
||||||
|
"""Called when an error occurs."""
|
||||||
|
print(f"WebSocket error: {error}")
|
||||||
|
self.error.emit(str(error))
|
||||||
|
|
||||||
|
def _on_close(self, ws, close_status_code, close_msg):
|
||||||
|
"""Called when WebSocket connection is closed."""
|
||||||
|
print(f"🔌 WebSocket disconnected: {close_status_code} - {close_msg}")
|
||||||
|
self.disconnected.emit()
|
||||||
|
|
||||||
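The core of `WebSocketClient._on_message` is "parse the frame as JSON, emit the dict, silently drop anything unparseable". A standalone sketch of that behavior, with a plain callback standing in for the Qt signal so it runs without PySide6 or a server (the helper name `handle_message` is illustrative):

```python
import json

def handle_message(message: str, emit) -> bool:
    """Parse a raw WebSocket frame; emit the dict on success, drop it otherwise."""
    try:
        emit(json.loads(message))
        return True
    except json.JSONDecodeError:
        return False

received = []
ok = handle_message('{"event_type": "complete"}', received.append)
bad = handle_message('not json', received.append)
print(ok, bad, received)  # → True False [{'event_type': 'complete'}]
```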
349	vision_tools.py	Normal file
@@ -0,0 +1,349 @@
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Vision Tools Module
|
||||||
|
|
||||||
|
This module provides vision analysis tools that work with image URLs.
|
||||||
|
Uses Gemini Flash via Nous Research API for intelligent image understanding.
|
||||||
|
|
||||||
|
Available tools:
|
||||||
|
- vision_analyze_tool: Analyze images from URLs with custom prompts
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Comprehensive image description
|
||||||
|
- Context-aware analysis based on user queries
|
||||||
|
- Proper error handling and validation
|
||||||
|
- Debug logging support
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from vision_tools import vision_analyze_tool
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Analyze an image
|
||||||
|
result = await vision_analyze_tool(
|
||||||
|
image_url="https://example.com/image.jpg",
|
||||||
|
user_prompt="What architectural style is this building?"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Initialize Nous Research API client for vision processing
|
||||||
|
nous_client = AsyncOpenAI(
|
||||||
|
api_key=os.getenv("NOUS_API_KEY"),
|
||||||
|
base_url="https://inference-api.nousresearch.com/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configuration for vision processing
|
||||||
|
DEFAULT_VISION_MODEL = "gemini-2.5-flash"
|
||||||
|
|
||||||
|
# Debug mode configuration
|
||||||
|
DEBUG_MODE = os.getenv("VISION_TOOLS_DEBUG", "false").lower() == "true"
|
||||||
|
DEBUG_SESSION_ID = str(uuid.uuid4())
|
||||||
|
DEBUG_LOG_PATH = Path("./logs")
|
||||||
|
DEBUG_DATA = {
|
||||||
|
"session_id": DEBUG_SESSION_ID,
|
||||||
|
"start_time": datetime.datetime.now().isoformat(),
|
||||||
|
"debug_enabled": DEBUG_MODE,
|
||||||
|
"tool_calls": []
|
||||||
|
} if DEBUG_MODE else None
|
||||||
|
|
||||||
|
# Create logs directory if debug mode is enabled
|
||||||
|
if DEBUG_MODE:
|
||||||
|
DEBUG_LOG_PATH.mkdir(exist_ok=True)
|
||||||
|
print(f"🐛 Vision debug mode enabled - Session ID: {DEBUG_SESSION_ID}")
|
||||||
|
|
||||||
|
|
||||||
|
def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Log a debug call entry to the global debug data structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name (str): Name of the tool being called
|
||||||
|
call_data (Dict[str, Any]): Data about the call including parameters and results
|
||||||
|
"""
|
||||||
|
if not DEBUG_MODE or not DEBUG_DATA:
|
||||||
|
return
|
||||||
|
|
||||||
|
call_entry = {
|
||||||
|
"timestamp": datetime.datetime.now().isoformat(),
|
||||||
|
"tool_name": tool_name,
|
||||||
|
**call_data
|
||||||
|
}
|
||||||
|
|
||||||
|
DEBUG_DATA["tool_calls"].append(call_entry)
|
||||||
|
|
||||||
|
|
||||||
|
def _save_debug_log() -> None:
|
||||||
|
"""
|
||||||
|
Save the current debug data to a JSON file in the logs directory.
|
||||||
|
"""
|
||||||
|
if not DEBUG_MODE or not DEBUG_DATA:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
debug_filename = f"vision_tools_debug_{DEBUG_SESSION_ID}.json"
|
||||||
|
debug_filepath = DEBUG_LOG_PATH / debug_filename
|
||||||
|
|
||||||
|
# Update end time
|
||||||
|
DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat()
|
||||||
|
DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"])
|
||||||
|
|
||||||
|
with open(debug_filepath, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
print(f"🐛 Vision debug log saved: {debug_filepath}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error saving vision debug log: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_image_url(url: str) -> bool:
|
||||||
|
"""
|
||||||
|
Basic validation of image URL format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url (str): The URL to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if URL appears to be valid, False otherwise
|
||||||
|
"""
|
||||||
|
if not url or not isinstance(url, str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if it's a valid URL format
|
||||||
|
if not (url.startswith('http://') or url.startswith('https://')):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for common image extensions (optional, as URLs may not have extensions)
|
||||||
|
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.svg']
|
||||||
|
|
||||||
|
return True # Allow all HTTP/HTTPS URLs for flexibility
|
||||||
|
|
||||||
|
|
||||||
|
async def vision_analyze_tool(
|
||||||
|
image_url: str,
|
||||||
|
user_prompt: str,
|
||||||
|
model: str = DEFAULT_VISION_MODEL
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Analyze an image from a URL using vision AI.
|
||||||
|
|
||||||
|
This tool processes images using Gemini Flash via Nous Research API.
|
||||||
|
The user_prompt parameter is expected to be pre-formatted by the calling
|
||||||
|
function (typically model_tools.py) to include both full description
|
||||||
|
requests and specific questions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_url (str): The URL of the image to analyze
|
||||||
|
user_prompt (str): The pre-formatted prompt for the vision model
|
||||||
|
model (str): The vision model to use (default: gemini-2.5-flash)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: JSON string containing the analysis results with the following structure:
|
||||||
|
{
|
||||||
|
"success": bool,
|
||||||
|
"analysis": str (defaults to error message if None)
|
||||||
|
}
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If analysis fails or API key is not set
|
||||||
|
"""
|
||||||
|
debug_call_data = {
|
||||||
|
"parameters": {
|
||||||
|
"image_url": image_url,
|
||||||
|
"user_prompt": user_prompt,
|
||||||
|
"model": model
|
||||||
|
},
|
||||||
|
"error": None,
|
||||||
|
"success": False,
|
||||||
|
"analysis_length": 0,
|
||||||
|
"model_used": model
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"🔍 Analyzing image from URL: {image_url[:60]}{'...' if len(image_url) > 60 else ''}")
|
||||||
|
print(f"📝 User prompt: {user_prompt[:100]}{'...' if len(user_prompt) > 100 else ''}")
|
||||||
|
|
||||||
|
# Validate image URL
|
||||||
|
if not _validate_image_url(image_url):
|
||||||
|
raise ValueError("Invalid image URL format. Must start with http:// or https://")
|
||||||
|
|
||||||
|
# Check API key availability
|
||||||
|
if not os.getenv("NOUS_API_KEY"):
|
||||||
|
raise ValueError("NOUS_API_KEY environment variable not set")
|
||||||
|
|
||||||
|
# Use the prompt as provided (model_tools.py now handles full description formatting)
|
||||||
|
comprehensive_prompt = user_prompt
|
||||||
|
|
||||||
|
# Prepare the message with image URL format
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": comprehensive_prompt
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": image_url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"🧠 Processing image with {model}...")
|
||||||
|
|
||||||
|
# Call the vision API
|
||||||
|
response = await nous_client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=0.1, # Low temperature for consistent analysis
|
||||||
|
max_tokens=2000 # Generous limit for detailed analysis
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the analysis
|
||||||
|
analysis = response.choices[0].message.content.strip()
|
||||||
|
analysis_length = len(analysis)
|
||||||
|
|
||||||
|
print(f"✅ Image analysis completed ({analysis_length} characters)")
|
||||||
|
|
||||||
|
# Prepare successful response
|
||||||
|
result = {
|
||||||
|
"success": True,
|
||||||
|
"analysis": analysis or "There was a problem with the request and the image could not be analyzed."
|
||||||
|
}
|
||||||
|
|
||||||
|
debug_call_data["success"] = True
|
||||||
|
debug_call_data["analysis_length"] = analysis_length
|
||||||
|
|
||||||
|
# Log debug information
|
||||||
|
_log_debug_call("vision_analyze_tool", debug_call_data)
|
||||||
|
_save_debug_log()
|
||||||
|
|
||||||
|
return json.dumps(result, indent=2)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error analyzing image: {str(e)}"
|
||||||
|
print(f"❌ {error_msg}")
|
||||||
|
|
||||||
|
# Prepare error response
|
||||||
|
result = {
|
||||||
|
"success": False,
|
||||||
|
"analysis": "There was a problem with the request and the image could not be analyzed."
|
||||||
|
}
|
||||||
|
|
||||||
|
debug_call_data["error"] = error_msg
|
||||||
|
_log_debug_call("vision_analyze_tool", debug_call_data)
|
||||||
|
_save_debug_log()
|
||||||
|
|
||||||
|
return json.dumps(result, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def check_nous_api_key() -> bool:
|
||||||
|
"""
|
||||||
|
Check if the Nous Research API key is available in environment variables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if API key is set, False otherwise
|
||||||
|
"""
|
||||||
|
return bool(os.getenv("NOUS_API_KEY"))
|
||||||
|
|
||||||
|
|
||||||
|
def check_vision_requirements() -> bool:
|
||||||
|
"""
|
||||||
|
Check if all requirements for vision tools are met.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if requirements are met, False otherwise
|
||||||
|
"""
|
||||||
|
return check_nous_api_key()
|
||||||
|
|
||||||
|
|
||||||
|
def get_debug_session_info() -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get information about the current debug session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Dictionary containing debug session information
|
||||||
|
"""
|
||||||
|
if not DEBUG_MODE or not DEBUG_DATA:
|
||||||
|
return {
|
||||||
|
"enabled": False,
|
||||||
|
"session_id": None,
|
||||||
|
"log_path": None,
|
||||||
|
"total_calls": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"enabled": True,
|
||||||
|
"session_id": DEBUG_SESSION_ID,
|
||||||
|
"log_path": str(DEBUG_LOG_PATH / f"vision_tools_debug_{DEBUG_SESSION_ID}.json"),
|
||||||
|
"total_calls": len(DEBUG_DATA["tool_calls"])
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    """
    Simple test/demo when run directly
    """
    print("👁️ Vision Tools Module")
    print("=" * 40)

    # Check if API key is available
    api_available = check_nous_api_key()

    if not api_available:
        print("❌ NOUS_API_KEY environment variable not set")
        print("Please set your API key: export NOUS_API_KEY='your-key-here'")
        print("Get API key at: https://inference-api.nousresearch.com/")
        exit(1)
    else:
        print("✅ Nous Research API key found")

    print("🛠️ Vision tools ready for use!")
    print(f"🧠 Using model: {DEFAULT_VISION_MODEL}")

    # Show debug mode status
    if DEBUG_MODE:
        print(f"🐛 Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}")
        print(f"   Debug logs will be saved to: ./logs/vision_tools_debug_{DEBUG_SESSION_ID}.json")
    else:
        print("🐛 Debug mode disabled (set VISION_TOOLS_DEBUG=true to enable)")

    print("\nBasic usage:")
    print("  from vision_tools import vision_analyze_tool")
    print("  import asyncio")
    print("")
    print("  async def main():")
    print("      result = await vision_analyze_tool(")
    print("          image_url='https://example.com/image.jpg',")
    print("          user_prompt='What do you see in this image?'")
    print("      )")
    print("      print(result)")
    print("  asyncio.run(main())")

    print("\nExample prompts:")
    print("  - 'What architectural style is this building?'")
    print("  - 'Describe the emotions and mood in this image'")
    print("  - 'What text can you read in this image?'")
    print("  - 'Identify any safety hazards visible'")
    print("  - 'What products or brands are shown?'")

    print("\nDebug mode:")
    print("  # Enable debug logging")
    print("  export VISION_TOOLS_DEBUG=true")
    print("  # Debug logs capture all vision analysis calls and results")
    print("  # Logs saved to: ./logs/vision_tools_debug_UUID.json")
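The usage text printed above describes an asyncio entry point; the same pattern runs end to end with a stub in place of the real `vision_analyze_tool` (which needs a valid `NOUS_API_KEY` and network access):

```python
import asyncio
import json

async def vision_analyze_stub(image_url: str, user_prompt: str) -> str:
    # Stand-in for vision_tools.vision_analyze_tool; returns the same JSON shape
    return json.dumps(
        {"success": True, "analysis": f"Stub analysis of {image_url}"},
        indent=2,
    )

async def main() -> str:
    return await vision_analyze_stub(
        image_url="https://example.com/image.jpg",
        user_prompt="What do you see in this image?",
    )

result = asyncio.run(main())
print(json.loads(result)["success"])  # True
```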
928 web_tools.py
File diff suppressed because it is too large