mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
4755 lines
181 KiB
Python
4755 lines
181 KiB
Python
"""
|
|
Yuanbao platform adapter.
|
|
|
|
Connects to the Yuanbao WebSocket gateway, handles authentication (AUTH_BIND),
|
|
heartbeat, reconnection, message receive (T05) and send (T06).
|
|
|
|
Configuration in config.yaml (or via env vars):
|
|
platforms:
|
|
yuanbao:
|
|
extra:
|
|
app_id: "..." # or YUANBAO_APP_ID
|
|
app_secret: "..." # or YUANBAO_APP_SECRET
|
|
bot_id: "..." # or YUANBAO_BOT_ID (optional, returned by sign-token)
|
|
ws_url: "wss://..." # or YUANBAO_WS_URL
|
|
api_domain: "https://..." # or YUANBAO_API_DOMAIN
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import collections
|
|
import dataclasses
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import secrets
|
|
import time
|
|
import urllib.parse
|
|
import uuid
|
|
from datetime import datetime, timezone, timedelta
|
|
from pathlib import Path
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple
|
|
|
|
import sys
|
|
|
|
import httpx
|
|
|
|
try:
|
|
import websockets
|
|
import websockets.exceptions
|
|
WEBSOCKETS_AVAILABLE = True
|
|
except ImportError:
|
|
WEBSOCKETS_AVAILABLE = False
|
|
websockets = None # type: ignore[assignment]
|
|
|
|
from gateway.config import Platform, PlatformConfig
|
|
from gateway.platforms.base import (
|
|
BasePlatformAdapter,
|
|
MessageEvent,
|
|
MessageType,
|
|
SendResult,
|
|
cache_document_from_bytes,
|
|
cache_image_from_bytes,
|
|
)
|
|
from gateway.platforms.helpers import MessageDeduplicator
|
|
from gateway.platforms.yuanbao_media import (
|
|
download_url as media_download_url,
|
|
get_cos_credentials,
|
|
upload_to_cos,
|
|
build_image_msg_body,
|
|
build_file_msg_body,
|
|
guess_mime_type,
|
|
md5_hex,
|
|
)
|
|
from gateway.platforms.yuanbao_proto import (
|
|
CMD_TYPE,
|
|
_fields_to_dict,
|
|
_get_string,
|
|
_get_varint,
|
|
_parse_fields,
|
|
WS_HEARTBEAT_RUNNING,
|
|
WS_HEARTBEAT_FINISH,
|
|
HERMES_INSTANCE_ID,
|
|
decode_conn_msg,
|
|
decode_inbound_push,
|
|
decode_query_group_info_rsp,
|
|
decode_get_group_member_list_rsp,
|
|
encode_auth_bind,
|
|
encode_ping,
|
|
encode_push_ack,
|
|
encode_send_c2c_message,
|
|
encode_send_group_message,
|
|
encode_send_private_heartbeat,
|
|
encode_send_group_heartbeat,
|
|
encode_query_group_info,
|
|
encode_get_group_member_list,
|
|
next_seq_no,
|
|
)
|
|
from gateway.session import SessionSource, build_session_key
|
|
|
|
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Version / platform identity (sent in AUTH_BIND and sign-token headers)
# ---------------------------------------------------------------------------
try:
    from hermes_cli import __version__ as _HERMES_VERSION
except ImportError:
    # hermes_cli (or its __version__) cannot be imported: use a placeholder.
    _HERMES_VERSION = "0.0.0"

# App and bot versions both mirror the hermes_cli version.
_APP_VERSION = _BOT_VERSION = _HERMES_VERSION
# Single source of truth is yuanbao_proto.HERMES_INSTANCE_ID; only stringified here.
_YUANBAO_INSTANCE_ID = str(HERMES_INSTANCE_ID)
_OPERATION_SYSTEM = sys.platform
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Module-level constants
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# -- Endpoints ----------------------------------------------------------------
DEFAULT_WS_GATEWAY_URL: str = "wss://bot-wss.yuanbao.tencent.com/wss/connection"
DEFAULT_API_DOMAIN: str = "https://bot.yuanbao.tencent.com"

# -- Connection timing --------------------------------------------------------
HEARTBEAT_INTERVAL_SECONDS: float = 30.0
CONNECT_TIMEOUT_SECONDS: float = 15.0
AUTH_TIMEOUT_SECONDS: float = 10.0
MAX_RECONNECT_ATTEMPTS: int = 100
DEFAULT_SEND_TIMEOUT: float = 30.0  # WS biz request timeout

# WebSocket close codes that signal a permanent error: never reconnect on these.
NO_RECONNECT_CLOSE_CODES: set[int] = {4012, 4013, 4014, 4018, 4019, 4021}

# Number of consecutive missed pongs after which a reconnect is triggered.
HEARTBEAT_TIMEOUT_THRESHOLD: int = 2

# -- Auth error code classification -------------------------------------------
AUTH_FAILED_CODES: set[int] = {4001, 4002, 4003}  # permanent auth failure, re-sign token
AUTH_RETRYABLE_CODES: set[int] = {4010, 4011, 4099}  # transient, can retry with same token

# -- Reply heartbeat ----------------------------------------------------------
REPLY_HEARTBEAT_INTERVAL_S: float = 2.0  # send RUNNING every 2 seconds
REPLY_HEARTBEAT_TIMEOUT_S: float = 30.0  # auto-stop after 30 seconds of inactivity

# -- Reply-to reference -------------------------------------------------------
REPLY_REF_TTL_S: float = 300.0  # reference dedup TTL (5 minutes)

# -- Slow-response hint -------------------------------------------------------
# When the agent produces no data for this many seconds, a waiting message is pushed.
SLOW_RESPONSE_TIMEOUT_S: float = 120.0
SLOW_RESPONSE_MESSAGE: str = "任务有点复杂,正在努力处理中,请耐心等待..."

# Yuanbao resource reference anchors as they appear in transcript text, e.g.
#   [image|ybres:abc123]  [file:report.pdf|ybres:xyz789]  [voice|ybres:...]
# Group 1 is the kind (with optional ":filename" for files), group 2 the resource id.
_YB_RES_REF_RE: re.Pattern[str] = re.compile(
    r"\[(image|voice|video|file(?::[^|\]]*)?)\|ybres:([A-Za-z0-9_\-]+)\]"
)

# Trailing page indicator such as " (1/3)" appended by BasePlatformAdapter.
_INDICATOR_RE: re.Pattern[str] = re.compile(r'\s*\(\d+/\d+\)$')

# -- Observed-media backfill --------------------------------------------------
# How many recent transcript messages to scan.
OBSERVED_MEDIA_BACKFILL_LOOKBACK: int = 50
# Upper bound on resource references resolved per inbound turn.
OBSERVED_MEDIA_BACKFILL_MAX_RESOLVE_PER_TURN: int = 12
|
|
|
|
class MarkdownProcessor:
    """Markdown helpers used when formatting replies for the Yuanbao platform.

    The class is a pure namespace (only static/class methods) covering:
    - fence detection and streaming merge
    - table row detection and sanitization
    - paragraph-boundary splitting
    - atomic-block extraction and chunk splitting
    - outer markdown fence stripping
    - the Markdown hint appended to the system prompt
    """

    # Sentence terminator (CJK full-width or ASCII) immediately followed by a newline.
    _SENTENCE_END_RE = re.compile(r'[。!?.!?]\n')
    # Table separator row with at least two columns, e.g. "| --- | :-: |".
    _SEPARATOR_ROW_RE = re.compile(r'^\|[\s\-:]+(\|[\s\-:]+)+\|$')
    # Opening line of a wrapper fence: a bare fence, or one tagged md / markdown.
    _OUTER_FENCE_OPEN_RE = re.compile(r'^```(?:markdown|md)?\s*$', re.IGNORECASE)

    # -- Shared predicate --------------------------------------------------

    @staticmethod
    def _is_table_row(line: str) -> bool:
        """Return True if *line*, ignoring surrounding whitespace, starts and ends with '|'."""
        stripped = line.strip()
        return stripped.startswith('|') and stripped.endswith('|')

    # -- Fence detection ---------------------------------------------------

    @staticmethod
    def has_unclosed_fence(text: str) -> bool:
        """Report whether *text* ends inside an open code fence.

        Every line that starts with three backticks toggles the in/out state,
        so an odd number of such lines means the last fence was never closed.

        Args:
            text: Markdown text to check.

        Returns:
            True if the text ends with an unclosed fence, otherwise False.
        """
        fence_lines = sum(1 for line in text.split('\n') if line.startswith('```'))
        return fence_lines % 2 == 1

    # -- Table detection ---------------------------------------------------

    @staticmethod
    def ends_with_table_row(text: str) -> bool:
        """Report whether the last non-empty line of *text* is a table row.

        Args:
            text: Text to check.

        Returns:
            True if the last non-empty line starts and ends with '|'.
        """
        trimmed = text.rstrip()
        if not trimmed:
            return False
        return MarkdownProcessor._is_table_row(trimmed.rsplit('\n', 1)[-1])

    # -- Paragraph boundary splitting --------------------------------------

    @staticmethod
    def split_at_paragraph_boundary(
        text: str,
        max_chars: int,
        len_fn: Optional[Callable[[str], int]] = None,
    ) -> tuple[str, str]:
        """Split *text* at the best boundary that fits within *max_chars*.

        Split priority inside the fitting window:
            1. Last blank line (paragraph boundary)
            2. Last newline that follows sentence-ending punctuation
               (Chinese or English period / question mark / exclamation mark)
            3. Last newline
            4. Forced cut at the end of the window

        Args:
            text: Text to split.
            max_chars: Maximum size of the head, measured with *len_fn*.
            len_fn: Optional custom length function (e.g. UTF-16 length);
                defaults to the built-in ``len``.

        Returns:
            ``(head, tail)`` with ``head + tail == text``. ``tail`` is empty
            when the whole text already fits. ``head`` can be empty when not
            even one character fits the limit.
        """
        measure = len_fn or len
        if measure(text) <= max_chars:
            return text, ''

        # Largest prefix of `text` that fits. With a custom length function a
        # plain slice is not enough, so binary-search the prefix length
        # (assumes len_fn is non-decreasing in prefix length).
        if measure is len:
            # Clamp so a negative limit cannot turn into a from-the-end slice.
            window = text[:max(max_chars, 0)]
        else:
            lo, hi = 0, len(text)
            while lo < hi:
                mid = (lo + hi + 1) // 2
                if measure(text[:mid]) <= max_chars:
                    lo = mid
                else:
                    hi = mid - 1
            window = text[:lo]

        # 1. Last blank line; the head keeps both newline characters.
        blank = window.rfind('\n\n')
        if blank > 0:
            return text[:blank + 2], text[blank + 2:]

        # 2. Last newline directly after sentence-ending punctuation.
        sentence_cut = max(
            (m.end() for m in MarkdownProcessor._SENTENCE_END_RE.finditer(window)),
            default=-1,
        )
        if sentence_cut > 0:
            return text[:sentence_cut], text[sentence_cut:]

        # 3. Last newline of any kind.
        newline = window.rfind('\n')
        if newline > 0:
            return text[:newline + 1], text[newline + 1:]

        # 4. Nothing better: cut at the window edge.
        cut = len(window)
        return text[:cut], text[cut:]

    # -- Atomic block helpers ----------------------------------------------

    @staticmethod
    def is_fence_atom(text: str) -> bool:
        """Determine whether an atomic block is a code block (starts with a fence)."""
        return text.lstrip().startswith('```')

    @staticmethod
    def is_table_atom(text: str) -> bool:
        """Determine whether an atomic block is a table (first line is a table row)."""
        return MarkdownProcessor._is_table_row(text.split('\n', 1)[0])

    @staticmethod
    def split_into_atoms(text: str) -> list[str]:
        """Split *text* into indivisible "atomic blocks".

        Atom kinds:
            - Code block: from an opening fence line through its closing fence
              line (both included); an unterminated fence runs to the end.
            - Table: a run of consecutive table-row lines.
            - Paragraph: consecutive plain lines.

        Blank lines outside fences separate atoms and are not part of any atom.

        Args:
            text: Markdown text to split.

        Returns:
            List of atomic block strings (all non-empty).
        """
        is_table_row = MarkdownProcessor._is_table_row
        atoms: list[str] = []
        pending: list[str] = []
        in_fence = False

        def flush() -> None:
            # Emit the accumulated lines as one atom (if not blank) and reset.
            if pending:
                atom = '\n'.join(pending)
                if atom.strip():
                    atoms.append(atom)
                pending.clear()

        for line in text.split('\n'):
            is_fence_line = line.startswith('```')
            if in_fence:
                pending.append(line)
                if is_fence_line:
                    # Closing fence: the code block is complete.
                    in_fence = False
                    flush()
            elif is_fence_line:
                # Opening fence terminates whatever came before it.
                flush()
                in_fence = True
                pending.append(line)
            elif not line.strip():
                flush()
            else:
                # A switch between table rows and plain text starts a new atom.
                if pending and is_table_row(pending[-1]) != is_table_row(line):
                    flush()
                pending.append(line)

        flush()
        return atoms

    # -- Core: chunk splitting ---------------------------------------------

    @classmethod
    def chunk_markdown_text(
        cls,
        text: str,
        max_chars: int = 4000,
        len_fn: Optional[Callable[[str], int]] = None,
    ) -> list[str]:
        """Split Markdown text into chunks of at most *max_chars*.

        Guarantees:
            - Each chunk is <= max_chars (unless a single code block or table
              on its own exceeds the limit).
            - Code blocks are never split in the middle.
            - Tables are emitted as atomic blocks.
            - Oversized paragraphs are split at paragraph/sentence/line
              boundaries where possible.
            - Small neighbouring chunks are merged when the result still fits.

        Args:
            text: Markdown text to split.
            max_chars: Max size per chunk, default 4000.
            len_fn: Optional custom length function (e.g. UTF-16 length);
                defaults to the built-in ``len``.

        Returns:
            List of non-empty text chunks.

        Raises:
            ValueError: If *text* is non-empty and *max_chars* is not positive
                (no chunking can make progress with such a limit).
        """
        measure = len_fn or len

        if not text:
            return []
        if max_chars <= 0:
            raise ValueError(f"max_chars must be positive, got {max_chars}")
        if measure(text) <= max_chars:
            return [text]

        # Phase 1: atomic blocks.
        atoms = cls.split_into_atoms(text)

        # Phase 2: greedily pack atoms, joined by blank lines, into chunks.
        chunks: list[str] = []
        indivisible: set[int] = set()  # indexes into `chunks` that must not be split
        pending: list[str] = []
        pending_len = 0

        for atom in atoms:
            atom_len = measure(atom)
            sep_len = 2 if pending else 0

            if pending and pending_len + sep_len + atom_len > max_chars:
                chunks.append('\n\n'.join(pending))
                pending = []
                pending_len = 0
                sep_len = 0

            if (not pending
                    and atom_len > max_chars
                    and (cls.is_fence_atom(atom) or cls.is_table_atom(atom))):
                # An oversized code block/table is passed through whole.
                indivisible.add(len(chunks))
                chunks.append(atom)
                continue

            pending.append(atom)
            pending_len += sep_len + atom_len

        if pending:
            chunks.append('\n\n'.join(pending))

        # Phase 3: split still-oversized, divisible chunks at paragraph boundaries.
        pieces: list[str] = []
        for idx, chunk in enumerate(chunks):
            if (measure(chunk) <= max_chars
                    or idx in indivisible
                    or cls.has_unclosed_fence(chunk)):
                pieces.append(chunk)
                continue

            rest = chunk
            while measure(rest) > max_chars:
                head, rest = cls.split_at_paragraph_boundary(
                    rest, max_chars, len_fn=len_fn,
                )
                if not head:
                    # Not even one character fits (wide char under a custom
                    # len_fn): cut by code points so the loop always advances.
                    head, rest = rest[:max_chars], rest[max_chars:]
                if head:
                    pieces.append(head)
            if rest:
                pieces.append(rest)

        # Phase 4: merge neighbours when the combination still fits.
        if len(pieces) > 1:
            merged: list[str] = [pieces[0]]
            for piece in pieces[1:]:
                prev = merged[-1]
                # Atoms never end in a newline, so a trailing '\n' marks a piece
                # cut out of a paragraph in phase 3 whose continuation follows.
                # Re-join those verbatim instead of inserting a blank line.
                glue = '' if prev.endswith('\n') else '\n\n'
                candidate = prev + glue + piece
                if measure(candidate) <= max_chars:
                    merged[-1] = candidate
                else:
                    merged.append(piece)
            pieces = merged

        return [piece for piece in pieces if piece]

    # -- Block separator inference -----------------------------------------

    @classmethod
    def infer_block_separator(cls, prev_chunk: str, next_chunk: str) -> str:
        """Infer the separator to place between two adjacent chunks.

        Rules (aligned with TS markdown-stream.ts):
            - Previous chunk ends with a fence, or next chunk starts with one:
              single newline.
            - Previous chunk ends with a table row and next chunk starts with
              one (continued table): single newline.
            - Otherwise: blank line (paragraph separator).

        Args:
            prev_chunk: Previous chunk.
            next_chunk: Next chunk.

        Returns:
            Either a single newline or two newlines.
        """
        prev_tail = prev_chunk.rstrip()
        next_head = next_chunk.lstrip()

        if prev_tail.endswith('```') or next_head.startswith('```'):
            return '\n'

        if cls.ends_with_table_row(prev_chunk) and cls._is_table_row(next_head.split('\n', 1)[0]):
            return '\n'

        return '\n\n'

    # -- Streaming fence merge ---------------------------------------------

    @classmethod
    def merge_block_streaming_fences(cls, chunks: list[str]) -> list[str]:
        """Merge chunks that were cut in the middle of a code fence.

        While the accumulated chunk has an unclosed fence and more chunks
        remain, the next chunk is appended to it (whatever it starts with),
        using ``infer_block_separator`` to pick the joiner. Merging stops as
        soon as the fence is closed or the input is exhausted.

        Args:
            chunks: Original chunk list.

        Returns:
            Merged chunk list (length <= original length).
        """
        if not chunks:
            return []

        result: list[str] = []
        i = 0
        while i < len(chunks):
            current = chunks[i]
            while cls.has_unclosed_fence(current) and i + 1 < len(chunks):
                following = chunks[i + 1]
                current = current + cls.infer_block_separator(current, following) + following
                i += 1
            result.append(current)
            i += 1

        return result

    # -- Outer fence stripping ---------------------------------------------

    @staticmethod
    def strip_outer_markdown_fence(text: str) -> str:
        """Strip a fence that wraps the entire reply.

        The wrapper is removed only when the first line is a fence that is
        bare or tagged ``md`` / ``markdown`` (case-insensitive) and the last
        line is a plain closing fence.

        Args:
            text: Text to process.

        Returns:
            The inner content, or the original text when there is no match.
        """
        if not text:
            return text

        lines = text.split('\n')
        if len(lines) < 3:
            return text

        if not MarkdownProcessor._OUTER_FENCE_OPEN_RE.match(lines[0].strip()):
            return text
        if lines[-1].strip() != '```':
            return text

        return '\n'.join(lines[1:-1])

    # -- Table sanitization ------------------------------------------------

    @staticmethod
    def sanitize_markdown_table(text: str) -> str:
        """Clean up common formatting issues in AI-generated Markdown tables.

        For every line that looks like a table row:
            1. Empty rows (only pipes and whitespace) are dropped.
            2. Separator rows are compacted: ``| --- | --- |`` becomes ``|---|---|``.
            3. Other rows lose their leading/trailing whitespace.
        Non-table lines are passed through unchanged.

        Args:
            text: Markdown text that may contain tables.

        Returns:
            Sanitized text.
        """
        if '|' not in text:
            return text

        cleaned: list[str] = []
        for line in text.split('\n'):
            stripped = line.strip()

            if not MarkdownProcessor._is_table_row(stripped):
                cleaned.append(line)
                continue

            # Empty rows must be tested before the separator pattern, which
            # also matches rows made of nothing but pipes and whitespace.
            if not stripped.replace('|', '').strip():
                continue

            if MarkdownProcessor._SEPARATOR_ROW_RE.match(stripped):
                cleaned.append('|'.join(
                    cell.strip() if cell.strip() else cell
                    for cell in stripped.split('|')
                ))
            else:
                cleaned.append(stripped)

        return '\n'.join(cleaned)

    # -- Markdown hint prompt ----------------------------------------------

    @staticmethod
    def markdown_hint_system_prompt() -> str:
        """Return the Markdown rendering hint appended to the system prompt.

        The hint tells the model that the platform renders Markdown: code
        blocks, tables, and bold/italic text.
        """
        return (
            "The current platform supports Markdown rendering. You can use the following formats:\n"
            "- Code blocks: ```language\\ncode\\n```\n"
            "- Tables: | col1 | col2 |\\n|---|---|\\n| val1 | val2 |\n"
            "- Bold: **text** / Italic: *text*\n"
            "Please use Markdown formatting when appropriate to improve readability."
        )
|
|
|
|
class SignManager:
    """Sign-token handling for the Yuanbao platform.

    Covers token acquisition, caching, signature computation and automatic
    retry. All state (cache, locks) lives in class-level attributes so that
    one shared store serves the whole process.
    """

    # -- Constants ---------------------------------------------------------

    TOKEN_PATH = "/api/v5/robotLogic/sign-token"

    RETRYABLE_CODE = 10099
    MAX_RETRIES = 3
    RETRY_DELAY_S = 1.0

    #: Early refresh margin (seconds): a token counts as expired 60s early.
    CACHE_REFRESH_MARGIN_S = 60

    #: HTTP timeout (seconds)
    HTTP_TIMEOUT_S = 10.0

    #: Lifetime (seconds) assumed when the API reports no usable duration.
    DEFAULT_TOKEN_TTL_S = 3600

    # -- Class-level shared state ------------------------------------------

    # app_key -> {"token", "bot_id", "duration", "product", "source", "expire_ts"}
    _cache: ClassVar[dict[str, dict[str, Any]]] = {}

    # Per-app_key refresh locks, preventing concurrent duplicate sign-token
    # requests. Created lazily by get_refresh_lock(); clear_locks() empties
    # the dict so stale locks are not reused.
    _locks: ClassVar[dict[str, asyncio.Lock]] = {}

    # -- Internal helpers --------------------------------------------------

    @classmethod
    def get_refresh_lock(cls, app_key: str) -> asyncio.Lock:
        """Return (creating if needed) the per-app_key refresh lock.

        Must only be called from within a running event loop (async context).
        """
        if app_key not in cls._locks:
            cls._locks[app_key] = asyncio.Lock()
        return cls._locks[app_key]

    @staticmethod
    def compute_signature(nonce: str, timestamp: str, app_key: str, app_secret: str) -> str:
        """Compute the HMAC-SHA256 request signature (aligned with the TypeScript original).

        plain     = nonce + timestamp + app_key + app_secret
        signature = HMAC-SHA256(key=app_secret, msg=plain).hexdigest()
        """
        plain = nonce + timestamp + app_key + app_secret
        return hmac.new(app_secret.encode(), plain.encode(), hashlib.sha256).hexdigest()

    @staticmethod
    def build_timestamp() -> str:
        """Build a Beijing-time ISO-8601 timestamp without milliseconds.

        Format: 2006-01-02T15:04:05+08:00
        """
        bjtime = datetime.now(tz=timezone(timedelta(hours=8)))
        return bjtime.strftime("%Y-%m-%dT%H:%M:%S+08:00")

    @classmethod
    def is_cache_valid(cls, entry: dict[str, Any]) -> bool:
        """Return True while *entry* has more than the refresh margin left to live."""
        return entry["expire_ts"] - time.time() > cls.CACHE_REFRESH_MARGIN_S

    @classmethod
    def clear_locks(cls) -> None:
        """Clear all per-app_key refresh locks (called on disconnect)."""
        cls._locks.clear()

    @classmethod
    def purge_expired(cls) -> int:
        """Remove all expired entries from the token cache.

        Called lazily from ``get_token()`` so that stale app_key entries do
        not accumulate indefinitely in long-running processes.

        Returns:
            The number of entries purged.
        """
        now = time.time()
        expired_keys = [
            key for key, entry in cls._cache.items()
            if now - entry.get("expire_ts", 0) > 0
        ]
        for key in expired_keys:
            cls._cache.pop(key, None)
        return len(expired_keys)

    @staticmethod
    def _coerce_duration(raw: Any) -> float:
        """Turn the API's ``duration`` value into a number; unusable values become 0."""
        if isinstance(raw, (int, float)):
            return raw
        try:
            return int(raw)
        except (TypeError, ValueError):
            return 0

    @classmethod
    def _remember(cls, app_key: str, data: dict[str, Any]) -> dict[str, Any]:
        """Cache the sign-token payload *data* for *app_key* and return a copy of the entry.

        A missing, non-numeric or non-positive duration falls back to
        ``DEFAULT_TOKEN_TTL_S`` for the expiry computation.
        """
        duration = cls._coerce_duration(data.get("duration", 0))
        ttl = duration if duration > 0 else cls.DEFAULT_TOKEN_TTL_S
        entry = {
            "token": data.get("token", ""),
            "bot_id": data.get("bot_id", ""),
            "duration": duration,
            "product": data.get("product", ""),
            "source": data.get("source", ""),
            "expire_ts": time.time() + ttl,
        }
        cls._cache[app_key] = entry
        # Hand out a copy so callers cannot mutate the cached entry.
        return dict(entry)

    # -- Core: fetch -------------------------------------------------------

    @classmethod
    async def fetch(
        cls,
        app_key: str,
        app_secret: str,
        api_domain: str,
        route_env: str = "",
    ) -> dict[str, Any]:
        """Send the sign-token HTTP request, retrying on the retryable code.

        A fresh nonce, timestamp and signature are generated for every
        attempt. Only responses carrying ``RETRYABLE_CODE`` are retried (up
        to ``MAX_RETRIES`` times, ``RETRY_DELAY_S`` apart).

        Args:
            app_key: Application key sent as ``app_key`` and used in the signature.
            app_secret: Secret used as HMAC key.
            api_domain: Base URL of the API; a trailing slash is ignored.
            route_env: Optional value for the ``X-Route-Env`` header.

        Returns:
            The ``data`` object of a successful (``code == 0``) response.

        Raises:
            RuntimeError: On a non-200 HTTP status or a non-retryable /
                retry-exhausted error code.
            ValueError: If the body is not a JSON object or lacks a ``data`` object.
        """
        url = f"{api_domain.rstrip('/')}{cls.TOKEN_PATH}"

        # Headers do not change between attempts.
        headers = {
            "Content-Type": "application/json",
            "X-AppVersion": _APP_VERSION,
            "X-OperationSystem": _OPERATION_SYSTEM,
            "X-Instance-Id": _YUANBAO_INSTANCE_ID,
            "X-Bot-Version": _BOT_VERSION,
        }
        if route_env:
            headers["X-Route-Env"] = route_env

        async with httpx.AsyncClient(timeout=cls.HTTP_TIMEOUT_S) as client:
            for attempt in range(cls.MAX_RETRIES + 1):
                nonce = secrets.token_hex(16)
                timestamp = cls.build_timestamp()
                payload = {
                    "app_key": app_key,
                    "nonce": nonce,
                    "signature": cls.compute_signature(nonce, timestamp, app_key, app_secret),
                    "timestamp": timestamp,
                }

                logger.info(
                    "Sign token request: url=%s%s",
                    url,
                    f" (retry {attempt}/{cls.MAX_RETRIES})" if attempt > 0 else "",
                )

                response = await client.post(url, json=payload, headers=headers)

                if response.status_code != 200:
                    body = response.text
                    raise RuntimeError(f"Sign token API returned {response.status_code}: {body[:200]}")

                try:
                    result_data = response.json()
                except Exception as exc:
                    raise ValueError(f"Sign token response parse error: {exc}") from exc
                if not isinstance(result_data, dict):
                    # Valid JSON but not an object: report it as a parse error
                    # instead of failing on .get() below.
                    raise ValueError(
                        f"Sign token response parse error: expected JSON object, "
                        f"got {type(result_data).__name__}"
                    )

                code = result_data.get("code")
                if code == 0:
                    data = result_data.get("data")
                    if not isinstance(data, dict):
                        raise ValueError(f"Sign token response missing 'data' field: {result_data}")
                    logger.info("Sign token success: bot_id=%s", data.get("bot_id"))
                    return data

                if code == cls.RETRYABLE_CODE and attempt < cls.MAX_RETRIES:
                    logger.warning(
                        "Sign token retryable: code=%s, retrying in %ss (attempt=%d/%d)",
                        code,
                        cls.RETRY_DELAY_S,
                        attempt + 1,
                        cls.MAX_RETRIES,
                    )
                    await asyncio.sleep(cls.RETRY_DELAY_S)
                    continue

                msg = result_data.get("msg", "")
                raise RuntimeError(f"Sign token error: code={code}, msg={msg}")

        # Defensive: every loop iteration returns, continues or raises.
        raise RuntimeError("Sign token failed: max retries exceeded")

    # -- Public API: get (with cache) --------------------------------------

    @classmethod
    async def get_token(
        cls,
        app_key: str,
        app_secret: str,
        api_domain: str,
        route_env: str = "",
    ) -> dict[str, Any]:
        """Get the WS auth token, served from cache when possible.

        A cached entry is returned without a request as long as it is more
        than ``CACHE_REFRESH_MARGIN_S`` seconds away from expiry; otherwise a
        new token is fetched under the per-app_key lock and cached.

        Returns:
            A copy of the cache entry (token, bot_id, duration, product,
            source, expire_ts).

        Raises:
            RuntimeError, ValueError: Propagated from ``fetch()``.
        """
        # Lazily evict stale entries, including those of other app_keys.
        cls.purge_expired()

        cached = cls._cache.get(app_key)
        if cached and cls.is_cache_valid(cached):
            remain = int(cached["expire_ts"] - time.time())
            logger.info("Using cached token (%ds remaining)", remain)
            return dict(cached)

        async with cls.get_refresh_lock(app_key):
            # Re-check: another task may have refreshed while we waited.
            cached = cls._cache.get(app_key)
            if cached and cls.is_cache_valid(cached):
                return dict(cached)

            data = await cls.fetch(app_key, app_secret, api_domain, route_env)
            return cls._remember(app_key, data)

    # -- Public API: force refresh -----------------------------------------

    @classmethod
    async def force_refresh(
        cls,
        app_key: str,
        app_secret: str,
        api_domain: str,
        route_env: str = "",
    ) -> dict[str, Any]:
        """Drop the cached token for *app_key* and sign a new one.

        Returns:
            A copy of the fresh cache entry.

        Raises:
            RuntimeError, ValueError: Propagated from ``fetch()``; the old
                entry has already been removed in that case.
        """
        logger.warning("[force-refresh] Clearing cache and re-signing token: app_key=****%s", app_key[-4:])
        async with cls.get_refresh_lock(app_key):
            cls._cache.pop(app_key, None)
            data = await cls.fetch(app_key, app_secret, api_domain, route_env)
            return cls._remember(app_key, data)
