mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 15:01:34 +08:00
Compare commits
1 Commits
bb/base-gu
...
feat/gatew
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8cf7e80fc5 |
114
gateway/run.py
114
gateway/run.py
@@ -631,7 +631,18 @@ class GatewayRunner:
|
||||
# Track background tasks to prevent garbage collection mid-execution
|
||||
self._background_tasks: set = set()
|
||||
|
||||
|
||||
# MCP config watcher state — detect header changes (e.g. OAuth token refresh)
|
||||
self._mcp_config_mtime: float = 0.0
|
||||
self._mcp_config_servers: dict = {}
|
||||
try:
|
||||
from hermes_cli.config import get_config_path, load_config
|
||||
cfg_path = get_config_path()
|
||||
if cfg_path.exists():
|
||||
self._mcp_config_mtime = cfg_path.stat().st_mtime
|
||||
cfg = load_config()
|
||||
self._mcp_config_servers = cfg.get("mcp_servers") or {}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# -- Setup skill availability ----------------------------------------
|
||||
@@ -1690,10 +1701,111 @@ class GatewayRunner:
|
||||
)
|
||||
asyncio.create_task(self._platform_reconnect_watcher())
|
||||
|
||||
# Start background MCP config watcher for auto-reloading on token refresh
|
||||
asyncio.create_task(self._mcp_config_watcher())
|
||||
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
return True
|
||||
|
||||
async def _mcp_config_watcher(self, interval: int = 30, _initial_delay: int = 30) -> None:
    """Background task that detects MCP config changes and auto-reloads connections.

    Polls config.yaml every ``interval`` seconds. When the ``mcp_servers``
    section changes (e.g. OAuth token refresh updates the Authorization
    header), triggers a full MCP shutdown + reconnect so the running
    gateway picks up new credentials without a restart.

    Mirrors the CLI's ``_check_config_mcp_changes`` but adapted for the
    async gateway event loop.

    Args:
        interval: Seconds to wait between polls of the config file.
        _initial_delay: Seconds to wait before the first poll so startup
            can finish.

    Every wait is taken in 1-second steps that re-check ``self._running``,
    so the task returns within about a second of shutdown on every path
    (including the "nothing changed" path).
    """

    async def _nap(seconds: int) -> bool:
        """Sleep ``seconds`` in 1s steps; return False once shutdown is seen."""
        if seconds <= 0:
            # Still yield once so a zero interval cannot starve the event loop.
            await asyncio.sleep(0)
            return True
        for _ in range(seconds):
            if not self._running:
                return False
            await asyncio.sleep(1)
        return True

    def _changed_mcp_servers():
        """Return the new ``mcp_servers`` mapping if it differs, else ``None``.

        ``None`` means "nothing to do"; an empty dict is a real result
        (every server was removed), so callers must test ``is not None``.
        """
        # Imported lazily so an import failure lands in the polling loop's
        # best-effort ``except`` instead of killing the task.
        from hermes_cli.config import get_config_path
        import yaml as _yaml

        cfg_path = get_config_path()
        try:
            mtime = cfg_path.stat().st_mtime
        except OSError:
            # Missing or unreadable file (FileNotFoundError is an OSError).
            return None
        if mtime == self._mcp_config_mtime:
            return None

        # File changed — read and compare mcp_servers section
        try:
            with open(cfg_path, encoding="utf-8") as f:
                new_cfg = _yaml.safe_load(f) or {}
        except Exception as e:
            # Possibly a half-written file. The stored mtime is left untouched
            # so the next poll retries instead of silently dropping the change.
            logger.debug("MCP config watcher could not read %s: %s", cfg_path, e)
            return None

        self._mcp_config_mtime = mtime
        new_mcp = new_cfg.get("mcp_servers") or {}
        if new_mcp == self._mcp_config_servers:
            # Some other config section changed, not MCP
            return None
        return new_mcp

    # Initial delay — let startup finish.
    if not await _nap(_initial_delay):
        return
    logger.info("MCP config watcher started (checking every %ds)", interval)

    while self._running:
        try:
            new_mcp = _changed_mcp_servers()
            if new_mcp is not None:
                # Recorded before the reload, so a failed reload is not
                # retried until the config changes again.
                self._mcp_config_servers = new_mcp
                logger.info("MCP config change detected — reloading connections...")

                try:
                    from tools.mcp_tool import shutdown_mcp_servers, discover_mcp_tools, _servers, _lock

                    loop = asyncio.get_running_loop()

                    with _lock:
                        old_servers = set(_servers)

                    # Perform the reload in worker threads to avoid blocking the event loop
                    await loop.run_in_executor(None, shutdown_mcp_servers)
                    new_tools = await loop.run_in_executor(None, discover_mcp_tools)

                    with _lock:
                        connected_servers = set(_servers)

                    parts = []
                    for label, names in (
                        ("♻️ Reconnected", connected_servers & old_servers),
                        ("➕ Added", connected_servers - old_servers),
                        ("➖ Removed", old_servers - connected_servers),
                    ):
                        if names:
                            parts.append(f"{label}: {', '.join(sorted(names))}")
                    parts.append(
                        f"🔧 {len(new_tools)} tool(s) from {len(connected_servers)} server(s)"
                    )
                    logger.info("MCP auto-reload complete: %s", "; ".join(parts))

                except Exception as e:
                    logger.warning("MCP auto-reload failed: %s", e)

        except Exception as e:
            # Best-effort watcher: never let one bad poll end the task.
            logger.debug("MCP config watcher error: %s", e)

        # Single shutdown-responsive wait shared by every path above.
        if not await _nap(interval):
            return
