Merge pull request #2465 from NousResearch/hermes/hermes-31d7db3b

feat(cli): MCP server management CLI + OAuth 2.1 PKCE auth
This commit is contained in:
Teknium
2026-03-22 04:56:48 -07:00
committed by GitHub
6 changed files with 1509 additions and 10 deletions

235
tools/mcp_oauth.py Normal file
View File

@@ -0,0 +1,235 @@
"""Thin OAuth adapter for MCP HTTP servers.
Wraps the MCP SDK's built-in ``OAuthClientProvider`` (which implements
``httpx.Auth``) with Hermes-specific token storage and browser-based
authorization. The SDK handles all of the heavy lifting: PKCE generation,
metadata discovery, dynamic client registration, token exchange, and refresh.
Usage in mcp_tool.py::
from tools.mcp_oauth import build_oauth_auth
auth = build_oauth_auth(server_name, server_url)
# pass ``auth`` as the httpx auth parameter
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import socket
import threading
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Any
from urllib.parse import parse_qs, urlparse
logger = logging.getLogger(__name__)
_TOKEN_DIR_NAME = "mcp-tokens"
# ---------------------------------------------------------------------------
# Token storage — persists tokens + client info to ~/.hermes/mcp-tokens/
# ---------------------------------------------------------------------------
class HermesTokenStorage:
"""File-backed token storage implementing the MCP SDK's TokenStorage protocol."""
def __init__(self, server_name: str):
self._server_name = server_name
def _base_dir(self) -> Path:
home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
d = home / _TOKEN_DIR_NAME
d.mkdir(parents=True, exist_ok=True)
return d
def _tokens_path(self) -> Path:
return self._base_dir() / f"{self._server_name}.json"
def _client_path(self) -> Path:
return self._base_dir() / f"{self._server_name}.client.json"
# -- TokenStorage protocol (async) --
async def get_tokens(self):
data = self._read_json(self._tokens_path())
if not data:
return None
try:
from mcp.shared.auth import OAuthToken
return OAuthToken(**data)
except Exception:
return None
async def set_tokens(self, tokens) -> None:
self._write_json(self._tokens_path(), tokens.model_dump(exclude_none=True))
async def get_client_info(self):
data = self._read_json(self._client_path())
if not data:
return None
try:
from mcp.shared.auth import OAuthClientInformationFull
return OAuthClientInformationFull(**data)
except Exception:
return None
async def set_client_info(self, client_info) -> None:
self._write_json(self._client_path(), client_info.model_dump(exclude_none=True))
# -- helpers --
@staticmethod
def _read_json(path: Path) -> dict | None:
if not path.exists():
return None
try:
return json.loads(path.read_text(encoding="utf-8"))
except Exception:
return None
@staticmethod
def _write_json(path: Path, data: dict) -> None:
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
try:
path.chmod(0o600)
except OSError:
pass
def remove(self) -> None:
"""Delete stored tokens and client info for this server."""
for p in (self._tokens_path(), self._client_path()):
try:
p.unlink(missing_ok=True)
except OSError:
pass
# ---------------------------------------------------------------------------
# Browser-based callback handler
# ---------------------------------------------------------------------------
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
class _CallbackHandler(BaseHTTPRequestHandler):
auth_code: str | None = None
state: str | None = None
def do_GET(self):
qs = parse_qs(urlparse(self.path).query)
_CallbackHandler.auth_code = (qs.get("code") or [None])[0]
_CallbackHandler.state = (qs.get("state") or [None])[0]
self.send_response(200)
self.send_header("Content-Type", "text/html")
self.end_headers()
self.wfile.write(b"<html><body><h3>Authorization complete. You can close this tab.</h3></body></html>")
def log_message(self, *_args: Any) -> None:
pass # suppress HTTP log noise
async def _redirect_to_browser(auth_url: str) -> None:
"""Open the authorization URL in the user's browser."""
try:
if _can_open_browser():
webbrowser.open(auth_url)
print(f" Opened browser for authorization...")
else:
print(f"\n Open this URL to authorize:\n {auth_url}\n")
except Exception:
print(f"\n Open this URL to authorize:\n {auth_url}\n")
async def _wait_for_callback() -> tuple[str, str | None]:
"""Start a local HTTP server and wait for the OAuth redirect callback."""
port = _find_free_port()
server = HTTPServer(("127.0.0.1", port), _CallbackHandler)
_CallbackHandler.auth_code = None
_CallbackHandler.state = None
def _serve():
server.timeout = 120
server.handle_request()
thread = threading.Thread(target=_serve, daemon=True)
thread.start()
# Wait for the callback
for _ in range(1200): # 120 seconds
await asyncio.sleep(0.1)
if _CallbackHandler.auth_code is not None:
break
server.server_close()
code = _CallbackHandler.auth_code or ""
state = _CallbackHandler.state
if not code:
# Fallback to manual entry
print(" Browser callback timed out. Paste the authorization code manually:")
code = input(" Code: ").strip()
return code, state
def _can_open_browser() -> bool:
if os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"):
return False
if not os.environ.get("DISPLAY") and os.name != "nt" and "darwin" not in os.uname().sysname.lower():
return False
return True
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def build_oauth_auth(server_name: str, server_url: str):
"""Build an ``httpx.Auth`` handler for the given MCP server using OAuth 2.1 PKCE.
Uses the MCP SDK's ``OAuthClientProvider`` which handles discovery,
registration, PKCE, token exchange, and refresh automatically.
Returns an ``OAuthClientProvider`` instance (implements ``httpx.Auth``),
or ``None`` if the MCP SDK auth module is not available.
"""
try:
from mcp.client.auth import OAuthClientProvider
from mcp.shared.auth import OAuthClientMetadata
except ImportError:
logger.warning("MCP SDK auth module not available — OAuth disabled")
return None
port = _find_free_port()
redirect_uri = f"http://127.0.0.1:{port}/callback"
client_metadata = OAuthClientMetadata(
client_name="Hermes Agent",
redirect_uris=[redirect_uri],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
scope="openid profile email offline_access",
token_endpoint_auth_method="none",
)
storage = HermesTokenStorage(server_name)
return OAuthClientProvider(
server_url=server_url,
client_metadata=client_metadata,
storage=storage,
redirect_handler=_redirect_to_browser,
callback_handler=_wait_for_callback,
timeout=120.0,
)
def remove_oauth_tokens(server_name: str) -> None:
"""Delete stored OAuth tokens and client info for a server."""
HermesTokenStorage(server_name).remove()

