mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-29 05:06:48 +08:00
Compare commits
2 Commits
feat/head-
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fbed199672 | ||
|
|
2d13eb9795 |
@@ -8,7 +8,6 @@ the first 6 and last 4 characters for debuggability.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
@@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Known API key prefixes -- match the prefix + contiguous token chars
|
||||
_PREFIX_PATTERNS = [
|
||||
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter / Anthropic (sk-ant-*)
|
||||
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter
|
||||
r"ghp_[A-Za-z0-9]{10,}", # GitHub PAT (classic)
|
||||
r"github_pat_[A-Za-z0-9_]{10,}", # GitHub PAT (fine-grained)
|
||||
r"xox[baprs]-[A-Za-z0-9-]{10,}", # Slack tokens
|
||||
@@ -26,18 +25,6 @@ _PREFIX_PATTERNS = [
|
||||
r"fc-[A-Za-z0-9]{10,}", # Firecrawl
|
||||
r"bb_live_[A-Za-z0-9_-]{10,}", # BrowserBase
|
||||
r"gAAAA[A-Za-z0-9_=-]{20,}", # Codex encrypted tokens
|
||||
r"AKIA[A-Z0-9]{16}", # AWS Access Key ID
|
||||
r"sk_live_[A-Za-z0-9]{10,}", # Stripe secret key (live)
|
||||
r"sk_test_[A-Za-z0-9]{10,}", # Stripe secret key (test)
|
||||
r"rk_live_[A-Za-z0-9]{10,}", # Stripe restricted key
|
||||
r"SG\.[A-Za-z0-9_-]{10,}", # SendGrid API key
|
||||
r"hf_[A-Za-z0-9]{10,}", # HuggingFace token
|
||||
r"r8_[A-Za-z0-9]{10,}", # Replicate API token
|
||||
r"npm_[A-Za-z0-9]{10,}", # npm access token
|
||||
r"pypi-[A-Za-z0-9_-]{10,}", # PyPI API token
|
||||
r"dop_v1_[A-Za-z0-9]{10,}", # DigitalOcean PAT
|
||||
r"doo_v1_[A-Za-z0-9]{10,}", # DigitalOcean OAuth
|
||||
r"am_[A-Za-z0-9_-]{10,}", # AgentMail API key
|
||||
]
|
||||
|
||||
# ENV assignment patterns: KEY=value where KEY contains a secret-like name
|
||||
@@ -65,18 +52,6 @@ _TELEGRAM_RE = re.compile(
|
||||
r"(bot)?(\d{8,}):([-A-Za-z0-9_]{30,})",
|
||||
)
|
||||
|
||||
# Private key blocks: -----BEGIN RSA PRIVATE KEY----- ... -----END RSA PRIVATE KEY-----
|
||||
_PRIVATE_KEY_RE = re.compile(
|
||||
r"-----BEGIN[A-Z ]*PRIVATE KEY-----[\s\S]*?-----END[A-Z ]*PRIVATE KEY-----"
|
||||
)
|
||||
|
||||
# Database connection strings: protocol://user:PASSWORD@host
|
||||
# Catches postgres, mysql, mongodb, redis, amqp URLs and redacts the password
|
||||
_DB_CONNSTR_RE = re.compile(
|
||||
r"((?:postgres(?:ql)?|mysql|mongodb(?:\+srv)?|redis|amqp)://[^:]+:)([^@]+)(@)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# E.164 phone numbers: +<country><number>, 7-15 digits
|
||||
# Negative lookahead prevents matching hex strings or identifiers
|
||||
_SIGNAL_PHONE_RE = re.compile(r"(\+[1-9]\d{6,14})(?![A-Za-z0-9])")
|
||||
@@ -98,12 +73,9 @@ def redact_sensitive_text(text: str) -> str:
|
||||
"""Apply all redaction patterns to a block of text.
|
||||
|
||||
Safe to call on any string -- non-matching text passes through unchanged.
|
||||
Disabled when security.redact_secrets is false in config.yaml.
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
if os.getenv("HERMES_REDACT_SECRETS", "").lower() in ("0", "false", "no", "off"):
|
||||
return text
|
||||
|
||||
# Known prefixes (sk-, ghp_, etc.)
|
||||
text = _PREFIX_RE.sub(lambda m: _mask_token(m.group(1)), text)
|
||||
@@ -133,12 +105,6 @@ def redact_sensitive_text(text: str) -> str:
|
||||
return f"{prefix}{digits}:***"
|
||||
text = _TELEGRAM_RE.sub(_redact_telegram, text)
|
||||
|
||||
# Private key blocks
|
||||
text = _PRIVATE_KEY_RE.sub("[REDACTED PRIVATE KEY]", text)
|
||||
|
||||
# Database connection string passwords
|
||||
text = _DB_CONNSTR_RE.sub(lambda m: f"{m.group(1)}***{m.group(3)}", text)
|
||||
|
||||
# E.164 phone numbers (Signal, WhatsApp)
|
||||
def _redact_phone(m):
|
||||
phone = m.group(1)
|
||||
|
||||
@@ -555,6 +555,21 @@ toolsets:
|
||||
# args: ["-y", "@modelcontextprotocol/server-github"]
|
||||
# env:
|
||||
# GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."
|
||||
#
|
||||
# Sampling (server-initiated LLM requests) — enabled by default.
|
||||
# Per-server config under the 'sampling' key:
|
||||
# analysis:
|
||||
# command: npx
|
||||
# args: ["-y", "analysis-server"]
|
||||
# sampling:
|
||||
# enabled: true # default: true
|
||||
# model: "gemini-3-flash" # override model (optional)
|
||||
# max_tokens_cap: 4096 # max tokens per request
|
||||
# timeout: 30 # LLM call timeout (seconds)
|
||||
# max_rpm: 10 # max requests per minute
|
||||
# allowed_models: [] # model whitelist (empty = all)
|
||||
# max_tool_rounds: 5 # tool loop limit (0 = disable)
|
||||
# log_level: "info" # audit verbosity
|
||||
|
||||
# =============================================================================
|
||||
# Voice Transcription (Speech-to-Text)
|
||||
|
||||
7
cli.py
7
cli.py
@@ -364,13 +364,6 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
if model:
|
||||
os.environ[model_env] = model
|
||||
|
||||
# Security settings
|
||||
security_config = defaults.get("security", {})
|
||||
if isinstance(security_config, dict):
|
||||
redact = security_config.get("redact_secrets")
|
||||
if redact is not None:
|
||||
os.environ["HERMES_REDACT_SECRETS"] = str(redact).lower()
|
||||
|
||||
return defaults
|
||||
|
||||
# Load configuration at module startup
|
||||
|
||||
@@ -52,7 +52,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
|
||||
try:
|
||||
DIRECTORY_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(DIRECTORY_PATH, "w", encoding="utf-8") as f:
|
||||
with open(DIRECTORY_PATH, "w") as f:
|
||||
json.dump(directory, f, indent=2, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.warning("Channel directory: failed to write: %s", e)
|
||||
@@ -115,7 +115,7 @@ def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]:
|
||||
|
||||
entries = []
|
||||
try:
|
||||
with open(sessions_path, encoding="utf-8") as f:
|
||||
with open(sessions_path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
seen_ids = set()
|
||||
@@ -147,7 +147,7 @@ def load_directory() -> Dict[str, Any]:
|
||||
if not DIRECTORY_PATH.exists():
|
||||
return {"updated_at": None, "platforms": {}}
|
||||
try:
|
||||
with open(DIRECTORY_PATH, encoding="utf-8") as f:
|
||||
with open(DIRECTORY_PATH) as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return {"updated_at": None, "platforms": {}}
|
||||
|
||||
@@ -73,7 +73,7 @@ def _find_session_id(platform: str, chat_id: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(_SESSIONS_INDEX, encoding="utf-8") as f:
|
||||
with open(_SESSIONS_INDEX) as f:
|
||||
data = json.load(f)
|
||||
except Exception:
|
||||
return None
|
||||
@@ -103,7 +103,7 @@ def _append_to_jsonl(session_id: str, message: dict) -> None:
|
||||
"""Append a message to the JSONL transcript file."""
|
||||
transcript_path = _SESSIONS_DIR / f"{session_id}.jsonl"
|
||||
try:
|
||||
with open(transcript_path, "a", encoding="utf-8") as f:
|
||||
with open(transcript_path, "a") as f:
|
||||
f.write(json.dumps(message, ensure_ascii=False) + "\n")
|
||||
except Exception as e:
|
||||
logger.debug("Mirror JSONL write failed: %s", e)
|
||||
|
||||
@@ -118,12 +118,6 @@ if _config_path.exists():
|
||||
_tz_cfg = _cfg.get("timezone", "")
|
||||
if _tz_cfg and isinstance(_tz_cfg, str) and "HERMES_TIMEZONE" not in os.environ:
|
||||
os.environ["HERMES_TIMEZONE"] = _tz_cfg.strip()
|
||||
# Security settings
|
||||
_security_cfg = _cfg.get("security", {})
|
||||
if isinstance(_security_cfg, dict):
|
||||
_redact = _security_cfg.get("redact_secrets")
|
||||
if _redact is not None:
|
||||
os.environ["HERMES_REDACT_SECRETS"] = str(_redact).lower()
|
||||
except Exception:
|
||||
pass # Non-fatal; gateway can still run with .env values
|
||||
|
||||
|
||||
@@ -342,7 +342,7 @@ class SessionStore:
|
||||
|
||||
if sessions_file.exists():
|
||||
try:
|
||||
with open(sessions_file, "r", encoding="utf-8") as f:
|
||||
with open(sessions_file, "r") as f:
|
||||
data = json.load(f)
|
||||
for key, entry_data in data.items():
|
||||
self._entries[key] = SessionEntry.from_dict(entry_data)
|
||||
@@ -357,7 +357,7 @@ class SessionStore:
|
||||
sessions_file = self.sessions_dir / "sessions.json"
|
||||
|
||||
data = {key: entry.to_dict() for key, entry in self._entries.items()}
|
||||
with open(sessions_file, "w", encoding="utf-8") as f:
|
||||
with open(sessions_file, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
def _generate_session_key(self, source: SessionSource) -> str:
|
||||
@@ -681,7 +681,7 @@ class SessionStore:
|
||||
|
||||
# Also write legacy JSONL (keeps existing tooling working during transition)
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
with open(transcript_path, "a", encoding="utf-8") as f:
|
||||
with open(transcript_path, "a") as f:
|
||||
f.write(json.dumps(message, ensure_ascii=False) + "\n")
|
||||
|
||||
def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
|
||||
@@ -708,7 +708,7 @@ class SessionStore:
|
||||
|
||||
# JSONL: overwrite the file
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
with open(transcript_path, "w", encoding="utf-8") as f:
|
||||
with open(transcript_path, "w") as f:
|
||||
for msg in messages:
|
||||
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
|
||||
|
||||
@@ -730,7 +730,7 @@ class SessionStore:
|
||||
return []
|
||||
|
||||
messages = []
|
||||
with open(transcript_path, "r", encoding="utf-8") as f:
|
||||
with open(transcript_path, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
|
||||
@@ -759,16 +759,8 @@ def load_config() -> Dict[str, Any]:
|
||||
return config
|
||||
|
||||
|
||||
_COMMENTED_SECTIONS = """
|
||||
# ── Security ──────────────────────────────────────────────────────────
|
||||
# API keys, tokens, and passwords are redacted from tool output by default.
|
||||
# Set to false to see full values (useful for debugging auth issues).
|
||||
#
|
||||
# security:
|
||||
# redact_secrets: false
|
||||
|
||||
# ── Fallback Model ────────────────────────────────────────────────────
|
||||
# Automatic provider failover when primary is unavailable.
|
||||
_FALLBACK_MODEL_COMMENT = """
|
||||
# Fallback model — automatic provider failover when primary is unavailable.
|
||||
# Uncomment and configure to enable. Triggers on rate limits (429),
|
||||
# overload (529), service errors (503), or connection failures.
|
||||
#
|
||||
@@ -796,18 +788,10 @@ def save_config(config: Dict[str, Any]):
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(config, f, default_flow_style=False, sort_keys=False)
|
||||
# Append commented-out sections for features that are off by default
|
||||
# or only relevant when explicitly configured. Skip sections the
|
||||
# user has already uncommented and configured.
|
||||
sections = []
|
||||
sec = config.get("security", {})
|
||||
if not sec or sec.get("redact_secrets") is None:
|
||||
sections.append("security")
|
||||
fb = config.get("fallback_model", {})
|
||||
# Append commented-out fallback_model docs if user hasn't configured it
|
||||
fb = config.get("fallback_model")
|
||||
if not fb or not (fb.get("provider") and fb.get("model")):
|
||||
sections.append("fallback")
|
||||
if sections:
|
||||
f.write(_COMMENTED_SECTIONS)
|
||||
f.write(_FALLBACK_MODEL_COMMENT)
|
||||
|
||||
|
||||
def load_env() -> Dict[str, str]:
|
||||
|
||||
55
run_agent.py
55
run_agent.py
@@ -3092,14 +3092,9 @@ class AIAgent:
|
||||
)
|
||||
self._iters_since_skill = 0
|
||||
|
||||
# Honcho prefetch: retrieve user context for system prompt injection.
|
||||
# Only on the FIRST turn of a session (empty history). On subsequent
|
||||
# turns the model already has all prior context in its conversation
|
||||
# history, and the Honcho context is baked into the stored system
|
||||
# prompt — re-fetching it would change the system message and break
|
||||
# Anthropic prompt caching.
|
||||
# Honcho prefetch: retrieve user context for system prompt injection
|
||||
self._honcho_context = ""
|
||||
if self._honcho and self._honcho_session_key and not conversation_history:
|
||||
if self._honcho and self._honcho_session_key:
|
||||
try:
|
||||
self._honcho_context = self._honcho_prefetch(user_message)
|
||||
except Exception as e:
|
||||
@@ -3117,42 +3112,14 @@ class AIAgent:
|
||||
# Built once on first call, reused for all subsequent calls.
|
||||
# Only rebuilt after context compression events (which invalidate
|
||||
# the cache and reload memory from disk).
|
||||
#
|
||||
# For continuing sessions (gateway creates a fresh AIAgent per
|
||||
# message), we load the stored system prompt from the session DB
|
||||
# instead of rebuilding. Rebuilding would pick up memory changes
|
||||
# from disk that the model already knows about (it wrote them!),
|
||||
# producing a different system prompt and breaking the Anthropic
|
||||
# prefix cache.
|
||||
if self._cached_system_prompt is None:
|
||||
stored_prompt = None
|
||||
if conversation_history and self._session_db:
|
||||
self._cached_system_prompt = self._build_system_prompt(system_message)
|
||||
# Store the system prompt snapshot in SQLite
|
||||
if self._session_db:
|
||||
try:
|
||||
session_row = self._session_db.get_session(self.session_id)
|
||||
if session_row:
|
||||
stored_prompt = session_row.get("system_prompt") or None
|
||||
except Exception:
|
||||
pass # Fall through to build fresh
|
||||
|
||||
if stored_prompt:
|
||||
# Continuing session — reuse the exact system prompt from
|
||||
# the previous turn so the Anthropic cache prefix matches.
|
||||
self._cached_system_prompt = stored_prompt
|
||||
else:
|
||||
# First turn of a new session — build from scratch.
|
||||
self._cached_system_prompt = self._build_system_prompt(system_message)
|
||||
# Bake Honcho context into the prompt so it's stable for
|
||||
# the entire session (not re-fetched per turn).
|
||||
if self._honcho_context:
|
||||
self._cached_system_prompt = (
|
||||
self._cached_system_prompt + "\n\n" + self._honcho_context
|
||||
).strip()
|
||||
# Store the system prompt snapshot in SQLite
|
||||
if self._session_db:
|
||||
try:
|
||||
self._session_db.update_system_prompt(self.session_id, self._cached_system_prompt)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB update_system_prompt failed: %s", e)
|
||||
self._session_db.update_system_prompt(self.session_id, self._cached_system_prompt)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB update_system_prompt failed: %s", e)
|
||||
|
||||
active_system_prompt = self._cached_system_prompt
|
||||
|
||||
@@ -3277,13 +3244,11 @@ class AIAgent:
|
||||
# Build the final system message: cached prompt + ephemeral system prompt.
|
||||
# The ephemeral part is appended here (not baked into the cached prompt)
|
||||
# so it stays out of the session DB and logs.
|
||||
# Note: Honcho context is baked into _cached_system_prompt on the first
|
||||
# turn and stored in the session DB, so it does NOT need to be injected
|
||||
# here. This keeps the system message identical across all turns in a
|
||||
# session, maximizing Anthropic prompt cache hits.
|
||||
effective_system = active_system_prompt or ""
|
||||
if self.ephemeral_system_prompt:
|
||||
effective_system = (effective_system + "\n\n" + self.ephemeral_system_prompt).strip()
|
||||
if self._honcho_context:
|
||||
effective_system = (effective_system + "\n\n" + self._honcho_context).strip()
|
||||
if effective_system:
|
||||
api_messages = [{"role": "system", "content": effective_system}] + api_messages
|
||||
|
||||
|
||||
@@ -321,6 +321,32 @@ mcp_servers:
|
||||
|
||||
All tools from all servers are registered and available simultaneously. Each server's tools are prefixed with its name to avoid collisions.
|
||||
|
||||
## Sampling (Server-Initiated LLM Requests)
|
||||
|
||||
Hermes supports MCP's `sampling/createMessage` capability — MCP servers can request LLM completions through the agent during tool execution. This enables agent-in-the-loop workflows (data analysis, content generation, decision-making).
|
||||
|
||||
Sampling is **enabled by default**. Configure per server:
|
||||
|
||||
```yaml
|
||||
mcp_servers:
|
||||
my_server:
|
||||
command: "npx"
|
||||
args: ["-y", "my-mcp-server"]
|
||||
sampling:
|
||||
enabled: true # default: true
|
||||
model: "gemini-3-flash" # model override (optional)
|
||||
max_tokens_cap: 4096 # max tokens per request
|
||||
timeout: 30 # LLM call timeout (seconds)
|
||||
max_rpm: 10 # max requests per minute
|
||||
allowed_models: [] # model whitelist (empty = all)
|
||||
max_tool_rounds: 5 # tool loop limit (0 = disable)
|
||||
log_level: "info" # audit verbosity
|
||||
```
|
||||
|
||||
Servers can also include `tools` in sampling requests for multi-turn tool-augmented workflows. The `max_tool_rounds` config prevents infinite tool loops. Per-server audit metrics (requests, errors, tokens, tool use count) are tracked via `get_mcp_status()`.
|
||||
|
||||
Disable sampling for untrusted servers with `sampling: { enabled: false }`.
|
||||
|
||||
## Notes
|
||||
|
||||
- MCP tools are called synchronously from the agent's perspective but run asynchronously on a dedicated background event loop
|
||||
|
||||
@@ -1040,136 +1040,3 @@ class TestMaxTokensParam:
|
||||
agent.base_url = "https://openrouter.ai/api/v1/api.openai.com"
|
||||
result = agent._max_tokens_param(4096)
|
||||
assert result == {"max_tokens": 4096}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System prompt stability for prompt caching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSystemPromptStability:
|
||||
"""Verify that the system prompt stays stable across turns for cache hits."""
|
||||
|
||||
def test_stored_prompt_reused_for_continuing_session(self, agent):
|
||||
"""When conversation_history is non-empty and session DB has a stored
|
||||
prompt, it should be reused instead of rebuilding from disk."""
|
||||
stored = "You are helpful. [stored from turn 1]"
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_session.return_value = {"system_prompt": stored}
|
||||
agent._session_db = mock_db
|
||||
|
||||
# Simulate a continuing session with history
|
||||
history = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
]
|
||||
|
||||
# First call — _cached_system_prompt is None, history is non-empty
|
||||
agent._cached_system_prompt = None
|
||||
|
||||
# Patch run_conversation internals to just test the system prompt logic.
|
||||
# We'll call the prompt caching block directly by simulating what
|
||||
# run_conversation does.
|
||||
conversation_history = history
|
||||
|
||||
# The block under test (from run_conversation):
|
||||
if agent._cached_system_prompt is None:
|
||||
stored_prompt = None
|
||||
if conversation_history and agent._session_db:
|
||||
try:
|
||||
session_row = agent._session_db.get_session(agent.session_id)
|
||||
if session_row:
|
||||
stored_prompt = session_row.get("system_prompt") or None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if stored_prompt:
|
||||
agent._cached_system_prompt = stored_prompt
|
||||
|
||||
assert agent._cached_system_prompt == stored
|
||||
mock_db.get_session.assert_called_once_with(agent.session_id)
|
||||
|
||||
def test_fresh_build_when_no_history(self, agent):
|
||||
"""On the first turn (no history), system prompt should be built fresh."""
|
||||
mock_db = MagicMock()
|
||||
agent._session_db = mock_db
|
||||
|
||||
agent._cached_system_prompt = None
|
||||
conversation_history = []
|
||||
|
||||
# The block under test:
|
||||
if agent._cached_system_prompt is None:
|
||||
stored_prompt = None
|
||||
if conversation_history and agent._session_db:
|
||||
session_row = agent._session_db.get_session(agent.session_id)
|
||||
if session_row:
|
||||
stored_prompt = session_row.get("system_prompt") or None
|
||||
|
||||
if stored_prompt:
|
||||
agent._cached_system_prompt = stored_prompt
|
||||
else:
|
||||
agent._cached_system_prompt = agent._build_system_prompt()
|
||||
|
||||
# Should have built fresh, not queried the DB
|
||||
mock_db.get_session.assert_not_called()
|
||||
assert agent._cached_system_prompt is not None
|
||||
assert "Hermes Agent" in agent._cached_system_prompt
|
||||
|
||||
def test_fresh_build_when_db_has_no_prompt(self, agent):
|
||||
"""If the session DB has no stored prompt, build fresh even with history."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_session.return_value = {"system_prompt": ""}
|
||||
agent._session_db = mock_db
|
||||
|
||||
agent._cached_system_prompt = None
|
||||
conversation_history = [{"role": "user", "content": "hi"}]
|
||||
|
||||
if agent._cached_system_prompt is None:
|
||||
stored_prompt = None
|
||||
if conversation_history and agent._session_db:
|
||||
try:
|
||||
session_row = agent._session_db.get_session(agent.session_id)
|
||||
if session_row:
|
||||
stored_prompt = session_row.get("system_prompt") or None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if stored_prompt:
|
||||
agent._cached_system_prompt = stored_prompt
|
||||
else:
|
||||
agent._cached_system_prompt = agent._build_system_prompt()
|
||||
|
||||
# Empty string is falsy, so should fall through to fresh build
|
||||
assert "Hermes Agent" in agent._cached_system_prompt
|
||||
|
||||
def test_honcho_context_baked_into_prompt_on_first_turn(self, agent):
|
||||
"""Honcho context should be baked into _cached_system_prompt on
|
||||
the first turn, not injected separately per API call."""
|
||||
agent._honcho_context = "User prefers Python over JavaScript."
|
||||
agent._cached_system_prompt = None
|
||||
|
||||
# Simulate first turn: build fresh and bake in Honcho
|
||||
agent._cached_system_prompt = agent._build_system_prompt()
|
||||
if agent._honcho_context:
|
||||
agent._cached_system_prompt = (
|
||||
agent._cached_system_prompt + "\n\n" + agent._honcho_context
|
||||
).strip()
|
||||
|
||||
assert "User prefers Python over JavaScript" in agent._cached_system_prompt
|
||||
|
||||
def test_honcho_prefetch_skipped_on_continuing_session(self):
|
||||
"""Honcho prefetch should not be called when conversation_history
|
||||
is non-empty (continuing session)."""
|
||||
conversation_history = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi there"},
|
||||
]
|
||||
|
||||
# The guard: `not conversation_history` is False when history exists
|
||||
should_prefetch = not conversation_history
|
||||
assert should_prefetch is False
|
||||
|
||||
def test_honcho_prefetch_runs_on_first_turn(self):
|
||||
"""Honcho prefetch should run when conversation_history is empty."""
|
||||
conversation_history = []
|
||||
should_prefetch = not conversation_history
|
||||
assert should_prefetch is True
|
||||
|
||||
@@ -393,56 +393,5 @@ class TestStubSchemaDrift(unittest.TestCase):
|
||||
self.assertIn("mode", src)
|
||||
|
||||
|
||||
class TestHeadTailTruncation(unittest.TestCase):
|
||||
"""Tests for head+tail truncation of large stdout in execute_code."""
|
||||
|
||||
def _run(self, code):
|
||||
with patch("model_tools.handle_function_call", side_effect=_mock_handle_function_call):
|
||||
result = execute_code(
|
||||
code=code,
|
||||
task_id="test-task",
|
||||
enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
|
||||
)
|
||||
return json.loads(result)
|
||||
|
||||
def test_short_output_not_truncated(self):
|
||||
"""Output under MAX_STDOUT_BYTES should not be truncated."""
|
||||
result = self._run('print("small output")')
|
||||
self.assertEqual(result["status"], "success")
|
||||
self.assertIn("small output", result["output"])
|
||||
self.assertNotIn("TRUNCATED", result["output"])
|
||||
|
||||
def test_large_output_preserves_head_and_tail(self):
|
||||
"""Output exceeding MAX_STDOUT_BYTES keeps both head and tail."""
|
||||
code = '''
|
||||
# Print HEAD marker, then filler, then TAIL marker
|
||||
print("HEAD_MARKER_START")
|
||||
for i in range(15000):
|
||||
print(f"filler_line_{i:06d}_padding_to_fill_buffer")
|
||||
print("TAIL_MARKER_END")
|
||||
'''
|
||||
result = self._run(code)
|
||||
self.assertEqual(result["status"], "success")
|
||||
output = result["output"]
|
||||
# Head should be preserved
|
||||
self.assertIn("HEAD_MARKER_START", output)
|
||||
# Tail should be preserved (this is the key improvement)
|
||||
self.assertIn("TAIL_MARKER_END", output)
|
||||
# Truncation notice should be present
|
||||
self.assertIn("TRUNCATED", output)
|
||||
|
||||
def test_truncation_notice_format(self):
|
||||
"""Truncation notice includes character counts."""
|
||||
code = '''
|
||||
for i in range(15000):
|
||||
print(f"padding_line_{i:06d}_xxxxxxxxxxxxxxxxxxxxxxxxxx")
|
||||
'''
|
||||
result = self._run(code)
|
||||
output = result["output"]
|
||||
if "TRUNCATED" in output:
|
||||
self.assertIn("chars omitted", output)
|
||||
self.assertIn("total", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -38,7 +38,6 @@ class TestReadFileHandler:
|
||||
def test_returns_file_content(self, mock_get):
|
||||
mock_ops = MagicMock()
|
||||
result_obj = MagicMock()
|
||||
result_obj.content = "line1\nline2"
|
||||
result_obj.to_dict.return_value = {"content": "line1\nline2", "total_lines": 2}
|
||||
mock_ops.read_file.return_value = result_obj
|
||||
mock_get.return_value = mock_ops
|
||||
@@ -53,7 +52,6 @@ class TestReadFileHandler:
|
||||
def test_custom_offset_and_limit(self, mock_get):
|
||||
mock_ops = MagicMock()
|
||||
result_obj = MagicMock()
|
||||
result_obj.content = "line10"
|
||||
result_obj.to_dict.return_value = {"content": "line10", "total_lines": 50}
|
||||
mock_ops.read_file.return_value = result_obj
|
||||
mock_get.return_value = mock_ops
|
||||
|
||||
@@ -1489,3 +1489,781 @@ class TestUtilityToolRegistration:
|
||||
assert entry.check_fn() is False
|
||||
|
||||
_servers.pop("chk", None)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# SamplingHandler tests
|
||||
# ===========================================================================
|
||||
|
||||
import math
|
||||
import time
|
||||
|
||||
from mcp.types import (
|
||||
CreateMessageResult,
|
||||
CreateMessageResultWithTools,
|
||||
ErrorData,
|
||||
SamplingCapability,
|
||||
SamplingToolsCapability,
|
||||
TextContent,
|
||||
ToolUseContent,
|
||||
)
|
||||
|
||||
from tools.mcp_tool import SamplingHandler, _safe_numeric
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for sampling tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_sampling_params(
|
||||
messages=None,
|
||||
max_tokens=100,
|
||||
system_prompt=None,
|
||||
model_preferences=None,
|
||||
temperature=None,
|
||||
stop_sequences=None,
|
||||
tools=None,
|
||||
tool_choice=None,
|
||||
):
|
||||
"""Create a fake CreateMessageRequestParams using SimpleNamespace.
|
||||
|
||||
Each message must have a ``content_as_list`` attribute that mirrors
|
||||
the SDK helper so that ``_convert_messages`` works correctly.
|
||||
"""
|
||||
if messages is None:
|
||||
content = SimpleNamespace(text="Hello")
|
||||
msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
|
||||
messages = [msg]
|
||||
|
||||
params = SimpleNamespace(
|
||||
messages=messages,
|
||||
maxTokens=max_tokens,
|
||||
modelPreferences=model_preferences,
|
||||
temperature=temperature,
|
||||
stopSequences=stop_sequences,
|
||||
tools=tools,
|
||||
toolChoice=tool_choice,
|
||||
)
|
||||
if system_prompt is not None:
|
||||
params.systemPrompt = system_prompt
|
||||
return params
|
||||
|
||||
|
||||
def _make_llm_response(
|
||||
content="LLM response",
|
||||
model="test-model",
|
||||
finish_reason="stop",
|
||||
tool_calls=None,
|
||||
):
|
||||
"""Create a fake OpenAI chat completion response (text)."""
|
||||
message = SimpleNamespace(content=content, tool_calls=tool_calls)
|
||||
choice = SimpleNamespace(
|
||||
finish_reason=finish_reason,
|
||||
message=message,
|
||||
)
|
||||
usage = SimpleNamespace(total_tokens=42)
|
||||
return SimpleNamespace(choices=[choice], model=model, usage=usage)
|
||||
|
||||
|
||||
def _make_llm_tool_response(tool_calls_data=None, model="test-model"):
|
||||
"""Create a fake response with tool_calls.
|
||||
|
||||
``tool_calls_data``: list of (id, name, arguments_json) tuples.
|
||||
"""
|
||||
if tool_calls_data is None:
|
||||
tool_calls_data = [("call_1", "get_weather", '{"city": "London"}')]
|
||||
|
||||
tc_list = [
|
||||
SimpleNamespace(
|
||||
id=tc_id,
|
||||
function=SimpleNamespace(name=name, arguments=args),
|
||||
)
|
||||
for tc_id, name, args in tool_calls_data
|
||||
]
|
||||
return _make_llm_response(
|
||||
content=None,
|
||||
model=model,
|
||||
finish_reason="tool_calls",
|
||||
tool_calls=tc_list,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. _safe_numeric helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSafeNumeric:
|
||||
def test_int_passthrough(self):
|
||||
assert _safe_numeric(10, 5, int) == 10
|
||||
|
||||
def test_string_coercion(self):
|
||||
assert _safe_numeric("20", 5, int) == 20
|
||||
|
||||
def test_none_returns_default(self):
|
||||
assert _safe_numeric(None, 7, int) == 7
|
||||
|
||||
def test_inf_returns_default(self):
|
||||
assert _safe_numeric(float("inf"), 3.0, float) == 3.0
|
||||
|
||||
def test_nan_returns_default(self):
|
||||
assert _safe_numeric(float("nan"), 4.0, float) == 4.0
|
||||
|
||||
def test_below_minimum_clamps(self):
|
||||
assert _safe_numeric(-5, 10, int, minimum=1) == 1
|
||||
|
||||
def test_minimum_zero_allowed(self):
|
||||
assert _safe_numeric(0, 10, int, minimum=0) == 0
|
||||
|
||||
def test_non_numeric_string_returns_default(self):
|
||||
assert _safe_numeric("abc", 42, int) == 42
|
||||
|
||||
def test_float_coercion(self):
|
||||
assert _safe_numeric("3.5", 1.0, float) == 3.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. SamplingHandler initialization and config parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSamplingHandlerInit:
|
||||
def test_defaults(self):
|
||||
h = SamplingHandler("srv", {})
|
||||
assert h.server_name == "srv"
|
||||
assert h.max_rpm == 10
|
||||
assert h.timeout == 30
|
||||
assert h.max_tokens_cap == 4096
|
||||
assert h.max_tool_rounds == 5
|
||||
assert h.model_override is None
|
||||
assert h.allowed_models == []
|
||||
assert h.metrics == {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}
|
||||
|
||||
def test_custom_config(self):
|
||||
cfg = {
|
||||
"max_rpm": 20,
|
||||
"timeout": 60,
|
||||
"max_tokens_cap": 2048,
|
||||
"max_tool_rounds": 3,
|
||||
"model": "gpt-4o",
|
||||
"allowed_models": ["gpt-4o", "gpt-3.5-turbo"],
|
||||
"log_level": "debug",
|
||||
}
|
||||
h = SamplingHandler("custom", cfg)
|
||||
assert h.max_rpm == 20
|
||||
assert h.timeout == 60.0
|
||||
assert h.max_tokens_cap == 2048
|
||||
assert h.max_tool_rounds == 3
|
||||
assert h.model_override == "gpt-4o"
|
||||
assert h.allowed_models == ["gpt-4o", "gpt-3.5-turbo"]
|
||||
|
||||
def test_string_numeric_config_values(self):
|
||||
"""YAML sometimes delivers numeric values as strings."""
|
||||
cfg = {"max_rpm": "15", "timeout": "45.5", "max_tokens_cap": "1024"}
|
||||
h = SamplingHandler("s", cfg)
|
||||
assert h.max_rpm == 15
|
||||
assert h.timeout == 45.5
|
||||
assert h.max_tokens_cap == 1024
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Rate limiting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRateLimit:
    """Rate limiting on a handler configured with max_rpm=3."""

    def setup_method(self):
        """Create a fresh handler limited to three requests per minute."""
        self.handler = SamplingHandler("rl", {"max_rpm": 3})

    def test_allows_under_limit(self):
        """The first three checks are all accepted."""
        for _ in range(3):
            assert self.handler._check_rate_limit() is True

    def test_rejects_over_limit(self):
        """A fourth check after three accepted ones is rejected."""
        attempts = 0
        while attempts < 3:
            self.handler._check_rate_limit()
            attempts += 1
        assert self.handler._check_rate_limit() is False

    def test_window_expiry(self):
        """Old timestamps should be purged from the sliding window."""
        for _ in range(3):
            self.handler._check_rate_limit()
        # Replace the recorded timestamps in place with ones 61 seconds old.
        stale = time.time() - 61
        self.handler._rate_timestamps[:] = [stale, stale, stale]
        assert self.handler._check_rate_limit() is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Model resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResolveModel:
    """Model resolution from the config override and request hints."""

    def setup_method(self):
        """Create a handler with no model override configured."""
        self.handler = SamplingHandler("mr", {})

    @staticmethod
    def _prefs_with_hints(*names):
        """Build a preferences stand-in carrying one hint per given name."""
        return SimpleNamespace(hints=[SimpleNamespace(name=n) for n in names])

    def test_no_preference_no_override(self):
        """Without preferences or an override, no model is resolved."""
        resolved = self.handler._resolve_model(None)
        assert resolved is None

    def test_config_override_wins(self):
        """A configured override takes precedence over a hint."""
        self.handler.model_override = "override-model"
        preferences = self._prefs_with_hints("hint-model")
        assert self.handler._resolve_model(preferences) == "override-model"

    def test_hint_used_when_no_override(self):
        """The hint name is used when no override is set."""
        preferences = self._prefs_with_hints("hint-model")
        assert self.handler._resolve_model(preferences) == "hint-model"

    def test_empty_hints(self):
        """An empty hint list resolves to no model."""
        preferences = self._prefs_with_hints()
        assert self.handler._resolve_model(preferences) is None

    def test_hint_without_name(self):
        """A hint whose name is None resolves to no model."""
        preferences = self._prefs_with_hints(None)
        assert self.handler._resolve_model(preferences) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Message conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConvertMessages:
    """Conversion of sampling messages into chat-completion message dicts."""

    def setup_method(self):
        """Create a handler with default settings."""
        self.handler = SamplingHandler("mc", {})

    def _convert(self, message):
        """Run a single message through the handler's message conversion."""
        request = _make_sampling_params(messages=[message])
        return self.handler._convert_messages(request)

    def test_single_text_message(self):
        """A lone text block becomes a plain string content."""
        block = SimpleNamespace(text="Hello world")
        converted = self._convert(
            SimpleNamespace(role="user", content=block, content_as_list=[block])
        )
        assert len(converted) == 1
        assert converted[0] == {"role": "user", "content": "Hello world"}

    def test_image_message(self):
        """Text plus image blocks become a two-part content list."""
        caption = SimpleNamespace(text="Look at this")
        picture = SimpleNamespace(data="abc123", mimeType="image/png")
        converted = self._convert(
            SimpleNamespace(
                role="user",
                content=[caption, picture],
                content_as_list=[caption, picture],
            )
        )
        assert len(converted) == 1
        parts = converted[0]["content"]
        assert len(parts) == 2
        assert parts[0] == {"type": "text", "text": "Look at this"}
        image_part = parts[1]
        assert image_part["type"] == "image_url"
        assert "data:image/png;base64,abc123" in image_part["image_url"]["url"]

    def test_tool_result_message(self):
        """A tool-result block becomes a role='tool' message."""
        payload = SimpleNamespace(text="42 degrees")
        tool_result = SimpleNamespace(toolUseId="call_1", content=[payload])
        converted = self._convert(
            SimpleNamespace(
                role="user",
                content=[tool_result],
                content_as_list=[tool_result],
            )
        )
        assert len(converted) == 1
        entry = converted[0]
        assert entry["role"] == "tool"
        assert entry["tool_call_id"] == "call_1"
        assert entry["content"] == "42 degrees"

    def test_tool_use_message(self):
        """A tool-use block becomes an assistant message with tool_calls."""
        tool_use = SimpleNamespace(
            id="call_2", name="get_weather", input={"city": "London"}
        )
        converted = self._convert(
            SimpleNamespace(
                role="assistant",
                content=[tool_use],
                content_as_list=[tool_use],
            )
        )
        assert len(converted) == 1
        entry = converted[0]
        assert entry["role"] == "assistant"
        assert len(entry["tool_calls"]) == 1
        function = entry["tool_calls"][0]["function"]
        assert function["name"] == "get_weather"
        assert json.loads(function["arguments"]) == {"city": "London"}

    def test_mixed_text_and_tool_use(self):
        """Assistant message with both text and tool_calls."""
        preface = SimpleNamespace(text="Let me check the weather")
        tool_use = SimpleNamespace(
            id="call_3", name="get_weather", input={"city": "Paris"}
        )
        converted = self._convert(
            SimpleNamespace(
                role="assistant",
                content=[preface, tool_use],
                content_as_list=[preface, tool_use],
            )
        )
        assert len(converted) == 1
        assert converted[0]["content"] == "Let me check the weather"
        assert len(converted[0]["tool_calls"]) == 1

    def test_fallback_without_content_as_list(self):
        """When content_as_list is absent, falls back to content."""
        block = SimpleNamespace(text="Fallback text")
        converted = self._convert(SimpleNamespace(role="user", content=block))
        assert len(converted) == 1
        assert converted[0]["content"] == "Fallback text"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Text-only sampling callback (full flow)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSamplingCallbackText:
    """Full sampling-callback flow for plain text LLM responses."""

    def setup_method(self):
        """Create a handler with default settings."""
        self.handler = SamplingHandler("txt", {})

    @staticmethod
    def _aux_client(client, model="default-model"):
        """Patch the auxiliary-client lookup to return *client* and *model*."""
        return patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(client, model),
        )

    def test_text_response(self):
        """Full flow: text response returns CreateMessageResult."""
        llm = MagicMock()
        llm.chat.completions.create.return_value = _make_llm_response(
            content="Hello from LLM"
        )

        with self._aux_client(llm):
            request = _make_sampling_params()
            outcome = asyncio.run(self.handler(None, request))

        assert isinstance(outcome, CreateMessageResult)
        assert isinstance(outcome.content, TextContent)
        assert outcome.content.text == "Hello from LLM"
        assert outcome.model == "test-model"
        assert outcome.role == "assistant"
        assert outcome.stopReason == "endTurn"

    def test_system_prompt_prepended(self):
        """System prompt is inserted as the first message."""
        llm = MagicMock()
        llm.chat.completions.create.return_value = _make_llm_response()

        with self._aux_client(llm):
            request = _make_sampling_params(system_prompt="Be helpful")
            asyncio.run(self.handler(None, request))

        sent_messages = llm.chat.completions.create.call_args.kwargs["messages"]
        assert sent_messages[0] == {"role": "system", "content": "Be helpful"}

    def test_length_stop_reason(self):
        """finish_reason='length' maps to stopReason='maxTokens'."""
        llm = MagicMock()
        llm.chat.completions.create.return_value = _make_llm_response(
            finish_reason="length"
        )

        with self._aux_client(llm):
            request = _make_sampling_params()
            outcome = asyncio.run(self.handler(None, request))

        assert isinstance(outcome, CreateMessageResult)
        assert outcome.stopReason == "maxTokens"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. Tool use sampling callback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSamplingCallbackToolUse:
    """Sampling-callback flow when the LLM answers with tool calls."""

    def setup_method(self):
        """Create a handler with default settings."""
        self.handler = SamplingHandler("tu", {})

    @staticmethod
    def _aux_client(client, model="default-model"):
        """Patch the auxiliary-client lookup to return *client* and *model*."""
        return patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(client, model),
        )

    def test_tool_use_response(self):
        """LLM tool_calls response returns CreateMessageResultWithTools."""
        llm = MagicMock()
        llm.chat.completions.create.return_value = _make_llm_tool_response()

        with self._aux_client(llm):
            request = _make_sampling_params()
            outcome = asyncio.run(self.handler(None, request))

        assert isinstance(outcome, CreateMessageResultWithTools)
        assert outcome.stopReason == "toolUse"
        assert outcome.model == "test-model"
        assert len(outcome.content) == 1
        call = outcome.content[0]
        assert isinstance(call, ToolUseContent)
        assert call.name == "get_weather"
        assert call.id == "call_1"
        assert call.input == {"city": "London"}

    def test_multiple_tool_calls(self):
        """Multiple tool_calls in a single response."""
        requested_calls = [
            ("call_a", "func_a", '{"x": 1}'),
            ("call_b", "func_b", '{"y": 2}'),
        ]
        llm = MagicMock()
        llm.chat.completions.create.return_value = _make_llm_tool_response(
            tool_calls_data=requested_calls
        )

        with self._aux_client(llm):
            outcome = asyncio.run(self.handler(None, _make_sampling_params()))

        assert isinstance(outcome, CreateMessageResultWithTools)
        assert len(outcome.content) == 2
        assert outcome.content[0].name == "func_a"
        assert outcome.content[1].name == "func_b"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Tool loop governance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToolLoopGovernance:
    """Enforcement of the max_tool_rounds limit on consecutive tool responses."""

    @staticmethod
    def _aux_client(client, model="default-model"):
        """Patch the auxiliary-client lookup to return *client* and *model*."""
        return patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(client, model),
        )

    def test_max_tool_rounds_enforcement(self):
        """After max_tool_rounds consecutive tool responses, an error is returned."""
        handler = SamplingHandler("tl", {"max_tool_rounds": 2})
        llm = MagicMock()
        llm.chat.completions.create.return_value = _make_llm_tool_response()

        with self._aux_client(llm):
            request = _make_sampling_params()
            # The first two tool rounds are within the limit.
            for _ in range(2):
                allowed = asyncio.run(handler(None, request))
                assert isinstance(allowed, CreateMessageResultWithTools)
            # The third consecutive tool round goes over it.
            rejected = asyncio.run(handler(None, request))
            assert isinstance(rejected, ErrorData)
            assert "Tool loop limit exceeded" in rejected.message

    def test_text_response_resets_counter(self):
        """A text response resets the tool loop counter."""
        handler = SamplingHandler("tl2", {"max_tool_rounds": 1})
        llm = MagicMock()
        create = llm.chat.completions.create

        with self._aux_client(llm):
            # Tool response: uses up the single allowed round.
            create.return_value = _make_llm_tool_response()
            first = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(first, CreateMessageResultWithTools)

            # Text response: expected to reset the counter.
            create.return_value = _make_llm_response()
            second = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(second, CreateMessageResult)

            # Another tool response succeeds because the counter was reset.
            create.return_value = _make_llm_tool_response()
            third = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(third, CreateMessageResultWithTools)

    def test_max_tool_rounds_zero_disables(self):
        """max_tool_rounds=0 means tool loops are disabled entirely."""
        handler = SamplingHandler("tl3", {"max_tool_rounds": 0})
        llm = MagicMock()
        llm.chat.completions.create.return_value = _make_llm_tool_response()

        with self._aux_client(llm):
            outcome = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(outcome, ErrorData)
            assert "Tool loops disabled" in outcome.message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 9. Error paths: rate limit, timeout, no provider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSamplingErrors:
    """Error paths of the sampling callback: rate limit, timeout, no provider."""

    def test_rate_limit_error(self):
        """With max_rpm=1 the second request is rejected and counted as an error."""
        handler = SamplingHandler("rle", {"max_rpm": 1})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response()

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            # First call succeeds
            r1 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r1, CreateMessageResult)
            # Second call is rate limited
            r2 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r2, ErrorData)
            assert "rate limit" in r2.message.lower()
            assert handler.metrics["errors"] == 1

    def test_timeout_error(self):
        """An LLM call slower than the configured timeout yields a timeout error."""
        import threading

        handler = SamplingHandler("to", {"timeout": 0.05})
        fake_client = MagicMock()

        # The blocking call runs in a worker thread, which cannot be cancelled
        # once the handler gives up on it, and asyncio.run() joins the default
        # executor's threads before returning. A shared event lets the test
        # release the worker as soon as the handler has produced its result,
        # instead of stalling for the full 5-second safety cap.
        release = threading.Event()

        def slow_call(**kwargs):
            release.wait(5)  # safety cap in case release is never set
            return _make_llm_response()

        fake_client.chat.completions.create.side_effect = slow_call

        async def call_then_release():
            try:
                return await handler(None, _make_sampling_params())
            finally:
                release.set()

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            result = asyncio.run(call_then_release())
            assert isinstance(result, ErrorData)
            assert "timed out" in result.message.lower()
            assert handler.metrics["errors"] == 1

    def test_no_provider_error(self):
        """When no auxiliary client is available an error result is returned."""
        handler = SamplingHandler("np", {})

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(None, None),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(result, ErrorData)
            assert "No LLM provider" in result.message
            assert handler.metrics["errors"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 10. Model whitelist
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestModelWhitelist:
    """Filtering of the resolved model against the allowed_models list."""

    @staticmethod
    def _aux_client(client, model):
        """Patch the auxiliary-client lookup to return *client* and *model*."""
        return patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(client, model),
        )

    def test_allowed_model_passes(self):
        """A model on the whitelist is accepted."""
        handler = SamplingHandler("wl", {"allowed_models": ["gpt-4o", "test-model"]})
        llm = MagicMock()
        llm.chat.completions.create.return_value = _make_llm_response()

        with self._aux_client(llm, "test-model"):
            outcome = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(outcome, CreateMessageResult)

    def test_disallowed_model_rejected(self):
        """A model missing from the whitelist yields an error and bumps the error metric."""
        handler = SamplingHandler("wl2", {"allowed_models": ["gpt-4o"]})
        llm = MagicMock()

        with self._aux_client(llm, "gpt-3.5-turbo"):
            outcome = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(outcome, ErrorData)
            assert "not allowed" in outcome.message
            assert handler.metrics["errors"] == 1

    def test_empty_whitelist_allows_all(self):
        """An empty whitelist places no restriction on the model."""
        handler = SamplingHandler("wl3", {"allowed_models": []})
        llm = MagicMock()
        llm.chat.completions.create.return_value = _make_llm_response()

        with self._aux_client(llm, "any-model"):
            outcome = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(outcome, CreateMessageResult)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 11. Malformed tool_call arguments
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMalformedToolCallArgs:
    """Handling of tool-call arguments that are not a valid JSON string."""

    @staticmethod
    def _aux_client(client, model="default-model"):
        """Patch the auxiliary-client lookup to return *client* and *model*."""
        return patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(client, model),
        )

    def test_invalid_json_wrapped_as_raw(self):
        """Malformed JSON arguments get wrapped in {"_raw": ...}."""
        handler = SamplingHandler("mf", {})
        llm = MagicMock()
        llm.chat.completions.create.return_value = _make_llm_tool_response(
            tool_calls_data=[("call_x", "some_tool", "not valid json {{{")]
        )

        with self._aux_client(llm):
            outcome = asyncio.run(handler(None, _make_sampling_params()))

        assert isinstance(outcome, CreateMessageResultWithTools)
        call = outcome.content[0]
        assert isinstance(call, ToolUseContent)
        assert call.input == {"_raw": "not valid json {{{"}

    def test_dict_args_pass_through(self):
        """When arguments are already a dict, they pass through directly."""
        handler = SamplingHandler("mf2", {})

        # Hand-build a response whose tool call carries dict arguments.
        function = SimpleNamespace(name="do_stuff", arguments={"key": "val"})
        tool_call = SimpleNamespace(id="call_d", function=function)
        choice = SimpleNamespace(
            finish_reason="tool_calls",
            message=SimpleNamespace(content=None, tool_calls=[tool_call]),
        )
        llm_response = SimpleNamespace(
            choices=[choice],
            model="m",
            usage=SimpleNamespace(total_tokens=10),
        )

        llm = MagicMock()
        llm.chat.completions.create.return_value = llm_response

        with self._aux_client(llm):
            outcome = asyncio.run(handler(None, _make_sampling_params()))

        assert isinstance(outcome, CreateMessageResultWithTools)
        assert outcome.content[0].input == {"key": "val"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 12. Metrics tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMetricsTracking:
    """Request, token, tool-use and error counters kept on the handler."""

    @staticmethod
    def _aux_client(client, model):
        """Patch the auxiliary-client lookup to return *client* and *model*."""
        return patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(client, model),
        )

    def test_request_and_token_metrics(self):
        """A successful text call counts one request and its token usage."""
        handler = SamplingHandler("met", {})
        llm = MagicMock()
        llm.chat.completions.create.return_value = _make_llm_response()

        with self._aux_client(llm, "default-model"):
            asyncio.run(handler(None, _make_sampling_params()))

        counters = handler.metrics
        assert counters["requests"] == 1
        assert counters["tokens_used"] == 42
        assert counters["errors"] == 0

    def test_tool_use_count_metric(self):
        """A tool-call response increments the tool_use_count counter."""
        handler = SamplingHandler("met2", {})
        llm = MagicMock()
        llm.chat.completions.create.return_value = _make_llm_tool_response()

        with self._aux_client(llm, "default-model"):
            asyncio.run(handler(None, _make_sampling_params()))

        counters = handler.metrics
        assert counters["tool_use_count"] == 1
        assert counters["requests"] == 1

    def test_error_metric_incremented(self):
        """A missing provider counts as an error and not as a request."""
        handler = SamplingHandler("met3", {})

        with self._aux_client(None, None):
            asyncio.run(handler(None, _make_sampling_params()))

        counters = handler.metrics
        assert counters["errors"] == 1
        assert counters["requests"] == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 13. session_kwargs()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSessionKwargs:
    """Shape of the keyword arguments returned by session_kwargs()."""

    def test_returns_correct_keys(self):
        """Both sampling keys are present and the callback is the handler itself."""
        handler = SamplingHandler("sk", {})
        session_args = handler.session_kwargs()
        for key in ("sampling_callback", "sampling_capabilities"):
            assert key in session_args
        assert session_args["sampling_callback"] is handler

    def test_sampling_capabilities_type(self):
        """The capabilities entry uses the MCP capability types."""
        handler = SamplingHandler("sk2", {})
        capability = handler.session_kwargs()["sampling_capabilities"]
        assert isinstance(capability, SamplingCapability)
        assert isinstance(capability.tools, SamplingToolsCapability)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 14. MCPServerTask integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMCPServerTaskSamplingIntegration:
    """Wiring of SamplingHandler onto MCPServerTask instances."""

    def test_sampling_handler_created_when_enabled(self):
        """MCPServerTask.run() creates a SamplingHandler when sampling is enabled."""
        from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES

        server = MCPServerTask("int_test")
        config = {
            "command": "fake",
            "sampling": {"enabled": True, "max_rpm": 5},
        }
        # run() would attempt a real connection, so only the sampling setup
        # step is reproduced here and checked directly.
        server._config = config
        sampling_config = config.get("sampling", {})
        wants_sampling = sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES
        server._sampling = (
            SamplingHandler(server.name, sampling_config) if wants_sampling else None
        )

        assert server._sampling is not None
        assert isinstance(server._sampling, SamplingHandler)
        assert server._sampling.server_name == "int_test"
        assert server._sampling.max_rpm == 5

    def test_sampling_handler_none_when_disabled(self):
        """MCPServerTask._sampling is None when sampling is disabled."""
        from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES

        server = MCPServerTask("int_test2")
        config = {
            "command": "fake",
            "sampling": {"enabled": False},
        }
        server._config = config
        sampling_config = config.get("sampling", {})
        wants_sampling = sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES
        server._sampling = (
            SamplingHandler(server.name, sampling_config) if wants_sampling else None
        )

        assert server._sampling is None

    def test_session_kwargs_used_in_stdio(self):
        """When sampling is set, session_kwargs() are passed to ClientSession."""
        from tools.mcp_tool import MCPServerTask

        server = MCPServerTask("sk_test")
        server._sampling = SamplingHandler("sk_test", {"max_rpm": 7})
        session_args = server._sampling.session_kwargs()
        for key in ("sampling_callback", "sampling_capabilities"):
            assert key in session_args