|
|
|
|
|
|
from dataclasses import dataclass, field as dc_field
|
|
|
|
@dataclass
class InboundContext:
    """Mutable state shared by all stages of the inbound middleware pipeline.

    One instance is created per inbound message and handed to every
    middleware in registration order; each stage reads what earlier stages
    wrote and fills in its own fields.
    """

    # --- Inputs -----------------------------------------------------------
    adapter: Any  # YuanbaoAdapter (forward-ref avoids circular import)
    raw_frames: list = dc_field(default_factory=list)  # Raw bytes frames (debounce-aggregated)

    # --- DecodeMiddleware -------------------------------------------------
    push: Optional[dict] = None
    decoded_via: str = ""  # "json" | "protobuf"

    # --- FieldExtractMiddleware (values taken from `push`) -----------------
    from_account: str = ""
    group_code: str = ""
    group_name: str = ""
    sender_nickname: str = ""
    msg_body: list = dc_field(default_factory=list)
    msg_id: str = ""
    cloud_custom_data: str = ""

    # --- ChatRoutingMiddleware --------------------------------------------
    chat_id: str = ""
    chat_type: str = ""  # "dm" | "group"
    chat_name: str = ""

    # --- ContentExtractMiddleware -----------------------------------------
    raw_text: str = ""
    media_refs: list = dc_field(default_factory=list)

    # --- Owner command detection ------------------------------------------
    owner_command: Optional[str] = None

    # --- BuildSourceMiddleware --------------------------------------------
    source: Optional[Any] = None  # SessionSource

    # --- ClassifyMessageTypeMiddleware ------------------------------------
    msg_type: Optional[Any] = None  # MessageType

    # --- QuoteContextMiddleware -------------------------------------------
    reply_to_message_id: Optional[str] = None
    reply_to_text: Optional[str] = None

    # --- MediaResolveMiddleware -------------------------------------------
    media_urls: list = dc_field(default_factory=list)
    media_types: list = dc_field(default_factory=list)

    # --- ExtractContentMiddleware -----------------------------------------
    link_urls: list = dc_field(default_factory=list)

    # --- GroupAttributionMiddleware ---------------------------------------
    channel_prompt: Optional[str] = None
|
|
|
|
|
|
class InboundMiddleware(ABC):
    """Base class for object-style inbound pipeline middlewares.

    A subclass must:
    - define ``name`` as a class attribute (the pipeline registers the
      middleware under it and uses it for insertion/removal by name);
    - implement ``async handle(ctx, next_fn)`` with the middleware logic.

    Convention inside ``handle``:
    - ``await next_fn()`` hands control to the next middleware;
    - returning without calling ``next_fn`` **stops** the pipeline.
    """

    name: str = ""  # Override in each subclass

    @abstractmethod
    async def handle(self, ctx: InboundContext, next_fn: Callable) -> None:
        """Process *ctx*; call *next_fn* to let the pipeline continue."""

    async def __call__(self, ctx: InboundContext, next_fn: Callable) -> None:
        """Make the instance usable wherever a plain ``(ctx, next_fn)`` callable is expected."""
        outcome = await self.handle(ctx, next_fn)
        return outcome

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"<{cls_name} name={self.name!r}>"
|
|
|
|
|
|
class InboundPipeline:
|
|
"""Onion-model middleware pipeline engine for inbound message processing.
|
|
|
|
Inspired by OpenClaw's MessagePipeline (extensions/yuanbao/src/business/
|
|
pipeline/engine.ts). Supports named middlewares, conditional guards
|
|
(``when``), and ``use_before`` / ``use_after`` / ``remove`` for dynamic
|
|
composition.
|
|
|
|
Accepts both ``InboundMiddleware`` instances (OOP style) and plain
|
|
``async def(ctx, next_fn)`` callables (functional style) for flexibility.
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
self._middlewares: list = [] # list of (name, handler, when_fn | None)
|
|
|
|
# -- Internal helpers --------------------------------------------------
|
|
|
|
@staticmethod
|
|
def _normalize(name_or_mw, handler=None):
    """Resolve either registration style to a ``(name, callable)`` pair.

    - OOP style: *name_or_mw* is an ``InboundMiddleware``; its ``name`` and
      the instance itself are returned and *handler* is ignored.
    - Functional style: *name_or_mw* is the name and *handler* the callable;
      both are returned as given.
    """
    if not isinstance(name_or_mw, InboundMiddleware):
        return name_or_mw, handler
    return name_or_mw.name, name_or_mw
|
|
|
|
# -- Registration API --------------------------------------------------
|
|
|
|
def use(self, name_or_mw, handler=None, when=None) -> "InboundPipeline":
|
|
"""Append a middleware to the end of the pipeline.
|
|
|
|
Accepts either:
|
|
- ``pipeline.use(SomeMiddleware())`` — OOP style
|
|
- ``pipeline.use("name", some_fn)`` — functional style
|
|
"""
|
|
name, h = self._normalize(name_or_mw, handler)
|
|
self._middlewares.append((name, h, when))
|
|
return self
|
|
|
|
def use_before(self, target: str, name_or_mw, handler=None, when=None) -> "InboundPipeline":
|
|
"""Insert a middleware before *target* (by name). Appends if not found."""
|
|
name, h = self._normalize(name_or_mw, handler)
|
|
idx = next((i for i, (n, _, _) in enumerate(self._middlewares) if n == target), None)
|
|
entry = (name, h, when)
|
|
if idx is None:
|
|
self._middlewares.append(entry)
|
|
else:
|
|
self._middlewares.insert(idx, entry)
|
|
return self
|
|
|
|
def use_after(self, target: str, name_or_mw, handler=None, when=None) -> "InboundPipeline":
|
|
"""Insert a middleware after *target* (by name). Appends if not found."""
|
|
name, h = self._normalize(name_or_mw, handler)
|
|
idx = next((i for i, (n, _, _) in enumerate(self._middlewares) if n == target), None)
|
|
entry = (name, h, when)
|
|
if idx is None:
|
|
self._middlewares.append(entry)
|
|
else:
|
|
self._middlewares.insert(idx + 1, entry)
|
|
return self
|
|
|
|
def remove(self, name: str) -> "InboundPipeline":
|
|
"""Remove a middleware by name."""
|
|
self._middlewares = [(n, h, w) for n, h, w in self._middlewares if n != name]
|
|
return self
|
|
|
|
@property
|
|
def middleware_names(self) -> list:
|
|
"""Return ordered list of registered middleware names (for testing)."""
|
|
return [n for n, _, _ in self._middlewares]
|
|
|
|
# -- Execution ---------------------------------------------------------
|
|
|
|
async def execute(self, ctx: InboundContext) -> None:
|
|
"""Run all middlewares in order. Each middleware receives ``(ctx, next_fn)``."""
|
|
chain = self._middlewares
|
|
index = 0
|
|
|
|
async def next_fn() -> None:
|
|
nonlocal index
|
|
while index < len(chain):
|
|
name, handler, when_fn = chain[index]
|
|
index += 1
|
|
# Conditional guard: skip when returns False
|
|
if when_fn is not None and not when_fn(ctx):
|
|
continue
|
|
try:
|
|
await handler(ctx, next_fn)
|
|
except Exception:
|
|
logger.error("[InboundPipeline] middleware [%s] error", name, exc_info=True)
|
|
raise
|
|
return
|
|
# End of chain — nothing more to do
|
|
|
|
await next_fn()
|
|
class DecodeMiddleware(InboundMiddleware):
    """Decode raw inbound frames from JSON or Protobuf into ctx.push.

    Encapsulates JSON push parsing (aligned with TS decodeFromContent)
    and Protobuf decoding via ``decode_inbound_push``.
    """

    name = "decode"

    # -- JSON push parsing -------------------------------------------------

    @staticmethod
    def convert_json_msg_body(raw_body: list) -> list:
        """Normalize raw JSON msg_body array to [{"msg_type": str, "msg_content": dict}].

        Compatible with both PascalCase (MsgType/MsgContent) and
        snake_case (msg_type/msg_content) naming.  ``msg_content`` in the
        result is always a dict: string content is parsed as JSON when it
        encodes an object and is otherwise wrapped as ``{"text": <string>}``;
        any other non-dict content becomes ``{}``.  Non-dict items in
        *raw_body* are skipped.
        """
        result = []
        for item in raw_body or []:
            if not isinstance(item, dict):
                continue
            msg_type = item.get("msg_type") or item.get("MsgType", "")
            msg_content = item.get("msg_content") or item.get("MsgContent", {})
            if isinstance(msg_content, str):
                raw_text = msg_content
                try:
                    msg_content = json.loads(raw_text)
                except (ValueError, RecursionError):
                    msg_content = None
                # Strings such as "123" or "null" are valid JSON but not an
                # object; treat them as plain text so that downstream code
                # can rely on ``msg_content`` being a dict.
                if not isinstance(msg_content, dict):
                    msg_content = {"text": raw_text}
            elif not isinstance(msg_content, dict):
                msg_content = {}
            result.append({"msg_type": msg_type, "msg_content": msg_content})
        return result

    @staticmethod
    def _first_of(raw_json: dict, keys: tuple, default: Any = "") -> Any:
        """Return the first truthy value among *keys* in *raw_json*, else *default*."""
        for key in keys:
            value = raw_json.get(key)
            if value:
                return value
        return default

    @staticmethod
    def parse_json_push(raw_json: dict) -> dict | None:
        """Convert JSON-format push to a dict with the same structure as
        ``decode_inbound_push``.

        Supports standard callback format (callback_command + from_account +
        msg_body) and legacy format fields (GroupId, MsgSeq, MsgKey, MsgBody,
        etc.).  Returns ``None`` when *raw_json* is empty or carries neither a
        sender, a message body, nor a callback command.
        """
        if not raw_json:
            return None

        pick = DecodeMiddleware._first_of

        # Tencent IM callback format uses PascalCase (From_Account, To_Account, MsgBody).
        # Internal format uses snake_case (from_account, to_account, msg_body).
        # Support both.
        from_account = pick(raw_json, ("from_account", "From_Account"))
        group_code = pick(raw_json, ("group_code", "GroupId", "group_id"))
        msg_body_raw = pick(raw_json, ("msg_body", "MsgBody"), [])
        msg_body = DecodeMiddleware.convert_json_msg_body(msg_body_raw)

        # Recall callbacks may have neither from_account nor msg_body.
        if not from_account and not msg_body and not raw_json.get("callback_command"):
            return None

        log_ext = raw_json.get("log_ext")
        trace_id = log_ext.get("trace_id", "") if isinstance(log_ext, dict) else ""

        return {
            "callback_command": raw_json.get("callback_command", ""),
            "from_account": from_account,
            "to_account": pick(raw_json, ("to_account", "To_Account")),
            "sender_nickname": pick(raw_json, ("sender_nickname", "nick_name")),
            "group_code": group_code,
            "group_name": raw_json.get("group_name", ""),
            "msg_seq": pick(raw_json, ("msg_seq", "MsgSeq"), 0),
            "msg_id": pick(raw_json, ("msg_id", "msg_key", "MsgKey")),
            "msg_body": msg_body,
            "cloud_custom_data": pick(raw_json, ("cloud_custom_data", "CloudCustomData")),
            "bot_owner_id": pick(raw_json, ("bot_owner_id", "botOwnerId")),
            "recall_msg_seq_list": raw_json.get("recall_msg_seq_list") or None,
            "trace_id": trace_id,
        }

    # -- Pipeline handler --------------------------------------------------

    def _decode_single(self, adapter, data: bytes) -> tuple:
        """Decode a single raw frame into (push_dict, decoded_via) or (None, '').

        The frame is first tried as UTF-8 JSON.  Only when it is not a JSON
        object is it handed to ``decode_inbound_push`` (Protobuf).  *adapter*
        is unused here and kept for signature compatibility.
        """
        try:
            conn_json = json.loads(data.decode("utf-8"))
        except Exception:
            # Not UTF-8 / not JSON (or not bytes at all): fall through to
            # the Protobuf decoder below.
            conn_json = None

        if isinstance(conn_json, dict):
            push = self.parse_json_push(conn_json)
            if push:
                return push, "json"
        else:
            try:
                push = decode_inbound_push(data)
            except Exception:
                # Best effort: an undecodable frame is reported by the caller.
                push = None
            if push:
                return push, "protobuf"

        return None, ""

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Decode ``ctx.raw_frames`` into ``ctx.push`` and continue the pipeline.

        The first frame that decodes successfully becomes the base push; the
        ``msg_body`` of every later frame is appended to it.  The pipeline is
        stopped when there are no frames or none of them decodes.
        """
        data_list = ctx.raw_frames
        if not data_list:
            return  # Stop pipeline — nothing to decode

        merged_push = None
        decoded_via = ""

        for data in data_list:
            push, via = self._decode_single(ctx.adapter, data)
            if not push:
                logger.info(
                    "[%s] Push decoded but no valid message. raw hex(first64)=%s",
                    ctx.adapter.name, data.hex()[:128] if data else "(empty)",
                )
                continue

            if merged_push is None:
                # First valid push becomes the base
                merged_push = push
                decoded_via = via
                logger.info(
                    "[%s] Frame decoded (via=%s): len=%d",
                    ctx.adapter.name, via, len(data),
                )
            else:
                # Subsequent pushes: merge msg_body into the base with a
                # newline text element separating the aggregated messages.
                extra_body = push.get("msg_body") or []
                if extra_body:
                    _sep = {"msg_type": "TIMTextElem", "msg_content": {"text": "\n"}}
                    # ``or []`` guards against a decoder that reports an
                    # absent body as None rather than an empty list.
                    merged_push["msg_body"] = (merged_push.get("msg_body") or []) + [_sep] + list(extra_body)
                    logger.info(
                        "[%s] Merged %d extra msg_body elements from aggregated push",
                        ctx.adapter.name, len(extra_body),
                    )

        if not merged_push:
            return  # Stop pipeline

        ctx.push = merged_push
        ctx.decoded_via = decoded_via

        logger.info(
            "[%s] Push decoded (via=%s): from=%s group=%s msg_id=%s msg_types=%s",
            ctx.adapter.name, ctx.decoded_via,
            ctx.push.get("from_account", ""),
            ctx.push.get("group_code", ""),
            ctx.push.get("msg_id", ""),
            [e.get("msg_type", "") for e in (ctx.push.get("msg_body") or []) if isinstance(e, dict)],
        )
        logger.debug("[%s] Push payload: %s", ctx.adapter.name, ctx.push)

        await next_fn()
|
|
|
|
|
|
class ExtractFieldsMiddleware(InboundMiddleware):
    """Copy the commonly used fields of ctx.push onto the context itself."""

    name = "extract-fields"

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Populate ctx attributes from ctx.push, then continue the pipeline."""
        payload = ctx.push
        # Each push key is mirrored onto the ctx attribute of the same name;
        # the second item is the fallback used when the key is absent.
        for field_name, fallback in (
            ("from_account", ""),
            ("group_code", ""),
            ("group_name", ""),
            ("sender_nickname", ""),
            ("msg_body", []),
            ("msg_id", ""),
            ("cloud_custom_data", ""),
        ):
            setattr(ctx, field_name, payload.get(field_name, fallback))
        await next_fn()
|
|
|
|
|
|
class DedupMiddleware(InboundMiddleware):
    """Drop inbound messages whose msg_id has already been seen."""

    name = "dedup"

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Stop the pipeline for duplicates; otherwise pass control on."""
        message_id = ctx.msg_id
        # Messages without an id cannot be deduplicated and always pass.
        already_seen = bool(message_id) and ctx.adapter._dedup.is_duplicate(message_id)
        if not already_seen:
            await next_fn()
            return
        logger.debug("[%s] Duplicate message ignored: msg_id=%s", ctx.adapter.name, ctx.msg_id)