View File

@@ -690,7 +690,7 @@ class MCPServerTask:
__slots__ = (
"name", "session", "tool_timeout",
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
"_sampling", "_registered_tool_names",
"_sampling", "_registered_tool_names", "_auth_type",
)
def __init__(self, name: str):
@@ -705,6 +705,7 @@ class MCPServerTask:
self._config: dict = {}
self._sampling: Optional[SamplingHandler] = None
self._registered_tool_names: list[str] = []
self._auth_type: str = ""
def _is_http(self) -> bool:
"""Check if this server uses HTTP transport."""
@@ -748,15 +749,28 @@ class MCPServerTask:
)
url = config["url"]
headers = config.get("headers")
headers = dict(config.get("headers") or {})
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
# OAuth 2.1 PKCE: build httpx.Auth handler using the MCP SDK
_oauth_auth = None
if self._auth_type == "oauth":
try:
from tools.mcp_oauth import build_oauth_auth
_oauth_auth = build_oauth_auth(self.name, url)
except Exception as exc:
logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
async with streamablehttp_client(
url,
headers=headers,
timeout=float(connect_timeout),
) as (read_stream, write_stream, _get_session_id):
_http_kwargs: dict = {
"headers": headers,
"timeout": float(connect_timeout),
}
if _oauth_auth is not None:
_http_kwargs["auth"] = _oauth_auth
async with streamablehttp_client(url, **_http_kwargs) as (
read_stream, write_stream, _get_session_id,
):
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
await session.initialize()
self.session = session
@@ -783,6 +797,7 @@ class MCPServerTask:
"""
self._config = config
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
self._auth_type = config.get("auth", "").lower().strip()
# Set up sampling handler if enabled and SDK types are available
sampling_config = config.get("sampling", {})
@@ -920,13 +935,30 @@ def _run_on_mcp_loop(coro, timeout: float = 30):
# Config loading
# ---------------------------------------------------------------------------
def _interpolate_env_vars(value):
"""Recursively resolve ``${VAR}`` placeholders from ``os.environ``."""
if isinstance(value, str):
import re
def _replace(m):
return os.environ.get(m.group(1), m.group(0))
return re.sub(r"\$\{([^}]+)\}", _replace, value)
if isinstance(value, dict):
return {k: _interpolate_env_vars(v) for k, v in value.items()}
if isinstance(value, list):
return [_interpolate_env_vars(v) for v in value]
return value
def _load_mcp_config() -> Dict[str, dict]:
"""Read ``mcp_servers`` from the Hermes config file.
Returns a dict of ``{server_name: server_config}`` or empty dict.
Server config can contain either ``command``/``args``/``env`` for stdio
transport or ``url``/``headers`` for HTTP transport, plus optional
``timeout`` and ``connect_timeout`` overrides.
``timeout``, ``connect_timeout``, and ``auth`` overrides.
``${ENV_VAR}`` placeholders in string values are resolved from
``os.environ`` (which includes ``~/.hermes/.env`` loaded at startup).
"""
try:
from hermes_cli.config import load_config
@@ -934,7 +966,13 @@ def _load_mcp_config() -> Dict[str, dict]:
servers = config.get("mcp_servers")
if not servers or not isinstance(servers, dict):
return {}
return servers
# Ensure .env vars are available for interpolation
try:
from hermes_cli.env_loader import load_hermes_dotenv
load_hermes_dotenv()
except Exception:
pass
return {name: _interpolate_env_vars(cfg) for name, cfg in servers.items()}
except Exception as exc:
logger.debug("Failed to load MCP config: %s", exc)
return {}