mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
Merge pull request #2465 from NousResearch/hermes/hermes-31d7db3b
feat(cli): MCP server management CLI + OAuth 2.1 PKCE auth
This commit is contained in:
235
tools/mcp_oauth.py
Normal file
235
tools/mcp_oauth.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""Thin OAuth adapter for MCP HTTP servers.
|
||||
|
||||
Wraps the MCP SDK's built-in ``OAuthClientProvider`` (which implements
|
||||
``httpx.Auth``) with Hermes-specific token storage and browser-based
|
||||
authorization. The SDK handles all of the heavy lifting: PKCE generation,
|
||||
metadata discovery, dynamic client registration, token exchange, and refresh.
|
||||
|
||||
Usage in mcp_tool.py::
|
||||
|
||||
from tools.mcp_oauth import build_oauth_auth
|
||||
auth = build_oauth_auth(server_name, server_url)
|
||||
# pass ``auth`` as the httpx auth parameter
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import webbrowser
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TOKEN_DIR_NAME = "mcp-tokens"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token storage — persists tokens + client info to ~/.hermes/mcp-tokens/
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class HermesTokenStorage:
|
||||
"""File-backed token storage implementing the MCP SDK's TokenStorage protocol."""
|
||||
|
||||
def __init__(self, server_name: str):
|
||||
self._server_name = server_name
|
||||
|
||||
def _base_dir(self) -> Path:
|
||||
home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
d = home / _TOKEN_DIR_NAME
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
return d
|
||||
|
||||
def _tokens_path(self) -> Path:
|
||||
return self._base_dir() / f"{self._server_name}.json"
|
||||
|
||||
def _client_path(self) -> Path:
|
||||
return self._base_dir() / f"{self._server_name}.client.json"
|
||||
|
||||
# -- TokenStorage protocol (async) --
|
||||
|
||||
async def get_tokens(self):
|
||||
data = self._read_json(self._tokens_path())
|
||||
if not data:
|
||||
return None
|
||||
try:
|
||||
from mcp.shared.auth import OAuthToken
|
||||
return OAuthToken(**data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def set_tokens(self, tokens) -> None:
|
||||
self._write_json(self._tokens_path(), tokens.model_dump(exclude_none=True))
|
||||
|
||||
async def get_client_info(self):
|
||||
data = self._read_json(self._client_path())
|
||||
if not data:
|
||||
return None
|
||||
try:
|
||||
from mcp.shared.auth import OAuthClientInformationFull
|
||||
return OAuthClientInformationFull(**data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def set_client_info(self, client_info) -> None:
|
||||
self._write_json(self._client_path(), client_info.model_dump(exclude_none=True))
|
||||
|
||||
# -- helpers --
|
||||
|
||||
@staticmethod
|
||||
def _read_json(path: Path) -> dict | None:
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _write_json(path: Path, data: dict) -> None:
|
||||
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
try:
|
||||
path.chmod(0o600)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def remove(self) -> None:
|
||||
"""Delete stored tokens and client info for this server."""
|
||||
for p in (self._tokens_path(), self._client_path()):
|
||||
try:
|
||||
p.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Browser-based callback handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
class _CallbackHandler(BaseHTTPRequestHandler):
|
||||
auth_code: str | None = None
|
||||
state: str | None = None
|
||||
|
||||
def do_GET(self):
|
||||
qs = parse_qs(urlparse(self.path).query)
|
||||
_CallbackHandler.auth_code = (qs.get("code") or [None])[0]
|
||||
_CallbackHandler.state = (qs.get("state") or [None])[0]
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"<html><body><h3>Authorization complete. You can close this tab.</h3></body></html>")
|
||||
|
||||
def log_message(self, *_args: Any) -> None:
|
||||
pass # suppress HTTP log noise
|
||||
|
||||
|
||||
async def _redirect_to_browser(auth_url: str) -> None:
|
||||
"""Open the authorization URL in the user's browser."""
|
||||
try:
|
||||
if _can_open_browser():
|
||||
webbrowser.open(auth_url)
|
||||
print(f" Opened browser for authorization...")
|
||||
else:
|
||||
print(f"\n Open this URL to authorize:\n {auth_url}\n")
|
||||
except Exception:
|
||||
print(f"\n Open this URL to authorize:\n {auth_url}\n")
|
||||
|
||||
|
||||
async def _wait_for_callback() -> tuple[str, str | None]:
|
||||
"""Start a local HTTP server and wait for the OAuth redirect callback."""
|
||||
port = _find_free_port()
|
||||
server = HTTPServer(("127.0.0.1", port), _CallbackHandler)
|
||||
_CallbackHandler.auth_code = None
|
||||
_CallbackHandler.state = None
|
||||
|
||||
def _serve():
|
||||
server.timeout = 120
|
||||
server.handle_request()
|
||||
|
||||
thread = threading.Thread(target=_serve, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Wait for the callback
|
||||
for _ in range(1200): # 120 seconds
|
||||
await asyncio.sleep(0.1)
|
||||
if _CallbackHandler.auth_code is not None:
|
||||
break
|
||||
|
||||
server.server_close()
|
||||
code = _CallbackHandler.auth_code or ""
|
||||
state = _CallbackHandler.state
|
||||
if not code:
|
||||
# Fallback to manual entry
|
||||
print(" Browser callback timed out. Paste the authorization code manually:")
|
||||
code = input(" Code: ").strip()
|
||||
return code, state
|
||||
|
||||
|
||||
def _can_open_browser() -> bool:
|
||||
if os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"):
|
||||
return False
|
||||
if not os.environ.get("DISPLAY") and os.name != "nt" and "darwin" not in os.uname().sysname.lower():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_oauth_auth(server_name: str, server_url: str):
|
||||
"""Build an ``httpx.Auth`` handler for the given MCP server using OAuth 2.1 PKCE.
|
||||
|
||||
Uses the MCP SDK's ``OAuthClientProvider`` which handles discovery,
|
||||
registration, PKCE, token exchange, and refresh automatically.
|
||||
|
||||
Returns an ``OAuthClientProvider`` instance (implements ``httpx.Auth``),
|
||||
or ``None`` if the MCP SDK auth module is not available.
|
||||
"""
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
from mcp.shared.auth import OAuthClientMetadata
|
||||
except ImportError:
|
||||
logger.warning("MCP SDK auth module not available — OAuth disabled")
|
||||
return None
|
||||
|
||||
port = _find_free_port()
|
||||
redirect_uri = f"http://127.0.0.1:{port}/callback"
|
||||
|
||||
client_metadata = OAuthClientMetadata(
|
||||
client_name="Hermes Agent",
|
||||
redirect_uris=[redirect_uri],
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
scope="openid profile email offline_access",
|
||||
token_endpoint_auth_method="none",
|
||||
)
|
||||
|
||||
storage = HermesTokenStorage(server_name)
|
||||
|
||||
return OAuthClientProvider(
|
||||
server_url=server_url,
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
redirect_handler=_redirect_to_browser,
|
||||
callback_handler=_wait_for_callback,
|
||||
timeout=120.0,
|
||||
)
|
||||
|
||||
|
||||
def remove_oauth_tokens(server_name: str) -> None:
|
||||
"""Delete stored OAuth tokens and client info for a server."""
|
||||
HermesTokenStorage(server_name).remove()
|
||||
@@ -690,7 +690,7 @@ class MCPServerTask:
|
||||
__slots__ = (
|
||||
"name", "session", "tool_timeout",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
|
||||
"_sampling", "_registered_tool_names",
|
||||
"_sampling", "_registered_tool_names", "_auth_type",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
@@ -705,6 +705,7 @@ class MCPServerTask:
|
||||
self._config: dict = {}
|
||||
self._sampling: Optional[SamplingHandler] = None
|
||||
self._registered_tool_names: list[str] = []
|
||||
self._auth_type: str = ""
|
||||
|
||||
def _is_http(self) -> bool:
|
||||
"""Check if this server uses HTTP transport."""
|
||||
@@ -748,15 +749,28 @@ class MCPServerTask:
|
||||
)
|
||||
|
||||
url = config["url"]
|
||||
headers = config.get("headers")
|
||||
headers = dict(config.get("headers") or {})
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
|
||||
# OAuth 2.1 PKCE: build httpx.Auth handler using the MCP SDK
|
||||
_oauth_auth = None
|
||||
if self._auth_type == "oauth":
|
||||
try:
|
||||
from tools.mcp_oauth import build_oauth_auth
|
||||
_oauth_auth = build_oauth_auth(self.name, url)
|
||||
except Exception as exc:
|
||||
logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
|
||||
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
async with streamablehttp_client(
|
||||
url,
|
||||
headers=headers,
|
||||
timeout=float(connect_timeout),
|
||||
) as (read_stream, write_stream, _get_session_id):
|
||||
_http_kwargs: dict = {
|
||||
"headers": headers,
|
||||
"timeout": float(connect_timeout),
|
||||
}
|
||||
if _oauth_auth is not None:
|
||||
_http_kwargs["auth"] = _oauth_auth
|
||||
async with streamablehttp_client(url, **_http_kwargs) as (
|
||||
read_stream, write_stream, _get_session_id,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
@@ -783,6 +797,7 @@ class MCPServerTask:
|
||||
"""
|
||||
self._config = config
|
||||
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
|
||||
self._auth_type = config.get("auth", "").lower().strip()
|
||||
|
||||
# Set up sampling handler if enabled and SDK types are available
|
||||
sampling_config = config.get("sampling", {})
|
||||
@@ -920,13 +935,30 @@ def _run_on_mcp_loop(coro, timeout: float = 30):
|
||||
# Config loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _interpolate_env_vars(value):
|
||||
"""Recursively resolve ``${VAR}`` placeholders from ``os.environ``."""
|
||||
if isinstance(value, str):
|
||||
import re
|
||||
def _replace(m):
|
||||
return os.environ.get(m.group(1), m.group(0))
|
||||
return re.sub(r"\$\{([^}]+)\}", _replace, value)
|
||||
if isinstance(value, dict):
|
||||
return {k: _interpolate_env_vars(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_interpolate_env_vars(v) for v in value]
|
||||
return value
|
||||
|
||||
|
||||
def _load_mcp_config() -> Dict[str, dict]:
|
||||
"""Read ``mcp_servers`` from the Hermes config file.
|
||||
|
||||
Returns a dict of ``{server_name: server_config}`` or empty dict.
|
||||
Server config can contain either ``command``/``args``/``env`` for stdio
|
||||
transport or ``url``/``headers`` for HTTP transport, plus optional
|
||||
``timeout`` and ``connect_timeout`` overrides.
|
||||
``timeout``, ``connect_timeout``, and ``auth`` overrides.
|
||||
|
||||
``${ENV_VAR}`` placeholders in string values are resolved from
|
||||
``os.environ`` (which includes ``~/.hermes/.env`` loaded at startup).
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
@@ -934,7 +966,13 @@ def _load_mcp_config() -> Dict[str, dict]:
|
||||
servers = config.get("mcp_servers")
|
||||
if not servers or not isinstance(servers, dict):
|
||||
return {}
|
||||
return servers
|
||||
# Ensure .env vars are available for interpolation
|
||||
try:
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
load_hermes_dotenv()
|
||||
except Exception:
|
||||
pass
|
||||
return {name: _interpolate_env_vars(cfg) for name, cfg in servers.items()}
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to load MCP config: %s", exc)
|
||||
return {}
|
||||
|
||||
Reference in New Issue
Block a user