Compare commits

...

1 Commit

Author     SHA1        Message                            Date
hjc-puro   31c733383b  add tracking for cluster failures  2025-11-15 00:01:19 -05:00
2 changed files with 409 additions and 116 deletions

batch_runner.py View File

@@ -9,6 +9,8 @@ across multiple prompts from a dataset. It includes:
 - Checkpointing for fault tolerance and resumption
 - Trajectory saving in the proper format (from/value pairs)
 - Tool usage statistics aggregation across all batches
+- Cluster failure detection and graceful shutdown (morph, firecrawl, API errors)
+- Configurable failure thresholds with automatic data consolidation
 
 Usage:
     python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
@@ -18,6 +20,10 @@ Usage:
     # Use a specific toolset distribution
     python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen
 
+    # Configure tool failure thresholds
+    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
+        --max_tool_failures=20 --max_tool_failure_rate=0.3
+
 """
 
 import json
@@ -29,6 +35,7 @@ from typing import List, Dict, Any, Optional, Tuple
 from datetime import datetime
 from multiprocessing import Pool, Manager, Lock
 import traceback
+import re
 
 import fire
@@ -39,12 +46,83 @@ from toolset_distributions import (
     sample_toolsets_from_distribution,
     validate_distribution
 )
+from safe_print import safe_print
 
 # Global configuration for worker processes
 _WORKER_CONFIG = {}
 
+
+def _extract_tool_errors_from_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """
+    Extract tool errors from message history with tool names.
+
+    Args:
+        messages (List[Dict]): Message history
+
+    Returns:
+        List[Dict]: List of tool errors with tool name, error message, and context
+    """
+    tool_errors = []
+    tool_calls_map = {}  # Map tool_call_id to tool name
+
+    for msg in messages:
+        # Track tool calls from assistant messages
+        if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]:
+            for tool_call in msg["tool_calls"]:
+                tool_name = tool_call["function"]["name"]
+                tool_call_id = tool_call["id"]
+                tool_calls_map[tool_call_id] = tool_name
+
+        # Check tool responses for errors
+        elif msg["role"] == "tool":
+            tool_call_id = msg.get("tool_call_id", "")
+            content = msg.get("content", "")
+
+            # Determine if tool call had an error
+            has_error = False
+            error_msg = None
+
+            try:
+                content_json = json.loads(content) if isinstance(content, str) else content
+                if isinstance(content_json, dict):
+                    # Check if error field exists AND has a non-null value
+                    if "error" in content_json and content_json["error"] is not None:
+                        has_error = True
+                        error_msg = str(content_json["error"])
+
+                    # Special handling for terminal tool responses
+                    if "content" in content_json and isinstance(content_json["content"], dict):
+                        inner_content = content_json["content"]
+                        if inner_content.get("error") is not None or inner_content.get("exit_code", 0) != 0:
+                            has_error = True
+                            error_msg = inner_content.get("error") or f"Exit code: {inner_content.get('exit_code')}"
+
+                    # Check for "success": false pattern
+                    if content_json.get("success") is False:
+                        has_error = True
+                        if not error_msg:
+                            error_msg = str(content_json.get("message", content_json.get("error", "Unknown error")))
+            except json.JSONDecodeError:
+                # If not JSON, check if content explicitly states an error
+                if content.strip().lower().startswith("error:"):
+                    has_error = True
+                    error_msg = content.strip()
+
+            # Record error if found
+            if has_error and tool_call_id in tool_calls_map:
+                tool_name = tool_calls_map[tool_call_id]
+                tool_errors.append({
+                    "tool_name": tool_name,
+                    "error_message": error_msg or "Unknown error",
+                    "full_content": content[:500]  # Keep first 500 chars of full response
+                })
+
+    return tool_errors
+
+
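To make the detection heuristics above concrete, here is a small worked example (hypothetical message shapes, not taken from this commit): one nested terminal-style payload flagged by its non-zero exit code, and one plain-text payload flagged by its "Error:" prefix.

# Hypothetical input for _extract_tool_errors_from_messages (illustrative only)
messages = [
    {"role": "assistant", "tool_calls": [
        {"id": "call_1", "function": {"name": "terminal"}},
        {"id": "call_2", "function": {"name": "firecrawl"}},
    ]},
    # Nested terminal payload: flagged because exit_code != 0
    {"role": "tool", "tool_call_id": "call_1",
     "content": '{"content": {"error": null, "exit_code": 127}}'},
    # Non-JSON payload: flagged because it starts with "Error:"
    {"role": "tool", "tool_call_id": "call_2", "content": "Error: rate limit exceeded"},
]

errors = _extract_tool_errors_from_messages(messages)
# -> [{"tool_name": "terminal", "error_message": "Exit code: 127", ...},
#     {"tool_name": "firecrawl", "error_message": "Error: rate limit exceeded", ...}]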
 def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
     """
     Extract tool usage statistics from message history.
@@ -174,6 +252,9 @@ def _process_single_prompt(
         # Extract tool usage statistics
         tool_stats = _extract_tool_stats(result["messages"])
 
+        # Extract tool errors from conversation
+        tool_errors = _extract_tool_errors_from_messages(result["messages"])
+
         # Convert to trajectory format (using existing method)
         trajectory = agent._convert_to_trajectory_format(
             result["messages"],
@@ -186,6 +267,7 @@ def _process_single_prompt(
             "prompt_index": prompt_index,
             "trajectory": trajectory,
             "tool_stats": tool_stats,
+            "tool_errors": tool_errors,
             "completed": result["completed"],
             "api_calls": result["api_calls"],
             "toolsets_used": selected_toolsets,
@@ -197,14 +279,18 @@ def _process_single_prompt(
         }
     except Exception as e:
-        print(f"❌ Error processing prompt {prompt_index}: {e}")
+        error_msg = str(e)
+        tb = traceback.format_exc()
+        safe_print(f"[bold red]❌ Error processing prompt {prompt_index}:[/bold red] {error_msg}")
         if config.get("verbose"):
-            traceback.print_exc()
+            safe_print(tb)
         return {
             "success": False,
             "prompt_index": prompt_index,
-            "error": str(e),
+            "error": error_msg,
+            "traceback": tb,
+            "tool_errors": [],
             "trajectory": None,
             "tool_stats": {},
             "toolsets_used": [],
@@ -254,6 +340,8 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
     # Initialize aggregated stats for this batch
     batch_tool_stats = {}
     completed_in_batch = []
+    all_tool_errors = []  # Track all tool errors in this batch
+    exception_errors = []  # Track top-level exceptions
 
     # Process each prompt sequentially in this batch
     for prompt_index, prompt_data in prompts_to_process:
@@ -265,6 +353,25 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
             config
         )
 
+        # Track tool errors from the conversation
+        if result.get("tool_errors"):
+            for tool_error in result["tool_errors"]:
+                all_tool_errors.append({
+                    "prompt_index": prompt_index,
+                    "tool_name": tool_error["tool_name"],
+                    "error_message": tool_error["error_message"],
+                    "full_content": tool_error.get("full_content", "")
+                })
+
+        # Track top-level exceptions (not tool errors)
+        if not result["success"]:
+            exception_errors.append({
+                "prompt_index": prompt_index,
+                "error": result.get("error", "Unknown error"),
+                "traceback": result.get("traceback", "")
+            })
+            safe_print(f"[bold red]❌ Exception in prompt {prompt_index}:[/bold red] {result.get('error', '')[:100]}")
+
         # Save trajectory if successful
         if result["success"] and result["trajectory"]:
             trajectory_entry = {
@@ -303,7 +410,9 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
         "processed": len(prompts_to_process),
         "skipped": len(batch_data) - len(prompts_to_process),
         "tool_stats": batch_tool_stats,
-        "completed_prompts": completed_in_batch
+        "completed_prompts": completed_in_batch,
+        "tool_errors": all_tool_errors,
+        "exception_errors": exception_errors
     }
@@ -326,6 +435,9 @@ class BatchRunner:
         verbose: bool = False,
         ephemeral_system_prompt: str = None,
         log_prefix_chars: int = 100,
+        max_tool_failures: int = 10,
+        max_tool_failure_rate: float = 0.5,
+        keep_recent_errors: int = 5,
     ):
         """
         Initialize the batch runner.
@@ -343,6 +455,9 @@ class BatchRunner:
             verbose (bool): Enable verbose logging
             ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
             log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
+            max_tool_failures (int): Maximum number of tool failures before stopping (default: 10)
+            max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5)
+            keep_recent_errors (int): Number of recent errors to keep per tool (default: 5)
         """
         self.dataset_file = Path(dataset_file)
         self.batch_size = batch_size
@@ -356,6 +471,9 @@ class BatchRunner:
         self.verbose = verbose
         self.ephemeral_system_prompt = ephemeral_system_prompt
         self.log_prefix_chars = log_prefix_chars
+        self.max_tool_failures = max_tool_failures
+        self.max_tool_failure_rate = max_tool_failure_rate
+        self.keep_recent_errors = keep_recent_errors
 
         # Validate distribution
         if not validate_distribution(distribution):
@@ -377,17 +495,21 @@ class BatchRunner:
         # Create batches
         self.batches = self._create_batches()
 
-        print(f"📊 Batch Runner Initialized")
-        print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
-        print(f" Batch size: {self.batch_size}")
-        print(f" Total batches: {len(self.batches)}")
-        print(f" Run name: {self.run_name}")
-        print(f" Distribution: {self.distribution}")
-        print(f" Output directory: {self.output_dir}")
-        print(f" Workers: {self.num_workers}")
+        safe_print("[bold cyan]📊 Batch Runner Initialized[/bold cyan]")
+        safe_print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
+        safe_print(f" Batch size: {self.batch_size}")
+        safe_print(f" Total batches: {len(self.batches)}")
+        safe_print(f" Run name: {self.run_name}")
+        safe_print(f" Distribution: {self.distribution}")
+        safe_print(f" Output directory: {self.output_dir}")
+        safe_print(f" Workers: {self.num_workers}")
+        safe_print(f" [yellow]Tool failure limits:[/yellow]")
+        safe_print(f" Max failures: {self.max_tool_failures}")
+        safe_print(f" Max failure rate: {self.max_tool_failure_rate:.1%}")
+        safe_print(f" Keep recent errors: {self.keep_recent_errors}")
 
         if self.ephemeral_system_prompt:
             prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
-            print(f" 🔒 Ephemeral system prompt: '{prompt_preview}'")
+            safe_print(f" 🔒 Ephemeral system prompt: '{prompt_preview}'")
 
     def _load_dataset(self) -> List[Dict[str, Any]]:
         """
@@ -480,6 +602,69 @@ class BatchRunner:
         with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
             json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
 
+    def _consolidate_data(self, num_batches: int, tool_stats: Dict[str, Dict[str, int]],
+                          start_time: float, tool_errors_by_tool: Dict[str, List[Dict]],
+                          exception_errors: List[Dict], early_exit: bool = False, exit_reason: str = None):
+        """
+        Consolidate batch data into trajectories.jsonl and save statistics.
+
+        Args:
+            num_batches (int): Number of batches processed
+            tool_stats (Dict): Aggregated tool statistics
+            start_time (float): Start time of the run
+            tool_errors_by_tool (Dict): Tool errors grouped by tool name with k most recent
+            exception_errors (List): Top-level exceptions
+            early_exit (bool): Whether this is an early exit
+            exit_reason (str): Reason for early exit
+        """
+        # Combine all batch files into a single trajectories.jsonl file
+        combined_file = self.output_dir / "trajectories.jsonl"
+        safe_print(f"\n[cyan]📦 Combining batch files into {combined_file.name}...[/cyan]")
+
+        entries_written = 0
+        with open(combined_file, 'w', encoding='utf-8') as outfile:
+            for batch_num in range(num_batches):
+                batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
+                if batch_file.exists():
+                    with open(batch_file, 'r', encoding='utf-8') as infile:
+                        for line in infile:
+                            outfile.write(line)
+                            entries_written += 1
+
+        safe_print(f"[green]✅ Combined {num_batches} batch files into trajectories.jsonl ({entries_written} entries)[/green]")
+
+        # Calculate success rates for tool stats
+        for tool_name in tool_stats:
+            stats = tool_stats[tool_name]
+            total_calls = stats["success"] + stats["failure"]
+            if total_calls > 0:
+                stats["success_rate"] = round(stats["success"] / total_calls * 100, 2)
+                stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2)
+            else:
+                stats["success_rate"] = 0.0
+                stats["failure_rate"] = 0.0
+
+        # Save final statistics
+        final_stats = {
+            "run_name": self.run_name,
+            "distribution": self.distribution,
+            "total_prompts": len(self.dataset),
+            "total_batches": len(self.batches),
+            "batches_processed": num_batches,
+            "batch_size": self.batch_size,
+            "model": self.model,
+            "completed_at": datetime.now().isoformat(),
+            "duration_seconds": round(time.time() - start_time, 2),
+            "early_exit": early_exit,
+            "exit_reason": exit_reason,
+            "tool_errors": tool_errors_by_tool,
"exception_errors": exception_errors[:self.keep_recent_errors], # Keep k most recent
"tool_statistics": tool_stats
}
with open(self.stats_file, 'w', encoding='utf-8') as f:
json.dump(final_stats, f, indent=2, ensure_ascii=False)
def run(self, resume: bool = False): def run(self, resume: bool = False):
""" """
@@ -520,6 +705,13 @@ class BatchRunner:
         # Aggregate statistics across all batches
         total_tool_stats = {}
+        tool_errors_by_tool = {}  # {tool_name: [list of k most recent errors]}
+        all_exception_errors = []
+        all_completed_prompts = list(completed_prompts_set)
+        total_processed = len(completed_prompts_set)
+        total_tool_errors = 0
+        early_exit = False
+        exit_reason = None
 
         start_time = time.time()
@@ -537,82 +729,145 @@ class BatchRunner:
                 for batch_num, batch_data in enumerate(self.batches)
             ]
 
-            # Use map to process batches in parallel
-            results = pool.map(_process_batch_worker, tasks)
-
-        # Aggregate all batch statistics and update checkpoint
-        all_completed_prompts = list(completed_prompts_set)
-        for batch_result in results:
-            # Add newly completed prompts
-            all_completed_prompts.extend(batch_result.get("completed_prompts", []))
-
-            # Aggregate tool stats
-            for tool_name, stats in batch_result.get("tool_stats", {}).items():
-                if tool_name not in total_tool_stats:
-                    total_tool_stats[tool_name] = {
-                        "count": 0,
-                        "success": 0,
-                        "failure": 0
-                    }
-                total_tool_stats[tool_name]["count"] += stats["count"]
-                total_tool_stats[tool_name]["success"] += stats["success"]
-                total_tool_stats[tool_name]["failure"] += stats["failure"]
+            # Process batches and check tool failure threshold after each batch
+            for batch_num, task in enumerate(tasks):
+                # Process single batch
+                result = pool.apply(_process_batch_worker, (task,))
+
+                # Update statistics
+                all_completed_prompts.extend(result.get("completed_prompts", []))
+                total_processed += result.get("processed", 0)
+
+                # Aggregate tool stats
+                for tool_name, stats in result.get("tool_stats", {}).items():
+                    if tool_name not in total_tool_stats:
+                        total_tool_stats[tool_name] = {
+                            "count": 0,
+                            "success": 0,
+                            "failure": 0
+                        }
+                    total_tool_stats[tool_name]["count"] += stats["count"]
+                    total_tool_stats[tool_name]["success"] += stats["success"]
+                    total_tool_stats[tool_name]["failure"] += stats["failure"]
+
+                # Aggregate tool errors (keep k most recent per tool)
+                for tool_error in result.get("tool_errors", []):
+                    tool_name = tool_error["tool_name"]
+                    if tool_name not in tool_errors_by_tool:
+                        tool_errors_by_tool[tool_name] = []
+                    # Add error and keep only k most recent
+                    tool_errors_by_tool[tool_name].append(tool_error)
+                    if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors:
+                        tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:]
+                    total_tool_errors += 1
+
+                # Track exception errors
+                all_exception_errors.extend(result.get("exception_errors", []))
+
+                # Check tool failure thresholds
+                if total_processed > 0:
+                    tool_failure_rate = total_tool_errors / total_processed
+
+                    # Check absolute count threshold
+                    if total_tool_errors >= self.max_tool_failures:
+                        early_exit = True
+                        exit_reason = f"Exceeded maximum tool failures ({total_tool_errors}/{self.max_tool_failures})"
+                        safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
+                        break
+
+                    # Check rate threshold
+                    if tool_failure_rate >= self.max_tool_failure_rate:
+                        early_exit = True
+                        exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%} >= {self.max_tool_failure_rate:.2%})"
+                        safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
+                        break
+
+                # Update checkpoint after each batch
+                checkpoint_data["completed_prompts"] = all_completed_prompts
+                self._save_checkpoint(checkpoint_data)
 
         # Save final checkpoint
         checkpoint_data["completed_prompts"] = all_completed_prompts
         self._save_checkpoint(checkpoint_data)
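The two stopping conditions compose as follows; a quick worked example with invented numbers under the defaults max_tool_failures=10 and max_tool_failure_rate=0.5:

# Hypothetical run, checked after each batch (illustrative values only):
#   total_processed=20, total_tool_errors=10 -> 10 >= 10, the absolute cap fires first
#   total_processed=8,  total_tool_errors=5  -> rate 5/8 = 62.5% >= 50%, the rate cap fires
# Note: total_tool_errors counts individual failed tool calls, so one prompt
# can contribute several errors toward both thresholds.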
-        # Calculate success rates
-        for tool_name in total_tool_stats:
-            stats = total_tool_stats[tool_name]
-            total_calls = stats["success"] + stats["failure"]
-            if total_calls > 0:
-                stats["success_rate"] = round(stats["success"] / total_calls * 100, 2)
-                stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2)
-            else:
-                stats["success_rate"] = 0.0
-                stats["failure_rate"] = 0.0
-
-        # Combine all batch files into a single trajectories.jsonl file
-        combined_file = self.output_dir / "trajectories.jsonl"
-        print(f"\n📦 Combining batch files into {combined_file.name}...")
-
-        with open(combined_file, 'w', encoding='utf-8') as outfile:
-            for batch_num in range(len(self.batches)):
-                batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
-                if batch_file.exists():
-                    with open(batch_file, 'r', encoding='utf-8') as infile:
-                        for line in infile:
-                            outfile.write(line)
-
-        print(f"✅ Combined {len(self.batches)} batch files into trajectories.jsonl")
-
-        # Save final statistics
-        final_stats = {
-            "run_name": self.run_name,
-            "distribution": self.distribution,
-            "total_prompts": len(self.dataset),
-            "total_batches": len(self.batches),
-            "batch_size": self.batch_size,
-            "model": self.model,
-            "completed_at": datetime.now().isoformat(),
-            "duration_seconds": round(time.time() - start_time, 2),
-            "tool_statistics": total_tool_stats
-        }
-
-        with open(self.stats_file, 'w', encoding='utf-8') as f:
-            json.dump(final_stats, f, indent=2, ensure_ascii=False)
+        # Consolidate data and save statistics
+        num_batches_processed = batch_num + 1 if early_exit else len(self.batches)
+        self._consolidate_data(
+            num_batches_processed,
+            total_tool_stats,
+            start_time,
+            tool_errors_by_tool,
+            all_exception_errors,
+            early_exit,
+            exit_reason
+        )
         # Print summary
-        print("\n" + "=" * 70)
-        print("📊 BATCH PROCESSING COMPLETE")
-        print("=" * 70)
-        print(f"✅ Total prompts processed: {len(self.dataset)}")
-        print(f"✅ Total batches: {len(self.batches)}")
-        print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
-        print(f"\n📈 Tool Usage Statistics:")
-        print("-" * 70)
+        safe_print("\n" + "=" * 70)
+        if early_exit:
+            safe_print("[bold yellow]⚠️ BATCH PROCESSING STOPPED EARLY[/bold yellow]")
+            safe_print(f"[yellow]Reason: {exit_reason}[/yellow]")
+        else:
+            safe_print("[bold green]📊 BATCH PROCESSING COMPLETE[/bold green]")
+        safe_print("=" * 70)
+
+        safe_print(f"✅ Total prompts processed: {total_processed}")
+        safe_print(f"✅ Batches completed: {num_batches_processed}/{len(self.batches)}")
+        safe_print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
+
+        # Tool error summary
+        if tool_errors_by_tool:
+            total_errors = sum(len(errors) for errors in tool_errors_by_tool.values())
+            safe_print(f"\n[bold red]🚨 Tool Errors: {total_tool_errors} total ({len(tool_errors_by_tool)} tools)[/bold red]")
+            safe_print("[red]-[/red]" * 70)
+
+            # Sort tools by error count
+            sorted_tools = sorted(
+                tool_errors_by_tool.items(),
+                key=lambda x: len(x[1]),
+                reverse=True
+            )
+
+            for tool_name, errors in sorted_tools:
+                # Count unique error messages
+                unique_errors = {}
+                for error in errors:
+                    error_msg = error["error_message"][:100]  # Truncate for grouping
+                    if error_msg not in unique_errors:
+                        unique_errors[error_msg] = []
+                    unique_errors[error_msg].append(error)
+
+                safe_print(f"\n [red]{tool_name}:[/red] {len(errors)} errors ({len(unique_errors)} unique)")
+
+                # Show up to 3 most recent unique error types
+                for idx, (error_msg, instances) in enumerate(list(unique_errors.items())[:3]):
+                    error_preview = error_msg if len(error_msg) <= 100 else error_msg[:97] + "..."
+                    safe_print(f" [{idx+1}] [dim]{error_preview}[/dim] (x{len(instances)})")
+                    # Show one example with prompt index
+                    example = instances[-1]  # Most recent
+                    safe_print(f" [dim]Prompt {example['prompt_index']}[/dim]")
+
+                if len(unique_errors) > 3:
+                    safe_print(f" [dim]... and {len(unique_errors) - 3} more error types[/dim]")
+
+            tool_failure_rate = total_tool_errors / total_processed if total_processed > 0 else 0
+            safe_print(f"\n [red]Tool failure rate: {tool_failure_rate:.2%}[/red]")
+
+        # Exception errors
+        if all_exception_errors:
+            safe_print(f"\n[bold red]💥 Top-level Exceptions: {len(all_exception_errors)}[/bold red]")
+            safe_print("[red]-[/red]" * 70)
+            for error in all_exception_errors[-self.keep_recent_errors:]:
+                error_preview = error["error"][:100]
+                if len(error["error"]) > 100:
+                    error_preview += "..."
+                safe_print(f" Prompt {error['prompt_index']}: [dim]{error_preview}[/dim]")
+
+        safe_print(f"\n[cyan]📈 Tool Usage Statistics:[/cyan]")
+        safe_print("-" * 70)
 
         if total_tool_stats:
             # Sort by count descending
@@ -622,24 +877,29 @@ class BatchRunner:
                 reverse=True
             )
 
-            print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
-            print("-" * 70)
+            safe_print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
+            safe_print("-" * 70)
             for tool_name, stats in sorted_tools:
-                print(
+                safe_print(
                     f"{tool_name:<25} "
                     f"{stats['count']:<10} "
                     f"{stats['success']:<10} "
                     f"{stats['failure']:<10} "
-                    f"{stats['success_rate']:.1f}%"
+                    f"{stats.get('success_rate', 0):.1f}%"
                 )
         else:
-            print("No tool calls were made during this run.")
+            safe_print("No tool calls were made during this run.")
 
-        print(f"\n💾 Results saved to: {self.output_dir}")
-        print(f" - Trajectories: trajectories.jsonl (combined)")
-        print(f" - Individual batches: batch_*.jsonl (for debugging)")
-        print(f" - Statistics: {self.stats_file.name}")
-        print(f" - Checkpoint: {self.checkpoint_file.name}")
+        safe_print(f"\n[cyan]💾 Results saved to:[/cyan] {self.output_dir}")
+        safe_print(f" - Trajectories: trajectories.jsonl (combined)")
+        safe_print(f" - Individual batches: batch_*.jsonl (for debugging)")
+        safe_print(f" - Statistics: {self.stats_file.name}")
+        safe_print(f" - Checkpoint: {self.checkpoint_file.name}")
+
+        if early_exit:
+            safe_print(f"\n[bold yellow] Run was stopped early due to tool failures.[/bold yellow]")
+            safe_print(f"[yellow] Check {self.stats_file.name} for detailed error information including tracebacks.[/yellow]")
+            safe_print(f"[yellow] You can resume this run later with --resume flag.[/yellow]")
 
 
 def main(
@@ -657,6 +917,9 @@ def main(
     list_distributions: bool = False,
     ephemeral_system_prompt: str = None,
    log_prefix_chars: int = 100,
+    max_tool_failures: int = 10,
+    max_tool_failure_rate: float = 0.5,
+    keep_recent_errors: int = 5,
 ):
     """
     Run batch processing of agent prompts from a dataset.
@@ -676,6 +939,9 @@ def main(
         list_distributions (bool): List available toolset distributions and exit
         ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
         log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
+        max_tool_failures (int): Maximum number of tool failures before stopping (default: 10)
+        max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5)
+        keep_recent_errors (int): Number of recent errors to keep per tool for reporting (default: 5)
 
     Examples:
         # Basic usage
@@ -691,6 +957,10 @@ def main(
         python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
             --ephemeral_system_prompt="You are a helpful assistant focused on image generation."
 
+        # With custom tool failure thresholds
+        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
+            --max_tool_failures=20 --max_tool_failure_rate=0.3 --keep_recent_errors=10
+
         # List available distributions
         python batch_runner.py --list_distributions
     """
@@ -737,7 +1007,10 @@ def main(
         num_workers=num_workers,
         verbose=verbose,
         ephemeral_system_prompt=ephemeral_system_prompt,
-        log_prefix_chars=log_prefix_chars
+        log_prefix_chars=log_prefix_chars,
+        max_tool_failures=max_tool_failures,
+        max_tool_failure_rate=max_tool_failure_rate,
+        keep_recent_errors=keep_recent_errors
     )
 
     runner.run(resume=resume)
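With these fields wired through, the statistics file written by _consolidate_data takes roughly the following shape. This is a hand-written sketch: the field names come from the final_stats dict in this diff, but every value below is invented for illustration.

{
  "run_name": "my_run",
  "batches_processed": 3,
  "early_exit": true,
  "exit_reason": "Exceeded maximum tool failures (12/10)",
  "tool_errors": {
    "firecrawl": [
      {
        "prompt_index": 41,
        "tool_name": "firecrawl",
        "error_message": "Error: rate limit exceeded",
        "full_content": "Error: rate limit exceeded"
      }
    ]
  },
  "exception_errors": [],
  "tool_statistics": {
    "firecrawl": {"count": 20, "success": 8, "failure": 12, "success_rate": 40.0, "failure_rate": 60.0}
  }
}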

safe_print.py Normal file (+20)
View File

@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+"""Simple safe print that tries rich, falls back to regular print."""
+
+try:
+    from rich import print as rich_print
+    RICH_AVAILABLE = True
+except ImportError:
+    RICH_AVAILABLE = False
+
+
+def safe_print(*args, **kwargs):
+    """Try rich.print, fall back to regular print if it fails."""
+    if RICH_AVAILABLE:
+        try:
+            rich_print(*args, **kwargs)
+            return
+        except Exception:
+            pass
+    # Fallback to regular print
+    print(*args, **kwargs)
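A minimal usage sketch (not part of the commit): markup is styled only when rich imports successfully; in the fallback path the built-in print emits the markup tags literally.

from safe_print import safe_print

safe_print("[bold green]✅ done[/bold green]")  # styled if rich is installed
# Without rich, the same call prints the raw string
# "[bold green]✅ done[/bold green]" via the built-in print.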