|
|
|
|
|
|
class RecallGuardMiddleware(InboundMiddleware):
    """Intercept Group.CallbackAfterRecallMsg / C2C.CallbackAfterMsgWithDraw.

    Branch A: message in transcript (observed, not yet consumed) → redact content
    Branch B: message not in transcript → append system note
    Branch C: message currently being processed → silent interrupt + delayed redact

    Recall callbacks always stop the pipeline; every other push is passed on
    untouched.
    """

    name = "recall_guard"

    _RECALL_COMMANDS = frozenset({
        "Group.CallbackAfterRecallMsg",
        "C2C.CallbackAfterMsgWithDraw",
    })
    _REDACTED = "[This message was recalled/withdrawn by the sender; original content removed]"

    # Delayed redaction (branch C) polls the transcript this many times,
    # sleeping this long before each attempt (30 × 0.5 s = 15 s at most).
    _REDACT_POLL_ATTEMPTS = 30
    _REDACT_POLL_INTERVAL = 0.5

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Handle recall callbacks here; forward everything else."""
        cmd = (ctx.push or {}).get("callback_command", "")
        if cmd not in self._RECALL_COMMANDS:
            await next_fn()
            return
        self._handle_recall(ctx, cmd)

    @staticmethod
    def _build_source(adapter, group_code: str, from_account: str):
        """Build the session source for the chat the recalled message belongs to."""
        return adapter.build_source(
            chat_id=(f"group:{group_code}" if group_code else f"direct:{from_account}"),
            chat_type="group" if group_code else "dm",
            user_id=from_account or None,
            thread_id="main" if group_code else None,
        )

    def _handle_recall(self, ctx: InboundContext, cmd: str) -> None:
        """Dispatch every recalled message in the push to branch A/B or C."""
        adapter = ctx.adapter
        push = ctx.push or {}

        if cmd == "Group.CallbackAfterRecallMsg":
            seq_list = push.get("recall_msg_seq_list") or []
        else:
            # C2C withdraw carries a single message identified by id and/or seq.
            mid = push.get("msg_id") or ""
            seq = push.get("msg_seq")
            seq_list = [{"msg_id": mid, "msg_seq": seq}] if (mid or seq) else []

        if not seq_list:
            logger.debug("[%s] Recall callback with empty seq_list, skipping", adapter.name)
            return

        group_code = (push.get("group_code") or "").strip()
        from_account = (push.get("from_account") or "").strip()

        for seq_entry in seq_list:
            if not isinstance(seq_entry, dict):
                # Malformed entry: nothing to identify the message by.
                continue
            recalled_id = seq_entry.get("msg_id") or str(seq_entry.get("msg_seq") or "")
            if not recalled_id:
                continue

            matched_sk = self._find_processing_session(adapter, recalled_id)
            if matched_sk is not None:
                self._interrupt_for_recall(adapter, matched_sk, recalled_id, group_code, from_account)
            else:
                recalled_content = adapter._msg_content_cache.get(recalled_id)
                self._patch_transcript(adapter, recalled_id, group_code, from_account, recalled_content)

    # -- Branch C: interrupt currently-processing message ---------------

    @staticmethod
    def _find_processing_session(adapter, recalled_id: str) -> Optional[str]:
        """Return the active session key currently processing *recalled_id*, if any."""
        for sk, mid in adapter._processing_msg_ids.items():
            if mid == recalled_id and sk in adapter._active_sessions:
                return sk
        return None

    @classmethod
    def _interrupt_for_recall(cls, adapter, session_key: str, recalled_id: str,
                              group_code: str, from_account: str) -> None:
        """Queue a synthetic internal message telling the running turn to stop.

        Also schedules a delayed redaction of the recalled text when the
        adapter has recorded the text being processed for this session.
        """
        where = f"group {group_code}" if group_code else f"direct chat with {from_account}"
        recall_text = (
            f"[CRITICAL — MESSAGE RECALLED] The user message that triggered "
            f"your current task (message_id=\"{recalled_id}\") in {where} has "
            f"been recalled/withdrawn by the sender. "
            f"IGNORE any prior system note asking you to finish processing "
            f"tool results — the original request is void. "
            f"Do NOT continue the task, do NOT call more tools, do NOT "
            f"reference the recalled content. "
            f"Reply only with a brief acknowledgment such as "
            f"\"The message has been recalled.\" in the "
            f"language the user was using."
        )

        synth_event = MessageEvent(
            text=recall_text,
            message_type=MessageType.TEXT,
            source=cls._build_source(adapter, group_code, from_account),
            internal=True,
        )
        # Set pending + signal directly (bypass handle_message to avoid busy-ack).
        # May overwrite a user message pending in the same ~200ms window — acceptable.
        adapter._pending_messages[session_key] = synth_event
        active_event = adapter._active_sessions.get(session_key)
        if active_event is not None:
            active_event.set()

        logger.info("[%s] Recall interrupt: msg_id=%s session=%s", adapter.name, recalled_id, session_key[:30])

        # The interrupted turn will persist the recalled content *after* our
        # interrupt — schedule a delayed redaction to clean it up.
        recalled_text = adapter._processing_msg_texts.get(session_key, "")
        if recalled_text:
            cls._schedule_content_redact(adapter, session_key, recalled_text, group_code, from_account)

    @classmethod
    def _schedule_content_redact(cls, adapter, session_key: str, recalled_text: str,
                                 group_code: str, from_account: str) -> None:
        """Start a background task that redacts *recalled_text* once it is persisted.

        Must be called while an event loop is running (it uses
        ``asyncio.create_task``).
        """
        async def _redact() -> None:
            store = getattr(adapter, "_session_store", None)
            if not store:
                return
            try:
                sid = store.get_or_create_session(
                    cls._build_source(adapter, group_code, from_account),
                ).session_id
            except Exception:
                return
            # Poll until the recalled content appears in transcript — the
            # interrupted turn hasn't finished writing yet when scheduled.
            for _ in range(cls._REDACT_POLL_ATTEMPTS):
                await asyncio.sleep(cls._REDACT_POLL_INTERVAL)
                try:
                    transcript = store.load_transcript(sid)
                except Exception:
                    continue
                for entry in transcript:
                    if entry.get("role") == "user" and entry.get("content") == recalled_text:
                        entry["content"] = cls._REDACTED
                        try:
                            store.rewrite_transcript(sid, transcript)
                            logger.info("[%s] Recall redact: session %s", adapter.name, session_key[:30])
                        except Exception as exc:
                            logger.warning("[%s] Recall redact failed: %s", adapter.name, exc)
                        return
            logger.debug("[%s] Recall redact: content not found after polling, session %s", adapter.name, session_key[:30])

        task = asyncio.create_task(_redact())
        # Keep a reference so the task is not garbage-collected mid-flight.
        adapter._background_tasks.add(task)
        task.add_done_callback(adapter._background_tasks.discard)

    # -- Branch A/B: patch transcript (session idle) --------------------

    @classmethod
    def _patch_transcript(cls, adapter, recalled_id: str, group_code: str,
                          from_account: str, recalled_content: Optional[str] = None) -> None:
        """Redact the recalled message in the transcript, or note the recall.

        Store failures are logged and swallowed so that one bad entry does
        not prevent the remaining recalled messages from being handled.
        """
        store = getattr(adapter, "_session_store", None)
        if not store:
            return
        try:
            sid = store.get_or_create_session(cls._build_source(adapter, group_code, from_account)).session_id
        except Exception as exc:
            logger.warning("[%s] Recall: failed to resolve session: %s", adapter.name, exc)
            return

        # Read JSONL directly — SQLite doesn't preserve message_id field.
        transcript: list = []
        try:
            path = store.get_transcript_path(sid)
            if path.exists():
                with open(path, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            transcript.append(json.loads(line))
                        except json.JSONDecodeError:
                            # Unparseable lines are skipped.
                            continue
        except Exception as exc:
            logger.warning("[%s] Recall: failed to load transcript: %s", adapter.name, exc)
            return

        # Branch A: redact — try message_id first, then content fallback.
        # Observed messages have message_id; agent-processed @bot messages
        # only have content (run.py doesn't write message_id to transcript).
        target = None
        for entry in transcript:
            if entry.get("message_id") == recalled_id:
                target = entry
                break
        if target is None and recalled_content:
            for entry in transcript:
                if entry.get("role") == "user" and entry.get("content") == recalled_content:
                    target = entry
                    break
        if target is not None:
            target["content"] = cls._REDACTED
            try:
                store.rewrite_transcript(sid, transcript)
                logger.info("[%s] Recall: redacted msg_id=%s (branch A)", adapter.name, recalled_id)
            except Exception as exc:
                logger.warning("[%s] Recall: rewrite_transcript failed: %s", adapter.name, exc)
            return

        # Branch B: not found in transcript → append system note
        try:
            store.append_to_transcript(sid, {
                "role": "system",
                "content": f'[recall] message_id="{recalled_id}" has been recalled; do not quote or reference it.',
                "timestamp": datetime.now(tz=timezone.utc).isoformat(),
            })
        except Exception as exc:
            # Same policy as the other store calls above: log and carry on
            # rather than abort the whole recall callback.
            logger.warning("[%s] Recall: append_to_transcript failed: %s", adapter.name, exc)
            return
        logger.info("[%s] Recall: system note for msg_id=%s (branch B)", adapter.name, recalled_id)
|
|
|
|
|
|
class SkipSelfMiddleware(InboundMiddleware):
    """Discard messages that were sent by the bot account itself."""

    name = "skip-self"

    @staticmethod
    def _is_self_reference(from_account: str, bot_id: Optional[str]) -> bool:
        """Return True only when both ids are known and identical."""
        both_known = bool(from_account) and bool(bot_id)
        return from_account == bot_id if both_known else False

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Stop the pipeline for self-sent messages; otherwise continue."""
        sent_by_bot = self._is_self_reference(ctx.from_account, ctx.adapter._bot_id)
        if not sent_by_bot:
            await next_fn()
            return
        logger.debug("[%s] Ignoring self-sent message from %s", ctx.adapter.name, ctx.from_account)
|
|
|
|
|
|
class ChatRoutingMiddleware(InboundMiddleware):
    """Derive chat_id, chat_type and chat_name from the extracted push fields."""

    name = "chat-routing"

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Route to a group chat when a group code is present, else to a DM."""
        if not ctx.group_code:
            # No group code: a direct conversation keyed by the sender.
            ctx.chat_id = f"direct:{ctx.from_account}"
            ctx.chat_type = "dm"
            ctx.chat_name = ctx.sender_nickname or ctx.from_account
        else:
            ctx.chat_id = f"group:{ctx.group_code}"
            ctx.chat_type = "group"
            ctx.chat_name = ctx.group_name or ctx.group_code
        await next_fn()
|
|
|
|
|
|
class AccessPolicy:
    """Platform-level DM / Group access control policy.

    Encapsulates the allow/deny logic so that both inbound middleware
    and outbound ``send_dm`` can share the same rules without reaching
    into adapter internals.

    A policy value of ``"disabled"`` rejects everything and ``"allowlist"``
    admits only ids found in the matching allow list; any other value is
    treated as open.
    """

    def __init__(
        self,
        dm_policy: str,
        dm_allow_from: list[str],
        group_policy: str,
        group_allow_from: list[str],
    ) -> None:
        self._dm_policy = dm_policy
        self._dm_allow_from = dm_allow_from
        self._group_policy = group_policy
        self._group_allow_from = group_allow_from

    @staticmethod
    def _in_allowlist(identifier: Any, allowlist: Any) -> bool:
        """Return True when *identifier* matches an entry of *allowlist*.

        Both sides are compared as stripped strings, so entries written with
        surrounding whitespace or as bare numbers still match.  A missing or
        blank identifier never matches.
        """
        if identifier is None:
            return False
        wanted = str(identifier).strip()
        if not wanted:
            return False
        return any(str(entry).strip() == wanted for entry in allowlist or ())

    def is_dm_allowed(self, sender_id: str) -> bool:
        """Platform-level DM inbound filter (open / allowlist / disabled)."""
        if self._dm_policy == "disabled":
            return False
        if self._dm_policy == "allowlist":
            return self._in_allowlist(sender_id, self._dm_allow_from)
        return True

    def is_group_allowed(self, group_code: str) -> bool:
        """Platform-level group chat inbound filter (open / allowlist / disabled)."""
        if self._group_policy == "disabled":
            return False
        if self._group_policy == "allowlist":
            return self._in_allowlist(group_code, self._group_allow_from)
        return True

    @property
    def dm_policy(self) -> str:
        """The configured DM policy string."""
        return self._dm_policy

    @property
    def group_policy(self) -> str:
        """The configured group policy string."""
        return self._group_policy
|
|
|
|
|
|
class AccessGuardMiddleware(InboundMiddleware):
    """Apply the adapter's DM / group access policy to inbound messages."""

    name = "access-guard"

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Stop the pipeline when the chat is rejected by the access policy."""
        adapter = ctx.adapter
        policy: AccessPolicy = adapter._access_policy
        kind = ctx.chat_type

        if kind == "dm" and not policy.is_dm_allowed(ctx.from_account):
            logger.debug(
                "[%s] DM from %s blocked by dm_policy=%s",
                adapter.name, ctx.from_account, policy.dm_policy,
            )
            return  # Stop pipeline

        if kind == "group" and not policy.is_group_allowed(ctx.group_code):
            logger.debug(
                "[%s] Group %s blocked by group_policy=%s",
                adapter.name, ctx.group_code, policy.group_policy,
            )
            return  # Stop pipeline

        # Allowed, or a chat type this guard does not police.
        await next_fn()
|
|
|
|
|
|
class AutoSetHomeMiddleware(InboundMiddleware):
    """Auto-designate the first inbound conversation as Yuanbao home channel.

    Fires when no home channel is configured yet, or when the configured
    home is a group chat and the first DM arrives (direct > group upgrade).
    Silent: writes config.yaml and the environment, no user-facing message.
    """

    name = "auto-sethome"

    @staticmethod
    def _persist_home_channel(chat_id) -> None:
        """Store *chat_id* as YUANBAO_HOME_CHANNEL in config.yaml and os.environ."""
        from hermes_constants import get_hermes_home
        from utils import atomic_yaml_write
        import yaml

        config_path = get_hermes_home() / "config.yaml"
        user_config: dict = {}
        if config_path.exists():
            with open(config_path, encoding="utf-8") as f:
                user_config = yaml.safe_load(f) or {}
        user_config["YUANBAO_HOME_CHANNEL"] = chat_id
        atomic_yaml_write(config_path, user_config)
        os.environ["YUANBAO_HOME_CHANNEL"] = str(chat_id)

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Possibly record the current chat as home, then always continue."""
        adapter = ctx.adapter
        if adapter._auto_sethome_done:
            await next_fn()
            return

        is_dm = ctx.chat_type == "dm"
        current_home = os.getenv("YUANBAO_HOME_CHANNEL", "")
        needs_home = (not current_home) or (current_home.startswith("group:") and is_dm)

        if is_dm:
            adapter._auto_sethome_done = True  # DM seen — no further upgrades needed

        if needs_home:
            try:
                self._persist_home_channel(ctx.chat_id)
                logger.info(
                    "[%s] Auto-sethome: designated %s (%s) as Yuanbao home channel",
                    adapter.name, ctx.chat_id, ctx.chat_name,
                )
                # Silent auto-sethome: no user-facing message, only log
            except Exception as e:
                logger.warning("[%s] Auto-sethome failed: %s", adapter.name, e)

        await next_fn()
|
|
|
|
|
|
class ExtractContentMiddleware(InboundMiddleware):
    """Extract raw text and media refs from msg_body."""

    name = "extract-content"

    _CARD_CONTENT_MAX_LENGTH = 1000
    _UNSUPPORTED = "[unsupported message type]"

    @staticmethod
    def _format_shared_link(custom: dict) -> str:
        """Format elem_type 1010 (share card) into bracket-placeholder text."""
        title = custom.get("title") or ""
        link = custom.get("link") or ""
        header = f"[share_card: {title} | {link}]" if link else f"[share_card: {title}]"
        lines = [header]
        max_len = ExtractContentMiddleware._CARD_CONTENT_MAX_LENGTH
        # Use the first preview field that holds a non-empty string.
        for field in ("card_content", "wechat_des"):
            val = custom.get(field)
            if val and isinstance(val, str):
                preview = val[:max_len] + "...(truncated)" if len(val) > max_len else val
                lines.append(f"Preview: {preview}")
                break
        if link:
            lines.append("[visit link for full content]")
        return "\n".join(lines)

    @staticmethod
    def _parse_understanding_link(custom: dict) -> Optional[str]:
        """Return the link of an elem_type 1007 card, or None when unavailable.

        The card's ``content`` field is a JSON string holding an object with
        a ``link`` key.
        """
        content = custom.get("content")
        if not content:
            return None
        try:
            parsed = json.loads(content)
        except (ValueError, TypeError):
            return None
        link = parsed.get("link") if isinstance(parsed, dict) else None
        if not link or not isinstance(link, str):
            return None
        return link

    @staticmethod
    def _format_link_understanding(custom: dict) -> Optional[str]:
        """Format elem_type 1007 (link understanding card) into bracket-placeholder text."""
        link = ExtractContentMiddleware._parse_understanding_link(custom)
        if link is None:
            return None
        return f"[link: {link} | visit link for full content]"

    @classmethod
    def _format_custom_elem(cls, data_val: Any) -> Optional[str]:
        """Render the ``data`` payload of a TIMCustomElem as text.

        Returns None when nothing should be emitted for the element (a link
        understanding card without a usable link).
        """
        if not data_val:
            return cls._UNSUPPORTED
        try:
            custom = json.loads(data_val)
        except (ValueError, TypeError):
            # Not JSON: surface the raw payload, coerced so that the final
            # join cannot fail on a non-string value.
            return data_val if isinstance(data_val, str) else str(data_val)
        if not isinstance(custom, dict):
            return cls._UNSUPPORTED

        ctype = custom.get("elem_type")
        if ctype == 1002:
            text = custom.get("text", "[mention]")
            if text is None:
                return "[mention]"
            return text if isinstance(text, str) else str(text)
        if ctype == 1010:
            return cls._format_shared_link(custom)
        if ctype == 1007:
            return cls._format_link_understanding(custom)
        return cls._UNSUPPORTED

    @classmethod
    def _extract_text(cls, msg_body: list) -> str:
        """Extract plain text content from MsgBody.

        - TIMTextElem      -> text field
        - TIMImageElem     -> "[image]"
        - TIMFileElem      -> "[file: {filename}]"
        - TIMSoundElem     -> "[voice]"
        - TIMVideoFileElem -> "[video]"
        - TIMFaceElem      -> "[emoji: {name}]" or "[emoji]"
        - TIMCustomElem    -> mention text / share card / link card, otherwise
          "[unsupported message type]" (or the raw data when it is not JSON)
        - Any other type   -> "[{type}]"
        - Multiple elems joined with spaces

        A missing body, non-dict elements and non-dict ``msg_content`` are
        tolerated rather than raising.
        """
        parts: list[str] = []
        for elem in msg_body or []:
            if not isinstance(elem, dict):
                continue
            elem_type: str = elem.get("msg_type", "")
            content = elem.get("msg_content")
            if not isinstance(content, dict):
                content = {}

            if elem_type == "TIMTextElem":
                text = content.get("text", "")
                if text:
                    parts.append(text if isinstance(text, str) else str(text))
            elif elem_type == "TIMImageElem":
                parts.append("[image]")
            elif elem_type == "TIMFileElem":
                # Same key fallback order as _extract_inbound_media_refs.
                filename = (
                    content.get("file_name")
                    or content.get("fileName")
                    or content.get("filename")
                    or ""
                )
                parts.append(f"[file: {filename}]" if filename else "[file]")
            elif elem_type == "TIMSoundElem":
                parts.append("[voice]")
            elif elem_type == "TIMVideoFileElem":
                parts.append("[video]")
            elif elem_type == "TIMCustomElem":
                rendered = cls._format_custom_elem(content.get("data", ""))
                if rendered is not None:
                    parts.append(rendered)
            elif elem_type == "TIMFaceElem":
                # Sticker/emoji: extract name from data JSON
                raw_data = content.get("data", "")
                face_name = ""
                if raw_data:
                    try:
                        face_data = json.loads(raw_data)
                        face_name = (face_data.get("name") or "").strip()
                    except (ValueError, TypeError, AttributeError):
                        pass
                parts.append(f"[emoji: {face_name}]" if face_name else "[emoji]")
            elif elem_type:
                # Unknown element type — include type as placeholder
                parts.append(f"[{elem_type}]")

        return " ".join(parts) if parts else ""

    @staticmethod
    def _rewrite_slash_command(text: str) -> str:
        """Normalize input text: strip whitespace and convert full-width slash
        (Chinese input method) to ASCII slash so commands are recognized correctly.
        """
        text = text.strip()
        if text.startswith('\uff0f'):  # Full-width slash
            text = '/' + text[1:]
        return text

    @staticmethod
    def _extract_inbound_media_refs(msg_body: list) -> List[Dict[str, str]]:
        """Extract inbound image/file references from TIM msg_body.

        Return example:
            [{"kind": "image", "url": "https://..."}, {"kind": "file", "url": "...", "name": "a.pdf"}]
        """
        refs: List[Dict[str, str]] = []
        for elem in msg_body or []:
            if not isinstance(elem, dict):
                continue
            msg_type = elem.get("msg_type", "")
            content = elem.get("msg_content", {}) or {}
            if not isinstance(content, dict):
                continue

            if msg_type == "TIMImageElem":
                # Prefer medium image (index 1), fallback to index 0.
                image_info_array = content.get("image_info_array")
                if not isinstance(image_info_array, list):
                    image_info_array = []
                image_info = None
                if len(image_info_array) > 1 and isinstance(image_info_array[1], dict):
                    image_info = image_info_array[1]
                elif len(image_info_array) > 0 and isinstance(image_info_array[0], dict):
                    image_info = image_info_array[0]
                image_url = str((image_info or {}).get("url") or "").strip()
                if image_url:
                    refs.append({"kind": "image", "url": image_url})
                continue

            if msg_type == "TIMFileElem":
                file_url = str(content.get("url") or "").strip()
                file_name = (
                    str(content.get("file_name") or "").strip()
                    or str(content.get("fileName") or "").strip()
                    or str(content.get("filename") or "").strip()
                )
                if file_url:
                    ref: Dict[str, str] = {"kind": "file", "url": file_url}
                    if file_name:
                        ref["name"] = file_name
                    refs.append(ref)
        return refs

    @staticmethod
    def _extract_link_urls(msg_body: list) -> list:
        """Extract link URLs from share-card (1010) and link-understanding (1007) custom elems."""
        urls: list[str] = []
        for elem in msg_body or []:
            if not isinstance(elem, dict) or elem.get("msg_type") != "TIMCustomElem":
                continue
            content = elem.get("msg_content")
            data_str = content.get("data", "") if isinstance(content, dict) else ""
            if not data_str:
                continue
            try:
                custom = json.loads(data_str)
            except (ValueError, TypeError):
                continue
            if not isinstance(custom, dict):
                continue
            ctype = custom.get("elem_type")
            if ctype == 1010:
                link = custom.get("link")
                if link and isinstance(link, str):
                    urls.append(link)
            elif ctype == 1007:
                link = ExtractContentMiddleware._parse_understanding_link(custom)
                if link:
                    urls.append(link)
        return urls

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Fill ctx.raw_text, ctx.media_refs and ctx.link_urls from ctx.msg_body."""
        ctx.raw_text = self._rewrite_slash_command(self._extract_text(ctx.msg_body))
        ctx.media_refs = self._extract_inbound_media_refs(ctx.msg_body)
        ctx.link_urls = self._extract_link_urls(ctx.msg_body)
        await next_fn()
|
|
|
|
class PlaceholderFilterMiddleware(InboundMiddleware):
    """Drop messages that consist solely of a media placeholder with no media."""

    name = "placeholder-filter"

    SKIPPABLE_PLACEHOLDERS: frozenset = frozenset({
        "[image]", "[图片]", "[file]", "[文件]",
        "[video]", "[视频]", "[voice]", "[语音]",
    })

    @classmethod
    def is_skippable_placeholder(cls, text: str, media_count: int = 0) -> bool:
        """Return True when *text* is only a placeholder and no media came with it."""
        has_media = media_count > 0
        # A message that carries real media is never skipped, whatever its text.
        return (not has_media) and text.strip() in cls.SKIPPABLE_PLACEHOLDERS

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Stop the pipeline for pure placeholders; otherwise continue."""
        media_total = len(ctx.media_refs)
        if not self.is_skippable_placeholder(ctx.raw_text, media_total):
            await next_fn()
            return
        logger.debug("[%s] Skipping placeholder message: %r", ctx.adapter.name, ctx.raw_text)
