Files
hermes-agent/tests/test_yuanbao_proto.py
Teknium ab6879634e yuanbao platform (#16298)
Co-authored-by: loongzhao <loongzhao@tencent.com>
2026-04-26 18:50:49 -07:00

655 lines
23 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
test_yuanbao_proto.py - yuanbao_proto 单元测试
测试覆盖:
1. varint 编解码 round-trip
2. conn 层 encode/decode round-trip
3. biz 层 encode/decode round-trip
4. decode_inbound_push 解析 TIMTextElem 消息
5. encode_send_c2c_message / encode_send_group_message 编码
6. 固定 bytes 常量验证(防止协议悄悄改动)
7. auth-bind / ping 编码
"""
import sys
import os
# 确保 hermes-agent 根目录在 sys.path 中
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _REPO_ROOT not in sys.path:
sys.path.insert(0, _REPO_ROOT)
import pytest
from gateway.platforms.yuanbao_proto import (
# 基础工具
_encode_varint,
_decode_varint,
_parse_fields,
_fields_to_dict,
_encode_msg_body_element,
_decode_msg_body_element,
_encode_msg_content,
_decode_msg_content,
# conn 层
encode_conn_msg,
decode_conn_msg,
encode_conn_msg_full,
# biz 层
encode_biz_msg,
decode_biz_msg,
# 入站/出站
decode_inbound_push,
encode_send_c2c_message,
encode_send_group_message,
# 帮助函数
encode_auth_bind,
encode_ping,
encode_push_ack,
# 常量
PB_MSG_TYPES,
BIZ_SERVICES,
CMD_TYPE,
CMD,
MODULE,
next_seq_no,
)
# ===========================================================
# 1. varint encode/decode
# ===========================================================
class TestVarint:
    """Round-trip and known-vector checks for the base-128 varint codec."""

    def test_small_values(self):
        samples = (0, 1, 127, 128, 255, 300, 16383, 16384, 2**21, 2**28)
        for value in samples:
            wire = _encode_varint(value)
            got, end = _decode_varint(wire, 0)
            assert got == value, f"round-trip failed for {value}"
            # The decoder must consume exactly the bytes the encoder wrote.
            assert end == len(wire)

    def test_zero(self):
        assert _encode_varint(0) == b"\x00"
        value, end = _decode_varint(b"\x00", 0)
        assert value == 0
        assert end == 1

    def test_1_byte_boundary(self):
        # 127 (0x7F) is the largest value that fits in a single byte.
        assert _encode_varint(127) == b"\x7f"
        # 128 needs a continuation byte: 0x80 0x01.
        assert _encode_varint(128) == b"\x80\x01"

    def test_known_values(self):
        # Example from the protobuf encoding spec: 300 -> 0xAC 0x02.
        assert _encode_varint(300) == b"\xac\x02"

    def test_multi_byte(self):
        top = 2**32 - 1  # 4294967295, the largest uint32
        decoded, _ = _decode_varint(_encode_varint(top), 0)
        assert decoded == top

    def test_partial_decode(self):
        # Decode starting at a non-zero offset, with padding on both sides.
        buf = b"\x00" + _encode_varint(300) + b"\x00"
        value, end = _decode_varint(buf, 1)
        assert value == 300
        assert end == 3  # start offset 1 + 2 bytes for 300
# ===========================================================
# 2. conn-layer round-trip
# ===========================================================
class TestConnCodec:
    """ConnMsg envelope: head fields and the opaque data payload."""

    def test_basic_round_trip(self):
        body = b"hello world"
        frame = decode_conn_msg(encode_conn_msg(msg_type=0, seq_no=42, data=body))
        assert frame["msg_type"] == 0
        assert frame["seq_no"] == 42
        assert frame["data"] == body

    def test_empty_data(self):
        frame = decode_conn_msg(encode_conn_msg(msg_type=2, seq_no=0, data=b""))
        assert frame["msg_type"] == 2
        assert frame["data"] == b""

    def test_all_cmd_types(self):
        for cmd_type in range(4):
            wire = encode_conn_msg(msg_type=cmd_type, seq_no=1, data=b"\x01\x02")
            assert decode_conn_msg(wire)["msg_type"] == cmd_type

    def test_large_seq_no(self):
        uint32_max = 2**32 - 1
        wire = encode_conn_msg(msg_type=1, seq_no=uint32_max, data=b"x")
        assert decode_conn_msg(wire)["seq_no"] == uint32_max

    def test_full_round_trip(self):
        """encode_conn_msg_full carries cmd, msg_id and module in the head."""
        payload = b"\xde\xad\xbe\xef"
        wire = encode_conn_msg_full(
            cmd_type=CMD_TYPE["Request"],
            cmd="auth-bind",
            seq_no=99,
            msg_id="abc123",
            module="conn_access",
            data=payload,
        )
        frame = decode_conn_msg(wire)
        expected_head = {
            "cmd_type": CMD_TYPE["Request"],
            "cmd": "auth-bind",
            "seq_no": 99,
            "msg_id": "abc123",
            "module": "conn_access",
        }
        for field, value in expected_head.items():
            assert frame["head"][field] == value
        assert frame["data"] == payload

    # Fixed-byte regression check: guards against silent protocol changes.
    def test_fixed_bytes_simple(self):
        """
        Exact wire bytes of encode_conn_msg(msg_type=0, seq_no=1, data=b"").

        Message shape: ConnMsg { head { seq_no=1 } }
          head bytes:  field 3 varint(1)              -> 0x18 0x01
          outer bytes: field 1, length 2, head bytes   -> 0x0a 0x02 0x18 0x01
        """
        enc = encode_conn_msg(msg_type=0, seq_no=1, data=b"")
        head = b"\x18\x01"  # field 3 (seq_no), wire type varint, value 1
        expected = b"\x0a" + bytes([len(head)]) + head  # field 1 (head), length-delimited
        assert enc == expected, f"got: {enc.hex()}, expected: {expected.hex()}"
# ===========================================================
# 3. biz-layer round-trip
# ===========================================================
class TestBizCodec:
    """Business-layer envelope: service/method/req_id routing plus a body."""

    def test_round_trip(self):
        payload = b"\x0a\x05hello"
        wire = encode_biz_msg(
            service="trpc.yuanbao.example",
            method="/im/send_c2c_msg",
            req_id="req-001",
            body=payload,
        )
        decoded = decode_biz_msg(wire)
        expected = {
            "service": "trpc.yuanbao.example",
            "method": "/im/send_c2c_msg",
            "req_id": "req-001",
            "body": payload,
        }
        for key, value in expected.items():
            assert decoded[key] == value
        # encode_biz_msg output is expected to decode as a request.
        assert decoded["is_response"] is False

    def test_is_response_flag(self):
        # A ConnMsg whose cmd_type is Response (1) must be reported as a
        # response by the biz decoder.
        wire = encode_conn_msg_full(
            cmd_type=CMD_TYPE["Response"],
            cmd="/im/send_c2c_msg",
            seq_no=1,
            msg_id="rsp-001",
            module="svc",
            data=b"\x01",
        )
        assert decode_biz_msg(wire)["is_response"] is True

    def test_empty_body(self):
        decoded = decode_biz_msg(encode_biz_msg("svc", "method", "id1", b""))
        assert decoded["body"] == b""
        assert decoded["method"] == "method"
# ===========================================================
# 4. MsgContent / MsgBodyElement encode/decode
# ===========================================================
class TestMsgBodyElement:
    """Round-trips of individual TIM message elements."""

    def test_text_elem_round_trip(self):
        greeting = "Hello, 世界!"
        wire = _encode_msg_body_element(
            {"msg_type": "TIMTextElem", "msg_content": {"text": greeting}}
        )
        back = _decode_msg_body_element(wire)
        assert back["msg_type"] == "TIMTextElem"
        assert back["msg_content"]["text"] == greeting

    def test_image_elem_round_trip(self):
        thumb = {"type": 1, "size": 1024, "width": 100, "height": 200, "url": "https://thumb.jpg"}
        image = {
            "msg_type": "TIMImageElem",
            "msg_content": {
                "uuid": "img-uuid-123",
                "image_format": 2,
                "url": "https://example.com/img.jpg",
                "image_info_array": [thumb],
            },
        }
        back = _decode_msg_body_element(_encode_msg_body_element(image))
        assert back["msg_type"] == "TIMImageElem"
        content = back["msg_content"]
        assert content["uuid"] == "img-uuid-123"
        assert content["image_format"] == 2
        assert content["url"] == "https://example.com/img.jpg"
        infos = content["image_info_array"]
        assert len(infos) == 1
        assert infos[0]["url"] == "https://thumb.jpg"

    def test_file_elem_round_trip(self):
        file_elem = {
            "msg_type": "TIMFileElem",
            "msg_content": {
                "url": "https://example.com/file.pdf",
                "file_size": 204800,
                "file_name": "document.pdf",
            },
        }
        content = _decode_msg_body_element(_encode_msg_body_element(file_elem))["msg_content"]
        assert content["file_name"] == "document.pdf"
        assert content["file_size"] == 204800

    def test_custom_elem_round_trip(self):
        custom = {
            "msg_type": "TIMCustomElem",
            "msg_content": {
                "data": '{"key":"value"}',
                "desc": "custom description",
                "ext": "extra info",
            },
        }
        content = _decode_msg_body_element(_encode_msg_body_element(custom))["msg_content"]
        assert content["data"] == '{"key":"value"}'
        assert content["desc"] == "custom description"

    def test_empty_content(self):
        wire = _encode_msg_body_element({"msg_type": "TIMTextElem", "msg_content": {}})
        assert _decode_msg_body_element(wire)["msg_type"] == "TIMTextElem"

    def test_fixed_text_elem_bytes(self):
        """
        Exact wire bytes of TIMTextElem { text="hi" }.

        MsgBodyElement:
          field 1 (msg_type="TIMTextElem"): 0a 0b 54494d54657874456c656d
          field 2 (msg_content):            12 <len> <MsgContent bytes>
        MsgContent:
          field 1 (text="hi"):              0a 02 6869
        """
        enc = _encode_msg_body_element(
            {"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}
        )
        type_name = b"TIMTextElem"  # 11 bytes
        # MsgContent: tag 0x0a, length 2, "hi"
        inner = b"\x0a\x02" + b"hi"
        expected = b"".join([
            bytes([0x0a, len(type_name)]), type_name,  # field 1, length-delimited
            bytes([0x12, len(inner)]), inner,          # field 2, length-delimited
        ])
        assert enc == expected, f"got {enc.hex()}, expected {expected.hex()}"
# ===========================================================
# 5. decode_inbound_push
# ===========================================================
class TestDecodeInboundPush:
    """Parse hand-built InboundMessagePush payloads with decode_inbound_push."""

    def _build_inbound_push_bytes(
        self,
        from_account: str = "user123",
        to_account: str = "bot456",
        group_code: str = "",
        msg_key: str = "key-001",
        msg_seq: int = 12345,
        text: str = "Hello!",
    ) -> bytes:
        """Hand-build InboundMessagePush bytes carrying one TIMTextElem.

        Fields are written in proto field order:
        2 from_account, 3 to_account, 6 group_code (only when non-empty),
        8 msg_seq, 11 msg_key, 13 msg_body[0].
        """
        # _encode_varint is already imported at module level.
        from gateway.platforms.yuanbao_proto import (
            _encode_field, _encode_string, _encode_message,
            WT_LEN, WT_VARINT,
        )
        el_bytes = _encode_msg_body_element(
            {"msg_type": "TIMTextElem", "msg_content": {"text": text}}
        )
        buf = b""
        buf += _encode_field(2, WT_LEN, _encode_string(from_account))  # from_account
        buf += _encode_field(3, WT_LEN, _encode_string(to_account))  # to_account
        if group_code:
            buf += _encode_field(6, WT_LEN, _encode_string(group_code))  # group_code
        buf += _encode_field(8, WT_VARINT, _encode_varint(msg_seq))  # msg_seq
        buf += _encode_field(11, WT_LEN, _encode_string(msg_key))  # msg_key
        buf += _encode_field(13, WT_LEN, _encode_message(el_bytes))  # msg_body[0]
        return buf

    def test_basic_c2c_text_message(self):
        raw = self._build_inbound_push_bytes(
            from_account="alice",
            to_account="bot",
            msg_key="k001",
            msg_seq=100,
            text="你好",
        )
        result = decode_inbound_push(raw)
        assert result is not None
        assert result["from_account"] == "alice"
        assert result["to_account"] == "bot"
        assert result["msg_seq"] == 100
        assert result["msg_key"] == "k001"
        assert len(result["msg_body"]) == 1
        assert result["msg_body"][0]["msg_type"] == "TIMTextElem"
        assert result["msg_body"][0]["msg_content"]["text"] == "你好"

    def test_group_message(self):
        raw = self._build_inbound_push_bytes(
            from_account="bob",
            to_account="bot",
            group_code="group-789",
            msg_seq=999,
            text="group msg",
        )
        result = decode_inbound_push(raw)
        assert result is not None
        assert result["group_code"] == "group-789"
        assert result["msg_body"][0]["msg_content"]["text"] == "group msg"

    def test_returns_none_on_empty(self):
        """Empty input must be handled without raising.

        The decoder may return None or a dict. If a dict is returned, it must
        not contain any msg_body elements, because the input had none.
        """
        result = decode_inbound_push(b"")
        assert result is None or isinstance(result, dict)
        if result is not None:
            assert not result.get("msg_body")

    def test_multiple_msg_body_elements(self):
        """Repeated field-13 entries decode into msg_body in wire order."""
        from gateway.platforms.yuanbao_proto import (
            _encode_field, _encode_string, _encode_message, WT_LEN,
        )
        el1 = _encode_msg_body_element(
            {"msg_type": "TIMTextElem", "msg_content": {"text": "part1"}}
        )
        el2 = _encode_msg_body_element(
            {"msg_type": "TIMTextElem", "msg_content": {"text": "part2"}}
        )
        # Use the same string helper as _build_inbound_push_bytes instead of a
        # hand-written length prefix, so the encoding stays in one place.
        buf = (
            _encode_field(2, WT_LEN, _encode_string("alice"))
            + _encode_field(13, WT_LEN, _encode_message(el1))
            + _encode_field(13, WT_LEN, _encode_message(el2))
        )
        result = decode_inbound_push(buf)
        assert result is not None
        assert result["from_account"] == "alice"
        assert len(result["msg_body"]) == 2
        assert result["msg_body"][0]["msg_content"]["text"] == "part1"
        assert result["msg_body"][1]["msg_content"]["text"] == "part2"
# ===========================================================
# 6. outbound message encoding
# ===========================================================
class TestEncodeOutbound:
    """Outbound C2C / group send requests wrapped in a ConnMsg envelope."""

    @staticmethod
    def _biz_fields(wire):
        """Decode the ConnMsg envelope and index the biz payload by field number."""
        return _fields_to_dict(_parse_fields(decode_conn_msg(wire)["data"]))

    def test_encode_send_c2c_message(self):
        wire = encode_send_c2c_message(
            to_account="user_b",
            msg_body=[{"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}],
            from_account="bot",
            msg_id="msg-001",
        )
        assert isinstance(wire, bytes)
        assert len(wire) > 0
        # The envelope must decode as a ConnMsg with the expected head.
        frame = decode_conn_msg(wire)
        head = frame["head"]
        assert head["cmd"] == "send_c2c_message"
        assert head["msg_id"] == "msg-001"
        assert head["module"] == "yuanbao_openclaw_proxy"
        assert len(frame["data"]) > 0

    def test_encode_send_group_message(self):
        wire = encode_send_group_message(
            group_code="grp-100",
            msg_body=[{"msg_type": "TIMTextElem", "msg_content": {"text": "group hello"}}],
            from_account="bot",
            msg_id="msg-002",
        )
        assert isinstance(wire, bytes)
        frame = decode_conn_msg(wire)
        assert frame["head"]["cmd"] == "send_group_message"
        assert frame["head"]["msg_id"] == "msg-002"
        assert len(frame["data"]) > 0

    def test_c2c_biz_payload_contains_to_account(self):
        """The C2C biz payload carries to_account."""
        from gateway.platforms.yuanbao_proto import _get_string
        wire = encode_send_c2c_message(
            to_account="target_user",
            msg_body=[{"msg_type": "TIMTextElem", "msg_content": {"text": "test"}}],
            from_account="bot",
        )
        # SendC2CMessageReq.to_account is field 2.
        assert _get_string(self._biz_fields(wire), 2) == "target_user"

    def test_group_biz_payload_contains_group_code(self):
        from gateway.platforms.yuanbao_proto import _get_string
        wire = encode_send_group_message(
            group_code="group-xyz",
            msg_body=[{"msg_type": "TIMTextElem", "msg_content": {"text": "test"}}],
            from_account="bot",
        )
        # SendGroupMessageReq.group_code is field 2.
        assert _get_string(self._biz_fields(wire), 2) == "group-xyz"
# ===========================================================
# 7. AuthBind / Ping encoding
# ===========================================================
class TestAuthAndPing:
    """Connection-level control frames: auth-bind, ping and push ACK."""

    def test_encode_auth_bind(self):
        wire = encode_auth_bind(
            biz_id="ybBot",
            uid="user_001",
            source="app",
            token="tok_abc",
            msg_id="auth-001",
            app_version="1.0.0",
            operation_system="Linux",
            bot_version="0.1.0",
        )
        assert isinstance(wire, bytes)
        frame = decode_conn_msg(wire)
        head = frame["head"]
        assert head["cmd"] == "auth-bind"
        assert head["module"] == "conn_access"
        assert head["msg_id"] == "auth-001"
        assert len(frame["data"]) > 0

    def test_encode_ping(self):
        wire = encode_ping("ping-001")
        assert isinstance(wire, bytes)
        head = decode_conn_msg(wire)["head"]
        assert head["cmd"] == "ping"
        assert head["module"] == "conn_access"

    def test_encode_push_ack(self):
        pushed_head = {
            "cmd_type": CMD_TYPE["Push"],
            "cmd": "some-push",
            "seq_no": 100,
            "msg_id": "push-001",
            "module": "im_module",
            "need_ack": True,
            "status": 0,
        }
        ack_head = decode_conn_msg(encode_push_ack(pushed_head))["head"]
        # The ACK carries cmd_type PushAck and echoes the pushed cmd and msg_id.
        assert ack_head["cmd_type"] == CMD_TYPE["PushAck"]
        assert ack_head["cmd"] == "some-push"
        assert ack_head["msg_id"] == "push-001"
# ===========================================================
# 8. constant tables
# ===========================================================
class TestConstants:
    """Sanity checks on the protocol lookup tables exported by the module."""

    def test_pb_msg_types_keys(self):
        for name in ("ConnMsg", "AuthBindReq", "PingReq", "KickoutMsg", "PushMsg"):
            assert name in PB_MSG_TYPES

    def test_biz_services_keys(self):
        for name in ("SendC2CMessageReq", "SendGroupMessageReq", "InboundMessagePush"):
            assert name in BIZ_SERVICES

    def test_cmd_type_values(self):
        expected = {"Request": 0, "Response": 1, "Push": 2, "PushAck": 3}
        for name, code in expected.items():
            assert CMD_TYPE[name] == code

    def test_pkg_prefix(self):
        # Every biz service name is expected to live in the proxy package.
        for name, full_name in BIZ_SERVICES.items():
            assert full_name.startswith("yuanbao_openclaw_proxy"), \
                f"{name}: unexpected prefix in {full_name}"
# ===========================================================
# 9. seq_no generation
# ===========================================================
class TestSeqNo:
    """next_seq_no must be increasing and must not repeat under threads."""

    def test_monotonic(self):
        first, second, third = next_seq_no(), next_seq_no(), next_seq_no()
        assert first < second < third

    def test_thread_safety(self):
        import threading

        collected = []
        guard = threading.Lock()

        def draw_many():
            batch = [next_seq_no() for _ in range(100)]
            with guard:
                collected.extend(batch)

        workers = [threading.Thread(target=draw_many) for _ in range(10)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        # No value may be handed out twice.
        assert len(collected) == len(set(collected)), "duplicate seq_no detected"
# ===========================================================
# 10. end-to-end flows (simulated send -> recv)
# ===========================================================
class TestEndToEnd:
    """Encode on one side, decode on the other, and compare field by field."""

    def test_send_recv_c2c(self):
        """Encode a C2C send, then decode it as the receiving side would."""
        # _parse_fields / _fields_to_dict are already imported at module level;
        # only the field getters need a local import.
        from gateway.platforms.yuanbao_proto import _get_repeated_bytes, _get_string

        msg_body = [
            {"msg_type": "TIMTextElem", "msg_content": {"text": "端到端测试"}},
        ]
        # Sender side: encode.
        wire_bytes = encode_send_c2c_message(
            to_account="recv_user",
            msg_body=msg_body,
            from_account="send_bot",
            msg_id="e2e-001",
        )
        # Receiver side: decode the ConnMsg envelope.
        dec = decode_conn_msg(wire_bytes)
        assert dec["head"]["cmd"] == "send_c2c_message"
        assert dec["head"]["msg_id"] == "e2e-001"
        # Read to_account, from_account and msg_body out of the biz payload.
        fdict = _fields_to_dict(_parse_fields(dec["data"]))
        assert _get_string(fdict, 2) == "recv_user"  # to_account
        assert _get_string(fdict, 3) == "send_bot"  # from_account
        el_list = _get_repeated_bytes(fdict, 5)  # repeated msg_body
        assert len(el_list) == 1
        el_dec = _decode_msg_body_element(el_list[0])
        assert el_dec["msg_type"] == "TIMTextElem"
        assert el_dec["msg_content"]["text"] == "端到端测试"

    def test_inbound_push_full_flow(self):
        """Build a server push, wrap it in a ConnMsg, and decode the inbound message."""
        # _encode_varint is already imported at module level.
        from gateway.platforms.yuanbao_proto import (
            _encode_field, _encode_string, _encode_message,
            WT_LEN, WT_VARINT,
        )
        # Build the inbound biz payload (InboundMessagePush field numbers).
        el_bytes = _encode_msg_body_element(
            {"msg_type": "TIMTextElem", "msg_content": {"text": "server push"}}
        )
        biz_payload = (
            _encode_field(2, WT_LEN, _encode_string("alice"))
            + _encode_field(3, WT_LEN, _encode_string("bot"))
            + _encode_field(6, WT_LEN, _encode_string("grp-001"))
            + _encode_field(8, WT_VARINT, _encode_varint(555))
            + _encode_field(11, WT_LEN, _encode_string("msg-key-xyz"))
            + _encode_field(13, WT_LEN, _encode_message(el_bytes))
        )
        # Wrap it in a ConnMsg, as the server would for a push.
        wire = encode_conn_msg_full(
            cmd_type=CMD_TYPE["Push"],
            cmd="/im/new_message",
            seq_no=77,
            msg_id="push-abc",
            module="yuanbao_openclaw_proxy",
            data=biz_payload,
            need_ack=True,
        )
        # Receiver side: decode the envelope, then the inbound message.
        conn = decode_conn_msg(wire)
        assert conn["head"]["cmd_type"] == CMD_TYPE["Push"]
        assert conn["head"]["need_ack"] is True
        msg = decode_inbound_push(conn["data"])
        assert msg is not None
        assert msg["from_account"] == "alice"
        assert msg["to_account"] == "bot"
        assert msg["group_code"] == "grp-001"
        assert msg["msg_seq"] == 555
        assert msg["msg_key"] == "msg-key-xyz"
        assert len(msg["msg_body"]) == 1
        assert msg["msg_body"][0]["msg_content"]["text"] == "server push"
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # reports failures to the shell/CI instead of always exiting 0.
    sys.exit(pytest.main([__file__, "-v"]))