|
||||
|
||||
@@ -457,17 +457,11 @@ def execute_code(
|
||||
|
||||
# --- Poll loop: watch for exit, timeout, and interrupt ---
|
||||
deadline = time.monotonic() + timeout
|
||||
stdout_chunks: list = []
|
||||
stderr_chunks: list = []
|
||||
|
||||
# Background readers to avoid pipe buffer deadlocks.
|
||||
# For stdout we use a head+tail strategy: keep the first HEAD_BYTES
|
||||
# and a rolling window of the last TAIL_BYTES so the final print()
|
||||
# output is never lost. Stderr keeps head-only (errors appear early).
|
||||
_STDOUT_HEAD_BYTES = int(MAX_STDOUT_BYTES * 0.4) # 40% head
|
||||
_STDOUT_TAIL_BYTES = MAX_STDOUT_BYTES - _STDOUT_HEAD_BYTES # 60% tail
|
||||
|
||||
# Background readers to avoid pipe buffer deadlocks
|
||||
def _drain(pipe, chunks, max_bytes):
|
||||
"""Simple head-only drain (used for stderr)."""
|
||||
total = 0
|
||||
try:
|
||||
while True:
|
||||
@@ -481,48 +475,8 @@ def execute_code(
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
|
||||
stdout_total_bytes = [0] # mutable ref for total bytes seen
|
||||
|
||||
def _drain_head_tail(pipe, head_chunks, tail_chunks, head_bytes, tail_bytes, total_ref):
|
||||
"""Drain stdout keeping both head and tail data."""
|
||||
head_collected = 0
|
||||
from collections import deque
|
||||
tail_buf = deque()
|
||||
tail_collected = 0
|
||||
try:
|
||||
while True:
|
||||
data = pipe.read(4096)
|
||||
if not data:
|
||||
break
|
||||
total_ref[0] += len(data)
|
||||
# Fill head buffer first
|
||||
if head_collected < head_bytes:
|
||||
keep = min(len(data), head_bytes - head_collected)
|
||||
head_chunks.append(data[:keep])
|
||||
head_collected += keep
|
||||
data = data[keep:] # remaining goes to tail
|
||||
if not data:
|
||||
continue
|
||||
# Everything past head goes into rolling tail buffer
|
||||
tail_buf.append(data)
|
||||
tail_collected += len(data)
|
||||
# Evict old tail data to stay within tail_bytes budget
|
||||
while tail_collected > tail_bytes and tail_buf:
|
||||
oldest = tail_buf.popleft()
|
||||
tail_collected -= len(oldest)
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
# Transfer final tail to output list
|
||||
tail_chunks.extend(tail_buf)
|
||||
|
||||
stdout_head_chunks: list = []
|
||||
stdout_tail_chunks: list = []
|
||||
|
||||
stdout_reader = threading.Thread(
|
||||
target=_drain_head_tail,
|
||||
args=(proc.stdout, stdout_head_chunks, stdout_tail_chunks,
|
||||
_STDOUT_HEAD_BYTES, _STDOUT_TAIL_BYTES, stdout_total_bytes),
|
||||
daemon=True
|
||||
target=_drain, args=(proc.stdout, stdout_chunks, MAX_STDOUT_BYTES), daemon=True
|
||||
)
|
||||
stderr_reader = threading.Thread(
|
||||
target=_drain, args=(proc.stderr, stderr_chunks, MAX_STDERR_BYTES), daemon=True
|
||||
@@ -546,21 +500,12 @@ def execute_code(
|
||||
stdout_reader.join(timeout=3)
|
||||
stderr_reader.join(timeout=3)
|
||||
|
||||
stdout_head = b"".join(stdout_head_chunks).decode("utf-8", errors="replace")
|
||||
stdout_tail = b"".join(stdout_tail_chunks).decode("utf-8", errors="replace")
|
||||
stdout_text = b"".join(stdout_chunks).decode("utf-8", errors="replace")
|
||||
stderr_text = b"".join(stderr_chunks).decode("utf-8", errors="replace")
|
||||
|
||||
# Assemble stdout with head+tail truncation
|
||||
total_stdout = stdout_total_bytes[0]
|
||||
if total_stdout > MAX_STDOUT_BYTES and stdout_tail:
|
||||
omitted = total_stdout - len(stdout_head) - len(stdout_tail)
|
||||
truncated_notice = (
|
||||
f"\n\n... [OUTPUT TRUNCATED - {omitted:,} chars omitted "
|
||||
f"out of {total_stdout:,} total] ...\n\n"
|
||||
)
|
||||
stdout_text = stdout_head + truncated_notice + stdout_tail
|
||||
else:
|
||||
stdout_text = stdout_head + stdout_tail
|
||||
# Truncation notice
|
||||
if len(stdout_text) >= MAX_STDOUT_BYTES:
|
||||
stdout_text = stdout_text[:MAX_STDOUT_BYTES] + "\n[output truncated at 50KB]"
|
||||
|
||||
exit_code = proc.returncode if proc.returncode is not None else -1
|
||||
duration = round(time.monotonic() - exec_start, 2)
|
||||
|
||||
@@ -7,7 +7,6 @@ import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
from tools.file_operations import ShellFileOperations
|
||||
from agent.redact import redact_sensitive_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -129,8 +128,6 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str =
|
||||
try:
|
||||
file_ops = _get_file_ops(task_id)
|
||||
result = file_ops.read_file(path, offset, limit)
|
||||
if result.content:
|
||||
result.content = redact_sensitive_text(result.content)
|
||||
return json.dumps(result.to_dict(), ensure_ascii=False)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
@@ -189,10 +186,6 @@ def search_tool(pattern: str, target: str = "content", path: str = ".",
|
||||
pattern=pattern, path=path, target=target, file_glob=file_glob,
|
||||
limit=limit, offset=offset, output_mode=output_mode, context=context
|
||||
)
|
||||
if hasattr(result, 'matches'):
|
||||
for m in result.matches:
|
||||
if hasattr(m, 'content') and m.content:
|
||||
m.content = redact_sensitive_text(m.content)
|
||||
result_dict = result.to_dict()
|
||||
result_json = json.dumps(result_dict, ensure_ascii=False)
|
||||
# Hint when results were truncated — explicit next offset is clearer
|
||||
|
||||
@@ -29,6 +29,18 @@ Example config::
|
||||
headers:
|
||||
Authorization: "Bearer sk-..."
|
||||
timeout: 180
|
||||
analysis:
|
||||
command: "npx"
|
||||
args: ["-y", "analysis-server"]
|
||||
sampling: # server-initiated LLM requests
|
||||
enabled: true # default: true
|
||||
model: "gemini-3-flash" # override model (optional)
|
||||
max_tokens_cap: 4096 # max tokens per request
|
||||
timeout: 30 # LLM call timeout (seconds)
|
||||
max_rpm: 10 # max requests per minute
|
||||
allowed_models: [] # model whitelist (empty = all)
|
||||
max_tool_rounds: 5 # tool loop limit (0 = disable)
|
||||
log_level: "info" # audit verbosity
|
||||
|
||||
Features:
|
||||
- Stdio transport (command + args) and HTTP/StreamableHTTP transport (url)
|
||||
@@ -37,6 +49,8 @@ Features:
|
||||
- Credential stripping in error messages returned to the LLM
|
||||
- Configurable per-server timeouts for tool calls and connections
|
||||
- Thread-safe architecture with dedicated background event loop
|
||||
- Sampling support: MCP servers can request LLM completions via
|
||||
sampling/createMessage (text and tool-use responses)
|
||||
|
||||
Architecture:
|
||||
A dedicated background event loop (_mcp_loop) runs in a daemon thread.
|
||||
@@ -58,9 +72,11 @@ Thread safety:
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -71,6 +87,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_MCP_AVAILABLE = False
|
||||
_MCP_HTTP_AVAILABLE = False
|
||||
_MCP_SAMPLING_TYPES = False
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
@@ -80,6 +97,20 @@ try:
|
||||
_MCP_HTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
_MCP_HTTP_AVAILABLE = False
|
||||
# Sampling types -- separated so older SDK versions don't break MCP support
|
||||
try:
|
||||
from mcp.types import (
|
||||
CreateMessageResult,
|
||||
CreateMessageResultWithTools,
|
||||
ErrorData,
|
||||
SamplingCapability,
|
||||
SamplingToolsCapability,
|
||||
TextContent,
|
||||
ToolUseContent,
|
||||
)
|
||||
_MCP_SAMPLING_TYPES = True
|
||||
except ImportError:
|
||||
logger.debug("MCP sampling types not available -- sampling disabled")
|
||||
except ImportError:
|
||||
logger.debug("mcp package not installed -- MCP tool support disabled")
|
||||
|
||||
@@ -145,6 +176,386 @@ def _sanitize_error(text: str) -> str:
|
||||
return _CREDENTIAL_PATTERN.sub("[REDACTED]", text)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _safe_numeric(value, default, coerce=int, minimum=1):
|
||||
"""Coerce a config value to a numeric type, returning *default* on failure.
|
||||
|
||||
Handles string values from YAML (e.g. ``"10"`` instead of ``10``),
|
||||
non-finite floats, and values below *minimum*.
|
||||
"""
|
||||
try:
|
||||
result = coerce(value)
|
||||
if isinstance(result, float) and not math.isfinite(result):
|
||||
return default
|
||||
return max(result, minimum)
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
return default
|
||||
|
||||
|
||||
class SamplingHandler:
    """Handle ``sampling/createMessage`` requests for a single MCP server.

    Each MCPServerTask that has sampling enabled creates one SamplingHandler.
    The handler is callable and passed directly to ``ClientSession`` as the
    ``sampling_callback``.  All state (rate-limit timestamps, metrics,
    tool-loop counters) lives on the instance -- no module-level globals.

    The callback is async and runs on the MCP background event loop.  The
    sync LLM call is offloaded to a thread via ``asyncio.to_thread()`` so
    it doesn't block the event loop.
    """

    # OpenAI ``finish_reason`` -> MCP ``stopReason``.
    _STOP_REASON_MAP = {"stop": "endTurn", "length": "maxTokens", "tool_calls": "toolUse"}

    def __init__(self, server_name: str, config: dict):
        """Read the per-server ``sampling`` config and set up empty state.

        Args:
            server_name: Name of the MCP server, used in log and error text.
            config: The server's ``sampling`` mapping.  Recognised keys are
                ``max_rpm``, ``timeout``, ``max_tokens_cap``,
                ``max_tool_rounds``, ``model``, ``allowed_models`` and
                ``log_level``; missing or unparsable numeric values fall
                back to their defaults.
        """
        self.server_name = server_name
        self.max_rpm = _safe_numeric(config.get("max_rpm", 10), 10, int)
        self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
        self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
        self.max_tool_rounds = _safe_numeric(
            config.get("max_tool_rounds", 5), 5, int, minimum=0,
        )
        self.model_override = config.get("model")
        # Normalised to a list of strings so a scalar YAML value such as
        # ``allowed_models: "gpt-4"`` is not treated as a substring whitelist.
        self.allowed_models = self._coerce_model_list(config.get("allowed_models"))

        _log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
        self.audit_level = _log_levels.get(
            str(config.get("log_level", "info")).lower(), logging.INFO,
        )

        # Per-instance state
        self._rate_timestamps: List[float] = []
        self._tool_loop_count = 0
        self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}

    @staticmethod
    def _coerce_model_list(value) -> List[str]:
        """Normalise an ``allowed_models`` config value to a list of strings.

        Falsy values (``None``, ``[]``, ``""``) give an empty list, which the
        callback treats as "no whitelist".  A single string becomes a
        one-element list; any other iterable has each item stringified; a
        non-iterable scalar is stringified and wrapped.
        """
        if not value:
            return []
        if isinstance(value, str):
            return [value]
        try:
            return [str(item) for item in value]
        except TypeError:
            return [str(value)]

    # -- Rate limiting -------------------------------------------------------

    def _check_rate_limit(self) -> bool:
        """Sliding-window rate limiter.  Return True if the request is allowed.

        Timestamps older than 60 seconds are dropped; the request is refused
        when ``max_rpm`` timestamps remain, otherwise it is recorded.
        """
        now = time.time()
        window = now - 60
        self._rate_timestamps[:] = [t for t in self._rate_timestamps if t > window]
        if len(self._rate_timestamps) >= self.max_rpm:
            return False
        self._rate_timestamps.append(now)
        return True

    # -- Model resolution ----------------------------------------------------

    def _resolve_model(self, preferences) -> Optional[str]:
        """Config override > server hint > None (use default).

        Returns the configured ``model`` when set, else the first hint in
        ``preferences.hints`` with a non-empty ``name``, else ``None``.
        """
        if self.model_override:
            return self.model_override
        if preferences and hasattr(preferences, "hints") and preferences.hints:
            for hint in preferences.hints:
                if hasattr(hint, "name") and hint.name:
                    return hint.name
        return None

    # -- Message conversion --------------------------------------------------

    @staticmethod
    def _extract_tool_result_text(block) -> str:
        """Extract text from a ToolResultContent block.

        ``block.content`` may be a single item or a list; the ``text`` of
        every item that has one is joined with newlines.  Returns ``""``
        when the block has no content.
        """
        if not hasattr(block, "content") or block.content is None:
            return ""
        items = block.content if isinstance(block.content, list) else [block.content]
        return "\n".join(item.text for item in items if hasattr(item, "text"))

    def _convert_messages(self, params) -> List[dict]:
        """Convert MCP SamplingMessages to OpenAI chat-completions format.

        Uses ``msg.content_as_list`` when the message provides it, otherwise
        wraps a single content block in a list.  Blocks are classified by
        duck-typing: ``toolUseId`` marks a tool result, ``name`` + ``input``
        marks a tool use, everything else is plain content (text or image).

        Per message, tool results are emitted first as ``role: tool``
        messages; then either one message carrying ``tool_calls`` (with any
        accompanying text) or one plain content message.
        """
        messages: List[dict] = []
        for msg in params.messages:
            blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
                msg.content if isinstance(msg.content, list) else [msg.content]
            )

            # Separate blocks by kind
            tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
            tool_uses = [
                b for b in blocks
                if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")
            ]
            content_blocks = [
                b for b in blocks
                if not hasattr(b, "toolUseId")
                and not (hasattr(b, "name") and hasattr(b, "input"))
            ]

            # Emit tool result messages (role: tool)
            for tr in tool_results:
                messages.append({
                    "role": "tool",
                    "tool_call_id": tr.toolUseId,
                    "content": self._extract_tool_result_text(tr),
                })

            # Emit assistant tool_calls message
            if tool_uses:
                tc_list = []
                for tu in tool_uses:
                    tc_list.append({
                        # Fall back to a positional id when the block has none.
                        "id": getattr(tu, "id", f"call_{len(tc_list)}"),
                        "type": "function",
                        "function": {
                            "name": tu.name,
                            "arguments": (
                                json.dumps(tu.input)
                                if isinstance(tu.input, dict)
                                else str(tu.input)
                            ),
                        },
                    })
                msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
                # Include any accompanying text
                text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
                if text_parts:
                    msg_dict["content"] = "\n".join(text_parts)
                messages.append(msg_dict)
            elif content_blocks:
                # Pure text/image content
                if len(content_blocks) == 1 and hasattr(content_blocks[0], "text"):
                    messages.append({"role": msg.role, "content": content_blocks[0].text})
                else:
                    parts = []
                    for block in content_blocks:
                        if hasattr(block, "text"):
                            parts.append({"type": "text", "text": block.text})
                        elif hasattr(block, "data") and hasattr(block, "mimeType"):
                            parts.append({
                                "type": "image_url",
                                "image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
                            })
                        else:
                            logger.warning(
                                "Unsupported sampling content block type: %s (skipped)",
                                type(block).__name__,
                            )
                    if parts:
                        messages.append({"role": msg.role, "content": parts})

        return messages

    # -- Error helper --------------------------------------------------------

    @staticmethod
    def _error(message: str, code: int = -1):
        """Return ``ErrorData`` when the SDK sampling types were imported.

        Raises:
            Exception: with *message*, as a fallback when the sampling types
                are unavailable.
        """
        if _MCP_SAMPLING_TYPES:
            return ErrorData(code=code, message=message)
        raise Exception(message)

    # -- Response building ---------------------------------------------------

    def _build_tool_use_result(self, choice, response):
        """Build a CreateMessageResultWithTools from an LLM tool_calls response.

        Enforces the tool-loop limit first: with ``max_tool_rounds == 0`` every
        tool-use response is rejected, otherwise the consecutive-round counter
        is incremented and an error is returned (and the counter reset) once
        it exceeds ``max_tool_rounds``.
        """
        self.metrics["tool_use_count"] += 1

        # Tool loop governance
        if self.max_tool_rounds == 0:
            self._tool_loop_count = 0
            return self._error(
                f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
            )

        self._tool_loop_count += 1
        if self._tool_loop_count > self.max_tool_rounds:
            self._tool_loop_count = 0
            return self._error(
                f"Tool loop limit exceeded for server '{self.server_name}' "
                f"(max {self.max_tool_rounds} rounds)"
            )

        content_blocks = []
        for tc in choice.message.tool_calls:
            args = tc.function.arguments
            if isinstance(args, str):
                try:
                    parsed = json.loads(args)
                except (json.JSONDecodeError, ValueError):
                    logger.warning(
                        "MCP server '%s': malformed tool_calls arguments "
                        "from LLM (wrapping as raw): %.100s",
                        self.server_name, args,
                    )
                    # Keep the unparsable payload so the server can inspect it.
                    parsed = {"_raw": args}
            else:
                parsed = args if isinstance(args, dict) else {"_raw": str(args)}

            content_blocks.append(ToolUseContent(
                type="tool_use",
                id=tc.id,
                name=tc.function.name,
                input=parsed,
            ))

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
            self.server_name, response.model,
            getattr(getattr(response, "usage", None), "total_tokens", "?"),
            len(content_blocks),
        )

        return CreateMessageResultWithTools(
            role="assistant",
            content=content_blocks,
            model=response.model,
            stopReason="toolUse",
        )

    def _build_text_result(self, choice, response):
        """Build a CreateMessageResult from a normal text response.

        Resets the consecutive tool-round counter and passes the response
        text through ``_sanitize_error`` before returning it.
        """
        self._tool_loop_count = 0  # reset on text response
        response_text = choice.message.content or ""

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling response: model=%s, tokens=%s",
            self.server_name, response.model,
            getattr(getattr(response, "usage", None), "total_tokens", "?"),
        )

        return CreateMessageResult(
            role="assistant",
            content=TextContent(type="text", text=_sanitize_error(response_text)),
            model=response.model,
            stopReason=self._STOP_REASON_MAP.get(choice.finish_reason, "endTurn"),
        )

    # -- Session kwargs helper -----------------------------------------------

    def session_kwargs(self) -> dict:
        """Return kwargs to pass to ClientSession for sampling support."""
        return {
            "sampling_callback": self,
            "sampling_capabilities": SamplingCapability(
                tools=SamplingToolsCapability(),
            ),
        }

    # -- Main callback -------------------------------------------------------

    async def __call__(self, context, params):
        """Sampling callback invoked by the MCP SDK.

        Conforms to ``SamplingFnT`` protocol.  Returns
        ``CreateMessageResult``, ``CreateMessageResultWithTools``, or
        ``ErrorData``.

        Steps: rate-limit check, model resolution, whitelist check, message
        conversion, LLM call in a worker thread bounded by ``self.timeout``,
        then dispatch on the response type.  Every failure path increments
        ``metrics["errors"]`` and returns via ``_error``.
        """
        # Rate limit
        if not self._check_rate_limit():
            logger.warning(
                "MCP server '%s' sampling rate limit exceeded (%d/min)",
                self.server_name, self.max_rpm,
            )
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling rate limit exceeded for server '{self.server_name}' "
                f"({self.max_rpm} requests/minute)"
            )

        # Resolve model
        model = self._resolve_model(getattr(params, "modelPreferences", None))

        # Get auxiliary LLM client
        from agent.auxiliary_client import get_text_auxiliary_client
        client, default_model = get_text_auxiliary_client()
        if client is None:
            self.metrics["errors"] += 1
            return self._error("No LLM provider available for sampling")

        resolved_model = model or default_model

        # Model whitelist check
        if self.allowed_models and resolved_model not in self.allowed_models:
            logger.warning(
                "MCP server '%s' requested model '%s' not in allowed_models",
                self.server_name, resolved_model,
            )
            self.metrics["errors"] += 1
            return self._error(
                f"Model '{resolved_model}' not allowed for server "
                f"'{self.server_name}'. Allowed: {', '.join(self.allowed_models)}"
            )

        # Convert messages
        messages = self._convert_messages(params)
        if hasattr(params, "systemPrompt") and params.systemPrompt:
            messages.insert(0, {"role": "system", "content": params.systemPrompt})

        # Build LLM call kwargs
        max_tokens = min(params.maxTokens, self.max_tokens_cap)
        call_kwargs: dict = {
            "model": resolved_model,
            "messages": messages,
            "max_tokens": max_tokens,
        }
        if hasattr(params, "temperature") and params.temperature is not None:
            call_kwargs["temperature"] = params.temperature
        if stop := getattr(params, "stopSequences", None):
            call_kwargs["stop"] = stop

        # Forward server-provided tools
        server_tools = getattr(params, "tools", None)
        if server_tools:
            call_kwargs["tools"] = [
                {
                    "type": "function",
                    "function": {
                        "name": getattr(t, "name", ""),
                        "description": getattr(t, "description", "") or "",
                        "parameters": getattr(t, "inputSchema", {}) or {},
                    },
                }
                for t in server_tools
            ]
            if tool_choice := getattr(params, "toolChoice", None):
                mode = getattr(tool_choice, "mode", "auto")
                # Unknown modes degrade to "auto".
                call_kwargs["tool_choice"] = {
                    "auto": "auto", "required": "required", "none": "none",
                }.get(mode, "auto")

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
            self.server_name, resolved_model, max_tokens, len(messages),
        )

        # Offload sync LLM call to thread (non-blocking)
        def _sync_call():
            return client.chat.completions.create(**call_kwargs)

        try:
            response = await asyncio.wait_for(
                asyncio.to_thread(_sync_call), timeout=self.timeout,
            )
        except asyncio.TimeoutError:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call timed out after {self.timeout}s "
                f"for server '{self.server_name}'"
            )
        except Exception as exc:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
            )

        # A response without choices would otherwise raise IndexError out of
        # the callback; report it as an error result like other failures.
        choices = getattr(response, "choices", None)
        if not choices:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call returned no choices for server '{self.server_name}'"
            )

        # Track metrics
        choice = choices[0]
        self.metrics["requests"] += 1
        total_tokens = getattr(getattr(response, "usage", None), "total_tokens", 0)
        if isinstance(total_tokens, int):
            self.metrics["tokens_used"] += total_tokens

        # Dispatch based on response type
        if (
            choice.finish_reason == "tool_calls"
            and hasattr(choice.message, "tool_calls")
            and choice.message.tool_calls
        ):
            return self._build_tool_use_result(choice, response)

        return self._build_text_result(choice, response)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Server task -- each MCP server lives in one long-lived asyncio Task
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -162,6 +573,7 @@ class MCPServerTask:
|
||||
__slots__ = (
|
||||
"name", "session", "tool_timeout",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
|
||||
"_sampling",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
@@ -174,6 +586,7 @@ class MCPServerTask:
|
||||
self._tools: list = []
|
||||
self._error: Optional[Exception] = None
|
||||
self._config: dict = {}
|
||||
self._sampling: Optional[SamplingHandler] = None
|
||||
|
||||
def _is_http(self) -> bool:
|
||||
"""Check if this server uses HTTP transport."""
|
||||
@@ -197,8 +610,9 @@ class MCPServerTask:
|
||||
env=safe_env if safe_env else None,
|
||||
)
|
||||
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
async with stdio_client(server_params) as (read_stream, write_stream):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
await self._discover_tools()
|
||||
@@ -218,12 +632,13 @@ class MCPServerTask:
|
||||
headers = config.get("headers")
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
async with streamablehttp_client(
|
||||
url,
|
||||
headers=headers,
|
||||
timeout=float(connect_timeout),
|
||||
) as (read_stream, write_stream, _get_session_id):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
await self._discover_tools()
|
||||
@@ -250,6 +665,13 @@ class MCPServerTask:
|
||||
self._config = config
|
||||
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
|
||||
|
||||
# Set up sampling handler if enabled and SDK types are available
|
||||
sampling_config = config.get("sampling", {})
|
||||
if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
|
||||
self._sampling = SamplingHandler(self.name, sampling_config)
|
||||
else:
|
||||
self._sampling = None
|
||||
|
||||
# Validate: warn if both url and command are present
|
||||
if "url" in config and "command" in config:
|
||||
logger.warning(
|
||||
@@ -975,12 +1397,15 @@ def get_mcp_status() -> List[dict]:
|
||||
transport = "http" if "url" in cfg else "stdio"
|
||||
server = active_servers.get(name)
|
||||
if server and server.session is not None:
|
||||
result.append({
|
||||
entry = {
|
||||
"name": name,
|
||||
"transport": transport,
|
||||
"tools": len(server._tools),
|
||||
"connected": True,
|
||||
})
|
||||
}
|
||||
if server._sampling:
|
||||
entry["sampling"] = dict(server._sampling.metrics)
|
||||
result.append(entry)
|
||||
else:
|
||||
result.append({
|
||||
"name": name,
|
||||
|
||||
@@ -69,36 +69,10 @@ def _read_manifest() -> Dict[str, str]:
|
||||
|
||||
|
||||
def _write_manifest(entries: Dict[str, str]):
    """Write the manifest file atomically in v2 format (name:hash).

    Entries are written one per line as ``name:hash``, sorted by name.
    Uses a temp file in the manifest's directory + ``os.replace()`` to avoid
    corruption if the process crashes or is interrupted mid-write.  Write
    failures are logged at debug level and otherwise ignored (best effort).
    """
    # Imported locally: the file-level import block is not guaranteed to
    # provide these names.
    import os
    import tempfile

    MANIFEST_FILE.parent.mkdir(parents=True, exist_ok=True)
    data = "\n".join(f"{name}:{hash_val}" for name, hash_val in sorted(entries.items())) + "\n"

    try:
        fd, tmp_path = tempfile.mkstemp(
            dir=str(MANIFEST_FILE.parent),
            prefix=".bundled_manifest_",
            suffix=".tmp",
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                f.write(data)
                f.flush()
                # Force the bytes to disk before the rename makes them visible.
                os.fsync(f.fileno())
            os.replace(tmp_path, MANIFEST_FILE)
        except BaseException:
            # Do not leave the temp file behind, whatever interrupted us.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise
    except Exception as e:
        logger.debug("Failed to write skills manifest %s: %s", MANIFEST_FILE, e, exc_info=True)
|
||||
|
||||
|
||||
def _discover_bundled_skills(bundled_dir: Path) -> List[Tuple[str, Path]]:
|
||||
|
||||
@@ -271,3 +271,62 @@ You can reload MCP servers without restarting Hermes:
|
||||
|
||||
- In the CLI: the agent reconnects automatically
|
||||
- In messaging: send `/reload-mcp`
|
||||
|
||||
## Sampling (Server-Initiated LLM Requests)
|
||||
|
||||
MCP's `sampling/createMessage` capability allows MCP servers to request LLM completions through the Hermes agent. This enables agent-in-the-loop workflows where servers can leverage the LLM during tool execution — for example, a database server asking the LLM to interpret query results, or a code analysis server requesting the LLM to review findings.
|
||||
|
||||
### How It Works
|
||||
|
||||
When an MCP server sends a `sampling/createMessage` request:
|
||||
|
||||
1. The sampling callback validates against rate limits and model whitelist
|
||||
2. Resolves which model to use (config override > server hint > default)
|
||||
3. Converts MCP messages to OpenAI-compatible format
|
||||
4. Offloads the LLM call to a thread via `asyncio.to_thread()` (non-blocking)
|
||||
5. Returns the response (text or tool use) back to the server
|
||||
|
||||
### Configuration
|
||||
|
||||
Sampling is **enabled by default** for all MCP servers. No extra setup needed — if you have an auxiliary LLM client configured, sampling works automatically.
|
||||
|
||||
```yaml
|
||||
mcp_servers:
|
||||
analysis_server:
|
||||
command: "npx"
|
||||
args: ["-y", "my-analysis-server"]
|
||||
sampling:
|
||||
enabled: true # default: true
|
||||
model: "gemini-3-flash" # override model (optional)
|
||||
max_tokens_cap: 4096 # max tokens per request (default: 4096)
|
||||
timeout: 30 # LLM call timeout in seconds (default: 30)
|
||||
max_rpm: 10 # max requests per minute (default: 10)
|
||||
allowed_models: [] # model whitelist (empty = allow all)
|
||||
max_tool_rounds: 5 # max consecutive tool use rounds (0 = disable)
|
||||
log_level: "info" # audit verbosity: debug, info, warning
|
||||
```
|
||||
|
||||
### Tool Use in Sampling
|
||||
|
||||
Servers can include `tools` and `toolChoice` in sampling requests, enabling multi-turn tool-augmented workflows within a single sampling session. The callback forwards tool definitions to the LLM, handles tool use responses with proper `ToolUseContent` types, and enforces `max_tool_rounds` to prevent infinite loops.
|
||||
|
||||
### Security
|
||||
|
||||
- **Rate limiting**: Per-server sliding window (default: 10 req/min)
|
||||
- **Token cap**: Servers can't request more than `max_tokens_cap` (default: 4096)
|
||||
- **Model whitelist**: `allowed_models` restricts which models a server can use
|
||||
- **Tool loop limit**: `max_tool_rounds` caps consecutive tool use rounds
|
||||
- **Credential stripping**: Text responses from the LLM are sanitized before being returned to the server (tool-call arguments are forwarded unchanged)
|
||||
- **Non-blocking**: LLM calls run in a separate thread via `asyncio.to_thread()`
|
||||
- **Typed errors**: All failures return structured `ErrorData` per MCP spec
|
||||
|
||||
To disable sampling for untrusted servers:
|
||||
|
||||
```yaml
|
||||
mcp_servers:
|
||||
untrusted:
|
||||
command: "npx"
|
||||
args: ["-y", "untrusted-server"]
|
||||
sampling:
|
||||
enabled: false
|
||||
```
|
||||
|
||||
Reference in New Issue
Block a user