Compare commits

...

2 Commits

Author SHA1 Message Date
Sam Herring
545809d09b Tau2 bench changes 2026-04-06 15:30:54 -07:00
Sam Herring
c32efc2885 Initial taubench implementation 2026-04-02 09:23:42 -07:00
4 changed files with 614 additions and 0 deletions

View File

@@ -0,0 +1,324 @@
"""
HermesAgent for tau2-bench evaluation.
Implements the tau2 HalfDuplexAgent interface using litellm with OpenRouter,
matching the inference path used across the rest of the Hermes Agent codebase.
Usage:
python environments/benchmarks/taubench/run_eval.py \\
--model anthropic/claude-sonnet-4-5 \\
--base-url openrouter \\
--env retail
"""
import json
import os
import sys
from pathlib import Path
from typing import Optional
import litellm
from pydantic import BaseModel
# Make the repo root importable (environments/benchmarks/taubench -> repo root,
# four levels up) so `environments.*` imports resolve when run as a script.
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))
from environments.tool_call_parsers import get_parser
from tau2.agent.base_agent import HalfDuplexAgent, ValidAgentInputMessage
from tau2.data_model.message import (
AssistantMessage,
Message,
MultiToolMessage,
SystemMessage,
ToolCall,
ToolMessage,
UserMessage,
)
from tau2.environment.tool import Tool
class HermesAgentState(BaseModel):
    """Conversation state carried between HermesAgent turns."""

    # System message(s) prepended to every completion request.
    system_messages: list[SystemMessage]
    # Running history of tau2 message objects (user/assistant/tool).
    # NOTE(review): left untyped as `list` — presumably to avoid pydantic
    # re-validating heterogeneous tau2 message types; confirm before tightening.
    messages: list