|
||||
|
||||
async def _session_expiry_watcher(self, interval: int = 300):
|
||||
"""Background task that proactively flushes memories for expired sessions.
|
||||
|
||||
|
||||
217
tests/gateway/test_gateway_mcp_config_watcher.py
Normal file
217
tests/gateway/test_gateway_mcp_config_watcher.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Tests for gateway MCP config watcher — auto-reload on mcp_servers changes."""
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
def _make_runner(tmp_path, mcp_servers=None):
    """Build a bare GatewayRunner and a matching ``config.yaml`` under ``tmp_path``.

    The config file is seeded with ``mcp_servers`` (an empty mapping when
    omitted) and the runner's watcher state is primed to match it, so a
    freshly built runner sees the file as unchanged.

    Returns:
        A ``(runner, config_file)`` tuple.
    """
    servers = mcp_servers or {}

    config_file = tmp_path / "config.yaml"
    config_file.write_text(yaml.dump({"mcp_servers": servers}))

    # object.__new__ bypasses GatewayRunner.__init__; only the attributes
    # assigned below exist on the instance.
    runner = object.__new__(GatewayRunner)
    runner._running = True
    runner._mcp_config_servers = servers
    runner._mcp_config_mtime = config_file.stat().st_mtime

    return runner, config_file
