mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-06 02:37:05 +08:00
Two mitigations for the CLOSE_WAIT accumulation reported against QQ Bot + Feishu on macOS behind Cloudflare Warp. 1. Shared httpx.Limits helper (gateway/platforms/_http_client_limits.py). Every long-lived platform adapter now constructs httpx.AsyncClient with max_keepalive_connections=10 and keepalive_expiry=2.0, vs httpx's default of unbounded keepalive pool and 5.0s expiry. On macOS/Warp the default 5s window let idle keepalive sockets sit in CLOSE_WAIT long enough for seven persistent adapters (QQ Bot, WeCom, DingTalk, Signal, BlueBubbles, WeCom-callback, plus the transient Feishu helper) to compound to the 256-fd ulimit. Tunable via HERMES_GATEWAY_HTTPX_KEEPALIVE_EXPIRY and HERMES_GATEWAY_HTTPX_MAX_KEEPALIVE env vars. 2. whatsapp.send_typing aiohttp leak. The call was 'await self._http_session.post(...)' with no 'async with' and no variable capture — the ClientResponse went out of scope unclosed, holding its TCP socket in CLOSE_WAIT until GC. Fixed by wrapping in 'async with'. This was the only bare-await aiohttp leak in the gateway/tools/plugins tree per audit; all other aiohttp sites use the context-manager pattern correctly. The underlying reporter also saw Feishu SDK (lark-oapi) connections in CLOSE_WAIT — those are inside the SDK and out of our direct control, but tightening httpx keepalive across adapters reduces the aggregate pool pressure regardless of which individual adapter leaks.
404 lines
16 KiB
Python
404 lines
16 KiB
Python
"""WeCom callback-mode adapter for self-built enterprise applications.
|
|
|
|
Unlike the bot/websocket adapter in ``wecom.py``, this handles the standard
|
|
WeCom callback flow: WeCom POSTs encrypted XML to an HTTP endpoint, the
|
|
adapter decrypts it, queues the message for the agent, and immediately
|
|
acknowledges. The agent's reply is delivered later via the proactive
|
|
``message/send`` API using an access-token.
|
|
|
|
Supports multiple self-built apps under one gateway instance, scoped by
|
|
``corp_id:user_id`` to avoid cross-corp collisions.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import socket as _socket
|
|
import time
|
|
from typing import Any, Dict, List, Optional
|
|
from xml.etree import ElementTree as ET
|
|
|
|
try:
|
|
from aiohttp import web
|
|
|
|
AIOHTTP_AVAILABLE = True
|
|
except ImportError:
|
|
web = None # type: ignore[assignment]
|
|
AIOHTTP_AVAILABLE = False
|
|
|
|
try:
|
|
import httpx
|
|
|
|
HTTPX_AVAILABLE = True
|
|
except ImportError:
|
|
httpx = None # type: ignore[assignment]
|
|
HTTPX_AVAILABLE = False
|
|
|
|
from gateway.config import Platform, PlatformConfig
|
|
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
|
from gateway.platforms.wecom_crypto import WXBizMsgCrypt, WeComCryptoError
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
DEFAULT_HOST = "0.0.0.0"
|
|
DEFAULT_PORT = 8645
|
|
DEFAULT_PATH = "/wecom/callback"
|
|
ACCESS_TOKEN_TTL_SECONDS = 7200
|
|
MESSAGE_DEDUP_TTL_SECONDS = 300
|
|
|
|
|
|
def check_wecom_callback_requirements() -> bool:
|
|
return AIOHTTP_AVAILABLE and HTTPX_AVAILABLE
|
|
|
|
|
|
class WecomCallbackAdapter(BasePlatformAdapter):
|
|
def __init__(self, config: PlatformConfig):
|
|
super().__init__(config, Platform.WECOM_CALLBACK)
|
|
extra = config.extra or {}
|
|
self._host = str(extra.get("host") or DEFAULT_HOST)
|
|
self._port = int(extra.get("port") or DEFAULT_PORT)
|
|
self._path = str(extra.get("path") or DEFAULT_PATH)
|
|
self._apps: List[Dict[str, Any]] = self._normalize_apps(extra)
|
|
self._runner: Optional[web.AppRunner] = None
|
|
self._site: Optional[web.TCPSite] = None
|
|
self._app: Optional[web.Application] = None
|
|
self._http_client: Optional[httpx.AsyncClient] = None
|
|
self._message_queue: asyncio.Queue[MessageEvent] = asyncio.Queue()
|
|
self._poll_task: Optional[asyncio.Task] = None
|
|
self._seen_messages: Dict[str, float] = {}
|
|
self._user_app_map: Dict[str, str] = {}
|
|
self._access_tokens: Dict[str, Dict[str, Any]] = {}
|
|
|
|
# ------------------------------------------------------------------
|
|
# App normalisation
|
|
# ------------------------------------------------------------------
|
|
|
|
@staticmethod
|
|
def _user_app_key(corp_id: str, user_id: str) -> str:
|
|
return f"{corp_id}:{user_id}" if corp_id else user_id
|
|
|
|
@staticmethod
|
|
def _normalize_apps(extra: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
apps = extra.get("apps")
|
|
if isinstance(apps, list) and apps:
|
|
return [dict(app) for app in apps if isinstance(app, dict)]
|
|
if extra.get("corp_id"):
|
|
return [
|
|
{
|
|
"name": extra.get("name") or "default",
|
|
"corp_id": extra.get("corp_id", ""),
|
|
"corp_secret": extra.get("corp_secret", ""),
|
|
"agent_id": str(extra.get("agent_id", "")),
|
|
"token": extra.get("token", ""),
|
|
"encoding_aes_key": extra.get("encoding_aes_key", ""),
|
|
}
|
|
]
|
|
return []
|
|
|
|
# ------------------------------------------------------------------
|
|
# Lifecycle
|
|
# ------------------------------------------------------------------
|
|
|
|
async def connect(self) -> bool:
|
|
if not self._apps:
|
|
logger.warning("[WecomCallback] No callback apps configured")
|
|
return False
|
|
if not check_wecom_callback_requirements():
|
|
logger.warning("[WecomCallback] aiohttp/httpx not installed")
|
|
return False
|
|
|
|
# Quick port-in-use check.
|
|
try:
|
|
with _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) as sock:
|
|
sock.settimeout(1)
|
|
sock.connect(("127.0.0.1", self._port))
|
|
logger.error("[WecomCallback] Port %d already in use", self._port)
|
|
return False
|
|
except (ConnectionRefusedError, OSError):
|
|
pass
|
|
|
|
try:
|
|
# Tighter keepalive so idle CLOSE_WAIT drains promptly (#18451).
|
|
from gateway.platforms._http_client_limits import platform_httpx_limits
|
|
self._http_client = httpx.AsyncClient(timeout=20.0, limits=platform_httpx_limits())
|
|
self._app = web.Application()
|
|
self._app.router.add_get("/health", self._handle_health)
|
|
self._app.router.add_get(self._path, self._handle_verify)
|
|
self._app.router.add_post(self._path, self._handle_callback)
|
|
self._runner = web.AppRunner(self._app)
|
|
await self._runner.setup()
|
|
self._site = web.TCPSite(self._runner, self._host, self._port)
|
|
await self._site.start()
|
|
self._poll_task = asyncio.create_task(self._poll_loop())
|
|
self._mark_connected()
|
|
logger.info(
|
|
"[WecomCallback] HTTP server listening on %s:%s%s",
|
|
self._host, self._port, self._path,
|
|
)
|
|
for app in self._apps:
|
|
try:
|
|
await self._refresh_access_token(app)
|
|
except Exception as exc:
|
|
logger.warning(
|
|
"[WecomCallback] Initial token refresh failed for app '%s': %s",
|
|
app.get("name", "default"), exc,
|
|
)
|
|
return True
|
|
except Exception:
|
|
await self._cleanup()
|
|
logger.exception("[WecomCallback] Failed to start")
|
|
return False
|
|
|
|
async def disconnect(self) -> None:
|
|
self._running = False
|
|
if self._poll_task:
|
|
self._poll_task.cancel()
|
|
try:
|
|
await self._poll_task
|
|
except asyncio.CancelledError:
|
|
pass
|
|
self._poll_task = None
|
|
await self._cleanup()
|
|
self._mark_disconnected()
|
|
logger.info("[WecomCallback] Disconnected")
|
|
|
|
async def _cleanup(self) -> None:
|
|
self._site = None
|
|
if self._runner:
|
|
await self._runner.cleanup()
|
|
self._runner = None
|
|
self._app = None
|
|
if self._http_client:
|
|
await self._http_client.aclose()
|
|
self._http_client = None
|
|
|
|
# ------------------------------------------------------------------
|
|
# Outbound: proactive send via access-token API
|
|
# ------------------------------------------------------------------
|
|
|
|
async def send(
|
|
self,
|
|
chat_id: str,
|
|
content: str,
|
|
reply_to: Optional[str] = None,
|
|
metadata: Optional[Dict[str, Any]] = None,
|
|
) -> SendResult:
|
|
app = self._resolve_app_for_chat(chat_id)
|
|
touser = chat_id.split(":", 1)[1] if ":" in chat_id else chat_id
|
|
try:
|
|
token = await self._get_access_token(app)
|
|
payload = {
|
|
"touser": touser,
|
|
"msgtype": "text",
|
|
"agentid": int(str(app.get("agent_id") or 0)),
|
|
"text": {"content": content[:2048]},
|
|
"safe": 0,
|
|
}
|
|
resp = await self._http_client.post(
|
|
f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={token}",
|
|
json=payload,
|
|
)
|
|
data = resp.json()
|
|
if data.get("errcode") != 0:
|
|
return SendResult(success=False, error=str(data))
|
|
return SendResult(
|
|
success=True,
|
|
message_id=str(data.get("msgid", "")),
|
|
raw_response=data,
|
|
)
|
|
except Exception as exc:
|
|
return SendResult(success=False, error=str(exc))
|
|
|
|
def _resolve_app_for_chat(self, chat_id: str) -> Dict[str, Any]:
|
|
"""Pick the app associated with *chat_id*, falling back sensibly."""
|
|
app_name = self._user_app_map.get(chat_id)
|
|
if not app_name and ":" not in chat_id:
|
|
# Legacy bare user_id — try to find a unique match.
|
|
matching = [k for k in self._user_app_map if k.endswith(f":{chat_id}")]
|
|
if len(matching) == 1:
|
|
app_name = self._user_app_map.get(matching[0])
|
|
app = self._get_app_by_name(app_name) if app_name else None
|
|
return app or self._apps[0]
|
|
|
|
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
|
return {"name": chat_id, "type": "dm"}
|
|
|
|
# ------------------------------------------------------------------
|
|
# Inbound: HTTP callback handlers
|
|
# ------------------------------------------------------------------
|
|
|
|
async def _handle_health(self, request: web.Request) -> web.Response:
|
|
return web.json_response({"status": "ok", "platform": "wecom_callback"})
|
|
|
|
async def _handle_verify(self, request: web.Request) -> web.Response:
|
|
"""GET endpoint — WeCom URL verification handshake."""
|
|
msg_signature = request.query.get("msg_signature", "")
|
|
timestamp = request.query.get("timestamp", "")
|
|
nonce = request.query.get("nonce", "")
|
|
echostr = request.query.get("echostr", "")
|
|
for app in self._apps:
|
|
try:
|
|
crypt = self._crypt_for_app(app)
|
|
plain = crypt.verify_url(msg_signature, timestamp, nonce, echostr)
|
|
return web.Response(text=plain, content_type="text/plain")
|
|
except Exception:
|
|
continue
|
|
return web.Response(status=403, text="signature verification failed")
|
|
|
|
async def _handle_callback(self, request: web.Request) -> web.Response:
|
|
"""POST endpoint — receive an encrypted message callback."""
|
|
msg_signature = request.query.get("msg_signature", "")
|
|
timestamp = request.query.get("timestamp", "")
|
|
nonce = request.query.get("nonce", "")
|
|
body = await request.text()
|
|
|
|
for app in self._apps:
|
|
try:
|
|
decrypted = self._decrypt_request(
|
|
app, body, msg_signature, timestamp, nonce,
|
|
)
|
|
event = self._build_event(app, decrypted)
|
|
if event is not None:
|
|
# Deduplicate: WeCom retries callbacks on timeout,
|
|
# producing duplicate inbound messages (#10305).
|
|
if event.message_id:
|
|
now = time.time()
|
|
if event.message_id in self._seen_messages:
|
|
if now - self._seen_messages[event.message_id] < MESSAGE_DEDUP_TTL_SECONDS:
|
|
logger.debug("[WecomCallback] Duplicate MsgId %s, skipping", event.message_id)
|
|
return web.Response(text="success", content_type="text/plain")
|
|
del self._seen_messages[event.message_id]
|
|
self._seen_messages[event.message_id] = now
|
|
# Prune expired entries when cache grows large
|
|
if len(self._seen_messages) > 2000:
|
|
cutoff = now - MESSAGE_DEDUP_TTL_SECONDS
|
|
self._seen_messages = {k: v for k, v in self._seen_messages.items() if v > cutoff}
|
|
# Record which app this user belongs to.
|
|
if event.source and event.source.user_id:
|
|
map_key = self._user_app_key(
|
|
str(app.get("corp_id") or ""), event.source.user_id,
|
|
)
|
|
self._user_app_map[map_key] = app["name"]
|
|
await self._message_queue.put(event)
|
|
# Immediately acknowledge — the agent's reply will arrive
|
|
# later via the proactive message/send API.
|
|
return web.Response(text="success", content_type="text/plain")
|
|
except WeComCryptoError:
|
|
continue
|
|
except Exception:
|
|
logger.exception("[WecomCallback] Error handling message")
|
|
break
|
|
return web.Response(status=400, text="invalid callback payload")
|
|
|
|
async def _poll_loop(self) -> None:
|
|
"""Drain the message queue and dispatch to the gateway runner."""
|
|
while True:
|
|
event = await self._message_queue.get()
|
|
try:
|
|
task = asyncio.create_task(self.handle_message(event))
|
|
self._background_tasks.add(task)
|
|
task.add_done_callback(self._background_tasks.discard)
|
|
except Exception:
|
|
logger.exception("[WecomCallback] Failed to enqueue event")
|
|
|
|
# ------------------------------------------------------------------
|
|
# XML / crypto helpers
|
|
# ------------------------------------------------------------------
|
|
|
|
def _decrypt_request(
|
|
self, app: Dict[str, Any], body: str,
|
|
msg_signature: str, timestamp: str, nonce: str,
|
|
) -> str:
|
|
root = ET.fromstring(body)
|
|
encrypt = root.findtext("Encrypt", default="")
|
|
crypt = self._crypt_for_app(app)
|
|
return crypt.decrypt(msg_signature, timestamp, nonce, encrypt).decode("utf-8")
|
|
|
|
def _build_event(self, app: Dict[str, Any], xml_text: str) -> Optional[MessageEvent]:
|
|
root = ET.fromstring(xml_text)
|
|
msg_type = (root.findtext("MsgType") or "").lower()
|
|
# Silently acknowledge lifecycle events.
|
|
if msg_type == "event":
|
|
event_name = (root.findtext("Event") or "").lower()
|
|
if event_name in {"enter_agent", "subscribe"}:
|
|
return None
|
|
if msg_type not in {"text", "event"}:
|
|
return None
|
|
|
|
user_id = root.findtext("FromUserName", default="")
|
|
corp_id = root.findtext("ToUserName", default=app.get("corp_id", ""))
|
|
scoped_chat_id = self._user_app_key(corp_id, user_id)
|
|
content = root.findtext("Content", default="").strip()
|
|
if not content and msg_type == "event":
|
|
content = "/start"
|
|
msg_id = (
|
|
root.findtext("MsgId")
|
|
or f"{user_id}:{root.findtext('CreateTime', default='0')}"
|
|
)
|
|
source = self.build_source(
|
|
chat_id=scoped_chat_id,
|
|
chat_name=user_id,
|
|
chat_type="dm",
|
|
user_id=user_id,
|
|
user_name=user_id,
|
|
)
|
|
return MessageEvent(
|
|
text=content,
|
|
message_type=MessageType.TEXT,
|
|
source=source,
|
|
raw_message=xml_text,
|
|
message_id=msg_id,
|
|
)
|
|
|
|
def _crypt_for_app(self, app: Dict[str, Any]) -> WXBizMsgCrypt:
|
|
return WXBizMsgCrypt(
|
|
token=str(app.get("token") or ""),
|
|
encoding_aes_key=str(app.get("encoding_aes_key") or ""),
|
|
receive_id=str(app.get("corp_id") or ""),
|
|
)
|
|
|
|
def _get_app_by_name(self, name: Optional[str]) -> Optional[Dict[str, Any]]:
|
|
if not name:
|
|
return None
|
|
for app in self._apps:
|
|
if app.get("name") == name:
|
|
return app
|
|
return None
|
|
|
|
# ------------------------------------------------------------------
|
|
# Access-token management
|
|
# ------------------------------------------------------------------
|
|
|
|
async def _get_access_token(self, app: Dict[str, Any]) -> str:
|
|
cached = self._access_tokens.get(app["name"])
|
|
now = time.time()
|
|
if cached and cached.get("expires_at", 0) > now + 60:
|
|
return cached["token"]
|
|
return await self._refresh_access_token(app)
|
|
|
|
async def _refresh_access_token(self, app: Dict[str, Any]) -> str:
|
|
resp = await self._http_client.get(
|
|
"https://qyapi.weixin.qq.com/cgi-bin/gettoken",
|
|
params={
|
|
"corpid": app.get("corp_id"),
|
|
"corpsecret": app.get("corp_secret"),
|
|
},
|
|
)
|
|
data = resp.json()
|
|
if data.get("errcode") != 0:
|
|
raise RuntimeError(f"WeCom token refresh failed: {data}")
|
|
token = data["access_token"]
|
|
expires_in = int(data.get("expires_in", ACCESS_TOKEN_TTL_SECONDS))
|
|
self._access_tokens[app["name"]] = {
|
|
"token": token,
|
|
"expires_at": time.time() + expires_in,
|
|
}
|
|
logger.info(
|
|
"[WecomCallback] Token refreshed for app '%s' (corp=%s), expires in %ss",
|
|
app.get("name", "default"),
|
|
app.get("corp_id", ""),
|
|
expires_in,
|
|
)
|
|
return token
|