|
|
|
|
|
|
class OwnerCommandMiddleware(InboundMiddleware):
|
|
"""Detect bot-owner slash commands in group chat.
|
|
|
|
Identifies in-group allowlisted slash commands and determines sender identity.
|
|
Owner commands skip @Bot detection; non-owner attempts are rejected.
|
|
"""
|
|
|
|
name = "owner-command"
|
|
|
|
# Slash command allowlist that bot owner can execute in group without @Bot
|
|
ALLOWLIST: frozenset = frozenset({
|
|
"/new", "/reset", "/retry", "/undo", "/stop",
|
|
"/approve", "/deny", "/background", "/bg",
|
|
"/btw", "/queue", "/q",
|
|
})
|
|
|
|
@staticmethod
|
|
def _rewrite_slash_command(text: str) -> str:
|
|
"""Normalize full-width slash to ASCII slash and strip whitespace."""
|
|
text = text.strip()
|
|
if text.startswith('\uff0f'): # Full-width slash
|
|
text = '/' + text[1:]
|
|
return text
|
|
|
|
@classmethod
|
|
def _detect_owner_command(
|
|
cls,
|
|
*,
|
|
push: dict,
|
|
msg_body: list,
|
|
chat_type: str,
|
|
from_account: str,
|
|
) -> Tuple[Optional[str], Optional[str], bool]:
|
|
"""Identify allowlisted slash commands and determine sender identity.
|
|
|
|
Returns (cmd, cmd_line, is_owner):
|
|
- (None, None, False): Not an allowlisted command
|
|
- (cmd, cmd_line, True): Owner match
|
|
- (cmd, cmd_line, False): Allowlisted command but sender is not owner
|
|
"""
|
|
if chat_type != "group" or not cls.ALLOWLIST:
|
|
return None, None, False
|
|
|
|
# Extract TIMTextElem: only do command recognition with exactly one text segment
|
|
text_elems = [
|
|
e for e in (msg_body or [])
|
|
if e.get("msg_type") == "TIMTextElem"
|
|
]
|
|
if len(text_elems) != 1:
|
|
return None, None, False
|
|
|
|
text = (text_elems[0].get("msg_content") or {}).get("text", "")
|
|
cmd_line = cls._rewrite_slash_command(text)
|
|
if not cmd_line.startswith("/"):
|
|
return None, None, False
|
|
cmd = cmd_line.split(maxsplit=1)[0].lower()
|
|
if cmd not in cls.ALLOWLIST:
|
|
return None, None, False
|
|
|
|
# Sender identity check: bot owner <-> push.from_account == push.bot_owner_id
|
|
owner_id = (push or {}).get("bot_owner_id") or ""
|
|
# is_owner = bool(owner_id) and owner_id == from_account
|
|
is_owner = True
|
|
return cmd, cmd_line, is_owner
|
|
|
|
async def handle(self, ctx: InboundContext, next_fn) -> None:
    """Route allowlisted slash commands depending on who sent them.

    Messages that are not allowlisted commands pass through untouched.
    A command from a non-owner is answered with a denial message and the
    pipeline stops. A command from the owner is recorded on the context and
    the pipeline continues.
    """
    adapter = ctx.adapter
    command, command_line, sender_is_owner = self._detect_owner_command(
        push=ctx.push,
        msg_body=ctx.msg_body,
        chat_type=ctx.chat_type,
        from_account=ctx.from_account,
    )

    if not command:
        # Not an allowlisted command: nothing to do here.
        await next_fn()
        return

    if not sender_is_owner:
        # Non-owner tried an owner-only command — reject and stop.
        logger.info(
            "[%s] Reject non-owner slash command: chat=%s from=%s cmd=%s",
            adapter.name, ctx.chat_id, ctx.from_account, command,
        )
        denial = asyncio.create_task(
            adapter.send(
                ctx.chat_id,
                f"⚠️ {command} is only available to the creator in private chat mode",
            ),
            name=f"yuanbao-owner-cmd-denial-{command}",
        )
        adapter._track_task(denial)
        return

    if command_line:
        logger.info(
            "[%s] Bot owner slash command: chat=%s from=%s cmd=%s",
            adapter.name, ctx.chat_id, ctx.from_account, command,
        )
        ctx.owner_command = command
        # Downstream middlewares see the normalised command text.
        ctx.raw_text = command_line

    await next_fn()
