mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-05 18:27:04 +08:00
Compare commits
22 Commits
fix-termin
...
profiling
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e06a15b3ab | ||
|
|
349e37de0a | ||
|
|
ab7293bed6 | ||
|
|
1614c15bb1 | ||
|
|
f813959750 | ||
|
|
f957ec2267 | ||
|
|
92e3074c10 | ||
|
|
31c733383b | ||
|
|
0c618482c4 | ||
|
|
2d8f6c46f1 | ||
|
|
c27787f09f | ||
|
|
d90fcd4e2b | ||
|
|
69fd0ca9aa | ||
|
|
4135cf4682 | ||
|
|
c82741c3d8 | ||
|
|
9573b2ac2d | ||
|
|
fbd3a2fdb8 | ||
|
|
a4db3fdee5 | ||
|
|
ab5c9fc37b | ||
|
|
0ca3e0aaa9 | ||
|
|
f6f75cbe2b | ||
|
|
d4544f08c5 |
11
.gitignore
vendored
11
.gitignore
vendored
@@ -20,4 +20,13 @@ logs/
|
||||
data/
|
||||
.pytest_cache/
|
||||
tmp/
|
||||
temp_vision_images/
|
||||
temp_vision_images/
|
||||
hermes-*/*
|
||||
examples/
|
||||
tests/quick_test_dataset.jsonl
|
||||
tests/sample_dataset.jsonl
|
||||
run_datagen_kimik2-thinking.sh
|
||||
run_datagen_megascience_glm4-6.sh
|
||||
run_datagen_sonnet.sh
|
||||
source-data/*
|
||||
run_datagen_megascience_glm4-6.sh
|
||||
|
||||
822
batch_runner.py
822
batch_runner.py
File diff suppressed because it is too large
Load Diff
103
model_tools.py
103
model_tools.py
@@ -28,10 +28,12 @@ Usage:
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from tools.web_tools import web_search_tool, web_extract_tool, web_crawl_tool, check_firecrawl_api_key
|
||||
from tools.terminal_tool import terminal_tool, check_hecate_requirements, TERMINAL_TOOL_DESCRIPTION
|
||||
from tools.simple_terminal_tool import simple_terminal_tool, check_requirements as check_simple_terminal_requirements, SIMPLE_TERMINAL_TOOL_DESCRIPTION
|
||||
# Keep old terminal tool for backwards compatibility if needed
|
||||
# from tools.terminal_tool import terminal_tool, check_hecate_requirements, TERMINAL_TOOL_DESCRIPTION
|
||||
from tools.vision_tools import vision_analyze_tool, check_vision_requirements
|
||||
from tools.mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements
|
||||
from tools.image_generation_tool import image_generate_tool, check_image_generation_requirements
|
||||
@@ -111,7 +113,7 @@ def get_web_tool_definitions() -> List[Dict[str, Any]]:
|
||||
def get_terminal_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get tool definitions for terminal tools in OpenAI's expected format.
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of terminal tool definitions compatible with OpenAI API
|
||||
"""
|
||||
@@ -120,7 +122,7 @@ def get_terminal_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "terminal",
|
||||
"description": TERMINAL_TOOL_DESCRIPTION,
|
||||
"description": SIMPLE_TERMINAL_TOOL_DESCRIPTION,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -128,28 +130,18 @@ def get_terminal_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "string",
|
||||
"description": "The command to execute on the VM"
|
||||
},
|
||||
"input_keys": {
|
||||
"type": "string",
|
||||
"description": "Keystrokes to send to the most recent interactive session (e.g., 'hello\\n' for typing hello + Enter). If no active session exists, this will be ignored."
|
||||
},
|
||||
"background": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to run the command in the background (default: false)",
|
||||
"default": False
|
||||
},
|
||||
"idle_threshold": {
|
||||
"type": "number",
|
||||
"description": "Seconds to wait for output before considering session idle (default: 5.0)",
|
||||
"default": 5.0,
|
||||
"minimum": 0.1
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Command timeout in seconds (optional)",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
"required": ["command"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -262,11 +254,11 @@ def get_all_tool_names() -> List[str]:
|
||||
# Web tools
|
||||
if check_firecrawl_api_key():
|
||||
tool_names.extend(["web_search", "web_extract", "web_crawl"])
|
||||
|
||||
# Terminal tools
|
||||
if check_hecate_requirements():
|
||||
|
||||
# Terminal tools
|
||||
if check_simple_terminal_requirements():
|
||||
tool_names.extend(["terminal"])
|
||||
|
||||
|
||||
# Vision tools
|
||||
if check_vision_requirements():
|
||||
tool_names.extend(["vision_analyze"])
|
||||
@@ -346,11 +338,11 @@ def get_tool_definitions(
|
||||
if check_firecrawl_api_key():
|
||||
for tool in get_web_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
if check_hecate_requirements():
|
||||
|
||||
if check_simple_terminal_requirements():
|
||||
for tool in get_terminal_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
|
||||
if check_vision_requirements():
|
||||
for tool in get_vision_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
@@ -478,30 +470,29 @@ def handle_web_function_call(function_name: str, function_args: Dict[str, Any])
|
||||
return asyncio.run(web_crawl_tool(url, instructions, "basic"))
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown web function: {function_name}"})
|
||||
return json.dumps({"error": f"Unknown web function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Handle function calls for terminal tools.
|
||||
|
||||
|
||||
Args:
|
||||
function_name (str): Name of the terminal function to call
|
||||
function_args (Dict): Arguments for the function
|
||||
|
||||
task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional)
|
||||
|
||||
Returns:
|
||||
str: Function result as JSON string
|
||||
"""
|
||||
if function_name == "terminal":
|
||||
command = function_args.get("command")
|
||||
input_keys = function_args.get("input_keys")
|
||||
background = function_args.get("background", False)
|
||||
idle_threshold = function_args.get("idle_threshold", 5.0)
|
||||
timeout = function_args.get("timeout")
|
||||
|
||||
return terminal_tool(command, input_keys, None, background, idle_threshold, timeout)
|
||||
|
||||
return simple_terminal_tool(command=command, background=background, timeout=timeout, task_id=task_id)
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown terminal function: {function_name}"})
|
||||
return json.dumps({"error": f"Unknown terminal function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_vision_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
@@ -525,7 +516,7 @@ def handle_vision_function_call(function_name: str, function_args: Dict[str, Any
|
||||
return asyncio.run(vision_analyze_tool(image_url, full_prompt, "gemini-2.5-flash"))
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown vision function: {function_name}"})
|
||||
return json.dumps({"error": f"Unknown vision function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
@@ -543,13 +534,13 @@ def handle_moa_function_call(function_name: str, function_args: Dict[str, Any])
|
||||
user_prompt = function_args.get("user_prompt", "")
|
||||
|
||||
if not user_prompt:
|
||||
return json.dumps({"error": "user_prompt is required for MoA processing"})
|
||||
return json.dumps({"error": "user_prompt is required for MoA processing"}, ensure_ascii=False)
|
||||
|
||||
# Run async function in event loop
|
||||
return asyncio.run(mixture_of_agents_tool(user_prompt=user_prompt))
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown MoA function: {function_name}"})
|
||||
return json.dumps({"error": f"Unknown MoA function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
@@ -567,7 +558,7 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
|
||||
prompt = function_args.get("prompt", "")
|
||||
|
||||
if not prompt:
|
||||
return json.dumps({"success": False, "image": None})
|
||||
return json.dumps({"success": False, "image": None}, ensure_ascii=False)
|
||||
|
||||
image_size = function_args.get("image_size", "landscape_16_9")
|
||||
|
||||
@@ -611,24 +602,25 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
|
||||
return result
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown image generation function: {function_name}"})
|
||||
return json.dumps({"error": f"Unknown image generation function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Main function call dispatcher that routes calls to appropriate toolsets.
|
||||
|
||||
|
||||
This function determines which toolset a function belongs to and dispatches
|
||||
the call to the appropriate handler. This makes it easy to add new toolsets
|
||||
without changing the main calling interface.
|
||||
|
||||
|
||||
Args:
|
||||
function_name (str): Name of the function to call
|
||||
function_args (Dict): Arguments for the function
|
||||
|
||||
task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional)
|
||||
|
||||
Returns:
|
||||
str: Function result as JSON string
|
||||
|
||||
|
||||
Raises:
|
||||
None: Returns error as JSON string instead of raising exceptions
|
||||
"""
|
||||
@@ -636,32 +628,33 @@ def handle_function_call(function_name: str, function_args: Dict[str, Any]) -> s
|
||||
# Route web tools
|
||||
if function_name in ["web_search", "web_extract", "web_crawl"]:
|
||||
return handle_web_function_call(function_name, function_args)
|
||||
|
||||
|
||||
# Route terminal tools
|
||||
elif function_name in ["terminal"]:
|
||||
return handle_terminal_function_call(function_name, function_args)
|
||||
|
||||
return handle_terminal_function_call(function_name, function_args, task_id)
|
||||
|
||||
# Route vision tools
|
||||
elif function_name in ["vision_analyze"]:
|
||||
return handle_vision_function_call(function_name, function_args)
|
||||
|
||||
|
||||
# Route MoA tools
|
||||
elif function_name in ["mixture_of_agents"]:
|
||||
return handle_moa_function_call(function_name, function_args)
|
||||
|
||||
|
||||
# Route image generation tools
|
||||
elif function_name in ["image_generate"]:
|
||||
return handle_image_function_call(function_name, function_args)
|
||||
|
||||
|
||||
else:
|
||||
error_msg = f"Unknown function: {function_name}"
|
||||
print(f"❌ {error_msg}")
|
||||
return json.dumps({"error": error_msg})
|
||||
|
||||
return json.dumps({"error": error_msg}, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing {function_name}: {str(e)}"
|
||||
print(f"❌ {error_msg}")
|
||||
return json.dumps({"error": error_msg})
|
||||
return json.dumps({"error": error_msg}, ensure_ascii=False)
|
||||
|
||||
def get_available_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
@@ -678,10 +671,10 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
"requirements": ["FIRECRAWL_API_KEY environment variable"]
|
||||
},
|
||||
"terminal_tools": {
|
||||
"available": check_hecate_requirements(),
|
||||
"tools": ["terminal_tool"],
|
||||
"description": "Execute commands with optional interactive session support on Linux VMs",
|
||||
"requirements": ["MORPH_API_KEY environment variable", "hecate package"]
|
||||
"available": check_simple_terminal_requirements(),
|
||||
"tools": ["simple_terminal_tool"],
|
||||
"description": "Execute commands on secure Linux VMs without session persistence",
|
||||
"requirements": ["MORPH_API_KEY environment variable"]
|
||||
},
|
||||
"vision_tools": {
|
||||
"available": check_vision_requirements(),
|
||||
@@ -708,13 +701,13 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
def check_toolset_requirements() -> Dict[str, bool]:
|
||||
"""
|
||||
Check if all requirements for available toolsets are met.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: Status of each toolset's requirements
|
||||
"""
|
||||
return {
|
||||
"web_tools": check_firecrawl_api_key(),
|
||||
"terminal_tools": check_hecate_requirements(),
|
||||
"terminal_tools": check_simple_terminal_requirements(),
|
||||
"vision_tools": check_vision_requirements(),
|
||||
"moa_tools": check_moa_requirements(),
|
||||
"image_tools": check_image_generation_requirements()
|
||||
|
||||
381
profiling.py
Normal file
381
profiling.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""
|
||||
Profiling module for tracking timing statistics of tools and LLM API calls.
|
||||
|
||||
This module provides a centralized way to track timing information for various
|
||||
operations in the agent system, including:
|
||||
- Individual tool executions
|
||||
- OpenAI API calls
|
||||
- Aggregate statistics (min, max, median, mean, total)
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
import statistics
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfilingStats:
|
||||
"""Statistics for a particular operation type."""
|
||||
call_count: int = 0
|
||||
total_time: float = 0.0
|
||||
min_time: float = float('inf')
|
||||
max_time: float = 0.0
|
||||
times: List[float] = field(default_factory=list)
|
||||
|
||||
def add_timing(self, duration: float):
|
||||
"""Add a timing measurement."""
|
||||
self.call_count += 1
|
||||
self.total_time += duration
|
||||
self.min_time = min(self.min_time, duration)
|
||||
self.max_time = max(self.max_time, duration)
|
||||
self.times.append(duration)
|
||||
|
||||
@property
|
||||
def mean_time(self) -> float:
|
||||
"""Calculate mean time."""
|
||||
return self.total_time / self.call_count if self.call_count > 0 else 0.0
|
||||
|
||||
@property
|
||||
def median_time(self) -> float:
|
||||
"""Calculate median time."""
|
||||
return statistics.median(self.times) if self.times else 0.0
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"call_count": self.call_count,
|
||||
"total_time": self.total_time,
|
||||
"min_time": self.min_time if self.min_time != float('inf') else 0.0,
|
||||
"max_time": self.max_time,
|
||||
"mean_time": self.mean_time,
|
||||
"median_time": self.median_time
|
||||
}
|
||||
|
||||
|
||||
class Profiler:
|
||||
"""
|
||||
Global profiler for tracking timing statistics across tools and API calls.
|
||||
|
||||
Usage:
|
||||
profiler = Profiler()
|
||||
|
||||
# Time a tool execution
|
||||
with profiler.time_tool("web_search"):
|
||||
# ... tool execution code ...
|
||||
pass
|
||||
|
||||
# Time an API call
|
||||
with profiler.time_api_call():
|
||||
# ... API call code ...
|
||||
pass
|
||||
|
||||
# Get statistics
|
||||
stats = profiler.get_statistics()
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the profiler."""
|
||||
self.tool_stats: Dict[str, ProfilingStats] = defaultdict(ProfilingStats)
|
||||
self.api_stats: ProfilingStats = ProfilingStats()
|
||||
self._enabled = True
|
||||
|
||||
def enable(self):
|
||||
"""Enable profiling."""
|
||||
self._enabled = True
|
||||
|
||||
def disable(self):
|
||||
"""Disable profiling."""
|
||||
self._enabled = False
|
||||
|
||||
def reset(self):
|
||||
"""Reset all profiling data."""
|
||||
self.tool_stats.clear()
|
||||
self.api_stats = ProfilingStats()
|
||||
|
||||
def record_tool_timing(self, tool_name: str, duration: float):
|
||||
"""Record timing for a tool execution."""
|
||||
if self._enabled:
|
||||
self.tool_stats[tool_name].add_timing(duration)
|
||||
|
||||
def record_api_timing(self, duration: float):
|
||||
"""Record timing for an API call."""
|
||||
if self._enabled:
|
||||
self.api_stats.add_timing(duration)
|
||||
|
||||
def get_statistics(self) -> Dict:
|
||||
"""
|
||||
Get all profiling statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary containing tool and API statistics
|
||||
"""
|
||||
return {
|
||||
"tools": {
|
||||
tool_name: stats.to_dict()
|
||||
for tool_name, stats in sorted(self.tool_stats.items())
|
||||
},
|
||||
"api_calls": self.api_stats.to_dict()
|
||||
}
|
||||
|
||||
def print_statistics(self, detailed: bool = True):
|
||||
"""
|
||||
Print profiling statistics in a readable format.
|
||||
|
||||
Args:
|
||||
detailed: If True, show per-tool breakdown. If False, show summary only.
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("📊 PROFILING STATISTICS")
|
||||
print("="*80)
|
||||
|
||||
# API Call Statistics
|
||||
print("\n🔷 OpenAI API Calls:")
|
||||
if self.api_stats.call_count > 0:
|
||||
api_dict = self.api_stats.to_dict()
|
||||
print(f" Total Calls: {api_dict['call_count']}")
|
||||
print(f" Total Time: {api_dict['total_time']:.2f}s")
|
||||
print(f" Min Time: {api_dict['min_time']:.2f}s")
|
||||
print(f" Max Time: {api_dict['max_time']:.2f}s")
|
||||
print(f" Mean Time: {api_dict['mean_time']:.2f}s")
|
||||
print(f" Median Time: {api_dict['median_time']:.2f}s")
|
||||
else:
|
||||
print(" No API calls recorded")
|
||||
|
||||
# Tool Statistics
|
||||
print("\n🔧 Tool Executions:")
|
||||
if self.tool_stats:
|
||||
if detailed:
|
||||
for tool_name in sorted(self.tool_stats.keys()):
|
||||
stats_dict = self.tool_stats[tool_name].to_dict()
|
||||
print(f"\n 📌 {tool_name}:")
|
||||
print(f" Total Calls: {stats_dict['call_count']}")
|
||||
print(f" Total Time: {stats_dict['total_time']:.2f}s")
|
||||
print(f" Min Time: {stats_dict['min_time']:.2f}s")
|
||||
print(f" Max Time: {stats_dict['max_time']:.2f}s")
|
||||
print(f" Mean Time: {stats_dict['mean_time']:.2f}s")
|
||||
print(f" Median Time: {stats_dict['median_time']:.2f}s")
|
||||
|
||||
# Summary
|
||||
total_tool_calls = sum(s.call_count for s in self.tool_stats.values())
|
||||
total_tool_time = sum(s.total_time for s in self.tool_stats.values())
|
||||
print(f"\n 📊 Summary:")
|
||||
print(f" Total Tool Calls: {total_tool_calls}")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Unique Tools Used: {len(self.tool_stats)}")
|
||||
else:
|
||||
print(" No tool executions recorded")
|
||||
|
||||
# Overall Summary
|
||||
total_api_time = self.api_stats.total_time
|
||||
total_tool_time = sum(s.total_time for s in self.tool_stats.values())
|
||||
print(f"\n📈 Overall Summary:")
|
||||
print(f" Total API Time: {total_api_time:.2f}s")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Total Time: {total_api_time + total_tool_time:.2f}s")
|
||||
print("="*80 + "\n")
|
||||
|
||||
def export_to_json(self) -> str:
|
||||
"""Export statistics as JSON string."""
|
||||
import json
|
||||
return json.dumps(self.get_statistics(), indent=2)
|
||||
|
||||
def export_to_file(self, filepath: str):
|
||||
"""
|
||||
Export statistics to a JSON file.
|
||||
|
||||
Args:
|
||||
filepath: Path to output file
|
||||
"""
|
||||
import json
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(self.get_statistics(), f, indent=2)
|
||||
print(f"📁 Profiling statistics exported to: {filepath}")
|
||||
|
||||
|
||||
# Global profiler instance
|
||||
_global_profiler: Optional[Profiler] = None
|
||||
|
||||
|
||||
def get_profiler() -> Profiler:
|
||||
"""Get or create the global profiler instance."""
|
||||
global _global_profiler
|
||||
if _global_profiler is None:
|
||||
_global_profiler = Profiler()
|
||||
return _global_profiler
|
||||
|
||||
|
||||
def reset_profiler():
|
||||
"""Reset the global profiler."""
|
||||
global _global_profiler
|
||||
if _global_profiler is not None:
|
||||
_global_profiler.reset()
|
||||
|
||||
|
||||
class TimingContext:
|
||||
"""Context manager for timing operations."""
|
||||
|
||||
def __init__(self, profiler: Profiler, operation_type: str, operation_name: Optional[str] = None):
|
||||
"""
|
||||
Initialize timing context.
|
||||
|
||||
Args:
|
||||
profiler: Profiler instance to record timing
|
||||
operation_type: 'tool' or 'api'
|
||||
operation_name: Name of the operation (required for tools)
|
||||
"""
|
||||
self.profiler = profiler
|
||||
self.operation_type = operation_type
|
||||
self.operation_name = operation_name
|
||||
self.start_time = None
|
||||
|
||||
def __enter__(self):
|
||||
"""Start timing."""
|
||||
self.start_time = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Stop timing and record."""
|
||||
duration = time.time() - self.start_time
|
||||
|
||||
if self.operation_type == 'tool':
|
||||
self.profiler.record_tool_timing(self.operation_name, duration)
|
||||
elif self.operation_type == 'api':
|
||||
self.profiler.record_api_timing(duration)
|
||||
|
||||
return False # Don't suppress exceptions
|
||||
|
||||
|
||||
def aggregate_profiling_stats(stats_list: List[Dict]) -> Dict:
|
||||
"""
|
||||
Aggregate multiple profiling statistics dictionaries into one.
|
||||
|
||||
This is useful for batch processing where each worker process has its own
|
||||
profiler instance that needs to be combined.
|
||||
|
||||
Args:
|
||||
stats_list: List of statistics dictionaries from get_statistics()
|
||||
|
||||
Returns:
|
||||
Dict: Aggregated statistics with combined tool and API call data
|
||||
"""
|
||||
aggregated = {
|
||||
"tools": defaultdict(lambda: {"times": []}),
|
||||
"api_calls": {"times": []}
|
||||
}
|
||||
|
||||
# Aggregate tool statistics
|
||||
for stats in stats_list:
|
||||
# Aggregate tool timings
|
||||
for tool_name, tool_stats in stats.get("tools", {}).items():
|
||||
# Reconstruct individual timings from aggregated stats
|
||||
# Since we have mean_time and call_count, we approximate
|
||||
aggregated["tools"][tool_name]["times"].extend(
|
||||
[tool_stats.get("mean_time", 0.0)] * tool_stats.get("call_count", 0)
|
||||
)
|
||||
|
||||
# Aggregate API call timings
|
||||
api_stats = stats.get("api_calls", {})
|
||||
if api_stats.get("call_count", 0) > 0:
|
||||
aggregated["api_calls"]["times"].extend(
|
||||
[api_stats.get("mean_time", 0.0)] * api_stats.get("call_count", 0)
|
||||
)
|
||||
|
||||
# Calculate final statistics for tools
|
||||
final_stats = {"tools": {}, "api_calls": {}}
|
||||
|
||||
for tool_name, data in aggregated["tools"].items():
|
||||
times = data["times"]
|
||||
if times:
|
||||
final_stats["tools"][tool_name] = {
|
||||
"call_count": len(times),
|
||||
"total_time": sum(times),
|
||||
"min_time": min(times),
|
||||
"max_time": max(times),
|
||||
"mean_time": statistics.mean(times),
|
||||
"median_time": statistics.median(times)
|
||||
}
|
||||
|
||||
# Calculate final statistics for API calls
|
||||
api_times = aggregated["api_calls"]["times"]
|
||||
if api_times:
|
||||
final_stats["api_calls"] = {
|
||||
"call_count": len(api_times),
|
||||
"total_time": sum(api_times),
|
||||
"min_time": min(api_times),
|
||||
"max_time": max(api_times),
|
||||
"mean_time": statistics.mean(api_times),
|
||||
"median_time": statistics.median(api_times)
|
||||
}
|
||||
else:
|
||||
final_stats["api_calls"] = {
|
||||
"call_count": 0,
|
||||
"total_time": 0.0,
|
||||
"min_time": 0.0,
|
||||
"max_time": 0.0,
|
||||
"mean_time": 0.0,
|
||||
"median_time": 0.0
|
||||
}
|
||||
|
||||
return final_stats
|
||||
|
||||
|
||||
def print_aggregated_statistics(stats: Dict, detailed: bool = True):
|
||||
"""
|
||||
Print aggregated profiling statistics in a readable format.
|
||||
|
||||
Args:
|
||||
stats: Aggregated statistics dictionary from aggregate_profiling_stats()
|
||||
detailed: If True, show per-tool breakdown. If False, show summary only.
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("📊 AGGREGATED PROFILING STATISTICS")
|
||||
print("="*80)
|
||||
|
||||
# API Call Statistics
|
||||
print("\n🔷 OpenAI API Calls:")
|
||||
api_stats = stats.get("api_calls", {})
|
||||
if api_stats.get("call_count", 0) > 0:
|
||||
print(f" Total Calls: {api_stats['call_count']}")
|
||||
print(f" Total Time: {api_stats['total_time']:.2f}s")
|
||||
print(f" Min Time: {api_stats['min_time']:.2f}s")
|
||||
print(f" Max Time: {api_stats['max_time']:.2f}s")
|
||||
print(f" Mean Time: {api_stats['mean_time']:.2f}s")
|
||||
print(f" Median Time: {api_stats['median_time']:.2f}s")
|
||||
else:
|
||||
print(" No API calls recorded")
|
||||
|
||||
# Tool Statistics
|
||||
print("\n🔧 Tool Executions:")
|
||||
tool_stats = stats.get("tools", {})
|
||||
if tool_stats:
|
||||
if detailed:
|
||||
for tool_name in sorted(tool_stats.keys()):
|
||||
stats_dict = tool_stats[tool_name]
|
||||
print(f"\n 📌 {tool_name}:")
|
||||
print(f" Total Calls: {stats_dict['call_count']}")
|
||||
print(f" Total Time: {stats_dict['total_time']:.2f}s")
|
||||
print(f" Min Time: {stats_dict['min_time']:.2f}s")
|
||||
print(f" Max Time: {stats_dict['max_time']:.2f}s")
|
||||
print(f" Mean Time: {stats_dict['mean_time']:.2f}s")
|
||||
print(f" Median Time: {stats_dict['median_time']:.2f}s")
|
||||
|
||||
# Summary
|
||||
total_tool_calls = sum(s["call_count"] for s in tool_stats.values())
|
||||
total_tool_time = sum(s["total_time"] for s in tool_stats.values())
|
||||
print(f"\n 📊 Summary:")
|
||||
print(f" Total Tool Calls: {total_tool_calls}")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Unique Tools Used: {len(tool_stats)}")
|
||||
else:
|
||||
print(" No tool executions recorded")
|
||||
|
||||
# Overall Summary
|
||||
total_api_time = api_stats.get("total_time", 0.0)
|
||||
total_tool_time = sum(s["total_time"] for s in tool_stats.values())
|
||||
print(f"\n📈 Overall Summary:")
|
||||
print(f" Total API Time: {total_api_time:.2f}s")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Total Time: {total_api_time + total_tool_time:.2f}s")
|
||||
print("="*80 + "\n")
|
||||
@@ -3,4 +3,4 @@ openai
|
||||
fal-client
|
||||
python-dotenv
|
||||
fire
|
||||
requests
|
||||
httpx
|
||||
149
run_agent.py
149
run_agent.py
@@ -43,6 +43,10 @@ else:
|
||||
|
||||
# Import our tool system
|
||||
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
|
||||
# Import profiling
|
||||
from profiling import get_profiler
|
||||
|
||||
|
||||
class AIAgent:
|
||||
@@ -54,9 +58,9 @@ class AIAgent:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
self,
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
model: str = "gpt-4",
|
||||
max_iterations: int = 10,
|
||||
tool_delay: float = 1.0,
|
||||
@@ -64,11 +68,12 @@ class AIAgent:
|
||||
disabled_toolsets: List[str] = None,
|
||||
save_trajectories: bool = False,
|
||||
verbose_logging: bool = False,
|
||||
ephemeral_system_prompt: str = None
|
||||
ephemeral_system_prompt: str = None,
|
||||
log_prefix_chars: int = 100,
|
||||
):
|
||||
"""
|
||||
Initialize the AI Agent.
|
||||
|
||||
|
||||
Args:
|
||||
base_url (str): Base URL for the model API (optional)
|
||||
api_key (str): API key for authentication (optional, uses env var if not provided)
|
||||
@@ -80,6 +85,7 @@ class AIAgent:
|
||||
save_trajectories (bool): Whether to save conversation trajectories to JSONL files (default: False)
|
||||
verbose_logging (bool): Enable verbose logging for debugging (default: False)
|
||||
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
|
||||
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
|
||||
"""
|
||||
self.model = model
|
||||
self.max_iterations = max_iterations
|
||||
@@ -87,7 +93,8 @@ class AIAgent:
|
||||
self.save_trajectories = save_trajectories
|
||||
self.verbose_logging = verbose_logging
|
||||
self.ephemeral_system_prompt = ephemeral_system_prompt
|
||||
|
||||
self.log_prefix_chars = log_prefix_chars
|
||||
|
||||
# Store toolset filtering options
|
||||
self.enabled_toolsets = enabled_toolsets
|
||||
self.disabled_toolsets = disabled_toolsets
|
||||
@@ -189,7 +196,7 @@ class AIAgent:
|
||||
}
|
||||
formatted_tools.append(formatted_tool)
|
||||
|
||||
return json.dumps(formatted_tools)
|
||||
return json.dumps(formatted_tools, ensure_ascii=False)
|
||||
|
||||
def _convert_to_trajectory_format(self, messages: List[Dict[str, Any]], user_query: str, completed: bool) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -250,7 +257,7 @@ class AIAgent:
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"]
|
||||
}
|
||||
content += f"<tool_call>\n{json.dumps(tool_call_json)}\n</tool_call>\n"
|
||||
content += f"<tool_call>\n{json.dumps(tool_call_json, ensure_ascii=False)}\n</tool_call>\n"
|
||||
|
||||
trajectory.append({
|
||||
"from": "gpt",
|
||||
@@ -277,7 +284,7 @@ class AIAgent:
|
||||
"tool_call_id": tool_msg.get("tool_call_id", ""),
|
||||
"name": msg["tool_calls"][len(tool_responses)]["function"]["name"] if len(tool_responses) < len(msg["tool_calls"]) else "unknown",
|
||||
"content": tool_content
|
||||
})
|
||||
}, ensure_ascii=False)
|
||||
tool_response += "\n</tool_response>"
|
||||
tool_responses.append(tool_response)
|
||||
j += 1
|
||||
@@ -342,22 +349,31 @@ class AIAgent:
|
||||
print(f"⚠️ Failed to save trajectory: {e}")
|
||||
|
||||
def run_conversation(
|
||||
self,
|
||||
user_message: str,
|
||||
system_message: str = None,
|
||||
conversation_history: List[Dict[str, Any]] = None
|
||||
self,
|
||||
user_message: str,
|
||||
system_message: str = None,
|
||||
conversation_history: List[Dict[str, Any]] = None,
|
||||
task_id: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run a complete conversation with tool calling until completion.
|
||||
|
||||
|
||||
Args:
|
||||
user_message (str): The user's message/question
|
||||
system_message (str): Custom system message (optional, overrides ephemeral_system_prompt if provided)
|
||||
conversation_history (List[Dict]): Previous conversation messages (optional)
|
||||
|
||||
task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional, auto-generated if not provided)
|
||||
|
||||
Returns:
|
||||
Dict: Complete conversation result with final response and message history
|
||||
"""
|
||||
# Reset profiler for this conversation to get fresh stats
|
||||
from profiling import reset_profiler as reset_prof
|
||||
reset_prof()
|
||||
|
||||
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
|
||||
import uuid
|
||||
effective_task_id = task_id or str(uuid.uuid4())
|
||||
# Initialize conversation
|
||||
messages = conversation_history or []
|
||||
|
||||
@@ -379,7 +395,7 @@ class AIAgent:
|
||||
|
||||
while api_call_count < self.max_iterations:
|
||||
api_call_count += 1
|
||||
print(f"\n🔄 Making API call #{api_call_count}...")
|
||||
print(f"\n🔄 Making OpenAI-compatible API call #{api_call_count}...")
|
||||
|
||||
# Log request details if verbose
|
||||
if self.verbose_logging:
|
||||
@@ -388,8 +404,8 @@ class AIAgent:
|
||||
|
||||
api_start_time = time.time()
|
||||
retry_count = 0
|
||||
max_retries = 3
|
||||
|
||||
max_retries = 6 # Increased to allow longer backoff periods
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# Prepare messages for API call
|
||||
@@ -398,30 +414,33 @@ class AIAgent:
|
||||
if active_system_prompt:
|
||||
# Insert system message at the beginning
|
||||
api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages
|
||||
|
||||
|
||||
# Make API call with tools
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=api_messages,
|
||||
tools=self.tools if self.tools else None,
|
||||
timeout=60.0 # Add explicit timeout
|
||||
timeout=300.0 # 5 minute timeout for long-running agent tasks
|
||||
)
|
||||
|
||||
|
||||
api_duration = time.time() - api_start_time
|
||||
print(f"⏱️ API call completed in {api_duration:.2f}s")
|
||||
|
||||
print(f"⏱️ OpenAI-compatible API call completed in {api_duration:.2f}s")
|
||||
|
||||
# Record API timing in profiler
|
||||
get_profiler().record_api_timing(api_duration)
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"API Response received - Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}")
|
||||
|
||||
|
||||
break # Success, exit retry loop
|
||||
|
||||
|
||||
except Exception as api_error:
|
||||
retry_count += 1
|
||||
if retry_count > max_retries:
|
||||
raise api_error
|
||||
|
||||
wait_time = min(2 ** retry_count, 10) # Exponential backoff, max 10s
|
||||
print(f"⚠️ API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}")
|
||||
|
||||
wait_time = min(2 ** retry_count, 60) # Exponential backoff: 2s, 4s, 8s, 16s, 32s, 60s, 60s
|
||||
print(f"⚠️ OpenAI-compatible API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}")
|
||||
print(f"⏳ Retrying in {wait_time}s...")
|
||||
logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}")
|
||||
time.sleep(wait_time)
|
||||
@@ -468,28 +487,36 @@ class AIAgent:
|
||||
print(f"❌ Invalid JSON in tool call arguments: {e}")
|
||||
function_args = {}
|
||||
|
||||
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())})")
|
||||
|
||||
# Preview tool call arguments
|
||||
args_str = json.dumps(function_args, ensure_ascii=False)
|
||||
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
||||
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}")
|
||||
|
||||
tool_start_time = time.time()
|
||||
|
||||
# Execute the tool
|
||||
function_result = handle_function_call(function_name, function_args)
|
||||
|
||||
|
||||
# Execute the tool with task_id to isolate VMs between concurrent tasks
|
||||
function_result = handle_function_call(function_name, function_args, effective_task_id)
|
||||
|
||||
tool_duration = time.time() - tool_start_time
|
||||
result_preview = function_result[:200] if len(function_result) > 200 else function_result
|
||||
|
||||
|
||||
# Record tool timing in profiler
|
||||
get_profiler().record_tool_timing(function_name, tool_duration)
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
|
||||
logging.debug(f"Tool result preview: {result_preview}...")
|
||||
|
||||
|
||||
# Add tool result to conversation
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": function_result,
|
||||
"tool_call_id": tool_call.id
|
||||
})
|
||||
|
||||
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s")
|
||||
|
||||
# Preview tool response
|
||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}")
|
||||
|
||||
# Delay between tool calls
|
||||
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
|
||||
@@ -508,11 +535,11 @@ class AIAgent:
|
||||
"content": final_response
|
||||
})
|
||||
|
||||
print(f"🎉 Conversation completed after {api_call_count} API call(s)")
|
||||
print(f"🎉 Conversation completed after {api_call_count} OpenAI-compatible API call(s)")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error during API call #{api_call_count}: {str(e)}"
|
||||
error_msg = f"Error during OpenAI-compatible API call #{api_call_count}: {str(e)}"
|
||||
print(f"❌ {error_msg}")
|
||||
|
||||
if self.verbose_logging:
|
||||
@@ -537,15 +564,26 @@ class AIAgent:
|
||||
|
||||
# Determine if conversation completed successfully
|
||||
completed = final_response is not None and api_call_count < self.max_iterations
|
||||
|
||||
|
||||
# Save trajectory if enabled
|
||||
self._save_trajectory(messages, user_message, completed)
|
||||
|
||||
|
||||
# Clean up VM for this task after conversation completes
|
||||
try:
|
||||
cleanup_vm(effective_task_id)
|
||||
except Exception as e:
|
||||
if self.verbose_logging:
|
||||
logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}")
|
||||
|
||||
# Get profiling statistics for this conversation
|
||||
profiling_stats = get_profiler().get_statistics()
|
||||
|
||||
return {
|
||||
"final_response": final_response,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": completed
|
||||
"completed": completed,
|
||||
"profiling_stats": profiling_stats
|
||||
}
|
||||
|
||||
def chat(self, message: str) -> str:
|
||||
@@ -564,7 +602,7 @@ class AIAgent:
|
||||
|
||||
def main(
|
||||
query: str = None,
|
||||
model: str = "claude-opus-4-20250514",
|
||||
model: str = "claude-opus-4-20250514",
|
||||
api_key: str = None,
|
||||
base_url: str = "https://api.anthropic.com/v1/",
|
||||
max_turns: int = 10,
|
||||
@@ -572,25 +610,29 @@ def main(
|
||||
disabled_toolsets: str = None,
|
||||
list_tools: bool = False,
|
||||
save_trajectories: bool = False,
|
||||
verbose: bool = False
|
||||
verbose: bool = False,
|
||||
log_prefix_chars: int = 20,
|
||||
show_profiling: bool = True
|
||||
):
|
||||
"""
|
||||
Main function for running the agent directly.
|
||||
|
||||
|
||||
Args:
|
||||
query (str): Natural language query for the agent. Defaults to Python 3.13 example.
|
||||
model (str): Model name to use. Defaults to claude-opus-4-20250514.
|
||||
api_key (str): API key for authentication. Uses ANTHROPIC_API_KEY env var if not provided.
|
||||
base_url (str): Base URL for the model API. Defaults to https://api.anthropic.com/v1/
|
||||
max_turns (int): Maximum number of API call iterations. Defaults to 10.
|
||||
enabled_toolsets (str): Comma-separated list of toolsets to enable. Supports predefined
|
||||
toolsets (e.g., "research", "development", "safe").
|
||||
enabled_toolsets (str): Comma-separated list of toolsets to enable. Supports predefined
|
||||
toolsets (e.g., "research", "development", "safe").
|
||||
Multiple toolsets can be combined: "web,vision"
|
||||
disabled_toolsets (str): Comma-separated list of toolsets to disable (e.g., "terminal")
|
||||
list_tools (bool): Just list available tools and exit
|
||||
save_trajectories (bool): Save conversation trajectories to JSONL files. Defaults to False.
|
||||
verbose (bool): Enable verbose logging for debugging. Defaults to False.
|
||||
|
||||
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses. Defaults to 20.
|
||||
show_profiling (bool): Display profiling statistics after conversation. Defaults to True.
|
||||
|
||||
Toolset Examples:
|
||||
- "research": Web search, extract, crawl + vision tools
|
||||
"""
|
||||
@@ -707,7 +749,8 @@ def main(
|
||||
enabled_toolsets=enabled_toolsets_list,
|
||||
disabled_toolsets=disabled_toolsets_list,
|
||||
save_trajectories=save_trajectories,
|
||||
verbose_logging=verbose
|
||||
verbose_logging=verbose,
|
||||
log_prefix_chars=log_prefix_chars
|
||||
)
|
||||
except RuntimeError as e:
|
||||
print(f"❌ Failed to initialize agent: {e}")
|
||||
@@ -739,7 +782,11 @@ def main(
|
||||
print(f"\n🎯 FINAL RESPONSE:")
|
||||
print("-" * 30)
|
||||
print(result['final_response'])
|
||||
|
||||
|
||||
# Display profiling statistics if enabled
|
||||
if show_profiling:
|
||||
get_profiler().print_statistics(detailed=True)
|
||||
|
||||
print("\n👋 Agent execution completed!")
|
||||
|
||||
|
||||
|
||||
12
run_datagen_megascience.sh
Executable file
12
run_datagen_megascience.sh
Executable file
@@ -0,0 +1,12 @@
|
||||
python batch_runner.py \
|
||||
--dataset_file="hermes-agent-megascience-data/hermes_agent_megascience_eval.jsonl" \
|
||||
--batch_size=10 \
|
||||
--run_name="megascience_eval_gpt5_2" \
|
||||
--distribution="science" \
|
||||
--model="gpt-5" \
|
||||
--base_url="https://api.openai.com/v1" \
|
||||
--api_key="${OPENAI_API_KEY}" \
|
||||
--num_workers=5 \
|
||||
--max_turns=30 \
|
||||
--verbose \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use a tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should not be confident in your own reasoning, knowledge, or calculations without using a tool to verify or validate your work."
|
||||
12
run_datagen_megascience_glm4-6.sh
Executable file
12
run_datagen_megascience_glm4-6.sh
Executable file
@@ -0,0 +1,12 @@
|
||||
python batch_runner.py \
|
||||
--dataset_file="hermes-agent-megascience-data/hermes_agent_megascience_eval.jsonl" \
|
||||
--batch_size=10 \
|
||||
--run_name="megascience_eval_glm4-6-fixedterminal-2" \
|
||||
--distribution="science" \
|
||||
--model="z-ai/glm-4.6" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--api_key="${OPENROUTER_API_KEY}" \
|
||||
--num_workers=5 \
|
||||
--max_turns=30 \
|
||||
--verbose \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use a tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run."
|
||||
20
safe_print.py
Normal file
20
safe_print.py
Normal file
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Simple safe print that tries rich, falls back to regular print."""
|
||||
|
||||
try:
|
||||
from rich import print as rich_print
|
||||
RICH_AVAILABLE = True
|
||||
except ImportError:
|
||||
RICH_AVAILABLE = False
|
||||
|
||||
|
||||
def safe_print(*args, **kwargs):
|
||||
"""Try rich.print, fall back to regular print if it fails."""
|
||||
if RICH_AVAILABLE:
|
||||
try:
|
||||
rich_print(*args, **kwargs)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
# Fallback to regular print
|
||||
print(*args, **kwargs)
|
||||
@@ -24,7 +24,7 @@ def create_test_dataset():
|
||||
|
||||
with open(test_file, 'w') as f:
|
||||
for prompt in prompts:
|
||||
f.write(json.dumps(prompt) + "\n")
|
||||
f.write(json.dumps(prompt, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"✅ Created test dataset: {test_file}")
|
||||
return test_file
|
||||
|
||||
424
tests/test_checkpoint_resumption.py
Normal file
424
tests/test_checkpoint_resumption.py
Normal file
@@ -0,0 +1,424 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify checkpoint behavior in batch_runner.py
|
||||
|
||||
This script simulates batch processing with intentional failures to test:
|
||||
1. Whether checkpoints are saved incrementally during processing
|
||||
2. Whether resume functionality works correctly after interruption
|
||||
3. Whether data integrity is maintained across checkpoint cycles
|
||||
|
||||
Usage:
|
||||
# Test current implementation
|
||||
python tests/test_checkpoint_resumption.py --test_current
|
||||
|
||||
# Test after fix is applied
|
||||
python tests/test_checkpoint_resumption.py --test_fixed
|
||||
|
||||
# Run full comparison
|
||||
python tests/test_checkpoint_resumption.py --compare
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import signal
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any
|
||||
import traceback
|
||||
|
||||
# Add parent directory to path to import batch_runner
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
|
||||
def create_test_dataset(num_prompts: int = 20) -> Path:
|
||||
"""Create a small test dataset for checkpoint testing."""
|
||||
test_data_dir = Path("tests/test_data")
|
||||
test_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
dataset_file = test_data_dir / "checkpoint_test_dataset.jsonl"
|
||||
|
||||
with open(dataset_file, 'w', encoding='utf-8') as f:
|
||||
for i in range(num_prompts):
|
||||
entry = {
|
||||
"prompt": f"Test prompt {i}: What is 2+2? Just answer briefly.",
|
||||
"test_id": i
|
||||
}
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"✅ Created test dataset: {dataset_file} ({num_prompts} prompts)")
|
||||
return dataset_file
|
||||
|
||||
|
||||
def monitor_checkpoint_during_run(checkpoint_file: Path, duration: int = 30) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Monitor checkpoint file during a batch run to see when it gets updated.
|
||||
|
||||
Args:
|
||||
checkpoint_file: Path to checkpoint file to monitor
|
||||
duration: How long to monitor (seconds)
|
||||
|
||||
Returns:
|
||||
List of checkpoint snapshots with timestamps
|
||||
"""
|
||||
snapshots = []
|
||||
start_time = time.time()
|
||||
last_mtime = None
|
||||
|
||||
print(f"\n🔍 Monitoring checkpoint file: {checkpoint_file}")
|
||||
print(f" Duration: {duration}s")
|
||||
print("-" * 70)
|
||||
|
||||
while time.time() - start_time < duration:
|
||||
if checkpoint_file.exists():
|
||||
current_mtime = checkpoint_file.stat().st_mtime
|
||||
|
||||
# Check if file was modified
|
||||
if last_mtime is None or current_mtime != last_mtime:
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
try:
|
||||
with open(checkpoint_file, 'r') as f:
|
||||
checkpoint_data = json.load(f)
|
||||
|
||||
snapshot = {
|
||||
"elapsed_seconds": round(elapsed, 2),
|
||||
"completed_count": len(checkpoint_data.get("completed_prompts", [])),
|
||||
"completed_prompts": checkpoint_data.get("completed_prompts", [])[:5], # First 5 for display
|
||||
"timestamp": checkpoint_data.get("last_updated")
|
||||
}
|
||||
|
||||
snapshots.append(snapshot)
|
||||
|
||||
print(f"[{elapsed:6.2f}s] Checkpoint updated: {snapshot['completed_count']} prompts completed")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{elapsed:6.2f}s] Error reading checkpoint: {e}")
|
||||
|
||||
last_mtime = current_mtime
|
||||
else:
|
||||
if len(snapshots) == 0:
|
||||
print(f"[{time.time() - start_time:6.2f}s] Checkpoint file not yet created...")
|
||||
|
||||
time.sleep(0.5) # Check every 0.5 seconds
|
||||
|
||||
return snapshots
|
||||
|
||||
|
||||
def test_current_implementation():
|
||||
"""Test the current checkpoint implementation."""
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 1: Current Implementation - Checkpoint Timing")
|
||||
print("=" * 70)
|
||||
print("\n📝 Testing whether checkpoints are saved incrementally during run...")
|
||||
|
||||
# Setup
|
||||
dataset_file = create_test_dataset(num_prompts=12)
|
||||
run_name = "checkpoint_test_current"
|
||||
output_dir = Path("data") / run_name
|
||||
|
||||
# Clean up any existing test data
|
||||
if output_dir.exists():
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
# Import here to avoid issues if module changes
|
||||
from batch_runner import BatchRunner
|
||||
|
||||
checkpoint_file = output_dir / "checkpoint.json"
|
||||
|
||||
# Start monitoring in a separate process would be ideal, but for simplicity
|
||||
# we'll just check before and after
|
||||
print(f"\n▶️ Starting batch run...")
|
||||
print(f" Dataset: {dataset_file}")
|
||||
print(f" Batch size: 3 (4 batches total)")
|
||||
print(f" Workers: 2")
|
||||
print(f" Expected behavior: If incremental, checkpoint should update during run")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
runner = BatchRunner(
|
||||
dataset_file=str(dataset_file),
|
||||
batch_size=3,
|
||||
run_name=run_name,
|
||||
distribution="default",
|
||||
max_iterations=3, # Keep it short
|
||||
model="claude-opus-4-20250514",
|
||||
num_workers=2,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Run with monitoring
|
||||
import threading
|
||||
snapshots = []
|
||||
|
||||
def monitor():
|
||||
nonlocal snapshots
|
||||
snapshots = monitor_checkpoint_during_run(checkpoint_file, duration=60)
|
||||
|
||||
monitor_thread = threading.Thread(target=monitor, daemon=True)
|
||||
monitor_thread.start()
|
||||
|
||||
runner.run(resume=False)
|
||||
|
||||
monitor_thread.join(timeout=2)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error during run: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Analyze results
|
||||
print("\n" + "=" * 70)
|
||||
print("📊 TEST RESULTS")
|
||||
print("=" * 70)
|
||||
print(f"Total run time: {elapsed:.2f}s")
|
||||
print(f"Checkpoint updates observed: {len(snapshots)}")
|
||||
|
||||
if len(snapshots) == 0:
|
||||
print("\n❌ ISSUE: No checkpoint updates observed during run")
|
||||
print(" This suggests checkpoints are only saved at the end")
|
||||
return False
|
||||
elif len(snapshots) == 1:
|
||||
print("\n⚠️ WARNING: Only 1 checkpoint update (likely at the end)")
|
||||
print(" This confirms the bug - no incremental checkpointing")
|
||||
return False
|
||||
else:
|
||||
print(f"\n✅ GOOD: Multiple checkpoint updates ({len(snapshots)}) observed")
|
||||
print(" Checkpointing appears to be incremental")
|
||||
|
||||
# Show timeline
|
||||
print("\n📈 Checkpoint Timeline:")
|
||||
for i, snapshot in enumerate(snapshots, 1):
|
||||
print(f" {i}. [{snapshot['elapsed_seconds']:6.2f}s] "
|
||||
f"{snapshot['completed_count']} prompts completed")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_interruption_and_resume():
|
||||
"""Test that resume actually works after interruption."""
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 2: Interruption and Resume")
|
||||
print("=" * 70)
|
||||
print("\n📝 Testing whether resume works after manual interruption...")
|
||||
|
||||
# Setup
|
||||
dataset_file = create_test_dataset(num_prompts=15)
|
||||
run_name = "checkpoint_test_resume"
|
||||
output_dir = Path("data") / run_name
|
||||
|
||||
# Clean up any existing test data
|
||||
if output_dir.exists():
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
from batch_runner import BatchRunner
|
||||
|
||||
checkpoint_file = output_dir / "checkpoint.json"
|
||||
|
||||
print(f"\n▶️ Starting first run (will process 5 prompts, then simulate interruption)...")
|
||||
|
||||
try:
|
||||
# Create a modified dataset with only first 5 prompts for initial run
|
||||
temp_dataset = Path("tests/test_data/checkpoint_test_resume_partial.jsonl")
|
||||
with open(dataset_file, 'r') as f:
|
||||
lines = f.readlines()[:5]
|
||||
with open(temp_dataset, 'w') as f:
|
||||
f.writelines(lines)
|
||||
|
||||
runner = BatchRunner(
|
||||
dataset_file=str(temp_dataset),
|
||||
batch_size=2,
|
||||
run_name=run_name,
|
||||
distribution="default",
|
||||
max_iterations=3,
|
||||
model="claude-opus-4-20250514",
|
||||
num_workers=1,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
runner.run(resume=False)
|
||||
|
||||
# Check checkpoint after first run
|
||||
if not checkpoint_file.exists():
|
||||
print("❌ ERROR: Checkpoint file not created after first run")
|
||||
return False
|
||||
|
||||
with open(checkpoint_file, 'r') as f:
|
||||
checkpoint_data = json.load(f)
|
||||
|
||||
initial_completed = len(checkpoint_data.get("completed_prompts", []))
|
||||
print(f"✅ First run completed: {initial_completed} prompts saved to checkpoint")
|
||||
|
||||
# Now try to resume with full dataset
|
||||
print(f"\n▶️ Starting resume run with full dataset (15 prompts)...")
|
||||
|
||||
runner2 = BatchRunner(
|
||||
dataset_file=str(dataset_file),
|
||||
batch_size=2,
|
||||
run_name=run_name,
|
||||
distribution="default",
|
||||
max_iterations=3,
|
||||
model="claude-opus-4-20250514",
|
||||
num_workers=1,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
runner2.run(resume=True)
|
||||
|
||||
# Check final checkpoint
|
||||
with open(checkpoint_file, 'r') as f:
|
||||
final_checkpoint = json.load(f)
|
||||
|
||||
final_completed = len(final_checkpoint.get("completed_prompts", []))
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("📊 TEST RESULTS")
|
||||
print("=" * 70)
|
||||
print(f"Initial completed: {initial_completed}")
|
||||
print(f"Final completed: {final_completed}")
|
||||
print(f"Expected: 15")
|
||||
|
||||
if final_completed == 15:
|
||||
print("\n✅ PASS: Resume successfully completed all prompts")
|
||||
return True
|
||||
else:
|
||||
print(f"\n❌ FAIL: Expected 15 completed, got {final_completed}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error during test: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def test_simulated_crash():
|
||||
"""Test behavior when process crashes mid-execution."""
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 3: Simulated Crash During Execution")
|
||||
print("=" * 70)
|
||||
print("\n📝 This test would require running in a subprocess and killing it...")
|
||||
print(" Skipping for safety - manual testing recommended")
|
||||
return None
|
||||
|
||||
|
||||
def print_test_plan():
|
||||
"""Print the detailed test and fix plan."""
|
||||
print("\n" + "=" * 70)
|
||||
print("CHECKPOINT FIX - DETAILED PLAN")
|
||||
print("=" * 70)
|
||||
|
||||
print("""
|
||||
📋 PROBLEM SUMMARY
|
||||
------------------
|
||||
Current implementation uses pool.map() which blocks until ALL batches complete.
|
||||
Checkpoint is only saved after all batches finish (line 558-559).
|
||||
|
||||
If process crashes during batch processing:
|
||||
- All progress is lost
|
||||
- Resume does nothing (no incremental checkpoint was saved)
|
||||
|
||||
📋 PROPOSED SOLUTION
|
||||
--------------------
|
||||
Replace pool.map() with pool.imap_unordered() to get results as they complete.
|
||||
Save checkpoint after EACH batch completes using a multiprocessing Lock.
|
||||
|
||||
Key changes:
|
||||
1. Use Manager().Lock() for thread-safe checkpoint writes
|
||||
2. Replace pool.map() with pool.imap_unordered()
|
||||
3. Update checkpoint after each batch result
|
||||
4. Maintain backward compatibility with existing checkpoints
|
||||
|
||||
📋 IMPLEMENTATION STEPS
|
||||
-----------------------
|
||||
1. Add Manager and Lock initialization before Pool creation
|
||||
2. Pass shared checkpoint data and lock to workers (via Manager)
|
||||
3. Replace pool.map() with pool.imap_unordered()
|
||||
4. In result loop: save checkpoint after each batch
|
||||
5. Add error handling for checkpoint write failures
|
||||
|
||||
📋 RISKS & MITIGATIONS
|
||||
----------------------
|
||||
Risk: Checkpoint file corruption if two processes write simultaneously
|
||||
→ Mitigation: Use multiprocessing.Lock() for exclusive access
|
||||
|
||||
Risk: Performance impact from frequent checkpoint writes
|
||||
→ Mitigation: Checkpoint writes are fast (small JSON), negligible impact
|
||||
|
||||
Risk: Breaking existing runs that are already checkpointed
|
||||
→ Mitigation: Maintain checkpoint format, only change timing
|
||||
|
||||
Risk: Bugs in multiprocessing lock/manager code
|
||||
→ Mitigation: Thorough testing with this test script
|
||||
|
||||
📋 TESTING STRATEGY
|
||||
-------------------
|
||||
1. Run test_current_implementation() - Confirm bug exists
|
||||
2. Apply fix to batch_runner.py
|
||||
3. Run test_current_implementation() again - Should see incremental updates
|
||||
4. Run test_interruption_and_resume() - Verify resume works
|
||||
5. Manual test: Start run, kill process mid-batch, resume
|
||||
|
||||
📋 ROLLBACK PLAN
|
||||
----------------
|
||||
If issues arise:
|
||||
1. Git revert the changes
|
||||
2. Original code is working (just missing incremental checkpoint)
|
||||
3. No data corruption risk - checkpoints are write-only
|
||||
""")
|
||||
|
||||
|
||||
def main(
|
||||
test_current: bool = False,
|
||||
test_resume: bool = False,
|
||||
test_crash: bool = False,
|
||||
compare: bool = False,
|
||||
show_plan: bool = False
|
||||
):
|
||||
"""
|
||||
Run checkpoint behavior tests.
|
||||
|
||||
Args:
|
||||
test_current: Test current implementation checkpoint timing
|
||||
test_resume: Test interruption and resume functionality
|
||||
test_crash: Test simulated crash scenario (manual)
|
||||
compare: Run all tests and compare
|
||||
show_plan: Show detailed fix plan
|
||||
"""
|
||||
if show_plan or (not any([test_current, test_resume, test_crash, compare])):
|
||||
print_test_plan()
|
||||
return
|
||||
|
||||
results = {}
|
||||
|
||||
if test_current or compare:
|
||||
results['current'] = test_current_implementation()
|
||||
|
||||
if test_resume or compare:
|
||||
results['resume'] = test_interruption_and_resume()
|
||||
|
||||
if test_crash or compare:
|
||||
results['crash'] = test_simulated_crash()
|
||||
|
||||
# Summary
|
||||
if results:
|
||||
print("\n" + "=" * 70)
|
||||
print("OVERALL TEST SUMMARY")
|
||||
print("=" * 70)
|
||||
for test_name, result in results.items():
|
||||
if result is None:
|
||||
status = "⏭️ SKIPPED"
|
||||
elif result:
|
||||
status = "✅ PASS"
|
||||
else:
|
||||
status = "❌ FAIL"
|
||||
print(f"{status} - {test_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
fire.Fire(main)
|
||||
|
||||
176
tests/test_nous_api_limits.py
Executable file
176
tests/test_nous_api_limits.py
Executable file
@@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to diagnose Nous API 400 errors with gemini-2.5-flash model.
|
||||
This tests various content lengths and parameters to identify what causes failures.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from openai import AsyncOpenAI
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Initialize the Nous API client
|
||||
nous_client = AsyncOpenAI(
|
||||
api_key=os.getenv("NOUS_API_KEY"),
|
||||
base_url="https://inference-api.nousresearch.com/v1"
|
||||
)
|
||||
|
||||
MODEL = "gemini-2.5-flash"
|
||||
|
||||
async def test_api_call(test_name: str, content_length: int, **kwargs):
|
||||
"""Test an API call with specific parameters."""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Test: {test_name}")
|
||||
print(f"Content length: {content_length:,} characters")
|
||||
print(f"Additional params: {kwargs}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Generate test content
|
||||
content = "A" * content_length
|
||||
|
||||
system_prompt = """You are an expert content analyst. Your job is to process web content and create a comprehensive yet concise summary that preserves all important information while dramatically reducing bulk.
|
||||
|
||||
Create a well-structured markdown summary that includes:
|
||||
1. Key excerpts (quotes, code snippets, important facts) in their original format
|
||||
2. Comprehensive summary of all other important information
|
||||
3. Proper markdown formatting with headers, bullets, and emphasis
|
||||
|
||||
Your goal is to preserve ALL important information while reducing length. Never lose key facts, figures, insights, or actionable information. Make it scannable and well-organized."""
|
||||
|
||||
user_prompt = f"""Please process this web content and create a comprehensive markdown summary:
|
||||
|
||||
CONTENT TO PROCESS:
|
||||
{content}
|
||||
|
||||
Create a markdown summary that captures all key information in a well-organized, scannable format. Include important quotes and code snippets in their original formatting. Focus on actionable information, specific details, and unique insights."""
|
||||
|
||||
try:
|
||||
response = await nous_client.chat.completions.create(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
**kwargs
|
||||
)
|
||||
|
||||
result = response.choices[0].message.content
|
||||
print(f"✅ SUCCESS")
|
||||
print(f" Response length: {len(result)} characters")
|
||||
print(f" Model used: {response.model}")
|
||||
print(f" Usage: {response.usage}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ FAILED: {str(e)}")
|
||||
return False
|
||||
|
||||
async def main():
|
||||
"""Run all tests."""
|
||||
print("Testing Nous API with gemini-2.5-flash model")
|
||||
print(f"API Key present: {'Yes' if os.getenv('NOUS_API_KEY') else 'No'}")
|
||||
|
||||
results = {}
|
||||
|
||||
# Test 1: Small content (should always work)
|
||||
results['small'] = await test_api_call(
|
||||
"Small content (5,000 chars)",
|
||||
5000,
|
||||
temperature=0.1,
|
||||
max_tokens=4000
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Test 2: Medium content (around what was failing)
|
||||
results['medium'] = await test_api_call(
|
||||
"Medium content (20,000 chars)",
|
||||
20000,
|
||||
temperature=0.1,
|
||||
max_tokens=4000
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Test 3: Large content (79,625 chars like the error)
|
||||
results['large'] = await test_api_call(
|
||||
"Large content (79,625 chars)",
|
||||
79625,
|
||||
temperature=0.1,
|
||||
max_tokens=4000
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Test 4: Very large content (100k chars)
|
||||
results['very_large'] = await test_api_call(
|
||||
"Very large content (100,000 chars)",
|
||||
100000,
|
||||
temperature=0.1,
|
||||
max_tokens=4000
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Test 5: Same as working case but different max_tokens
|
||||
results['diff_max_tokens'] = await test_api_call(
|
||||
"Medium content with higher max_tokens",
|
||||
20000,
|
||||
temperature=0.1,
|
||||
max_tokens=8000
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Test 6: No max_tokens specified
|
||||
results['no_max_tokens'] = await test_api_call(
|
||||
"Medium content without max_tokens",
|
||||
20000,
|
||||
temperature=0.1
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Test 7: With actual web content (mixed characters)
|
||||
mixed_content = """
|
||||
This is a test of web content with various characters:
|
||||
- Unicode: 你好世界 🌍
|
||||
- Special chars: <>&"'
|
||||
- Numbers: 123456789
|
||||
- Markdown: **bold** _italic_ `code`
|
||||
- URLs: https://example.com
|
||||
""" * 1000 # Repeat to make it ~79k chars
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Test: Mixed content (real-world scenario)")
|
||||
print(f"Content length: {len(mixed_content):,} characters")
|
||||
print(f"{'='*60}")
|
||||
|
||||
try:
|
||||
response = await nous_client.chat.completions.create(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": "Summarize this content."},
|
||||
{"role": "user", "content": mixed_content}
|
||||
],
|
||||
temperature=0.1,
|
||||
max_tokens=4000
|
||||
)
|
||||
print(f"✅ SUCCESS")
|
||||
results['mixed_content'] = True
|
||||
except Exception as e:
|
||||
print(f"❌ FAILED: {str(e)}")
|
||||
results['mixed_content'] = False
|
||||
|
||||
# Summary
|
||||
print(f"\n{'='*60}")
|
||||
print("SUMMARY OF RESULTS:")
|
||||
print(f"{'='*60}")
|
||||
for test, passed in results.items():
|
||||
status = "✅ PASS" if passed else "❌ FAIL"
|
||||
print(f"{test:20s}: {status}")
|
||||
|
||||
passed = sum(results.values())
|
||||
total = len(results)
|
||||
print(f"\nTotal: {passed}/{total} tests passed")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
131
tests/test_nous_api_pattern.py
Normal file
131
tests/test_nous_api_pattern.py
Normal file
@@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test to understand the pattern of failures - it's not about content length!
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from openai import AsyncOpenAI
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
nous_client = AsyncOpenAI(
|
||||
api_key=os.getenv("NOUS_API_KEY"),
|
||||
base_url="https://inference-api.nousresearch.com/v1"
|
||||
)
|
||||
|
||||
MODEL = "gemini-2.5-flash"
|
||||
|
||||
async def quick_test(description: str, content: str, **kwargs):
|
||||
"""Quick API test."""
|
||||
print(f"\n{description} ({len(content):,} chars)...", end=" ")
|
||||
|
||||
try:
|
||||
response = await nous_client.chat.completions.create(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": "Summarize this."},
|
||||
{"role": "user", "content": content}
|
||||
],
|
||||
**kwargs
|
||||
)
|
||||
print(f"✅ SUCCESS")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ FAILED: {str(e)[:80]}")
|
||||
return False
|
||||
|
||||
async def main():
|
||||
print("Testing different content types and parameters...")
|
||||
|
||||
# Theory 1: Repeated characters trigger validation
|
||||
print("\n" + "="*60)
|
||||
print("THEORY 1: Repeated characters")
|
||||
print("="*60)
|
||||
await quick_test("Repeated 'A's (5k)", "A" * 5000, temperature=0.1, max_tokens=4000)
|
||||
await asyncio.sleep(0.5)
|
||||
await quick_test("Repeated 'A's (79k)", "A" * 79625, temperature=0.1, max_tokens=4000)
|
||||
await asyncio.sleep(0.5)
|
||||
await quick_test("Varied text (5k)", "Test content. " * 400, temperature=0.1, max_tokens=4000)
|
||||
await asyncio.sleep(0.5)
|
||||
await quick_test("Varied text (79k)", "Test content with variety. " * 3000, temperature=0.1, max_tokens=4000)
|
||||
|
||||
# Theory 2: max_tokens parameter
|
||||
print("\n" + "="*60)
|
||||
print("THEORY 2: max_tokens parameter")
|
||||
print("="*60)
|
||||
content = "Test " * 4000 # 20k chars
|
||||
await quick_test("max_tokens=4000", content, temperature=0.1, max_tokens=4000)
|
||||
await asyncio.sleep(0.5)
|
||||
await quick_test("max_tokens=8000", content, temperature=0.1, max_tokens=8000)
|
||||
await asyncio.sleep(0.5)
|
||||
await quick_test("max_tokens=2000", content, temperature=0.1, max_tokens=2000)
|
||||
await asyncio.sleep(0.5)
|
||||
await quick_test("No max_tokens", content, temperature=0.1)
|
||||
|
||||
# Theory 3: Temperature parameter
|
||||
print("\n" + "="*60)
|
||||
print("THEORY 3: Temperature parameter")
|
||||
print("="*60)
|
||||
content = "Test " * 4000
|
||||
await quick_test("temperature=0.1", content, temperature=0.1, max_tokens=4000)
|
||||
await asyncio.sleep(0.5)
|
||||
await quick_test("temperature=0.0", content, temperature=0.0, max_tokens=4000)
|
||||
await asyncio.sleep(0.5)
|
||||
await quick_test("temperature=0.5", content, temperature=0.5, max_tokens=4000)
|
||||
await asyncio.sleep(0.5)
|
||||
await quick_test("No temperature", content, max_tokens=4000)
|
||||
|
||||
# Theory 4: System prompt impact
|
||||
print("\n" + "="*60)
|
||||
print("THEORY 4: System prompt length")
|
||||
print("="*60)
|
||||
|
||||
short_system = "Summarize this."
|
||||
long_system = """You are an expert content analyst. Your job is to process web content and create a comprehensive yet concise summary that preserves all important information while dramatically reducing bulk.
|
||||
|
||||
Create a well-structured markdown summary that includes:
|
||||
1. Key excerpts (quotes, code snippets, important facts) in their original format
|
||||
2. Comprehensive summary of all other important information
|
||||
3. Proper markdown formatting with headers, bullets, and emphasis
|
||||
|
||||
Your goal is to preserve ALL important information while reducing length."""
|
||||
|
||||
content = "A" * 5000
|
||||
|
||||
print(f"\nShort system prompt...", end=" ")
|
||||
try:
|
||||
response = await nous_client.chat.completions.create(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": short_system},
|
||||
{"role": "user", "content": content}
|
||||
],
|
||||
temperature=0.1,
|
||||
max_tokens=4000
|
||||
)
|
||||
print(f"✅ SUCCESS")
|
||||
except Exception as e:
|
||||
print(f"❌ FAILED")
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
print(f"Long system prompt...", end=" ")
|
||||
try:
|
||||
response = await nous_client.chat.completions.create(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": long_system},
|
||||
{"role": "user", "content": content}
|
||||
],
|
||||
temperature=0.1,
|
||||
max_tokens=4000
|
||||
)
|
||||
print(f"✅ SUCCESS")
|
||||
except Exception as e:
|
||||
print(f"❌ FAILED")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
109
tests/test_temperature_fix.py
Normal file
109
tests/test_temperature_fix.py
Normal file
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test to confirm: temperature < 0.3 causes failures on Nous API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from openai import AsyncOpenAI
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
nous_client = AsyncOpenAI(
|
||||
api_key=os.getenv("NOUS_API_KEY"),
|
||||
base_url="https://inference-api.nousresearch.com/v1"
|
||||
)
|
||||
|
||||
MODEL = "gemini-2.5-flash"
|
||||
|
||||
async def test_temp(temp_value):
|
||||
"""Test a specific temperature value."""
|
||||
content = "Test content. " * 1000 # 14k chars
|
||||
|
||||
print(f"Testing temperature={temp_value}...", end=" ")
|
||||
|
||||
try:
|
||||
response = await nous_client.chat.completions.create(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": "Summarize this content."},
|
||||
{"role": "user", "content": content}
|
||||
],
|
||||
temperature=temp_value,
|
||||
max_tokens=4000
|
||||
)
|
||||
print(f"✅ SUCCESS")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ FAILED")
|
||||
return False
|
||||
|
||||
async def main():
|
||||
print("Testing temperature threshold for Nous API...")
|
||||
print("="*60)
|
||||
|
||||
temps = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 1.0]
|
||||
|
||||
for temp in temps:
|
||||
await test_temp(temp)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
print("="*60)
|
||||
print("\nNow testing with ACTUAL web_tools.py content and parameters:")
|
||||
print("="*60)
|
||||
|
||||
# Simulate the actual web_tools.py call
|
||||
system_prompt = """You are an expert content analyst. Your job is to process web content and create a comprehensive yet concise summary that preserves all important information while dramatically reducing bulk.
|
||||
|
||||
Create a well-structured markdown summary that includes:
|
||||
1. Key excerpts (quotes, code snippets, important facts) in their original format
|
||||
2. Comprehensive summary of all other important information
|
||||
3. Proper markdown formatting with headers, bullets, and emphasis
|
||||
|
||||
Your goal is to preserve ALL important information while reducing length. Never lose key facts, figures, insights, or actionable information. Make it scannable and well-organized."""
|
||||
|
||||
content = "Sample web page content. " * 3000 # ~75k chars like the real failures
|
||||
|
||||
user_prompt = f"""Please process this web content and create a comprehensive markdown summary:
|
||||
|
||||
CONTENT TO PROCESS:
|
||||
{content}
|
||||
|
||||
Create a markdown summary that captures all key information in a well-organized, scannable format. Include important quotes and code snippets in their original formatting. Focus on actionable information, specific details, and unique insights."""
|
||||
|
||||
print(f"\nActual web_tools call (temp=0.1, {len(content):,} chars)...", end=" ")
|
||||
try:
|
||||
response = await nous_client.chat.completions.create(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
temperature=0.1,
|
||||
max_tokens=4000
|
||||
)
|
||||
print(f"✅ SUCCESS")
|
||||
except:
|
||||
print(f"❌ FAILED")
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
print(f"Same call but with temp=0.3...", end=" ")
|
||||
try:
|
||||
response = await nous_client.chat.completions.create(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
temperature=0.3,
|
||||
max_tokens=4000
|
||||
)
|
||||
print(f"✅ SUCCESS")
|
||||
except:
|
||||
print(f"❌ FAILED")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -414,7 +414,7 @@ async def image_generate_tool(
|
||||
_log_debug_call("image_generate_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps(response_data, indent=2)
|
||||
return json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
generation_time = (datetime.datetime.now() - start_time).total_seconds()
|
||||
@@ -432,7 +432,7 @@ async def image_generate_tool(
|
||||
_log_debug_call("image_generate_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps(response_data, indent=2)
|
||||
return json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_fal_api_key() -> bool:
|
||||
|
||||
@@ -161,11 +161,11 @@ def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> st
|
||||
|
||||
|
||||
async def _run_reference_model_safe(
|
||||
model: str,
|
||||
user_prompt: str,
|
||||
model: str,
|
||||
user_prompt: str,
|
||||
temperature: float = REFERENCE_TEMPERATURE,
|
||||
max_tokens: int = 32000,
|
||||
max_retries: int = 3
|
||||
max_retries: int = 6
|
||||
) -> tuple[str, str, bool]:
|
||||
"""
|
||||
Run a single reference model with retry logic and graceful failure handling.
|
||||
@@ -212,8 +212,8 @@ async def _run_reference_model_safe(
|
||||
print(f"⚠️ {model} unknown error (attempt {attempt + 1}): {error_str}")
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
# Exponential backoff for rate limiting
|
||||
sleep_time = 2 ** attempt
|
||||
# Exponential backoff for rate limiting: 2s, 4s, 8s, 16s, 32s, 60s
|
||||
sleep_time = min(2 ** (attempt + 1), 60)
|
||||
print(f" Retrying in {sleep_time}s...")
|
||||
await asyncio.sleep(sleep_time)
|
||||
else:
|
||||
@@ -410,7 +410,7 @@ async def mixture_of_agents_tool(
|
||||
_log_debug_call("mixture_of_agents_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps(result, indent=2)
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error in MoA processing: {str(e)}"
|
||||
@@ -436,7 +436,7 @@ async def mixture_of_agents_tool(
|
||||
_log_debug_call("mixture_of_agents_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps(result, indent=2)
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_nous_api_key() -> bool:
|
||||
|
||||
395
tools/simple_terminal_tool.py
Normal file
395
tools/simple_terminal_tool.py
Normal file
@@ -0,0 +1,395 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple Terminal Tool Module
|
||||
|
||||
A simplified terminal tool that executes commands on MorphCloud VMs without tmux.
|
||||
No session persistence, no interactive app support - just simple command execution.
|
||||
|
||||
Features:
|
||||
- Direct SSH command execution
|
||||
- Background task support
|
||||
- VM lifecycle management with TTL
|
||||
- Automatic cleanup after inactivity
|
||||
|
||||
Usage:
|
||||
from simple_terminal_tool import simple_terminal_tool
|
||||
|
||||
# Execute a simple command
|
||||
result = simple_terminal_tool("ls -la")
|
||||
|
||||
# Execute in background
|
||||
result = simple_terminal_tool("python server.py", background=True)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
import atexit
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
# Tool description for LLM
|
||||
SIMPLE_TERMINAL_TOOL_DESCRIPTION = """Execute commands on a secure Linux VM environment.
|
||||
|
||||
**Environment:**
|
||||
- Minimal Debian-based OS with internet access
|
||||
- Automatic VM lifecycle management (creates on-demand, reuses, cleans up)
|
||||
- Filesystem is persisted between tool calls but environment variables, venvs, etc are reset.
|
||||
|
||||
**Command Execution:**
|
||||
- Simple commands: Just provide the 'command' parameter
|
||||
- Background processes: Set 'background': True for servers/long-running tasks
|
||||
- Command timeout: Optional 'timeout' parameter in seconds
|
||||
|
||||
**Examples:**
|
||||
- Run command: `{"command": "ls -la"}`
|
||||
- Background task: `{"command": "source path/to/my/venv/bin/activate && python server.py", "background": True}`
|
||||
- With timeout: `{"command": "long_task.sh", "timeout": 300}`
|
||||
|
||||
**Best Practices:**
|
||||
- Run servers/long processes in background
|
||||
- Monitor disk usage for large tasks
|
||||
- Install whatever tools you need with sudo apt-get
|
||||
- Do not be afraid to run pip with --break-system-packages
|
||||
|
||||
**Things to avoid**
|
||||
- Do NOT use interactive tools such as tmux, vim, nano, python repl - you will get stuck. Even git sometimes becomes interactive if the output is large. If you're not sure pipe to cat.
|
||||
"""
|
||||
|
||||
# Global state for VM lifecycle management
|
||||
_active_instances: Dict[str, Any] = {}
|
||||
_last_activity: Dict[str, float] = {}
|
||||
_instance_lock = threading.Lock()
|
||||
_cleanup_thread = None
|
||||
_cleanup_running = False
|
||||
|
||||
|
||||
def _cleanup_inactive_vms(vm_lifetime_seconds: int = 300):
|
||||
"""Clean up VMs that have been inactive for longer than vm_lifetime_seconds."""
|
||||
global _active_instances, _last_activity
|
||||
|
||||
current_time = time.time()
|
||||
tasks_to_cleanup = []
|
||||
|
||||
with _instance_lock:
|
||||
for task_id, last_time in list(_last_activity.items()):
|
||||
if current_time - last_time > vm_lifetime_seconds:
|
||||
tasks_to_cleanup.append(task_id)
|
||||
|
||||
for task_id in tasks_to_cleanup:
|
||||
try:
|
||||
if task_id in _active_instances:
|
||||
instance = _active_instances[task_id]
|
||||
if hasattr(instance, 'terminate'):
|
||||
instance.terminate()
|
||||
elif hasattr(instance, 'stop'):
|
||||
instance.stop()
|
||||
elif hasattr(instance, 'delete'):
|
||||
instance.delete()
|
||||
|
||||
del _active_instances[task_id]
|
||||
print(f"[VM Cleanup] Terminated inactive VM for task: {task_id}")
|
||||
|
||||
if task_id in _last_activity:
|
||||
del _last_activity[task_id]
|
||||
|
||||
except Exception as e:
|
||||
# 404 errors are benign - VM already cleaned up by TTL
|
||||
error_str = str(e)
|
||||
if "404" in error_str or "InstanceNotFoundError" in error_str or "not found" in error_str.lower():
|
||||
print(f"[VM Cleanup] VM for task {task_id} already cleaned up (likely TTL expiration)")
|
||||
else:
|
||||
print(f"[VM Cleanup] Error cleaning up VM for task {task_id}: {e}")
|
||||
|
||||
|
||||
def _cleanup_thread_worker():
|
||||
"""Background thread worker that periodically cleans up inactive VMs."""
|
||||
global _cleanup_running
|
||||
|
||||
while _cleanup_running:
|
||||
try:
|
||||
vm_lifetime = int(os.getenv("HECATE_VM_LIFETIME_SECONDS", "300"))
|
||||
_cleanup_inactive_vms(vm_lifetime)
|
||||
except Exception as e:
|
||||
print(f"[VM Cleanup] Error in cleanup thread: {e}")
|
||||
|
||||
for _ in range(60):
|
||||
if not _cleanup_running:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def _start_cleanup_thread():
|
||||
"""Start the background cleanup thread if not already running."""
|
||||
global _cleanup_thread, _cleanup_running
|
||||
|
||||
with _instance_lock:
|
||||
if _cleanup_thread is None or not _cleanup_thread.is_alive():
|
||||
_cleanup_running = True
|
||||
_cleanup_thread = threading.Thread(target=_cleanup_thread_worker, daemon=True)
|
||||
_cleanup_thread.start()
|
||||
|
||||
|
||||
def _stop_cleanup_thread():
|
||||
"""Stop the background cleanup thread."""
|
||||
global _cleanup_running
|
||||
_cleanup_running = False
|
||||
if _cleanup_thread is not None:
|
||||
_cleanup_thread.join(timeout=5)
|
||||
|
||||
|
||||
def cleanup_vm(task_id: str):
|
||||
"""Manually clean up a specific VM by task_id."""
|
||||
global _active_instances, _last_activity
|
||||
|
||||
with _instance_lock:
|
||||
try:
|
||||
if task_id in _active_instances:
|
||||
instance = _active_instances[task_id]
|
||||
if hasattr(instance, 'terminate'):
|
||||
instance.terminate()
|
||||
elif hasattr(instance, 'stop'):
|
||||
instance.stop()
|
||||
elif hasattr(instance, 'delete'):
|
||||
instance.delete()
|
||||
|
||||
del _active_instances[task_id]
|
||||
print(f"[VM Cleanup] Manually terminated VM for task: {task_id}")
|
||||
|
||||
if task_id in _last_activity:
|
||||
del _last_activity[task_id]
|
||||
|
||||
except Exception as e:
|
||||
# 404 errors are benign - VM already cleaned up by TTL
|
||||
error_str = str(e)
|
||||
if "404" in error_str or "InstanceNotFoundError" in error_str or "not found" in error_str.lower():
|
||||
print(f"[VM Cleanup] VM for task {task_id} already cleaned up (likely TTL expiration)")
|
||||
else:
|
||||
print(f"[VM Cleanup] Error manually cleaning up VM for task {task_id}: {e}")
|
||||
|
||||
|
||||
atexit.register(_stop_cleanup_thread)
|
||||
|
||||
|
||||
def _execute_ssh_command(instance, command: str, timeout: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute a command via SSH on the VM instance.
|
||||
|
||||
Args:
|
||||
instance: MorphVM instance
|
||||
command: Command to execute
|
||||
timeout: Optional timeout in seconds
|
||||
|
||||
Returns:
|
||||
dict with stdout, stderr, returncode
|
||||
"""
|
||||
ssh_context_manager = None
|
||||
try:
|
||||
# Use the instance's SSH context manager
|
||||
ssh_context_manager = instance.ssh()
|
||||
ssh_context = ssh_context_manager.__enter__()
|
||||
|
||||
# Execute the command
|
||||
result = ssh_context.run(command, get_pty=False, timeout=timeout or 120)
|
||||
|
||||
# Close the SSH connection
|
||||
if ssh_context_manager:
|
||||
try:
|
||||
ssh_context_manager.__exit__(None, None, None)
|
||||
except:
|
||||
pass
|
||||
|
||||
return {
|
||||
"stdout": result.stdout or "",
|
||||
"stderr": result.stderr or "",
|
||||
"returncode": result.returncode
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# Close connection on error
|
||||
if ssh_context_manager:
|
||||
try:
|
||||
ssh_context_manager.__exit__(None, None, None)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check if it's a timeout
|
||||
error_str = str(e).lower()
|
||||
if "timeout" in error_str:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"Command timed out after {timeout or 120} seconds",
|
||||
"returncode": 124
|
||||
}
|
||||
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"SSH execution failed: {str(e)}",
|
||||
"returncode": -1
|
||||
}
|
||||
|
||||
|
||||
def simple_terminal_tool(
|
||||
command: str,
|
||||
background: bool = False,
|
||||
timeout: Optional[int] = None,
|
||||
task_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Execute a command on a MorphCloud VM without session persistence.
|
||||
|
||||
Args:
|
||||
command: The command to execute
|
||||
background: Whether to run in background (default: False)
|
||||
timeout: Command timeout in seconds (default: 120)
|
||||
task_id: Unique identifier for VM isolation (optional)
|
||||
|
||||
Returns:
|
||||
str: JSON string with output, exit_code, and error fields
|
||||
|
||||
Examples:
|
||||
# Execute a simple command
|
||||
>>> result = simple_terminal_tool(command="ls -la /tmp")
|
||||
|
||||
# Run a background task
|
||||
>>> result = simple_terminal_tool(command="python server.py", background=True)
|
||||
|
||||
# With custom timeout
|
||||
>>> result = simple_terminal_tool(command="long_task.sh", timeout=300)
|
||||
"""
|
||||
global _active_instances, _last_activity
|
||||
|
||||
try:
|
||||
# Import required modules
|
||||
try:
|
||||
from morphcloud.api import MorphCloudClient
|
||||
except ImportError as import_error:
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": f"Terminal tool disabled: {import_error}",
|
||||
"status": "disabled"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# Get configuration
|
||||
vm_ttl_seconds = int(os.getenv("HECATE_VM_TTL_SECONDS", "1200"))
|
||||
snapshot_id = os.getenv("HECATE_DEFAULT_SNAPSHOT_ID", "snapshot_defv9tjg")
|
||||
|
||||
# Check API key
|
||||
morph_api_key = os.getenv("MORPH_API_KEY")
|
||||
if not morph_api_key:
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": "MORPH_API_KEY environment variable not set",
|
||||
"status": "disabled"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# Use task_id for VM isolation
|
||||
effective_task_id = task_id or "default"
|
||||
|
||||
# Start cleanup thread
|
||||
_start_cleanup_thread()
|
||||
|
||||
# Get or create VM instance
|
||||
with _instance_lock:
|
||||
if effective_task_id not in _active_instances:
|
||||
morph_client = MorphCloudClient(api_key=morph_api_key)
|
||||
_active_instances[effective_task_id] = morph_client.instances.start(
|
||||
snapshot_id=snapshot_id,
|
||||
ttl_seconds=vm_ttl_seconds,
|
||||
ttl_action="stop"
|
||||
)
|
||||
|
||||
# Update last activity time
|
||||
_last_activity[effective_task_id] = time.time()
|
||||
instance = _active_instances[effective_task_id]
|
||||
|
||||
# Wait for instance to be ready
|
||||
instance.wait_until_ready()
|
||||
|
||||
# Prepare command for execution
|
||||
if background:
|
||||
# Run in background with nohup and redirect output
|
||||
exec_command = f"nohup {command} > /tmp/bg_output.log 2>&1 &"
|
||||
result = _execute_ssh_command(instance, exec_command, timeout=10)
|
||||
|
||||
# For background tasks, return immediately with info
|
||||
if result["returncode"] == 0:
|
||||
return json.dumps({
|
||||
"output": "Background task started successfully",
|
||||
"exit_code": 0,
|
||||
"error": None
|
||||
}, ensure_ascii=False)
|
||||
else:
|
||||
return json.dumps({
|
||||
"output": result["stdout"],
|
||||
"exit_code": result["returncode"],
|
||||
"error": result["stderr"]
|
||||
}, ensure_ascii=False)
|
||||
else:
|
||||
# Run foreground command
|
||||
result = _execute_ssh_command(instance, command, timeout=timeout)
|
||||
|
||||
# Combine stdout and stderr for output
|
||||
output = result["stdout"]
|
||||
if result["stderr"] and result["returncode"] != 0:
|
||||
output = f"{output}\n{result['stderr']}" if output else result["stderr"]
|
||||
|
||||
return json.dumps({
|
||||
"output": output.strip(),
|
||||
"exit_code": result["returncode"],
|
||||
"error": result["stderr"] if result["returncode"] != 0 else None
|
||||
}, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": f"Failed to execute command: {str(e)}",
|
||||
"status": "error"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_requirements() -> bool:
|
||||
"""Check if all requirements for the simple terminal tool are met."""
|
||||
required_vars = ["MORPH_API_KEY"]
|
||||
missing_required = [var for var in required_vars if not os.getenv(var)]
|
||||
|
||||
if missing_required:
|
||||
print(f"Missing required environment variables: {', '.join(missing_required)}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from morphcloud.api import MorphCloudClient
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"MorphCloud not available: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Simple test when run directly."""
|
||||
print("Simple Terminal Tool Module")
|
||||
print("=" * 40)
|
||||
|
||||
if not check_requirements():
|
||||
print("Requirements not met. Please check the messages above.")
|
||||
exit(1)
|
||||
|
||||
print("All requirements met!")
|
||||
print("\nAvailable Tool:")
|
||||
print(" - simple_terminal_tool: Execute commands without session persistence")
|
||||
|
||||
print("\nUsage Examples:")
|
||||
print(" # Execute a command")
|
||||
print(" result = simple_terminal_tool(command='ls -la')")
|
||||
print(" ")
|
||||
print(" # Run a background task")
|
||||
print(" result = simple_terminal_tool(command='python server.py', background=True)")
|
||||
|
||||
print("\nEnvironment Variables:")
|
||||
print(f" MORPH_API_KEY: {'Set' if os.getenv('MORPH_API_KEY') else 'Not set'}")
|
||||
print(f" HECATE_VM_TTL_SECONDS: {os.getenv('HECATE_VM_TTL_SECONDS', '1200')} (default: 1200 / 20 minutes)")
|
||||
print(f" HECATE_VM_LIFETIME_SECONDS: {os.getenv('HECATE_VM_LIFETIME_SECONDS', '300')} (default: 300 / 5 minutes)")
|
||||
print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_defv9tjg')}")
|
||||
@@ -4,8 +4,12 @@ Terminal Tool Module
|
||||
|
||||
This module provides a single terminal tool using Hecate's VM infrastructure.
|
||||
It wraps Hecate's functionality to provide a simple interface for executing commands
|
||||
on Morph VMs with automatic lifecycle management. VMs live for 5 minutes after last use.
|
||||
Timer resets with each use.
|
||||
on Morph VMs with automatic lifecycle management.
|
||||
|
||||
VM Lifecycle:
|
||||
- VMs have a TTL (time to live) set at creation (default: 20 minutes)
|
||||
- VMs are also cleaned up locally after 5 minutes of inactivity
|
||||
- Timer resets with each use
|
||||
|
||||
Available tool:
|
||||
- terminal_tool: Execute commands with optional interactive session support
|
||||
@@ -24,6 +28,8 @@ import json
|
||||
import os
|
||||
import uuid
|
||||
import threading
|
||||
import time
|
||||
import atexit
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
# Detailed description for the terminal tool based on Hermes Terminal system prompt
|
||||
@@ -75,9 +81,137 @@ When commands enter interactive mode (vim, nano, less, git prompts, package mana
|
||||
|
||||
# Global state for VM lifecycle management
|
||||
# These persist across tool calls to enable session continuity
|
||||
_active_instance = None
|
||||
_active_context = None
|
||||
# Changed to dictionaries keyed by task_id to prevent leakage between concurrent tasks
|
||||
_active_instances: Dict[str, Any] = {}
|
||||
_active_contexts: Dict[str, Any] = {}
|
||||
_last_activity: Dict[str, float] = {} # Track last activity time for each VM
|
||||
_instance_lock = threading.Lock()
|
||||
_cleanup_thread = None
|
||||
_cleanup_running = False
|
||||
|
||||
def _cleanup_inactive_vms(vm_lifetime_seconds: int = 300):
|
||||
"""
|
||||
Clean up VMs that have been inactive for longer than vm_lifetime_seconds.
|
||||
This function should be called periodically by a background thread.
|
||||
|
||||
Args:
|
||||
vm_lifetime_seconds: Maximum lifetime in seconds for inactive VMs (default: 300)
|
||||
"""
|
||||
global _active_instances, _active_contexts, _last_activity
|
||||
|
||||
current_time = time.time()
|
||||
tasks_to_cleanup = []
|
||||
|
||||
with _instance_lock:
|
||||
# Find all VMs that have been inactive for too long
|
||||
for task_id, last_time in list(_last_activity.items()):
|
||||
if current_time - last_time > vm_lifetime_seconds:
|
||||
tasks_to_cleanup.append(task_id)
|
||||
|
||||
# Clean up the inactive VMs
|
||||
for task_id in tasks_to_cleanup:
|
||||
try:
|
||||
if task_id in _active_instances:
|
||||
instance = _active_instances[task_id]
|
||||
# Terminate the VM instance
|
||||
if hasattr(instance, 'terminate'):
|
||||
instance.terminate()
|
||||
elif hasattr(instance, 'stop'):
|
||||
instance.stop()
|
||||
elif hasattr(instance, 'delete'):
|
||||
instance.delete()
|
||||
|
||||
# Remove from tracking dictionaries
|
||||
del _active_instances[task_id]
|
||||
print(f"[VM Cleanup] Terminated inactive VM for task: {task_id}")
|
||||
|
||||
if task_id in _active_contexts:
|
||||
del _active_contexts[task_id]
|
||||
|
||||
if task_id in _last_activity:
|
||||
del _last_activity[task_id]
|
||||
|
||||
except Exception as e:
|
||||
print(f"[VM Cleanup] Error cleaning up VM for task {task_id}: {e}")
|
||||
|
||||
def _cleanup_thread_worker():
|
||||
"""
|
||||
Background thread worker that periodically cleans up inactive VMs.
|
||||
Runs every 60 seconds.
|
||||
"""
|
||||
global _cleanup_running
|
||||
|
||||
while _cleanup_running:
|
||||
try:
|
||||
vm_lifetime = int(os.getenv("HECATE_VM_LIFETIME_SECONDS", "300"))
|
||||
_cleanup_inactive_vms(vm_lifetime)
|
||||
except Exception as e:
|
||||
print(f"[VM Cleanup] Error in cleanup thread: {e}")
|
||||
|
||||
# Sleep for 60 seconds, but check every second if we should stop
|
||||
for _ in range(60):
|
||||
if not _cleanup_running:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
def _start_cleanup_thread():
|
||||
"""
|
||||
Start the background cleanup thread if it's not already running.
|
||||
"""
|
||||
global _cleanup_thread, _cleanup_running
|
||||
|
||||
with _instance_lock:
|
||||
if _cleanup_thread is None or not _cleanup_thread.is_alive():
|
||||
_cleanup_running = True
|
||||
_cleanup_thread = threading.Thread(target=_cleanup_thread_worker, daemon=True)
|
||||
_cleanup_thread.start()
|
||||
|
||||
def _stop_cleanup_thread():
|
||||
"""
|
||||
Stop the background cleanup thread.
|
||||
"""
|
||||
global _cleanup_running
|
||||
_cleanup_running = False
|
||||
if _cleanup_thread is not None:
|
||||
_cleanup_thread.join(timeout=5)
|
||||
|
||||
def cleanup_vm(task_id: str):
|
||||
"""
|
||||
Manually clean up a specific VM by task_id.
|
||||
This should be called when a task is completed.
|
||||
|
||||
Args:
|
||||
task_id: The task ID of the VM to clean up
|
||||
"""
|
||||
global _active_instances, _active_contexts, _last_activity
|
||||
|
||||
with _instance_lock:
|
||||
try:
|
||||
if task_id in _active_instances:
|
||||
instance = _active_instances[task_id]
|
||||
# Terminate the VM instance
|
||||
if hasattr(instance, 'terminate'):
|
||||
instance.terminate()
|
||||
elif hasattr(instance, 'stop'):
|
||||
instance.stop()
|
||||
elif hasattr(instance, 'delete'):
|
||||
instance.delete()
|
||||
|
||||
# Remove from tracking dictionaries
|
||||
del _active_instances[task_id]
|
||||
print(f"[VM Cleanup] Manually terminated VM for task: {task_id}")
|
||||
|
||||
if task_id in _active_contexts:
|
||||
del _active_contexts[task_id]
|
||||
|
||||
if task_id in _last_activity:
|
||||
del _last_activity[task_id]
|
||||
|
||||
except Exception as e:
|
||||
print(f"[VM Cleanup] Error manually cleaning up VM for task {task_id}: {e}")
|
||||
|
||||
# Register cleanup on program exit
|
||||
atexit.register(_stop_cleanup_thread)
|
||||
|
||||
def terminal_tool(
|
||||
command: Optional[str] = None,
|
||||
@@ -85,23 +219,25 @@ def terminal_tool(
|
||||
session_id: Optional[str] = None,
|
||||
background: bool = False,
|
||||
idle_threshold: float = 5.0,
|
||||
timeout: Optional[int] = None
|
||||
timeout: Optional[int] = None,
|
||||
task_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Execute a command on a Morph VM with optional interactive session support.
|
||||
|
||||
|
||||
This tool uses Hecate's VM lifecycle management to automatically create
|
||||
and manage VMs. VMs are reused within the configured lifetime window
|
||||
and automatically cleaned up after inactivity.
|
||||
|
||||
|
||||
Args:
|
||||
command: The command to execute (optional if continuing existing session)
|
||||
input_keys: Keystrokes to send to interactive session (e.g., "hello\\n")
|
||||
session_id: ID of existing session to continue (optional)
|
||||
background: Whether to run the command in the background (default: False)
|
||||
background: Whether to run the command in the background (default: False)
|
||||
idle_threshold: Seconds to wait for output before considering session idle (default: 5.0)
|
||||
timeout: Command timeout in seconds (optional)
|
||||
|
||||
task_id: Unique identifier for this task to isolate VMs between concurrent tasks (optional)
|
||||
|
||||
Returns:
|
||||
str: JSON string containing command output, session info, exit code, and any errors
|
||||
|
||||
@@ -120,7 +256,7 @@ def terminal_tool(
|
||||
# Run a background task
|
||||
>>> result = terminal_tool(command="sleep 60", background=True)
|
||||
"""
|
||||
global _active_instance, _active_context
|
||||
global _active_instances, _active_contexts
|
||||
|
||||
try:
|
||||
# Import required modules lazily so this module can be imported
|
||||
@@ -135,15 +271,16 @@ def terminal_tool(
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"screen": "",
|
||||
"session_id": None,
|
||||
"exit_code": -1,
|
||||
"error": f"Terminal tool is disabled due to import error: {import_error}",
|
||||
"status": "disabled"
|
||||
})
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
# Get configuration from environment
|
||||
vm_lifetime_seconds = int(os.getenv("HECATE_VM_LIFETIME_SECONDS", "300"))
|
||||
snapshot_id = os.getenv("HECATE_DEFAULT_SNAPSHOT_ID", "python-2025-10-31")
|
||||
vm_ttl_seconds = int(os.getenv("HECATE_VM_TTL_SECONDS", "1200")) # 20 minutes default
|
||||
snapshot_id = os.getenv("HECATE_DEFAULT_SNAPSHOT_ID", "snapshot_defv9tjg")
|
||||
|
||||
# Check API key
|
||||
morph_api_key = os.getenv("MORPH_API_KEY")
|
||||
@@ -151,25 +288,38 @@ def terminal_tool(
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"screen": "",
|
||||
"session_id": None,
|
||||
"exit_code": -1,
|
||||
"error": "MORPH_API_KEY environment variable not set",
|
||||
"status": "disabled"
|
||||
})
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# Get or create VM instance and execution context
|
||||
# Use task_id to isolate VMs between concurrent tasks
|
||||
# If no task_id provided, use "default" for backward compatibility
|
||||
effective_task_id = task_id or "default"
|
||||
|
||||
# Start the cleanup thread if not already running
|
||||
_start_cleanup_thread()
|
||||
|
||||
# Get or create VM instance and execution context per task
|
||||
# This is critical for interactive session support - the context must persist!
|
||||
with _instance_lock:
|
||||
if _active_instance is None:
|
||||
if effective_task_id not in _active_instances:
|
||||
morph_client = MorphCloudClient(api_key=morph_api_key)
|
||||
_active_instance = morph_client.instances.start(snapshot_id=snapshot_id)
|
||||
_active_instances[effective_task_id] = morph_client.instances.start(
|
||||
snapshot_id=snapshot_id,
|
||||
ttl_seconds=vm_ttl_seconds,
|
||||
ttl_action="stop"
|
||||
)
|
||||
|
||||
# Get or create persistent execution context
|
||||
if _active_context is None:
|
||||
_active_context = ExecutionContext()
|
||||
# Get or create persistent execution context per task
|
||||
if effective_task_id not in _active_contexts:
|
||||
_active_contexts[effective_task_id] = ExecutionContext()
|
||||
|
||||
instance = _active_instance
|
||||
ctx = _active_context
|
||||
# Update last activity time for this VM (resets the inactivity timer)
|
||||
_last_activity[effective_task_id] = time.time()
|
||||
|
||||
instance = _active_instances[effective_task_id]
|
||||
ctx = _active_contexts[effective_task_id]
|
||||
|
||||
# Build tool input based on provided parameters
|
||||
tool_input = {}
|
||||
@@ -208,28 +358,25 @@ def terminal_tool(
|
||||
ctx=ctx
|
||||
)
|
||||
|
||||
# Format the result with all possible fields
|
||||
# Format the result with only essential fields for the LLM
|
||||
# Map hecate's "stdout" to "output" for compatibility
|
||||
formatted_result = {
|
||||
"output": result.get("stdout", result.get("output", "")),
|
||||
"screen": result.get("screen", ""),
|
||||
"session_id": result.get("session_id"),
|
||||
"exit_code": result.get("returncode", result.get("exit_code", -1)),
|
||||
"error": result.get("error"),
|
||||
"status": "active" if result.get("session_id") else "ended"
|
||||
"error": result.get("error")
|
||||
}
|
||||
|
||||
return json.dumps(formatted_result)
|
||||
return json.dumps(formatted_result, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"screen": "",
|
||||
"session_id": None,
|
||||
"exit_code": -1,
|
||||
"error": f"Failed to execute terminal command: {str(e)}",
|
||||
"status": "error"
|
||||
})
|
||||
}, ensure_ascii=False)
|
||||
|
||||
def check_hecate_requirements() -> bool:
|
||||
"""
|
||||
@@ -304,5 +451,6 @@ if __name__ == "__main__":
|
||||
print("\nEnvironment Variables:")
|
||||
print(f" MORPH_API_KEY: {'Set' if os.getenv('MORPH_API_KEY') else 'Not set'}")
|
||||
print(f" OPENAI_API_KEY: {'Set' if os.getenv('OPENAI_API_KEY') else 'Not set (optional)'}")
|
||||
print(f" HECATE_VM_LIFETIME_SECONDS: {os.getenv('HECATE_VM_LIFETIME_SECONDS', '300')} (default: 300)")
|
||||
print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_p5294qxt')} (default: snapshot_p5294qxt)")
|
||||
print(f" HECATE_VM_TTL_SECONDS: {os.getenv('HECATE_VM_TTL_SECONDS', '1200')} (default: 1200 / 20 minutes)")
|
||||
print(f" HECATE_VM_LIFETIME_SECONDS: {os.getenv('HECATE_VM_LIFETIME_SECONDS', '300')} (default: 300 / 5 minutes)")
|
||||
print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_defv9tjg')} (default: snapshot_defv9tjg)")
|
||||
|
||||
@@ -346,7 +346,7 @@ async def vision_analyze_tool(
|
||||
_log_debug_call("vision_analyze_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps(result, indent=2)
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error analyzing image: {str(e)}"
|
||||
@@ -362,7 +362,7 @@ async def vision_analyze_tool(
|
||||
_log_debug_call("vision_analyze_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps(result, indent=2)
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
finally:
|
||||
# Clean up temporary image file
|
||||
|
||||
@@ -183,16 +183,33 @@ Your goal is to preserve ALL important information while reducing length. Never
|
||||
|
||||
Create a markdown summary that captures all key information in a well-organized, scannable format. Include important quotes and code snippets in their original formatting. Focus on actionable information, specific details, and unique insights."""
|
||||
|
||||
# Call the LLM asynchronously
|
||||
response = await nous_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
temperature=0.1, # Low temperature for consistent extraction
|
||||
max_tokens=4000 # Generous limit for comprehensive processing
|
||||
)
|
||||
# Call the LLM asynchronously with retry logic for flaky API
|
||||
max_retries = 6
|
||||
retry_delay = 2 # Start with 2 seconds
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = await nous_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
temperature=0.1, # Low temperature for consistent extraction
|
||||
max_tokens=4000 # Generous limit for comprehensive processing
|
||||
)
|
||||
break # Success, exit retry loop
|
||||
except Exception as api_error:
|
||||
last_error = api_error
|
||||
if attempt < max_retries - 1:
|
||||
print(f"⚠️ LLM API call failed (attempt {attempt + 1}/{max_retries}): {str(api_error)[:100]}")
|
||||
print(f" Retrying in {retry_delay}s...")
|
||||
await asyncio.sleep(retry_delay)
|
||||
retry_delay = min(retry_delay * 2, 60) # Exponential backoff: 2s, 4s, 8s, 16s, 32s, 60s
|
||||
else:
|
||||
# All retries exhausted
|
||||
raise last_error
|
||||
|
||||
# Get the markdown response directly
|
||||
processed_content = response.choices[0].message.content.strip()
|
||||
@@ -344,7 +361,7 @@ def web_search_tool(query: str, limit: int = 5) -> str:
|
||||
debug_call_data["results_count"] = results_count
|
||||
|
||||
# Convert to JSON
|
||||
result_json = json.dumps(response_data, indent=2)
|
||||
result_json = json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
|
||||
debug_call_data["final_response_size"] = len(result_json)
|
||||
|
||||
@@ -362,7 +379,7 @@ def web_search_tool(query: str, limit: int = 5) -> str:
|
||||
_log_debug_call("web_search_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps({"error": error_msg})
|
||||
return json.dumps({"error": error_msg}, ensure_ascii=False)
|
||||
|
||||
|
||||
async def web_extract_tool(
|
||||
@@ -575,18 +592,20 @@ async def web_extract_tool(
|
||||
"title": r.get("title", ""),
|
||||
"content": r.get("content", ""),
|
||||
"error": r.get("error"),
|
||||
**({"llm_model": model} if use_llm_processing else {})
|
||||
}
|
||||
for r in response.get("results", [])
|
||||
]
|
||||
trimmed_response = {"results": trimmed_results}
|
||||
# Include model name used for summarization when LLM processing was requested
|
||||
if use_llm_processing:
|
||||
trimmed_response["llm_model"] = model
|
||||
|
||||
if trimmed_response.get("results") == []:
|
||||
result_json = json.dumps({"error": "Content was inaccessible or not found"}, ensure_ascii=False)
|
||||
|
||||
cleaned_result = clean_base64_images(result_json)
|
||||
|
||||
result_json = json.dumps(trimmed_response, indent=2)
|
||||
# Clean base64 images from extracted content
|
||||
cleaned_result = clean_base64_images(result_json)
|
||||
else:
|
||||
result_json = json.dumps(trimmed_response, indent=2, ensure_ascii=False)
|
||||
|
||||
cleaned_result = clean_base64_images(result_json)
|
||||
|
||||
debug_call_data["final_response_size"] = len(cleaned_result)
|
||||
debug_call_data["processing_applied"].append("base64_image_removal")
|
||||
@@ -605,7 +624,7 @@ async def web_extract_tool(
|
||||
_log_debug_call("web_extract_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps({"error": error_msg})
|
||||
return json.dumps({"error": error_msg}, ensure_ascii=False)
|
||||
|
||||
|
||||
async def web_crawl_tool(
|
||||
@@ -851,17 +870,13 @@ async def web_crawl_tool(
|
||||
{
|
||||
"title": r.get("title", ""),
|
||||
"content": r.get("content", ""),
|
||||
"error": r.get("error"),
|
||||
**({"llm_model": model} if use_llm_processing else {})
|
||||
"error": r.get("error")
|
||||
}
|
||||
for r in response.get("results", [])
|
||||
]
|
||||
trimmed_response = {"results": trimmed_results}
|
||||
# Include model name used for summarization when LLM processing was requested
|
||||
if use_llm_processing:
|
||||
trimmed_response["llm_model"] = model
|
||||
|
||||
result_json = json.dumps(trimmed_response, indent=2)
|
||||
result_json = json.dumps(trimmed_response, indent=2, ensure_ascii=False)
|
||||
# Clean base64 images from crawled content
|
||||
cleaned_result = clean_base64_images(result_json)
|
||||
|
||||
@@ -882,7 +897,7 @@ async def web_crawl_tool(
|
||||
_log_debug_call("web_crawl_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps({"error": error_msg})
|
||||
return json.dumps({"error": error_msg}, ensure_ascii=False)
|
||||
|
||||
|
||||
# Convenience function to check if API key is available
|
||||
|
||||
@@ -61,7 +61,19 @@ DISTRIBUTIONS = {
|
||||
"terminal": 10 # 10% chance of terminal tools
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
# Scientific problem solving focused distribution
|
||||
"science": {
|
||||
"description": "Web research with vision analysis and reasoning",
|
||||
"toolsets": {
|
||||
"web": 94, # 90% chance of web tools
|
||||
"vision": 65, # 50% chance of vision tools
|
||||
"moa": 10, # 40% chance of reasoning tools
|
||||
"terminal": 94, # 10% chance of terminal tools
|
||||
"image_gen": 15 # 80% chance of image generation tools
|
||||
}
|
||||
},
|
||||
|
||||
# Development-focused distribution
|
||||
"development": {
|
||||
"description": "Terminal and reasoning with occasional web lookup",
|
||||
|
||||
Reference in New Issue
Block a user