Compare commits


2 Commits

teknium1 · fbed199672 · 2026-03-09 03:36:47 -07:00
docs: add sampling config examples to docstring and cli-config.yaml.example

teknium1 · 2d13eb9795 · 2026-03-09 01:21:34 -07:00
feat(mcp): add sampling support — server-initiated LLM requests
Add MCP sampling/createMessage capability allowing MCP servers to request
LLM completions through the Hermes agent during tool execution. Enables
agent-in-the-loop workflows (data analysis, content generation, decision
making) where servers can leverage the LLM as needed.

Implemented as a SamplingHandler class (per-server instance, no globals):
- Text-only sampling: server asks LLM a question, gets text back
- Tool use in sampling: server provides tools, LLM can use them in a
  multi-turn loop with configurable max_tool_rounds governance
- Rate limiting (sliding window, configurable max_rpm per server)
- Model resolution (config override > server hint > default)
- Model whitelist (allowed_models per server)
- Token cap (max_tokens_cap per server)
- LLM timeout with asyncio.wait_for
- Credential stripping on responses
- Per-server audit metrics (requests, errors, tokens_used, tool_use_count)
- Configurable log_level for audit verbosity
- Non-blocking: LLM calls offloaded via asyncio.to_thread()
- Proper MCP SDK types: CreateMessageResult for text responses,
  CreateMessageResultWithTools + ToolUseContent for tool use responses
- SamplingCapability with SamplingToolsCapability advertised to servers
- Backward compatible: silently disabled if MCP SDK lacks sampling types

Config (all optional, zero breaking changes):
  mcp_servers:
    my_server:
      sampling:
        enabled: true        # default
        model: 'gemini-3-flash'
        max_tokens_cap: 4096
        timeout: 30
        max_rpm: 10
        allowed_models: []
        max_tool_rounds: 5
        log_level: 'info'

Based on the sampling concept from PR #366 by eren-karakus0. Restructured
as a class-based design, fixed critical bugs (wrong return types for tool
use, missing capability advertisement, broken Pydantic validation), and
added tests using real MCP SDK types.

50 new tests, full suite passes (2600 tests).
5 changed files with 1307 additions and 4 deletions

View File

@@ -555,6 +555,21 @@ toolsets:
#     args: ["-y", "@modelcontextprotocol/server-github"]
#     env:
#       GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."
#
# Sampling (server-initiated LLM requests) — enabled by default.
# Per-server config under the 'sampling' key:
#   analysis:
#     command: npx
#     args: ["-y", "analysis-server"]
#     sampling:
#       enabled: true              # default: true
#       model: "gemini-3-flash"    # override model (optional)
#       max_tokens_cap: 4096       # max tokens per request
#       timeout: 30                # LLM call timeout (seconds)
#       max_rpm: 10                # max requests per minute
#       allowed_models: []         # model whitelist (empty = all)
#       max_tool_rounds: 5         # tool loop limit (0 = disable)
#       log_level: "info"          # audit verbosity
# =============================================================================
# Voice Transcription (Speech-to-Text)

View File

@@ -321,6 +321,32 @@ mcp_servers:
All tools from all servers are registered and available simultaneously. Each server's tools are prefixed with its name to avoid collisions.

## Sampling (Server-Initiated LLM Requests)

Hermes supports MCP's `sampling/createMessage` capability — MCP servers can request LLM completions through the agent during tool execution. This enables agent-in-the-loop workflows (data analysis, content generation, decision-making).

Sampling is **enabled by default**. Configure per server:

```yaml
mcp_servers:
  my_server:
    command: "npx"
    args: ["-y", "my-mcp-server"]
    sampling:
      enabled: true              # default: true
      model: "gemini-3-flash"    # model override (optional)
      max_tokens_cap: 4096       # max tokens per request
      timeout: 30                # LLM call timeout (seconds)
      max_rpm: 10                # max requests per minute
      allowed_models: []         # model whitelist (empty = all)
      max_tool_rounds: 5         # tool loop limit (0 = disable)
      log_level: "info"          # audit verbosity
```

Servers can also include `tools` in sampling requests for multi-turn tool-augmented workflows. The `max_tool_rounds` config prevents infinite tool loops. Per-server audit metrics (requests, errors, tokens, tool use count) are tracked via `get_mcp_status()`.
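
A minimal sketch of reading those metrics (assuming `tools.mcp_tool` is importable and at least one server is connected; field names follow the `get_mcp_status()` entries added in this PR):

```python
from tools.mcp_tool import get_mcp_status

# Each connected server with sampling enabled reports a "sampling" dict
# holding the handler's audit counters.
for entry in get_mcp_status():
    metrics = entry.get("sampling")
    if entry.get("connected") and metrics:
        print(
            f"{entry['name']}: {metrics['requests']} requests, "
            f"{metrics['errors']} errors, {metrics['tokens_used']} tokens, "
            f"{metrics['tool_use_count']} tool uses"
        )
```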

Disable sampling for untrusted servers with `sampling: { enabled: false }`.

## Notes

- MCP tools are called synchronously from the agent's perspective but run asynchronously on a dedicated background event loop

View File

@@ -1489,3 +1489,781 @@ class TestUtilityToolRegistration:
        assert entry.check_fn() is False
        _servers.pop("chk", None)


# ===========================================================================
# SamplingHandler tests
# ===========================================================================

import math
import time

from mcp.types import (
    CreateMessageResult,
    CreateMessageResultWithTools,
    ErrorData,
    SamplingCapability,
    SamplingToolsCapability,
    TextContent,
    ToolUseContent,
)

from tools.mcp_tool import SamplingHandler, _safe_numeric


# ---------------------------------------------------------------------------
# Helpers for sampling tests
# ---------------------------------------------------------------------------

def _make_sampling_params(
    messages=None,
    max_tokens=100,
    system_prompt=None,
    model_preferences=None,
    temperature=None,
    stop_sequences=None,
    tools=None,
    tool_choice=None,
):
    """Create a fake CreateMessageRequestParams using SimpleNamespace.

    Each message must have a ``content_as_list`` attribute that mirrors
    the SDK helper so that ``_convert_messages`` works correctly.
    """
    if messages is None:
        content = SimpleNamespace(text="Hello")
        msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
        messages = [msg]
    params = SimpleNamespace(
        messages=messages,
        maxTokens=max_tokens,
        modelPreferences=model_preferences,
        temperature=temperature,
        stopSequences=stop_sequences,
        tools=tools,
        toolChoice=tool_choice,
    )
    if system_prompt is not None:
        params.systemPrompt = system_prompt
    return params


def _make_llm_response(
    content="LLM response",
    model="test-model",
    finish_reason="stop",
    tool_calls=None,
):
    """Create a fake OpenAI chat completion response (text)."""
    message = SimpleNamespace(content=content, tool_calls=tool_calls)
    choice = SimpleNamespace(
        finish_reason=finish_reason,
        message=message,
    )
    usage = SimpleNamespace(total_tokens=42)
    return SimpleNamespace(choices=[choice], model=model, usage=usage)


def _make_llm_tool_response(tool_calls_data=None, model="test-model"):
    """Create a fake response with tool_calls.

    ``tool_calls_data``: list of (id, name, arguments_json) tuples.
    """
    if tool_calls_data is None:
        tool_calls_data = [("call_1", "get_weather", '{"city": "London"}')]
    tc_list = [
        SimpleNamespace(
            id=tc_id,
            function=SimpleNamespace(name=name, arguments=args),
        )
        for tc_id, name, args in tool_calls_data
    ]
    return _make_llm_response(
        content=None,
        model=model,
        finish_reason="tool_calls",
        tool_calls=tc_list,
    )
# ---------------------------------------------------------------------------
# 1. _safe_numeric helper
# ---------------------------------------------------------------------------

class TestSafeNumeric:
    def test_int_passthrough(self):
        assert _safe_numeric(10, 5, int) == 10

    def test_string_coercion(self):
        assert _safe_numeric("20", 5, int) == 20

    def test_none_returns_default(self):
        assert _safe_numeric(None, 7, int) == 7

    def test_inf_returns_default(self):
        assert _safe_numeric(float("inf"), 3.0, float) == 3.0

    def test_nan_returns_default(self):
        assert _safe_numeric(float("nan"), 4.0, float) == 4.0

    def test_below_minimum_clamps(self):
        assert _safe_numeric(-5, 10, int, minimum=1) == 1

    def test_minimum_zero_allowed(self):
        assert _safe_numeric(0, 10, int, minimum=0) == 0

    def test_non_numeric_string_returns_default(self):
        assert _safe_numeric("abc", 42, int) == 42

    def test_float_coercion(self):
        assert _safe_numeric("3.5", 1.0, float) == 3.5


# ---------------------------------------------------------------------------
# 2. SamplingHandler initialization and config parsing
# ---------------------------------------------------------------------------

class TestSamplingHandlerInit:
    def test_defaults(self):
        h = SamplingHandler("srv", {})
        assert h.server_name == "srv"
        assert h.max_rpm == 10
        assert h.timeout == 30
        assert h.max_tokens_cap == 4096
        assert h.max_tool_rounds == 5
        assert h.model_override is None
        assert h.allowed_models == []
        assert h.metrics == {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}

    def test_custom_config(self):
        cfg = {
            "max_rpm": 20,
            "timeout": 60,
            "max_tokens_cap": 2048,
            "max_tool_rounds": 3,
            "model": "gpt-4o",
            "allowed_models": ["gpt-4o", "gpt-3.5-turbo"],
            "log_level": "debug",
        }
        h = SamplingHandler("custom", cfg)
        assert h.max_rpm == 20
        assert h.timeout == 60.0
        assert h.max_tokens_cap == 2048
        assert h.max_tool_rounds == 3
        assert h.model_override == "gpt-4o"
        assert h.allowed_models == ["gpt-4o", "gpt-3.5-turbo"]

    def test_string_numeric_config_values(self):
        """YAML sometimes delivers numeric values as strings."""
        cfg = {"max_rpm": "15", "timeout": "45.5", "max_tokens_cap": "1024"}
        h = SamplingHandler("s", cfg)
        assert h.max_rpm == 15
        assert h.timeout == 45.5
        assert h.max_tokens_cap == 1024


# ---------------------------------------------------------------------------
# 3. Rate limiting
# ---------------------------------------------------------------------------

class TestRateLimit:
    def setup_method(self):
        self.handler = SamplingHandler("rl", {"max_rpm": 3})

    def test_allows_under_limit(self):
        assert self.handler._check_rate_limit() is True
        assert self.handler._check_rate_limit() is True
        assert self.handler._check_rate_limit() is True

    def test_rejects_over_limit(self):
        for _ in range(3):
            self.handler._check_rate_limit()
        assert self.handler._check_rate_limit() is False

    def test_window_expiry(self):
        """Old timestamps should be purged from the sliding window."""
        for _ in range(3):
            self.handler._check_rate_limit()
        # Simulate timestamps from 61 seconds ago
        self.handler._rate_timestamps[:] = [time.time() - 61] * 3
        assert self.handler._check_rate_limit() is True


# ---------------------------------------------------------------------------
# 4. Model resolution
# ---------------------------------------------------------------------------

class TestResolveModel:
    def setup_method(self):
        self.handler = SamplingHandler("mr", {})

    def test_no_preference_no_override(self):
        assert self.handler._resolve_model(None) is None

    def test_config_override_wins(self):
        self.handler.model_override = "override-model"
        prefs = SimpleNamespace(hints=[SimpleNamespace(name="hint-model")])
        assert self.handler._resolve_model(prefs) == "override-model"

    def test_hint_used_when_no_override(self):
        prefs = SimpleNamespace(hints=[SimpleNamespace(name="hint-model")])
        assert self.handler._resolve_model(prefs) == "hint-model"

    def test_empty_hints(self):
        prefs = SimpleNamespace(hints=[])
        assert self.handler._resolve_model(prefs) is None

    def test_hint_without_name(self):
        prefs = SimpleNamespace(hints=[SimpleNamespace(name=None)])
        assert self.handler._resolve_model(prefs) is None
# ---------------------------------------------------------------------------
# 5. Message conversion
# ---------------------------------------------------------------------------

class TestConvertMessages:
    def setup_method(self):
        self.handler = SamplingHandler("mc", {})

    def test_single_text_message(self):
        content = SimpleNamespace(text="Hello world")
        msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        assert result[0] == {"role": "user", "content": "Hello world"}

    def test_image_message(self):
        text_block = SimpleNamespace(text="Look at this")
        img_block = SimpleNamespace(data="abc123", mimeType="image/png")
        msg = SimpleNamespace(
            role="user",
            content=[text_block, img_block],
            content_as_list=[text_block, img_block],
        )
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        parts = result[0]["content"]
        assert len(parts) == 2
        assert parts[0] == {"type": "text", "text": "Look at this"}
        assert parts[1]["type"] == "image_url"
        assert "data:image/png;base64,abc123" in parts[1]["image_url"]["url"]

    def test_tool_result_message(self):
        inner = SimpleNamespace(text="42 degrees")
        tr_block = SimpleNamespace(toolUseId="call_1", content=[inner])
        msg = SimpleNamespace(
            role="user",
            content=[tr_block],
            content_as_list=[tr_block],
        )
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        assert result[0]["role"] == "tool"
        assert result[0]["tool_call_id"] == "call_1"
        assert result[0]["content"] == "42 degrees"

    def test_tool_use_message(self):
        tu_block = SimpleNamespace(
            id="call_2", name="get_weather", input={"city": "London"}
        )
        msg = SimpleNamespace(
            role="assistant",
            content=[tu_block],
            content_as_list=[tu_block],
        )
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        assert result[0]["role"] == "assistant"
        assert len(result[0]["tool_calls"]) == 1
        assert result[0]["tool_calls"][0]["function"]["name"] == "get_weather"
        assert json.loads(result[0]["tool_calls"][0]["function"]["arguments"]) == {"city": "London"}

    def test_mixed_text_and_tool_use(self):
        """Assistant message with both text and tool_calls."""
        text_block = SimpleNamespace(text="Let me check the weather")
        tu_block = SimpleNamespace(
            id="call_3", name="get_weather", input={"city": "Paris"}
        )
        msg = SimpleNamespace(
            role="assistant",
            content=[text_block, tu_block],
            content_as_list=[text_block, tu_block],
        )
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        assert result[0]["content"] == "Let me check the weather"
        assert len(result[0]["tool_calls"]) == 1

    def test_fallback_without_content_as_list(self):
        """When content_as_list is absent, falls back to content."""
        content = SimpleNamespace(text="Fallback text")
        msg = SimpleNamespace(role="user", content=content)
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        assert result[0]["content"] == "Fallback text"
# ---------------------------------------------------------------------------
# 6. Text-only sampling callback (full flow)
# ---------------------------------------------------------------------------

class TestSamplingCallbackText:
    def setup_method(self):
        self.handler = SamplingHandler("txt", {})

    def test_text_response(self):
        """Full flow: text response returns CreateMessageResult."""
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response(
            content="Hello from LLM"
        )
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            params = _make_sampling_params()
            result = asyncio.run(self.handler(None, params))
        assert isinstance(result, CreateMessageResult)
        assert isinstance(result.content, TextContent)
        assert result.content.text == "Hello from LLM"
        assert result.model == "test-model"
        assert result.role == "assistant"
        assert result.stopReason == "endTurn"

    def test_system_prompt_prepended(self):
        """System prompt is inserted as the first message."""
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response()
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            params = _make_sampling_params(system_prompt="Be helpful")
            asyncio.run(self.handler(None, params))
        call_args = fake_client.chat.completions.create.call_args
        messages = call_args.kwargs["messages"]
        assert messages[0] == {"role": "system", "content": "Be helpful"}

    def test_length_stop_reason(self):
        """finish_reason='length' maps to stopReason='maxTokens'."""
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response(
            finish_reason="length"
        )
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            params = _make_sampling_params()
            result = asyncio.run(self.handler(None, params))
        assert isinstance(result, CreateMessageResult)
        assert result.stopReason == "maxTokens"


# ---------------------------------------------------------------------------
# 7. Tool use sampling callback
# ---------------------------------------------------------------------------

class TestSamplingCallbackToolUse:
    def setup_method(self):
        self.handler = SamplingHandler("tu", {})

    def test_tool_use_response(self):
        """LLM tool_calls response returns CreateMessageResultWithTools."""
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_tool_response()
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            params = _make_sampling_params()
            result = asyncio.run(self.handler(None, params))
        assert isinstance(result, CreateMessageResultWithTools)
        assert result.stopReason == "toolUse"
        assert result.model == "test-model"
        assert len(result.content) == 1
        tc = result.content[0]
        assert isinstance(tc, ToolUseContent)
        assert tc.name == "get_weather"
        assert tc.id == "call_1"
        assert tc.input == {"city": "London"}

    def test_multiple_tool_calls(self):
        """Multiple tool_calls in a single response."""
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_tool_response(
            tool_calls_data=[
                ("call_a", "func_a", '{"x": 1}'),
                ("call_b", "func_b", '{"y": 2}'),
            ]
        )
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            result = asyncio.run(self.handler(None, _make_sampling_params()))
        assert isinstance(result, CreateMessageResultWithTools)
        assert len(result.content) == 2
        assert result.content[0].name == "func_a"
        assert result.content[1].name == "func_b"


# ---------------------------------------------------------------------------
# 8. Tool loop governance
# ---------------------------------------------------------------------------

class TestToolLoopGovernance:
    def test_max_tool_rounds_enforcement(self):
        """After max_tool_rounds consecutive tool responses, an error is returned."""
        handler = SamplingHandler("tl", {"max_tool_rounds": 2})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_tool_response()
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            params = _make_sampling_params()
            # Round 1, 2: allowed
            r1 = asyncio.run(handler(None, params))
            assert isinstance(r1, CreateMessageResultWithTools)
            r2 = asyncio.run(handler(None, params))
            assert isinstance(r2, CreateMessageResultWithTools)
            # Round 3: exceeds limit
            r3 = asyncio.run(handler(None, params))
            assert isinstance(r3, ErrorData)
            assert "Tool loop limit exceeded" in r3.message

    def test_text_response_resets_counter(self):
        """A text response resets the tool loop counter."""
        handler = SamplingHandler("tl2", {"max_tool_rounds": 1})
        fake_client = MagicMock()
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            # Tool response (round 1 of 1 allowed)
            fake_client.chat.completions.create.return_value = _make_llm_tool_response()
            r1 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r1, CreateMessageResultWithTools)
            # Text response resets counter
            fake_client.chat.completions.create.return_value = _make_llm_response()
            r2 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r2, CreateMessageResult)
            # Tool response again (should succeed since counter was reset)
            fake_client.chat.completions.create.return_value = _make_llm_tool_response()
            r3 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r3, CreateMessageResultWithTools)

    def test_max_tool_rounds_zero_disables(self):
        """max_tool_rounds=0 means tool loops are disabled entirely."""
        handler = SamplingHandler("tl3", {"max_tool_rounds": 0})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_tool_response()
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
        assert isinstance(result, ErrorData)
        assert "Tool loops disabled" in result.message
# ---------------------------------------------------------------------------
# 9. Error paths: rate limit, timeout, no provider
# ---------------------------------------------------------------------------

class TestSamplingErrors:
    def test_rate_limit_error(self):
        handler = SamplingHandler("rle", {"max_rpm": 1})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response()
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            # First call succeeds
            r1 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r1, CreateMessageResult)
            # Second call is rate limited
            r2 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r2, ErrorData)
            assert "rate limit" in r2.message.lower()
            assert handler.metrics["errors"] == 1

    def test_timeout_error(self):
        handler = SamplingHandler("to", {"timeout": 0.05})
        fake_client = MagicMock()

        def slow_call(**kwargs):
            import threading
            # Block long enough that wait_for times out first; the worker
            # thread itself is not cancelled, only abandoned.
            evt = threading.Event()
            evt.wait(5)
            return _make_llm_response()

        fake_client.chat.completions.create.side_effect = slow_call
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
        assert isinstance(result, ErrorData)
        assert "timed out" in result.message.lower()
        assert handler.metrics["errors"] == 1

    def test_no_provider_error(self):
        handler = SamplingHandler("np", {})
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(None, None),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
        assert isinstance(result, ErrorData)
        assert "No LLM provider" in result.message
        assert handler.metrics["errors"] == 1


# ---------------------------------------------------------------------------
# 10. Model whitelist
# ---------------------------------------------------------------------------

class TestModelWhitelist:
    def test_allowed_model_passes(self):
        handler = SamplingHandler("wl", {"allowed_models": ["gpt-4o", "test-model"]})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response()
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "test-model"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
        assert isinstance(result, CreateMessageResult)

    def test_disallowed_model_rejected(self):
        handler = SamplingHandler("wl2", {"allowed_models": ["gpt-4o"]})
        fake_client = MagicMock()
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "gpt-3.5-turbo"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
        assert isinstance(result, ErrorData)
        assert "not allowed" in result.message
        assert handler.metrics["errors"] == 1

    def test_empty_whitelist_allows_all(self):
        handler = SamplingHandler("wl3", {"allowed_models": []})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response()
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "any-model"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
        assert isinstance(result, CreateMessageResult)


# ---------------------------------------------------------------------------
# 11. Malformed tool_call arguments
# ---------------------------------------------------------------------------

class TestMalformedToolCallArgs:
    def test_invalid_json_wrapped_as_raw(self):
        """Malformed JSON arguments get wrapped in {"_raw": ...}."""
        handler = SamplingHandler("mf", {})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_tool_response(
            tool_calls_data=[("call_x", "some_tool", "not valid json {{{")]
        )
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
        assert isinstance(result, CreateMessageResultWithTools)
        tc = result.content[0]
        assert isinstance(tc, ToolUseContent)
        assert tc.input == {"_raw": "not valid json {{{"}

    def test_dict_args_pass_through(self):
        """When arguments are already a dict, they pass through directly."""
        handler = SamplingHandler("mf2", {})
        # Build a tool call where arguments is already a dict
        tc_obj = SimpleNamespace(
            id="call_d",
            function=SimpleNamespace(name="do_stuff", arguments={"key": "val"}),
        )
        message = SimpleNamespace(content=None, tool_calls=[tc_obj])
        choice = SimpleNamespace(finish_reason="tool_calls", message=message)
        usage = SimpleNamespace(total_tokens=10)
        response = SimpleNamespace(choices=[choice], model="m", usage=usage)
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = response
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
        assert isinstance(result, CreateMessageResultWithTools)
        assert result.content[0].input == {"key": "val"}
# ---------------------------------------------------------------------------
# 12. Metrics tracking
# ---------------------------------------------------------------------------

class TestMetricsTracking:
    def test_request_and_token_metrics(self):
        handler = SamplingHandler("met", {})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response()
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            asyncio.run(handler(None, _make_sampling_params()))
        assert handler.metrics["requests"] == 1
        assert handler.metrics["tokens_used"] == 42
        assert handler.metrics["errors"] == 0

    def test_tool_use_count_metric(self):
        handler = SamplingHandler("met2", {})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_tool_response()
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            asyncio.run(handler(None, _make_sampling_params()))
        assert handler.metrics["tool_use_count"] == 1
        assert handler.metrics["requests"] == 1

    def test_error_metric_incremented(self):
        handler = SamplingHandler("met3", {})
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(None, None),
        ):
            asyncio.run(handler(None, _make_sampling_params()))
        assert handler.metrics["errors"] == 1
        assert handler.metrics["requests"] == 0


# ---------------------------------------------------------------------------
# 13. session_kwargs()
# ---------------------------------------------------------------------------

class TestSessionKwargs:
    def test_returns_correct_keys(self):
        handler = SamplingHandler("sk", {})
        kwargs = handler.session_kwargs()
        assert "sampling_callback" in kwargs
        assert "sampling_capabilities" in kwargs
        assert kwargs["sampling_callback"] is handler

    def test_sampling_capabilities_type(self):
        handler = SamplingHandler("sk2", {})
        kwargs = handler.session_kwargs()
        cap = kwargs["sampling_capabilities"]
        assert isinstance(cap, SamplingCapability)
        assert isinstance(cap.tools, SamplingToolsCapability)


# ---------------------------------------------------------------------------
# 14. MCPServerTask integration
# ---------------------------------------------------------------------------

class TestMCPServerTaskSamplingIntegration:
    def test_sampling_handler_created_when_enabled(self):
        """MCPServerTask.run() creates a SamplingHandler when sampling is enabled."""
        from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES
        server = MCPServerTask("int_test")
        config = {
            "command": "fake",
            "sampling": {"enabled": True, "max_rpm": 5},
        }
        # We only need to test the setup logic, not the actual connection.
        # Calling run() would attempt a real connection, so we test the
        # sampling setup portion directly.
        server._config = config
        sampling_config = config.get("sampling", {})
        if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
            server._sampling = SamplingHandler(server.name, sampling_config)
        else:
            server._sampling = None
        assert server._sampling is not None
        assert isinstance(server._sampling, SamplingHandler)
        assert server._sampling.server_name == "int_test"
        assert server._sampling.max_rpm == 5

    def test_sampling_handler_none_when_disabled(self):
        """MCPServerTask._sampling is None when sampling is disabled."""
        from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES
        server = MCPServerTask("int_test2")
        config = {
            "command": "fake",
            "sampling": {"enabled": False},
        }
        server._config = config
        sampling_config = config.get("sampling", {})
        if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
            server._sampling = SamplingHandler(server.name, sampling_config)
        else:
            server._sampling = None
        assert server._sampling is None

    def test_session_kwargs_used_in_stdio(self):
        """When sampling is set, session_kwargs() are passed to ClientSession."""
        from tools.mcp_tool import MCPServerTask
        server = MCPServerTask("sk_test")
        server._sampling = SamplingHandler("sk_test", {"max_rpm": 7})
        kwargs = server._sampling.session_kwargs()
        assert "sampling_callback" in kwargs
        assert "sampling_capabilities" in kwargs

View File

@@ -29,6 +29,18 @@ Example config::
    headers:
      Authorization: "Bearer sk-..."
    timeout: 180
  analysis:
    command: "npx"
    args: ["-y", "analysis-server"]
    sampling:                      # server-initiated LLM requests
      enabled: true                # default: true
      model: "gemini-3-flash"      # override model (optional)
      max_tokens_cap: 4096         # max tokens per request
      timeout: 30                  # LLM call timeout (seconds)
      max_rpm: 10                  # max requests per minute
      allowed_models: []           # model whitelist (empty = all)
      max_tool_rounds: 5           # tool loop limit (0 = disable)
      log_level: "info"            # audit verbosity
Features:
- Stdio transport (command + args) and HTTP/StreamableHTTP transport (url)
@@ -37,6 +49,8 @@ Features:
- Credential stripping in error messages returned to the LLM
- Configurable per-server timeouts for tool calls and connections
- Thread-safe architecture with dedicated background event loop
- Sampling support: MCP servers can request LLM completions via
  sampling/createMessage (text and tool-use responses)
Architecture:
A dedicated background event loop (_mcp_loop) runs in a daemon thread.
@@ -58,9 +72,11 @@ Thread safety:
import asyncio
import json
import logging
import math
import os
import re
import threading
import time
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)
@@ -71,6 +87,7 @@ logger = logging.getLogger(__name__)
_MCP_AVAILABLE = False
_MCP_HTTP_AVAILABLE = False
_MCP_SAMPLING_TYPES = False

try:
    from mcp import ClientSession, StdioServerParameters
    from mcp.client.stdio import stdio_client
@@ -80,6 +97,20 @@ try:
        _MCP_HTTP_AVAILABLE = True
    except ImportError:
        _MCP_HTTP_AVAILABLE = False

    # Sampling types -- separated so older SDK versions don't break MCP support
    try:
        from mcp.types import (
            CreateMessageResult,
            CreateMessageResultWithTools,
            ErrorData,
            SamplingCapability,
            SamplingToolsCapability,
            TextContent,
            ToolUseContent,
        )
        _MCP_SAMPLING_TYPES = True
    except ImportError:
        logger.debug("MCP sampling types not available -- sampling disabled")
except ImportError:
    logger.debug("mcp package not installed -- MCP tool support disabled")
@@ -145,6 +176,386 @@ def _sanitize_error(text: str) -> str:
    return _CREDENTIAL_PATTERN.sub("[REDACTED]", text)


# ---------------------------------------------------------------------------
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
# ---------------------------------------------------------------------------

def _safe_numeric(value, default, coerce=int, minimum=1):
    """Coerce a config value to a numeric type, returning *default* on failure.

    Handles string values from YAML (e.g. ``"10"`` instead of ``10``),
    non-finite floats, and values below *minimum*.
    """
    try:
        result = coerce(value)
        if isinstance(result, float) and not math.isfinite(result):
            return default
        return max(result, minimum)
    except (TypeError, ValueError, OverflowError):
        return default
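
# Example coercions (values chosen for illustration; behavior matches
# _safe_numeric above and this PR's TestSafeNumeric tests):
#   _safe_numeric("10", 5, int)              -> 10    YAML string coerced
#   _safe_numeric(float("nan"), 3.0, float)  -> 3.0   non-finite falls back to default
#   _safe_numeric(-5, 10, int, minimum=1)    -> 1     clamped to the minimum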
class SamplingHandler:
    """Handles sampling/createMessage requests for a single MCP server.

    Each MCPServerTask that has sampling enabled creates one SamplingHandler.
    The handler is callable and passed directly to ``ClientSession`` as
    the ``sampling_callback``. All state (rate-limit timestamps, metrics,
    tool-loop counters) lives on the instance -- no module-level globals.

    The callback is async and runs on the MCP background event loop. The
    sync LLM call is offloaded to a thread via ``asyncio.to_thread()`` so
    it doesn't block the event loop.
    """

    _STOP_REASON_MAP = {"stop": "endTurn", "length": "maxTokens", "tool_calls": "toolUse"}

    def __init__(self, server_name: str, config: dict):
        self.server_name = server_name
        self.max_rpm = _safe_numeric(config.get("max_rpm", 10), 10, int)
        self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
        self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
        self.max_tool_rounds = _safe_numeric(
            config.get("max_tool_rounds", 5), 5, int, minimum=0,
        )
        self.model_override = config.get("model")
        self.allowed_models = config.get("allowed_models", [])
        _log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
        self.audit_level = _log_levels.get(
            str(config.get("log_level", "info")).lower(), logging.INFO,
        )
        # Per-instance state
        self._rate_timestamps: List[float] = []
        self._tool_loop_count = 0
        self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}

    # -- Rate limiting -------------------------------------------------------

    def _check_rate_limit(self) -> bool:
        """Sliding-window rate limiter. Returns True if request is allowed."""
        now = time.time()
        window = now - 60
        self._rate_timestamps[:] = [t for t in self._rate_timestamps if t > window]
        if len(self._rate_timestamps) >= self.max_rpm:
            return False
        self._rate_timestamps.append(now)
        return True

    # -- Model resolution ----------------------------------------------------

    def _resolve_model(self, preferences) -> Optional[str]:
        """Config override > server hint > None (use default)."""
        if self.model_override:
            return self.model_override
        if preferences and hasattr(preferences, "hints") and preferences.hints:
            for hint in preferences.hints:
                if hasattr(hint, "name") and hint.name:
                    return hint.name
        return None
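
    # Resolution precedence, illustrated (model names here are hypothetical):
    # with model_override="gpt-4o" and a server hint "claude-x", "gpt-4o" wins;
    # with no override, the first hint with a non-empty .name is used;
    # otherwise None, meaning the auxiliary client's default model applies.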
    # -- Message conversion --------------------------------------------------

    @staticmethod
    def _extract_tool_result_text(block) -> str:
        """Extract text from a ToolResultContent block."""
        if not hasattr(block, "content") or block.content is None:
            return ""
        items = block.content if isinstance(block.content, list) else [block.content]
        return "\n".join(item.text for item in items if hasattr(item, "text"))

    def _convert_messages(self, params) -> List[dict]:
        """Convert MCP SamplingMessages to OpenAI format.

        Uses ``msg.content_as_list`` (SDK helper) so single-block and
        list-of-blocks messages are handled uniformly. Dispatches per block
        type via duck-typing (``hasattr``) so it works with both real SDK
        types and lightweight test doubles.
        """
        messages: List[dict] = []
        for msg in params.messages:
            blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
                msg.content if isinstance(msg.content, list) else [msg.content]
            )
            # Separate blocks by kind
            tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
            tool_uses = [
                b for b in blocks
                if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")
            ]
            content_blocks = [
                b for b in blocks
                if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))
            ]
            # Emit tool result messages (role: tool)
            for tr in tool_results:
                messages.append({
                    "role": "tool",
                    "tool_call_id": tr.toolUseId,
                    "content": self._extract_tool_result_text(tr),
                })
            # Emit assistant tool_calls message
            if tool_uses:
                tc_list = []
                for tu in tool_uses:
                    tc_list.append({
                        "id": getattr(tu, "id", f"call_{len(tc_list)}"),
                        "type": "function",
                        "function": {
                            "name": tu.name,
                            "arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
                        },
                    })
                msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
                # Include any accompanying text
                text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
                if text_parts:
                    msg_dict["content"] = "\n".join(text_parts)
                messages.append(msg_dict)
            elif content_blocks:
                # Pure text/image content
                if len(content_blocks) == 1 and hasattr(content_blocks[0], "text"):
                    messages.append({"role": msg.role, "content": content_blocks[0].text})
                else:
                    parts = []
                    for block in content_blocks:
                        if hasattr(block, "text"):
                            parts.append({"type": "text", "text": block.text})
                        elif hasattr(block, "data") and hasattr(block, "mimeType"):
                            parts.append({
                                "type": "image_url",
                                "image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
                            })
                        else:
                            logger.warning(
                                "Unsupported sampling content block type: %s (skipped)",
                                type(block).__name__,
                            )
                    if parts:
                        messages.append({"role": msg.role, "content": parts})
        return messages
    # -- Error helper --------------------------------------------------------

    @staticmethod
    def _error(message: str, code: int = -1):
        """Return ErrorData (MCP spec) or raise as fallback."""
        if _MCP_SAMPLING_TYPES:
            return ErrorData(code=code, message=message)
        raise Exception(message)

    # -- Response building ---------------------------------------------------

    def _build_tool_use_result(self, choice, response):
        """Build a CreateMessageResultWithTools from an LLM tool_calls response."""
        self.metrics["tool_use_count"] += 1
        # Tool loop governance
        if self.max_tool_rounds == 0:
            self._tool_loop_count = 0
            return self._error(
                f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
            )
        self._tool_loop_count += 1
        if self._tool_loop_count > self.max_tool_rounds:
            self._tool_loop_count = 0
            return self._error(
                f"Tool loop limit exceeded for server '{self.server_name}' "
                f"(max {self.max_tool_rounds} rounds)"
            )
        content_blocks = []
        for tc in choice.message.tool_calls:
            args = tc.function.arguments
            if isinstance(args, str):
                try:
                    parsed = json.loads(args)
                except (json.JSONDecodeError, ValueError):
                    logger.warning(
                        "MCP server '%s': malformed tool_calls arguments "
                        "from LLM (wrapping as raw): %.100s",
                        self.server_name, args,
                    )
                    parsed = {"_raw": args}
            else:
                parsed = args if isinstance(args, dict) else {"_raw": str(args)}
            content_blocks.append(ToolUseContent(
                type="tool_use",
                id=tc.id,
                name=tc.function.name,
                input=parsed,
            ))
        logger.log(
            self.audit_level,
            "MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
            self.server_name, response.model,
            getattr(getattr(response, "usage", None), "total_tokens", "?"),
            len(content_blocks),
        )
        return CreateMessageResultWithTools(
            role="assistant",
            content=content_blocks,
            model=response.model,
            stopReason="toolUse",
        )
    def _build_text_result(self, choice, response):
        """Build a CreateMessageResult from a normal text response."""
        self._tool_loop_count = 0  # reset on text response
        response_text = choice.message.content or ""
        logger.log(
            self.audit_level,
            "MCP server '%s' sampling response: model=%s, tokens=%s",
            self.server_name, response.model,
            getattr(getattr(response, "usage", None), "total_tokens", "?"),
        )
        return CreateMessageResult(
            role="assistant",
            content=TextContent(type="text", text=_sanitize_error(response_text)),
            model=response.model,
            stopReason=self._STOP_REASON_MAP.get(choice.finish_reason, "endTurn"),
        )

    # -- Session kwargs helper -----------------------------------------------

    def session_kwargs(self) -> dict:
        """Return kwargs to pass to ClientSession for sampling support."""
        return {
            "sampling_callback": self,
            "sampling_capabilities": SamplingCapability(
                tools=SamplingToolsCapability(),
            ),
        }
    # -- Main callback -------------------------------------------------------

    async def __call__(self, context, params):
        """Sampling callback invoked by the MCP SDK.

        Conforms to ``SamplingFnT`` protocol. Returns
        ``CreateMessageResult``, ``CreateMessageResultWithTools``, or
        ``ErrorData``.
        """
        # Rate limit
        if not self._check_rate_limit():
            logger.warning(
                "MCP server '%s' sampling rate limit exceeded (%d/min)",
                self.server_name, self.max_rpm,
            )
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling rate limit exceeded for server '{self.server_name}' "
                f"({self.max_rpm} requests/minute)"
            )
        # Resolve model
        model = self._resolve_model(getattr(params, "modelPreferences", None))
        # Get auxiliary LLM client
        from agent.auxiliary_client import get_text_auxiliary_client
        client, default_model = get_text_auxiliary_client()
        if client is None:
            self.metrics["errors"] += 1
            return self._error("No LLM provider available for sampling")
        resolved_model = model or default_model
        # Model whitelist check
        if self.allowed_models and resolved_model not in self.allowed_models:
            logger.warning(
                "MCP server '%s' requested model '%s' not in allowed_models",
                self.server_name, resolved_model,
            )
            self.metrics["errors"] += 1
            return self._error(
                f"Model '{resolved_model}' not allowed for server "
                f"'{self.server_name}'. Allowed: {', '.join(self.allowed_models)}"
            )
        # Convert messages
        messages = self._convert_messages(params)
        if hasattr(params, "systemPrompt") and params.systemPrompt:
            messages.insert(0, {"role": "system", "content": params.systemPrompt})
        # Build LLM call kwargs
        max_tokens = min(params.maxTokens, self.max_tokens_cap)
        call_kwargs: dict = {
            "model": resolved_model,
            "messages": messages,
            "max_tokens": max_tokens,
        }
        if hasattr(params, "temperature") and params.temperature is not None:
            call_kwargs["temperature"] = params.temperature
        if stop := getattr(params, "stopSequences", None):
            call_kwargs["stop"] = stop
        # Forward server-provided tools
        server_tools = getattr(params, "tools", None)
        if server_tools:
            call_kwargs["tools"] = [
                {
                    "type": "function",
                    "function": {
                        "name": getattr(t, "name", ""),
                        "description": getattr(t, "description", "") or "",
                        "parameters": getattr(t, "inputSchema", {}) or {},
                    },
                }
                for t in server_tools
            ]
            if tool_choice := getattr(params, "toolChoice", None):
                mode = getattr(tool_choice, "mode", "auto")
                call_kwargs["tool_choice"] = {"auto": "auto", "required": "required", "none": "none"}.get(mode, "auto")
        logger.log(
            self.audit_level,
            "MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
            self.server_name, resolved_model, max_tokens, len(messages),
        )
        # Offload sync LLM call to thread (non-blocking)
        def _sync_call():
            return client.chat.completions.create(**call_kwargs)
        try:
            response = await asyncio.wait_for(
                asyncio.to_thread(_sync_call), timeout=self.timeout,
            )
        except asyncio.TimeoutError:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call timed out after {self.timeout}s "
                f"for server '{self.server_name}'"
            )
        except Exception as exc:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
            )
        # Track metrics
        choice = response.choices[0]
        self.metrics["requests"] += 1
        total_tokens = getattr(getattr(response, "usage", None), "total_tokens", 0)
        if isinstance(total_tokens, int):
            self.metrics["tokens_used"] += total_tokens
        # Dispatch based on response type
        if (
            choice.finish_reason == "tool_calls"
            and hasattr(choice.message, "tool_calls")
            and choice.message.tool_calls
        ):
            return self._build_tool_use_result(choice, response)
        return self._build_text_result(choice, response)
# ---------------------------------------------------------------------------
# Server task -- each MCP server lives in one long-lived asyncio Task
# ---------------------------------------------------------------------------
@@ -162,6 +573,7 @@ class MCPServerTask:
    __slots__ = (
        "name", "session", "tool_timeout",
        "_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
        "_sampling",
    )

    def __init__(self, name: str):
@@ -174,6 +586,7 @@ class MCPServerTask:
        self._tools: list = []
        self._error: Optional[Exception] = None
        self._config: dict = {}
        self._sampling: Optional[SamplingHandler] = None

    def _is_http(self) -> bool:
        """Check if this server uses HTTP transport."""
@@ -197,8 +610,9 @@ class MCPServerTask:
            env=safe_env if safe_env else None,
        )
        sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
        async with stdio_client(server_params) as (read_stream, write_stream):
-            async with ClientSession(read_stream, write_stream) as session:
+            async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
                await session.initialize()
                self.session = session
                await self._discover_tools()
@@ -218,12 +632,13 @@ class MCPServerTask:
        headers = config.get("headers")
        connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
        sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
        async with streamablehttp_client(
            url,
            headers=headers,
            timeout=float(connect_timeout),
        ) as (read_stream, write_stream, _get_session_id):
-            async with ClientSession(read_stream, write_stream) as session:
+            async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
                await session.initialize()
                self.session = session
                await self._discover_tools()
@@ -250,6 +665,13 @@ class MCPServerTask:
        self._config = config
        self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
        # Set up sampling handler if enabled and SDK types are available
        sampling_config = config.get("sampling", {})
        if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
            self._sampling = SamplingHandler(self.name, sampling_config)
        else:
            self._sampling = None
        # Validate: warn if both url and command are present
        if "url" in config and "command" in config:
            logger.warning(
@@ -975,12 +1397,15 @@ def get_mcp_status() -> List[dict]:
transport = "http" if "url" in cfg else "stdio"
server = active_servers.get(name)
if server and server.session is not None:
result.append({
entry = {
"name": name,
"transport": transport,
"tools": len(server._tools),
"connected": True,
})
}
if server._sampling:
entry["sampling"] = dict(server._sampling.metrics)
result.append(entry)
else:
result.append({
"name": name,

View File

@@ -271,3 +271,62 @@ You can reload MCP servers without restarting Hermes:
- In the CLI: the agent reconnects automatically
- In messaging: send `/reload-mcp`

## Sampling (Server-Initiated LLM Requests)

MCP's `sampling/createMessage` capability allows MCP servers to request LLM completions through the Hermes agent. This enables agent-in-the-loop workflows where servers can leverage the LLM during tool execution — for example, a database server asking the LLM to interpret query results, or a code analysis server requesting the LLM to review findings.

### How It Works

When an MCP server sends a `sampling/createMessage` request, the sampling callback:

1. Validates the request against rate limits and the model whitelist
2. Resolves which model to use (config override > server hint > default)
3. Converts MCP messages to OpenAI-compatible format
4. Offloads the LLM call to a thread via `asyncio.to_thread()` (non-blocking)
5. Returns the response (text or tool use) back to the server
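
A minimal sketch of that flow, driving the handler directly the way this PR's unit tests do (the fake client and message objects are illustrative stand-ins, not part of the API):

```python
import asyncio
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

from tools.mcp_tool import SamplingHandler

handler = SamplingHandler("demo", {"max_rpm": 10, "timeout": 30})

# Fake a sampling/createMessage request: one user message, 100-token budget.
content = SimpleNamespace(text="Summarize: 42 is the answer.")
msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
params = SimpleNamespace(
    messages=[msg], maxTokens=100, modelPreferences=None,
    temperature=None, stopSequences=None, tools=None, toolChoice=None,
)

# Fake the auxiliary OpenAI-compatible client the handler calls into.
message = SimpleNamespace(content="42 is the answer.", tool_calls=None)
choice = SimpleNamespace(finish_reason="stop", message=message)
fake_response = SimpleNamespace(
    choices=[choice], model="test-model", usage=SimpleNamespace(total_tokens=42),
)
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = fake_response

with patch(
    "agent.auxiliary_client.get_text_auxiliary_client",
    return_value=(fake_client, "default-model"),
):
    result = asyncio.run(handler(None, params))

print(result.content.text)   # -> "42 is the answer."
print(handler.metrics)       # requests and tokens_used counters updated
```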

### Configuration

Sampling is **enabled by default** for all MCP servers. No extra setup needed — if you have an auxiliary LLM client configured, sampling works automatically.
```yaml
mcp_servers:
  analysis_server:
    command: "npx"
    args: ["-y", "my-analysis-server"]
    sampling:
      enabled: true              # default: true
      model: "gemini-3-flash"    # override model (optional)
      max_tokens_cap: 4096       # max tokens per request (default: 4096)
      timeout: 30                # LLM call timeout in seconds (default: 30)
      max_rpm: 10                # max requests per minute (default: 10)
      allowed_models: []         # model whitelist (empty = allow all)
      max_tool_rounds: 5         # max consecutive tool use rounds (0 = disable)
      log_level: "info"          # audit verbosity: debug, info, warning
```
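
Numeric settings survive YAML quirks: string values are coerced and non-numeric ones fall back to the defaults above. A small illustration, assuming `tools.mcp_tool` is importable (values chosen for demonstration):

```python
from tools.mcp_tool import SamplingHandler

# Strings (as YAML sometimes delivers them) are coerced to numbers;
# anything non-numeric falls back to the documented default.
h = SamplingHandler("analysis_server", {
    "max_rpm": "15",           # coerced to int 15
    "timeout": "45.5",         # coerced to float 45.5
    "max_tokens_cap": "oops",  # falls back to default 4096
})
assert (h.max_rpm, h.timeout, h.max_tokens_cap) == (15, 45.5, 4096)
```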

### Tool Use in Sampling

Servers can include `tools` and `toolChoice` in sampling requests, enabling multi-turn tool-augmented workflows within a single sampling session. The callback forwards tool definitions to the LLM, handles tool use responses with proper `ToolUseContent` types, and enforces `max_tool_rounds` to prevent infinite loops.
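
A sketch of the tool-use path under the same test-style fakes as above (the weather tool and fake client are illustrative):

```python
import asyncio
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

from mcp.types import CreateMessageResultWithTools
from tools.mcp_tool import SamplingHandler

handler = SamplingHandler("demo", {"max_tool_rounds": 5})

content = SimpleNamespace(text="What's the weather in London?")
msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
params = SimpleNamespace(
    messages=[msg], maxTokens=100, modelPreferences=None,
    temperature=None, stopSequences=None,
    tools=[SimpleNamespace(name="get_weather", description="", inputSchema={})],
    toolChoice=None,
)

# Fake an LLM response that asks to call get_weather.
tc = SimpleNamespace(
    id="call_1",
    function=SimpleNamespace(name="get_weather", arguments='{"city": "London"}'),
)
choice = SimpleNamespace(
    finish_reason="tool_calls",
    message=SimpleNamespace(content=None, tool_calls=[tc]),
)
fake_response = SimpleNamespace(
    choices=[choice], model="test-model", usage=SimpleNamespace(total_tokens=42),
)
fake_client = MagicMock()
fake_client.chat.completions.create.return_value = fake_response

with patch(
    "agent.auxiliary_client.get_text_auxiliary_client",
    return_value=(fake_client, "default-model"),
):
    result = asyncio.run(handler(None, params))

assert isinstance(result, CreateMessageResultWithTools)
assert result.content[0].name == "get_weather"   # stopReason == "toolUse"
```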

### Security

- **Rate limiting**: Per-server sliding window (default: 10 req/min); a sketch of the window logic follows this list
- **Token cap**: Servers can't request more than `max_tokens_cap` (default: 4096)
- **Model whitelist**: `allowed_models` restricts which models a server can use
- **Tool loop limit**: `max_tool_rounds` caps consecutive tool use rounds
- **Credential stripping**: LLM responses are sanitized before returning to the server
- **Non-blocking**: LLM calls run in a separate thread via `asyncio.to_thread()`
- **Typed errors**: All failures return structured `ErrorData` per MCP spec
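
The sliding window is just a timestamp list pruned on every call; this standalone sketch mirrors the PR's `_check_rate_limit` logic:

```python
import time

class SlidingWindowLimiter:
    """Allow at most max_rpm calls in any trailing 60-second window."""

    def __init__(self, max_rpm: int = 10):
        self.max_rpm = max_rpm
        self._timestamps: list[float] = []

    def allow(self) -> bool:
        now = time.time()
        # Drop timestamps older than the 60-second window.
        self._timestamps = [t for t in self._timestamps if t > now - 60]
        if len(self._timestamps) >= self.max_rpm:
            return False   # over budget: reject without recording
        self._timestamps.append(now)
        return True

limiter = SlidingWindowLimiter(max_rpm=3)
assert all(limiter.allow() for _ in range(3))
assert limiter.allow() is False   # fourth call inside the window is rejected
```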

To disable sampling for untrusted servers:

```yaml
mcp_servers:
  untrusted:
    command: "npx"
    args: ["-y", "untrusted-server"]
    sampling:
      enabled: false
```