class HermesAgent(HalfDuplexAgent[HermesAgentState]):
    """
    tau2 HalfDuplexAgent backed by litellm, using OpenRouter (or any
    OpenAI-compatible endpoint).
    Registered as "hermes_agent" in the tau2 registry by run_eval.py.
    """

    SYSTEM_PROMPT = (
        "You are a customer service agent that helps the user according to the "
        "<policy> provided below.\n"
        "In each turn you can either:\n"
        "- Send a message to the user.\n"
        "- Make a tool call.\n"
        "You cannot do both at the same time.\n\n"
        "Try to be helpful and always follow the policy. "
        "Always make sure you generate valid JSON only.\n\n"
        "<policy>\n{domain_policy}\n</policy>"
    )

    # System prompt variant for qwen3_coder tool format — tools are embedded
    # directly in the system prompt as <tools> XML instead of passed via the
    # OpenAI tools= parameter.
    SYSTEM_PROMPT_QWEN3_CODER = (
        "You are a customer service agent that helps the user according to the "
        "<policy> provided below.\n"
        "In each turn you can either:\n"
        "- Send a message to the user.\n"
        "- Make a tool call.\n"
        "You cannot do both at the same time.\n\n"
        "Try to be helpful and always follow the policy. "
        "Always make sure you generate valid JSON only.\n\n"
        "You may call one or more functions to assist with the user query.\n\n"
        "You are provided with function signatures within <tools></tools> XML tags:\n"
        "<tools>\n{tools_json}\n</tools>\n\n"
        "<policy>\n{domain_policy}\n</policy>"
    )

    def __init__(
        self,
        tools: list[Tool],
        domain_policy: str,
        model: str,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        thinking: bool = False,
        tool_parser: Optional[str] = None,
    ):
        """
        Args:
            tools: tau2 Tool objects available to the agent.
            domain_policy: policy text substituted into the system prompt.
            model: litellm model string (e.g. "openrouter/anthropic/...").
            base_url: optional OpenAI-compatible endpoint URL.
            api_key: optional API key (litellm env-var fallbacks apply otherwise).
            temperature: sampling temperature.
            max_tokens: optional completion-token cap.
            top_p: optional nucleus-sampling value.
            thinking: enable provider reasoning mode (include_reasoning).
            tool_parser: optional custom tool-call parser name (e.g.
                "qwen3_coder"); when set, tool calls are parsed from raw text.
        """
        super().__init__(tools=tools, domain_policy=domain_policy)
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.thinking = thinking
        self.tool_parser = tool_parser
        self._parser = get_parser(tool_parser) if tool_parser else None
        # OpenRouter requires specific headers; pass them via litellm extra_headers
        self._extra_headers: dict = {}
        if base_url and "openrouter" in base_url.lower():
            self._extra_headers = {
                "HTTP-Referer": "https://hermes-agent.nousresearch.com",
                "X-Title": "Hermes Agent",
            }

    @property
    def system_prompt(self) -> str:
        """System prompt with policy (and, for qwen3_coder, tool schemas) filled in."""
        if self.tool_parser == "qwen3_coder" and self.tools:
            tools_json = json.dumps(
                [t.openai_schema for t in self.tools], indent=2, ensure_ascii=False
            )
            return self.SYSTEM_PROMPT_QWEN3_CODER.format(
                tools_json=tools_json,
                domain_policy=self.domain_policy,
            )
        return self.SYSTEM_PROMPT.format(domain_policy=self.domain_policy)

    def get_init_state(
        self, message_history: Optional[list[Message]] = None
    ) -> HermesAgentState:
        """Build the initial state: system prompt plus any seeded history."""
        return HermesAgentState(
            system_messages=[SystemMessage(role="system", content=self.system_prompt)],
            messages=list(message_history or []),
        )

    def generate_next_message(
        self, message: ValidAgentInputMessage, state: HermesAgentState
    ) -> tuple[AssistantMessage, HermesAgentState]:
        """Record the incoming message, call the LLM, and return its reply.

        Returns the AssistantMessage and the updated state (mutated in place).
        """
        # Append incoming message(s) to history
        if isinstance(message, MultiToolMessage):
            state.messages.extend(message.tool_messages)
        else:
            state.messages.append(message)
        # Build litellm-compatible message list
        all_messages = state.system_messages + state.messages
        lm_messages = [_to_litellm_message(m) for m in all_messages]
        kwargs = dict(
            model=self.model,
            messages=lm_messages,
            temperature=self.temperature,
        )
        # FIX: when a custom tool parser is active (e.g. qwen3_coder), the tool
        # schemas are already embedded in the system prompt as <tools> XML (see
        # SYSTEM_PROMPT_QWEN3_CODER), so do NOT also request native OpenAI
        # function calling via tools= — previously both were sent at once.
        if self.tools and self._parser is None:
            kwargs["tools"] = [t.openai_schema for t in self.tools]
        if self.max_tokens is not None:
            kwargs["max_tokens"] = self.max_tokens
        if self.top_p is not None:
            kwargs["top_p"] = self.top_p
        # Enable thinking/reasoning mode. OpenRouter exposes this as
        # `include_reasoning` for nemotron (per supported_parameters in the
        # model metadata). Pass via extra_body to bypass litellm filtering.
        if self.thinking:
            kwargs["extra_body"] = {"include_reasoning": True}
        # Only pass base_url when model doesn't already have a provider prefix
        # (litellm uses either the prefix OR base_url, not both)
        if self.base_url and not self.model.startswith("openrouter/"):
            kwargs["base_url"] = self.base_url
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self._extra_headers:
            kwargs["extra_headers"] = self._extra_headers
        response = litellm.completion(**kwargs)
        assistant_msg = _litellm_response_to_assistant_message(response, parser=self._parser)
        state.messages.append(assistant_msg)
        return assistant_msg, state
# ---------------------------------------------------------------------------
# Conversion helpers
# ---------------------------------------------------------------------------
def _to_litellm_message(msg) -> dict:
    """Translate a tau2 message object into the dict shape litellm expects."""
    if isinstance(msg, SystemMessage):
        return {"role": "system", "content": msg.content or ""}

    if isinstance(msg, UserMessage):
        user_dict: dict = {"role": "user", "content": msg.content or ""}
        if msg.tool_calls:
            # User tool calls (tau2 v2 feature — user has tools too)
            user_dict["tool_calls"] = [_tool_call_to_dict(tc) for tc in msg.tool_calls]
        return user_dict

    if isinstance(msg, AssistantMessage):
        assistant_dict: dict = {"role": "assistant", "content": msg.content or ""}
        if msg.tool_calls:
            assistant_dict["tool_calls"] = [
                _tool_call_to_dict(tc) for tc in msg.tool_calls
            ]
        return assistant_dict

    if isinstance(msg, ToolMessage):
        return {
            "role": "tool",
            "tool_call_id": msg.id,
            "content": msg.content or "",
        }

    # Fallback for unrecognized message types: best-effort role/content mapping.
    role = getattr(msg, "role", "user")
    content = str(getattr(msg, "content", ""))
    return {"role": role, "content": content}
