Compare commits

...

1 Commits

Author SHA1 Message Date
heathley
109a49bb03 fix(yuanbao): skip resource resolve on cache hits 2026-05-29 03:50:29 -07:00
2 changed files with 150 additions and 0 deletions

View File

@@ -2269,6 +2269,27 @@ class MediaResolveMiddleware(InboundMiddleware):
cls._resource_cache.pop(k, None)
cls._resource_cache[resource_id] = (local_path, mime, time.time())
@classmethod
def _append_cached_resource(
cls,
adapter,
resource_id: str,
media_paths: List[str],
mimes: List[str],
) -> bool:
"""Append a cached resource to output lists when available."""
hit = cls._get_cached_resource(resource_id)
if hit is None:
return False
local_path, mime = hit
logger.debug(
"[%s] resource cache hit: rid=%s path=%s",
adapter.name, resource_id, local_path,
)
media_paths.append(local_path)
mimes.append(mime)
return True
@staticmethod
def _guess_image_ext_from_url(url: str) -> str:
"""Guess image extension from URL path."""
@@ -2451,6 +2472,8 @@ class MediaResolveMiddleware(InboundMiddleware):
# Extract resourceId from the placeholder URL for cache dedup.
rid = ExtractContentMiddleware._parse_resource_id(url)
if rid and cls._append_cached_resource(adapter, rid, media_urls, media_types):
continue
try:
fetch_url = await cls._resolve_download_url(adapter, url)
@@ -2526,6 +2549,8 @@ class MediaResolveMiddleware(InboundMiddleware):
media_paths: List[str] = []
mimes: List[str] = []
for rid, kind, filename in order:
if cls._append_cached_resource(adapter, rid, media_paths, mimes):
continue
try:
fresh_url = await cls._resolve_by_resource_id(adapter, rid)
except Exception as exc:
@@ -2610,6 +2635,10 @@ class DispatchMiddleware(InboundMiddleware):
for rid, kind, filename in ctx.quote_media_refs:
if kind not in _RESOLVABLE_MEDIA_KINDS:
continue
if MediaResolveMiddleware._append_cached_resource(
adapter, rid, media_urls, media_types
):
continue
try:
fresh_url = await MediaResolveMiddleware._resolve_by_resource_id(adapter, rid)
except Exception as exc:

View File

@@ -10,6 +10,7 @@ Tests cover:
6. OOP middleware ABC and class tests
"""
import asyncio
import sys
import os
import json
@@ -34,6 +35,7 @@ from gateway.platforms.yuanbao import (
AccessPolicy,
AccessGuardMiddleware,
ExtractContentMiddleware,
MediaResolveMiddleware,
PlaceholderFilterMiddleware,
OwnerCommandMiddleware,
BuildSourceMiddleware,
@@ -43,6 +45,7 @@ from gateway.platforms.yuanbao import (
YuanbaoAdapter,
)
from gateway.config import PlatformConfig
from gateway.session import Platform, SessionSource
# ============================================================
@@ -340,6 +343,124 @@ class TestInboundContext:
assert ctx.chat_type == "dm"
class _SessionEntry:
def __init__(self, session_id: str) -> None:
self.session_id = session_id
class _TranscriptStore:
def __init__(self, history):
self._history = history
def get_or_create_session(self, _source):
return _SessionEntry("test-session")
def load_transcript(self, _session_id):
return self._history
@pytest.fixture(autouse=True)
def _clear_media_resource_cache():
MediaResolveMiddleware._resource_cache.clear()
yield
MediaResolveMiddleware._resource_cache.clear()
class TestMediaResolveResourceCache:
"""ResourceId cache hits must skip Yuanbao resource URL resolution."""
def _put_existing_cached_resource(self, tmp_path, rid="rid-cached"):
local = tmp_path / f"{rid}.jpg"
local.write_bytes(b"cached-image")
MediaResolveMiddleware._put_cached_resource(rid, str(local), "image/jpeg")
return str(local)
@pytest.mark.asyncio
async def test_current_message_cache_hit_skips_resource_url_resolve(
self, tmp_path, monkeypatch
):
cached_path = self._put_existing_cached_resource(tmp_path)
fetch_resource_url = AsyncMock(return_value="https://fresh.example/image.jpg")
monkeypatch.setattr(
MediaResolveMiddleware, "_fetch_resource_url", fetch_resource_url
)
paths, mimes = await MediaResolveMiddleware._resolve_media_urls(
make_adapter(),
[
{
"kind": "image",
"url": "https://hunyuan.tencent.com/api/resource/download?resourceId=rid-cached",
}
],
)
assert paths == [cached_path]
assert mimes == ["image/jpeg"]
fetch_resource_url.assert_not_awaited()
@pytest.mark.asyncio
async def test_observed_media_cache_hit_skips_resource_url_resolve(
self, tmp_path, monkeypatch
):
cached_path = self._put_existing_cached_resource(tmp_path)
adapter = make_adapter()
adapter._session_store = _TranscriptStore(
[{"content": "earlier [image|ybres:rid-cached]"}]
)
fetch_resource_url = AsyncMock(return_value="https://fresh.example/image.jpg")
monkeypatch.setattr(
MediaResolveMiddleware, "_fetch_resource_url", fetch_resource_url
)
paths, mimes = await MediaResolveMiddleware._collect_observed_media(
adapter, "direct:alice"
)
assert paths == [cached_path]
assert mimes == ["image/jpeg"]
fetch_resource_url.assert_not_awaited()
@pytest.mark.asyncio
async def test_quote_media_cache_hit_skips_resource_url_resolve(
self, tmp_path, monkeypatch
):
cached_path = self._put_existing_cached_resource(tmp_path)
adapter = make_adapter()
adapter.handle_message = AsyncMock()
fetch_resource_url = AsyncMock(return_value="https://fresh.example/image.jpg")
monkeypatch.setattr(
MediaResolveMiddleware, "_fetch_resource_url", fetch_resource_url
)
ctx = make_ctx(
adapter=adapter,
chat_type="dm",
source=SessionSource(
platform=Platform.YUANBAO,
chat_id="alice",
chat_type="dm",
user_id="alice",
),
msg_id="msg-quote",
raw_text="replying",
reply_to_message_id="quoted-msg",
quote_media_refs=[("rid-cached", "image", "")],
media_urls=[],
media_types=[],
)
next_fn = AsyncMock()
await DispatchMiddleware()(ctx, next_fn)
await asyncio.gather(*list(adapter._inbound_tasks))
next_fn.assert_awaited_once()
fetch_resource_url.assert_not_awaited()
adapter.handle_message.assert_awaited_once()
event = adapter.handle_message.await_args.args[0]
assert event.media_urls == [cached_path]
assert event.media_types == ["image/jpeg"]
# ============================================================
# 3. Individual Middleware Tests
# ============================================================