mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-29 15:31:38 +08:00
Compare commits
1 Commits
skill/gith
...
feat/gatew
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8cf7e80fc5 |
114
gateway/run.py
114
gateway/run.py
@@ -631,7 +631,18 @@ class GatewayRunner:
|
|||||||
# Track background tasks to prevent garbage collection mid-execution
|
# Track background tasks to prevent garbage collection mid-execution
|
||||||
self._background_tasks: set = set()
|
self._background_tasks: set = set()
|
||||||
|
|
||||||
|
# MCP config watcher state — detect header changes (e.g. OAuth token refresh)
|
||||||
|
self._mcp_config_mtime: float = 0.0
|
||||||
|
self._mcp_config_servers: dict = {}
|
||||||
|
try:
|
||||||
|
from hermes_cli.config import get_config_path, load_config
|
||||||
|
cfg_path = get_config_path()
|
||||||
|
if cfg_path.exists():
|
||||||
|
self._mcp_config_mtime = cfg_path.stat().st_mtime
|
||||||
|
cfg = load_config()
|
||||||
|
self._mcp_config_servers = cfg.get("mcp_servers") or {}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# -- Setup skill availability ----------------------------------------
|
# -- Setup skill availability ----------------------------------------
|
||||||
@@ -1690,10 +1701,111 @@ class GatewayRunner:
|
|||||||
)
|
)
|
||||||
asyncio.create_task(self._platform_reconnect_watcher())
|
asyncio.create_task(self._platform_reconnect_watcher())
|
||||||
|
|
||||||
|
# Start background MCP config watcher for auto-reloading on token refresh
|
||||||
|
asyncio.create_task(self._mcp_config_watcher())
|
||||||
|
|
||||||
logger.info("Press Ctrl+C to stop")
|
logger.info("Press Ctrl+C to stop")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def _mcp_config_watcher(self, interval: int = 30, _initial_delay: int = 30) -> None:
|
||||||
|
"""Background task that detects MCP config changes and auto-reloads connections.
|
||||||
|
|
||||||
|
Polls config.yaml every ``interval`` seconds. When the ``mcp_servers``
|
||||||
|
section changes (e.g. OAuth token refresh updates the Authorization
|
||||||
|
header), triggers a full MCP shutdown + reconnect so the running
|
||||||
|
gateway picks up new credentials without a restart.
|
||||||
|
|
||||||
|
Mirrors the CLI's ``_check_config_mcp_changes`` but adapted for the
|
||||||
|
async gateway event loop.
|
||||||
|
"""
|
||||||
|
# Initial delay — let startup finish. Sleep in 1s increments for quick shutdown.
|
||||||
|
for _ in range(_initial_delay):
|
||||||
|
if not self._running:
|
||||||
|
return
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
logger.info("MCP config watcher started (checking every %ds)", interval)
|
||||||
|
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
from hermes_cli.config import get_config_path
|
||||||
|
import yaml as _yaml
|
||||||
|
|
||||||
|
cfg_path = get_config_path()
|
||||||
|
if not cfg_path.exists():
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
mtime = cfg_path.stat().st_mtime
|
||||||
|
except OSError:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if mtime == self._mcp_config_mtime:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# File changed — read and compare mcp_servers section
|
||||||
|
self._mcp_config_mtime = mtime
|
||||||
|
try:
|
||||||
|
with open(cfg_path, encoding="utf-8") as f:
|
||||||
|
new_cfg = _yaml.safe_load(f) or {}
|
||||||
|
except Exception:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_mcp = new_cfg.get("mcp_servers") or {}
|
||||||
|
if new_mcp == self._mcp_config_servers:
|
||||||
|
# Some other config section changed, not MCP
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._mcp_config_servers = new_mcp
|
||||||
|
logger.info("MCP config change detected — reloading connections...")
|
||||||
|
|
||||||
|
# Perform the reload in a thread to avoid blocking the event loop
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
from tools.mcp_tool import shutdown_mcp_servers, discover_mcp_tools, _servers, _lock
|
||||||
|
|
||||||
|
with _lock:
|
||||||
|
old_servers = set(_servers.keys())
|
||||||
|
|
||||||
|
await loop.run_in_executor(None, shutdown_mcp_servers)
|
||||||
|
new_tools = await loop.run_in_executor(None, discover_mcp_tools)
|
||||||
|
|
||||||
|
with _lock:
|
||||||
|
connected_servers = set(_servers.keys())
|
||||||
|
|
||||||
|
added = connected_servers - old_servers
|
||||||
|
removed = old_servers - connected_servers
|
||||||
|
reconnected = connected_servers & old_servers
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
if reconnected:
|
||||||
|
parts.append(f"♻️ Reconnected: {', '.join(sorted(reconnected))}")
|
||||||
|
if added:
|
||||||
|
parts.append(f"➕ Added: {', '.join(sorted(added))}")
|
||||||
|
if removed:
|
||||||
|
parts.append(f"➖ Removed: {', '.join(sorted(removed))}")
|
||||||
|
parts.append(
|
||||||
|
f"🔧 {len(new_tools)} tool(s) from {len(connected_servers)} server(s)"
|
||||||
|
)
|
||||||
|
logger.info("MCP auto-reload complete: %s", "; ".join(parts))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("MCP auto-reload failed: %s", e)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("MCP config watcher error: %s", e)
|
||||||
|
|
||||||
|
# Sleep in 1-second increments so we respond quickly to shutdown
|
||||||
|
for _ in range(interval):
|
||||||
|
if not self._running:
|
||||||
|
return
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
async def _session_expiry_watcher(self, interval: int = 300):
|
async def _session_expiry_watcher(self, interval: int = 300):
|
||||||
"""Background task that proactively flushes memories for expired sessions.
|
"""Background task that proactively flushes memories for expired sessions.
|
||||||
|
|
||||||
|
|||||||
217
tests/gateway/test_gateway_mcp_config_watcher.py
Normal file
217
tests/gateway/test_gateway_mcp_config_watcher.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
"""Tests for gateway MCP config watcher — auto-reload on mcp_servers changes."""
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from gateway.run import GatewayRunner
|
||||||
|
|
||||||
|
|
||||||
|
def _make_runner(tmp_path, mcp_servers=None):
|
||||||
|
"""Create a minimal GatewayRunner with mocked MCP config watcher state."""
|
||||||
|
runner = object.__new__(GatewayRunner)
|
||||||
|
runner._running = True
|
||||||
|
runner._mcp_config_servers = mcp_servers or {}
|
||||||
|
|
||||||
|
cfg_file = tmp_path / "config.yaml"
|
||||||
|
cfg_file.write_text(yaml.dump({"mcp_servers": mcp_servers or {}}))
|
||||||
|
runner._mcp_config_mtime = cfg_file.stat().st_mtime
|
||||||
|
|
||||||
|
return runner, cfg_file
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPConfigWatcher:
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_change_does_not_reload(self, tmp_path):
|
||||||
|
"""If config file hasn't changed, no MCP reload should happen."""
|
||||||
|
runner, cfg_file = _make_runner(tmp_path, mcp_servers={
|
||||||
|
"betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer old"}}
|
||||||
|
})
|
||||||
|
|
||||||
|
reload_called = False
|
||||||
|
|
||||||
|
async def fake_watcher_iteration():
|
||||||
|
nonlocal reload_called
|
||||||
|
from hermes_cli.config import get_config_path
|
||||||
|
import yaml as _yaml
|
||||||
|
|
||||||
|
cfg_path = cfg_file
|
||||||
|
mtime = cfg_path.stat().st_mtime
|
||||||
|
|
||||||
|
if mtime == runner._mcp_config_mtime:
|
||||||
|
return # No change — fast path
|
||||||
|
|
||||||
|
runner._mcp_config_mtime = mtime
|
||||||
|
with open(cfg_path, encoding="utf-8") as f:
|
||||||
|
new_cfg = _yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
new_mcp = new_cfg.get("mcp_servers") or {}
|
||||||
|
if new_mcp == runner._mcp_config_servers:
|
||||||
|
return
|
||||||
|
|
||||||
|
reload_called = True
|
||||||
|
|
||||||
|
await fake_watcher_iteration()
|
||||||
|
assert not reload_called
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_header_change_triggers_reload(self, tmp_path):
|
||||||
|
"""When Authorization header changes, reload should be triggered."""
|
||||||
|
old_servers = {
|
||||||
|
"betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer old_token"}}
|
||||||
|
}
|
||||||
|
runner, cfg_file = _make_runner(tmp_path, mcp_servers=old_servers)
|
||||||
|
|
||||||
|
# Simulate token refresh updating the config
|
||||||
|
new_servers = {
|
||||||
|
"betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer new_token"}}
|
||||||
|
}
|
||||||
|
cfg_file.write_text(yaml.dump({"mcp_servers": new_servers}))
|
||||||
|
|
||||||
|
# Force mtime to look different
|
||||||
|
runner._mcp_config_mtime = 0.0
|
||||||
|
|
||||||
|
reload_triggered = False
|
||||||
|
|
||||||
|
# Simulate one iteration of the watcher's core logic
|
||||||
|
mtime = cfg_file.stat().st_mtime
|
||||||
|
assert mtime != runner._mcp_config_mtime
|
||||||
|
|
||||||
|
runner._mcp_config_mtime = mtime
|
||||||
|
with open(cfg_file, encoding="utf-8") as f:
|
||||||
|
new_cfg = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
new_mcp = new_cfg.get("mcp_servers") or {}
|
||||||
|
if new_mcp != runner._mcp_config_servers:
|
||||||
|
reload_triggered = True
|
||||||
|
runner._mcp_config_servers = new_mcp
|
||||||
|
|
||||||
|
assert reload_triggered
|
||||||
|
assert runner._mcp_config_servers == new_servers
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_non_mcp_change_does_not_reload(self, tmp_path):
|
||||||
|
"""If a non-MCP section changes but mcp_servers stays the same, no reload."""
|
||||||
|
servers = {
|
||||||
|
"betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer tok"}}
|
||||||
|
}
|
||||||
|
runner, cfg_file = _make_runner(tmp_path, mcp_servers=servers)
|
||||||
|
|
||||||
|
# Write same mcp_servers but change something else
|
||||||
|
cfg_file.write_text(yaml.dump({
|
||||||
|
"mcp_servers": servers,
|
||||||
|
"some_other_setting": "changed"
|
||||||
|
}))
|
||||||
|
runner._mcp_config_mtime = 0.0 # force stale mtime
|
||||||
|
|
||||||
|
mtime = cfg_file.stat().st_mtime
|
||||||
|
runner._mcp_config_mtime = mtime
|
||||||
|
with open(cfg_file, encoding="utf-8") as f:
|
||||||
|
new_cfg = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
new_mcp = new_cfg.get("mcp_servers") or {}
|
||||||
|
assert new_mcp == runner._mcp_config_servers # Should be unchanged
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_server_added_triggers_reload(self, tmp_path):
|
||||||
|
"""Adding a new MCP server to config triggers reload."""
|
||||||
|
runner, cfg_file = _make_runner(tmp_path, mcp_servers={})
|
||||||
|
|
||||||
|
new_servers = {"github": {"url": "https://api.github.com/mcp"}}
|
||||||
|
cfg_file.write_text(yaml.dump({"mcp_servers": new_servers}))
|
||||||
|
runner._mcp_config_mtime = 0.0
|
||||||
|
|
||||||
|
mtime = cfg_file.stat().st_mtime
|
||||||
|
runner._mcp_config_mtime = mtime
|
||||||
|
with open(cfg_file, encoding="utf-8") as f:
|
||||||
|
new_cfg = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
new_mcp = new_cfg.get("mcp_servers") or {}
|
||||||
|
assert new_mcp != runner._mcp_config_servers
|
||||||
|
runner._mcp_config_servers = new_mcp
|
||||||
|
assert runner._mcp_config_servers == new_servers
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_server_removed_triggers_reload(self, tmp_path):
|
||||||
|
"""Removing an MCP server from config triggers reload."""
|
||||||
|
runner, cfg_file = _make_runner(tmp_path, mcp_servers={
|
||||||
|
"github": {"url": "https://api.github.com/mcp"}
|
||||||
|
})
|
||||||
|
|
||||||
|
cfg_file.write_text(yaml.dump({"mcp_servers": {}}))
|
||||||
|
runner._mcp_config_mtime = 0.0
|
||||||
|
|
||||||
|
mtime = cfg_file.stat().st_mtime
|
||||||
|
runner._mcp_config_mtime = mtime
|
||||||
|
with open(cfg_file, encoding="utf-8") as f:
|
||||||
|
new_cfg = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
new_mcp = new_cfg.get("mcp_servers") or {}
|
||||||
|
assert new_mcp != runner._mcp_config_servers
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_watcher_stops_on_shutdown(self, tmp_path):
|
||||||
|
"""Watcher loop exits when _running is set to False."""
|
||||||
|
runner, cfg_file = _make_runner(tmp_path)
|
||||||
|
runner._running = False
|
||||||
|
|
||||||
|
# The watcher should return almost immediately
|
||||||
|
# We test it doesn't hang by using a timeout
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
runner._mcp_config_watcher(interval=1, _initial_delay=0),
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pytest.fail("_mcp_config_watcher did not exit after _running=False")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_watcher_detects_change_and_reloads(self, tmp_path):
|
||||||
|
"""Integration test: watcher detects a header change and calls MCP reload."""
|
||||||
|
old_servers = {
|
||||||
|
"betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer old"}}
|
||||||
|
}
|
||||||
|
runner, cfg_file = _make_runner(tmp_path, mcp_servers=old_servers)
|
||||||
|
|
||||||
|
# Prepare the config change that will happen during the watcher run
|
||||||
|
new_servers = {
|
||||||
|
"betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer new"}}
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdown_mock = MagicMock()
|
||||||
|
discover_mock = MagicMock(return_value=[{"function": {"name": "test_tool"}}])
|
||||||
|
servers_dict = {"betterstack": MagicMock()}
|
||||||
|
lock_mock = MagicMock()
|
||||||
|
|
||||||
|
async def stop_after_reload():
|
||||||
|
"""Write the config change, wait for the watcher to pick it up, then stop."""
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
cfg_file.write_text(yaml.dump({"mcp_servers": new_servers}))
|
||||||
|
# Wait enough time for the watcher to detect + reload
|
||||||
|
await asyncio.sleep(4)
|
||||||
|
runner._running = False
|
||||||
|
|
||||||
|
with patch("hermes_cli.config.get_config_path", return_value=cfg_file), \
|
||||||
|
patch("tools.mcp_tool.shutdown_mcp_servers", shutdown_mock), \
|
||||||
|
patch("tools.mcp_tool.discover_mcp_tools", discover_mock), \
|
||||||
|
patch("tools.mcp_tool._servers", servers_dict), \
|
||||||
|
patch("tools.mcp_tool._lock", lock_mock):
|
||||||
|
|
||||||
|
stop_task = asyncio.create_task(stop_after_reload())
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
runner._mcp_config_watcher(interval=1, _initial_delay=0),
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
runner._running = False
|
||||||
|
|
||||||
|
await stop_task
|
||||||
|
|
||||||
|
shutdown_mock.assert_called_once()
|
||||||
|
discover_mock.assert_called_once()
|
||||||
|
assert runner._mcp_config_servers == new_servers
|
||||||
Reference in New Issue
Block a user