mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-29 15:31:38 +08:00
Compare commits
2 Commits
fix/plugin
...
fix/mcp-oa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc06beaf13 | ||
|
|
3eeab4bc06 |
210
tests/tools/test_mcp_oauth_bidirectional.py
Normal file
210
tests/tools/test_mcp_oauth_bidirectional.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""Regression test for the ``HermesMCPOAuthProvider.async_auth_flow`` bidirectional
|
||||||
|
generator bridge.
|
||||||
|
|
||||||
|
PR #11383 introduced a subclass method that wrapped the SDK's ``auth_flow`` with::
|
||||||
|
|
||||||
|
async for item in super().async_auth_flow(request):
|
||||||
|
yield item
|
||||||
|
|
||||||
|
``httpx``'s auth_flow contract is a **bidirectional** async generator — the
|
||||||
|
driving code (``httpx._client._send_handling_auth``) does::
|
||||||
|
|
||||||
|
next_request = await auth_flow.asend(response)
|
||||||
|
|
||||||
|
to feed HTTP responses back into the generator. The naive ``async for ...``
|
||||||
|
wrapper discards those ``.asend(response)`` values and resumes the inner
|
||||||
|
generator with ``None``, so the SDK's ``response = yield request`` branch in
|
||||||
|
``mcp/client/auth/oauth2.py`` sees ``response = None`` and crashes at
|
||||||
|
``if response.status_code == 401`` with
|
||||||
|
``AttributeError: 'NoneType' object has no attribute 'status_code'``.
|
||||||
|
|
||||||
|
This broke every OAuth MCP server on the first HTTP response regardless of
|
||||||
|
status code. The reason nothing caught it in CI: zero existing tests drive
|
||||||
|
the full ``.asend()`` round-trip — the integration tests in
|
||||||
|
``test_mcp_oauth_integration.py`` stop at ``_initialize()`` and disk-watching.
|
||||||
|
|
||||||
|
These tests drive the wrapper through a manual ``.asend()`` sequence to prove
|
||||||
|
the bridge forwards responses correctly into the inner SDK generator.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
pytest.importorskip("mcp.client.auth.oauth2", reason="MCP SDK 1.26.0+ required")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hermes_provider_forwards_asend_values(tmp_path, monkeypatch):
|
||||||
|
"""The wrapper MUST forward ``.asend(response)`` into the inner generator.
|
||||||
|
|
||||||
|
This is the primary regression test. With the broken wrapper, the inner
|
||||||
|
SDK generator sees ``response = None`` and raises ``AttributeError`` at
|
||||||
|
``oauth2.py:505``. With the correct bridge, a 200 response finishes the
|
||||||
|
flow cleanly (``StopAsyncIteration``).
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
from mcp.shared.auth import OAuthClientMetadata, OAuthToken
|
||||||
|
from pydantic import AnyUrl
|
||||||
|
|
||||||
|
from tools.mcp_oauth import HermesTokenStorage
|
||||||
|
from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests
|
||||||
|
|
||||||
|
assert _HERMES_PROVIDER_CLS is not None, "SDK OAuth types must be available"
|
||||||
|
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
reset_manager_for_tests()
|
||||||
|
|
||||||
|
# Seed a valid-looking token so the SDK's _initialize loads something and
|
||||||
|
# can_refresh_token() is True (though we don't exercise refresh here — we
|
||||||
|
# go straight through the 200 path).
|
||||||
|
storage = HermesTokenStorage("srv")
|
||||||
|
await storage.set_tokens(
|
||||||
|
OAuthToken(
|
||||||
|
access_token="old_access",
|
||||||
|
token_type="Bearer",
|
||||||
|
expires_in=3600,
|
||||||
|
refresh_token="old_refresh",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Also seed client_info so the SDK doesn't attempt registration.
|
||||||
|
from mcp.shared.auth import OAuthClientInformationFull
|
||||||
|
|
||||||
|
await storage.set_client_info(
|
||||||
|
OAuthClientInformationFull(
|
||||||
|
client_id="test-client",
|
||||||
|
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
|
||||||
|
grant_types=["authorization_code", "refresh_token"],
|
||||||
|
response_types=["code"],
|
||||||
|
token_endpoint_auth_method="none",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = OAuthClientMetadata(
|
||||||
|
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
|
||||||
|
client_name="Hermes Agent",
|
||||||
|
)
|
||||||
|
provider = _HERMES_PROVIDER_CLS(
|
||||||
|
server_name="srv",
|
||||||
|
server_url="https://example.com/mcp",
|
||||||
|
client_metadata=metadata,
|
||||||
|
storage=storage,
|
||||||
|
redirect_handler=_noop_redirect,
|
||||||
|
callback_handler=_noop_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
req = httpx.Request("POST", "https://example.com/mcp")
|
||||||
|
flow = provider.async_auth_flow(req)
|
||||||
|
|
||||||
|
# First anext() drives the wrapper + inner generator until the inner
|
||||||
|
# yields the outbound request (at oauth2.py:503 ``response = yield request``).
|
||||||
|
outbound = await flow.__anext__()
|
||||||
|
assert outbound is not None, "wrapper must yield the outbound request"
|
||||||
|
assert outbound.url.host == "example.com"
|
||||||
|
|
||||||
|
# Simulate httpx returning a 200 response.
|
||||||
|
fake_response = httpx.Response(200, request=outbound)
|
||||||
|
|
||||||
|
# The broken wrapper would crash here with AttributeError: 'NoneType'
|
||||||
|
# object has no attribute 'status_code', because the SDK's inner generator
|
||||||
|
# resumes with response=None and dereferences .status_code at line 505.
|
||||||
|
#
|
||||||
|
# The correct wrapper forwards the response, the SDK takes the non-401
|
||||||
|
# non-403 exit, and the generator ends cleanly (StopAsyncIteration).
|
||||||
|
with pytest.raises(StopAsyncIteration):
|
||||||
|
await flow.asend(fake_response)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hermes_provider_forwards_401_triggers_refresh(tmp_path, monkeypatch):
|
||||||
|
"""A 401 response MUST flow into the inner generator and trigger the
|
||||||
|
SDK's 401 recovery branch.
|
||||||
|
|
||||||
|
With the broken wrapper, the inner generator sees ``response = None``
|
||||||
|
and the 401 check short-circuits into AttributeError. With the correct
|
||||||
|
bridge, the 401 is routed into the SDK's ``response.status_code == 401``
|
||||||
|
branch which begins discovery (yielding a metadata-discovery request).
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
||||||
|
from pydantic import AnyUrl
|
||||||
|
|
||||||
|
from tools.mcp_oauth import HermesTokenStorage
|
||||||
|
from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests
|
||||||
|
|
||||||
|
assert _HERMES_PROVIDER_CLS is not None
|
||||||
|
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
reset_manager_for_tests()
|
||||||
|
|
||||||
|
storage = HermesTokenStorage("srv")
|
||||||
|
await storage.set_tokens(
|
||||||
|
OAuthToken(
|
||||||
|
access_token="old_access",
|
||||||
|
token_type="Bearer",
|
||||||
|
expires_in=3600,
|
||||||
|
refresh_token="old_refresh",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await storage.set_client_info(
|
||||||
|
OAuthClientInformationFull(
|
||||||
|
client_id="test-client",
|
||||||
|
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
|
||||||
|
grant_types=["authorization_code", "refresh_token"],
|
||||||
|
response_types=["code"],
|
||||||
|
token_endpoint_auth_method="none",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = OAuthClientMetadata(
|
||||||
|
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
|
||||||
|
client_name="Hermes Agent",
|
||||||
|
)
|
||||||
|
provider = _HERMES_PROVIDER_CLS(
|
||||||
|
server_name="srv",
|
||||||
|
server_url="https://example.com/mcp",
|
||||||
|
client_metadata=metadata,
|
||||||
|
storage=storage,
|
||||||
|
redirect_handler=_noop_redirect,
|
||||||
|
callback_handler=_noop_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
req = httpx.Request("POST", "https://example.com/mcp")
|
||||||
|
flow = provider.async_auth_flow(req)
|
||||||
|
|
||||||
|
# Drive to the first yield (outbound MCP request).
|
||||||
|
outbound = await flow.__anext__()
|
||||||
|
|
||||||
|
# Reply with a 401 including a minimal WWW-Authenticate so the SDK's
|
||||||
|
# 401 branch can parse resource metadata from it. We just need something
|
||||||
|
# the SDK accepts before it tries to yield the metadata-discovery request.
|
||||||
|
fake_401 = httpx.Response(
|
||||||
|
401,
|
||||||
|
request=outbound,
|
||||||
|
headers={"www-authenticate": 'Bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource"'},
|
||||||
|
)
|
||||||
|
|
||||||
|
# The correct bridge forwards the 401 into the SDK; the SDK then yields
|
||||||
|
# its NEXT request (a metadata-discovery GET). We assert we get a request
|
||||||
|
# back — any request. The broken bridge would have crashed with
|
||||||
|
# AttributeError before we ever reach this point.
|
||||||
|
next_request = await flow.asend(fake_401)
|
||||||
|
assert isinstance(next_request, httpx.Request), (
|
||||||
|
"wrapper must forward .asend() so the SDK's 401 branch can yield the "
|
||||||
|
"next request in the discovery flow"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up the generator — we don't need to complete the full dance.
|
||||||
|
await flow.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
async def _noop_redirect(_url: str) -> None:
|
||||||
|
"""Redirect handler that does nothing (won't be invoked in these tests)."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _noop_callback() -> tuple[str, str | None]:
|
||||||
|
"""Callback handler that won't be invoked in these tests."""
|
||||||
|
raise AssertionError(
|
||||||
|
"callback handler should not be invoked in bidirectional-generator tests"
|
||||||
|
)
|
||||||
546
tests/tools/test_mcp_oauth_cold_load_expiry.py
Normal file
546
tests/tools/test_mcp_oauth_cold_load_expiry.py
Normal file
@@ -0,0 +1,546 @@
|
|||||||
|
"""Tests for cold-load token expiry tracking in MCP OAuth.
|
||||||
|
|
||||||
|
PR #11383's consolidation fixed external-refresh reloading (mtime disk-watch)
|
||||||
|
and 401 dedup, but left two underlying latent bugs in place:
|
||||||
|
|
||||||
|
1. ``HermesTokenStorage.set_tokens`` persisted only relative ``expires_in``,
|
||||||
|
which is meaningless after a process restart.
|
||||||
|
2. The MCP SDK's ``OAuthContext._initialize`` loads ``current_tokens`` from
|
||||||
|
storage but does NOT call ``update_token_expiry``, so
|
||||||
|
``token_expiry_time`` stays None. ``is_token_valid()`` then returns True
|
||||||
|
for any loaded token regardless of actual age, and the SDK's preemptive
|
||||||
|
refresh branch at ``oauth2.py:491`` is never taken.
|
||||||
|
|
||||||
|
Consequence: a token that expired while the process was down ships to the
|
||||||
|
server with a stale Bearer header. The server's response is provider-specific
|
||||||
|
— some return HTTP 401 (caught by the consolidation's 401 handler, which
|
||||||
|
surfaces a ``needs_reauth`` error), others return HTTP 200 with an
|
||||||
|
application-level auth failure in the body (e.g. BetterStack's "No teams
|
||||||
|
found. Please check your authentication."), which the consolidation cannot
|
||||||
|
detect.
|
||||||
|
|
||||||
|
These tests pin the contract for Fix A:
|
||||||
|
- ``set_tokens`` persists an absolute ``expires_at`` wall-clock timestamp.
|
||||||
|
- ``get_tokens`` reconstructs ``expires_in`` from ``expires_at - now`` so
|
||||||
|
the SDK's ``update_token_expiry`` computes the correct absolute expiry.
|
||||||
|
- ``HermesMCPOAuthProvider._initialize`` seeds ``context.token_expiry_time``
|
||||||
|
after loading, so ``is_token_valid()`` reports True only for tokens that
|
||||||
|
are actually still valid, and the SDK's preemptive refresh fires for
|
||||||
|
expired tokens with a live refresh_token.
|
||||||
|
|
||||||
|
Reference: Claude Code solves this via an ``OAuthTokens.expiresAt`` absolute
|
||||||
|
timestamp persisted alongside the access_token (``auth.ts:~180``).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
pytest.importorskip("mcp.client.auth.oauth2", reason="MCP SDK 1.26.0+ required")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HermesTokenStorage — absolute expiry persistence
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetTokensAbsoluteExpiry:
|
||||||
|
def test_set_tokens_persists_absolute_expires_at(self, tmp_path, monkeypatch):
|
||||||
|
"""Tokens round-tripped through disk must encode absolute expiry."""
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
from mcp.shared.auth import OAuthToken
|
||||||
|
|
||||||
|
from tools.mcp_oauth import HermesTokenStorage
|
||||||
|
|
||||||
|
storage = HermesTokenStorage("srv")
|
||||||
|
before = time.time()
|
||||||
|
asyncio.run(
|
||||||
|
storage.set_tokens(
|
||||||
|
OAuthToken(
|
||||||
|
access_token="a",
|
||||||
|
token_type="Bearer",
|
||||||
|
expires_in=3600,
|
||||||
|
refresh_token="r",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
after = time.time()
|
||||||
|
|
||||||
|
on_disk = json.loads(
|
||||||
|
(tmp_path / "mcp-tokens" / "srv.json").read_text()
|
||||||
|
)
|
||||||
|
assert "expires_at" in on_disk, (
|
||||||
|
"Fix A: set_tokens must record an absolute expires_at wall-clock "
|
||||||
|
"timestamp alongside the SDK's serialized token so cold-loads "
|
||||||
|
"can compute correct remaining TTL."
|
||||||
|
)
|
||||||
|
assert before + 3600 <= on_disk["expires_at"] <= after + 3600
|
||||||
|
|
||||||
|
def test_set_tokens_without_expires_in_omits_expires_at(
|
||||||
|
self, tmp_path, monkeypatch
|
||||||
|
):
|
||||||
|
"""Tokens without a TTL must not gain a fabricated expires_at."""
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
from mcp.shared.auth import OAuthToken
|
||||||
|
|
||||||
|
from tools.mcp_oauth import HermesTokenStorage
|
||||||
|
|
||||||
|
storage = HermesTokenStorage("srv")
|
||||||
|
asyncio.run(
|
||||||
|
storage.set_tokens(
|
||||||
|
OAuthToken(
|
||||||
|
access_token="a",
|
||||||
|
token_type="Bearer",
|
||||||
|
refresh_token="r",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
on_disk = json.loads(
|
||||||
|
(tmp_path / "mcp-tokens" / "srv.json").read_text()
|
||||||
|
)
|
||||||
|
assert "expires_at" not in on_disk
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetTokensReconstructsExpiresIn:
|
||||||
|
def test_get_tokens_uses_expires_at_for_remaining_ttl(
|
||||||
|
self, tmp_path, monkeypatch
|
||||||
|
):
|
||||||
|
"""Round-trip: expires_in on read must reflect time remaining."""
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
from mcp.shared.auth import OAuthToken
|
||||||
|
|
||||||
|
from tools.mcp_oauth import HermesTokenStorage
|
||||||
|
|
||||||
|
storage = HermesTokenStorage("srv")
|
||||||
|
asyncio.run(
|
||||||
|
storage.set_tokens(
|
||||||
|
OAuthToken(
|
||||||
|
access_token="a",
|
||||||
|
token_type="Bearer",
|
||||||
|
expires_in=3600,
|
||||||
|
refresh_token="r",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait briefly so the remaining TTL is measurably less than 3600.
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
reloaded = asyncio.run(storage.get_tokens())
|
||||||
|
assert reloaded is not None
|
||||||
|
assert reloaded.expires_in is not None
|
||||||
|
# Should be slightly less than 3600 after the 50ms sleep.
|
||||||
|
assert 3500 < reloaded.expires_in <= 3600
|
||||||
|
|
||||||
|
def test_get_tokens_returns_zero_ttl_for_expired_token(
|
||||||
|
self, tmp_path, monkeypatch
|
||||||
|
):
|
||||||
|
"""An already-expired token reloaded from disk must report expires_in=0."""
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
from tools.mcp_oauth import HermesTokenStorage, _get_token_dir
|
||||||
|
|
||||||
|
token_dir = _get_token_dir()
|
||||||
|
token_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
# Write an already-expired token file directly.
|
||||||
|
(token_dir / "srv.json").write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"access_token": "a",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"expires_at": time.time() - 60, # expired 1 min ago
|
||||||
|
"refresh_token": "r",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
storage = HermesTokenStorage("srv")
|
||||||
|
reloaded = asyncio.run(storage.get_tokens())
|
||||||
|
assert reloaded is not None
|
||||||
|
assert reloaded.expires_in == 0, (
|
||||||
|
"Expired token must reload with expires_in=0 so the SDK's "
|
||||||
|
"is_token_valid() returns False and preemptive refresh fires."
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_tokens_legacy_file_without_expires_at_is_loadable(
|
||||||
|
self, tmp_path, monkeypatch
|
||||||
|
):
|
||||||
|
"""Existing on-disk files (pre-Fix-A) must still load without crashing.
|
||||||
|
|
||||||
|
Pre-existing token files have ``expires_in`` but no ``expires_at``.
|
||||||
|
Fix A falls back to the file's mtime as a best-effort wall-clock
|
||||||
|
proxy: a file whose (mtime + expires_in) is in the past clamps
|
||||||
|
expires_in to zero so the SDK refreshes on next request. A fresh
|
||||||
|
legacy-format file (mtime = now) keeps most of its TTL.
|
||||||
|
"""
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
from tools.mcp_oauth import HermesTokenStorage, _get_token_dir
|
||||||
|
|
||||||
|
token_dir = _get_token_dir()
|
||||||
|
token_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
# Legacy-shape file (no expires_at). Make it stale by backdating mtime
|
||||||
|
# well past its nominal expires_in.
|
||||||
|
legacy_path = token_dir / "srv.json"
|
||||||
|
legacy_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"access_token": "a",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"refresh_token": "r",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
stale_time = time.time() - 7200 # 2hr ago, exceeds 3600s TTL
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.utime(legacy_path, (stale_time, stale_time))
|
||||||
|
|
||||||
|
storage = HermesTokenStorage("srv")
|
||||||
|
reloaded = asyncio.run(storage.get_tokens())
|
||||||
|
assert reloaded is not None
|
||||||
|
assert reloaded.expires_in == 0, (
|
||||||
|
"Legacy file whose mtime + expires_in is in the past must report "
|
||||||
|
"expires_in=0 so the SDK refreshes on next request."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HermesMCPOAuthProvider._initialize — seed token_expiry_time
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialize_seeds_token_expiry_time_from_stored_tokens(
|
||||||
|
tmp_path, monkeypatch
|
||||||
|
):
|
||||||
|
"""Cold-load must populate context.token_expiry_time.
|
||||||
|
|
||||||
|
The SDK's base ``_initialize`` loads current_tokens but doesn't seed
|
||||||
|
token_expiry_time. Our subclass must do it so ``is_token_valid()``
|
||||||
|
reports correctly and the preemptive-refresh path fires when needed.
|
||||||
|
"""
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||||
|
from pydantic import AnyUrl
|
||||||
|
|
||||||
|
from tools.mcp_oauth import HermesTokenStorage
|
||||||
|
from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests
|
||||||
|
|
||||||
|
assert _HERMES_PROVIDER_CLS is not None
|
||||||
|
reset_manager_for_tests()
|
||||||
|
|
||||||
|
storage = HermesTokenStorage("srv")
|
||||||
|
await storage.set_tokens(
|
||||||
|
OAuthToken(
|
||||||
|
access_token="a",
|
||||||
|
token_type="Bearer",
|
||||||
|
expires_in=7200,
|
||||||
|
refresh_token="r",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await storage.set_client_info(
|
||||||
|
OAuthClientInformationFull(
|
||||||
|
client_id="test-client",
|
||||||
|
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
|
||||||
|
grant_types=["authorization_code", "refresh_token"],
|
||||||
|
response_types=["code"],
|
||||||
|
token_endpoint_auth_method="none",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
from mcp.shared.auth import OAuthClientMetadata
|
||||||
|
|
||||||
|
metadata = OAuthClientMetadata(
|
||||||
|
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
|
||||||
|
client_name="Hermes Agent",
|
||||||
|
)
|
||||||
|
provider = _HERMES_PROVIDER_CLS(
|
||||||
|
server_name="srv",
|
||||||
|
server_url="https://example.com/mcp",
|
||||||
|
client_metadata=metadata,
|
||||||
|
storage=storage,
|
||||||
|
redirect_handler=_noop_redirect,
|
||||||
|
callback_handler=_noop_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
await provider._initialize()
|
||||||
|
|
||||||
|
assert provider.context.token_expiry_time is not None, (
|
||||||
|
"Fix A: _initialize must seed context.token_expiry_time so "
|
||||||
|
"is_token_valid() correctly reports expiry on cold-load."
|
||||||
|
)
|
||||||
|
# Should be ~7200s in the future (fresh write).
|
||||||
|
assert provider.context.token_expiry_time > time.time() + 7000
|
||||||
|
assert provider.context.token_expiry_time <= time.time() + 7200 + 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialize_flags_expired_token_as_invalid(tmp_path, monkeypatch):
|
||||||
|
"""After _initialize, an expired-on-disk token must report is_token_valid=False.
|
||||||
|
|
||||||
|
This is the end-to-end assertion: cold-load an expired token, verify the
|
||||||
|
SDK's own ``is_token_valid()`` now returns False (the consequence of
|
||||||
|
seeding token_expiry_time correctly), so the SDK's ``async_auth_flow``
|
||||||
|
will take the ``can_refresh_token()`` branch on the next request and
|
||||||
|
silently refresh instead of sending the stale Bearer.
|
||||||
|
"""
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
|
||||||
|
from pydantic import AnyUrl
|
||||||
|
|
||||||
|
from tools.mcp_oauth import HermesTokenStorage, _get_token_dir
|
||||||
|
from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests
|
||||||
|
|
||||||
|
assert _HERMES_PROVIDER_CLS is not None
|
||||||
|
reset_manager_for_tests()
|
||||||
|
|
||||||
|
# Write an already-expired token directly so we control the wall-clock.
|
||||||
|
token_dir = _get_token_dir()
|
||||||
|
token_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
(token_dir / "srv.json").write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"access_token": "stale",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"expires_at": time.time() - 60,
|
||||||
|
"refresh_token": "fresh",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
storage = HermesTokenStorage("srv")
|
||||||
|
await storage.set_client_info(
|
||||||
|
OAuthClientInformationFull(
|
||||||
|
client_id="test-client",
|
||||||
|
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
|
||||||
|
grant_types=["authorization_code", "refresh_token"],
|
||||||
|
response_types=["code"],
|
||||||
|
token_endpoint_auth_method="none",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = OAuthClientMetadata(
|
||||||
|
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
|
||||||
|
client_name="Hermes Agent",
|
||||||
|
)
|
||||||
|
provider = _HERMES_PROVIDER_CLS(
|
||||||
|
server_name="srv",
|
||||||
|
server_url="https://example.com/mcp",
|
||||||
|
client_metadata=metadata,
|
||||||
|
storage=storage,
|
||||||
|
redirect_handler=_noop_redirect,
|
||||||
|
callback_handler=_noop_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
await provider._initialize()
|
||||||
|
|
||||||
|
assert provider.context.is_token_valid() is False, (
|
||||||
|
"After _initialize with an expired-on-disk token, is_token_valid() "
|
||||||
|
"must return False so the SDK's async_auth_flow takes the "
|
||||||
|
"preemptive refresh path."
|
||||||
|
)
|
||||||
|
assert provider.context.can_refresh_token() is True, (
|
||||||
|
"Refresh should remain possible because refresh_token + client_info "
|
||||||
|
"are both present."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _noop_redirect(_url: str) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _noop_callback() -> tuple[str, str | None]:
|
||||||
|
raise AssertionError("callback handler should not be invoked in these tests")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pre-flight OAuth metadata discovery
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialize_prefetches_oauth_metadata_when_missing(
|
||||||
|
tmp_path, monkeypatch
|
||||||
|
):
|
||||||
|
"""Cold-load must pre-flight PRM + ASM discovery so ``_refresh_token``
|
||||||
|
has the correct ``token_endpoint`` before the first refresh attempt.
|
||||||
|
|
||||||
|
Without this, the SDK's ``_refresh_token`` falls back to
|
||||||
|
``{server_url}/token`` which is wrong for providers whose AS is at
|
||||||
|
a different origin. BetterStack specifically: MCP at
|
||||||
|
``mcp.betterstack.com`` but token_endpoint at
|
||||||
|
``betterstack.com/oauth/token``. Without pre-flight the refresh 404s
|
||||||
|
and we drop into full browser re-auth — visible to the user as an
|
||||||
|
unwanted OAuth browser prompt every time the process restarts.
|
||||||
|
"""
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from mcp.shared.auth import (
|
||||||
|
OAuthClientInformationFull,
|
||||||
|
OAuthClientMetadata,
|
||||||
|
OAuthToken,
|
||||||
|
)
|
||||||
|
from pydantic import AnyUrl
|
||||||
|
|
||||||
|
from tools.mcp_oauth import HermesTokenStorage
|
||||||
|
from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests
|
||||||
|
|
||||||
|
assert _HERMES_PROVIDER_CLS is not None
|
||||||
|
reset_manager_for_tests()
|
||||||
|
|
||||||
|
storage = HermesTokenStorage("srv")
|
||||||
|
await storage.set_tokens(
|
||||||
|
OAuthToken(
|
||||||
|
access_token="a",
|
||||||
|
token_type="Bearer",
|
||||||
|
expires_in=3600,
|
||||||
|
refresh_token="r",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await storage.set_client_info(
|
||||||
|
OAuthClientInformationFull(
|
||||||
|
client_id="test-client",
|
||||||
|
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
|
||||||
|
grant_types=["authorization_code", "refresh_token"],
|
||||||
|
response_types=["code"],
|
||||||
|
token_endpoint_auth_method="none",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Route the AsyncClient used inside _prefetch_oauth_metadata through a
|
||||||
|
# MockTransport that mimics BetterStack's split-origin discovery:
|
||||||
|
# PRM at mcp.example.com/.well-known/oauth-protected-resource -> points to auth.example.com
|
||||||
|
# ASM at auth.example.com/.well-known/oauth-authorization-server -> token_endpoint at auth.example.com/oauth/token
|
||||||
|
def mock_handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
url = str(request.url)
|
||||||
|
if url.endswith("/.well-known/oauth-protected-resource"):
|
||||||
|
return httpx.Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"resource": "https://mcp.example.com",
|
||||||
|
"authorization_servers": ["https://auth.example.com"],
|
||||||
|
"scopes_supported": ["read", "write"],
|
||||||
|
"bearer_methods_supported": ["header"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if url.endswith("/.well-known/oauth-authorization-server"):
|
||||||
|
return httpx.Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"issuer": "https://auth.example.com",
|
||||||
|
"authorization_endpoint": "https://auth.example.com/oauth/authorize",
|
||||||
|
"token_endpoint": "https://auth.example.com/oauth/token",
|
||||||
|
"registration_endpoint": "https://auth.example.com/oauth/register",
|
||||||
|
"response_types_supported": ["code"],
|
||||||
|
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||||
|
"code_challenge_methods_supported": ["S256"],
|
||||||
|
"token_endpoint_auth_methods_supported": ["none"],
|
||||||
|
"scopes_supported": ["read", "write"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return httpx.Response(404)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(mock_handler)
|
||||||
|
|
||||||
|
# Patch the AsyncClient constructor used by _prefetch_oauth_metadata so
|
||||||
|
# it uses our mock transport instead of the real network.
|
||||||
|
import httpx as real_httpx
|
||||||
|
|
||||||
|
original_async_client = real_httpx.AsyncClient
|
||||||
|
|
||||||
|
def patched_async_client(*args, **kwargs):
|
||||||
|
kwargs["transport"] = transport
|
||||||
|
return original_async_client(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(real_httpx, "AsyncClient", patched_async_client)
|
||||||
|
|
||||||
|
metadata = OAuthClientMetadata(
|
||||||
|
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
|
||||||
|
client_name="Hermes Agent",
|
||||||
|
)
|
||||||
|
provider = _HERMES_PROVIDER_CLS(
|
||||||
|
server_name="srv",
|
||||||
|
server_url="https://mcp.example.com",
|
||||||
|
client_metadata=metadata,
|
||||||
|
storage=storage,
|
||||||
|
redirect_handler=_noop_redirect,
|
||||||
|
callback_handler=_noop_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
await provider._initialize()
|
||||||
|
|
||||||
|
assert provider.context.protected_resource_metadata is not None, (
|
||||||
|
"Pre-flight must cache PRM for the SDK to reference later."
|
||||||
|
)
|
||||||
|
assert provider.context.oauth_metadata is not None, (
|
||||||
|
"Pre-flight must cache ASM so _refresh_token builds the correct "
|
||||||
|
"token_endpoint URL."
|
||||||
|
)
|
||||||
|
assert str(provider.context.oauth_metadata.token_endpoint) == (
|
||||||
|
"https://auth.example.com/oauth/token"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialize_skips_prefetch_when_no_tokens(tmp_path, monkeypatch):
|
||||||
|
"""Pre-flight must not run when there are no stored tokens yet.
|
||||||
|
|
||||||
|
Without this guard, every fresh-install ``_initialize`` would do two
|
||||||
|
extra network roundtrips that gain nothing (the SDK's 401-branch
|
||||||
|
discovery will run on the first real request anyway).
|
||||||
|
"""
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
import httpx
|
||||||
|
from mcp.shared.auth import OAuthClientMetadata
|
||||||
|
from pydantic import AnyUrl
|
||||||
|
|
||||||
|
from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests
|
||||||
|
from tools.mcp_oauth import HermesTokenStorage
|
||||||
|
|
||||||
|
assert _HERMES_PROVIDER_CLS is not None
|
||||||
|
reset_manager_for_tests()
|
||||||
|
|
||||||
|
calls: list[str] = []
|
||||||
|
|
||||||
|
def mock_handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
calls.append(str(request.url))
|
||||||
|
return httpx.Response(404)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(mock_handler)
|
||||||
|
import httpx as real_httpx
|
||||||
|
|
||||||
|
original = real_httpx.AsyncClient
|
||||||
|
|
||||||
|
def patched(*args, **kwargs):
|
||||||
|
kwargs["transport"] = transport
|
||||||
|
return original(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(real_httpx, "AsyncClient", patched)
|
||||||
|
|
||||||
|
storage = HermesTokenStorage("srv") # empty — no tokens on disk
|
||||||
|
metadata = OAuthClientMetadata(
|
||||||
|
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
|
||||||
|
client_name="Hermes Agent",
|
||||||
|
)
|
||||||
|
provider = _HERMES_PROVIDER_CLS(
|
||||||
|
server_name="srv",
|
||||||
|
server_url="https://mcp.example.com",
|
||||||
|
client_metadata=metadata,
|
||||||
|
storage=storage,
|
||||||
|
redirect_handler=_noop_redirect,
|
||||||
|
callback_handler=_noop_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
await provider._initialize()
|
||||||
|
|
||||||
|
assert calls == [], (
|
||||||
|
f"Pre-flight must not fire when no tokens are stored, but got {calls}"
|
||||||
|
)
|
||||||
@@ -40,6 +40,7 @@ import re
|
|||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
import webbrowser
|
import webbrowser
|
||||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -196,6 +197,35 @@ class HermesTokenStorage:
|
|||||||
data = _read_json(self._tokens_path())
|
data = _read_json(self._tokens_path())
|
||||||
if data is None:
|
if data is None:
|
||||||
return None
|
return None
|
||||||
|
# Hermes records an absolute wall-clock ``expires_at`` alongside the
|
||||||
|
# SDK's serialized token (see ``set_tokens``). On read we rewrite
|
||||||
|
# ``expires_in`` to the remaining seconds so the SDK's downstream
|
||||||
|
# ``update_token_expiry`` computes the correct absolute time and
|
||||||
|
# ``is_token_valid()`` correctly reports False for tokens that
|
||||||
|
# expired while the process was down.
|
||||||
|
#
|
||||||
|
# Legacy token files (pre-Fix-A) have ``expires_in`` but no
|
||||||
|
# ``expires_at``. We fall back to the file's mtime as a best-effort
|
||||||
|
# wall-clock proxy for when the token was written: if (mtime +
|
||||||
|
# expires_in) is in the past, clamp ``expires_in`` to zero so the
|
||||||
|
# SDK refreshes before the first request. This self-heals one-time
|
||||||
|
# on the next successful ``set_tokens``, which writes the new
|
||||||
|
# ``expires_at`` field. The stored ``expires_at`` is stripped before
|
||||||
|
# model_validate because it's not part of the SDK's OAuthToken schema.
|
||||||
|
absolute_expiry = data.pop("expires_at", None)
|
||||||
|
if absolute_expiry is not None:
|
||||||
|
data["expires_in"] = int(max(absolute_expiry - time.time(), 0))
|
||||||
|
elif data.get("expires_in") is not None:
|
||||||
|
try:
|
||||||
|
file_mtime = self._tokens_path().stat().st_mtime
|
||||||
|
except OSError:
|
||||||
|
file_mtime = None
|
||||||
|
if file_mtime is not None:
|
||||||
|
try:
|
||||||
|
implied_expiry = file_mtime + int(data["expires_in"])
|
||||||
|
data["expires_in"] = int(max(implied_expiry - time.time(), 0))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
try:
|
try:
|
||||||
return OAuthToken.model_validate(data)
|
return OAuthToken.model_validate(data)
|
||||||
except (ValueError, TypeError, KeyError) as exc:
|
except (ValueError, TypeError, KeyError) as exc:
|
||||||
@@ -203,7 +233,23 @@ class HermesTokenStorage:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def set_tokens(self, tokens: "OAuthToken") -> None:
|
async def set_tokens(self, tokens: "OAuthToken") -> None:
|
||||||
_write_json(self._tokens_path(), tokens.model_dump(exclude_none=True))
|
payload = tokens.model_dump(exclude_none=True)
|
||||||
|
# Persist an absolute ``expires_at`` so a process restart can
|
||||||
|
# reconstruct the correct remaining TTL. Without this the MCP SDK's
|
||||||
|
# ``_initialize`` reloads a relative ``expires_in`` which has no
|
||||||
|
# wall-clock reference, leaving ``context.token_expiry_time=None``
|
||||||
|
# and ``is_token_valid()`` falsely reporting True. See Fix A in
|
||||||
|
# ``mcp-oauth-token-diagnosis`` skill + Claude Code's
|
||||||
|
# ``OAuthTokens.expiresAt`` persistence (auth.ts ~180).
|
||||||
|
expires_in = payload.get("expires_in")
|
||||||
|
if expires_in is not None:
|
||||||
|
try:
|
||||||
|
payload["expires_at"] = time.time() + int(expires_in)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
# Mock tokens or unusual shapes: skip the expires_at write
|
||||||
|
# rather than fail persistence.
|
||||||
|
pass
|
||||||
|
_write_json(self._tokens_path(), payload)
|
||||||
logger.debug("OAuth tokens saved for %s", self._server_name)
|
logger.debug("OAuth tokens saved for %s", self._server_name)
|
||||||
|
|
||||||
# -- client info -------------------------------------------------------
|
# -- client info -------------------------------------------------------
|
||||||
|
|||||||
@@ -111,6 +111,131 @@ def _make_hermes_provider_class() -> Optional[type]:
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._hermes_server_name = server_name
|
self._hermes_server_name = server_name
|
||||||
|
|
||||||
|
async def _initialize(self) -> None:
|
||||||
|
"""Load stored tokens + client info AND seed token_expiry_time.
|
||||||
|
|
||||||
|
Also eagerly fetches OAuth authorization-server metadata (PRM +
|
||||||
|
ASM) when we have stored tokens but no cached metadata, so the
|
||||||
|
SDK's ``_refresh_token`` can build the correct token_endpoint
|
||||||
|
URL on the preemptive-refresh path. Without this, the SDK
|
||||||
|
falls back to ``{mcp_server_url}/token`` (wrong for providers
|
||||||
|
whose AS is a different origin — BetterStack's MCP lives at
|
||||||
|
``https://mcp.betterstack.com`` but its token endpoint is at
|
||||||
|
``https://betterstack.com/oauth/token``), the refresh 404s, and
|
||||||
|
we drop through to full browser reauth.
|
||||||
|
|
||||||
|
The SDK's base ``_initialize`` populates ``current_tokens`` but
|
||||||
|
does NOT call ``update_token_expiry``, so ``token_expiry_time``
|
||||||
|
stays ``None`` and ``is_token_valid()`` returns True for any
|
||||||
|
loaded token regardless of actual age. After a process restart
|
||||||
|
this ships stale Bearer tokens to the server; some providers
|
||||||
|
return HTTP 401 (caught by the 401 handler), others return 200
|
||||||
|
with an app-level auth error (invisible to the transport layer,
|
||||||
|
e.g. BetterStack returning "No teams found. Please check your
|
||||||
|
authentication.").
|
||||||
|
|
||||||
|
Seeding ``token_expiry_time`` from the reloaded token fixes that:
|
||||||
|
``is_token_valid()`` correctly reports False for expired tokens,
|
||||||
|
``async_auth_flow`` takes the ``can_refresh_token()`` branch,
|
||||||
|
and the SDK quietly refreshes before the first real request.
|
||||||
|
|
||||||
|
Paired with :class:`HermesTokenStorage` persisting an absolute
|
||||||
|
``expires_at`` timestamp (``mcp_oauth.py:set_tokens``) so the
|
||||||
|
remaining TTL we compute here reflects real wall-clock age.
|
||||||
|
"""
|
||||||
|
await super()._initialize()
|
||||||
|
tokens = self.context.current_tokens
|
||||||
|
if tokens is not None and tokens.expires_in is not None:
|
||||||
|
self.context.update_token_expiry(tokens)
|
||||||
|
|
||||||
|
# Pre-flight OAuth AS discovery so ``_refresh_token`` has a
|
||||||
|
# correct ``token_endpoint`` before the first refresh attempt.
|
||||||
|
# Only runs when we have tokens on cold-load but no cached
|
||||||
|
# metadata — i.e. the exact scenario where the SDK's built-in
|
||||||
|
# 401-branch discovery hasn't had a chance to run yet.
|
||||||
|
if (
|
||||||
|
tokens is not None
|
||||||
|
and self.context.oauth_metadata is None
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
await self._prefetch_oauth_metadata()
|
||||||
|
except Exception as exc: # pragma: no cover — defensive
|
||||||
|
# Non-fatal: if discovery fails, the SDK's normal 401-
|
||||||
|
# branch discovery will run on the next request.
|
||||||
|
logger.debug(
|
||||||
|
"MCP OAuth '%s': pre-flight metadata discovery "
|
||||||
|
"failed (non-fatal): %s",
|
||||||
|
self._hermes_server_name, exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _prefetch_oauth_metadata(self) -> None:
|
||||||
|
"""Fetch PRM + ASM from the well-known endpoints, cache on context.
|
||||||
|
|
||||||
|
Mirrors the SDK's 401-branch discovery (oauth2.py ~line 511-551)
|
||||||
|
but runs synchronously before the first request instead of
|
||||||
|
inside the httpx auth_flow generator. Uses the SDK's own URL
|
||||||
|
builders and response handlers so we track whatever the SDK
|
||||||
|
version we're pinned to expects.
|
||||||
|
"""
|
||||||
|
import httpx # local import: httpx is an MCP SDK dependency
|
||||||
|
from mcp.client.auth.utils import (
|
||||||
|
build_oauth_authorization_server_metadata_discovery_urls,
|
||||||
|
build_protected_resource_metadata_discovery_urls,
|
||||||
|
create_oauth_metadata_request,
|
||||||
|
handle_auth_metadata_response,
|
||||||
|
handle_protected_resource_response,
|
||||||
|
)
|
||||||
|
|
||||||
|
server_url = self.context.server_url
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
# Step 1: PRM discovery to learn the authorization_server URL.
|
||||||
|
for url in build_protected_resource_metadata_discovery_urls(
|
||||||
|
None, server_url
|
||||||
|
):
|
||||||
|
req = create_oauth_metadata_request(url)
|
||||||
|
try:
|
||||||
|
resp = await client.send(req)
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
logger.debug(
|
||||||
|
"MCP OAuth '%s': PRM discovery to %s failed: %s",
|
||||||
|
self._hermes_server_name, url, exc,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
prm = await handle_protected_resource_response(resp)
|
||||||
|
if prm:
|
||||||
|
self.context.protected_resource_metadata = prm
|
||||||
|
if prm.authorization_servers:
|
||||||
|
self.context.auth_server_url = str(
|
||||||
|
prm.authorization_servers[0]
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Step 2: ASM discovery against the auth_server_url (or
|
||||||
|
# server_url fallback for legacy providers).
|
||||||
|
for url in build_oauth_authorization_server_metadata_discovery_urls(
|
||||||
|
self.context.auth_server_url, server_url
|
||||||
|
):
|
||||||
|
req = create_oauth_metadata_request(url)
|
||||||
|
try:
|
||||||
|
resp = await client.send(req)
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
logger.debug(
|
||||||
|
"MCP OAuth '%s': ASM discovery to %s failed: %s",
|
||||||
|
self._hermes_server_name, url, exc,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
ok, asm = await handle_auth_metadata_response(resp)
|
||||||
|
if not ok:
|
||||||
|
break
|
||||||
|
if asm:
|
||||||
|
self.context.oauth_metadata = asm
|
||||||
|
logger.debug(
|
||||||
|
"MCP OAuth '%s': pre-flight ASM discovered "
|
||||||
|
"token_endpoint=%s",
|
||||||
|
self._hermes_server_name, asm.token_endpoint,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
async def async_auth_flow(self, request): # type: ignore[override]
|
async def async_auth_flow(self, request): # type: ignore[override]
|
||||||
# Pre-flow hook: ask the manager to refresh from disk if needed.
|
# Pre-flow hook: ask the manager to refresh from disk if needed.
|
||||||
# Any failure here is non-fatal — we just log and proceed with
|
# Any failure here is non-fatal — we just log and proceed with
|
||||||
@@ -125,9 +250,28 @@ def _make_hermes_provider_class() -> Optional[type]:
|
|||||||
self._hermes_server_name, exc,
|
self._hermes_server_name, exc,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Delegate to the SDK's auth flow
|
# Manually bridge the bidirectional generator protocol. httpx's
|
||||||
async for item in super().async_auth_flow(request):
|
# auth_flow driver (httpx._client._send_handling_auth) calls
|
||||||
yield item
|
# ``auth_flow.asend(response)`` to feed HTTP responses back into
|
||||||
|
# the generator. A naive wrapper using ``async for item in inner:
|
||||||
|
# yield item`` DISCARDS those .asend(response) values and resumes
|
||||||
|
# the inner generator with None, so the SDK's
|
||||||
|
# ``response = yield request`` branch in
|
||||||
|
# mcp/client/auth/oauth2.py sees response=None and crashes at
|
||||||
|
# ``if response.status_code == 401`` with AttributeError.
|
||||||
|
#
|
||||||
|
# The bridge below forwards each .asend() value into the inner
|
||||||
|
# generator via inner.asend(incoming), preserving the bidirectional
|
||||||
|
# contract. Regression from PR #11383 caught by
|
||||||
|
# tests/tools/test_mcp_oauth_bidirectional.py.
|
||||||
|
inner = super().async_auth_flow(request)
|
||||||
|
try:
|
||||||
|
outgoing = await inner.__anext__()
|
||||||
|
while True:
|
||||||
|
incoming = yield outgoing
|
||||||
|
outgoing = await inner.asend(incoming)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
return
|
||||||
|
|
||||||
return HermesMCPOAuthProvider
|
return HermesMCPOAuthProvider
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user