mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
1030 lines
34 KiB
Python
1030 lines
34 KiB
Python
"""
test_yuanbao_pipeline.py - Unit tests for the inbound middleware pipeline.

Tests cover:
1. InboundPipeline engine (use, use_before, use_after, remove, execute)
2. InboundContext dataclass
3. Individual middlewares (DecodeMiddleware, DedupMiddleware, SkipSelfMiddleware, etc.)
4. InboundPipelineBuilder
5. End-to-end pipeline integration
6. OOP middleware ABC and class tests
"""
|
|
|
|
import sys
|
|
import os
|
|
import json
|
|
import asyncio
|
|
|
|
# Make the parent of this file's directory importable so that the
# `gateway.*` imports below resolve when the project is not installed.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_REPO_ROOT = os.path.dirname(_THIS_DIR)
if _REPO_ROOT not in sys.path:
    # Front of the list: this checkout wins over any installed copy.
    sys.path.insert(0, _REPO_ROOT)
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock
|
|
|
|
from gateway.platforms.yuanbao import (
|
|
InboundContext,
|
|
InboundMiddleware,
|
|
InboundPipeline,
|
|
DecodeMiddleware,
|
|
ExtractFieldsMiddleware,
|
|
DedupMiddleware,
|
|
SkipSelfMiddleware,
|
|
ChatRoutingMiddleware,
|
|
AccessPolicy,
|
|
AccessGuardMiddleware,
|
|
ExtractContentMiddleware,
|
|
PlaceholderFilterMiddleware,
|
|
OwnerCommandMiddleware,
|
|
BuildSourceMiddleware,
|
|
GroupAtGuardMiddleware,
|
|
DispatchMiddleware,
|
|
InboundPipelineBuilder,
|
|
YuanbaoAdapter,
|
|
)
|
|
from gateway.config import Platform, PlatformConfig
|
|
|
|
|
|
# ============================================================
|
|
# Helpers
|
|
# ============================================================
|
|
|
|
def make_config(**kwargs):
    """Build a PlatformConfig pre-filled with test credentials.

    Args:
        **kwargs: Forwarded to ``PlatformConfig``. An optional ``extra``
            mapping may be supplied; any of ``app_id``, ``app_secret``,
            ``ws_url`` and ``api_domain`` missing from it are filled with
            test values.

    Returns:
        A ``PlatformConfig`` whose ``extra`` contains at least the four
        keys listed above.
    """
    # Work on a copy so the caller's dict is never mutated by setdefault();
    # also tolerates an explicit extra=None.
    extra = dict(kwargs.pop("extra", None) or {})
    test_defaults = {
        "app_id": "test_key",
        "app_secret": "test_secret",
        "ws_url": "wss://test.example.com/ws",
        "api_domain": "https://test.example.com",
    }
    for key, value in test_defaults.items():
        extra.setdefault(key, value)
    return PlatformConfig(extra=extra, **kwargs)
|
|
|
|
|
|
def make_adapter(**kwargs) -> YuanbaoAdapter:
    """Return a YuanbaoAdapter built from ``make_config(**kwargs)``.

    The adapter's ``_bot_id`` is preset to ``"bot_123"`` so tests that compare
    sender ids against the bot have a known value.
    """
    test_adapter = YuanbaoAdapter(make_config(**kwargs))
    test_adapter._bot_id = "bot_123"
    return test_adapter
|
|
|
|
|
|
def make_ctx(adapter=None, conn_data=b"", **overrides) -> InboundContext:
    """Build an InboundContext for a test.

    Args:
        adapter: Adapter to attach; a fresh ``make_adapter()`` is used when None.
        conn_data: Raw frame bytes. A falsy value yields an empty ``raw_frames``
            list, otherwise a single-element list.
        **overrides: Attribute names and values set on the context after
            construction.
    """
    if adapter is None:
        adapter = make_adapter()
    frames = []
    if conn_data:
        frames.append(conn_data)
    context = InboundContext(adapter=adapter, raw_frames=frames)
    for attr_name, attr_value in overrides.items():
        setattr(context, attr_name, attr_value)
    return context