def _tool_call_to_dict(tc: ToolCall) -> dict:
    """Convert a tau2 ToolCall into an OpenAI-format tool_call dict.

    Arguments are re-serialized to a JSON string, as the OpenAI chat format
    requires; a missing id falls back to "call_0".
    """
    # NOTE: uses the module-level `json` import; the previous function-local
    # `import json` was redundant.
    return {
        "id": tc.id or "call_0",
        "type": "function",
        "function": {
            "name": tc.name,
            "arguments": json.dumps(tc.arguments),
        },
    }
def _litellm_response_to_assistant_message(response, parser=None) -> AssistantMessage:
    """Convert a litellm ModelResponse to a tau2 AssistantMessage.

    Args:
        response: litellm ModelResponse (first choice is used).
        parser: optional custom tool-call parser (e.g. qwen3_coder). When given
            and the response has text content, tool calls are extracted from the
            raw text; native OpenAI tool_calls are only consulted otherwise.

    Returns:
        AssistantMessage with either text content or tool calls (never both),
        plus best-effort cost/usage metadata.
    """
    # NOTE: uses the module-level `json` import; the previous function-local
    # `import json` was redundant.
    choice = response.choices[0]
    msg = choice.message
    content = msg.content or ""
    tool_calls_raw = getattr(msg, "tool_calls", None)
    tau2_tool_calls: Optional[list[ToolCall]] = None
    if parser and content:
        # Use the custom tool parser (e.g. qwen3_coder) to extract tool calls
        # from the raw text response.
        parsed_content, parsed_tool_calls = parser.parse(content)
        if parsed_tool_calls:
            content = parsed_content or ""
            tau2_tool_calls = []
            for tc in parsed_tool_calls:
                # Arguments arrive as a JSON string; tolerate malformed JSON
                # by degrading to an empty argument dict.
                try:
                    arguments = json.loads(tc.function.arguments or "{}")
                except json.JSONDecodeError:
                    arguments = {}
                tau2_tool_calls.append(
                    ToolCall(
                        id=tc.id or "call_0",
                        name=tc.function.name,
                        arguments=arguments,
                        requestor="assistant",
                    )
                )
    elif tool_calls_raw:
        tau2_tool_calls = []
        for tc in tool_calls_raw:
            if hasattr(tc, "function"):
                name = tc.function.name
                try:
                    arguments = json.loads(tc.function.arguments or "{}")
                except json.JSONDecodeError:
                    arguments = {}
                tau2_tool_calls.append(
                    ToolCall(
                        id=tc.id or "call_0",
                        name=name,
                        arguments=arguments,
                        requestor="assistant",
                    )
                )
    # Cost is best-effort: litellm raises for models it cannot price.
    cost = None
    try:
        cost = litellm.completion_cost(response)
    except Exception:
        pass
    usage = None
    if hasattr(response, "usage") and response.usage:
        usage = dict(response.usage)
    return AssistantMessage(
        role="assistant",
        # tau2 convention mirrored here: a message carries either content or
        # tool calls, not both (see the system prompt's "cannot do both" rule).
        content=content if not tau2_tool_calls else None,
        tool_calls=tau2_tool_calls,
        cost=cost,
        usage=usage,
    )
def create_hermes_agent(tools: list[Tool], domain_policy: str, **kwargs) -> HermesAgent:
    """Factory function registered with the tau2 registry.

    Recognized kwargs (all optional except ``model``):
        model (str): litellm model string — required.
        base_url (str): API base URL.
        api_key (str): API key.
        temperature (float): sampling temperature, defaults to 0.0.
        top_p (float): nucleus sampling value.
        max_tokens (int): completion-token cap.
        thinking (bool): enable reasoning/thinking mode, defaults to False.
        tool_parser (str): custom tool-call parser name.
    """
    opt = kwargs.get
    return HermesAgent(
        tools=tools,
        domain_policy=domain_policy,
        model=kwargs["model"],
        base_url=opt("base_url"),
        api_key=opt("api_key"),
        temperature=opt("temperature", 0.0),
        top_p=opt("top_p"),
        max_tokens=opt("max_tokens"),
        thinking=opt("thinking", False),
        tool_parser=opt("tool_parser"),
    )

