Mirror of https://github.com/NousResearch/hermes-agent.git
Synced 2026-04-28 23:11:37 +08:00

Compare commits: skill/gith...thought-si (4 commits)
Commits: a219e178a1, e06a15b3ab, 349e37de0a, 31c733383b
batch_runner.py (717 lines changed)
@@ -9,6 +9,8 @@ across multiple prompts from a dataset. It includes:
 - Checkpointing for fault tolerance and resumption
 - Trajectory saving in the proper format (from/value pairs)
 - Tool usage statistics aggregation across all batches
+- Cluster failure detection and graceful shutdown (morph, firecrawl, API errors)
+- Configurable failure thresholds with automatic data consolidation
 
 Usage:
     python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
@@ -18,6 +20,10 @@ Usage:
     # Use a specific toolset distribution
     python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen
 
+    # Configure tool failure thresholds
+    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
+        --max_tool_failures=20 --max_tool_failure_rate=0.3 --min_tool_calls_for_rate=10
+
 """
 
 import json
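Note: the checkpointing listed above pairs with the --resume flag handled in main() further down in this diff; a plausible way to continue an interrupted run (same run_name, flag name taken from the summary output later in this file) would be:

    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume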
@@ -29,6 +35,7 @@ from typing import List, Dict, Any, Optional, Tuple
 from datetime import datetime
 from multiprocessing import Pool, Manager, Lock
 import traceback
+import re
 
 import fire
 
@@ -39,11 +46,166 @@ from toolset_distributions import (
     sample_toolsets_from_distribution,
     validate_distribution
 )
+from safe_print import safe_print
 
 
 # Global configuration for worker processes
 _WORKER_CONFIG = {}
 
+# Canonical names for the terminal tool (old & new implementations)
+_TERMINAL_TOOL_NAMES = {"terminal", "terminal_tool", "simple_terminal_tool"}
+
+
+def _is_terminal_tool_name(tool_name: Optional[str]) -> bool:
+    """Return True if the given tool name corresponds to a terminal tool."""
+    return bool(tool_name) and tool_name.lower() in _TERMINAL_TOOL_NAMES
+
+
+def _terminal_tool_failed(content_json: Dict[str, Any]) -> bool:
+    """
+    Determine whether the terminal tool itself failed (not the user command).
+
+    Terminal failures are indicated by explicit status flags or negative exit codes.
+    Regular command failures (non-zero positive exit codes, stderr, timeouts) are not counted.
+    """
+    if not isinstance(content_json, dict):
+        return False
+
+    status = str(content_json.get("status", "")).lower()
+    if status in {"error", "disabled"}:
+        return True
+
+    exit_code = content_json.get("exit_code")
+    if isinstance(exit_code, int) and exit_code < 0:
+        return True
+
+    return False
+
+
+def _categorize_error_type(error_message: str) -> str:
+    """
+    Categorize an error message into a failure type.
+
+    Args:
+        error_message (str): The error message to categorize
+
+    Returns:
+        str: Category of the error
+    """
+    error_lower = error_message.lower()
+
+    # Common error patterns
+    if "timeout" in error_lower or "timed out" in error_lower:
+        return "Timeout"
+    elif "connection" in error_lower or "connect" in error_lower:
+        return "Connection Error"
+    elif "rate limit" in error_lower or "ratelimit" in error_lower or "429" in error_lower:
+        return "Rate Limit"
+    elif "authentication" in error_lower or "auth" in error_lower or "unauthorized" in error_lower or "401" in error_lower:
+        return "Authentication"
+    elif "not found" in error_lower or "404" in error_lower:
+        return "Not Found"
+    elif "permission" in error_lower or "forbidden" in error_lower or "403" in error_lower:
+        return "Permission Denied"
+    elif "invalid" in error_lower or "malformed" in error_lower or "bad request" in error_lower or "400" in error_lower:
+        return "Invalid Input"
+    elif "out of memory" in error_lower or "oom" in error_lower:
+        return "Out of Memory"
+    elif "network" in error_lower:
+        return "Network Error"
+    elif "server error" in error_lower or "500" in error_lower or "502" in error_lower or "503" in error_lower:
+        return "Server Error"
+    elif "vm" in error_lower and ("fail" in error_lower or "error" in error_lower):
+        return "VM Error"
+    else:
+        return "Other"
+
+
+def _extract_tool_errors_from_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """
+    Extract tool errors from message history with tool names.
+
+    Args:
+        messages (List[Dict]): Message history
+
+    Returns:
+        List[Dict]: List of tool errors with tool name, error message, error type, and context
+    """
+    tool_errors = []
+    tool_calls_map = {}  # Map tool_call_id to tool name
+
+    for msg in messages:
+        # Track tool calls from assistant messages
+        if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]:
+            for tool_call in msg["tool_calls"]:
+                tool_name = tool_call["function"]["name"]
+                tool_call_id = tool_call["id"]
+                tool_calls_map[tool_call_id] = tool_name
+
+        # Check tool responses for errors
+        elif msg["role"] == "tool":
+            tool_call_id = msg.get("tool_call_id", "")
+            content = msg.get("content", "")
+
+            # Determine if tool call had an error
+            has_error = False
+            error_msg = None
+
+            try:
+                content_json = json.loads(content) if isinstance(content, str) else content
+
+                if isinstance(content_json, dict):
+                    # Get tool name for special handling
+                    tool_name = tool_calls_map.get(tool_call_id, "unknown")
+
+                    # Special handling for terminal tool outputs
+                    if _is_terminal_tool_name(tool_name):
+                        if _terminal_tool_failed(content_json):
+                            has_error = True
+                            # Prefer explicit error text, fall back to status or generic message
+                            error_msg = str(
+                                content_json.get("error")
+                                or content_json.get("status")
+                                or "Terminal tool failure"
+                            )
+                    else:
+                        # For other tools, check if error field exists AND has a non-null value
+                        if "error" in content_json and content_json["error"] is not None:
+                            has_error = True
+                            error_msg = str(content_json["error"])
+
+                        # Check nested content structure (some tools wrap responses)
+                        if "content" in content_json and isinstance(content_json["content"], dict):
+                            inner_content = content_json["content"]
+                            if inner_content.get("error") is not None:
+                                has_error = True
+                                error_msg = inner_content.get("error")
+
+                        # Check for "success": false pattern
+                        if content_json.get("success") is False:
+                            has_error = True
+                            if not error_msg:
+                                error_msg = str(content_json.get("message", content_json.get("error", "Unknown error")))
+
+            except:
+                # If not JSON, check if content explicitly states an error
+                if content.strip().lower().startswith("error:"):
+                    has_error = True
+                    error_msg = content.strip()
+
+            # Record error if found
+            if has_error and tool_call_id in tool_calls_map:
+                tool_name = tool_calls_map[tool_call_id]
+                error_message = error_msg or "Unknown error"
+                tool_errors.append({
+                    "tool_name": tool_name,
+                    "error_message": error_message,
+                    "error_type": _categorize_error_type(error_message),
+                    "full_content": content[:500]  # Keep first 500 chars of full response
+                })
+
+    return tool_errors
+
+
 def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
     """
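The categorizer above is a first-match keyword chain, so ordering matters; a short illustrative check (error messages invented for the example):

    # Illustrative only; expected categories follow the keyword order above.
    assert _categorize_error_type("Request timed out after 30s") == "Timeout"
    assert _categorize_error_type("HTTP 429 Too Many Requests") == "Rate Limit"
    assert _categorize_error_type("database constraint violation") == "Other"
    # "connection" is tested before "rate limit", so a mixed message
    # lands in the earlier bucket:
    assert _categorize_error_type("connection refused by rate limiter") == "Connection Error"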
@@ -90,22 +252,28 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
             content_json = json.loads(content) if isinstance(content, str) else content
 
             if isinstance(content_json, dict):
-                # Check if error field exists AND has a non-null value
-                if "error" in content_json and content_json["error"] is not None:
-                    is_success = False
+                # Get tool name for special handling
+                tool_name = tool_calls_map.get(tool_call_id, "unknown")
 
-                # Special handling for terminal tool responses
-                # Terminal wraps its response in a "content" field
-                if "content" in content_json and isinstance(content_json["content"], dict):
-                    inner_content = content_json["content"]
-                    # Check for actual error (non-null error field)
-                    # Note: non-zero exit codes are not failures - the model can self-correct
-                    if inner_content.get("error") is not None:
-                        is_success = False
+                # Special handling for terminal tool: only count as failure when the tool itself fails
+                if _is_terminal_tool_name(tool_name):
+                    if _terminal_tool_failed(content_json):
+                        is_success = False
+                else:
+                    # For other tools, check if error field exists AND has a non-null value
+                    if "error" in content_json and content_json["error"] is not None:
+                        is_success = False
 
-                # Check for "success": false pattern used by some tools
-                if content_json.get("success") is False:
-                    is_success = False
+                    # Check nested content structure (some tools wrap responses)
+                    if "content" in content_json and isinstance(content_json["content"], dict):
+                        inner_content = content_json["content"]
+                        # Check for actual error (non-null error field)
+                        if inner_content.get("error") is not None:
+                            is_success = False
+
+                    # Check for "success": false pattern used by some tools
+                    if content_json.get("success") is False:
+                        is_success = False
 
         except:
             # If not JSON, check if content is empty or explicitly states an error
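The intent of the terminal special-casing in both extractors: only failures of the tool itself count toward the failure thresholds, while ordinary command failures are left for the model to self-correct. Illustrative payloads (field values invented):

    _terminal_tool_failed({"exit_code": 1, "stderr": "No such file"})  # False: the command failed, not the tool
    _terminal_tool_failed({"exit_code": -9})                           # True: negative exit code
    _terminal_tool_failed({"status": "disabled"})                      # True: explicit status flag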
@@ -173,6 +341,9 @@ def _process_single_prompt(
         # Extract tool usage statistics
         tool_stats = _extract_tool_stats(result["messages"])
 
+        # Extract tool errors from conversation
+        tool_errors = _extract_tool_errors_from_messages(result["messages"])
+
         # Convert to trajectory format (using existing method)
         trajectory = agent._convert_to_trajectory_format(
             result["messages"],
@@ -180,11 +351,16 @@ def _process_single_prompt(
             result["completed"]
         )
 
+        # Get profiling stats from the result
+        profiling_stats = result.get("profiling_stats", {"tools": {}, "api_calls": {}})
+
         return {
             "success": True,
             "prompt_index": prompt_index,
             "trajectory": trajectory,
             "tool_stats": tool_stats,
+            "tool_errors": tool_errors,
+            "profiling_stats": profiling_stats,
             "completed": result["completed"],
             "api_calls": result["api_calls"],
             "toolsets_used": selected_toolsets,
@@ -196,14 +372,19 @@ def _process_single_prompt(
         }
 
     except Exception as e:
-        print(f"❌ Error processing prompt {prompt_index}: {e}")
+        error_msg = str(e)
+        tb = traceback.format_exc()
+        safe_print(f"[bold red]❌ Error processing prompt {prompt_index}:[/bold red] {error_msg}")
         if config.get("verbose"):
-            traceback.print_exc()
+            safe_print(tb)
 
         return {
             "success": False,
             "prompt_index": prompt_index,
-            "error": str(e),
+            "error": error_msg,
+            "traceback": tb,
+            "tool_errors": [],
+            "profiling_stats": {"tools": {}, "api_calls": {}},
             "trajectory": None,
             "tool_stats": {},
             "toolsets_used": [],
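Worth noting: the failure branch now returns the same keys as the success path (tool_errors, profiling_stats, tool_stats, trajectory), just with empty placeholder values, so the batch worker and the aggregation code downstream never have to special-case exceptions.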
@@ -252,7 +433,10 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
 
     # Initialize aggregated stats for this batch
     batch_tool_stats = {}
+    batch_profiling_stats = []  # Collect profiling stats from each prompt
     completed_in_batch = []
+    all_tool_errors = []  # Track all tool errors in this batch
+    exception_errors = []  # Track top-level exceptions
 
     # Process each prompt sequentially in this batch
     for prompt_index, prompt_data in prompts_to_process:
@@ -264,6 +448,26 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
             config
         )
 
+        # Track tool errors from the conversation
+        if result.get("tool_errors"):
+            for tool_error in result["tool_errors"]:
+                all_tool_errors.append({
+                    "prompt_index": prompt_index,
+                    "tool_name": tool_error["tool_name"],
+                    "error_message": tool_error["error_message"],
+                    "full_content": tool_error.get("full_content", ""),
+                    "error_type": tool_error.get("error_type", "Other")
+                })
+
+        # Track top-level exceptions (not tool errors)
+        if not result["success"]:
+            exception_errors.append({
+                "prompt_index": prompt_index,
+                "error": result.get("error", "Unknown error"),
+                "traceback": result.get("traceback", "")
+            })
+            safe_print(f"[bold red]❌ Exception in prompt {prompt_index}:[/bold red] {result.get('error', '')[:100]}")
+
         # Save trajectory if successful
         if result["success"] and result["trajectory"]:
             trajectory_entry = {
@@ -292,6 +496,10 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
                 batch_tool_stats[tool_name]["success"] += stats["success"]
                 batch_tool_stats[tool_name]["failure"] += stats["failure"]
 
+            # Collect profiling statistics
+            if result.get("profiling_stats"):
+                batch_profiling_stats.append(result["profiling_stats"])
+
             completed_in_batch.append(prompt_index)
             print(f"   ✅ Prompt {prompt_index} completed")
 
@@ -302,7 +510,10 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
         "processed": len(prompts_to_process),
         "skipped": len(batch_data) - len(prompts_to_process),
         "tool_stats": batch_tool_stats,
-        "completed_prompts": completed_in_batch
+        "profiling_stats": batch_profiling_stats,
+        "completed_prompts": completed_in_batch,
+        "tool_errors": all_tool_errors,
+        "exception_errors": exception_errors
     }
 
 
@@ -325,6 +536,10 @@ class BatchRunner:
         verbose: bool = False,
         ephemeral_system_prompt: str = None,
         log_prefix_chars: int = 100,
+        max_tool_failures: int = 10,
+        max_tool_failure_rate: float = 0.5,
+        keep_recent_errors: int = 5,
+        min_tool_calls_for_rate: int = 10,
     ):
         """
         Initialize the batch runner.
@@ -342,6 +557,10 @@ class BatchRunner:
             verbose (bool): Enable verbose logging
             ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
             log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
+            max_tool_failures (int): Maximum number of tool failures before stopping (default: 10)
+            max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5)
+            keep_recent_errors (int): Number of recent errors to keep per tool (default: 5)
+            min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10)
         """
         self.dataset_file = Path(dataset_file)
         self.batch_size = batch_size
@@ -355,6 +574,10 @@ class BatchRunner:
         self.verbose = verbose
         self.ephemeral_system_prompt = ephemeral_system_prompt
         self.log_prefix_chars = log_prefix_chars
+        self.max_tool_failures = max_tool_failures
+        self.max_tool_failure_rate = max_tool_failure_rate
+        self.keep_recent_errors = keep_recent_errors
+        self.min_tool_calls_for_rate = min_tool_calls_for_rate
 
         # Validate distribution
         if not validate_distribution(distribution):
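The four new attributes mirror the constructor arguments one-to-one, so the thresholds can also be set programmatically; a sketch (argument names inferred from the attributes assigned in __init__, remaining arguments left at their defaults):

    runner = BatchRunner(
        dataset_file="data.jsonl",
        batch_size=10,
        run_name="my_run",
        max_tool_failures=20,
        max_tool_failure_rate=0.3,
        min_tool_calls_for_rate=10,
    )
    runner.run(resume=False)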
@@ -370,23 +593,31 @@ class BatchRunner:
         # Statistics file
         self.stats_file = self.output_dir / "statistics.json"
 
+        # Errors file
+        self.errors_file = self.output_dir / "errors.json"
+
         # Load dataset
         self.dataset = self._load_dataset()
 
         # Create batches
         self.batches = self._create_batches()
 
-        print(f"📊 Batch Runner Initialized")
-        print(f"   Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
-        print(f"   Batch size: {self.batch_size}")
-        print(f"   Total batches: {len(self.batches)}")
-        print(f"   Run name: {self.run_name}")
-        print(f"   Distribution: {self.distribution}")
-        print(f"   Output directory: {self.output_dir}")
-        print(f"   Workers: {self.num_workers}")
+        safe_print("[bold cyan]📊 Batch Runner Initialized[/bold cyan]")
+        safe_print(f"   Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
+        safe_print(f"   Batch size: {self.batch_size}")
+        safe_print(f"   Total batches: {len(self.batches)}")
+        safe_print(f"   Run name: {self.run_name}")
+        safe_print(f"   Distribution: {self.distribution}")
+        safe_print(f"   Output directory: {self.output_dir}")
+        safe_print(f"   Workers: {self.num_workers}")
+        safe_print(f"   [yellow]Tool failure limits:[/yellow]")
+        safe_print(f"      Max failures: {self.max_tool_failures}")
+        safe_print(f"      Max failure rate: {self.max_tool_failure_rate:.1%}")
+        safe_print(f"      Min tool calls for rate check: {self.min_tool_calls_for_rate}")
+        safe_print(f"      Keep recent errors: {self.keep_recent_errors}")
         if self.ephemeral_system_prompt:
             prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
-            print(f"   🔒 Ephemeral system prompt: '{prompt_preview}'")
+            safe_print(f"   🔒 Ephemeral system prompt: '{prompt_preview}'")
 
     def _load_dataset(self) -> List[Dict[str, Any]]:
         """
@@ -479,6 +710,118 @@ class BatchRunner:
         with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
             json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
 
+    def _consolidate_data(self, num_batches: int, tool_stats: Dict[str, Dict[str, int]],
+                          start_time: float, tool_errors_by_tool: Dict[str, List[Dict]],
+                          exception_errors: List[Dict], early_exit: bool = False, exit_reason: str = None,
+                          profiling_stats_list: List[Dict] = None):
+        """
+        Consolidate batch data into trajectories.jsonl and save statistics.
+
+        Args:
+            num_batches (int): Number of batches processed
+            tool_stats (Dict): Aggregated tool statistics
+            start_time (float): Start time of the run
+            tool_errors_by_tool (Dict): Tool errors grouped by tool name with k most recent
+            exception_errors (List): Top-level exceptions
+            early_exit (bool): Whether this is an early exit
+            exit_reason (str): Reason for early exit
+            profiling_stats_list (List[Dict]): List of profiling statistics from each conversation
+        """
+        # Combine all batch files into a single trajectories.jsonl file
+        combined_file = self.output_dir / "trajectories.jsonl"
+        safe_print(f"\n[cyan]📦 Combining batch files into {combined_file.name}...[/cyan]")
+
+        entries_written = 0
+        with open(combined_file, 'w', encoding='utf-8') as outfile:
+            for batch_num in range(num_batches):
+                batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
+                if batch_file.exists():
+                    with open(batch_file, 'r', encoding='utf-8') as infile:
+                        for line in infile:
+                            outfile.write(line)
+                            entries_written += 1
+
+        safe_print(f"[green]✅ Combined {num_batches} batch files into trajectories.jsonl ({entries_written} entries)[/green]")
+
+        # Calculate success rates for tool stats
+        for tool_name in tool_stats:
+            stats = tool_stats[tool_name]
+            total_calls = stats["success"] + stats["failure"]
+            if total_calls > 0:
+                stats["success_rate"] = round(stats["success"] / total_calls * 100, 2)
+                stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2)
+            else:
+                stats["success_rate"] = 0.0
+                stats["failure_rate"] = 0.0
+
+        # Build failure type breakdown for each tool
+        failure_type_breakdown = {}
+        for tool_name, errors in tool_errors_by_tool.items():
+            failure_types = {}
+            for error in errors:
+                error_type = error.get("error_type", "Other")
+                if error_type not in failure_types:
+                    failure_types[error_type] = 0
+                failure_types[error_type] += 1
+
+            # Calculate percentages
+            total_failures = len(errors)
+            failure_type_breakdown[tool_name] = {
+                "total_failures": total_failures,
+                "types": {
+                    error_type: {
+                        "count": count,
+                        "percentage": round((count / total_failures) * 100, 2)
+                    }
+                    for error_type, count in failure_types.items()
+                }
+            }
+
+        # Save error information to separate file
+        error_data = {
+            "run_name": self.run_name,
+            "completed_at": datetime.now().isoformat(),
+            "total_tool_errors": sum(len(errors) for errors in tool_errors_by_tool.values()),
+            "total_exception_errors": len(exception_errors),
+            "tool_errors": tool_errors_by_tool,
+            "failure_type_breakdown": failure_type_breakdown,
+            "exception_errors": exception_errors[:self.keep_recent_errors]  # Keep k most recent
+        }
+
+        with open(self.errors_file, 'w', encoding='utf-8') as f:
+            json.dump(error_data, f, indent=2, ensure_ascii=False)
+
+        # Aggregate profiling statistics if available
+        aggregated_profiling_stats = None
+        if profiling_stats_list:
+            from profiling import aggregate_profiling_stats
+            aggregated_profiling_stats = aggregate_profiling_stats(profiling_stats_list)
+
+        # Save final statistics (without detailed errors)
+        final_stats = {
+            "run_name": self.run_name,
+            "distribution": self.distribution,
+            "total_prompts": len(self.dataset),
+            "total_batches": len(self.batches),
+            "batches_processed": num_batches,
+            "batch_size": self.batch_size,
+            "model": self.model,
+            "completed_at": datetime.now().isoformat(),
+            "duration_seconds": round(time.time() - start_time, 2),
+            "early_exit": early_exit,
+            "exit_reason": exit_reason,
+            "tool_statistics": tool_stats,
+            "profiling_statistics": aggregated_profiling_stats
+        }
+
+        with open(self.stats_file, 'w', encoding='utf-8') as f:
+            json.dump(final_stats, f, indent=2, ensure_ascii=False)
+
+        # Display aggregated profiling statistics
+        if aggregated_profiling_stats:
+            from profiling import print_aggregated_statistics
+            print_aggregated_statistics(aggregated_profiling_stats, detailed=True)
+
 
     def run(self, resume: bool = False):
         """
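For reference, the failure_type_breakdown built above nests counts and percentages per tool; with, say, three Timeout errors and one Rate Limit error recorded for a hypothetical web_search tool, the errors.json fragment would look roughly like:

    "failure_type_breakdown": {
        "web_search": {
            "total_failures": 4,
            "types": {
                "Timeout":    {"count": 3, "percentage": 75.0},
                "Rate Limit": {"count": 1, "percentage": 25.0}
            }
        }
    }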
@@ -519,6 +862,14 @@ class BatchRunner:
 
         # Aggregate statistics across all batches
         total_tool_stats = {}
+        all_profiling_stats = []  # Collect all profiling stats for aggregation
+        tool_errors_by_tool = {}  # {tool_name: [list of k most recent errors]}
+        all_exception_errors = []
+        all_completed_prompts = list(completed_prompts_set)
+        total_processed = len(completed_prompts_set)
+        total_tool_errors = 0
+        early_exit = False
+        exit_reason = None
 
         start_time = time.time()
 
@@ -536,82 +887,182 @@ class BatchRunner:
                 for batch_num, batch_data in enumerate(self.batches)
             ]
 
-            # Use map to process batches in parallel
-            results = pool.map(_process_batch_worker, tasks)
-
-        # Aggregate all batch statistics and update checkpoint
-        all_completed_prompts = list(completed_prompts_set)
-        for batch_result in results:
-            # Add newly completed prompts
-            all_completed_prompts.extend(batch_result.get("completed_prompts", []))
-
-            # Aggregate tool stats
-            for tool_name, stats in batch_result.get("tool_stats", {}).items():
-                if tool_name not in total_tool_stats:
-                    total_tool_stats[tool_name] = {
-                        "count": 0,
-                        "success": 0,
-                        "failure": 0
-                    }
-
-                total_tool_stats[tool_name]["count"] += stats["count"]
-                total_tool_stats[tool_name]["success"] += stats["success"]
-                total_tool_stats[tool_name]["failure"] += stats["failure"]
+            # Process batches in parallel and check tool failure threshold as results come in
+            # imap_unordered allows parallel processing while getting results as they complete
+            batch_num = 0
+            try:
+                for result in pool.imap_unordered(_process_batch_worker, tasks):
+                    # Update statistics
+                    all_completed_prompts.extend(result.get("completed_prompts", []))
+                    total_processed += result.get("processed", 0)
+
+                    # Aggregate tool stats
+                    for tool_name, stats in result.get("tool_stats", {}).items():
+                        if tool_name not in total_tool_stats:
+                            total_tool_stats[tool_name] = {
+                                "count": 0,
+                                "success": 0,
+                                "failure": 0
+                            }
+
+                        total_tool_stats[tool_name]["count"] += stats["count"]
+                        total_tool_stats[tool_name]["success"] += stats["success"]
+                        total_tool_stats[tool_name]["failure"] += stats["failure"]
+
+                    # Collect profiling stats from this batch
+                    if result.get("profiling_stats"):
+                        all_profiling_stats.extend(result["profiling_stats"])
+
+                    # Aggregate tool errors (keep k most recent per tool)
+                    for tool_error in result.get("tool_errors", []):
+                        tool_name = tool_error["tool_name"]
+                        if tool_name not in tool_errors_by_tool:
+                            tool_errors_by_tool[tool_name] = []
+
+                        # Add error and keep only k most recent
+                        tool_errors_by_tool[tool_name].append(tool_error)
+                        if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors:
+                            tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:]
+
+                        total_tool_errors += 1
+
+                    # Track exception errors
+                    all_exception_errors.extend(result.get("exception_errors", []))
+
+                    # Check tool failure thresholds
+                    # Calculate total tool calls (not prompts)
+                    total_tool_calls = sum(stats["count"] for stats in total_tool_stats.values())
+
+                    # Check absolute count threshold
+                    if total_tool_errors >= self.max_tool_failures:
+                        early_exit = True
+                        exit_reason = f"Exceeded maximum tool failures ({total_tool_errors}/{self.max_tool_failures})"
+                        safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
+                        pool.terminate()  # Stop all workers immediately
+                        break
+
+                    # Check rate threshold (only if we have enough tool calls to trust the rate)
+                    if total_tool_calls >= self.min_tool_calls_for_rate:
+                        tool_failure_rate = total_tool_errors / total_tool_calls
+
+                        if tool_failure_rate >= self.max_tool_failure_rate:
+                            early_exit = True
+                            exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%} >= {self.max_tool_failure_rate:.2%}, {total_tool_errors}/{total_tool_calls} tool calls)"
+                            safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
+                            pool.terminate()  # Stop all workers immediately
+                            break
+
+                    # Update checkpoint after each batch completes
+                    checkpoint_data["completed_prompts"] = all_completed_prompts
+                    self._save_checkpoint(checkpoint_data)
+
+                    batch_num += 1
+            except KeyboardInterrupt:
+                safe_print("\n[bold yellow]⚠️ Interrupted by user, stopping workers...[/bold yellow]")
+                pool.terminate()
+                early_exit = True
+                exit_reason = "Interrupted by user"
 
         # Save final checkpoint
         checkpoint_data["completed_prompts"] = all_completed_prompts
         self._save_checkpoint(checkpoint_data)
 
-        # Calculate success rates
-        for tool_name in total_tool_stats:
-            stats = total_tool_stats[tool_name]
-            total_calls = stats["success"] + stats["failure"]
-            if total_calls > 0:
-                stats["success_rate"] = round(stats["success"] / total_calls * 100, 2)
-                stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2)
-            else:
-                stats["success_rate"] = 0.0
-                stats["failure_rate"] = 0.0
-
-        # Combine all batch files into a single trajectories.jsonl file
-        combined_file = self.output_dir / "trajectories.jsonl"
-        print(f"\n📦 Combining batch files into {combined_file.name}...")
-
-        with open(combined_file, 'w', encoding='utf-8') as outfile:
-            for batch_num in range(len(self.batches)):
-                batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
-                if batch_file.exists():
-                    with open(batch_file, 'r', encoding='utf-8') as infile:
-                        for line in infile:
-                            outfile.write(line)
-
-        print(f"✅ Combined {len(self.batches)} batch files into trajectories.jsonl")
-
-        # Save final statistics
-        final_stats = {
-            "run_name": self.run_name,
-            "distribution": self.distribution,
-            "total_prompts": len(self.dataset),
-            "total_batches": len(self.batches),
-            "batch_size": self.batch_size,
-            "model": self.model,
-            "completed_at": datetime.now().isoformat(),
-            "duration_seconds": round(time.time() - start_time, 2),
-            "tool_statistics": total_tool_stats
-        }
-
-        with open(self.stats_file, 'w', encoding='utf-8') as f:
-            json.dump(final_stats, f, indent=2, ensure_ascii=False)
+        # Consolidate data and save statistics
+        num_batches_processed = batch_num + 1 if early_exit else len(self.batches)
+        self._consolidate_data(
+            num_batches_processed,
+            total_tool_stats,
+            start_time,
+            tool_errors_by_tool,
+            all_exception_errors,
+            early_exit,
+            exit_reason,
+            all_profiling_stats
+        )
 
         # Print summary
-        print("\n" + "=" * 70)
-        print("📊 BATCH PROCESSING COMPLETE")
-        print("=" * 70)
-        print(f"✅ Total prompts processed: {len(self.dataset)}")
-        print(f"✅ Total batches: {len(self.batches)}")
-        print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
-        print(f"\n📈 Tool Usage Statistics:")
-        print("-" * 70)
+        safe_print("\n" + "=" * 70)
+        if early_exit:
+            safe_print("[bold yellow]⚠️ BATCH PROCESSING STOPPED EARLY[/bold yellow]")
+            safe_print(f"[yellow]Reason: {exit_reason}[/yellow]")
+        else:
+            safe_print("[bold green]📊 BATCH PROCESSING COMPLETE[/bold green]")
+        safe_print("=" * 70)
+
+        safe_print(f"✅ Total prompts processed: {total_processed}")
+        safe_print(f"✅ Batches completed: {num_batches_processed}/{len(self.batches)}")
+        safe_print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
+
+        # Tool error summary
+        if tool_errors_by_tool:
+            total_errors = sum(len(errors) for errors in tool_errors_by_tool.values())
+            safe_print(f"\n[bold red]🚨 Tool Errors: {total_tool_errors} total ({len(tool_errors_by_tool)} tools)[/bold red]")
+            safe_print("[red]-[/red]" * 70)
+
+            # Sort tools by error count
+            sorted_tools = sorted(
+                tool_errors_by_tool.items(),
+                key=lambda x: len(x[1]),
+                reverse=True
+            )
+
+            for tool_name, errors in sorted_tools:
+                # Count unique error messages
+                unique_errors = {}
+                for error in errors:
+                    error_msg = error["error_message"][:100]  # Truncate for grouping
+                    if error_msg not in unique_errors:
+                        unique_errors[error_msg] = []
+                    unique_errors[error_msg].append(error)
+
+                safe_print(f"\n   [red]{tool_name}:[/red] {len(errors)} errors ({len(unique_errors)} unique)")
+
+                # Show up to 3 most recent unique error types
+                for idx, (error_msg, instances) in enumerate(list(unique_errors.items())[:3]):
+                    error_preview = error_msg if len(error_msg) <= 100 else error_msg[:97] + "..."
+                    safe_print(f"      [{idx+1}] [dim]{error_preview}[/dim] (x{len(instances)})")
+
+                    # Show one example with prompt index and full content prefix
+                    example = instances[-1]  # Most recent
+                    safe_print(f"          [dim]Prompt {example['prompt_index']}[/dim]")
+
+                    # Show full content prefix (first 200 chars)
+                    full_content = example.get('full_content', '')
+                    if full_content and full_content != error_preview:
+                        content_preview = full_content[:200]
+                        if len(full_content) > 200:
+                            content_preview += "..."
+                        # Show with prefix indicator
+                        safe_print(f"          [dim]Content: {content_preview}[/dim]")
+
+                if len(unique_errors) > 3:
+                    safe_print(f"      [dim]... and {len(unique_errors) - 3} more error types[/dim]")
+
+            tool_failure_rate = total_tool_errors / total_processed if total_processed > 0 else 0
+            safe_print(f"\n   [red]Tool failure rate: {tool_failure_rate:.2%}[/red]")
+
+        # Exception errors
+        if all_exception_errors:
+            safe_print(f"\n[bold red]💥 Top-level Exceptions: {len(all_exception_errors)}[/bold red]")
+            safe_print("[red]-[/red]" * 70)
+            for error in all_exception_errors[:self.keep_recent_errors]:
+                error_msg = error["error"]
+                error_preview = error_msg[:150]
+                if len(error_msg) > 150:
+                    error_preview += "..."
+                safe_print(f"   [red]Prompt {error['prompt_index']}:[/red] [dim]{error_preview}[/dim]")
+
+                # Show traceback prefix if available
+                traceback_text = error.get("traceback", "")
+                if traceback_text:
+                    # Show last 3 lines of traceback for context
+                    tb_lines = traceback_text.strip().split('\n')
+                    relevant_lines = tb_lines[-3:] if len(tb_lines) > 3 else tb_lines
+                    for line in relevant_lines:
+                        safe_print(f"      [dim]{line}[/dim]")
+
+        safe_print(f"\n[cyan]📈 Tool Usage Statistics:[/cyan]")
+        safe_print("-" * 70)
+
         if total_tool_stats:
             # Sort by count descending
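Distilled out of the loop above, the early-exit policy reduces to two independent trip-wires, an absolute cap and a rate cap that only arms once enough calls have accumulated; a sketch, not the committed code (defaults copied from BatchRunner):

    def should_stop(total_tool_errors, total_tool_calls,
                    max_failures=10, max_rate=0.5, min_calls_for_rate=10):
        # Absolute cap trips regardless of call volume.
        if total_tool_errors >= max_failures:
            return True
        # Rate cap only trips once the ratio is trustworthy.
        if total_tool_calls >= min_calls_for_rate:
            return total_tool_errors / total_tool_calls >= max_rate
        return False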
@@ -621,24 +1072,67 @@ class BatchRunner:
                 reverse=True
             )
 
-            print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
-            print("-" * 70)
+            safe_print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
+            safe_print("-" * 70)
             for tool_name, stats in sorted_tools:
-                print(
+                safe_print(
                     f"{tool_name:<25} "
                     f"{stats['count']:<10} "
                     f"{stats['success']:<10} "
                     f"{stats['failure']:<10} "
-                    f"{stats['success_rate']:.1f}%"
+                    f"{stats.get('success_rate', 0):.1f}%"
                 )
         else:
-            print("No tool calls were made during this run.")
-
-        print(f"\n💾 Results saved to: {self.output_dir}")
-        print(f"   - Trajectories: trajectories.jsonl (combined)")
-        print(f"   - Individual batches: batch_*.jsonl (for debugging)")
-        print(f"   - Statistics: {self.stats_file.name}")
-        print(f"   - Checkpoint: {self.checkpoint_file.name}")
+            safe_print("No tool calls were made during this run.")
+
+        # Display failure type breakdown for tools with failures
+        if tool_errors_by_tool:
+            safe_print(f"\n[cyan]📊 Failure Type Breakdown:[/cyan]")
+            safe_print("-" * 70)
+
+            # Sort tools by total error count
+            sorted_tools = sorted(
+                tool_errors_by_tool.items(),
+                key=lambda x: len(x[1]),
+                reverse=True
+            )
+
+            for tool_name, errors in sorted_tools:
+                # Count failure types for this tool
+                failure_types = {}
+                for error in errors:
+                    error_type = error.get("error_type", "Other")
+                    if error_type not in failure_types:
+                        failure_types[error_type] = 0
+                    failure_types[error_type] += 1
+
+                # Display tool name and total failures
+                total_failures = len(errors)
+                safe_print(f"\n[yellow]{tool_name}[/yellow] ({total_failures} failures):")
+
+                # Sort failure types by count
+                sorted_types = sorted(
+                    failure_types.items(),
+                    key=lambda x: x[1],
+                    reverse=True
+                )
+
+                # Display each failure type with count and percentage
+                for failure_type, count in sorted_types:
+                    percentage = (count / total_failures) * 100
+                    safe_print(f"   • {failure_type:<20} {count:>4} ({percentage:>5.1f}%)")
+
+        safe_print(f"\n[cyan]💾 Results saved to:[/cyan] {self.output_dir}")
+        safe_print(f"   - Trajectories: trajectories.jsonl (combined)")
+        safe_print(f"   - Individual batches: batch_*.jsonl (for debugging)")
+        safe_print(f"   - Statistics: {self.stats_file.name}")
+        safe_print(f"   - Errors: {self.errors_file.name}")
+        safe_print(f"   - Checkpoint: {self.checkpoint_file.name}")
+
+        if early_exit:
+            safe_print(f"\n[bold yellow]ℹ️ Run was stopped early due to tool failures.[/bold yellow]")
+            safe_print(f"[yellow]   Check {self.errors_file.name} for detailed error information including tracebacks.[/yellow]")
+            safe_print(f"[yellow]   You can resume this run later with --resume flag.[/yellow]")
 
 
 def main(
@@ -656,6 +1150,10 @@ def main(
     list_distributions: bool = False,
     ephemeral_system_prompt: str = None,
     log_prefix_chars: int = 100,
+    max_tool_failures: int = 10,
+    max_tool_failure_rate: float = 0.5,
+    keep_recent_errors: int = 5,
+    min_tool_calls_for_rate: int = 10,
 ):
     """
     Run batch processing of agent prompts from a dataset.
@@ -675,6 +1173,10 @@ def main(
         list_distributions (bool): List available toolset distributions and exit
         ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
         log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
+        max_tool_failures (int): Maximum number of tool failures before stopping (default: 10)
+        max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5)
+        keep_recent_errors (int): Number of recent errors to keep per tool for reporting (default: 5)
+        min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10)
 
     Examples:
         # Basic usage
@@ -690,6 +1192,10 @@ def main(
         python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
             --ephemeral_system_prompt="You are a helpful assistant focused on image generation."
 
+        # With custom tool failure thresholds
+        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
+            --max_tool_failures=20 --max_tool_failure_rate=0.3 --min_tool_calls_for_rate=10 --keep_recent_errors=10
+
         # List available distributions
         python batch_runner.py --list_distributions
     """
@@ -736,7 +1242,11 @@ def main(
         num_workers=num_workers,
         verbose=verbose,
         ephemeral_system_prompt=ephemeral_system_prompt,
-        log_prefix_chars=log_prefix_chars
+        log_prefix_chars=log_prefix_chars,
+        max_tool_failures=max_tool_failures,
+        max_tool_failure_rate=max_tool_failure_rate,
+        keep_recent_errors=keep_recent_errors,
+        min_tool_calls_for_rate=min_tool_calls_for_rate
     )
 
     runner.run(resume=resume)
@@ -750,4 +1260,3 @@ def main(
 
 if __name__ == "__main__":
     fire.Fire(main)
-
gemini_nothinking.sh (new file, 12 lines)
@@ -0,0 +1,12 @@
+python batch_runner.py \
+    --dataset_file="source-data/agent_tasks_eval.jsonl" \
+    --batch_size=1 \
+    --run_name="agenttasks_eval_gemini-4.5-3-nothinking" \
+    --distribution="science" \
+    --model="gemini-3-pro-preview" \
+    --base_url="https://generativelanguage.googleapis.com/v1beta/openai/" \
+    --api_key="${GEMINI_API_KEY}" \
+    --num_workers=10 \
+    --max_turns=60 \
+    --verbose \
+    --ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used. Always use the terminal or search tool if it can provide additional context, verify formulas, double-check concepts, recent studies, and understanding, do all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. If you require API keys please check which ones already exist in your environment variables in a way that does not read them."
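Note: the script assumes GEMINI_API_KEY is already exported in the calling environment; with --batch_size=1 and --num_workers=10, each batch holds a single prompt, so up to ten prompts run concurrently.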
profiling.py (new file, 381 lines)
@@ -0,0 +1,381 @@
+"""
+Profiling module for tracking timing statistics of tools and LLM API calls.
+
+This module provides a centralized way to track timing information for various
+operations in the agent system, including:
+- Individual tool executions
+- OpenAI API calls
+- Aggregate statistics (min, max, median, mean, total)
+"""
+
+import time
+from typing import Dict, List, Optional
+from dataclasses import dataclass, field
+from collections import defaultdict
+import statistics
+
+
+@dataclass
+class ProfilingStats:
+    """Statistics for a particular operation type."""
+    call_count: int = 0
+    total_time: float = 0.0
+    min_time: float = float('inf')
+    max_time: float = 0.0
+    times: List[float] = field(default_factory=list)
+
+    def add_timing(self, duration: float):
+        """Add a timing measurement."""
+        self.call_count += 1
+        self.total_time += duration
+        self.min_time = min(self.min_time, duration)
+        self.max_time = max(self.max_time, duration)
+        self.times.append(duration)
+
+    @property
+    def mean_time(self) -> float:
+        """Calculate mean time."""
+        return self.total_time / self.call_count if self.call_count > 0 else 0.0
+
+    @property
+    def median_time(self) -> float:
+        """Calculate median time."""
+        return statistics.median(self.times) if self.times else 0.0
+
+    def to_dict(self) -> Dict:
+        """Convert to dictionary for serialization."""
+        return {
+            "call_count": self.call_count,
+            "total_time": self.total_time,
+            "min_time": self.min_time if self.min_time != float('inf') else 0.0,
+            "max_time": self.max_time,
+            "mean_time": self.mean_time,
+            "median_time": self.median_time
+        }
+
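A quick sanity check of the dataclass semantics (timings invented for the example):

    stats = ProfilingStats()
    stats.add_timing(0.5)
    stats.add_timing(1.5)
    stats.to_dict()
    # -> {"call_count": 2, "total_time": 2.0, "min_time": 0.5,
    #     "max_time": 1.5, "mean_time": 1.0, "median_time": 1.0}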
+class Profiler:
+    """
+    Global profiler for tracking timing statistics across tools and API calls.
+
+    Usage:
+        profiler = Profiler()
+
+        # Time a tool execution
+        with profiler.time_tool("web_search"):
+            # ... tool execution code ...
+            pass
+
+        # Time an API call
+        with profiler.time_api_call():
+            # ... API call code ...
+            pass
+
+        # Get statistics
+        stats = profiler.get_statistics()
+    """
+
+    def __init__(self):
+        """Initialize the profiler."""
+        self.tool_stats: Dict[str, ProfilingStats] = defaultdict(ProfilingStats)
+        self.api_stats: ProfilingStats = ProfilingStats()
+        self._enabled = True
+
+    def enable(self):
+        """Enable profiling."""
+        self._enabled = True
+
+    def disable(self):
+        """Disable profiling."""
+        self._enabled = False
+
+    def reset(self):
+        """Reset all profiling data."""
+        self.tool_stats.clear()
+        self.api_stats = ProfilingStats()
+
+    def record_tool_timing(self, tool_name: str, duration: float):
+        """Record timing for a tool execution."""
+        if self._enabled:
+            self.tool_stats[tool_name].add_timing(duration)
+
+    def record_api_timing(self, duration: float):
+        """Record timing for an API call."""
+        if self._enabled:
+            self.api_stats.add_timing(duration)
+
+    def get_statistics(self) -> Dict:
+        """
+        Get all profiling statistics.
+
+        Returns:
+            Dictionary containing tool and API statistics
+        """
+        return {
+            "tools": {
+                tool_name: stats.to_dict()
+                for tool_name, stats in sorted(self.tool_stats.items())
+            },
+            "api_calls": self.api_stats.to_dict()
+        }
+
+    def print_statistics(self, detailed: bool = True):
+        """
+        Print profiling statistics in a readable format.
+
+        Args:
+            detailed: If True, show per-tool breakdown. If False, show summary only.
+        """
+        print("\n" + "="*80)
+        print("📊 PROFILING STATISTICS")
+        print("="*80)
+
+        # API Call Statistics
+        print("\n🔷 OpenAI API Calls:")
+        if self.api_stats.call_count > 0:
+            api_dict = self.api_stats.to_dict()
+            print(f"   Total Calls: {api_dict['call_count']}")
+            print(f"   Total Time: {api_dict['total_time']:.2f}s")
+            print(f"   Min Time: {api_dict['min_time']:.2f}s")
+            print(f"   Max Time: {api_dict['max_time']:.2f}s")
+            print(f"   Mean Time: {api_dict['mean_time']:.2f}s")
+            print(f"   Median Time: {api_dict['median_time']:.2f}s")
+        else:
+            print("   No API calls recorded")
+
+        # Tool Statistics
+        print("\n🔧 Tool Executions:")
+        if self.tool_stats:
+            if detailed:
+                for tool_name in sorted(self.tool_stats.keys()):
+                    stats_dict = self.tool_stats[tool_name].to_dict()
+                    print(f"\n   📌 {tool_name}:")
+                    print(f"      Total Calls: {stats_dict['call_count']}")
+                    print(f"      Total Time: {stats_dict['total_time']:.2f}s")
+                    print(f"      Min Time: {stats_dict['min_time']:.2f}s")
+                    print(f"      Max Time: {stats_dict['max_time']:.2f}s")
+                    print(f"      Mean Time: {stats_dict['mean_time']:.2f}s")
+                    print(f"      Median Time: {stats_dict['median_time']:.2f}s")
+
+            # Summary
+            total_tool_calls = sum(s.call_count for s in self.tool_stats.values())
+            total_tool_time = sum(s.total_time for s in self.tool_stats.values())
+            print(f"\n   📊 Summary:")
+            print(f"      Total Tool Calls: {total_tool_calls}")
+            print(f"      Total Tool Time: {total_tool_time:.2f}s")
+            print(f"      Unique Tools Used: {len(self.tool_stats)}")
+        else:
+            print("   No tool executions recorded")
+
+        # Overall Summary
+        total_api_time = self.api_stats.total_time
+        total_tool_time = sum(s.total_time for s in self.tool_stats.values())
+        print(f"\n📈 Overall Summary:")
+        print(f"   Total API Time: {total_api_time:.2f}s")
+        print(f"   Total Tool Time: {total_tool_time:.2f}s")
+        print(f"   Total Time: {total_api_time + total_tool_time:.2f}s")
+        print("="*80 + "\n")
+
+    def export_to_json(self) -> str:
+        """Export statistics as JSON string."""
+        import json
+        return json.dumps(self.get_statistics(), indent=2)
+
+    def export_to_file(self, filepath: str):
+        """
+        Export statistics to a JSON file.
+
+        Args:
+            filepath: Path to output file
+        """
+        import json
+        with open(filepath, 'w') as f:
+            json.dump(self.get_statistics(), f, indent=2)
+        print(f"📁 Profiling statistics exported to: {filepath}")
+
+
+# Global profiler instance
+_global_profiler: Optional[Profiler] = None
+
+
+def get_profiler() -> Profiler:
+    """Get or create the global profiler instance."""
+    global _global_profiler
+    if _global_profiler is None:
+        _global_profiler = Profiler()
+    return _global_profiler
+
+
+def reset_profiler():
+    """Reset the global profiler."""
+    global _global_profiler
+    if _global_profiler is not None:
+        _global_profiler.reset()
+
+
class TimingContext:
|
||||||
|
"""Context manager for timing operations."""
|
||||||
|
|
||||||
|
def __init__(self, profiler: Profiler, operation_type: str, operation_name: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize timing context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
profiler: Profiler instance to record timing
|
||||||
|
operation_type: 'tool' or 'api'
|
||||||
|
operation_name: Name of the operation (required for tools)
|
||||||
|
"""
|
||||||
|
self.profiler = profiler
|
||||||
|
self.operation_type = operation_type
|
||||||
|
self.operation_name = operation_name
|
||||||
|
self.start_time = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Start timing."""
|
||||||
|
self.start_time = time.time()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Stop timing and record."""
|
||||||
|
duration = time.time() - self.start_time
|
||||||
|
|
||||||
|
if self.operation_type == 'tool':
|
||||||
|
self.profiler.record_tool_timing(self.operation_name, duration)
|
||||||
|
elif self.operation_type == 'api':
|
||||||
|
self.profiler.record_api_timing(duration)
|
||||||
|
|
||||||
|
return False # Don't suppress exceptions
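
# Example usage (an illustrative sketch; `expensive_tool_call()` is a
# hypothetical stand-in for real work):
#
#     with TimingContext(get_profiler(), 'tool', operation_name='web_search'):
#         expensive_tool_call()
#
# On exit, the elapsed wall-clock time is recorded under the 'web_search'
# tool entry. The timing is recorded even if the body raises, since
# __exit__ records before returning False (the exception still propagates).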


def aggregate_profiling_stats(stats_list: List[Dict]) -> Dict:
    """
    Aggregate multiple profiling statistics dictionaries into one.

    This is useful for batch processing, where each worker process has its own
    profiler instance whose statistics need to be combined.

    Args:
        stats_list: List of statistics dictionaries from get_statistics()

    Returns:
        Dict: Aggregated statistics with combined tool and API call data
    """
    aggregated = {
        "tools": defaultdict(lambda: {"times": []}),
        "api_calls": {"times": []}
    }

    for stats in stats_list:
        # Aggregate tool timings. Individual timings are not preserved in the
        # per-worker output, so each call is approximated by the tool's
        # mean_time, repeated call_count times.
        for tool_name, tool_stats in stats.get("tools", {}).items():
            aggregated["tools"][tool_name]["times"].extend(
                [tool_stats.get("mean_time", 0.0)] * tool_stats.get("call_count", 0)
            )

        # Aggregate API call timings using the same approximation
        api_stats = stats.get("api_calls", {})
        if api_stats.get("call_count", 0) > 0:
            aggregated["api_calls"]["times"].extend(
                [api_stats.get("mean_time", 0.0)] * api_stats.get("call_count", 0)
            )

    # Calculate final statistics for tools
    final_stats = {"tools": {}, "api_calls": {}}

    for tool_name, data in aggregated["tools"].items():
        times = data["times"]
        if times:
            final_stats["tools"][tool_name] = {
                "call_count": len(times),
                "total_time": sum(times),
                "min_time": min(times),
                "max_time": max(times),
                "mean_time": statistics.mean(times),
                "median_time": statistics.median(times)
            }

    # Calculate final statistics for API calls
    api_times = aggregated["api_calls"]["times"]
    if api_times:
        final_stats["api_calls"] = {
            "call_count": len(api_times),
            "total_time": sum(api_times),
            "min_time": min(api_times),
            "max_time": max(api_times),
            "mean_time": statistics.mean(api_times),
            "median_time": statistics.median(api_times)
        }
    else:
        final_stats["api_calls"] = {
            "call_count": 0,
            "total_time": 0.0,
            "min_time": 0.0,
            "max_time": 0.0,
            "mean_time": 0.0,
            "median_time": 0.0
        }

    return final_stats
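
# Example usage (an illustrative sketch; `batch_results` is a hypothetical
# list of per-trajectory results, each carrying the "profiling_stats" dict
# produced by get_statistics()):
#
#     worker_stats = [r["profiling_stats"] for r in batch_results]
#     combined = aggregate_profiling_stats(worker_stats)
#     print_aggregated_statistics(combined)
#
# Because only per-tool means and call counts survive serialization, the
# combined min/max/median are approximations: a worker reporting a 1.0s mean
# over 2 calls merged with one reporting a 3.0s mean over 2 calls gives a
# combined mean of 2.0s, but the true per-call extremes are lost.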


def print_aggregated_statistics(stats: Dict, detailed: bool = True):
    """
    Print aggregated profiling statistics in a readable format.

    Args:
        stats: Aggregated statistics dictionary from aggregate_profiling_stats()
        detailed: If True, show per-tool breakdown. If False, show summary only.
    """
    print("\n" + "="*80)
    print("📊 AGGREGATED PROFILING STATISTICS")
    print("="*80)

    # API Call Statistics
    print("\n🔷 OpenAI API Calls:")
    api_stats = stats.get("api_calls", {})
    if api_stats.get("call_count", 0) > 0:
        print(f"   Total Calls: {api_stats['call_count']}")
        print(f"   Total Time: {api_stats['total_time']:.2f}s")
        print(f"   Min Time: {api_stats['min_time']:.2f}s")
        print(f"   Max Time: {api_stats['max_time']:.2f}s")
        print(f"   Mean Time: {api_stats['mean_time']:.2f}s")
        print(f"   Median Time: {api_stats['median_time']:.2f}s")
    else:
        print("   No API calls recorded")

    # Tool Statistics
    print("\n🔧 Tool Executions:")
    tool_stats = stats.get("tools", {})
    if tool_stats:
        if detailed:
            for tool_name in sorted(tool_stats.keys()):
                stats_dict = tool_stats[tool_name]
                print(f"\n   📌 {tool_name}:")
                print(f"      Total Calls: {stats_dict['call_count']}")
                print(f"      Total Time: {stats_dict['total_time']:.2f}s")
                print(f"      Min Time: {stats_dict['min_time']:.2f}s")
                print(f"      Max Time: {stats_dict['max_time']:.2f}s")
                print(f"      Mean Time: {stats_dict['mean_time']:.2f}s")
                print(f"      Median Time: {stats_dict['median_time']:.2f}s")

        # Summary
        total_tool_calls = sum(s["call_count"] for s in tool_stats.values())
        total_tool_time = sum(s["total_time"] for s in tool_stats.values())
        print(f"\n   📊 Summary:")
        print(f"      Total Tool Calls: {total_tool_calls}")
        print(f"      Total Tool Time: {total_tool_time:.2f}s")
        print(f"      Unique Tools Used: {len(tool_stats)}")
    else:
        print("   No tool executions recorded")

    # Overall Summary
    total_api_time = api_stats.get("total_time", 0.0)
    total_tool_time = sum(s["total_time"] for s in tool_stats.values())
    print(f"\n📈 Overall Summary:")
    print(f"   Total API Time: {total_api_time:.2f}s")
    print(f"   Total Tool Time: {total_tool_time:.2f}s")
    print(f"   Total Time: {total_api_time + total_tool_time:.2f}s")
    print("="*80 + "\n")
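For a single-process run, these pieces are driven through the module-level profiler rather than constructed directly. A minimal sketch of that flow, assuming the functions above live in the `profiling` module that run_agent.py imports (the recorded duration and output filename here are illustrative):

    from profiling import get_profiler, reset_profiler

    reset_profiler()                               # start with fresh stats
    get_profiler().record_api_timing(1.42)         # normally done inside the agent loop
    get_profiler().print_statistics(detailed=True)
    get_profiler().export_to_file("profiling_stats.json")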
88
run_agent.py
@@ -45,6 +45,9 @@ else:
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
from tools.terminal_tool import cleanup_vm

# Import profiling
from profiling import get_profiler


class AIAgent:
    """
@@ -364,6 +367,10 @@ class AIAgent:
        Returns:
            Dict: Complete conversation result with final response and message history
        """
        # Reset the profiler for this conversation to get fresh stats
        from profiling import reset_profiler as reset_prof
        reset_prof()

        # Generate unique task_id if not provided to isolate VMs between concurrent tasks
        import uuid
        effective_task_id = task_id or str(uuid.uuid4())
@@ -394,6 +401,8 @@ class AIAgent:
        if self.verbose_logging:
            logging.debug(f"API Request - Model: {self.model}, Messages: {len(messages)}, Tools: {len(self.tools) if self.tools else 0}")
            logging.debug(f"Last message role: {messages[-1]['role'] if messages else 'none'}")
            # Log the last message to see if thought_signature is present
            logging.debug(f"Last message content: {json.dumps(messages[-1] if messages else {}, indent=2)}")

        api_start_time = time.time()
        retry_count = 0
@@ -419,6 +428,9 @@ class AIAgent:
        api_duration = time.time() - api_start_time
        print(f"⏱️ OpenAI-compatible API call completed in {api_duration:.2f}s")

        # Record API timing in the profiler
        get_profiler().record_api_timing(api_duration)

        if self.verbose_logging:
            logging.debug(f"API Response received - Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}")
@@ -449,22 +461,58 @@ class AIAgent:
        if self.verbose_logging:
            for tc in assistant_message.tool_calls:
                logging.debug(f"Tool call: {tc.function.name} with args: {tc.function.arguments[:200]}...")
                # Debug: check which attributes are available on the tool call
                logging.debug(f"Tool call attributes: {dir(tc)}")
                # Dump the model to see all fields (Pydantic v2)
                if hasattr(tc, 'model_dump'):
                    logging.debug(f"Tool call data: {tc.model_dump()}")

        # Add assistant message with tool calls to the conversation.
        # Extract thought_signature if present (required for Gemini models).
        tool_calls_data = []
        for tool_call in assistant_message.tool_calls:
            tool_call_dict = {
                "id": tool_call.id,
                "type": tool_call.type,
                "function": {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments
                }
            }
            # Try multiple ways to access thought_signature (Gemini-specific);
            # Gemini uses the extra_content.google.thought_signature structure.
            thought_sig = None

            # Method 1: check the extra_content attribute
            if hasattr(tool_call, 'extra_content'):
                extra = tool_call.extra_content
                if isinstance(extra, dict) and 'google' in extra:
                    thought_sig = extra['google'].get('thought_signature')

            # Method 2: check model_dump() if available (Pydantic v2)
            if thought_sig is None and hasattr(tool_call, 'model_dump'):
                dumped = tool_call.model_dump()
                if 'extra_content' in dumped and isinstance(dumped['extra_content'], dict):
                    google_data = dumped['extra_content'].get('google', {})
                    thought_sig = google_data.get('thought_signature')

            if thought_sig is not None:
                tool_call_dict["extra_content"] = {
                    "google": {
                        "thought_signature": thought_sig
                    }
                }
                if self.verbose_logging:
                    logging.debug(f"Captured thought_signature for tool call {tool_call.id}")
            elif self.verbose_logging:
                logging.debug(f"No thought_signature found for tool call {tool_call.id}")

            tool_calls_data.append(tool_call_dict)

        messages.append({
            "role": "assistant",
            "content": assistant_message.content,
            "tool_calls": tool_calls_data
        })

        # Execute each tool call
@@ -490,11 +538,15 @@ class AIAgent:
        tool_duration = time.time() - tool_start_time
        result_preview = function_result[:200] if len(function_result) > 200 else function_result

        # Record tool timing in the profiler
        get_profiler().record_tool_timing(function_name, tool_duration)

        if self.verbose_logging:
            logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
            logging.debug(f"Tool result preview: {result_preview}...")

        # Add tool result to the conversation.
        # Note: thought_signature should NOT be in tool responses, only in assistant messages
        messages.append({
            "role": "tool",
            "content": function_result,
@@ -562,11 +614,15 @@ class AIAgent:
            if self.verbose_logging:
                logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}")

        # Get profiling statistics for this conversation
        profiling_stats = get_profiler().get_statistics()

        return {
            "final_response": final_response,
            "messages": messages,
            "api_calls": api_call_count,
            "completed": completed,
            "profiling_stats": profiling_stats
        }

    def chat(self, message: str) -> str:
@@ -594,7 +650,8 @@ def main(
    list_tools: bool = False,
    save_trajectories: bool = False,
    verbose: bool = False,
    log_prefix_chars: int = 20,
    show_profiling: bool = True
):
    """
    Main function for running the agent directly.
@@ -613,6 +670,7 @@ def main(
        save_trajectories (bool): Save conversation trajectories to JSONL files. Defaults to False.
        verbose (bool): Enable verbose logging for debugging. Defaults to False.
        log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses. Defaults to 20.
        show_profiling (bool): Display profiling statistics after the conversation. Defaults to True.

    Toolset Examples:
        - "research": Web search, extract, crawl + vision tools
@@ -764,6 +822,10 @@ def main(
    print("-" * 30)
    print(result['final_response'])

    # Display profiling statistics if enabled
    if show_profiling:
        get_profiler().print_statistics(detailed=True)

    print("\n👋 Agent execution completed!")
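With this change, an assistant message that round-trips a Gemini thought signature carries it inside each tool call's extra_content. A rough sketch of the resulting message shape (all field values here are illustrative, not taken from a real response):

    {
        "role": "assistant",
        "content": null,
        "tool_calls": [
            {
                "id": "call_abc123",
                "type": "function",
                "function": {
                    "name": "web_search",
                    "arguments": "{\"query\": \"example\"}"
                },
                "extra_content": {
                    "google": {"thought_signature": "CqQB..."}
                }
            }
        ]
    }

Tool-role messages deliberately omit the signature, since it belongs only on assistant turns.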
20
safe_print.py
Normal file
@@ -0,0 +1,20 @@
#!/usr/bin/env python3
"""Simple safe print that tries rich, falls back to regular print."""

try:
    from rich import print as rich_print
    RICH_AVAILABLE = True
except ImportError:
    RICH_AVAILABLE = False


def safe_print(*args, **kwargs):
    """Try rich.print, fall back to regular print if it fails."""
    if RICH_AVAILABLE:
        try:
            rich_print(*args, **kwargs)
            return
        except Exception:
            pass
    # Fallback to regular print
    print(*args, **kwargs)
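In batch_runner.py this is used as a drop-in replacement for print, so rich console markup renders when rich is installed and execution still succeeds when it is not. Note that the fallback path prints any markup tags verbatim, so markup should stay cosmetic. A small illustrative call (the message text is made up):

    from safe_print import safe_print

    safe_print("[bold green]✅ Batch complete[/bold green] (42/50 prompts succeeded)")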