|
|
|
|
|
|
def make_json_push(
    from_account="alice",
    to_account="bot_123",
    group_code="",
    text="Hello!",
    msg_id="msg-001",
) -> bytes:
    """Build a JSON callback_command push payload.

    Note: MsgContent inner fields use lowercase ("text" not "Text")
    because _extract_text() looks for lowercase keys.

    Args:
        from_account: Value for the ``From_Account`` field.
        to_account: Value for the ``To_Account`` field.
        group_code: When non-empty, the payload becomes a group callback and
            carries this value as ``GroupId``.
        text: Text placed in the single ``TIMTextElem`` body element.
        msg_id: Value for the ``MsgKey`` field.

    Returns:
        The payload serialized with ``json.dumps`` and encoded as UTF-8.
    """
    msg_body = [{"MsgType": "TIMTextElem", "MsgContent": {"text": text}}]
    # Key order is kept stable so the serialized bytes are deterministic.
    push = {
        "CallbackCommand": "C2C.CallbackAfterSendMsg",
        "From_Account": from_account,
        "To_Account": to_account,
        "MsgBody": msg_body,
        "MsgKey": msg_id,
    }
    if group_code:
        # Group messages use a different callback command and add GroupId.
        push["CallbackCommand"] = "Group.CallbackAfterSendMsg"
        push["GroupId"] = group_code
    return json.dumps(push).encode("utf-8")
|
|
|
|
|
|
# ============================================================
|
|
# 1. InboundPipeline Engine Tests
|
|
# ============================================================
|
|
|
|
class TestInboundPipeline:
    """Exercise the InboundPipeline engine using throwaway function middlewares."""

    @staticmethod
    async def _passthrough(ctx, next_fn):
        """Middleware that does nothing except hand control to the next one."""
        await next_fn()

    @staticmethod
    def _recorder(trace, tag):
        """Return a middleware that appends ``tag`` to ``trace`` and continues."""
        async def _mw(ctx, next_fn):
            trace.append(tag)
            await next_fn()
        return _mw

    @pytest.mark.asyncio
    async def test_empty_pipeline(self):
        """Executing a pipeline with no middlewares completes without raising."""
        engine = InboundPipeline()
        context = make_ctx()
        await engine.execute(context)

    @pytest.mark.asyncio
    async def test_single_middleware(self):
        """A lone registered middleware is invoked once."""
        invocations = []

        async def record(ctx, next_fn):
            invocations.append("mw")
            await next_fn()

        engine = InboundPipeline().use("test", record)
        context = make_ctx()
        await engine.execute(context)
        assert invocations == ["mw"]

    @pytest.mark.asyncio
    async def test_middleware_order(self):
        """Middlewares run in the order they were registered."""
        trace = []

        engine = InboundPipeline()
        for tag in ("a", "b", "c"):
            # Re-bind to the return value, mirroring fluent chaining.
            engine = engine.use(tag, self._recorder(trace, tag))

        await engine.execute(make_ctx())
        assert trace == ["a", "b", "c"]

    @pytest.mark.asyncio
    async def test_middleware_can_stop_pipeline(self):
        """Not awaiting next_fn prevents later middlewares from running."""
        trace = []

        async def halt(ctx, next_fn):
            # Deliberately never awaits next_fn.
            trace.append("stop")

        async def downstream(ctx, next_fn):
            trace.append("after")
            await next_fn()

        engine = InboundPipeline().use("stop", halt).use("after", downstream)
        await engine.execute(make_ctx())
        # "after" must be absent from the trace.
        assert trace == ["stop"]

    @pytest.mark.asyncio
    async def test_conditional_guard_skip(self):
        """A middleware whose ``when`` predicate returns False is bypassed."""
        trace = []

        engine = InboundPipeline()
        engine = engine.use("a", self._recorder(trace, "a"))
        engine = engine.use(
            "skipped", self._recorder(trace, "skipped"), when=lambda ctx: False
        )
        engine = engine.use("c", self._recorder(trace, "c"))

        await engine.execute(make_ctx())
        assert trace == ["a", "c"]

    @pytest.mark.asyncio
    async def test_conditional_guard_pass(self):
        """A middleware whose ``when`` predicate returns True still runs."""
        trace = []

        engine = InboundPipeline().use(
            "mw", self._recorder(trace, "mw"), when=lambda ctx: True
        )
        await engine.execute(make_ctx())
        assert trace == ["mw"]

    def test_use_before(self):
        """use_before places the new middleware ahead of the named target."""
        engine = InboundPipeline().use("a", self._passthrough).use("c", self._passthrough)
        engine.use_before("c", "b", self._passthrough)
        assert engine.middleware_names == ["a", "b", "c"]

    def test_use_before_nonexistent_appends(self):
        """use_before with an unknown target adds the middleware at the end."""
        engine = InboundPipeline().use("a", self._passthrough)
        engine.use_before("nonexistent", "b", self._passthrough)
        assert engine.middleware_names == ["a", "b"]

    def test_use_after(self):
        """use_after places the new middleware right behind the named target."""
        engine = InboundPipeline().use("a", self._passthrough).use("c", self._passthrough)
        engine.use_after("a", "b", self._passthrough)
        assert engine.middleware_names == ["a", "b", "c"]

    def test_use_after_nonexistent_appends(self):
        """use_after with an unknown target adds the middleware at the end."""
        engine = InboundPipeline().use("a", self._passthrough)
        engine.use_after("nonexistent", "b", self._passthrough)
        assert engine.middleware_names == ["a", "b"]

    def test_remove(self):
        """remove drops the middleware registered under the given name."""
        engine = InboundPipeline()
        for tag in ("a", "b", "c"):
            engine = engine.use(tag, self._passthrough)
        engine.remove("b")
        assert engine.middleware_names == ["a", "c"]

    def test_remove_nonexistent_is_noop(self):
        """Removing an unknown name leaves the pipeline unchanged."""
        engine = InboundPipeline().use("a", self._passthrough)
        engine.remove("nonexistent")
        assert engine.middleware_names == ["a"]

    @pytest.mark.asyncio
    async def test_error_propagation(self):
        """An exception raised inside a middleware reaches the execute() caller."""
        async def explode(ctx, next_fn):
            raise ValueError("test error")

        engine = InboundPipeline().use("error", explode)
        with pytest.raises(ValueError, match="test error"):
            await engine.execute(make_ctx())

    def test_middleware_names_property(self):
        """middleware_names lists registered names in registration order."""
        registered = ["decode", "dedup", "dispatch"]
        engine = InboundPipeline()
        for tag in registered:
            engine = engine.use(tag, self._passthrough)
        assert engine.middleware_names == ["decode", "dedup", "dispatch"]

    @pytest.mark.asyncio
    async def test_onion_model(self):
        """Code after ``await next_fn()`` runs once inner middlewares finish."""
        trace = []

        async def outer(ctx, next_fn):
            trace.append("outer-before")
            await next_fn()
            trace.append("outer-after")

        engine = InboundPipeline().use("outer", outer).use(
            "inner", self._recorder(trace, "inner")
        )
        await engine.execute(make_ctx())
        assert trace == ["outer-before", "inner", "outer-after"]