|
|
|
|
|
|
class BuildSourceMiddleware(InboundMiddleware):
    """Populate ``ctx.source`` from the chat and sender fields on the context."""

    name = "build-source"

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Build the session source via the adapter, then continue the pipeline."""
        adapter = ctx.adapter
        in_group = ctx.chat_type == "group"
        # Fall back to the account id when no nickname is available.
        display_name = ctx.sender_nickname or ctx.from_account
        ctx.source = adapter.build_source(
            chat_id=ctx.chat_id,
            chat_type=ctx.chat_type,
            chat_name=ctx.chat_name,
            user_id=ctx.from_account or None,
            user_name=display_name,
            # Group chats are filed under a fixed "main" thread id.
            thread_id="main" if in_group else None,
        )
        await next_fn()
|
|
|
|
|
|
class GroupAtGuardMiddleware(InboundMiddleware):
    """In group chat, observe non-@bot messages; only reply on @Bot.

    Owner commands skip @Bot detection (owner doesn't need to @Bot).
    """

    name = "group-at-guard"

    @staticmethod
    def _iter_bot_mentions(msg_body: Optional[list], bot_id: Optional[str]):
        """Yield the decoded custom payload of every element that @-mentions the bot.

        AT element format: ``TIMCustomElem`` whose ``msg_content.data`` is a
        JSON string such as
        ``{"elem_type": 1002, "text": "@xxx", "user_id": "<botId>"}``.
        An element counts as a mention when ``elem_type == 1002`` and
        ``user_id == bot_id``. Malformed elements are skipped, never raised on.
        """
        if not bot_id:
            return
        for elem in msg_body or []:
            if not isinstance(elem, dict) or elem.get("msg_type") != "TIMCustomElem":
                continue
            content = elem.get("msg_content")
            if not isinstance(content, dict):
                continue
            data_str = content.get("data", "")
            if not data_str:
                continue
            try:
                custom = json.loads(data_str)
            except (ValueError, TypeError):
                # JSONDecodeError is a ValueError; TypeError covers non-str data.
                continue
            # Valid JSON is not necessarily an object (e.g. "[]" or "1").
            if not isinstance(custom, dict):
                continue
            if custom.get("elem_type") == 1002 and custom.get("user_id") == bot_id:
                yield custom

    @staticmethod
    def _is_at_bot(msg_body: list, bot_id: Optional[str]) -> bool:
        """Return True when the message contains an @-mention of ``bot_id``."""
        mentions = GroupAtGuardMiddleware._iter_bot_mentions(msg_body, bot_id)
        return next(mentions, None) is not None

    @staticmethod
    def _extract_bot_mention_text(msg_body: list, bot_id: Optional[str]) -> str:
        """Extract the display text used to @-mention this bot (e.g. ``@yuanbao-bot``).

        Returns the first non-empty mention text, or ``""`` when there is none.
        """
        for custom in GroupAtGuardMiddleware._iter_bot_mentions(msg_body, bot_id):
            mention_text = str(custom.get("text") or "").strip()
            if mention_text:
                return mention_text
        return ""

    @staticmethod
    def _build_group_channel_prompt(msg_body: list, bot_id: Optional[str]) -> str:
        """Build a per-turn group-chat prompt that highlights which message to respond to."""
        bid = str(bot_id or "unknown")
        bot_mention = GroupAtGuardMiddleware._extract_bot_mention_text(msg_body, bot_id) or "unknown"
        return (
            "You are handling a Yuanbao group chat message.\n"
            f"- Your identity: user_id={bid}, @-mention name in this group={bot_mention}\n"
            "- Lines in history prefixed with `[nickname|user_id]` are observed group context "
            "and are not necessarily addressed to you.\n"
            "- Treat only the current new message as a request explicitly directed at you, "
            "and answer it directly."
        )

    @staticmethod
    def _observe_group_message(
        adapter, source, sender_display: str, text: str,
        *, msg_id: Optional[str] = None,
    ) -> None:
        """Write a group message into the session transcript without triggering the agent.

        This allows the model to see the full group conversation when it is
        eventually invoked via @bot.  Messages are stored with ``role: "user"``
        in the format ``[nickname|user_id]\\n<content>`` so the model
        can distinguish participants and their user ids.

        Best effort: when the adapter has no session store this is a no-op, and
        any storage error is logged as a warning instead of being raised.
        """
        store = getattr(adapter, "_session_store", None)
        if not store:
            return
        try:
            session_entry = store.get_or_create_session(source)
            user_id = source.user_id or "unknown"
            attributed = f"[{sender_display}|{user_id}]\n{text}"
            entry: dict = {
                "role": "user",
                "content": attributed,
                "timestamp": datetime.now(tz=timezone.utc).isoformat(),
                # Marks the entry as context only, not a request to the bot.
                "observed": True,
            }
            if msg_id:
                entry["message_id"] = msg_id
            store.append_to_transcript(
                session_entry.session_id,
                entry,
            )
        except Exception as exc:
            logger.warning("[%s] Failed to observe group message: %s", adapter.name, exc)

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Stop the pipeline for group messages that neither @-mention the bot nor are owner commands.

        Such messages are still recorded in the transcript as observed context.
        """
        adapter = ctx.adapter
        if ctx.chat_type == "group" and not ctx.owner_command and not self._is_at_bot(ctx.msg_body, adapter._bot_id):
            self._observe_group_message(
                adapter, ctx.source, ctx.sender_nickname or ctx.from_account, ctx.raw_text,
                msg_id=ctx.msg_id or None,
            )
            logger.info(
                "[%s] Group message observed (no @bot): chat=%s from=%s",
                adapter.name, ctx.chat_id, ctx.from_account,
            )
            return  # Stop pipeline — message observed but not dispatched
        await next_fn()
|
|
|
|
|
|
class GroupAttributionMiddleware(InboundMiddleware):
    """Attach sender attribution and a channel prompt to group @bot messages.

    Applies to group messages that are not owner commands. For those it:
      - sets ``ctx.channel_prompt`` to a per-turn prompt describing the bot's
        identity and the attribution scheme;
      - prefixes ``ctx.raw_text`` with ``[nickname|user_id]\\n`` so the text
        matches the observed-history format;
      - clears ``source.user_name`` so the runner does not add its own
        ``[user_name]`` shared-thread prefix.
    """

    name = "group-attribution"

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Rewrite the context for attributable group messages, then continue."""
        if ctx.chat_type != "group" or ctx.owner_command:
            await next_fn()
            return

        adapter = ctx.adapter
        ctx.channel_prompt = GroupAtGuardMiddleware._build_group_channel_prompt(
            ctx.msg_body, adapter._bot_id,
        )

        uid = ctx.from_account or "unknown"
        nick = ctx.sender_nickname or ctx.from_account or "unknown"
        ctx.raw_text = f"[{nick}|{uid}]\n{ctx.raw_text}"

        # Drop the user name so the runner's default ``[user_name]`` prefix
        # is not added on top of the attribution written above.
        if ctx.source is not None:
            ctx.source = dataclasses.replace(ctx.source, user_name=None)

        await next_fn()
|
|
|
|
|
|
class ClassifyMessageTypeMiddleware(InboundMiddleware):
    """Set ``ctx.msg_type`` from the message text and its body elements."""

    name = "classify-msg-type"

    @staticmethod
    def _classify(text: str, msg_body: list) -> MessageType:
        """Return the message type for ``text`` and ``msg_body``.

        Text starting with ``/`` is a command. Otherwise the first media
        element in ``msg_body`` decides the type, with plain text as fallback.
        """
        if text.startswith("/"):
            return MessageType.COMMAND

        by_elem_type = {
            "TIMImageElem": MessageType.PHOTO,
            "TIMSoundElem": MessageType.VOICE,
            "TIMVideoFileElem": MessageType.VIDEO,
            "TIMFileElem": MessageType.DOCUMENT,
        }
        for elem in msg_body:
            matched = by_elem_type.get(elem.get("msg_type", ""))
            if matched is not None:
                return matched
        return MessageType.TEXT

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Classify the current message and continue the pipeline."""
        ctx.msg_type = self._classify(ctx.raw_text, ctx.msg_body)
        await next_fn()
|
|
|
|
|
|
class QuoteContextMiddleware(InboundMiddleware):
    """Extract quote/reply context from cloud_custom_data."""

    name = "quote-context"

    @staticmethod
    def _extract_quote_context(cloud_custom_data: str) -> Tuple[Optional[str], Optional[str]]:
        """Extract quote context, mapping to MessageEvent.reply_to_*.

        Args:
            cloud_custom_data: JSON string expected to hold a ``quote`` object
                with ``id``, ``type``, ``desc`` and sender fields.

        Returns:
            ``(reply_to_message_id, reply_to_text)``; ``(None, None)`` when the
            payload is missing, malformed, or carries no usable description.
        """
        if not cloud_custom_data:
            return None, None
        try:
            parsed = json.loads(cloud_custom_data)
        except (json.JSONDecodeError, TypeError):
            return None, None

        quote = parsed.get("quote") if isinstance(parsed, dict) else None
        if not isinstance(quote, dict):
            return None, None

        # The payload comes from the remote side; a non-numeric ``type`` must
        # not abort message processing, so it is treated as "unknown" (0).
        try:
            quote_type = int(quote.get("type") or 0)
        except (TypeError, ValueError):
            quote_type = 0

        # type=2 corresponds to image reference; desc may be empty, provide a placeholder.
        desc = str(quote.get("desc") or "").strip()
        if quote_type == 2 and not desc:
            desc = "[image]"
        if not desc:
            return None, None

        quote_id = str(quote.get("id") or "").strip() or None
        sender = str(quote.get("sender_nickname") or quote.get("sender_id") or "").strip()
        quote_text = f"{sender}: {desc}" if sender else desc
        return quote_id, quote_text

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Store the quoted message id and text on the context, then continue."""
        ctx.reply_to_message_id, ctx.reply_to_text = self._extract_quote_context(ctx.cloud_custom_data)
        await next_fn()
|
|
|
|
|
|
class MediaResolveMiddleware(InboundMiddleware):
    """Resolve inbound media references and download them into the local cache."""

    name = "media-resolve"

    @staticmethod
    def _guess_image_ext_from_url(url: str) -> str:
        """Guess image extension from URL path; defaults to ``.jpg``."""
        path = urllib.parse.urlparse(url).path
        ext = os.path.splitext(path)[1].lower()
        if ext in {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".heic", ".tiff"}:
            return ext
        return ".jpg"

    @staticmethod
    async def _fetch_resource_url(adapter, resource_id: str) -> str:
        """Low-level helper: exchange a ``resourceId`` for a direct download URL.

        Handles token retrieval, the ``/api/resource/v1/download`` API call,
        and a single 401-retry with token force-refresh.

        Raises:
            RuntimeError: on a missing resource id, missing credentials, a
                non-zero API ``code``, or a response without a URL.
            httpx.HTTPStatusError: when the API answers with an error status.
        """
        resource_id = resource_id.strip()
        if not resource_id:
            raise RuntimeError("missing resource_id")

        token_data = await adapter._get_cached_token()
        token = str(token_data.get("token") or "").strip()
        source = str(token_data.get("source") or "web").strip() or "web"
        bot_id = str(token_data.get("bot_id") or adapter._bot_id or adapter._app_key).strip()
        if not token or not bot_id:
            raise RuntimeError("missing token or bot_id for resource download")

        api_url = f"{adapter._api_domain}/api/resource/v1/download"
        headers = {
            "Content-Type": "application/json",
            "X-ID": bot_id,
            "X-Token": token,
            "X-Source": source,
        }

        async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
            for attempt in range(2):
                resp = await client.get(api_url, params={"resourceId": resource_id}, headers=headers)
                if resp.status_code == 401 and attempt == 0:
                    # Force refresh token once on expiry and retry
                    token_data = await SignManager.force_refresh(
                        adapter._app_key, adapter._app_secret, adapter._api_domain,
                    )
                    token = str(token_data.get("token") or "").strip()
                    source = str(token_data.get("source") or source or "web").strip() or "web"
                    bot_id = str(token_data.get("bot_id") or adapter._bot_id or adapter._app_key).strip()
                    if not token or not bot_id:
                        # Refresh gave nothing usable; fall through to the final error.
                        break
                    headers["X-ID"] = bot_id
                    headers["X-Token"] = token
                    headers["X-Source"] = source
                    continue

                resp.raise_for_status()
                payload = resp.json()
                code = payload.get("code")
                if code not in (None, 0):
                    raise RuntimeError(
                        f"resource/v1/download failed: code={code}, msg={payload.get('msg', '')}"
                    )
                # The URL may sit under "data" or directly on the payload.
                data = payload.get("data") if isinstance(payload.get("data"), dict) else payload
                real_url = str((data or {}).get("url") or (data or {}).get("realUrl") or "").strip()
                if real_url:
                    return real_url
                raise RuntimeError("resource/v1/download missing url/realUrl")

        raise RuntimeError("resource/v1/download did not return a URL")

    @staticmethod
    async def _resolve_download_url(adapter, url: str) -> str:
        """Resolve Yuanbao resource placeholder to a directly fetchable real URL.

        Common URL patterns:
            https://hunyuan.tencent.com/api/resource/download?resourceId=...
        Direct GET returns 401; need business API:
            GET /api/resource/v1/download?resourceId=...

        Never raises: when the URL carries no ``resourceId`` or the exchange
        fails, the original ``url`` is returned (the failure is logged).
        """
        try:
            parsed = urllib.parse.urlparse(url)
        except Exception:
            return url

        query = urllib.parse.parse_qs(parsed.query)
        resource_ids = query.get("resourceId") or query.get("resourceid") or []
        resource_id = str(resource_ids[0]).strip() if resource_ids else ""
        if not resource_id:
            return url

        try:
            return await MediaResolveMiddleware._fetch_resource_url(adapter, resource_id)
        except Exception as exc:
            # Best effort: keep the placeholder URL, but leave a trace so a
            # later download failure can be tied back to its cause.
            logger.warning(
                "[%s] resource URL resolve failed, using original URL: resource_id=%s err=%s",
                adapter.name, resource_id, exc,
            )
            return url

    @classmethod
    async def _download_and_cache(
        cls, adapter, *, fetch_url: str, kind: str,
        file_name: Optional[str] = None, log_tag: str = "",
    ) -> Optional[Tuple[str, str]]:
        """Download a Yuanbao resource and cache locally. Returns ``(local_path, mime)`` or ``None``.

        Args:
            adapter: Adapter providing ``name`` and ``MEDIA_MAX_SIZE_MB``.
            fetch_url: Direct URL to download.
            kind: ``"image"`` is cached as an image; anything else as a file.
            file_name: Name for file caching; derived from the URL when omitted.
            log_tag: Short identifier included in warning logs.
        """
        try:
            file_bytes, content_type = await media_download_url(
                fetch_url, max_size_mb=adapter.MEDIA_MAX_SIZE_MB,
            )
        except Exception as exc:
            logger.warning(
                "[%s] inbound media download failed: kind=%s %s err=%s",
                adapter.name, kind, log_tag, exc,
            )
            return None

        if kind == "image":
            ext = cls._guess_image_ext_from_url(fetch_url)
            try:
                local_path = cache_image_from_bytes(file_bytes, ext=ext)
            except ValueError as exc:
                logger.warning(
                    "[%s] inbound image cache rejected: %s err=%s",
                    adapter.name, log_tag, exc,
                )
                return None
            # Both the guessed type and the header may be empty/None (the file
            # branch below already treats them that way), so normalise first.
            mime = guess_mime_type(f"image{ext}") or ""
            if not mime.startswith("image/"):
                header_mime = content_type or ""
                mime = header_mime if header_mime.startswith("image/") else "image/jpeg"
            return local_path, mime

        # kind == "file"
        if not file_name:
            parsed = urllib.parse.urlparse(fetch_url)
            file_name = os.path.basename(parsed.path) or "file"
        try:
            local_path = cache_document_from_bytes(file_bytes, file_name)
        except Exception as exc:
            logger.warning(
                "[%s] inbound file cache failed: %s err=%s",
                adapter.name, log_tag, exc,
            )
            return None
        mime = guess_mime_type(file_name) or content_type or "application/octet-stream"
        return local_path, mime

    @classmethod
    async def _resolve_by_resource_id(cls, adapter, resource_id: str) -> str:
        """Exchange a Yuanbao ``resourceId`` for a short-lived direct download URL. Raises on failure."""
        return await cls._fetch_resource_url(adapter, resource_id)

    @classmethod
    async def _resolve_media_urls(
        cls, adapter, media_refs: List[Dict[str, str]]
    ) -> Tuple[List[str], List[str]]:
        """Resolve inbound media refs: download to local cache, return (local_paths, mime_types).

        Yuanbao COS hostnames resolve to private IPs, tripping the SSRF guard
        in vision_tools.  We download ourselves and return local cache paths.
        Refs that are not image/file, have no URL, or fail to download are skipped.
        """
        media_urls: List[str] = []
        media_types: List[str] = []

        for ref in media_refs:
            kind = str(ref.get("kind") or "").strip().lower()
            url = str(ref.get("url") or "").strip()
            if kind not in {"image", "file"} or not url:
                continue

            try:
                fetch_url = await cls._resolve_download_url(adapter, url)
            except Exception as exc:
                logger.warning(
                    "[%s] inbound media resolve failed: kind=%s url=%s err=%s",
                    adapter.name, kind, url, exc,
                )
                continue

            cached = await cls._download_and_cache(
                adapter,
                fetch_url=fetch_url,
                kind=kind,
                file_name=str(ref.get("name") or "").strip() or None,
                log_tag=f"placeholder_url={url[:80]}",
            )
            if cached is None:
                continue
            local_path, mime = cached
            media_urls.append(local_path)
            media_types.append(mime)

        return media_urls, media_types

    @classmethod
    async def _collect_observed_media(
        cls, adapter, source,
    ) -> Tuple[List[str], List[str]]:
        """Resolve recent observed image/file anchors from transcript into ``(local_paths, mimes)``.

        Scans the last ``OBSERVED_MEDIA_BACKFILL_LOOKBACK`` transcript entries
        for ``|ybres:`` anchors, de-duplicates by resource id, and resolves at
        most ``OBSERVED_MEDIA_BACKFILL_MAX_RESOLVE_PER_TURN`` of them.
        """
        store = getattr(adapter, "_session_store", None)
        if not store:
            return [], []
        try:
            session_entry = store.get_or_create_session(source)
            history = store.load_transcript(session_entry.session_id)
        except Exception as exc:
            logger.warning(
                "[%s] Observed-media hydration setup failed: %s",
                adapter.name, exc,
            )
            return [], []
        if not history:
            return [], []

        limit = OBSERVED_MEDIA_BACKFILL_MAX_RESOLVE_PER_TURN
        start = max(0, len(history) - OBSERVED_MEDIA_BACKFILL_LOOKBACK)
        order: List[Tuple[str, str, str]] = []  # (rid, kind, filename)
        seen: set = set()
        for msg in history[start:]:
            content = msg.get("content")
            # Cheap substring test before running the regex.
            if not isinstance(content, str) or "|ybres:" not in content:
                continue
            for m in _YB_RES_REF_RE.finditer(content):
                head = m.group(1)  # "image" | "file:<name>" | "voice" | "video"
                rid = m.group(2)
                kind, _, filename = head.partition(":")
                kind = kind.strip()
                if kind not in ("image", "file"):
                    continue
                if rid in seen:
                    continue
                seen.add(rid)
                order.append((rid, kind, filename.strip()))
                if len(order) >= limit:
                    break
            if len(order) >= limit:
                break

        if not order:
            return [], []

        media_paths: List[str] = []
        mimes: List[str] = []
        for rid, kind, filename in order:
            try:
                fresh_url = await cls._resolve_by_resource_id(adapter, rid)
            except Exception as exc:
                logger.warning(
                    "[%s] observed-media resolve failed: rid=%s kind=%s err=%s",
                    adapter.name, rid, kind, exc,
                )
                continue
            cached = await cls._download_and_cache(
                adapter,
                fetch_url=fresh_url,
                kind=kind,
                file_name=filename or None,
                log_tag=f"rid={rid}",
            )
            if cached is None:
                continue
            path, mime = cached
            media_paths.append(path)
            mimes.append(mime)
        return media_paths, mimes

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Download the message's media, then drop it if only a placeholder remains."""
        adapter = ctx.adapter
        ctx.media_urls, ctx.media_types = await self._resolve_media_urls(adapter, ctx.media_refs)
        # Re-check placeholder after media resolution
        if PlaceholderFilterMiddleware.is_skippable_placeholder(ctx.raw_text, len(ctx.media_urls)):
            logger.debug("[%s] Skip placeholder after media download: %r", adapter.name, ctx.raw_text)
            return  # Stop pipeline
        await next_fn()
|
|
|
|
|
|
class DispatchMiddleware(InboundMiddleware):
    """Build MessageEvent and dispatch to AI handler.

    Direct chats are dispatched in their own task right away. Group chats
    are queued per session key and drained one dispatch at a time.
    """

    name = "dispatch"

    @staticmethod
    def _patch_resource_anchors(text: str, media_urls: List[str], media_types: List[str]) -> str:
        """Replace ``[kind|ybres:xxx]`` anchors in ``text`` with local cache paths.

        For each local path (one starting with ``/``) the first remaining
        anchor is inspected: an image anchor paired with an image MIME type
        becomes ``[image: <path>]`` and a file anchor becomes
        ``[file: <label> → <path>]``. Any other pairing leaves the text as is.
        """
        patched = text
        for u, m in zip(media_urls, media_types):
            if not u.startswith("/"):
                continue
            anchor_match = _YB_RES_REF_RE.search(patched)
            if not anchor_match:
                continue
            head = anchor_match.group(1)
            kind, _, filename = head.partition(":")
            kind = kind.strip()
            if kind == "image" and m.startswith("image/"):
                replacement = f"[image: {u}]"
            elif kind == "file":
                label = filename.strip() or os.path.basename(u)
                replacement = f"[file: {label} → {u}]"
            else:
                continue
            patched = (
                patched[:anchor_match.start()]
                + replacement
                + patched[anchor_match.end():]
            )
        return patched

    async def handle(self, ctx: InboundContext, next_fn) -> None:
        """Schedule the dispatch of the current message and continue the pipeline."""
        adapter = ctx.adapter

        _sk = build_session_key(
            ctx.source,
            group_sessions_per_user=adapter.config.extra.get("group_sessions_per_user", True),
            thread_sessions_per_user=adapter.config.extra.get("thread_sessions_per_user", False),
        )

        async def _dispatch_inbound_event() -> None:
            media_urls = list(ctx.media_urls)
            media_types = list(ctx.media_types)

            # Backfill observed media from recent transcript history
            extra_img_urls: List[str] = []
            extra_img_mimes: List[str] = []
            try:
                extra_img_urls, extra_img_mimes = await MediaResolveMiddleware._collect_observed_media(
                    adapter, ctx.source,
                )
            except Exception as exc:
                logger.warning(
                    "[%s] observed-image hydration raised, continuing anyway: %s",
                    adapter.name, exc,
                )
            if extra_img_urls:
                current = set(media_urls)
                for u, m in zip(extra_img_urls, extra_img_mimes):
                    if u in current:
                        continue
                    media_urls.append(u)
                    media_types.append(m)
                    current.add(u)

            # Replace [kind|ybres:xxx] anchors with local cache paths so
            # the transcript records usable paths for the model.
            _patched_event_text = self._patch_resource_anchors(
                ctx.raw_text, media_urls, media_types,
            )

            event = MessageEvent(
                text=_patched_event_text,
                message_type=ctx.msg_type,
                source=ctx.source,
                message_id=ctx.msg_id or None,
                raw_message=ctx.push,
                media_urls=media_urls,
                media_types=media_types,
                reply_to_message_id=ctx.reply_to_message_id,
                reply_to_text=ctx.reply_to_text,
                channel_prompt=ctx.channel_prompt,
            )
            if _sk and ctx.msg_id:
                adapter._processing_msg_ids[_sk] = ctx.msg_id
                adapter._processing_msg_texts[_sk] = ctx.raw_text or ""
            if ctx.msg_id and ctx.raw_text:
                cache = adapter._msg_content_cache
                cache[ctx.msg_id] = ctx.raw_text
                # Keep only the 200 most recently inserted entries.
                if len(cache) > 200:
                    for k in list(cache)[:len(cache) - 200]:
                        del cache[k]
            await adapter.handle_message(event)

        if ctx.chat_type == "group":
            is_new = _sk not in adapter._group_queues
            queue = adapter._group_queues.setdefault(_sk, asyncio.Queue())
            queue.put_nowait(_dispatch_inbound_event)
            logger.info(
                "[%s] Group message enqueued (qsize=%d) for %s",
                adapter.name, queue.qsize(), (_sk or "")[:50],
            )
            if is_new:
                # One consumer per session key; it removes the queue when idle.
                consumer = asyncio.create_task(
                    self._consume_group_queue(adapter, _sk),
                    name=f"yuanbao-group-consumer-{(_sk or '')[:30]}",
                )
                adapter._inbound_tasks.add(consumer)
                consumer.add_done_callback(adapter._inbound_tasks.discard)
        else:
            task = asyncio.create_task(
                _dispatch_inbound_event(),
                name=f"yuanbao-inbound-{ctx.msg_id or 'unknown'}",
            )
            adapter._inbound_tasks.add(task)
            task.add_done_callback(adapter._inbound_tasks.discard)

        await next_fn()

    @staticmethod
    async def _consume_group_queue(adapter: "YuanbaoAdapter", session_key: str) -> None:
        """Drain the group queue one dispatch at a time, waiting for each to finish.

        The consumer exits and removes the queue after ``_IDLE_TIMEOUT``
        seconds without work. A failing dispatch is logged and does not stop
        the consumer.
        """
        _IDLE_TIMEOUT = 2.0
        queue = adapter._group_queues.get(session_key)
        if queue is None:
            return
        try:
            while True:
                try:
                    dispatch_fn = await asyncio.wait_for(queue.get(), timeout=_IDLE_TIMEOUT)
                except asyncio.TimeoutError:
                    # An item can be enqueued while the timed-out ``get()`` is
                    # being cancelled. The producer saw the queue still
                    # registered and started no new consumer, so exiting now
                    # would drop that item. Only stop when truly empty.
                    if queue.empty():
                        break
                    continue
                logger.debug(
                    "[%s] Group queue: dispatching for %s (remaining=%d)",
                    adapter.name, (session_key or "")[:50], queue.qsize(),
                )
                try:
                    await dispatch_fn()
                    # Hold the next dispatch until the session is no longer active.
                    while session_key in adapter._active_sessions:
                        await asyncio.sleep(0.1)
                except Exception:
                    logger.exception("[%s] Group queue consumer error", adapter.name)
        finally:
            adapter._group_queues.pop(session_key, None)
|
|
|
|
|
|
class InboundPipelineBuilder:
    """Assemble ``InboundPipeline`` instances.

    Keeps the knowledge of which middlewares run, and in what order, apart
    from the pipeline engine (``InboundPipeline``) so the engine stays
    generic and reusable.
    """

    # Middleware classes for Yuanbao inbound processing, in execution order.
    _DEFAULT_MIDDLEWARES: list[type] = [
        DecodeMiddleware,
        ExtractFieldsMiddleware,
        RecallGuardMiddleware,
        DedupMiddleware,
        SkipSelfMiddleware,
        ChatRoutingMiddleware,
        AccessGuardMiddleware,
        AutoSetHomeMiddleware,
        ExtractContentMiddleware,
        PlaceholderFilterMiddleware,
        OwnerCommandMiddleware,
        BuildSourceMiddleware,
        GroupAtGuardMiddleware,
        GroupAttributionMiddleware,
        ClassifyMessageTypeMiddleware,
        QuoteContextMiddleware,
        MediaResolveMiddleware,
        DispatchMiddleware,
    ]

    @classmethod
    def build(cls) -> InboundPipeline:
        """Return a new pipeline holding one fresh instance of each default middleware."""
        assembled = InboundPipeline()
        for middleware_cls in cls._DEFAULT_MIDDLEWARES:
            middleware = middleware_cls()
            assembled.use(middleware)
        return assembled
|
|
|
|
class ConnectionManager:
|
|
"""Manages the WebSocket connection lifecycle for YuanbaoAdapter.
|
|
|
|
Responsibilities:
|
|
- Opening and closing the WebSocket
|
|
- AUTH_BIND handshake
|
|
- Heartbeat (ping/pong) loop
|
|
- Receive loop (frame dispatch)
|
|
- Reconnect with exponential backoff
|
|
"""
|
|
|
|
def __init__(self, adapter: "YuanbaoAdapter") -> None:
|
|
self._adapter = adapter
|
|
self._ws = None # websockets connection
|
|
self._connect_id: Optional[str] = None
|
|
self._heartbeat_task: Optional[asyncio.Task] = None
|
|
self._recv_task: Optional[asyncio.Task] = None
|
|
self._pending_acks: Dict[str, asyncio.Future] = {}
|
|
self._pending_pong: Optional[asyncio.Future] = None
|
|
self._consecutive_hb_timeouts: int = 0
|
|
self._reconnect_attempts: int = 0
|
|
self._reconnecting: bool = False
|
|
# Debounce buffer for aggregating multi-part inbound messages
|
|
self._inbound_buffer: Dict[str, list] = {} # key -> [raw_data_frames, ...]
|
|
self._inbound_timers: Dict[str, asyncio.TimerHandle] = {} # key -> timer
|
|
|
|
# -- Properties --------------------------------------------------------
|
|
|
|
@property
|
|
def ws(self):
|
|
return self._ws
|
|
|
|
@property
|
|
def connect_id(self) -> Optional[str]:
|
|
return self._connect_id
|
|
|
|
@property
|
|
def reconnect_attempts(self) -> int:
|
|
return self._reconnect_attempts
|
|
|
|
@property
|
|
def is_connected(self) -> bool:
|
|
if self._ws is None:
|
|
return False
|
|
open_attr = getattr(self._ws, "open", None)
|
|
if open_attr is True:
|
|
return True
|
|
if callable(open_attr):
|
|
try:
|
|
return bool(open_attr())
|
|
except Exception:
|
|
return False
|
|
return False
|
|
|
|
# -- Open / Close ------------------------------------------------------
|
|
|
|
async def open(self) -> bool:
    """Open WebSocket connection: sign-token → WS connect → AUTH_BIND → start loops.

    The call is idempotent: when a connection is already open it returns
    ``True`` without doing anything. On every failure after the platform
    lock was taken, the socket is cleaned up and the lock is released.

    Returns True on success, False on failure.
    """
    adapter = self._adapter

    if not WEBSOCKETS_AVAILABLE:
        msg = "Yuanbao startup failed: 'websockets' package not installed"
        adapter._set_fatal_error("yuanbao_missing_dependency", msg, retryable=True)
        logger.warning("[%s] %s. Run: pip install websockets", adapter.name, msg)
        return False

    if not adapter._app_key or not adapter._app_secret:
        msg = (
            "Yuanbao startup failed: "
            "YUANBAO_APP_ID and YUANBAO_APP_SECRET are required"
        )
        adapter._set_fatal_error("yuanbao_missing_credentials", msg, retryable=False)
        logger.error("[%s] %s", adapter.name, msg)
        return False

    # Idempotency guard (same probing rules as the ``is_connected`` property).
    if self.is_connected:
        logger.debug("[%s] Already connected, skipping connect()", adapter.name)
        return True

    # Acquire platform-scoped lock to prevent duplicate connections
    if not adapter._acquire_platform_lock(
        'yuanbao-app-key', adapter._app_key, 'Yuanbao app key'
    ):
        return False

    try:
        # Step 1: Get sign token
        logger.info("[%s] Fetching sign token from %s", adapter.name, adapter._api_domain)
        token_data = await SignManager.get_token(
            adapter._app_key, adapter._app_secret, adapter._api_domain,
            route_env=adapter._route_env,
        )

        # Update bot_id if returned by sign-token API
        if token_data.get("bot_id"):
            adapter._bot_id = str(token_data["bot_id"])

        # Step 2: Open WebSocket connection (disable built-in ping/pong)
        logger.info("[%s] Connecting to %s", adapter.name, adapter._ws_url)
        self._ws = await asyncio.wait_for(
            websockets.connect(  # type: ignore[attr-defined]
                adapter._ws_url,
                ping_interval=None,
                ping_timeout=None,
                close_timeout=5,
            ),
            timeout=CONNECT_TIMEOUT_SECONDS,
        )

        # Step 3: Authenticate (AUTH_BIND + wait for BIND_ACK)
        authed = await self._authenticate(token_data)
        if not authed:
            await self._cleanup_ws()
            # Release the lock here too, as the other failure paths do;
            # otherwise a failed handshake leaves the lock held.
            adapter._release_platform_lock()
            return False

        # Step 4: Start background tasks
        self._reconnect_attempts = 0
        adapter._mark_connected()
        adapter._loop = asyncio.get_running_loop()
        self._heartbeat_task = asyncio.create_task(
            self._heartbeat_loop(), name=f"yuanbao-heartbeat-{self._connect_id}"
        )
        self._recv_task = asyncio.create_task(
            self._receive_loop(), name=f"yuanbao-recv-{self._connect_id}"
        )
        logger.info(
            "[%s] Connected. connectId=%s botId=%s",
            adapter.name, self._connect_id, adapter._bot_id,
        )

        YuanbaoAdapter.set_active(adapter)

        return True

    except asyncio.TimeoutError:
        logger.error("[%s] Connection timed out", adapter.name)
        await self._cleanup_ws()
        adapter._release_platform_lock()
        return False
    except Exception as exc:
        logger.error("[%s] connect() failed: %s", adapter.name, exc, exc_info=True)
        await self._cleanup_ws()
        adapter._release_platform_lock()
        return False
|
|
|
|
async def close(self) -> None:
    """Cancel background tasks, fail pending futures, and close the WebSocket.

    A background task that already ended with an error does not interrupt
    the shutdown: the error is logged and the remaining cleanup still runs.
    """
    for task_attr in ("_heartbeat_task", "_recv_task"):
        task = getattr(self, task_attr)
        if not task:
            continue
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            # Awaiting a task that failed earlier re-raises its exception;
            # swallow it so the socket below is still closed.
            logger.debug(
                "[%s] Background task %s ended with error during close: %s",
                self._adapter.name, task_attr, exc,
            )
        setattr(self, task_attr, None)

    # Fail any pending ACK futures
    disc_exc = RuntimeError("YuanbaoAdapter disconnected")
    for fut in self._pending_acks.values():
        if not fut.done():
            fut.set_exception(disc_exc)
    self._pending_acks.clear()

    # Clear refresh locks to avoid stale locks from a previous event loop
    SignManager.clear_locks()

    await self._cleanup_ws()
|
|
|
|
# -- Authentication ----------------------------------------------------
|
|
|
|
async def _authenticate(self, token_data: dict) -> bool:
    """Send AUTH_BIND, then read frames until the matching response arrives.

    Non-binary frames, undecodable frames and unrelated messages are
    skipped. On a response carrying a connect id, the id is stored on the
    manager.

    Returns True on success, False on failure/timeout.
    """
    adapter = self._adapter
    if self._ws is None:
        return False

    # Adapter-level settings take precedence over values from the token.
    bind_token = token_data.get("token", "")
    bind_uid = adapter._bot_id or token_data.get("bot_id", "")
    bind_source = token_data.get("source") or "bot"
    bind_route_env = adapter._route_env or token_data.get("route_env", "") or ""

    request_id = str(uuid.uuid4())
    frame = encode_auth_bind(
        biz_id="ybBot",
        uid=bind_uid,
        source=bind_source,
        token=bind_token,
        msg_id=request_id,
        app_version=_APP_VERSION,
        operation_system=_OPERATION_SYSTEM,
        bot_version=_BOT_VERSION,
        route_env=bind_route_env,
    )
    await self._ws.send(frame)
    logger.debug("[%s] AUTH_BIND sent (msg_id=%s uid=%s)", adapter.name, request_id, bind_uid)

    try:
        running_loop = asyncio.get_running_loop()
        give_up_at = running_loop.time() + AUTH_TIMEOUT_SECONDS
        while True:
            time_left = give_up_at - running_loop.time()
            if time_left <= 0:
                logger.error("[%s] AUTH_BIND timeout waiting for BIND_ACK", adapter.name)
                return False

            frame_in = await asyncio.wait_for(self._ws.recv(), timeout=time_left)
            if not isinstance(frame_in, (bytes, bytearray)):
                continue

            try:
                decoded = decode_conn_msg(bytes(frame_in))
            except Exception:
                continue

            header = decoded.get("head", {})
            is_bind_ack = (
                header.get("cmd_type", -1) == CMD_TYPE["Response"]
                and header.get("cmd", "") == "auth-bind"
            )
            if not is_bind_ack:
                continue

            bound_id = self._extract_connect_id(decoded)
            if not bound_id:
                logger.error("[%s] BIND_ACK missing connectId", adapter.name)
                return False

            self._connect_id = bound_id
            logger.info("[%s] BIND_ACK received: connectId=%s", adapter.name, bound_id)
            return True

    except asyncio.TimeoutError:
        logger.error("[%s] AUTH_BIND timeout", adapter.name)
        return False
    except Exception as exc:
        logger.error("[%s] AUTH_BIND error: %s", adapter.name, exc, exc_info=True)
        return False
|
|
|
|
def _extract_connect_id(self, decoded_msg: dict) -> Optional[str]:
|
|
"""Extract connectId from decoded BIND_ACK message."""
|
|
data: bytes = decoded_msg.get("data", b"")
|
|
if not data:
|
|
return None
|
|
try:
|
|
fdict = _fields_to_dict(_parse_fields(data))
|
|
code = _get_varint(fdict, 1)
|
|
if code != 0:
|
|
message = _get_string(fdict, 2)
|
|
logger.error(
|
|
"[%s] AuthBindRsp error: code=%d message=%r",
|
|
self._adapter.name, code, message,
|
|
)
|
|
return None
|
|
connect_id = _get_string(fdict, 3)
|
|
return connect_id if connect_id else None
|
|
except Exception as exc:
|
|
logger.warning("[%s] Failed to extract connectId: %s", self._adapter.name, exc)
|
|
return None
|
|
|
|
# -- Heartbeat ---------------------------------------------------------
|
|
|
|
async def _heartbeat_loop(self) -> None:
    """Send a HEARTBEAT ping periodically; reconnect after repeated misses.

    Every ``HEARTBEAT_INTERVAL_SECONDS`` a ping is sent and its future is
    registered both as ``_pending_pong`` and under its msg_id in
    ``_pending_acks``.  The loop then waits up to 10 seconds for the future
    to be resolved.  Each timeout bumps ``_consecutive_hb_timeouts``; once
    it reaches ``HEARTBEAT_TIMEOUT_THRESHOLD`` a reconnect is scheduled and
    the loop exits.  Any other failure while pinging is logged at debug
    level and the loop carries on.  Cancellation ends the loop silently.
    """
    adapter = self._adapter
    try:
        while adapter._running:
            await asyncio.sleep(HEARTBEAT_INTERVAL_SECONDS)
            if self._ws is None:
                continue
            try:
                msg_id = str(uuid.uuid4())
                ping_bytes = encode_ping(msg_id)
                pong_future: asyncio.Future = asyncio.get_running_loop().create_future()
                self._pending_pong = pong_future
                self._pending_acks[msg_id] = pong_future
                try:
                    # The send sits inside the try/finally so that a failed
                    # (or cancelled) send does not leave the future behind
                    # in _pending_acks / _pending_pong.
                    await self._ws.send(ping_bytes)
                    logger.debug("[%s] PING sent (msg_id=%s)", adapter.name, msg_id)
                    try:
                        await asyncio.wait_for(pong_future, timeout=10.0)
                        self._consecutive_hb_timeouts = 0
                    except asyncio.TimeoutError:
                        self._consecutive_hb_timeouts += 1
                        logger.warning(
                            "[%s] PONG timeout (%d/%d)",
                            adapter.name, self._consecutive_hb_timeouts, HEARTBEAT_TIMEOUT_THRESHOLD,
                        )
                        if self._consecutive_hb_timeouts >= HEARTBEAT_TIMEOUT_THRESHOLD:
                            logger.warning("[%s] Heartbeat threshold exceeded, triggering reconnect", adapter.name)
                            self.schedule_reconnect()
                            return
                finally:
                    self._pending_acks.pop(msg_id, None)
                    # Only clear the slot if it still holds this ping's future.
                    if self._pending_pong is pong_future:
                        self._pending_pong = None
            except Exception as exc:
                logger.debug("[%s] Heartbeat send failed: %s", adapter.name, exc)
    except asyncio.CancelledError:
        pass
|
|
|
|
# -- Receive loop ------------------------------------------------------
|
|
|
|
async def _receive_loop(self) -> None:
|
|
"""Read WS frames and dispatch by cmd_type."""
|
|
adapter = self._adapter
|
|
try:
|
|
async for raw in self._ws: # type: ignore[union-attr]
|
|
if not isinstance(raw, (bytes, bytearray)):
|
|
continue
|
|
await self._handle_frame(bytes(raw))
|
|
except asyncio.CancelledError:
|
|
pass
|
|
except websockets.exceptions.ConnectionClosed as close_exc: # type: ignore[union-attr]
|
|
close_code = getattr(close_exc, 'code', None)
|
|
logger.warning(
|
|
"[%s] WebSocket connection closed: code=%s reason=%s",
|
|
adapter.name, close_code, getattr(close_exc, 'reason', ''),
|
|
)
|
|
if close_code and close_code in NO_RECONNECT_CLOSE_CODES:
|
|
logger.error(
|
|
"[%s] Close code %d is non-recoverable, NOT reconnecting",
|
|
adapter.name, close_code,
|
|
)
|
|
adapter._mark_disconnected()
|
|
else:
|
|
self.schedule_reconnect()
|
|
except Exception as exc:
|
|
logger.warning("[%s] receive_loop exited: %s", adapter.name, exc)
|
|
self.schedule_reconnect()
|
|
|
|
async def _handle_frame(self, raw: bytes) -> None:
    """Decode one binary WebSocket frame and route it by command type.

    Routing order:
      1. ``ping`` responses resolve the outstanding heartbeat future.
      2. ``send_group_heartbeat`` / ``send_private_heartbeat`` responses are
         logged and dropped.
      3. Any other response resolves the future registered under its msg_id
         in ``_pending_acks`` (or is logged as unmatched).
      4. Pushes are acknowledged when the head requests it, then either
         resolve a waiting future or go to the debounced inbound dispatcher.

    Frames that fail to decode, and frames matching none of the above, are
    logged at debug level and ignored.
    """
    adapter = self._adapter
    try:
        frame = decode_conn_msg(raw)
    except Exception as exc:
        logger.debug("[%s] Failed to decode frame: %s", adapter.name, exc)
        return

    head = frame.get("head", {})
    cmd_type = head.get("cmd_type", -1)
    cmd = head.get("cmd", "")
    msg_id = head.get("msg_id", "")
    need_ack = head.get("need_ack", False)
    data: bytes = frame.get("data", b"")

    if cmd_type == CMD_TYPE["Response"]:
        if cmd == "ping":
            # HEARTBEAT_ACK: prefer the dedicated pong slot, fall back to
            # the msg_id lookup when that slot is empty or already resolved.
            logger.debug("[%s] HEARTBEAT_ACK received (msg_id=%s)", adapter.name, msg_id)
            pong = self._pending_pong
            if pong is not None and not pong.done():
                pong.set_result(True)
            elif msg_id and msg_id in self._pending_acks:
                waiter = self._pending_acks.pop(msg_id)
                if not waiter.done():
                    waiter.set_result(True)
            return

        # Fire-and-forget heartbeat ACKs — server always responds but callers
        # don't wait on these; silently discard to avoid "Unmatched Response"
        # noise.
        if cmd in ("send_group_heartbeat", "send_private_heartbeat"):
            logger.debug("[%s] Heartbeat ACK received: cmd=%s msg_id=%s", adapter.name, cmd, msg_id)
            return

        # Response to an outbound RPC call.
        if not (msg_id and msg_id in self._pending_acks):
            logger.debug(
                "[%s] Unmatched Response: cmd=%s msg_id=%s",
                adapter.name, cmd, msg_id,
            )
            return
        waiter = self._pending_acks.pop(msg_id)
        if not waiter.done():
            reply = {"head": head}
            if data:
                reply["data"] = data
            waiter.set_result(reply)
        return

    # Server-initiated Push
    if cmd_type == CMD_TYPE["Push"]:
        logger.info("[%s] Push received: cmd=%s msg_id=%s data_len=%d", adapter.name, cmd, msg_id, len(data))
        if need_ack and self._ws is not None:
            try:
                await self._ws.send(encode_push_ack(head))
            except Exception as ack_exc:
                logger.debug("[%s] Failed to send PushAck: %s", adapter.name, ack_exc)

        if msg_id and msg_id in self._pending_acks:
            waiter = self._pending_acks.pop(msg_id)
            if not waiter.done():
                try:
                    decoded = decode_inbound_push(data) if data else {"head": head}
                    waiter.set_result(decoded)
                except Exception as exc:
                    waiter.set_exception(exc)
            return

        # Genuine inbound message — dispatch to AI
        if data:
            logger.info(
                "[%s] WS received inbound push, decoding and dispatching: cmd=%s, data_len=%d",
                adapter.name, cmd, len(data),
            )
            self._push_to_inbound(data)
            return

    logger.debug(
        "[%s] Ignoring frame: cmd_type=%d cmd=%s msg_id=%s",
        adapter.name, cmd_type, cmd, msg_id,
    )
|
|
|
|
# -- Inbound dispatch ---------------------------------------------------
|
|
|
|
# Delay, in seconds, that _push_to_inbound passes to loop.call_later before a
# sender's buffered frames are flushed; each new frame restarts the timer.
_DEBOUNCE_WINDOW: float = 1.5 # seconds to wait for companion messages
|
|
|
|
def _extract_sender_key(self, raw_data: bytes) -> str:
|
|
"""Lightweight decode to extract sender key for debounce grouping.
|
|
|
|
Returns 'from_account:group_code' or a fallback unique key.
|
|
"""
|
|
try:
|
|
parsed = json.loads(raw_data.decode("utf-8"))
|
|
if isinstance(parsed, dict):
|
|
from_account = (
|
|
parsed.get("from_account", "")
|
|
or parsed.get("From_Account", "")
|
|
)
|
|
group_code = (
|
|
parsed.get("group_code", "")
|
|
or parsed.get("GroupId", "")
|
|
or parsed.get("group_id", "")
|
|
)
|
|
if from_account:
|
|
return f"{from_account}:{group_code}"
|
|
except Exception:
|
|
pass
|
|
# Protobuf: try decode_inbound_push for sender info
|
|
try:
|
|
push = decode_inbound_push(raw_data)
|
|
if push:
|
|
return f"{push.get('from_account', '')}:{push.get('group_code', '')}"
|
|
except Exception:
|
|
pass
|
|
# Fallback: unique key (no aggregation)
|
|
return f"__unknown_{id(raw_data)}"
|
|
|
|
def _push_to_inbound(self, raw_data: bytes) -> None:
    """Buffer an inbound frame and (re)arm the per-sender flush timer.

    Frames that share a sender key are collected in ``_inbound_buffer``.
    Each new frame cancels the sender's pending timer and schedules a fresh
    ``_flush_inbound_buffer`` call ``_DEBOUNCE_WINDOW`` seconds later, so
    multi-part messages (e.g. image + text sent as separate WS pushes) end
    up in a single pipeline run.
    """
    key = self._extract_sender_key(raw_data)

    # A new frame restarts the debounce window for this sender.
    pending_timer = self._inbound_timers.pop(key, None)
    if pending_timer:
        pending_timer.cancel()

    frames = self._inbound_buffer.setdefault(key, [])
    frames.append(raw_data)

    logger.debug(
        "[%s] Debounce: buffered frame for key=%s, count=%d",
        self._adapter.name, key, len(frames),
    )

    self._inbound_timers[key] = asyncio.get_running_loop().call_later(
        self._DEBOUNCE_WINDOW,
        self._flush_inbound_buffer,
        key,
    )
|
|
|
|
def _flush_inbound_buffer(self, key: str) -> None:
|
|
"""Flush the debounce buffer for a given key — execute the pipeline."""
|
|
self._inbound_timers.pop(key, None)
|
|
data_list = self._inbound_buffer.pop(key, [])
|
|
if not data_list:
|
|
return
|
|
|
|
adapter = self._adapter
|
|
logger.info(
|
|
"[%s] Debounce flush: key=%s, aggregated %d frames",
|
|
adapter.name, key, len(data_list),
|
|
)
|
|
|
|
ctx = InboundContext(adapter=adapter, raw_frames=data_list)
|
|
|
|
adapter._track_task(asyncio.create_task(
|
|
adapter._inbound_pipeline.execute(ctx),
|
|
name=f"yuanbao-pipeline-{key}",
|
|
))
|
|
|
|
# -- Send business request ---------------------------------------------
|
|
|
|
async def send_biz_request(
    self,
    encoded_conn_msg: bytes,
    req_id: str,
    timeout: float = DEFAULT_SEND_TIMEOUT,
) -> dict:
    """Send a business-layer request and wait for its response.

    A future is registered in ``_pending_acks`` under *req_id*, the encoded
    message is written to the WebSocket, and the call waits up to *timeout*
    seconds for the future to be resolved.  The registration is removed on
    every exit path.

    Args:
        encoded_conn_msg: Fully encoded connection message to send.
        req_id: Message id the response will be matched on.
        timeout: Seconds to wait for the response.

    Returns:
        Whatever value the matching response handler set on the future.

    Raises:
        RuntimeError: If there is no WebSocket connection.
        asyncio.TimeoutError: If no response arrives within *timeout*.
    """
    if self._ws is None:
        raise RuntimeError("Not connected")

    reply: asyncio.Future = asyncio.get_running_loop().create_future()
    self._pending_acks[req_id] = reply
    try:
        await self._ws.send(encoded_conn_msg)
        # shield() keeps wait_for's timeout from cancelling `reply` itself.
        return await asyncio.wait_for(asyncio.shield(reply), timeout=timeout)
    finally:
        self._pending_acks.pop(req_id, None)
|
|
|
|
# -- Reconnect ---------------------------------------------------------
|
|
|
|
def schedule_reconnect(self) -> None:
    """Start a background reconnect unless stopped or already reconnecting.

    Does nothing when the adapter is not running or a reconnect is already
    in progress.  Must be called from within a running event loop.
    """
    if not self._adapter._running or self._reconnecting:
        return
    # The event loop only keeps weak references to tasks, so an otherwise
    # unreferenced task may be garbage-collected before it finishes.  Hold
    # a strong reference for the lifetime of the reconnect.
    self._reconnect_task = asyncio.create_task(
        self._reconnect_with_backoff(),
        name="yuanbao-reconnect",
    )
|
|
|
|
async def _reconnect_with_backoff(self) -> bool:
|
|
"""Reconnect with exponential backoff (1s, 2s, 4s, … up to 60s)."""
|
|
if self._reconnecting:
|
|
logger.debug("[%s] Reconnect already in progress, skipping", self._adapter.name)
|
|
return False
|
|
self._reconnecting = True
|
|
try:
|
|
return await self._do_reconnect()
|
|
finally:
|
|
self._reconnecting = False
|
|
|
|
async def _do_reconnect(self) -> bool:
    """Retry connect + authenticate until it works or attempts run out.

    Called under the ``_reconnecting`` guard.  Each attempt sleeps
    ``min(2 ** attempt, 60)`` seconds, drops the old socket, force-refreshes
    the sign token, opens a new WebSocket and authenticates.  On success
    the heartbeat and receive tasks are restarted.

    Returns:
        True once reconnected; False after ``MAX_RECONNECT_ATTEMPTS``
        failures, in which case the adapter is marked disconnected.
    """
    adapter = self._adapter
    for attempt_no in range(1, MAX_RECONNECT_ATTEMPTS + 1):
        self._reconnect_attempts = attempt_no
        delay = min(2 ** (attempt_no - 1), 60)
        logger.info(
            "[%s] Reconnect attempt %d/%d in %ds",
            adapter.name, attempt_no, MAX_RECONNECT_ATTEMPTS, delay,
        )
        await asyncio.sleep(delay)

        await self._cleanup_ws()

        try:
            token_data = await SignManager.force_refresh(
                adapter._app_key, adapter._app_secret, adapter._api_domain,
                route_env=adapter._route_env,
            )
            refreshed_bot_id = token_data.get("bot_id")
            if refreshed_bot_id:
                adapter._bot_id = str(refreshed_bot_id)

            # Library-level pings are switched off (ping_interval=None);
            # _heartbeat_loop sends the application-level heartbeat instead.
            connecting = websockets.connect(  # type: ignore[attr-defined]
                adapter._ws_url,
                ping_interval=None,
                ping_timeout=None,
                close_timeout=5,
            )
            self._ws = await asyncio.wait_for(connecting, timeout=CONNECT_TIMEOUT_SECONDS)

            if not await self._authenticate(token_data):
                logger.warning("[%s] Re-auth failed on attempt %d", adapter.name, attempt_no)
                await self._cleanup_ws()
                continue

            self._reconnect_attempts = 0
            self._consecutive_hb_timeouts = 0
            adapter._mark_connected()

            # Replace both background loops with ones bound to the new socket.
            stale_heartbeat = self._heartbeat_task
            if stale_heartbeat and not stale_heartbeat.done():
                stale_heartbeat.cancel()
            self._heartbeat_task = asyncio.create_task(
                self._heartbeat_loop(),
                name=f"yuanbao-heartbeat-{self._connect_id}",
            )

            stale_receiver = self._recv_task
            if stale_receiver and not stale_receiver.done():
                stale_receiver.cancel()
            self._recv_task = asyncio.create_task(
                self._receive_loop(),
                name=f"yuanbao-recv-{self._connect_id}",
            )

            logger.info(
                "[%s] Reconnected on attempt %d. connectId=%s",
                adapter.name, attempt_no, self._connect_id,
            )
            return True

        except asyncio.TimeoutError:
            logger.warning("[%s] Reconnect attempt %d timed out", adapter.name, attempt_no)
        except Exception as exc:
            logger.warning(
                "[%s] Reconnect attempt %d failed: %s", adapter.name, attempt_no, exc
            )

    logger.error(
        "[%s] Giving up after %d reconnect attempts", adapter.name, MAX_RECONNECT_ATTEMPTS
    )
    adapter._mark_disconnected()
    return False
|
|
|
|
async def _cleanup_ws(self) -> None:
|
|
"""Close and clear the WebSocket connection."""
|
|
ws = self._ws
|
|
self._ws = None
|
|
if ws is not None:
|
|
try:
|
|
await ws.close()
|
|
except Exception:
|
|
pass
|
|
|
|
class MediaSendHandler(ABC):
    """Template for sending one kind of media message.

    Subclasses supply two pieces:
      - acquire_file(): where the bytes come from (URL download, local read)
      - build_msg_body(): how the upload result becomes a MsgBody list

    handle() owns everything the strategies share: connection check,
    cancelling the slow-response notifier, validation, COS upload, optional
    caption and the final dispatch through the message sender.
    """

    @abstractmethod
    async def acquire_file(
        self, adapter: "YuanbaoAdapter", **kwargs: Any,
    ) -> Tuple[bytes, str, str]:
        """Return (file_bytes, filename, content_type).

        Raises:
            ValueError: when file cannot be acquired (not found, empty, etc.)
        """

    @abstractmethod
    def build_msg_body(self, upload_result: dict, **kwargs: Any) -> list:
        """Build platform-specific MsgBody list from COS upload result."""

    def needs_cos_upload(self) -> bool:
        """Override to return False for non-COS media (e.g. sticker)."""
        return True

    async def handle(
        self,
        adapter: "YuanbaoAdapter",
        chat_id: str,
        reply_to: Optional[str] = None,
        caption: Optional[str] = None,
        **kwargs: Any,
    ) -> "SendResult":
        """Run the shared media send flow for this strategy.

        Args:
            adapter: Owning adapter; provides connection, sender, credentials.
            chat_id: Target chat.
            reply_to: Optional message id being replied to.
            caption: Optional text appended as a trailing TIMTextElem.
            **kwargs: Strategy-specific inputs, forwarded to acquire_file()
                and build_msg_body(); ``group_code`` is also passed on to
                the dispatcher.

        Returns:
            The dispatcher's result, or a failed SendResult when there is no
            connection, validation fails, or any step raises.
        """
        conn = adapter._connection
        sender = adapter._outbound.sender

        if conn.ws is None:
            return SendResult(success=False, error="Not connected", retryable=True)

        adapter._outbound.cancel_slow_notifier(chat_id)

        try:
            # 1. Acquire file bytes
            file_bytes, filename, content_type = await self.acquire_file(
                adapter, **kwargs,
            )

            if self.needs_cos_upload():
                # 2. Validate. Only COS-backed media is checked; stickers use
                #    TIMFaceElem and legitimately carry no file bytes, so
                #    validating them would yield a spurious "Empty file".
                validation_err = MessageSender.validate_media(
                    file_bytes, filename, adapter.MEDIA_MAX_SIZE_MB,
                )
                if validation_err:
                    return SendResult(success=False, error=validation_err)

                file_uuid = md5_hex(file_bytes)

                # 3. Get COS upload credentials
                token_data = await adapter._get_cached_token()
                token: str = token_data.get("token", "")
                bot_id: str = token_data.get("bot_id", "") or adapter._bot_id or ""

                credentials = await get_cos_credentials(
                    app_key=adapter._app_key,
                    api_domain=adapter._api_domain,
                    token=token,
                    filename=filename,
                    bot_id=bot_id,
                    route_env=adapter._route_env,
                )

                # 4. Upload to COS
                upload_result = await upload_to_cos(
                    file_bytes=file_bytes,
                    filename=filename,
                    content_type=content_type,
                    credentials=credentials,
                    bucket=credentials["bucketName"],
                    region=credentials["region"],
                )

                # 5. Build MsgBody. Keys passed explicitly below are dropped
                #    from the forwarded kwargs to avoid a "multiple values"
                #    TypeError.
                explicit = ("file_uuid", "filename", "content_type")
                fwd_kwargs = {k: v for k, v in kwargs.items() if k not in explicit}
                msg_body = self.build_msg_body(
                    upload_result,
                    file_uuid=file_uuid,
                    filename=filename,
                    content_type=content_type,
                    **fwd_kwargs,
                )
            else:
                # Non-COS media (e.g. sticker): build MsgBody directly
                msg_body = self.build_msg_body({}, **kwargs)

            # 6. Append caption if provided
            if caption:
                msg_body.append(
                    {"msg_type": "TIMTextElem", "msg_content": {"text": caption}},
                )

            # 7. Lock + dispatch
            group_code = kwargs.get("group_code", "")
            return await sender.dispatch_msg_body(
                chat_id, msg_body, reply_to, group_code=group_code,
            )

        except ValueError as ve:
            return SendResult(success=False, error=str(ve))
        except Exception as exc:
            logger.error(
                "[%s] %s.handle() failed: %s",
                adapter.name, type(self).__name__, exc, exc_info=True,
            )
            return SendResult(success=False, error=str(exc))
|
|
|
|
|
|
class ImageUrlHandler(MediaSendHandler):
    """Strategy: send image from a URL (download → COS → TIMImageElem)."""

    async def acquire_file(self, adapter, **kwargs):
        """Download ``kwargs["image_url"]`` and derive filename / MIME type.

        The filename and the fallback MIME guess are taken from the URL's
        path component only, so neither the query string nor a ``#fragment``
        ends up in the filename.

        Returns:
            (file_bytes, filename, content_type). The filename falls back to
            "image.jpg" and the content type to "image/jpeg".
        """
        image_url: str = kwargs["image_url"]
        logger.info("[%s] ImageUrlHandler: downloading %s", adapter.name, image_url)
        file_bytes, content_type = await media_download_url(
            image_url, max_size_mb=adapter.MEDIA_MAX_SIZE_MB,
        )

        url_path = urllib.parse.urlsplit(image_url).path
        # A missing or generic server-reported type is replaced by a guess
        # from the path's extension.
        if not content_type or content_type == "application/octet-stream":
            content_type = guess_mime_type(url_path) or "image/jpeg"
        filename = os.path.basename(url_path) or "image.jpg"
        return file_bytes, filename, content_type

    def build_msg_body(self, upload_result, **kwargs):
        """Build the image MsgBody from the upload result and file metadata."""
        return build_image_msg_body(
            url=upload_result["url"],
            uuid=kwargs["file_uuid"],
            filename=kwargs["filename"],
            size=upload_result["size"],
            width=upload_result.get("width", 0),
            height=upload_result.get("height", 0),
            mime_type=kwargs["content_type"],
        )
|
|
|
|
|
|
class ImageFileHandler(MediaSendHandler):
    """Strategy: send image from a local file path (read → COS → TIMImageElem)."""

    async def acquire_file(self, adapter, **kwargs):
        """Read ``kwargs["image_path"]`` from disk.

        Returns:
            (file_bytes, filename, content_type); the content type is guessed
            from the filename and defaults to "image/jpeg".

        Raises:
            ValueError: If the path is not an existing regular file.
        """
        image_path: str = kwargs["image_path"]
        if not os.path.isfile(image_path):
            raise ValueError(f"File not found: {image_path}")

        logger.info("[%s] ImageFileHandler: reading %s", adapter.name, image_path)
        file_bytes = Path(image_path).read_bytes()

        filename = os.path.basename(image_path) or "image.jpg"
        content_type = guess_mime_type(filename) or "image/jpeg"
        return file_bytes, filename, content_type

    def build_msg_body(self, upload_result, **kwargs):
        """Build the image MsgBody from the upload result and file metadata."""
        image_fields = {
            "url": upload_result["url"],
            "uuid": kwargs["file_uuid"],
            "filename": kwargs["filename"],
            "size": upload_result["size"],
            "width": upload_result.get("width", 0),
            "height": upload_result.get("height", 0),
            "mime_type": kwargs["content_type"],
        }
        return build_image_msg_body(**image_fields)
|
|
|
|
|
|
class FileUrlHandler(MediaSendHandler):
    """Strategy: send file from a URL (download → COS → TIMFileElem)."""

    async def acquire_file(self, adapter, **kwargs):
        """Download ``kwargs["file_url"]`` and derive filename / MIME type.

        An explicit ``kwargs["filename"]`` wins.  Otherwise the name is the
        basename of the URL's path component, so neither the query string
        nor a ``#fragment`` leaks into it; "file" is the last resort.

        Returns:
            (file_bytes, filename, content_type). A missing or generic
            content type is replaced by a guess from the filename.
        """
        file_url: str = kwargs["file_url"]
        logger.info("[%s] FileUrlHandler: downloading %s", adapter.name, file_url)
        file_bytes, content_type = await media_download_url(
            file_url, max_size_mb=adapter.MEDIA_MAX_SIZE_MB,
        )

        filename = kwargs.get("filename")
        if not filename:
            url_path = urllib.parse.urlsplit(file_url).path
            filename = os.path.basename(url_path) or "file"

        if not content_type or content_type == "application/octet-stream":
            content_type = guess_mime_type(filename) or "application/octet-stream"
        return file_bytes, filename, content_type

    def build_msg_body(self, upload_result, **kwargs):
        """Build the file MsgBody from the upload result and file metadata."""
        return build_file_msg_body(
            url=upload_result["url"],
            filename=kwargs["filename"],
            uuid=kwargs["file_uuid"],
            size=upload_result["size"],
        )
|
|
|
|
|
|
class DocumentHandler(MediaSendHandler):
    """Strategy: send local file/document (read → COS → TIMFileElem)."""

    async def acquire_file(self, adapter, **kwargs):
        """Read ``kwargs["file_path"]`` from disk.

        The filename is ``kwargs["filename"]`` when given, else the path's
        basename, else "document".

        Returns:
            (file_bytes, filename, content_type); the content type is guessed
            from the filename and defaults to "application/octet-stream".

        Raises:
            ValueError: If the path is not an existing regular file.
        """
        file_path: str = kwargs["file_path"]
        if not os.path.isfile(file_path):
            raise ValueError(f"File not found: {file_path}")

        logger.info("[%s] DocumentHandler: reading %s", adapter.name, file_path)
        with open(file_path, "rb") as handle:
            payload = handle.read()

        display_name = kwargs.get("filename") or os.path.basename(file_path) or "document"
        mime = guess_mime_type(display_name) or "application/octet-stream"
        return payload, display_name, mime

    def build_msg_body(self, upload_result, **kwargs):
        """Build the file MsgBody from the upload result and file metadata."""
        file_fields = {
            "url": upload_result["url"],
            "filename": kwargs["filename"],
            "uuid": kwargs["file_uuid"],
            "size": upload_result["size"],
        }
        return build_file_msg_body(**file_fields)
|
|
|
|
|
|
class StickerHandler(MediaSendHandler):
    """Strategy: send sticker/emoji (TIMFaceElem, no COS upload needed)."""

    def needs_cos_upload(self) -> bool:
        """Stickers are built from metadata only; nothing is uploaded."""
        return False

    async def acquire_file(self, adapter, **kwargs):
        """Return placeholder values; a sticker has no file bytes."""
        return b"", "sticker", "application/octet-stream"

    def build_msg_body(self, upload_result, **kwargs):
        """Build the sticker / face MsgBody.

        Selection order: ``sticker_name`` if given, else ``face_index`` if
        given, else a random sticker.

        Raises:
            ValueError: If ``sticker_name`` is given but unknown.
        """
        from gateway.platforms.yuanbao_sticker import (
            get_sticker_by_name,
            get_random_sticker,
            build_face_msg_body,
            build_sticker_msg_body,
        )

        requested_name = kwargs.get("sticker_name")
        requested_face = kwargs.get("face_index")

        if requested_name is not None:
            found = get_sticker_by_name(requested_name)
            if found is None:
                raise ValueError(f"Sticker not found: {requested_name!r}")
            return build_sticker_msg_body(found)

        if requested_face is not None:
            return build_face_msg_body(face_index=requested_face)

        return build_sticker_msg_body(get_random_sticker())
|
|
|
|
class GroupQueryService:
    """Group query operations for the adapter.

    Two layers live here:
      - ``*_raw`` methods: encode the WS request, send it, decode the reply
      - AI-tool-facing wrappers: parse the chat id, wrap errors, filter results

    The member-list query also refreshes ``adapter._member_cache``.
    """

    def __init__(self, adapter: "YuanbaoAdapter") -> None:
        self._adapter = adapter

    # ------------------------------------------------------------------
    # Low-level WS query methods
    # ------------------------------------------------------------------

    async def query_group_info_raw(self, group_code: str) -> Optional[dict]:
        """Query group info over the WebSocket.

        Returns:
            The decoded response dict, ``{"group_code": group_code}`` when
            the response carries no bytes payload, or None when not
            connected, on non-zero status, on timeout or on any other error.
        """
        adapter = self._adapter
        if adapter._connection.ws is None:
            return None

        encoded = encode_query_group_info(group_code)
        from gateway.platforms.yuanbao_proto import decode_conn_msg as _decode
        # The request id is read back from the encoded message's own head.
        req_id = _decode(encoded)["head"]["msg_id"]

        try:
            response = await adapter._connection.send_biz_request(encoded, req_id=req_id)
            status = response.get("head", {}).get("status", 0)
            if status != 0:
                logger.warning("[%s] query_group_info failed: status=%d", adapter.name, status)
                return None

            biz_data = response.get("data", b"") or response.get("body", b"")
            if not (biz_data and isinstance(biz_data, bytes)):
                return {"group_code": group_code}
            return decode_query_group_info_rsp(biz_data)
        except asyncio.TimeoutError:
            logger.warning("[%s] query_group_info timeout: group=%s", adapter.name, group_code)
            return None
        except Exception as exc:
            logger.warning("[%s] query_group_info failed: %s", adapter.name, exc)
            return None

    async def get_group_member_list_raw(
        self, group_code: str, offset: int = 0, limit: int = 200
    ) -> Optional[dict]:
        """Query one page of the group member list over the WebSocket.

        A non-empty member list is stored in ``adapter._member_cache`` under
        *group_code* together with the current time.

        Returns:
            The decoded response dict (an empty, complete page when the
            response carries no bytes payload), or None when not connected,
            on non-zero status, on timeout or on any other error.
        """
        adapter = self._adapter
        if adapter._connection.ws is None:
            return None

        encoded = encode_get_group_member_list(group_code, offset=offset, limit=limit)
        from gateway.platforms.yuanbao_proto import decode_conn_msg as _decode
        # The request id is read back from the encoded message's own head.
        req_id = _decode(encoded)["head"]["msg_id"]

        try:
            response = await adapter._connection.send_biz_request(encoded, req_id=req_id)
            status = response.get("head", {}).get("status", 0)
            if status != 0:
                logger.warning("[%s] get_group_member_list failed: status=%d", adapter.name, status)
                return None

            biz_data = response.get("data", b"") or response.get("body", b"")
            if biz_data and isinstance(biz_data, bytes):
                page = decode_get_group_member_list_rsp(biz_data)
            else:
                page = {"members": [], "next_offset": 0, "is_complete": True}

            if page and page.get("members"):
                adapter._member_cache[group_code] = (time.time(), page["members"])
            return page
        except asyncio.TimeoutError:
            logger.warning("[%s] get_group_member_list timeout: group=%s", adapter.name, group_code)
            return None
        except Exception as exc:
            logger.warning("[%s] get_group_member_list failed: %s", adapter.name, exc)
            return None

    # ------------------------------------------------------------------
    # AI-tool-facing wrappers (chat_id parsing + filtering)
    # ------------------------------------------------------------------

    async def query_group_info(self, chat_id: str) -> dict:
        """AI tool: query info for the group identified by *chat_id*.

        Returns:
            The group info dict, or ``{"error": ...}`` when *chat_id* is not
            a ``group:`` chat or the query fails.
        """
        if not chat_id.startswith("group:"):
            return {"error": "This command is only available in group chats"}

        info = await self.query_group_info_raw(chat_id[len("group:"):])
        if info is None:
            return {"error": "Failed to query group info"}
        return info

    async def query_session_members(
        self,
        chat_id: str,
        action: str = "list_all",
        name: Optional[str] = None,
    ) -> dict:
        """AI tool: query (and optionally filter) the group member list.

        Args:
            chat_id: Chat ID; must start with ``group:``.
            action: 'find' (substring match on nickname, name card or user
                id) | 'list_bots' (nickname contains "bot") | 'list_all'.
            name: Search keyword when action='find'.

        Returns:
            ``{"members": [...], "total": int, "mentionHint": str}`` with at
            most 50 members, or ``{"error": ...}`` on failure.
        """
        if not chat_id.startswith("group:"):
            return {"error": "This command is only available in group chats"}

        page = await self.get_group_member_list_raw(chat_id[len("group:"):])
        if page is None:
            return {"error": "Failed to query group members"}

        members = page.get("members", [])

        if action == "find" and name:
            needle = name.lower()

            def _matches(member: dict) -> bool:
                return any(
                    needle in (member.get(field, "") or "").lower()
                    for field in ("nickname", "name_card", "user_id")
                )

            members = [m for m in members if _matches(m)]
        elif action == "list_bots":
            members = [m for m in members if "bot" in (m.get("nickname", "") or "").lower()]

        # A mention hint is only offered for short result lists.
        mention_hint = ""
        if members and len(members) <= 10:
            display_names = [
                m.get("name_card") or m.get("nickname") or m.get("user_id", "")
                for m in members
            ]
            mention_hint = "Mention with @name: " + ", ".join(display_names)

        return {
            "members": members[:50],  # Limit return count
            "total": len(members),
            "mentionHint": mention_hint,
        }
|
|
|
|
|
|
class HeartbeatManager:
    """Owns the per-chat reply heartbeat (RUNNING / FINISH) lifecycle.

    For each active chat a background worker sends RUNNING every
    ``REPLY_HEARTBEAT_INTERVAL_S`` seconds.  The worker finishes on its own,
    sending FINISH, once the chat has not been renewed for more than
    ``REPLY_HEARTBEAT_TIMEOUT_S`` seconds; it can also be stopped explicitly
    with an optional FINISH signal.
    """

    def __init__(self, adapter: "YuanbaoAdapter") -> None:
        self._adapter = adapter
        # chat_id -> running worker task
        self._reply_heartbeat_tasks: Dict[str, asyncio.Task] = {}
        # chat_id -> time.time() of the most recent start()/renewal
        self._reply_hb_last_active: Dict[str, float] = {}

    async def send_heartbeat_once(self, chat_id: str, heartbeat_val: int) -> None:
        """Send a single heartbeat (RUNNING or FINISH), best effort.

        Silently does nothing when disconnected or when the bot id is
        unknown; send failures are logged at debug level and swallowed.
        """
        adapter = self._adapter
        conn = adapter._connection
        if conn.ws is None or not adapter._bot_id:
            return

        try:
            if chat_id.startswith("group:"):
                encoded = encode_send_group_heartbeat(
                    from_account=adapter._bot_id,
                    group_code=chat_id[len("group:"):],
                    heartbeat=heartbeat_val,
                )
            else:
                encoded = encode_send_private_heartbeat(
                    from_account=adapter._bot_id,
                    to_account=chat_id.removeprefix("direct:"),
                    heartbeat=heartbeat_val,
                )
            await conn.ws.send(encoded)

            status_name = "FINISH"
            if heartbeat_val == WS_HEARTBEAT_RUNNING:
                status_name = "RUNNING"
            logger.debug(
                "[%s] Reply heartbeat %s sent: chat=%s",
                adapter.name, status_name, chat_id,
            )
        except Exception as exc:
            logger.debug("[%s] send_heartbeat_once failed: %s", adapter.name, exc)

    async def start(self, chat_id: str) -> None:
        """Start the periodic RUNNING sender for *chat_id*, or renew it.

        When a worker is already running for the chat only its last-active
        timestamp is refreshed.
        """
        adapter = self._adapter
        if adapter._connection.ws is None or not adapter._bot_id:
            return

        running = self._reply_heartbeat_tasks.get(chat_id)
        # Both the renew and the fresh-start path stamp the activity time.
        self._reply_hb_last_active[chat_id] = time.time()
        if running and not running.done():
            return

        self._reply_heartbeat_tasks[chat_id] = asyncio.create_task(
            self._worker(chat_id),
            name=f"yuanbao-reply-hb-{chat_id}",
        )

    async def _worker(self, chat_id: str) -> None:
        """Background coroutine that keeps sending RUNNING for one chat.

        Exits when the chat has been idle past the timeout or the connection
        is gone.  FINISH is sent on every exit except cancellation (stop()
        decides about FINISH itself in that case).
        """
        cancelled = False
        try:
            await self.send_heartbeat_once(chat_id, WS_HEARTBEAT_RUNNING)

            while True:
                await asyncio.sleep(REPLY_HEARTBEAT_INTERVAL_S)

                idle_for = time.time() - self._reply_hb_last_active.get(chat_id, 0)
                if idle_for > REPLY_HEARTBEAT_TIMEOUT_S:
                    break
                if self._adapter._connection.ws is None:
                    break

                await self.send_heartbeat_once(chat_id, WS_HEARTBEAT_RUNNING)
        except asyncio.CancelledError:
            cancelled = True
        except Exception:
            pass
        finally:
            if not cancelled:
                try:
                    await self.send_heartbeat_once(chat_id, WS_HEARTBEAT_FINISH)
                except Exception:
                    pass
            self._reply_heartbeat_tasks.pop(chat_id, None)
            self._reply_hb_last_active.pop(chat_id, None)

    async def stop(self, chat_id: str, send_finish: bool = True) -> None:
        """Stop the worker for *chat_id* and optionally send FINISH."""
        worker = self._reply_heartbeat_tasks.pop(chat_id, None)
        if worker and not worker.done():
            worker.cancel()
            try:
                await worker
            except asyncio.CancelledError:
                pass

        if not send_finish:
            return
        try:
            await self.send_heartbeat_once(chat_id, WS_HEARTBEAT_FINISH)
        except Exception:
            pass

    async def close(self) -> None:
        """Cancel all reply heartbeat tasks and forget their state."""
        for worker in list(self._reply_heartbeat_tasks.values()):
            if not worker.done():
                worker.cancel()
        self._reply_heartbeat_tasks.clear()
        self._reply_hb_last_active.clear()
|
|
|
|
|
|
class SlowResponseNotifier:
    """Sends a delayed 'please wait' notice when the agent is slow to reply.

    One timer task is kept per chat_id.  If it is not cancelled within
    ``SLOW_RESPONSE_TIMEOUT_S`` seconds, ``SLOW_RESPONSE_MESSAGE`` is sent
    to that chat.
    """

    def __init__(self, adapter: "YuanbaoAdapter", sender: "MessageSender") -> None:
        self._adapter = adapter
        self._sender = sender
        # chat_id -> pending notifier task
        self._tasks: Dict[str, asyncio.Task] = {}

    async def start(self, chat_id: str) -> None:
        """Arm (or re-arm) the slow-response timer for *chat_id*."""
        # Replace any timer that is already pending for this chat.
        self.cancel(chat_id)
        self._tasks[chat_id] = asyncio.create_task(
            self._notifier(chat_id),
            name=f"yuanbao-slow-resp-{chat_id}",
        )

    async def _notifier(self, chat_id: str) -> None:
        """Wait SLOW_RESPONSE_TIMEOUT_S, then push a 'please wait' message.

        Cancellation is silent; any other failure is logged at debug level.
        """
        try:
            await asyncio.sleep(SLOW_RESPONSE_TIMEOUT_S)
            logger.info(
                "[%s] Agent response exceeded %ds for %s, sending wait notice",
                self._adapter.name, int(SLOW_RESPONSE_TIMEOUT_S), chat_id,
            )
            await self._sender.send_text_chunk(chat_id, SLOW_RESPONSE_MESSAGE)
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            logger.debug("[%s] Slow-response notifier failed: %s", self._adapter.name, exc)

    def cancel(self, chat_id: str) -> None:
        """Cancel the pending slow-response notifier for *chat_id*, if any."""
        pending = self._tasks.pop(chat_id, None)
        if pending is None or pending.done():
            return
        pending.cancel()

    async def close(self) -> None:
        """Cancel all slow-response tasks."""
        for pending in list(self._tasks.values()):
            if not pending.done():
                pending.cancel()
        self._tasks.clear()
|
|
|
|
|
|
class MessageSender:
    """Core message sending dispatcher for YuanbaoAdapter.

    Responsibilities:
      - Per-chat-id lock management (serial send ordering)
      - Text chunk sending with retry
      - C2C / Group message encoding and dispatch
      - Media send helpers (image, file, sticker, document)
      - Direct send helper (text + media, used by send_message tool)
    """

    IMAGE_EXTS: ClassVar[frozenset] = frozenset({".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"})
    CHAT_DICT_MAX_SIZE: ClassVar[int] = 1000  # Max distinct chat IDs in _chat_locks

    def __init__(self, adapter: "YuanbaoAdapter") -> None:
        self._adapter = adapter
        # Insertion/access-ordered so the least recently used chat is first.
        self._chat_locks: collections.OrderedDict[str, asyncio.Lock] = collections.OrderedDict()

        # Optional hooks injected by OutboundManager for coordination
        self._on_send_start: Optional[Callable[[str], Any]] = None  # cancel slow-notifier
        self._on_send_finish: Optional[Callable[[str], Any]] = None  # send FINISH heartbeat

        # Media send handlers (strategy pattern)
        self._media_handlers: Dict[str, MediaSendHandler] = {
            "image_url": ImageUrlHandler(),
            "image_file": ImageFileHandler(),
            "file_url": FileUrlHandler(),
            "document": DocumentHandler(),
            "sticker": StickerHandler(),
        }

    # -- Media handler registry ---------------------------------------------

    def register_handler(self, name: str, handler: "MediaSendHandler") -> None:
        """Register (or replace) a named media send handler."""
        self._media_handlers[name] = handler

    # -- Chat lock ---------------------------------------------------------

    def get_chat_lock(self, chat_id: str) -> asyncio.Lock:
        """Return (or create) the lock that serialises sends for *chat_id*.

        The lock table is kept in least-recently-used order and trimmed once
        it reaches ``CHAT_DICT_MAX_SIZE``. Only idle (unlocked) entries are
        ever evicted: dropping a lock that is currently held would hand the
        next caller for that chat a brand-new lock and let two sends run
        concurrently. If every entry is held, the table is allowed to exceed
        the limit temporarily and is trimmed on a later call.
        """
        locks = self._chat_locks
        existing = locks.get(chat_id)
        if existing is not None:
            locks.move_to_end(chat_id)
            return existing

        if len(locks) >= self.CHAT_DICT_MAX_SIZE:
            # Walk from least to most recently used, dropping idle locks
            # until there is room for the new entry.
            for key in list(locks):
                if len(locks) < self.CHAT_DICT_MAX_SIZE:
                    break
                if not locks[key].locked():
                    del locks[key]

        new_lock = asyncio.Lock()
        locks[chat_id] = new_lock
        return new_lock

    # -- Text send ---------------------------------------------------------

    async def send_text(
        self,
        chat_id: str,
        content: str,
        reply_to: Optional[str] = None,
        group_code: str = "",
    ) -> "SendResult":
        """Send text message with auto-chunking and per-chat-id ordering guarantee.

        Args:
            chat_id: Target chat (``group:<code>`` or ``direct:<account>``).
            content: Text to send; a cron wrapper is stripped and the result
                is split into chunks of at most ``adapter.MAX_TEXT_CHUNK``.
            reply_to: Message id to reference; applied to the first chunk only.
            group_code: Forwarded to the chunk sender for C2C messages.

        Returns:
            A failed, retryable SendResult when the WebSocket is not connected,
            the first failing chunk's SendResult if any chunk fails, otherwise
            ``SendResult(success=True)``.
        """
        adapter = self._adapter
        conn = adapter._connection
        if conn.ws is None:
            return SendResult(success=False, error="Not connected", retryable=True)

        if self._on_send_start:
            self._on_send_start(chat_id)

        lock = self.get_chat_lock(chat_id)
        async with lock:
            content_to_send = self.strip_cron_wrapper(content)
            chunks = self.truncate_message(content_to_send, adapter.MAX_TEXT_CHUNK)
            logger.info(
                "[%s] truncate_message: input=%d chars, max=%d, output=%d chunk(s) sizes=%s",
                adapter.name, len(content_to_send), adapter.MAX_TEXT_CHUNK,
                len(chunks), [len(c) for c in chunks],
            )
            for i, chunk in enumerate(chunks):
                # Only the first chunk carries the reply reference.
                r_to = reply_to if i == 0 else None
                result = await self.send_text_chunk(chat_id, chunk, r_to, group_code=group_code)
                if not result.success:
                    return result

            # Notify outbound coordinator that send is complete (e.g. FINISH heartbeat)
            if self._on_send_finish:
                try:
                    await self._on_send_finish(chat_id)
                except Exception as exc:
                    # Best-effort: the message itself was delivered, so a
                    # failing hook must not turn the send into a failure.
                    logger.debug(
                        "[%s] on_send_finish hook failed for %s: %s",
                        adapter.name, chat_id, exc,
                    )
            return SendResult(success=True)

    async def send_media(
        self,
        chat_id: str,
        handler_name: str,
        reply_to: Optional[str] = None,
        caption: Optional[str] = None,
        **kwargs: Any,
    ) -> "SendResult":
        """Dispatch media send to the named handler strategy.

        Returns a failed SendResult when *handler_name* is not registered;
        otherwise whatever the handler's ``handle`` coroutine returns.
        """
        handler = self._media_handlers.get(handler_name)
        if handler is None:
            return SendResult(
                success=False,
                error=f"Unknown media handler: {handler_name!r}",
            )
        return await handler.handle(
            self._adapter, chat_id,
            reply_to=reply_to, caption=caption, **kwargs,
        )

    # -- Direct send (text + media, used by send_message tool) -------------

    async def send_direct(
        self,
        chat_id: str,
        message: str,
        media_files: Optional[List[Tuple[str, bool]]] = None,
    ) -> Dict[str, Any]:
        """Send text + media via Yuanbao (used by the ``send_message`` tool).

        Unlike Weixin which creates a fresh adapter per call, Yuanbao reuses
        the running gateway adapter (persistent WebSocket). Logic mirrors
        send_weixin_direct: send text first, then iterate media_files by
        extension.

        Args:
            chat_id: Target chat id.
            message: Text to send; skipped when blank.
            media_files: ``(path, is_voice)`` pairs. The voice flag is not
                used here; files are routed purely by extension.

        Returns:
            ``{"error": ...}`` on the first failure or when nothing was sent,
            otherwise a success dict carrying the last send's ``message_id``.
        """
        adapter = self._adapter
        last_result: Optional["SendResult"] = None

        # 1. Send text
        if message.strip():
            last_result = await adapter.send(chat_id, message)
            if not last_result.success:
                return {"error": f"Yuanbao send failed: {last_result.error}"}

        # 2. Iterate media_files, dispatch by file extension
        for media_path, _is_voice in media_files or []:
            ext = Path(media_path).suffix.lower()
            if ext in self.IMAGE_EXTS:
                last_result = await adapter.send_image_file(chat_id, media_path)
            else:
                last_result = await adapter.send_document(chat_id, media_path)

            if not last_result.success:
                return {"error": f"Yuanbao media send failed: {last_result.error}"}

        if last_result is None:
            return {"error": "No deliverable text or media remained after processing"}

        return {
            "success": True,
            "platform": "yuanbao",
            "chat_id": chat_id,
            "message_id": last_result.message_id,
        }

    async def dispatch_msg_body(
        self,
        chat_id: str,
        msg_body: list,
        reply_to: Optional[str] = None,
        group_code: str = "",
    ) -> "SendResult":
        """Lock + dispatch an arbitrary MsgBody to C2C or group.

        ``reply_to`` is only forwarded on the group path and ``group_code``
        only on the C2C path, mirroring the two underlying senders.
        """
        lock = self.get_chat_lock(chat_id)
        async with lock:
            if chat_id.startswith("group:"):
                grp = chat_id[len("group:"):]
                result = await self.send_group_msg_body(grp, msg_body, reply_to)
            else:
                to_account = chat_id.removeprefix("direct:")
                result = await self.send_c2c_msg_body(to_account, msg_body, group_code=group_code)

        if result.get("success"):
            return SendResult(success=True, message_id=result.get("msg_key"))
        return SendResult(success=False, error=result.get("error", "Unknown error"))

    async def send_text_chunk(
        self,
        chat_id: str,
        text: str,
        reply_to: Optional[str] = None,
        retry: int = 3,
        group_code: str = "",
    ) -> "SendResult":
        """Send a single text chunk with retry (exponential backoff: 1s, 2s, 4s).

        Does not take the chat lock itself; callers that need ordering hold it.

        Returns:
            A successful SendResult carrying the gateway's ``msg_key`` as soon
            as one attempt succeeds, otherwise a failed SendResult describing
            the last error after *retry* attempts.
        """
        adapter = self._adapter
        last_error: str = "Unknown error"
        for attempt in range(retry):
            try:
                if chat_id.startswith("group:"):
                    grp = chat_id[len("group:"):]
                    raw = await self.send_group_message(grp, text, reply_to)
                else:
                    to_account = chat_id.removeprefix("direct:")
                    raw = await self.send_c2c_message(to_account, text, group_code=group_code)

                if raw.get("success"):
                    return SendResult(success=True, message_id=raw.get("msg_key"))

                last_error = raw.get("error", "Unknown error")
                logger.warning(
                    "[%s] send_text_chunk attempt %d/%d failed: %s",
                    adapter.name, attempt + 1, retry, last_error,
                )
            except Exception as exc:
                last_error = str(exc)
                logger.warning(
                    "[%s] send_text_chunk attempt %d/%d exception: %s",
                    adapter.name, attempt + 1, retry, last_error,
                )

            # No sleep after the final attempt.
            if attempt < retry - 1:
                await asyncio.sleep(2 ** attempt)

        logger.error(
            "[%s] send_text_chunk max retries (%d) exceeded. Last error: %s",
            adapter.name, retry, last_error,
        )
        return SendResult(success=False, error=f"Max retries exceeded: {last_error}")

    # -- C2C / Group message -----------------------------------------------

    async def send_c2c_message(self, to_account: str, text: str, group_code: str = "") -> dict:
        """Send C2C text message, return {success: bool, msg_key: str}."""
        msg_body = [{"msg_type": "TIMTextElem", "msg_content": {"text": text}}]
        return await self.send_c2c_msg_body(to_account, msg_body, group_code=group_code)

    async def send_group_message(
        self,
        group_code: str,
        text: str,
        reply_to: Optional[str] = None,
    ) -> dict:
        """Send group text message, auto-converting @nickname to TIMCustomElem."""
        msg_body = self._build_msg_body_with_mentions(text, group_code)
        return await self.send_group_msg_body(group_code, msg_body, reply_to)

    # @mention pattern: (whitespace or start) + @ + nickname + (whitespace or end)
    _AT_USER_RE = re.compile(r'(?:(?<=\s)|(?<=^))@(\S+?)(?=\s|$)', re.MULTILINE)

    def _build_msg_body_with_mentions(self, text: str, group_code: str) -> list:
        """Parse @nickname patterns and build mixed TIMTextElem + TIMCustomElem msg_body.

        Nicknames are resolved case-insensitively against the adapter's member
        cache for *group_code*. When the cache entry is missing, empty or older
        than ``MEMBER_CACHE_TTL_S``, the text is sent as one plain text element.
        Text segments between mentions are whitespace-stripped and dropped
        when empty; unresolved ``@name`` tokens are kept as plain text.
        """
        cached = self._adapter._member_cache.get(group_code)
        if cached:
            ts, member_list = cached
            members = member_list if (time.time() - ts < self._adapter.MEMBER_CACHE_TTL_S) else []
        else:
            members = []
        if not members:
            return [{"msg_type": "TIMTextElem", "msg_content": {"text": text}}]

        # lower-cased nickname -> (display nickname, user id)
        nickname_to_uid = {}
        for m in members:
            nick = m.get("nickname") or m.get("nick_name") or ""
            uid = m.get("user_id") or ""
            if nick and uid:
                nickname_to_uid[nick.lower()] = (nick, uid)

        msg_body: list = []
        last_idx = 0
        for match in self._AT_USER_RE.finditer(text):
            start = match.start()
            if start > last_idx:
                seg = text[last_idx:start].strip()
                if seg:
                    msg_body.append({"msg_type": "TIMTextElem", "msg_content": {"text": seg}})

            nickname = match.group(1)
            entry = nickname_to_uid.get(nickname.lower())
            if entry:
                real_nick, uid = entry
                msg_body.append({
                    "msg_type": "TIMCustomElem",
                    "msg_content": {
                        "data": json.dumps({"elem_type": 1002, "text": f"@{real_nick}", "user_id": uid}),
                    },
                })
            else:
                msg_body.append({"msg_type": "TIMTextElem", "msg_content": {"text": f"@{nickname}"}})

            last_idx = match.end()

        if last_idx < len(text):
            tail = text[last_idx:].strip()
            if tail:
                msg_body.append({"msg_type": "TIMTextElem", "msg_content": {"text": tail}})

        # Whitespace-only input with no mentions would otherwise yield nothing.
        if not msg_body:
            msg_body.append({"msg_type": "TIMTextElem", "msg_content": {"text": text}})

        return msg_body

    async def send_c2c_msg_body(self, to_account: str, msg_body: list, group_code: str = "") -> dict:
        """Send C2C message with arbitrary MsgBody."""
        adapter = self._adapter
        req_id = f"c2c_{next_seq_no()}"
        encoded = encode_send_c2c_message(
            to_account=to_account,
            msg_body=msg_body,
            from_account=adapter._bot_id or "",
            msg_id=req_id,
            group_code=group_code,
        )
        return await self._dispatch_encoded(adapter, encoded, req_id)

    async def send_group_msg_body(
        self,
        group_code: str,
        msg_body: list,
        reply_to: Optional[str] = None,
    ) -> dict:
        """Send group message with arbitrary MsgBody."""
        adapter = self._adapter
        req_id = f"grp_{next_seq_no()}"
        encoded = encode_send_group_message(
            group_code=group_code,
            msg_body=msg_body,
            from_account=adapter._bot_id or "",
            msg_id=req_id,
            ref_msg_id=reply_to or "",
        )
        return await self._dispatch_encoded(adapter, encoded, req_id)

    # -- Common dispatch helper --------------------------------------------

    @staticmethod
    async def _dispatch_encoded(
        adapter: "YuanbaoAdapter", encoded: bytes, req_id: str,
    ) -> dict:
        """Send pre-encoded bytes via WS and return a normalised result dict.

        Never raises for ordinary errors: timeouts and other exceptions are
        converted into ``{"success": False, "error": ...}``.
        """
        try:
            response = await adapter._connection.send_biz_request(encoded, req_id=req_id)
            return {"success": True, "msg_key": response.get("msg_id", "")}
        except asyncio.TimeoutError:
            return {"success": False, "error": f"Request timeout after {DEFAULT_SEND_TIMEOUT}s"}
        except Exception as exc:
            return {"success": False, "error": str(exc)}

    # -- Media validation ---------------------------------------------------

    @staticmethod
    def validate_media(
        file_bytes: Optional[bytes], filename: str, max_size_mb: int = 20
    ) -> Optional[str]:
        """Media pre-validation: check file validity before sending/uploading.

        Returns:
            Error description (str) if validation fails, otherwise None.
        """
        if file_bytes is None or len(file_bytes) == 0:
            return f"Empty file: {filename}"
        max_bytes = max_size_mb * 1024 * 1024
        if len(file_bytes) > max_bytes:
            size_mb = len(file_bytes) / 1024 / 1024
            return f"File too large: {filename} ({size_mb:.1f}MB > {max_size_mb}MB)"
        return None

    # -- Text truncation (table-aware) --------------------------------------

    @staticmethod
    def truncate_message(
        content: str,
        max_length: int = 4000,
        len_fn: Optional[Callable[[str], int]] = None,
    ) -> List[str]:
        """
        Split a long message into chunks with table-awareness.

        Content whose length (measured with *len_fn*, default ``len``) fits in
        *max_length* is returned unchanged as a single-element list. Longer
        content is delegated to ``MarkdownProcessor.chunk_markdown_text`` and
        any text matched by ``_INDICATOR_RE`` (page indicators like ``(1/3)``)
        is removed from each chunk.
        """
        _len = len_fn or len
        if _len(content) <= max_length:
            return [content]

        # Delegate to MarkdownProcessor for table/fence-aware chunking
        chunks = MarkdownProcessor.chunk_markdown_text(
            content, max_length, len_fn=len_fn,
        )

        # Strip page indicators like (1/3) from each chunk
        chunks = [_INDICATOR_RE.sub('', c) for c in chunks]

        return chunks if chunks else [content]

    # -- Cron wrapper stripping ---------------------------------------------

    @staticmethod
    def strip_cron_wrapper(content: str) -> str:
        """Strip scheduler cron header/footer wrapper for cleaner Yuanbao output.

        The wrapper is only removed when all of its markers are present: the
        ``Cronjob Response: `` prefix, a header containing ``(job_id: ``, the
        divider line, and the management footer after the divider. Otherwise
        (or when the extracted body is empty) *content* is returned unchanged.
        """
        if not content.startswith("Cronjob Response: "):
            return content

        divider = "\n-------------\n\n"
        footer_prefix = '\n\nTo stop or manage this job, send me a new message (e.g. "stop reminder '
        divider_pos = content.find(divider)
        footer_pos = content.rfind(footer_prefix)
        if divider_pos < 0 or footer_pos < 0 or footer_pos <= divider_pos:
            return content

        header = content[:divider_pos]
        if "\n(job_id: " not in header:
            return content

        body_start = divider_pos + len(divider)
        body = content[body_start:footer_pos].strip()
        return body or content

    # -- Cleanup on disconnect ---------------------------------------------

    async def close(self) -> None:
        """Drop all per-chat locks (the lock table is emptied)."""
        self._chat_locks.clear()
|
|
|
|
|
|
class OutboundManager:
    """Single outbound facade used by YuanbaoAdapter.

    Owns three collaborators and wires them together:
      - ``sender`` (MessageSender): text/media sending
      - ``heartbeat`` (HeartbeatManager): reply heartbeat RUNNING / FINISH
      - ``slow_notifier`` (SlowResponseNotifier): delayed 'please wait' notice

    The sender's start/finish hooks are pointed at this object so that a send
    cancels the slow-response timer and ends with a FINISH heartbeat.
    """

    # Mirror of MessageSender.CHAT_DICT_MAX_SIZE, kept for backward compatibility
    CHAT_DICT_MAX_SIZE: ClassVar[int] = MessageSender.CHAT_DICT_MAX_SIZE

    def __init__(self, adapter: "YuanbaoAdapter") -> None:
        self._adapter = adapter
        self.sender: MessageSender = MessageSender(adapter)
        self.heartbeat: HeartbeatManager = HeartbeatManager(adapter)
        self.slow_notifier: SlowResponseNotifier = SlowResponseNotifier(adapter, self.sender)

        # Install the coordination callbacks on the sender.
        self.sender._on_send_start = self._handle_send_start
        self.sender._on_send_finish = self._handle_send_finish

    # -- Coordination hooks ------------------------------------------------

    def _handle_send_start(self, chat_id: str) -> None:
        """Sender pre-send hook: disarm the slow-response timer for *chat_id*."""
        self.slow_notifier.cancel(chat_id)

    async def _handle_send_finish(self, chat_id: str) -> None:
        """Sender post-send hook: emit a FINISH heartbeat for *chat_id*."""
        await self.heartbeat.send_heartbeat_once(chat_id, WS_HEARTBEAT_FINISH)

    # -- Delegated public API (used by YuanbaoAdapter) ---------------------

    async def send_text(
        self, chat_id: str, content: str, reply_to: Optional[str] = None,
        group_code: str = "",
    ) -> "SendResult":
        """Forward a (possibly multi-chunk) text send to the MessageSender."""
        return await self.sender.send_text(
            chat_id, content, reply_to=reply_to, group_code=group_code,
        )

    async def send_media(
        self, chat_id: str, handler_name: str, **kwargs: Any,
    ) -> "SendResult":
        """Forward a media send to the handler registered under *handler_name*."""
        return await self.sender.send_media(chat_id, handler_name, **kwargs)

    async def send_direct(
        self, chat_id: str, message: str,
        media_files: Optional[List[Tuple[str, bool]]] = None,
    ) -> Dict[str, Any]:
        """Forward a combined text + media send (``send_message`` tool path)."""
        return await self.sender.send_direct(chat_id, message, media_files=media_files)

    async def start_typing(self, chat_id: str) -> None:
        """Begin the RUNNING reply heartbeat for *chat_id*."""
        await self.heartbeat.start(chat_id)

    async def stop_typing(self, chat_id: str, send_finish: bool = False) -> None:
        """End the reply heartbeat; FINISH is only sent when *send_finish* is true."""
        await self.heartbeat.stop(chat_id, send_finish=send_finish)

    async def start_slow_notifier(self, chat_id: str) -> None:
        """Arm the slow-response timer for *chat_id*."""
        await self.slow_notifier.start(chat_id)

    def cancel_slow_notifier(self, chat_id: str) -> None:
        """Disarm the slow-response timer for *chat_id*."""
        self.slow_notifier.cancel(chat_id)

    def get_chat_lock(self, chat_id: str) -> asyncio.Lock:
        """Backward-compatible passthrough to ``MessageSender.get_chat_lock``."""
        return self.sender.get_chat_lock(chat_id)

    @property
    def _chat_locks(self) -> collections.OrderedDict:
        """Backward-compatible view of the sender's lock table."""
        return self.sender._chat_locks

    @staticmethod
    def validate_media(
        file_bytes: Optional[bytes], filename: str, max_size_mb: int = 20,
    ) -> Optional[str]:
        """Backward-compatible passthrough to ``MessageSender.validate_media``."""
        return MessageSender.validate_media(file_bytes, filename, max_size_mb=max_size_mb)

    async def close(self) -> None:
        """Close the sender, then the heartbeat manager, then the slow notifier."""
        for component in (self.sender, self.heartbeat, self.slow_notifier):
            await component.close()
|
|
|
|
|
|
class YuanbaoAdapter(BasePlatformAdapter):
|
|
"""Yuanbao AI Bot adapter backed by a persistent WebSocket connection."""
|
|
|
|
PLATFORM = Platform.YUANBAO
|
|
MAX_TEXT_CHUNK: int = 4000 # Yuanbao single message character limit
|
|
MEDIA_MAX_SIZE_MB: int = 50 # Max media file size in MB for upload validation
|
|
REPLY_REF_MAX_ENTRIES: ClassVar[int] = 500 # Max capacity of reference dedup dict
|
|
|
|
# -- Active instance registry (class-level singleton) -------------------
|
|
|
|
_active_instance: ClassVar[Optional["YuanbaoAdapter"]] = None
|
|
|
|
@classmethod
|
|
def get_active(cls) -> Optional["YuanbaoAdapter"]:
|
|
"""Return the currently connected YuanbaoAdapter, or None."""
|
|
return cls._active_instance
|
|
|
|
@classmethod
|
|
def set_active(cls, adapter: Optional["YuanbaoAdapter"]) -> None:
|
|
"""Register (or clear) the active adapter instance."""
|
|
cls._active_instance = adapter
|
|
|
|
def __init__(self, config: PlatformConfig, **kwargs: Any) -> None:
|
|
super().__init__(config, Platform.YUANBAO)
|
|
|
|
# Credentials / endpoints from config.extra (populated by config.py from env/yaml)
|
|
_extra = config.extra or {}
|
|
self._app_key: str = (_extra.get("app_id") or "").strip()
|
|
self._app_secret: str = (_extra.get("app_secret") or "").strip()
|
|
self._bot_id: Optional[str] = _extra.get("bot_id") or None
|
|
self._ws_url: str = (_extra.get("ws_url") or DEFAULT_WS_GATEWAY_URL).strip()
|
|
self._api_domain: str = (_extra.get("api_domain") or DEFAULT_API_DOMAIN).rstrip("/")
|
|
self._route_env: str = (_extra.get("route_env") or "").strip()
|
|
|
|
# Core managers (UML composition)
|
|
self._connection: ConnectionManager = ConnectionManager(self)
|
|
self._outbound: OutboundManager = OutboundManager(self)
|
|
|
|
# Inbound dispatch tasks — tracked so disconnect() can cancel them
|
|
self._inbound_tasks: set[asyncio.Task] = set()
|
|
|
|
# Set of background tasks — prevent GC from collecting fire-and-forget tasks
|
|
self._background_tasks: set[asyncio.Task] = set()
|
|
|
|
# Member cache: group_code -> (updated_ts, [{"user_id":..., "nickname":..., ...}, ...])
|
|
# Populated by get_group_member_list(), used by @mention resolution.
|
|
# Entries older than MEMBER_CACHE_TTL_S are treated as stale.
|
|
self._member_cache: Dict[str, Tuple[float, list]] = {}
|
|
self.MEMBER_CACHE_TTL_S: float = 300.0 # 5 minutes
|
|
|
|
# Inbound message deduplication (WS reconnect / network jitter)
|
|
self._dedup = MessageDeduplicator(ttl_seconds=300)
|
|
|
|
# Group chat sequential dispatch queue (session_key → asyncio.Queue).
|
|
self._group_queues: Dict[str, asyncio.Queue] = {}
|
|
|
|
# Recall support: track which msg_id is being processed per session_key
|
|
# so RecallGuardMiddleware can detect "currently processing" messages.
|
|
self._processing_msg_ids: Dict[str, str] = {}
|
|
self._processing_msg_texts: Dict[str, str] = {}
|
|
# Bounded cache of msg_id → attributed content for recent messages.
|
|
# Used by _patch_transcript as content-match fallback when transcript
|
|
# entries lack a message_id field (agent-processed @bot messages).
|
|
self._msg_content_cache: Dict[str, str] = {}
|
|
|
|
# Reply-to dedup: inbound_msg_id -> expire_ts
|
|
# ------------------------------------------------------------------
|
|
# Access control policy (DM / Group)
|
|
# ------------------------------------------------------------------
|
|
dm_policy: str = (
|
|
_extra.get("dm_policy")
|
|
or os.getenv("YUANBAO_DM_POLICY", "open")
|
|
).strip().lower()
|
|
|
|
_dm_allow_from_raw: str = (
|
|
_extra.get("dm_allow_from")
|
|
or os.getenv("YUANBAO_DM_ALLOW_FROM", "")
|
|
)
|
|
dm_allow_from: list[str] = [x.strip() for x in _dm_allow_from_raw.split(",") if x.strip()]
|
|
|
|
group_policy: str = (
|
|
_extra.get("group_policy")
|
|
or os.getenv("YUANBAO_GROUP_POLICY", "open")
|
|
).strip().lower()
|
|
|
|
_group_allow_from_raw: str = (
|
|
_extra.get("group_allow_from")
|
|
or os.getenv("YUANBAO_GROUP_ALLOW_FROM", "")
|
|
)
|
|
group_allow_from: list[str] = [x.strip() for x in _group_allow_from_raw.split(",") if x.strip()]
|
|
|
|
self._access_policy = AccessPolicy(
|
|
dm_policy=dm_policy,
|
|
dm_allow_from=dm_allow_from,
|
|
group_policy=group_policy,
|
|
group_allow_from=group_allow_from,
|
|
)
|
|
|
|
# Group query service (AI tool backing)
|
|
self._group_query = GroupQueryService(self)
|
|
|
|
# Inbound message processing pipeline (middleware pattern)
|
|
self._inbound_pipeline: InboundPipeline = InboundPipelineBuilder.build()
|
|
|
|
# ------------------------------------------------------------------
|
|
# Auto-sethome: first user to message the bot becomes the owner.
|
|
# If no home channel is configured, the first conversation will be
|
|
# automatically set as the home channel. When the existing home
|
|
# channel is a group chat (group:xxx), it stays eligible for
|
|
# upgrade — the first DM will override it with direct:xxx.
|
|
# ------------------------------------------------------------------
|
|
_existing_home = os.getenv("YUANBAO_HOME_CHANNEL") or (
|
|
config.home_channel.chat_id if config.home_channel else ""
|
|
)
|
|
self._auto_sethome_done: bool = bool(_existing_home) and not _existing_home.startswith("group:")
|
|
|
|
# ------------------------------------------------------------------
|
|
# Task tracking helper
|
|
# ------------------------------------------------------------------
|
|
|
|
def _track_task(self, task: asyncio.Task) -> asyncio.Task:
|
|
"""Register a fire-and-forget task so it won't be GC'd prematurely."""
|
|
self._background_tasks.add(task)
|
|
task.add_done_callback(self._background_tasks.discard)
|
|
return task
|
|
|
|
# ------------------------------------------------------------------
|
|
# Abstract method implementations
|
|
# ------------------------------------------------------------------
|
|
|
|
async def connect(self) -> bool:
|
|
"""Connect to Yuanbao WS gateway and authenticate.
|
|
|
|
Delegates to ConnectionManager.open().
|
|
"""
|
|
return await self._connection.open()
|
|
|
|
async def disconnect(self) -> None:
|
|
"""Cancel background tasks and close the WebSocket connection."""
|
|
if YuanbaoAdapter._active_instance is self:
|
|
YuanbaoAdapter.set_active(None)
|
|
|
|
self._running = False
|
|
self._mark_disconnected()
|
|
self._release_platform_lock()
|
|
|
|
# Delegate to managers
|
|
await self._connection.close()
|
|
await self._outbound.close()
|
|
|
|
# Cancel all in-flight inbound dispatch tasks
|
|
for task in list(self._inbound_tasks):
|
|
if not task.done():
|
|
task.cancel()
|
|
self._inbound_tasks.clear()
|
|
|
|
self._group_queues.clear()
|
|
|
|
logger.info("[%s] Disconnected", self.name)
|
|
|
|
async def send(
|
|
self,
|
|
chat_id: str,
|
|
content: str,
|
|
reply_to: Optional[str] = None,
|
|
metadata: Optional[Dict[str, Any]] = None,
|
|
group_code: str = "",
|
|
) -> SendResult:
|
|
"""Send text message with auto-chunking. Delegates to OutboundManager."""
|
|
return await self._outbound.send_text(chat_id, content, reply_to, group_code=group_code)
|
|
|
|
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
|
"""Return basic chat metadata derived from the chat_id prefix.
|
|
|
|
chat_id conventions:
|
|
"group:<group_code>" → group chat
|
|
"direct:<account>" → C2C / direct message (default)
|
|
|
|
TODO (T06): fetch real chat name/member-count from Yuanbao API.
|
|
"""
|
|
if chat_id.startswith("group:"):
|
|
return {"name": chat_id, "type": "group"}
|
|
return {"name": chat_id, "type": "dm"}
|
|
|
|
async def send_typing(self, chat_id: str, metadata: Optional[dict] = None) -> None:
|
|
"""Send "typing" status heartbeat (RUNNING). Delegates to OutboundManager."""
|
|
try:
|
|
await self._outbound.start_typing(chat_id)
|
|
except Exception:
|
|
pass
|
|
|
|
async def stop_typing(self, chat_id: str) -> None:
|
|
"""Stop the RUNNING heartbeat loop without sending FINISH immediately.
|
|
|
|
FINISH is sent by send() after actual message delivery to ensure correct ordering:
|
|
RUNNING... -> message arrives -> FINISH.
|
|
"""
|
|
try:
|
|
await self._outbound.stop_typing(chat_id, send_finish=False)
|
|
except Exception:
|
|
pass
|
|
|
|
async def _process_message_background(self, event, session_key: str) -> None:
|
|
"""Wrap base class processing with a slow-response notifier."""
|
|
chat_id = event.source.chat_id
|
|
await self._outbound.start_slow_notifier(chat_id)
|
|
try:
|
|
await super()._process_message_background(event, session_key)
|
|
finally:
|
|
self._outbound.cancel_slow_notifier(chat_id)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Group query (delegate to GroupQueryService)
|
|
# ------------------------------------------------------------------
|
|
|
|
async def query_group_info(self, group_code: str) -> Optional[dict]:
|
|
"""Query group info (delegates to GroupQueryService)."""
|
|
return await self._group_query.query_group_info_raw(group_code)
|
|
|
|
async def get_group_member_list(
|
|
self, group_code: str, offset: int = 0, limit: int = 200
|
|
) -> Optional[dict]:
|
|
"""Query group member list (delegates to GroupQueryService)."""
|
|
return await self._group_query.get_group_member_list_raw(group_code, offset=offset, limit=limit)
|
|
|
|
# ------------------------------------------------------------------
|
|
# DM active private chat + access control
|
|
# ------------------------------------------------------------------
|
|
|
|
DM_MAX_CHARS = 10000 # DM text limit
|
|
|
|
async def send_dm(self, user_id: str, text: str, group_code: str = "") -> SendResult:
|
|
"""
|
|
Actively send C2C private chat message.
|
|
|
|
Args:
|
|
user_id: Target user ID
|
|
text: Message text (limit 10000 characters)
|
|
group_code: Source group code (for group-originated DM context)
|
|
|
|
Returns:
|
|
SendResult
|
|
"""
|
|
if not self._access_policy.is_dm_allowed(user_id):
|
|
return SendResult(success=False, error="DM access denied for this user")
|
|
if len(text) > self.DM_MAX_CHARS:
|
|
text = text[:self.DM_MAX_CHARS] + "\n...(truncated)"
|
|
chat_id = f"direct:{user_id}"
|
|
return await self.send(chat_id, text, group_code=group_code)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Media send methods
|
|
# ------------------------------------------------------------------
|
|
|
|
async def send_image(
|
|
self,
|
|
chat_id: str,
|
|
image_url: str,
|
|
caption: Optional[str] = None,
|
|
reply_to: Optional[str] = None,
|
|
metadata: Optional[dict] = None,
|
|
**kwargs: Any,
|
|
) -> SendResult:
|
|
"""Send image message (URL). Delegates to OutboundManager via ImageUrlHandler."""
|
|
return await self._outbound.send_media(
|
|
chat_id, "image_url",
|
|
reply_to=reply_to, caption=caption, image_url=image_url,
|
|
**kwargs,
|
|
)
|
|
|
|
async def send_image_file(
|
|
self,
|
|
chat_id: str,
|
|
image_path: str,
|
|
caption: Optional[str] = None,
|
|
reply_to: Optional[str] = None,
|
|
metadata: Optional[dict] = None,
|
|
**kwargs: Any,
|
|
) -> SendResult:
|
|
"""Send local image file. Delegates to OutboundManager via ImageFileHandler."""
|
|
return await self._outbound.send_media(
|
|
chat_id, "image_file",
|
|
reply_to=reply_to, caption=caption, image_path=image_path,
|
|
**kwargs,
|
|
)
|
|
|
|
async def send_file(
|
|
self,
|
|
chat_id: str,
|
|
file_url: str,
|
|
filename: Optional[str] = None,
|
|
reply_to: Optional[str] = None,
|
|
metadata: Optional[dict] = None,
|
|
**kwargs: Any,
|
|
) -> SendResult:
|
|
"""Send file message (URL). Delegates to OutboundManager via FileUrlHandler."""
|
|
return await self._outbound.send_media(
|
|
chat_id, "file_url",
|
|
reply_to=reply_to, file_url=file_url, filename=filename,
|
|
**kwargs,
|
|
)
|
|
|
|
async def send_sticker(
    self,
    chat_id: str,
    sticker_name: Optional[str] = None,
    face_index: Optional[int] = None,
    reply_to: Optional[str] = None,
    **kwargs: Any,
) -> SendResult:
    """Send a sticker/emoji message.

    The request is handed to ``self._outbound.send_media`` with the
    ``"sticker"`` media kind.

    Args:
        chat_id: Target chat identifier.
        sticker_name: Optional sticker name, forwarded as-is.
        face_index: Optional face index, forwarded as-is.
        reply_to: Optional reference forwarded as ``reply_to``.
        **kwargs: Extra keyword arguments passed through unchanged.

    Returns:
        Whatever ``self._outbound.send_media`` returns.
    """
    sticker_options = {"reply_to": reply_to}
    sticker_options["sticker_name"] = sticker_name
    sticker_options["face_index"] = face_index
    sticker_options.update(kwargs)
    reply = await self._outbound.send_media(chat_id, "sticker", **sticker_options)
    return reply
|
|
|
|
async def send_document(
    self,
    chat_id: str,
    file_path: str,
    filename: Optional[str] = None,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    metadata: Optional[dict] = None,
    **kwargs: Any,
) -> SendResult:
    """Send a local file as a document.

    The request is handed to ``self._outbound.send_media`` with the
    ``"document"`` media kind.

    Args:
        chat_id: Target chat identifier.
        file_path: Path of the file to send.
        filename: Optional file name forwarded with the request.
        caption: Optional caption forwarded with the document.
        reply_to: Optional reference forwarded as ``reply_to``.
        metadata: Accepted as part of the signature; this method does not
            forward it to the outbound manager.
        **kwargs: Extra keyword arguments passed through unchanged.

    Returns:
        Whatever ``self._outbound.send_media`` returns.
    """
    document_options = {
        "reply_to": reply_to,
        "caption": caption,
        "file_path": file_path,
        "filename": filename,
    }
    document_options.update(kwargs)
    receipt = await self._outbound.send_media(chat_id, "document", **document_options)
    return receipt
|
|
|
|
async def _get_cached_token(self) -> dict:
    """Fetch the current sign token through ``SignManager.get_token``.

    Passes this adapter's app key, app secret and API domain positionally
    and its route environment as ``route_env``. Any caching is done by
    ``SignManager`` and is not visible in this method.

    Returns:
        The value produced by ``SignManager.get_token``.
    """
    fetch_token = SignManager.get_token
    token = await fetch_token(
        self._app_key,
        self._app_secret,
        self._api_domain,
        route_env=self._route_env,
    )
    return token
|
|
|
|
def get_status(self) -> dict:
    """Build a snapshot of the adapter's current connection status.

    Returns:
        A dict with the keys ``connected``, ``bot_id``, ``connect_id``,
        ``reconnect_attempts`` and ``ws_url``, read from
        ``self._connection`` and from the adapter itself.
    """
    link = self._connection
    snapshot: dict = {}
    # Filled one key at a time, in the order callers see when iterating.
    snapshot["connected"] = link.is_connected
    snapshot["bot_id"] = self._bot_id
    snapshot["connect_id"] = link.connect_id
    snapshot["reconnect_attempts"] = link.reconnect_attempts
    snapshot["ws_url"] = self._ws_url
    return snapshot
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Module-level thin delegates (preserve import compatibility for external callers)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def get_active_adapter() -> Optional["YuanbaoAdapter"]:
    """Return the result of ``YuanbaoAdapter.get_active()``.

    Module-level thin wrapper around the class-level accessor.
    """
    active = YuanbaoAdapter.get_active()
    return active
|
|
|
|
|
|
async def send_yuanbao_direct(
    adapter: "YuanbaoAdapter",
    chat_id: str,
    message: str,
    media_files: Optional[List[Tuple[str, bool]]] = None,
) -> Dict[str, Any]:
    """Forward a direct send to the adapter's outbound manager.

    Args:
        adapter: Adapter whose ``_outbound.send_direct`` performs the send.
        chat_id: Target chat identifier.
        message: Message text to send.
        media_files: Optional list of ``(str, bool)`` pairs, passed through
            unchanged. NOTE(review): presumably (path, flag) -- confirm the
            meaning of the bool against ``send_direct``.

    Returns:
        Whatever ``adapter._outbound.send_direct`` returns.
    """
    outbound = adapter._outbound
    result = await outbound.send_direct(chat_id, message, media_files)
    return result
|