mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 15:01:34 +08:00
Compare commits
7 Commits
skill/gith
...
asyncio
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d9a1e119d | ||
|
|
e91d9e839a | ||
|
|
98321be8b0 | ||
|
|
a219e178a1 | ||
|
|
e06a15b3ab | ||
|
|
349e37de0a | ||
|
|
31c733383b |
882
batch_runner.py
882
batch_runner.py
File diff suppressed because it is too large
Load Diff
12
gemini_nothinking.sh
Normal file
12
gemini_nothinking.sh
Normal file
@@ -0,0 +1,12 @@
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/agent_tasks_eval.jsonl" \
|
||||
--batch_size=1 \
|
||||
--run_name="agenttasks_eval_gemini-4.5-3-nothinking" \
|
||||
--distribution="science" \
|
||||
--model="gemini-3-pro-preview" \
|
||||
--base_url="https://generativelanguage.googleapis.com/v1beta/openai/" \
|
||||
--api_key="${GEMINI_API_KEY}" \
|
||||
--num_workers=10 \
|
||||
--max_turns=60 \
|
||||
--verbose \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. If you require API keys please check which ones already exist in your environment variables in a way that does not read them."
|
||||
@@ -23,7 +23,7 @@ Usage:
|
||||
web_tools = get_tool_definitions(enabled_toolsets=['web_tools'])
|
||||
|
||||
# Handle function calls from model
|
||||
result = handle_function_call("web_search", {"query": "Python"})
|
||||
result = await handle_function_call("web_search", {"query": "Python"})
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -439,7 +439,7 @@ def get_tool_definitions(
|
||||
|
||||
return filtered_tools
|
||||
|
||||
def handle_web_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
async def handle_web_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Handle function calls for web tools.
|
||||
|
||||
@@ -454,25 +454,25 @@ def handle_web_function_call(function_name: str, function_args: Dict[str, Any])
|
||||
query = function_args.get("query", "")
|
||||
# Always use fixed limit of 5
|
||||
limit = 5
|
||||
return web_search_tool(query, limit)
|
||||
return await web_search_tool(query, limit)
|
||||
|
||||
elif function_name == "web_extract":
|
||||
urls = function_args.get("urls", [])
|
||||
# Limit URLs to prevent abuse
|
||||
urls = urls[:5] if isinstance(urls, list) else []
|
||||
# Run async function in event loop
|
||||
return asyncio.run(web_extract_tool(urls, "markdown"))
|
||||
# Run async function
|
||||
return await web_extract_tool(urls, "markdown")
|
||||
|
||||
elif function_name == "web_crawl":
|
||||
url = function_args.get("url", "")
|
||||
instructions = function_args.get("instructions")
|
||||
# Run async function in event loop
|
||||
return asyncio.run(web_crawl_tool(url, instructions, "basic"))
|
||||
# Run async function
|
||||
return await web_crawl_tool(url, instructions, "basic")
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown web function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
|
||||
async def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Handle function calls for terminal tools.
|
||||
|
||||
@@ -489,13 +489,20 @@ def handle_terminal_function_call(function_name: str, function_args: Dict[str, A
|
||||
background = function_args.get("background", False)
|
||||
timeout = function_args.get("timeout")
|
||||
|
||||
return simple_terminal_tool(command=command, background=background, timeout=timeout, task_id=task_id)
|
||||
# Run sync terminal tool in a thread to avoid blocking
|
||||
return await asyncio.to_thread(
|
||||
simple_terminal_tool,
|
||||
command=command,
|
||||
background=background,
|
||||
timeout=timeout,
|
||||
task_id=task_id
|
||||
)
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown terminal function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_vision_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
async def handle_vision_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Handle function calls for vision tools.
|
||||
|
||||
@@ -512,14 +519,14 @@ def handle_vision_function_call(function_name: str, function_args: Dict[str, Any
|
||||
|
||||
full_prompt = f"Fully describe and explain everything about this image, then answer the following question:\n\n{question}"
|
||||
|
||||
# Run async function in event loop
|
||||
return asyncio.run(vision_analyze_tool(image_url, full_prompt, "gemini-2.5-flash"))
|
||||
# Run async function
|
||||
return await vision_analyze_tool(image_url, full_prompt, "gemini-2.5-flash")
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown vision function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
async def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Handle function calls for Mixture-of-Agents tools.
|
||||
|
||||
@@ -536,14 +543,14 @@ def handle_moa_function_call(function_name: str, function_args: Dict[str, Any])
|
||||
if not user_prompt:
|
||||
return json.dumps({"error": "user_prompt is required for MoA processing"}, ensure_ascii=False)
|
||||
|
||||
# Run async function in event loop
|
||||
return asyncio.run(mixture_of_agents_tool(user_prompt=user_prompt))
|
||||
# Run async function
|
||||
return await mixture_of_agents_tool(user_prompt=user_prompt)
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown MoA function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
async def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Handle function calls for image generation tools.
|
||||
|
||||
@@ -572,21 +579,8 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
|
||||
allow_nsfw_images = True
|
||||
seed = None
|
||||
|
||||
# Run async function in event loop with proper handling for multiprocessing
|
||||
try:
|
||||
# Try to get existing event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
# If closed, create a new one
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
# No event loop in current thread, create one
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Run the coroutine in the event loop
|
||||
result = loop.run_until_complete(image_generate_tool(
|
||||
# Run async function
|
||||
return await image_generate_tool(
|
||||
prompt=prompt,
|
||||
image_size=image_size,
|
||||
num_inference_steps=num_inference_steps,
|
||||
@@ -597,15 +591,13 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
|
||||
acceleration=acceleration,
|
||||
allow_nsfw_images=allow_nsfw_images,
|
||||
seed=seed
|
||||
))
|
||||
|
||||
return result
|
||||
)
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown image generation function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
|
||||
async def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Main function call dispatcher that routes calls to appropriate toolsets.
|
||||
|
||||
@@ -627,23 +619,23 @@ def handle_function_call(function_name: str, function_args: Dict[str, Any], task
|
||||
try:
|
||||
# Route web tools
|
||||
if function_name in ["web_search", "web_extract", "web_crawl"]:
|
||||
return handle_web_function_call(function_name, function_args)
|
||||
return await handle_web_function_call(function_name, function_args)
|
||||
|
||||
# Route terminal tools
|
||||
elif function_name in ["terminal"]:
|
||||
return handle_terminal_function_call(function_name, function_args, task_id)
|
||||
return await handle_terminal_function_call(function_name, function_args, task_id)
|
||||
|
||||
# Route vision tools
|
||||
elif function_name in ["vision_analyze"]:
|
||||
return handle_vision_function_call(function_name, function_args)
|
||||
return await handle_vision_function_call(function_name, function_args)
|
||||
|
||||
# Route MoA tools
|
||||
elif function_name in ["mixture_of_agents"]:
|
||||
return handle_moa_function_call(function_name, function_args)
|
||||
return await handle_moa_function_call(function_name, function_args)
|
||||
|
||||
# Route image generation tools
|
||||
elif function_name in ["image_generate"]:
|
||||
return handle_image_function_call(function_name, function_args)
|
||||
return await handle_image_function_call(function_name, function_args)
|
||||
|
||||
else:
|
||||
error_msg = f"Unknown function: {function_name}"
|
||||
@@ -773,4 +765,4 @@ if __name__ == "__main__":
|
||||
|
||||
if "terminal" in all_tool_names:
|
||||
no_terminal = get_tool_definitions(disabled_tools=["terminal"])
|
||||
print(f" All except terminal: {len(no_terminal)} tools")
|
||||
print(f" All except terminal: {len(no_terminal)} tools")
|
||||
381
profiling.py
Normal file
381
profiling.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""
|
||||
Profiling module for tracking timing statistics of tools and LLM API calls.
|
||||
|
||||
This module provides a centralized way to track timing information for various
|
||||
operations in the agent system, including:
|
||||
- Individual tool executions
|
||||
- OpenAI API calls
|
||||
- Aggregate statistics (min, max, median, mean, total)
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
import statistics
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfilingStats:
|
||||
"""Statistics for a particular operation type."""
|
||||
call_count: int = 0
|
||||
total_time: float = 0.0
|
||||
min_time: float = float('inf')
|
||||
max_time: float = 0.0
|
||||
times: List[float] = field(default_factory=list)
|
||||
|
||||
def add_timing(self, duration: float):
|
||||
"""Add a timing measurement."""
|
||||
self.call_count += 1
|
||||
self.total_time += duration
|
||||
self.min_time = min(self.min_time, duration)
|
||||
self.max_time = max(self.max_time, duration)
|
||||
self.times.append(duration)
|
||||
|
||||
@property
|
||||
def mean_time(self) -> float:
|
||||
"""Calculate mean time."""
|
||||
return self.total_time / self.call_count if self.call_count > 0 else 0.0
|
||||
|
||||
@property
|
||||
def median_time(self) -> float:
|
||||
"""Calculate median time."""
|
||||
return statistics.median(self.times) if self.times else 0.0
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"call_count": self.call_count,
|
||||
"total_time": self.total_time,
|
||||
"min_time": self.min_time if self.min_time != float('inf') else 0.0,
|
||||
"max_time": self.max_time,
|
||||
"mean_time": self.mean_time,
|
||||
"median_time": self.median_time
|
||||
}
|
||||
|
||||
|
||||
class Profiler:
|
||||
"""
|
||||
Global profiler for tracking timing statistics across tools and API calls.
|
||||
|
||||
Usage:
|
||||
profiler = Profiler()
|
||||
|
||||
# Time a tool execution
|
||||
with profiler.time_tool("web_search"):
|
||||
# ... tool execution code ...
|
||||
pass
|
||||
|
||||
# Time an API call
|
||||
with profiler.time_api_call():
|
||||
# ... API call code ...
|
||||
pass
|
||||
|
||||
# Get statistics
|
||||
stats = profiler.get_statistics()
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the profiler."""
|
||||
self.tool_stats: Dict[str, ProfilingStats] = defaultdict(ProfilingStats)
|
||||
self.api_stats: ProfilingStats = ProfilingStats()
|
||||
self._enabled = True
|
||||
|
||||
def enable(self):
|
||||
"""Enable profiling."""
|
||||
self._enabled = True
|
||||
|
||||
def disable(self):
|
||||
"""Disable profiling."""
|
||||
self._enabled = False
|
||||
|
||||
def reset(self):
|
||||
"""Reset all profiling data."""
|
||||
self.tool_stats.clear()
|
||||
self.api_stats = ProfilingStats()
|
||||
|
||||
def record_tool_timing(self, tool_name: str, duration: float):
|
||||
"""Record timing for a tool execution."""
|
||||
if self._enabled:
|
||||
self.tool_stats[tool_name].add_timing(duration)
|
||||
|
||||
def record_api_timing(self, duration: float):
|
||||
"""Record timing for an API call."""
|
||||
if self._enabled:
|
||||
self.api_stats.add_timing(duration)
|
||||
|
||||
def get_statistics(self) -> Dict:
|
||||
"""
|
||||
Get all profiling statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary containing tool and API statistics
|
||||
"""
|
||||
return {
|
||||
"tools": {
|
||||
tool_name: stats.to_dict()
|
||||
for tool_name, stats in sorted(self.tool_stats.items())
|
||||
},
|
||||
"api_calls": self.api_stats.to_dict()
|
||||
}
|
||||
|
||||
def print_statistics(self, detailed: bool = True):
|
||||
"""
|
||||
Print profiling statistics in a readable format.
|
||||
|
||||
Args:
|
||||
detailed: If True, show per-tool breakdown. If False, show summary only.
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("📊 PROFILING STATISTICS")
|
||||
print("="*80)
|
||||
|
||||
# API Call Statistics
|
||||
print("\n🔷 OpenAI API Calls:")
|
||||
if self.api_stats.call_count > 0:
|
||||
api_dict = self.api_stats.to_dict()
|
||||
print(f" Total Calls: {api_dict['call_count']}")
|
||||
print(f" Total Time: {api_dict['total_time']:.2f}s")
|
||||
print(f" Min Time: {api_dict['min_time']:.2f}s")
|
||||
print(f" Max Time: {api_dict['max_time']:.2f}s")
|
||||
print(f" Mean Time: {api_dict['mean_time']:.2f}s")
|
||||
print(f" Median Time: {api_dict['median_time']:.2f}s")
|
||||
else:
|
||||
print(" No API calls recorded")
|
||||
|
||||
# Tool Statistics
|
||||
print("\n🔧 Tool Executions:")
|
||||
if self.tool_stats:
|
||||
if detailed:
|
||||
for tool_name in sorted(self.tool_stats.keys()):
|
||||
stats_dict = self.tool_stats[tool_name].to_dict()
|
||||
print(f"\n 📌 {tool_name}:")
|
||||
print(f" Total Calls: {stats_dict['call_count']}")
|
||||
print(f" Total Time: {stats_dict['total_time']:.2f}s")
|
||||
print(f" Min Time: {stats_dict['min_time']:.2f}s")
|
||||
print(f" Max Time: {stats_dict['max_time']:.2f}s")
|
||||
print(f" Mean Time: {stats_dict['mean_time']:.2f}s")
|
||||
print(f" Median Time: {stats_dict['median_time']:.2f}s")
|
||||
|
||||
# Summary
|
||||
total_tool_calls = sum(s.call_count for s in self.tool_stats.values())
|
||||
total_tool_time = sum(s.total_time for s in self.tool_stats.values())
|
||||
print(f"\n 📊 Summary:")
|
||||
print(f" Total Tool Calls: {total_tool_calls}")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Unique Tools Used: {len(self.tool_stats)}")
|
||||
else:
|
||||
print(" No tool executions recorded")
|
||||
|
||||
# Overall Summary
|
||||
total_api_time = self.api_stats.total_time
|
||||
total_tool_time = sum(s.total_time for s in self.tool_stats.values())
|
||||
print(f"\n📈 Overall Summary:")
|
||||
print(f" Total API Time: {total_api_time:.2f}s")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Total Time: {total_api_time + total_tool_time:.2f}s")
|
||||
print("="*80 + "\n")
|
||||
|
||||
def export_to_json(self) -> str:
|
||||
"""Export statistics as JSON string."""
|
||||
import json
|
||||
return json.dumps(self.get_statistics(), indent=2)
|
||||
|
||||
def export_to_file(self, filepath: str):
|
||||
"""
|
||||
Export statistics to a JSON file.
|
||||
|
||||
Args:
|
||||
filepath: Path to output file
|
||||
"""
|
||||
import json
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(self.get_statistics(), f, indent=2)
|
||||
print(f"📁 Profiling statistics exported to: {filepath}")
|
||||
|
||||
|
||||
# Global profiler instance
|
||||
_global_profiler: Optional[Profiler] = None
|
||||
|
||||
|
||||
def get_profiler() -> Profiler:
|
||||
"""Get or create the global profiler instance."""
|
||||
global _global_profiler
|
||||
if _global_profiler is None:
|
||||
_global_profiler = Profiler()
|
||||
return _global_profiler
|
||||
|
||||
|
||||
def reset_profiler():
|
||||
"""Reset the global profiler."""
|
||||
global _global_profiler
|
||||
if _global_profiler is not None:
|
||||
_global_profiler.reset()
|
||||
|
||||
|
||||
class TimingContext:
|
||||
"""Context manager for timing operations."""
|
||||
|
||||
def __init__(self, profiler: Profiler, operation_type: str, operation_name: Optional[str] = None):
|
||||
"""
|
||||
Initialize timing context.
|
||||
|
||||
Args:
|
||||
profiler: Profiler instance to record timing
|
||||
operation_type: 'tool' or 'api'
|
||||
operation_name: Name of the operation (required for tools)
|
||||
"""
|
||||
self.profiler = profiler
|
||||
self.operation_type = operation_type
|
||||
self.operation_name = operation_name
|
||||
self.start_time = None
|
||||
|
||||
def __enter__(self):
|
||||
"""Start timing."""
|
||||
self.start_time = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Stop timing and record."""
|
||||
duration = time.time() - self.start_time
|
||||
|
||||
if self.operation_type == 'tool':
|
||||
self.profiler.record_tool_timing(self.operation_name, duration)
|
||||
elif self.operation_type == 'api':
|
||||
self.profiler.record_api_timing(duration)
|
||||
|
||||
return False # Don't suppress exceptions
|
||||
|
||||
|
||||
def aggregate_profiling_stats(stats_list: List[Dict]) -> Dict:
|
||||
"""
|
||||
Aggregate multiple profiling statistics dictionaries into one.
|
||||
|
||||
This is useful for batch processing where each worker process has its own
|
||||
profiler instance that needs to be combined.
|
||||
|
||||
Args:
|
||||
stats_list: List of statistics dictionaries from get_statistics()
|
||||
|
||||
Returns:
|
||||
Dict: Aggregated statistics with combined tool and API call data
|
||||
"""
|
||||
aggregated = {
|
||||
"tools": defaultdict(lambda: {"times": []}),
|
||||
"api_calls": {"times": []}
|
||||
}
|
||||
|
||||
# Aggregate tool statistics
|
||||
for stats in stats_list:
|
||||
# Aggregate tool timings
|
||||
for tool_name, tool_stats in stats.get("tools", {}).items():
|
||||
# Reconstruct individual timings from aggregated stats
|
||||
# Since we have mean_time and call_count, we approximate
|
||||
aggregated["tools"][tool_name]["times"].extend(
|
||||
[tool_stats.get("mean_time", 0.0)] * tool_stats.get("call_count", 0)
|
||||
)
|
||||
|
||||
# Aggregate API call timings
|
||||
api_stats = stats.get("api_calls", {})
|
||||
if api_stats.get("call_count", 0) > 0:
|
||||
aggregated["api_calls"]["times"].extend(
|
||||
[api_stats.get("mean_time", 0.0)] * api_stats.get("call_count", 0)
|
||||
)
|
||||
|
||||
# Calculate final statistics for tools
|
||||
final_stats = {"tools": {}, "api_calls": {}}
|
||||
|
||||
for tool_name, data in aggregated["tools"].items():
|
||||
times = data["times"]
|
||||
if times:
|
||||
final_stats["tools"][tool_name] = {
|
||||
"call_count": len(times),
|
||||
"total_time": sum(times),
|
||||
"min_time": min(times),
|
||||
"max_time": max(times),
|
||||
"mean_time": statistics.mean(times),
|
||||
"median_time": statistics.median(times)
|
||||
}
|
||||
|
||||
# Calculate final statistics for API calls
|
||||
api_times = aggregated["api_calls"]["times"]
|
||||
if api_times:
|
||||
final_stats["api_calls"] = {
|
||||
"call_count": len(api_times),
|
||||
"total_time": sum(api_times),
|
||||
"min_time": min(api_times),
|
||||
"max_time": max(api_times),
|
||||
"mean_time": statistics.mean(api_times),
|
||||
"median_time": statistics.median(api_times)
|
||||
}
|
||||
else:
|
||||
final_stats["api_calls"] = {
|
||||
"call_count": 0,
|
||||
"total_time": 0.0,
|
||||
"min_time": 0.0,
|
||||
"max_time": 0.0,
|
||||
"mean_time": 0.0,
|
||||
"median_time": 0.0
|
||||
}
|
||||
|
||||
return final_stats
|
||||
|
||||
|
||||
def print_aggregated_statistics(stats: Dict, detailed: bool = True):
|
||||
"""
|
||||
Print aggregated profiling statistics in a readable format.
|
||||
|
||||
Args:
|
||||
stats: Aggregated statistics dictionary from aggregate_profiling_stats()
|
||||
detailed: If True, show per-tool breakdown. If False, show summary only.
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("📊 AGGREGATED PROFILING STATISTICS")
|
||||
print("="*80)
|
||||
|
||||
# API Call Statistics
|
||||
print("\n🔷 OpenAI API Calls:")
|
||||
api_stats = stats.get("api_calls", {})
|
||||
if api_stats.get("call_count", 0) > 0:
|
||||
print(f" Total Calls: {api_stats['call_count']}")
|
||||
print(f" Total Time: {api_stats['total_time']:.2f}s")
|
||||
print(f" Min Time: {api_stats['min_time']:.2f}s")
|
||||
print(f" Max Time: {api_stats['max_time']:.2f}s")
|
||||
print(f" Mean Time: {api_stats['mean_time']:.2f}s")
|
||||
print(f" Median Time: {api_stats['median_time']:.2f}s")
|
||||
else:
|
||||
print(" No API calls recorded")
|
||||
|
||||
# Tool Statistics
|
||||
print("\n🔧 Tool Executions:")
|
||||
tool_stats = stats.get("tools", {})
|
||||
if tool_stats:
|
||||
if detailed:
|
||||
for tool_name in sorted(tool_stats.keys()):
|
||||
stats_dict = tool_stats[tool_name]
|
||||
print(f"\n 📌 {tool_name}:")
|
||||
print(f" Total Calls: {stats_dict['call_count']}")
|
||||
print(f" Total Time: {stats_dict['total_time']:.2f}s")
|
||||
print(f" Min Time: {stats_dict['min_time']:.2f}s")
|
||||
print(f" Max Time: {stats_dict['max_time']:.2f}s")
|
||||
print(f" Mean Time: {stats_dict['mean_time']:.2f}s")
|
||||
print(f" Median Time: {stats_dict['median_time']:.2f}s")
|
||||
|
||||
# Summary
|
||||
total_tool_calls = sum(s["call_count"] for s in tool_stats.values())
|
||||
total_tool_time = sum(s["total_time"] for s in tool_stats.values())
|
||||
print(f"\n 📊 Summary:")
|
||||
print(f" Total Tool Calls: {total_tool_calls}")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Unique Tools Used: {len(tool_stats)}")
|
||||
else:
|
||||
print(" No tool executions recorded")
|
||||
|
||||
# Overall Summary
|
||||
total_api_time = api_stats.get("total_time", 0.0)
|
||||
total_tool_time = sum(s["total_time"] for s in tool_stats.values())
|
||||
print(f"\n📈 Overall Summary:")
|
||||
print(f" Total API Time: {total_api_time:.2f}s")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Total Time: {total_api_time + total_tool_time:.2f}s")
|
||||
print("="*80 + "\n")
|
||||
408
run_agent.py
408
run_agent.py
@@ -24,11 +24,23 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import List, Dict, Any, Optional
|
||||
from openai import OpenAI
|
||||
from openai import AsyncOpenAI
|
||||
import fire
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from rich import print
|
||||
|
||||
from prokletor.formatters.hermes import HermesToolFormatterWithReasoning
|
||||
from prokletor.formatters.hermes import HermesToolFormatterWithReasoning
|
||||
from prokletor.clients.hermes import HermesToolClientWithReasoning, HermesToolClient
|
||||
from prokletor.clients.claude import AsyncClaudeClient
|
||||
try:
|
||||
from anthropic import AsyncAnthropic
|
||||
except ImportError:
|
||||
AsyncAnthropic = None
|
||||
|
||||
# Load environment variables from .env file
|
||||
from dotenv import load_dotenv
|
||||
@@ -45,6 +57,9 @@ else:
|
||||
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
|
||||
# Import profiling
|
||||
from profiling import get_profiler
|
||||
|
||||
|
||||
class AIAgent:
|
||||
"""
|
||||
@@ -67,6 +82,8 @@ class AIAgent:
|
||||
verbose_logging: bool = False,
|
||||
ephemeral_system_prompt: str = None,
|
||||
log_prefix_chars: int = 100,
|
||||
prokletor_client: str = None,
|
||||
prokletor_formatter: str = None,
|
||||
):
|
||||
"""
|
||||
Initialize the AI Agent.
|
||||
@@ -83,6 +100,8 @@ class AIAgent:
|
||||
verbose_logging (bool): Enable verbose logging for debugging (default: False)
|
||||
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
|
||||
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
|
||||
prokletor_client (str): Name of the prokletor client to use (e.g., "AsyncClaudeClient", "HermesToolClient")
|
||||
prokletor_formatter (str): Name of the prokletor formatter to use (optional)
|
||||
"""
|
||||
self.model = model
|
||||
self.max_iterations = max_iterations
|
||||
@@ -91,6 +110,8 @@ class AIAgent:
|
||||
self.verbose_logging = verbose_logging
|
||||
self.ephemeral_system_prompt = ephemeral_system_prompt
|
||||
self.log_prefix_chars = log_prefix_chars
|
||||
self.prokletor_client_name = prokletor_client
|
||||
self.prokletor_formatter_name = prokletor_formatter
|
||||
|
||||
# Store toolset filtering options
|
||||
self.enabled_toolsets = enabled_toolsets
|
||||
@@ -119,7 +140,7 @@ class AIAgent:
|
||||
logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
|
||||
# Initialize OpenAI client
|
||||
# Initialize Client
|
||||
client_kwargs = {}
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
@@ -129,12 +150,45 @@ class AIAgent:
|
||||
client_kwargs["api_key"] = os.getenv("ANTHROPIC_API_KEY", "dummy-key")
|
||||
|
||||
try:
|
||||
self.client = OpenAI(**client_kwargs)
|
||||
if prokletor_client == "AsyncClaudeClient":
|
||||
if AsyncAnthropic is None:
|
||||
raise ImportError("anthropic package is required for AsyncClaudeClient")
|
||||
|
||||
# AsyncAnthropic kwargs
|
||||
anthropic_kwargs = {k: v for k, v in client_kwargs.items() if k in ["api_key", "base_url", "timeout", "max_retries", "default_headers"]}
|
||||
|
||||
anthropic_client = AsyncAnthropic(**anthropic_kwargs)
|
||||
self.client = AsyncClaudeClient(anthropic_client)
|
||||
print(f"🧠 Wrapped Anthropic client with AsyncClaudeClient")
|
||||
|
||||
elif prokletor_client == "HermesToolClient":
|
||||
oai_client = AsyncOpenAI(**client_kwargs)
|
||||
self.client = HermesToolClient(oai_client)
|
||||
print(f"🧠 Wrapped OpenAI client with HermesToolClient")
|
||||
|
||||
elif prokletor_client == "HermesToolClientWithReasoning":
|
||||
oai_client = AsyncOpenAI(**client_kwargs)
|
||||
self.client = HermesToolClientWithReasoning(oai_client)
|
||||
print(f"🧠 Wrapped OpenAI client with HermesToolClientWithReasoning")
|
||||
|
||||
elif prokletor_client:
|
||||
# Fallback for unknown client names or if user provides a custom one (future proofing?)
|
||||
# For now, raise error or default to OpenAI
|
||||
print(f"⚠️ Unknown prokletor_client '{prokletor_client}'. Defaulting to HermesToolClientWithReasoning.")
|
||||
oai_client = AsyncOpenAI(**client_kwargs)
|
||||
self.client = HermesToolClientWithReasoning(oai_client)
|
||||
|
||||
else:
|
||||
# Default behavior
|
||||
oai_client = AsyncOpenAI(**client_kwargs)
|
||||
self.client = oai_client
|
||||
print(f"🧠 Using raw OpenAI client (no prokletor wrapper)")
|
||||
|
||||
print(f"🤖 AI Agent initialized with model: {self.model}")
|
||||
if base_url:
|
||||
print(f"🔗 Using custom base URL: {base_url}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize OpenAI client: {e}")
|
||||
raise RuntimeError(f"Failed to initialize client: {e}")
|
||||
|
||||
# Get available tools with filtering
|
||||
self.tools = get_tool_definitions(
|
||||
@@ -207,22 +261,54 @@ class AIAgent:
|
||||
Returns:
|
||||
List[Dict]: Messages in trajectory format
|
||||
"""
|
||||
# Use the client wrapper's format method if available to get the exact Hermes format
|
||||
# This ensures batch runner also gets the correct formatting
|
||||
if hasattr(self, 'client') and hasattr(self.client, 'format'):
|
||||
formatted_messages = self.client.format(messages, self.tools, render_final=True)
|
||||
|
||||
trajectory = []
|
||||
for msg in formatted_messages:
|
||||
role = msg["role"]
|
||||
content = msg["content"]
|
||||
|
||||
# Map roles to trajectory format (human, gpt, system, tool)
|
||||
if role == "user":
|
||||
trajectory_role = "human"
|
||||
elif role == "assistant":
|
||||
trajectory_role = "gpt"
|
||||
elif role == "system":
|
||||
trajectory_role = "system"
|
||||
elif role == "tool":
|
||||
trajectory_role = "tool"
|
||||
else:
|
||||
trajectory_role = role
|
||||
|
||||
trajectory.append({
|
||||
"from": trajectory_role,
|
||||
"value": content
|
||||
})
|
||||
return trajectory
|
||||
|
||||
trajectory = []
|
||||
|
||||
# Add system message with tool definitions
|
||||
system_msg = (
|
||||
"You are a function calling AI model. You are provided with function signatures within <tools> </tools> XML tags. "
|
||||
"You may call one or more functions to assist with the user query. If available tools are not relevant in assisting "
|
||||
"with user query, just respond in natural conversational language. Don't make assumptions about what values to plug "
|
||||
"into functions. After calling & executing the functions, you will be provided with function results within "
|
||||
"<tool_response> </tool_response> XML tags. Here are the available tools:\n"
|
||||
f"<tools>\n{self._format_tools_for_system_message()}\n</tools>\n"
|
||||
"For each function call return a JSON object, with the following pydantic model json schema for each:\n"
|
||||
"{'title': 'FunctionCall', 'type': 'object', 'properties': {'name': {'title': 'Name', 'type': 'string'}, "
|
||||
"'arguments': {'title': 'Arguments', 'type': 'object'}}, 'required': ['name', 'arguments']}\n"
|
||||
"Each function call should be enclosed within <tool_call> </tool_call> XML tags.\n"
|
||||
"Example:\n<tool_call>\n{'name': <function-name>,'arguments': <args-dict>}\n</tool_call>"
|
||||
)
|
||||
# Use the client's formatter if available to ensure consistency (e.g. reasoning prompt)
|
||||
if hasattr(self, 'client') and hasattr(self.client, 'formatter'):
|
||||
system_msg = self.client.formatter.format_system_message(self.tools if self.tools else [])
|
||||
else:
|
||||
system_msg = (
|
||||
"You are a function calling AI model. You are provided with function signatures within <tools> </tools> XML tags. "
|
||||
"You may call one or more functions to assist with the user query. If available tools are not relevant in assisting "
|
||||
"with user query, just respond in natural conversational language. Don't make assumptions about what values to plug "
|
||||
"into functions. After calling & executing the functions, you will be provided with function results within "
|
||||
"<tool_response> </tool_response> XML tags. Here are the available tools:\n"
|
||||
f"<tools>\n{self._format_tools_for_system_message()}\n</tools>\n"
|
||||
"For each function call return a JSON object, with the following pydantic model json schema for each:\n"
|
||||
"{'title': 'FunctionCall', 'type': 'object', 'properties': {'name': {'title': 'Name', 'type': 'string'}, "
|
||||
"'arguments': {'title': 'Arguments', 'type': 'object'}}, 'required': ['name', 'arguments']}\n"
|
||||
"Each function call should be enclosed within <tool_call> </tool_call> XML tags.\n"
|
||||
"Example:\n<tool_call>\n{'name': <function-name>,'arguments': <args-dict>}\n</tool_call>"
|
||||
)
|
||||
|
||||
trajectory.append({
|
||||
"from": "system",
|
||||
@@ -345,7 +431,7 @@ class AIAgent:
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to save trajectory: {e}")
|
||||
|
||||
def run_conversation(
|
||||
async def run_conversation(
|
||||
self,
|
||||
user_message: str,
|
||||
system_message: str = None,
|
||||
@@ -364,6 +450,10 @@ class AIAgent:
|
||||
Returns:
|
||||
Dict: Complete conversation result with final response and message history
|
||||
"""
|
||||
# Reset profiler for this conversation to get fresh stats
|
||||
from profiling import reset_profiler as reset_prof
|
||||
reset_prof()
|
||||
|
||||
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
|
||||
import uuid
|
||||
effective_task_id = task_id or str(uuid.uuid4())
|
||||
@@ -394,10 +484,14 @@ class AIAgent:
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"API Request - Model: {self.model}, Messages: {len(messages)}, Tools: {len(self.tools) if self.tools else 0}")
|
||||
logging.debug(f"Last message role: {messages[-1]['role'] if messages else 'none'}")
|
||||
# Log the last few messages to see if thought_signature is present
|
||||
logging.debug(f"Last message content: {json.dumps(messages[-1] if messages else {}, indent=2)}")
|
||||
|
||||
api_start_time = time.time()
|
||||
retry_count = 0
|
||||
max_retries = 6 # Increased to allow longer backoff periods
|
||||
response = None
|
||||
last_api_error = None
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
@@ -407,24 +501,47 @@ class AIAgent:
|
||||
if active_system_prompt:
|
||||
# Insert system message at the beginning
|
||||
api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages
|
||||
|
||||
|
||||
# Make API call with tools
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=api_messages,
|
||||
tools=self.tools if self.tools else None,
|
||||
timeout=300.0 # 5 minute timeout for long-running agent tasks
|
||||
)
|
||||
api_kwargs = {
|
||||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
"tools": self.tools if self.tools else None,
|
||||
"timeout": 300.0, # 5 minute timeout for long-running agent tasks
|
||||
}
|
||||
|
||||
# Enable thinking by default for AsyncClaudeClient if using a supported model
|
||||
if self.prokletor_client_name == "AsyncClaudeClient" and self.model.startswith("claude"):
|
||||
api_kwargs["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": 8000
|
||||
}
|
||||
# Ensure max_tokens is set higher than budget_tokens
|
||||
api_kwargs["max_tokens"] = 16000
|
||||
|
||||
response = await self.client.chat.completions.create(**api_kwargs)
|
||||
|
||||
api_duration = time.time() - api_start_time
|
||||
print(f"⏱️ OpenAI-compatible API call completed in {api_duration:.2f}s")
|
||||
|
||||
# Record API timing in profiler
|
||||
get_profiler().record_api_timing(api_duration)
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"API Response received - Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}")
|
||||
|
||||
break # Success, exit retry loop
|
||||
|
||||
except Exception as api_error:
|
||||
last_api_error = api_error
|
||||
error_message = str(api_error)
|
||||
token_limit_error = "input token count exceeds the maximum number of tokens" in error_message.lower()
|
||||
|
||||
if token_limit_error:
|
||||
print("❌ OpenAI-compatible API call failed: input token limit exceeded. Not retrying this request.")
|
||||
logging.error("Non-retryable token limit error from API: %s", api_error)
|
||||
break
|
||||
|
||||
retry_count += 1
|
||||
if retry_count > max_retries:
|
||||
raise api_error
|
||||
@@ -433,8 +550,11 @@ class AIAgent:
|
||||
print(f"⚠️ OpenAI-compatible API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}")
|
||||
print(f"⏳ Retrying in {wait_time}s...")
|
||||
logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}")
|
||||
time.sleep(wait_time)
|
||||
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
if response is None:
|
||||
raise last_api_error if last_api_error else RuntimeError("OpenAI-compatible API call failed without a response")
|
||||
|
||||
try:
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
@@ -449,25 +569,62 @@ class AIAgent:
|
||||
if self.verbose_logging:
|
||||
for tc in assistant_message.tool_calls:
|
||||
logging.debug(f"Tool call: {tc.function.name} with args: {tc.function.arguments[:200]}...")
|
||||
# Debug: Check what attributes are available on tool_call
|
||||
logging.debug(f"Tool call attributes: {dir(tc)}")
|
||||
# Try to dump the model to see all fields
|
||||
if hasattr(tc, 'model_dump'):
|
||||
logging.debug(f"Tool call data: {tc.model_dump()}")
|
||||
|
||||
# Add assistant message with tool calls to conversation
|
||||
# Extract thought_signature if present (required for Gemini models)
|
||||
tool_calls_data = []
|
||||
for tool_call in assistant_message.tool_calls:
|
||||
tool_call_dict = {
|
||||
"id": tool_call.id,
|
||||
"type": tool_call.type,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
# Try multiple ways to access thought_signature (Gemini-specific)
|
||||
# Gemini uses extra_content.google.thought_signature structure
|
||||
thought_sig = None
|
||||
|
||||
# Method 1: Check extra_content attribute
|
||||
if hasattr(tool_call, 'extra_content'):
|
||||
extra = tool_call.extra_content
|
||||
if isinstance(extra, dict) and 'google' in extra:
|
||||
thought_sig = extra['google'].get('thought_signature')
|
||||
|
||||
# Method 2: Check model_dump() if available (Pydantic v2)
|
||||
if thought_sig is None and hasattr(tool_call, 'model_dump'):
|
||||
dumped = tool_call.model_dump()
|
||||
if 'extra_content' in dumped and isinstance(dumped['extra_content'], dict):
|
||||
google_data = dumped['extra_content'].get('google', {})
|
||||
thought_sig = google_data.get('thought_signature')
|
||||
|
||||
if thought_sig is not None:
|
||||
tool_call_dict["extra_content"] = {
|
||||
"google": {
|
||||
"thought_signature": thought_sig
|
||||
}
|
||||
}
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"Captured thought_signature for tool call {tool_call.id}")
|
||||
elif self.verbose_logging:
|
||||
logging.debug(f"No thought_signature found for tool call {tool_call.id}")
|
||||
|
||||
tool_calls_data.append(tool_call_dict)
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": assistant_message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": tool_call.type,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
for tool_call in assistant_message.tool_calls
|
||||
]
|
||||
"tool_calls": tool_calls_data
|
||||
})
|
||||
|
||||
# Execute each tool call
|
||||
# Execute tool calls concurrently
|
||||
tool_tasks = []
|
||||
for i, tool_call in enumerate(assistant_message.tool_calls, 1):
|
||||
function_name = tool_call.function.name
|
||||
|
||||
@@ -482,32 +639,55 @@ class AIAgent:
|
||||
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
||||
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}")
|
||||
|
||||
# Create coroutine for tool execution
|
||||
task = handle_function_call(function_name, function_args, effective_task_id)
|
||||
tool_tasks.append(task)
|
||||
|
||||
if tool_tasks:
|
||||
tool_start_time = time.time()
|
||||
|
||||
# Execute the tool with task_id to isolate VMs between concurrent tasks
|
||||
function_result = handle_function_call(function_name, function_args, effective_task_id)
|
||||
|
||||
tool_duration = time.time() - tool_start_time
|
||||
result_preview = function_result[:200] if len(function_result) > 200 else function_result
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
|
||||
logging.debug(f"Tool result preview: {result_preview}...")
|
||||
|
||||
# Add tool result to conversation
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": function_result,
|
||||
"tool_call_id": tool_call.id
|
||||
})
|
||||
|
||||
# Preview tool response
|
||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}")
|
||||
|
||||
# Delay between tool calls
|
||||
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
|
||||
time.sleep(self.tool_delay)
|
||||
# Execute all tools concurrently
|
||||
# We use return_exceptions=True to ensure one failure doesn't stop others
|
||||
# Order of results corresponds to order of tasks
|
||||
results = await asyncio.gather(*tool_tasks, return_exceptions=True)
|
||||
|
||||
tool_duration = time.time() - tool_start_time
|
||||
|
||||
# Process results
|
||||
for i, (result, tool_call) in enumerate(zip(results, assistant_message.tool_calls), 1):
|
||||
function_name = tool_call.function.name
|
||||
|
||||
# Handle exceptions from asyncio.gather
|
||||
if isinstance(result, Exception):
|
||||
function_result = json.dumps({"error": str(result)}, ensure_ascii=False)
|
||||
print(f"❌ Tool {i} ({function_name}) failed: {result}")
|
||||
else:
|
||||
function_result = result
|
||||
|
||||
result_preview = function_result[:200] if len(function_result) > 200 else function_result
|
||||
|
||||
# Record tool timing in profiler (approximate since they ran in parallel)
|
||||
get_profiler().record_tool_timing(function_name, tool_duration)
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"Tool {function_name} completed in parallel batch")
|
||||
logging.debug(f"Tool result preview: {result_preview}...")
|
||||
|
||||
# Add tool result to conversation
|
||||
# Note: thought_signature should NOT be in tool responses, only in assistant messages
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": function_result,
|
||||
"tool_call_id": tool_call.id
|
||||
})
|
||||
|
||||
# Preview tool response
|
||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||
print(f" ✅ Tool {i} completed - {response_preview}")
|
||||
|
||||
# Optional delay after batch execution
|
||||
if self.tool_delay > 0:
|
||||
await asyncio.sleep(self.tool_delay)
|
||||
|
||||
# Continue loop for next response
|
||||
continue
|
||||
@@ -553,23 +733,95 @@ class AIAgent:
|
||||
completed = final_response is not None and api_call_count < self.max_iterations
|
||||
|
||||
# Save trajectory if enabled
|
||||
self._save_trajectory(messages, user_message, completed)
|
||||
# When saving trajectory, we want to show what the prompt would look like with proper tool roles
|
||||
# This is helpful for training data or debugging
|
||||
if self.save_trajectories:
|
||||
# Use the client wrapper's format method if available to get the exact Hermes format
|
||||
if hasattr(self, 'client') and hasattr(self.client, 'format'):
|
||||
raise ValueError("reached this point")
|
||||
formatted_messages = self.client.format(messages, self.tools, render_final=True)
|
||||
|
||||
# We need to adapt this formatted list to the trajectory format expected by _save_trajectory
|
||||
# Since _convert_to_trajectory_format expects raw OAI messages, we might need a different approach
|
||||
# OR just pass the formatted messages directly if _save_trajectory supports it.
|
||||
|
||||
# Let's look at _convert_to_trajectory_format. It iterates through messages and converts them.
|
||||
# If we pass messages that are already formatted (e.g. system prompt with tools, tool calls in XML),
|
||||
# we need to be careful not to double-format.
|
||||
|
||||
# Actually, the goal is to save the trajectory in a specific JSONL format for training/eval.
|
||||
# If we use the Hermes formatter, it produces a list of messages where content is XML strings.
|
||||
# The existing _convert_to_trajectory_format does manual XML wrapping.
|
||||
|
||||
# Ideally, we should use the messages as they are (OAI format) and let the training pipeline handle formatting,
|
||||
# OR save them in the exact format the model sees.
|
||||
|
||||
# The user request is: "accumulating history in oai format and then calling that final thing with use_tool_call True"
|
||||
# referring to client.format(messages, tools, use_tool_role=True)
|
||||
|
||||
# So let's save the RESULT of client.format() to the trajectory file.
|
||||
|
||||
# Create a custom trajectory entry directly from the formatted messages
|
||||
trajectory_content = []
|
||||
for msg in formatted_messages:
|
||||
role = msg["role"]
|
||||
content = msg["content"]
|
||||
|
||||
# Map roles to trajectory format (human, gpt, system, tool)
|
||||
if role == "user":
|
||||
trajectory_role = "human"
|
||||
elif role == "assistant":
|
||||
trajectory_role = "gpt"
|
||||
elif role == "system":
|
||||
trajectory_role = "system"
|
||||
elif role == "tool":
|
||||
trajectory_role = "tool"
|
||||
else:
|
||||
trajectory_role = role
|
||||
|
||||
trajectory_content.append({
|
||||
"from": trajectory_role,
|
||||
"value": content
|
||||
})
|
||||
|
||||
# Save this specific formatted trajectory
|
||||
filename = "trajectory_samples.jsonl" if completed else "failed_trajectories.jsonl"
|
||||
entry = {
|
||||
"conversations": trajectory_content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": self.model,
|
||||
"completed": completed
|
||||
}
|
||||
|
||||
try:
|
||||
with open(filename, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||
print(f"💾 Trajectory saved to {filename} (using Hermes format)")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to save trajectory: {e}")
|
||||
else:
|
||||
# Fallback to original saving method
|
||||
self._save_trajectory(messages, user_message, completed)
|
||||
|
||||
# Clean up VM for this task after conversation completes
|
||||
try:
|
||||
cleanup_vm(effective_task_id)
|
||||
await asyncio.to_thread(cleanup_vm, effective_task_id)
|
||||
except Exception as e:
|
||||
if self.verbose_logging:
|
||||
logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}")
|
||||
|
||||
# Get profiling statistics for this conversation
|
||||
profiling_stats = get_profiler().get_statistics()
|
||||
|
||||
return {
|
||||
"final_response": final_response,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": completed
|
||||
"completed": completed,
|
||||
"profiling_stats": profiling_stats
|
||||
}
|
||||
|
||||
def chat(self, message: str) -> str:
|
||||
async def chat(self, message: str) -> str:
|
||||
"""
|
||||
Simple chat interface that returns just the final response.
|
||||
|
||||
@@ -579,7 +831,7 @@ class AIAgent:
|
||||
Returns:
|
||||
str: Final assistant response
|
||||
"""
|
||||
result = self.run_conversation(message)
|
||||
result = await self.run_conversation(message)
|
||||
return result["final_response"]
|
||||
|
||||
|
||||
@@ -594,7 +846,10 @@ def main(
|
||||
list_tools: bool = False,
|
||||
save_trajectories: bool = False,
|
||||
verbose: bool = False,
|
||||
log_prefix_chars: int = 20
|
||||
log_prefix_chars: int = 20,
|
||||
show_profiling: bool = True,
|
||||
prokletor_client: str = None,
|
||||
prokletor_formatter: str = None,
|
||||
):
|
||||
"""
|
||||
Main function for running the agent directly.
|
||||
@@ -613,6 +868,9 @@ def main(
|
||||
save_trajectories (bool): Save conversation trajectories to JSONL files. Defaults to False.
|
||||
verbose (bool): Enable verbose logging for debugging. Defaults to False.
|
||||
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses. Defaults to 20.
|
||||
show_profiling (bool): Display profiling statistics after conversation. Defaults to True.
|
||||
prokletor_client (str): Name of the prokletor client to use (e.g., "AsyncClaudeClient")
|
||||
prokletor_formatter (str): Name of the prokletor formatter to use
|
||||
|
||||
Toolset Examples:
|
||||
- "research": Web search, extract, crawl + vision tools
|
||||
@@ -731,7 +989,9 @@ def main(
|
||||
disabled_toolsets=disabled_toolsets_list,
|
||||
save_trajectories=save_trajectories,
|
||||
verbose_logging=verbose,
|
||||
log_prefix_chars=log_prefix_chars
|
||||
log_prefix_chars=log_prefix_chars,
|
||||
prokletor_client=prokletor_client,
|
||||
prokletor_formatter=prokletor_formatter
|
||||
)
|
||||
except RuntimeError as e:
|
||||
print(f"❌ Failed to initialize agent: {e}")
|
||||
@@ -750,7 +1010,7 @@ def main(
|
||||
print("\n" + "=" * 50)
|
||||
|
||||
# Run conversation
|
||||
result = agent.run_conversation(user_query)
|
||||
result = asyncio.run(agent.run_conversation(user_query))
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("📋 CONVERSATION SUMMARY")
|
||||
@@ -763,7 +1023,11 @@ def main(
|
||||
print(f"\n🎯 FINAL RESPONSE:")
|
||||
print("-" * 30)
|
||||
print(result['final_response'])
|
||||
|
||||
|
||||
# Display profiling statistics if enabled
|
||||
if show_profiling:
|
||||
get_profiler().print_statistics(detailed=True)
|
||||
|
||||
print("\n👋 Agent execution completed!")
|
||||
|
||||
|
||||
|
||||
20
safe_print.py
Normal file
20
safe_print.py
Normal file
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Simple safe print that tries rich, falls back to regular print."""
|
||||
|
||||
try:
|
||||
from rich import print as rich_print
|
||||
RICH_AVAILABLE = True
|
||||
except ImportError:
|
||||
RICH_AVAILABLE = False
|
||||
|
||||
|
||||
def safe_print(*args, **kwargs):
|
||||
"""Try rich.print, fall back to regular print if it fails."""
|
||||
if RICH_AVAILABLE:
|
||||
try:
|
||||
rich_print(*args, **kwargs)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
# Fallback to regular print
|
||||
print(*args, **kwargs)
|
||||
@@ -78,6 +78,7 @@ AGGREGATOR_TEMPERATURE = 0.4 # Focused synthesis for consistency
|
||||
|
||||
# Failure handling configuration
|
||||
MIN_SUCCESSFUL_REFERENCES = 1 # Minimum successful reference models needed to proceed
|
||||
UNAVAILABLE_TOOL_RESPONSE = "This tools is not available"
|
||||
|
||||
# System prompt for the aggregator model (from the research paper)
|
||||
AGGREGATOR_SYSTEM_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
|
||||
@@ -364,13 +365,28 @@ async def mixture_of_agents_tool(
|
||||
if failed_models:
|
||||
print(f"⚠️ Failed models: {', '.join(failed_models)}")
|
||||
|
||||
# Check if we have enough successful responses to proceed
|
||||
if successful_count < MIN_SUCCESSFUL_REFERENCES:
|
||||
raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.")
|
||||
|
||||
debug_call_data["reference_responses_count"] = successful_count
|
||||
debug_call_data["failed_models_count"] = failed_count
|
||||
debug_call_data["failed_models"] = failed_models
|
||||
|
||||
# Check if we have enough successful responses to proceed
|
||||
if successful_count < MIN_SUCCESSFUL_REFERENCES:
|
||||
print("🚫 MoA tool unavailable: insufficient successful reference models after retries")
|
||||
result = {
|
||||
"success": False,
|
||||
"response": UNAVAILABLE_TOOL_RESPONSE,
|
||||
"models_used": {
|
||||
"reference_models": ref_models,
|
||||
"aggregator_model": agg_model
|
||||
}
|
||||
}
|
||||
debug_call_data["error"] = UNAVAILABLE_TOOL_RESPONSE
|
||||
debug_call_data["models_used"] = result["models_used"]
|
||||
processing_time = (datetime.datetime.now() - start_time).total_seconds()
|
||||
debug_call_data["processing_time_seconds"] = processing_time
|
||||
_log_debug_call("mixture_of_agents_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
# Layer 2: Aggregate responses using the aggregator model
|
||||
print("🧠 Layer 2: Synthesizing final response...")
|
||||
|
||||
@@ -189,8 +189,13 @@ def _execute_ssh_command(instance, command: str, timeout: Optional[int] = None)
|
||||
ssh_context_manager = instance.ssh()
|
||||
ssh_context = ssh_context_manager.__enter__()
|
||||
|
||||
# Execute the command
|
||||
result = ssh_context.run(command, get_pty=False, timeout=timeout or 120)
|
||||
# Execute the command. Using a PTY ensures stdout/stderr ordering matches
|
||||
# what a human would see in a terminal session.
|
||||
result = ssh_context.run(
|
||||
command,
|
||||
get_pty=True,
|
||||
timeout=timeout or 120,
|
||||
)
|
||||
|
||||
# Close the SSH connection
|
||||
if ssh_context_manager:
|
||||
@@ -213,22 +218,12 @@ def _execute_ssh_command(instance, command: str, timeout: Optional[int] = None)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check if it's a timeout
|
||||
error_str = str(e).lower()
|
||||
if "timeout" in error_str:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"Command timed out after {timeout or 120} seconds",
|
||||
"returncode": 124
|
||||
}
|
||||
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"SSH execution failed: {str(e)}",
|
||||
"returncode": -1
|
||||
}
|
||||
|
||||
|
||||
def simple_terminal_tool(
|
||||
command: str,
|
||||
background: bool = False,
|
||||
@@ -315,15 +310,21 @@ def simple_terminal_tool(
|
||||
result = _execute_ssh_command(instance, exec_command, timeout=10)
|
||||
|
||||
# For background tasks, return immediately with info
|
||||
stderr_text = (result["stderr"] or "").strip()
|
||||
if result["returncode"] == 0:
|
||||
return json.dumps({
|
||||
"output": "Background task started successfully",
|
||||
"stderr": stderr_text,
|
||||
"exit_code": 0,
|
||||
"error": None
|
||||
}, ensure_ascii=False)
|
||||
else:
|
||||
output_text = result["stdout"] or ""
|
||||
if result["stderr"] and not output_text:
|
||||
output_text = result["stderr"]
|
||||
return json.dumps({
|
||||
"output": result["stdout"],
|
||||
"output": output_text,
|
||||
"stderr": stderr_text,
|
||||
"exit_code": result["returncode"],
|
||||
"error": result["stderr"]
|
||||
}, ensure_ascii=False)
|
||||
@@ -331,13 +332,13 @@ def simple_terminal_tool(
|
||||
# Run foreground command
|
||||
result = _execute_ssh_command(instance, command, timeout=timeout)
|
||||
|
||||
# Combine stdout and stderr for output
|
||||
output = result["stdout"]
|
||||
output = result["stdout"] or ""
|
||||
if result["stderr"] and result["returncode"] != 0:
|
||||
output = f"{output}\n{result['stderr']}" if output else result["stderr"]
|
||||
|
||||
stderr_text = (result["stderr"] or "").strip()
|
||||
return json.dumps({
|
||||
"output": output.strip(),
|
||||
"stderr": stderr_text,
|
||||
"exit_code": result["returncode"],
|
||||
"error": result["stderr"] if result["returncode"] != 0 else None
|
||||
}, ensure_ascii=False)
|
||||
|
||||
@@ -48,11 +48,11 @@ import uuid
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
from firecrawl import Firecrawl
|
||||
from firecrawl import AsyncFirecrawl
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Initialize Firecrawl client once at module level
|
||||
firecrawl_client = Firecrawl(api_key=os.getenv("FIRECRAWL_API_KEY"))
|
||||
firecrawl_client = AsyncFirecrawl(api_key=os.getenv("FIRECRAWL_API_KEY"))
|
||||
|
||||
# Initialize Nous Research API client for LLM processing (async)
|
||||
nous_client = AsyncOpenAI(
|
||||
@@ -261,7 +261,7 @@ def clean_base64_images(text: str) -> str:
|
||||
return cleaned_text
|
||||
|
||||
|
||||
def web_search_tool(query: str, limit: int = 5) -> str:
|
||||
async def web_search_tool(query: str, limit: int = 5) -> str:
|
||||
"""
|
||||
Search the web for information using available search API backend.
|
||||
|
||||
@@ -312,7 +312,7 @@ def web_search_tool(query: str, limit: int = 5) -> str:
|
||||
# Use Firecrawl's v2 search functionality WITHOUT scraping
|
||||
# We only want search result metadata, not scraped content
|
||||
# Docs: https://docs.firecrawl.dev/features/search
|
||||
response = firecrawl_client.search(
|
||||
response = await firecrawl_client.search(
|
||||
query=query,
|
||||
limit=limit
|
||||
)
|
||||
@@ -446,7 +446,7 @@ async def web_extract_tool(
|
||||
for url in urls:
|
||||
try:
|
||||
print(f" 📄 Scraping: {url}")
|
||||
scrape_result = firecrawl_client.scrape(
|
||||
scrape_result = await firecrawl_client.scrape(
|
||||
url=url,
|
||||
formats=formats
|
||||
)
|
||||
@@ -703,7 +703,7 @@ async def web_crawl_tool(
|
||||
|
||||
# Use the crawl method which waits for completion automatically
|
||||
try:
|
||||
crawl_result = firecrawl_client.crawl(
|
||||
crawl_result = await firecrawl_client.crawl(
|
||||
url=url,
|
||||
**crawl_params
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user