|
||||
|
||||
|
||||
class TestMCPConfigWatcher:
    """Change-detection and lifecycle tests for the gateway MCP config watcher."""

    @staticmethod
    def _poll_once(runner, cfg_file):
        """Replay one change-detection pass against ``cfg_file``.

        Steps: compare the file mtime with the runner's stored mtime, parse
        the YAML, then compare its ``mcp_servers`` section with the runner's
        stored copy.

        Returns:
            True when a reload would be triggered (the runner's stored
            ``mcp_servers`` is updated in that case), otherwise False.
        """
        mtime = cfg_file.stat().st_mtime
        if mtime == runner._mcp_config_mtime:
            return False  # No change — fast path

        runner._mcp_config_mtime = mtime
        with open(cfg_file, encoding="utf-8") as f:
            new_cfg = yaml.safe_load(f) or {}

        new_mcp = new_cfg.get("mcp_servers") or {}
        if new_mcp == runner._mcp_config_servers:
            return False

        runner._mcp_config_servers = new_mcp
        return True

    def test_no_change_does_not_reload(self, tmp_path):
        """If config file hasn't changed, no MCP reload should happen."""
        runner, cfg_file = _make_runner(tmp_path, mcp_servers={
            "betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer old"}}
        })

        assert not self._poll_once(runner, cfg_file)

    def test_header_change_triggers_reload(self, tmp_path):
        """When Authorization header changes, reload should be triggered."""
        old_servers = {
            "betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer old_token"}}
        }
        runner, cfg_file = _make_runner(tmp_path, mcp_servers=old_servers)

        # Simulate token refresh updating the config
        new_servers = {
            "betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer new_token"}}
        }
        cfg_file.write_text(yaml.dump({"mcp_servers": new_servers}))

        # Force mtime to look different regardless of filesystem granularity
        runner._mcp_config_mtime = 0.0

        assert self._poll_once(runner, cfg_file)
        assert runner._mcp_config_servers == new_servers

    def test_non_mcp_change_does_not_reload(self, tmp_path):
        """If a non-MCP section changes but mcp_servers stays the same, no reload."""
        servers = {
            "betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer tok"}}
        }
        runner, cfg_file = _make_runner(tmp_path, mcp_servers=servers)

        # Write same mcp_servers but change something else
        cfg_file.write_text(yaml.dump({
            "mcp_servers": servers,
            "some_other_setting": "changed",
        }))
        runner._mcp_config_mtime = 0.0  # force stale mtime

        assert not self._poll_once(runner, cfg_file)
        assert runner._mcp_config_servers == servers  # Should be unchanged
        # The new mtime is still recorded so the file is not re-read next poll.
        assert runner._mcp_config_mtime == cfg_file.stat().st_mtime

    def test_server_added_triggers_reload(self, tmp_path):
        """Adding a new MCP server to config triggers reload."""
        runner, cfg_file = _make_runner(tmp_path, mcp_servers={})

        new_servers = {"github": {"url": "https://api.github.com/mcp"}}
        cfg_file.write_text(yaml.dump({"mcp_servers": new_servers}))
        runner._mcp_config_mtime = 0.0

        assert self._poll_once(runner, cfg_file)
        assert runner._mcp_config_servers == new_servers

    def test_server_removed_triggers_reload(self, tmp_path):
        """Removing an MCP server from config triggers reload."""
        runner, cfg_file = _make_runner(tmp_path, mcp_servers={
            "github": {"url": "https://api.github.com/mcp"}
        })

        cfg_file.write_text(yaml.dump({"mcp_servers": {}}))
        runner._mcp_config_mtime = 0.0

        assert self._poll_once(runner, cfg_file)
        assert runner._mcp_config_servers == {}

    @pytest.mark.asyncio
    async def test_watcher_stops_on_shutdown(self, tmp_path):
        """Watcher loop exits when _running is set to False."""
        runner, cfg_file = _make_runner(tmp_path)
        runner._running = False

        # The watcher should return almost immediately.
        # We test it doesn't hang by using a timeout.
        try:
            await asyncio.wait_for(
                runner._mcp_config_watcher(interval=1, _initial_delay=0),
                timeout=5.0,
            )
        except asyncio.TimeoutError:
            pytest.fail("_mcp_config_watcher did not exit after _running=False")

    @pytest.mark.asyncio
    async def test_full_watcher_detects_change_and_reloads(self, tmp_path):
        """Integration test: watcher detects a header change and calls MCP reload."""
        import os  # local import: only this test needs it

        old_servers = {
            "betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer old"}}
        }
        runner, cfg_file = _make_runner(tmp_path, mcp_servers=old_servers)

        # Prepare the config change that will happen during the watcher run
        new_servers = {
            "betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer new"}}
        }

        shutdown_mock = MagicMock()
        discover_mock = MagicMock(return_value=[{"function": {"name": "test_tool"}}])
        servers_dict = {"betterstack": MagicMock()}
        lock_mock = MagicMock()  # MagicMock supports the ``with`` protocol

        async def stop_after_reload():
            """Write the config change, wait for the watcher to pick it up, then stop."""
            await asyncio.sleep(0.5)
            cfg_file.write_text(yaml.dump({"mcp_servers": new_servers}))
            # Two writes this close together can share an mtime on filesystems
            # with coarse timestamps; push the mtime forward explicitly so
            # detection does not depend on timestamp granularity.
            st = cfg_file.stat()
            os.utime(cfg_file, (st.st_atime, st.st_mtime + 10))
            # Poll for the reload (up to ~4s) instead of always sleeping 4s.
            for _ in range(40):
                if discover_mock.called:
                    break
                await asyncio.sleep(0.1)
            runner._running = False

        with patch("hermes_cli.config.get_config_path", return_value=cfg_file), \
             patch("tools.mcp_tool.shutdown_mcp_servers", shutdown_mock), \
             patch("tools.mcp_tool.discover_mcp_tools", discover_mock), \
             patch("tools.mcp_tool._servers", servers_dict), \
             patch("tools.mcp_tool._lock", lock_mock):

            stop_task = asyncio.create_task(stop_after_reload())
            try:
                await asyncio.wait_for(
                    runner._mcp_config_watcher(interval=1, _initial_delay=0),
                    timeout=10.0,
                )
            except asyncio.TimeoutError:
                runner._running = False

            await stop_task

        shutdown_mock.assert_called_once()
        discover_mock.assert_called_once()
        assert runner._mcp_config_servers == new_servers
|
||||
Reference in New Issue
Block a user