mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-30 07:51:45 +08:00
Compare commits
2 Commits
fix/plugin
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
804b961c80 | ||
|
|
ef6455238a |
@@ -134,6 +134,92 @@ class TestUtilities:
|
|||||||
# remove_oauth_tokens
|
# remove_oauth_tokens
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPathTraversal:
|
||||||
|
"""Verify server_name is sanitized to prevent path traversal."""
|
||||||
|
|
||||||
|
def test_path_traversal_blocked(self, tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
storage = HermesTokenStorage("../../.ssh/config")
|
||||||
|
path = storage._tokens_path()
|
||||||
|
# Should stay within mcp-tokens directory
|
||||||
|
assert "mcp-tokens" in str(path)
|
||||||
|
assert ".ssh" not in str(path.resolve())
|
||||||
|
|
||||||
|
def test_dots_and_slashes_sanitized(self, tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
storage = HermesTokenStorage("../../../etc/passwd")
|
||||||
|
path = storage._tokens_path()
|
||||||
|
resolved = path.resolve()
|
||||||
|
assert resolved.is_relative_to((tmp_path / "mcp-tokens").resolve())
|
||||||
|
|
||||||
|
def test_normal_name_unchanged(self, tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
storage = HermesTokenStorage("my-mcp-server")
|
||||||
|
assert "my-mcp-server.json" in str(storage._tokens_path())
|
||||||
|
|
||||||
|
def test_special_chars_sanitized(self, tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
storage = HermesTokenStorage("server@host:8080/path")
|
||||||
|
path = storage._tokens_path()
|
||||||
|
assert "@" not in path.name
|
||||||
|
assert ":" not in path.name
|
||||||
|
assert "/" not in path.stem
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallbackHandlerIsolation:
|
||||||
|
"""Verify concurrent OAuth flows don't share state."""
|
||||||
|
|
||||||
|
def test_independent_result_dicts(self):
|
||||||
|
from tools.mcp_oauth import _make_callback_handler
|
||||||
|
_, result_a = _make_callback_handler()
|
||||||
|
_, result_b = _make_callback_handler()
|
||||||
|
|
||||||
|
result_a["auth_code"] = "code_A"
|
||||||
|
result_b["auth_code"] = "code_B"
|
||||||
|
|
||||||
|
assert result_a["auth_code"] == "code_A"
|
||||||
|
assert result_b["auth_code"] == "code_B"
|
||||||
|
|
||||||
|
def test_handler_writes_to_own_result(self):
|
||||||
|
from tools.mcp_oauth import _make_callback_handler
|
||||||
|
from io import BytesIO
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
HandlerClass, result = _make_callback_handler()
|
||||||
|
assert result["auth_code"] is None
|
||||||
|
|
||||||
|
# Simulate a GET request
|
||||||
|
handler = HandlerClass.__new__(HandlerClass)
|
||||||
|
handler.path = "/callback?code=test123&state=mystate"
|
||||||
|
handler.wfile = BytesIO()
|
||||||
|
handler.send_response = MagicMock()
|
||||||
|
handler.send_header = MagicMock()
|
||||||
|
handler.end_headers = MagicMock()
|
||||||
|
handler.do_GET()
|
||||||
|
|
||||||
|
assert result["auth_code"] == "test123"
|
||||||
|
assert result["state"] == "mystate"
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthPortSharing:
|
||||||
|
"""Verify build_oauth_auth and _wait_for_callback use the same port."""
|
||||||
|
|
||||||
|
def test_port_stored_globally(self):
|
||||||
|
import tools.mcp_oauth as mod
|
||||||
|
# Reset
|
||||||
|
mod._oauth_port = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mcp.client.auth import OAuthClientProvider
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("MCP SDK auth not available")
|
||||||
|
|
||||||
|
build_oauth_auth("test-port", "https://example.com/mcp")
|
||||||
|
assert mod._oauth_port is not None
|
||||||
|
assert isinstance(mod._oauth_port, int)
|
||||||
|
assert 1024 <= mod._oauth_port <= 65535
|
||||||
|
|
||||||
|
|
||||||
class TestRemoveOAuthTokens:
|
class TestRemoveOAuthTokens:
|
||||||
def test_removes_files(self, tmp_path, monkeypatch):
|
def test_removes_files(self, tmp_path, monkeypatch):
|
||||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
|||||||
@@ -35,11 +35,19 @@ _TOKEN_DIR_NAME = "mcp-tokens"
|
|||||||
# Token storage — persists tokens + client info to ~/.hermes/mcp-tokens/
|
# Token storage — persists tokens + client info to ~/.hermes/mcp-tokens/
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _sanitize_server_name(name: str) -> str:
|
||||||
|
"""Sanitize server name for safe use as a filename."""
|
||||||
|
import re
|
||||||
|
clean = re.sub(r"[^\w\-]", "-", name.strip().lower())
|
||||||
|
clean = re.sub(r"-+", "-", clean).strip("-")
|
||||||
|
return clean[:60] or "unnamed"
|
||||||
|
|
||||||
|
|
||||||
class HermesTokenStorage:
|
class HermesTokenStorage:
|
||||||
"""File-backed token storage implementing the MCP SDK's TokenStorage protocol."""
|
"""File-backed token storage implementing the MCP SDK's TokenStorage protocol."""
|
||||||
|
|
||||||
def __init__(self, server_name: str):
|
def __init__(self, server_name: str):
|
||||||
self._server_name = server_name
|
self._server_name = _sanitize_server_name(server_name)
|
||||||
|
|
||||||
def _base_dir(self) -> Path:
|
def _base_dir(self) -> Path:
|
||||||
home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||||
@@ -119,21 +127,28 @@ def _find_free_port() -> int:
|
|||||||
return s.getsockname()[1]
|
return s.getsockname()[1]
|
||||||
|
|
||||||
|
|
||||||
class _CallbackHandler(BaseHTTPRequestHandler):
|
def _make_callback_handler():
|
||||||
auth_code: str | None = None
|
"""Create a callback handler class with instance-scoped result storage."""
|
||||||
state: str | None = None
|
result = {"auth_code": None, "state": None}
|
||||||
|
|
||||||
|
class Handler(BaseHTTPRequestHandler):
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
qs = parse_qs(urlparse(self.path).query)
|
qs = parse_qs(urlparse(self.path).query)
|
||||||
_CallbackHandler.auth_code = (qs.get("code") or [None])[0]
|
result["auth_code"] = (qs.get("code") or [None])[0]
|
||||||
_CallbackHandler.state = (qs.get("state") or [None])[0]
|
result["state"] = (qs.get("state") or [None])[0]
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header("Content-Type", "text/html")
|
self.send_header("Content-Type", "text/html")
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
self.wfile.write(b"<html><body><h3>Authorization complete. You can close this tab.</h3></body></html>")
|
self.wfile.write(b"<html><body><h3>Authorization complete. You can close this tab.</h3></body></html>")
|
||||||
|
|
||||||
def log_message(self, *_args: Any) -> None:
|
def log_message(self, *_args: Any) -> None:
|
||||||
pass # suppress HTTP log noise
|
pass
|
||||||
|
|
||||||
|
return Handler, result
|
||||||
|
|
||||||
|
|
||||||
|
# Port chosen at build time and shared with the callback handler via closure.
|
||||||
|
_oauth_port: int | None = None
|
||||||
|
|
||||||
|
|
||||||
async def _redirect_to_browser(auth_url: str) -> None:
|
async def _redirect_to_browser(auth_url: str) -> None:
|
||||||
@@ -149,11 +164,11 @@ async def _redirect_to_browser(auth_url: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def _wait_for_callback() -> tuple[str, str | None]:
|
async def _wait_for_callback() -> tuple[str, str | None]:
|
||||||
"""Start a local HTTP server and wait for the OAuth redirect callback."""
|
"""Start a local HTTP server on the pre-registered port and wait for the OAuth redirect."""
|
||||||
port = _find_free_port()
|
global _oauth_port
|
||||||
server = HTTPServer(("127.0.0.1", port), _CallbackHandler)
|
port = _oauth_port or _find_free_port()
|
||||||
_CallbackHandler.auth_code = None
|
HandlerClass, result = _make_callback_handler()
|
||||||
_CallbackHandler.state = None
|
server = HTTPServer(("127.0.0.1", port), HandlerClass)
|
||||||
|
|
||||||
def _serve():
|
def _serve():
|
||||||
server.timeout = 120
|
server.timeout = 120
|
||||||
@@ -162,17 +177,15 @@ async def _wait_for_callback() -> tuple[str, str | None]:
|
|||||||
thread = threading.Thread(target=_serve, daemon=True)
|
thread = threading.Thread(target=_serve, daemon=True)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
# Wait for the callback
|
|
||||||
for _ in range(1200): # 120 seconds
|
for _ in range(1200): # 120 seconds
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
if _CallbackHandler.auth_code is not None:
|
if result["auth_code"] is not None:
|
||||||
break
|
break
|
||||||
|
|
||||||
server.server_close()
|
server.server_close()
|
||||||
code = _CallbackHandler.auth_code or ""
|
code = result["auth_code"] or ""
|
||||||
state = _CallbackHandler.state
|
state = result["state"]
|
||||||
if not code:
|
if not code:
|
||||||
# Fallback to manual entry
|
|
||||||
print(" Browser callback timed out. Paste the authorization code manually:")
|
print(" Browser callback timed out. Paste the authorization code manually:")
|
||||||
code = input(" Code: ").strip()
|
code = input(" Code: ").strip()
|
||||||
return code, state
|
return code, state
|
||||||
@@ -206,8 +219,9 @@ def build_oauth_auth(server_name: str, server_url: str):
|
|||||||
logger.warning("MCP SDK auth module not available — OAuth disabled")
|
logger.warning("MCP SDK auth module not available — OAuth disabled")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
port = _find_free_port()
|
global _oauth_port
|
||||||
redirect_uri = f"http://127.0.0.1:{port}/callback"
|
_oauth_port = _find_free_port()
|
||||||
|
redirect_uri = f"http://127.0.0.1:{_oauth_port}/callback"
|
||||||
|
|
||||||
client_metadata = OAuthClientMetadata(
|
client_metadata = OAuthClientMetadata(
|
||||||
client_name="Hermes Agent",
|
client_name="Hermes Agent",
|
||||||
|
|||||||
Reference in New Issue
Block a user