Files
hermes-agent/tests/test_yuanbao_pipeline.py
Teknium ab6879634e yuanbao platform (#16298)
Co-authored-by: loongzhao <loongzhao@tencent.com>
2026-04-26 18:50:49 -07:00

1030 lines
34 KiB
Python

"""
test_yuanbao_pipeline.py - Unit tests for the inbound middleware pipeline.
Tests cover:
1. InboundPipeline engine (use, use_before, use_after, remove, execute)
2. InboundContext dataclass
3. Individual middlewares (DecodeMiddleware, DedupMiddleware, SkipSelfMiddleware, etc.)
4. InboundPipelineBuilder
5. End-to-end pipeline integration
6. OOP middleware ABC and class tests
"""
import sys
import os
import json
import asyncio
# Ensure project root is on the path
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _REPO_ROOT not in sys.path:
sys.path.insert(0, _REPO_ROOT)
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock
from gateway.platforms.yuanbao import (
InboundContext,
InboundMiddleware,
InboundPipeline,
DecodeMiddleware,
ExtractFieldsMiddleware,
DedupMiddleware,
SkipSelfMiddleware,
ChatRoutingMiddleware,
AccessPolicy,
AccessGuardMiddleware,
ExtractContentMiddleware,
PlaceholderFilterMiddleware,
OwnerCommandMiddleware,
BuildSourceMiddleware,
GroupAtGuardMiddleware,
DispatchMiddleware,
InboundPipelineBuilder,
YuanbaoAdapter,
)
from gateway.config import Platform, PlatformConfig
# ============================================================
# Helpers
# ============================================================
def make_config(**kwargs):
    """Build a ``PlatformConfig`` pre-populated with test credentials.

    An ``extra`` mapping supplied by the caller takes precedence; the keys
    ``app_id``, ``app_secret``, ``ws_url`` and ``api_domain`` are filled with
    test values only when they are absent.  All remaining keyword arguments
    are forwarded to ``PlatformConfig`` unchanged.
    """
    # Work on a copy: setdefault() mutates in place, and the caller's dict
    # (possibly shared between tests) must not pick up these defaults.
    extra = dict(kwargs.pop("extra", None) or {})
    extra.setdefault("app_id", "test_key")
    extra.setdefault("app_secret", "test_secret")
    extra.setdefault("ws_url", "wss://test.example.com/ws")
    extra.setdefault("api_domain", "https://test.example.com")
    return PlatformConfig(
        extra=extra,
        **kwargs,
    )
def make_adapter(**kwargs) -> YuanbaoAdapter:
    """Return a YuanbaoAdapter built from the test config with bot id ``bot_123``."""
    test_adapter = YuanbaoAdapter(make_config(**kwargs))
    # Tests compare against this fixed id when simulating the bot's own messages.
    test_adapter._bot_id = "bot_123"
    return test_adapter
def make_ctx(adapter=None, conn_data=b"", **overrides) -> InboundContext:
    """Build an InboundContext for a test.

    A fresh test adapter is created when *adapter* is omitted.  A non-empty
    *conn_data* becomes the single raw frame; otherwise the frame list is
    empty.  Every entry in *overrides* is set as an attribute on the context.
    """
    owner = make_adapter() if adapter is None else adapter
    frames = []
    if conn_data:
        frames.append(conn_data)
    context = InboundContext(adapter=owner, raw_frames=frames)
    for field_name, field_value in overrides.items():
        setattr(context, field_name, field_value)
    return context
def make_json_push(
    from_account="alice",
    to_account="bot_123",
    group_code="",
    text="Hello!",
    msg_id="msg-001",
) -> bytes:
    """Return a UTF-8 encoded JSON callback push carrying one text element.

    With an empty *group_code* the payload is a C2C callback; otherwise it is
    a Group callback and additionally carries ``GroupId``.

    Note: the inner MsgContent field is lowercase ("text", not "Text")
    because _extract_text() looks for lowercase keys.
    """
    is_group = bool(group_code)
    payload = {
        "CallbackCommand": (
            "Group.CallbackAfterSendMsg" if is_group else "C2C.CallbackAfterSendMsg"
        ),
        "From_Account": from_account,
        "To_Account": to_account,
        "MsgBody": [{"MsgType": "TIMTextElem", "MsgContent": {"text": text}}],
        "MsgKey": msg_id,
    }
    if is_group:
        # Appended last so the serialized key order stays the same as before.
        payload["GroupId"] = group_code
    return json.dumps(payload).encode("utf-8")
# ============================================================
# 1. InboundPipeline Engine Tests
# ============================================================
class TestInboundPipeline:
    """Behaviour of the pipeline engine itself, using throwaway middlewares."""

    @staticmethod
    def _recorder(log, label):
        """Build a middleware that appends *label* to *log* and then continues."""

        async def _record(ctx, next_fn):
            log.append(label)
            await next_fn()

        return _record

    @staticmethod
    async def _passthrough(ctx, next_fn):
        """Middleware that only hands over to the next stage."""
        await next_fn()

    @pytest.mark.asyncio
    async def test_empty_pipeline(self):
        """A pipeline without middlewares executes without raising."""
        engine = InboundPipeline()
        context = make_ctx()
        await engine.execute(context)

    @pytest.mark.asyncio
    async def test_single_middleware(self):
        """A single registered middleware runs when the pipeline executes."""
        trace = []
        engine = InboundPipeline().use("test", self._recorder(trace, "mw"))
        await engine.execute(make_ctx())
        assert trace == ["mw"]

    @pytest.mark.asyncio
    async def test_middleware_order(self):
        """Middlewares run in the order they were registered."""
        trace = []
        engine = InboundPipeline()
        for label in ("a", "b", "c"):
            engine = engine.use(label, self._recorder(trace, label))
        await engine.execute(make_ctx())
        assert trace == ["a", "b", "c"]

    @pytest.mark.asyncio
    async def test_middleware_can_stop_pipeline(self):
        """Not awaiting next_fn prevents later middlewares from running."""
        trace = []

        async def halt(ctx, next_fn):
            # Deliberately never awaits next_fn.
            trace.append("stop")

        engine = (
            InboundPipeline()
            .use("stop", halt)
            .use("after", self._recorder(trace, "after"))
        )
        await engine.execute(make_ctx())
        assert trace == ["stop"]  # "after" must not have run

    @pytest.mark.asyncio
    async def test_conditional_guard_skip(self):
        """A middleware whose ``when`` predicate is false is skipped."""
        trace = []

        def never(ctx):
            return False

        engine = (
            InboundPipeline()
            .use("a", self._recorder(trace, "a"))
            .use("skipped", self._recorder(trace, "skipped"), when=never)
            .use("c", self._recorder(trace, "c"))
        )
        await engine.execute(make_ctx())
        assert trace == ["a", "c"]

    @pytest.mark.asyncio
    async def test_conditional_guard_pass(self):
        """A middleware whose ``when`` predicate is true is executed."""
        trace = []

        def always(ctx):
            return True

        engine = InboundPipeline().use("mw", self._recorder(trace, "mw"), when=always)
        await engine.execute(make_ctx())
        assert trace == ["mw"]

    def test_use_before(self):
        """use_before places the new middleware ahead of the named target."""
        engine = InboundPipeline().use("a", self._passthrough).use("c", self._passthrough)
        engine.use_before("c", "b", self._passthrough)
        assert engine.middleware_names == ["a", "b", "c"]

    def test_use_before_nonexistent_appends(self):
        """use_before with an unknown target appends to the end."""
        engine = InboundPipeline().use("a", self._passthrough)
        engine.use_before("nonexistent", "b", self._passthrough)
        assert engine.middleware_names == ["a", "b"]

    def test_use_after(self):
        """use_after places the new middleware right behind the named target."""
        engine = InboundPipeline().use("a", self._passthrough).use("c", self._passthrough)
        engine.use_after("a", "b", self._passthrough)
        assert engine.middleware_names == ["a", "b", "c"]

    def test_use_after_nonexistent_appends(self):
        """use_after with an unknown target appends to the end."""
        engine = InboundPipeline().use("a", self._passthrough)
        engine.use_after("nonexistent", "b", self._passthrough)
        assert engine.middleware_names == ["a", "b"]

    def test_remove(self):
        """remove drops the middleware registered under the given name."""
        engine = InboundPipeline()
        for label in ("a", "b", "c"):
            engine = engine.use(label, self._passthrough)
        engine.remove("b")
        assert engine.middleware_names == ["a", "c"]

    def test_remove_nonexistent_is_noop(self):
        """remove with an unknown name leaves the pipeline unchanged."""
        engine = InboundPipeline().use("a", self._passthrough)
        engine.remove("nonexistent")
        assert engine.middleware_names == ["a"]

    @pytest.mark.asyncio
    async def test_error_propagation(self):
        """An exception raised inside a middleware reaches the caller."""

        async def explode(ctx, next_fn):
            raise ValueError("test error")

        engine = InboundPipeline().use("error", explode)
        with pytest.raises(ValueError, match="test error"):
            await engine.execute(make_ctx())

    def test_middleware_names_property(self):
        """middleware_names lists the registered names in order."""
        labels = ["decode", "dedup", "dispatch"]
        engine = InboundPipeline()
        for label in labels:
            engine = engine.use(label, self._passthrough)
        assert engine.middleware_names == ["decode", "dedup", "dispatch"]

    @pytest.mark.asyncio
    async def test_onion_model(self):
        """Code after ``await next_fn()`` runs once the inner middlewares finish."""
        trace = []

        async def wrap(ctx, next_fn):
            trace.append("outer-before")
            await next_fn()
            trace.append("outer-after")

        engine = (
            InboundPipeline()
            .use("outer", wrap)
            .use("inner", self._recorder(trace, "inner"))
        )
        await engine.execute(make_ctx())
        assert trace == ["outer-before", "inner", "outer-after"]
# ============================================================
# 2. InboundContext Tests
# ============================================================
class TestInboundContext:
    """Defaults and mutability of the InboundContext container."""

    def test_default_values(self):
        """A context built with only an adapter exposes empty defaults."""
        context = InboundContext(adapter=make_adapter())

        # Fields that default to an empty list.
        for field in ("raw_frames", "msg_body", "media_refs"):
            assert getattr(context, field) == [], field

        # Fields that default to an empty string.
        for field in (
            "decoded_via",
            "from_account",
            "group_code",
            "msg_id",
            "chat_id",
            "chat_type",
            "raw_text",
        ):
            assert getattr(context, field) == "", field

        # Fields that default to None.
        for field in ("push", "owner_command", "source", "msg_type"):
            assert getattr(context, field) is None, field

    def test_mutable_fields(self):
        """Attributes on the context can be reassigned after construction."""
        context = make_ctx()
        context.from_account = "alice"
        context.chat_type = "dm"
        assert (context.from_account, context.chat_type) == ("alice", "dm")
# ============================================================
# 3. Individual Middleware Tests
# ============================================================
class TestDecodeMiddleware:
    """DecodeMiddleware against JSON, empty and arbitrary binary frames."""

    @pytest.mark.asyncio
    async def test_json_decode(self):
        """A JSON push is decoded, tagged as "json", and the pipeline continues."""
        frame = make_json_push(from_account="alice", text="hi")
        context = make_ctx(conn_data=frame)
        proceed = AsyncMock()

        await DecodeMiddleware()(context, proceed)

        assert context.push is not None
        assert context.decoded_via == "json"
        assert context.push.get("from_account") == "alice"
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_empty_data_stops_pipeline(self):
        """Empty conn_data leaves push unset and stops the pipeline."""
        context = make_ctx(conn_data=b"")
        proceed = AsyncMock()

        await DecodeMiddleware()(context, proceed)

        assert context.push is None
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_invalid_data_may_produce_garbage(self):
        """Arbitrary bytes must not make DecodeMiddleware raise.

        The protobuf parser is lenient and may produce "seemingly valid"
        fields from arbitrary bytes; downstream middlewares (dedup,
        skip-self, etc.) are expected to filter out such garbage.  Whether a
        result is produced is therefore not asserted here.
        """
        context = make_ctx(conn_data=b"\x00\x01\x02\x03")
        proceed = AsyncMock()
        # The only invariant checked: this await completes without an exception.
        await DecodeMiddleware()(context, proceed)
class TestExtractFieldsMiddleware:
    """ExtractFieldsMiddleware copies push-dict entries onto the context."""

    @pytest.mark.asyncio
    async def test_extracts_fields(self):
        """Every scalar field of the push dict appears on the context afterwards."""
        scalar_fields = {
            "from_account": "alice",
            "group_code": "grp-1",
            "group_name": "Test Group",
            "sender_nickname": "Alice",
            "msg_id": "msg-001",
            "cloud_custom_data": '{"key": "val"}',
        }
        push = dict(scalar_fields)
        push["msg_body"] = [{"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}]
        context = make_ctx(push=push)
        proceed = AsyncMock()

        await ExtractFieldsMiddleware()(context, proceed)

        for field, value in scalar_fields.items():
            assert getattr(context, field) == value, field
        assert len(context.msg_body) == 1
        proceed.assert_awaited_once()
class TestDedupMiddleware:
    """DedupMiddleware for unseen, already-seen and id-less messages."""

    @staticmethod
    async def _run(context):
        """Run the middleware once on *context* and return the ``next_fn`` mock."""
        proceed = AsyncMock()
        await DedupMiddleware()(context, proceed)
        return proceed

    @pytest.mark.asyncio
    async def test_new_message_passes(self):
        """A message id not seen before is let through."""
        context = make_ctx(adapter=make_adapter(), msg_id="unique-msg-001")
        proceed = await self._run(context)
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_duplicate_stops_pipeline(self):
        """A message id already recorded by the adapter stops the pipeline."""
        adapter = make_adapter()
        # The first lookup registers the id, so the middleware sees a repeat.
        adapter._dedup.is_duplicate("dup-msg-001")
        context = make_ctx(adapter=adapter, msg_id="dup-msg-001")
        proceed = await self._run(context)
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_empty_msg_id_passes(self):
        """A message with an empty msg_id is let through."""
        proceed = await self._run(make_ctx(msg_id=""))
        proceed.assert_awaited_once()
class TestSkipSelfMiddleware:
    """SkipSelfMiddleware for the bot's own messages versus other senders."""

    @staticmethod
    async def _run_for_sender(sender):
        """Run the middleware for a message from *sender*; return the next_fn mock."""
        adapter = make_adapter()
        adapter._bot_id = "bot_123"
        context = make_ctx(adapter=adapter, from_account=sender)
        proceed = AsyncMock()
        await SkipSelfMiddleware()(context, proceed)
        return proceed

    @pytest.mark.asyncio
    async def test_self_message_stops(self):
        """A message whose sender equals the bot id stops the pipeline."""
        proceed = await self._run_for_sender("bot_123")
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_other_message_passes(self):
        """A message from any other account is let through."""
        proceed = await self._run_for_sender("alice")
        proceed.assert_awaited_once()
class TestChatRoutingMiddleware:
    """ChatRoutingMiddleware derives chat_id / chat_type / chat_name."""

    @pytest.mark.asyncio
    async def test_group_routing(self):
        """A context with a group_code is routed as a group chat."""
        context = make_ctx(group_code="grp-1", group_name="Test Group")
        proceed = AsyncMock()

        await ChatRoutingMiddleware()(context, proceed)

        routed = (context.chat_id, context.chat_type, context.chat_name)
        assert routed == ("group:grp-1", "group", "Test Group")
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_dm_routing(self):
        """A context without a group_code is routed as a direct message."""
        context = make_ctx(from_account="alice", sender_nickname="Alice")
        proceed = AsyncMock()

        await ChatRoutingMiddleware()(context, proceed)

        routed = (context.chat_id, context.chat_type, context.chat_name)
        assert routed == ("direct:alice", "dm", "Alice")
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_dm_routing_no_nickname(self):
        """With an empty nickname the chat name falls back to from_account."""
        context = make_ctx(from_account="alice", sender_nickname="")
        proceed = AsyncMock()

        await ChatRoutingMiddleware()(context, proceed)

        assert context.chat_name == "alice"
class TestAccessGuardMiddleware:
    """AccessGuardMiddleware under the different DM / group access policies."""

    @staticmethod
    def _policy(dm="open", dm_allow=(), group="open", group_allow=()):
        """Assemble an AccessPolicy; allow-lists are handed over as fresh lists."""
        return AccessPolicy(
            dm_policy=dm,
            dm_allow_from=list(dm_allow),
            group_policy=group,
            group_allow_from=list(group_allow),
        )

    @staticmethod
    async def _guard(policy, **ctx_fields):
        """Run the guard once under *policy* and return the ``next_fn`` mock."""
        adapter = make_adapter()
        adapter._access_policy = policy
        context = make_ctx(adapter=adapter, **ctx_fields)
        proceed = AsyncMock()
        await AccessGuardMiddleware()(context, proceed)
        return proceed

    @pytest.mark.asyncio
    async def test_open_policy_passes(self):
        """With everything open, a DM is let through."""
        proceed = await self._guard(
            self._policy(), chat_type="dm", from_account="alice"
        )
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_disabled_dm_stops(self):
        """dm_policy=disabled stops a DM."""
        proceed = await self._guard(
            self._policy(dm="disabled"), chat_type="dm", from_account="alice"
        )
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_allowlist_dm_allowed(self):
        """dm_policy=allowlist lets a listed sender through."""
        proceed = await self._guard(
            self._policy(dm="allowlist", dm_allow=["alice"]),
            chat_type="dm",
            from_account="alice",
        )
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_allowlist_dm_blocked(self):
        """dm_policy=allowlist stops a sender that is not listed."""
        proceed = await self._guard(
            self._policy(dm="allowlist", dm_allow=["bob"]),
            chat_type="dm",
            from_account="alice",
        )
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_disabled_group_stops(self):
        """group_policy=disabled stops a group message."""
        proceed = await self._guard(
            self._policy(group="disabled"), chat_type="group", group_code="grp-1"
        )
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_allowlist_group_allowed(self):
        """group_policy=allowlist lets a listed group_code through."""
        proceed = await self._guard(
            self._policy(group="allowlist", group_allow=["grp-1"]),
            chat_type="group",
            group_code="grp-1",
        )
        proceed.assert_awaited_once()
class TestExtractContentMiddleware:
    """ExtractContentMiddleware on a body mixing a text and an image element."""

    @pytest.mark.asyncio
    async def test_extracts_text_and_media(self):
        """The text lands in raw_text and the image yields one media ref."""
        text_elem = {"msg_type": "TIMTextElem", "msg_content": {"text": "Hello!"}}
        image_elem = {
            "msg_type": "TIMImageElem",
            "msg_content": {
                "image_info_array": [{"url": "https://img.example.com/1.jpg"}],
            },
        }
        context = make_ctx(adapter=make_adapter(), msg_body=[text_elem, image_elem])
        proceed = AsyncMock()

        await ExtractContentMiddleware()(context, proceed)

        assert "Hello!" in context.raw_text
        assert len(context.media_refs) == 1
        assert context.media_refs[0]["kind"] == "image"
        proceed.assert_awaited_once()
class TestPlaceholderFilterMiddleware:
    """PlaceholderFilterMiddleware for placeholder-only versus real content."""

    @staticmethod
    async def _run(**ctx_fields):
        """Run the filter on a context with *ctx_fields*; return the next_fn mock."""
        proceed = AsyncMock()
        await PlaceholderFilterMiddleware()(make_ctx(**ctx_fields), proceed)
        return proceed

    @pytest.mark.asyncio
    async def test_placeholder_stops(self):
        """A bare "[image]" placeholder with no media stops the pipeline."""
        proceed = await self._run(raw_text="[image]", media_refs=[])
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_placeholder_with_media_passes(self):
        """The same placeholder is let through when a media ref accompanies it."""
        attachment = {"kind": "image", "url": "https://img.example.com/1.jpg"}
        proceed = await self._run(raw_text="[image]", media_refs=[attachment])
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_normal_text_passes(self):
        """Ordinary text is let through."""
        proceed = await self._run(raw_text="Hello world!")
        proceed.assert_awaited_once()
class TestGroupAtGuardMiddleware:
    """GroupAtGuardMiddleware: DMs, @-mentions, plain group chatter, owner commands."""

    @pytest.mark.asyncio
    async def test_dm_passes(self):
        """A DM is let through."""
        context = make_ctx(adapter=make_adapter(), chat_type="dm")
        proceed = AsyncMock()

        await GroupAtGuardMiddleware()(context, proceed)

        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_group_with_at_bot_passes(self):
        """A group message containing an @-mention of the bot is let through."""
        adapter = make_adapter()
        adapter._bot_id = "bot_123"
        mention_payload = json.dumps(
            {"elem_type": 1002, "text": "@Bot", "user_id": "bot_123"}
        )
        mention_elem = {
            "msg_type": "TIMCustomElem",
            "msg_content": {"data": mention_payload},
        }
        context = make_ctx(
            adapter=adapter,
            chat_type="group",
            chat_id="group:grp-1",
            msg_body=[mention_elem],
            from_account="alice",
            sender_nickname="Alice",
            raw_text="Hello",
            source=MagicMock(),
        )
        proceed = AsyncMock()

        await GroupAtGuardMiddleware()(context, proceed)

        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_group_without_at_bot_observes(self):
        """A group message without an @-mention does not continue the pipeline."""
        adapter = make_adapter()
        adapter._bot_id = "bot_123"
        adapter._session_store = None  # No session store -> observe is a no-op
        plain_elem = {"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}
        context = make_ctx(
            adapter=adapter,
            chat_type="group",
            chat_id="group:grp-1",
            msg_body=[plain_elem],
            from_account="alice",
            sender_nickname="Alice",
            raw_text="hi",
            source=MagicMock(),
        )
        proceed = AsyncMock()

        await GroupAtGuardMiddleware()(context, proceed)

        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_owner_command_skips_at_check(self):
        """With owner_command set, a group message passes without any mention."""
        adapter = make_adapter()
        adapter._bot_id = "bot_123"
        context = make_ctx(
            adapter=adapter,
            chat_type="group",
            msg_body=[],
            owner_command="/new",
            source=MagicMock(),
        )
        proceed = AsyncMock()

        await GroupAtGuardMiddleware()(context, proceed)

        proceed.assert_awaited_once()
# ============================================================
# 4. Factory Tests
# ============================================================
class TestCreateInboundPipeline:
    """Checks on the pipeline produced by ``InboundPipelineBuilder.build()``."""

    def test_default_pipeline_has_all_middlewares(self):
        """InboundPipelineBuilder.build() creates pipeline with all expected middlewares."""
        pipeline = InboundPipelineBuilder.build()
        expected = [
            "decode",
            "extract-fields",
            "dedup",
            "skip-self",
            "chat-routing",
            "access-guard",
            "extract-content",
            "placeholder-filter",
            "owner-command",
            "build-source",
            "group-at-guard",
            "classify-msg-type",
            "quote-context",
            "media-resolve",
            "dispatch",
        ]
        # List equality checks both membership and registration order.
        assert pipeline.middleware_names == expected

    def test_pipeline_is_customizable(self):
        """Pipeline can be customized after creation."""
        pipeline = InboundPipelineBuilder.build()

        async def custom_mw(ctx, next_fn):
            await next_fn()

        pipeline.use_before("dispatch", "custom", custom_mw)

        names = pipeline.middleware_names
        assert "custom" in names
        # The inserted middleware must sit ahead of the target it was placed before.
        assert names.index("custom") < names.index("dispatch")
# ============================================================
# 5. End-to-End Pipeline Integration Tests
# ============================================================
class TestPipelineIntegration:
    """Runs JSON pushes through the complete pipeline built by the builder."""

    @staticmethod
    def _open_policy(dm_policy="open"):
        """AccessPolicy with open group access and the given DM policy."""
        return AccessPolicy(
            dm_policy=dm_policy,
            dm_allow_from=[],
            group_policy="open",
            group_allow_from=[],
        )

    @staticmethod
    async def _run_push(adapter, **push_fields):
        """Push one JSON frame through a freshly built pipeline; return the context."""
        frame = make_json_push(**push_fields)
        context = InboundContext(adapter=adapter, raw_frames=[frame])
        await InboundPipelineBuilder.build().execute(context)
        return context

    @pytest.mark.asyncio
    async def test_full_dm_message_flow(self):
        """A DM from another user populates the context all the way to source."""
        adapter = make_adapter()
        adapter._bot_id = "bot_123"
        adapter._access_policy = self._open_policy()
        adapter.handle_message = AsyncMock()
        adapter._resolve_inbound_media_urls = AsyncMock(return_value=([], []))

        context = await self._run_push(
            adapter,
            from_account="alice",
            to_account="bot_123",
            text="Hello bot!",
            msg_id="msg-e2e-001",
        )

        # Verify the context was populated by the successive stages.
        assert context.decoded_via == "json"
        assert context.from_account == "alice"
        assert context.chat_type == "dm"
        assert context.chat_id == "direct:alice"
        assert "Hello bot!" in context.raw_text
        assert context.source is not None

    @pytest.mark.asyncio
    async def test_self_message_filtered(self):
        """A message sent by the bot itself never gets a source built."""
        adapter = make_adapter()
        adapter._bot_id = "bot_123"

        context = await self._run_push(
            adapter,
            from_account="bot_123",
            to_account="bot_123",
            text="echo",
            msg_id="msg-self-001",
        )

        # Pipeline should have stopped at skip-self — no source built
        assert context.source is None

    @pytest.mark.asyncio
    async def test_duplicate_message_filtered(self):
        """The second delivery of the same msg_id stops before chat routing."""
        adapter = make_adapter()
        adapter._bot_id = "bot_123"
        frame = make_json_push(
            from_account="alice",
            text="Hello!",
            msg_id="msg-dup-001",
        )
        # One pipeline instance handles both deliveries.
        engine = InboundPipelineBuilder.build()

        first = InboundContext(adapter=adapter, raw_frames=[frame])
        await engine.execute(first)
        assert first.from_account == "alice"

        second = InboundContext(adapter=adapter, raw_frames=[frame])
        await engine.execute(second)
        # Dedup should stop pipeline before chat routing
        assert second.chat_type == ""

    @pytest.mark.asyncio
    async def test_blocked_dm_filtered(self):
        """With dm_policy=disabled no content is extracted from a DM."""
        adapter = make_adapter()
        adapter._bot_id = "bot_123"
        adapter._access_policy = self._open_policy(dm_policy="disabled")

        context = await self._run_push(
            adapter,
            from_account="alice",
            text="Hello!",
            msg_id="msg-blocked-001",
        )

        # Pipeline stopped at access-guard — no content extracted
        assert context.raw_text == ""

    @pytest.mark.asyncio
    async def test_adapter_has_pipeline(self):
        """A constructed adapter carries an InboundPipeline in _inbound_pipeline."""
        adapter = make_adapter()
        assert hasattr(adapter, "_inbound_pipeline")
        assert isinstance(adapter._inbound_pipeline, InboundPipeline)
# Allow running this file directly as a script: passes this file's own path
# to pytest in verbose mode.
# NOTE(review): this guard sits ahead of the section 6 test classes rather
# than at the end of the file — consider moving it to the bottom.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
# ============================================================
# 6. OOP Middleware Tests
# ============================================================
class TestInboundMiddlewareABC:
    """Contract of the InboundMiddleware abstract base class."""

    def test_cannot_instantiate_abc(self):
        """Instantiating InboundMiddleware directly raises TypeError."""
        with pytest.raises(TypeError):
            InboundMiddleware()

    def test_subclass_must_implement_handle(self):
        """A subclass that omits handle() raises TypeError."""
        with pytest.raises(TypeError):

            class Incomplete(InboundMiddleware):
                name = "bad"

            Incomplete()

    def test_subclass_with_handle_works(self):
        """A subclass that defines handle() can be instantiated and keeps its name."""

        class Complete(InboundMiddleware):
            name = "good"

            async def handle(self, ctx, next_fn):
                await next_fn()

        assert Complete().name == "good"

    @pytest.mark.asyncio
    async def test_callable_protocol(self):
        """Calling an instance runs its handle() with the same arguments."""

        class Marker(InboundMiddleware):
            name = "test"

            async def handle(self, ctx, next_fn):
                ctx.raw_text = "called"
                await next_fn()

        context = make_ctx()
        proceed = AsyncMock()

        await Marker()(context, proceed)  # goes through __call__

        assert context.raw_text == "called"
        proceed.assert_awaited_once()

    def test_repr(self):
        """repr() mentions both the class name and the middleware name."""

        class MyMW(InboundMiddleware):
            name = "my-mw"

            async def handle(self, ctx, next_fn):
                pass

        instance = MyMW()
        for fragment in ("MyMW", "my-mw"):
            assert fragment in repr(instance)
class TestMiddlewareClasses:
    """Each concrete middleware subclasses InboundMiddleware and has its expected name."""

    # (class, expected ``name`` attribute) pairs shared by all tests below.
    MIDDLEWARE_CLASSES = [
        (DecodeMiddleware, "decode"),
        (ExtractFieldsMiddleware, "extract-fields"),
        (DedupMiddleware, "dedup"),
        (SkipSelfMiddleware, "skip-self"),
        (ChatRoutingMiddleware, "chat-routing"),
        (AccessGuardMiddleware, "access-guard"),
        (ExtractContentMiddleware, "extract-content"),
        (PlaceholderFilterMiddleware, "placeholder-filter"),
        (OwnerCommandMiddleware, "owner-command"),
        (BuildSourceMiddleware, "build-source"),
        (GroupAtGuardMiddleware, "group-at-guard"),
        (DispatchMiddleware, "dispatch"),
    ]

    @pytest.mark.parametrize(("cls", "expected_name"), MIDDLEWARE_CLASSES)
    def test_is_inbound_middleware(self, cls, expected_name):
        """The class derives from InboundMiddleware."""
        assert issubclass(cls, InboundMiddleware)

    @pytest.mark.parametrize(("cls", "expected_name"), MIDDLEWARE_CLASSES)
    def test_has_correct_name(self, cls, expected_name):
        """A default-constructed instance reports the expected name."""
        instance = cls()
        assert instance.name == expected_name

    @pytest.mark.parametrize(("cls", "expected_name"), MIDDLEWARE_CLASSES)
    def test_is_callable(self, cls, expected_name):
        """A default-constructed instance is callable."""
        instance = cls()
        assert callable(instance)
class TestPipelineOOPRegistration:
    """InboundPipeline registration using InboundMiddleware instances."""

    @staticmethod
    def _passthrough_instance(mw_name):
        """Return an instance of a pass-through middleware named *mw_name*."""

        class _Passthrough(InboundMiddleware):
            name = mw_name

            async def handle(self, ctx, next_fn):
                await next_fn()

        return _Passthrough()

    @pytest.mark.asyncio
    async def test_use_with_middleware_instance(self):
        """use() given only an instance registers it under the instance's name."""

        class Stamp(InboundMiddleware):
            name = "test-mw"

            async def handle(self, ctx, next_fn):
                ctx.raw_text = "oop-works"
                await next_fn()

        engine = InboundPipeline().use(Stamp())
        assert engine.middleware_names == ["test-mw"]

        context = make_ctx()
        await engine.execute(context)
        assert context.raw_text == "oop-works"

    @pytest.mark.asyncio
    async def test_mixed_oop_and_functional(self):
        """Class-based and function-based middlewares can share one pipeline."""
        trace = []

        class ClassBased(InboundMiddleware):
            name = "oop"

            async def handle(self, ctx, next_fn):
                trace.append("oop")
                await next_fn()

        async def function_based(ctx, next_fn):
            trace.append("func")
            await next_fn()

        engine = InboundPipeline().use(ClassBased()).use("func", function_based)
        assert engine.middleware_names == ["oop", "func"]

        await engine.execute(make_ctx())
        assert trace == ["oop", "func"]

    def test_use_before_with_middleware_instance(self):
        """use_before accepts a middleware instance in place of name + callable."""
        engine = (
            InboundPipeline()
            .use(self._passthrough_instance("a"))
            .use(self._passthrough_instance("c"))
        )
        engine.use_before("c", self._passthrough_instance("b"))
        assert engine.middleware_names == ["a", "b", "c"]

    def test_use_after_with_middleware_instance(self):
        """use_after accepts a middleware instance in place of name + callable."""
        engine = (
            InboundPipeline()
            .use(self._passthrough_instance("a"))
            .use(self._passthrough_instance("c"))
        )
        engine.use_after("a", self._passthrough_instance("b"))
        assert engine.middleware_names == ["a", "b", "c"]