View File

@@ -0,0 +1,288 @@
"""
tau2-bench evaluation runner for Hermes Agent.
Runs the tau2-bench retail, airline, telecom, or banking_knowledge evaluation
using HermesAgent backed by litellm — the same inference path used across the
rest of the Hermes Agent codebase.
Usage:
# Against OpenRouter (auto-detects OPENROUTER_API_KEY)
python environments/benchmarks/taubench/run_eval.py \\
--model openrouter/anthropic/claude-sonnet-4-5 \\
--base-url openrouter \\
--env retail
# Against OpenAI directly
python environments/benchmarks/taubench/run_eval.py \\
--model gpt-4o \\
--env retail
# Local vLLM
python environments/benchmarks/taubench/run_eval.py \\
--model openai/NousResearch/Hermes-3-Llama-3.1-70B \\
--base-url http://localhost:8000/v1 \\
--env retail \\
--num-trials 3
# Specific tasks only
python environments/benchmarks/taubench/run_eval.py \\
--model openrouter/anthropic/claude-sonnet-4-5 \\
--base-url openrouter \\
--env retail \\
--task-ids task_1 task_2 task_5
Results are saved to results/tau2bench/ as JSON.
Dependencies (requires Python 3.12+):
pip install "tau2 @ git+https://github.com/sierra-research/tau2-bench.git"
# or: pip install -e ".[tau2bench]"
"""
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Optional
# Make the repo root importable (four levels up from this script) so the
# `environments.*` package resolves when run directly via `python .../run_eval.py`.
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))
from tau2.data_model.simulation import Results, TextRunConfig
from tau2.evaluator.evaluator import EvaluationType
from tau2.registry import registry
from tau2.runner.batch import run_tasks
from tau2.runner.helpers import get_tasks
from environments.benchmarks.taubench.hermes_agent import create_hermes_agent
# Module-wide logging setup; tau2 components emit through the root config too.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
logger = logging.getLogger(__name__)
# Canonical OpenRouter endpoint, substituted when --base-url "openrouter"
# shorthand is used.
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
# Name under which the HermesAgent factory is registered in the tau2 registry.
AGENT_NAME = "hermes_agent"
def _register_agent(
    model: str,
    base_url: Optional[str],
    api_key: Optional[str],
    temperature: float,
    top_p: Optional[float],
    max_tokens: Optional[int],
    thinking: bool,
    tool_parser: Optional[str],
) -> None:
    """Register the HermesAgent factory with the tau2 registry (idempotent)."""
    if registry.get_agent_factory(AGENT_NAME) is not None:
        # Already registered — e.g. run_eval invoked twice in one process.
        return

    def _factory(tools, domain_policy, **_ignored):
        # Closure over the CLI-provided inference settings; tau2 supplies
        # tools and domain_policy at simulation time.
        return create_hermes_agent(
            tools=tools,
            domain_policy=domain_policy,
            model=model,
            base_url=base_url,
            api_key=api_key,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            thinking=thinking,
            tool_parser=tool_parser,
        )

    registry.register_agent_factory(factory=_factory, name=AGENT_NAME)
    logger.info(
        "Registered agent factory: %s (model=%s, thinking=%s, tool_parser=%s)",
        AGENT_NAME,
        model,
        thinking,
        tool_parser,
    )
