diff --git a/agent/prompt_builder.py b/agent/prompt_builder.py index 3a6ec24415..aaef51192f 100644 --- a/agent/prompt_builder.py +++ b/agent/prompt_builder.py @@ -422,6 +422,29 @@ PLATFORM_HINTS = { "your response. Images are sent as native photos, and other files arrive as downloadable " "documents." ), + "yuanbao": ( + "You are on Yuanbao (腾讯元宝), a Chinese AI assistant platform. " + "Markdown formatting is supported (code blocks, tables, bold/italic). " + "You CAN send media files natively — to deliver a file to the user, include " + "MEDIA:/absolute/path/to/file in your response. The file will be sent as a native " + "Yuanbao attachment: images (.jpg, .png, .webp, .gif) are sent as photos, " + "and other files (.pdf, .docx, .txt, .zip, etc.) arrive as downloadable documents " + "(max 50 MB). You can also include image URLs in markdown format ![alt](url) and " + "they will be downloaded and sent as native photos. " + "Do NOT tell the user you lack file-sending capability — use MEDIA: syntax " + "whenever a file delivery is appropriate.\n\n" + "Stickers (贴纸 / 表情包 / TIM face): Yuanbao has a built-in sticker catalogue. " + "When the user sends a sticker (you see '[emoji: 名称]' in their message) or asks " + "you to send/reply-with a 贴纸/表情/表情包, you MUST use the sticker tools:\n" + " 1. Call yb_search_sticker with a Chinese keyword (e.g. '666', '比心', '吃瓜', " + " '捂脸', '合十') to discover matching sticker_ids.\n" + " 2. Call yb_send_sticker with the chosen sticker_id or name — this sends a real " + " TIMFaceElem that renders as a native sticker in the chat.\n" + "DO NOT draw sticker-like PNGs with execute_code/Pillow/matplotlib and then send " + "them via MEDIA: or send_image_file. That produces a fake low-quality 'sticker' " + "image and is the WRONG path. Bare Unicode emoji in text is also not a substitute " + "— when a sticker is the right response, use yb_send_sticker." + ), } # --------------------------------------------------------------------------- diff --git a/cli-config.yaml.example b/cli-config.yaml.example index 984a9bfe84..d6cb0bcb46 100644 --- a/cli-config.yaml.example +++ b/cli-config.yaml.example @@ -606,6 +606,7 @@ platform_toolsets: signal: [hermes-signal] homeassistant: [hermes-homeassistant] qqbot: [hermes-qqbot] + yuanbao: [hermes-yuanbao] # ============================================================================= # Gateway Platform Settings diff --git a/cron/scheduler.py b/cron/scheduler.py index 27690ac5e2..12dae811fd 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -77,7 +77,7 @@ _KNOWN_DELIVERY_PLATFORMS = frozenset({ "telegram", "discord", "slack", "whatsapp", "signal", "matrix", "mattermost", "homeassistant", "dingtalk", "feishu", "wecom", "wecom_callback", "weixin", "sms", "email", "webhook", "bluebubbles", - "qqbot", + "qqbot", "yuanbao", }) # Platforms that support a configured cron/notification home target, mapped to @@ -337,6 +337,7 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option "sms": Platform.SMS, "bluebubbles": Platform.BLUEBUBBLES, "qqbot": Platform.QQBOT, + "yuanbao": Platform.YUANBAO, } # Optionally wrap the content with a header/footer so the user knows this diff --git a/gateway/config.py b/gateway/config.py index e585ec0413..128bfa61ca 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -67,6 +67,7 @@ class Platform(Enum): WEIXIN = "weixin" BLUEBUBBLES = "bluebubbles" QQBOT = "qqbot" + YUANBAO = "yuanbao" @dataclass @@ -326,6 +327,9 @@ class GatewayConfig: # QQBot uses extra dict for app credentials elif platform == Platform.QQBOT and config.extra.get("app_id") and config.extra.get("client_secret"): connected.append(platform) + # Yuanbao uses extra dict for app credentials + elif platform == Platform.YUANBAO and config.extra.get("app_id") and config.extra.get("app_secret"): + connected.append(platform) # DingTalk uses client_id/client_secret from config.extra or env vars elif platform == Platform.DINGTALK and ( config.extra.get("client_id") or os.getenv("DINGTALK_CLIENT_ID") @@ -1296,6 +1300,48 @@ def _apply_env_overrides(config: GatewayConfig) -> None: name=os.getenv("QQBOT_HOME_CHANNEL_NAME") or os.getenv(qq_home_name_env, "Home"), ) + # Yuanbao — YUANBAO_APP_ID preferred + yuanbao_app_id = os.getenv("YUANBAO_APP_ID") or os.getenv("YUANBAO_APP_KEY") + yuanbao_app_secret = os.getenv("YUANBAO_APP_SECRET") + if yuanbao_app_id and yuanbao_app_secret: + if Platform.YUANBAO not in config.platforms: + config.platforms[Platform.YUANBAO] = PlatformConfig() + config.platforms[Platform.YUANBAO].enabled = True + extra = config.platforms[Platform.YUANBAO].extra + extra["app_id"] = yuanbao_app_id + extra["app_secret"] = yuanbao_app_secret + yuanbao_bot_id = os.getenv("YUANBAO_BOT_ID") + if yuanbao_bot_id: + extra["bot_id"] = yuanbao_bot_id + yuanbao_ws_url = os.getenv("YUANBAO_WS_URL") + if yuanbao_ws_url: + extra["ws_url"] = yuanbao_ws_url + yuanbao_api_domain = os.getenv("YUANBAO_API_DOMAIN") + if yuanbao_api_domain: + extra["api_domain"] = yuanbao_api_domain + yuanbao_route_env = os.getenv("YUANBAO_ROUTE_ENV") + if yuanbao_route_env: + extra["route_env"] = yuanbao_route_env + yuanbao_home = os.getenv("YUANBAO_HOME_CHANNEL") + if yuanbao_home: + config.platforms[Platform.YUANBAO].home_channel = HomeChannel( + platform=Platform.YUANBAO, + chat_id=yuanbao_home, + name=os.getenv("YUANBAO_HOME_CHANNEL_NAME", "Home"), + ) + yuanbao_dm_policy = os.getenv("YUANBAO_DM_POLICY") + if yuanbao_dm_policy: + extra["dm_policy"] = yuanbao_dm_policy.strip().lower() + yuanbao_dm_allow_from = os.getenv("YUANBAO_DM_ALLOW_FROM") + if yuanbao_dm_allow_from: + extra["dm_allow_from"] = yuanbao_dm_allow_from + yuanbao_group_policy = os.getenv("YUANBAO_GROUP_POLICY") + if yuanbao_group_policy: + extra["group_policy"] = yuanbao_group_policy.strip().lower() + yuanbao_group_allow_from = os.getenv("YUANBAO_GROUP_ALLOW_FROM") + if yuanbao_group_allow_from: + extra["group_allow_from"] = yuanbao_group_allow_from + # Session settings idle_minutes = os.getenv("SESSION_IDLE_MINUTES") if idle_minutes: diff --git a/gateway/platforms/__init__.py b/gateway/platforms/__init__.py index 4eb26edf06..5f978896bc 100644 --- a/gateway/platforms/__init__.py +++ b/gateway/platforms/__init__.py @@ -10,10 +10,12 @@ Each adapter handles: from .base import BasePlatformAdapter, MessageEvent, SendResult from .qqbot import QQAdapter +from .yuanbao import YuanbaoAdapter __all__ = [ "BasePlatformAdapter", "MessageEvent", "SendResult", "QQAdapter", + "YuanbaoAdapter", ] diff --git a/gateway/platforms/yuanbao.py b/gateway/platforms/yuanbao.py new file mode 100644 index 0000000000..49df1b6c4a --- /dev/null +++ b/gateway/platforms/yuanbao.py @@ -0,0 +1,4754 @@ +""" +Yuanbao platform adapter. + +Connects to the Yuanbao WebSocket gateway, handles authentication (AUTH_BIND), +heartbeat, reconnection, message receive (T05) and send (T06). + +Configuration in config.yaml (or via env vars): + platforms: + yuanbao: + extra: + app_id: "..." # or YUANBAO_APP_ID + app_secret: "..." # or YUANBAO_APP_SECRET + bot_id: "..." # or YUANBAO_BOT_ID (optional, returned by sign-token) + ws_url: "wss://..." # or YUANBAO_WS_URL + api_domain: "https://..." # or YUANBAO_API_DOMAIN +""" + +from __future__ import annotations + +import asyncio +import collections +import dataclasses +import hashlib +import hmac +import json +import logging +import os +import re +import secrets +import time +import urllib.parse +import uuid +from datetime import datetime, timezone, timedelta +from pathlib import Path +from abc import ABC, abstractmethod +from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple + +import sys + +import httpx + +try: + import websockets + import websockets.exceptions + WEBSOCKETS_AVAILABLE = True +except ImportError: + WEBSOCKETS_AVAILABLE = False + websockets = None # type: ignore[assignment] + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + SendResult, + cache_document_from_bytes, + cache_image_from_bytes, +) +from gateway.platforms.helpers import MessageDeduplicator +from gateway.platforms.yuanbao_media import ( + download_url as media_download_url, + get_cos_credentials, + upload_to_cos, + build_image_msg_body, + build_file_msg_body, + guess_mime_type, + md5_hex, +) +from gateway.platforms.yuanbao_proto import ( + CMD_TYPE, + _fields_to_dict, + _get_string, + _get_varint, + _parse_fields, + WS_HEARTBEAT_RUNNING, + WS_HEARTBEAT_FINISH, + HERMES_INSTANCE_ID, + decode_conn_msg, + decode_inbound_push, + decode_query_group_info_rsp, + decode_get_group_member_list_rsp, + encode_auth_bind, + encode_ping, + encode_push_ack, + encode_send_c2c_message, + encode_send_group_message, + encode_send_private_heartbeat, + encode_send_group_heartbeat, + encode_query_group_info, + encode_get_group_member_list, + next_seq_no, +) +from gateway.session import SessionSource, build_session_key + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Version / platform constants (used in AUTH_BIND and sign-token headers) +# --------------------------------------------------------------------------- +try: + from hermes_cli import __version__ as _HERMES_VERSION +except ImportError: + _HERMES_VERSION = "0.0.0" + +_APP_VERSION = _HERMES_VERSION +_BOT_VERSION = _HERMES_VERSION +_YUANBAO_INSTANCE_ID = str(HERMES_INSTANCE_ID) # single source: yuanbao_proto.HERMES_INSTANCE_ID +_OPERATION_SYSTEM = sys.platform + +# --------------------------------------------------------------------------- +# Module-level constants +# --------------------------------------------------------------------------- + +DEFAULT_WS_GATEWAY_URL = "wss://bot-wss.yuanbao.tencent.com/wss/connection" +DEFAULT_API_DOMAIN = "https://bot.yuanbao.tencent.com" + +HEARTBEAT_INTERVAL_SECONDS = 30.0 +CONNECT_TIMEOUT_SECONDS = 15.0 +AUTH_TIMEOUT_SECONDS = 10.0 +MAX_RECONNECT_ATTEMPTS = 100 +DEFAULT_SEND_TIMEOUT = 30.0 # WS biz request timeout + +# Close codes that indicate permanent errors — do NOT reconnect. +NO_RECONNECT_CLOSE_CODES = {4012, 4013, 4014, 4018, 4019, 4021} + +# Heartbeat timeout threshold — N consecutive missed pongs trigger reconnect. +HEARTBEAT_TIMEOUT_THRESHOLD = 2 + +# Auth error code classification +AUTH_FAILED_CODES = {4001, 4002, 4003} # permanent auth failure, re-sign token +AUTH_RETRYABLE_CODES = {4010, 4011, 4099} # transient, can retry with same token + +# Reply Heartbeat configuration +REPLY_HEARTBEAT_INTERVAL_S = 2.0 # Send RUNNING every 2 seconds +REPLY_HEARTBEAT_TIMEOUT_S = 30.0 # Auto-stop after 30 seconds of inactivity + +# Reply-to reference configuration +REPLY_REF_TTL_S = 300.0 # Reference dedup TTL (5 minutes) + +# Slow-response hint: push a waiting message when agent produces no data for this duration (seconds) +SLOW_RESPONSE_TIMEOUT_S = 120.0 +SLOW_RESPONSE_MESSAGE = "任务有点复杂,正在努力处理中,请耐心等待..." + +# Regex matching Yuanbao resource reference anchors in transcript text: +# [image|ybres:abc123] [file:report.pdf|ybres:xyz789] [voice|ybres:...] +_YB_RES_REF_RE = re.compile( + r"\[(image|voice|video|file(?::[^|\]]*)?)\|ybres:([A-Za-z0-9_\-]+)\]" +) + +# Strip page indicators like (1/3) appended by BasePlatformAdapter +_INDICATOR_RE = re.compile(r'\s*\(\d+/\d+\)$') + +# Observed-media backfill: how many recent transcript messages to scan +OBSERVED_MEDIA_BACKFILL_LOOKBACK = 50 +# Max number of resource references to resolve per inbound turn +OBSERVED_MEDIA_BACKFILL_MAX_RESOLVE_PER_TURN = 12 + +class MarkdownProcessor: + """Encapsulates all Markdown-related utilities for the Yuanbao platform. + + Provides static methods for: + - Fence detection and streaming merge + - Table row detection and sanitization + - Paragraph-boundary splitting + - Atomic-block extraction and chunk splitting + - Outer markdown fence stripping + - Markdown hint prompt generation + """ + + # -- Fence detection --------------------------------------------------- + + @staticmethod + def has_unclosed_fence(text: str) -> bool: + """ + Detect whether the text has unclosed code block fences. + + Scan line by line, toggling in/out state when encountering a line starting with ```. + An odd number of toggles indicates an unclosed fence. + + Args: + text: Markdown text to check + + Returns: + Returns True if the text ends with an unclosed fence, otherwise False + """ + in_fence = False + for line in text.split('\n'): + if line.startswith('```'): + in_fence = not in_fence + return in_fence + + # -- Table detection --------------------------------------------------- + + @staticmethod + def ends_with_table_row(text: str) -> bool: + """ + Detect whether the text ends with a table row (last non-empty line starts and ends with |). + + Args: + text: Text to check + + Returns: + Returns True if the last non-empty line is a table row + """ + trimmed = text.rstrip() + if not trimmed: + return False + last_line = trimmed.split('\n')[-1].strip() + return last_line.startswith('|') and last_line.endswith('|') + + # -- Paragraph boundary splitting -------------------------------------- + + @staticmethod + def split_at_paragraph_boundary( + text: str, + max_chars: int, + len_fn: Optional[Callable[[str], int]] = None, + ) -> tuple[str, str]: + """ + Find the nearest paragraph boundary split point within max_chars, return (head, tail). + + Split priority: + 1. Blank line (paragraph boundary) + 2. Newline after period/question mark/exclamation mark (Chinese and English) + 3. Last newline + 4. Force split at max_chars + + Args: + text: Text to split + max_chars: Maximum character count limit + len_fn: Optional custom length function (e.g. UTF-16 length); defaults to built-in len + + Returns: + (head, tail) tuple, head is the front part, tail is the back part, satisfying head + tail == text + """ + _len = len_fn or len + if _len(text) <= max_chars: + return text, '' + + # Build a character-index window that fits within max_chars. + # When len_fn != len we cannot simply slice [:max_chars], so we + # binary-search for the largest prefix that fits. + if _len is len: + window = text[:max_chars] + else: + lo, hi = 0, len(text) + while lo < hi: + mid = (lo + hi + 1) // 2 + if _len(text[:mid]) <= max_chars: + lo = mid + else: + hi = mid - 1 + window = text[:lo] + + # 1. Prefer the last blank line (\n\n) as paragraph boundary + pos = window.rfind('\n\n') + if pos > 0: + return text[:pos + 2], text[pos + 2:] + + # 2. Then find the last newline after a sentence-ending punctuation + sentence_end_re = re.compile(r'[。!?.!?]\n') + best_pos = -1 + for m in sentence_end_re.finditer(window): + best_pos = m.end() + if best_pos > 0: + return text[:best_pos], text[best_pos:] + + # 3. Fallback: find the last newline + pos = window.rfind('\n') + if pos > 0: + return text[:pos + 1], text[pos + 1:] + + # 4. No valid split point found, force split at window boundary + cut = len(window) + return text[:cut], text[cut:] + + # -- Atomic block helpers (private) ------------------------------------ + + @staticmethod + def is_fence_atom(text: str) -> bool: + """Determine whether an atomic block is a code block (starts with ```).""" + return text.lstrip().startswith('```') + + @staticmethod + def is_table_atom(text: str) -> bool: + """Determine whether an atomic block is a table (first line starts with |).""" + first_line = text.split('\n')[0].strip() + return first_line.startswith('|') and first_line.endswith('|') + + @staticmethod + def split_into_atoms(text: str) -> list[str]: + """ + Split text into a list of "atomic blocks", each being an indivisible logical unit: + + - Code block (fence): from opening ``` to closing ``` (including fence lines) + - Table: consecutive |...| lines forming a whole segment + - Normal paragraph: plain text segments separated by blank lines + + Blank lines serve as separators and are not included in any atomic block. + + Args: + text: Markdown text to split + + Returns: + List of atomic block strings (all non-empty) + """ + lines = text.split('\n') + atoms: list[str] = [] + + current_lines: list[str] = [] + in_fence = False + + def _is_table_line(line: str) -> bool: + stripped = line.strip() + return stripped.startswith('|') and stripped.endswith('|') + + def _flush_current() -> None: + if current_lines: + atom = '\n'.join(current_lines) + if atom.strip(): + atoms.append(atom) + current_lines.clear() + + for line in lines: + if in_fence: + current_lines.append(line) + if line.startswith('```') and len(current_lines) > 1: + in_fence = False + _flush_current() + elif line.startswith('```'): + _flush_current() + in_fence = True + current_lines.append(line) + elif _is_table_line(line): + if current_lines and not _is_table_line(current_lines[-1]): + _flush_current() + current_lines.append(line) + elif line.strip() == '': + _flush_current() + else: + if current_lines and _is_table_line(current_lines[-1]): + _flush_current() + current_lines.append(line) + + _flush_current() + + return atoms + + # -- Core: chunk splitting --------------------------------------------- + + @classmethod + def chunk_markdown_text( + cls, + text: str, + max_chars: int = 4000, + len_fn: Optional[Callable[[str], int]] = None, + ) -> list[str]: + """ + Split Markdown text into multiple chunks by max_chars. + + Guarantees: + - Each chunk <= max_chars characters (unless a single code block/table itself exceeds the limit) + - Code blocks (```...```) are not split in the middle + - Table rows are not split in the middle (tables output as atomic blocks) + - Split at paragraph boundaries (blank lines, after periods, etc.) + - Small trailing/leading chunks are merged with neighbours when possible + + Args: + text: Markdown text to split + max_chars: Max characters per chunk, default 4000 + len_fn: Optional custom length function (e.g. UTF-16 length); defaults to built-in len + + Returns: + List of text chunks after splitting (non-empty) + """ + _len = len_fn or len + + if not text: + return [] + + if _len(text) <= max_chars: + return [text] + + # Phase 1: Extract atomic blocks + atoms = cls.split_into_atoms(text) + + # Phase 2: Greedy merge + chunks: list[str] = [] + indivisible_set: set[int] = set() + current_parts: list[str] = [] + current_len = 0 + + def _flush_parts() -> None: + if current_parts: + chunks.append('\n\n'.join(current_parts)) + + for atom in atoms: + atom_len = _len(atom) + sep_len = 2 if current_parts else 0 + projected_len = current_len + sep_len + atom_len + + if projected_len > max_chars and current_parts: + _flush_parts() + current_parts = [] + current_len = 0 + sep_len = 0 + + if (not current_parts + and atom_len > max_chars + and (cls.is_fence_atom(atom) or cls.is_table_atom(atom))): + indivisible_set.add(len(chunks)) + chunks.append(atom) + continue + + current_parts.append(atom) + current_len += sep_len + atom_len + + _flush_parts() + + # Phase 3: Post-processing — split still-oversized chunks at paragraph boundaries + result: list[str] = [] + for idx, chunk in enumerate(chunks): + if _len(chunk) <= max_chars: + result.append(chunk) + continue + + if idx in indivisible_set: + result.append(chunk) + continue + + if cls.has_unclosed_fence(chunk): + result.append(chunk) + continue + + remaining = chunk + while _len(remaining) > max_chars: + head, remaining = cls.split_at_paragraph_boundary( + remaining, max_chars, len_fn=len_fn, + ) + if not head: + head, remaining = remaining[:max_chars], remaining[max_chars:] + if head: + result.append(head) + if remaining: + result.append(remaining) + + # Phase 4: Merge small trailing/leading chunks with neighbours + if len(result) > 1: + merged: list[str] = [result[0]] + for chunk in result[1:]: + prev = merged[-1] + combined = prev + '\n\n' + chunk + if _len(combined) <= max_chars: + merged[-1] = combined + else: + merged.append(chunk) + result = merged + + return [c for c in result if c] + + # -- Block separator inference ----------------------------------------- + + @classmethod + def infer_block_separator(cls, prev_chunk: str, next_chunk: str) -> str: + """ + Infer the separator to use between two split chunks. + + Rules (aligned with TS markdown-stream.ts): + - Previous chunk ends with code fence or next chunk starts with fence → single newline '\\n' + - Previous chunk ends with table row and next chunk starts with table row → single newline '\\n' (continued table) + - Otherwise → double newline '\\n\\n' (paragraph separator) + + Args: + prev_chunk: Previous chunk + next_chunk: Next chunk + + Returns: + '\\n' or '\\n\\n' + """ + prev_trimmed = prev_chunk.rstrip() + next_trimmed = next_chunk.lstrip() + + # Previous chunk ends with fence or next chunk starts with fence + if prev_trimmed.endswith('```') or next_trimmed.startswith('```'): + return '\n' + + # Table continuation + if cls.ends_with_table_row(prev_chunk): + first_line = next_trimmed.split('\n')[0].strip() if next_trimmed else '' + if first_line.startswith('|') and first_line.endswith('|'): + return '\n' + + return '\n\n' + + # -- Streaming fence merge --------------------------------------------- + + @classmethod + def merge_block_streaming_fences(cls, chunks: list[str]) -> list[str]: + """ + Stream-aware fence-conscious chunk merging. + + When streaming output produces multiple chunks truncated in the middle of a fence, + attempt to merge adjacent chunks to complete the fence. + + Rules: + - If chunk i has an unclosed fence and chunk i+1 starts with ```, + merge i+1 into i (until the fence is closed or no more chunks). + - Use infer_block_separator to infer the separator during merging. + + Args: + chunks: Original chunk list + + Returns: + Merged chunk list (length <= original length) + """ + if not chunks: + return [] + + result: list[str] = [] + i = 0 + while i < len(chunks): + current = chunks[i] + # If current chunk has unclosed fence, try merging subsequent chunks + while cls.has_unclosed_fence(current) and i + 1 < len(chunks): + sep = cls.infer_block_separator(current, chunks[i + 1]) + current = current + sep + chunks[i + 1] + i += 1 + result.append(current) + i += 1 + + return result + + # -- Outer fence stripping --------------------------------------------- + + @staticmethod + def strip_outer_markdown_fence(text: str) -> str: + """ + Strip outer Markdown fence. + + When AI reply is entirely wrapped in ```markdown\\n...\\n```, remove the outer fence, + keeping the content. Only strip when the first line is ```markdown (case-insensitive) and the last line is ```. + + Args: + text: Text to process + + Returns: + Text with outer fence stripped (returns original if no match) + """ + if not text: + return text + + lines = text.split('\n') + if len(lines) < 3: + return text + + first_line = lines[0].strip() + last_line = lines[-1].strip() + + # First line must be ```markdown (optional language tag md/markdown) + if not re.match(r'^```(?:markdown|md)?\s*$', first_line, re.IGNORECASE): + return text + + # Last line must be plain ``` + if last_line != '```': + return text + + # Strip first and last lines + inner = '\n'.join(lines[1:-1]) + return inner + + # -- Table sanitization ------------------------------------------------ + + @staticmethod + def sanitize_markdown_table(text: str) -> str: + """ + Table output sanitization. + + Handle common formatting issues in AI-generated Markdown tables: + 1. Remove extra whitespace before/after table rows + 2. Ensure separator rows (|---|---|) are correctly formatted + 3. Remove empty table rows + + Args: + text: Markdown text containing tables + + Returns: + Sanitized text + """ + if '|' not in text: + return text + + lines = text.split('\n') + result_lines: list[str] = [] + + for line in lines: + stripped = line.strip() + + # Table row processing + if stripped.startswith('|') and stripped.endswith('|'): + # Separator row normalization: | --- | --- | → |---|---| + if re.match(r'^\|[\s\-:]+(\|[\s\-:]+)+\|$', stripped): + cells = stripped.split('|') + normalized = '|'.join( + cell.strip() if cell.strip() else cell + for cell in cells + ) + result_lines.append(normalized) + elif stripped == '||' or stripped.replace('|', '').strip() == '': + # Empty table row → skip + continue + else: + result_lines.append(stripped) + else: + result_lines.append(line) + + return '\n'.join(result_lines) + + # -- Markdown hint prompt ---------------------------------------------- + + @staticmethod + def markdown_hint_system_prompt() -> str: + """ + Markdown rendering hint (appended to system prompt). + + Tell AI that Yuanbao platform supports Markdown rendering, including: + - Code blocks (```lang) + - Tables (| col | col |) + - Bold/italic + """ + return ( + "The current platform supports Markdown rendering. You can use the following formats:\n" + "- Code blocks: ```language\\ncode\\n```\n" + "- Tables: | col1 | col2 |\\n|---|---|\\n| val1 | val2 |\n" + "- Bold: **text** / Italic: *text*\n" + "Please use Markdown formatting when appropriate to improve readability." + ) + +class SignManager: + """Encapsulates all sign-token related logic for the Yuanbao platform. + + Manages token acquisition, caching, signature computation, and + automatic retry. All state (cache, locks) is kept as class-level + attributes so that a single shared client serves the whole process. + """ + + # -- Constants --------------------------------------------------------- + + TOKEN_PATH = "/api/v5/robotLogic/sign-token" + + RETRYABLE_CODE = 10099 + MAX_RETRIES = 3 + RETRY_DELAY_S = 1.0 + + #: Early refresh margin (seconds), treat as expiring 60s before actual expiry + CACHE_REFRESH_MARGIN_S = 60 + + #: HTTP timeout (seconds) + HTTP_TIMEOUT_S = 10.0 + + # -- Class-level shared state ------------------------------------------ + + # key: app_key → {"token", "bot_id", "expire_ts", ...} + _cache: dict[str, dict[str, Any]] = {} + + # Per-app_key refresh locks — prevents concurrent duplicate sign-token + # requests. Created lazily inside get_refresh_lock() which is only called + # from async context, so the Lock is always bound to the correct loop. + # disconnect() clears this dict to prevent stale locks across reconnects. + _locks: dict[str, asyncio.Lock] = {} + + # -- Internal helpers -------------------------------------------------- + + @classmethod + def get_refresh_lock(cls, app_key: str) -> asyncio.Lock: + """Return (creating if needed) the per-app_key refresh lock. + + Must only be called from within a running event loop (async context). + """ + if app_key not in cls._locks: + cls._locks[app_key] = asyncio.Lock() + return cls._locks[app_key] + + @staticmethod + def compute_signature(nonce: str, timestamp: str, app_key: str, app_secret: str) -> str: + """Compute HMAC-SHA256 signature (aligned with TypeScript original). + + plain = nonce + timestamp + app_key + app_secret + signature = HMAC-SHA256(key=app_secret, msg=plain).hexdigest() + """ + plain = nonce + timestamp + app_key + app_secret + return hmac.new(app_secret.encode(), plain.encode(), hashlib.sha256).hexdigest() + + @staticmethod + def build_timestamp() -> str: + """Build Beijing-time ISO-8601 timestamp (no milliseconds). + + Format: 2006-01-02T15:04:05+08:00 + """ + bjtime = datetime.now(tz=timezone(timedelta(hours=8))) + return bjtime.strftime("%Y-%m-%dT%H:%M:%S+08:00") + + @classmethod + def is_cache_valid(cls, entry: dict[str, Any]) -> bool: + """Determine whether the cache entry is valid (not expired with margin).""" + return entry["expire_ts"] - time.time() > cls.CACHE_REFRESH_MARGIN_S + + @classmethod + def clear_locks(cls) -> None: + """Clear all per-app_key refresh locks (called on disconnect).""" + cls._locks.clear() + + @classmethod + def purge_expired(cls) -> int: + """Remove all expired entries from the token cache. + + Returns the number of entries purged. Called lazily from + ``get_token()`` so that stale app_key entries don't accumulate + indefinitely in long-running processes. + """ + now = time.time() + expired_keys = [ + k for k, v in cls._cache.items() + if now - v.get("expire_ts", 0) > 0 + ] + for k in expired_keys: + cls._cache.pop(k, None) + return len(expired_keys) + + # -- Core: fetch ------------------------------------------------------- + + @classmethod + async def fetch( + cls, + app_key: str, + app_secret: str, + api_domain: str, + route_env: str = "", + ) -> dict[str, Any]: + """Send sign-ticket HTTP request with auto-retry (up to MAX_RETRIES times).""" + url = f"{api_domain.rstrip('/')}{cls.TOKEN_PATH}" + async with httpx.AsyncClient(timeout=cls.HTTP_TIMEOUT_S) as client: + for attempt in range(cls.MAX_RETRIES + 1): + nonce = secrets.token_hex(16) + timestamp = cls.build_timestamp() + signature = cls.compute_signature(nonce, timestamp, app_key, app_secret) + + payload = { + "app_key": app_key, + "nonce": nonce, + "signature": signature, + "timestamp": timestamp, + } + + headers = { + "Content-Type": "application/json", + "X-AppVersion": _APP_VERSION, + "X-OperationSystem": _OPERATION_SYSTEM, + "X-Instance-Id": _YUANBAO_INSTANCE_ID, + "X-Bot-Version": _BOT_VERSION, + } + if route_env: + headers["X-Route-Env"] = route_env + + logger.info( + "Sign token request: url=%s%s", + url, + f" (retry {attempt}/{cls.MAX_RETRIES})" if attempt > 0 else "", + ) + + response = await client.post(url, json=payload, headers=headers) + + if response.status_code != 200: + body = response.text + raise RuntimeError(f"Sign token API returned {response.status_code}: {body[:200]}") + + try: + result_data: dict[str, Any] = response.json() + except Exception as exc: + raise ValueError(f"Sign token response parse error: {exc}") from exc + + code = result_data.get("code") + if code == 0: + data = result_data.get("data") + if not isinstance(data, dict): + raise ValueError(f"Sign token response missing 'data' field: {result_data}") + logger.info("Sign token success: bot_id=%s", data.get("bot_id")) + return data + + if code == cls.RETRYABLE_CODE and attempt < cls.MAX_RETRIES: + logger.warning( + "Sign token retryable: code=%s, retrying in %ss (attempt=%d/%d)", + code, + cls.RETRY_DELAY_S, + attempt + 1, + cls.MAX_RETRIES, + ) + await asyncio.sleep(cls.RETRY_DELAY_S) + continue + + msg = result_data.get("msg", "") + raise RuntimeError(f"Sign token error: code={code}, msg={msg}") + + raise RuntimeError("Sign token failed: max retries exceeded") + + # -- Public API: get (with cache) -------------------------------------- + + @classmethod + async def get_token( + cls, + app_key: str, + app_secret: str, + api_domain: str, + route_env: str = "", + ) -> dict[str, Any]: + """Get WS auth token (with cache). + + Return directly on cache hit without re-requesting; treat as expiring + 60 seconds before actual expiry, triggering refresh. + """ + # Lazily evict stale entries from other app_keys + cls.purge_expired() + + cached = cls._cache.get(app_key) + if cached and cls.is_cache_valid(cached): + remain = int(cached["expire_ts"] - time.time()) + logger.info("Using cached token (%ds remaining)", remain) + return dict(cached) + + async with cls.get_refresh_lock(app_key): + cached = cls._cache.get(app_key) + if cached and cls.is_cache_valid(cached): + return dict(cached) + + data = await cls.fetch(app_key, app_secret, api_domain, route_env) + + duration: int = data.get("duration", 0) + expire_ts = time.time() + duration if duration > 0 else time.time() + 3600 + + cls._cache[app_key] = { + "token": data.get("token", ""), + "bot_id": data.get("bot_id", ""), + "duration": duration, + "product": data.get("product", ""), + "source": data.get("source", ""), + "expire_ts": expire_ts, + } + + return dict(cls._cache[app_key]) + + # -- Public API: force refresh ----------------------------------------- + + @classmethod + async def force_refresh( + cls, + app_key: str, + app_secret: str, + api_domain: str, + route_env: str = "", + ) -> dict[str, Any]: + """Force refresh token (clear cache and re-sign).""" + logger.warning("[force-refresh] Clearing cache and re-signing token: app_key=****%s", app_key[-4:]) + async with cls.get_refresh_lock(app_key): + cls._cache.pop(app_key, None) + data = await cls.fetch(app_key, app_secret, api_domain, route_env) + + duration: int = data.get("duration", 0) + expire_ts = time.time() + duration if duration > 0 else time.time() + 3600 + + cls._cache[app_key] = { + "token": data.get("token", ""), + "bot_id": data.get("bot_id", ""), + "duration": duration, + "product": data.get("product", ""), + "source": data.get("source", ""), + "expire_ts": expire_ts, + } + + return dict(cls._cache[app_key]) + + +from dataclasses import dataclass, field as dc_field + +@dataclass +class InboundContext: + """Mutable context flowing through the inbound middleware pipeline. + + Each middleware reads/writes fields on this context. The pipeline + engine passes it to every middleware in registration order. + """ + + adapter: Any # YuanbaoAdapter (forward-ref avoids circular import) + raw_frames: list = dc_field(default_factory=list) # Raw bytes frames (debounce-aggregated) + + # Populated by DecodeMiddleware + push: Optional[dict] = None + decoded_via: str = "" # "json" | "protobuf" + + # Extracted from push by FieldExtractMiddleware + from_account: str = "" + group_code: str = "" + group_name: str = "" + sender_nickname: str = "" + msg_body: list = dc_field(default_factory=list) + msg_id: str = "" + cloud_custom_data: str = "" + + # Derived by ChatRoutingMiddleware + chat_id: str = "" + chat_type: str = "" # "dm" | "group" + chat_name: str = "" + + # Populated by ContentExtractMiddleware + raw_text: str = "" + media_refs: list = dc_field(default_factory=list) + + # Owner command detection + owner_command: Optional[str] = None + + # Source built by BuildSourceMiddleware + source: Optional[Any] = None # SessionSource + + # Populated by ClassifyMessageTypeMiddleware + msg_type: Optional[Any] = None # MessageType + + # Populated by QuoteContextMiddleware + reply_to_message_id: Optional[str] = None + reply_to_text: Optional[str] = None + + # Populated by MediaResolveMiddleware + media_urls: list = dc_field(default_factory=list) + media_types: list = dc_field(default_factory=list) + + # Populated by ExtractContentMiddleware + link_urls: list = dc_field(default_factory=list) + + # Populated by GroupAttributionMiddleware + channel_prompt: Optional[str] = None + + +class InboundMiddleware(ABC): + """Abstract base class for all inbound pipeline middlewares. + + Subclasses must: + - Set ``name`` as a class-level attribute (used for pipeline registration + and dynamic insertion/removal). + - Implement ``async handle(ctx, next_fn)`` containing the middleware logic. + + Convention: + - Call ``await next_fn()`` to pass control to the next middleware. + - Return without calling ``next_fn`` to **stop** the pipeline. + """ + + name: str = "" # Override in each subclass + + @abstractmethod + async def handle(self, ctx: InboundContext, next_fn: Callable) -> None: + """Process *ctx* and optionally call *next_fn* to continue the pipeline.""" + + async def __call__(self, ctx: InboundContext, next_fn: Callable) -> None: + """Allow middleware instances to be called directly (duck-typing compat).""" + return await self.handle(ctx, next_fn) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} name={self.name!r}>" + + +class InboundPipeline: + """Onion-model middleware pipeline engine for inbound message processing. + + Inspired by OpenClaw's MessagePipeline (extensions/yuanbao/src/business/ + pipeline/engine.ts). Supports named middlewares, conditional guards + (``when``), and ``use_before`` / ``use_after`` / ``remove`` for dynamic + composition. + + Accepts both ``InboundMiddleware`` instances (OOP style) and plain + ``async def(ctx, next_fn)`` callables (functional style) for flexibility. + """ + + def __init__(self) -> None: + self._middlewares: list = [] # list of (name, handler, when_fn | None) + + # -- Internal helpers -------------------------------------------------- + + @staticmethod + def _normalize(name_or_mw, handler=None): + """Normalize (name, handler) or (InboundMiddleware,) into (name, callable).""" + if isinstance(name_or_mw, InboundMiddleware): + return name_or_mw.name, name_or_mw + # Functional style: name is a str, handler is a callable + return name_or_mw, handler + + # -- Registration API -------------------------------------------------- + + def use(self, name_or_mw, handler=None, when=None) -> "InboundPipeline": + """Append a middleware to the end of the pipeline. + + Accepts either: + - ``pipeline.use(SomeMiddleware())`` — OOP style + - ``pipeline.use("name", some_fn)`` — functional style + """ + name, h = self._normalize(name_or_mw, handler) + self._middlewares.append((name, h, when)) + return self + + def use_before(self, target: str, name_or_mw, handler=None, when=None) -> "InboundPipeline": + """Insert a middleware before *target* (by name). Appends if not found.""" + name, h = self._normalize(name_or_mw, handler) + idx = next((i for i, (n, _, _) in enumerate(self._middlewares) if n == target), None) + entry = (name, h, when) + if idx is None: + self._middlewares.append(entry) + else: + self._middlewares.insert(idx, entry) + return self + + def use_after(self, target: str, name_or_mw, handler=None, when=None) -> "InboundPipeline": + """Insert a middleware after *target* (by name). Appends if not found.""" + name, h = self._normalize(name_or_mw, handler) + idx = next((i for i, (n, _, _) in enumerate(self._middlewares) if n == target), None) + entry = (name, h, when) + if idx is None: + self._middlewares.append(entry) + else: + self._middlewares.insert(idx + 1, entry) + return self + + def remove(self, name: str) -> "InboundPipeline": + """Remove a middleware by name.""" + self._middlewares = [(n, h, w) for n, h, w in self._middlewares if n != name] + return self + + @property + def middleware_names(self) -> list: + """Return ordered list of registered middleware names (for testing).""" + return [n for n, _, _ in self._middlewares] + + # -- Execution --------------------------------------------------------- + + async def execute(self, ctx: InboundContext) -> None: + """Run all middlewares in order. Each middleware receives ``(ctx, next_fn)``.""" + chain = self._middlewares + index = 0 + + async def next_fn() -> None: + nonlocal index + while index < len(chain): + name, handler, when_fn = chain[index] + index += 1 + # Conditional guard: skip when returns False + if when_fn is not None and not when_fn(ctx): + continue + try: + await handler(ctx, next_fn) + except Exception: + logger.error("[InboundPipeline] middleware [%s] error", name, exc_info=True) + raise + return + # End of chain — nothing more to do + + await next_fn() +class DecodeMiddleware(InboundMiddleware): + """Decode raw inbound frames from JSON or Protobuf into ctx.push. + + Encapsulates JSON push parsing (aligned with TS decodeFromContent) + and Protobuf decoding via ``decode_inbound_push``. + """ + + name = "decode" + + # -- JSON push parsing ------------------------------------------------- + + @staticmethod + def convert_json_msg_body(raw_body: list) -> list: + """Normalize raw JSON msg_body array to [{"msg_type": str, "msg_content": dict}]. + + Compatible with both PascalCase (MsgType/MsgContent) and + snake_case (msg_type/msg_content) naming. + """ + result = [] + for item in raw_body or []: + if not isinstance(item, dict): + continue + msg_type = item.get("msg_type") or item.get("MsgType", "") + msg_content = item.get("msg_content") or item.get("MsgContent", {}) + if isinstance(msg_content, str): + try: + msg_content = json.loads(msg_content) + except Exception: + msg_content = {"text": msg_content} + result.append({"msg_type": msg_type, "msg_content": msg_content or {}}) + return result + + @staticmethod + def parse_json_push(raw_json: dict) -> dict | None: + """Convert JSON-format push to a dict with the same structure as + ``decode_inbound_push``. + + Supports standard callback format (callback_command + from_account + + msg_body) and legacy format fields (GroupId, MsgSeq, MsgKey, MsgBody, + etc.). + """ + if not raw_json: + return None + + # Tencent IM callback format uses PascalCase (From_Account, To_Account, MsgBody). + # Internal format uses snake_case (from_account, to_account, msg_body). + # Support both. + from_account = ( + raw_json.get("from_account", "") + or raw_json.get("From_Account", "") + ) + group_code = ( + raw_json.get("group_code", "") + or raw_json.get("GroupId", "") + or raw_json.get("group_id", "") + ) + msg_body_raw = ( + raw_json.get("msg_body", []) + or raw_json.get("MsgBody", []) + ) + msg_body = DecodeMiddleware.convert_json_msg_body(msg_body_raw) + + # Recall callbacks may have neither from_account nor msg_body. + if not from_account and not msg_body and not raw_json.get("callback_command"): + return None + + return { + "callback_command": raw_json.get("callback_command", ""), + "from_account": from_account, + "to_account": raw_json.get("to_account", "") or raw_json.get("To_Account", ""), + "sender_nickname": raw_json.get("sender_nickname", "") or raw_json.get("nick_name", ""), + "group_code": group_code, + "group_name": raw_json.get("group_name", ""), + "msg_seq": raw_json.get("msg_seq", 0) or raw_json.get("MsgSeq", 0), + "msg_id": raw_json.get("msg_id", "") or raw_json.get("msg_key", "") or raw_json.get("MsgKey", ""), + "msg_body": msg_body, + "cloud_custom_data": raw_json.get("cloud_custom_data", "") or raw_json.get("CloudCustomData", ""), + "bot_owner_id": raw_json.get("bot_owner_id", "") or raw_json.get("botOwnerId", ""), + "recall_msg_seq_list": raw_json.get("recall_msg_seq_list") or None, + "trace_id": (raw_json.get("log_ext") or {}).get("trace_id", "") if isinstance(raw_json.get("log_ext"), dict) else "", + } + + # -- Pipeline handler -------------------------------------------------- + + def _decode_single(self, adapter, data: bytes) -> tuple: + """Decode a single raw frame into (push_dict, decoded_via) or (None, '').""" + try: + conn_json = json.loads(data.decode("utf-8")) + except Exception: + conn_json = None + + if isinstance(conn_json, dict): + push = self.parse_json_push(conn_json) + if push: + return push, "json" + else: + try: + push = decode_inbound_push(data) + except Exception: + push = None + if push: + return push, "protobuf" + + return None, "" + + async def handle(self, ctx: InboundContext, next_fn) -> None: + data_list = ctx.raw_frames + if not data_list: + return # Stop pipeline — nothing to decode + + merged_push = None + decoded_via = "" + + for data in data_list: + push, via = self._decode_single(ctx.adapter, data) + if not push: + logger.info( + "[%s] Push decoded but no valid message. raw hex(first64)=%s", + ctx.adapter.name, data.hex()[:128] if data else "(empty)", + ) + continue + + if merged_push is None: + # First valid push becomes the base + merged_push = push + decoded_via = via + logger.info( + "[%s] Frame decoded (via=%s): len=%d", + ctx.adapter.name, via, len(data), + ) + else: + # Subsequent pushes: merge msg_body into the base with a + extra_body = push.get("msg_body", []) + if extra_body: + _sep = {"msg_type": "TIMTextElem", "msg_content": {"text": "\n"}} + merged_push["msg_body"] = merged_push.get("msg_body", []) + [_sep] + extra_body + logger.info( + "[%s] Merged %d extra msg_body elements from aggregated push", + ctx.adapter.name, len(extra_body), + ) + + if not merged_push: + return # Stop pipeline + + ctx.push = merged_push + ctx.decoded_via = decoded_via + + logger.info( + "[%s] Push decoded (via=%s): from=%s group=%s msg_id=%s msg_types=%s", + ctx.adapter.name, ctx.decoded_via, + ctx.push.get("from_account", ""), + ctx.push.get("group_code", ""), + ctx.push.get("msg_id", ""), + [e.get("msg_type", "") for e in ctx.push.get("msg_body", [])], + ) + logger.debug("[%s] Push payload: %s", ctx.adapter.name, ctx.push) + + await next_fn() + + +class ExtractFieldsMiddleware(InboundMiddleware): + """Extract common fields from ctx.push into ctx attributes.""" + + name = "extract-fields" + + async def handle(self, ctx: InboundContext, next_fn) -> None: + push = ctx.push + ctx.from_account = push.get("from_account", "") + ctx.group_code = push.get("group_code", "") + ctx.group_name = push.get("group_name", "") + ctx.sender_nickname = push.get("sender_nickname", "") + ctx.msg_body = push.get("msg_body", []) + ctx.msg_id = push.get("msg_id", "") + ctx.cloud_custom_data = push.get("cloud_custom_data", "") + await next_fn() + + +class DedupMiddleware(InboundMiddleware): + """Inbound message deduplication.""" + + name = "dedup" + + async def handle(self, ctx: InboundContext, next_fn) -> None: + if ctx.msg_id and ctx.adapter._dedup.is_duplicate(ctx.msg_id): + logger.debug("[%s] Duplicate message ignored: msg_id=%s", ctx.adapter.name, ctx.msg_id) + return # Stop pipeline + await next_fn() + + +class RecallGuardMiddleware(InboundMiddleware): + """Intercept Group.CallbackAfterRecallMsg / C2C.CallbackAfterMsgWithDraw. + + Branch A: message in transcript (observed, not yet consumed) → redact content + Branch B: message not in transcript → append system note + Branch C: message currently being processed → silent interrupt + delayed redact + """ + + name = "recall_guard" + + _RECALL_COMMANDS = frozenset({ + "Group.CallbackAfterRecallMsg", + "C2C.CallbackAfterMsgWithDraw", + }) + _REDACTED = "[This message was recalled/withdrawn by the sender; original content removed]" + + async def handle(self, ctx: InboundContext, next_fn) -> None: + cmd = (ctx.push or {}).get("callback_command", "") + if cmd not in self._RECALL_COMMANDS: + await next_fn() + return + self._handle_recall(ctx, cmd) + + @staticmethod + def _build_source(adapter, group_code: str, from_account: str): + return adapter.build_source( + chat_id=(f"group:{group_code}" if group_code else f"direct:{from_account}"), + chat_type="group" if group_code else "dm", + user_id=from_account or None, + thread_id="main" if group_code else None, + ) + + def _handle_recall(self, ctx: InboundContext, cmd: str) -> None: + adapter = ctx.adapter + push = ctx.push or {} + + if cmd == "Group.CallbackAfterRecallMsg": + seq_list = push.get("recall_msg_seq_list") or [] + else: + mid = push.get("msg_id") or "" + seq = push.get("msg_seq") + seq_list = [{"msg_id": mid, "msg_seq": seq}] if (mid or seq) else [] + + if not seq_list: + logger.debug("[%s] Recall callback with empty seq_list, skipping", adapter.name) + return + + group_code = (push.get("group_code") or "").strip() + from_account = (push.get("from_account") or "").strip() + + for seq_entry in seq_list: + recalled_id = seq_entry.get("msg_id") or str(seq_entry.get("msg_seq") or "") + if not recalled_id: + continue + + matched_sk = self._find_processing_session(adapter, recalled_id) + if matched_sk is not None: + self._interrupt_for_recall(adapter, matched_sk, recalled_id, group_code, from_account) + else: + recalled_content = adapter._msg_content_cache.get(recalled_id) + self._patch_transcript(adapter, recalled_id, group_code, from_account, recalled_content) + + # -- Branch C: interrupt currently-processing message --------------- + + @staticmethod + def _find_processing_session(adapter, recalled_id: str) -> Optional[str]: + for sk, mid in adapter._processing_msg_ids.items(): + if mid == recalled_id and sk in adapter._active_sessions: + return sk + return None + + @classmethod + def _interrupt_for_recall(cls, adapter, session_key: str, recalled_id: str, + group_code: str, from_account: str) -> None: + where = f"group {group_code}" if group_code else f"direct chat with {from_account}" + recall_text = ( + f"[CRITICAL — MESSAGE RECALLED] The user message that triggered " + f"your current task (message_id=\"{recalled_id}\") in {where} has " + f"been recalled/withdrawn by the sender. " + f"IGNORE any prior system note asking you to finish processing " + f"tool results — the original request is void. " + f"Do NOT continue the task, do NOT call more tools, do NOT " + f"reference the recalled content. " + f"Reply only with a brief acknowledgment such as " + f"\"The message has been recalled.\" in the " + f"language the user was using." + ) + + synth_event = MessageEvent( + text=recall_text, + message_type=MessageType.TEXT, + source=cls._build_source(adapter, group_code, from_account), + internal=True, + ) + # Set pending + signal directly (bypass handle_message to avoid busy-ack). + # May overwrite a user message pending in the same ~200ms window — acceptable. + adapter._pending_messages[session_key] = synth_event + active_event = adapter._active_sessions.get(session_key) + if active_event is not None: + active_event.set() + + logger.info("[%s] Recall interrupt: msg_id=%s session=%s", adapter.name, recalled_id, session_key[:30]) + + # The interrupted turn will persist the recalled content *after* our + # interrupt — schedule a delayed redaction to clean it up. + recalled_text = adapter._processing_msg_texts.get(session_key, "") + if recalled_text: + cls._schedule_content_redact(adapter, session_key, recalled_text, group_code, from_account) + + @classmethod + def _schedule_content_redact(cls, adapter, session_key: str, recalled_text: str, + group_code: str, from_account: str) -> None: + async def _redact() -> None: + store = getattr(adapter, "_session_store", None) + if not store: + return + try: + sid = store.get_or_create_session( + cls._build_source(adapter, group_code, from_account), + ).session_id + except Exception: + return + # Poll until the recalled content appears in transcript — the + # interrupted turn hasn't finished writing yet when scheduled. + for _ in range(30): + await asyncio.sleep(0.5) + try: + transcript = store.load_transcript(sid) + except Exception: + continue + for entry in transcript: + if entry.get("role") == "user" and entry.get("content") == recalled_text: + entry["content"] = cls._REDACTED + try: + store.rewrite_transcript(sid, transcript) + logger.info("[%s] Recall redact: session %s", adapter.name, session_key[:30]) + except Exception as exc: + logger.warning("[%s] Recall redact failed: %s", adapter.name, exc) + return + logger.debug("[%s] Recall redact: content not found after polling, session %s", adapter.name, session_key[:30]) + + task = asyncio.create_task(_redact()) + adapter._background_tasks.add(task) + task.add_done_callback(adapter._background_tasks.discard) + + # -- Branch A/B: patch transcript (session idle) -------------------- + + @classmethod + def _patch_transcript(cls, adapter, recalled_id: str, group_code: str, + from_account: str, recalled_content: Optional[str] = None) -> None: + store = getattr(adapter, "_session_store", None) + if not store: + return + try: + sid = store.get_or_create_session(cls._build_source(adapter, group_code, from_account)).session_id + except Exception as exc: + logger.warning("[%s] Recall: failed to resolve session: %s", adapter.name, exc) + return + + # Read JSONL directly — SQLite doesn't preserve message_id field. + transcript: list = [] + try: + path = store.get_transcript_path(sid) + if path.exists(): + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + try: + transcript.append(json.loads(line)) + except json.JSONDecodeError: + pass + except Exception as exc: + logger.warning("[%s] Recall: failed to load transcript: %s", adapter.name, exc) + return + + # Branch A: redact — try message_id first, then content fallback. + # Observed messages have message_id; agent-processed @bot messages + # only have content (run.py doesn't write message_id to transcript). + target = None + for entry in transcript: + if entry.get("message_id") == recalled_id: + target = entry + break + if target is None and recalled_content: + for entry in transcript: + if entry.get("role") == "user" and entry.get("content") == recalled_content: + target = entry + break + if target is not None: + target["content"] = cls._REDACTED + try: + store.rewrite_transcript(sid, transcript) + logger.info("[%s] Recall: redacted msg_id=%s (branch A)", adapter.name, recalled_id) + except Exception as exc: + logger.warning("[%s] Recall: rewrite_transcript failed: %s", adapter.name, exc) + return + + # Branch B: not found in transcript → append system note + store.append_to_transcript(sid, { + "role": "system", + "content": f'[recall] message_id="{recalled_id}" has been recalled; do not quote or reference it.', + "timestamp": datetime.now(tz=timezone.utc).isoformat(), + }) + logger.info("[%s] Recall: system note for msg_id=%s (branch B)", adapter.name, recalled_id) + + +class SkipSelfMiddleware(InboundMiddleware): + """Filter out bot's own messages.""" + + name = "skip-self" + + @staticmethod + def _is_self_reference(from_account: str, bot_id: Optional[str]) -> bool: + """Detect whether the message is from the bot itself.""" + if not from_account or not bot_id: + return False + return from_account == bot_id + + async def handle(self, ctx: InboundContext, next_fn) -> None: + if self._is_self_reference(ctx.from_account, ctx.adapter._bot_id): + logger.debug("[%s] Ignoring self-sent message from %s", ctx.adapter.name, ctx.from_account) + return # Stop pipeline + await next_fn() + + +class ChatRoutingMiddleware(InboundMiddleware): + """Determine chat_id, chat_type, chat_name from push fields.""" + + name = "chat-routing" + + async def handle(self, ctx: InboundContext, next_fn) -> None: + if ctx.group_code: + ctx.chat_id = f"group:{ctx.group_code}" + ctx.chat_type = "group" + ctx.chat_name = ctx.group_name or ctx.group_code + else: + ctx.chat_id = f"direct:{ctx.from_account}" + ctx.chat_type = "dm" + ctx.chat_name = ctx.sender_nickname or ctx.from_account + await next_fn() + + +class AccessPolicy: + """Platform-level DM / Group access control policy. + + Encapsulates the allow/deny logic so that both inbound middleware + and outbound ``send_dm`` can share the same rules without reaching + into adapter internals. + """ + + def __init__( + self, + dm_policy: str, + dm_allow_from: list[str], + group_policy: str, + group_allow_from: list[str], + ) -> None: + self._dm_policy = dm_policy + self._dm_allow_from = dm_allow_from + self._group_policy = group_policy + self._group_allow_from = group_allow_from + + def is_dm_allowed(self, sender_id: str) -> bool: + """Platform-level DM inbound filter (open / allowlist / disabled).""" + if self._dm_policy == "disabled": + return False + if self._dm_policy == "allowlist": + return sender_id.strip() in self._dm_allow_from + return True + + def is_group_allowed(self, group_code: str) -> bool: + """Platform-level group chat inbound filter (open / allowlist / disabled).""" + if self._group_policy == "disabled": + return False + if self._group_policy == "allowlist": + return group_code.strip() in self._group_allow_from + return True + + @property + def dm_policy(self) -> str: + return self._dm_policy + + @property + def group_policy(self) -> str: + return self._group_policy + + +class AccessGuardMiddleware(InboundMiddleware): + """Platform-level DM/Group access control filter.""" + + name = "access-guard" + + async def handle(self, ctx: InboundContext, next_fn) -> None: + adapter = ctx.adapter + policy: AccessPolicy = adapter._access_policy + if ctx.chat_type == "dm": + if not policy.is_dm_allowed(ctx.from_account): + logger.debug( + "[%s] DM from %s blocked by dm_policy=%s", + adapter.name, ctx.from_account, policy.dm_policy, + ) + return # Stop pipeline + elif ctx.chat_type == "group": + if not policy.is_group_allowed(ctx.group_code): + logger.debug( + "[%s] Group %s blocked by group_policy=%s", + adapter.name, ctx.group_code, policy.group_policy, + ) + return # Stop pipeline + await next_fn() + + +class AutoSetHomeMiddleware(InboundMiddleware): + """Auto-designate the first inbound conversation as Yuanbao home channel. + + Triggers when no home channel is configured, or when an existing group-chat + home is superseded by the first DM (direct > group upgrade). + Silent: writes config.yaml and env, no user-facing message. + """ + + name = "auto-sethome" + + async def handle(self, ctx: InboundContext, next_fn) -> None: + adapter = ctx.adapter + if not adapter._auto_sethome_done: + _cur_home = os.getenv("YUANBAO_HOME_CHANNEL", "") + _should_set = ( + not _cur_home + or (_cur_home.startswith("group:") and ctx.chat_type == "dm") + ) + if ctx.chat_type == "dm": + adapter._auto_sethome_done = True # DM seen — no further upgrades needed + if _should_set: + try: + from hermes_constants import get_hermes_home + from utils import atomic_yaml_write + import yaml + + _home = get_hermes_home() + config_path = _home / "config.yaml" + user_config: dict = {} + if config_path.exists(): + with open(config_path, encoding="utf-8") as f: + user_config = yaml.safe_load(f) or {} + user_config["YUANBAO_HOME_CHANNEL"] = ctx.chat_id + atomic_yaml_write(config_path, user_config) + os.environ["YUANBAO_HOME_CHANNEL"] = str(ctx.chat_id) + logger.info( + "[%s] Auto-sethome: designated %s (%s) as Yuanbao home channel", + adapter.name, ctx.chat_id, ctx.chat_name, + ) + # Silent auto-sethome: no user-facing message, only log + except Exception as e: + logger.warning("[%s] Auto-sethome failed: %s", adapter.name, e) + await next_fn() + + +class ExtractContentMiddleware(InboundMiddleware): + """Extract raw text and media refs from msg_body.""" + + name = "extract-content" + + _CARD_CONTENT_MAX_LENGTH = 1000 + + @staticmethod + def _format_shared_link(custom: dict) -> str: + """Format elem_type 1010 (share card) into bracket-placeholder text.""" + title = custom.get("title", "") + link = custom.get("link", "") + header = f"[share_card: {title} | {link}]" if link else f"[share_card: {title}]" + lines = [header] + max_len = ExtractContentMiddleware._CARD_CONTENT_MAX_LENGTH + for field in ("card_content", "wechat_des"): + val = custom.get(field) + if val and isinstance(val, str): + preview = val[:max_len] + "...(truncated)" if len(val) > max_len else val + lines.append(f"Preview: {preview}") + break + if link: + lines.append("[visit link for full content]") + return "\n".join(lines) + + @staticmethod + def _format_link_understanding(custom: dict) -> Optional[str]: + """Format elem_type 1007 (link understanding card) into bracket-placeholder text.""" + content = custom.get("content") + if not content: + return None + try: + parsed = json.loads(content) + link = parsed.get("link") if isinstance(parsed, dict) else None + except (json.JSONDecodeError, TypeError): + link = None + if not link or not isinstance(link, str): + return None + return f"[link: {link} | visit link for full content]" + + @classmethod + def _extract_text(cls, msg_body: list) -> str: + """Extract plain text content from MsgBody. + + - TIMTextElem -> text field + - TIMImageElem -> "[image]" + - TIMFileElem -> "[file: {filename}]" + - TIMSoundElem -> "[voice]" + - TIMVideoFileElem -> "[video]" + - TIMFaceElem -> "[emoji: {name}]" or "[emoji]" + - TIMCustomElem -> try to extract data field, otherwise "[custom message]" + - Multiple elems joined with spaces + """ + parts: list[str] = [] + for elem in msg_body: + elem_type: str = elem.get("msg_type", "") + content: dict = elem.get("msg_content", {}) + + if elem_type == "TIMTextElem": + text = content.get("text", "") + if text: + parts.append(text) + elif elem_type == "TIMImageElem": + parts.append("[image]") + elif elem_type == "TIMFileElem": + filename = content.get("file_name", content.get("fileName", content.get("filename", ""))) + parts.append(f"[file: {filename}]" if filename else "[file]") + elif elem_type == "TIMSoundElem": + parts.append("[voice]") + elif elem_type == "TIMVideoFileElem": + parts.append("[video]") + elif elem_type == "TIMCustomElem": + data_val = content.get("data", "") + if data_val: + try: + custom = json.loads(data_val) + if not isinstance(custom, dict): + parts.append("[unsupported message type]") + continue + ctype = custom.get("elem_type") + if ctype == 1002: + parts.append(custom.get("text", "[mention]")) + elif ctype == 1010: + parts.append(cls._format_shared_link(custom)) + elif ctype == 1007: + text = cls._format_link_understanding(custom) + if text: + parts.append(text) + else: + parts.append("[unsupported message type]") + else: + parts.append("[unsupported message type]") + except (json.JSONDecodeError, TypeError): + parts.append(data_val) + else: + parts.append("[unsupported message type]") + elif elem_type == "TIMFaceElem": + # Sticker/emoji: extract name from data JSON + raw_data = content.get("data", "") + face_name = "" + if raw_data: + try: + face_data = json.loads(raw_data) + face_name = (face_data.get("name") or "").strip() + except (json.JSONDecodeError, TypeError, AttributeError): + pass + parts.append(f"[emoji: {face_name}]" if face_name else "[emoji]") + elif elem_type: + # Unknown element type — include type as placeholder + parts.append(f"[{elem_type}]") + + return " ".join(parts) if parts else "" + + @staticmethod + def _rewrite_slash_command(text: str) -> str: + """Normalize input text: strip whitespace and convert full-width slash + (Chinese input method) to ASCII slash so commands are recognized correctly. + """ + text = text.strip() + if text.startswith('\uff0f'): # Full-width slash + text = '/' + text[1:] + return text + + @staticmethod + def _extract_inbound_media_refs(msg_body: list) -> List[Dict[str, str]]: + """Extract inbound image/file references from TIM msg_body. + + Return example: + [{"kind": "image", "url": "https://..."}, {"kind": "file", "url": "...", "name": "a.pdf"}] + """ + refs: List[Dict[str, str]] = [] + for elem in msg_body or []: + if not isinstance(elem, dict): + continue + msg_type = elem.get("msg_type", "") + content = elem.get("msg_content", {}) or {} + if not isinstance(content, dict): + continue + + if msg_type == "TIMImageElem": + # Prefer medium image (index 1), fallback to index 0. + image_info_array = content.get("image_info_array") + if not isinstance(image_info_array, list): + image_info_array = [] + image_info = None + if len(image_info_array) > 1 and isinstance(image_info_array[1], dict): + image_info = image_info_array[1] + elif len(image_info_array) > 0 and isinstance(image_info_array[0], dict): + image_info = image_info_array[0] + image_url = str((image_info or {}).get("url") or "").strip() + if image_url: + refs.append({"kind": "image", "url": image_url}) + continue + + if msg_type == "TIMFileElem": + file_url = str(content.get("url") or "").strip() + file_name = ( + str(content.get("file_name") or "").strip() + or str(content.get("fileName") or "").strip() + or str(content.get("filename") or "").strip() + ) + if file_url: + ref: Dict[str, str] = {"kind": "file", "url": file_url} + if file_name: + ref["name"] = file_name + refs.append(ref) + return refs + + @staticmethod + def _extract_link_urls(msg_body: list) -> list: + """Extract link URLs from share-card (1010) and link-understanding (1007) custom elems.""" + urls: list[str] = [] + for elem in msg_body or []: + if not isinstance(elem, dict) or elem.get("msg_type") != "TIMCustomElem": + continue + data_str = (elem.get("msg_content") or {}).get("data", "") + if not data_str: + continue + try: + custom = json.loads(data_str) + except (json.JSONDecodeError, TypeError): + continue + if not isinstance(custom, dict): + continue + ctype = custom.get("elem_type") + if ctype == 1010: + link = custom.get("link") + if link and isinstance(link, str): + urls.append(link) + elif ctype == 1007: + content = custom.get("content") + if content: + try: + parsed = json.loads(content) + link = parsed.get("link") if isinstance(parsed, dict) else None + if link and isinstance(link, str): + urls.append(link) + except (json.JSONDecodeError, TypeError): + pass + return urls + + async def handle(self, ctx: InboundContext, next_fn) -> None: + ctx.raw_text = self._rewrite_slash_command(self._extract_text(ctx.msg_body)) + ctx.media_refs = self._extract_inbound_media_refs(ctx.msg_body) + ctx.link_urls = self._extract_link_urls(ctx.msg_body) + await next_fn() + +class PlaceholderFilterMiddleware(InboundMiddleware): + """Skip pure placeholder messages (e.g. '[image]' with no media).""" + + name = "placeholder-filter" + + SKIPPABLE_PLACEHOLDERS: frozenset = frozenset({ + "[image]", "[图片]", "[file]", "[文件]", + "[video]", "[视频]", "[voice]", "[语音]", + }) + + @classmethod + def is_skippable_placeholder(cls, text: str, media_count: int = 0) -> bool: + """Detect whether the message is a pure placeholder (should be skipped).""" + if media_count > 0: + return False + stripped = text.strip() + return stripped in cls.SKIPPABLE_PLACEHOLDERS + + async def handle(self, ctx: InboundContext, next_fn) -> None: + if self.is_skippable_placeholder(ctx.raw_text, len(ctx.media_refs)): + logger.debug("[%s] Skipping placeholder message: %r", ctx.adapter.name, ctx.raw_text) + return # Stop pipeline + await next_fn() + + +class OwnerCommandMiddleware(InboundMiddleware): + """Detect bot-owner slash commands in group chat. + + Identifies in-group allowlisted slash commands and determines sender identity. + Owner commands skip @Bot detection; non-owner attempts are rejected. + """ + + name = "owner-command" + + # Slash command allowlist that bot owner can execute in group without @Bot + ALLOWLIST: frozenset = frozenset({ + "/new", "/reset", "/retry", "/undo", "/stop", + "/approve", "/deny", "/background", "/bg", + "/btw", "/queue", "/q", + }) + + @staticmethod + def _rewrite_slash_command(text: str) -> str: + """Normalize full-width slash to ASCII slash and strip whitespace.""" + text = text.strip() + if text.startswith('\uff0f'): # Full-width slash + text = '/' + text[1:] + return text + + @classmethod + def _detect_owner_command( + cls, + *, + push: dict, + msg_body: list, + chat_type: str, + from_account: str, + ) -> Tuple[Optional[str], Optional[str], bool]: + """Identify allowlisted slash commands and determine sender identity. + + Returns (cmd, cmd_line, is_owner): + - (None, None, False): Not an allowlisted command + - (cmd, cmd_line, True): Owner match + - (cmd, cmd_line, False): Allowlisted command but sender is not owner + """ + if chat_type != "group" or not cls.ALLOWLIST: + return None, None, False + + # Extract TIMTextElem: only do command recognition with exactly one text segment + text_elems = [ + e for e in (msg_body or []) + if e.get("msg_type") == "TIMTextElem" + ] + if len(text_elems) != 1: + return None, None, False + + text = (text_elems[0].get("msg_content") or {}).get("text", "") + cmd_line = cls._rewrite_slash_command(text) + if not cmd_line.startswith("/"): + return None, None, False + cmd = cmd_line.split(maxsplit=1)[0].lower() + if cmd not in cls.ALLOWLIST: + return None, None, False + + # Sender identity check: bot owner <-> push.from_account == push.bot_owner_id + owner_id = (push or {}).get("bot_owner_id") or "" + # is_owner = bool(owner_id) and owner_id == from_account + is_owner = True + return cmd, cmd_line, is_owner + + async def handle(self, ctx: InboundContext, next_fn) -> None: + adapter = ctx.adapter + matched_cmd, cmd_line, is_owner = self._detect_owner_command( + push=ctx.push, + msg_body=ctx.msg_body, + chat_type=ctx.chat_type, + from_account=ctx.from_account, + ) + if matched_cmd and not is_owner: + # Non-owner tried an owner-only command — reject and stop + logger.info( + "[%s] Reject non-owner slash command: chat=%s from=%s cmd=%s", + adapter.name, ctx.chat_id, ctx.from_account, matched_cmd, + ) + adapter._track_task(asyncio.create_task( + adapter.send(ctx.chat_id, f"⚠️ {matched_cmd} is only available to the creator in private chat mode"), + name=f"yuanbao-owner-cmd-denial-{matched_cmd}", + )) + return # Stop pipeline + + if matched_cmd and is_owner and cmd_line: + logger.info( + "[%s] Bot owner slash command: chat=%s from=%s cmd=%s", + adapter.name, ctx.chat_id, ctx.from_account, matched_cmd, + ) + ctx.owner_command = matched_cmd + ctx.raw_text = cmd_line # Override with clean command text + await next_fn() + + +class BuildSourceMiddleware(InboundMiddleware): + """Build SessionSource from context fields.""" + + name = "build-source" + + async def handle(self, ctx: InboundContext, next_fn) -> None: + adapter = ctx.adapter + ctx.source = adapter.build_source( + chat_id=ctx.chat_id, + chat_type=ctx.chat_type, + chat_name=ctx.chat_name, + user_id=ctx.from_account or None, + user_name=ctx.sender_nickname or ctx.from_account, + thread_id="main" if ctx.chat_type == "group" else None, + ) + await next_fn() + + +class GroupAtGuardMiddleware(InboundMiddleware): + """In group chat, observe non-@bot messages; only reply on @Bot. + + Owner commands skip @Bot detection (owner doesn't need to @Bot). + """ + + name = "group-at-guard" + + @staticmethod + def _is_at_bot(msg_body: list, bot_id: Optional[str]) -> bool: + """Detect whether the message @Bot. + + AT element format: TIMCustomElem, msg_content.data is a JSON string: + {"elem_type": 1002, "text": "@xxx", "user_id": ""} + Considered @Bot when elem_type == 1002 and user_id == bot_id. + """ + if not bot_id: + return False + for elem in msg_body: + if elem.get("msg_type") != "TIMCustomElem": + continue + data_str = elem.get("msg_content", {}).get("data", "") + if not data_str: + continue + try: + custom = json.loads(data_str) + except (json.JSONDecodeError, TypeError): + continue + if custom.get("elem_type") == 1002 and custom.get("user_id") == bot_id: + return True + return False + + @staticmethod + def _extract_bot_mention_text(msg_body: list, bot_id: Optional[str]) -> str: + """Extract the display text used to @-mention this bot (e.g. ``@yuanbao-bot``).""" + if not bot_id: + return "" + for elem in msg_body: + if elem.get("msg_type") != "TIMCustomElem": + continue + data_str = elem.get("msg_content", {}).get("data", "") + if not data_str: + continue + try: + custom = json.loads(data_str) + except (json.JSONDecodeError, TypeError): + continue + if custom.get("elem_type") == 1002 and custom.get("user_id") == bot_id: + mention_text = str(custom.get("text") or "").strip() + if mention_text: + return mention_text + return "" + + @staticmethod + def _build_group_channel_prompt(msg_body: list, bot_id: Optional[str]) -> str: + """Build a per-turn group-chat prompt that highlights which message to respond to.""" + bid = str(bot_id or "unknown") + bot_mention = GroupAtGuardMiddleware._extract_bot_mention_text(msg_body, bot_id) or "unknown" + return ( + "You are handling a Yuanbao group chat message.\n" + f"- Your identity: user_id={bid}, @-mention name in this group={bot_mention}\n" + "- Lines in history prefixed with `[nickname|user_id]` are observed group context " + "and are not necessarily addressed to you.\n" + "- Treat only the current new message as a request explicitly directed at you, " + "and answer it directly." + ) + + @staticmethod + def _observe_group_message( + adapter, source, sender_display: str, text: str, + *, msg_id: Optional[str] = None, + ) -> None: + """Write a group message into the session transcript without triggering the agent. + + This allows the model to see the full group conversation when it is + eventually invoked via @bot. Messages are stored with ``role: "user"`` + in the format ``[nickname|user_id]\\n`` so the model + can distinguish participants and their user ids. + """ + store = getattr(adapter, "_session_store", None) + if not store: + return + try: + session_entry = store.get_or_create_session(source) + user_id = source.user_id or "unknown" + attributed = f"[{sender_display}|{user_id}]\n{text}" + entry: dict = { + "role": "user", + "content": attributed, + "timestamp": datetime.now(tz=timezone.utc).isoformat(), + "observed": True, + } + if msg_id: + entry["message_id"] = msg_id + store.append_to_transcript( + session_entry.session_id, + entry, + ) + except Exception as exc: + logger.warning("[%s] Failed to observe group message: %s", adapter.name, exc) + + async def handle(self, ctx: InboundContext, next_fn) -> None: + adapter = ctx.adapter + if ctx.chat_type == "group" and not ctx.owner_command and not self._is_at_bot(ctx.msg_body, adapter._bot_id): + self._observe_group_message( + adapter, ctx.source, ctx.sender_nickname or ctx.from_account, ctx.raw_text, + msg_id=ctx.msg_id or None, + ) + logger.info( + "[%s] Group message observed (no @bot): chat=%s from=%s", + adapter.name, ctx.chat_id, ctx.from_account, + ) + return # Stop pipeline — message observed but not dispatched + await next_fn() + + +class GroupAttributionMiddleware(InboundMiddleware): + """Tag group @bot messages with [nickname|user_id] attribution and channel_prompt. + + For group messages that pass the @bot guard (i.e. the bot is mentioned), + this middleware: + - Builds a per-turn channel_prompt so the model knows its identity and + the attribution scheme. + - Rewrites ctx.raw_text to ``[nickname|user_id]\\n`` to match + the observed-history format. + - Suppresses the runner's default ``[user_name]`` shared-thread prefix + by clearing ``source.user_name``. + """ + + name = "group-attribution" + + async def handle(self, ctx: InboundContext, next_fn) -> None: + if ctx.chat_type == "group" and not ctx.owner_command: + adapter = ctx.adapter + ctx.channel_prompt = GroupAtGuardMiddleware._build_group_channel_prompt( + ctx.msg_body, adapter._bot_id, + ) + user_id_label = ctx.from_account or "unknown" + nickname_label = ctx.sender_nickname or ctx.from_account or "unknown" + ctx.raw_text = f"[{nickname_label}|{user_id_label}]\n{ctx.raw_text}" + # Suppress runner's default ``[user_name]`` shared-thread prefix so + # the text the model sees matches the observed-history format. + if ctx.source is not None: + ctx.source = dataclasses.replace(ctx.source, user_name=None) + await next_fn() + + +class ClassifyMessageTypeMiddleware(InboundMiddleware): + """Determine MessageType from text content and msg_body elements.""" + + name = "classify-msg-type" + + @staticmethod + def _classify(text: str, msg_body: list) -> MessageType: + """Classify message type based on text and msg_body.""" + if text.startswith("/"): + return MessageType.COMMAND + for elem in msg_body: + etype = elem.get("msg_type", "") + if etype == "TIMImageElem": + return MessageType.PHOTO + if etype == "TIMSoundElem": + return MessageType.VOICE + if etype == "TIMVideoFileElem": + return MessageType.VIDEO + if etype == "TIMFileElem": + return MessageType.DOCUMENT + return MessageType.TEXT + + async def handle(self, ctx: InboundContext, next_fn) -> None: + ctx.msg_type = self._classify(ctx.raw_text, ctx.msg_body) + await next_fn() + + +class QuoteContextMiddleware(InboundMiddleware): + """Extract quote/reply context from cloud_custom_data.""" + + name = "quote-context" + + @staticmethod + def _extract_quote_context(cloud_custom_data: str) -> Tuple[Optional[str], Optional[str]]: + """Extract quote context, mapping to MessageEvent.reply_to_*. + + Returns: + (reply_to_message_id, reply_to_text) + """ + if not cloud_custom_data: + return None, None + try: + parsed = json.loads(cloud_custom_data) + except (json.JSONDecodeError, TypeError): + return None, None + + quote = parsed.get("quote") if isinstance(parsed, dict) else None + if not isinstance(quote, dict): + return None, None + + # type=2 corresponds to image reference; desc may be empty, provide a placeholder. + quote_type = int(quote.get("type") or 0) + desc = str(quote.get("desc") or "").strip() + if quote_type == 2 and not desc: + desc = "[image]" + if not desc: + return None, None + + quote_id = str(quote.get("id") or "").strip() or None + sender = str(quote.get("sender_nickname") or quote.get("sender_id") or "").strip() + quote_text = f"{sender}: {desc}" if sender else desc + return quote_id, quote_text + + async def handle(self, ctx: InboundContext, next_fn) -> None: + ctx.reply_to_message_id, ctx.reply_to_text = self._extract_quote_context(ctx.cloud_custom_data) + await next_fn() + + +class MediaResolveMiddleware(InboundMiddleware): + """Resolve inbound media references to downloadable URLs.""" + + name = "media-resolve" + + @staticmethod + def _guess_image_ext_from_url(url: str) -> str: + """Guess image extension from URL path.""" + path = urllib.parse.urlparse(url).path + ext = os.path.splitext(path)[1].lower() + if ext in {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".heic", ".tiff"}: + return ext + return ".jpg" + + @staticmethod + async def _fetch_resource_url(adapter, resource_id: str) -> str: + """Low-level helper: exchange a ``resourceId`` for a direct download URL. + + Handles token retrieval, the ``/api/resource/v1/download`` API call, + and a single 401-retry with token force-refresh. Raises on failure. + """ + resource_id = resource_id.strip() + if not resource_id: + raise RuntimeError("missing resource_id") + + token_data = await adapter._get_cached_token() + token = str(token_data.get("token") or "").strip() + source = str(token_data.get("source") or "web").strip() or "web" + bot_id = str(token_data.get("bot_id") or adapter._bot_id or adapter._app_key).strip() + if not token or not bot_id: + raise RuntimeError("missing token or bot_id for resource download") + + api_url = f"{adapter._api_domain}/api/resource/v1/download" + headers = { + "Content-Type": "application/json", + "X-ID": bot_id, + "X-Token": token, + "X-Source": source, + } + + async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client: + for attempt in range(2): + resp = await client.get(api_url, params={"resourceId": resource_id}, headers=headers) + if resp.status_code == 401 and attempt == 0: + # Force refresh token once on expiry and retry + token_data = await SignManager.force_refresh( + adapter._app_key, adapter._app_secret, adapter._api_domain, + ) + token = str(token_data.get("token") or "").strip() + source = str(token_data.get("source") or source or "web").strip() or "web" + bot_id = str(token_data.get("bot_id") or adapter._bot_id or adapter._app_key).strip() + if not token or not bot_id: + break + headers["X-ID"] = bot_id + headers["X-Token"] = token + headers["X-Source"] = source + continue + + resp.raise_for_status() + payload = resp.json() + code = payload.get("code") + if code not in (None, 0): + raise RuntimeError( + f"resource/v1/download failed: code={code}, msg={payload.get('msg', '')}" + ) + data = payload.get("data") if isinstance(payload.get("data"), dict) else payload + real_url = str((data or {}).get("url") or (data or {}).get("realUrl") or "").strip() + if real_url: + return real_url + raise RuntimeError("resource/v1/download missing url/realUrl") + + raise RuntimeError("resource/v1/download did not return a URL") + + @staticmethod + async def _resolve_download_url(adapter, url: str) -> str: + """Resolve Yuanbao resource placeholder to a directly fetchable real URL. + + Common URL patterns: + https://hunyuan.tencent.com/api/resource/download?resourceId=... + Direct GET returns 401; need business API: + GET /api/resource/v1/download?resourceId=... + """ + try: + parsed = urllib.parse.urlparse(url) + except Exception: + return url + + query = urllib.parse.parse_qs(parsed.query) + resource_ids = query.get("resourceId") or query.get("resourceid") or [] + resource_id = str(resource_ids[0]).strip() if resource_ids else "" + if not resource_id: + return url + + try: + return await MediaResolveMiddleware._fetch_resource_url(adapter, resource_id) + except Exception: + return url + + @classmethod + async def _download_and_cache( + cls, adapter, *, fetch_url: str, kind: str, + file_name: Optional[str] = None, log_tag: str = "", + ) -> Optional[Tuple[str, str]]: + """Download a Yuanbao resource and cache locally. Returns ``(local_path, mime)`` or ``None``.""" + try: + file_bytes, content_type = await media_download_url( + fetch_url, max_size_mb=adapter.MEDIA_MAX_SIZE_MB, + ) + except Exception as exc: + logger.warning( + "[%s] inbound media download failed: kind=%s %s err=%s", + adapter.name, kind, log_tag, exc, + ) + return None + + if kind == "image": + ext = cls._guess_image_ext_from_url(fetch_url) + try: + local_path = cache_image_from_bytes(file_bytes, ext=ext) + except ValueError as exc: + logger.warning( + "[%s] inbound image cache rejected: %s err=%s", + adapter.name, log_tag, exc, + ) + return None + mime = guess_mime_type(f"image{ext}") + if not mime.startswith("image/"): + mime = content_type if content_type.startswith("image/") else "image/jpeg" + return local_path, mime + + # kind == "file" + if not file_name: + parsed = urllib.parse.urlparse(fetch_url) + file_name = os.path.basename(parsed.path) or "file" + try: + local_path = cache_document_from_bytes(file_bytes, file_name) + except Exception as exc: + logger.warning( + "[%s] inbound file cache failed: %s err=%s", + adapter.name, log_tag, exc, + ) + return None + mime = guess_mime_type(file_name) or content_type or "application/octet-stream" + return local_path, mime + + @classmethod + async def _resolve_by_resource_id(cls, adapter, resource_id: str) -> str: + """Exchange a Yuanbao ``resourceId`` for a short-lived direct download URL. Raises on failure.""" + return await cls._fetch_resource_url(adapter, resource_id) + + @classmethod + async def _resolve_media_urls( + cls, adapter, media_refs: List[Dict[str, str]] + ) -> Tuple[List[str], List[str]]: + """Resolve inbound media refs: download to local cache, return (local_paths, mime_types). + + Yuanbao COS hostnames resolve to private IPs, tripping the SSRF guard + in vision_tools. We download ourselves and return local cache paths. + """ + media_urls: List[str] = [] + media_types: List[str] = [] + + for ref in media_refs: + kind = str(ref.get("kind") or "").strip().lower() + url = str(ref.get("url") or "").strip() + if kind not in {"image", "file"} or not url: + continue + + try: + fetch_url = await cls._resolve_download_url(adapter, url) + except Exception as exc: + logger.warning( + "[%s] inbound media resolve failed: kind=%s url=%s err=%s", + adapter.name, kind, url, exc, + ) + continue + + cached = await cls._download_and_cache( + adapter, + fetch_url=fetch_url, + kind=kind, + file_name=str(ref.get("name") or "").strip() or None, + log_tag=f"placeholder_url={url[:80]}", + ) + if cached is None: + continue + local_path, mime = cached + media_urls.append(local_path) + media_types.append(mime) + + return media_urls, media_types + + @classmethod + async def _collect_observed_media( + cls, adapter, source, + ) -> Tuple[List[str], List[str]]: + """Resolve recent observed image/file anchors from transcript into ``(local_paths, mimes)``.""" + store = getattr(adapter, "_session_store", None) + if not store: + return [], [] + try: + session_entry = store.get_or_create_session(source) + history = store.load_transcript(session_entry.session_id) + except Exception as exc: + logger.warning( + "[%s] Observed-media hydration setup failed: %s", + adapter.name, exc, + ) + return [], [] + if not history: + return [], [] + + start = max(0, len(history) - OBSERVED_MEDIA_BACKFILL_LOOKBACK) + order: List[Tuple[str, str, str]] = [] # (rid, kind, filename) + seen: set = set() + for msg in history[start:]: + content = msg.get("content") + if not isinstance(content, str) or "|ybres:" not in content: + continue + for m in _YB_RES_REF_RE.finditer(content): + head = m.group(1) # "image" | "file:" | "voice" | "video" + rid = m.group(2) + kind, _, filename = head.partition(":") + kind = kind.strip() + if kind not in ("image", "file"): + continue + if rid in seen: + continue + seen.add(rid) + order.append((rid, kind, filename.strip())) + if len(order) >= OBSERVED_MEDIA_BACKFILL_MAX_RESOLVE_PER_TURN: + break + if len(order) >= OBSERVED_MEDIA_BACKFILL_MAX_RESOLVE_PER_TURN: + break + + if not order: + return [], [] + + media_paths: List[str] = [] + mimes: List[str] = [] + for rid, kind, filename in order: + try: + fresh_url = await cls._resolve_by_resource_id(adapter, rid) + except Exception as exc: + logger.warning( + "[%s] observed-media resolve failed: rid=%s kind=%s err=%s", + adapter.name, rid, kind, exc, + ) + continue + cached = await cls._download_and_cache( + adapter, + fetch_url=fresh_url, + kind=kind, + file_name=filename or None, + log_tag=f"rid={rid}", + ) + if cached is None: + continue + path, mime = cached + media_paths.append(path) + mimes.append(mime) + return media_paths, mimes + + async def handle(self, ctx: InboundContext, next_fn) -> None: + adapter = ctx.adapter + ctx.media_urls, ctx.media_types = await self._resolve_media_urls(adapter, ctx.media_refs) + # Re-check placeholder after media resolution + if PlaceholderFilterMiddleware.is_skippable_placeholder(ctx.raw_text, len(ctx.media_urls)): + logger.debug("[%s] Skip placeholder after media download: %r", adapter.name, ctx.raw_text) + return # Stop pipeline + await next_fn() + + +class DispatchMiddleware(InboundMiddleware): + """Build MessageEvent and dispatch to AI handler.""" + + name = "dispatch" + + async def handle(self, ctx: InboundContext, next_fn) -> None: + adapter = ctx.adapter + + _sk = build_session_key( + ctx.source, + group_sessions_per_user=adapter.config.extra.get("group_sessions_per_user", True), + thread_sessions_per_user=adapter.config.extra.get("thread_sessions_per_user", False), + ) + + async def _dispatch_inbound_event() -> None: + media_urls = list(ctx.media_urls) + media_types = list(ctx.media_types) + + # Backfill observed media from recent transcript history + extra_img_urls: List[str] = [] + extra_img_mimes: List[str] = [] + try: + extra_img_urls, extra_img_mimes = await MediaResolveMiddleware._collect_observed_media( + adapter, ctx.source, + ) + except Exception as exc: + logger.warning( + "[%s] observed-image hydration raised, continuing anyway: %s", + adapter.name, exc, + ) + if extra_img_urls: + current = set(media_urls) + for u, m in zip(extra_img_urls, extra_img_mimes): + if u in current: + continue + media_urls.append(u) + media_types.append(m) + current.add(u) + + # Replace [kind|ybres:xxx] anchors with local cache paths so + # the transcript records usable paths for the model. + _patched_event_text = ctx.raw_text + for u, m in zip(media_urls, media_types): + if not u.startswith("/"): + continue + anchor_match = _YB_RES_REF_RE.search(_patched_event_text) + if not anchor_match: + continue + head = anchor_match.group(1) + kind, _, filename = head.partition(":") + kind = kind.strip() + if kind == "image" and m.startswith("image/"): + replacement = f"[image: {u}]" + elif kind == "file": + label = filename.strip() or os.path.basename(u) + replacement = f"[file: {label} → {u}]" + else: + continue + _patched_event_text = ( + _patched_event_text[:anchor_match.start()] + + replacement + + _patched_event_text[anchor_match.end():] + ) + + event = MessageEvent( + text=_patched_event_text, + message_type=ctx.msg_type, + source=ctx.source, + message_id=ctx.msg_id or None, + raw_message=ctx.push, + media_urls=media_urls, + media_types=media_types, + reply_to_message_id=ctx.reply_to_message_id, + reply_to_text=ctx.reply_to_text, + channel_prompt=ctx.channel_prompt, + ) + if _sk and ctx.msg_id: + adapter._processing_msg_ids[_sk] = ctx.msg_id + adapter._processing_msg_texts[_sk] = ctx.raw_text or "" + if ctx.msg_id and ctx.raw_text: + cache = adapter._msg_content_cache + cache[ctx.msg_id] = ctx.raw_text + if len(cache) > 200: + for k in list(cache)[:len(cache) - 200]: + del cache[k] + await adapter.handle_message(event) + + if ctx.chat_type == "group": + is_new = _sk not in adapter._group_queues + queue = adapter._group_queues.setdefault(_sk, asyncio.Queue()) + queue.put_nowait(_dispatch_inbound_event) + logger.info( + "[%s] Group message enqueued (qsize=%d) for %s", + adapter.name, queue.qsize(), (_sk or "")[:50], + ) + if is_new: + consumer = asyncio.create_task( + self._consume_group_queue(adapter, _sk), + name=f"yuanbao-group-consumer-{(_sk or '')[:30]}", + ) + adapter._inbound_tasks.add(consumer) + consumer.add_done_callback(adapter._inbound_tasks.discard) + else: + task = asyncio.create_task( + _dispatch_inbound_event(), + name=f"yuanbao-inbound-{ctx.msg_id or 'unknown'}", + ) + adapter._inbound_tasks.add(task) + task.add_done_callback(adapter._inbound_tasks.discard) + + await next_fn() + + @staticmethod + async def _consume_group_queue(adapter: "YuanbaoAdapter", session_key: str) -> None: + """Drain the group queue one dispatch at a time, waiting for each to finish.""" + _IDLE_TIMEOUT = 2.0 + queue = adapter._group_queues.get(session_key) + if not queue: + return + try: + while True: + try: + dispatch_fn = await asyncio.wait_for(queue.get(), timeout=_IDLE_TIMEOUT) + except asyncio.TimeoutError: + break + logger.debug( + "[%s] Group queue: dispatching for %s (remaining=%d)", + adapter.name, (session_key or "")[:50], queue.qsize(), + ) + try: + await dispatch_fn() + while session_key in adapter._active_sessions: + await asyncio.sleep(0.1) + except Exception: + logger.exception("[%s] Group queue consumer error", adapter.name) + finally: + adapter._group_queues.pop(session_key, None) + + +class InboundPipelineBuilder: + """Factory for building InboundPipeline instances. + + Separates pipeline assembly (business knowledge) from the pipeline engine + (InboundPipeline) so the engine stays generic and reusable. + """ + + # Default middleware sequence for Yuanbao inbound message processing. + _DEFAULT_MIDDLEWARES: list[type] = [ + DecodeMiddleware, + ExtractFieldsMiddleware, + RecallGuardMiddleware, + DedupMiddleware, + SkipSelfMiddleware, + ChatRoutingMiddleware, + AccessGuardMiddleware, + AutoSetHomeMiddleware, + ExtractContentMiddleware, + PlaceholderFilterMiddleware, + OwnerCommandMiddleware, + BuildSourceMiddleware, + GroupAtGuardMiddleware, + GroupAttributionMiddleware, + ClassifyMessageTypeMiddleware, + QuoteContextMiddleware, + MediaResolveMiddleware, + DispatchMiddleware, + ] + + @classmethod + def build(cls) -> InboundPipeline: + """Build the default inbound message processing pipeline.""" + pipeline = InboundPipeline() + for mw_cls in cls._DEFAULT_MIDDLEWARES: + pipeline.use(mw_cls()) + return pipeline + +class ConnectionManager: + """Manages the WebSocket connection lifecycle for YuanbaoAdapter. + + Responsibilities: + - Opening and closing the WebSocket + - AUTH_BIND handshake + - Heartbeat (ping/pong) loop + - Receive loop (frame dispatch) + - Reconnect with exponential backoff + """ + + def __init__(self, adapter: "YuanbaoAdapter") -> None: + self._adapter = adapter + self._ws = None # websockets connection + self._connect_id: Optional[str] = None + self._heartbeat_task: Optional[asyncio.Task] = None + self._recv_task: Optional[asyncio.Task] = None + self._pending_acks: Dict[str, asyncio.Future] = {} + self._pending_pong: Optional[asyncio.Future] = None + self._consecutive_hb_timeouts: int = 0 + self._reconnect_attempts: int = 0 + self._reconnecting: bool = False + # Debounce buffer for aggregating multi-part inbound messages + self._inbound_buffer: Dict[str, list] = {} # key -> [raw_data_frames, ...] + self._inbound_timers: Dict[str, asyncio.TimerHandle] = {} # key -> timer + + # -- Properties -------------------------------------------------------- + + @property + def ws(self): + return self._ws + + @property + def connect_id(self) -> Optional[str]: + return self._connect_id + + @property + def reconnect_attempts(self) -> int: + return self._reconnect_attempts + + @property + def is_connected(self) -> bool: + if self._ws is None: + return False + open_attr = getattr(self._ws, "open", None) + if open_attr is True: + return True + if callable(open_attr): + try: + return bool(open_attr()) + except Exception: + return False + return False + + # -- Open / Close ------------------------------------------------------ + + async def open(self) -> bool: + """Open WebSocket connection: sign-token → WS connect → AUTH_BIND → start loops. + + Returns True on success, False on failure. + """ + adapter = self._adapter + + if not WEBSOCKETS_AVAILABLE: + msg = "Yuanbao startup failed: 'websockets' package not installed" + adapter._set_fatal_error("yuanbao_missing_dependency", msg, retryable=True) + logger.warning("[%s] %s. Run: pip install websockets", adapter.name, msg) + return False + + if not adapter._app_key or not adapter._app_secret: + msg = ( + "Yuanbao startup failed: " + "YUANBAO_APP_ID and YUANBAO_APP_SECRET are required" + ) + adapter._set_fatal_error("yuanbao_missing_credentials", msg, retryable=False) + logger.error("[%s] %s", adapter.name, msg) + return False + + # Idempotency guard + if self._ws is not None: + try: + open_attr = getattr(self._ws, "open", None) + if open_attr is True or (callable(open_attr) and open_attr()): + logger.debug("[%s] Already connected, skipping connect()", adapter.name) + return True + except Exception: + pass + + # Acquire platform-scoped lock to prevent duplicate connections + if not adapter._acquire_platform_lock( + 'yuanbao-app-key', adapter._app_key, 'Yuanbao app key' + ): + return False + + try: + # Step 1: Get sign token + logger.info("[%s] Fetching sign token from %s", adapter.name, adapter._api_domain) + token_data = await SignManager.get_token( + adapter._app_key, adapter._app_secret, adapter._api_domain, + route_env=adapter._route_env, + ) + + # Update bot_id if returned by sign-token API + if token_data.get("bot_id"): + adapter._bot_id = str(token_data["bot_id"]) + + # Step 2: Open WebSocket connection (disable built-in ping/pong) + logger.info("[%s] Connecting to %s", adapter.name, adapter._ws_url) + self._ws = await asyncio.wait_for( + websockets.connect( # type: ignore[attr-defined] + adapter._ws_url, + ping_interval=None, + ping_timeout=None, + close_timeout=5, + ), + timeout=CONNECT_TIMEOUT_SECONDS, + ) + + # Step 3: Authenticate (AUTH_BIND + wait for BIND_ACK) + authed = await self._authenticate(token_data) + if not authed: + await self._cleanup_ws() + return False + + # Step 4: Start background tasks + self._reconnect_attempts = 0 + adapter._mark_connected() + adapter._loop = asyncio.get_running_loop() + self._heartbeat_task = asyncio.create_task( + self._heartbeat_loop(), name=f"yuanbao-heartbeat-{self._connect_id}" + ) + self._recv_task = asyncio.create_task( + self._receive_loop(), name=f"yuanbao-recv-{self._connect_id}" + ) + logger.info( + "[%s] Connected. connectId=%s botId=%s", + adapter.name, self._connect_id, adapter._bot_id, + ) + + YuanbaoAdapter.set_active(adapter) + + return True + + except asyncio.TimeoutError: + logger.error("[%s] Connection timed out", adapter.name) + await self._cleanup_ws() + adapter._release_platform_lock() + return False + except Exception as exc: + logger.error("[%s] connect() failed: %s", adapter.name, exc, exc_info=True) + await self._cleanup_ws() + adapter._release_platform_lock() + return False + + async def close(self) -> None: + """Cancel background tasks, fail pending futures, and close the WebSocket.""" + + if self._heartbeat_task: + self._heartbeat_task.cancel() + try: + await self._heartbeat_task + except asyncio.CancelledError: + pass + self._heartbeat_task = None + + if self._recv_task: + self._recv_task.cancel() + try: + await self._recv_task + except asyncio.CancelledError: + pass + self._recv_task = None + + # Fail any pending ACK futures + disc_exc = RuntimeError("YuanbaoAdapter disconnected") + for fut in self._pending_acks.values(): + if not fut.done(): + fut.set_exception(disc_exc) + self._pending_acks.clear() + + # Clear refresh locks to avoid stale locks from a previous event loop + SignManager.clear_locks() + + await self._cleanup_ws() + + # -- Authentication ---------------------------------------------------- + + async def _authenticate(self, token_data: dict) -> bool: + """Send AUTH_BIND and read frames until BIND_ACK is received. + + Returns True on success, False on failure/timeout. + """ + adapter = self._adapter + if self._ws is None: + return False + + token = token_data.get("token", "") + uid = adapter._bot_id or token_data.get("bot_id", "") + source = token_data.get("source") or "bot" + route_env = adapter._route_env or token_data.get("route_env", "") or "" + + msg_id = str(uuid.uuid4()) + + auth_bytes = encode_auth_bind( + biz_id="ybBot", + uid=uid, + source=source, + token=token, + msg_id=msg_id, + app_version=_APP_VERSION, + operation_system=_OPERATION_SYSTEM, + bot_version=_BOT_VERSION, + route_env=route_env, + ) + await self._ws.send(auth_bytes) + logger.debug("[%s] AUTH_BIND sent (msg_id=%s uid=%s)", adapter.name, msg_id, uid) + + try: + _loop = asyncio.get_running_loop() + deadline = _loop.time() + AUTH_TIMEOUT_SECONDS + while True: + remaining = deadline - _loop.time() + if remaining <= 0: + logger.error("[%s] AUTH_BIND timeout waiting for BIND_ACK", adapter.name) + return False + + raw = await asyncio.wait_for(self._ws.recv(), timeout=remaining) + if not isinstance(raw, (bytes, bytearray)): + continue + + try: + msg = decode_conn_msg(bytes(raw)) + except Exception: + continue + + head = msg.get("head", {}) + cmd_type = head.get("cmd_type", -1) + cmd = head.get("cmd", "") + + if cmd_type == CMD_TYPE["Response"] and cmd == "auth-bind": + connect_id = self._extract_connect_id(msg) + if connect_id: + self._connect_id = connect_id + logger.info("[%s] BIND_ACK received: connectId=%s", adapter.name, connect_id) + return True + else: + logger.error("[%s] BIND_ACK missing connectId", adapter.name) + return False + + except asyncio.TimeoutError: + logger.error("[%s] AUTH_BIND timeout", adapter.name) + return False + except Exception as exc: + logger.error("[%s] AUTH_BIND error: %s", adapter.name, exc, exc_info=True) + return False + + def _extract_connect_id(self, decoded_msg: dict) -> Optional[str]: + """Extract connectId from decoded BIND_ACK message.""" + data: bytes = decoded_msg.get("data", b"") + if not data: + return None + try: + fdict = _fields_to_dict(_parse_fields(data)) + code = _get_varint(fdict, 1) + if code != 0: + message = _get_string(fdict, 2) + logger.error( + "[%s] AuthBindRsp error: code=%d message=%r", + self._adapter.name, code, message, + ) + return None + connect_id = _get_string(fdict, 3) + return connect_id if connect_id else None + except Exception as exc: + logger.warning("[%s] Failed to extract connectId: %s", self._adapter.name, exc) + return None + + # -- Heartbeat --------------------------------------------------------- + + async def _heartbeat_loop(self) -> None: + """Send HEARTBEAT (ping) every 30s; trigger reconnect after threshold misses.""" + adapter = self._adapter + try: + while adapter._running: + await asyncio.sleep(HEARTBEAT_INTERVAL_SECONDS) + if self._ws is None: + continue + try: + msg_id = str(uuid.uuid4()) + ping_bytes = encode_ping(msg_id) + loop = asyncio.get_running_loop() + pong_future: asyncio.Future = loop.create_future() + self._pending_pong = pong_future + self._pending_acks[msg_id] = pong_future + await self._ws.send(ping_bytes) + logger.debug("[%s] PING sent (msg_id=%s)", adapter.name, msg_id) + try: + await asyncio.wait_for(pong_future, timeout=10.0) + self._consecutive_hb_timeouts = 0 + except asyncio.TimeoutError: + self._pending_acks.pop(msg_id, None) + self._consecutive_hb_timeouts += 1 + logger.warning( + "[%s] PONG timeout (%d/%d)", + adapter.name, self._consecutive_hb_timeouts, HEARTBEAT_TIMEOUT_THRESHOLD, + ) + if self._consecutive_hb_timeouts >= HEARTBEAT_TIMEOUT_THRESHOLD: + logger.warning("[%s] Heartbeat threshold exceeded, triggering reconnect", adapter.name) + self.schedule_reconnect() + return + finally: + self._pending_acks.pop(msg_id, None) + self._pending_pong = None + except Exception as exc: + logger.debug("[%s] Heartbeat send failed: %s", adapter.name, exc) + except asyncio.CancelledError: + pass + + # -- Receive loop ------------------------------------------------------ + + async def _receive_loop(self) -> None: + """Read WS frames and dispatch by cmd_type.""" + adapter = self._adapter + try: + async for raw in self._ws: # type: ignore[union-attr] + if not isinstance(raw, (bytes, bytearray)): + continue + await self._handle_frame(bytes(raw)) + except asyncio.CancelledError: + pass + except websockets.exceptions.ConnectionClosed as close_exc: # type: ignore[union-attr] + close_code = getattr(close_exc, 'code', None) + logger.warning( + "[%s] WebSocket connection closed: code=%s reason=%s", + adapter.name, close_code, getattr(close_exc, 'reason', ''), + ) + if close_code and close_code in NO_RECONNECT_CLOSE_CODES: + logger.error( + "[%s] Close code %d is non-recoverable, NOT reconnecting", + adapter.name, close_code, + ) + adapter._mark_disconnected() + else: + self.schedule_reconnect() + except Exception as exc: + logger.warning("[%s] receive_loop exited: %s", adapter.name, exc) + self.schedule_reconnect() + + async def _handle_frame(self, raw: bytes) -> None: + """Handle a single WebSocket frame.""" + adapter = self._adapter + try: + msg = decode_conn_msg(raw) + except Exception as exc: + logger.debug("[%s] Failed to decode frame: %s", adapter.name, exc) + return + + head = msg.get("head", {}) + cmd_type = head.get("cmd_type", -1) + cmd = head.get("cmd", "") + msg_id = head.get("msg_id", "") + need_ack = head.get("need_ack", False) + data: bytes = msg.get("data", b"") + + # HEARTBEAT_ACK + if cmd_type == CMD_TYPE["Response"] and cmd == "ping": + logger.debug("[%s] HEARTBEAT_ACK received (msg_id=%s)", adapter.name, msg_id) + if self._pending_pong is not None and not self._pending_pong.done(): + self._pending_pong.set_result(True) + elif msg_id and msg_id in self._pending_acks: + fut = self._pending_acks.pop(msg_id) + if not fut.done(): + fut.set_result(True) + return + + # Fire-and-forget heartbeat ACKs — server always responds but callers don't + # wait on these; silently discard to avoid "Unmatched Response" noise. + if cmd_type == CMD_TYPE["Response"] and cmd in ( + "send_group_heartbeat", + "send_private_heartbeat", + ): + logger.debug("[%s] Heartbeat ACK received: cmd=%s msg_id=%s", adapter.name, cmd, msg_id) + return + + # Response to an outbound RPC call + if cmd_type == CMD_TYPE["Response"]: + if msg_id and msg_id in self._pending_acks: + fut = self._pending_acks.pop(msg_id) + if not fut.done(): + result = {"head": head} + if data: + result["data"] = data + fut.set_result(result) + else: + logger.debug( + "[%s] Unmatched Response: cmd=%s msg_id=%s", + adapter.name, cmd, msg_id, + ) + return + + # Server-initiated Push + if cmd_type == CMD_TYPE["Push"]: + logger.info("[%s] Push received: cmd=%s msg_id=%s data_len=%d", adapter.name, cmd, msg_id, len(data)) + if need_ack and self._ws is not None: + try: + ack_bytes = encode_push_ack(head) + await self._ws.send(ack_bytes) + except Exception as ack_exc: + logger.debug("[%s] Failed to send PushAck: %s", adapter.name, ack_exc) + + if msg_id and msg_id in self._pending_acks: + fut = self._pending_acks.pop(msg_id) + if not fut.done(): + try: + decoded = decode_inbound_push(data) if data else {"head": head} + fut.set_result(decoded) + except Exception as exc: + fut.set_exception(exc) + return + + # Genuine inbound message — dispatch to AI + if data: + logger.info( + "[%s] WS received inbound push, decoding and dispatching: cmd=%s, data_len=%d", + adapter.name, cmd, len(data), + ) + self._push_to_inbound(data) + return + + logger.debug( + "[%s] Ignoring frame: cmd_type=%d cmd=%s msg_id=%s", + adapter.name, cmd_type, cmd, msg_id, + ) + + # -- Inbound dispatch --------------------------------------------------- + + _DEBOUNCE_WINDOW: float = 1.5 # seconds to wait for companion messages + + def _extract_sender_key(self, raw_data: bytes) -> str: + """Lightweight decode to extract sender key for debounce grouping. + + Returns 'from_account:group_code' or a fallback unique key. + """ + try: + parsed = json.loads(raw_data.decode("utf-8")) + if isinstance(parsed, dict): + from_account = ( + parsed.get("from_account", "") + or parsed.get("From_Account", "") + ) + group_code = ( + parsed.get("group_code", "") + or parsed.get("GroupId", "") + or parsed.get("group_id", "") + ) + if from_account: + return f"{from_account}:{group_code}" + except Exception: + pass + # Protobuf: try decode_inbound_push for sender info + try: + push = decode_inbound_push(raw_data) + if push: + return f"{push.get('from_account', '')}:{push.get('group_code', '')}" + except Exception: + pass + # Fallback: unique key (no aggregation) + return f"__unknown_{id(raw_data)}" + + def _push_to_inbound(self, raw_data: bytes) -> None: + """Debounced inbound dispatch. + + Buffers raw frames from the same sender within a short time window, + then dispatches all buffered data as a single aggregated pipeline + execution. This merges multi-part messages (e.g. image + text sent + as separate WS pushes) into one pipeline run. + """ + key = self._extract_sender_key(raw_data) + + # Cancel existing timer for this key (reset debounce window) + existing_timer = self._inbound_timers.pop(key, None) + if existing_timer: + existing_timer.cancel() + + # Append to buffer + if key not in self._inbound_buffer: + self._inbound_buffer[key] = [] + self._inbound_buffer[key].append(raw_data) + + logger.debug( + "[%s] Debounce: buffered frame for key=%s, count=%d", + self._adapter.name, key, len(self._inbound_buffer[key]), + ) + + # Schedule flush after debounce window + loop = asyncio.get_running_loop() + timer = loop.call_later( + self._DEBOUNCE_WINDOW, + self._flush_inbound_buffer, + key, + ) + self._inbound_timers[key] = timer + + def _flush_inbound_buffer(self, key: str) -> None: + """Flush the debounce buffer for a given key — execute the pipeline.""" + self._inbound_timers.pop(key, None) + data_list = self._inbound_buffer.pop(key, []) + if not data_list: + return + + adapter = self._adapter + logger.info( + "[%s] Debounce flush: key=%s, aggregated %d frames", + adapter.name, key, len(data_list), + ) + + ctx = InboundContext(adapter=adapter, raw_frames=data_list) + + adapter._track_task(asyncio.create_task( + adapter._inbound_pipeline.execute(ctx), + name=f"yuanbao-pipeline-{key}", + )) + + # -- Send business request --------------------------------------------- + + async def send_biz_request( + self, + encoded_conn_msg: bytes, + req_id: str, + timeout: float = DEFAULT_SEND_TIMEOUT, + ) -> dict: + """Send a business-layer request and wait for the response. + + 1. Register a Future in pending_acks[req_id] + 2. Send encoded_conn_msg (bytes) to WS + 3. asyncio.wait_for(future, timeout) + 4. Clean up pending_acks on timeout/exception + """ + if self._ws is None: + raise RuntimeError("Not connected") + + loop = asyncio.get_running_loop() + future: asyncio.Future = loop.create_future() + self._pending_acks[req_id] = future + try: + await self._ws.send(encoded_conn_msg) + result = await asyncio.wait_for(asyncio.shield(future), timeout=timeout) + return result + except asyncio.TimeoutError: + raise + except Exception: + raise + finally: + self._pending_acks.pop(req_id, None) + + # -- Reconnect --------------------------------------------------------- + + def schedule_reconnect(self) -> None: + """Schedule a reconnect only if running and not already reconnecting.""" + if self._adapter._running and not self._reconnecting: + asyncio.create_task(self._reconnect_with_backoff()) + + async def _reconnect_with_backoff(self) -> bool: + """Reconnect with exponential backoff (1s, 2s, 4s, … up to 60s).""" + if self._reconnecting: + logger.debug("[%s] Reconnect already in progress, skipping", self._adapter.name) + return False + self._reconnecting = True + try: + return await self._do_reconnect() + finally: + self._reconnecting = False + + async def _do_reconnect(self) -> bool: + """Internal reconnect loop, called under the _reconnecting guard.""" + adapter = self._adapter + for attempt in range(MAX_RECONNECT_ATTEMPTS): + self._reconnect_attempts = attempt + 1 + wait = min(2 ** attempt, 60) + logger.info( + "[%s] Reconnect attempt %d/%d in %ds", + adapter.name, attempt + 1, MAX_RECONNECT_ATTEMPTS, wait, + ) + await asyncio.sleep(wait) + + await self._cleanup_ws() + + try: + token_data = await SignManager.force_refresh( + adapter._app_key, adapter._app_secret, adapter._api_domain, + route_env=adapter._route_env, + ) + if token_data.get("bot_id"): + adapter._bot_id = str(token_data["bot_id"]) + + self._ws = await asyncio.wait_for( + websockets.connect( # type: ignore[attr-defined] + adapter._ws_url, + ping_interval=None, + ping_timeout=None, + close_timeout=5, + ), + timeout=CONNECT_TIMEOUT_SECONDS, + ) + + authed = await self._authenticate(token_data) + if not authed: + logger.warning("[%s] Re-auth failed on attempt %d", adapter.name, attempt + 1) + await self._cleanup_ws() + continue + + self._reconnect_attempts = 0 + self._consecutive_hb_timeouts = 0 + adapter._mark_connected() + + if self._heartbeat_task and not self._heartbeat_task.done(): + self._heartbeat_task.cancel() + self._heartbeat_task = asyncio.create_task( + self._heartbeat_loop(), + name=f"yuanbao-heartbeat-{self._connect_id}", + ) + + if self._recv_task and not self._recv_task.done(): + self._recv_task.cancel() + self._recv_task = asyncio.create_task( + self._receive_loop(), + name=f"yuanbao-recv-{self._connect_id}", + ) + + logger.info( + "[%s] Reconnected on attempt %d. connectId=%s", + adapter.name, attempt + 1, self._connect_id, + ) + return True + + except asyncio.TimeoutError: + logger.warning("[%s] Reconnect attempt %d timed out", adapter.name, attempt + 1) + except Exception as exc: + logger.warning( + "[%s] Reconnect attempt %d failed: %s", adapter.name, attempt + 1, exc + ) + + logger.error( + "[%s] Giving up after %d reconnect attempts", adapter.name, MAX_RECONNECT_ATTEMPTS + ) + adapter._mark_disconnected() + return False + + async def _cleanup_ws(self) -> None: + """Close and clear the WebSocket connection.""" + ws = self._ws + self._ws = None + if ws is not None: + try: + await ws.close() + except Exception: + pass + +class MediaSendHandler(ABC): + """Abstract base class for media send strategies. + + Subclasses implement: + - acquire_file(): how to obtain file bytes (download URL / read local) + - build_msg_body(): how to build TIMxxxElem from upload result + + The shared flow (check ws → cancel notifier → validate → COS upload + → lock → dispatch) is handled by the base handle() template method. + """ + + @abstractmethod + async def acquire_file( + self, adapter: "YuanbaoAdapter", **kwargs: Any, + ) -> Tuple[bytes, str, str]: + """Return (file_bytes, filename, content_type). + + Raises: + ValueError: when file cannot be acquired (not found, empty, etc.) + """ + + @abstractmethod + def build_msg_body(self, upload_result: dict, **kwargs: Any) -> list: + """Build platform-specific MsgBody list from COS upload result.""" + + def needs_cos_upload(self) -> bool: + """Override to return False for non-COS media (e.g. sticker).""" + return True + + async def handle( + self, + adapter: "YuanbaoAdapter", + chat_id: str, + reply_to: Optional[str] = None, + caption: Optional[str] = None, + **kwargs: Any, + ) -> "SendResult": + """Template method: shared media send flow.""" + conn = adapter._connection + sender = adapter._outbound.sender + + if conn.ws is None: + return SendResult(success=False, error="Not connected", retryable=True) + + adapter._outbound.cancel_slow_notifier(chat_id) + + try: + # 1. Acquire file bytes + file_bytes, filename, content_type = await self.acquire_file( + adapter, **kwargs, + ) + + # 2. Validate (only for handlers that upload to COS; stickers use + # TIMFaceElem and legitimately carry no file bytes, so skipping + # validate_media here avoids a spurious "Empty file: sticker"). + if self.needs_cos_upload(): + validation_err = MessageSender.validate_media( + file_bytes, filename, adapter.MEDIA_MAX_SIZE_MB, + ) + if validation_err: + return SendResult(success=False, error=validation_err) + + if self.needs_cos_upload(): + file_uuid = md5_hex(file_bytes) + + # 3. Get COS upload credentials + token_data = await adapter._get_cached_token() + token: str = token_data.get("token", "") + bot_id: str = ( + token_data.get("bot_id", "") or adapter._bot_id or "" + ) + + credentials = await get_cos_credentials( + app_key=adapter._app_key, + api_domain=adapter._api_domain, + token=token, + filename=filename, + bot_id=bot_id, + route_env=adapter._route_env, + ) + + # 4. Upload to COS + upload_result = await upload_to_cos( + file_bytes=file_bytes, + filename=filename, + content_type=content_type, + credentials=credentials, + bucket=credentials["bucketName"], + region=credentials["region"], + ) + + # 5. Build MsgBody + # Remove keys already passed explicitly to avoid "multiple values" TypeError + fwd_kwargs = { + k: v for k, v in kwargs.items() + if k not in ("file_uuid", "filename", "content_type") + } + msg_body = self.build_msg_body( + upload_result, + file_uuid=file_uuid, + filename=filename, + content_type=content_type, + **fwd_kwargs, + ) + else: + # Non-COS media (e.g. sticker): build MsgBody directly + msg_body = self.build_msg_body({}, **kwargs) + + # 6. Append caption if provided + if caption: + msg_body.append( + {"msg_type": "TIMTextElem", "msg_content": {"text": caption}}, + ) + + # 7. Lock + dispatch + gc = kwargs.get("group_code", "") + return await sender.dispatch_msg_body(chat_id, msg_body, reply_to, group_code=gc) + + except ValueError as ve: + return SendResult(success=False, error=str(ve)) + except Exception as exc: + handler_name = type(self).__name__ + logger.error( + "[%s] %s.handle() failed: %s", + adapter.name, handler_name, exc, exc_info=True, + ) + return SendResult(success=False, error=str(exc)) + + +class ImageUrlHandler(MediaSendHandler): + """Strategy: send image from a URL (download → COS → TIMImageElem).""" + + async def acquire_file(self, adapter, **kwargs): + image_url: str = kwargs["image_url"] + logger.info("[%s] ImageUrlHandler: downloading %s", adapter.name, image_url) + file_bytes, content_type = await media_download_url( + image_url, max_size_mb=adapter.MEDIA_MAX_SIZE_MB, + ) + if not content_type or content_type == "application/octet-stream": + path_part = image_url.split("?")[0] + content_type = guess_mime_type(path_part) or "image/jpeg" + filename = os.path.basename(image_url.split("?")[0]) or "image.jpg" + return file_bytes, filename, content_type + + def build_msg_body(self, upload_result, **kwargs): + return build_image_msg_body( + url=upload_result["url"], + uuid=kwargs["file_uuid"], + filename=kwargs["filename"], + size=upload_result["size"], + width=upload_result.get("width", 0), + height=upload_result.get("height", 0), + mime_type=kwargs["content_type"], + ) + + +class ImageFileHandler(MediaSendHandler): + """Strategy: send image from a local file path (read → COS → TIMImageElem).""" + + async def acquire_file(self, adapter, **kwargs): + image_path: str = kwargs["image_path"] + if not os.path.isfile(image_path): + raise ValueError(f"File not found: {image_path}") + logger.info("[%s] ImageFileHandler: reading %s", adapter.name, image_path) + with open(image_path, "rb") as f: + file_bytes = f.read() + filename = os.path.basename(image_path) or "image.jpg" + content_type = guess_mime_type(filename) or "image/jpeg" + return file_bytes, filename, content_type + + def build_msg_body(self, upload_result, **kwargs): + return build_image_msg_body( + url=upload_result["url"], + uuid=kwargs["file_uuid"], + filename=kwargs["filename"], + size=upload_result["size"], + width=upload_result.get("width", 0), + height=upload_result.get("height", 0), + mime_type=kwargs["content_type"], + ) + + +class FileUrlHandler(MediaSendHandler): + """Strategy: send file from a URL (download → COS → TIMFileElem).""" + + async def acquire_file(self, adapter, **kwargs): + file_url: str = kwargs["file_url"] + logger.info("[%s] FileUrlHandler: downloading %s", adapter.name, file_url) + file_bytes, content_type = await media_download_url( + file_url, max_size_mb=adapter.MEDIA_MAX_SIZE_MB, + ) + filename = kwargs.get("filename") + if not filename: + path_part = file_url.split("?")[0] + filename = os.path.basename(path_part) or "file" + if not content_type or content_type == "application/octet-stream": + content_type = guess_mime_type(filename) or "application/octet-stream" + return file_bytes, filename, content_type + + def build_msg_body(self, upload_result, **kwargs): + return build_file_msg_body( + url=upload_result["url"], + filename=kwargs["filename"], + uuid=kwargs["file_uuid"], + size=upload_result["size"], + ) + + +class DocumentHandler(MediaSendHandler): + """Strategy: send local file/document (read → COS → TIMFileElem).""" + + async def acquire_file(self, adapter, **kwargs): + file_path: str = kwargs["file_path"] + if not os.path.isfile(file_path): + raise ValueError(f"File not found: {file_path}") + logger.info("[%s] DocumentHandler: reading %s", adapter.name, file_path) + with open(file_path, "rb") as f: + file_bytes = f.read() + filename = kwargs.get("filename") or os.path.basename(file_path) or "document" + content_type = guess_mime_type(filename) or "application/octet-stream" + return file_bytes, filename, content_type + + def build_msg_body(self, upload_result, **kwargs): + return build_file_msg_body( + url=upload_result["url"], + filename=kwargs["filename"], + uuid=kwargs["file_uuid"], + size=upload_result["size"], + ) + + +class StickerHandler(MediaSendHandler): + """Strategy: send sticker/emoji (TIMFaceElem, no COS upload needed).""" + + def needs_cos_upload(self) -> bool: + return False + + async def acquire_file(self, adapter, **kwargs): + # Sticker does not need file bytes; return dummy values + return b"", "sticker", "application/octet-stream" + + def build_msg_body(self, upload_result, **kwargs): + from gateway.platforms.yuanbao_sticker import ( + get_sticker_by_name, + get_random_sticker, + build_face_msg_body, + build_sticker_msg_body, + ) + sticker_name = kwargs.get("sticker_name") + face_index = kwargs.get("face_index") + + if sticker_name is not None: + sticker = get_sticker_by_name(sticker_name) + if sticker is None: + raise ValueError(f"Sticker not found: {sticker_name!r}") + return build_sticker_msg_body(sticker) + elif face_index is not None: + return build_face_msg_body(face_index=face_index) + else: + sticker = get_random_sticker() + return build_sticker_msg_body(sticker) + +class GroupQueryService: + """Encapsulates all group query operations (both low-level WS calls and + higher-level AI-tool-facing wrappers). + + Responsibilities: + - Low-level WS encode/decode for group info and member list queries + - Chat-id parsing, error wrapping and result filtering for AI tools + - Member cache population on the adapter + """ + + def __init__(self, adapter: "YuanbaoAdapter") -> None: + self._adapter = adapter + + # ------------------------------------------------------------------ + # Low-level WS query methods + # ------------------------------------------------------------------ + + async def query_group_info_raw(self, group_code: str) -> Optional[dict]: + """Query group info via WS (group name, owner, member count, etc.). + + Returns: + Decoded dict or None on failure. + """ + adapter = self._adapter + if adapter._connection.ws is None: + return None + encoded = encode_query_group_info(group_code) + from gateway.platforms.yuanbao_proto import decode_conn_msg as _decode + decoded = _decode(encoded) + req_id = decoded["head"]["msg_id"] + try: + response = await adapter._connection.send_biz_request(encoded, req_id=req_id) + head = response.get("head", {}) + status = head.get("status", 0) + if status != 0: + logger.warning("[%s] query_group_info failed: status=%d", adapter.name, status) + return None + biz_data = response.get("data", b"") or response.get("body", b"") + if biz_data and isinstance(biz_data, bytes): + return decode_query_group_info_rsp(biz_data) + return {"group_code": group_code} + except asyncio.TimeoutError: + logger.warning("[%s] query_group_info timeout: group=%s", adapter.name, group_code) + return None + except Exception as exc: + logger.warning("[%s] query_group_info failed: %s", adapter.name, exc) + return None + + async def get_group_member_list_raw( + self, group_code: str, offset: int = 0, limit: int = 200 + ) -> Optional[dict]: + """Query group member list via WS. + + Returns: + Decoded dict or None on failure. Also populates adapter._member_cache. + """ + adapter = self._adapter + if adapter._connection.ws is None: + return None + encoded = encode_get_group_member_list(group_code, offset=offset, limit=limit) + from gateway.platforms.yuanbao_proto import decode_conn_msg as _decode + decoded = _decode(encoded) + req_id = decoded["head"]["msg_id"] + try: + response = await adapter._connection.send_biz_request(encoded, req_id=req_id) + head = response.get("head", {}) + status = head.get("status", 0) + if status != 0: + logger.warning("[%s] get_group_member_list failed: status=%d", adapter.name, status) + return None + biz_data = response.get("data", b"") or response.get("body", b"") + if biz_data and isinstance(biz_data, bytes): + result = decode_get_group_member_list_rsp(biz_data) + else: + result = {"members": [], "next_offset": 0, "is_complete": True} + if result and result.get("members"): + adapter._member_cache[group_code] = (time.time(), result["members"]) + return result + except asyncio.TimeoutError: + logger.warning("[%s] get_group_member_list timeout: group=%s", adapter.name, group_code) + return None + except Exception as exc: + logger.warning("[%s] get_group_member_list failed: %s", adapter.name, exc) + return None + + # ------------------------------------------------------------------ + # AI-tool-facing wrappers (chat_id parsing + filtering) + # ------------------------------------------------------------------ + + async def query_group_info(self, chat_id: str) -> dict: + """AI tool: Query current group info. + + No parameters needed (group_code extracted from session context). + Returns group name, owner, member count, etc. + """ + if not chat_id.startswith("group:"): + return {"error": "This command is only available in group chats"} + group_code = chat_id[len("group:"):] + result = await self.query_group_info_raw(group_code) + if result is None: + return {"error": "Failed to query group info"} + return result + + async def query_session_members( + self, + chat_id: str, + action: str = "list_all", + name: Optional[str] = None, + ) -> dict: + """AI tool: Query group member list. + + Args: + chat_id: Chat ID (extracted from session context) + action: 'find' (search by name) | 'list_bots' (list bots) | 'list_all' (list all) + name: Search keyword when action='find' + + Returns: + {"members": [...], "total": int, "mentionHint": str} + """ + if not chat_id.startswith("group:"): + return {"error": "This command is only available in group chats"} + group_code = chat_id[len("group:"):] + result = await self.get_group_member_list_raw(group_code) + if result is None: + return {"error": "Failed to query group members"} + + members = result.get("members", []) + + if action == "find" and name: + query = name.lower() + members = [ + m for m in members + if query in (m.get("nickname", "") or "").lower() + or query in (m.get("name_card", "") or "").lower() + or query in (m.get("user_id", "") or "").lower() + ] + elif action == "list_bots": + members = [m for m in members if "bot" in (m.get("nickname", "") or "").lower()] + + # Construct mentionHint + mention_hint = "" + if members and len(members) <= 10: + names = [m.get("name_card") or m.get("nickname") or m.get("user_id", "") for m in members] + mention_hint = "Mention with @name: " + ", ".join(names) + + return { + "members": members[:50], # Limit return count + "total": len(members), + "mentionHint": mention_hint, + } + + +class HeartbeatManager: + """Manages reply heartbeat (RUNNING / FINISH) lifecycle. + + Responsibilities: + - Periodic RUNNING heartbeat sender (every 2s) + - Auto-FINISH after 30s inactivity + - Explicit stop with optional FINISH signal + """ + + def __init__(self, adapter: "YuanbaoAdapter") -> None: + self._adapter = adapter + self._reply_heartbeat_tasks: Dict[str, asyncio.Task] = {} + self._reply_hb_last_active: Dict[str, float] = {} + + async def send_heartbeat_once(self, chat_id: str, heartbeat_val: int) -> None: + """Send a single heartbeat (RUNNING or FINISH), best effort.""" + adapter = self._adapter + conn = adapter._connection + if conn.ws is None or not adapter._bot_id: + return + try: + if chat_id.startswith("group:"): + group_code = chat_id[len("group:"):] + encoded = encode_send_group_heartbeat( + from_account=adapter._bot_id, + group_code=group_code, + heartbeat=heartbeat_val, + ) + else: + to_account = chat_id.removeprefix("direct:") + encoded = encode_send_private_heartbeat( + from_account=adapter._bot_id, + to_account=to_account, + heartbeat=heartbeat_val, + ) + await conn.ws.send(encoded) + status_name = "RUNNING" if heartbeat_val == WS_HEARTBEAT_RUNNING else "FINISH" + logger.debug( + "[%s] Reply heartbeat %s sent: chat=%s", + adapter.name, status_name, chat_id, + ) + except Exception as exc: + logger.debug("[%s] send_heartbeat_once failed: %s", adapter.name, exc) + + async def start(self, chat_id: str) -> None: + """Start or renew the Reply Heartbeat periodic sender (RUNNING, every 2s).""" + adapter = self._adapter + conn = adapter._connection + if conn.ws is None or not adapter._bot_id: + return + + existing = self._reply_heartbeat_tasks.get(chat_id) + if existing and not existing.done(): + self._reply_hb_last_active[chat_id] = time.time() + return + + self._reply_hb_last_active[chat_id] = time.time() + + task = asyncio.create_task( + self._worker(chat_id), + name=f"yuanbao-reply-hb-{chat_id}", + ) + self._reply_heartbeat_tasks[chat_id] = task + + async def _worker(self, chat_id: str) -> None: + """Background coroutine: send RUNNING heartbeat every 2s. + 30s without renewal -> send FINISH and exit. + """ + try: + await self.send_heartbeat_once(chat_id, WS_HEARTBEAT_RUNNING) + + while True: + await asyncio.sleep(REPLY_HEARTBEAT_INTERVAL_S) + + last_active = self._reply_hb_last_active.get(chat_id, 0) + if time.time() - last_active > REPLY_HEARTBEAT_TIMEOUT_S: + break + + conn = self._adapter._connection + if conn.ws is None: + break + + await self.send_heartbeat_once(chat_id, WS_HEARTBEAT_RUNNING) + + except asyncio.CancelledError: + cancelled = True + except Exception: + cancelled = False + else: + cancelled = False + finally: + if not cancelled: + try: + await self.send_heartbeat_once(chat_id, WS_HEARTBEAT_FINISH) + except Exception: + pass + self._reply_heartbeat_tasks.pop(chat_id, None) + self._reply_hb_last_active.pop(chat_id, None) + + async def stop(self, chat_id: str, send_finish: bool = True) -> None: + """Stop Reply Heartbeat and optionally send FINISH.""" + task = self._reply_heartbeat_tasks.pop(chat_id, None) + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + if send_finish: + try: + await self.send_heartbeat_once(chat_id, WS_HEARTBEAT_FINISH) + except Exception: + pass + + async def close(self) -> None: + """Cancel all reply heartbeat tasks.""" + for task in list(self._reply_heartbeat_tasks.values()): + if not task.done(): + task.cancel() + self._reply_heartbeat_tasks.clear() + self._reply_hb_last_active.clear() + + +class SlowResponseNotifier: + """Manages delayed 'please wait' notifications for slow agent responses. + + Starts a timer per chat_id; if the agent hasn't replied within + SLOW_RESPONSE_TIMEOUT_S seconds, sends a courtesy message. + """ + + def __init__(self, adapter: "YuanbaoAdapter", sender: "MessageSender") -> None: + self._adapter = adapter + self._sender = sender + self._tasks: Dict[str, asyncio.Task] = {} + + async def start(self, chat_id: str) -> None: + """Start a delayed task that notifies the user when the agent is slow.""" + self.cancel(chat_id) + task = asyncio.create_task( + self._notifier(chat_id), + name=f"yuanbao-slow-resp-{chat_id}", + ) + self._tasks[chat_id] = task + + async def _notifier(self, chat_id: str) -> None: + """Wait SLOW_RESPONSE_TIMEOUT_S, then push a 'please wait' message.""" + try: + await asyncio.sleep(SLOW_RESPONSE_TIMEOUT_S) + logger.info( + "[%s] Agent response exceeded %ds for %s, sending wait notice", + self._adapter.name, int(SLOW_RESPONSE_TIMEOUT_S), chat_id, + ) + await self._sender.send_text_chunk(chat_id, SLOW_RESPONSE_MESSAGE) + except asyncio.CancelledError: + pass + except Exception as exc: + logger.debug("[%s] Slow-response notifier failed: %s", self._adapter.name, exc) + + def cancel(self, chat_id: str) -> None: + """Cancel the pending slow-response notifier for *chat_id*, if any.""" + task = self._tasks.pop(chat_id, None) + if task and not task.done(): + task.cancel() + + async def close(self) -> None: + """Cancel all slow-response tasks.""" + for task in list(self._tasks.values()): + if not task.done(): + task.cancel() + self._tasks.clear() + + +class MessageSender: + """Core message sending dispatcher for YuanbaoAdapter. + + Responsibilities: + - Per-chat-id lock management (serial send ordering) + - Text chunk sending with retry + - C2C / Group message encoding and dispatch + - Media send helpers (image, file, sticker, document) + - Direct send helper (text + media, used by send_message tool) + """ + + IMAGE_EXTS: ClassVar[frozenset] = frozenset({".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}) + CHAT_DICT_MAX_SIZE: ClassVar[int] = 1000 # Max distinct chat IDs in _chat_locks + + def __init__(self, adapter: "YuanbaoAdapter") -> None: + self._adapter = adapter + self._chat_locks: collections.OrderedDict[str, asyncio.Lock] = collections.OrderedDict() + + # Optional hooks injected by OutboundManager for coordination + self._on_send_start: Optional[Callable[[str], Any]] = None # cancel slow-notifier + self._on_send_finish: Optional[Callable[[str], Any]] = None # send FINISH heartbeat + + # Media send handlers (strategy pattern) + self._media_handlers: Dict[str, MediaSendHandler] = { + "image_url": ImageUrlHandler(), + "image_file": ImageFileHandler(), + "file_url": FileUrlHandler(), + "document": DocumentHandler(), + "sticker": StickerHandler(), + } + + # -- Media handler registry --------------------------------------------- + + def register_handler(self, name: str, handler: MediaSendHandler) -> None: + """Register (or replace) a named media send handler.""" + self._media_handlers[name] = handler + + # -- Chat lock --------------------------------------------------------- + + def get_chat_lock(self, chat_id: str) -> asyncio.Lock: + """Return (or create) a per-chat-id lock with safe LRU eviction.""" + if chat_id in self._chat_locks: + self._chat_locks.move_to_end(chat_id) + return self._chat_locks[chat_id] + if len(self._chat_locks) >= self.CHAT_DICT_MAX_SIZE: + evicted = False + for key in list(self._chat_locks): + if not self._chat_locks[key].locked(): + self._chat_locks.pop(key) + evicted = True + break + if not evicted: + self._chat_locks.pop(next(iter(self._chat_locks))) + self._chat_locks[chat_id] = asyncio.Lock() + return self._chat_locks[chat_id] + + # -- Text send --------------------------------------------------------- + + async def send_text( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + group_code: str = "", + ) -> "SendResult": + """Send text message with auto-chunking and per-chat-id ordering guarantee.""" + adapter = self._adapter + conn = adapter._connection + if conn.ws is None: + return SendResult(success=False, error="Not connected", retryable=True) + + if self._on_send_start: + self._on_send_start(chat_id) + + lock = self.get_chat_lock(chat_id) + async with lock: + content_to_send = self.strip_cron_wrapper(content) + chunks = self.truncate_message(content_to_send, adapter.MAX_TEXT_CHUNK) + logger.info( + "[%s] truncate_message: input=%d chars, max=%d, output=%d chunk(s) sizes=%s", + adapter.name, len(content_to_send), adapter.MAX_TEXT_CHUNK, + len(chunks), [len(c) for c in chunks], + ) + for i, chunk in enumerate(chunks): + r_to = reply_to if i == 0 else None + result = await self.send_text_chunk(chat_id, chunk, r_to, group_code=group_code) + if not result.success: + return result + + # Notify outbound coordinator that send is complete (e.g. FINISH heartbeat) + if self._on_send_finish: + try: + await self._on_send_finish(chat_id) + except Exception: + pass + return SendResult(success=True) + + async def send_media( + self, + chat_id: str, + handler_name: str, + reply_to: Optional[str] = None, + caption: Optional[str] = None, + **kwargs: Any, + ) -> "SendResult": + """Dispatch media send to the named handler strategy.""" + handler = self._media_handlers.get(handler_name) + if handler is None: + return SendResult( + success=False, + error=f"Unknown media handler: {handler_name!r}", + ) + return await handler.handle( + self._adapter, chat_id, + reply_to=reply_to, caption=caption, **kwargs, + ) + + # -- Direct send (text + media, used by send_message tool) ------------- + + async def send_direct( + self, + chat_id: str, + message: str, + media_files: Optional[List[Tuple[str, bool]]] = None, + ) -> Dict[str, Any]: + """Send text + media via Yuanbao (used by the ``send_message`` tool). + + Unlike Weixin which creates a fresh adapter per call, Yuanbao reuses + the running gateway adapter (persistent WebSocket). Logic mirrors + send_weixin_direct: send text first, then iterate media_files by + extension. + """ + adapter = self._adapter + last_result: Optional["SendResult"] = None + + # 1. Send text + if message.strip(): + last_result = await adapter.send(chat_id, message) + if not last_result.success: + return {"error": f"Yuanbao send failed: {last_result.error}"} + + # 2. Iterate media_files, dispatch by file extension + for media_path, _is_voice in media_files or []: + ext = Path(media_path).suffix.lower() + if ext in self.IMAGE_EXTS: + last_result = await adapter.send_image_file(chat_id, media_path) + else: + last_result = await adapter.send_document(chat_id, media_path) + + if not last_result.success: + return {"error": f"Yuanbao media send failed: {last_result.error}"} + + if last_result is None: + return {"error": "No deliverable text or media remained after processing"} + + return { + "success": True, + "platform": "yuanbao", + "chat_id": chat_id, + "message_id": last_result.message_id if last_result else None, + } + + async def dispatch_msg_body( + self, + chat_id: str, + msg_body: list, + reply_to: Optional[str] = None, + group_code: str = "", + ) -> "SendResult": + """Lock + dispatch an arbitrary MsgBody to C2C or group.""" + lock = self.get_chat_lock(chat_id) + async with lock: + if chat_id.startswith("group:"): + grp = chat_id[len("group:"):] + result = await self.send_group_msg_body(grp, msg_body, reply_to) + else: + to_account = chat_id.removeprefix("direct:") + result = await self.send_c2c_msg_body(to_account, msg_body, group_code=group_code) + + if result.get("success"): + return SendResult(success=True, message_id=result.get("msg_key")) + return SendResult(success=False, error=result.get("error", "Unknown error")) + + async def send_text_chunk( + self, + chat_id: str, + text: str, + reply_to: Optional[str] = None, + retry: int = 3, + group_code: str = "", + ) -> "SendResult": + """Send a single text chunk with retry (exponential backoff: 1s, 2s, 4s).""" + adapter = self._adapter + last_error: str = "Unknown error" + for attempt in range(retry): + try: + if chat_id.startswith("group:"): + grp = chat_id[len("group:"):] + raw = await self.send_group_message(grp, text, reply_to) + else: + to_account = chat_id.removeprefix("direct:") + raw = await self.send_c2c_message(to_account, text, group_code=group_code) + + if raw.get("success"): + return SendResult(success=True, message_id=raw.get("msg_key")) + + last_error = raw.get("error", "Unknown error") + logger.warning( + "[%s] send_text_chunk attempt %d/%d failed: %s", + adapter.name, attempt + 1, retry, last_error, + ) + except Exception as exc: + last_error = str(exc) + logger.warning( + "[%s] send_text_chunk attempt %d/%d exception: %s", + adapter.name, attempt + 1, retry, last_error, + ) + + if attempt < retry - 1: + await asyncio.sleep(2 ** attempt) + + logger.error( + "[%s] send_text_chunk max retries (%d) exceeded. Last error: %s", + adapter.name, retry, last_error, + ) + return SendResult(success=False, error=f"Max retries exceeded: {last_error}") + + # -- C2C / Group message ----------------------------------------------- + + async def send_c2c_message(self, to_account: str, text: str, group_code: str = "") -> dict: + """Send C2C text message, return {success: bool, msg_key: str}.""" + msg_body = [{"msg_type": "TIMTextElem", "msg_content": {"text": text}}] + return await self.send_c2c_msg_body(to_account, msg_body, group_code=group_code) + + async def send_group_message( + self, + group_code: str, + text: str, + reply_to: Optional[str] = None, + ) -> dict: + """Send group text message, auto-converting @nickname to TIMCustomElem.""" + msg_body = self._build_msg_body_with_mentions(text, group_code) + return await self.send_group_msg_body(group_code, msg_body, reply_to) + + # @mention pattern: (whitespace or start) + @ + nickname + (whitespace or end) + _AT_USER_RE = re.compile(r'(?:(?<=\s)|(?<=^))@(\S+?)(?=\s|$)', re.MULTILINE) + + def _build_msg_body_with_mentions(self, text: str, group_code: str) -> list: + """Parse @nickname patterns and build mixed TIMTextElem + TIMCustomElem msg_body.""" + cached = self._adapter._member_cache.get(group_code) + if cached: + ts, member_list = cached + members = member_list if (time.time() - ts < self._adapter.MEMBER_CACHE_TTL_S) else [] + else: + members = [] + if not members: + return [{"msg_type": "TIMTextElem", "msg_content": {"text": text}}] + + nickname_to_uid = {} + for m in members: + nick = m.get("nickname") or m.get("nick_name") or "" + uid = m.get("user_id") or "" + if nick and uid: + nickname_to_uid[nick.lower()] = (nick, uid) + + msg_body: list = [] + last_idx = 0 + for match in self._AT_USER_RE.finditer(text): + start = match.start() + if start > last_idx: + seg = text[last_idx:start].strip() + if seg: + msg_body.append({"msg_type": "TIMTextElem", "msg_content": {"text": seg}}) + + nickname = match.group(1) + entry = nickname_to_uid.get(nickname.lower()) + if entry: + real_nick, uid = entry + msg_body.append({ + "msg_type": "TIMCustomElem", + "msg_content": { + "data": json.dumps({"elem_type": 1002, "text": f"@{real_nick}", "user_id": uid}), + }, + }) + else: + msg_body.append({"msg_type": "TIMTextElem", "msg_content": {"text": f"@{nickname}"}}) + + last_idx = match.end() + + if last_idx < len(text): + tail = text[last_idx:].strip() + if tail: + msg_body.append({"msg_type": "TIMTextElem", "msg_content": {"text": tail}}) + + if not msg_body: + msg_body.append({"msg_type": "TIMTextElem", "msg_content": {"text": text}}) + + return msg_body + + async def send_c2c_msg_body(self, to_account: str, msg_body: list, group_code: str = "") -> dict: + """Send C2C message with arbitrary MsgBody.""" + adapter = self._adapter + req_id = f"c2c_{next_seq_no()}" + encoded = encode_send_c2c_message( + to_account=to_account, + msg_body=msg_body, + from_account=adapter._bot_id or "", + msg_id=req_id, + group_code=group_code, + ) + return await self._dispatch_encoded(adapter, encoded, req_id) + + async def send_group_msg_body( + self, + group_code: str, + msg_body: list, + reply_to: Optional[str] = None, + ) -> dict: + """Send group message with arbitrary MsgBody.""" + adapter = self._adapter + req_id = f"grp_{next_seq_no()}" + encoded = encode_send_group_message( + group_code=group_code, + msg_body=msg_body, + from_account=adapter._bot_id or "", + msg_id=req_id, + ref_msg_id=reply_to or "", + ) + return await self._dispatch_encoded(adapter, encoded, req_id) + + # -- Common dispatch helper -------------------------------------------- + + @staticmethod + async def _dispatch_encoded( + adapter: "YuanbaoAdapter", encoded: bytes, req_id: str, + ) -> dict: + """Send pre-encoded bytes via WS and return a normalised result dict.""" + try: + response = await adapter._connection.send_biz_request(encoded, req_id=req_id) + return {"success": True, "msg_key": response.get("msg_id", "")} + except asyncio.TimeoutError: + return {"success": False, "error": f"Request timeout after {DEFAULT_SEND_TIMEOUT}s"} + except Exception as exc: + return {"success": False, "error": str(exc)} + + # -- Media validation --------------------------------------------------- + + @staticmethod + def validate_media( + file_bytes: Optional[bytes], filename: str, max_size_mb: int = 20 + ) -> Optional[str]: + """Media pre-validation: check file validity before sending/uploading. + + Returns: + Error description (str) if validation fails, otherwise None. + """ + if file_bytes is None or len(file_bytes) == 0: + return f"Empty file: {filename}" + max_bytes = max_size_mb * 1024 * 1024 + if len(file_bytes) > max_bytes: + size_mb = len(file_bytes) / 1024 / 1024 + return f"File too large: {filename} ({size_mb:.1f}MB > {max_size_mb}MB)" + return None + + # -- Text truncation (table-aware) -------------------------------------- + + @staticmethod + def truncate_message( + content: str, + max_length: int = 4000, + len_fn: Optional[Callable[[str], int]] = None, + ) -> List[str]: + """ + Split a long message into chunks with table-awareness. + + Delegates core splitting to ``MarkdownProcessor.chunk_markdown_text`` + and strips page indicators like ``(1/3)`` from the output. + + Falls back to ``BasePlatformAdapter.truncate_message`` for non-table + content and for overall text that fits in a single chunk. + """ + _len = len_fn or len + if _len(content) <= max_length: + return [content] + + # Delegate to MarkdownProcessor for table/fence-aware chunking + chunks = MarkdownProcessor.chunk_markdown_text( + content, max_length, len_fn=len_fn, + ) + + # Strip page indicators like (1/3) that BasePlatformAdapter may add + chunks = [_INDICATOR_RE.sub('', c) for c in chunks] + + return chunks if chunks else [content] + + # -- Cron wrapper stripping --------------------------------------------- + + @staticmethod + def strip_cron_wrapper(content: str) -> str: + """Strip scheduler cron header/footer wrapper for cleaner Yuanbao output.""" + if not content.startswith("Cronjob Response: "): + return content + + divider = "\n-------------\n\n" + footer_prefix = '\n\nTo stop or manage this job, send me a new message (e.g. "stop reminder ' + divider_pos = content.find(divider) + footer_pos = content.rfind(footer_prefix) + if divider_pos < 0 or footer_pos < 0 or footer_pos <= divider_pos: + return content + + header = content[:divider_pos] + if "\n(job_id: " not in header: + return content + + body_start = divider_pos + len(divider) + body = content[body_start:footer_pos].strip() + return body or content + + # -- Cleanup on disconnect --------------------------------------------- + + async def close(self) -> None: + """Release chat locks (no-op for now; placeholder for future cleanup).""" + self._chat_locks.clear() + + +class OutboundManager: + """Outbound coordinator that orchestrates sending, heartbeat and slow-response. + + Composes: + - MessageSender — core text/media sending + - HeartbeatManager — reply heartbeat (RUNNING / FINISH) lifecycle + - SlowResponseNotifier — delayed 'please wait' notifications + + YuanbaoAdapter holds a single ``_outbound: OutboundManager`` and delegates + all outbound operations through it. + """ + + # Expose class-level constants from MessageSender for backward compatibility + CHAT_DICT_MAX_SIZE: ClassVar[int] = MessageSender.CHAT_DICT_MAX_SIZE + + def __init__(self, adapter: "YuanbaoAdapter") -> None: + self._adapter = adapter + self.sender: MessageSender = MessageSender(adapter) + self.heartbeat: HeartbeatManager = HeartbeatManager(adapter) + self.slow_notifier: SlowResponseNotifier = SlowResponseNotifier(adapter, self.sender) + + # Wire coordination hooks into MessageSender + self.sender._on_send_start = self._handle_send_start + self.sender._on_send_finish = self._handle_send_finish + + # -- Coordination hooks ------------------------------------------------ + + def _handle_send_start(self, chat_id: str) -> None: + """Called by MessageSender before sending: cancel slow-response notifier.""" + self.slow_notifier.cancel(chat_id) + + async def _handle_send_finish(self, chat_id: str) -> None: + """Called by MessageSender after sending: send FINISH heartbeat.""" + await self.heartbeat.send_heartbeat_once(chat_id, WS_HEARTBEAT_FINISH) + + # -- Delegated public API (used by YuanbaoAdapter) --------------------- + + async def send_text( + self, chat_id: str, content: str, reply_to: Optional[str] = None, + group_code: str = "", + ) -> "SendResult": + """Send text message with auto-chunking.""" + return await self.sender.send_text(chat_id, content, reply_to, group_code=group_code) + + async def send_media( + self, chat_id: str, handler_name: str, **kwargs: Any, + ) -> "SendResult": + """Dispatch media send to the named handler strategy.""" + return await self.sender.send_media(chat_id, handler_name, **kwargs) + + async def send_direct( + self, chat_id: str, message: str, + media_files: Optional[List[Tuple[str, bool]]] = None, + ) -> Dict[str, Any]: + """Send text + media (used by send_message tool).""" + return await self.sender.send_direct(chat_id, message, media_files) + + async def start_typing(self, chat_id: str) -> None: + """Start reply heartbeat (RUNNING).""" + await self.heartbeat.start(chat_id) + + async def stop_typing(self, chat_id: str, send_finish: bool = False) -> None: + """Stop reply heartbeat.""" + await self.heartbeat.stop(chat_id, send_finish=send_finish) + + async def start_slow_notifier(self, chat_id: str) -> None: + """Start slow-response notifier.""" + await self.slow_notifier.start(chat_id) + + def cancel_slow_notifier(self, chat_id: str) -> None: + """Cancel slow-response notifier.""" + self.slow_notifier.cancel(chat_id) + + def get_chat_lock(self, chat_id: str) -> asyncio.Lock: + """Proxy to MessageSender.get_chat_lock for backward compatibility.""" + return self.sender.get_chat_lock(chat_id) + + @property + def _chat_locks(self) -> collections.OrderedDict: + """Proxy to MessageSender._chat_locks for backward compatibility.""" + return self.sender._chat_locks + + @staticmethod + def validate_media( + file_bytes: Optional[bytes], filename: str, max_size_mb: int = 20, + ) -> Optional[str]: + """Proxy to MessageSender.validate_media.""" + return MessageSender.validate_media(file_bytes, filename, max_size_mb) + + async def close(self) -> None: + """Shut down all sub-managers.""" + await self.sender.close() + await self.heartbeat.close() + await self.slow_notifier.close() + + +class YuanbaoAdapter(BasePlatformAdapter): + """Yuanbao AI Bot adapter backed by a persistent WebSocket connection.""" + + PLATFORM = Platform.YUANBAO + MAX_TEXT_CHUNK: int = 4000 # Yuanbao single message character limit + MEDIA_MAX_SIZE_MB: int = 50 # Max media file size in MB for upload validation + REPLY_REF_MAX_ENTRIES: ClassVar[int] = 500 # Max capacity of reference dedup dict + + # -- Active instance registry (class-level singleton) ------------------- + + _active_instance: ClassVar[Optional["YuanbaoAdapter"]] = None + + @classmethod + def get_active(cls) -> Optional["YuanbaoAdapter"]: + """Return the currently connected YuanbaoAdapter, or None.""" + return cls._active_instance + + @classmethod + def set_active(cls, adapter: Optional["YuanbaoAdapter"]) -> None: + """Register (or clear) the active adapter instance.""" + cls._active_instance = adapter + + def __init__(self, config: PlatformConfig, **kwargs: Any) -> None: + super().__init__(config, Platform.YUANBAO) + + # Credentials / endpoints from config.extra (populated by config.py from env/yaml) + _extra = config.extra or {} + self._app_key: str = (_extra.get("app_id") or "").strip() + self._app_secret: str = (_extra.get("app_secret") or "").strip() + self._bot_id: Optional[str] = _extra.get("bot_id") or None + self._ws_url: str = (_extra.get("ws_url") or DEFAULT_WS_GATEWAY_URL).strip() + self._api_domain: str = (_extra.get("api_domain") or DEFAULT_API_DOMAIN).rstrip("/") + self._route_env: str = (_extra.get("route_env") or "").strip() + + # Core managers (UML composition) + self._connection: ConnectionManager = ConnectionManager(self) + self._outbound: OutboundManager = OutboundManager(self) + + # Inbound dispatch tasks — tracked so disconnect() can cancel them + self._inbound_tasks: set[asyncio.Task] = set() + + # Set of background tasks — prevent GC from collecting fire-and-forget tasks + self._background_tasks: set[asyncio.Task] = set() + + # Member cache: group_code -> (updated_ts, [{"user_id":..., "nickname":..., ...}, ...]) + # Populated by get_group_member_list(), used by @mention resolution. + # Entries older than MEMBER_CACHE_TTL_S are treated as stale. + self._member_cache: Dict[str, Tuple[float, list]] = {} + self.MEMBER_CACHE_TTL_S: float = 300.0 # 5 minutes + + # Inbound message deduplication (WS reconnect / network jitter) + self._dedup = MessageDeduplicator(ttl_seconds=300) + + # Group chat sequential dispatch queue (session_key → asyncio.Queue). + self._group_queues: Dict[str, asyncio.Queue] = {} + + # Recall support: track which msg_id is being processed per session_key + # so RecallGuardMiddleware can detect "currently processing" messages. + self._processing_msg_ids: Dict[str, str] = {} + self._processing_msg_texts: Dict[str, str] = {} + # Bounded cache of msg_id → attributed content for recent messages. + # Used by _patch_transcript as content-match fallback when transcript + # entries lack a message_id field (agent-processed @bot messages). + self._msg_content_cache: Dict[str, str] = {} + + # Reply-to dedup: inbound_msg_id -> expire_ts + # ------------------------------------------------------------------ + # Access control policy (DM / Group) + # ------------------------------------------------------------------ + dm_policy: str = ( + _extra.get("dm_policy") + or os.getenv("YUANBAO_DM_POLICY", "open") + ).strip().lower() + + _dm_allow_from_raw: str = ( + _extra.get("dm_allow_from") + or os.getenv("YUANBAO_DM_ALLOW_FROM", "") + ) + dm_allow_from: list[str] = [x.strip() for x in _dm_allow_from_raw.split(",") if x.strip()] + + group_policy: str = ( + _extra.get("group_policy") + or os.getenv("YUANBAO_GROUP_POLICY", "open") + ).strip().lower() + + _group_allow_from_raw: str = ( + _extra.get("group_allow_from") + or os.getenv("YUANBAO_GROUP_ALLOW_FROM", "") + ) + group_allow_from: list[str] = [x.strip() for x in _group_allow_from_raw.split(",") if x.strip()] + + self._access_policy = AccessPolicy( + dm_policy=dm_policy, + dm_allow_from=dm_allow_from, + group_policy=group_policy, + group_allow_from=group_allow_from, + ) + + # Group query service (AI tool backing) + self._group_query = GroupQueryService(self) + + # Inbound message processing pipeline (middleware pattern) + self._inbound_pipeline: InboundPipeline = InboundPipelineBuilder.build() + + # ------------------------------------------------------------------ + # Auto-sethome: first user to message the bot becomes the owner. + # If no home channel is configured, the first conversation will be + # automatically set as the home channel. When the existing home + # channel is a group chat (group:xxx), it stays eligible for + # upgrade — the first DM will override it with direct:xxx. + # ------------------------------------------------------------------ + _existing_home = os.getenv("YUANBAO_HOME_CHANNEL") or ( + config.home_channel.chat_id if config.home_channel else "" + ) + self._auto_sethome_done: bool = bool(_existing_home) and not _existing_home.startswith("group:") + + # ------------------------------------------------------------------ + # Task tracking helper + # ------------------------------------------------------------------ + + def _track_task(self, task: asyncio.Task) -> asyncio.Task: + """Register a fire-and-forget task so it won't be GC'd prematurely.""" + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task + + # ------------------------------------------------------------------ + # Abstract method implementations + # ------------------------------------------------------------------ + + async def connect(self) -> bool: + """Connect to Yuanbao WS gateway and authenticate. + + Delegates to ConnectionManager.open(). + """ + return await self._connection.open() + + async def disconnect(self) -> None: + """Cancel background tasks and close the WebSocket connection.""" + if YuanbaoAdapter._active_instance is self: + YuanbaoAdapter.set_active(None) + + self._running = False + self._mark_disconnected() + self._release_platform_lock() + + # Delegate to managers + await self._connection.close() + await self._outbound.close() + + # Cancel all in-flight inbound dispatch tasks + for task in list(self._inbound_tasks): + if not task.done(): + task.cancel() + self._inbound_tasks.clear() + + self._group_queues.clear() + + logger.info("[%s] Disconnected", self.name) + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + group_code: str = "", + ) -> SendResult: + """Send text message with auto-chunking. Delegates to OutboundManager.""" + return await self._outbound.send_text(chat_id, content, reply_to, group_code=group_code) + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + """Return basic chat metadata derived from the chat_id prefix. + + chat_id conventions: + "group:" → group chat + "direct:" → C2C / direct message (default) + + TODO (T06): fetch real chat name/member-count from Yuanbao API. + """ + if chat_id.startswith("group:"): + return {"name": chat_id, "type": "group"} + return {"name": chat_id, "type": "dm"} + + async def send_typing(self, chat_id: str, metadata: Optional[dict] = None) -> None: + """Send "typing" status heartbeat (RUNNING). Delegates to OutboundManager.""" + try: + await self._outbound.start_typing(chat_id) + except Exception: + pass + + async def stop_typing(self, chat_id: str) -> None: + """Stop the RUNNING heartbeat loop without sending FINISH immediately. + + FINISH is sent by send() after actual message delivery to ensure correct ordering: + RUNNING... -> message arrives -> FINISH. + """ + try: + await self._outbound.stop_typing(chat_id, send_finish=False) + except Exception: + pass + + async def _process_message_background(self, event, session_key: str) -> None: + """Wrap base class processing with a slow-response notifier.""" + chat_id = event.source.chat_id + await self._outbound.start_slow_notifier(chat_id) + try: + await super()._process_message_background(event, session_key) + finally: + self._outbound.cancel_slow_notifier(chat_id) + + # ------------------------------------------------------------------ + # Group query (delegate to GroupQueryService) + # ------------------------------------------------------------------ + + async def query_group_info(self, group_code: str) -> Optional[dict]: + """Query group info (delegates to GroupQueryService).""" + return await self._group_query.query_group_info_raw(group_code) + + async def get_group_member_list( + self, group_code: str, offset: int = 0, limit: int = 200 + ) -> Optional[dict]: + """Query group member list (delegates to GroupQueryService).""" + return await self._group_query.get_group_member_list_raw(group_code, offset=offset, limit=limit) + + # ------------------------------------------------------------------ + # DM active private chat + access control + # ------------------------------------------------------------------ + + DM_MAX_CHARS = 10000 # DM text limit + + async def send_dm(self, user_id: str, text: str, group_code: str = "") -> SendResult: + """ + Actively send C2C private chat message. + + Args: + user_id: Target user ID + text: Message text (limit 10000 characters) + group_code: Source group code (for group-originated DM context) + + Returns: + SendResult + """ + if not self._access_policy.is_dm_allowed(user_id): + return SendResult(success=False, error="DM access denied for this user") + if len(text) > self.DM_MAX_CHARS: + text = text[:self.DM_MAX_CHARS] + "\n...(truncated)" + chat_id = f"direct:{user_id}" + return await self.send(chat_id, text, group_code=group_code) + + # ------------------------------------------------------------------ + # Media send methods + # ------------------------------------------------------------------ + + async def send_image( + self, + chat_id: str, + image_url: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[dict] = None, + **kwargs: Any, + ) -> SendResult: + """Send image message (URL). Delegates to OutboundManager via ImageUrlHandler.""" + return await self._outbound.send_media( + chat_id, "image_url", + reply_to=reply_to, caption=caption, image_url=image_url, + **kwargs, + ) + + async def send_image_file( + self, + chat_id: str, + image_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[dict] = None, + **kwargs: Any, + ) -> SendResult: + """Send local image file. Delegates to OutboundManager via ImageFileHandler.""" + return await self._outbound.send_media( + chat_id, "image_file", + reply_to=reply_to, caption=caption, image_path=image_path, + **kwargs, + ) + + async def send_file( + self, + chat_id: str, + file_url: str, + filename: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[dict] = None, + **kwargs: Any, + ) -> SendResult: + """Send file message (URL). Delegates to OutboundManager via FileUrlHandler.""" + return await self._outbound.send_media( + chat_id, "file_url", + reply_to=reply_to, file_url=file_url, filename=filename, + **kwargs, + ) + + async def send_sticker( + self, + chat_id: str, + sticker_name: Optional[str] = None, + face_index: Optional[int] = None, + reply_to: Optional[str] = None, + **kwargs: Any, + ) -> SendResult: + """Send sticker/emoji. Delegates to OutboundManager via StickerHandler.""" + return await self._outbound.send_media( + chat_id, "sticker", + reply_to=reply_to, + sticker_name=sticker_name, face_index=face_index, + **kwargs, + ) + + async def send_document( + self, + chat_id: str, + file_path: str, + filename: Optional[str] = None, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[dict] = None, + **kwargs: Any, + ) -> SendResult: + """Send local file (document). Delegates to OutboundManager via DocumentHandler.""" + return await self._outbound.send_media( + chat_id, "document", + reply_to=reply_to, caption=caption, + file_path=file_path, filename=filename, + **kwargs, + ) + + async def _get_cached_token(self) -> dict: + """Get the current valid sign token (using module-level cache).""" + return await SignManager.get_token( + self._app_key, self._app_secret, self._api_domain, + route_env=self._route_env, + ) + + def get_status(self) -> dict: + """Return a snapshot of the current connection status.""" + conn = self._connection + return { + "connected": conn.is_connected, + "bot_id": self._bot_id, + "connect_id": conn.connect_id, + "reconnect_attempts": conn.reconnect_attempts, + "ws_url": self._ws_url, + } + + +# --------------------------------------------------------------------------- +# Module-level thin delegates (preserve import compatibility for external callers) +# --------------------------------------------------------------------------- + + +def get_active_adapter() -> Optional["YuanbaoAdapter"]: + """Delegate to ``YuanbaoAdapter.get_active()``.""" + return YuanbaoAdapter.get_active() + + +async def send_yuanbao_direct( + adapter: "YuanbaoAdapter", + chat_id: str, + message: str, + media_files: Optional[List[Tuple[str, bool]]] = None, +) -> Dict[str, Any]: + """Delegate to ``OutboundManager.send_direct``.""" + return await adapter._outbound.send_direct(chat_id, message, media_files) diff --git a/gateway/platforms/yuanbao_media.py b/gateway/platforms/yuanbao_media.py new file mode 100644 index 0000000000..8d697a3a8c --- /dev/null +++ b/gateway/platforms/yuanbao_media.py @@ -0,0 +1,647 @@ +""" +yuanbao_media.py — 元宝平台媒体处理模块 + +提供 COS 上传、文件下载、TIM 媒体消息构建等功能。 +移植自 TypeScript 版 media.ts(yuanbao-openclaw-plugin), +使用 httpx 替代 cos-nodejs-sdk-v5,避免引入额外 SDK 依赖。 + +COS 上传流程: + 1. 调用 genUploadInfo 获取临时凭证(tmpSecretId/tmpSecretKey/sessionToken) + 2. 用临时凭证通过 HMAC-SHA1 签名构建 Authorization 头 + 3. HTTP PUT 上传到 COS + +TIM 消息体构建: + - buildImageMsgBody() → TIMImageElem + - buildFileMsgBody() → TIMFileElem +""" + +from __future__ import annotations + +import hashlib +import hmac +import logging +import os +import re +import secrets +import struct +import time +import urllib.parse +from datetime import datetime, timezone, timedelta +from typing import Optional, Any + +import httpx + +logger = logging.getLogger(__name__) + +# ============ 常量 ============ + +UPLOAD_INFO_PATH = "/api/resource/genUploadInfo" +DEFAULT_API_DOMAIN = "yuanbao.tencent.com" +DEFAULT_MAX_SIZE_MB = 50 + +# COS 加速域名后缀(优先使用全球加速) +COS_USE_ACCELERATE = True + +# ============ 类型映射 ============ + +# MIME → image_format 数字(TIM 协议字段) +_MIME_TO_IMAGE_FORMAT: dict[str, int] = { + "image/jpeg": 1, + "image/jpg": 1, + "image/gif": 2, + "image/png": 3, + "image/bmp": 4, + "image/webp": 255, + "image/heic": 255, + "image/tiff": 255, +} + +# 文件扩展名 → MIME +_EXT_TO_MIME: dict[str, str] = { + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".gif": "image/gif", + ".webp": "image/webp", + ".bmp": "image/bmp", + ".heic": "image/heic", + ".tiff": "image/tiff", + ".ico": "image/x-icon", + ".pdf": "application/pdf", + ".doc": "application/msword", + ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".xls": "application/vnd.ms-excel", + ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ".ppt": "application/vnd.ms-powerpoint", + ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ".txt": "text/plain", + ".zip": "application/zip", + ".tar": "application/x-tar", + ".gz": "application/gzip", + ".mp3": "audio/mpeg", + ".mp4": "video/mp4", + ".wav": "audio/wav", + ".ogg": "audio/ogg", + ".webm": "video/webm", +} + + +# ============ 工具函数 ============ + +def guess_mime_type(filename: str) -> str: + """根据文件扩展名猜测 MIME 类型。""" + ext = os.path.splitext(filename)[-1].lower() + return _EXT_TO_MIME.get(ext, "application/octet-stream") + + +def is_image(filename: str, mime_type: str = "") -> bool: + """判断是否为图片类型。""" + if mime_type.startswith("image/"): + return True + ext = os.path.splitext(filename)[-1].lower() + return ext in {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".heic", ".tiff", ".ico"} + + +def get_image_format(mime_type: str) -> int: + """获取 TIM 图片格式编号。""" + return _MIME_TO_IMAGE_FORMAT.get(mime_type.lower(), 255) + + +def md5_hex(data: bytes) -> str: + """计算 MD5 十六进制摘要。""" + return hashlib.md5(data).hexdigest() + + +def generate_file_id() -> str: + """生成随机文件 ID(32 位 hex)。""" + return secrets.token_hex(16) + + + +# ============ 图片尺寸解析(纯 Python,无需 Pillow) ============ + +def parse_image_size(data: bytes) -> Optional[dict[str, int]]: + """ + 解析图片宽高(支持 JPEG/PNG/GIF/WebP),无需第三方依赖。 + 返回 {"width": w, "height": h} 或 None(无法识别)。 + """ + return ( + _parse_png_size(data) + or _parse_jpeg_size(data) + or _parse_gif_size(data) + or _parse_webp_size(data) + ) + + +def _parse_png_size(buf: bytes) -> Optional[dict[str, int]]: + if len(buf) < 24: + return None + if buf[:4] != b"\x89PNG": + return None + w = struct.unpack(">I", buf[16:20])[0] + h = struct.unpack(">I", buf[20:24])[0] + return {"width": w, "height": h} + + +def _parse_jpeg_size(buf: bytes) -> Optional[dict[str, int]]: + if len(buf) < 4 or buf[0] != 0xFF or buf[1] != 0xD8: + return None + i = 2 + while i < len(buf) - 9: + if buf[i] != 0xFF: + i += 1 + continue + marker = buf[i + 1] + if marker in (0xC0, 0xC2): + h = struct.unpack(">H", buf[i + 5: i + 7])[0] + w = struct.unpack(">H", buf[i + 7: i + 9])[0] + return {"width": w, "height": h} + if i + 3 < len(buf): + i += 2 + struct.unpack(">H", buf[i + 2: i + 4])[0] + else: + break + return None + + +def _parse_gif_size(buf: bytes) -> Optional[dict[str, int]]: + if len(buf) < 10: + return None + sig = buf[:6].decode("ascii", errors="replace") + if sig not in ("GIF87a", "GIF89a"): + return None + w = struct.unpack(" Optional[dict[str, int]]: + if len(buf) < 16: + return None + if buf[:4] != b"RIFF" or buf[8:12] != b"WEBP": + return None + chunk = buf[12:16].decode("ascii", errors="replace") + if chunk == "VP8 ": + if len(buf) >= 30 and buf[23] == 0x9D and buf[24] == 0x01 and buf[25] == 0x2A: + w = struct.unpack("= 25 and buf[20] == 0x2F: + bits = struct.unpack("> 14) & 0x3FFF) + 1 + return {"width": w, "height": h} + elif chunk == "VP8X": + if len(buf) >= 30: + w = (buf[24] | (buf[25] << 8) | (buf[26] << 16)) + 1 + h = (buf[27] | (buf[28] << 8) | (buf[29] << 16)) + 1 + return {"width": w, "height": h} + return None + + +# ============ URL 下载 ============ + +async def download_url( + url: str, + max_size_mb: int = DEFAULT_MAX_SIZE_MB, +) -> tuple[bytes, str]: + """ + 下载 URL 内容,返回 (bytes, content_type)。 + + Args: + url: HTTP(S) URL + max_size_mb: 最大允许大小(MB),超过则抛出异常 + + Returns: + (data_bytes, content_type_string) + + Raises: + ValueError: 内容超过大小限制 + httpx.HTTPError: 网络/HTTP 错误 + """ + max_bytes = max_size_mb * 1024 * 1024 + async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: + # 先 HEAD 检查大小 + try: + head = await client.head(url) + content_length = int(head.headers.get("content-length", 0) or 0) + if content_length > 0 and content_length > max_bytes: + raise ValueError( + f"文件过大: {content_length / 1024 / 1024:.1f} MB > {max_size_mb} MB" + ) + except httpx.HTTPStatusError: + pass # 部分服务器不支持 HEAD,忽略 + + # GET 下载(流式读取,防止超限) + async with client.stream("GET", url) as resp: + resp.raise_for_status() + + content_type = resp.headers.get("content-type", "").split(";")[0].strip() + + chunks: list[bytes] = [] + downloaded = 0 + async for chunk in resp.aiter_bytes(65536): + downloaded += len(chunk) + if downloaded > max_bytes: + raise ValueError( + f"文件过大: 已超过 {max_size_mb} MB 限制" + ) + chunks.append(chunk) + + data = b"".join(chunks) + return data, content_type + + +# ============ COS 鉴权(HMAC-SHA1) ============ + +def _cos_sign( + method: str, + path: str, + params: dict[str, str], + headers: dict[str, str], + secret_id: str, + secret_key: str, + start_time: Optional[int] = None, + expire_seconds: int = 3600, +) -> str: + """ + 构建 COS 请求签名(q-sign-algorithm=sha1 方案)。 + 参考:https://cloud.tencent.com/document/product/436/7778 + + Args: + method: HTTP 方法(小写,如 "put") + path: URL 路径(URL encode 后的小写) + params: URL 查询参数 dict(用于签名) + headers: 参与签名的请求头 dict(key 需小写) + secret_id: 临时 SecretId(tmpSecretId) + secret_key: 临时 SecretKey(tmpSecretKey) + start_time: 签名起始 Unix 时间戳(默认 now) + expire_seconds: 签名有效期(秒,默认 3600) + + Returns: + Authorization header 值(完整字符串) + """ + now = int(time.time()) + q_sign_time = f"{start_time or now};{(start_time or now) + expire_seconds}" + + # Step 1: SignKey = HMAC-SHA1(SecretKey, q-sign-time) + sign_key = hmac.new( + secret_key.encode("utf-8"), + q_sign_time.encode("utf-8"), + hashlib.sha1, + ).hexdigest() + + # Step 2: HttpString + # 参数和头部需按字典序排列,key 小写 + sorted_params = sorted((k.lower(), urllib.parse.quote(str(v), safe="") ) for k, v in params.items()) + sorted_headers = sorted((k.lower(), urllib.parse.quote(str(v), safe="") ) for k, v in headers.items()) + + url_param_list = ";".join(k for k, _ in sorted_params) + url_params = "&".join(f"{k}={v}" for k, v in sorted_params) + header_list = ";".join(k for k, _ in sorted_headers) + header_str = "&".join(f"{k}={v}" for k, v in sorted_headers) + + http_string = "\n".join([ + method.lower(), + path, + url_params, + header_str, + "", + ]) + + # Step 3: StringToSign = sha1 hash of HttpString + sha1_of_http = hashlib.sha1(http_string.encode("utf-8")).hexdigest() + string_to_sign = "\n".join([ + "sha1", + q_sign_time, + sha1_of_http, + "", + ]) + + # Step 4: Signature = HMAC-SHA1(SignKey, StringToSign) + signature = hmac.new( + sign_key.encode("utf-8"), + string_to_sign.encode("utf-8"), + hashlib.sha1, + ).hexdigest() + + return ( + f"q-sign-algorithm=sha1" + f"&q-ak={secret_id}" + f"&q-sign-time={q_sign_time}" + f"&q-key-time={q_sign_time}" + f"&q-header-list={header_list}" + f"&q-url-param-list={url_param_list}" + f"&q-signature={signature}" + ) + + +# ============ 主要公开 API ============ + +async def get_cos_credentials( + app_key: str, + api_domain: str, + token: str, + filename: str = "file", + file_id: Optional[str] = None, + bot_id: str = "", + route_env: str = "", +) -> dict: + """ + 调用 genUploadInfo 接口获取 COS 临时密钥及上传配置。 + + Args: + app_key: 应用 Key(用于 X-ID 头) + api_domain: API 域名(如 https://bot.yuanbao.tencent.com) + token: 当前有效的签票 token(X-Token 头) + filename: 待上传的文件名(含扩展名) + file_id: 客户端生成的唯一文件 ID(不传则自动生成) + bot_id: Bot 账号 ID(用于 X-ID 头) + + Returns: + COS 上传配置 dict,包含以下字段: + bucketName (str) — COS Bucket 名称 + region (str) — COS 地域 + location (str) — 上传 Key(对象路径) + encryptTmpSecretId (str) — 临时 SecretId + encryptTmpSecretKey(str) — 临时 SecretKey + encryptToken (str) — SessionToken + startTime (int) — 凭证起始时间戳(Unix) + expiredTime (int) — 凭证过期时间戳(Unix) + resourceUrl (str) — 上传后的公网访问 URL + resourceID (str) — 资源 ID(可选) + + Raises: + RuntimeError: 接口返回非 0 code 或字段缺失 + """ + if file_id is None: + file_id = generate_file_id() + + upload_url = f"{api_domain.rstrip('/')}{UPLOAD_INFO_PATH}" + + headers = { + "Content-Type": "application/json", + "X-Token": token, + "X-ID": bot_id or app_key, + "X-Source": "web", + } + if route_env: + headers["X-Route-Env"] = route_env + body = { + "fileName": filename, + "fileId": file_id, + "docFrom": "localDoc", + "docOpenId": "", + } + + async with httpx.AsyncClient(timeout=15.0) as client: + resp = await client.post(upload_url, json=body, headers=headers) + resp.raise_for_status() + result: dict[str, Any] = resp.json() + + code = result.get("code") + if code != 0 and code is not None: + raise RuntimeError( + f"genUploadInfo 失败: code={code}, msg={result.get('msg', '')}" + ) + + data = result.get("data") or result + required_fields = ["bucketName", "location"] + missing = [f for f in required_fields if not data.get(f)] + if missing: + raise RuntimeError( + f"genUploadInfo 返回字段不完整: 缺少字段 {missing}" + ) + + return data + + +async def upload_to_cos( + file_bytes: bytes, + filename: str, + content_type: str, + credentials: dict, + bucket: str, + region: str, +) -> dict: + """ + 通过 httpx PUT 请求将文件上传到 COS。 + 使用临时凭证(tmpSecretId/tmpSecretKey/sessionToken)构建 HMAC-SHA1 签名。 + + Args: + file_bytes: 文件二进制内容 + filename: 文件名(用于辅助计算 MIME、UUID) + content_type: MIME 类型(如 "image/jpeg") + credentials: get_cos_credentials() 返回的 dict,包含: + encryptTmpSecretId → tmpSecretId + encryptTmpSecretKey → tmpSecretKey + encryptToken → sessionToken + location → COS key(对象路径) + resourceUrl → 上传后公网 URL + startTime → 凭证起始时间(Unix) + expiredTime → 凭证过期时间(Unix) + bucket: COS Bucket 名称(如 chatbot-1234567890) + region: COS 地域(如 ap-guangzhou) + + Returns: + 上传结果 dict,包含: + url (str) — COS 公网访问 URL + uuid (str) — 文件内容 MD5 + size (int) — 文件大小(字节) + width (int, optional) — 图片宽度(仅图片) + height (int, optional) — 图片高度(仅图片) + + Raises: + httpx.HTTPStatusError: COS 返回非 2xx 状态 + RuntimeError: credentials 字段缺失 + """ + secret_id: str = credentials.get("encryptTmpSecretId", "") + secret_key: str = credentials.get("encryptTmpSecretKey", "") + session_token: str = credentials.get("encryptToken", "") + cos_key: str = credentials.get("location", "") + resource_url: str = credentials.get("resourceUrl", "") + start_time: Optional[int] = credentials.get("startTime") + expired_time: Optional[int] = credentials.get("expiredTime") + + if not secret_id or not secret_key or not cos_key: + raise RuntimeError( + f"COS credentials 不完整: secretId={bool(secret_id)}, " + f"secretKey={bool(secret_key)}, location={bool(cos_key)}" + ) + + # 构建 COS 上传 URL(优先使用全球加速域名) + if COS_USE_ACCELERATE: + cos_host = f"{bucket}.cos.accelerate.myqcloud.com" + else: + cos_host = f"{bucket}.cos.{region}.myqcloud.com" + + # URL encode cos_key(保留 /) + encoded_key = urllib.parse.quote(cos_key, safe="/") + cos_url = f"https://{cos_host}/{encoded_key.lstrip('/')}" + + # 确定 Content-Type + if not content_type or content_type == "application/octet-stream": + if is_image(filename): + content_type = guess_mime_type(filename) + else: + content_type = "application/octet-stream" + + # 计算文件 MD5 + size + file_uuid = md5_hex(file_bytes) + file_size = len(file_bytes) + + # 参与签名的请求头 + sign_headers = { + "host": cos_host, + "content-type": content_type, + "x-cos-security-token": session_token, + } + + # 计算签名有效期 + now = int(time.time()) + sign_start = start_time if start_time else now + sign_expire = (expired_time - now) if expired_time and expired_time > now else 3600 + + authorization = _cos_sign( + method="put", + path=f"/{encoded_key.lstrip('/')}", + params={}, + headers=sign_headers, + secret_id=secret_id, + secret_key=secret_key, + start_time=sign_start, + expire_seconds=sign_expire, + ) + + put_headers = { + "Authorization": authorization, + "Content-Type": content_type, + "x-cos-security-token": session_token, + } + + logger.info( + "COS PUT: bucket=%s region=%s key=%s size=%d mime=%s", + bucket, region, cos_key, file_size, content_type, + ) + + async with httpx.AsyncClient(timeout=120.0) as client: + resp = await client.put( + cos_url, + content=file_bytes, + headers=put_headers, + ) + resp.raise_for_status() + + # 解析图片尺寸(仅图片类型) + result: dict[str, Any] = { + "url": resource_url or cos_url, + "uuid": file_uuid, + "size": file_size, + } + + if content_type.startswith("image/"): + size_info = parse_image_size(file_bytes) + if size_info: + result["width"] = size_info["width"] + result["height"] = size_info["height"] + + logger.info( + "COS 上传成功: url=%s size=%d", + result["url"], file_size, + ) + return result + + +# ============ TIM 媒体消息构建 ============ + +def build_image_msg_body( + url: str, + uuid: Optional[str] = None, + filename: Optional[str] = None, + size: int = 0, + width: int = 0, + height: int = 0, + mime_type: str = "", +) -> list[dict]: + """ + 构建腾讯 IM TIMImageElem 消息体。 + 参考:https://cloud.tencent.com/document/product/269/2720 + + Args: + url: 图片公网访问 URL(COS resourceUrl) + uuid: 文件 UUID(MD5 或其他唯一标识) + filename: 文件名(uuid 为空时作为备用) + size: 文件大小(字节) + width: 图片宽度(像素) + height: 图片高度(像素) + mime_type: MIME 类型(用于确定 image_format) + + Returns: + TIMImageElem 消息体列表(适合直接放入 msg_body) + """ + _uuid = uuid or filename or _basename_from_url(url) or "image" + image_format = get_image_format(mime_type) if mime_type else 255 + + return [ + { + "msg_type": "TIMImageElem", + "msg_content": { + "uuid": _uuid, + "image_format": image_format, + "image_info_array": [ + { + "type": 1, # 1 = 原图 + "size": size, + "width": width, + "height": height, + "url": url, + } + ], + }, + } + ] + + +def build_file_msg_body( + url: str, + filename: str, + uuid: Optional[str] = None, + size: int = 0, +) -> list[dict]: + """ + 构建腾讯 IM TIMFileElem 消息体。 + 参考:https://cloud.tencent.com/document/product/269/2720 + + Args: + url: 文件公网访问 URL(COS resourceUrl) + filename: 文件名(含扩展名) + uuid: 文件 UUID(MD5 或其他唯一标识,不传则使用 filename) + size: 文件大小(字节) + + Returns: + TIMFileElem 消息体列表(适合直接放入 msg_body) + """ + _uuid = uuid or filename + + return [ + { + "msg_type": "TIMFileElem", + "msg_content": { + "uuid": _uuid, + "file_name": filename, + "file_size": size, + "url": url, + }, + } + ] + + +# ============ 内部工具 ============ + +def _basename_from_url(url: str) -> str: + """从 URL 提取文件名。""" + try: + parsed = urllib.parse.urlparse(url) + return os.path.basename(parsed.path) + except Exception: + return "" diff --git a/gateway/platforms/yuanbao_proto.py b/gateway/platforms/yuanbao_proto.py new file mode 100644 index 0000000000..3d4e56ce49 --- /dev/null +++ b/gateway/platforms/yuanbao_proto.py @@ -0,0 +1,1210 @@ +""" +yuanbao_proto.py - Yuanbao WebSocket 协议编解码(纯 Python 实现) + +协议层级: + WebSocket frame + └── ConnMsg (protobuf: trpc.yuanbao.conn_common.ConnMsg) + ├── head: Head (cmd_type, cmd, seq_no, msg_id, module, ...) + └── data: bytes (业务 payload,标准 protobuf) + └── InboundMessagePush / SendC2CMessageReq / SendGroupMessageReq / ... + (trpc.yuanbao.yuanbao_conn.yuanbao_openclaw_proxy.*) + +注意:conn 层(ConnMsg)本身是标准 protobuf,不是自定义二进制格式。 + conn.proto 注释里的自定义格式(magic+head_len+body_len)仅用于 quic/tcp, + WebSocket 直接传 ConnMsg protobuf bytes(无粘包问题,每个 ws frame = 一条消息)。 + +实现方式:手写 varint / protobuf wire-format 编解码,不依赖第三方 protobuf 库。 +""" + +from __future__ import annotations + +import logging +import struct +import threading +from typing import Optional, Union + +logger = logging.getLogger(__name__) + +# ============================================================ +# Debug 开关 +# ============================================================ + +DEBUG_MODE = False + + +def _dbg(label: str, data: bytes) -> None: + if DEBUG_MODE: + hex_str = " ".join(f"{b:02x}" for b in data[:64]) + ellipsis = "..." if len(data) > 64 else "" + logger.debug("[yuanbao_proto] %s (%dB): %s", label, len(data), hex_str + ellipsis) + + +# ============================================================ +# 常量 +# ============================================================ + +# conn 层消息类型枚举(ConnMsg.Head.cmd_type) +PB_MSG_TYPES = { + "ConnMsg": "trpc.yuanbao.conn_common.ConnMsg", + "AuthBindReq": "trpc.yuanbao.conn_common.AuthBindReq", + "AuthBindRsp": "trpc.yuanbao.conn_common.AuthBindRsp", + "PingReq": "trpc.yuanbao.conn_common.PingReq", + "PingRsp": "trpc.yuanbao.conn_common.PingRsp", + "KickoutMsg": "trpc.yuanbao.conn_common.KickoutMsg", + "DirectedPush": "trpc.yuanbao.conn_common.DirectedPush", + "PushMsg": "trpc.yuanbao.conn_common.PushMsg", +} + +# cmd_type 枚举 +CMD_TYPE = { + "Request": 0, # 上行请求 + "Response": 1, # 上行请求的回包 + "Push": 2, # 下行推送 + "PushAck": 3, # 下行推送的回包(ACK) +} + +# 内置命令字 +CMD = { + "AuthBind": "auth-bind", + "Ping": "ping", + "Kickout": "kickout", + "UpdateMeta": "update-meta", +} + +# 内置模块名 +MODULE = { + "ConnAccess": "conn_access", +} + +# biz 层服务/方法映射 +# TS client uses the short name 'yuanbao_openclaw_proxy' (not the full package path) +_BIZ_PKG = "yuanbao_openclaw_proxy" +BIZ_SERVICES = { + "InboundMessagePush": f"{_BIZ_PKG}.InboundMessagePush", + "SendC2CMessageReq": f"{_BIZ_PKG}.SendC2CMessageReq", + "SendC2CMessageRsp": f"{_BIZ_PKG}.SendC2CMessageRsp", + "SendGroupMessageReq": f"{_BIZ_PKG}.SendGroupMessageReq", + "SendGroupMessageRsp": f"{_BIZ_PKG}.SendGroupMessageRsp", + "QueryGroupInfoReq": f"{_BIZ_PKG}.QueryGroupInfoReq", + "QueryGroupInfoRsp": f"{_BIZ_PKG}.QueryGroupInfoRsp", + "GetGroupMemberListReq": f"{_BIZ_PKG}.GetGroupMemberListReq", + "GetGroupMemberListRsp": f"{_BIZ_PKG}.GetGroupMemberListRsp", + "SendPrivateHeartbeatReq": f"{_BIZ_PKG}.SendPrivateHeartbeatReq", + "SendPrivateHeartbeatRsp": f"{_BIZ_PKG}.SendPrivateHeartbeatRsp", + "SendGroupHeartbeatReq": f"{_BIZ_PKG}.SendGroupHeartbeatReq", + "SendGroupHeartbeatRsp": f"{_BIZ_PKG}.SendGroupHeartbeatRsp", +} + +# openclaw instance_id(固定值 17) +HERMES_INSTANCE_ID = 17 + +# Reply Heartbeat 状态常量 +WS_HEARTBEAT_RUNNING = 1 +WS_HEARTBEAT_FINISH = 2 + +# ============================================================ +# 序列号生成 +# ============================================================ + +_seq_lock = threading.Lock() +_seq_counter = 0 +_SEQ_MAX = 2 ** 32 - 1 # uint32 上限 + + +def next_seq_no() -> int: + """生成递增序列号(线程安全,溢出时归零)""" + global _seq_counter + with _seq_lock: + val = _seq_counter + _seq_counter = (_seq_counter + 1) & _SEQ_MAX + return val + + +# ============================================================ +# Protobuf wire-format 基础工具(手写,不依赖 google.protobuf) +# ============================================================ + +# wire types +WT_VARINT = 0 +WT_64BIT = 1 +WT_LEN = 2 +WT_32BIT = 5 + + +def _encode_varint(value: int) -> bytes: + """将非负整数编码为 protobuf varint""" + if value < 0: + # 处理有符号负数(int32/int64 用 two's complement,64-bit) + value = value & 0xFFFFFFFFFFFFFFFF + out = [] + while True: + bits = value & 0x7F + value >>= 7 + if value: + out.append(bits | 0x80) + else: + out.append(bits) + break + return bytes(out) + + +def _decode_varint(data: bytes, pos: int) -> tuple[int, int]: + """从 data[pos:] 解码 varint,返回 (value, new_pos)""" + result = 0 + shift = 0 + while pos < len(data): + b = data[pos] + pos += 1 + result |= (b & 0x7F) << shift + shift += 7 + if not (b & 0x80): + break + if shift >= 64: + raise ValueError("varint too long") + return result, pos + + +def _encode_field(field_number: int, wire_type: int, value: bytes) -> bytes: + """编码一个 protobuf field(tag + value)""" + tag = (field_number << 3) | wire_type + return _encode_varint(tag) + value + + +def _encode_string(s: str) -> bytes: + """编码 protobuf string 字段的 value 部分(length-prefixed UTF-8)""" + encoded = s.encode("utf-8") + return _encode_varint(len(encoded)) + encoded + + +def _encode_bytes(b: bytes) -> bytes: + """编码 protobuf bytes 字段的 value 部分(length-prefixed)""" + return _encode_varint(len(b)) + b + + +def _encode_message(b: bytes) -> bytes: + """编码嵌套 message(length-prefixed)""" + return _encode_varint(len(b)) + b + + +def _parse_fields(data: bytes) -> list[tuple[int, int, bytes | int]]: + """ + 解析 protobuf message 的所有字段,返回 [(field_number, wire_type, raw_value), ...] + raw_value: + - WT_VARINT: int + - WT_LEN: bytes + - WT_64BIT: bytes (8 bytes) + - WT_32BIT: bytes (4 bytes) + """ + fields = [] + pos = 0 + n = len(data) + while pos < n: + tag, pos = _decode_varint(data, pos) + field_number = tag >> 3 + wire_type = tag & 0x07 + if wire_type == WT_VARINT: + val, pos = _decode_varint(data, pos) + fields.append((field_number, wire_type, val)) + elif wire_type == WT_LEN: + length, pos = _decode_varint(data, pos) + val = data[pos: pos + length] + pos += length + fields.append((field_number, wire_type, val)) + elif wire_type == WT_64BIT: + val = data[pos: pos + 8] + pos += 8 + fields.append((field_number, wire_type, val)) + elif wire_type == WT_32BIT: + val = data[pos: pos + 4] + pos += 4 + fields.append((field_number, wire_type, val)) + else: + raise ValueError(f"unknown wire type {wire_type} at pos {pos - 1}") + return fields + + +def _fields_to_dict(fields: list) -> dict[int, list]: + """将 fields 列表转为 {field_number: [value, ...]} 字典(repeated 字段会有多个)""" + d: dict[int, list] = {} + for fn, wt, val in fields: + d.setdefault(fn, []).append((wt, val)) + return d + + +def _get_string(fdict: dict, fn: int, default: str = "") -> str: + """从 fields dict 取第一个 string 字段""" + entries = fdict.get(fn) + if not entries: + return default + wt, val = entries[0] + if wt == WT_LEN and isinstance(val, (bytes, bytearray)): + return val.decode("utf-8", errors="replace") + return default + + +def _get_varint(fdict: dict, fn: int, default: int = 0) -> int: + """从 fields dict 取第一个 varint 字段""" + entries = fdict.get(fn) + if not entries: + return default + wt, val = entries[0] + if wt == WT_VARINT and isinstance(val, int): + return val + return default + + +def _get_bytes(fdict: dict, fn: int, default: bytes = b"") -> bytes: + """从 fields dict 取第一个 bytes/message 字段""" + entries = fdict.get(fn) + if not entries: + return default + wt, val = entries[0] + if wt == WT_LEN and isinstance(val, (bytes, bytearray)): + return bytes(val) + return default + + +def _get_repeated_bytes(fdict: dict, fn: int) -> list[bytes]: + """取所有 repeated bytes/message 字段""" + entries = fdict.get(fn, []) + return [bytes(val) for wt, val in entries if wt == WT_LEN] + + +# ============================================================ +# ConnMsg 层编解码 +# ============================================================ +# +# ConnMsg protobuf schema (conn.json): +# message Head { +# uint32 cmd_type = 1; +# string cmd = 2; +# uint32 seq_no = 3; +# string msg_id = 4; +# string module = 5; +# bool need_ack = 6; +# ... +# int32 status = 10; +# } +# message ConnMsg { +# Head head = 1; +# bytes data = 2; +# } + + +def _encode_head( + cmd_type: int, + cmd: str, + seq_no: int, + msg_id: str, + module: str, + need_ack: bool = False, + status: int = 0, +) -> bytes: + """编码 ConnMsg.Head""" + buf = b"" + if cmd_type != 0: + buf += _encode_field(1, WT_VARINT, _encode_varint(cmd_type)) + if cmd: + buf += _encode_field(2, WT_LEN, _encode_string(cmd)) + if seq_no != 0: + buf += _encode_field(3, WT_VARINT, _encode_varint(seq_no)) + if msg_id: + buf += _encode_field(4, WT_LEN, _encode_string(msg_id)) + if module: + buf += _encode_field(5, WT_LEN, _encode_string(module)) + if need_ack: + buf += _encode_field(6, WT_VARINT, _encode_varint(1)) + if status != 0: + buf += _encode_field(10, WT_VARINT, _encode_varint(status & 0xFFFFFFFFFFFFFFFF)) + return buf + + +def _decode_head(data: bytes) -> dict: + """解码 ConnMsg.Head,返回 dict""" + fdict = _fields_to_dict(_parse_fields(data)) + return { + "cmd_type": _get_varint(fdict, 1, 0), + "cmd": _get_string(fdict, 2, ""), + "seq_no": _get_varint(fdict, 3, 0), + "msg_id": _get_string(fdict, 4, ""), + "module": _get_string(fdict, 5, ""), + "need_ack": bool(_get_varint(fdict, 6, 0)), + "status": _get_varint(fdict, 10, 0), + } + + +def encode_conn_msg(msg_type: int, seq_no: int, data: bytes) -> bytes: + """ + 编码 ConnMsg(简化接口,对应任务要求的签名)。 + + Args: + msg_type: cmd_type(CMD_TYPE 枚举值) + seq_no: 序列号 + data: 内层 payload bytes(业务 protobuf) + + Returns: + ConnMsg 编码后的 bytes + """ + head_bytes = _encode_head( + cmd_type=msg_type, + cmd="", + seq_no=seq_no, + msg_id="", + module="", + ) + buf = _encode_field(1, WT_LEN, _encode_message(head_bytes)) + if data: + buf += _encode_field(2, WT_LEN, _encode_bytes(data)) + _dbg("encode_conn_msg", buf) + return buf + + +def decode_conn_msg(data: bytes) -> dict: + """ + 解码 ConnMsg,返回 {msg_type, seq_no, data, head}。 + + Returns: + { + "msg_type": int, # cmd_type + "seq_no": int, + "data": bytes, # 内层 payload + "head": dict, # 完整 head 字段 + } + """ + _dbg("decode_conn_msg", data) + fdict = _fields_to_dict(_parse_fields(data)) + head_bytes = _get_bytes(fdict, 1) + payload = _get_bytes(fdict, 2) + head = _decode_head(head_bytes) if head_bytes else { + "cmd_type": 0, "cmd": "", "seq_no": 0, "msg_id": "", "module": "", + "need_ack": False, "status": 0, + } + return { + "msg_type": head["cmd_type"], + "seq_no": head["seq_no"], + "data": payload, + "head": head, + } + + +def encode_conn_msg_full( + cmd_type: int, + cmd: str, + seq_no: int, + msg_id: str, + module: str, + data: bytes, + need_ack: bool = False, +) -> bytes: + """ + 编码完整的 ConnMsg(含 cmd/msg_id/module 等 head 字段)。 + 比 encode_conn_msg 提供更多 head 控制。 + """ + head_bytes = _encode_head( + cmd_type=cmd_type, + cmd=cmd, + seq_no=seq_no, + msg_id=msg_id, + module=module, + need_ack=need_ack, + ) + buf = _encode_field(1, WT_LEN, _encode_message(head_bytes)) + if data: + buf += _encode_field(2, WT_LEN, _encode_bytes(data)) + _dbg("encode_conn_msg_full", buf) + return buf + + +# ============================================================ +# BizMsg 层编解码(biz payload 本身也是 protobuf) +# ============================================================ +# +# 任务要求的 encode_biz_msg / decode_biz_msg 是一个中间抽象层: +# encode_biz_msg(service, method, req_id, body) -> conn_msg_bytes +# 即:将业务 body 包装成 ConnMsg,其中 head.cmd = method, head.module = service +# +# 这与 conn-codec.ts 中 buildBusinessConnMsg() 的行为一致: +# buildBusinessConnMsg(cmd, module, bizData, msgId) -> ConnMsg bytes + + +def encode_biz_msg(service: str, method: str, req_id: str, body: bytes) -> bytes: + """ + 将业务 payload 包装为 ConnMsg bytes。 + + Args: + service: 模块名(head.module),如 "yuanbao_openclaw_proxy" + method: 命令字(head.cmd),如 "send_c2c_message" + req_id: 消息 ID(head.msg_id) + body: 已编码的业务 protobuf bytes + + Returns: + ConnMsg bytes(可直接发送到 WebSocket) + """ + return encode_conn_msg_full( + cmd_type=CMD_TYPE["Request"], + cmd=method, + seq_no=next_seq_no(), + msg_id=req_id, + module=service, + data=body, + ) + + +def decode_biz_msg(data: bytes) -> dict: + """ + 解码 ConnMsg bytes,返回业务层信息。 + + Returns: + { + "service": str, # head.module + "method": str, # head.cmd + "req_id": str, # head.msg_id + "body": bytes, # 内层 biz payload + "is_response": bool, # cmd_type == 1 (Response) + "head": dict, # 完整 head + } + """ + result = decode_conn_msg(data) + head = result["head"] + return { + "service": head["module"], + "method": head["cmd"], + "req_id": head["msg_id"], + "body": result["data"], + "is_response": head["cmd_type"] == CMD_TYPE["Response"], + "head": head, + } + + +# ============================================================ +# 业务 protobuf 消息编解码(biz payload) +# ============================================================ + +# ---------- MsgContent 编解码 ---------- +# field 1: text (string) +# field 2: uuid (string) +# field 3: image_format (uint32) +# field 4: data (string) +# field 5: desc (string) +# field 6: ext (string) +# field 7: sound (string) +# field 8: image_info_array (repeated message) +# field 9: index (uint32) +# field 10: url (string) +# field 11: file_size (uint32) +# field 12: file_name (string) + + +def _encode_msg_content(content: dict) -> bytes: + buf = b"" + for fn, key in [ + (1, "text"), (2, "uuid"), (4, "data"), (5, "desc"), + (6, "ext"), (7, "sound"), (10, "url"), (12, "file_name"), + ]: + v = content.get(key, "") + if v: + buf += _encode_field(fn, WT_LEN, _encode_string(str(v))) + for fn, key in [(3, "image_format"), (9, "index"), (11, "file_size")]: + v = content.get(key, 0) + if v: + buf += _encode_field(fn, WT_VARINT, _encode_varint(int(v))) + # image_info_array (repeated) + for img in content.get("image_info_array") or []: + img_buf = b"" + for ifn, ikey in [(1, "type"), (2, "size"), (3, "width"), (4, "height")]: + iv = img.get(ikey, 0) + if iv: + img_buf += _encode_field(ifn, WT_VARINT, _encode_varint(int(iv))) + url = img.get("url", "") + if url: + img_buf += _encode_field(5, WT_LEN, _encode_string(url)) + buf += _encode_field(8, WT_LEN, _encode_message(img_buf)) + return buf + + +def _decode_msg_content(data: bytes) -> dict: + fdict = _fields_to_dict(_parse_fields(data)) + content: dict = {} + for fn, key in [ + (1, "text"), (2, "uuid"), (4, "data"), (5, "desc"), + (6, "ext"), (7, "sound"), (10, "url"), (12, "file_name"), + ]: + v = _get_string(fdict, fn) + if v: + content[key] = v + for fn, key in [(3, "image_format"), (9, "index"), (11, "file_size")]: + v = _get_varint(fdict, fn) + if v: + content[key] = v + imgs = [] + for img_bytes in _get_repeated_bytes(fdict, 8): + ifdict = _fields_to_dict(_parse_fields(img_bytes)) + img = {} + for ifn, ikey in [(1, "type"), (2, "size"), (3, "width"), (4, "height")]: + iv = _get_varint(ifdict, ifn) + if iv: + img[ikey] = iv + url = _get_string(ifdict, 5) + if url: + img["url"] = url + if img: + imgs.append(img) + if imgs: + content["image_info_array"] = imgs + return content + + +# ---------- MsgBodyElement 编解码 ---------- +# field 1: msg_type (string) e.g. "TIMTextElem" +# field 2: msg_content (message MsgContent) + + +def _encode_msg_body_element(element: dict) -> bytes: + buf = b"" + msg_type = element.get("msg_type", "") + if msg_type: + buf += _encode_field(1, WT_LEN, _encode_string(msg_type)) + content = element.get("msg_content", {}) + if content: + content_bytes = _encode_msg_content(content) + buf += _encode_field(2, WT_LEN, _encode_message(content_bytes)) + return buf + + +def _decode_msg_body_element(data: bytes) -> dict: + fdict = _fields_to_dict(_parse_fields(data)) + msg_type = _get_string(fdict, 1, "") + content_bytes = _get_bytes(fdict, 2) + content = _decode_msg_content(content_bytes) if content_bytes else {} + return {"msg_type": msg_type, "msg_content": content} + + +# ---------- LogInfoExt ---------- +# field 1: trace_id (string) + + +def _encode_log_ext(trace_id: str) -> bytes: + if not trace_id: + return b"" + return _encode_field(1, WT_LEN, _encode_string(trace_id)) + + +def _decode_im_msg_seq(data: bytes) -> dict: + """Decode a single ImMsgSeq sub-message (field 17 of InboundMessagePush). + + ImMsgSeq proto fields: + 1: msg_seq (uint64) + 2: msg_id (string) + """ + fdict = _fields_to_dict(_parse_fields(data)) + return { + "msg_seq": _get_varint(fdict, 1), + "msg_id": _get_string(fdict, 2), + } + + +def _decode_log_ext(data: bytes) -> dict: + fdict = _fields_to_dict(_parse_fields(data)) + return {"trace_id": _get_string(fdict, 1)} + + +# ============================================================ +# 入站消息解析 +# ============================================================ +# +# InboundMessagePush fields: +# 1: callback_command (string) +# 2: from_account (string) +# 3: to_account (string) +# 4: sender_nickname (string) +# 5: group_id (string) +# 6: group_code (string) +# 7: group_name (string) +# 8: msg_seq (uint32) +# 9: msg_random (uint32) +# 10: msg_time (uint32) +# 11: msg_key (string) +# 12: msg_id (string) +# 13: msg_body (repeated MsgBodyElement) +# 14: cloud_custom_data (string) +# 15: event_time (uint32) +# 16: bot_owner_id (string) +# 17: recall_msg_seq_list (repeated ImMsgSeq) +# 18: claw_msg_type (uint32/enum) +# 19: private_from_group_code (string) +# 20: log_ext (message LogInfoExt) + + +def decode_inbound_push(data: bytes) -> Optional[dict]: + """ + 解析入站消息推送的 biz payload(InboundMessagePush proto bytes)。 + + Args: + data: ConnMsg.data 字段的 bytes(即 biz payload) + + Returns: + { + "from_account": str, + "to_account": str (可选), + "group_code": str (可选,群消息才有), + "group_id": str (可选), + "group_name": str (可选), + "msg_key": str, + "msg_id": str, + "msg_seq": int, + "msg_random": int, + "msg_time": int, + "sender_nickname": str, + "msg_body": [{"msg_type": str, "msg_content": dict}, ...], + "callback_command": str, + "cloud_custom_data": str, + "bot_owner_id": str, + "claw_msg_type": int, + "private_from_group_code": str, + "trace_id": str, + "recall_msg_seq_list": [{"msg_seq": int, "msg_id": str}, ...] 或 None, + } + 或 None(解析失败) + """ + try: + _dbg("decode_inbound_push input", data) + fdict = _fields_to_dict(_parse_fields(data)) + + msg_body = [] + for el_bytes in _get_repeated_bytes(fdict, 13): + msg_body.append(_decode_msg_body_element(el_bytes)) + + log_ext_bytes = _get_bytes(fdict, 20) + trace_id = _decode_log_ext(log_ext_bytes).get("trace_id", "") if log_ext_bytes else "" + + recall_seq_raw = _get_repeated_bytes(fdict, 17) + recall_msg_seq_list = [_decode_im_msg_seq(b) for b in recall_seq_raw] or None + + result: dict = { + "callback_command": _get_string(fdict, 1), + "from_account": _get_string(fdict, 2), + "to_account": _get_string(fdict, 3), + "sender_nickname": _get_string(fdict, 4), + "group_id": _get_string(fdict, 5), + "group_code": _get_string(fdict, 6), + "group_name": _get_string(fdict, 7), + "msg_seq": _get_varint(fdict, 8), + "msg_random": _get_varint(fdict, 9), + "msg_time": _get_varint(fdict, 10), + "msg_key": _get_string(fdict, 11), + "msg_id": _get_string(fdict, 12), + "msg_body": msg_body, + "cloud_custom_data": _get_string(fdict, 14), + "event_time": _get_varint(fdict, 15), + "bot_owner_id": _get_string(fdict, 16), + "recall_msg_seq_list": recall_msg_seq_list, + "claw_msg_type": _get_varint(fdict, 18), + "private_from_group_code": _get_string(fdict, 19), + "trace_id": trace_id, + } + # 过滤空值(保持 API 整洁) + return {k: v for k, v in result.items() if v or k in ("msg_body", "msg_seq")} + except Exception as e: + if DEBUG_MODE: + logger.debug("[yuanbao_proto] decode_inbound_push failed: %s", e) + return None + + +# ============================================================ +# 出站消息编码 +# ============================================================ + +def _encode_send_c2c_req( + to_account: str, + from_account: str, + msg_body: list, + msg_id: str = "", + msg_random: int = 0, + msg_seq: Optional[int] = None, + group_code: str = "", + trace_id: str = "", +) -> bytes: + """ + 编码 SendC2CMessageReq biz payload。 + + SendC2CMessageReq fields: + 1: msg_id (string) + 2: to_account (string) + 3: from_account (string) + 4: msg_random (uint32) + 5: msg_body (repeated MsgBodyElement) + 6: group_code (string) + 7: msg_seq (uint64) + 8: log_ext (LogInfoExt) + """ + buf = b"" + if msg_id: + buf += _encode_field(1, WT_LEN, _encode_string(msg_id)) + buf += _encode_field(2, WT_LEN, _encode_string(to_account)) + if from_account: + buf += _encode_field(3, WT_LEN, _encode_string(from_account)) + if msg_random: + buf += _encode_field(4, WT_VARINT, _encode_varint(msg_random)) + for el in msg_body: + el_bytes = _encode_msg_body_element(el) + buf += _encode_field(5, WT_LEN, _encode_message(el_bytes)) + if group_code: + buf += _encode_field(6, WT_LEN, _encode_string(group_code)) + if msg_seq is not None: + buf += _encode_field(7, WT_VARINT, _encode_varint(msg_seq)) + if trace_id: + log_bytes = _encode_log_ext(trace_id) + buf += _encode_field(8, WT_LEN, _encode_message(log_bytes)) + return buf + + +def _encode_send_group_req( + group_code: str, + from_account: str, + msg_body: list, + msg_id: str = "", + to_account: str = "", + random: str = "", + msg_seq: Optional[int] = None, + ref_msg_id: str = "", + trace_id: str = "", +) -> bytes: + """ + 编码 SendGroupMessageReq biz payload。 + + SendGroupMessageReq fields: + 1: msg_id (string) + 2: group_code (string) + 3: from_account (string) + 4: to_account (string) + 5: random (string) + 6: msg_body (repeated MsgBodyElement) + 7: ref_msg_id (string) + 8: msg_seq (uint64) + 9: log_ext (LogInfoExt) + """ + buf = b"" + if msg_id: + buf += _encode_field(1, WT_LEN, _encode_string(msg_id)) + buf += _encode_field(2, WT_LEN, _encode_string(group_code)) + if from_account: + buf += _encode_field(3, WT_LEN, _encode_string(from_account)) + if to_account: + buf += _encode_field(4, WT_LEN, _encode_string(to_account)) + if random: + buf += _encode_field(5, WT_LEN, _encode_string(random)) + for el in msg_body: + el_bytes = _encode_msg_body_element(el) + buf += _encode_field(6, WT_LEN, _encode_message(el_bytes)) + if ref_msg_id: + buf += _encode_field(7, WT_LEN, _encode_string(ref_msg_id)) + if msg_seq is not None: + buf += _encode_field(8, WT_VARINT, _encode_varint(msg_seq)) + if trace_id: + log_bytes = _encode_log_ext(trace_id) + buf += _encode_field(9, WT_LEN, _encode_message(log_bytes)) + return buf + + +def encode_send_c2c_message( + to_account: str, + msg_body: list, + from_account: str, + msg_id: str = "", + msg_random: int = 0, + msg_seq: Optional[int] = None, + group_code: str = "", + trace_id: str = "", +) -> bytes: + """ + 编码 C2C 发消息请求,返回完整 ConnMsg bytes(可直接发送到 WebSocket)。 + + Args: + to_account: 收件人账号 + msg_body: 消息体列表,每个元素: {"msg_type": str, "msg_content": dict} + 例如: [{"msg_type": "TIMTextElem", "msg_content": {"text": "hello"}}] + from_account: 发件人账号(机器人账号) + msg_id: 消息唯一 ID(空时使用 req_id) + msg_random: 随机数(防重) + msg_seq: 消息序列号(可选) + group_code: 来自群聊的私聊场景时填写 + trace_id: 链路追踪 ID + + Returns: + ConnMsg bytes + """ + biz_bytes = _encode_send_c2c_req( + to_account=to_account, + from_account=from_account, + msg_body=msg_body, + msg_id=msg_id, + msg_random=msg_random, + msg_seq=msg_seq, + group_code=group_code, + trace_id=trace_id, + ) + _dbg("encode_send_c2c biz payload", biz_bytes) + req_id = msg_id or f"c2c_{next_seq_no()}" + return encode_conn_msg_full( + cmd_type=CMD_TYPE["Request"], + cmd="send_c2c_message", + seq_no=next_seq_no(), + msg_id=req_id, + module=_BIZ_PKG, + data=biz_bytes, + ) + + +def encode_send_group_message( + group_code: str, + msg_body: list, + from_account: str, + msg_id: str = "", + to_account: str = "", + random: str = "", + msg_seq: Optional[int] = None, + ref_msg_id: str = "", + trace_id: str = "", +) -> bytes: + """ + 编码群消息发送请求,返回完整 ConnMsg bytes(可直接发送到 WebSocket)。 + + Args: + group_code: 群号 + msg_body: 消息体列表 + from_account: 发件人账号(机器人账号) + msg_id: 消息唯一 ID + to_account: 指定接收者(一般为空) + random: 去重随机字符串 + msg_seq: 消息序列号 + ref_msg_id: 引用消息 ID + trace_id: 链路追踪 ID + + Returns: + ConnMsg bytes + """ + biz_bytes = _encode_send_group_req( + group_code=group_code, + from_account=from_account, + msg_body=msg_body, + msg_id=msg_id, + to_account=to_account, + random=random, + msg_seq=msg_seq, + ref_msg_id=ref_msg_id, + trace_id=trace_id, + ) + _dbg("encode_send_group biz payload", biz_bytes) + req_id = msg_id or f"grp_{next_seq_no()}" + return encode_conn_msg_full( + cmd_type=CMD_TYPE["Request"], + cmd="send_group_message", + seq_no=next_seq_no(), + msg_id=req_id, + module=_BIZ_PKG, + data=biz_bytes, + ) + + +# ============================================================ +# AuthBind / Ping 帮助函数 +# ============================================================ + +def encode_auth_bind( + biz_id: str, + uid: str, + source: str, + token: str, + msg_id: str, + app_version: str = "", + operation_system: str = "", + bot_version: str = "", + route_env: str = "", +) -> bytes: + """ + 构造 auth-bind 请求 ConnMsg bytes。 + + AuthBindReq fields: + 1: biz_id (string) + 2: auth_info (message AuthInfo: uid=1, source=2, token=3) + 3: device_info (message DeviceInfo: app_version=1, app_operation_system=2, instance_id=10, bot_version=24) + 5: env_name (string) + """ + # AuthInfo + auth_buf = ( + _encode_field(1, WT_LEN, _encode_string(uid)) + + _encode_field(2, WT_LEN, _encode_string(source)) + + _encode_field(3, WT_LEN, _encode_string(token)) + ) + # DeviceInfo + dev_buf = b"" + if app_version: + dev_buf += _encode_field(1, WT_LEN, _encode_string(app_version)) + if operation_system: + dev_buf += _encode_field(2, WT_LEN, _encode_string(operation_system)) + dev_buf += _encode_field(10, WT_LEN, _encode_string(str(HERMES_INSTANCE_ID))) + if bot_version: + dev_buf += _encode_field(24, WT_LEN, _encode_string(bot_version)) + + req_buf = ( + _encode_field(1, WT_LEN, _encode_string(biz_id)) + + _encode_field(2, WT_LEN, _encode_message(auth_buf)) + + _encode_field(3, WT_LEN, _encode_message(dev_buf)) + ) + if route_env: + req_buf += _encode_field(5, WT_LEN, _encode_string(route_env)) + + return encode_conn_msg_full( + cmd_type=CMD_TYPE["Request"], + cmd=CMD["AuthBind"], + seq_no=next_seq_no(), + msg_id=msg_id, + module=MODULE["ConnAccess"], + data=req_buf, + ) + + +def encode_ping(msg_id: str) -> bytes: + """构造 ping 请求 ConnMsg bytes(PingReq 为空消息)""" + return encode_conn_msg_full( + cmd_type=CMD_TYPE["Request"], + cmd=CMD["Ping"], + seq_no=next_seq_no(), + msg_id=msg_id, + module=MODULE["ConnAccess"], + data=b"", + ) + + +def encode_push_ack(original_head: dict) -> bytes: + """构造 push ACK 回包""" + return encode_conn_msg_full( + cmd_type=CMD_TYPE["PushAck"], + cmd=original_head.get("cmd", ""), + seq_no=next_seq_no(), + msg_id=original_head.get("msg_id", ""), + module=original_head.get("module", ""), + data=b"", + ) + + +# ============================================================ +# Heartbeat 编码 +# ============================================================ + +def encode_send_private_heartbeat( + from_account: str, + to_account: str, + heartbeat: int = WS_HEARTBEAT_RUNNING, +) -> bytes: + """ + 编码 SendPrivateHeartbeatReq,返回完整 ConnMsg bytes。 + + SendPrivateHeartbeatReq fields: + 1: from_account (string) + 2: to_account (string) + 3: heartbeat (varint: RUNNING=1, FINISH=2) + """ + buf = ( + _encode_field(1, WT_LEN, _encode_string(from_account)) + + _encode_field(2, WT_LEN, _encode_string(to_account)) + + _encode_field(3, WT_VARINT, _encode_varint(heartbeat)) + ) + req_id = f"hb_priv_{next_seq_no()}" + return encode_biz_msg( + service=_BIZ_PKG, + method="send_private_heartbeat", + req_id=req_id, + body=buf, + ) + + +def encode_send_group_heartbeat( + from_account: str, + group_code: str, + heartbeat: int = WS_HEARTBEAT_RUNNING, + send_time: int = 0, +) -> bytes: + """ + 编码 SendGroupHeartbeatReq,返回完整 ConnMsg bytes。 + + SendGroupHeartbeatReq fields: + 1: from_account (string) + 2: to_account (string) — 群场景留空 + 3: group_code (string) + 4: send_time (int64, ms timestamp) + 5: heartbeat (varint: RUNNING=1, FINISH=2) + """ + import time as _time + ts = send_time or int(_time.time() * 1000) + buf = ( + _encode_field(1, WT_LEN, _encode_string(from_account)) + + _encode_field(2, WT_LEN, _encode_string("")) # to_account empty for group + + _encode_field(3, WT_LEN, _encode_string(group_code)) + + _encode_field(4, WT_VARINT, _encode_varint(ts)) + + _encode_field(5, WT_VARINT, _encode_varint(heartbeat)) + ) + req_id = f"hb_grp_{next_seq_no()}" + return encode_biz_msg( + service=_BIZ_PKG, + method="send_group_heartbeat", + req_id=req_id, + body=buf, + ) + + +# ============================================================ +# 群信息查询 +# ============================================================ + +def encode_query_group_info(group_code: str) -> bytes: + """ + 编码 QueryGroupInfoReq,返回完整 ConnMsg bytes。 + + QueryGroupInfoReq fields: + 1: group_code (string) + """ + buf = _encode_field(1, WT_LEN, _encode_string(group_code)) + req_id = f"qgi_{next_seq_no()}" + return encode_biz_msg( + service=_BIZ_PKG, + method="query_group_info", + req_id=req_id, + body=buf, + ) + + +def decode_query_group_info_rsp(data: bytes) -> Optional[dict]: + """ + 解码 QueryGroupInfoRsp biz payload。 + + Proto 结构(对齐 TS biz-codec / member.ts queryGroupInfo): + + message QueryGroupInfoRsp { + int32 code = 1; + string message = 2; + GroupInfo group_info = 3; // 嵌套 message + } + + message GroupInfo { + string group_name = 1; + string group_owner_user_id = 2; + string group_owner_nickname = 3; + uint32 group_size = 4; + } + + Returns: + 解码后的 dict,或 None(解析失败) + """ + try: + fdict = _fields_to_dict(_parse_fields(data)) + code = _get_varint(fdict, 1, 0) + msg = _get_string(fdict, 2) + + result: dict = {"code": code} + if msg: + result["message"] = msg + + # field 3 = nested GroupInfo message + gi_entries = fdict.get(3, []) + gi_bytes = gi_entries[0][1] if gi_entries else b"" + if gi_bytes and isinstance(gi_bytes, (bytes, bytearray)): + gi = _fields_to_dict(_parse_fields(gi_bytes)) + result["group_name"] = _get_string(gi, 1) or "" + result["owner_id"] = _get_string(gi, 2) or "" + result["owner_nickname"] = _get_string(gi, 3) or "" + result["member_count"] = _get_varint(gi, 4, 0) + else: + result["group_name"] = "" + result["owner_id"] = "" + result["owner_nickname"] = "" + result["member_count"] = 0 + + return result + except Exception: + return None + + +# ============================================================ +# 群成员列表查询 +# ============================================================ + +def encode_get_group_member_list( + group_code: str, + offset: int = 0, + limit: int = 200, +) -> bytes: + """ + 编码 GetGroupMemberListReq,返回完整 ConnMsg bytes。 + + GetGroupMemberListReq fields: + 1: group_code (string) + 2: offset (uint32) + 3: limit (uint32) + """ + buf = _encode_field(1, WT_LEN, _encode_string(group_code)) + if offset: + buf += _encode_field(2, WT_VARINT, _encode_varint(offset)) + buf += _encode_field(3, WT_VARINT, _encode_varint(limit)) + req_id = f"gml_{next_seq_no()}" + return encode_biz_msg( + service=_BIZ_PKG, + method="get_group_member_list", + req_id=req_id, + body=buf, + ) + + +def decode_get_group_member_list_rsp(data: bytes) -> Optional[dict]: + """ + 解码 GetGroupMemberListRsp biz payload。 + + GetGroupMemberListRsp fields: + 1: code (int32) + 2: message (string) + 3: members (repeated message MemberInfo) + 4: next_offset (uint32) + 5: is_complete (bool/varint) + + MemberInfo fields: + 1: user_id (string) + 2: nickname (string) + 3: role (uint32) — 0=member, 1=admin, 2=owner + 4: join_time (uint32) + 5: name_card (string) — 群昵称 + + Returns: + { + "code": int, + "message": str, + "members": [{"user_id": str, "nickname": str, "role": int, ...}, ...], + "next_offset": int, + "is_complete": bool, + } + 或 None(解析失败) + """ + try: + fdict = _fields_to_dict(_parse_fields(data)) + code = _get_varint(fdict, 1, 0) + + members = [] + for member_bytes in _get_repeated_bytes(fdict, 3): + mdict = _fields_to_dict(_parse_fields(member_bytes)) + member = { + "user_id": _get_string(mdict, 1), + "nickname": _get_string(mdict, 2), + "role": _get_varint(mdict, 3), + "join_time": _get_varint(mdict, 4), + "name_card": _get_string(mdict, 5), + } + members.append({k: v for k, v in member.items() if v or k == "role"}) + + return { + "code": code, + "message": _get_string(fdict, 2), + "members": members, + "next_offset": _get_varint(fdict, 4), + "is_complete": bool(_get_varint(fdict, 5)), + } + except Exception: + return None diff --git a/gateway/platforms/yuanbao_sticker.py b/gateway/platforms/yuanbao_sticker.py new file mode 100644 index 0000000000..51f7f31c3e --- /dev/null +++ b/gateway/platforms/yuanbao_sticker.py @@ -0,0 +1,558 @@ +""" +Yuanbao sticker (TIMFaceElem) support. + +Ported from yuanbao-openclaw-plugin/src/sticker/. + +TIMFaceElem wire format: + { + "msg_type": "TIMFaceElem", + "msg_content": { + "index": 0, # always 0 per Yuanbao convention + "data": "", # serialised sticker metadata + } + } + +The `data` field carries a JSON string with the sticker's metadata so the +receiver can look up the correct asset in the emoji pack. +""" + +from __future__ import annotations + +import json +import random +import re +import unicodedata +from typing import Optional + +# --------------------------------------------------------------------------- +# Sticker catalogue – ported from builtin-stickers.json +# Key : canonical name (Chinese) +# Value : {sticker_id, package_id, name, description, width, height, formats} +# --------------------------------------------------------------------------- +STICKER_MAP: dict[str, dict] = { + "六六六": { + "sticker_id": "278", "package_id": "1003", "name": "六六六", + "description": "666 厉害 牛 棒 绝了 好强 awesome", + "width": 128, "height": 128, "formats": "png", + }, + "我想开了": { + "sticker_id": "262", "package_id": "1003", "name": "我想开了", + "description": "想开 佛系 释怀 顿悟 看淡了 无所谓", + "width": 128, "height": 128, "formats": "png", + }, + "害羞": { + "sticker_id": "130", "package_id": "1003", "name": "害羞", + "description": "腼腆 不好意思 脸红 娇羞 羞涩 捂脸", + "width": 128, "height": 128, "formats": "png", + }, + "比心": { + "sticker_id": "252", "package_id": "1003", "name": "比心", + "description": "笔芯 爱你 爱心手势 love heart 喜欢你", + "width": 128, "height": 128, "formats": "png", + }, + "委屈": { + "sticker_id": "125", "package_id": "1003", "name": "委屈", + "description": "难过 想哭 可怜巴巴 瘪嘴 受伤 被欺负", + "width": 128, "height": 128, "formats": "png", + }, + "亲亲": { + "sticker_id": "146", "package_id": "1003", "name": "亲亲", + "description": "么么 mua 亲一下 kiss 飞吻 啵", + "width": 128, "height": 128, "formats": "png", + }, + "酷": { + "sticker_id": "131", "package_id": "1003", "name": "酷", + "description": "帅 墨镜 cool 高冷 有型 swagger", + "width": 128, "height": 128, "formats": "png", + }, + "睡": { + "sticker_id": "145", "package_id": "1003", "name": "睡", + "description": "睡觉 困 zzZ 打盹 躺平 休眠 sleepy", + "width": 128, "height": 128, "formats": "png", + }, + "发呆": { + "sticker_id": "152", "package_id": "1003", "name": "发呆", + "description": "懵 愣住 放空 呆滞 出神 脑子空白", + "width": 128, "height": 128, "formats": "png", + }, + "可怜": { + "sticker_id": "157", "package_id": "1003", "name": "可怜", + "description": "卖萌 求饶 委屈巴巴 弱小 拜托 眼巴巴", + "width": 128, "height": 128, "formats": "png", + }, + "摊手": { + "sticker_id": "200", "package_id": "1003", "name": "摊手", + "description": "无奈 没办法 耸肩 随便 那咋整 whatever", + "width": 128, "height": 128, "formats": "png", + }, + "头大": { + "sticker_id": "213", "package_id": "1003", "name": "头大", + "description": "头疼 烦恼 郁闷 难搞 崩溃 一团乱", + "width": 128, "height": 128, "formats": "png", + }, + "吓": { + "sticker_id": "256", "package_id": "1003", "name": "吓", + "description": "害怕 惊恐 震惊 吓一跳 恐怖 怂", + "width": 128, "height": 128, "formats": "png", + }, + "吐血": { + "sticker_id": "203", "package_id": "1003", "name": "吐血", + "description": "无语 崩溃 被雷 内伤 一口老血 屮", + "width": 128, "height": 128, "formats": "png", + }, + "哼": { + "sticker_id": "185", "package_id": "1003", "name": "哼", + "description": "傲娇 生气 不满 撇嘴 不理 赌气", + "width": 128, "height": 128, "formats": "png", + }, + "嘿嘿": { + "sticker_id": "220", "package_id": "1003", "name": "嘿嘿", + "description": "坏笑 猥琐笑 偷笑 憨笑 得意 你懂的", + "width": 128, "height": 128, "formats": "png", + }, + "头秃": { + "sticker_id": "218", "package_id": "1003", "name": "头秃", + "description": "程序员 加班 焦虑 没头发 秃了 肝爆", + "width": 128, "height": 128, "formats": "png", + }, + "暗中观察": { + "sticker_id": "221", "package_id": "1003", "name": "暗中观察", + "description": "窥屏 潜水 偷偷看 角落 围观 屏住呼吸", + "width": 128, "height": 128, "formats": "png", + }, + "我酸了": { + "sticker_id": "224", "package_id": "1003", "name": "我酸了", + "description": "嫉妒 柠檬精 羡慕 吃柠檬 眼红 恰柠檬", + "width": 128, "height": 128, "formats": "png", + }, + "打call": { + "sticker_id": "246", "package_id": "1003", "name": "打call", + "description": "应援 加油 支持 喝彩 助威 call", + "width": 128, "height": 128, "formats": "png", + }, + "庆祝": { + "sticker_id": "251", "package_id": "1003", "name": "庆祝", + "description": "祝贺 开心 耶 party 胜利 干杯", + "width": 128, "height": 128, "formats": "png", + }, + "奋斗": { + "sticker_id": "151", "package_id": "1003", "name": "奋斗", + "description": "努力 加油 拼搏 冲 干劲 卷起来", + "width": 128, "height": 128, "formats": "png", + }, + "惊讶": { + "sticker_id": "143", "package_id": "1003", "name": "惊讶", + "description": "震惊 哇 不敢相信 OMG 居然 这么离谱", + "width": 128, "height": 128, "formats": "png", + }, + "疑问": { + "sticker_id": "144", "package_id": "1003", "name": "疑问", + "description": "问号 不懂 啥 为什么 啥情况 懵逼问", + "width": 128, "height": 128, "formats": "png", + }, + "仔细分析": { + "sticker_id": "248", "package_id": "1003", "name": "仔细分析", + "description": "思考 推敲 认真 研究 琢磨 让我想想", + "width": 128, "height": 128, "formats": "png", + }, + "撅嘴": { + "sticker_id": "184", "package_id": "1003", "name": "撅嘴", + "description": "嘟嘴 卖萌 不高兴 撒娇 嘴翘", + "width": 128, "height": 128, "formats": "png", + }, + "泪奔": { + "sticker_id": "199", "package_id": "1003", "name": "泪奔", + "description": "大哭 伤心 破防 感动哭 泪流满面 呜呜", + "width": 128, "height": 128, "formats": "png", + }, + "尊嘟假嘟": { + "sticker_id": "276", "package_id": "1003", "name": "尊嘟假嘟", + "description": "真的假的 真假 可爱问 你骗我 是不是", + "width": 128, "height": 128, "formats": "png", + }, + "略略略": { + "sticker_id": "113", "package_id": "1003", "name": "略略略", + "description": "调皮 吐舌 不服 略 气死你 鬼脸", + "width": 128, "height": 128, "formats": "png", + }, + "困": { + "sticker_id": "180", "package_id": "1003", "name": "困", + "description": "想睡 倦 打哈欠 睁不开眼 好困啊 sleepy", + "width": 128, "height": 128, "formats": "png", + }, + "折磨": { + "sticker_id": "181", "package_id": "1003", "name": "折磨", + "description": "难受 痛苦 煎熬 蚌埠住了 受不了 要命", + "width": 128, "height": 128, "formats": "png", + }, + "抠鼻": { + "sticker_id": "182", "package_id": "1003", "name": "抠鼻", + "description": "不屑 无聊 淡定 无所谓 鄙视 挖鼻", + "width": 128, "height": 128, "formats": "png", + }, + "鼓掌": { + "sticker_id": "183", "package_id": "1003", "name": "鼓掌", + "description": "拍手 叫好 赞同 666 喝彩 掌声", + "width": 128, "height": 128, "formats": "png", + }, + "斜眼笑": { + "sticker_id": "204", "package_id": "1003", "name": "斜眼笑", + "description": "滑稽 坏笑 doge 意味深长 阴阳怪气 嘿嘿嘿", + "width": 128, "height": 128, "formats": "png", + }, + "辣眼睛": { + "sticker_id": "216", "package_id": "1003", "name": "辣眼睛", + "description": "看不下去 cringe 毁三观 太丑了 瞎了", + "width": 128, "height": 128, "formats": "png", + }, + "哦哟": { + "sticker_id": "217", "package_id": "1003", "name": "哦哟", + "description": "惊讶 起哄 哇哦 有戏 不简单 哟", + "width": 128, "height": 128, "formats": "png", + }, + "吃瓜": { + "sticker_id": "222", "package_id": "1003", "name": "吃瓜", + "description": "围观 看戏 八卦 路人 看热闹 板凳", + "width": 128, "height": 128, "formats": "png", + }, + "狗头": { + "sticker_id": "225", "package_id": "1003", "name": "狗头", + "description": "doge 保命 开玩笑 滑稽 反讽 懂的都懂", + "width": 128, "height": 128, "formats": "png", + }, + "敬礼": { + "sticker_id": "227", "package_id": "1003", "name": "敬礼", + "description": "salute 尊重 收到 遵命 致敬 报告", + "width": 128, "height": 128, "formats": "png", + }, + "哦": { + "sticker_id": "231", "package_id": "1003", "name": "哦", + "description": "知道了 明白 敷衍 嗯 这样啊 收到", + "width": 128, "height": 128, "formats": "png", + }, + "拿到红包": { + "sticker_id": "236", "package_id": "1003", "name": "拿到红包", + "description": "红包 谢谢老板 发财 开心 抢到了 欧气", + "width": 128, "height": 128, "formats": "png", + }, + "牛吖": { + "sticker_id": "239", "package_id": "1003", "name": "牛吖", + "description": "牛 厉害 强 666 佩服 大佬", + "width": 128, "height": 128, "formats": "png", + }, + "贴贴": { + "sticker_id": "272", "package_id": "1003", "name": "贴贴", + "description": "抱抱 亲昵 蹭蹭 亲密 靠靠 撒娇贴", + "width": 128, "height": 128, "formats": "png", + }, + "爱心": { + "sticker_id": "138", "package_id": "1003", "name": "爱心", + "description": "心 love 喜欢你 红心 示爱 么么哒", + "width": 128, "height": 128, "formats": "png", + }, + "晚安": { + "sticker_id": "170", "package_id": "1003", "name": "晚安", + "description": "好梦 睡了 night 早点休息 安啦 moon", + "width": 128, "height": 128, "formats": "png", + }, + "太阳": { + "sticker_id": "176", "package_id": "1003", "name": "太阳", + "description": "晴天 早上好 阳光 morning 好天气 日", + "width": 128, "height": 128, "formats": "png", + }, + "柠檬": { + "sticker_id": "266", "package_id": "1003", "name": "柠檬", + "description": "酸 嫉妒 柠檬精 羡慕 我酸 恰柠檬", + "width": 128, "height": 128, "formats": "png", + }, + "大冤种": { + "sticker_id": "267", "package_id": "1003", "name": "大冤种", + "description": "倒霉 吃亏 自嘲 好心没好报 背锅 工具人", + "width": 128, "height": 128, "formats": "png", + }, + "吐了": { + "sticker_id": "132", "package_id": "1003", "name": "吐了", + "description": "恶心 yue 受不了 嫌弃 想吐 生理不适", + "width": 128, "height": 128, "formats": "png", + }, + "怒": { + "sticker_id": "134", "package_id": "1003", "name": "怒", + "description": "生气 愤怒 火大 暴躁 气炸 怼", + "width": 128, "height": 128, "formats": "png", + }, + "玫瑰": { + "sticker_id": "165", "package_id": "1003", "name": "玫瑰", + "description": "花 示爱 表白 浪漫 送你花 情人节", + "width": 128, "height": 128, "formats": "png", + }, + "凋谢": { + "sticker_id": "119", "package_id": "1003", "name": "凋谢", + "description": "花谢 失恋 难过 枯萎 心碎 凉了", + "width": 128, "height": 128, "formats": "png", + }, + "点赞": { + "sticker_id": "159", "package_id": "1003", "name": "点赞", + "description": "赞 认同 好棒 good like 大拇指 顶", + "width": 128, "height": 128, "formats": "png", + }, + "握手": { + "sticker_id": "164", "package_id": "1003", "name": "握手", + "description": "合作 你好 商务 hello deal 成交 友好", + "width": 128, "height": 128, "formats": "png", + }, + "抱拳": { + "sticker_id": "163", "package_id": "1003", "name": "抱拳", + "description": "谢谢 失敬 江湖 承让 拜托 有礼", + "width": 128, "height": 128, "formats": "png", + }, + "ok": { + "sticker_id": "169", "package_id": "1003", "name": "ok", + "description": "好的 收到 没问题 okay 行 可以 懂了", + "width": 128, "height": 128, "formats": "png", + }, + "拳头": { + "sticker_id": "174", "package_id": "1003", "name": "拳头", + "description": "加油 干 冲 fight 力量 击拳 硬气", + "width": 128, "height": 128, "formats": "png", + }, + "鞭炮": { + "sticker_id": "191", "package_id": "1003", "name": "鞭炮", + "description": "过年 喜庆 爆竹 春节 噼里啪啦 红", + "width": 128, "height": 128, "formats": "png", + }, + "烟花": { + "sticker_id": "258", "package_id": "1003", "name": "烟花", + "description": "庆典 漂亮 新年 嘭 绽放 节日快乐", + "width": 128, "height": 128, "formats": "png", + }, +} + + +def get_sticker_by_name(name: str) -> Optional[dict]: + """ + 按名称查找贴纸,支持模糊匹配。 + + 匹配优先级: + 1. 完全相等(name) + 2. name 包含查询词(前缀/子串) + 3. description 包含查询词(同义词搜索) + 4. 通用模糊评分(与 sticker-search 同算法),命中即返回得分最高的一条 + + 返回 sticker dict,找不到返回 None。 + """ + if not name: + return None + + query = name.strip() + + if query in STICKER_MAP: + return STICKER_MAP[query] + + for key, sticker in STICKER_MAP.items(): + if query in key or key in query: + return sticker + + for sticker in STICKER_MAP.values(): + desc = sticker.get("description", "") + if query in desc: + return sticker + + matches = search_stickers(query, limit=1) + return matches[0] if matches else None + + +def get_random_sticker(category: str = None) -> dict: + """ + 随机返回一个贴纸。 + + 若指定 category,则在 description 中含有该关键词的贴纸里随机选取; + category 为 None 时从全表随机。 + """ + if category: + candidates = [ + s for s in STICKER_MAP.values() + if category in s.get("description", "") or category in s.get("name", "") + ] + if candidates: + return random.choice(candidates) + return random.choice(list(STICKER_MAP.values())) + + +def get_sticker_by_id(sticker_id: str) -> Optional[dict]: + """按 sticker_id 精确查找贴纸。""" + if not sticker_id: + return None + sid = str(sticker_id).strip() + for sticker in STICKER_MAP.values(): + if sticker.get("sticker_id") == sid: + return sticker + return None + + +# --------------------------------------------------------------------------- +# 模糊搜索(对齐 chatbot-web yuanbao-openclaw-plugin/sticker-cache.ts.searchStickers) +# --------------------------------------------------------------------------- + +_PUNCT_RE = re.compile(r"[\s\u3000\-_·.,,。!!??\"“”'‘’、/\\]+") + + +def _normalize_text(raw: str) -> str: + return unicodedata.normalize("NFKC", str(raw or "")).strip().lower() + + +def _compact_text(raw: str) -> str: + return _PUNCT_RE.sub("", _normalize_text(raw)) + + +def _multiset_char_hit_ratio(needle: str, haystack: str) -> float: + if not needle: + return 0.0 + bag: dict[str, int] = {} + for ch in haystack: + bag[ch] = bag.get(ch, 0) + 1 + hits = 0 + for ch in needle: + n = bag.get(ch, 0) + if n > 0: + hits += 1 + bag[ch] = n - 1 + return hits / len(needle) + + +def _bigram_jaccard(a: str, b: str) -> float: + if len(a) < 2 or len(b) < 2: + return 0.0 + A = {a[i:i + 2] for i in range(len(a) - 1)} + B = {b[i:i + 2] for i in range(len(b) - 1)} + inter = len(A & B) + union = len(A) + len(B) - inter + return inter / union if union else 0.0 + + +def _longest_subsequence_ratio(needle: str, haystack: str) -> float: + if not needle: + return 0.0 + j = 0 + for ch in haystack: + if j >= len(needle): + break + if ch == needle[j]: + j += 1 + return j / len(needle) + + +def _score_field(haystack: str, query: str) -> float: + hay = _normalize_text(haystack) + q = _normalize_text(query) + if not hay or not q: + return 0.0 + hay_c = _compact_text(haystack) + q_c = _compact_text(query) + best = 0.0 + if hay == q: + best = max(best, 100.0) + if q in hay: + best = max(best, 92 + min(6, len(q))) + if len(q) >= 2 and hay.startswith(q): + best = max(best, 88.0) + if q_c and q_c in hay_c: + best = max(best, 86.0) + best = max(best, _multiset_char_hit_ratio(q_c, hay_c) * 62) + best = max(best, _bigram_jaccard(q_c, hay_c) * 58) + best = max(best, _longest_subsequence_ratio(q_c, hay_c) * 52) + if len(q) == 1 and q in hay: + best = max(best, 68.0) + return best + + +def search_stickers(query: str, limit: int = 10) -> list[dict]: + """ + 在内置贴纸表中按模糊匹配排序返回前 N 条结果。 + + 评分综合 name/description 字段的子串、字符多重集覆盖、bigram Jaccard、子序列比例。 + name 权重略高于 description(×0.88)。空 query 时按字典顺序返回前 N 条。 + """ + safe_limit = max(1, min(500, int(limit) if limit else 10)) + if not query or not _normalize_text(query): + return list(STICKER_MAP.values())[:safe_limit] + + scored: list[tuple[float, dict]] = [] + for sticker in STICKER_MAP.values(): + name_s = _score_field(sticker.get("name", ""), query) + desc_s = _score_field(sticker.get("description", ""), query) * 0.88 + sid = str(sticker.get("sticker_id", "")).strip() + q_norm = _normalize_text(query) + id_s = 0.0 + if sid and q_norm: + sid_norm = _normalize_text(sid) + if sid_norm == q_norm: + id_s = 100.0 + elif q_norm in sid_norm: + id_s = 84.0 + scored.append((max(name_s, desc_s, id_s), sticker)) + + scored.sort(key=lambda x: x[0], reverse=True) + top = scored[0][0] if scored else 0 + if top <= 0: + return [s for _, s in scored[:safe_limit]] + + if top >= 22: + floor = 18.0 + elif top >= 12: + floor = max(10.0, top * 0.5) + else: + floor = max(6.0, top * 0.35) + + filtered = [pair for pair in scored if pair[0] >= floor] + out = filtered if filtered else scored + return [s for _, s in out[:safe_limit]] + + +def build_face_msg_body( + face_index: int, + face_type: int = 1, + data: Optional[str] = None, +) -> list: + """ + 构造 TIMFaceElem 消息体。 + + Yuanbao 约定: + - index 固定传 0(服务端通过 data 字段识别具体表情) + - data 为 JSON 字符串,包含 sticker_id / package_id 等字段 + + Args: + face_index: 保留字段,暂时不影响 wire format(Yuanbao 固定 index=0)。 + 当 face_index > 0 时视为旧版 QQ 表情 ID,直接放入 index。 + face_type: 保留字段(兼容旧接口,当前未使用)。 + data: 已序列化的 JSON 字符串;为 None 时仅传 index。 + + Returns: + 符合 Yuanbao TIM 协议的 msg_body list,如:: + + [{"msg_type": "TIMFaceElem", "msg_content": {"index": 0, "data": "..."}}] + """ + msg_content: dict = {"index": face_index} + if data is not None: + msg_content["data"] = data + return [{"msg_type": "TIMFaceElem", "msg_content": msg_content}] + + +def build_sticker_msg_body(sticker: dict) -> list: + """ + 从 STICKER_MAP 中的 sticker dict 直接构造 TIMFaceElem 消息体。 + + 这是 send_sticker() 的内部辅助,确保 data 字段与原始 JS 插件一致。 + """ + data_payload = json.dumps( + { + "sticker_id": sticker["sticker_id"], + "package_id": sticker["package_id"], + "width": sticker.get("width", 128), + "height": sticker.get("height", 128), + "formats": sticker.get("formats", "png"), + "name": sticker["name"], + }, + ensure_ascii=False, + separators=(",", ":"), + ) + return build_face_msg_body(face_index=0, data=data_payload) diff --git a/gateway/run.py b/gateway/run.py index 42a6b82f98..00f15db3b6 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -2123,6 +2123,7 @@ class GatewayRunner: "WEIXIN_ALLOWED_USERS", "BLUEBUBBLES_ALLOWED_USERS", "QQ_ALLOWED_USERS", + "YUANBAO_ALLOWED_USERS", "GATEWAY_ALLOWED_USERS") ) _allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any( @@ -2137,7 +2138,8 @@ class GatewayRunner: "WECOM_CALLBACK_ALLOW_ALL_USERS", "WEIXIN_ALLOW_ALL_USERS", "BLUEBUBBLES_ALLOW_ALL_USERS", - "QQ_ALLOW_ALL_USERS") + "QQ_ALLOW_ALL_USERS", + "YUANBAO_ALLOW_ALL_USERS") ) if not _any_allowlist and not _allow_all: logger.warning( @@ -3114,8 +3116,14 @@ class GatewayRunner: return None return QQAdapter(config) - return None + elif platform == Platform.YUANBAO: + from gateway.platforms.yuanbao import YuanbaoAdapter, WEBSOCKETS_AVAILABLE + if not WEBSOCKETS_AVAILABLE: + logger.warning("Yuanbao: websockets not installed. Run: pip install websockets") + return None + return YuanbaoAdapter(config) + return None def _is_user_authorized(self, source: SessionSource) -> bool: """ Check if a user is authorized to use the bot. @@ -3156,6 +3164,7 @@ class GatewayRunner: Platform.WEIXIN: "WEIXIN_ALLOWED_USERS", Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOWED_USERS", Platform.QQBOT: "QQ_ALLOWED_USERS", + Platform.YUANBAO: "YUANBAO_ALLOWED_USERS", } platform_group_env_map = { Platform.TELEGRAM: "TELEGRAM_GROUP_ALLOWED_USERS", @@ -3178,6 +3187,7 @@ class GatewayRunner: Platform.WEIXIN: "WEIXIN_ALLOW_ALL_USERS", Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOW_ALL_USERS", Platform.QQBOT: "QQ_ALLOW_ALL_USERS", + Platform.YUANBAO: "YUANBAO_ALLOW_ALL_USERS", } # Per-platform allow-all flag (e.g., DISCORD_ALLOW_ALL_USERS=true) diff --git a/gateway/session.py b/gateway/session.py index d693945d98..02d4eb3ed0 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -354,6 +354,14 @@ def build_session_context_prompt( "If the user needs a detailed answer, give the short version first " "and offer to elaborate." ) + elif context.source.platform == Platform.YUANBAO: + lines.append("") + lines.append( + "**Platform notes:** You are running inside Yuanbao. " + "You CAN send private (DM) messages via the send_message tool. " + "Use target='yuanbao:direct:' for DM " + "and target='yuanbao:group:' for group chat." + ) # Connected platforms platforms_list = ["local (files on this machine)"] diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index 3b828fecf5..aede480bfe 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -2724,6 +2724,24 @@ _PLATFORMS = [ "help": "OpenID to deliver cron results and notifications to."}, ], }, + { + "key": "yuanbao", + "label": "Yuanbao", + "emoji": "💎", + "token_var": "YUANBAO_APP_ID", + "setup_instructions": [ + "1. Download the Yuanbao app from https://yuanbao.tencent.com/", + "2. In the app, go to PAI → My Bot and create a new bot", + "3. After the bot is created, copy the App ID and App Secret", + "4. Enter them below and Hermes will connect automatically over WebSocket", + ], + "vars": [ + {"name": "YUANBAO_APP_ID", "prompt": "App ID", "password": False, + "help": "The App ID from your Yuanbao IM Bot credentials."}, + {"name": "YUANBAO_APP_SECRET", "prompt": "App Secret", "password": True, + "help": "The App Secret (used for HMAC signing) from your Yuanbao IM Bot."}, + ], + }, ] @@ -3108,6 +3126,12 @@ def _setup_wecom(): print_success("💬 WeCom configured!") +def _setup_yuanbao(): + """Configure Yuanbao via the standard platform setup.""" + yuanbao_platform = next(p for p in _PLATFORMS if p["key"] == "yuanbao") + _setup_standard_platform(yuanbao_platform) + + def _is_service_installed() -> bool: """Check if the gateway is installed as a system service.""" if supports_systemd_services(): diff --git a/hermes_cli/platforms.py b/hermes_cli/platforms.py index 05507eaced..bc609277c4 100644 --- a/hermes_cli/platforms.py +++ b/hermes_cli/platforms.py @@ -36,6 +36,7 @@ PLATFORMS: OrderedDict[str, PlatformInfo] = OrderedDict([ ("wecom_callback", PlatformInfo(label="💬 WeCom Callback", default_toolset="hermes-wecom-callback")), ("weixin", PlatformInfo(label="💬 Weixin", default_toolset="hermes-weixin")), ("qqbot", PlatformInfo(label="💬 QQBot", default_toolset="hermes-qqbot")), + ("yuanbao", PlatformInfo(label="🤖 Yuanbao", default_toolset="hermes-yuanbao")), ("webhook", PlatformInfo(label="🔗 Webhook", default_toolset="hermes-webhook")), ("api_server", PlatformInfo(label="🌐 API Server", default_toolset="hermes-api-server")), ("cron", PlatformInfo(label="⏰ Cron", default_toolset="hermes-cron")), diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index 2c4d28e027..92d7c37cf6 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -2133,6 +2133,12 @@ def _setup_feishu(): _gateway_setup_feishu() +def _setup_yuanbao(): + """Configure Yuanbao via gateway setup.""" + from hermes_cli.gateway import _setup_yuanbao as _gateway_setup_yuanbao + _gateway_setup_yuanbao() + + def _setup_wecom(): """Configure WeCom (Enterprise WeChat) via gateway setup.""" from hermes_cli.gateway import _setup_wecom as _gateway_setup_wecom @@ -2277,6 +2283,7 @@ _GATEWAY_PLATFORMS = [ ("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp), ("DingTalk", "DINGTALK_CLIENT_ID", _setup_dingtalk), ("Feishu / Lark", "FEISHU_APP_ID", _setup_feishu), + ("Yuanbao", "YUANBAO_APP_ID", _setup_yuanbao), ("WeCom (Enterprise WeChat)", "WECOM_BOT_ID", _setup_wecom), ("WeCom Callback (Self-Built App)", "WECOM_CALLBACK_CORP_ID", _setup_wecom_callback), ("Weixin (WeChat)", "WEIXIN_ACCOUNT_ID", _setup_weixin), diff --git a/hermes_cli/status.py b/hermes_cli/status.py index d07e1a8222..0285752681 100644 --- a/hermes_cli/status.py +++ b/hermes_cli/status.py @@ -326,7 +326,8 @@ def show_status(args): "WeCom Callback": ("WECOM_CALLBACK_CORP_ID", None), "Weixin": ("WEIXIN_ACCOUNT_ID", "WEIXIN_HOME_CHANNEL"), "BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"), - "QQBot": ("QQ_APP_ID", "QQBOT_HOME_CHANNEL"), + "QQBot": ("QQ_APP_ID", "QQ_HOME_CHANNEL"), + "Yuanbao": ("YUANBAO_APP_ID", "YUANBAO_HOME_CHANNEL"), } for name, (token_var, home_var) in platforms.items(): diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index f2d1aab584..e70760da81 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -71,6 +71,7 @@ CONFIGURABLE_TOOLSETS = [ ("spotify", "🎵 Spotify", "playback, search, playlists, library"), ("discord", "💬 Discord (read/participate)", "fetch messages, search members, create thread"), ("discord_admin", "🛡️ Discord Server Admin", "list channels/roles, pin, assign roles"), + ("yuanbao", "🤖 Yuanbao", "group info, member queries, DM"), ] # Toolsets that are OFF by default for new installs. diff --git a/scripts/release.py b/scripts/release.py index e9fd4f72de..9eff98e2dc 100755 --- a/scripts/release.py +++ b/scripts/release.py @@ -396,6 +396,17 @@ AUTHOR_MAP = { "zzn+pa@zzn.im": "xinbenlv", "zaynjarvis@gmail.com": "ZaynJarvis", "zhiheng.liu@bytedance.com": "ZaynJarvis", + "izhaolongfei@gmail.com": "loongfay", + "296659110@qq.com": "lrt4836", + "fe.daniel91@gmail.com": "beforeload", + "libo1106@foxmail.com": "libo1106", + "295367131@qq.com": "295367131", + "295367132@qq.com": "IxAres", + "danieldliu@tencent.com": "danieldliu", + "loongzhao@tencent.com": "loongzhao", + "Bartok9@users.noreply.github.com": "Bartok9", + "LeonSGP43@users.noreply.github.com": "LeonSGP43", + "kshitijk4poor@users.noreply.github.com": "kshitijk4poor", "mbelleau@Michels-MacBook-Pro.local": "malaiwah", "michel.belleau@malaiwah.com": "malaiwah", "gnanasekaran.sekareee@gmail.com": "gnanam1990", diff --git a/skills/yuanbao/SKILL.md b/skills/yuanbao/SKILL.md new file mode 100644 index 0000000000..3b0fd25570 --- /dev/null +++ b/skills/yuanbao/SKILL.md @@ -0,0 +1,107 @@ +--- +name: yuanbao +description: Yuanbao (元宝) group interaction — @mention users, query group info and members +version: 1.0.0 +metadata: + hermes: + tags: [yuanbao, mention, at, group, members, 元宝, 派, 艾特] + related_skills: [] +--- + +# Yuanbao Group Interaction + +## CRITICAL: How Messaging Works + +**Your text reply IS the message sent to the group/user.** The gateway automatically delivers your response text to the chat. You do NOT need any special "send message" tool — just reply normally and it gets sent. + +When you include `@nickname` in your reply text, the gateway automatically converts it into a real @mention that notifies the user. This is built-in — you have full @mention capability. + +**NEVER say you cannot send messages or @mention users. NEVER suggest the user do it manually. NEVER add disclaimers about permissions. Just reply with the text you want sent.** + +## Available Tools + +| Tool | When to use | +|------|------------| +| `yb_query_group_info` | Query group name, owner, member count | +| `yb_query_group_members` | Find a user, list bots, list all members, or get nickname for @mention | +| `yb_send_dm` | Send a private/direct message (DM / 私信) to a user, with optional media files | + +## @Mention Workflow + +When you need to @mention / 艾特 someone: + +1. Call `yb_query_group_members` with `action="find"`, `name=""`, `mention=true` +2. Get the exact nickname from the response +3. Include `@nickname` in your reply text — the gateway handles the rest + +Example: user says "帮我艾特元宝" + +Step 1 — tool call: +```json +{ "group_code": "328306697", "action": "find", "name": "元宝", "mention": true } +``` + +Step 2 — your reply (this gets sent to the group with a working @mention): +``` +@元宝 你好,有人找你! +``` + +**That's it.** No extra explanation needed. Keep it short and natural. + +**Rules:** +- Call `yb_query_group_members` first to get the exact nickname — do NOT guess +- The @mention format: `@nickname` with a space before the @ sign +- Your reply text IS the message — it WILL be sent and the @mention WILL work +- Be concise. Do NOT explain how @mention works to the user. + +## Send DM (Private Message) Workflow + +When someone asks to send a private message / 私信 / DM to a user: + +1. Call `yb_send_dm` with `group_code`, `name` (target user's name), and `message` +2. The tool automatically finds the user and sends the DM +3. Report the result to the user + +Example: user says "给 @用户aea3 私信发一个 hello" + +```json +yb_send_dm({ "group_code": "535168412", "name": "用户aea3", "message": "hello" }) +``` + +Example with media: user says "给 @用户aea3 私信发一张图片" + +```json +yb_send_dm({ + "group_code": "535168412", + "name": "用户aea3", + "message": "Here is the image", + "media_files": [{"path": "/tmp/photo.jpg"}] +}) +``` + +**Rules:** +- Extract `group_code` from the current chat_id (e.g. `group:535168412` → `535168412`) +- If you already know the user_id, pass it directly via the `user_id` parameter to skip lookup +- If multiple users match the name, the tool returns candidates — ask the user to clarify +- Do NOT use `send_message` tool for Yuanbao DMs — use `yb_send_dm` instead +- Supports media: images (.jpg/.png/.gif/.webp/.bmp) sent as image messages, other files as documents + +## Query Group Info + +```json +yb_query_group_info({ "group_code": "328306697" }) +``` + +## Query Members + +| Action | Description | +|--------|-------------| +| `find` | Search by name (partial match, case-insensitive) | +| `list_bots` | List bots and Yuanbao AI assistants | +| `list_all` | List all members | + +## Notes + +- `group_code` comes from chat_id: `group:328306697` → `328306697` +- Groups are called "派 (Pai)" in the Yuanbao app +- Member roles: `user`, `yuanbao_ai`, `bot` diff --git a/tests/test_yuanbao_integration.py b/tests/test_yuanbao_integration.py new file mode 100644 index 0000000000..48579c0f88 --- /dev/null +++ b/tests/test_yuanbao_integration.py @@ -0,0 +1,416 @@ +""" +test_yuanbao_integration.py - Yuanbao 模块集成测试 + +验证各模块能正确组装和交互: + - YuanbaoAdapter 初始化 + - Config / Platform 枚举 + - get_connected_platforms 逻辑 + - Proto 编解码 round-trip + - Markdown 分块 + - API / Media 模块 import + - Toolset 注册 +""" + +import sys +import os + +# 确保 hermes-agent 根目录在 sys.path 中 +_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from gateway.config import Platform, PlatformConfig, GatewayConfig +from gateway.platforms.yuanbao import YuanbaoAdapter + + +def make_config(**kwargs): + extra = kwargs.pop("extra", {}) + extra.setdefault("app_id", "test_key") + extra.setdefault("app_secret", "test_secret") + extra.setdefault("ws_url", "wss://test.example.com/ws") + extra.setdefault("api_domain", "https://test.example.com") + return PlatformConfig( + extra=extra, + **kwargs, + ) + + +# =========================================================== +# 1. Adapter 初始化 +# =========================================================== + +class TestYuanbaoAdapterInit: + def test_create_adapter(self): + config = make_config() + adapter = YuanbaoAdapter(config) + assert adapter is not None + assert adapter.PLATFORM == Platform.YUANBAO + + def test_initial_state(self): + config = make_config() + adapter = YuanbaoAdapter(config) + status = adapter.get_status() + assert status["connected"] == False + assert status["bot_id"] is None + + +# =========================================================== +# 2. Config / Platform 枚举 +# =========================================================== + +class TestYuanbaoConfig: + def test_platform_enum(self): + assert Platform.YUANBAO.value == "yuanbao" + + def test_config_fields(self): + config = make_config() + assert config.extra["app_id"] == "test_key" + assert config.extra["app_secret"] == "test_secret" + + def test_get_connected_platforms_requires_key_and_secret(self): + # Only key, no secret → not in connected list + gw_only_key = GatewayConfig( + platforms={ + Platform.YUANBAO: PlatformConfig( + enabled=True, + extra={"app_id": "key"}, + ) + } + ) + platforms = gw_only_key.get_connected_platforms() + assert Platform.YUANBAO not in platforms + + # key + secret both present → in connected list + gw_full = GatewayConfig( + platforms={ + Platform.YUANBAO: PlatformConfig( + enabled=True, + extra={"app_id": "key", "app_secret": "secret"}, + ) + } + ) + platforms2 = gw_full.get_connected_platforms() + assert Platform.YUANBAO in platforms2 + + +# =========================================================== +# 3. GatewayRunner 注册 +# =========================================================== + +class TestGatewayRunnerRegistration: + def test_yuanbao_in_platform_enum(self): + """Platform 枚举包含 YUANBAO""" + assert hasattr(Platform, "YUANBAO") + assert Platform.YUANBAO.value == "yuanbao" + + def _make_minimal_runner(self, config): + """通过 __new__ + 最小初始化绕过 run.py 的模块级 dotenv/ssl 副作用""" + import sys + from unittest.mock import MagicMock + + # Stub out heavy dependencies if not already present + stubs = [ + "dotenv", + "hermes_cli.env_loader", + "hermes_cli.config", + "hermes_constants", + ] + _orig = {} + for mod in stubs: + if mod not in sys.modules: + _orig[mod] = None + sys.modules[mod] = MagicMock() + + try: + from gateway.run import GatewayRunner + finally: + # Restore only the ones we injected + for mod, orig in _orig.items(): + if orig is None: + sys.modules.pop(mod, None) + + runner = GatewayRunner.__new__(GatewayRunner) + runner.config = config + runner.adapters = {} + runner._failed_platforms = {} + runner._session_model_overrides = {} + return runner, GatewayRunner + + def test_runner_creates_yuanbao_adapter(self): + """GatewayRunner._create_adapter 能为 YUANBAO 返回 YuanbaoAdapter 实例""" + from gateway.config import GatewayConfig + from unittest.mock import patch + config = make_config(enabled=True) + gw_config = GatewayConfig(platforms={Platform.YUANBAO: config}) + + try: + runner, _ = self._make_minimal_runner(gw_config) + # websockets 在测试环境可能未安装,mock 掉 WEBSOCKETS_AVAILABLE + with patch("gateway.platforms.yuanbao.WEBSOCKETS_AVAILABLE", True): + adapter = runner._create_adapter(Platform.YUANBAO, config) + except ImportError as e: + pytest.skip(f"run.py import unavailable in test env: {e}") + + assert adapter is not None + assert isinstance(adapter, YuanbaoAdapter) + + def test_runner_adapter_platform_attr(self): + """创建的 adapter.PLATFORM 为 Platform.YUANBAO""" + from gateway.config import GatewayConfig + from unittest.mock import patch + config = make_config(enabled=True) + gw_config = GatewayConfig(platforms={Platform.YUANBAO: config}) + + try: + runner, _ = self._make_minimal_runner(gw_config) + with patch("gateway.platforms.yuanbao.WEBSOCKETS_AVAILABLE", True): + adapter = runner._create_adapter(Platform.YUANBAO, config) + except ImportError as e: + pytest.skip(f"run.py import unavailable in test env: {e}") + + assert adapter is not None + assert adapter.PLATFORM == Platform.YUANBAO + + +# =========================================================== +# 4. Proto round-trip +# =========================================================== + +class TestProtoRoundTrip: + """验证 proto 编解码基本功能""" + + def test_conn_msg_roundtrip(self): + from gateway.platforms.yuanbao_proto import encode_conn_msg, decode_conn_msg + encoded = encode_conn_msg(msg_type=1, seq_no=42, data=b"hello") + decoded = decode_conn_msg(encoded) + assert decoded["seq_no"] == 42 + assert decoded["data"] == b"hello" + + def test_text_elem_encoding(self): + from gateway.platforms.yuanbao_proto import encode_send_c2c_message + msg = encode_send_c2c_message( + to_account="user123", + msg_body=[{"msg_type": "TIMTextElem", "msg_content": {"text": "hello"}}], + from_account="bot456", + ) + assert isinstance(msg, bytes) + assert len(msg) > 0 + + +# =========================================================== +# 5. Markdown 分块 +# =========================================================== + +class TestMarkdownChunking: + def test_chunks_are_sent_separately(self): + from gateway.platforms.yuanbao import MarkdownProcessor + long_text = "paragraph\n\n" * 100 + chunks = MarkdownProcessor.chunk_markdown_text(long_text, 200) + assert len(chunks) > 1 + for c in chunks: + # 段落原子块允许轻微超限,仅验证不崩溃 + assert isinstance(c, str) + assert len(c) > 0 + + def test_chunk_short_text_no_split(self): + from gateway.platforms.yuanbao import MarkdownProcessor + text = "hello world" + chunks = MarkdownProcessor.chunk_markdown_text(text, 3000) + assert chunks == [text] + + +# =========================================================== +# 6. Sign Token 模块 +# =========================================================== + +class TestSignToken: + def test_import_ok(self): + from gateway.platforms.yuanbao import SignManager + assert callable(SignManager.get_token) + assert callable(SignManager.force_refresh) + + +# =========================================================== +# 6b. ConnectionManager / OutboundManager +# =========================================================== + +class TestManagerImports: + def test_connection_manager_import(self): + from gateway.platforms.yuanbao import ConnectionManager + assert ConnectionManager is not None + + def test_outbound_manager_import(self): + from gateway.platforms.yuanbao import OutboundManager + assert OutboundManager is not None + + def test_message_sender_import(self): + from gateway.platforms.yuanbao import MessageSender + assert MessageSender is not None + + def test_heartbeat_manager_import(self): + from gateway.platforms.yuanbao import HeartbeatManager + assert HeartbeatManager is not None + + def test_slow_response_notifier_import(self): + from gateway.platforms.yuanbao import SlowResponseNotifier + assert SlowResponseNotifier is not None + + def test_adapter_has_outbound_manager(self): + adapter = YuanbaoAdapter(make_config()) + from gateway.platforms.yuanbao import ConnectionManager, OutboundManager + assert isinstance(adapter._connection, ConnectionManager) + assert isinstance(adapter._outbound, OutboundManager) + + def test_outbound_composes_sub_managers(self): + adapter = YuanbaoAdapter(make_config()) + from gateway.platforms.yuanbao import MessageSender, HeartbeatManager, SlowResponseNotifier + assert isinstance(adapter._outbound.sender, MessageSender) + assert isinstance(adapter._outbound.heartbeat, HeartbeatManager) + assert isinstance(adapter._outbound.slow_notifier, SlowResponseNotifier) + + +# =========================================================== +# 7. Media 模块 +# =========================================================== + +class TestMediaModule: + def test_import_ok(self): + from gateway.platforms.yuanbao_media import upload_to_cos, download_url + assert callable(upload_to_cos) + assert callable(download_url) + + +# =========================================================== +# 8. Toolset 注册 +# =========================================================== + +class TestToolset: + def test_yuanbao_toolset_registered(self): + """toolsets.py 中存在 hermes-yuanbao 键""" + import importlib + ts = importlib.import_module("toolsets") + assert hasattr(ts, "TOOLSETS") or hasattr(ts, "toolsets") + toolsets_dict = getattr(ts, "TOOLSETS", getattr(ts, "toolsets", {})) + assert "hermes-yuanbao" in toolsets_dict + + def test_tools_import(self): + from tools.yuanbao_tools import ( + get_group_info, + query_group_members, + send_dm, + ) + assert all(callable(f) for f in [ + get_group_info, + query_group_members, + send_dm, + ]) + + +# =========================================================== +# 9. platforms/__init__.py 导出 +# =========================================================== + +class TestPlatformInit: + def test_yuanbao_adapter_exported(self): + """gateway.platforms.__init__.py 应导出 YuanbaoAdapter""" + from gateway.platforms import YuanbaoAdapter as _YuanbaoAdapter + assert _YuanbaoAdapter is YuanbaoAdapter + + +# =========================================================== +# 10. P0 fixes verification +# =========================================================== + +import asyncio +import collections + + +class TestP0ReconnectGuard: + """P0-1: _reconnecting flag prevents concurrent reconnect attempts.""" + + def test_reconnecting_flag_initialized(self): + adapter = YuanbaoAdapter(make_config()) + assert hasattr(adapter._connection, '_reconnecting') + assert adapter._connection._reconnecting is False + + def test_schedule_reconnect_skips_when_not_running(self): + adapter = YuanbaoAdapter(make_config()) + adapter._running = False + adapter._connection._reconnecting = False + adapter._connection.schedule_reconnect() + # No task should be created because _running is False + + def test_schedule_reconnect_skips_when_already_reconnecting(self): + adapter = YuanbaoAdapter(make_config()) + adapter._running = True + adapter._connection._reconnecting = True + adapter._connection.schedule_reconnect() + # No new task should be created because already reconnecting + + +class TestP0InboundTaskTracking: + """P0-2: _inbound_tasks set is initialized and usable.""" + + def test_inbound_tasks_initialized(self): + adapter = YuanbaoAdapter(make_config()) + assert hasattr(adapter, '_inbound_tasks') + assert isinstance(adapter._inbound_tasks, set) + assert len(adapter._inbound_tasks) == 0 + + +class TestP0ChatLockEviction: + """P0-3: get_chat_lock uses OrderedDict and safe eviction.""" + + def test_chat_locks_is_ordered_dict(self): + adapter = YuanbaoAdapter(make_config()) + assert isinstance(adapter._outbound._chat_locks, collections.OrderedDict) + + def test_eviction_skips_locked(self): + """When eviction is needed, locked entries are skipped.""" + adapter = YuanbaoAdapter(make_config()) + from gateway.platforms.yuanbao import OutboundManager + + # Fill to capacity with unlocked locks + for i in range(OutboundManager.CHAT_DICT_MAX_SIZE): + adapter._outbound._chat_locks[f"chat_{i}"] = asyncio.Lock() + + # Lock the oldest entry + oldest_key = next(iter(adapter._outbound._chat_locks)) + oldest_lock = adapter._outbound._chat_locks[oldest_key] + # Simulate a held lock by acquiring it in a non-async way (set _locked) + # asyncio.Lock is not held until actually acquired; so we test the + # method logic by acquiring the first lock manually. + # For a sync test, we check that get_chat_lock doesn't crash. + new_lock = adapter._outbound.get_chat_lock("new_chat") + assert "new_chat" in adapter._outbound._chat_locks + assert isinstance(new_lock, asyncio.Lock) + # The oldest unlocked entry should have been evicted + assert len(adapter._outbound._chat_locks) == OutboundManager.CHAT_DICT_MAX_SIZE + + def test_move_to_end_on_access(self): + """Accessing an existing key moves it to the end (MRU).""" + adapter = YuanbaoAdapter(make_config()) + adapter._outbound._chat_locks["a"] = asyncio.Lock() + adapter._outbound._chat_locks["b"] = asyncio.Lock() + adapter._outbound._chat_locks["c"] = asyncio.Lock() + + # Access "a" — should move to end + adapter._outbound.get_chat_lock("a") + keys = list(adapter._outbound._chat_locks.keys()) + assert keys[-1] == "a" + assert keys[0] == "b" + + +class TestP0PlatformScopedLock: + """P0-4: connect() calls _acquire_platform_lock.""" + + def test_adapter_has_platform_lock_methods(self): + adapter = YuanbaoAdapter(make_config()) + assert hasattr(adapter, '_acquire_platform_lock') + assert hasattr(adapter, '_release_platform_lock') + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_yuanbao_markdown.py b/tests/test_yuanbao_markdown.py new file mode 100644 index 0000000000..a5bff3e320 --- /dev/null +++ b/tests/test_yuanbao_markdown.py @@ -0,0 +1,324 @@ +""" +test_yuanbao_markdown.py - Unit tests for yuanbao_markdown.py + +Run (no pytest needed): + cd /root/.openclaw/workspace/hermes-agent + python3 tests/test_yuanbao_markdown.py -v + +Or with pytest if available: + python3 -m pytest tests/test_yuanbao_markdown.py -v +""" + +import sys +import os +import unittest + +# Ensure project root is on the path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from gateway.platforms.yuanbao import MarkdownProcessor + + +# ============ has_unclosed_fence ============ + +class TestHasUnclosedFence(unittest.TestCase): + def test_unclosed_fence(self): + self.assertTrue(MarkdownProcessor.has_unclosed_fence("```python\ncode")) + + def test_closed_fence(self): + self.assertFalse(MarkdownProcessor.has_unclosed_fence("```python\ncode\n```")) + + def test_empty(self): + self.assertFalse(MarkdownProcessor.has_unclosed_fence("")) + + def test_no_fence(self): + self.assertFalse(MarkdownProcessor.has_unclosed_fence("just some text\nno fences here")) + + def test_multiple_closed_fences(self): + text = "```python\ncode1\n```\n\n```js\ncode2\n```" + self.assertFalse(MarkdownProcessor.has_unclosed_fence(text)) + + def test_second_fence_unclosed(self): + text = "```python\ncode1\n```\n\n```js\ncode2" + self.assertTrue(MarkdownProcessor.has_unclosed_fence(text)) + + def test_fence_at_start(self): + self.assertTrue(MarkdownProcessor.has_unclosed_fence("```\nsome code")) + + def test_inline_backtick_ignored(self): + text = "`inline code` is fine" + self.assertFalse(MarkdownProcessor.has_unclosed_fence(text)) + + +# ============ ends_with_table_row ============ + +class TestEndsWithTableRow(unittest.TestCase): + def test_simple_table_row(self): + self.assertTrue(MarkdownProcessor.ends_with_table_row("| col1 | col2 |")) + + def test_table_row_with_trailing_newline(self): + self.assertTrue(MarkdownProcessor.ends_with_table_row("| col1 | col2 |\n")) + + def test_table_row_in_middle(self): + text = "| col1 | col2 |\nsome other text" + self.assertFalse(MarkdownProcessor.ends_with_table_row(text)) + + def test_empty(self): + self.assertFalse(MarkdownProcessor.ends_with_table_row("")) + + def test_non_table(self): + self.assertFalse(MarkdownProcessor.ends_with_table_row("just a normal line")) + + def test_only_pipe_start(self): + self.assertFalse(MarkdownProcessor.ends_with_table_row("| just pipe at start")) + + def test_table_separator_row(self): + self.assertTrue(MarkdownProcessor.ends_with_table_row("| --- | --- |")) + + def test_whitespace_only(self): + self.assertFalse(MarkdownProcessor.ends_with_table_row(" \n ")) + + +# ============ split_at_paragraph_boundary ============ + +class TestSplitAtParagraphBoundary(unittest.TestCase): + def test_split_at_empty_line(self): + text = "paragraph one\n\nparagraph two\n\nparagraph three\nextra" + head, tail = MarkdownProcessor.split_at_paragraph_boundary(text, 30) + self.assertLessEqual(len(head), 30) + self.assertEqual(head + tail, text) + + def test_split_at_sentence_end(self): + text = "This is a sentence.\nNext line.\nAnother line." + head, tail = MarkdownProcessor.split_at_paragraph_boundary(text, 25) + self.assertLessEqual(len(head), 25) + self.assertEqual(head + tail, text) + + def test_forced_split_no_boundary(self): + text = "a" * 100 + head, tail = MarkdownProcessor.split_at_paragraph_boundary(text, 50) + self.assertEqual(len(head), 50) + self.assertEqual(head + tail, text) + + def test_split_at_newline(self): + text = "line one\nline two\nline three" + head, tail = MarkdownProcessor.split_at_paragraph_boundary(text, 15) + self.assertLessEqual(len(head), 15) + self.assertEqual(head + tail, text) + + def test_chinese_sentence_boundary(self): + text = "这是第一句话。\n这是第二句话。\n这是第三句话。" + head, tail = MarkdownProcessor.split_at_paragraph_boundary(text, 15) + self.assertLessEqual(len(head), 15) + self.assertEqual(head + tail, text) + + +# ============ chunk_markdown_text ============ + +class TestChunkMarkdownText(unittest.TestCase): + def test_empty(self): + self.assertEqual(MarkdownProcessor.chunk_markdown_text(""), []) + + def test_short_text_no_split(self): + text = "hello world" + self.assertEqual(MarkdownProcessor.chunk_markdown_text(text, 3000), [text]) + + def test_exactly_max_chars(self): + text = "a" * 3000 + result = MarkdownProcessor.chunk_markdown_text(text, 3000) + self.assertEqual(len(result), 1) + self.assertEqual(result[0], text) + + def test_plain_text_split(self): + """x * 9000 should return 3 chunks of ~3000""" + text = "x" * 9000 + result = MarkdownProcessor.chunk_markdown_text(text, 3000) + self.assertEqual(len(result), 3) + for chunk in result: + self.assertLessEqual(len(chunk), 3000) + self.assertEqual(''.join(result), text) + + def test_5000_chars_returns_2(self): + """验收标准: 'a'*5000 with max 3000 → 2 chunks""" + result = MarkdownProcessor.chunk_markdown_text("a" * 5000, 3000) + self.assertEqual(len(result), 2) + + def test_code_fence_not_split(self): + """代码块不应被切断""" + code_lines = "\n".join([f" line_{i} = {i}" for i in range(200)]) + text = f"Some intro text.\n\n```python\n{code_lines}\n```\n\nSome outro text." + result = MarkdownProcessor.chunk_markdown_text(text, 3000) + for chunk in result: + self.assertFalse(MarkdownProcessor.has_unclosed_fence(chunk), + f"Chunk has unclosed fence:\n{chunk[:200]}...") + + def test_table_not_split(self): + """表格行不应被切断""" + header = "| Name | Value | Description |\n| --- | --- | --- |" + rows = "\n".join([f"| item_{i} | {i * 100} | description for item {i} |" + for i in range(50)]) + table = f"{header}\n{rows}" + text = "Some intro text.\n\n" + table + "\n\nSome outro text." + result = MarkdownProcessor.chunk_markdown_text(text, 3000) + for chunk in result: + self.assertFalse(MarkdownProcessor.has_unclosed_fence(chunk)) + + def test_code_fence_200_lines_not_cut(self): + """包含 200 行代码块的文本,代码块不被切断""" + code_lines = "\n".join([f"x = {i}" for i in range(200)]) + text = f"Intro.\n\n```python\n{code_lines}\n```\n\nOutro." + result = MarkdownProcessor.chunk_markdown_text(text, 3000) + for chunk in result: + self.assertFalse(MarkdownProcessor.has_unclosed_fence(chunk)) + + def test_multiple_paragraphs(self): + """多段落文本应在段落边界切割""" + paragraphs = ["This is paragraph number " + str(i) + ". " * 50 + for i in range(10)] + text = "\n\n".join(paragraphs) + result = MarkdownProcessor.chunk_markdown_text(text, 500) + self.assertGreater(len(result), 1) + total_content = ''.join(result) + self.assertGreaterEqual(len(total_content), len(text) * 0.95) + + def test_single_long_line(self): + """单行超长文本应被强制切割""" + text = "a" * 10000 + result = MarkdownProcessor.chunk_markdown_text(text, 3000) + self.assertGreaterEqual(len(result), 3) + for c in result: + self.assertLessEqual(len(c), 3000) + + def test_fence_followed_by_text(self): + """围栏后的文本应正常切割""" + text = "```python\nprint('hi')\n```\n\n" + "Normal text. " * 300 + result = MarkdownProcessor.chunk_markdown_text(text, 500) + for chunk in result: + self.assertFalse(MarkdownProcessor.has_unclosed_fence(chunk)) + + def test_returns_non_empty_strings(self): + """所有返回的片段都应为非空字符串""" + text = "Hello world!\n\n" * 100 + result = MarkdownProcessor.chunk_markdown_text(text, 100) + for chunk in result: + self.assertGreater(len(chunk), 0) + + +# ============ Acceptance criteria ============ + +class TestAcceptanceCriteria(unittest.TestCase): + def test_9000_x_returns_3_chunks(self): + """验收:MarkdownProcessor.chunk_markdown_text("x" * 9000, 3000) 返回 3 个片段""" + result = MarkdownProcessor.chunk_markdown_text("x" * 9000, 3000) + self.assertEqual(len(result), 3) + for chunk in result: + self.assertLessEqual(len(chunk), 3000) + + def test_5000_a_returns_2_chunks(self): + """验收:python -c 输出 2""" + result = MarkdownProcessor.chunk_markdown_text("a" * 5000, 3000) + self.assertEqual(len(result), 2) + + def test_has_unclosed_fence_true(self): + """验收:MarkdownProcessor.has_unclosed_fence("```python\\ncode") 返回 True""" + self.assertTrue(MarkdownProcessor.has_unclosed_fence("```python\ncode")) + + def test_has_unclosed_fence_false(self): + """验收:MarkdownProcessor.has_unclosed_fence("```python\\ncode\\n```") 返回 False""" + self.assertFalse(MarkdownProcessor.has_unclosed_fence("```python\ncode\n```")) + + def test_code_block_200_lines_not_broken(self): + """验收:包含 200 行代码块的文本,代码块不被切断""" + code_lines = "\n".join([f" result_{i} = compute({i})" for i in range(200)]) + text = f"Introduction.\n\n```python\n{code_lines}\n```\n\nConclusion." + result = MarkdownProcessor.chunk_markdown_text(text, 3000) + for chunk in result: + self.assertFalse(MarkdownProcessor.has_unclosed_fence(chunk), + f"Found unclosed fence in chunk:\n{chunk[:100]}...") + + def test_table_rows_not_broken(self): + """验收:表格行不被切断(每个 chunk 中的表格 fence 完整)""" + rows = "\n".join([ + f"| Col A {i} | Col B {i} | Col C {i} |" for i in range(100) + ]) + text = f"Table:\n\n| A | B | C |\n| --- | --- | --- |\n{rows}\n\nDone." + result = MarkdownProcessor.chunk_markdown_text(text, 500) + for chunk in result: + self.assertFalse(MarkdownProcessor.has_unclosed_fence(chunk)) + + +if __name__ == '__main__': + unittest.main(verbosity=2) + + +# ============ pytest-style function tests (task specification) ============ + +def test_short_text_no_split(): + assert MarkdownProcessor.chunk_markdown_text("hello", 100) == ["hello"] + + +def test_plain_text_split(): + chunks = MarkdownProcessor.chunk_markdown_text("a" * 5000, 3000) + assert len(chunks) >= 2 + for c in chunks: + assert len(c) <= 3000 + + +def test_fence_not_broken(): + """代码块不应被切断""" + code_block = "```python\n" + "x = 1\n" * 200 + "```" + chunks = MarkdownProcessor.chunk_markdown_text(code_block, 1000) + for c in chunks: + assert not MarkdownProcessor.has_unclosed_fence(c), f"Chunk has unclosed fence: {c[:100]}" + + +def test_large_fence_kept_whole(): + """超大代码块即便超过 max_chars 也应整块输出""" + code_block = "```python\n" + "x = 1\n" * 200 + "```" + chunks = MarkdownProcessor.chunk_markdown_text(code_block, 500) + # 代码块应在同一个 chunk 中(允许超出 max_chars) + fence_chunks = [c for c in chunks if "```python" in c] + for c in fence_chunks: + assert not MarkdownProcessor.has_unclosed_fence(c) + + +def test_mixed_content(): + """代码块前后的普通文本可以正常切割""" + text = "intro paragraph\n\n" + "```python\nx=1\n```" + "\n\noutro paragraph" + chunks = MarkdownProcessor.chunk_markdown_text(text, 100) + for c in chunks: + assert not MarkdownProcessor.has_unclosed_fence(c) + + +def test_table_not_broken(): + """表格不应被切断""" + table = "| A | B |\n|---|---|\n| 1 | 2 |\n| 3 | 4 |" + text = "before\n\n" + table + "\n\nafter" + chunks = MarkdownProcessor.chunk_markdown_text(text, 30) + table_in_chunk = [c for c in chunks if "|" in c] + for c in table_in_chunk: + lines = [line for line in c.split('\n') if line.strip().startswith('|')] + if lines: + # 至少表格行不被半截切割 + pass + + +def test_has_unclosed_fence(): + assert MarkdownProcessor.has_unclosed_fence("```python\ncode") == True + assert MarkdownProcessor.has_unclosed_fence("```python\ncode\n```") == False + assert MarkdownProcessor.has_unclosed_fence("no fence") == False + + +def test_ends_with_table_row(): + assert MarkdownProcessor.ends_with_table_row("| a | b |") == True + assert MarkdownProcessor.ends_with_table_row("normal text") == False + + +def test_empty_text(): + assert MarkdownProcessor.chunk_markdown_text("", 100) == [] + + +def test_exact_limit(): + text = "a" * 3000 + chunks = MarkdownProcessor.chunk_markdown_text(text, 3000) + assert len(chunks) == 1 diff --git a/tests/test_yuanbao_pipeline.py b/tests/test_yuanbao_pipeline.py new file mode 100644 index 0000000000..659f1e7056 --- /dev/null +++ b/tests/test_yuanbao_pipeline.py @@ -0,0 +1,1029 @@ +""" +test_yuanbao_pipeline.py - Unit tests for the inbound middleware pipeline. + +Tests cover: + 1. InboundPipeline engine (use, use_before, use_after, remove, execute) + 2. InboundContext dataclass + 3. Individual middlewares (DecodeMiddleware, DedupMiddleware, SkipSelfMiddleware, etc.) + 4. InboundPipelineBuilder + 5. End-to-end pipeline integration + 6. OOP middleware ABC and class tests +""" + +import sys +import os +import json +import asyncio + +# Ensure project root is on the path +_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock + +from gateway.platforms.yuanbao import ( + InboundContext, + InboundMiddleware, + InboundPipeline, + DecodeMiddleware, + ExtractFieldsMiddleware, + DedupMiddleware, + SkipSelfMiddleware, + ChatRoutingMiddleware, + AccessPolicy, + AccessGuardMiddleware, + ExtractContentMiddleware, + PlaceholderFilterMiddleware, + OwnerCommandMiddleware, + BuildSourceMiddleware, + GroupAtGuardMiddleware, + DispatchMiddleware, + InboundPipelineBuilder, + YuanbaoAdapter, +) +from gateway.config import Platform, PlatformConfig + + +# ============================================================ +# Helpers +# ============================================================ + +def make_config(**kwargs): + extra = kwargs.pop("extra", {}) + extra.setdefault("app_id", "test_key") + extra.setdefault("app_secret", "test_secret") + extra.setdefault("ws_url", "wss://test.example.com/ws") + extra.setdefault("api_domain", "https://test.example.com") + return PlatformConfig( + extra=extra, + **kwargs, + ) + + +def make_adapter(**kwargs) -> YuanbaoAdapter: + """Create a YuanbaoAdapter with test config.""" + config = make_config(**kwargs) + adapter = YuanbaoAdapter(config) + adapter._bot_id = "bot_123" + return adapter + + +def make_ctx(adapter=None, conn_data=b"", **overrides) -> InboundContext: + """Create an InboundContext with sensible defaults for testing.""" + if adapter is None: + adapter = make_adapter() + raw_frames = [conn_data] if conn_data else [] + ctx = InboundContext(adapter=adapter, raw_frames=raw_frames) + for k, v in overrides.items(): + setattr(ctx, k, v) + return ctx + + +def make_json_push( + from_account="alice", + to_account="bot_123", + group_code="", + text="Hello!", + msg_id="msg-001", +) -> bytes: + """Build a JSON callback_command push payload. + + Note: MsgContent inner fields use lowercase ("text" not "Text") + because _extract_text() looks for lowercase keys. + """ + msg_body = [{"MsgType": "TIMTextElem", "MsgContent": {"text": text}}] + push = { + "CallbackCommand": "C2C.CallbackAfterSendMsg", + "From_Account": from_account, + "To_Account": to_account, + "MsgBody": msg_body, + "MsgKey": msg_id, + } + if group_code: + push["CallbackCommand"] = "Group.CallbackAfterSendMsg" + push["GroupId"] = group_code + return json.dumps(push).encode("utf-8") + + +# ============================================================ +# 1. InboundPipeline Engine Tests +# ============================================================ + +class TestInboundPipeline: + """Test the pipeline engine itself.""" + + @pytest.mark.asyncio + async def test_empty_pipeline(self): + """Empty pipeline executes without error.""" + pipeline = InboundPipeline() + ctx = make_ctx() + await pipeline.execute(ctx) # Should not raise + + @pytest.mark.asyncio + async def test_single_middleware(self): + """Single middleware is called with ctx and next_fn.""" + called = [] + + async def mw(ctx, next_fn): + called.append("mw") + await next_fn() + + pipeline = InboundPipeline().use("test", mw) + ctx = make_ctx() + await pipeline.execute(ctx) + assert called == ["mw"] + + @pytest.mark.asyncio + async def test_middleware_order(self): + """Middlewares execute in registration order.""" + order = [] + + async def mw_a(ctx, next_fn): + order.append("a") + await next_fn() + + async def mw_b(ctx, next_fn): + order.append("b") + await next_fn() + + async def mw_c(ctx, next_fn): + order.append("c") + await next_fn() + + pipeline = InboundPipeline().use("a", mw_a).use("b", mw_b).use("c", mw_c) + await pipeline.execute(make_ctx()) + assert order == ["a", "b", "c"] + + @pytest.mark.asyncio + async def test_middleware_can_stop_pipeline(self): + """A middleware that doesn't call next_fn stops the pipeline.""" + order = [] + + async def mw_stop(ctx, next_fn): + order.append("stop") + # Don't call next_fn — pipeline stops here + + async def mw_after(ctx, next_fn): + order.append("after") + await next_fn() + + pipeline = InboundPipeline().use("stop", mw_stop).use("after", mw_after) + await pipeline.execute(make_ctx()) + assert order == ["stop"] # "after" should NOT be called + + @pytest.mark.asyncio + async def test_conditional_guard_skip(self): + """Middleware with when=False is skipped.""" + order = [] + + async def mw_a(ctx, next_fn): + order.append("a") + await next_fn() + + async def mw_skipped(ctx, next_fn): + order.append("skipped") + await next_fn() + + async def mw_c(ctx, next_fn): + order.append("c") + await next_fn() + + pipeline = ( + InboundPipeline() + .use("a", mw_a) + .use("skipped", mw_skipped, when=lambda ctx: False) + .use("c", mw_c) + ) + await pipeline.execute(make_ctx()) + assert order == ["a", "c"] + + @pytest.mark.asyncio + async def test_conditional_guard_pass(self): + """Middleware with when=True is executed.""" + order = [] + + async def mw(ctx, next_fn): + order.append("mw") + await next_fn() + + pipeline = InboundPipeline().use("mw", mw, when=lambda ctx: True) + await pipeline.execute(make_ctx()) + assert order == ["mw"] + + def test_use_before(self): + """use_before inserts middleware before the target.""" + async def noop(ctx, next_fn): + await next_fn() + + pipeline = InboundPipeline().use("a", noop).use("c", noop) + pipeline.use_before("c", "b", noop) + assert pipeline.middleware_names == ["a", "b", "c"] + + def test_use_before_nonexistent_appends(self): + """use_before with nonexistent target appends to end.""" + async def noop(ctx, next_fn): + await next_fn() + + pipeline = InboundPipeline().use("a", noop) + pipeline.use_before("nonexistent", "b", noop) + assert pipeline.middleware_names == ["a", "b"] + + def test_use_after(self): + """use_after inserts middleware after the target.""" + async def noop(ctx, next_fn): + await next_fn() + + pipeline = InboundPipeline().use("a", noop).use("c", noop) + pipeline.use_after("a", "b", noop) + assert pipeline.middleware_names == ["a", "b", "c"] + + def test_use_after_nonexistent_appends(self): + """use_after with nonexistent target appends to end.""" + async def noop(ctx, next_fn): + await next_fn() + + pipeline = InboundPipeline().use("a", noop) + pipeline.use_after("nonexistent", "b", noop) + assert pipeline.middleware_names == ["a", "b"] + + def test_remove(self): + """remove deletes middleware by name.""" + async def noop(ctx, next_fn): + await next_fn() + + pipeline = InboundPipeline().use("a", noop).use("b", noop).use("c", noop) + pipeline.remove("b") + assert pipeline.middleware_names == ["a", "c"] + + def test_remove_nonexistent_is_noop(self): + """remove with nonexistent name is a no-op.""" + async def noop(ctx, next_fn): + await next_fn() + + pipeline = InboundPipeline().use("a", noop) + pipeline.remove("nonexistent") + assert pipeline.middleware_names == ["a"] + + @pytest.mark.asyncio + async def test_error_propagation(self): + """Errors in middlewares propagate to the caller.""" + async def mw_error(ctx, next_fn): + raise ValueError("test error") + + pipeline = InboundPipeline().use("error", mw_error) + with pytest.raises(ValueError, match="test error"): + await pipeline.execute(make_ctx()) + + def test_middleware_names_property(self): + """middleware_names returns ordered list of names.""" + async def noop(ctx, next_fn): + await next_fn() + + pipeline = ( + InboundPipeline() + .use("decode", noop) + .use("dedup", noop) + .use("dispatch", noop) + ) + assert pipeline.middleware_names == ["decode", "dedup", "dispatch"] + + @pytest.mark.asyncio + async def test_onion_model(self): + """Middlewares support before/after processing (onion model).""" + order = [] + + async def mw_outer(ctx, next_fn): + order.append("outer-before") + await next_fn() + order.append("outer-after") + + async def mw_inner(ctx, next_fn): + order.append("inner") + await next_fn() + + pipeline = InboundPipeline().use("outer", mw_outer).use("inner", mw_inner) + await pipeline.execute(make_ctx()) + assert order == ["outer-before", "inner", "outer-after"] + + +# ============================================================ +# 2. InboundContext Tests +# ============================================================ + +class TestInboundContext: + def test_default_values(self): + """InboundContext has sensible defaults.""" + adapter = make_adapter() + ctx = InboundContext(adapter=adapter) + assert ctx.raw_frames == [] + assert ctx.push is None + assert ctx.decoded_via == "" + assert ctx.from_account == "" + assert ctx.group_code == "" + assert ctx.msg_body == [] + assert ctx.msg_id == "" + assert ctx.chat_id == "" + assert ctx.chat_type == "" + assert ctx.raw_text == "" + assert ctx.media_refs == [] + assert ctx.owner_command is None + assert ctx.source is None + assert ctx.msg_type is None + + def test_mutable_fields(self): + """InboundContext fields are mutable.""" + ctx = make_ctx() + ctx.from_account = "alice" + ctx.chat_type = "dm" + assert ctx.from_account == "alice" + assert ctx.chat_type == "dm" + + +# ============================================================ +# 3. Individual Middleware Tests +# ============================================================ + +class TestDecodeMiddleware: + @pytest.mark.asyncio + async def test_json_decode(self): + """DecodeMiddleware parses JSON push correctly.""" + push_data = make_json_push(from_account="alice", text="hi") + ctx = make_ctx(conn_data=push_data) + next_fn = AsyncMock() + + await DecodeMiddleware()(ctx, next_fn) + + assert ctx.push is not None + assert ctx.decoded_via == "json" + assert ctx.push.get("from_account") == "alice" + next_fn.assert_awaited_once() + + @pytest.mark.asyncio + async def test_empty_data_stops_pipeline(self): + """DecodeMiddleware stops pipeline on empty conn_data.""" + ctx = make_ctx(conn_data=b"") + next_fn = AsyncMock() + + await DecodeMiddleware()(ctx, next_fn) + + assert ctx.push is None + next_fn.assert_not_awaited() + + @pytest.mark.asyncio + async def test_invalid_data_may_produce_garbage(self): + """DecodeMiddleware: binary data may be parsed by protobuf as garbage fields. + + This is expected behavior — the protobuf parser is lenient and may + produce "seemingly valid" fields from arbitrary bytes. The downstream + middlewares (dedup, skip-self, etc.) will filter out such garbage. + """ + ctx = make_ctx(conn_data=b"\x00\x01\x02\x03") + next_fn = AsyncMock() + + await DecodeMiddleware()(ctx, next_fn) + + # Protobuf parser may or may not produce a result — either is acceptable. + # The key invariant: no exception is raised. + assert True # Reached here without error + + +class TestExtractFieldsMiddleware: + @pytest.mark.asyncio + async def test_extracts_fields(self): + """ExtractFieldsMiddleware populates ctx from push dict.""" + ctx = make_ctx(push={ + "from_account": "alice", + "group_code": "grp-1", + "group_name": "Test Group", + "sender_nickname": "Alice", + "msg_body": [{"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}], + "msg_id": "msg-001", + "cloud_custom_data": '{"key": "val"}', + }) + next_fn = AsyncMock() + + await ExtractFieldsMiddleware()(ctx, next_fn) + + assert ctx.from_account == "alice" + assert ctx.group_code == "grp-1" + assert ctx.group_name == "Test Group" + assert ctx.sender_nickname == "Alice" + assert len(ctx.msg_body) == 1 + assert ctx.msg_id == "msg-001" + assert ctx.cloud_custom_data == '{"key": "val"}' + next_fn.assert_awaited_once() + + +class TestDedupMiddleware: + @pytest.mark.asyncio + async def test_new_message_passes(self): + """DedupMiddleware passes new messages through.""" + adapter = make_adapter() + ctx = make_ctx(adapter=adapter, msg_id="unique-msg-001") + next_fn = AsyncMock() + + await DedupMiddleware()(ctx, next_fn) + next_fn.assert_awaited_once() + + @pytest.mark.asyncio + async def test_duplicate_stops_pipeline(self): + """DedupMiddleware stops pipeline for duplicate messages.""" + adapter = make_adapter() + # Mark message as seen + adapter._dedup.is_duplicate("dup-msg-001") + + ctx = make_ctx(adapter=adapter, msg_id="dup-msg-001") + next_fn = AsyncMock() + + await DedupMiddleware()(ctx, next_fn) + next_fn.assert_not_awaited() + + @pytest.mark.asyncio + async def test_empty_msg_id_passes(self): + """DedupMiddleware passes messages with empty msg_id.""" + ctx = make_ctx(msg_id="") + next_fn = AsyncMock() + + await DedupMiddleware()(ctx, next_fn) + next_fn.assert_awaited_once() + + +class TestSkipSelfMiddleware: + @pytest.mark.asyncio + async def test_self_message_stops(self): + """SkipSelfMiddleware stops pipeline for bot's own messages.""" + adapter = make_adapter() + adapter._bot_id = "bot_123" + ctx = make_ctx(adapter=adapter, from_account="bot_123") + next_fn = AsyncMock() + + await SkipSelfMiddleware()(ctx, next_fn) + next_fn.assert_not_awaited() + + @pytest.mark.asyncio + async def test_other_message_passes(self): + """SkipSelfMiddleware passes messages from other users.""" + adapter = make_adapter() + adapter._bot_id = "bot_123" + ctx = make_ctx(adapter=adapter, from_account="alice") + next_fn = AsyncMock() + + await SkipSelfMiddleware()(ctx, next_fn) + next_fn.assert_awaited_once() + + +class TestChatRoutingMiddleware: + @pytest.mark.asyncio + async def test_group_routing(self): + """ChatRoutingMiddleware sets group chat fields.""" + ctx = make_ctx(group_code="grp-1", group_name="Test Group") + next_fn = AsyncMock() + + await ChatRoutingMiddleware()(ctx, next_fn) + + assert ctx.chat_id == "group:grp-1" + assert ctx.chat_type == "group" + assert ctx.chat_name == "Test Group" + next_fn.assert_awaited_once() + + @pytest.mark.asyncio + async def test_dm_routing(self): + """ChatRoutingMiddleware sets DM chat fields.""" + ctx = make_ctx(from_account="alice", sender_nickname="Alice") + next_fn = AsyncMock() + + await ChatRoutingMiddleware()(ctx, next_fn) + + assert ctx.chat_id == "direct:alice" + assert ctx.chat_type == "dm" + assert ctx.chat_name == "Alice" + next_fn.assert_awaited_once() + + @pytest.mark.asyncio + async def test_dm_routing_no_nickname(self): + """ChatRoutingMiddleware falls back to from_account when no nickname.""" + ctx = make_ctx(from_account="alice", sender_nickname="") + next_fn = AsyncMock() + + await ChatRoutingMiddleware()(ctx, next_fn) + + assert ctx.chat_name == "alice" + + +class TestAccessGuardMiddleware: + @pytest.mark.asyncio + async def test_open_policy_passes(self): + """AccessGuardMiddleware passes with open policy.""" + adapter = make_adapter() + adapter._access_policy = AccessPolicy(dm_policy="open", dm_allow_from=[], group_policy="open", group_allow_from=[]) + ctx = make_ctx(adapter=adapter, chat_type="dm", from_account="alice") + next_fn = AsyncMock() + + await AccessGuardMiddleware()(ctx, next_fn) + next_fn.assert_awaited_once() + + @pytest.mark.asyncio + async def test_disabled_dm_stops(self): + """AccessGuardMiddleware stops DM when dm_policy=disabled.""" + adapter = make_adapter() + adapter._access_policy = AccessPolicy(dm_policy="disabled", dm_allow_from=[], group_policy="open", group_allow_from=[]) + ctx = make_ctx(adapter=adapter, chat_type="dm", from_account="alice") + next_fn = AsyncMock() + + await AccessGuardMiddleware()(ctx, next_fn) + next_fn.assert_not_awaited() + + @pytest.mark.asyncio + async def test_allowlist_dm_allowed(self): + """AccessGuardMiddleware passes DM when sender is in allowlist.""" + adapter = make_adapter() + adapter._access_policy = AccessPolicy(dm_policy="allowlist", dm_allow_from=["alice"], group_policy="open", group_allow_from=[]) + ctx = make_ctx(adapter=adapter, chat_type="dm", from_account="alice") + next_fn = AsyncMock() + + await AccessGuardMiddleware()(ctx, next_fn) + next_fn.assert_awaited_once() + + @pytest.mark.asyncio + async def test_allowlist_dm_blocked(self): + """AccessGuardMiddleware blocks DM when sender is not in allowlist.""" + adapter = make_adapter() + adapter._access_policy = AccessPolicy(dm_policy="allowlist", dm_allow_from=["bob"], group_policy="open", group_allow_from=[]) + ctx = make_ctx(adapter=adapter, chat_type="dm", from_account="alice") + next_fn = AsyncMock() + + await AccessGuardMiddleware()(ctx, next_fn) + next_fn.assert_not_awaited() + + @pytest.mark.asyncio + async def test_disabled_group_stops(self): + """AccessGuardMiddleware stops group when group_policy=disabled.""" + adapter = make_adapter() + adapter._access_policy = AccessPolicy(dm_policy="open", dm_allow_from=[], group_policy="disabled", group_allow_from=[]) + ctx = make_ctx(adapter=adapter, chat_type="group", group_code="grp-1") + next_fn = AsyncMock() + + await AccessGuardMiddleware()(ctx, next_fn) + next_fn.assert_not_awaited() + + @pytest.mark.asyncio + async def test_allowlist_group_allowed(self): + """AccessGuardMiddleware passes group when group_code is in allowlist.""" + adapter = make_adapter() + adapter._access_policy = AccessPolicy(dm_policy="open", dm_allow_from=[], group_policy="allowlist", group_allow_from=["grp-1"]) + ctx = make_ctx(adapter=adapter, chat_type="group", group_code="grp-1") + next_fn = AsyncMock() + + await AccessGuardMiddleware()(ctx, next_fn) + next_fn.assert_awaited_once() + + +class TestExtractContentMiddleware: + @pytest.mark.asyncio + async def test_extracts_text_and_media(self): + """ExtractContentMiddleware extracts text and media refs.""" + adapter = make_adapter() + msg_body = [ + {"msg_type": "TIMTextElem", "msg_content": {"text": "Hello!"}}, + {"msg_type": "TIMImageElem", "msg_content": { + "image_info_array": [{"url": "https://img.example.com/1.jpg"}] + }}, + ] + ctx = make_ctx(adapter=adapter, msg_body=msg_body) + next_fn = AsyncMock() + + await ExtractContentMiddleware()(ctx, next_fn) + + assert "Hello!" in ctx.raw_text + assert len(ctx.media_refs) == 1 + assert ctx.media_refs[0]["kind"] == "image" + next_fn.assert_awaited_once() + + +class TestPlaceholderFilterMiddleware: + @pytest.mark.asyncio + async def test_placeholder_stops(self): + """PlaceholderFilterMiddleware stops on pure placeholder.""" + ctx = make_ctx(raw_text="[image]", media_refs=[]) + next_fn = AsyncMock() + + await PlaceholderFilterMiddleware()(ctx, next_fn) + next_fn.assert_not_awaited() + + @pytest.mark.asyncio + async def test_placeholder_with_media_passes(self): + """PlaceholderFilterMiddleware passes placeholder when media exists.""" + ctx = make_ctx( + raw_text="[image]", + media_refs=[{"kind": "image", "url": "https://img.example.com/1.jpg"}], + ) + next_fn = AsyncMock() + + await PlaceholderFilterMiddleware()(ctx, next_fn) + next_fn.assert_awaited_once() + + @pytest.mark.asyncio + async def test_normal_text_passes(self): + """PlaceholderFilterMiddleware passes normal text.""" + ctx = make_ctx(raw_text="Hello world!") + next_fn = AsyncMock() + + await PlaceholderFilterMiddleware()(ctx, next_fn) + next_fn.assert_awaited_once() + + +class TestGroupAtGuardMiddleware: + @pytest.mark.asyncio + async def test_dm_passes(self): + """GroupAtGuardMiddleware passes DM messages.""" + adapter = make_adapter() + ctx = make_ctx(adapter=adapter, chat_type="dm") + next_fn = AsyncMock() + + await GroupAtGuardMiddleware()(ctx, next_fn) + next_fn.assert_awaited_once() + + @pytest.mark.asyncio + async def test_group_with_at_bot_passes(self): + """GroupAtGuardMiddleware passes group messages that @bot.""" + adapter = make_adapter() + adapter._bot_id = "bot_123" + msg_body = [ + {"msg_type": "TIMCustomElem", "msg_content": { + "data": json.dumps({"elem_type": 1002, "text": "@Bot", "user_id": "bot_123"}) + }}, + ] + ctx = make_ctx( + adapter=adapter, + chat_type="group", + chat_id="group:grp-1", + msg_body=msg_body, + from_account="alice", + sender_nickname="Alice", + raw_text="Hello", + source=MagicMock(), + ) + next_fn = AsyncMock() + + await GroupAtGuardMiddleware()(ctx, next_fn) + next_fn.assert_awaited_once() + + @pytest.mark.asyncio + async def test_group_without_at_bot_observes(self): + """GroupAtGuardMiddleware observes group messages without @bot.""" + adapter = make_adapter() + adapter._bot_id = "bot_123" + adapter._session_store = None # No session store -> observe is a no-op + ctx = make_ctx( + adapter=adapter, + chat_type="group", + chat_id="group:grp-1", + msg_body=[{"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}], + from_account="alice", + sender_nickname="Alice", + raw_text="hi", + source=MagicMock(), + ) + next_fn = AsyncMock() + + await GroupAtGuardMiddleware()(ctx, next_fn) + + next_fn.assert_not_awaited() + + @pytest.mark.asyncio + async def test_owner_command_skips_at_check(self): + """GroupAtGuardMiddleware passes when owner_command is set.""" + adapter = make_adapter() + adapter._bot_id = "bot_123" + ctx = make_ctx( + adapter=adapter, + chat_type="group", + msg_body=[], + owner_command="/new", + source=MagicMock(), + ) + next_fn = AsyncMock() + + await GroupAtGuardMiddleware()(ctx, next_fn) + next_fn.assert_awaited_once() + + +# ============================================================ +# 4. Factory Tests +# ============================================================ + +class TestCreateInboundPipeline: + def test_default_pipeline_has_all_middlewares(self): + """InboundPipelineBuilder.build() creates pipeline with all expected middlewares.""" + pipeline = InboundPipelineBuilder.build() + expected = [ + "decode", + "extract-fields", + "dedup", + "skip-self", + "chat-routing", + "access-guard", + "extract-content", + "placeholder-filter", + "owner-command", + "build-source", + "group-at-guard", + "classify-msg-type", + "quote-context", + "media-resolve", + "dispatch", + ] + """Pipeline can be customized after creation.""" + pipeline = InboundPipelineBuilder.build() + + async def custom_mw(ctx, next_fn): + await next_fn() + + pipeline.use_before("dispatch", "custom", custom_mw) + assert "custom" in pipeline.middleware_names + idx_custom = pipeline.middleware_names.index("custom") + idx_dispatch = pipeline.middleware_names.index("dispatch") + assert idx_custom < idx_dispatch + + +# ============================================================ +# 5. End-to-End Pipeline Integration Tests +# ============================================================ + +class TestPipelineIntegration: + @pytest.mark.asyncio + async def test_full_dm_message_flow(self): + """Full pipeline processes a DM message end-to-end.""" + adapter = make_adapter() + adapter._bot_id = "bot_123" + adapter._access_policy = AccessPolicy(dm_policy="open", dm_allow_from=[], group_policy="open", group_allow_from=[]) + adapter.handle_message = AsyncMock() + adapter._resolve_inbound_media_urls = AsyncMock(return_value=([], [])) + + push_data = make_json_push( + from_account="alice", + to_account="bot_123", + text="Hello bot!", + msg_id="msg-e2e-001", + ) + + ctx = InboundContext(adapter=adapter, raw_frames=[push_data]) + pipeline = InboundPipelineBuilder.build() + await pipeline.execute(ctx) + + # Verify context was populated correctly + assert ctx.decoded_via == "json" + assert ctx.from_account == "alice" + assert ctx.chat_type == "dm" + assert ctx.chat_id == "direct:alice" + assert "Hello bot!" in ctx.raw_text + assert ctx.source is not None + + @pytest.mark.asyncio + async def test_self_message_filtered(self): + """Pipeline stops when message is from bot itself.""" + adapter = make_adapter() + adapter._bot_id = "bot_123" + + push_data = make_json_push( + from_account="bot_123", + to_account="bot_123", + text="echo", + msg_id="msg-self-001", + ) + + ctx = InboundContext(adapter=adapter, raw_frames=[push_data]) + pipeline = InboundPipelineBuilder.build() + await pipeline.execute(ctx) + + # Pipeline should have stopped at skip-self — no source built + assert ctx.source is None + + @pytest.mark.asyncio + async def test_duplicate_message_filtered(self): + """Pipeline stops on duplicate message.""" + adapter = make_adapter() + adapter._bot_id = "bot_123" + + # First message goes through + push_data = make_json_push( + from_account="alice", + text="Hello!", + msg_id="msg-dup-001", + ) + ctx1 = InboundContext(adapter=adapter, raw_frames=[push_data]) + pipeline = InboundPipelineBuilder.build() + await pipeline.execute(ctx1) + assert ctx1.from_account == "alice" + + # Second message with same msg_id is filtered + ctx2 = InboundContext(adapter=adapter, raw_frames=[push_data]) + await pipeline.execute(ctx2) + # Dedup should stop pipeline before chat routing + assert ctx2.chat_type == "" + + @pytest.mark.asyncio + async def test_blocked_dm_filtered(self): + """Pipeline stops when DM is blocked by policy.""" + adapter = make_adapter() + adapter._bot_id = "bot_123" + adapter._access_policy = AccessPolicy(dm_policy="disabled", dm_allow_from=[], group_policy="open", group_allow_from=[]) + + push_data = make_json_push( + from_account="alice", + text="Hello!", + msg_id="msg-blocked-001", + ) + + ctx = InboundContext(adapter=adapter, raw_frames=[push_data]) + pipeline = InboundPipelineBuilder.build() + await pipeline.execute(ctx) + + # Pipeline stopped at access-guard — no content extracted + assert ctx.raw_text == "" + + @pytest.mark.asyncio + async def test_adapter_has_pipeline(self): + """YuanbaoAdapter.__init__ creates an inbound pipeline.""" + adapter = make_adapter() + assert hasattr(adapter, "_inbound_pipeline") + assert isinstance(adapter._inbound_pipeline, InboundPipeline) + + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) + + +# ============================================================ +# 6. OOP Middleware Tests +# ============================================================ + +class TestInboundMiddlewareABC: + """Test the InboundMiddleware abstract base class.""" + + def test_cannot_instantiate_abc(self): + """InboundMiddleware cannot be instantiated directly.""" + with pytest.raises(TypeError): + InboundMiddleware() + + def test_subclass_must_implement_handle(self): + """Subclass without handle() raises TypeError.""" + with pytest.raises(TypeError): + class BadMiddleware(InboundMiddleware): + name = "bad" + BadMiddleware() + + def test_subclass_with_handle_works(self): + """Subclass with handle() can be instantiated.""" + class GoodMiddleware(InboundMiddleware): + name = "good" + async def handle(self, ctx, next_fn): + await next_fn() + mw = GoodMiddleware() + assert mw.name == "good" + + @pytest.mark.asyncio + async def test_callable_protocol(self): + """Middleware instances are callable via __call__.""" + class TestMW(InboundMiddleware): + name = "test" + async def handle(self, ctx, next_fn): + ctx.raw_text = "called" + await next_fn() + + mw = TestMW() + ctx = make_ctx() + next_fn = AsyncMock() + await mw(ctx, next_fn) # Call via __call__ + assert ctx.raw_text == "called" + next_fn.assert_awaited_once() + + def test_repr(self): + """Middleware has a useful repr.""" + class MyMW(InboundMiddleware): + name = "my-mw" + async def handle(self, ctx, next_fn): + pass + mw = MyMW() + assert "MyMW" in repr(mw) + assert "my-mw" in repr(mw) + + +class TestMiddlewareClasses: + """Test that all concrete middleware classes have correct names and are InboundMiddleware subclasses.""" + + MIDDLEWARE_CLASSES = [ + (DecodeMiddleware, "decode"), + (ExtractFieldsMiddleware, "extract-fields"), + (DedupMiddleware, "dedup"), + (SkipSelfMiddleware, "skip-self"), + (ChatRoutingMiddleware, "chat-routing"), + (AccessGuardMiddleware, "access-guard"), + (ExtractContentMiddleware, "extract-content"), + (PlaceholderFilterMiddleware, "placeholder-filter"), + (OwnerCommandMiddleware, "owner-command"), + (BuildSourceMiddleware, "build-source"), + (GroupAtGuardMiddleware, "group-at-guard"), + (DispatchMiddleware, "dispatch"), + ] + + @pytest.mark.parametrize("cls,expected_name", MIDDLEWARE_CLASSES) + def test_is_inbound_middleware(self, cls, expected_name): + """Each middleware class is a subclass of InboundMiddleware.""" + assert issubclass(cls, InboundMiddleware) + + @pytest.mark.parametrize("cls,expected_name", MIDDLEWARE_CLASSES) + def test_has_correct_name(self, cls, expected_name): + """Each middleware class has the expected name.""" + mw = cls() + assert mw.name == expected_name + + @pytest.mark.parametrize("cls,expected_name", MIDDLEWARE_CLASSES) + def test_is_callable(self, cls, expected_name): + """Each middleware instance is callable.""" + mw = cls() + assert callable(mw) + + +class TestPipelineOOPRegistration: + """Test that InboundPipeline works with OOP middleware instances.""" + + @pytest.mark.asyncio + async def test_use_with_middleware_instance(self): + """pipeline.use(SomeMiddleware()) auto-extracts name.""" + class TestMW(InboundMiddleware): + name = "test-mw" + async def handle(self, ctx, next_fn): + ctx.raw_text = "oop-works" + await next_fn() + + pipeline = InboundPipeline().use(TestMW()) + assert pipeline.middleware_names == ["test-mw"] + + ctx = make_ctx() + await pipeline.execute(ctx) + assert ctx.raw_text == "oop-works" + + @pytest.mark.asyncio + async def test_mixed_oop_and_functional(self): + """Pipeline supports mixing OOP and functional middlewares.""" + order = [] + + class OopMW(InboundMiddleware): + name = "oop" + async def handle(self, ctx, next_fn): + order.append("oop") + await next_fn() + + async def func_mw(ctx, next_fn): + order.append("func") + await next_fn() + + pipeline = ( + InboundPipeline() + .use(OopMW()) + .use("func", func_mw) + ) + assert pipeline.middleware_names == ["oop", "func"] + + await pipeline.execute(make_ctx()) + assert order == ["oop", "func"] + + def test_use_before_with_middleware_instance(self): + """use_before works with OOP middleware instances.""" + class MwA(InboundMiddleware): + name = "a" + async def handle(self, ctx, next_fn): await next_fn() + + class MwB(InboundMiddleware): + name = "b" + async def handle(self, ctx, next_fn): await next_fn() + + class MwC(InboundMiddleware): + name = "c" + async def handle(self, ctx, next_fn): await next_fn() + + pipeline = InboundPipeline().use(MwA()).use(MwC()) + pipeline.use_before("c", MwB()) + assert pipeline.middleware_names == ["a", "b", "c"] + + def test_use_after_with_middleware_instance(self): + """use_after works with OOP middleware instances.""" + class MwA(InboundMiddleware): + name = "a" + async def handle(self, ctx, next_fn): await next_fn() + + class MwB(InboundMiddleware): + name = "b" + async def handle(self, ctx, next_fn): await next_fn() + + class MwC(InboundMiddleware): + name = "c" + async def handle(self, ctx, next_fn): await next_fn() + + pipeline = InboundPipeline().use(MwA()).use(MwC()) + pipeline.use_after("a", MwB()) + assert pipeline.middleware_names == ["a", "b", "c"] diff --git a/tests/test_yuanbao_proto.py b/tests/test_yuanbao_proto.py new file mode 100644 index 0000000000..d5dc1fa2fd --- /dev/null +++ b/tests/test_yuanbao_proto.py @@ -0,0 +1,654 @@ +""" +test_yuanbao_proto.py - yuanbao_proto 单元测试 + +测试覆盖: + 1. varint 编解码 round-trip + 2. conn 层 encode/decode round-trip + 3. biz 层 encode/decode round-trip + 4. decode_inbound_push 解析 TIMTextElem 消息 + 5. encode_send_c2c_message / encode_send_group_message 编码 + 6. 固定 bytes 常量验证(防止协议悄悄改动) + 7. auth-bind / ping 编码 +""" + +import sys +import os + +# 确保 hermes-agent 根目录在 sys.path 中 +_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +import pytest +from gateway.platforms.yuanbao_proto import ( + # 基础工具 + _encode_varint, + _decode_varint, + _parse_fields, + _fields_to_dict, + _encode_msg_body_element, + _decode_msg_body_element, + _encode_msg_content, + _decode_msg_content, + # conn 层 + encode_conn_msg, + decode_conn_msg, + encode_conn_msg_full, + # biz 层 + encode_biz_msg, + decode_biz_msg, + # 入站/出站 + decode_inbound_push, + encode_send_c2c_message, + encode_send_group_message, + # 帮助函数 + encode_auth_bind, + encode_ping, + encode_push_ack, + # 常量 + PB_MSG_TYPES, + BIZ_SERVICES, + CMD_TYPE, + CMD, + MODULE, + next_seq_no, +) + + +# =========================================================== +# 1. varint 编解码 +# =========================================================== + +class TestVarint: + def test_small_values(self): + for v in [0, 1, 127, 128, 255, 300, 16383, 16384, 2**21, 2**28]: + encoded = _encode_varint(v) + decoded, pos = _decode_varint(encoded, 0) + assert decoded == v, f"round-trip failed for {v}" + assert pos == len(encoded) + + def test_zero(self): + assert _encode_varint(0) == b"\x00" + v, p = _decode_varint(b"\x00", 0) + assert v == 0 and p == 1 + + def test_1_byte_boundary(self): + # 127 = 0x7F => 1 byte + assert _encode_varint(127) == b"\x7f" + # 128 => 2 bytes: 0x80 0x01 + assert _encode_varint(128) == b"\x80\x01" + + def test_known_values(self): + # protobuf spec examples + # 300 => 0xAC 0x02 + assert _encode_varint(300) == bytes([0xAC, 0x02]) + + def test_multi_byte(self): + # 2^32 - 1 = 4294967295 + v = 2**32 - 1 + enc = _encode_varint(v) + dec, _ = _decode_varint(enc, 0) + assert dec == v + + def test_partial_decode(self): + # 在 offset 处解码 + data = b"\x00" + _encode_varint(300) + b"\x00" + v, pos = _decode_varint(data, 1) + assert v == 300 + assert pos == 3 # 1 + 2 bytes for 300 + + +# =========================================================== +# 2. conn 层 round-trip +# =========================================================== + +class TestConnCodec: + def test_basic_round_trip(self): + payload = b"hello world" + encoded = encode_conn_msg(msg_type=0, seq_no=42, data=payload) + decoded = decode_conn_msg(encoded) + assert decoded["msg_type"] == 0 + assert decoded["seq_no"] == 42 + assert decoded["data"] == payload + + def test_empty_data(self): + encoded = encode_conn_msg(msg_type=2, seq_no=0, data=b"") + decoded = decode_conn_msg(encoded) + assert decoded["msg_type"] == 2 + assert decoded["data"] == b"" + + def test_all_cmd_types(self): + for ct in [0, 1, 2, 3]: + enc = encode_conn_msg(msg_type=ct, seq_no=1, data=b"\x01\x02") + dec = decode_conn_msg(enc) + assert dec["msg_type"] == ct + + def test_large_seq_no(self): + enc = encode_conn_msg(msg_type=1, seq_no=2**32 - 1, data=b"x") + dec = decode_conn_msg(enc) + assert dec["seq_no"] == 2**32 - 1 + + def test_full_round_trip(self): + """encode_conn_msg_full 含 cmd/msg_id/module""" + enc = encode_conn_msg_full( + cmd_type=CMD_TYPE["Request"], + cmd="auth-bind", + seq_no=99, + msg_id="abc123", + module="conn_access", + data=b"\xde\xad\xbe\xef", + ) + dec = decode_conn_msg(enc) + head = dec["head"] + assert head["cmd_type"] == CMD_TYPE["Request"] + assert head["cmd"] == "auth-bind" + assert head["seq_no"] == 99 + assert head["msg_id"] == "abc123" + assert head["module"] == "conn_access" + assert dec["data"] == b"\xde\xad\xbe\xef" + + # 固定 bytes 常量测试——防协议悄悄改动 + def test_fixed_bytes_simple(self): + """ + encode_conn_msg(msg_type=0, seq_no=1, data=b"") 的固定编码。 + ConnMsg { head { seq_no=1 } } + head bytes: field3 varint(1) = 0x18 0x01 + head field: field1 len(2) 0x18 0x01 = 0x0a 0x02 0x18 0x01 + """ + enc = encode_conn_msg(msg_type=0, seq_no=1, data=b"") + # head: field 3 (seq_no=1) => tag=0x18, value=0x01 + head_content = bytes([0x18, 0x01]) + # outer field 1 (head message) + expected = bytes([0x0a, len(head_content)]) + head_content + assert enc == expected, f"got: {enc.hex()}, expected: {expected.hex()}" + + +# =========================================================== +# 3. biz 层 round-trip +# =========================================================== + +class TestBizCodec: + def test_round_trip(self): + body = b"\x0a\x05hello" + enc = encode_biz_msg( + service="trpc.yuanbao.example", + method="/im/send_c2c_msg", + req_id="req-001", + body=body, + ) + dec = decode_biz_msg(enc) + assert dec["service"] == "trpc.yuanbao.example" + assert dec["method"] == "/im/send_c2c_msg" + assert dec["req_id"] == "req-001" + assert dec["body"] == body + assert dec["is_response"] is False + + def test_is_response_flag(self): + # Response cmd_type = 1 + enc = encode_conn_msg_full( + cmd_type=CMD_TYPE["Response"], + cmd="/im/send_c2c_msg", + seq_no=1, + msg_id="rsp-001", + module="svc", + data=b"\x01", + ) + dec = decode_biz_msg(enc) + assert dec["is_response"] is True + + def test_empty_body(self): + enc = encode_biz_msg("svc", "method", "id1", b"") + dec = decode_biz_msg(enc) + assert dec["body"] == b"" + assert dec["method"] == "method" + + +# =========================================================== +# 4. MsgContent / MsgBodyElement 编解码 +# =========================================================== + +class TestMsgBodyElement: + def test_text_elem_round_trip(self): + el = { + "msg_type": "TIMTextElem", + "msg_content": {"text": "Hello, 世界!"}, + } + encoded = _encode_msg_body_element(el) + decoded = _decode_msg_body_element(encoded) + assert decoded["msg_type"] == "TIMTextElem" + assert decoded["msg_content"]["text"] == "Hello, 世界!" + + def test_image_elem_round_trip(self): + el = { + "msg_type": "TIMImageElem", + "msg_content": { + "uuid": "img-uuid-123", + "image_format": 2, + "url": "https://example.com/img.jpg", + "image_info_array": [ + {"type": 1, "size": 1024, "width": 100, "height": 200, "url": "https://thumb.jpg"}, + ], + }, + } + encoded = _encode_msg_body_element(el) + decoded = _decode_msg_body_element(encoded) + assert decoded["msg_type"] == "TIMImageElem" + mc = decoded["msg_content"] + assert mc["uuid"] == "img-uuid-123" + assert mc["image_format"] == 2 + assert mc["url"] == "https://example.com/img.jpg" + assert len(mc["image_info_array"]) == 1 + assert mc["image_info_array"][0]["url"] == "https://thumb.jpg" + + def test_file_elem_round_trip(self): + el = { + "msg_type": "TIMFileElem", + "msg_content": { + "url": "https://example.com/file.pdf", + "file_size": 204800, + "file_name": "document.pdf", + }, + } + enc = _encode_msg_body_element(el) + dec = _decode_msg_body_element(enc) + assert dec["msg_content"]["file_name"] == "document.pdf" + assert dec["msg_content"]["file_size"] == 204800 + + def test_custom_elem_round_trip(self): + el = { + "msg_type": "TIMCustomElem", + "msg_content": { + "data": '{"key":"value"}', + "desc": "custom description", + "ext": "extra info", + }, + } + enc = _encode_msg_body_element(el) + dec = _decode_msg_body_element(enc) + assert dec["msg_content"]["data"] == '{"key":"value"}' + assert dec["msg_content"]["desc"] == "custom description" + + def test_empty_content(self): + el = {"msg_type": "TIMTextElem", "msg_content": {}} + enc = _encode_msg_body_element(el) + dec = _decode_msg_body_element(enc) + assert dec["msg_type"] == "TIMTextElem" + + def test_fixed_text_elem_bytes(self): + """ + 固定 bytes 验证:TIMTextElem { text="hi" } + MsgBodyElement: + field1 (msg_type="TIMTextElem"): 0a 0b 54494d5465787445 6c656d + field2 (msg_content): 12 + MsgContent field1 (text="hi"): 0a 02 6869 + """ + el = { + "msg_type": "TIMTextElem", + "msg_content": {"text": "hi"}, + } + enc = _encode_msg_body_element(el) + # 手动计算期望值 + # msg_type = "TIMTextElem" (11 bytes) + type_bytes = b"TIMTextElem" + # MsgContent: field1(text="hi") = tag(0a) + len(02) + "hi" + content_inner = bytes([0x0a, 0x02]) + b"hi" + # MsgBodyElement: + # field1: tag=0x0a, len=11, type_bytes + # field2: tag=0x12, len=len(content_inner), content_inner + expected = ( + bytes([0x0a, len(type_bytes)]) + type_bytes + + bytes([0x12, len(content_inner)]) + content_inner + ) + assert enc == expected, f"got {enc.hex()}, expected {expected.hex()}" + + +# =========================================================== +# 5. decode_inbound_push 测试 +# =========================================================== + +class TestDecodeInboundPush: + def _build_inbound_push_bytes( + self, + from_account: str = "user123", + to_account: str = "bot456", + group_code: str = "", + msg_key: str = "key-001", + msg_seq: int = 12345, + text: str = "Hello!", + ) -> bytes: + """手工构造 InboundMessagePush bytes(与 proto 字段顺序一致)""" + from gateway.platforms.yuanbao_proto import ( + _encode_field, _encode_string, _encode_message, + _encode_varint, WT_LEN, WT_VARINT, + ) + el = { + "msg_type": "TIMTextElem", + "msg_content": {"text": text}, + } + el_bytes = _encode_msg_body_element(el) + + buf = b"" + buf += _encode_field(2, WT_LEN, _encode_string(from_account)) # from_account + buf += _encode_field(3, WT_LEN, _encode_string(to_account)) # to_account + if group_code: + buf += _encode_field(6, WT_LEN, _encode_string(group_code)) # group_code + buf += _encode_field(8, WT_VARINT, _encode_varint(msg_seq)) # msg_seq + buf += _encode_field(11, WT_LEN, _encode_string(msg_key)) # msg_key + buf += _encode_field(13, WT_LEN, _encode_message(el_bytes)) # msg_body[0] + return buf + + def test_basic_c2c_text_message(self): + raw = self._build_inbound_push_bytes( + from_account="alice", + to_account="bot", + msg_key="k001", + msg_seq=100, + text="你好", + ) + result = decode_inbound_push(raw) + assert result is not None + assert result["from_account"] == "alice" + assert result["to_account"] == "bot" + assert result["msg_seq"] == 100 + assert result["msg_key"] == "k001" + assert len(result["msg_body"]) == 1 + assert result["msg_body"][0]["msg_type"] == "TIMTextElem" + assert result["msg_body"][0]["msg_content"]["text"] == "你好" + + def test_group_message(self): + raw = self._build_inbound_push_bytes( + from_account="bob", + to_account="bot", + group_code="group-789", + msg_seq=999, + text="group msg", + ) + result = decode_inbound_push(raw) + assert result is not None + assert result["group_code"] == "group-789" + assert result["msg_body"][0]["msg_content"]["text"] == "group msg" + + def test_returns_none_on_empty(self): + # 空 bytes 应返回空字段 dict,而不是 None + result = decode_inbound_push(b"") + # 空消息解析结果是 {}(无字段),过滤后 msg_body=[] 也会保留 + assert result is not None or result is None # 不崩溃即可 + + def test_multiple_msg_body_elements(self): + from gateway.platforms.yuanbao_proto import ( + _encode_field, _encode_message, WT_LEN, + ) + el1 = _encode_msg_body_element( + {"msg_type": "TIMTextElem", "msg_content": {"text": "part1"}} + ) + el2 = _encode_msg_body_element( + {"msg_type": "TIMTextElem", "msg_content": {"text": "part2"}} + ) + buf = ( + _encode_field(2, WT_LEN, b"\x05alice") + + _encode_field(13, WT_LEN, _encode_message(el1)) + + _encode_field(13, WT_LEN, _encode_message(el2)) + ) + result = decode_inbound_push(buf) + assert result is not None + assert len(result["msg_body"]) == 2 + assert result["msg_body"][0]["msg_content"]["text"] == "part1" + assert result["msg_body"][1]["msg_content"]["text"] == "part2" + + +# =========================================================== +# 6. 出站消息编码 +# =========================================================== + +class TestEncodeOutbound: + def test_encode_send_c2c_message(self): + msg_body = [{"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}] + result = encode_send_c2c_message( + to_account="user_b", + msg_body=msg_body, + from_account="bot", + msg_id="msg-001", + ) + assert isinstance(result, bytes) + assert len(result) > 0 + # 解码验证 ConnMsg 结构 + dec = decode_conn_msg(result) + assert dec["head"]["cmd"] == "send_c2c_message" + assert dec["head"]["msg_id"] == "msg-001" + assert dec["head"]["module"] == "yuanbao_openclaw_proxy" + assert len(dec["data"]) > 0 + + def test_encode_send_group_message(self): + msg_body = [{"msg_type": "TIMTextElem", "msg_content": {"text": "group hello"}}] + result = encode_send_group_message( + group_code="grp-100", + msg_body=msg_body, + from_account="bot", + msg_id="msg-002", + ) + assert isinstance(result, bytes) + dec = decode_conn_msg(result) + assert dec["head"]["cmd"] == "send_group_message" + assert dec["head"]["msg_id"] == "msg-002" + assert len(dec["data"]) > 0 + + def test_c2c_biz_payload_contains_to_account(self): + """验证 biz payload 包含 to_account 字段""" + from gateway.platforms.yuanbao_proto import _parse_fields, _fields_to_dict, _get_string + msg_body = [{"msg_type": "TIMTextElem", "msg_content": {"text": "test"}}] + result = encode_send_c2c_message( + to_account="target_user", + msg_body=msg_body, + from_account="bot", + ) + dec = decode_conn_msg(result) + biz_data = dec["data"] + fdict = _fields_to_dict(_parse_fields(biz_data)) + to_acc = _get_string(fdict, 2) # SendC2CMessageReq.to_account = field 2 + assert to_acc == "target_user" + + def test_group_biz_payload_contains_group_code(self): + from gateway.platforms.yuanbao_proto import _parse_fields, _fields_to_dict, _get_string + msg_body = [{"msg_type": "TIMTextElem", "msg_content": {"text": "test"}}] + result = encode_send_group_message( + group_code="group-xyz", + msg_body=msg_body, + from_account="bot", + ) + dec = decode_conn_msg(result) + biz_data = dec["data"] + fdict = _fields_to_dict(_parse_fields(biz_data)) + grp = _get_string(fdict, 2) # SendGroupMessageReq.group_code = field 2 + assert grp == "group-xyz" + + +# =========================================================== +# 7. AuthBind / Ping 编码 +# =========================================================== + +class TestAuthAndPing: + def test_encode_auth_bind(self): + result = encode_auth_bind( + biz_id="ybBot", + uid="user_001", + source="app", + token="tok_abc", + msg_id="auth-001", + app_version="1.0.0", + operation_system="Linux", + bot_version="0.1.0", + ) + assert isinstance(result, bytes) + dec = decode_conn_msg(result) + assert dec["head"]["cmd"] == "auth-bind" + assert dec["head"]["module"] == "conn_access" + assert dec["head"]["msg_id"] == "auth-001" + assert len(dec["data"]) > 0 + + def test_encode_ping(self): + result = encode_ping("ping-001") + assert isinstance(result, bytes) + dec = decode_conn_msg(result) + assert dec["head"]["cmd"] == "ping" + assert dec["head"]["module"] == "conn_access" + + def test_encode_push_ack(self): + original_head = { + "cmd_type": CMD_TYPE["Push"], + "cmd": "some-push", + "seq_no": 100, + "msg_id": "push-001", + "module": "im_module", + "need_ack": True, + "status": 0, + } + result = encode_push_ack(original_head) + dec = decode_conn_msg(result) + assert dec["head"]["cmd_type"] == CMD_TYPE["PushAck"] + assert dec["head"]["cmd"] == "some-push" + assert dec["head"]["msg_id"] == "push-001" + + +# =========================================================== +# 8. 常量验证 +# =========================================================== + +class TestConstants: + def test_pb_msg_types_keys(self): + assert "ConnMsg" in PB_MSG_TYPES + assert "AuthBindReq" in PB_MSG_TYPES + assert "PingReq" in PB_MSG_TYPES + assert "KickoutMsg" in PB_MSG_TYPES + assert "PushMsg" in PB_MSG_TYPES + + def test_biz_services_keys(self): + assert "SendC2CMessageReq" in BIZ_SERVICES + assert "SendGroupMessageReq" in BIZ_SERVICES + assert "InboundMessagePush" in BIZ_SERVICES + + def test_cmd_type_values(self): + assert CMD_TYPE["Request"] == 0 + assert CMD_TYPE["Response"] == 1 + assert CMD_TYPE["Push"] == 2 + assert CMD_TYPE["PushAck"] == 3 + + def test_pkg_prefix(self): + for k, v in BIZ_SERVICES.items(): + assert v.startswith("yuanbao_openclaw_proxy"), \ + f"{k}: unexpected prefix in {v}" + + +# =========================================================== +# 9. seq_no 生成 +# =========================================================== + +class TestSeqNo: + def test_monotonic(self): + a = next_seq_no() + b = next_seq_no() + c = next_seq_no() + assert b > a + assert c > b + + def test_thread_safety(self): + import threading + results = [] + lock = threading.Lock() + + def worker(): + for _ in range(100): + v = next_seq_no() + with lock: + results.append(v) + + threads = [threading.Thread(target=worker) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + # 无重复 + assert len(results) == len(set(results)), "duplicate seq_no detected" + + +# =========================================================== +# 10. 完整端到端流程(模拟 send -> recv) +# =========================================================== + +class TestEndToEnd: + def test_send_recv_c2c(self): + """模拟发送 C2C 消息,然后(在接收方)解码""" + msg_body = [ + {"msg_type": "TIMTextElem", "msg_content": {"text": "端到端测试"}}, + ] + # 发送方编码 + wire_bytes = encode_send_c2c_message( + to_account="recv_user", + msg_body=msg_body, + from_account="send_bot", + msg_id="e2e-001", + ) + # 接收方解码 ConnMsg + dec = decode_conn_msg(wire_bytes) + assert dec["head"]["cmd"] == "send_c2c_message" + assert dec["head"]["msg_id"] == "e2e-001" + + # 从 biz payload 中读取 to_account 和 msg_body + from gateway.platforms.yuanbao_proto import ( + _parse_fields, _fields_to_dict, _get_string, _get_repeated_bytes, WT_LEN + ) + biz = dec["data"] + fdict = _fields_to_dict(_parse_fields(biz)) + assert _get_string(fdict, 2) == "recv_user" # to_account + assert _get_string(fdict, 3) == "send_bot" # from_account + + el_list = _get_repeated_bytes(fdict, 5) # msg_body repeated + assert len(el_list) == 1 + el_dec = _decode_msg_body_element(el_list[0]) + assert el_dec["msg_type"] == "TIMTextElem" + assert el_dec["msg_content"]["text"] == "端到端测试" + + def test_inbound_push_full_flow(self): + """构造服务端 push -> 解码入站消息""" + from gateway.platforms.yuanbao_proto import ( + _encode_field, _encode_string, _encode_message, + _encode_varint, WT_LEN, WT_VARINT, + ) + # 构造入站消息 biz payload + el_bytes = _encode_msg_body_element( + {"msg_type": "TIMTextElem", "msg_content": {"text": "server push"}} + ) + biz_payload = ( + _encode_field(2, WT_LEN, _encode_string("alice")) + + _encode_field(3, WT_LEN, _encode_string("bot")) + + _encode_field(6, WT_LEN, _encode_string("grp-001")) + + _encode_field(8, WT_VARINT, _encode_varint(555)) + + _encode_field(11, WT_LEN, _encode_string("msg-key-xyz")) + + _encode_field(13, WT_LEN, _encode_message(el_bytes)) + ) + # 封装成 ConnMsg(模拟服务端 push) + wire = encode_conn_msg_full( + cmd_type=CMD_TYPE["Push"], + cmd="/im/new_message", + seq_no=77, + msg_id="push-abc", + module="yuanbao_openclaw_proxy", + data=biz_payload, + need_ack=True, + ) + # 接收方解码 + conn = decode_conn_msg(wire) + assert conn["head"]["cmd_type"] == CMD_TYPE["Push"] + assert conn["head"]["need_ack"] is True + + msg = decode_inbound_push(conn["data"]) + assert msg is not None + assert msg["from_account"] == "alice" + assert msg["group_code"] == "grp-001" + assert msg["msg_seq"] == 555 + assert msg["msg_key"] == "msg-key-xyz" + assert msg["msg_body"][0]["msg_content"]["text"] == "server push" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/tools/test_registry.py b/tests/tools/test_registry.py index f5e65582ab..3c753f64f5 100644 --- a/tests/tools/test_registry.py +++ b/tests/tools/test_registry.py @@ -317,6 +317,7 @@ class TestBuiltinDiscovery: "tools.tts_tool", "tools.vision_tools", "tools.web_tools", + "tools.yuanbao_tools", } with patch("tools.registry.importlib.import_module"): diff --git a/tools/send_message_tool.py b/tools/send_message_tool.py index 5c392291f6..c36e54e02f 100644 --- a/tools/send_message_tool.py +++ b/tools/send_message_tool.py @@ -28,6 +28,7 @@ _FEISHU_TARGET_RE = re.compile(r"^\s*((?:oc|ou|on|chat|open)_[-A-Za-z0-9]+)(?::( # through to channel-name resolution, which only matches by name and fails. _SLACK_TARGET_RE = re.compile(r"^\s*([CGD][A-Z0-9]{8,})\s*$") _WEIXIN_TARGET_RE = re.compile(r"^\s*((?:wxid|gh|v\d+|wm|wb)_[A-Za-z0-9_-]+|[A-Za-z0-9._-]+@chatroom|filehelper)\s*$") +_YUANBAO_TARGET_RE = re.compile(r"^\s*((?:group|direct):[^:]+)\s*$") # Discord snowflake IDs are numeric, same regex pattern as Telegram topic targets. _NUMERIC_TOPIC_RE = _TELEGRAM_TOPIC_TARGET_RE # Platforms that address recipients by phone number and accept E.164 format @@ -127,11 +128,11 @@ SEND_MESSAGE_SCHEMA = { }, "target": { "type": "string", - "description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', 'platform:chat_id', or 'platform:chat_id:thread_id' for Telegram topics and Discord threads. Examples: 'telegram', 'telegram:-1001234567890:17585', 'discord:999888777:555444333', 'discord:#bot-home', 'slack:#engineering', 'signal:+155****4567', 'matrix:!roomid:server.org', 'matrix:@user:server.org'" + "description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', 'platform:chat_id', or 'platform:chat_id:thread_id' for Telegram topics and Discord threads. Examples: 'telegram', 'telegram:-1001234567890:17585', 'discord:999888777:555444333', 'discord:#bot-home', 'slack:#engineering', 'signal:+155****4567', 'matrix:!roomid:server.org', 'matrix:@user:server.org', 'yuanbao:direct:' (DM), 'yuanbao:group:' (group chat)" }, "message": { "type": "string", - "description": "The message text to send" + "description": "The message text to send. To send an image or file, include MEDIA: (e.g. 'MEDIA:/tmp/hermes/cache/img_xxx.jpg') in the message — the platform will deliver it as a native media attachment." } }, "required": [] @@ -222,6 +223,7 @@ def _handle_send(args): "weixin": Platform.WEIXIN, "email": Platform.EMAIL, "sms": Platform.SMS, + "yuanbao": Platform.YUANBAO, } platform = platform_map.get(platform_name) if not platform: @@ -341,6 +343,13 @@ def _parse_target_ref(platform_name: str, target_ref: str): match = _WEIXIN_TARGET_RE.fullmatch(target_ref) if match: return match.group(1), None, True + if platform_name == "yuanbao": + match = _YUANBAO_TARGET_RE.fullmatch(target_ref) + if match: + return match.group(1), None, True + if target_ref.strip().isdigit(): + return f"group:{target_ref.strip()}", None, True + return None, None, False if platform_name in _PHONE_PLATFORMS: match = _E164_TARGET_RE.fullmatch(target_ref) if match: @@ -551,7 +560,7 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, if media_files and not message.strip(): return { "error": ( - f"send_message MEDIA delivery is currently only supported for telegram, discord, matrix, weixin, and signal; " + f"send_message MEDIA delivery is currently only supported for telegram, discord, matrix, weixin, signal and yuanbao; " f"target {platform.value} had only media attachments" ) } @@ -559,7 +568,7 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, if media_files: warning = ( f"MEDIA attachments were omitted for {platform.value}; " - "native send_message media delivery is currently only supported for telegram, discord, matrix, weixin, and signal" + "native send_message media delivery is currently only supported for telegram, discord, matrix, weixin, signal and yuanbao" ) last_result = None @@ -1529,6 +1538,35 @@ async def _send_qqbot(pconfig, chat_id, message): return _error(f"QQBot send failed: {e}") +async def _send_yuanbao(chat_id, message, media_files=None): + """Send via Yuanbao using the running gateway adapter's WebSocket connection. + + Yuanbao uses a persistent WebSocket — unlike HTTP-based platforms, we + cannot create a throwaway client. We obtain the running singleton from + the adapter module itself (``get_active_adapter``). + + chat_id format: + - Group: "group:" + - DM: "direct:" or just "" + """ + try: + from gateway.platforms.yuanbao import get_active_adapter, send_yuanbao_direct + except ImportError: + return _error("Yuanbao adapter module not available.") + + adapter = get_active_adapter() + if adapter is None: + return _error( + "Yuanbao adapter is not running. " + "Start the gateway with yuanbao platform enabled first." + ) + + try: + return await send_yuanbao_direct(adapter, chat_id, message, media_files=media_files) + except Exception as e: + return _error(f"Yuanbao send failed: {e}") + + # --- Registry --- from tools.registry import registry, tool_error diff --git a/tools/yuanbao_tools.py b/tools/yuanbao_tools.py new file mode 100644 index 0000000000..bdb36c8b85 --- /dev/null +++ b/tools/yuanbao_tools.py @@ -0,0 +1,740 @@ +""" +yuanbao_tools.py - 元宝平台工具集 + +提供以下工具函数,供 hermes-agent 的 "hermes-yuanbao" toolset 使用: + - get_group_info : 查询群基本信息(群名、群主、成员数) + - query_group_members : 查询群成员(按名搜索、列举 bot、列举全部) + - search_sticker : 按关键词搜索内置贴纸(返回候选列表,含 sticker_id/name/description) + - send_sticker : 向当前会话或指定 chat_id 发送贴纸(TIMFaceElem) + - send_dm : 发送私聊消息(按昵称查找用户并发送) + +对齐 chatbot-web/yuanbao-openclaw-plugin 的 sticker-search/sticker-send 行为: +LLM 应先用 search_sticker 找到合适的 sticker_id(或直接传中文 name),再用 send_sticker +发送。不要在文本中夹杂裸的 Unicode emoji 当作贴纸。 + +The active adapter singleton lives in ``gateway.platforms.yuanbao`` and is +accessed via ``get_active_adapter()``. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, List, Optional, Tuple + +logger = logging.getLogger(__name__) + + +def _get_active_adapter(): + """Lazy import to avoid ImportError when gateway.platforms.yuanbao is unavailable.""" + try: + from gateway.platforms.yuanbao import get_active_adapter + return get_active_adapter() + except ImportError: + return None + + +if TYPE_CHECKING: + from gateway.platforms.yuanbao import YuanbaoAdapter + + +# --------------------------------------------------------------------------- +# 角色标签 +# --------------------------------------------------------------------------- + +_USER_TYPE_LABEL = {0: "unknown", 1: "user", 2: "yuanbao_ai", 3: "bot"} + +MENTION_HINT = ( + 'To @mention a user, you MUST use the format: ' + 'space + @ + nickname + space (e.g. " @Alice ").' +) + + +# --------------------------------------------------------------------------- +# 工具函数 +# --------------------------------------------------------------------------- + +async def get_group_info(group_code: str) -> dict: + """查询群基本信息(群名、群主、成员数)。""" + if not group_code: + return {"success": False, "error": "group_code is required"} + + adapter = _get_active_adapter() + if adapter is None: + return {"success": False, "error": "Yuanbao adapter is not connected"} + + try: + gi = await adapter.query_group_info(group_code) + if gi is None: + return {"success": False, "error": "query_group_info returned None"} + return { + "success": True, + "group_code": group_code, + "group_name": gi.get("group_name", ""), + "member_count": gi.get("member_count", 0), + "owner": { + "user_id": gi.get("owner_id", ""), + "nickname": gi.get("owner_nickname", ""), + }, + "note": 'The group is called "派 (Pai)" in the app.', + } + except Exception as exc: + logger.exception("[yuanbao_tools] get_group_info error") + return {"success": False, "error": str(exc)} + + +async def query_group_members( + group_code: str, + action: str = "list_all", + name: str = "", + mention: bool = False, +) -> dict: + """ + 统一的群成员查询工具(对齐 TS query_session_members)。 + + action: + - find : 按昵称模糊搜索 + - list_bots : 列出 bot 和元宝 AI + - list_all : 列出全部成员 + """ + if not group_code: + return {"success": False, "error": "group_code is required"} + + adapter = _get_active_adapter() + if adapter is None: + return {"success": False, "error": "Yuanbao adapter is not connected"} + + try: + raw = await adapter.get_group_member_list(group_code) + if raw is None: + return {"success": False, "error": "get_group_member_list returned None"} + + all_members = [ + { + "user_id": m.get("user_id", ""), + "nickname": m.get("nickname", m.get("nick_name", "")), + "role": _USER_TYPE_LABEL.get( + m.get("user_type", m.get("role", 0)), "unknown" + ), + } + for m in raw.get("members", []) + ] + + if not all_members: + return {"success": False, "error": "No members found in this group."} + + hint = {"mention_hint": MENTION_HINT} if mention else {} + + if action == "list_bots": + bots = [m for m in all_members if m["role"] in ("yuanbao_ai", "bot")] + if not bots: + return {"success": False, "error": "No bots found in this group."} + return { + "success": True, + "msg": f"Found {len(bots)} bot(s).", + "members": bots, + **hint, + } + + if action == "find": + if name: + filt = name.strip().lower() + matched = [m for m in all_members if filt in m["nickname"].lower()] + if matched: + return { + "success": True, + "msg": f'Found {len(matched)} member(s) matching "{name}".', + "members": matched, + **hint, + } + return { + "success": False, + "msg": f'No match for "{name}". All members listed below.', + "members": all_members, + **hint, + } + return { + "success": True, + "msg": f"Found {len(all_members)} member(s).", + "members": all_members, + **hint, + } + + # list_all (default) + return { + "success": True, + "msg": f"Found {len(all_members)} member(s).", + "members": all_members, + **hint, + } + + except Exception as exc: + logger.exception("[yuanbao_tools] query_group_members error") + return {"success": False, "error": str(exc)} + + +async def search_sticker(query: str = "", limit: int = 10) -> dict: + """ + 在内置贴纸表中按关键词模糊搜索,返回 Top-N 候选。 + + 返回每条候选的 sticker_id / name / description / package_id, + 供 LLM 选择后传给 send_sticker。空 query 时返回前 N 条。 + """ + from gateway.platforms.yuanbao_sticker import search_stickers + + try: + safe_limit = max(1, min(50, int(limit) if limit else 10)) + except (TypeError, ValueError): + safe_limit = 10 + + try: + matches = search_stickers(query or "", limit=safe_limit) + except Exception as exc: + logger.exception("[yuanbao_tools] search_sticker error") + return {"success": False, "error": str(exc)} + + return { + "success": True, + "query": query or "", + "count": len(matches), + "results": [ + { + "sticker_id": s.get("sticker_id", ""), + "name": s.get("name", ""), + "description": s.get("description", ""), + "package_id": s.get("package_id", ""), + } + for s in matches + ], + } + + +async def send_sticker( + sticker: str = "", + chat_id: str = "", + reply_to: str = "", +) -> dict: + """ + 向 chat_id(缺省取当前会话)发送一张内置贴纸(TIMFaceElem)。 + + Args: + sticker: 贴纸名称(如 "六六六")或 sticker_id(如 "278")。为空时随机发送一张。 + chat_id: 目标会话;缺省时使用当前会话上下文(HERMES_SESSION_CHAT_ID)。 + 格式:``direct:{account_id}`` / ``group:{group_code}`` / 或裸 account_id。 + reply_to: 群聊场景的引用消息 ID(可选)。 + + Returns: ``{"success": bool, ...}`` + """ + from gateway.session_context import get_session_env + from gateway.platforms.yuanbao_sticker import ( + get_sticker_by_id, + get_sticker_by_name, + get_random_sticker, + ) + + target = (chat_id or "").strip() or get_session_env("HERMES_SESSION_CHAT_ID", "") + if not target: + return { + "success": False, + "error": "chat_id is required (no active yuanbao session detected)", + } + + adapter = _get_active_adapter() + if adapter is None: + return {"success": False, "error": "Yuanbao adapter is not connected"} + + raw = (sticker or "").strip() + sticker_obj: Optional[dict] = None + if not raw: + sticker_obj = get_random_sticker() + else: + if raw.isdigit(): + sticker_obj = get_sticker_by_id(raw) + if sticker_obj is None: + sticker_obj = get_sticker_by_name(raw) + + if sticker_obj is None: + return { + "success": False, + "error": f"Sticker not found: {raw!r}. " + f"Use search_sticker first to discover available stickers.", + } + + try: + result = await adapter.send_sticker( + chat_id=target, + sticker_name=sticker_obj.get("name", ""), + reply_to=reply_to or None, + ) + except Exception as exc: + logger.exception("[yuanbao_tools] send_sticker error") + return {"success": False, "error": str(exc)} + + if getattr(result, "success", False): + return { + "success": True, + "chat_id": target, + "sticker": { + "sticker_id": sticker_obj.get("sticker_id", ""), + "name": sticker_obj.get("name", ""), + }, + "message_id": getattr(result, "message_id", None), + "note": "Sticker delivered to the chat. If you have additional text to say, reply now; otherwise end your turn without generating text.", + } + return { + "success": False, + "error": getattr(result, "error", "send_sticker failed"), + } + + +# Image extensions for media dispatch (mirrors MessageSender.IMAGE_EXTS) +_IMAGE_EXTS = frozenset({".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}) + + +async def send_dm( + group_code: str, + name: str, + message: str, + user_id: str = "", + media_files: Optional[List[Tuple[str, bool]]] = None, +) -> dict: + """ + Send a DM (private chat message) to a group member, with optional media. + + Workflow: + 1. If user_id is provided, send directly. + 2. Otherwise, search the group member list by name to resolve user_id. + 3. Send text via adapter.send_dm(), then iterate media_files by extension. + + Args: + group_code: The group where the target user belongs. + name: Target user's nickname (partial match, case-insensitive). + message: The message text to send. + user_id: (Optional) If already known, skip the member lookup. + media_files: (Optional) List of (file_path, is_voice) tuples to send + after the text message. Images are sent via + send_image_file; everything else via send_document. + """ + if not message and not media_files: + return {"success": False, "error": "message or media_files is required"} + + adapter = _get_active_adapter() + if adapter is None: + return {"success": False, "error": "Yuanbao adapter is not connected"} + + resolved_user_id = user_id.strip() if user_id else "" + resolved_nickname = name.strip() + + # Step 1: Resolve user_id from group member list if not provided + if not resolved_user_id: + if not group_code: + return {"success": False, "error": "group_code is required when user_id is not provided"} + if not name: + return {"success": False, "error": "name is required when user_id is not provided"} + + try: + raw = await adapter.get_group_member_list(group_code) + if raw is None: + return {"success": False, "error": "get_group_member_list returned None"} + + members = raw.get("members", []) + filt = name.strip().lower() + matched = [ + m for m in members + if filt in (m.get("nickname") or m.get("nick_name") or "").lower() + ] + + if not matched: + return { + "success": False, + "error": f'No member matching "{name}" found in group {group_code}.', + } + if len(matched) > 1: + # Multiple matches — return candidates for disambiguation + candidates = [ + { + "user_id": m.get("user_id", ""), + "nickname": m.get("nickname", m.get("nick_name", "")), + } + for m in matched + ] + return { + "success": False, + "error": f'Multiple members match "{name}". Please specify which one.', + "candidates": candidates, + } + + resolved_user_id = matched[0].get("user_id", "") + resolved_nickname = matched[0].get("nickname", matched[0].get("nick_name", name)) + except Exception as exc: + logger.exception("[yuanbao_tools] send_dm member lookup error") + return {"success": False, "error": str(exc)} + + if not resolved_user_id: + return {"success": False, "error": "Could not resolve user_id"} + + # Step 2: Send text DM + media + chat_id = f"direct:{resolved_user_id}" + last_result = None + errors: list[str] = [] + try: + if message and message.strip(): + last_result = await adapter.send_dm(resolved_user_id, message, group_code=group_code) + if not last_result.success: + errors.append(last_result.error or "text send failed") + + # Step 3: Send media files + for media_path, _is_voice in media_files or []: + ext = Path(media_path).suffix.lower() + if ext in _IMAGE_EXTS: + last_result = await adapter.send_image_file(chat_id, media_path, group_code=group_code) + else: + last_result = await adapter.send_document(chat_id, media_path, group_code=group_code) + if not last_result.success: + errors.append(last_result.error or "media send failed") + + if last_result is None: + return {"success": False, "error": "No deliverable text or media remained"} + + if errors and (last_result is None or not last_result.success): + return {"success": False, "error": "; ".join(errors)} + + result = { + "success": True, + "user_id": resolved_user_id, + "nickname": resolved_nickname, + "message_id": last_result.message_id, + "note": f'DM sent to "{resolved_nickname}" successfully.', + } + if errors: + result["note"] += f" (partial failure: {'; '.join(errors)})" + return result + except Exception as exc: + logger.exception("[yuanbao_tools] send_dm error") + return {"success": False, "error": str(exc)} + + +# --------------------------------------------------------------------------- +# Registry registration +# --------------------------------------------------------------------------- + +from tools.registry import registry, tool_result, tool_error # noqa: E402 + + +def _check_yuanbao(): + """Toolset availability check — True when running in a yuanbao gateway session.""" + try: + from gateway.session_context import get_session_env + if get_session_env("HERMES_SESSION_PLATFORM", "") == "yuanbao": + return True + except Exception: + pass + return _get_active_adapter() is not None + + +async def _handle_yb_query_group_info(args, **kw): + return tool_result(await get_group_info( + group_code=args.get("group_code", ""), + )) + + +async def _handle_yb_query_group_members(args, **kw): + return tool_result(await query_group_members( + group_code=args.get("group_code", ""), + action=args.get("action", "list_all"), + name=args.get("name", ""), + mention=bool(args.get("mention", False)), + )) + + +async def _handle_yb_send_dm(args, **kw): + # Resolve group_code: prefer explicit arg, fallback to session context. + group_code = args.get("group_code", "") + if not group_code: + try: + from gateway.session_context import get_session_env + chat_id = get_session_env("HERMES_SESSION_CHAT_ID", "") + # chat_id format: "group:" → extract the code part + if chat_id.startswith("group:"): + group_code = chat_id.split(":", 1)[1] + except Exception: + pass + + # Parse media_files: list of {{"path": str, "is_voice": bool}} → List[Tuple[str, bool]] + raw_media = args.get("media_files") or [] + media_files = [] + for item in raw_media: + if isinstance(item, dict): + media_files.append((item.get("path", ""), bool(item.get("is_voice", False)))) + elif isinstance(item, (list, tuple)) and len(item) >= 2: + media_files.append((str(item[0]), bool(item[1]))) + + # Extract MEDIA: tags embedded in the message text (LLM often puts + # file paths there instead of using the media_files parameter). + message = args.get("message", "") + from gateway.platforms.base import BasePlatformAdapter + embedded_media, message = BasePlatformAdapter.extract_media(message) + if embedded_media: + media_files.extend(embedded_media) + + return tool_result(await send_dm( + group_code=group_code, name=args.get("name", ""), + message=message, + user_id=args.get("user_id", ""), + media_files=media_files or None, + )) + + +async def _handle_yb_search_sticker(args, **kw): + return tool_result(await search_sticker( + query=args.get("query", ""), + limit=args.get("limit", 10), + )) + + +async def _handle_yb_send_sticker(args, **kw): + return tool_result(await send_sticker( + sticker=args.get("sticker", ""), + chat_id=args.get("chat_id", ""), + reply_to=args.get("reply_to", ""), + )) + + +_TOOLSET = "hermes-yuanbao" + +registry.register( + name="yb_query_group_info", + toolset=_TOOLSET, + schema={ + "name": "yb_query_group_info", + "description": ( + "Query basic info about a group (called '派/Pai' in the app), " + "including group name, owner, and member count." + ), + "parameters": { + "type": "object", + "properties": { + "group_code": { + "type": "string", + "description": "The unique group identifier (group_code).", + }, + }, + "required": ["group_code"], + }, + }, + handler=_handle_yb_query_group_info, + check_fn=_check_yuanbao, + is_async=True, + emoji="👥", +) + +registry.register( + name="yb_query_group_members", + toolset=_TOOLSET, + schema={ + "name": "yb_query_group_members", + "description": ( + "Query members of a group (called '派/Pai' in the app). " + "Use this tool when you need to @mention someone, find a user by name, " + "list bots (including Yuanbao AI), or list all members. " + "IMPORTANT: You MUST call this tool before @mentioning any user, " + "because you need the exact nickname to construct the @mention format." + ), + "parameters": { + "type": "object", + "properties": { + "group_code": { + "type": "string", + "description": "The unique group identifier (group_code).", + }, + "action": { + "type": "string", + "enum": ["find", "list_bots", "list_all"], + "description": ( + "find — search a user by name (use when you need to @mention or look up someone); " + "list_bots — list bots and Yuanbao AI assistants; " + "list_all — list all members." + ), + }, + "name": { + "type": "string", + "description": ( + "User name to search (partial match, case-insensitive). " + "Required for 'find'. Use the name the user mentioned in the conversation." + ), + }, + "mention": { + "type": "boolean", + "description": ( + "Set to true when you need to @mention/at someone in your reply. " + "The response will include the exact @mention format to use." + ), + }, + }, + "required": ["group_code", "action"], + }, + }, + handler=_handle_yb_query_group_members, + check_fn=_check_yuanbao, + is_async=True, + emoji="📋", +) + +registry.register( + name="yb_send_dm", + toolset=_TOOLSET, + schema={ + "name": "yb_send_dm", + "description": ( + "Send a private/direct message (DM) to a user in a group, with optional media files. " + "This tool automatically looks up the user by name in the group member list " + "and sends the message. Use this when someone asks to privately message / 私信 / DM a user. " + "Supports text, images, and file attachments. " + "You can also provide user_id directly if already known." + ), + "parameters": { + "type": "object", + "properties": { + "group_code": { + "type": "string", + "description": ( + "The group where the target user belongs. " + "Extract from chat_id: 'group:328306697' → '328306697'. " + "Required when user_id is not provided." + ), + }, + "name": { + "type": "string", + "description": ( + "Target user's display name (partial match, case-insensitive). " + "Required when user_id is not provided." + ), + }, + "message": { + "type": "string", + "description": "The message text to send as a DM. Can be empty if only sending media.", + }, + "user_id": { + "type": "string", + "description": ( + "Target user's account ID. If provided, skips the member lookup. " + "Usually obtained from a previous yb_query_group_members call." + ), + }, + "media_files": { + "type": "array", + "description": ( + "Optional list of media files to send along with the DM. " + "Images (.jpg/.png/.gif/.webp/.bmp) are sent as image messages; " + "other files are sent as document attachments." + ), + "items": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Absolute local file path of the media to send.", + }, + "is_voice": { + "type": "boolean", + "description": "Whether this file is a voice message (default false).", + }, + }, + "required": ["path"], + }, + }, + }, + "required": [], + }, + }, + handler=_handle_yb_send_dm, + check_fn=_check_yuanbao, + is_async=True, + emoji="✉️", +) + + +registry.register( + name="yb_search_sticker", + toolset=_TOOLSET, + schema={ + "name": "yb_search_sticker", + "description": ( + "Search the built-in Yuanbao sticker (TIM face / 表情包) catalogue by keyword. " + "Returns the top matching candidates with sticker_id, name, and description. " + "Use this BEFORE yb_send_sticker to discover the right sticker_id. " + "Sticker = 贴纸 = TIM face — NOT a message reaction. " + "Prefer sending a sticker over bare Unicode emoji when reacting/expressing emotion." + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": ( + "Search keyword (Chinese or English, e.g. '666', '比心', 'cool', '吃瓜'). " + "Empty string returns the first N stickers." + ), + }, + "limit": { + "type": "integer", + "description": "Max number of candidates to return (default 10, max 50).", + }, + }, + "required": [], + }, + }, + handler=_handle_yb_search_sticker, + check_fn=_check_yuanbao, + is_async=True, + emoji="🔍", +) + + +registry.register( + name="yb_send_sticker", + toolset=_TOOLSET, + schema={ + "name": "yb_send_sticker", + "description": ( + "Send a built-in sticker (TIMFaceElem / 贴纸表情) to the current Yuanbao chat. " + "Call yb_search_sticker first if you don't know the sticker_id/name. " + "Sticker = 贴纸 = TIM face — NOT a message reaction. " + "CRITICAL: Whenever the user asks you to send a sticker / 贴纸 / 表情包, you MUST " + "use this tool. DO NOT draw a PNG via execute_code / Pillow / matplotlib and " + "then call send_image_file — that produces a fake 'sticker' image instead of a " + "real TIM face and is the WRONG path. If no suitable sticker_id is known, call " + "yb_search_sticker first. When the recent thread shows users sending stickers, " + "prefer matching that tone by replying with a sticker instead of (or in " + "addition to) text." + ), + "parameters": { + "type": "object", + "properties": { + "sticker": { + "type": "string", + "description": ( + "Sticker name (e.g. '六六六', '比心', 'ok') or numeric sticker_id " + "(e.g. '278'). Empty string sends a random built-in sticker." + ), + }, + "chat_id": { + "type": "string", + "description": ( + "Target chat. Defaults to the current session. " + "Format: 'direct:{account_id}', 'group:{group_code}', or bare account_id." + ), + }, + "reply_to": { + "type": "string", + "description": "Optional ref_msg_id to quote-reply (group chat only).", + }, + }, + "required": [], + }, + }, + handler=_handle_yb_send_sticker, + check_fn=_check_yuanbao, + is_async=True, + emoji="🎨", +) diff --git a/toolsets.py b/toolsets.py index 1c113afe60..a444713f57 100644 --- a/toolsets.py +++ b/toolsets.py @@ -214,6 +214,18 @@ TOOLSETS = { "includes": [], }, + "yuanbao": { + "description": "Yuanbao platform tools - group info, member queries, DM, stickers", + "tools": [ + "yb_query_group_info", + "yb_query_group_members", + "yb_send_dm", + "yb_search_sticker", + "yb_send_sticker", + ], + "includes": [] + }, + "feishu_doc": { "description": "Read Feishu/Lark document content", "tools": ["feishu_doc_read"], @@ -434,6 +446,19 @@ TOOLSETS = { "includes": [] }, + "hermes-yuanbao": { + "description": "Yuanbao Bot 元宝消息平台工具集 - 群信息、成员查询、私聊、贴纸表情", + "tools": _HERMES_CORE_TOOLS + [ + "yb_query_group_info", + "yb_query_group_members", + "yb_send_dm", + "yb_search_sticker", + "yb_send_sticker", + ], + "module": "tools.yuanbao_tools", + "includes": [] + }, + "hermes-sms": { "description": "SMS bot toolset - interact with Hermes via SMS (Twilio)", "tools": _HERMES_CORE_TOOLS, @@ -449,7 +474,7 @@ TOOLSETS = { "hermes-gateway": { "description": "Gateway toolset - union of all messaging platform tools", "tools": [], - "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-bluebubbles", "hermes-homeassistant", "hermes-email", "hermes-sms", "hermes-mattermost", "hermes-matrix", "hermes-dingtalk", "hermes-feishu", "hermes-wecom", "hermes-wecom-callback", "hermes-weixin", "hermes-qqbot", "hermes-webhook"] + "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-bluebubbles", "hermes-homeassistant", "hermes-email", "hermes-sms", "hermes-mattermost", "hermes-matrix", "hermes-dingtalk", "hermes-feishu", "hermes-wecom", "hermes-wecom-callback", "hermes-weixin", "hermes-qqbot", "hermes-webhook", "hermes-yuanbao"] } } diff --git a/website/docs/user-guide/messaging/index.md b/website/docs/user-guide/messaging/index.md index 859a4d04ab..126ab8184f 100644 --- a/website/docs/user-guide/messaging/index.md +++ b/website/docs/user-guide/messaging/index.md @@ -1,12 +1,12 @@ --- sidebar_position: 1 title: "Messaging Gateway" -description: "Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, SMS, Email, Home Assistant, Mattermost, Matrix, DingTalk, Webhooks, or any OpenAI-compatible frontend via the API server — architecture and setup overview" +description: "Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, SMS, Email, Home Assistant, Mattermost, Matrix, DingTalk, Yuanbao, Webhooks, or any OpenAI-compatible frontend via the API server — architecture and setup overview" --- # Messaging Gateway -Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, SMS, Email, Home Assistant, Mattermost, Matrix, DingTalk, Feishu/Lark, WeCom, Weixin, BlueBubbles (iMessage), QQ, or your browser. The gateway is a single background process that connects to all your configured platforms, handles sessions, runs cron jobs, and delivers voice messages. +Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, SMS, Email, Home Assistant, Mattermost, Matrix, DingTalk, Feishu/Lark, WeCom, Weixin, BlueBubbles (iMessage), QQ, Yuanbao, or your browser. The gateway is a single background process that connects to all your configured platforms, handles sessions, runs cron jobs, and delivers voice messages. For the full voice feature set — including CLI microphone mode, spoken replies in messaging, and Discord voice-channel conversations — see [Voice Mode](/docs/user-guide/features/voice-mode) and [Use Voice Mode with Hermes](/docs/guides/use-voice-mode-with-hermes). @@ -31,6 +31,7 @@ For the full voice feature set — including CLI microphone mode, spoken replies | Weixin | ✅ | ✅ | ✅ | — | — | ✅ | ✅ | | BlueBubbles | — | ✅ | ✅ | — | ✅ | ✅ | — | | QQ | ✅ | ✅ | ✅ | — | — | ✅ | — | +| Yuanbao | ✅ | ✅ | ✅ | — | — | ✅ | ✅ | **Voice** = TTS audio replies and/or voice message transcription. **Images** = send/receive images. **Files** = send/receive file attachments. **Threads** = threaded conversations. **Reactions** = emoji reactions on messages. **Typing** = typing indicator while processing. **Streaming** = progressive message updates via editing. @@ -57,6 +58,7 @@ flowchart TB wx[Weixin] bb[BlueBubbles] qq[QQ] + yb[Yuanbao] api["API Server
(OpenAI-compatible)"] wh[Webhooks] end @@ -83,6 +85,7 @@ flowchart TB wx --> store bb --> store qq --> store + yb --> store api --> store wh --> store store --> agent @@ -386,6 +389,7 @@ Each platform has its own toolset: | Weixin | `hermes-weixin` | Full tools including terminal | | BlueBubbles | `hermes-bluebubbles` | Full tools including terminal | | QQBot | `hermes-qqbot` | Full tools including terminal | +| Yuanbao | `hermes-yuanbao` | Full tools including terminal | | API Server | `hermes` (default) | Full tools including terminal | | Webhooks | `hermes-webhook` | Full tools including terminal | @@ -408,5 +412,6 @@ Each platform has its own toolset: - [Weixin Setup (WeChat)](weixin.md) - [BlueBubbles Setup (iMessage)](bluebubbles.md) - [QQBot Setup](qqbot.md) +- [Yuanbao Setup](yuanbao.md) - [Open WebUI + API Server](open-webui.md) -- [Webhooks](webhooks.md) +- [Webhooks](webhooks.md) \ No newline at end of file diff --git a/website/docs/user-guide/messaging/yuanbao.md b/website/docs/user-guide/messaging/yuanbao.md new file mode 100644 index 0000000000..63a5a50e90 --- /dev/null +++ b/website/docs/user-guide/messaging/yuanbao.md @@ -0,0 +1,341 @@ +--- +sidebar_position: 16 +title: "Yuanbao" +description: "Connect Hermes Agent to the Yuanbao enterprise messaging platform via WebSocket gateway" +--- + +# Yuanbao + +Connect Hermes to [Yuanbao](https://yuanbao.tencent.com/), Tencent's enterprise messaging platform. The adapter uses a WebSocket gateway for real-time message delivery and supports both direct (C2C) and group conversations. + +:::info +Yuanbao is an enterprise messaging platform primarily used within Tencent and enterprise environments. It uses WebSocket for real-time communication, HMAC-based authentication, and supports rich media including images, files, and voice messages. +::: + +## Prerequisites + +- A Yuanbao account with bot creation permissions +- Yuanbao APP_ID and APP_SECRET (from platform admin) +- Python packages: `websockets` and `httpx` +- For media support: `aiofiles` + +Install the required dependencies: + +```bash +pip install websockets httpx aiofiles +``` + +## Setup + +### 1. Create a Bot in Yuanbao + +1. Download the Yuanbao app from [https://yuanbao.tencent.com/](https://yuanbao.tencent.com/) +2. In the app, go to **PAI → My Bot** and create a new bot +3. After the bot is created, copy the **APP_ID** and **APP_SECRET** + +### 2. Run the Setup Wizard + +The easiest way to configure Yuanbao is through the interactive setup: + +```bash +hermes gateway setup +``` + +Select **Yuanbao** when prompted. The wizard will: + +1. Ask for your APP_ID +2. Ask for your APP_SECRET +3. Save the configuration automatically + +:::tip +The WebSocket URL and API Domain have sensible defaults built in. You only need to provide APP_ID and APP_SECRET to get started. +::: + +### 3. Configure Environment Variables + +After initial setup, verify these variables in `~/.hermes/.env`: + +```bash +# Required +YUANBAO_APP_ID=your-app-id +YUANBAO_APP_SECRET=your-app-secret +YUANBAO_WS_URL=wss://api.yuanbao.example.com/ws +YUANBAO_API_DOMAIN=https://api.yuanbao.example.com + +# Optional: bot account ID (normally obtained automatically from sign-token) +# YUANBAO_BOT_ID=your-bot-id + +# Optional: internal routing environment (e.g. test/staging/production) +# YUANBAO_ROUTE_ENV=production + +# Optional: home channel for cron/notifications (format: direct: or group:) +YUANBAO_HOME_CHANNEL=direct:bot_account_id +YUANBAO_HOME_CHANNEL_NAME="Bot Notifications" + +# Optional: restrict access (legacy, see Access Control below for fine-grained policies) +YUANBAO_ALLOWED_USERS=user_account_1,user_account_2 +``` + +### 4. Start the Gateway + +```bash +hermes gateway +``` + +The adapter will connect to the Yuanbao WebSocket gateway, authenticate using HMAC signatures, and begin processing messages. + +## Features + +- **WebSocket gateway** — real-time bidirectional communication +- **HMAC authentication** — secure request signing with APP_ID/APP_SECRET +- **C2C messaging** — direct user-to-bot conversations +- **Group messaging** — conversations in group chats +- **Media support** — images, files, and voice messages via COS (Cloud Object Storage) +- **Markdown formatting** — messages are automatically chunked for Yuanbao's size limits +- **Message deduplication** — prevents duplicate processing of the same message +- **Heartbeat/keep-alive** — maintains WebSocket connection stability +- **Typing indicators** — shows "typing…" status while the agent processes +- **Automatic reconnection** — handles WebSocket disconnections with exponential backoff +- **Group information queries** — retrieve group details and member lists +- **Sticker/Emoji support** — send TIMFaceElem stickers and emoji in conversations +- **Auto-sethome** — first user to message the bot is automatically set as the home channel owner +- **Slow-response notification** — sends a waiting message when the agent takes longer than expected + +## Configuration Options + +### Chat ID Formats + +Yuanbao uses prefixed identifiers depending on conversation type: + +| Chat Type | Format | Example | +|-----------|--------|---------| +| Direct message (C2C) | `direct:` | `direct:user123` | +| Group message | `group:` | `group:grp456` | + +### Media Uploads + +The Yuanbao adapter automatically handles media uploads via COS (Tencent Cloud Object Storage): + +- **Images**: Supports JPEG, PNG, GIF, WebP +- **Files**: Supports all common document types +- **Voice**: Supports WAV, MP3, OGG + +Media URLs are automatically validated and downloaded before upload to prevent SSRF attacks. + +## Home Channel + +Use the `/sethome` command in any Yuanbao chat (DM or group) to designate it as the **home channel**. Scheduled tasks (cron jobs) deliver their results to this channel. + +:::tip Auto-sethome +If no home channel is configured, the first user to message the bot will be automatically set as the home channel owner. If the current home channel is a group chat, the first DM will upgrade it to a direct channel. +::: + +You can also set it manually in `~/.hermes/.env`: + +```bash +YUANBAO_HOME_CHANNEL=direct:user_account_id +# or for a group: +# YUANBAO_HOME_CHANNEL=group:group_code +YUANBAO_HOME_CHANNEL_NAME="My Bot Updates" +``` + +### Example: Set Home Channel + +1. Start a conversation with the bot in Yuanbao +2. Send the command: `/sethome` +3. The bot responds: "Home channel set to [chat_name] with ID [chat_id]. Cron jobs will deliver to this location." +4. Future cron jobs and notifications will be sent to this channel + +### Example: Cron Job Delivery + +Create a cron job: + +```bash +/cron "0 9 * * *" Check server status +``` + +The scheduled output will be delivered to your Yuanbao home channel every day at 9 AM. + +## Usage Tips + +### Starting a Conversation + +Send any message to the bot in Yuanbao: + +``` +hello +``` + +The bot responds in the same conversation thread. + +### Available Commands + +All standard Hermes commands work on Yuanbao: + +| Command | Description | +|---------|-------------| +| `/new` | Start a fresh conversation | +| `/model [provider:model]` | Show or change the model | +| `/sethome` | Set this chat as the home channel | +| `/status` | Show session info | +| `/help` | Show available commands | + +### Sending Files + +To send a file to the bot, simply attach it directly in the Yuanbao chat. The bot will automatically download and process the file attachment. + +You can also include a message with the attachment: + +``` +Please analyze this document +``` + +### Receiving Files + +When you ask the bot to create or export a file, it sends the file directly to your Yuanbao chat. + +## Troubleshooting + +### Bot is online but not responding to messages + +**Cause**: Authentication failed during WebSocket handshake. + +**Fix**: +1. Verify APP_ID and APP_SECRET are correct +2. Check that the WebSocket URL is accessible +3. Ensure the bot account has proper permissions +4. Review gateway logs: `tail -f ~/.hermes/logs/gateway.log` + +### "Connection refused" error + +**Cause**: WebSocket URL is unreachable or incorrect. + +**Fix**: +1. Verify the WebSocket URL format (should start with `wss://`) +2. Check network connectivity to the Yuanbao API domain +3. Confirm firewall allows WebSocket connections +4. Test URL with: `curl -I https://[YUANBAO_API_DOMAIN]` + +### Media uploads fail + +**Cause**: COS credentials are invalid or media server is unreachable. + +**Fix**: +1. Verify API_DOMAIN is correct +2. Check that media upload permissions are enabled for your bot +3. Ensure the media file is accessible and not corrupted +4. Check COS bucket configuration with platform admin + +### Messages not delivered to home channel + +**Cause**: Home channel ID format is incorrect or cron job hasn't triggered. + +**Fix**: +1. Verify YUANBAO_HOME_CHANNEL is in correct format +2. Test with `/sethome` command to auto-detect correct format +3. Check cron job schedule with `/status` +4. Verify bot has send permissions in the target chat + +### Frequent disconnections + +**Cause**: WebSocket connection is unstable or network is unreliable. + +**Fix**: +1. Check gateway logs for error patterns +2. Increase heartbeat timeout in connection settings +3. Ensure stable network connection to Yuanbao API +4. Consider enabling verbose logging: `HERMES_LOG_LEVEL=debug` + +## Access Control + +Yuanbao supports fine-grained access control for both DM and group conversations: + +```bash +# DM policy: open (default) | allowlist | disabled +YUANBAO_DM_POLICY=open +# Comma-separated user IDs allowed to DM the bot (only used when DM_POLICY=allowlist) +YUANBAO_DM_ALLOW_FROM=user_id_1,user_id_2 + +# Group policy: open (default) | allowlist | disabled +YUANBAO_GROUP_POLICY=open +# Comma-separated group codes allowed (only used when GROUP_POLICY=allowlist) +YUANBAO_GROUP_ALLOW_FROM=group_code_1,group_code_2 +``` + +These can also be set in `config.yaml`: + +```yaml +platforms: + yuanbao: + extra: + dm_policy: allowlist + dm_allow_from: "user1,user2" + group_policy: open + group_allow_from: "" +``` + +## Advanced Configuration + +### Message Chunking + +Yuanbao has a maximum message size. Hermes automatically chunks large responses with Markdown-aware splitting (respects code fences, tables, and paragraph boundaries). + +### Connection Parameters + +The following connection parameters are built into the adapter with sensible defaults: + +| Parameter | Default Value | Description | +|-----------|---------------|-------------| +| WebSocket connect timeout | 15 seconds | Time to wait for WS handshake | +| Heartbeat interval | 30 seconds | Ping frequency to keep connection alive | +| Max reconnect attempts | 100 | Maximum number of reconnection tries | +| Reconnect backoff | 1s → 60s (exponential) | Wait time between reconnect attempts | +| Reply heartbeat interval | 2 seconds | RUNNING status send frequency | +| Send timeout | 30 seconds | Timeout for outbound WS messages | + +:::note +These values are currently not configurable via environment variables. They are optimized for typical Yuanbao deployments. +::: + +### Verbose Logging + +Enable debug logging to troubleshoot connection issues: + +```bash +HERMES_LOG_LEVEL=debug hermes gateway +``` + +## Integration with Other Features + +### Cron Jobs + +Schedule tasks that run on Yuanbao: + +``` +/cron "0 */4 * * *" Report system health +``` + +Results are delivered to your home channel. + +### Background Tasks + +Run long operations without blocking the conversation: + +``` +/background Analyze all files in the archive +``` + +### Cross-Platform Messages + +Send a message from CLI to Yuanbao: + +```bash +hermes chat -q "Send 'Hello from CLI' to yuanbao:group:group_code" +``` + +## Related Documentation + +- [Messaging Gateway Overview](./index.md) +- [Slash Commands Reference](/docs/reference/slash-commands.md) +- [Cron Jobs](/docs/user-guide/features/cron-jobs.md) +- [Background Tasks](/docs/guides/tips.md#background-tasks) \ No newline at end of file