fix(mcp-oauth): anchor 401 handler task to prevent GC mid-flight

`handle_401` spawned a dedup'd recovery coroutine via
`asyncio.create_task(_do_handle())` and discarded the returned task
reference. Python's event loop only keeps weak references to tasks, so
the coroutine could be garbage-collected before it called
`pending.set_result(...)`. Every concurrent caller awaiting that future
then hangs forever, and the `finally: entry.pending_401.pop(...)`
cleanup never runs — so subsequent 401s for the same key latch onto the
dead future too. Same pattern the adapter-side fixes address (#11997,
#11998, #12000, #12001, #12006).

Hold the task in a process-wide set on the manager and discard it via
`add_done_callback` once it completes. Regression test covers both the
structural invariant (task tracked, then removed on completion) and a
concurrent dedup path with a forced `gc.collect()` between the handler's
await points.
This commit is contained in:
haileymarshall
2026-04-18 18:13:04 +01:00
committed by Teknium
parent d431dfc448
commit 9f22f36625
2 changed files with 102 additions and 1 deletions

View File

@@ -134,6 +134,101 @@ async def test_disk_watch_invalidates_on_mtime_change(tmp_path, monkeypatch):
assert provider._initialized is False
@pytest.mark.asyncio
async def test_handle_401_tracks_inflight_task_to_prevent_gc(tmp_path, monkeypatch):
"""The 401 handler task must be strongly referenced by the manager.
``asyncio.create_task`` returns a task the event loop only weakly
references. If the manager discards its handle, the background coroutine
can be garbage-collected mid-run and every concurrent waiter stuck on
``await pending`` hangs forever. See the design note on
``MCPOAuthManager._inflight_tasks``.
"""
import asyncio
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from tools.mcp_oauth_manager import MCPOAuthManager, _ProviderEntry
class _TrackedSet(set):
"""set subclass that records every element ever inserted."""
def __init__(self):
super().__init__()
self.ever_added: list = []
def add(self, item): # noqa: A003
self.ever_added.append(item)
super().add(item)
mgr = MCPOAuthManager()
mgr._inflight_tasks = _TrackedSet()
class _DummyProvider:
context = None # forces the can_refresh=False branch
mgr._entries["srv"] = _ProviderEntry(
server_url="https://example.com/mcp",
oauth_config=None,
provider=_DummyProvider(),
)
result = await mgr.handle_401("srv", failed_access_token="TOK")
# Exactly one handler task was created and tracked.
assert len(mgr._inflight_tasks.ever_added) == 1
tracked_task = mgr._inflight_tasks.ever_added[0]
assert isinstance(tracked_task, asyncio.Task)
# done_callback must have removed the finished task from the live set,
# otherwise the set would grow unbounded across repeated 401s.
assert tracked_task not in mgr._inflight_tasks
assert len(mgr._inflight_tasks) == 0
assert tracked_task.done()
# With provider.context=None, there's nothing to refresh — result False.
assert result is False
@pytest.mark.asyncio
async def test_handle_401_dedup_survives_even_if_task_reference_dropped(tmp_path, monkeypatch):
"""Concurrent 401s share one handler task and all callers resolve.
Regression guard: if the manager ever stops holding a strong reference
to the `_do_handle` task, this test can intermittently hang when the
task is GC'd between the ``await`` checkpoints inside ``_do_handle``.
Running it in CI with ``gc.collect()`` mid-flight (below) exercises
that window.
"""
import asyncio
import gc
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from tools.mcp_oauth_manager import MCPOAuthManager, _ProviderEntry
mgr = MCPOAuthManager()
class _DummyProvider:
context = None
mgr._entries["srv"] = _ProviderEntry(
server_url="https://example.com/mcp",
oauth_config=None,
provider=_DummyProvider(),
)
# Fan out N concurrent callers sharing the same failed token so all
# collapse onto a single deduped handler future.
async def _caller():
return await mgr.handle_401("srv", failed_access_token="TOK")
tasks = [asyncio.create_task(_caller()) for _ in range(8)]
# Give the event loop one tick to schedule _do_handle, then force GC.
await asyncio.sleep(0)
gc.collect()
results = await asyncio.wait_for(asyncio.gather(*tasks), timeout=5.0)
assert results == [False] * 8
assert len(mgr._inflight_tasks) == 0
def test_manager_builds_hermes_provider_subclass(tmp_path, monkeypatch):
"""get_or_build_provider returns HermesMCPOAuthProvider, not plain OAuthClientProvider."""
from tools.mcp_oauth_manager import (

View File

@@ -451,6 +451,10 @@ class MCPOAuthManager:
def __init__(self) -> None:
self._entries: dict[str, _ProviderEntry] = {}
self._entries_lock = threading.Lock()
# Holds strong references to in-flight 401 handler tasks so the
# event loop's weak-reference bookkeeping cannot GC them mid-run
# and leave `await pending` waiters hanging forever.
self._inflight_tasks: set[asyncio.Task] = set()
# -- Provider construction / caching -------------------------------------
@@ -677,7 +681,9 @@ class MCPOAuthManager:
finally:
entry.pending_401.pop(key, None)
asyncio.create_task(_do_handle())
task = asyncio.create_task(_do_handle())
self._inflight_tasks.add(task)
task.add_done_callback(self._inflight_tasks.discard)
try:
return await pending