def run_eval(
    model: str,
    base_url: Optional[str],
    api_key: Optional[str],
    user_model: str,
    env_name: str,
    task_split: Optional[str],
    num_trials: int,
    max_concurrency: int,
    max_steps: int,
    temperature: float,
    top_p: Optional[float],
    max_tokens: Optional[int],
    thinking: bool,
    tool_parser: Optional[str],
    task_ids: Optional[list],
    start_index: int,
    end_index: int,
    log_dir: str,
    seed: int,
) -> Results:
    """Run a tau2-bench evaluation end-to-end and return the tau2 Results.

    Resolves the endpoint/API key, registers the HermesAgent factory, loads
    the requested tasks, and runs them via tau2's batch runner, saving JSON
    results under ``log_dir``.

    Side effects: may set the OPENAI_API_KEY environment variable (see below)
    and registers the agent factory in the process-global tau2 registry.

    Args:
        model: litellm model string for the agent.
        base_url: endpoint URL, or the shorthand "openrouter".
        api_key: explicit API key; env-var fallbacks apply when None.
        user_model: litellm model string for the tau2 user simulator.
        env_name: tau2 domain ("retail", "airline", ...).
        task_split: optional split name; None uses the domain default.
        num_trials: trials per task.
        max_concurrency: concurrent simulations.
        max_steps: per-simulation step cap.
        temperature / top_p / max_tokens / thinking / tool_parser: agent
            inference settings, forwarded to the HermesAgent factory.
        task_ids: explicit task IDs; when given, start/end indices are ignored.
        start_index / end_index: index window over the task list
            (end_index == -1 means "to the end").
        log_dir: output directory for the results JSON.
        seed: tau2 run seed.
    """
    # Resolve OpenRouter shorthand
    if base_url and base_url.strip().lower() == "openrouter":
        base_url = OPENROUTER_BASE_URL
    is_openrouter = base_url and "openrouter" in base_url.lower()
    # litellm requires the "openrouter/" prefix to route correctly
    if is_openrouter and not model.startswith("openrouter/"):
        model = f"openrouter/{model}"
    if is_openrouter and not user_model.startswith("openrouter/"):
        user_model = f"openrouter/{user_model}"
    # Resolve API key
    if is_openrouter:
        api_key = api_key or os.environ.get("OPENROUTER_API_KEY") or os.environ.get("OPENAI_API_KEY")
        # litellm reads OPENAI_API_KEY for base_url overrides; set it so the
        # user simulator's generate() call also authenticates correctly.
        if api_key and not os.environ.get("OPENAI_API_KEY"):
            os.environ["OPENAI_API_KEY"] = api_key
    else:
        api_key = api_key or os.environ.get("OPENAI_API_KEY")
    _register_agent(
        model=model,
        base_url=base_url,
        api_key=api_key,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        thinking=thinking,
        tool_parser=tool_parser,
    )
    # Load tasks — task_ids in tau2 are strings like "task_1"
    tasks = get_tasks(
        task_set_name=env_name,
        task_split_name=task_split,
        task_ids=[str(i) for i in task_ids] if task_ids else None,
    )
    # Index-window slicing applies only when explicit task IDs were NOT given.
    if not task_ids and (end_index != -1 or start_index != 0):
        end = end_index if end_index != -1 else len(tasks)
        tasks = tasks[start_index:end]
    logger.info(
        "Running tau2-%s eval: %d tasks, %d trial(s), concurrency=%d",
        env_name, len(tasks), num_trials, max_concurrency,
    )
    # Results file is keyed by domain + the last path segment of the model name.
    save_path = Path(log_dir) / f"tau2-{env_name}-{model.split('/')[-1]}.json"
    save_path.parent.mkdir(parents=True, exist_ok=True)
    # Pass api_key/base_url to user sim via llm_args so tau2's generate() authenticates.
    # When using OpenRouter for the user sim, mirror the agent's key + endpoint.
    user_llm_args: dict = {}
    if is_openrouter and api_key:
        user_llm_args["api_key"] = api_key
        user_llm_args["base_url"] = base_url
    config = TextRunConfig(
        domain=env_name,
        agent=AGENT_NAME,
        user="user_simulator",
        llm_agent=model,
        llm_args_agent={},
        llm_user=user_model,
        llm_args_user=user_llm_args,
        num_trials=num_trials,
        max_steps=max_steps,
        max_concurrency=max_concurrency,
        seed=seed,
    )
    results = run_tasks(
        config,
        tasks,
        save_path=save_path,
        console_display=True,
        # ALL: respects each task's reward_basis. NL assertions are skipped
        # gracefully (scored as pass) rather than raising an error, so tasks
        # are evaluated only on their actual basis components (DB, ACTION, etc.)
        evaluation_type=EvaluationType.ALL,
    )
    logger.info("Results saved to %s", save_path)
    return results
