mirror of https://github.com/NousResearch/hermes-agent.git
synced 2026-05-05 10:17:17 +08:00

Compare commits (4 commits)

| SHA1 |
|---|
| 7d9a1e119d |
| e91d9e839a |
| 98321be8b0 |
| a219e178a1 |

695	batch_runner.py
@@ -4,25 +4,25 @@ Batch Agent Runner

 This module provides parallel batch processing capabilities for running the agent
 across multiple prompts from a dataset. It includes:
-- Dataset loading and batching
-- Parallel batch processing with multiprocessing
+- Dataset loading
+- Concurrent processing with asyncio (Producer-Consumer pattern)
 - Checkpointing for fault tolerance and resumption
 - Trajectory saving in the proper format (from/value pairs)
-- Tool usage statistics aggregation across all batches
+- Tool usage statistics aggregation across all prompts
 - Cluster failure detection and graceful shutdown (morph, firecrawl, API errors)
 - Configurable failure thresholds with automatic data consolidation

 Usage:
-    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
+    python batch_runner.py --dataset_file=data.jsonl --run_name=my_run

     # Resume an interrupted run
-    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume
+    python batch_runner.py --dataset_file=data.jsonl --run_name=my_run --resume

     # Use a specific toolset distribution
-    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen
+    python batch_runner.py --dataset_file=data.jsonl --run_name=my_run --distribution=image_gen

     # Configure tool failure thresholds
-    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
+    python batch_runner.py --dataset_file=data.jsonl --run_name=my_run \\
         --max_tool_failures=20 --max_tool_failure_rate=0.3 --min_tool_calls_for_rate=10
 """
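The "Producer-Consumer pattern" named in the new docstring is the standard asyncio.Queue idiom: the runner enqueues every pending prompt up front, a fixed pool of consumer tasks pulls from the queue, and one `None` sentinel per worker signals shutdown. A minimal, self-contained sketch of that idiom (the item values and `num_workers` here are placeholders, not code from batch_runner.py):

```python
import asyncio

async def consumer(work_queue: asyncio.Queue, result_queue: asyncio.Queue) -> None:
    # Mirrors the worker() coroutine added below: pull items until a None sentinel arrives.
    while True:
        item = await work_queue.get()
        if item is None:
            work_queue.task_done()
            break
        await result_queue.put(f"processed {item}")  # placeholder for real work
        work_queue.task_done()

async def main(num_workers: int = 3) -> None:
    work_queue: asyncio.Queue = asyncio.Queue()
    result_queue: asyncio.Queue = asyncio.Queue()
    for item in range(10):
        work_queue.put_nowait(item)          # producer side: enqueue everything up front
    workers = [asyncio.create_task(consumer(work_queue, result_queue))
               for _ in range(num_workers)]
    for _ in workers:                         # one sentinel per worker
        work_queue.put_nowait(None)
    await asyncio.gather(*workers)
    while not result_queue.empty():
        print(result_queue.get_nowait())

asyncio.run(main())
```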
@@ -30,10 +30,10 @@ import json

 import logging
 import os
 import time
+import asyncio
 from pathlib import Path
-from typing import List, Dict, Any, Optional, Tuple
+from typing import List, Dict, Any, Optional, Tuple, Set
 from datetime import datetime
-from multiprocessing import Pool, Manager, Lock
 import traceback
 import re

@@ -49,9 +49,6 @@ from toolset_distributions import (
 from safe_print import safe_print


-# Global configuration for worker processes
-_WORKER_CONFIG = {}
-
 # Canonical names for the terminal tool (old & new implementations)
 _TERMINAL_TOOL_NAMES = {"terminal", "terminal_tool", "simple_terminal_tool"}

@@ -295,10 +292,9 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
     return tool_stats


-def _process_single_prompt(
+async def _process_single_prompt(
     prompt_index: int,
     prompt_data: Dict[str, Any],
-    batch_num: int,
     config: Dict[str, Any]
 ) -> Dict[str, Any]:
     """
@@ -307,7 +303,6 @@ def _process_single_prompt(
     Args:
         prompt_index (int): Index of prompt in dataset
        prompt_data (Dict): Prompt data containing 'prompt' field
-        batch_num (int): Batch number
        config (Dict): Configuration dict with agent parameters

    Returns:
@@ -332,11 +327,13 @@ def _process_single_prompt(
         save_trajectories=False,  # We handle saving ourselves
         verbose_logging=config.get("verbose", False),
         ephemeral_system_prompt=config.get("ephemeral_system_prompt"),
-        log_prefix_chars=config.get("log_prefix_chars", 100)
+        log_prefix_chars=config.get("log_prefix_chars", 100),
+        prokletor_client=config.get("prokletor_client"),
+        prokletor_formatter=config.get("prokletor_formatter")
     )

     # Run the agent with task_id to ensure each task gets its own isolated VM
-    result = agent.run_conversation(prompt, task_id=f"task_{prompt_index}")
+    result = await agent.run_conversation(prompt, task_id=f"task_{prompt_index}")

     # Extract tool usage statistics
     tool_stats = _extract_tool_stats(result["messages"])
@@ -365,7 +362,6 @@ def _process_single_prompt(
         "api_calls": result["api_calls"],
         "toolsets_used": selected_toolsets,
         "metadata": {
-            "batch_num": batch_num,
             "timestamp": datetime.now().isoformat(),
             "model": config["model"]
         }
@@ -389,132 +385,38 @@ def _process_single_prompt(
             "tool_stats": {},
             "toolsets_used": [],
             "metadata": {
-                "batch_num": batch_num,
                 "timestamp": datetime.now().isoformat()
             }
         }


-def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
+async def worker(
+    work_queue: asyncio.Queue,
+    result_queue: asyncio.Queue,
+    config: Dict[str, Any]
+):
     """
-    Worker function to process a single batch of prompts.
-
-    Args:
-        args (Tuple): (batch_num, batch_data, output_dir, completed_prompts, config)
-
-    Returns:
-        Dict: Batch results with statistics
+    Consumer worker that processes prompts from the work queue.
     """
-    batch_num, batch_data, output_dir, completed_prompts_set, config = args
-
-    output_dir = Path(output_dir)
-    print(f"\n🔄 Batch {batch_num}: Starting ({len(batch_data)} prompts)")
-
-    # Output file for this batch
-    batch_output_file = output_dir / f"batch_{batch_num}.jsonl"
-
-    # Filter out already completed prompts
-    prompts_to_process = [
-        (idx, data) for idx, data in batch_data
-        if idx not in completed_prompts_set
-    ]
-
-    if not prompts_to_process:
-        print(f"✅ Batch {batch_num}: Already completed (skipping)")
-        return {
-            "batch_num": batch_num,
-            "processed": 0,
-            "skipped": len(batch_data),
-            "tool_stats": {},
-            "completed_prompts": []
-        }
-
-    print(f"   Processing {len(prompts_to_process)} prompts (skipping {len(batch_data) - len(prompts_to_process)} already completed)")
-
-    # Initialize aggregated stats for this batch
-    batch_tool_stats = {}
-    batch_profiling_stats = []  # Collect profiling stats from each prompt
-    completed_in_batch = []
-    all_tool_errors = []  # Track all tool errors in this batch
-    exception_errors = []  # Track top-level exceptions
-
-    # Process each prompt sequentially in this batch
-    for prompt_index, prompt_data in prompts_to_process:
-        # Process the prompt
-        result = _process_single_prompt(
-            prompt_index,
-            prompt_data,
-            batch_num,
-            config
-        )
-
-        # Track tool errors from the conversation
-        if result.get("tool_errors"):
-            for tool_error in result["tool_errors"]:
-                all_tool_errors.append({
-                    "prompt_index": prompt_index,
-                    "tool_name": tool_error["tool_name"],
-                    "error_message": tool_error["error_message"],
-                    "full_content": tool_error.get("full_content", ""),
-                    "error_type": tool_error.get("error_type", "Other")
-                })
-
-        # Track top-level exceptions (not tool errors)
-        if not result["success"]:
-            exception_errors.append({
-                "prompt_index": prompt_index,
-                "error": result.get("error", "Unknown error"),
-                "traceback": result.get("traceback", "")
-            })
-            safe_print(f"[bold red]❌ Exception in prompt {prompt_index}:[/bold red] {result.get('error', '')[:100]}")
-
-        # Save trajectory if successful
-        if result["success"] and result["trajectory"]:
-            trajectory_entry = {
-                "prompt_index": prompt_index,
-                "conversations": result["trajectory"],
-                "metadata": result["metadata"],
-                "completed": result["completed"],
-                "api_calls": result["api_calls"],
-                "toolsets_used": result["toolsets_used"]
-            }
-
-            # Append to batch output file
-            with open(batch_output_file, 'a', encoding='utf-8') as f:
-                f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n")
-
-        # Aggregate tool statistics
-        for tool_name, stats in result.get("tool_stats", {}).items():
-            if tool_name not in batch_tool_stats:
-                batch_tool_stats[tool_name] = {
-                    "count": 0,
-                    "success": 0,
-                    "failure": 0
-                }
-
-            batch_tool_stats[tool_name]["count"] += stats["count"]
-            batch_tool_stats[tool_name]["success"] += stats["success"]
-            batch_tool_stats[tool_name]["failure"] += stats["failure"]
-
-        # Collect profiling statistics
-        if result.get("profiling_stats"):
-            batch_profiling_stats.append(result["profiling_stats"])
-
-        completed_in_batch.append(prompt_index)
-        print(f"   ✅ Prompt {prompt_index} completed")
-
-    print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)")
-
-    return {
-        "batch_num": batch_num,
-        "processed": len(prompts_to_process),
-        "skipped": len(batch_data) - len(prompts_to_process),
-        "tool_stats": batch_tool_stats,
-        "profiling_stats": batch_profiling_stats,
-        "completed_prompts": completed_in_batch,
-        "tool_errors": all_tool_errors,
-        "exception_errors": exception_errors
-    }
+    while True:
+        try:
+            task = await work_queue.get()
+            if task is None:
+                # Sentinel to stop worker
+                work_queue.task_done()
+                break
+
+            prompt_index, prompt_data = task
+
+            result = await _process_single_prompt(prompt_index, prompt_data, config)
+
+            await result_queue.put(result)
+            work_queue.task_done()
+
+        except Exception as e:
+            print(f"Error in worker: {e}")
+            if 'task' in locals() and task is not None:
+                work_queue.task_done()


 class BatchRunner:
@@ -525,7 +427,6 @@ class BatchRunner:
     def __init__(
         self,
         dataset_file: str,
-        batch_size: int,
         run_name: str,
         distribution: str = "default",
         max_iterations: int = 10,
@@ -536,34 +437,36 @@ class BatchRunner:
         verbose: bool = False,
         ephemeral_system_prompt: str = None,
         log_prefix_chars: int = 100,
-        max_tool_failures: int = 10,
+        max_tool_failures: float = float("inf"),
         max_tool_failure_rate: float = 0.5,
         keep_recent_errors: int = 5,
         min_tool_calls_for_rate: int = 10,
+        prokletor_client: str = None,
+        prokletor_formatter: str = None,
     ):
         """
         Initialize the batch runner.

         Args:
             dataset_file (str): Path to the dataset JSONL file with 'prompt' field
-            batch_size (int): Number of prompts per batch
             run_name (str): Name for this run (used for checkpointing and output)
             distribution (str): Toolset distribution to use (default: "default")
             max_iterations (int): Max iterations per agent run
             base_url (str): Base URL for model API
             api_key (str): API key for model
             model (str): Model name to use
-            num_workers (int): Number of parallel workers
+            num_workers (int): Number of parallel workers (default: 4)
             verbose (bool): Enable verbose logging
             ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
             log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
-            max_tool_failures (int): Maximum number of tool failures before stopping (default: 10)
+            max_tool_failures (float): Maximum number of tool failures before stopping (default: inf for unlimited)
             max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5)
             keep_recent_errors (int): Number of recent errors to keep per tool (default: 5)
             min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10)
+            prokletor_client (str): Name of the prokletor client to use
+            prokletor_formatter (str): Name of the prokletor formatter to use
         """
         self.dataset_file = Path(dataset_file)
-        self.batch_size = batch_size
         self.run_name = run_name
         self.distribution = distribution
         self.max_iterations = max_iterations
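Switching the `max_tool_failures` default from `10` to `float("inf")` makes the absolute-count cutoff opt-in without adding a separate "disabled" code path: an integer error counter compared against an infinite float never trips the check. A standalone illustration (not repo code):

```python
max_tool_failures = float("inf")   # new default: absolute cap effectively disabled
total_tool_errors = 0

for _ in range(1_000_000):
    total_tool_errors += 1
    if total_tool_errors >= max_tool_failures:   # int >= inf is always False
        raise SystemExit("absolute failure cap tripped")

print("ran to completion")  # pass --max_tool_failures=20 to restore a hard cap
```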
@@ -578,6 +481,8 @@ class BatchRunner:
         self.max_tool_failure_rate = max_tool_failure_rate
         self.keep_recent_errors = keep_recent_errors
         self.min_tool_calls_for_rate = min_tool_calls_for_rate
+        self.prokletor_client = prokletor_client
+        self.prokletor_formatter = prokletor_formatter

         # Validate distribution
         if not validate_distribution(distribution):
@@ -596,16 +501,14 @@ class BatchRunner:
         # Errors file
         self.errors_file = self.output_dir / "errors.json"

+        # Trajectories file
+        self.trajectories_file = self.output_dir / "trajectories.jsonl"
+
         # Load dataset
         self.dataset = self._load_dataset()

-        # Create batches
-        self.batches = self._create_batches()
-
         safe_print("[bold cyan]📊 Batch Runner Initialized[/bold cyan]")
         safe_print(f"   Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
-        safe_print(f"   Batch size: {self.batch_size}")
-        safe_print(f"   Total batches: {len(self.batches)}")
         safe_print(f"   Run name: {self.run_name}")
         safe_print(f"   Distribution: {self.distribution}")
         safe_print(f"   Output directory: {self.output_dir}")
@@ -651,20 +554,6 @@ class BatchRunner:

         return dataset

-    def _create_batches(self) -> List[List[Tuple[int, Dict[str, Any]]]]:
-        """
-        Split dataset into batches with indices.
-
-        Returns:
-            List of batches, where each batch is a list of (index, entry) tuples
-        """
-        batches = []
-        for i in range(0, len(self.dataset), self.batch_size):
-            batch = [(idx, entry) for idx, entry in enumerate(self.dataset[i:i + self.batch_size], start=i)]
-            batches.append(batch)
-
-        return batches
-
     def _load_checkpoint(self) -> Dict[str, Any]:
         """
         Load checkpoint data if it exists.
@@ -676,7 +565,6 @@ class BatchRunner:
             return {
                 "run_name": self.run_name,
                 "completed_prompts": [],
-                "batch_stats": {},
                 "last_updated": None
             }
@@ -688,61 +576,34 @@ class BatchRunner:
             return {
                 "run_name": self.run_name,
                 "completed_prompts": [],
-                "batch_stats": {},
                 "last_updated": None
             }

-    def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None):
+    def _save_checkpoint(self, checkpoint_data: Dict[str, Any]):
         """
         Save checkpoint data.

         Args:
             checkpoint_data (Dict): Checkpoint data to save
-            lock (Lock): Optional lock for thread-safe access
         """
         checkpoint_data["last_updated"] = datetime.now().isoformat()
-
-        if lock:
-            with lock:
-                with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
-                    json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
-        else:
-            with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
-                json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
-
-    def _consolidate_data(self, num_batches: int, tool_stats: Dict[str, Dict[str, int]],
-                          start_time: float, tool_errors_by_tool: Dict[str, List[Dict]],
-                          exception_errors: List[Dict], early_exit: bool = False, exit_reason: str = None,
-                          profiling_stats_list: List[Dict] = None):
+        with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
+            json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
+
+    def _save_final_stats(
+        self,
+        num_processed: int,
+        tool_stats: Dict[str, Dict[str, int]],
+        start_time: float,
+        tool_errors_by_tool: Dict[str, List[Dict]],
+        exception_errors: List[Dict],
+        early_exit: bool = False,
+        exit_reason: str = None,
+        profiling_stats_list: List[Dict] = None
+    ):
         """
-        Consolidate batch data into trajectories.jsonl and save statistics.
+        Save final statistics and errors.

-        Args:
-            num_batches (int): Number of batches processed
-            tool_stats (Dict): Aggregated tool statistics
-            start_time (float): Start time of the run
-            tool_errors_by_tool (Dict): Tool errors grouped by tool name with k most recent
-            exception_errors (List): Top-level exceptions
-            early_exit (bool): Whether this is an early exit
-            exit_reason (str): Reason for early exit
-            profiling_stats_list (List[Dict]): List of profiling statistics from each conversation
         """
-        # Combine all batch files into a single trajectories.jsonl file
-        combined_file = self.output_dir / "trajectories.jsonl"
-        safe_print(f"\n[cyan]📦 Combining batch files into {combined_file.name}...[/cyan]")
-
-        entries_written = 0
-        with open(combined_file, 'w', encoding='utf-8') as outfile:
-            for batch_num in range(num_batches):
-                batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
-                if batch_file.exists():
-                    with open(batch_file, 'r', encoding='utf-8') as infile:
-                        for line in infile:
-                            outfile.write(line)
-                            entries_written += 1
-
-        safe_print(f"[green]✅ Combined {num_batches} batch files into trajectories.jsonl ({entries_written} entries)[/green]")
-
         # Calculate success rates for tool stats
         for tool_name in tool_stats:
             stats = tool_stats[tool_name]
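With the run loop now living on a single asyncio event loop, `_save_checkpoint` has exactly one writer (the result-consuming coroutine), which is why the optional `multiprocessing.Lock` argument could be dropped. Based on the fields visible in this diff, the checkpoint file after the rewrite holds roughly the following (a sketch; the example values are made up):

```python
# Approximate shape of checkpoint.json after the asyncio rewrite; the
# per-batch "batch_stats" key is gone and only per-prompt progress remains.
checkpoint_data = {
    "run_name": "my_run",                      # illustrative value
    "completed_prompts": [0, 1, 2, 5, 7],      # prompt indices already finished
    "last_updated": "2026-05-05T10:17:17",     # set by _save_checkpoint()
}
```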
@@ -794,17 +655,21 @@ class BatchRunner:
         # Aggregate profiling statistics if available
         aggregated_profiling_stats = None
         if profiling_stats_list:
-            from profiling import aggregate_profiling_stats
-            aggregated_profiling_stats = aggregate_profiling_stats(profiling_stats_list)
+            try:
+                from profiling import aggregate_profiling_stats, print_aggregated_statistics
+                aggregated_profiling_stats = aggregate_profiling_stats(profiling_stats_list)

-        # Save final statistics (without detailed errors)
+                # Display aggregated profiling statistics
+                print_aggregated_statistics(aggregated_profiling_stats, detailed=True)
+            except ImportError:
+                pass
+
+        # Save final statistics
         final_stats = {
             "run_name": self.run_name,
             "distribution": self.distribution,
             "total_prompts": len(self.dataset),
-            "total_batches": len(self.batches),
-            "batches_processed": num_batches,
-            "batch_size": self.batch_size,
+            "processed": num_processed,
             "model": self.model,
             "completed_at": datetime.now().isoformat(),
             "duration_seconds": round(time.time() - start_time, 2),
@@ -817,18 +682,9 @@ class BatchRunner:
         with open(self.stats_file, 'w', encoding='utf-8') as f:
             json.dump(final_stats, f, indent=2, ensure_ascii=False)

-        # Display aggregated profiling statistics
-        if aggregated_profiling_stats:
-            from profiling import print_aggregated_statistics
-            print_aggregated_statistics(aggregated_profiling_stats, detailed=True)
-
-
-    def run(self, resume: bool = False):
+    async def _run_async(self, resume: bool = False):
         """
-        Run the batch processing pipeline.
+        Async implementation of the batch runner pipeline.

-        Args:
-            resume (bool): Whether to resume from checkpoint
         """
         print("\n" + "=" * 70)
         print("🚀 Starting Batch Processing")
@@ -838,15 +694,32 @@ class BatchRunner:
         checkpoint_data = self._load_checkpoint() if resume else {
             "run_name": self.run_name,
             "completed_prompts": [],
-            "batch_stats": {},
             "last_updated": None
         }

         if resume and checkpoint_data.get("completed_prompts"):
             print(f"📂 Resuming from checkpoint ({len(checkpoint_data['completed_prompts'])} prompts already completed)")

-        # Prepare configuration for workers
-        config = {
+        completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
+
+        # Prepare queues
+        work_queue = asyncio.Queue()
+        result_queue = asyncio.Queue()
+
+        # Enqueue prompts to process
+        prompts_to_process = []
+        for idx, entry in enumerate(self.dataset):
+            if idx not in completed_prompts_set:
+                prompts_to_process.append((idx, entry))
+                work_queue.put_nowait((idx, entry))
+
+        total_to_process = len(prompts_to_process)
+        if total_to_process == 0:
+            print("✅ All prompts already completed.")
+            return
+
+        # Worker configuration
+        worker_config = {
             "distribution": self.distribution,
             "model": self.model,
             "max_iterations": self.max_iterations,
@@ -854,123 +727,124 @@ class BatchRunner:
             "api_key": self.api_key,
             "verbose": self.verbose,
             "ephemeral_system_prompt": self.ephemeral_system_prompt,
-            "log_prefix_chars": self.log_prefix_chars
+            "log_prefix_chars": self.log_prefix_chars,
+            "prokletor_client": self.prokletor_client,
+            "prokletor_formatter": self.prokletor_formatter
         }

-        # Get completed prompts set
-        completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
+        # Start workers
+        workers = []
+        for _ in range(min(self.num_workers, total_to_process)):
+            w = asyncio.create_task(worker(work_queue, result_queue, worker_config))
+            workers.append(w)

-        # Aggregate statistics across all batches
+        print(f"   Processing {total_to_process} prompts with {len(workers)} workers...")
+
+        # Aggregate statistics
         total_tool_stats = {}
-        all_profiling_stats = []  # Collect all profiling stats for aggregation
-        tool_errors_by_tool = {}  # {tool_name: [list of k most recent errors]}
+        all_profiling_stats = []
+        tool_errors_by_tool = {}
         all_exception_errors = []
-        all_completed_prompts = list(completed_prompts_set)
-        total_processed = len(completed_prompts_set)
         total_tool_errors = 0
         early_exit = False
         exit_reason = None
+        processed_count = 0

         start_time = time.time()

-        # Process batches in parallel
-        with Pool(processes=self.num_workers) as pool:
-            # Create tasks for each batch
-            tasks = [
-                (
-                    batch_num,
-                    batch_data,
-                    str(self.output_dir),  # Convert Path to string for pickling
-                    completed_prompts_set,
-                    config
-                )
-                for batch_num, batch_data in enumerate(self.batches)
-            ]
-
-            # Process batches in parallel and check tool failure threshold as results come in
-            # imap_unordered allows parallel processing while getting results as they complete
-            batch_num = 0
-            try:
-                for result in pool.imap_unordered(_process_batch_worker, tasks):
-                    # Update statistics
-                    all_completed_prompts.extend(result.get("completed_prompts", []))
-                    total_processed += result.get("processed", 0)
-
-                    # Aggregate tool stats
-                    for tool_name, stats in result.get("tool_stats", {}).items():
-                        if tool_name not in total_tool_stats:
-                            total_tool_stats[tool_name] = {
-                                "count": 0,
-                                "success": 0,
-                                "failure": 0
-                            }
-
-                        total_tool_stats[tool_name]["count"] += stats["count"]
-                        total_tool_stats[tool_name]["success"] += stats["success"]
-                        total_tool_stats[tool_name]["failure"] += stats["failure"]
-
-                    # Collect profiling stats from this batch
-                    if result.get("profiling_stats"):
-                        all_profiling_stats.extend(result["profiling_stats"])
-
-                    # Aggregate tool errors (keep k most recent per tool)
-                    for tool_error in result.get("tool_errors", []):
-                        tool_name = tool_error["tool_name"]
-                        if tool_name not in tool_errors_by_tool:
-                            tool_errors_by_tool[tool_name] = []
-
-                        # Add error and keep only k most recent
-                        tool_errors_by_tool[tool_name].append(tool_error)
-                        if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors:
-                            tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:]
-
-                        total_tool_errors += 1
-
-                    # Track exception errors
-                    all_exception_errors.extend(result.get("exception_errors", []))
-
-                    # Check tool failure thresholds
-                    # Calculate total tool calls (not prompts)
-                    total_tool_calls = sum(stats["count"] for stats in total_tool_stats.values())
-
-                    # Check absolute count threshold
-                    if total_tool_errors >= self.max_tool_failures:
-                        early_exit = True
-                        exit_reason = f"Exceeded maximum tool failures ({total_tool_errors}/{self.max_tool_failures})"
-                        safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
-                        pool.terminate()  # Stop all workers immediately
-                        break
-
-                    # Check rate threshold (only if we have enough tool calls to trust the rate)
-                    if total_tool_calls >= self.min_tool_calls_for_rate:
-                        tool_failure_rate = total_tool_errors / total_tool_calls
-
-                        if tool_failure_rate >= self.max_tool_failure_rate:
-                            early_exit = True
-                            exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%} >= {self.max_tool_failure_rate:.2%}, {total_tool_errors}/{total_tool_calls} tool calls)"
-                            safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
-                            pool.terminate()  # Stop all workers immediately
-                            break
-
-                    # Update checkpoint after each batch completes
-                    checkpoint_data["completed_prompts"] = all_completed_prompts
-                    self._save_checkpoint(checkpoint_data)
-
-                    batch_num += 1
-            except KeyboardInterrupt:
-                safe_print("\n[bold yellow]⚠️ Interrupted by user, stopping workers...[/bold yellow]")
-                pool.terminate()
-                early_exit = True
-                exit_reason = "Interrupted by user"
-
-        # Save final checkpoint
-        checkpoint_data["completed_prompts"] = all_completed_prompts
-        self._save_checkpoint(checkpoint_data)
-
-        # Consolidate data and save statistics
-        num_batches_processed = batch_num + 1 if early_exit else len(self.batches)
-        self._consolidate_data(
-            num_batches_processed,
+        # Process results as they arrive
+        try:
+            while processed_count < total_to_process:
+                result = await result_queue.get()
+                processed_count += 1
+
+                prompt_index = result["prompt_index"]
+
+                # Track exceptions
+                if not result["success"]:
+                    safe_print(f"[bold red]❌ Exception in prompt {prompt_index}:[/bold red] {result.get('error', '')[:100]}")
+                    all_exception_errors.append({
+                        "prompt_index": prompt_index,
+                        "error": result.get("error", "Unknown error"),
+                        "traceback": result.get("traceback", "")
+                    })
+                else:
+                    print(f"   ✅ Prompt {prompt_index} completed")
+
+                # Save trajectory immediately
+                if result.get("trajectory"):
+                    trajectory_entry = {
+                        "prompt_index": prompt_index,
+                        "conversations": result["trajectory"],
+                        "metadata": result["metadata"],
+                        "completed": result["completed"],
+                        "api_calls": result["api_calls"],
+                        "toolsets_used": result["toolsets_used"]
+                    }
+                    with open(self.trajectories_file, 'a', encoding='utf-8') as f:
+                        f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n")
+
+                # Aggregate tool stats
+                for tool_name, stats in result.get("tool_stats", {}).items():
+                    if tool_name not in total_tool_stats:
+                        total_tool_stats[tool_name] = {"count": 0, "success": 0, "failure": 0}
+
+                    total_tool_stats[tool_name]["count"] += stats["count"]
+                    total_tool_stats[tool_name]["success"] += stats["success"]
+                    total_tool_stats[tool_name]["failure"] += stats["failure"]
+
+                # Collect profiling stats
+                if result.get("profiling_stats"):
+                    all_profiling_stats.append(result["profiling_stats"])
+
+                # Aggregate tool errors
+                for tool_error in result.get("tool_errors", []):
+                    tool_name = tool_error["tool_name"]
+                    if tool_name not in tool_errors_by_tool:
+                        tool_errors_by_tool[tool_name] = []
+
+                    tool_errors_by_tool[tool_name].append(tool_error)
+                    # Keep only k most recent
+                    if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors:
+                        tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:]
+
+                    total_tool_errors += 1
+
+                # Update checkpoint
+                completed_prompts_set.add(prompt_index)
+                checkpoint_data["completed_prompts"] = list(completed_prompts_set)
+                self._save_checkpoint(checkpoint_data)
+
+                # Check failure thresholds
+                total_tool_calls = sum(stats["count"] for stats in total_tool_stats.values())
+
+                if total_tool_errors >= self.max_tool_failures:
+                    early_exit = True
+                    exit_reason = f"Exceeded maximum tool failures ({total_tool_errors}/{self.max_tool_failures})"
+                    break
+
+                if total_tool_calls >= self.min_tool_calls_for_rate:
+                    tool_failure_rate = total_tool_errors / total_tool_calls
+                    if tool_failure_rate >= self.max_tool_failure_rate:
+                        early_exit = True
+                        exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%})"
+                        break
+
+        except asyncio.CancelledError:
+            early_exit = True
+            exit_reason = "Run cancelled"
+        finally:
+            # Stop all workers
+            for _ in range(len(workers)):
+                work_queue.put_nowait(None)
+            await asyncio.gather(*workers, return_exceptions=True)
+
+        if early_exit:
+            safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
+
+        # Save final statistics
+        self._save_final_stats(
+            processed_count,
             total_tool_stats,
             start_time,
             tool_errors_by_tool,
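The two early-exit checks in the new result loop compose as follows: the absolute cap fires as soon as `total_tool_errors` reaches `max_tool_failures`, while the rate check only activates once `min_tool_calls_for_rate` calls have been observed, so one or two failures at the very start of a run cannot stop it. A small worked example using the threshold values from the usage docstring (a sketch, not repo code):

```python
max_tool_failures = 20          # e.g. --max_tool_failures=20
max_tool_failure_rate = 0.5
min_tool_calls_for_rate = 10

def should_stop(total_tool_errors: int, total_tool_calls: int) -> bool:
    if total_tool_errors >= max_tool_failures:
        return True
    if total_tool_calls >= min_tool_calls_for_rate:
        return total_tool_errors / total_tool_calls >= max_tool_failure_rate
    return False

print(should_stop(4, 6))     # False: only 6 calls seen, rate check not active yet
print(should_stop(6, 12))    # True: 6/12 = 50% >= 50% failure rate
print(should_stop(20, 200))  # True: absolute cap reached despite a 10% rate
```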
@@ -980,164 +854,28 @@ class BatchRunner:
             all_profiling_stats
         )

-        # Print summary
+        # Summary output
         safe_print("\n" + "=" * 70)
-        if early_exit:
-            safe_print("[bold yellow]⚠️ BATCH PROCESSING STOPPED EARLY[/bold yellow]")
-            safe_print(f"[yellow]Reason: {exit_reason}[/yellow]")
-        else:
-            safe_print("[bold green]📊 BATCH PROCESSING COMPLETE[/bold green]")
-        safe_print("=" * 70)
-
-        safe_print(f"✅ Total prompts processed: {total_processed}")
-        safe_print(f"✅ Batches completed: {num_batches_processed}/{len(self.batches)}")
+        safe_print(f"✅ Total prompts processed: {processed_count}/{total_to_process}")
         safe_print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")

-        # Tool error summary
         if tool_errors_by_tool:
-            total_errors = sum(len(errors) for errors in tool_errors_by_tool.values())
-            safe_print(f"\n[bold red]🚨 Tool Errors: {total_tool_errors} total ({len(tool_errors_by_tool)} tools)[/bold red]")
-            safe_print("[red]-[/red]" * 70)
-
-            # Sort tools by error count
-            sorted_tools = sorted(
-                tool_errors_by_tool.items(),
-                key=lambda x: len(x[1]),
-                reverse=True
-            )
-
-            for tool_name, errors in sorted_tools:
-                # Count unique error messages
-                unique_errors = {}
-                for error in errors:
-                    error_msg = error["error_message"][:100]  # Truncate for grouping
-                    if error_msg not in unique_errors:
-                        unique_errors[error_msg] = []
-                    unique_errors[error_msg].append(error)
-
-                safe_print(f"\n  [red]{tool_name}:[/red] {len(errors)} errors ({len(unique_errors)} unique)")
-
-                # Show up to 3 most recent unique error types
-                for idx, (error_msg, instances) in enumerate(list(unique_errors.items())[:3]):
-                    error_preview = error_msg if len(error_msg) <= 100 else error_msg[:97] + "..."
-                    safe_print(f"    [{idx+1}] [dim]{error_preview}[/dim] (x{len(instances)})")
-
-                    # Show one example with prompt index and full content prefix
-                    example = instances[-1]  # Most recent
-                    safe_print(f"        [dim]Prompt {example['prompt_index']}[/dim]")
-
-                    # Show full content prefix (first 200 chars)
-                    full_content = example.get('full_content', '')
-                    if full_content and full_content != error_preview:
-                        content_preview = full_content[:200]
-                        if len(full_content) > 200:
-                            content_preview += "..."
-                        # Show with prefix indicator
-                        safe_print(f"        [dim]Content: {content_preview}[/dim]")
-
-                if len(unique_errors) > 3:
-                    safe_print(f"    [dim]... and {len(unique_errors) - 3} more error types[/dim]")
-
-            tool_failure_rate = total_tool_errors / total_processed if total_processed > 0 else 0
-            safe_print(f"\n  [red]Tool failure rate: {tool_failure_rate:.2%}[/red]")
-
-        # Exception errors
-        if all_exception_errors:
-            safe_print(f"\n[bold red]💥 Top-level Exceptions: {len(all_exception_errors)}[/bold red]")
-            safe_print("[red]-[/red]" * 70)
-            for error in all_exception_errors[:self.keep_recent_errors]:
-                error_msg = error["error"]
-                error_preview = error_msg[:150]
-                if len(error_msg) > 150:
-                    error_preview += "..."
-                safe_print(f"  [red]Prompt {error['prompt_index']}:[/red] [dim]{error_preview}[/dim]")
-
-                # Show traceback prefix if available
-                traceback_text = error.get("traceback", "")
-                if traceback_text:
-                    # Show last 3 lines of traceback for context
-                    tb_lines = traceback_text.strip().split('\n')
-                    relevant_lines = tb_lines[-3:] if len(tb_lines) > 3 else tb_lines
-                    for line in relevant_lines:
-                        safe_print(f"    [dim]{line}[/dim]")
-
-        safe_print(f"\n[cyan]📈 Tool Usage Statistics:[/cyan]")
-        safe_print("-" * 70)
-
-        if total_tool_stats:
-            # Sort by count descending
-            sorted_tools = sorted(
-                total_tool_stats.items(),
-                key=lambda x: x[1]["count"],
-                reverse=True
-            )
-
-            safe_print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
-            safe_print("-" * 70)
-            for tool_name, stats in sorted_tools:
-                safe_print(
-                    f"{tool_name:<25} "
-                    f"{stats['count']:<10} "
-                    f"{stats['success']:<10} "
-                    f"{stats['failure']:<10} "
-                    f"{stats.get('success_rate', 0):.1f}%"
-                )
-        else:
-            safe_print("No tool calls were made during this run.")
-
-        # Display failure type breakdown for tools with failures
-        if tool_errors_by_tool:
-            safe_print(f"\n[cyan]📊 Failure Type Breakdown:[/cyan]")
-            safe_print("-" * 70)
-
-            # Sort tools by total error count
-            sorted_tools = sorted(
-                tool_errors_by_tool.items(),
-                key=lambda x: len(x[1]),
-                reverse=True
-            )
-
-            for tool_name, errors in sorted_tools:
-                # Count failure types for this tool
-                failure_types = {}
-                for error in errors:
-                    error_type = error.get("error_type", "Other")
-                    if error_type not in failure_types:
-                        failure_types[error_type] = 0
-                    failure_types[error_type] += 1
-
-                # Display tool name and total failures
-                total_failures = len(errors)
-                safe_print(f"\n[yellow]{tool_name}[/yellow] ({total_failures} failures):")
-
-                # Sort failure types by count
-                sorted_types = sorted(
-                    failure_types.items(),
-                    key=lambda x: x[1],
-                    reverse=True
-                )
-
-                # Display each failure type with count and percentage
-                for failure_type, count in sorted_types:
-                    percentage = (count / total_failures) * 100
-                    safe_print(f"  • {failure_type:<20} {count:>4} ({percentage:>5.1f}%)")
+            safe_print(f"\n[bold red]🚨 Tool Errors: {total_tool_errors} total[/bold red]")
+            # Simplified error printing here, full detail is in json
+            for tool_name, errors in tool_errors_by_tool.items():
+                safe_print(f"  {tool_name}: {len(errors)} errors")

         safe_print(f"\n[cyan]💾 Results saved to:[/cyan] {self.output_dir}")
-        safe_print(f"  - Trajectories: trajectories.jsonl (combined)")
-        safe_print(f"  - Individual batches: batch_*.jsonl (for debugging)")
-        safe_print(f"  - Statistics: {self.stats_file.name}")
-        safe_print(f"  - Errors: {self.errors_file.name}")
-        safe_print(f"  - Checkpoint: {self.checkpoint_file.name}")

-        if early_exit:
-            safe_print(f"\n[bold yellow]ℹ️ Run was stopped early due to tool failures.[/bold yellow]")
-            safe_print(f"[yellow]  Check {self.errors_file.name} for detailed error information including tracebacks.[/yellow]")
-            safe_print(f"[yellow]  You can resume this run later with --resume flag.[/yellow]")
+    def run(self, resume: bool = False):
+        """
+        Run the batch processing pipeline (sync wrapper).
+        """
+        asyncio.run(self._run_async(resume))


 def main(
     dataset_file: str = None,
-    batch_size: int = None,
     run_name: str = None,
     distribution: str = "default",
     model: str = "claude-opus-4-20250514",
@@ -1150,17 +888,18 @@ def main(
     list_distributions: bool = False,
     ephemeral_system_prompt: str = None,
     log_prefix_chars: int = 100,
-    max_tool_failures: int = 10,
+    max_tool_failures: float = float("inf"),
     max_tool_failure_rate: float = 0.5,
     keep_recent_errors: int = 5,
     min_tool_calls_for_rate: int = 10,
+    prokletor_client: str = None,
+    prokletor_formatter: str = None,
 ):
     """
     Run batch processing of agent prompts from a dataset.

     Args:
         dataset_file (str): Path to JSONL file with 'prompt' field in each entry
-        batch_size (int): Number of prompts per batch
         run_name (str): Name for this run (used for output and checkpointing)
         distribution (str): Toolset distribution to use (default: "default")
         model (str): Model name to use (default: "claude-opus-4-20250514")
@@ -1173,31 +912,22 @@ def main(
         list_distributions (bool): List available toolset distributions and exit
         ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
         log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
-        max_tool_failures (int): Maximum number of tool failures before stopping (default: 10)
+        max_tool_failures (float): Maximum number of tool failures before stopping (default: inf for unlimited)
         max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5)
         keep_recent_errors (int): Number of recent errors to keep per tool for reporting (default: 5)
         min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10)
+        prokletor_client (str): Name of the prokletor client to use
+        prokletor_formatter (str): Name of the prokletor formatter to use

     Examples:
         # Basic usage
-        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
+        python batch_runner.py --dataset_file=data.jsonl --run_name=my_run

         # Resume interrupted run
-        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume
+        python batch_runner.py --dataset_file=data.jsonl --run_name=my_run --resume

         # Use specific distribution
-        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen
-
-        # With ephemeral system prompt (not saved to dataset)
-        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
-            --ephemeral_system_prompt="You are a helpful assistant focused on image generation."
-
-        # With custom tool failure thresholds
-        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
-            --max_tool_failures=20 --max_tool_failure_rate=0.3 --min_tool_calls_for_rate=10 --keep_recent_errors=10
-
-        # List available distributions
-        python batch_runner.py --list_distributions
+        python batch_runner.py --dataset_file=data.jsonl --run_name=image_test --distribution=image_gen
     """
     # Handle list distributions
     if list_distributions:
@@ -1209,10 +939,6 @@ def main(
         all_dists = get_all_dists()
         for dist_name in sorted(all_dists.keys()):
             print_distribution_info(dist_name)
-
-        print("\n💡 Usage:")
-        print("   python batch_runner.py --dataset_file=data.jsonl --batch_size=10 \\")
-        print("                          --run_name=my_run --distribution=<name>")
         return

     # Validate required arguments
@@ -1220,10 +946,6 @@ def main(
         print("❌ Error: --dataset_file is required")
         return

-    if not batch_size or batch_size < 1:
-        print("❌ Error: --batch_size must be a positive integer")
-        return
-
     if not run_name:
         print("❌ Error: --run_name is required")
         return
@@ -1232,7 +954,6 @@ def main(
     try:
         runner = BatchRunner(
             dataset_file=dataset_file,
-            batch_size=batch_size,
             run_name=run_name,
             distribution=distribution,
             max_iterations=max_turns,
@@ -1246,7 +967,9 @@ def main(
             max_tool_failures=max_tool_failures,
             max_tool_failure_rate=max_tool_failure_rate,
             keep_recent_errors=keep_recent_errors,
-            min_tool_calls_for_rate=min_tool_calls_for_rate
+            min_tool_calls_for_rate=min_tool_calls_for_rate,
+            prokletor_client=prokletor_client,
+            prokletor_formatter=prokletor_formatter
         )

         runner.run(resume=resume)
12	gemini_nothinking.sh	Normal file

@@ -0,0 +1,12 @@
+python batch_runner.py \
+  --dataset_file="source-data/agent_tasks_eval.jsonl" \
+  --batch_size=1 \
+  --run_name="agenttasks_eval_gemini-4.5-3-nothinking" \
+  --distribution="science" \
+  --model="gemini-3-pro-preview" \
+  --base_url="https://generativelanguage.googleapis.com/v1beta/openai/" \
+  --api_key="${GEMINI_API_KEY}" \
+  --num_workers=10 \
+  --max_turns=60 \
+  --verbose \
+  --ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. If you require API keys please check which ones already exist in your environment variables in a way that does not read them."
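The `--dataset_file` given above is expected to be JSONL with a `prompt` field in each entry, per the `main()` docstring. A hypothetical helper that writes such a file (the path matches the script; the prompts are invented for illustration):

```python
import json

# Illustrative prompts only; they are not part of the repository's eval set.
prompts = [
    "Estimate the Reynolds number for water flowing at 2 m/s through a 5 cm pipe.",
    "Summarize the key differences between CRISPR-Cas9 and base editing.",
]

with open("source-data/agent_tasks_eval.jsonl", "w", encoding="utf-8") as f:
    for p in prompts:
        f.write(json.dumps({"prompt": p}, ensure_ascii=False) + "\n")
```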
@@ -23,7 +23,7 @@ Usage:
    web_tools = get_tool_definitions(enabled_toolsets=['web_tools'])

    # Handle function calls from model
-    result = handle_function_call("web_search", {"query": "Python"})
+    result = await handle_function_call("web_search", {"query": "Python"})
"""

import json
@@ -439,7 +439,7 @@ def get_tool_definitions(

    return filtered_tools

-def handle_web_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
+async def handle_web_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
    """
    Handle function calls for web tools.

@@ -454,25 +454,25 @@ def handle_web_function_call(function_name: str, function_args: Dict[str, Any])
        query = function_args.get("query", "")
        # Always use fixed limit of 5
        limit = 5
-        return web_search_tool(query, limit)
+        return await web_search_tool(query, limit)

    elif function_name == "web_extract":
        urls = function_args.get("urls", [])
        # Limit URLs to prevent abuse
        urls = urls[:5] if isinstance(urls, list) else []
-        # Run async function in event loop
-        return asyncio.run(web_extract_tool(urls, "markdown"))
+        # Run async function
+        return await web_extract_tool(urls, "markdown")

    elif function_name == "web_crawl":
        url = function_args.get("url", "")
        instructions = function_args.get("instructions")
-        # Run async function in event loop
-        return asyncio.run(web_crawl_tool(url, instructions, "basic"))
+        # Run async function
+        return await web_crawl_tool(url, instructions, "basic")

    else:
        return json.dumps({"error": f"Unknown web function: {function_name}"}, ensure_ascii=False)


-def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
+async def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
    """
    Handle function calls for terminal tools.

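The change above swaps `asyncio.run(...)` for plain `await` inside the handlers. That matters because once the agent itself runs inside an event loop, calling `asyncio.run()` from a handler raises `RuntimeError: asyncio.run() cannot be called from a running event loop`. A minimal sketch of the pattern, not part of the commit (`fetch` and `dispatch` are placeholder names):

import asyncio

async def fetch(query: str) -> str:
    # Stand-in for an async tool such as a web search call.
    await asyncio.sleep(0.1)
    return f"results for {query!r}"

async def dispatch(query: str) -> str:
    # Correct inside async code: just await the coroutine.
    return await fetch(query)
    # Wrong inside a running loop: asyncio.run(fetch(query)) would raise RuntimeError.

if __name__ == "__main__":
    # A synchronous entry point starts the loop exactly once, at the top level.
    print(asyncio.run(dispatch("python")))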
@@ -489,13 +489,20 @@ def handle_terminal_function_call(function_name: str, function_args: Dict[str, A
        background = function_args.get("background", False)
        timeout = function_args.get("timeout")

-        return simple_terminal_tool(command=command, background=background, timeout=timeout, task_id=task_id)
+        # Run sync terminal tool in a thread to avoid blocking
+        return await asyncio.to_thread(
+            simple_terminal_tool,
+            command=command,
+            background=background,
+            timeout=timeout,
+            task_id=task_id
+        )

    else:
        return json.dumps({"error": f"Unknown terminal function: {function_name}"}, ensure_ascii=False)


-def handle_vision_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
+async def handle_vision_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
    """
    Handle function calls for vision tools.

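`simple_terminal_tool` stays synchronous (it blocks on SSH), so the handler above pushes it onto a worker thread with `asyncio.to_thread` and keeps the event loop free for other tasks. A hedged sketch of the same idea with a placeholder blocking function (requires Python 3.9+ for `asyncio.to_thread`):

import asyncio
import time

def blocking_command(command: str) -> str:
    # Placeholder for a synchronous call such as an SSH exec.
    time.sleep(1.0)
    return f"ran: {command}"

async def run_command(command: str) -> str:
    # to_thread keeps the event loop responsive while the call blocks in a thread.
    return await asyncio.to_thread(blocking_command, command)

async def main() -> None:
    # Two blocking calls now overlap instead of running back to back.
    print(await asyncio.gather(run_command("ls"), run_command("uptime")))

if __name__ == "__main__":
    asyncio.run(main())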
@@ -512,14 +519,14 @@ def handle_vision_function_call(function_name: str, function_args: Dict[str, Any

        full_prompt = f"Fully describe and explain everything about this image, then answer the following question:\n\n{question}"

-        # Run async function in event loop
-        return asyncio.run(vision_analyze_tool(image_url, full_prompt, "gemini-2.5-flash"))
+        # Run async function
+        return await vision_analyze_tool(image_url, full_prompt, "gemini-2.5-flash")

    else:
        return json.dumps({"error": f"Unknown vision function: {function_name}"}, ensure_ascii=False)


-def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
+async def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
    """
    Handle function calls for Mixture-of-Agents tools.

@@ -536,14 +543,14 @@ def handle_moa_function_call(function_name: str, function_args: Dict[str, Any])
        if not user_prompt:
            return json.dumps({"error": "user_prompt is required for MoA processing"}, ensure_ascii=False)

-        # Run async function in event loop
-        return asyncio.run(mixture_of_agents_tool(user_prompt=user_prompt))
+        # Run async function
+        return await mixture_of_agents_tool(user_prompt=user_prompt)

    else:
        return json.dumps({"error": f"Unknown MoA function: {function_name}"}, ensure_ascii=False)


-def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
+async def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
    """
    Handle function calls for image generation tools.

@@ -572,21 +579,8 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
        allow_nsfw_images = True
        seed = None

-        # Run async function in event loop with proper handling for multiprocessing
-        try:
-            # Try to get existing event loop
-            loop = asyncio.get_event_loop()
-            if loop.is_closed():
-                # If closed, create a new one
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-        except RuntimeError:
-            # No event loop in current thread, create one
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-
-        # Run the coroutine in the event loop
-        result = loop.run_until_complete(image_generate_tool(
+        # Run async function
+        return await image_generate_tool(
            prompt=prompt,
            image_size=image_size,
            num_inference_steps=num_inference_steps,
@@ -597,15 +591,13 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
            acceleration=acceleration,
            allow_nsfw_images=allow_nsfw_images,
            seed=seed
-        ))
+        )

-        return result

    else:
        return json.dumps({"error": f"Unknown image generation function: {function_name}"}, ensure_ascii=False)


-def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
+async def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
    """
    Main function call dispatcher that routes calls to appropriate toolsets.

@@ -627,23 +619,23 @@ def handle_function_call(function_name: str, function_args: Dict[str, Any], task
    try:
        # Route web tools
        if function_name in ["web_search", "web_extract", "web_crawl"]:
-            return handle_web_function_call(function_name, function_args)
+            return await handle_web_function_call(function_name, function_args)

        # Route terminal tools
        elif function_name in ["terminal"]:
-            return handle_terminal_function_call(function_name, function_args, task_id)
+            return await handle_terminal_function_call(function_name, function_args, task_id)

        # Route vision tools
        elif function_name in ["vision_analyze"]:
-            return handle_vision_function_call(function_name, function_args)
+            return await handle_vision_function_call(function_name, function_args)

        # Route MoA tools
        elif function_name in ["mixture_of_agents"]:
-            return handle_moa_function_call(function_name, function_args)
+            return await handle_moa_function_call(function_name, function_args)

        # Route image generation tools
        elif function_name in ["image_generate"]:
-            return handle_image_function_call(function_name, function_args)
+            return await handle_image_function_call(function_name, function_args)

        else:
            error_msg = f"Unknown function: {function_name}"
369
run_agent.py
@@ -24,11 +24,23 @@ import json
import logging
import os
import time
+import asyncio
+import sys
from typing import List, Dict, Any, Optional
-from openai import OpenAI
+from openai import AsyncOpenAI
import fire
from datetime import datetime
from pathlib import Path
+from rich import print
+
+from prokletor.formatters.hermes import HermesToolFormatterWithReasoning
+from prokletor.clients.hermes import HermesToolClientWithReasoning, HermesToolClient
+from prokletor.clients.claude import AsyncClaudeClient
+try:
+    from anthropic import AsyncAnthropic
+except ImportError:
+    AsyncAnthropic = None

# Load environment variables from .env file
from dotenv import load_dotenv
@@ -70,6 +82,8 @@ class AIAgent:
        verbose_logging: bool = False,
        ephemeral_system_prompt: str = None,
        log_prefix_chars: int = 100,
+        prokletor_client: str = None,
+        prokletor_formatter: str = None,
    ):
        """
        Initialize the AI Agent.
@@ -86,6 +100,8 @@ class AIAgent:
            verbose_logging (bool): Enable verbose logging for debugging (default: False)
            ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
            log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
+            prokletor_client (str): Name of the prokletor client to use (e.g., "AsyncClaudeClient", "HermesToolClient")
+            prokletor_formatter (str): Name of the prokletor formatter to use (optional)
        """
        self.model = model
        self.max_iterations = max_iterations
@@ -94,6 +110,8 @@ class AIAgent:
        self.verbose_logging = verbose_logging
        self.ephemeral_system_prompt = ephemeral_system_prompt
        self.log_prefix_chars = log_prefix_chars
+        self.prokletor_client_name = prokletor_client
+        self.prokletor_formatter_name = prokletor_formatter

        # Store toolset filtering options
        self.enabled_toolsets = enabled_toolsets
@@ -122,7 +140,7 @@ class AIAgent:
        logging.getLogger('openai').setLevel(logging.WARNING)
        logging.getLogger('httpx').setLevel(logging.WARNING)

-        # Initialize OpenAI client
+        # Initialize Client
        client_kwargs = {}
        if base_url:
            client_kwargs["base_url"] = base_url
@@ -132,12 +150,45 @@ class AIAgent:
            client_kwargs["api_key"] = os.getenv("ANTHROPIC_API_KEY", "dummy-key")

        try:
-            self.client = OpenAI(**client_kwargs)
+            if prokletor_client == "AsyncClaudeClient":
+                if AsyncAnthropic is None:
+                    raise ImportError("anthropic package is required for AsyncClaudeClient")
+
+                # AsyncAnthropic kwargs
+                anthropic_kwargs = {k: v for k, v in client_kwargs.items() if k in ["api_key", "base_url", "timeout", "max_retries", "default_headers"]}
+
+                anthropic_client = AsyncAnthropic(**anthropic_kwargs)
+                self.client = AsyncClaudeClient(anthropic_client)
+                print(f"🧠 Wrapped Anthropic client with AsyncClaudeClient")
+
+            elif prokletor_client == "HermesToolClient":
+                oai_client = AsyncOpenAI(**client_kwargs)
+                self.client = HermesToolClient(oai_client)
+                print(f"🧠 Wrapped OpenAI client with HermesToolClient")
+
+            elif prokletor_client == "HermesToolClientWithReasoning":
+                oai_client = AsyncOpenAI(**client_kwargs)
+                self.client = HermesToolClientWithReasoning(oai_client)
+                print(f"🧠 Wrapped OpenAI client with HermesToolClientWithReasoning")
+
+            elif prokletor_client:
+                # Fallback for unknown client names or if user provides a custom one (future proofing?)
+                # For now, raise error or default to OpenAI
+                print(f"⚠️ Unknown prokletor_client '{prokletor_client}'. Defaulting to HermesToolClientWithReasoning.")
+                oai_client = AsyncOpenAI(**client_kwargs)
+                self.client = HermesToolClientWithReasoning(oai_client)
+
+            else:
+                # Default behavior
+                oai_client = AsyncOpenAI(**client_kwargs)
+                self.client = oai_client
+                print(f"🧠 Using raw OpenAI client (no prokletor wrapper)")

            print(f"🤖 AI Agent initialized with model: {self.model}")
            if base_url:
                print(f"🔗 Using custom base URL: {base_url}")
        except Exception as e:
-            raise RuntimeError(f"Failed to initialize OpenAI client: {e}")
+            raise RuntimeError(f"Failed to initialize client: {e}")

        # Get available tools with filtering
        self.tools = get_tool_definitions(
@@ -210,22 +261,54 @@ class AIAgent:
        Returns:
            List[Dict]: Messages in trajectory format
        """
+        # Use the client wrapper's format method if available to get the exact Hermes format
+        # This ensures batch runner also gets the correct formatting
+        if hasattr(self, 'client') and hasattr(self.client, 'format'):
+            formatted_messages = self.client.format(messages, self.tools, render_final=True)
+
+            trajectory = []
+            for msg in formatted_messages:
+                role = msg["role"]
+                content = msg["content"]
+
+                # Map roles to trajectory format (human, gpt, system, tool)
+                if role == "user":
+                    trajectory_role = "human"
+                elif role == "assistant":
+                    trajectory_role = "gpt"
+                elif role == "system":
+                    trajectory_role = "system"
+                elif role == "tool":
+                    trajectory_role = "tool"
+                else:
+                    trajectory_role = role
+
+                trajectory.append({
+                    "from": trajectory_role,
+                    "value": content
+                })
+            return trajectory
+
        trajectory = []

        # Add system message with tool definitions
-        system_msg = (
-            "You are a function calling AI model. You are provided with function signatures within <tools> </tools> XML tags. "
-            "You may call one or more functions to assist with the user query. If available tools are not relevant in assisting "
-            "with user query, just respond in natural conversational language. Don't make assumptions about what values to plug "
-            "into functions. After calling & executing the functions, you will be provided with function results within "
-            "<tool_response> </tool_response> XML tags. Here are the available tools:\n"
-            f"<tools>\n{self._format_tools_for_system_message()}\n</tools>\n"
-            "For each function call return a JSON object, with the following pydantic model json schema for each:\n"
-            "{'title': 'FunctionCall', 'type': 'object', 'properties': {'name': {'title': 'Name', 'type': 'string'}, "
-            "'arguments': {'title': 'Arguments', 'type': 'object'}}, 'required': ['name', 'arguments']}\n"
-            "Each function call should be enclosed within <tool_call> </tool_call> XML tags.\n"
-            "Example:\n<tool_call>\n{'name': <function-name>,'arguments': <args-dict>}\n</tool_call>"
-        )
+        # Use the client's formatter if available to ensure consistency (e.g. reasoning prompt)
+        if hasattr(self, 'client') and hasattr(self.client, 'formatter'):
+            system_msg = self.client.formatter.format_system_message(self.tools if self.tools else [])
+        else:
+            system_msg = (
+                "You are a function calling AI model. You are provided with function signatures within <tools> </tools> XML tags. "
+                "You may call one or more functions to assist with the user query. If available tools are not relevant in assisting "
+                "with user query, just respond in natural conversational language. Don't make assumptions about what values to plug "
+                "into functions. After calling & executing the functions, you will be provided with function results within "
+                "<tool_response> </tool_response> XML tags. Here are the available tools:\n"
+                f"<tools>\n{self._format_tools_for_system_message()}\n</tools>\n"
+                "For each function call return a JSON object, with the following pydantic model json schema for each:\n"
+                "{'title': 'FunctionCall', 'type': 'object', 'properties': {'name': {'title': 'Name', 'type': 'string'}, "
+                "'arguments': {'title': 'Arguments', 'type': 'object'}}, 'required': ['name', 'arguments']}\n"
+                "Each function call should be enclosed within <tool_call> </tool_call> XML tags.\n"
+                "Example:\n<tool_call>\n{'name': <function-name>,'arguments': <args-dict>}\n</tool_call>"
+            )

        trajectory.append({
            "from": "system",
@@ -348,7 +431,7 @@ class AIAgent:
        except Exception as e:
            print(f"⚠️ Failed to save trajectory: {e}")

-    def run_conversation(
+    async def run_conversation(
        self,
        user_message: str,
        system_message: str = None,
@@ -401,10 +484,14 @@ class AIAgent:
            if self.verbose_logging:
                logging.debug(f"API Request - Model: {self.model}, Messages: {len(messages)}, Tools: {len(self.tools) if self.tools else 0}")
                logging.debug(f"Last message role: {messages[-1]['role'] if messages else 'none'}")
+                # Log the last few messages to see if thought_signature is present
+                logging.debug(f"Last message content: {json.dumps(messages[-1] if messages else {}, indent=2)}")

            api_start_time = time.time()
            retry_count = 0
            max_retries = 6  # Increased to allow longer backoff periods
+            response = None
+            last_api_error = None

            while retry_count <= max_retries:
                try:
@@ -416,12 +503,23 @@ class AIAgent:
                        api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages

                    # Make API call with tools
-                    response = self.client.chat.completions.create(
-                        model=self.model,
-                        messages=api_messages,
-                        tools=self.tools if self.tools else None,
-                        timeout=300.0  # 5 minute timeout for long-running agent tasks
-                    )
+                    api_kwargs = {
+                        "model": self.model,
+                        "messages": api_messages,
+                        "tools": self.tools if self.tools else None,
+                        "timeout": 300.0,  # 5 minute timeout for long-running agent tasks
+                    }
+
+                    # Enable thinking by default for AsyncClaudeClient if using a supported model
+                    if self.prokletor_client_name == "AsyncClaudeClient" and self.model.startswith("claude"):
+                        api_kwargs["thinking"] = {
+                            "type": "enabled",
+                            "budget_tokens": 8000
+                        }
+                        # Ensure max_tokens is set higher than budget_tokens
+                        api_kwargs["max_tokens"] = 16000
+
+                    response = await self.client.chat.completions.create(**api_kwargs)

                    api_duration = time.time() - api_start_time
                    print(f"⏱️ OpenAI-compatible API call completed in {api_duration:.2f}s")
@@ -435,6 +533,15 @@ class AIAgent:
                    break  # Success, exit retry loop

                except Exception as api_error:
+                    last_api_error = api_error
+                    error_message = str(api_error)
+                    token_limit_error = "input token count exceeds the maximum number of tokens" in error_message.lower()
+
+                    if token_limit_error:
+                        print("❌ OpenAI-compatible API call failed: input token limit exceeded. Not retrying this request.")
+                        logging.error("Non-retryable token limit error from API: %s", api_error)
+                        break
+
                    retry_count += 1
                    if retry_count > max_retries:
                        raise api_error
@@ -443,7 +550,10 @@ class AIAgent:
                    print(f"⚠️ OpenAI-compatible API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}")
                    print(f"⏳ Retrying in {wait_time}s...")
                    logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}")
-                    time.sleep(wait_time)
+                    await asyncio.sleep(wait_time)

+            if response is None:
+                raise last_api_error if last_api_error else RuntimeError("OpenAI-compatible API call failed without a response")

            try:
                assistant_message = response.choices[0].message
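The retry loop above now sleeps with `await asyncio.sleep(...)` instead of `time.sleep(...)`, so a backoff pause in one conversation no longer stalls every other task on the shared event loop, and the new `response is None` guard re-raises the last error once retries are exhausted. A minimal sketch of that shape (illustrative only; `make_request` is a placeholder and the real code uses its own backoff schedule):

import asyncio
import random

async def call_with_backoff(make_request, max_retries: int = 6):
    # make_request: any zero-argument coroutine function.
    last_error = None
    for attempt in range(max_retries + 1):
        try:
            return await make_request()
        except Exception as err:  # narrow the exception types in real code
            last_error = err
            if attempt == max_retries:
                break
            wait_time = min(2 ** attempt, 60) + random.random()
            # Non-blocking sleep: other tasks keep running while this one waits.
            await asyncio.sleep(wait_time)
    raise last_error if last_error else RuntimeError("request failed without a response")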
@@ -459,25 +569,62 @@ class AIAgent:
                if self.verbose_logging:
                    for tc in assistant_message.tool_calls:
                        logging.debug(f"Tool call: {tc.function.name} with args: {tc.function.arguments[:200]}...")
+                        # Debug: Check what attributes are available on tool_call
+                        logging.debug(f"Tool call attributes: {dir(tc)}")
+                        # Try to dump the model to see all fields
+                        if hasattr(tc, 'model_dump'):
+                            logging.debug(f"Tool call data: {tc.model_dump()}")

                # Add assistant message with tool calls to conversation
+                # Extract thought_signature if present (required for Gemini models)
+                tool_calls_data = []
+                for tool_call in assistant_message.tool_calls:
+                    tool_call_dict = {
+                        "id": tool_call.id,
+                        "type": tool_call.type,
+                        "function": {
+                            "name": tool_call.function.name,
+                            "arguments": tool_call.function.arguments
+                        }
+                    }
+                    # Try multiple ways to access thought_signature (Gemini-specific)
+                    # Gemini uses extra_content.google.thought_signature structure
+                    thought_sig = None
+
+                    # Method 1: Check extra_content attribute
+                    if hasattr(tool_call, 'extra_content'):
+                        extra = tool_call.extra_content
+                        if isinstance(extra, dict) and 'google' in extra:
+                            thought_sig = extra['google'].get('thought_signature')
+
+                    # Method 2: Check model_dump() if available (Pydantic v2)
+                    if thought_sig is None and hasattr(tool_call, 'model_dump'):
+                        dumped = tool_call.model_dump()
+                        if 'extra_content' in dumped and isinstance(dumped['extra_content'], dict):
+                            google_data = dumped['extra_content'].get('google', {})
+                            thought_sig = google_data.get('thought_signature')
+
+                    if thought_sig is not None:
+                        tool_call_dict["extra_content"] = {
+                            "google": {
+                                "thought_signature": thought_sig
+                            }
+                        }
+                        if self.verbose_logging:
+                            logging.debug(f"Captured thought_signature for tool call {tool_call.id}")
+                    elif self.verbose_logging:
+                        logging.debug(f"No thought_signature found for tool call {tool_call.id}")

+                    tool_calls_data.append(tool_call_dict)

                messages.append({
                    "role": "assistant",
                    "content": assistant_message.content,
-                    "tool_calls": [
-                        {
-                            "id": tool_call.id,
-                            "type": tool_call.type,
-                            "function": {
-                                "name": tool_call.function.name,
-                                "arguments": tool_call.function.arguments
-                            }
-                        }
-                        for tool_call in assistant_message.tool_calls
-                    ]
+                    "tool_calls": tool_calls_data
                })

-                # Execute each tool call
+                # Execute tool calls concurrently
+                tool_tasks = []
                for i, tool_call in enumerate(assistant_message.tool_calls, 1):
                    function_name = tool_call.function.name

@@ -492,35 +639,55 @@ class AIAgent:
                    args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
                    print(f"  📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}")

+                    # Create coroutine for tool execution
+                    task = handle_function_call(function_name, function_args, effective_task_id)
+                    tool_tasks.append(task)

+                if tool_tasks:
                    tool_start_time = time.time()

-                    # Execute the tool with task_id to isolate VMs between concurrent tasks
-                    function_result = handle_function_call(function_name, function_args, effective_task_id)
+                    # Execute all tools concurrently
+                    # We use return_exceptions=True to ensure one failure doesn't stop others
+                    # Order of results corresponds to order of tasks
+                    results = await asyncio.gather(*tool_tasks, return_exceptions=True)

                    tool_duration = time.time() - tool_start_time
-                    result_preview = function_result[:200] if len(function_result) > 200 else function_result

-                    # Record tool timing in profiler
-                    get_profiler().record_tool_timing(function_name, tool_duration)
+                    # Process results
+                    for i, (result, tool_call) in enumerate(zip(results, assistant_message.tool_calls), 1):
+                        function_name = tool_call.function.name

-                    if self.verbose_logging:
-                        logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
-                        logging.debug(f"Tool result preview: {result_preview}...")
+                        # Handle exceptions from asyncio.gather
+                        if isinstance(result, Exception):
+                            function_result = json.dumps({"error": str(result)}, ensure_ascii=False)
+                            print(f"❌ Tool {i} ({function_name}) failed: {result}")
+                        else:
+                            function_result = result

-                    # Add tool result to conversation
-                    messages.append({
-                        "role": "tool",
-                        "content": function_result,
-                        "tool_call_id": tool_call.id
-                    })
+                        result_preview = function_result[:200] if len(function_result) > 200 else function_result

-                    # Preview tool response
-                    response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
-                    print(f"  ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}")
+                        # Record tool timing in profiler (approximate since they ran in parallel)
+                        get_profiler().record_tool_timing(function_name, tool_duration)

-                    # Delay between tool calls
-                    if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
-                        time.sleep(self.tool_delay)
+                        if self.verbose_logging:
+                            logging.debug(f"Tool {function_name} completed in parallel batch")
+                            logging.debug(f"Tool result preview: {result_preview}...")

+                        # Add tool result to conversation
+                        # Note: thought_signature should NOT be in tool responses, only in assistant messages
+                        messages.append({
+                            "role": "tool",
+                            "content": function_result,
+                            "tool_call_id": tool_call.id
+                        })

+                        # Preview tool response
+                        response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
+                        print(f"  ✅ Tool {i} completed - {response_preview}")

+                    # Optional delay after batch execution
+                    if self.tool_delay > 0:
+                        await asyncio.sleep(self.tool_delay)

                # Continue loop for next response
                continue
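Tool execution above switches from a sequential per-call loop to collecting coroutines and running them with `asyncio.gather(..., return_exceptions=True)`, so one failing tool becomes an error payload rather than an exception that aborts the whole turn. A sketch of the pattern in isolation (names are illustrative, not the repository's API):

import asyncio
import json

async def run_tool_batch(named_coroutines):
    # named_coroutines: list of (tool_name, coroutine) pairs.
    results = await asyncio.gather(
        *(coro for _, coro in named_coroutines),
        return_exceptions=True,
    )
    tool_messages = []
    for (name, _), result in zip(named_coroutines, results):
        if isinstance(result, Exception):
            # A failed tool is reported as structured JSON instead of crashing the batch.
            content = json.dumps({"error": str(result)}, ensure_ascii=False)
        else:
            content = result
        tool_messages.append({"role": "tool", "name": name, "content": content})
    return tool_messages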
@@ -566,11 +733,79 @@ class AIAgent:
        completed = final_response is not None and api_call_count < self.max_iterations

        # Save trajectory if enabled
-        self._save_trajectory(messages, user_message, completed)
+        # When saving trajectory, we want to show what the prompt would look like with proper tool roles
+        # This is helpful for training data or debugging
+        if self.save_trajectories:
+            # Use the client wrapper's format method if available to get the exact Hermes format
+            if hasattr(self, 'client') and hasattr(self.client, 'format'):
+                raise ValueError("reached this point")
+                formatted_messages = self.client.format(messages, self.tools, render_final=True)

+                # We need to adapt this formatted list to the trajectory format expected by _save_trajectory
+                # Since _convert_to_trajectory_format expects raw OAI messages, we might need a different approach
+                # OR just pass the formatted messages directly if _save_trajectory supports it.

+                # Let's look at _convert_to_trajectory_format. It iterates through messages and converts them.
+                # If we pass messages that are already formatted (e.g. system prompt with tools, tool calls in XML),
+                # we need to be careful not to double-format.

+                # Actually, the goal is to save the trajectory in a specific JSONL format for training/eval.
+                # If we use the Hermes formatter, it produces a list of messages where content is XML strings.
+                # The existing _convert_to_trajectory_format does manual XML wrapping.

+                # Ideally, we should use the messages as they are (OAI format) and let the training pipeline handle formatting,
+                # OR save them in the exact format the model sees.

+                # The user request is: "accumulating history in oai format and then calling that final thing with use_tool_call True"
+                # referring to client.format(messages, tools, use_tool_role=True)

+                # So let's save the RESULT of client.format() to the trajectory file.

+                # Create a custom trajectory entry directly from the formatted messages
+                trajectory_content = []
+                for msg in formatted_messages:
+                    role = msg["role"]
+                    content = msg["content"]

+                    # Map roles to trajectory format (human, gpt, system, tool)
+                    if role == "user":
+                        trajectory_role = "human"
+                    elif role == "assistant":
+                        trajectory_role = "gpt"
+                    elif role == "system":
+                        trajectory_role = "system"
+                    elif role == "tool":
+                        trajectory_role = "tool"
+                    else:
+                        trajectory_role = role

+                    trajectory_content.append({
+                        "from": trajectory_role,
+                        "value": content
+                    })

+                # Save this specific formatted trajectory
+                filename = "trajectory_samples.jsonl" if completed else "failed_trajectories.jsonl"
+                entry = {
+                    "conversations": trajectory_content,
+                    "timestamp": datetime.now().isoformat(),
+                    "model": self.model,
+                    "completed": completed
+                }

+                try:
+                    with open(filename, "a", encoding="utf-8") as f:
+                        f.write(json.dumps(entry, ensure_ascii=False) + "\n")
+                    print(f"💾 Trajectory saved to {filename} (using Hermes format)")
+                except Exception as e:
+                    print(f"⚠️ Failed to save trajectory: {e}")
+            else:
+                # Fallback to original saving method
+                self._save_trajectory(messages, user_message, completed)

        # Clean up VM for this task after conversation completes
        try:
-            cleanup_vm(effective_task_id)
+            await asyncio.to_thread(cleanup_vm, effective_task_id)
        except Exception as e:
            if self.verbose_logging:
                logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}")
@@ -586,7 +821,7 @@ class AIAgent:
            "profiling_stats": profiling_stats
        }

-    def chat(self, message: str) -> str:
+    async def chat(self, message: str) -> str:
        """
        Simple chat interface that returns just the final response.

@@ -596,7 +831,7 @@ class AIAgent:
        Returns:
            str: Final assistant response
        """
-        result = self.run_conversation(message)
+        result = await self.run_conversation(message)
        return result["final_response"]


@@ -612,7 +847,9 @@ def main(
    save_trajectories: bool = False,
    verbose: bool = False,
    log_prefix_chars: int = 20,
-    show_profiling: bool = True
+    show_profiling: bool = True,
+    prokletor_client: str = None,
+    prokletor_formatter: str = None,
):
    """
    Main function for running the agent directly.
@@ -632,6 +869,8 @@ def main(
        verbose (bool): Enable verbose logging for debugging. Defaults to False.
        log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses. Defaults to 20.
        show_profiling (bool): Display profiling statistics after conversation. Defaults to True.
+        prokletor_client (str): Name of the prokletor client to use (e.g., "AsyncClaudeClient")
+        prokletor_formatter (str): Name of the prokletor formatter to use

    Toolset Examples:
    - "research": Web search, extract, crawl + vision tools
@@ -750,7 +989,9 @@ def main(
            disabled_toolsets=disabled_toolsets_list,
            save_trajectories=save_trajectories,
            verbose_logging=verbose,
-            log_prefix_chars=log_prefix_chars
+            log_prefix_chars=log_prefix_chars,
+            prokletor_client=prokletor_client,
+            prokletor_formatter=prokletor_formatter
        )
    except RuntimeError as e:
        print(f"❌ Failed to initialize agent: {e}")
@@ -769,7 +1010,7 @@ def main(
    print("\n" + "=" * 50)

    # Run conversation
-    result = agent.run_conversation(user_query)
+    result = asyncio.run(agent.run_conversation(user_query))

    print("\n" + "=" * 50)
    print("📋 CONVERSATION SUMMARY")
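Because `run_conversation` is now a coroutine, the synchronous CLI path in `main()` bridges into async exactly once via `asyncio.run(...)`. A stripped-down sketch of that entry-point shape (illustrative; the real `main()` is driven by `fire` and takes many more options):

import asyncio

class EchoAgent:
    async def run_conversation(self, prompt: str) -> dict:
        await asyncio.sleep(0)  # stand-in for API and tool calls
        return {"final_response": f"echo: {prompt}"}

def main(prompt: str = "hello") -> None:
    # Synchronous entry point; the event loop exists only inside asyncio.run().
    agent = EchoAgent()
    result = asyncio.run(agent.run_conversation(prompt))
    print(result["final_response"])

if __name__ == "__main__":
    main()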
@@ -78,6 +78,7 @@ AGGREGATOR_TEMPERATURE = 0.4  # Focused synthesis for consistency

# Failure handling configuration
MIN_SUCCESSFUL_REFERENCES = 1  # Minimum successful reference models needed to proceed
+UNAVAILABLE_TOOL_RESPONSE = "This tool is not available"

# System prompt for the aggregator model (from the research paper)
AGGREGATOR_SYSTEM_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
@@ -364,14 +365,29 @@ async def mixture_of_agents_tool(
    if failed_models:
        print(f"⚠️ Failed models: {', '.join(failed_models)}")

-    # Check if we have enough successful responses to proceed
-    if successful_count < MIN_SUCCESSFUL_REFERENCES:
-        raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.")

    debug_call_data["reference_responses_count"] = successful_count
    debug_call_data["failed_models_count"] = failed_count
    debug_call_data["failed_models"] = failed_models

+    # Check if we have enough successful responses to proceed
+    if successful_count < MIN_SUCCESSFUL_REFERENCES:
+        print("🚫 MoA tool unavailable: insufficient successful reference models after retries")
+        result = {
+            "success": False,
+            "response": UNAVAILABLE_TOOL_RESPONSE,
+            "models_used": {
+                "reference_models": ref_models,
+                "aggregator_model": agg_model
+            }
+        }
+        debug_call_data["error"] = UNAVAILABLE_TOOL_RESPONSE
+        debug_call_data["models_used"] = result["models_used"]
+        processing_time = (datetime.datetime.now() - start_time).total_seconds()
+        debug_call_data["processing_time_seconds"] = processing_time
+        _log_debug_call("mixture_of_agents_tool", debug_call_data)
+        _save_debug_log()
+        return json.dumps(result, indent=2, ensure_ascii=False)

    # Layer 2: Aggregate responses using the aggregator model
    print("🧠 Layer 2: Synthesizing final response...")
    aggregator_system_prompt = _construct_aggregator_prompt(
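The rewrite above turns an insufficient-references condition from a raised `ValueError` into a structured "tool unavailable" payload, so the calling agent receives a readable tool response and can carry on instead of losing the whole turn. A minimal sketch of that fallback shape (function and argument names here are illustrative):

import json

UNAVAILABLE_TOOL_RESPONSE = "This tool is not available"

def result_or_fallback(successful_count: int, minimum: int, synthesized: str) -> str:
    # Below the threshold, return a payload the model can read rather than raising.
    if successful_count < minimum:
        payload = {"success": False, "response": UNAVAILABLE_TOOL_RESPONSE}
    else:
        payload = {"success": True, "response": synthesized}
    return json.dumps(payload, indent=2, ensure_ascii=False)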
@@ -189,8 +189,13 @@ def _execute_ssh_command(instance, command: str, timeout: Optional[int] = None)
        ssh_context_manager = instance.ssh()
        ssh_context = ssh_context_manager.__enter__()

-        # Execute the command
-        result = ssh_context.run(command, get_pty=False, timeout=timeout or 120)
+        # Execute the command. Using a PTY ensures stdout/stderr ordering matches
+        # what a human would see in a terminal session.
+        result = ssh_context.run(
+            command,
+            get_pty=True,
+            timeout=timeout or 120,
+        )

        # Close the SSH connection
        if ssh_context_manager:
@@ -213,22 +218,12 @@ def _execute_ssh_command(instance, command: str, timeout: Optional[int] = None)
            except:
                pass

-        # Check if it's a timeout
-        error_str = str(e).lower()
-        if "timeout" in error_str:
-            return {
-                "stdout": "",
-                "stderr": f"Command timed out after {timeout or 120} seconds",
-                "returncode": 124
-            }

        return {
            "stdout": "",
            "stderr": f"SSH execution failed: {str(e)}",
            "returncode": -1
        }


def simple_terminal_tool(
    command: str,
    background: bool = False,
@@ -315,15 +310,21 @@ def simple_terminal_tool(
        result = _execute_ssh_command(instance, exec_command, timeout=10)

        # For background tasks, return immediately with info
+        stderr_text = (result["stderr"] or "").strip()
        if result["returncode"] == 0:
            return json.dumps({
                "output": "Background task started successfully",
+                "stderr": stderr_text,
                "exit_code": 0,
                "error": None
            }, ensure_ascii=False)
        else:
+            output_text = result["stdout"] or ""
+            if result["stderr"] and not output_text:
+                output_text = result["stderr"]
            return json.dumps({
-                "output": result["stdout"],
+                "output": output_text,
+                "stderr": stderr_text,
                "exit_code": result["returncode"],
                "error": result["stderr"]
            }, ensure_ascii=False)
@@ -331,13 +332,13 @@ def simple_terminal_tool(
        # Run foreground command
        result = _execute_ssh_command(instance, command, timeout=timeout)

-        # Combine stdout and stderr for output
-        output = result["stdout"]
+        output = result["stdout"] or ""
        if result["stderr"] and result["returncode"] != 0:
            output = f"{output}\n{result['stderr']}" if output else result["stderr"]
+        stderr_text = (result["stderr"] or "").strip()
        return json.dumps({
            "output": output.strip(),
+            "stderr": stderr_text,
            "exit_code": result["returncode"],
            "error": result["stderr"] if result["returncode"] != 0 else None
        }, ensure_ascii=False)
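The terminal tool's JSON now always carries a separate `stderr` field and falls back to stderr as the output when stdout is empty, so a failing command never comes back as a blank result. A small sketch of that normalization step (the `result` dict shape mirrors what the SSH helper above returns; names are illustrative):

import json

def format_terminal_result(result: dict) -> str:
    # result is assumed to look like {"stdout": str, "stderr": str, "returncode": int}.
    stderr_text = (result.get("stderr") or "").strip()
    output = result.get("stdout") or ""
    if stderr_text and not output:
        # Surface stderr so a failing command is never silently empty.
        output = stderr_text
    return json.dumps({
        "output": output.strip(),
        "stderr": stderr_text,
        "exit_code": result.get("returncode"),
        "error": stderr_text if result.get("returncode") != 0 else None,
    }, ensure_ascii=False)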
@@ -48,11 +48,11 @@ import uuid
import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional
-from firecrawl import Firecrawl
+from firecrawl import AsyncFirecrawl
from openai import AsyncOpenAI

# Initialize Firecrawl client once at module level
-firecrawl_client = Firecrawl(api_key=os.getenv("FIRECRAWL_API_KEY"))
+firecrawl_client = AsyncFirecrawl(api_key=os.getenv("FIRECRAWL_API_KEY"))

# Initialize Nous Research API client for LLM processing (async)
nous_client = AsyncOpenAI(
@@ -261,7 +261,7 @@ def clean_base64_images(text: str) -> str:
    return cleaned_text


-def web_search_tool(query: str, limit: int = 5) -> str:
+async def web_search_tool(query: str, limit: int = 5) -> str:
    """
    Search the web for information using available search API backend.

@@ -312,7 +312,7 @@ def web_search_tool(query: str, limit: int = 5) -> str:
        # Use Firecrawl's v2 search functionality WITHOUT scraping
        # We only want search result metadata, not scraped content
        # Docs: https://docs.firecrawl.dev/features/search
-        response = firecrawl_client.search(
+        response = await firecrawl_client.search(
            query=query,
            limit=limit
        )
@@ -446,7 +446,7 @@ async def web_extract_tool(
    for url in urls:
        try:
            print(f"  📄 Scraping: {url}")
-            scrape_result = firecrawl_client.scrape(
+            scrape_result = await firecrawl_client.scrape(
                url=url,
                formats=formats
            )
@@ -703,7 +703,7 @@ async def web_crawl_tool(

    # Use the crawl method which waits for completion automatically
    try:
-        crawl_result = firecrawl_client.crawl(
+        crawl_result = await firecrawl_client.crawl(
            url=url,
            **crawl_params
        )