mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-07-03 08:47:04 +08:00
fix(mcp-oauth): anchor 401 handler task to prevent GC mid-flight
`handle_401` spawned a dedup'd recovery coroutine via `asyncio.create_task(_do_handle())` and discarded the returned task reference. Python's event loop only keeps weak references to tasks, so the coroutine could be garbage-collected before it called `pending.set_result(...)`. Every concurrent caller awaiting that future then hangs forever, and the `finally: entry.pending_401.pop(...)` cleanup never runs — so subsequent 401s for the same key latch onto the dead future too. Same pattern the adapter-side fixes address (#11997, #11998, #12000, #12001, #12006). Hold the task in a set owned by the manager instance (`self._inflight_tasks`) and discard it via `add_done_callback` once it completes. Regression test covers both the structural invariant (task tracked, then removed on completion) and a concurrent dedup path with a forced `gc.collect()` between the handler's await points.
This commit is contained in:
@@ -134,6 +134,101 @@ async def test_disk_watch_invalidates_on_mtime_change(tmp_path, monkeypatch):
|
||||
assert provider._initialized is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_401_tracks_inflight_task_to_prevent_gc(tmp_path, monkeypatch):
|
||||
"""The 401 handler task must be strongly referenced by the manager.
|
||||
|
||||
``asyncio.create_task`` returns a task the event loop only weakly
|
||||
references. If the manager discards its handle, the background coroutine
|
||||
can be garbage-collected mid-run and every concurrent waiter stuck on
|
||||
``await pending`` hangs forever. See the design note on
|
||||
``MCPOAuthManager._inflight_tasks``.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
from tools.mcp_oauth_manager import MCPOAuthManager, _ProviderEntry
|
||||
|
||||
class _TrackedSet(set):
|
||||
"""set subclass that records every element ever inserted."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ever_added: list = []
|
||||
|
||||
def add(self, item): # noqa: A003
|
||||
self.ever_added.append(item)
|
||||
super().add(item)
|
||||
|
||||
mgr = MCPOAuthManager()
|
||||
mgr._inflight_tasks = _TrackedSet()
|
||||
|
||||
class _DummyProvider:
|
||||
context = None # forces the can_refresh=False branch
|
||||
|
||||
mgr._entries["srv"] = _ProviderEntry(
|
||||
server_url="https://example.com/mcp",
|
||||
oauth_config=None,
|
||||
provider=_DummyProvider(),
|
||||
)
|
||||
|
||||
result = await mgr.handle_401("srv", failed_access_token="TOK")
|
||||
|
||||
# Exactly one handler task was created and tracked.
|
||||
assert len(mgr._inflight_tasks.ever_added) == 1
|
||||
tracked_task = mgr._inflight_tasks.ever_added[0]
|
||||
assert isinstance(tracked_task, asyncio.Task)
|
||||
# done_callback must have removed the finished task from the live set,
|
||||
# otherwise the set would grow unbounded across repeated 401s.
|
||||
assert tracked_task not in mgr._inflight_tasks
|
||||
assert len(mgr._inflight_tasks) == 0
|
||||
assert tracked_task.done()
|
||||
# With provider.context=None, there's nothing to refresh — result False.
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_401_dedup_survives_even_if_task_reference_dropped(tmp_path, monkeypatch):
|
||||
"""Concurrent 401s share one handler task and all callers resolve.
|
||||
|
||||
Regression guard: if the manager ever stops holding a strong reference
|
||||
to the `_do_handle` task, this test can intermittently time out when the
|
||||
task is GC'd between the ``await`` checkpoints inside ``_do_handle``.
|
||||
Running it in CI with ``gc.collect()`` mid-flight (below) exercises
|
||||
that window.
|
||||
"""
|
||||
import asyncio
|
||||
import gc
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
from tools.mcp_oauth_manager import MCPOAuthManager, _ProviderEntry
|
||||
|
||||
mgr = MCPOAuthManager()
|
||||
|
||||
class _DummyProvider:
|
||||
context = None
|
||||
|
||||
mgr._entries["srv"] = _ProviderEntry(
|
||||
server_url="https://example.com/mcp",
|
||||
oauth_config=None,
|
||||
provider=_DummyProvider(),
|
||||
)
|
||||
|
||||
# Fan out N concurrent callers sharing the same failed token so all
|
||||
# collapse onto a single deduped handler future.
|
||||
async def _caller():
|
||||
return await mgr.handle_401("srv", failed_access_token="TOK")
|
||||
|
||||
tasks = [asyncio.create_task(_caller()) for _ in range(8)]
|
||||
# Give the event loop one tick to schedule _do_handle, then force GC.
|
||||
await asyncio.sleep(0)
|
||||
gc.collect()
|
||||
|
||||
results = await asyncio.wait_for(asyncio.gather(*tasks), timeout=5.0)
|
||||
assert results == [False] * 8
|
||||
assert len(mgr._inflight_tasks) == 0
|
||||
|
||||
|
||||
def test_manager_builds_hermes_provider_subclass(tmp_path, monkeypatch):
|
||||
"""get_or_build_provider returns HermesMCPOAuthProvider, not plain OAuthClientProvider."""
|
||||
from tools.mcp_oauth_manager import (
|
||||
|
||||
@@ -451,6 +451,10 @@ class MCPOAuthManager:
|
||||
def __init__(self) -> None:
|
||||
self._entries: dict[str, _ProviderEntry] = {}
|
||||
self._entries_lock = threading.Lock()
|
||||
# Holds strong references to in-flight 401 handler tasks so the
|
||||
# event loop's weak-reference bookkeeping cannot GC them mid-run
|
||||
# and leave `await pending` waiters hanging forever.
|
||||
self._inflight_tasks: set[asyncio.Task] = set()
|
||||
|
||||
# -- Provider construction / caching -------------------------------------
|
||||
|
||||
@@ -677,7 +681,9 @@ class MCPOAuthManager:
|
||||
finally:
|
||||
entry.pending_401.pop(key, None)
|
||||
|
||||
asyncio.create_task(_do_handle())
|
||||
task = asyncio.create_task(_do_handle())
|
||||
self._inflight_tasks.add(task)
|
||||
task.add_done_callback(self._inflight_tasks.discard)
|
||||
|
||||
try:
|
||||
return await pending
|
||||
|
||||
Reference in New Issue
Block a user