Compare commits

2 Commits

Author SHA1 Message Date
teknium1
fbed199672 docs: add sampling config examples to docstring and cli-config.yaml.example 2026-03-09 03:36:47 -07:00
teknium1
2d13eb9795 feat(mcp): add sampling support — server-initiated LLM requests
Add MCP sampling/createMessage capability allowing MCP servers to request
LLM completions through the Hermes agent during tool execution. Enables
agent-in-the-loop workflows (data analysis, content generation, decision
making) where servers can leverage the LLM as needed.

Implementation as SamplingHandler class (per-server instance, no globals):
- Text-only sampling: server asks LLM a question, gets text back
- Tool use in sampling: server provides tools, LLM can use them in a
  multi-turn loop with configurable max_tool_rounds governance
- Rate limiting (sliding window, configurable max_rpm per server)
- Model resolution (config override > server hint > default)
- Model whitelist (allowed_models per server)
- Token cap (max_tokens_cap per server)
- LLM timeout with asyncio.wait_for
- Credential stripping on responses
- Per-server audit metrics (requests, errors, tokens_used, tool_use_count)
- Configurable log_level for audit verbosity
- Non-blocking: LLM calls offloaded via asyncio.to_thread()
- Proper MCP SDK types: CreateMessageResult for text responses,
  CreateMessageResultWithTools + ToolUseContent for tool use responses
- SamplingCapability with SamplingToolsCapability advertised to servers
- Backward compatible: silently disabled if MCP SDK lacks sampling types
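
The sliding-window rate limit listed above can be sketched roughly as follows. This is a minimal illustration, not the handler's actual code: the class name `SlidingWindowLimiter`, the `check` method, and the injectable `now` argument are hypothetical, and the 60-second window is assumed from the `max_rpm` (requests per minute) setting.

```python
import time
from collections import deque


class SlidingWindowLimiter:
    """Hypothetical stand-in for a per-server sliding-window rate limiter."""

    def __init__(self, max_rpm: int, window: float = 60.0):
        self.max_rpm = max_rpm          # requests allowed per window
        self.window = window            # window length in seconds (assumed 60)
        self._timestamps = deque()      # times of accepted requests, oldest first

    def check(self, now=None) -> bool:
        """Return True and record the request if it fits in the window."""
        now = time.time() if now is None else now
        # Purge timestamps that have aged out of the window.
        while self._timestamps and now - self._timestamps[0] >= self.window:
            self._timestamps.popleft()
        if len(self._timestamps) >= self.max_rpm:
            return False                # over the limit: reject without recording
        self._timestamps.append(now)
        return True
```

Keeping one limiter instance per server matches the "per-server instance, no globals" design described above; passing `now` explicitly is only a convenience for testing the window expiry without sleeping.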

Config (all optional, zero breaking changes):
  mcp_servers:
    my_server:
      sampling:
        enabled: true        # default
        model: 'gemini-3-flash'
        max_tokens_cap: 4096
        timeout: 30
        max_rpm: 10
        allowed_models: []
        max_tool_rounds: 5
        log_level: 'info'
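
The `model` and `allowed_models` keys tie into the resolution order stated above (config override > server hint > default). A rough sketch of that order, under assumptions: the function names `resolve_model` and `is_allowed` are hypothetical, the preferences object is assumed to expose a `hints` list whose items carry a `name`, and an empty whitelist is taken to mean "allow all".

```python
from types import SimpleNamespace
from typing import Optional, Sequence


def resolve_model(override: Optional[str], preferences, default: Optional[str] = None) -> Optional[str]:
    """Pick a model: config override first, then the first named server hint, then the default."""
    if override:
        return override
    hints = getattr(preferences, "hints", None) or []
    for hint in hints:
        name = getattr(hint, "name", None)
        if name:
            return name
    return default


def is_allowed(model: Optional[str], allowed_models: Sequence[str]) -> bool:
    """Whitelist check; an empty list is assumed to allow every model."""
    return not allowed_models or model in allowed_models


# Illustrative use with a fake preferences object (not a real MCP SDK type).
prefs = SimpleNamespace(hints=[SimpleNamespace(name="hint-model")])
chosen = resolve_model(None, prefs)
print(chosen, is_allowed(chosen, []))
```

Whether the real handler skips unnamed hints or only inspects the first one is not stated in the commit message; the loop here is one plausible reading.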

Based on the sampling concept from PR #366 by eren-karakus0. Restructured
as a class-based design, fixed critical bugs (wrong return types for tool
use, missing capability advertisement, broken Pydantic validation), and
added tests using real MCP SDK types.

50 new tests, full suite passes (2600 tests).
2026-03-09 01:21:34 -07:00
19 changed files with 1343 additions and 412 deletions

View File

@@ -8,7 +8,6 @@ the first 6 and last 4 characters for debuggability.
"""
import logging
import os
import re
from typing import Optional
@@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
# Known API key prefixes -- match the prefix + contiguous token chars
_PREFIX_PATTERNS = [
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter / Anthropic (sk-ant-*)
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter
r"ghp_[A-Za-z0-9]{10,}", # GitHub PAT (classic)
r"github_pat_[A-Za-z0-9_]{10,}", # GitHub PAT (fine-grained)
r"xox[baprs]-[A-Za-z0-9-]{10,}", # Slack tokens
@@ -26,18 +25,6 @@ _PREFIX_PATTERNS = [
r"fc-[A-Za-z0-9]{10,}", # Firecrawl
r"bb_live_[A-Za-z0-9_-]{10,}", # BrowserBase
r"gAAAA[A-Za-z0-9_=-]{20,}", # Codex encrypted tokens
r"AKIA[A-Z0-9]{16}", # AWS Access Key ID
r"sk_live_[A-Za-z0-9]{10,}", # Stripe secret key (live)
r"sk_test_[A-Za-z0-9]{10,}", # Stripe secret key (test)
r"rk_live_[A-Za-z0-9]{10,}", # Stripe restricted key
r"SG\.[A-Za-z0-9_-]{10,}", # SendGrid API key
r"hf_[A-Za-z0-9]{10,}", # HuggingFace token
r"r8_[A-Za-z0-9]{10,}", # Replicate API token
r"npm_[A-Za-z0-9]{10,}", # npm access token
r"pypi-[A-Za-z0-9_-]{10,}", # PyPI API token
r"dop_v1_[A-Za-z0-9]{10,}", # DigitalOcean PAT
r"doo_v1_[A-Za-z0-9]{10,}", # DigitalOcean OAuth
r"am_[A-Za-z0-9_-]{10,}", # AgentMail API key
]
# ENV assignment patterns: KEY=value where KEY contains a secret-like name
@@ -65,18 +52,6 @@ _TELEGRAM_RE = re.compile(
r"(bot)?(\d{8,}):([-A-Za-z0-9_]{30,})",
)
# Private key blocks: -----BEGIN RSA PRIVATE KEY----- ... -----END RSA PRIVATE KEY-----
_PRIVATE_KEY_RE = re.compile(
r"-----BEGIN[A-Z ]*PRIVATE KEY-----[\s\S]*?-----END[A-Z ]*PRIVATE KEY-----"
)
# Database connection strings: protocol://user:PASSWORD@host
# Catches postgres, mysql, mongodb, redis, amqp URLs and redacts the password
_DB_CONNSTR_RE = re.compile(
r"((?:postgres(?:ql)?|mysql|mongodb(?:\+srv)?|redis|amqp)://[^:]+:)([^@]+)(@)",
re.IGNORECASE,
)
# E.164 phone numbers: +<country><number>, 7-15 digits
# Negative lookahead prevents matching hex strings or identifiers
_SIGNAL_PHONE_RE = re.compile(r"(\+[1-9]\d{6,14})(?![A-Za-z0-9])")
@@ -98,12 +73,9 @@ def redact_sensitive_text(text: str) -> str:
"""Apply all redaction patterns to a block of text.
Safe to call on any string -- non-matching text passes through unchanged.
Disabled when security.redact_secrets is false in config.yaml.
"""
if not text:
return text
if os.getenv("HERMES_REDACT_SECRETS", "").lower() in ("0", "false", "no", "off"):
return text
# Known prefixes (sk-, ghp_, etc.)
text = _PREFIX_RE.sub(lambda m: _mask_token(m.group(1)), text)
@@ -133,12 +105,6 @@ def redact_sensitive_text(text: str) -> str:
return f"{prefix}{digits}:***"
text = _TELEGRAM_RE.sub(_redact_telegram, text)
# Private key blocks
text = _PRIVATE_KEY_RE.sub("[REDACTED PRIVATE KEY]", text)
# Database connection string passwords
text = _DB_CONNSTR_RE.sub(lambda m: f"{m.group(1)}***{m.group(3)}", text)
# E.164 phone numbers (Signal, WhatsApp)
def _redact_phone(m):
phone = m.group(1)

View File

@@ -555,6 +555,21 @@ toolsets:
# args: ["-y", "@modelcontextprotocol/server-github"]
# env:
# GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."
#
# Sampling (server-initiated LLM requests) — enabled by default.
# Per-server config under the 'sampling' key:
# analysis:
# command: npx
# args: ["-y", "analysis-server"]
# sampling:
# enabled: true # default: true
# model: "gemini-3-flash" # override model (optional)
# max_tokens_cap: 4096 # max tokens per request
# timeout: 30 # LLM call timeout (seconds)
# max_rpm: 10 # max requests per minute
# allowed_models: [] # model whitelist (empty = all)
# max_tool_rounds: 5 # tool loop limit (0 = disable)
# log_level: "info" # audit verbosity
# =============================================================================
# Voice Transcription (Speech-to-Text)

7
cli.py
View File

@@ -364,13 +364,6 @@ def load_cli_config() -> Dict[str, Any]:
if model:
os.environ[model_env] = model
# Security settings
security_config = defaults.get("security", {})
if isinstance(security_config, dict):
redact = security_config.get("redact_secrets")
if redact is not None:
os.environ["HERMES_REDACT_SECRETS"] = str(redact).lower()
return defaults
# Load configuration at module startup

View File

@@ -52,7 +52,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
try:
DIRECTORY_PATH.parent.mkdir(parents=True, exist_ok=True)
with open(DIRECTORY_PATH, "w", encoding="utf-8") as f:
with open(DIRECTORY_PATH, "w") as f:
json.dump(directory, f, indent=2, ensure_ascii=False)
except Exception as e:
logger.warning("Channel directory: failed to write: %s", e)
@@ -115,7 +115,7 @@ def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]:
entries = []
try:
with open(sessions_path, encoding="utf-8") as f:
with open(sessions_path) as f:
data = json.load(f)
seen_ids = set()
@@ -147,7 +147,7 @@ def load_directory() -> Dict[str, Any]:
if not DIRECTORY_PATH.exists():
return {"updated_at": None, "platforms": {}}
try:
with open(DIRECTORY_PATH, encoding="utf-8") as f:
with open(DIRECTORY_PATH) as f:
return json.load(f)
except Exception:
return {"updated_at": None, "platforms": {}}

View File

@@ -73,7 +73,7 @@ def _find_session_id(platform: str, chat_id: str) -> Optional[str]:
return None
try:
with open(_SESSIONS_INDEX, encoding="utf-8") as f:
with open(_SESSIONS_INDEX) as f:
data = json.load(f)
except Exception:
return None
@@ -103,7 +103,7 @@ def _append_to_jsonl(session_id: str, message: dict) -> None:
"""Append a message to the JSONL transcript file."""
transcript_path = _SESSIONS_DIR / f"{session_id}.jsonl"
try:
with open(transcript_path, "a", encoding="utf-8") as f:
with open(transcript_path, "a") as f:
f.write(json.dumps(message, ensure_ascii=False) + "\n")
except Exception as e:
logger.debug("Mirror JSONL write failed: %s", e)

View File

@@ -118,12 +118,6 @@ if _config_path.exists():
_tz_cfg = _cfg.get("timezone", "")
if _tz_cfg and isinstance(_tz_cfg, str) and "HERMES_TIMEZONE" not in os.environ:
os.environ["HERMES_TIMEZONE"] = _tz_cfg.strip()
# Security settings
_security_cfg = _cfg.get("security", {})
if isinstance(_security_cfg, dict):
_redact = _security_cfg.get("redact_secrets")
if _redact is not None:
os.environ["HERMES_REDACT_SECRETS"] = str(_redact).lower()
except Exception:
pass # Non-fatal; gateway can still run with .env values

View File

@@ -342,7 +342,7 @@ class SessionStore:
if sessions_file.exists():
try:
with open(sessions_file, "r", encoding="utf-8") as f:
with open(sessions_file, "r") as f:
data = json.load(f)
for key, entry_data in data.items():
self._entries[key] = SessionEntry.from_dict(entry_data)
@@ -357,7 +357,7 @@ class SessionStore:
sessions_file = self.sessions_dir / "sessions.json"
data = {key: entry.to_dict() for key, entry in self._entries.items()}
with open(sessions_file, "w", encoding="utf-8") as f:
with open(sessions_file, "w") as f:
json.dump(data, f, indent=2)
def _generate_session_key(self, source: SessionSource) -> str:
@@ -681,7 +681,7 @@ class SessionStore:
# Also write legacy JSONL (keeps existing tooling working during transition)
transcript_path = self.get_transcript_path(session_id)
with open(transcript_path, "a", encoding="utf-8") as f:
with open(transcript_path, "a") as f:
f.write(json.dumps(message, ensure_ascii=False) + "\n")
def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
@@ -708,7 +708,7 @@ class SessionStore:
# JSONL: overwrite the file
transcript_path = self.get_transcript_path(session_id)
with open(transcript_path, "w", encoding="utf-8") as f:
with open(transcript_path, "w") as f:
for msg in messages:
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
@@ -730,7 +730,7 @@ class SessionStore:
return []
messages = []
with open(transcript_path, "r", encoding="utf-8") as f:
with open(transcript_path, "r") as f:
for line in f:
line = line.strip()
if line:

View File

@@ -759,16 +759,8 @@ def load_config() -> Dict[str, Any]:
return config
_COMMENTED_SECTIONS = """
# ── Security ──────────────────────────────────────────────────────────
# API keys, tokens, and passwords are redacted from tool output by default.
# Set to false to see full values (useful for debugging auth issues).
#
# security:
# redact_secrets: false
# ── Fallback Model ────────────────────────────────────────────────────
# Automatic provider failover when primary is unavailable.
_FALLBACK_MODEL_COMMENT = """
# Fallback model — automatic provider failover when primary is unavailable.
# Uncomment and configure to enable. Triggers on rate limits (429),
# overload (529), service errors (503), or connection failures.
#
@@ -796,18 +788,10 @@ def save_config(config: Dict[str, Any]):
with open(config_path, 'w') as f:
yaml.dump(config, f, default_flow_style=False, sort_keys=False)
# Append commented-out sections for features that are off by default
# or only relevant when explicitly configured. Skip sections the
# user has already uncommented and configured.
sections = []
sec = config.get("security", {})
if not sec or sec.get("redact_secrets") is None:
sections.append("security")
fb = config.get("fallback_model", {})
# Append commented-out fallback_model docs if user hasn't configured it
fb = config.get("fallback_model")
if not fb or not (fb.get("provider") and fb.get("model")):
sections.append("fallback")
if sections:
f.write(_COMMENTED_SECTIONS)
f.write(_FALLBACK_MODEL_COMMENT)
def load_env() -> Dict[str, str]:

View File

@@ -3092,14 +3092,9 @@ class AIAgent:
)
self._iters_since_skill = 0
# Honcho prefetch: retrieve user context for system prompt injection.
# Only on the FIRST turn of a session (empty history). On subsequent
# turns the model already has all prior context in its conversation
# history, and the Honcho context is baked into the stored system
# prompt — re-fetching it would change the system message and break
# Anthropic prompt caching.
# Honcho prefetch: retrieve user context for system prompt injection
self._honcho_context = ""
if self._honcho and self._honcho_session_key and not conversation_history:
if self._honcho and self._honcho_session_key:
try:
self._honcho_context = self._honcho_prefetch(user_message)
except Exception as e:
@@ -3117,42 +3112,14 @@ class AIAgent:
# Built once on first call, reused for all subsequent calls.
# Only rebuilt after context compression events (which invalidate
# the cache and reload memory from disk).
#
# For continuing sessions (gateway creates a fresh AIAgent per
# message), we load the stored system prompt from the session DB
# instead of rebuilding. Rebuilding would pick up memory changes
# from disk that the model already knows about (it wrote them!),
# producing a different system prompt and breaking the Anthropic
# prefix cache.
if self._cached_system_prompt is None:
stored_prompt = None
if conversation_history and self._session_db:
self._cached_system_prompt = self._build_system_prompt(system_message)
# Store the system prompt snapshot in SQLite
if self._session_db:
try:
session_row = self._session_db.get_session(self.session_id)
if session_row:
stored_prompt = session_row.get("system_prompt") or None
except Exception:
pass # Fall through to build fresh
if stored_prompt:
# Continuing session — reuse the exact system prompt from
# the previous turn so the Anthropic cache prefix matches.
self._cached_system_prompt = stored_prompt
else:
# First turn of a new session — build from scratch.
self._cached_system_prompt = self._build_system_prompt(system_message)
# Bake Honcho context into the prompt so it's stable for
# the entire session (not re-fetched per turn).
if self._honcho_context:
self._cached_system_prompt = (
self._cached_system_prompt + "\n\n" + self._honcho_context
).strip()
# Store the system prompt snapshot in SQLite
if self._session_db:
try:
self._session_db.update_system_prompt(self.session_id, self._cached_system_prompt)
except Exception as e:
logger.debug("Session DB update_system_prompt failed: %s", e)
self._session_db.update_system_prompt(self.session_id, self._cached_system_prompt)
except Exception as e:
logger.debug("Session DB update_system_prompt failed: %s", e)
active_system_prompt = self._cached_system_prompt
@@ -3277,13 +3244,11 @@ class AIAgent:
# Build the final system message: cached prompt + ephemeral system prompt.
# The ephemeral part is appended here (not baked into the cached prompt)
# so it stays out of the session DB and logs.
# Note: Honcho context is baked into _cached_system_prompt on the first
# turn and stored in the session DB, so it does NOT need to be injected
# here. This keeps the system message identical across all turns in a
# session, maximizing Anthropic prompt cache hits.
effective_system = active_system_prompt or ""
if self.ephemeral_system_prompt:
effective_system = (effective_system + "\n\n" + self.ephemeral_system_prompt).strip()
if self._honcho_context:
effective_system = (effective_system + "\n\n" + self._honcho_context).strip()
if effective_system:
api_messages = [{"role": "system", "content": effective_system}] + api_messages

View File

@@ -321,6 +321,32 @@ mcp_servers:
All tools from all servers are registered and available simultaneously. Each server's tools are prefixed with its name to avoid collisions.
## Sampling (Server-Initiated LLM Requests)
Hermes supports MCP's `sampling/createMessage` capability — MCP servers can request LLM completions through the agent during tool execution. This enables agent-in-the-loop workflows (data analysis, content generation, decision-making).
Sampling is **enabled by default**. Configure per server:
```yaml
mcp_servers:
my_server:
command: "npx"
args: ["-y", "my-mcp-server"]
sampling:
enabled: true # default: true
model: "gemini-3-flash" # model override (optional)
max_tokens_cap: 4096 # max tokens per request
timeout: 30 # LLM call timeout (seconds)
max_rpm: 10 # max requests per minute
allowed_models: [] # model whitelist (empty = all)
max_tool_rounds: 5 # tool loop limit (0 = disable)
log_level: "info" # audit verbosity
```
Servers can also include `tools` in sampling requests for multi-turn tool-augmented workflows. The `max_tool_rounds` config prevents infinite tool loops. Per-server audit metrics (requests, errors, tokens, tool use count) are tracked via `get_mcp_status()`.
Disable sampling for untrusted servers with `sampling: { enabled: false }`.
## Notes
- MCP tools are called synchronously from the agent's perspective but run asynchronously on a dedicated background event loop

View File

@@ -1040,136 +1040,3 @@ class TestMaxTokensParam:
agent.base_url = "https://openrouter.ai/api/v1/api.openai.com"
result = agent._max_tokens_param(4096)
assert result == {"max_tokens": 4096}
# ---------------------------------------------------------------------------
# System prompt stability for prompt caching
# ---------------------------------------------------------------------------
class TestSystemPromptStability:
"""Verify that the system prompt stays stable across turns for cache hits."""
def test_stored_prompt_reused_for_continuing_session(self, agent):
"""When conversation_history is non-empty and session DB has a stored
prompt, it should be reused instead of rebuilding from disk."""
stored = "You are helpful. [stored from turn 1]"
mock_db = MagicMock()
mock_db.get_session.return_value = {"system_prompt": stored}
agent._session_db = mock_db
# Simulate a continuing session with history
history = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi"},
]
# First call — _cached_system_prompt is None, history is non-empty
agent._cached_system_prompt = None
# Patch run_conversation internals to just test the system prompt logic.
# We'll call the prompt caching block directly by simulating what
# run_conversation does.
conversation_history = history
# The block under test (from run_conversation):
if agent._cached_system_prompt is None:
stored_prompt = None
if conversation_history and agent._session_db:
try:
session_row = agent._session_db.get_session(agent.session_id)
if session_row:
stored_prompt = session_row.get("system_prompt") or None
except Exception:
pass
if stored_prompt:
agent._cached_system_prompt = stored_prompt
assert agent._cached_system_prompt == stored
mock_db.get_session.assert_called_once_with(agent.session_id)
def test_fresh_build_when_no_history(self, agent):
"""On the first turn (no history), system prompt should be built fresh."""
mock_db = MagicMock()
agent._session_db = mock_db
agent._cached_system_prompt = None
conversation_history = []
# The block under test:
if agent._cached_system_prompt is None:
stored_prompt = None
if conversation_history and agent._session_db:
session_row = agent._session_db.get_session(agent.session_id)
if session_row:
stored_prompt = session_row.get("system_prompt") or None
if stored_prompt:
agent._cached_system_prompt = stored_prompt
else:
agent._cached_system_prompt = agent._build_system_prompt()
# Should have built fresh, not queried the DB
mock_db.get_session.assert_not_called()
assert agent._cached_system_prompt is not None
assert "Hermes Agent" in agent._cached_system_prompt
def test_fresh_build_when_db_has_no_prompt(self, agent):
"""If the session DB has no stored prompt, build fresh even with history."""
mock_db = MagicMock()
mock_db.get_session.return_value = {"system_prompt": ""}
agent._session_db = mock_db
agent._cached_system_prompt = None
conversation_history = [{"role": "user", "content": "hi"}]
if agent._cached_system_prompt is None:
stored_prompt = None
if conversation_history and agent._session_db:
try:
session_row = agent._session_db.get_session(agent.session_id)
if session_row:
stored_prompt = session_row.get("system_prompt") or None
except Exception:
pass
if stored_prompt:
agent._cached_system_prompt = stored_prompt
else:
agent._cached_system_prompt = agent._build_system_prompt()
# Empty string is falsy, so should fall through to fresh build
assert "Hermes Agent" in agent._cached_system_prompt
def test_honcho_context_baked_into_prompt_on_first_turn(self, agent):
"""Honcho context should be baked into _cached_system_prompt on
the first turn, not injected separately per API call."""
agent._honcho_context = "User prefers Python over JavaScript."
agent._cached_system_prompt = None
# Simulate first turn: build fresh and bake in Honcho
agent._cached_system_prompt = agent._build_system_prompt()
if agent._honcho_context:
agent._cached_system_prompt = (
agent._cached_system_prompt + "\n\n" + agent._honcho_context
).strip()
assert "User prefers Python over JavaScript" in agent._cached_system_prompt
def test_honcho_prefetch_skipped_on_continuing_session(self):
"""Honcho prefetch should not be called when conversation_history
is non-empty (continuing session)."""
conversation_history = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi there"},
]
# The guard: `not conversation_history` is False when history exists
should_prefetch = not conversation_history
assert should_prefetch is False
def test_honcho_prefetch_runs_on_first_turn(self):
"""Honcho prefetch should run when conversation_history is empty."""
conversation_history = []
should_prefetch = not conversation_history
assert should_prefetch is True

View File

@@ -393,56 +393,5 @@ class TestStubSchemaDrift(unittest.TestCase):
self.assertIn("mode", src)
class TestHeadTailTruncation(unittest.TestCase):
"""Tests for head+tail truncation of large stdout in execute_code."""
def _run(self, code):
with patch("model_tools.handle_function_call", side_effect=_mock_handle_function_call):
result = execute_code(
code=code,
task_id="test-task",
enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
)
return json.loads(result)
def test_short_output_not_truncated(self):
"""Output under MAX_STDOUT_BYTES should not be truncated."""
result = self._run('print("small output")')
self.assertEqual(result["status"], "success")
self.assertIn("small output", result["output"])
self.assertNotIn("TRUNCATED", result["output"])
def test_large_output_preserves_head_and_tail(self):
"""Output exceeding MAX_STDOUT_BYTES keeps both head and tail."""
code = '''
# Print HEAD marker, then filler, then TAIL marker
print("HEAD_MARKER_START")
for i in range(15000):
print(f"filler_line_{i:06d}_padding_to_fill_buffer")
print("TAIL_MARKER_END")
'''
result = self._run(code)
self.assertEqual(result["status"], "success")
output = result["output"]
# Head should be preserved
self.assertIn("HEAD_MARKER_START", output)
# Tail should be preserved (this is the key improvement)
self.assertIn("TAIL_MARKER_END", output)
# Truncation notice should be present
self.assertIn("TRUNCATED", output)
def test_truncation_notice_format(self):
"""Truncation notice includes character counts."""
code = '''
for i in range(15000):
print(f"padding_line_{i:06d}_xxxxxxxxxxxxxxxxxxxxxxxxxx")
'''
result = self._run(code)
output = result["output"]
if "TRUNCATED" in output:
self.assertIn("chars omitted", output)
self.assertIn("total", output)
if __name__ == "__main__":
unittest.main()

View File

@@ -38,7 +38,6 @@ class TestReadFileHandler:
def test_returns_file_content(self, mock_get):
mock_ops = MagicMock()
result_obj = MagicMock()
result_obj.content = "line1\nline2"
result_obj.to_dict.return_value = {"content": "line1\nline2", "total_lines": 2}
mock_ops.read_file.return_value = result_obj
mock_get.return_value = mock_ops
@@ -53,7 +52,6 @@ class TestReadFileHandler:
def test_custom_offset_and_limit(self, mock_get):
mock_ops = MagicMock()
result_obj = MagicMock()
result_obj.content = "line10"
result_obj.to_dict.return_value = {"content": "line10", "total_lines": 50}
mock_ops.read_file.return_value = result_obj
mock_get.return_value = mock_ops

View File

@@ -1489,3 +1489,781 @@ class TestUtilityToolRegistration:
assert entry.check_fn() is False
_servers.pop("chk", None)
# ===========================================================================
# SamplingHandler tests
# ===========================================================================
import math
import time
from mcp.types import (
CreateMessageResult,
CreateMessageResultWithTools,
ErrorData,
SamplingCapability,
SamplingToolsCapability,
TextContent,
ToolUseContent,
)
from tools.mcp_tool import SamplingHandler, _safe_numeric
# ---------------------------------------------------------------------------
# Helpers for sampling tests
# ---------------------------------------------------------------------------
def _make_sampling_params(
messages=None,
max_tokens=100,
system_prompt=None,
model_preferences=None,
temperature=None,
stop_sequences=None,
tools=None,
tool_choice=None,
):
"""Create a fake CreateMessageRequestParams using SimpleNamespace.
Each message must have a ``content_as_list`` attribute that mirrors
the SDK helper so that ``_convert_messages`` works correctly.
"""
if messages is None:
content = SimpleNamespace(text="Hello")
msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
messages = [msg]
params = SimpleNamespace(
messages=messages,
maxTokens=max_tokens,
modelPreferences=model_preferences,
temperature=temperature,
stopSequences=stop_sequences,
tools=tools,
toolChoice=tool_choice,
)
if system_prompt is not None:
params.systemPrompt = system_prompt
return params
def _make_llm_response(
content="LLM response",
model="test-model",
finish_reason="stop",
tool_calls=None,
):
"""Create a fake OpenAI chat completion response (text)."""
message = SimpleNamespace(content=content, tool_calls=tool_calls)
choice = SimpleNamespace(
finish_reason=finish_reason,
message=message,
)
usage = SimpleNamespace(total_tokens=42)
return SimpleNamespace(choices=[choice], model=model, usage=usage)
def _make_llm_tool_response(tool_calls_data=None, model="test-model"):
"""Create a fake response with tool_calls.
``tool_calls_data``: list of (id, name, arguments_json) tuples.
"""
if tool_calls_data is None:
tool_calls_data = [("call_1", "get_weather", '{"city": "London"}')]
tc_list = [
SimpleNamespace(
id=tc_id,
function=SimpleNamespace(name=name, arguments=args),
)
for tc_id, name, args in tool_calls_data
]
return _make_llm_response(
content=None,
model=model,
finish_reason="tool_calls",
tool_calls=tc_list,
)
# ---------------------------------------------------------------------------
# 1. _safe_numeric helper
# ---------------------------------------------------------------------------
class TestSafeNumeric:
def test_int_passthrough(self):
assert _safe_numeric(10, 5, int) == 10
def test_string_coercion(self):
assert _safe_numeric("20", 5, int) == 20
def test_none_returns_default(self):
assert _safe_numeric(None, 7, int) == 7
def test_inf_returns_default(self):
assert _safe_numeric(float("inf"), 3.0, float) == 3.0
def test_nan_returns_default(self):
assert _safe_numeric(float("nan"), 4.0, float) == 4.0
def test_below_minimum_clamps(self):
assert _safe_numeric(-5, 10, int, minimum=1) == 1
def test_minimum_zero_allowed(self):
assert _safe_numeric(0, 10, int, minimum=0) == 0
def test_non_numeric_string_returns_default(self):
assert _safe_numeric("abc", 42, int) == 42
def test_float_coercion(self):
assert _safe_numeric("3.5", 1.0, float) == 3.5
# ---------------------------------------------------------------------------
# 2. SamplingHandler initialization and config parsing
# ---------------------------------------------------------------------------
class TestSamplingHandlerInit:
def test_defaults(self):
h = SamplingHandler("srv", {})
assert h.server_name == "srv"
assert h.max_rpm == 10
assert h.timeout == 30
assert h.max_tokens_cap == 4096
assert h.max_tool_rounds == 5
assert h.model_override is None
assert h.allowed_models == []
assert h.metrics == {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}
def test_custom_config(self):
cfg = {
"max_rpm": 20,
"timeout": 60,
"max_tokens_cap": 2048,
"max_tool_rounds": 3,
"model": "gpt-4o",
"allowed_models": ["gpt-4o", "gpt-3.5-turbo"],
"log_level": "debug",
}
h = SamplingHandler("custom", cfg)
assert h.max_rpm == 20
assert h.timeout == 60.0
assert h.max_tokens_cap == 2048
assert h.max_tool_rounds == 3
assert h.model_override == "gpt-4o"
assert h.allowed_models == ["gpt-4o", "gpt-3.5-turbo"]
def test_string_numeric_config_values(self):
"""YAML sometimes delivers numeric values as strings."""
cfg = {"max_rpm": "15", "timeout": "45.5", "max_tokens_cap": "1024"}
h = SamplingHandler("s", cfg)
assert h.max_rpm == 15
assert h.timeout == 45.5
assert h.max_tokens_cap == 1024
# ---------------------------------------------------------------------------
# 3. Rate limiting
# ---------------------------------------------------------------------------
class TestRateLimit:
def setup_method(self):
self.handler = SamplingHandler("rl", {"max_rpm": 3})
def test_allows_under_limit(self):
assert self.handler._check_rate_limit() is True
assert self.handler._check_rate_limit() is True
assert self.handler._check_rate_limit() is True
def test_rejects_over_limit(self):
for _ in range(3):
self.handler._check_rate_limit()
assert self.handler._check_rate_limit() is False
def test_window_expiry(self):
"""Old timestamps should be purged from the sliding window."""
for _ in range(3):
self.handler._check_rate_limit()
# Simulate timestamps from 61 seconds ago
self.handler._rate_timestamps[:] = [time.time() - 61] * 3
assert self.handler._check_rate_limit() is True
# ---------------------------------------------------------------------------
# 4. Model resolution
# ---------------------------------------------------------------------------
class TestResolveModel:
def setup_method(self):
self.handler = SamplingHandler("mr", {})
def test_no_preference_no_override(self):
assert self.handler._resolve_model(None) is None
def test_config_override_wins(self):
self.handler.model_override = "override-model"
prefs = SimpleNamespace(hints=[SimpleNamespace(name="hint-model")])
assert self.handler._resolve_model(prefs) == "override-model"
def test_hint_used_when_no_override(self):
prefs = SimpleNamespace(hints=[SimpleNamespace(name="hint-model")])
assert self.handler._resolve_model(prefs) == "hint-model"
def test_empty_hints(self):
prefs = SimpleNamespace(hints=[])
assert self.handler._resolve_model(prefs) is None
def test_hint_without_name(self):
prefs = SimpleNamespace(hints=[SimpleNamespace(name=None)])
assert self.handler._resolve_model(prefs) is None
# ---------------------------------------------------------------------------
# 5. Message conversion
# ---------------------------------------------------------------------------
class TestConvertMessages:
def setup_method(self):
self.handler = SamplingHandler("mc", {})
def test_single_text_message(self):
content = SimpleNamespace(text="Hello world")
msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
params = _make_sampling_params(messages=[msg])
result = self.handler._convert_messages(params)
assert len(result) == 1
assert result[0] == {"role": "user", "content": "Hello world"}
def test_image_message(self):
text_block = SimpleNamespace(text="Look at this")
img_block = SimpleNamespace(data="abc123", mimeType="image/png")
msg = SimpleNamespace(
role="user",
content=[text_block, img_block],
content_as_list=[text_block, img_block],
)
params = _make_sampling_params(messages=[msg])
result = self.handler._convert_messages(params)
assert len(result) == 1
parts = result[0]["content"]
assert len(parts) == 2
assert parts[0] == {"type": "text", "text": "Look at this"}
assert parts[1]["type"] == "image_url"
assert "data:image/png;base64,abc123" in parts[1]["image_url"]["url"]
def test_tool_result_message(self):
inner = SimpleNamespace(text="42 degrees")
tr_block = SimpleNamespace(toolUseId="call_1", content=[inner])
msg = SimpleNamespace(
role="user",
content=[tr_block],
content_as_list=[tr_block],
)
params = _make_sampling_params(messages=[msg])
result = self.handler._convert_messages(params)
assert len(result) == 1
assert result[0]["role"] == "tool"
assert result[0]["tool_call_id"] == "call_1"
assert result[0]["content"] == "42 degrees"
def test_tool_use_message(self):
tu_block = SimpleNamespace(
id="call_2", name="get_weather", input={"city": "London"}
)
msg = SimpleNamespace(
role="assistant",
content=[tu_block],
content_as_list=[tu_block],
)
params = _make_sampling_params(messages=[msg])
result = self.handler._convert_messages(params)
assert len(result) == 1
assert result[0]["role"] == "assistant"
assert len(result[0]["tool_calls"]) == 1
assert result[0]["tool_calls"][0]["function"]["name"] == "get_weather"
assert json.loads(result[0]["tool_calls"][0]["function"]["arguments"]) == {"city": "London"}
def test_mixed_text_and_tool_use(self):
"""Assistant message with both text and tool_calls."""
text_block = SimpleNamespace(text="Let me check the weather")
tu_block = SimpleNamespace(
id="call_3", name="get_weather", input={"city": "Paris"}
)
msg = SimpleNamespace(
role="assistant",
content=[text_block, tu_block],
content_as_list=[text_block, tu_block],
)
params = _make_sampling_params(messages=[msg])
result = self.handler._convert_messages(params)
assert len(result) == 1
assert result[0]["content"] == "Let me check the weather"
assert len(result[0]["tool_calls"]) == 1
def test_fallback_without_content_as_list(self):
"""When content_as_list is absent, falls back to content."""
content = SimpleNamespace(text="Fallback text")
msg = SimpleNamespace(role="user", content=content)
params = _make_sampling_params(messages=[msg])
result = self.handler._convert_messages(params)
assert len(result) == 1
assert result[0]["content"] == "Fallback text"
# ---------------------------------------------------------------------------
# 6. Text-only sampling callback (full flow)
# ---------------------------------------------------------------------------
class TestSamplingCallbackText:
def setup_method(self):
self.handler = SamplingHandler("txt", {})
def test_text_response(self):
"""Full flow: text response returns CreateMessageResult."""
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_response(
content="Hello from LLM"
)
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
params = _make_sampling_params()
result = asyncio.run(self.handler(None, params))
assert isinstance(result, CreateMessageResult)
assert isinstance(result.content, TextContent)
assert result.content.text == "Hello from LLM"
assert result.model == "test-model"
assert result.role == "assistant"
assert result.stopReason == "endTurn"
def test_system_prompt_prepended(self):
"""System prompt is inserted as the first message."""
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
params = _make_sampling_params(system_prompt="Be helpful")
asyncio.run(self.handler(None, params))
call_args = fake_client.chat.completions.create.call_args
messages = call_args.kwargs["messages"]
assert messages[0] == {"role": "system", "content": "Be helpful"}
def test_length_stop_reason(self):
"""finish_reason='length' maps to stopReason='maxTokens'."""
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_response(
finish_reason="length"
)
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
params = _make_sampling_params()
result = asyncio.run(self.handler(None, params))
assert isinstance(result, CreateMessageResult)
assert result.stopReason == "maxTokens"
# ---------------------------------------------------------------------------
# 7. Tool use sampling callback
# ---------------------------------------------------------------------------
class TestSamplingCallbackToolUse:
def setup_method(self):
self.handler = SamplingHandler("tu", {})
def test_tool_use_response(self):
"""LLM tool_calls response returns CreateMessageResultWithTools."""
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
params = _make_sampling_params()
result = asyncio.run(self.handler(None, params))
assert isinstance(result, CreateMessageResultWithTools)
assert result.stopReason == "toolUse"
assert result.model == "test-model"
assert len(result.content) == 1
tc = result.content[0]
assert isinstance(tc, ToolUseContent)
assert tc.name == "get_weather"
assert tc.id == "call_1"
assert tc.input == {"city": "London"}
def test_multiple_tool_calls(self):
"""Multiple tool_calls in a single response."""
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_tool_response(
tool_calls_data=[
("call_a", "func_a", '{"x": 1}'),
("call_b", "func_b", '{"y": 2}'),
]
)
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
result = asyncio.run(self.handler(None, _make_sampling_params()))
assert isinstance(result, CreateMessageResultWithTools)
assert len(result.content) == 2
assert result.content[0].name == "func_a"
assert result.content[1].name == "func_b"
# ---------------------------------------------------------------------------
# 8. Tool loop governance
# ---------------------------------------------------------------------------
class TestToolLoopGovernance:
def test_max_tool_rounds_enforcement(self):
"""After max_tool_rounds consecutive tool responses, an error is returned."""
handler = SamplingHandler("tl", {"max_tool_rounds": 2})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
params = _make_sampling_params()
# Round 1, 2: allowed
r1 = asyncio.run(handler(None, params))
assert isinstance(r1, CreateMessageResultWithTools)
r2 = asyncio.run(handler(None, params))
assert isinstance(r2, CreateMessageResultWithTools)
# Round 3: exceeds limit
r3 = asyncio.run(handler(None, params))
assert isinstance(r3, ErrorData)
assert "Tool loop limit exceeded" in r3.message
def test_text_response_resets_counter(self):
"""A text response resets the tool loop counter."""
handler = SamplingHandler("tl2", {"max_tool_rounds": 1})
fake_client = MagicMock()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
# Tool response (round 1 of 1 allowed)
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
r1 = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(r1, CreateMessageResultWithTools)
# Text response resets counter
fake_client.chat.completions.create.return_value = _make_llm_response()
r2 = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(r2, CreateMessageResult)
# Tool response again (should succeed since counter was reset)
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
r3 = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(r3, CreateMessageResultWithTools)
def test_max_tool_rounds_zero_disables(self):
"""max_tool_rounds=0 means tool loops are disabled entirely."""
handler = SamplingHandler("tl3", {"max_tool_rounds": 0})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, ErrorData)
assert "Tool loops disabled" in result.message
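The governance behaviour exercised by the three tests above can be sketched in isolation. This is a minimal illustration rather than the handler's own code: `ToolLoopGuard` and its method names are hypothetical, and it returns plain strings where the handler returns `ErrorData`.

```python
from typing import Optional


class ToolLoopGuard:
    """Hypothetical stand-in for the per-handler tool-loop counter."""

    def __init__(self, max_tool_rounds: int = 5):
        self.max_tool_rounds = max_tool_rounds
        self._count = 0

    def on_tool_response(self) -> Optional[str]:
        """Return an error message if this tool round is rejected, else None."""
        if self.max_tool_rounds == 0:
            self._count = 0
            return "Tool loops disabled"
        self._count += 1
        if self._count > self.max_tool_rounds:
            self._count = 0  # reset so a later request starts fresh
            return "Tool loop limit exceeded"
        return None

    def on_text_response(self) -> None:
        """A plain text answer ends the loop, so the counter resets."""
        self._count = 0
```

With `max_tool_rounds=2`, the first two tool responses pass and the third is rejected, which matches `test_max_tool_rounds_enforcement`.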
# ---------------------------------------------------------------------------
# 9. Error paths: rate limit, timeout, no provider
# ---------------------------------------------------------------------------
class TestSamplingErrors:
def test_rate_limit_error(self):
handler = SamplingHandler("rle", {"max_rpm": 1})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
# First call succeeds
r1 = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(r1, CreateMessageResult)
# Second call is rate limited
r2 = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(r2, ErrorData)
assert "rate limit" in r2.message.lower()
assert handler.metrics["errors"] == 1
def test_timeout_error(self):
handler = SamplingHandler("to", {"timeout": 0.05})
fake_client = MagicMock()
def slow_call(**kwargs):
import threading
# Use an event to ensure the thread truly blocks long enough
evt = threading.Event()
evt.wait(5) # blocks for 5 seconds; the timeout abandons the await, but this thread still runs to completion
return _make_llm_response()
fake_client.chat.completions.create.side_effect = slow_call
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, ErrorData)
assert "timed out" in result.message.lower()
assert handler.metrics["errors"] == 1
def test_no_provider_error(self):
handler = SamplingHandler("np", {})
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(None, None),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, ErrorData)
assert "No LLM provider" in result.message
assert handler.metrics["errors"] == 1
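The timeout path tested above relies on combining `asyncio.wait_for` with `asyncio.to_thread`. A minimal sketch of that pattern follows, assuming Python 3.9 or later; `slow_llm_call` and `call_with_timeout` are hypothetical names, and the string return values stand in for the handler's result and `ErrorData` objects.

```python
import asyncio
import time


def slow_llm_call() -> str:
    """Stand-in for a blocking, synchronous LLM client call."""
    time.sleep(0.3)
    return "done"


async def call_with_timeout(timeout: float) -> str:
    """Run the blocking call in a worker thread, bounded by *timeout* seconds."""
    try:
        return await asyncio.wait_for(asyncio.to_thread(slow_llm_call), timeout=timeout)
    except asyncio.TimeoutError:
        # Only the await is abandoned; the worker thread runs to completion.
        return "timed out"
```

Because the worker thread cannot be interrupted, a timed-out call keeps running in the background until the blocking function returns, which is worth keeping in mind when choosing how long test doubles block.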
# ---------------------------------------------------------------------------
# 10. Model whitelist
# ---------------------------------------------------------------------------
class TestModelWhitelist:
def test_allowed_model_passes(self):
handler = SamplingHandler("wl", {"allowed_models": ["gpt-4o", "test-model"]})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "test-model"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, CreateMessageResult)
def test_disallowed_model_rejected(self):
handler = SamplingHandler("wl2", {"allowed_models": ["gpt-4o"]})
fake_client = MagicMock()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "gpt-3.5-turbo"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, ErrorData)
assert "not allowed" in result.message
assert handler.metrics["errors"] == 1
def test_empty_whitelist_allows_all(self):
handler = SamplingHandler("wl3", {"allowed_models": []})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "any-model"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, CreateMessageResult)
# ---------------------------------------------------------------------------
# 11. Malformed tool_call arguments
# ---------------------------------------------------------------------------
class TestMalformedToolCallArgs:
def test_invalid_json_wrapped_as_raw(self):
"""Malformed JSON arguments get wrapped in {"_raw": ...}."""
handler = SamplingHandler("mf", {})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_tool_response(
tool_calls_data=[("call_x", "some_tool", "not valid json {{{")]
)
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, CreateMessageResultWithTools)
tc = result.content[0]
assert isinstance(tc, ToolUseContent)
assert tc.input == {"_raw": "not valid json {{{"}
def test_dict_args_pass_through(self):
"""When arguments are already a dict, they pass through directly."""
handler = SamplingHandler("mf2", {})
# Build a tool call where arguments is already a dict
tc_obj = SimpleNamespace(
id="call_d",
function=SimpleNamespace(name="do_stuff", arguments={"key": "val"}),
)
message = SimpleNamespace(content=None, tool_calls=[tc_obj])
choice = SimpleNamespace(finish_reason="tool_calls", message=message)
usage = SimpleNamespace(total_tokens=10)
response = SimpleNamespace(choices=[choice], model="m", usage=usage)
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = response
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, CreateMessageResultWithTools)
assert result.content[0].input == {"key": "val"}
# ---------------------------------------------------------------------------
# 12. Metrics tracking
# ---------------------------------------------------------------------------
class TestMetricsTracking:
def test_request_and_token_metrics(self):
handler = SamplingHandler("met", {})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
asyncio.run(handler(None, _make_sampling_params()))
assert handler.metrics["requests"] == 1
assert handler.metrics["tokens_used"] == 42
assert handler.metrics["errors"] == 0
def test_tool_use_count_metric(self):
handler = SamplingHandler("met2", {})
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
asyncio.run(handler(None, _make_sampling_params()))
assert handler.metrics["tool_use_count"] == 1
assert handler.metrics["requests"] == 1
def test_error_metric_incremented(self):
handler = SamplingHandler("met3", {})
with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(None, None),
):
asyncio.run(handler(None, _make_sampling_params()))
assert handler.metrics["errors"] == 1
assert handler.metrics["requests"] == 0
# ---------------------------------------------------------------------------
# 13. session_kwargs()
# ---------------------------------------------------------------------------
class TestSessionKwargs:
def test_returns_correct_keys(self):
handler = SamplingHandler("sk", {})
kwargs = handler.session_kwargs()
assert "sampling_callback" in kwargs
assert "sampling_capabilities" in kwargs
assert kwargs["sampling_callback"] is handler
def test_sampling_capabilities_type(self):
handler = SamplingHandler("sk2", {})
kwargs = handler.session_kwargs()
cap = kwargs["sampling_capabilities"]
assert isinstance(cap, SamplingCapability)
assert isinstance(cap.tools, SamplingToolsCapability)
# ---------------------------------------------------------------------------
# 14. MCPServerTask integration
# ---------------------------------------------------------------------------
class TestMCPServerTaskSamplingIntegration:
def test_sampling_handler_created_when_enabled(self):
"""MCPServerTask.run() creates a SamplingHandler when sampling is enabled."""
from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES
server = MCPServerTask("int_test")
config = {
"command": "fake",
"sampling": {"enabled": True, "max_rpm": 5},
}
# We only need to test the setup logic, not the actual connection.
# Calling run() would attempt a real connection, so we test the
# sampling setup portion directly.
server._config = config
sampling_config = config.get("sampling", {})
if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
server._sampling = SamplingHandler(server.name, sampling_config)
else:
server._sampling = None
assert server._sampling is not None
assert isinstance(server._sampling, SamplingHandler)
assert server._sampling.server_name == "int_test"
assert server._sampling.max_rpm == 5
def test_sampling_handler_none_when_disabled(self):
"""MCPServerTask._sampling is None when sampling is disabled."""
from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES
server = MCPServerTask("int_test2")
config = {
"command": "fake",
"sampling": {"enabled": False},
}
server._config = config
sampling_config = config.get("sampling", {})
if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
server._sampling = SamplingHandler(server.name, sampling_config)
else:
server._sampling = None
assert server._sampling is None
def test_session_kwargs_used_in_stdio(self):
"""When sampling is set, session_kwargs() are passed to ClientSession."""
from tools.mcp_tool import MCPServerTask
server = MCPServerTask("sk_test")
server._sampling = SamplingHandler("sk_test", {"max_rpm": 7})
kwargs = server._sampling.session_kwargs()
assert "sampling_callback" in kwargs
assert "sampling_capabilities" in kwargs


@@ -457,17 +457,11 @@ def execute_code(
# --- Poll loop: watch for exit, timeout, and interrupt ---
deadline = time.monotonic() + timeout
stdout_chunks: list = []
stderr_chunks: list = []
# Background readers to avoid pipe buffer deadlocks.
# For stdout we use a head+tail strategy: keep the first HEAD_BYTES
# and a rolling window of the last TAIL_BYTES so the final print()
# output is never lost. Stderr keeps head-only (errors appear early).
_STDOUT_HEAD_BYTES = int(MAX_STDOUT_BYTES * 0.4) # 40% head
_STDOUT_TAIL_BYTES = MAX_STDOUT_BYTES - _STDOUT_HEAD_BYTES # 60% tail
# Background readers to avoid pipe buffer deadlocks
def _drain(pipe, chunks, max_bytes):
"""Simple head-only drain (used for stderr)."""
total = 0
try:
while True:
@@ -481,48 +475,8 @@ def execute_code(
except (ValueError, OSError):
pass
stdout_total_bytes = [0] # mutable ref for total bytes seen
def _drain_head_tail(pipe, head_chunks, tail_chunks, head_bytes, tail_bytes, total_ref):
"""Drain stdout keeping both head and tail data."""
head_collected = 0
from collections import deque
tail_buf = deque()
tail_collected = 0
try:
while True:
data = pipe.read(4096)
if not data:
break
total_ref[0] += len(data)
# Fill head buffer first
if head_collected < head_bytes:
keep = min(len(data), head_bytes - head_collected)
head_chunks.append(data[:keep])
head_collected += keep
data = data[keep:] # remaining goes to tail
if not data:
continue
# Everything past head goes into rolling tail buffer
tail_buf.append(data)
tail_collected += len(data)
# Evict old tail data to stay within tail_bytes budget
while tail_collected > tail_bytes and tail_buf:
oldest = tail_buf.popleft()
tail_collected -= len(oldest)
except (ValueError, OSError):
pass
# Transfer final tail to output list
tail_chunks.extend(tail_buf)
stdout_head_chunks: list = []
stdout_tail_chunks: list = []
stdout_reader = threading.Thread(
target=_drain_head_tail,
args=(proc.stdout, stdout_head_chunks, stdout_tail_chunks,
_STDOUT_HEAD_BYTES, _STDOUT_TAIL_BYTES, stdout_total_bytes),
daemon=True
target=_drain, args=(proc.stdout, stdout_chunks, MAX_STDOUT_BYTES), daemon=True
)
stderr_reader = threading.Thread(
target=_drain, args=(proc.stderr, stderr_chunks, MAX_STDERR_BYTES), daemon=True
@@ -546,21 +500,12 @@ def execute_code(
stdout_reader.join(timeout=3)
stderr_reader.join(timeout=3)
stdout_head = b"".join(stdout_head_chunks).decode("utf-8", errors="replace")
stdout_tail = b"".join(stdout_tail_chunks).decode("utf-8", errors="replace")
stdout_text = b"".join(stdout_chunks).decode("utf-8", errors="replace")
stderr_text = b"".join(stderr_chunks).decode("utf-8", errors="replace")
# Assemble stdout with head+tail truncation
total_stdout = stdout_total_bytes[0]
if total_stdout > MAX_STDOUT_BYTES and stdout_tail:
omitted = total_stdout - len(stdout_head) - len(stdout_tail)
truncated_notice = (
f"\n\n... [OUTPUT TRUNCATED - {omitted:,} chars omitted "
f"out of {total_stdout:,} total] ...\n\n"
)
stdout_text = stdout_head + truncated_notice + stdout_tail
else:
stdout_text = stdout_head + stdout_tail
# Truncation notice
if len(stdout_text) >= MAX_STDOUT_BYTES:
stdout_text = stdout_text[:MAX_STDOUT_BYTES] + "\n[output truncated at 50KB]"
exit_code = proc.returncode if proc.returncode is not None else -1
duration = round(time.monotonic() - exec_start, 2)
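The head+tail strategy touched by this hunk can be illustrated on its own. The sketch below is a simplified variant, not the code from the diff: `head_tail` is a hypothetical helper that works on an iterable of byte chunks instead of a pipe, and it trims the tail to exactly `tail_bytes` at the end, whereas the diffed version evicts whole chunks.

```python
from collections import deque
from typing import Iterable, Tuple


def head_tail(chunks: Iterable[bytes], head_bytes: int, tail_bytes: int) -> Tuple[bytes, bytes, int]:
    """Keep the first *head_bytes* and the last *tail_bytes* of a byte stream.

    Returns (head, tail, total_bytes_seen).
    """
    head = bytearray()
    tail = deque()
    tail_size = 0
    total = 0
    for data in chunks:
        total += len(data)
        # Fill the head buffer first.
        if len(head) < head_bytes:
            keep = min(len(data), head_bytes - len(head))
            head += data[:keep]
            data = data[keep:]
            if not data:
                continue
        # Everything past the head goes into a rolling tail buffer.
        tail.append(data)
        tail_size += len(data)
        # Evict whole chunks only while the remainder still covers the budget.
        while tail and tail_size - len(tail[0]) >= tail_bytes:
            tail_size -= len(tail.popleft())
    tail_data = b"".join(tail)[-tail_bytes:] if tail_bytes > 0 else b""
    return bytes(head), tail_data, total
```

Memory use stays bounded by roughly `head_bytes + tail_bytes` plus one chunk, regardless of how much output the process produces.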


@@ -7,7 +7,6 @@ import os
import threading
from typing import Optional
from tools.file_operations import ShellFileOperations
from agent.redact import redact_sensitive_text
logger = logging.getLogger(__name__)
@@ -129,8 +128,6 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str =
try:
file_ops = _get_file_ops(task_id)
result = file_ops.read_file(path, offset, limit)
if result.content:
result.content = redact_sensitive_text(result.content)
return json.dumps(result.to_dict(), ensure_ascii=False)
except Exception as e:
return json.dumps({"error": str(e)}, ensure_ascii=False)
@@ -189,10 +186,6 @@ def search_tool(pattern: str, target: str = "content", path: str = ".",
pattern=pattern, path=path, target=target, file_glob=file_glob,
limit=limit, offset=offset, output_mode=output_mode, context=context
)
if hasattr(result, 'matches'):
for m in result.matches:
if hasattr(m, 'content') and m.content:
m.content = redact_sensitive_text(m.content)
result_dict = result.to_dict()
result_json = json.dumps(result_dict, ensure_ascii=False)
# Hint when results were truncated — explicit next offset is clearer


@@ -29,6 +29,18 @@ Example config::
headers:
Authorization: "Bearer sk-..."
timeout: 180
analysis:
command: "npx"
args: ["-y", "analysis-server"]
sampling: # server-initiated LLM requests
enabled: true # default: true
model: "gemini-3-flash" # override model (optional)
max_tokens_cap: 4096 # max tokens per request
timeout: 30 # LLM call timeout (seconds)
max_rpm: 10 # max requests per minute
allowed_models: [] # model whitelist (empty = all)
max_tool_rounds: 5 # tool loop limit (0 = disable)
log_level: "info" # audit verbosity
Features:
- Stdio transport (command + args) and HTTP/StreamableHTTP transport (url)
@@ -37,6 +49,8 @@ Features:
- Credential stripping in error messages returned to the LLM
- Configurable per-server timeouts for tool calls and connections
- Thread-safe architecture with dedicated background event loop
- Sampling support: MCP servers can request LLM completions via
sampling/createMessage (text and tool-use responses)
Architecture:
A dedicated background event loop (_mcp_loop) runs in a daemon thread.
@@ -58,9 +72,11 @@ Thread safety:
import asyncio
import json
import logging
import math
import os
import re
import threading
import time
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
@@ -71,6 +87,7 @@ logger = logging.getLogger(__name__)
_MCP_AVAILABLE = False
_MCP_HTTP_AVAILABLE = False
_MCP_SAMPLING_TYPES = False
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
@@ -80,6 +97,20 @@ try:
_MCP_HTTP_AVAILABLE = True
except ImportError:
_MCP_HTTP_AVAILABLE = False
# Sampling types -- separated so older SDK versions don't break MCP support
try:
from mcp.types import (
CreateMessageResult,
CreateMessageResultWithTools,
ErrorData,
SamplingCapability,
SamplingToolsCapability,
TextContent,
ToolUseContent,
)
_MCP_SAMPLING_TYPES = True
except ImportError:
logger.debug("MCP sampling types not available -- sampling disabled")
except ImportError:
logger.debug("mcp package not installed -- MCP tool support disabled")
@@ -145,6 +176,386 @@ def _sanitize_error(text: str) -> str:
return _CREDENTIAL_PATTERN.sub("[REDACTED]", text)
# ---------------------------------------------------------------------------
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
# ---------------------------------------------------------------------------
def _safe_numeric(value, default, coerce=int, minimum=1):
"""Coerce a config value to a numeric type, returning *default* on failure.
String values from YAML (e.g. ``"10"`` instead of ``10``) are coerced,
non-finite floats fall back to *default*, and values below *minimum*
are clamped up to *minimum*.
"""
try:
result = coerce(value)
if isinstance(result, float) and not math.isfinite(result):
return default
return max(result, minimum)
except (TypeError, ValueError, OverflowError):
return default
class SamplingHandler:
"""Handles sampling/createMessage requests for a single MCP server.
Each MCPServerTask that has sampling enabled creates one SamplingHandler.
The handler is callable and passed directly to ``ClientSession`` as
the ``sampling_callback``. All state (rate-limit timestamps, metrics,
tool-loop counters) lives on the instance -- no module-level globals.
The callback is async and runs on the MCP background event loop. The
sync LLM call is offloaded to a thread via ``asyncio.to_thread()`` so
it doesn't block the event loop.
"""
_STOP_REASON_MAP = {"stop": "endTurn", "length": "maxTokens", "tool_calls": "toolUse"}
def __init__(self, server_name: str, config: dict):
self.server_name = server_name
self.max_rpm = _safe_numeric(config.get("max_rpm", 10), 10, int)
self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
self.max_tool_rounds = _safe_numeric(
config.get("max_tool_rounds", 5), 5, int, minimum=0,
)
self.model_override = config.get("model")
self.allowed_models = config.get("allowed_models", [])
_log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
self.audit_level = _log_levels.get(
str(config.get("log_level", "info")).lower(), logging.INFO,
)
# Per-instance state
self._rate_timestamps: List[float] = []
self._tool_loop_count = 0
self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}
# -- Rate limiting -------------------------------------------------------
def _check_rate_limit(self) -> bool:
"""Sliding-window rate limiter. Returns True if request is allowed."""
now = time.time()
window = now - 60
self._rate_timestamps[:] = [t for t in self._rate_timestamps if t > window]
if len(self._rate_timestamps) >= self.max_rpm:
return False
self._rate_timestamps.append(now)
return True
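The sliding-window check above can be sketched as a standalone class. This is an illustrative variant under stated assumptions: `SlidingWindowLimiter` and its `clock` parameter are hypothetical, and it defaults to `time.monotonic` where the diff uses `time.time()`, so that the window is unaffected by wall-clock adjustments and tests can inject a fake clock.

```python
import time
from typing import Callable, List


class SlidingWindowLimiter:
    """Allow at most *max_per_window* calls in any *window_seconds* span."""

    def __init__(self, max_per_window: int, window_seconds: float = 60.0,
                 clock: Callable[[], float] = time.monotonic):
        self.max_per_window = max_per_window
        self.window_seconds = window_seconds
        self._clock = clock
        self._timestamps: List[float] = []

    def allow(self) -> bool:
        """Record and permit the call if the window has room, else refuse."""
        now = self._clock()
        cutoff = now - self.window_seconds
        # Drop timestamps that have slid out of the window.
        self._timestamps = [t for t in self._timestamps if t > cutoff]
        if len(self._timestamps) >= self.max_per_window:
            return False
        self._timestamps.append(now)
        return True
```

Refused calls are not recorded, so a burst of rejected requests does not extend the period during which the limiter stays closed.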
# -- Model resolution ----------------------------------------------------
def _resolve_model(self, preferences) -> Optional[str]:
"""Config override > server hint > None (use default)."""
if self.model_override:
return self.model_override
if preferences and hasattr(preferences, "hints") and preferences.hints:
for hint in preferences.hints:
if hasattr(hint, "name") and hint.name:
return hint.name
return None
# -- Message conversion --------------------------------------------------
@staticmethod
def _extract_tool_result_text(block) -> str:
"""Extract text from a ToolResultContent block."""
if not hasattr(block, "content") or block.content is None:
return ""
items = block.content if isinstance(block.content, list) else [block.content]
return "\n".join(item.text for item in items if hasattr(item, "text"))
def _convert_messages(self, params) -> List[dict]:
"""Convert MCP SamplingMessages to OpenAI format.
Uses ``msg.content_as_list`` (SDK helper) so single-block and
list-of-blocks are handled uniformly. Dispatches per block type
by duck-typing via ``hasattr``, so it works with real SDK types and
with simple stand-in objects alike.
"""
messages: List[dict] = []
for msg in params.messages:
blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
msg.content if isinstance(msg.content, list) else [msg.content]
)
# Separate blocks by kind
tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
tool_uses = [b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")]
content_blocks = [b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))]
# Emit tool result messages (role: tool)
for tr in tool_results:
messages.append({
"role": "tool",
"tool_call_id": tr.toolUseId,
"content": self._extract_tool_result_text(tr),
})
# Emit assistant tool_calls message
if tool_uses:
tc_list = []
for tu in tool_uses:
tc_list.append({
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
"type": "function",
"function": {
"name": tu.name,
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
},
})
msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
# Include any accompanying text
text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
if text_parts:
msg_dict["content"] = "\n".join(text_parts)
messages.append(msg_dict)
elif content_blocks:
# Pure text/image content
if len(content_blocks) == 1 and hasattr(content_blocks[0], "text"):
messages.append({"role": msg.role, "content": content_blocks[0].text})
else:
parts = []
for block in content_blocks:
if hasattr(block, "text"):
parts.append({"type": "text", "text": block.text})
elif hasattr(block, "data") and hasattr(block, "mimeType"):
parts.append({
"type": "image_url",
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
})
else:
logger.warning(
"Unsupported sampling content block type: %s (skipped)",
type(block).__name__,
)
if parts:
messages.append({"role": msg.role, "content": parts})
return messages
# -- Error helper --------------------------------------------------------
@staticmethod
def _error(message: str, code: int = -1):
"""Return ErrorData (MCP spec) or raise as fallback."""
if _MCP_SAMPLING_TYPES:
return ErrorData(code=code, message=message)
raise Exception(message)
# -- Response building ---------------------------------------------------
def _build_tool_use_result(self, choice, response):
"""Build a CreateMessageResultWithTools from an LLM tool_calls response."""
self.metrics["tool_use_count"] += 1
# Tool loop governance
if self.max_tool_rounds == 0:
self._tool_loop_count = 0
return self._error(
f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
)
self._tool_loop_count += 1
if self._tool_loop_count > self.max_tool_rounds:
self._tool_loop_count = 0
return self._error(
f"Tool loop limit exceeded for server '{self.server_name}' "
f"(max {self.max_tool_rounds} rounds)"
)
content_blocks = []
for tc in choice.message.tool_calls:
args = tc.function.arguments
if isinstance(args, str):
try:
parsed = json.loads(args)
except (json.JSONDecodeError, ValueError):
logger.warning(
"MCP server '%s': malformed tool_calls arguments "
"from LLM (wrapping as raw): %.100s",
self.server_name, args,
)
parsed = {"_raw": args}
else:
parsed = args if isinstance(args, dict) else {"_raw": str(args)}
content_blocks.append(ToolUseContent(
type="tool_use",
id=tc.id,
name=tc.function.name,
input=parsed,
))
logger.log(
self.audit_level,
"MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
self.server_name, response.model,
getattr(getattr(response, "usage", None), "total_tokens", "?"),
len(content_blocks),
)
return CreateMessageResultWithTools(
role="assistant",
content=content_blocks,
model=response.model,
stopReason="toolUse",
)
def _build_text_result(self, choice, response):
"""Build a CreateMessageResult from a normal text response."""
self._tool_loop_count = 0 # reset on text response
response_text = choice.message.content or ""
logger.log(
self.audit_level,
"MCP server '%s' sampling response: model=%s, tokens=%s",
self.server_name, response.model,
getattr(getattr(response, "usage", None), "total_tokens", "?"),
)
return CreateMessageResult(
role="assistant",
content=TextContent(type="text", text=_sanitize_error(response_text)),
model=response.model,
stopReason=self._STOP_REASON_MAP.get(choice.finish_reason, "endTurn"),
)
# -- Session kwargs helper -----------------------------------------------
def session_kwargs(self) -> dict:
"""Return kwargs to pass to ClientSession for sampling support."""
return {
"sampling_callback": self,
"sampling_capabilities": SamplingCapability(
tools=SamplingToolsCapability(),
),
}

    # -- Main callback -------------------------------------------------------

    async def __call__(self, context, params):
        """Sampling callback invoked by the MCP SDK.

        Conforms to ``SamplingFnT`` protocol. Returns
        ``CreateMessageResult``, ``CreateMessageResultWithTools``, or
        ``ErrorData``.
        """
        # Rate limit
        if not self._check_rate_limit():
            logger.warning(
                "MCP server '%s' sampling rate limit exceeded (%d/min)",
                self.server_name, self.max_rpm,
            )
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling rate limit exceeded for server '{self.server_name}' "
                f"({self.max_rpm} requests/minute)"
            )

        # Resolve model
        model = self._resolve_model(getattr(params, "modelPreferences", None))

        # Get auxiliary LLM client
        from agent.auxiliary_client import get_text_auxiliary_client
        client, default_model = get_text_auxiliary_client()
        if client is None:
            self.metrics["errors"] += 1
            return self._error("No LLM provider available for sampling")

        resolved_model = model or default_model

        # Model whitelist check
        if self.allowed_models and resolved_model not in self.allowed_models:
            logger.warning(
                "MCP server '%s' requested model '%s' not in allowed_models",
                self.server_name, resolved_model,
            )
            self.metrics["errors"] += 1
            return self._error(
                f"Model '{resolved_model}' not allowed for server "
                f"'{self.server_name}'. Allowed: {', '.join(self.allowed_models)}"
            )

        # Convert messages
        messages = self._convert_messages(params)
        if hasattr(params, "systemPrompt") and params.systemPrompt:
            messages.insert(0, {"role": "system", "content": params.systemPrompt})

        # Build LLM call kwargs
        max_tokens = min(params.maxTokens, self.max_tokens_cap)
        call_kwargs: dict = {
            "model": resolved_model,
            "messages": messages,
            "max_tokens": max_tokens,
        }
        if hasattr(params, "temperature") and params.temperature is not None:
            call_kwargs["temperature"] = params.temperature
        if stop := getattr(params, "stopSequences", None):
            call_kwargs["stop"] = stop

        # Forward server-provided tools
        server_tools = getattr(params, "tools", None)
        if server_tools:
            call_kwargs["tools"] = [
                {
                    "type": "function",
                    "function": {
                        "name": getattr(t, "name", ""),
                        "description": getattr(t, "description", "") or "",
                        "parameters": getattr(t, "inputSchema", {}) or {},
                    },
                }
                for t in server_tools
            ]
        if tool_choice := getattr(params, "toolChoice", None):
            mode = getattr(tool_choice, "mode", "auto")
            call_kwargs["tool_choice"] = {"auto": "auto", "required": "required", "none": "none"}.get(mode, "auto")

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
            self.server_name, resolved_model, max_tokens, len(messages),
        )

        # Offload sync LLM call to thread (non-blocking)
        def _sync_call():
            return client.chat.completions.create(**call_kwargs)

        try:
            response = await asyncio.wait_for(
                asyncio.to_thread(_sync_call), timeout=self.timeout,
            )
        except asyncio.TimeoutError:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call timed out after {self.timeout}s "
                f"for server '{self.server_name}'"
            )
        except Exception as exc:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
            )

        # Track metrics
        choice = response.choices[0]
        self.metrics["requests"] += 1
        total_tokens = getattr(getattr(response, "usage", None), "total_tokens", 0)
        if isinstance(total_tokens, int):
            self.metrics["tokens_used"] += total_tokens

        # Dispatch based on response type
        if (
            choice.finish_reason == "tool_calls"
            and hasattr(choice.message, "tool_calls")
            and choice.message.tool_calls
        ):
            return self._build_tool_use_result(choice, response)

        return self._build_text_result(choice, response)

# ---------------------------------------------------------------------------
# Server task -- each MCP server lives in one long-lived asyncio Task
# ---------------------------------------------------------------------------
@@ -162,6 +573,7 @@ class MCPServerTask:
     __slots__ = (
         "name", "session", "tool_timeout",
         "_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
+        "_sampling",
     )
     def __init__(self, name: str):
@@ -174,6 +586,7 @@ class MCPServerTask:
         self._tools: list = []
         self._error: Optional[Exception] = None
         self._config: dict = {}
+        self._sampling: Optional[SamplingHandler] = None
     def _is_http(self) -> bool:
         """Check if this server uses HTTP transport."""
@@ -197,8 +610,9 @@ class MCPServerTask:
             env=safe_env if safe_env else None,
         )
+        sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
         async with stdio_client(server_params) as (read_stream, write_stream):
-            async with ClientSession(read_stream, write_stream) as session:
+            async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
                 await session.initialize()
                 self.session = session
                 await self._discover_tools()
@@ -218,12 +632,13 @@ class MCPServerTask:
         headers = config.get("headers")
         connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
+        sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
         async with streamablehttp_client(
             url,
             headers=headers,
             timeout=float(connect_timeout),
         ) as (read_stream, write_stream, _get_session_id):
-            async with ClientSession(read_stream, write_stream) as session:
+            async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
                 await session.initialize()
                 self.session = session
                 await self._discover_tools()
@@ -250,6 +665,13 @@ class MCPServerTask:
         self._config = config
         self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
+        # Set up sampling handler if enabled and SDK types are available
+        sampling_config = config.get("sampling", {})
+        if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
+            self._sampling = SamplingHandler(self.name, sampling_config)
+        else:
+            self._sampling = None
         # Validate: warn if both url and command are present
         if "url" in config and "command" in config:
             logger.warning(
@@ -975,12 +1397,15 @@ def get_mcp_status() -> List[dict]:
         transport = "http" if "url" in cfg else "stdio"
         server = active_servers.get(name)
         if server and server.session is not None:
-            result.append({
+            entry = {
                 "name": name,
                 "transport": transport,
                 "tools": len(server._tools),
                 "connected": True,
-            })
+            }
+            if server._sampling:
+                entry["sampling"] = dict(server._sampling.metrics)
+            result.append(entry)
         else:
             result.append({
                 "name": name,


@@ -69,36 +69,10 @@ def _read_manifest() -> Dict[str, str]:
 def _write_manifest(entries: Dict[str, str]):
-    """Write the manifest file atomically in v2 format (name:hash).
-    Uses a temp file + os.replace() to avoid corruption if the process
-    crashes or is interrupted mid-write.
-    """
-    import tempfile
+    """Write the manifest file in v2 format (name:hash)."""
     MANIFEST_FILE.parent.mkdir(parents=True, exist_ok=True)
-    data = "\n".join(f"{name}:{hash_val}" for name, hash_val in sorted(entries.items())) + "\n"
-    try:
-        fd, tmp_path = tempfile.mkstemp(
-            dir=str(MANIFEST_FILE.parent),
-            prefix=".bundled_manifest_",
-            suffix=".tmp",
-        )
-        try:
-            with os.fdopen(fd, "w", encoding="utf-8") as f:
-                f.write(data)
-                f.flush()
-                os.fsync(f.fileno())
-            os.replace(tmp_path, MANIFEST_FILE)
-        except BaseException:
-            try:
-                os.unlink(tmp_path)
-            except OSError:
-                pass
-            raise
-    except Exception as e:
-        logger.debug("Failed to write skills manifest %s: %s", MANIFEST_FILE, e, exc_info=True)
+    lines = [f"{name}:{hash_val}" for name, hash_val in sorted(entries.items())]
+    MANIFEST_FILE.write_text("\n".join(lines) + "\n", encoding="utf-8")
 def _discover_bundled_skills(bundled_dir: Path) -> List[Tuple[str, Path]]:


@@ -271,3 +271,62 @@ You can reload MCP servers without restarting Hermes:
- In the CLI: the agent reconnects automatically
- In messaging: send `/reload-mcp`
## Sampling (Server-Initiated LLM Requests)
MCP's `sampling/createMessage` capability lets MCP servers request LLM completions through the Hermes agent. This enables agent-in-the-loop workflows in which a server calls on the LLM during tool execution: for example, a database server asking the LLM to interpret query results, or a code analysis server asking it to review findings.
### How It Works
When an MCP server sends a `sampling/createMessage` request, the sampling callback:
1. Checks the per-server rate limit
2. Resolves which model to use (config override > server hint > default) and checks it against `allowed_models`
3. Converts the MCP messages to OpenAI-compatible format and caps `max_tokens` at `max_tokens_cap`
4. Offloads the LLM call to a thread via `asyncio.to_thread()` (non-blocking), bounded by `timeout`
5. Returns the response (text or tool use) to the server, or `ErrorData` if any step fails
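Step 4 can be sketched in isolation. The snippet below is a minimal illustration, not the Hermes implementation: `blocking_llm_call` and `sample_with_timeout` are hypothetical names, and the "LLM" is a `time.sleep` stand-in. It shows only the pattern of running a synchronous call in a worker thread and bounding the wait with `asyncio.wait_for`.

```python
import asyncio
import time


def blocking_llm_call(prompt: str, delay: float) -> str:
    """Hypothetical stand-in for a synchronous LLM client call."""
    time.sleep(delay)
    return f"echo: {prompt}"


async def sample_with_timeout(prompt: str, delay: float, timeout: float) -> dict:
    """Run the blocking call in a worker thread, bounded by a timeout."""
    try:
        text = await asyncio.wait_for(
            asyncio.to_thread(blocking_llm_call, prompt, delay),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        # Report the failure as data instead of raising to the caller.
        return {"ok": False, "error": f"timed out after {timeout}s"}
    return {"ok": True, "text": text}


ok_result = asyncio.run(sample_with_timeout("hi", delay=0.01, timeout=1.0))
slow_result = asyncio.run(sample_with_timeout("hi", delay=0.3, timeout=0.05))
```

Note that a timeout only stops the event loop from waiting. The worker thread itself keeps running until the blocking call returns, so a timed-out request may still consume provider tokens.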
### Configuration
Sampling is **enabled by default** for all MCP servers and needs no extra setup: it works whenever an auxiliary LLM client is configured and the installed MCP SDK provides the sampling types. If the SDK lacks those types, sampling is silently disabled.
```yaml
mcp_servers:
  analysis_server:
    command: "npx"
    args: ["-y", "my-analysis-server"]
    sampling:
      enabled: true              # default: true
      model: "gemini-3-flash"    # override model (optional)
      max_tokens_cap: 4096       # max tokens per request (default: 4096)
      timeout: 30                # LLM call timeout in seconds (default: 30)
      max_rpm: 10                # max requests per minute (default: 10)
      allowed_models: []         # model whitelist (empty = allow all)
      max_tool_rounds: 5         # max consecutive tool use rounds (0 = disable)
      log_level: "info"          # audit verbosity: debug, info, warning
```
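To make the interaction between `model`, `allowed_models`, and `max_tokens_cap` concrete, here is a small sketch of the rules as the docs describe them. The helper names (`resolve_model`, `is_allowed`, `cap_tokens`) are hypothetical and chosen for illustration; the handler's own methods may be structured differently.

```python
from typing import Optional, Sequence


def resolve_model(
    config_model: Optional[str],
    server_hint: Optional[str],
    default_model: str,
) -> str:
    """Priority order: config override > server hint > default."""
    return config_model or server_hint or default_model


def is_allowed(model: str, allowed_models: Sequence[str]) -> bool:
    """An empty whitelist allows every model."""
    return not allowed_models or model in allowed_models


def cap_tokens(requested: int, max_tokens_cap: int = 4096) -> int:
    """A server can never obtain more tokens than the configured cap."""
    return min(requested, max_tokens_cap)


chosen = resolve_model(None, "server-suggested-model", "default-model")
```

With no `model` override in the config, the server's hint wins over the default. Setting `model` pins every sampling request from that server to one model regardless of the hint.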
### Tool Use in Sampling
Servers can include `tools` and `toolChoice` in sampling requests, enabling multi-turn tool-augmented workflows within a single sampling session. The callback forwards tool definitions to the LLM, handles tool use responses with proper `ToolUseContent` types, and enforces `max_tool_rounds` to prevent infinite loops.
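One detail of the tool-use path is worth illustrating: OpenAI-style tool calls usually carry their arguments as a JSON string, while MCP's `ToolUseContent` expects a mapping. The sketch below shows one way to normalize them, wrapping anything unparseable under a `_raw` key as the handler in this change does. The function name `normalize_tool_args` is hypothetical, and wrapping valid JSON that is not an object is an extra assumption of this sketch rather than something the change itself does.

```python
import json
from typing import Any


def normalize_tool_args(args: Any) -> dict:
    """Coerce LLM tool-call arguments into a dict, keeping bad input as `_raw`."""
    if isinstance(args, str):
        try:
            parsed = json.loads(args)
        except ValueError:  # json.JSONDecodeError is a subclass of ValueError
            return {"_raw": args}
        # Assumption of this sketch: valid JSON that is not an object is also wrapped.
        return parsed if isinstance(parsed, dict) else {"_raw": args}
    return args if isinstance(args, dict) else {"_raw": str(args)}


example = normalize_tool_args('{"query": "SELECT 1"}')
```

Keeping the raw text rather than dropping the call lets the server decide how to handle malformed arguments.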
### Security
- **Rate limiting**: Per-server sliding window (default: 10 req/min)
- **Token cap**: Servers can't request more than `max_tokens_cap` (default: 4096)
- **Model whitelist**: `allowed_models` restricts which models a server can use
- **Tool loop limit**: `max_tool_rounds` caps consecutive tool use rounds
- **Credential stripping**: LLM text responses are sanitized before being returned to the server
- **Non-blocking**: LLM calls run in a separate thread via `asyncio.to_thread()`
- **Typed errors**: All failures return structured `ErrorData` per MCP spec
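The per-server sliding window can be pictured with the following sketch. It is an illustration under stated assumptions, not the handler's actual code: the class name `SlidingWindowLimiter`, the injectable `clock` parameter, and the use of a `deque` of timestamps are all choices made here for clarity and testability.

```python
import time
from collections import deque
from typing import Callable, Deque


class SlidingWindowLimiter:
    """Allow at most `max_rpm` requests within any `window`-second span."""

    def __init__(
        self,
        max_rpm: int,
        window: float = 60.0,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self.max_rpm = max_rpm
        self.window = window
        self._clock = clock
        self._stamps: Deque[float] = deque()

    def allow(self) -> bool:
        now = self._clock()
        # Drop timestamps that have aged out of the window.
        while self._stamps and now - self._stamps[0] >= self.window:
            self._stamps.popleft()
        if len(self._stamps) >= self.max_rpm:
            return False
        self._stamps.append(now)
        return True


# Usage with a fake clock so the behaviour is deterministic.
fake_now = [0.0]
limiter = SlidingWindowLimiter(max_rpm=2, clock=lambda: fake_now[0])
first, second, third = limiter.allow(), limiter.allow(), limiter.allow()
fake_now[0] = 60.0
after_window = limiter.allow()
```

Unlike a fixed per-minute counter, a sliding window does not let a server burst twice the limit across a minute boundary.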
To disable sampling for untrusted servers:
```yaml
mcp_servers:
  untrusted:
    command: "npx"
    args: ["-y", "untrusted-server"]
    sampling:
      enabled: false
```