def main():
    """CLI entry point: parse arguments and launch the tau2-bench run."""
    cli = argparse.ArgumentParser(
        description="Run tau2-bench evaluation with Hermes Agent (requires Python 3.12+)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # --- endpoint / agent model ---
    cli.add_argument(
        "--model", required=True,
        help="litellm model string, e.g. 'openrouter/anthropic/claude-sonnet-4-5' or 'gpt-4o'",
    )
    cli.add_argument(
        "--base-url", default=None,
        help="API base URL. Use 'openrouter' as shorthand for https://openrouter.ai/api/v1.",
    )
    cli.add_argument("--api-key", default=None, help="API key (falls back to OPENROUTER_API_KEY / OPENAI_API_KEY)")
    # --- sampling ---
    cli.add_argument("--temperature", type=float, default=1.0,
                     help="Sampling temperature. NVIDIA used 1.0 for nemotron-super.")
    cli.add_argument("--top-p", type=float, default=0.95,
                     help="Nucleus sampling. NVIDIA used 0.95 for nemotron-super.")
    cli.add_argument("--max-tokens", type=int, default=None)
    cli.add_argument("--thinking", action="store_true", default=False,
                     help="Enable reasoning/thinking mode (use_reasoning=true). "
                          "Required to match NVIDIA's reported nemotron-super scores.")
    cli.add_argument("--tool-parser", default=None,
                     help="Tool call parser to use (e.g. 'qwen3_coder'). When set, tools are "
                          "embedded in the system prompt as <tools> XML and responses are parsed "
                          "from raw text instead of using OpenAI function calling format.")
    # --- user simulator ---
    cli.add_argument(
        "--user-model", default="qwen/qwen3-235b-a22b-2507:nitro",
        help="litellm model string for the tau2 user simulator. "
             "Defaults to qwen/qwen3-235b-a22b-2507:nitro (instruct, non-thinking) to match NVIDIA's eval setup. "
             "When using --base-url openrouter the openrouter/ prefix is added automatically.",
    )
    # --- task selection ---
    cli.add_argument(
        "--env", default="retail",
        choices=["retail", "airline", "telecom", "banking_knowledge", "mock"],
    )
    cli.add_argument(
        "--task-split", default=None,
        help="Task split name (e.g. 'base'). Defaults to the domain default.",
    )
    cli.add_argument("--num-trials", type=int, default=1)
    cli.add_argument("--max-concurrency", type=int, default=8)
    cli.add_argument("--max-steps", type=int, default=50)
    cli.add_argument(
        "--task-ids", nargs="*", default=None,
        help="Specific task IDs to run (tau2 task IDs are strings like 'task_1')",
    )
    cli.add_argument("--start-index", type=int, default=0)
    cli.add_argument("--end-index", type=int, default=-1)
    cli.add_argument("--seed", type=int, default=10)
    cli.add_argument("--log-dir", default="results/tau2bench")

    opts = cli.parse_args()
    run_eval(
        model=opts.model,
        base_url=opts.base_url,
        api_key=opts.api_key,
        user_model=opts.user_model,
        env_name=opts.env,
        task_split=opts.task_split,
        num_trials=opts.num_trials,
        max_concurrency=opts.max_concurrency,
        max_steps=opts.max_steps,
        temperature=opts.temperature,
        top_p=opts.top_p,
        max_tokens=opts.max_tokens,
        thinking=opts.thinking,
        tool_parser=opts.tool_parser,
        task_ids=opts.task_ids,
        start_index=opts.start_index,
        end_index=opts.end_index,
        log_dir=opts.log_dir,
        seed=opts.seed,
    )
# Script entry point.
if __name__ == "__main__":
    main()

View File

@@ -72,6 +72,8 @@ rl = [
"wandb>=0.15.0,<1",
]
yc-bench = ["yc-bench @ git+https://github.com/collinear-ai/yc-bench.git ; python_version >= '3.12'"]
taubench = ["tau-bench @ git+https://github.com/sierra-research/tau-bench.git"]
tau2bench = ["tau2 @ git+https://github.com/sierra-research/tau2-bench.git"]
all = [
"hermes-agent[modal]",
"hermes-agent[daytona]",