|
|
|
|
|
|
# ============================================================
|
|
# 2. InboundContext Tests
|
|
# ============================================================
|
|
|
|
class TestInboundContext:
    """Check InboundContext construction defaults and attribute mutability."""

    def test_default_values(self):
        """A context built with only an adapter exposes empty/None defaults."""
        ctx = InboundContext(adapter=make_adapter())

        # Fields expected to start as an empty list or empty string.
        empty_defaults = {
            "raw_frames": [],
            "decoded_via": "",
            "from_account": "",
            "group_code": "",
            "msg_body": [],
            "msg_id": "",
            "chat_id": "",
            "chat_type": "",
            "raw_text": "",
            "media_refs": [],
        }
        for field_name, expected in empty_defaults.items():
            assert getattr(ctx, field_name) == expected

        # Fields expected to start unset.
        for field_name in ("push", "owner_command", "source", "msg_type"):
            assert getattr(ctx, field_name) is None

    def test_mutable_fields(self):
        """Attributes can be reassigned after construction."""
        ctx = make_ctx()
        ctx.from_account = "alice"
        ctx.chat_type = "dm"
        assert (ctx.from_account, ctx.chat_type) == ("alice", "dm")
|
|
|
|
|
|
# ============================================================
|
|
# 3. Individual Middleware Tests
|
|
# ============================================================
|
|
|
|
class TestDecodeMiddleware:
    """DecodeMiddleware behaviour for JSON, empty, and arbitrary binary frames."""

    @pytest.mark.asyncio
    async def test_json_decode(self):
        """A JSON push frame is decoded into ctx.push and the chain continues."""
        frame = make_json_push(from_account="alice", text="hi")
        ctx = make_ctx(conn_data=frame)
        proceed = AsyncMock()

        await DecodeMiddleware()(ctx, proceed)

        assert ctx.push is not None
        assert ctx.decoded_via == "json"
        assert ctx.push.get("from_account") == "alice"
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_empty_data_stops_pipeline(self):
        """With empty conn_data nothing is decoded and next_fn is not awaited."""
        ctx = make_ctx(conn_data=b"")
        proceed = AsyncMock()

        await DecodeMiddleware()(ctx, proceed)

        assert ctx.push is None
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_invalid_data_may_produce_garbage(self):
        """Arbitrary bytes must not make DecodeMiddleware raise.

        The protobuf parser is lenient and may turn arbitrary bytes into
        "seemingly valid" fields; that is expected. Downstream middlewares
        (dedup, skip-self, etc.) are relied on to filter such garbage.
        """
        ctx = make_ctx(conn_data=b"\x00\x01\x02\x03")
        proceed = AsyncMock()

        await DecodeMiddleware()(ctx, proceed)

        # Whether or not a result was produced is irrelevant here; reaching
        # this line without an exception is the invariant under test.
        assert True
