Mirror of https://github.com/NousResearch/hermes-agent.git
Synced 2026-04-28 23:11:37 +08:00

Compare commits: skill/gith... → taubench_e (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 545809d09b |  |
|  | c32efc2885 |  |
environments/benchmarks/taubench/__init__.py (new file, 0 lines)

environments/benchmarks/taubench/hermes_agent.py (new file, 324 lines)
@@ -0,0 +1,324 @@
"""
HermesAgent for tau2-bench evaluation.

Implements the tau2 HalfDuplexAgent interface using litellm with OpenRouter,
matching the inference path used across the rest of the Hermes Agent codebase.

Usage:
    python environments/benchmarks/taubench/run_eval.py \\
        --model anthropic/claude-sonnet-4-5 \\
        --base-url openrouter \\
        --env retail
"""

import json
import os
import sys
from pathlib import Path
from typing import Optional

import litellm
from pydantic import BaseModel

_repo_root = Path(__file__).resolve().parent.parent.parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))

from environments.tool_call_parsers import get_parser

from tau2.agent.base_agent import HalfDuplexAgent, ValidAgentInputMessage
from tau2.data_model.message import (
    AssistantMessage,
    Message,
    MultiToolMessage,
    SystemMessage,
    ToolCall,
    ToolMessage,
    UserMessage,
)
from tau2.environment.tool import Tool


class HermesAgentState(BaseModel):
    system_messages: list[SystemMessage]
    messages: list


class HermesAgent(HalfDuplexAgent[HermesAgentState]):
    """
    tau2 HalfDuplexAgent backed by litellm, using OpenRouter (or any
    OpenAI-compatible endpoint).

    Registered as "hermes_agent" in the tau2 registry by run_eval.py.
    """

    SYSTEM_PROMPT = (
        "You are a customer service agent that helps the user according to the "
        "<policy> provided below.\n"
        "In each turn you can either:\n"
        "- Send a message to the user.\n"
        "- Make a tool call.\n"
        "You cannot do both at the same time.\n\n"
        "Try to be helpful and always follow the policy. "
        "Always make sure you generate valid JSON only.\n\n"
        "<policy>\n{domain_policy}\n</policy>"
    )

    # System prompt variant for qwen3_coder tool format — tools are embedded
    # directly in the system prompt as <tools> XML instead of passed via the
    # OpenAI tools= parameter.
    SYSTEM_PROMPT_QWEN3_CODER = (
        "You are a customer service agent that helps the user according to the "
        "<policy> provided below.\n"
        "In each turn you can either:\n"
        "- Send a message to the user.\n"
        "- Make a tool call.\n"
        "You cannot do both at the same time.\n\n"
        "Try to be helpful and always follow the policy. "
        "Always make sure you generate valid JSON only.\n\n"
        "You may call one or more functions to assist with the user query.\n\n"
        "You are provided with function signatures within <tools></tools> XML tags:\n"
        "<tools>\n{tools_json}\n</tools>\n\n"
        "<policy>\n{domain_policy}\n</policy>"
    )

    def __init__(
        self,
        tools: list[Tool],
        domain_policy: str,
        model: str,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        thinking: bool = False,
        tool_parser: Optional[str] = None,
    ):
        super().__init__(tools=tools, domain_policy=domain_policy)
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.thinking = thinking
        self.tool_parser = tool_parser
        self._parser = get_parser(tool_parser) if tool_parser else None

        # OpenRouter requires specific headers; pass them via litellm extra_headers
        self._extra_headers: dict = {}
        if base_url and "openrouter" in base_url.lower():
            self._extra_headers = {
                "HTTP-Referer": "https://hermes-agent.nousresearch.com",
                "X-Title": "Hermes Agent",
            }

    @property
    def system_prompt(self) -> str:
        if self.tool_parser == "qwen3_coder" and self.tools:
            tools_json = json.dumps(
                [t.openai_schema for t in self.tools], indent=2, ensure_ascii=False
            )
            return self.SYSTEM_PROMPT_QWEN3_CODER.format(
                tools_json=tools_json,
                domain_policy=self.domain_policy,
            )
        return self.SYSTEM_PROMPT.format(domain_policy=self.domain_policy)

    def get_init_state(
        self, message_history: Optional[list[Message]] = None
    ) -> HermesAgentState:
        return HermesAgentState(
            system_messages=[SystemMessage(role="system", content=self.system_prompt)],
            messages=list(message_history or []),
        )

    def generate_next_message(
        self, message: ValidAgentInputMessage, state: HermesAgentState
    ) -> tuple[AssistantMessage, HermesAgentState]:
        # Append incoming message(s) to history
        if isinstance(message, MultiToolMessage):
            state.messages.extend(message.tool_messages)
        else:
            state.messages.append(message)

        # Build litellm-compatible message list
        all_messages = state.system_messages + state.messages
        lm_messages = [_to_litellm_message(m) for m in all_messages]

        kwargs = dict(
            model=self.model,
            messages=lm_messages,
            temperature=self.temperature,
        )
        if self.tools:
            kwargs["tools"] = [t.openai_schema for t in self.tools]
        if self.max_tokens is not None:
            kwargs["max_tokens"] = self.max_tokens
        if self.top_p is not None:
            kwargs["top_p"] = self.top_p
        # Enable thinking/reasoning mode. OpenRouter exposes this as
        # `include_reasoning` for nemotron (per supported_parameters in the
        # model metadata). Pass via extra_body to bypass litellm filtering.
        if self.thinking:
            kwargs["extra_body"] = {"include_reasoning": True}
        # Only pass base_url when model doesn't already have a provider prefix
        # (litellm uses either the prefix OR base_url, not both)
        if self.base_url and not self.model.startswith("openrouter/"):
            kwargs["base_url"] = self.base_url
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self._extra_headers:
            kwargs["extra_headers"] = self._extra_headers

        response = litellm.completion(**kwargs)
        assistant_msg = _litellm_response_to_assistant_message(response, parser=self._parser)

        state.messages.append(assistant_msg)
        return assistant_msg, state


# ---------------------------------------------------------------------------
# Conversion helpers
# ---------------------------------------------------------------------------


def _to_litellm_message(msg) -> dict:
    """Convert a tau2 message object to a litellm-compatible dict."""
    if isinstance(msg, SystemMessage):
        return {"role": "system", "content": msg.content or ""}

    if isinstance(msg, UserMessage):
        if msg.tool_calls:
            # User tool calls (tau2 v2 feature — user has tools too)
            return {
                "role": "user",
                "content": msg.content or "",
                "tool_calls": [_tool_call_to_dict(tc) for tc in msg.tool_calls],
            }
        return {"role": "user", "content": msg.content or ""}

    if isinstance(msg, AssistantMessage):
        d: dict = {"role": "assistant", "content": msg.content or ""}
        if msg.tool_calls:
            d["tool_calls"] = [_tool_call_to_dict(tc) for tc in msg.tool_calls]
        return d

    if isinstance(msg, ToolMessage):
        return {
            "role": "tool",
            "tool_call_id": msg.id,
            "content": msg.content or "",
        }

    # Fallback
    return {"role": getattr(msg, "role", "user"), "content": str(getattr(msg, "content", ""))}


def _tool_call_to_dict(tc: ToolCall) -> dict:
    import json
    return {
        "id": tc.id or "call_0",
        "type": "function",
        "function": {
            "name": tc.name,
            "arguments": json.dumps(tc.arguments),
        },
    }


def _litellm_response_to_assistant_message(response, parser=None) -> AssistantMessage:
    """Convert a litellm ModelResponse to a tau2 AssistantMessage."""
    import json

    choice = response.choices[0]
    msg = choice.message

    content = msg.content or ""
    tool_calls_raw = getattr(msg, "tool_calls", None)

    tau2_tool_calls: Optional[list[ToolCall]] = None

    if parser and content:
        # Use the custom tool parser (e.g. qwen3_coder) to extract tool calls
        # from the raw text response.
        parsed_content, parsed_tool_calls = parser.parse(content)
        if parsed_tool_calls:
            content = parsed_content or ""
            tau2_tool_calls = []
            for tc in parsed_tool_calls:
                try:
                    arguments = json.loads(tc.function.arguments or "{}")
                except json.JSONDecodeError:
                    arguments = {}
                tau2_tool_calls.append(
                    ToolCall(
                        id=tc.id or "call_0",
                        name=tc.function.name,
                        arguments=arguments,
                        requestor="assistant",
                    )
                )
    elif tool_calls_raw:
        tau2_tool_calls = []
        for tc in tool_calls_raw:
            if hasattr(tc, "function"):
                name = tc.function.name
                try:
                    arguments = json.loads(tc.function.arguments or "{}")
                except json.JSONDecodeError:
                    arguments = {}
                tau2_tool_calls.append(
                    ToolCall(
                        id=tc.id or "call_0",
                        name=name,
                        arguments=arguments,
                        requestor="assistant",
                    )
                )

    cost = None
    try:
        cost = litellm.completion_cost(response)
    except Exception:
        pass

    usage = None
    if hasattr(response, "usage") and response.usage:
        usage = dict(response.usage)

    return AssistantMessage(
        role="assistant",
        content=content if not tau2_tool_calls else None,
        tool_calls=tau2_tool_calls,
        cost=cost,
        usage=usage,
    )


def create_hermes_agent(tools: list[Tool], domain_policy: str, **kwargs) -> HermesAgent:
    """
    Factory function registered with the tau2 registry.

    Expected kwargs:
        model (str): litellm model string
        base_url (str): API base URL (optional)
        api_key (str): API key (optional)
        temperature (float): sampling temperature (default 0.0)
        top_p (float): nucleus sampling (optional)
        max_tokens (int): max tokens (optional)
        thinking (bool): enable reasoning/thinking mode (default False)
    """
    return HermesAgent(
        tools=tools,
        domain_policy=domain_policy,
        model=kwargs["model"],
        base_url=kwargs.get("base_url"),
        api_key=kwargs.get("api_key"),
        temperature=kwargs.get("temperature", 0.0),
        top_p=kwargs.get("top_p"),
        max_tokens=kwargs.get("max_tokens"),
        thinking=kwargs.get("thinking", False),
        tool_parser=kwargs.get("tool_parser"),
    )
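
A minimal sketch (not part of the commit) of how the conversion helpers above map a tau2 tool call onto the OpenAI-style dict that litellm expects. The tool name and arguments are hypothetical, and it assumes the tau2 `ToolCall`/`AssistantMessage` constructors accept exactly the keyword fields used in hermes_agent.py:

```python
# Sketch only: exercising _to_litellm_message on a hand-built assistant turn.
from tau2.data_model.message import AssistantMessage, ToolCall

from environments.benchmarks.taubench.hermes_agent import _to_litellm_message

msg = AssistantMessage(
    role="assistant",
    content=None,  # tool-call turns carry no user-facing text
    tool_calls=[
        ToolCall(
            id="call_0",
            name="get_order_details",             # hypothetical retail-domain tool
            arguments={"order_id": "#W1234567"},  # hypothetical arguments
            requestor="assistant",
        )
    ],
)

# Yields an OpenAI-format dict, e.g.:
# {"role": "assistant", "content": "",
#  "tool_calls": [{"id": "call_0", "type": "function",
#                  "function": {"name": "get_order_details",
#                               "arguments": "{\"order_id\": \"#W1234567\"}"}}]}
print(_to_litellm_message(msg))
```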
environments/benchmarks/taubench/run_eval.py (new file, 288 lines)
@@ -0,0 +1,288 @@
"""
tau2-bench evaluation runner for Hermes Agent.

Runs the tau2-bench retail, airline, telecom, or banking_knowledge evaluation
using HermesAgent backed by litellm — the same inference path used across the
rest of the Hermes Agent codebase.

Usage:
    # Against OpenRouter (auto-detects OPENROUTER_API_KEY)
    python environments/benchmarks/taubench/run_eval.py \\
        --model openrouter/anthropic/claude-sonnet-4-5 \\
        --base-url openrouter \\
        --env retail

    # Against OpenAI directly
    python environments/benchmarks/taubench/run_eval.py \\
        --model gpt-4o \\
        --env retail

    # Local vLLM
    python environments/benchmarks/taubench/run_eval.py \\
        --model openai/NousResearch/Hermes-3-Llama-3.1-70B \\
        --base-url http://localhost:8000/v1 \\
        --env retail \\
        --num-trials 3

    # Specific tasks only
    python environments/benchmarks/taubench/run_eval.py \\
        --model openrouter/anthropic/claude-sonnet-4-5 \\
        --base-url openrouter \\
        --env retail \\
        --task-ids task_1 task_2 task_5

Results are saved to results/tau2bench/ as JSON.

Dependencies (requires Python 3.12+):
    pip install "tau2 @ git+https://github.com/sierra-research/tau2-bench.git"
    # or: pip install -e ".[tau2bench]"
"""

import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Optional

_repo_root = Path(__file__).resolve().parent.parent.parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))

from tau2.data_model.simulation import Results, TextRunConfig
from tau2.evaluator.evaluator import EvaluationType
from tau2.registry import registry
from tau2.runner.batch import run_tasks
from tau2.runner.helpers import get_tasks

from environments.benchmarks.taubench.hermes_agent import create_hermes_agent

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
logger = logging.getLogger(__name__)

OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
AGENT_NAME = "hermes_agent"


def _register_agent(
    model: str,
    base_url: Optional[str],
    api_key: Optional[str],
    temperature: float,
    top_p: Optional[float],
    max_tokens: Optional[int],
    thinking: bool,
    tool_parser: Optional[str],
) -> None:
    """Register the HermesAgent factory with the tau2 registry (idempotent)."""
    if registry.get_agent_factory(AGENT_NAME) is not None:
        return

    def factory(tools, domain_policy, **kwargs):
        return create_hermes_agent(
            tools=tools,
            domain_policy=domain_policy,
            model=model,
            base_url=base_url,
            api_key=api_key,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            thinking=thinking,
            tool_parser=tool_parser,
        )

    registry.register_agent_factory(factory=factory, name=AGENT_NAME)
    logger.info("Registered agent factory: %s (model=%s, thinking=%s, tool_parser=%s)", AGENT_NAME, model, thinking, tool_parser)


def run_eval(
    model: str,
    base_url: Optional[str],
    api_key: Optional[str],
    user_model: str,
    env_name: str,
    task_split: Optional[str],
    num_trials: int,
    max_concurrency: int,
    max_steps: int,
    temperature: float,
    top_p: Optional[float],
    max_tokens: Optional[int],
    thinking: bool,
    tool_parser: Optional[str],
    task_ids: Optional[list],
    start_index: int,
    end_index: int,
    log_dir: str,
    seed: int,
) -> Results:
    # Resolve OpenRouter shorthand
    if base_url and base_url.strip().lower() == "openrouter":
        base_url = OPENROUTER_BASE_URL

    is_openrouter = base_url and "openrouter" in base_url.lower()

    # litellm requires the "openrouter/" prefix to route correctly
    if is_openrouter and not model.startswith("openrouter/"):
        model = f"openrouter/{model}"
    if is_openrouter and not user_model.startswith("openrouter/"):
        user_model = f"openrouter/{user_model}"

    # Resolve API key
    if is_openrouter:
        api_key = api_key or os.environ.get("OPENROUTER_API_KEY") or os.environ.get("OPENAI_API_KEY")
        # litellm reads OPENAI_API_KEY for base_url overrides; set it so the
        # user simulator's generate() call also authenticates correctly.
        if api_key and not os.environ.get("OPENAI_API_KEY"):
            os.environ["OPENAI_API_KEY"] = api_key
    else:
        api_key = api_key or os.environ.get("OPENAI_API_KEY")

    _register_agent(
        model=model,
        base_url=base_url,
        api_key=api_key,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        thinking=thinking,
        tool_parser=tool_parser,
    )

    # Load tasks — task_ids in tau2 are strings like "task_1"
    tasks = get_tasks(
        task_set_name=env_name,
        task_split_name=task_split,
        task_ids=[str(i) for i in task_ids] if task_ids else None,
    )

    if not task_ids and (end_index != -1 or start_index != 0):
        end = end_index if end_index != -1 else len(tasks)
        tasks = tasks[start_index:end]

    logger.info(
        "Running tau2-%s eval: %d tasks, %d trial(s), concurrency=%d",
        env_name, len(tasks), num_trials, max_concurrency,
    )

    save_path = Path(log_dir) / f"tau2-{env_name}-{model.split('/')[-1]}.json"
    save_path.parent.mkdir(parents=True, exist_ok=True)

    # Pass api_key/base_url to user sim via llm_args so tau2's generate() authenticates.
    # When using OpenRouter for the user sim, mirror the agent's key + endpoint.
    user_llm_args: dict = {}
    if is_openrouter and api_key:
        user_llm_args["api_key"] = api_key
        user_llm_args["base_url"] = base_url

    config = TextRunConfig(
        domain=env_name,
        agent=AGENT_NAME,
        user="user_simulator",
        llm_agent=model,
        llm_args_agent={},
        llm_user=user_model,
        llm_args_user=user_llm_args,
        num_trials=num_trials,
        max_steps=max_steps,
        max_concurrency=max_concurrency,
        seed=seed,
    )

    results = run_tasks(
        config,
        tasks,
        save_path=save_path,
        console_display=True,
        # ALL: respects each task's reward_basis. NL assertions are skipped
        # gracefully (scored as pass) rather than raising an error, so tasks
        # are evaluated only on their actual basis components (DB, ACTION, etc.)
        evaluation_type=EvaluationType.ALL,
    )

    logger.info("Results saved to %s", save_path)
    return results


def main():
    parser = argparse.ArgumentParser(
        description="Run tau2-bench evaluation with Hermes Agent (requires Python 3.12+)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model", required=True,
        help="litellm model string, e.g. 'openrouter/anthropic/claude-sonnet-4-5' or 'gpt-4o'",
    )
    parser.add_argument(
        "--base-url", default=None,
        help="API base URL. Use 'openrouter' as shorthand for https://openrouter.ai/api/v1.",
    )
    parser.add_argument("--api-key", default=None, help="API key (falls back to OPENROUTER_API_KEY / OPENAI_API_KEY)")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Sampling temperature. NVIDIA used 1.0 for nemotron-super.")
    parser.add_argument("--top-p", type=float, default=0.95,
                        help="Nucleus sampling. NVIDIA used 0.95 for nemotron-super.")
    parser.add_argument("--max-tokens", type=int, default=None)
    parser.add_argument("--thinking", action="store_true", default=False,
                        help="Enable reasoning/thinking mode (use_reasoning=true). "
                             "Required to match NVIDIA's reported nemotron-super scores.")
    parser.add_argument("--tool-parser", default=None,
                        help="Tool call parser to use (e.g. 'qwen3_coder'). When set, tools are "
                             "embedded in the system prompt as <tools> XML and responses are parsed "
                             "from raw text instead of using OpenAI function calling format.")
    parser.add_argument(
        "--user-model", default="qwen/qwen3-235b-a22b-2507:nitro",
        help="litellm model string for the tau2 user simulator. "
             "Defaults to qwen/qwen3-235b-a22b-2507:nitro (instruct, non-thinking) to match NVIDIA's eval setup. "
             "When using --base-url openrouter the openrouter/ prefix is added automatically.",
    )
    parser.add_argument(
        "--env", default="retail",
        choices=["retail", "airline", "telecom", "banking_knowledge", "mock"],
    )
    parser.add_argument(
        "--task-split", default=None,
        help="Task split name (e.g. 'base'). Defaults to the domain default.",
    )
    parser.add_argument("--num-trials", type=int, default=1)
    parser.add_argument("--max-concurrency", type=int, default=8)
    parser.add_argument("--max-steps", type=int, default=50)
    parser.add_argument(
        "--task-ids", nargs="*", default=None,
        help="Specific task IDs to run (tau2 task IDs are strings like 'task_1')",
    )
    parser.add_argument("--start-index", type=int, default=0)
    parser.add_argument("--end-index", type=int, default=-1)
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--log-dir", default="results/tau2bench")

    args = parser.parse_args()

    run_eval(
        model=args.model,
        base_url=args.base_url,
        api_key=args.api_key,
        user_model=args.user_model,
        env_name=args.env,
        task_split=args.task_split,
        num_trials=args.num_trials,
        max_concurrency=args.max_concurrency,
        max_steps=args.max_steps,
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
        thinking=args.thinking,
        tool_parser=args.tool_parser,
        task_ids=args.task_ids,
        start_index=args.start_index,
        end_index=args.end_index,
        log_dir=args.log_dir,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()
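
For orientation, a hedged sketch of driving the evaluation programmatically rather than through the CLI. The keyword values simply mirror the argparse defaults wired up in main(); the task-ID subset is hypothetical.

```python
# Sketch only: calling run_eval() directly with the CLI's default settings.
from environments.benchmarks.taubench.run_eval import run_eval

results = run_eval(
    model="openrouter/anthropic/claude-sonnet-4-5",
    base_url="openrouter",          # shorthand, resolved to https://openrouter.ai/api/v1
    api_key=None,                   # falls back to OPENROUTER_API_KEY / OPENAI_API_KEY
    user_model="qwen/qwen3-235b-a22b-2507:nitro",
    env_name="retail",
    task_split=None,
    num_trials=1,
    max_concurrency=8,
    max_steps=50,
    temperature=1.0,
    top_p=0.95,
    max_tokens=None,
    thinking=False,
    tool_parser=None,
    task_ids=["task_1", "task_2"],  # hypothetical smoke-test subset
    start_index=0,
    end_index=-1,
    log_dir="results/tau2bench",
    seed=10,
)
```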
@@ -72,6 +72,8 @@ rl = [
     "wandb>=0.15.0,<1",
 ]
 yc-bench = ["yc-bench @ git+https://github.com/collinear-ai/yc-bench.git ; python_version >= '3.12'"]
+taubench = ["tau-bench @ git+https://github.com/sierra-research/tau-bench.git"]
+tau2bench = ["tau2 @ git+https://github.com/sierra-research/tau2-bench.git"]
 all = [
     "hermes-agent[modal]",
     "hermes-agent[daytona]",