mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
fix(whatsapp): enforce require_mention in group chats
This commit is contained in:
@@ -563,6 +563,14 @@ def load_gateway_config() -> GatewayConfig:
|
||||
if isinstance(frc, list):
|
||||
frc = ",".join(str(v) for v in frc)
|
||||
os.environ["TELEGRAM_FREE_RESPONSE_CHATS"] = str(frc)
|
||||
|
||||
whatsapp_cfg = yaml_cfg.get("whatsapp", {})
|
||||
if isinstance(whatsapp_cfg, dict):
|
||||
if "require_mention" in whatsapp_cfg and not os.getenv("WHATSAPP_REQUIRE_MENTION"):
|
||||
os.environ["WHATSAPP_REQUIRE_MENTION"] = str(whatsapp_cfg["require_mention"]).lower()
|
||||
if "mention_patterns" in whatsapp_cfg and not os.getenv("WHATSAPP_MENTION_PATTERNS"):
|
||||
import json as _json
|
||||
os.environ["WHATSAPP_MENTION_PATTERNS"] = _json.dumps(whatsapp_cfg["mention_patterns"])
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to process config.yaml — falling back to .env / gateway.json values. "
|
||||
|
||||
@@ -16,9 +16,11 @@ with different backends via a bridge pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
@@ -138,12 +140,113 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
get_hermes_dir("platforms/whatsapp/session", "whatsapp/session")
|
||||
))
|
||||
self._reply_prefix: Optional[str] = config.extra.get("reply_prefix")
|
||||
self._mention_patterns = self._compile_mention_patterns()
|
||||
self._message_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._bridge_log_fh = None
|
||||
self._bridge_log: Optional[Path] = None
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
self._http_session: Optional["aiohttp.ClientSession"] = None
|
||||
self._session_lock_identity: Optional[str] = None
|
||||
|
||||
def _whatsapp_require_mention(self) -> bool:
|
||||
configured = self.config.extra.get("require_mention")
|
||||
if configured is not None:
|
||||
if isinstance(configured, str):
|
||||
return configured.lower() in ("true", "1", "yes", "on")
|
||||
return bool(configured)
|
||||
return os.getenv("WHATSAPP_REQUIRE_MENTION", "false").lower() in ("true", "1", "yes", "on")
|
||||
|
||||
def _compile_mention_patterns(self):
|
||||
patterns = self.config.extra.get("mention_patterns")
|
||||
if patterns is None:
|
||||
raw = os.getenv("WHATSAPP_MENTION_PATTERNS", "").strip()
|
||||
if raw:
|
||||
try:
|
||||
patterns = json.loads(raw)
|
||||
except Exception:
|
||||
patterns = [part.strip() for part in raw.splitlines() if part.strip()]
|
||||
if not patterns:
|
||||
patterns = [part.strip() for part in raw.split(",") if part.strip()]
|
||||
if patterns is None:
|
||||
return []
|
||||
if isinstance(patterns, str):
|
||||
patterns = [patterns]
|
||||
if not isinstance(patterns, list):
|
||||
logger.warning("[%s] whatsapp mention_patterns must be a list or string; got %s", self.name, type(patterns).__name__)
|
||||
return []
|
||||
|
||||
compiled = []
|
||||
for pattern in patterns:
|
||||
if not isinstance(pattern, str) or not pattern.strip():
|
||||
continue
|
||||
try:
|
||||
compiled.append(re.compile(pattern, re.IGNORECASE))
|
||||
except re.error as exc:
|
||||
logger.warning("[%s] Invalid WhatsApp mention pattern %r: %s", self.name, pattern, exc)
|
||||
return compiled
|
||||
|
||||
@staticmethod
|
||||
def _normalize_whatsapp_id(value: Optional[str]) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
normalized = str(value).strip()
|
||||
if ":" in normalized and "@" in normalized:
|
||||
normalized = normalized.replace(":", "@", 1)
|
||||
return normalized
|
||||
|
||||
def _bot_ids_from_message(self, data: Dict[str, Any]) -> set[str]:
|
||||
bot_ids = set()
|
||||
for candidate in data.get("botIds") or []:
|
||||
normalized = self._normalize_whatsapp_id(candidate)
|
||||
if normalized:
|
||||
bot_ids.add(normalized)
|
||||
return bot_ids
|
||||
|
||||
def _message_is_reply_to_bot(self, data: Dict[str, Any]) -> bool:
|
||||
quoted_participant = self._normalize_whatsapp_id(data.get("quotedParticipant"))
|
||||
if not quoted_participant:
|
||||
return False
|
||||
return quoted_participant in self._bot_ids_from_message(data)
|
||||
|
||||
def _message_mentions_bot(self, data: Dict[str, Any]) -> bool:
|
||||
bot_ids = self._bot_ids_from_message(data)
|
||||
if not bot_ids:
|
||||
return False
|
||||
mentioned_ids = {
|
||||
self._normalize_whatsapp_id(candidate)
|
||||
for candidate in (data.get("mentionedIds") or [])
|
||||
if self._normalize_whatsapp_id(candidate)
|
||||
}
|
||||
if mentioned_ids & bot_ids:
|
||||
return True
|
||||
|
||||
body = str(data.get("body") or "")
|
||||
lower_body = body.lower()
|
||||
for bot_id in bot_ids:
|
||||
bare_id = bot_id.split("@", 1)[0].lower()
|
||||
if bare_id and (f"@{bare_id}" in lower_body or bare_id in lower_body):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _message_matches_mention_patterns(self, data: Dict[str, Any]) -> bool:
|
||||
if not self._mention_patterns:
|
||||
return False
|
||||
body = str(data.get("body") or "")
|
||||
return any(pattern.search(body) for pattern in self._mention_patterns)
|
||||
|
||||
def _should_process_message(self, data: Dict[str, Any]) -> bool:
|
||||
if not data.get("isGroup"):
|
||||
return True
|
||||
if not self._whatsapp_require_mention():
|
||||
return True
|
||||
body = str(data.get("body") or "").strip()
|
||||
if body.startswith("/"):
|
||||
return True
|
||||
if self._message_is_reply_to_bot(data):
|
||||
return True
|
||||
if self._message_mentions_bot(data):
|
||||
return True
|
||||
return self._message_matches_mention_patterns(data)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
@@ -687,6 +790,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
async def _build_message_event(self, data: Dict[str, Any]) -> Optional[MessageEvent]:
|
||||
"""Build a MessageEvent from bridge message data, downloading images to cache."""
|
||||
try:
|
||||
if not self._should_process_message(data):
|
||||
return None
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if data.get("hasMedia"):
|
||||
|
||||
@@ -62,6 +62,30 @@ function formatOutgoingMessage(message) {
|
||||
return REPLY_PREFIX ? `${REPLY_PREFIX}${message}` : message;
|
||||
}
|
||||
|
||||
function normalizeWhatsAppId(value) {
|
||||
if (!value) return '';
|
||||
return String(value).replace(':', '@');
|
||||
}
|
||||
|
||||
function getMessageContent(msg) {
|
||||
const content = msg?.message || {};
|
||||
if (content.ephemeralMessage?.message) return content.ephemeralMessage.message;
|
||||
if (content.viewOnceMessage?.message) return content.viewOnceMessage.message;
|
||||
if (content.viewOnceMessageV2?.message) return content.viewOnceMessageV2.message;
|
||||
if (content.documentWithCaptionMessage?.message) return content.documentWithCaptionMessage.message;
|
||||
return content;
|
||||
}
|
||||
|
||||
function getContextInfo(messageContent) {
|
||||
if (!messageContent || typeof messageContent !== 'object') return {};
|
||||
for (const value of Object.values(messageContent)) {
|
||||
if (value && typeof value === 'object' && value.contextInfo) {
|
||||
return value.contextInfo;
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
mkdirSync(SESSION_DIR, { recursive: true });
|
||||
|
||||
// Build LID → phone reverse map from session files (lid-mapping-{phone}.json)
|
||||
@@ -200,23 +224,32 @@ async function startSocket() {
|
||||
continue;
|
||||
}
|
||||
|
||||
const messageContent = getMessageContent(msg);
|
||||
const contextInfo = getContextInfo(messageContent);
|
||||
const mentionedIds = Array.from(new Set((contextInfo?.mentionedJid || []).map(normalizeWhatsAppId).filter(Boolean)));
|
||||
const quotedParticipant = normalizeWhatsAppId(contextInfo?.participant || contextInfo?.remoteJid || '');
|
||||
const botIds = Array.from(new Set([
|
||||
normalizeWhatsAppId(sock.user?.id),
|
||||
normalizeWhatsAppId(sock.user?.lid),
|
||||
].filter(Boolean)));
|
||||
|
||||
// Extract message body
|
||||
let body = '';
|
||||
let hasMedia = false;
|
||||
let mediaType = '';
|
||||
const mediaUrls = [];
|
||||
|
||||
if (msg.message.conversation) {
|
||||
body = msg.message.conversation;
|
||||
} else if (msg.message.extendedTextMessage?.text) {
|
||||
body = msg.message.extendedTextMessage.text;
|
||||
} else if (msg.message.imageMessage) {
|
||||
body = msg.message.imageMessage.caption || '';
|
||||
if (messageContent.conversation) {
|
||||
body = messageContent.conversation;
|
||||
} else if (messageContent.extendedTextMessage?.text) {
|
||||
body = messageContent.extendedTextMessage.text;
|
||||
} else if (messageContent.imageMessage) {
|
||||
body = messageContent.imageMessage.caption || '';
|
||||
hasMedia = true;
|
||||
mediaType = 'image';
|
||||
try {
|
||||
const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage });
|
||||
const mime = msg.message.imageMessage.mimetype || 'image/jpeg';
|
||||
const mime = messageContent.imageMessage.mimetype || 'image/jpeg';
|
||||
const extMap = { 'image/jpeg': '.jpg', 'image/png': '.png', 'image/webp': '.webp', 'image/gif': '.gif' };
|
||||
const ext = extMap[mime] || '.jpg';
|
||||
mkdirSync(IMAGE_CACHE_DIR, { recursive: true });
|
||||
@@ -226,13 +259,13 @@ async function startSocket() {
|
||||
} catch (err) {
|
||||
console.error('[bridge] Failed to download image:', err.message);
|
||||
}
|
||||
} else if (msg.message.videoMessage) {
|
||||
body = msg.message.videoMessage.caption || '';
|
||||
} else if (messageContent.videoMessage) {
|
||||
body = messageContent.videoMessage.caption || '';
|
||||
hasMedia = true;
|
||||
mediaType = 'video';
|
||||
try {
|
||||
const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage });
|
||||
const mime = msg.message.videoMessage.mimetype || 'video/mp4';
|
||||
const mime = messageContent.videoMessage.mimetype || 'video/mp4';
|
||||
const ext = mime.includes('mp4') ? '.mp4' : '.mkv';
|
||||
mkdirSync(DOCUMENT_CACHE_DIR, { recursive: true });
|
||||
const filePath = path.join(DOCUMENT_CACHE_DIR, `vid_${randomBytes(6).toString('hex')}${ext}`);
|
||||
@@ -241,11 +274,11 @@ async function startSocket() {
|
||||
} catch (err) {
|
||||
console.error('[bridge] Failed to download video:', err.message);
|
||||
}
|
||||
} else if (msg.message.audioMessage || msg.message.pttMessage) {
|
||||
} else if (messageContent.audioMessage || messageContent.pttMessage) {
|
||||
hasMedia = true;
|
||||
mediaType = msg.message.pttMessage ? 'ptt' : 'audio';
|
||||
mediaType = messageContent.pttMessage ? 'ptt' : 'audio';
|
||||
try {
|
||||
const audioMsg = msg.message.pttMessage || msg.message.audioMessage;
|
||||
const audioMsg = messageContent.pttMessage || messageContent.audioMessage;
|
||||
const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage });
|
||||
const mime = audioMsg.mimetype || 'audio/ogg';
|
||||
const ext = mime.includes('ogg') ? '.ogg' : mime.includes('mp4') ? '.m4a' : '.ogg';
|
||||
@@ -256,11 +289,11 @@ async function startSocket() {
|
||||
} catch (err) {
|
||||
console.error('[bridge] Failed to download audio:', err.message);
|
||||
}
|
||||
} else if (msg.message.documentMessage) {
|
||||
body = msg.message.documentMessage.caption || '';
|
||||
} else if (messageContent.documentMessage) {
|
||||
body = messageContent.documentMessage.caption || '';
|
||||
hasMedia = true;
|
||||
mediaType = 'document';
|
||||
const fileName = msg.message.documentMessage.fileName || 'document';
|
||||
const fileName = messageContent.documentMessage.fileName || 'document';
|
||||
try {
|
||||
const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage });
|
||||
mkdirSync(DOCUMENT_CACHE_DIR, { recursive: true });
|
||||
@@ -309,6 +342,9 @@ async function startSocket() {
|
||||
hasMedia,
|
||||
mediaType,
|
||||
mediaUrls,
|
||||
mentionedIds,
|
||||
quotedParticipant,
|
||||
botIds,
|
||||
timestamp: msg.messageTimestamp,
|
||||
};
|
||||
|
||||
|
||||
97
tests/gateway/test_whatsapp_group_gating.py
Normal file
97
tests/gateway/test_whatsapp_group_gating.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from gateway.config import Platform, PlatformConfig, load_gateway_config
|
||||
|
||||
|
||||
def _make_adapter(require_mention=None, mention_patterns=None):
|
||||
from gateway.platforms.whatsapp import WhatsAppAdapter
|
||||
|
||||
extra = {}
|
||||
if require_mention is not None:
|
||||
extra["require_mention"] = require_mention
|
||||
if mention_patterns is not None:
|
||||
extra["mention_patterns"] = mention_patterns
|
||||
|
||||
adapter = object.__new__(WhatsAppAdapter)
|
||||
adapter.platform = Platform.WHATSAPP
|
||||
adapter.config = PlatformConfig(enabled=True, extra=extra)
|
||||
adapter._message_handler = AsyncMock()
|
||||
adapter._mention_patterns = adapter._compile_mention_patterns()
|
||||
return adapter
|
||||
|
||||
|
||||
def _group_message(body="hello", **overrides):
|
||||
data = {
|
||||
"isGroup": True,
|
||||
"body": body,
|
||||
"mentionedIds": [],
|
||||
"botIds": ["15551230000@s.whatsapp.net", "15551230000@lid"],
|
||||
"quotedParticipant": "",
|
||||
}
|
||||
data.update(overrides)
|
||||
return data
|
||||
|
||||
|
||||
def test_group_messages_can_be_opened_via_config():
|
||||
adapter = _make_adapter(require_mention=False)
|
||||
|
||||
assert adapter._should_process_message(_group_message("hello everyone")) is True
|
||||
|
||||
|
||||
def test_group_messages_can_require_direct_trigger_via_config():
|
||||
adapter = _make_adapter(require_mention=True)
|
||||
|
||||
assert adapter._should_process_message(_group_message("hello everyone")) is False
|
||||
assert adapter._should_process_message(
|
||||
_group_message(
|
||||
"hi there",
|
||||
mentionedIds=["15551230000@s.whatsapp.net"],
|
||||
)
|
||||
) is True
|
||||
assert adapter._should_process_message(
|
||||
_group_message(
|
||||
"replying",
|
||||
quotedParticipant="15551230000@lid",
|
||||
)
|
||||
) is True
|
||||
assert adapter._should_process_message(_group_message("/status")) is True
|
||||
|
||||
|
||||
def test_regex_mention_patterns_allow_custom_wake_words():
|
||||
adapter = _make_adapter(require_mention=True, mention_patterns=[r"^\s*chompy\b"])
|
||||
|
||||
assert adapter._should_process_message(_group_message("chompy status")) is True
|
||||
assert adapter._should_process_message(_group_message(" chompy help")) is True
|
||||
assert adapter._should_process_message(_group_message("hey chompy")) is False
|
||||
|
||||
|
||||
def test_invalid_regex_patterns_are_ignored():
|
||||
adapter = _make_adapter(require_mention=True, mention_patterns=[r"(", r"^\s*chompy\b"])
|
||||
|
||||
assert adapter._should_process_message(_group_message("chompy status")) is True
|
||||
assert adapter._should_process_message(_group_message("hello everyone")) is False
|
||||
|
||||
|
||||
def test_config_bridges_whatsapp_group_settings(monkeypatch, tmp_path):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
"whatsapp:\n"
|
||||
" require_mention: true\n"
|
||||
" mention_patterns:\n"
|
||||
" - \"^\\\\s*chompy\\\\b\"\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.delenv("WHATSAPP_REQUIRE_MENTION", raising=False)
|
||||
monkeypatch.delenv("WHATSAPP_MENTION_PATTERNS", raising=False)
|
||||
|
||||
config = load_gateway_config()
|
||||
|
||||
assert config is not None
|
||||
assert config.platforms[Platform.WHATSAPP].extra["require_mention"] is True
|
||||
assert config.platforms[Platform.WHATSAPP].extra["mention_patterns"] == [r"^\s*chompy\b"]
|
||||
assert __import__("os").environ["WHATSAPP_REQUIRE_MENTION"] == "true"
|
||||
assert json.loads(__import__("os").environ["WHATSAPP_MENTION_PATTERNS"]) == [r"^\s*chompy\b"]
|
||||
Reference in New Issue
Block a user