|
|
|
|
|
|
class TestExtractFieldsMiddleware:
    """ExtractFieldsMiddleware copies push-dict entries onto the context."""

    @pytest.mark.asyncio
    async def test_extracts_fields(self):
        """Each push key ends up on the matching ctx attribute."""
        payload = {
            "from_account": "alice",
            "group_code": "grp-1",
            "group_name": "Test Group",
            "sender_nickname": "Alice",
            "msg_body": [{"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}],
            "msg_id": "msg-001",
            "cloud_custom_data": '{"key": "val"}',
        }
        ctx = make_ctx(push=payload)
        proceed = AsyncMock()

        await ExtractFieldsMiddleware()(ctx, proceed)

        assert ctx.from_account == "alice"
        assert ctx.group_code == "grp-1"
        assert ctx.group_name == "Test Group"
        assert ctx.sender_nickname == "Alice"
        assert len(ctx.msg_body) == 1
        assert ctx.msg_id == "msg-001"
        assert ctx.cloud_custom_data == '{"key": "val"}'
        proceed.assert_awaited_once()
|
|
|
|
|
|
class TestDedupMiddleware:
    """DedupMiddleware lets unseen message ids through and drops repeats."""

    @pytest.mark.asyncio
    async def test_new_message_passes(self):
        """A msg_id not seen before continues down the chain."""
        bot = make_adapter()
        ctx = make_ctx(adapter=bot, msg_id="unique-msg-001")
        proceed = AsyncMock()

        await DedupMiddleware()(ctx, proceed)

        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_duplicate_stops_pipeline(self):
        """A msg_id already registered with the adapter's dedup is dropped."""
        bot = make_adapter()
        # Prime the dedup store so the id counts as already seen.
        bot._dedup.is_duplicate("dup-msg-001")

        ctx = make_ctx(adapter=bot, msg_id="dup-msg-001")
        proceed = AsyncMock()

        await DedupMiddleware()(ctx, proceed)

        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_empty_msg_id_passes(self):
        """A message with an empty msg_id continues down the chain."""
        ctx = make_ctx(msg_id="")
        proceed = AsyncMock()

        await DedupMiddleware()(ctx, proceed)

        proceed.assert_awaited_once()
|
|
|
|
|
|
class TestSkipSelfMiddleware:
    """SkipSelfMiddleware drops messages whose sender is the bot itself."""

    @staticmethod
    def _ctx_from(sender):
        """Return a context on a bot_123 adapter with the given from_account."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        return make_ctx(adapter=bot, from_account=sender)

    @pytest.mark.asyncio
    async def test_self_message_stops(self):
        """A message sent by the bot's own id does not reach next_fn."""
        ctx = self._ctx_from("bot_123")
        proceed = AsyncMock()

        await SkipSelfMiddleware()(ctx, proceed)

        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_other_message_passes(self):
        """A message from any other account reaches next_fn."""
        ctx = self._ctx_from("alice")
        proceed = AsyncMock()

        await SkipSelfMiddleware()(ctx, proceed)

        proceed.assert_awaited_once()
|
|
|
|
|
|
class TestChatRoutingMiddleware:
    """ChatRoutingMiddleware derives chat_id, chat_type and chat_name."""

    @pytest.mark.asyncio
    async def test_group_routing(self):
        """A context carrying a group_code is routed as a group chat."""
        ctx = make_ctx(group_code="grp-1", group_name="Test Group")
        proceed = AsyncMock()

        await ChatRoutingMiddleware()(ctx, proceed)

        routed = (ctx.chat_id, ctx.chat_type, ctx.chat_name)
        assert routed == ("group:grp-1", "group", "Test Group")
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_dm_routing(self):
        """A context without a group_code is routed as a direct message."""
        ctx = make_ctx(from_account="alice", sender_nickname="Alice")
        proceed = AsyncMock()

        await ChatRoutingMiddleware()(ctx, proceed)

        routed = (ctx.chat_id, ctx.chat_type, ctx.chat_name)
        assert routed == ("direct:alice", "dm", "Alice")
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_dm_routing_no_nickname(self):
        """With an empty nickname, chat_name falls back to from_account."""
        ctx = make_ctx(from_account="alice", sender_nickname="")
        proceed = AsyncMock()

        await ChatRoutingMiddleware()(ctx, proceed)

        assert ctx.chat_name == "alice"
|
|
|
|
|
|
class TestAccessGuardMiddleware:
    """AccessGuardMiddleware decisions for DM and group access policies."""

    @staticmethod
    def _adapter_with_policy(
        dm_policy="open",
        dm_allow_from=(),
        group_policy="open",
        group_allow_from=(),
    ):
        """Return a test adapter whose _access_policy is built from the arguments."""
        bot = make_adapter()
        bot._access_policy = AccessPolicy(
            dm_policy=dm_policy,
            dm_allow_from=list(dm_allow_from),
            group_policy=group_policy,
            group_allow_from=list(group_allow_from),
        )
        return bot

    @staticmethod
    async def _run_guard(ctx):
        """Run AccessGuardMiddleware on ``ctx`` and return the next_fn mock."""
        proceed = AsyncMock()
        await AccessGuardMiddleware()(ctx, proceed)
        return proceed

    @pytest.mark.asyncio
    async def test_open_policy_passes(self):
        """With dm_policy="open" a DM continues down the chain."""
        bot = self._adapter_with_policy()
        ctx = make_ctx(adapter=bot, chat_type="dm", from_account="alice")
        proceed = await self._run_guard(ctx)
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_disabled_dm_stops(self):
        """With dm_policy="disabled" a DM is stopped."""
        bot = self._adapter_with_policy(dm_policy="disabled")
        ctx = make_ctx(adapter=bot, chat_type="dm", from_account="alice")
        proceed = await self._run_guard(ctx)
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_allowlist_dm_allowed(self):
        """With dm_policy="allowlist" a listed sender continues."""
        bot = self._adapter_with_policy(dm_policy="allowlist", dm_allow_from=["alice"])
        ctx = make_ctx(adapter=bot, chat_type="dm", from_account="alice")
        proceed = await self._run_guard(ctx)
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_allowlist_dm_blocked(self):
        """With dm_policy="allowlist" an unlisted sender is stopped."""
        bot = self._adapter_with_policy(dm_policy="allowlist", dm_allow_from=["bob"])
        ctx = make_ctx(adapter=bot, chat_type="dm", from_account="alice")
        proceed = await self._run_guard(ctx)
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_disabled_group_stops(self):
        """With group_policy="disabled" a group message is stopped."""
        bot = self._adapter_with_policy(group_policy="disabled")
        ctx = make_ctx(adapter=bot, chat_type="group", group_code="grp-1")
        proceed = await self._run_guard(ctx)
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_allowlist_group_allowed(self):
        """With group_policy="allowlist" a listed group_code continues."""
        bot = self._adapter_with_policy(
            group_policy="allowlist", group_allow_from=["grp-1"]
        )
        ctx = make_ctx(adapter=bot, chat_type="group", group_code="grp-1")
        proceed = await self._run_guard(ctx)
        proceed.assert_awaited_once()
|
|
|
|
|
|
class TestExtractContentMiddleware:
    """ExtractContentMiddleware fills raw_text and media_refs from msg_body."""

    @pytest.mark.asyncio
    async def test_extracts_text_and_media(self):
        """A text element plus an image element yield text and one image ref."""
        text_elem = {"msg_type": "TIMTextElem", "msg_content": {"text": "Hello!"}}
        image_elem = {
            "msg_type": "TIMImageElem",
            "msg_content": {
                "image_info_array": [{"url": "https://img.example.com/1.jpg"}]
            },
        }
        bot = make_adapter()
        ctx = make_ctx(adapter=bot, msg_body=[text_elem, image_elem])
        proceed = AsyncMock()

        await ExtractContentMiddleware()(ctx, proceed)

        assert "Hello!" in ctx.raw_text
        assert len(ctx.media_refs) == 1
        assert ctx.media_refs[0]["kind"] == "image"
        proceed.assert_awaited_once()
|
|
|
|
|
|
class TestPlaceholderFilterMiddleware:
    """PlaceholderFilterMiddleware handling of placeholder-only text."""

    @pytest.mark.asyncio
    async def test_placeholder_stops(self):
        """Text that is only "[image]" with no media refs is stopped."""
        ctx = make_ctx(raw_text="[image]", media_refs=[])
        proceed = AsyncMock()

        await PlaceholderFilterMiddleware()(ctx, proceed)

        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_placeholder_with_media_passes(self):
        """The same placeholder text continues when a media ref is present."""
        image_ref = {"kind": "image", "url": "https://img.example.com/1.jpg"}
        ctx = make_ctx(raw_text="[image]", media_refs=[image_ref])
        proceed = AsyncMock()

        await PlaceholderFilterMiddleware()(ctx, proceed)

        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_normal_text_passes(self):
        """Ordinary text continues down the chain."""
        ctx = make_ctx(raw_text="Hello world!")
        proceed = AsyncMock()

        await PlaceholderFilterMiddleware()(ctx, proceed)

        proceed.assert_awaited_once()
|
|
|
|
|
|
class TestGroupAtGuardMiddleware:
    """GroupAtGuardMiddleware: only @-mentions or owner commands pass in groups."""

    @pytest.mark.asyncio
    async def test_dm_passes(self):
        """A DM context continues without any mention check."""
        bot = make_adapter()
        ctx = make_ctx(adapter=bot, chat_type="dm")
        proceed = AsyncMock()

        await GroupAtGuardMiddleware()(ctx, proceed)

        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_group_with_at_bot_passes(self):
        """A group message whose custom element names the bot's id continues."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        mention = {"elem_type": 1002, "text": "@Bot", "user_id": "bot_123"}
        body = [
            {"msg_type": "TIMCustomElem", "msg_content": {"data": json.dumps(mention)}},
        ]
        ctx = make_ctx(
            adapter=bot,
            chat_type="group",
            chat_id="group:grp-1",
            msg_body=body,
            from_account="alice",
            sender_nickname="Alice",
            raw_text="Hello",
            source=MagicMock(),
        )
        proceed = AsyncMock()

        await GroupAtGuardMiddleware()(ctx, proceed)

        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_group_without_at_bot_observes(self):
        """A group message with no mention of the bot does not reach next_fn."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        # With no session store the observe step has nothing to write to.
        bot._session_store = None
        plain_text = [{"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}]
        ctx = make_ctx(
            adapter=bot,
            chat_type="group",
            chat_id="group:grp-1",
            msg_body=plain_text,
            from_account="alice",
            sender_nickname="Alice",
            raw_text="hi",
            source=MagicMock(),
        )
        proceed = AsyncMock()

        await GroupAtGuardMiddleware()(ctx, proceed)

        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_owner_command_skips_at_check(self):
        """A group context with owner_command set continues without a mention."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        ctx = make_ctx(
            adapter=bot,
            chat_type="group",
            msg_body=[],
            owner_command="/new",
            source=MagicMock(),
        )
        proceed = AsyncMock()

        await GroupAtGuardMiddleware()(ctx, proceed)

        proceed.assert_awaited_once()
|
|
|
|
|
|
# ============================================================
|
|
# 4. Factory Tests
|
|
# ============================================================
|
|
|
|
class TestCreateInboundPipeline:
    """Check the pipeline produced by InboundPipelineBuilder.build()."""

    def test_default_pipeline_has_all_middlewares(self):
        """InboundPipelineBuilder.build() creates pipeline with all expected middlewares."""
        pipeline = InboundPipelineBuilder.build()
        expected = [
            "decode",
            "extract-fields",
            "dedup",
            "skip-self",
            "chat-routing",
            "access-guard",
            "extract-content",
            "placeholder-filter",
            "owner-command",
            "build-source",
            "group-at-guard",
            "classify-msg-type",
            "quote-context",
            "media-resolve",
            "dispatch",
        ]
        # The list above was previously built but never checked, so this test
        # could not fail on a wrong default pipeline.
        assert pipeline.middleware_names == expected

    def test_pipeline_is_customizable(self):
        """Pipeline can be customized after creation."""
        pipeline = InboundPipelineBuilder.build()

        async def custom_mw(ctx, next_fn):
            await next_fn()

        pipeline.use_before("dispatch", "custom", custom_mw)

        names = pipeline.middleware_names
        assert "custom" in names
        # The inserted middleware must sit ahead of "dispatch".
        assert names.index("custom") < names.index("dispatch")
|
|
|
|
|
|
# ============================================================
|
|
# 5. End-to-End Pipeline Integration Tests
|
|
# ============================================================
|
|
|
|
class TestPipelineIntegration:
    """Run the builder-produced pipeline end to end on JSON push frames."""

    @pytest.mark.asyncio
    async def test_full_dm_message_flow(self):
        """A DM push populates routing, text and source on the context."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        bot._access_policy = AccessPolicy(
            dm_policy="open",
            dm_allow_from=[],
            group_policy="open",
            group_allow_from=[],
        )
        # Stub the adapter hooks the test replaces with mocks.
        bot.handle_message = AsyncMock()
        bot._resolve_inbound_media_urls = AsyncMock(return_value=([], []))

        frame = make_json_push(
            from_account="alice",
            to_account="bot_123",
            text="Hello bot!",
            msg_id="msg-e2e-001",
        )

        ctx = InboundContext(adapter=bot, raw_frames=[frame])
        engine = InboundPipelineBuilder.build()
        await engine.execute(ctx)

        assert ctx.decoded_via == "json"
        assert ctx.from_account == "alice"
        assert ctx.chat_type == "dm"
        assert ctx.chat_id == "direct:alice"
        assert "Hello bot!" in ctx.raw_text
        assert ctx.source is not None

    @pytest.mark.asyncio
    async def test_self_message_filtered(self):
        """A push sent by the bot's own id leaves ctx.source unset."""
        bot = make_adapter()
        bot._bot_id = "bot_123"

        frame = make_json_push(
            from_account="bot_123",
            to_account="bot_123",
            text="echo",
            msg_id="msg-self-001",
        )

        ctx = InboundContext(adapter=bot, raw_frames=[frame])
        engine = InboundPipelineBuilder.build()
        await engine.execute(ctx)

        # Expected to stop at skip-self, i.e. before a source is built.
        assert ctx.source is None

    @pytest.mark.asyncio
    async def test_duplicate_message_filtered(self):
        """Replaying the same msg_id leaves the second context unrouted."""
        bot = make_adapter()
        bot._bot_id = "bot_123"

        frame = make_json_push(
            from_account="alice",
            text="Hello!",
            msg_id="msg-dup-001",
        )
        engine = InboundPipelineBuilder.build()

        # First delivery is processed.
        first = InboundContext(adapter=bot, raw_frames=[frame])
        await engine.execute(first)
        assert first.from_account == "alice"

        # Second delivery of the same frame: expected to stop at dedup,
        # before chat routing sets chat_type.
        second = InboundContext(adapter=bot, raw_frames=[frame])
        await engine.execute(second)
        assert second.chat_type == ""

    @pytest.mark.asyncio
    async def test_blocked_dm_filtered(self):
        """With dm_policy="disabled" a DM push never gets its text extracted."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        bot._access_policy = AccessPolicy(
            dm_policy="disabled",
            dm_allow_from=[],
            group_policy="open",
            group_allow_from=[],
        )

        frame = make_json_push(
            from_account="alice",
            text="Hello!",
            msg_id="msg-blocked-001",
        )

        ctx = InboundContext(adapter=bot, raw_frames=[frame])
        engine = InboundPipelineBuilder.build()
        await engine.execute(ctx)

        # Expected to stop at access-guard, before content extraction.
        assert ctx.raw_text == ""

    @pytest.mark.asyncio
    async def test_adapter_has_pipeline(self):
        """A freshly built adapter carries an InboundPipeline as _inbound_pipeline."""
        bot = make_adapter()
        assert hasattr(bot, "_inbound_pipeline")
        assert isinstance(bot._inbound_pipeline, InboundPipeline)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns the run's exit
    # code; propagate it so a failing run does not exit with status 0.
    # pytest collects this file by importing it separately, so definitions
    # that appear after this guard are still picked up.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
|
|
|
|
# ============================================================
|
|
# 6. OOP Middleware Tests
|
|
# ============================================================
|
|
|
|
class TestInboundMiddlewareABC:
    """Contract checks for the InboundMiddleware abstract base class."""

    def test_cannot_instantiate_abc(self):
        """Instantiating InboundMiddleware directly raises TypeError."""
        with pytest.raises(TypeError):
            InboundMiddleware()

    def test_subclass_must_implement_handle(self):
        """A subclass that omits handle() cannot be instantiated."""
        with pytest.raises(TypeError):

            class BadMiddleware(InboundMiddleware):
                name = "bad"

            BadMiddleware()

    def test_subclass_with_handle_works(self):
        """A subclass that defines handle() instantiates and keeps its name."""

        class GoodMiddleware(InboundMiddleware):
            name = "good"

            async def handle(self, ctx, next_fn):
                await next_fn()

        instance = GoodMiddleware()
        assert instance.name == "good"

    @pytest.mark.asyncio
    async def test_callable_protocol(self):
        """Calling an instance runs its handle() with the same arguments."""

        class TestMW(InboundMiddleware):
            name = "test"

            async def handle(self, ctx, next_fn):
                ctx.raw_text = "called"
                await next_fn()

        instance = TestMW()
        ctx = make_ctx()
        proceed = AsyncMock()

        # Invoke through __call__ rather than handle() directly.
        await instance(ctx, proceed)

        assert ctx.raw_text == "called"
        proceed.assert_awaited_once()

    def test_repr(self):
        """repr() of an instance mentions both the class name and ``name``."""

        class MyMW(InboundMiddleware):
            name = "my-mw"

            async def handle(self, ctx, next_fn):
                pass

        text = repr(MyMW())
        assert "MyMW" in text
        assert "my-mw" in text
|
|
|
|
|
|
class TestMiddlewareClasses:
    """Per-class checks: base class, registered name, and callability."""

    # (middleware class, name its instances are expected to report)
    MIDDLEWARE_CLASSES = [
        (DecodeMiddleware, "decode"),
        (ExtractFieldsMiddleware, "extract-fields"),
        (DedupMiddleware, "dedup"),
        (SkipSelfMiddleware, "skip-self"),
        (ChatRoutingMiddleware, "chat-routing"),
        (AccessGuardMiddleware, "access-guard"),
        (ExtractContentMiddleware, "extract-content"),
        (PlaceholderFilterMiddleware, "placeholder-filter"),
        (OwnerCommandMiddleware, "owner-command"),
        (BuildSourceMiddleware, "build-source"),
        (GroupAtGuardMiddleware, "group-at-guard"),
        (DispatchMiddleware, "dispatch"),
    ]

    @pytest.mark.parametrize("cls,expected_name", MIDDLEWARE_CLASSES)
    def test_is_inbound_middleware(self, cls, expected_name):
        """The class derives from InboundMiddleware."""
        assert issubclass(cls, InboundMiddleware)

    @pytest.mark.parametrize("cls,expected_name", MIDDLEWARE_CLASSES)
    def test_has_correct_name(self, cls, expected_name):
        """An instance built with no arguments reports the expected name."""
        instance = cls()
        assert instance.name == expected_name

    @pytest.mark.parametrize("cls,expected_name", MIDDLEWARE_CLASSES)
    def test_is_callable(self, cls, expected_name):
        """An instance built with no arguments is callable."""
        instance = cls()
        assert callable(instance)
|
|
|
|
|
|
class TestPipelineOOPRegistration:
    """Registering InboundMiddleware instances (not bare functions) on a pipeline."""

    @staticmethod
    def _named_passthrough(label):
        """Return an InboundMiddleware instance named ``label`` that just continues."""

        class _Passthrough(InboundMiddleware):
            name = label

            async def handle(self, ctx, next_fn):
                await next_fn()

        return _Passthrough()

    @pytest.mark.asyncio
    async def test_use_with_middleware_instance(self):
        """use() given only an instance registers it under the instance's name."""

        class TestMW(InboundMiddleware):
            name = "test-mw"

            async def handle(self, ctx, next_fn):
                ctx.raw_text = "oop-works"
                await next_fn()

        engine = InboundPipeline().use(TestMW())
        assert engine.middleware_names == ["test-mw"]

        ctx = make_ctx()
        await engine.execute(ctx)
        assert ctx.raw_text == "oop-works"

    @pytest.mark.asyncio
    async def test_mixed_oop_and_functional(self):
        """Instance and function middlewares can share one pipeline."""
        trace = []

        class OopMW(InboundMiddleware):
            name = "oop"

            async def handle(self, ctx, next_fn):
                trace.append("oop")
                await next_fn()

        async def func_mw(ctx, next_fn):
            trace.append("func")
            await next_fn()

        engine = InboundPipeline().use(OopMW()).use("func", func_mw)
        assert engine.middleware_names == ["oop", "func"]

        await engine.execute(make_ctx())
        assert trace == ["oop", "func"]

    def test_use_before_with_middleware_instance(self):
        """use_before accepts an instance in place of a (name, function) pair."""
        first = self._named_passthrough("a")
        middle = self._named_passthrough("b")
        last = self._named_passthrough("c")

        engine = InboundPipeline().use(first).use(last)
        engine.use_before("c", middle)
        assert engine.middleware_names == ["a", "b", "c"]

    def test_use_after_with_middleware_instance(self):
        """use_after accepts an instance in place of a (name, function) pair."""
        first = self._named_passthrough("a")
        middle = self._named_passthrough("b")
        last = self._named_passthrough("c")

        engine = InboundPipeline().use(first).use(last)
        engine.use_after("a", middle)
        assert engine.middleware_names == ["a", "b", "c"]
|