diff --git a/gateway/config.py b/gateway/config.py index ba0840bfc0c..5d3dfa9f59f 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -292,6 +292,18 @@ def load_gateway_config() -> GatewayConfig: sr = yaml_cfg.get("session_reset") if sr and isinstance(sr, dict): config.default_reset_policy = SessionResetPolicy.from_dict(sr) + + # Bridge discord settings from config.yaml to env vars + # (env vars take precedence — only set if not already defined) + discord_cfg = yaml_cfg.get("discord", {}) + if isinstance(discord_cfg, dict): + if "require_mention" in discord_cfg and not os.getenv("DISCORD_REQUIRE_MENTION"): + os.environ["DISCORD_REQUIRE_MENTION"] = str(discord_cfg["require_mention"]).lower() + frc = discord_cfg.get("free_response_channels") + if frc is not None and not os.getenv("DISCORD_FREE_RESPONSE_CHANNELS"): + if isinstance(frc, list): + frc = ",".join(str(v) for v in frc) + os.environ["DISCORD_FREE_RESPONSE_CHANNELS"] = str(frc) except Exception: pass diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index 04607ab0779..c7ae2ada5db 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -775,6 +775,46 @@ class DiscordAdapter(BasePlatformAdapter): except Exception as e: return SendResult(success=False, error=str(e)) + def _get_parent_channel_id(self, channel: Any) -> Optional[str]: + """Return the parent channel ID for a Discord thread-like channel, if present.""" + parent = getattr(channel, "parent", None) + if parent is not None and getattr(parent, "id", None) is not None: + return str(parent.id) + parent_id = getattr(channel, "parent_id", None) + if parent_id is not None: + return str(parent_id) + return None + + def _is_forum_parent(self, channel: Any) -> bool: + """Best-effort check for whether a Discord channel is a forum channel.""" + if channel is None: + return False + forum_cls = getattr(discord, "ForumChannel", None) + if forum_cls and isinstance(channel, forum_cls): + return True + channel_type = getattr(channel, "type", None) + if channel_type is not None: + type_value = getattr(channel_type, "value", channel_type) + if type_value == 15: + return True + return False + + def _format_thread_chat_name(self, thread: Any) -> str: + """Build a readable chat name for thread-like Discord channels, including forum context when available.""" + thread_name = getattr(thread, "name", None) or str(getattr(thread, "id", "thread")) + parent = getattr(thread, "parent", None) + guild = getattr(thread, "guild", None) or getattr(parent, "guild", None) + guild_name = getattr(guild, "name", None) + parent_name = getattr(parent, "name", None) + + if self._is_forum_parent(parent) and guild_name and parent_name: + return f"{guild_name} / {parent_name} / {thread_name}" + if parent_name and guild_name: + return f"{guild_name} / #{parent_name} / {thread_name}" + if parent_name: + return f"{parent_name} / {thread_name}" + return thread_name + async def _handle_message(self, message: DiscordMessage) -> None: """Handle incoming Discord messages.""" # In server channels (not DMs), require the bot to be @mentioned @@ -785,28 +825,33 @@ class DiscordAdapter(BasePlatformAdapter): # bot responds to every message without needing a mention. # DISCORD_REQUIRE_MENTION: Set to "false" to disable mention requirement # globally (all channels become free-response). Default: "true". - + # Can also be set via discord.require_mention in config.yaml. + + thread_id = None + parent_channel_id = None + is_thread = isinstance(message.channel, discord.Thread) + if is_thread: + thread_id = str(message.channel.id) + parent_channel_id = self._get_parent_channel_id(message.channel) + if not isinstance(message.channel, discord.DMChannel): - # Check if this channel is in the free-response list free_channels_raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "") free_channels = {ch.strip() for ch in free_channels_raw.split(",") if ch.strip()} - channel_id = str(message.channel.id) - - # Global override: if DISCORD_REQUIRE_MENTION=false, all channels are free + channel_ids = {str(message.channel.id)} + if parent_channel_id: + channel_ids.add(parent_channel_id) + require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no") - - is_free_channel = channel_id in free_channels - + is_free_channel = bool(channel_ids & free_channels) + if require_mention and not is_free_channel: - # Must be @mentioned to respond if self._client.user not in message.mentions: - return # Silently ignore messages that don't mention the bot - - # Strip the bot mention from the message text so the agent sees clean input + return + if self._client.user and self._client.user in message.mentions: message.content = message.content.replace(f"<@{self._client.user.id}>", "").strip() message.content = message.content.replace(f"<@!{self._client.user.id}>", "").strip() - + # Determine message type msg_type = MessageType.TEXT if message.content.startswith("/"): @@ -829,20 +874,15 @@ class DiscordAdapter(BasePlatformAdapter): if isinstance(message.channel, discord.DMChannel): chat_type = "dm" chat_name = message.author.name - elif isinstance(message.channel, discord.Thread): + elif is_thread: chat_type = "thread" - chat_name = message.channel.name + chat_name = self._format_thread_chat_name(message.channel) else: - chat_type = "group" # Treat server channels as groups + chat_type = "group" chat_name = getattr(message.channel, "name", str(message.channel.id)) if hasattr(message.channel, "guild") and message.channel.guild: chat_name = f"{message.channel.guild.name} / #{chat_name}" - - # Get thread ID if in a thread - thread_id = None - if isinstance(message.channel, discord.Thread): - thread_id = str(message.channel.id) - + # Get channel topic (if available - TextChannels have topics, DMs/threads don't) chat_topic = getattr(message.channel, "topic", None) diff --git a/hermes_cli/config.py b/hermes_cli/config.py index f2b5d42c188..0094b94b5e4 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -208,6 +208,12 @@ DEFAULT_CONFIG = { # Empty string means use server-local time. "timezone": "", + # Discord platform settings (gateway mode) + "discord": { + "require_mention": True, # Require @mention to respond in server channels + "free_response_channels": "", # Comma-separated channel IDs where bot responds without mention + }, + # Permanently allowed dangerous command patterns (added via "always" approval) "command_allowlist": [], # User-defined quick commands that bypass the agent loop (type: exec only) diff --git a/tests/gateway/test_discord_free_response.py b/tests/gateway/test_discord_free_response.py new file mode 100644 index 00000000000..fd9eacab253 --- /dev/null +++ b/tests/gateway/test_discord_free_response.py @@ -0,0 +1,249 @@ +"""Tests for Discord free-response defaults and mention gating.""" + +from datetime import datetime, timezone +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock +import sys + +import pytest + +from gateway.config import PlatformConfig + + +def _ensure_discord_mock(): + """Install a mock discord module when discord.py isn't available.""" + if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"): + return + + discord_mod = MagicMock() + discord_mod.Intents.default.return_value = MagicMock() + discord_mod.Client = MagicMock + discord_mod.File = MagicMock + discord_mod.DMChannel = type("DMChannel", (), {}) + discord_mod.Thread = type("Thread", (), {}) + discord_mod.ForumChannel = type("ForumChannel", (), {}) + discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object) + discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3) + discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4) + discord_mod.Interaction = object + discord_mod.Embed = MagicMock + + ext_mod = MagicMock() + commands_mod = MagicMock() + commands_mod.Bot = MagicMock + ext_mod.commands = commands_mod + + sys.modules.setdefault("discord", discord_mod) + sys.modules.setdefault("discord.ext", ext_mod) + sys.modules.setdefault("discord.ext.commands", commands_mod) + + +_ensure_discord_mock() + +import gateway.platforms.discord as discord_platform # noqa: E402 +from gateway.platforms.discord import DiscordAdapter # noqa: E402 + + +class FakeDMChannel: + def __init__(self, channel_id: int = 1, name: str = "dm"): + self.id = channel_id + self.name = name + + +class FakeTextChannel: + def __init__(self, channel_id: int = 1, name: str = "general", guild_name: str = "Hermes Server"): + self.id = channel_id + self.name = name + self.guild = SimpleNamespace(name=guild_name) + self.topic = None + + +class FakeForumChannel: + def __init__(self, channel_id: int = 1, name: str = "support-forum", guild_name: str = "Hermes Server"): + self.id = channel_id + self.name = name + self.guild = SimpleNamespace(name=guild_name) + self.type = 15 + self.topic = None + + +class FakeThread: + def __init__(self, channel_id: int = 1, name: str = "thread", parent=None, guild_name: str = "Hermes Server"): + self.id = channel_id + self.name = name + self.parent = parent + self.parent_id = getattr(parent, "id", None) + self.guild = getattr(parent, "guild", None) or SimpleNamespace(name=guild_name) + self.topic = None + + +@pytest.fixture +def adapter(monkeypatch): + monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False) + monkeypatch.setattr(discord_platform.discord, "Thread", FakeThread, raising=False) + monkeypatch.setattr(discord_platform.discord, "ForumChannel", FakeForumChannel, raising=False) + + config = PlatformConfig(enabled=True, token="fake-token") + adapter = DiscordAdapter(config) + adapter._client = SimpleNamespace(user=SimpleNamespace(id=999)) + adapter.handle_message = AsyncMock() + return adapter + + +def make_message(*, channel, content: str, mentions=None): + author = SimpleNamespace(id=42, display_name="Jezza", name="Jezza") + return SimpleNamespace( + id=123, + content=content, + mentions=list(mentions or []), + attachments=[], + reference=None, + created_at=datetime.now(timezone.utc), + channel=channel, + author=author, + ) + + +@pytest.mark.asyncio +async def test_discord_defaults_to_require_mention(adapter, monkeypatch): + """Default behavior: require @mention in server channels.""" + monkeypatch.delenv("DISCORD_REQUIRE_MENTION", raising=False) + monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False) + + message = make_message(channel=FakeTextChannel(channel_id=123), content="hello from channel") + + await adapter._handle_message(message) + + # Should be ignored — no mention, require_mention defaults to true + adapter.handle_message.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_discord_free_response_in_server_channels(adapter, monkeypatch): + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false") + monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False) + + message = make_message(channel=FakeTextChannel(channel_id=123), content="hello from channel") + + await adapter._handle_message(message) + + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.await_args.args[0] + assert event.text == "hello from channel" + assert event.source.chat_id == "123" + assert event.source.chat_type == "group" + + +@pytest.mark.asyncio +async def test_discord_free_response_in_threads(adapter, monkeypatch): + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false") + monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False) + + thread = FakeThread(channel_id=456, name="Ghost reader skill") + message = make_message(channel=thread, content="hello from thread") + + await adapter._handle_message(message) + + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.await_args.args[0] + assert event.text == "hello from thread" + assert event.source.chat_id == "456" + assert event.source.thread_id == "456" + assert event.source.chat_type == "thread" + + +@pytest.mark.asyncio +async def test_discord_forum_threads_are_handled_as_threads(adapter, monkeypatch): + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false") + monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False) + + forum = FakeForumChannel(channel_id=222, name="support-forum") + thread = FakeThread(channel_id=456, name="Can Hermes reply here?", parent=forum) + message = make_message(channel=thread, content="hello from forum post") + + await adapter._handle_message(message) + + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.await_args.args[0] + assert event.text == "hello from forum post" + assert event.source.chat_id == "456" + assert event.source.thread_id == "456" + assert event.source.chat_type == "thread" + assert event.source.chat_name == "Hermes Server / support-forum / Can Hermes reply here?" + + +@pytest.mark.asyncio +async def test_discord_can_still_require_mentions_when_enabled(adapter, monkeypatch): + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true") + monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False) + + message = make_message(channel=FakeTextChannel(channel_id=789), content="ignored without mention") + + await adapter._handle_message(message) + + adapter.handle_message.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_discord_free_response_channel_overrides_mention_requirement(adapter, monkeypatch): + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true") + monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "789,999") + + message = make_message(channel=FakeTextChannel(channel_id=789), content="allowed without mention") + + await adapter._handle_message(message) + + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.await_args.args[0] + assert event.text == "allowed without mention" + + +@pytest.mark.asyncio +async def test_discord_forum_parent_in_free_response_list_allows_forum_thread(adapter, monkeypatch): + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true") + monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "222") + + forum = FakeForumChannel(channel_id=222, name="support-forum") + thread = FakeThread(channel_id=333, name="Forum topic", parent=forum) + message = make_message(channel=thread, content="allowed from forum thread") + + await adapter._handle_message(message) + + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.await_args.args[0] + assert event.text == "allowed from forum thread" + assert event.source.chat_id == "333" + + +@pytest.mark.asyncio +async def test_discord_accepts_and_strips_bot_mentions_when_required(adapter, monkeypatch): + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true") + monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False) + + bot_user = adapter._client.user + message = make_message( + channel=FakeTextChannel(channel_id=321), + content=f"<@{bot_user.id}> hello with mention", + mentions=[bot_user], + ) + + await adapter._handle_message(message) + + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.await_args.args[0] + assert event.text == "hello with mention" + + +@pytest.mark.asyncio +async def test_discord_dms_ignore_mention_requirement(adapter, monkeypatch): + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true") + monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False) + + message = make_message(channel=FakeDMChannel(channel_id=654), content="dm without mention") + + await adapter._handle_message(message) + + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.await_args.args[0] + assert event.text == "dm without mention" + assert event.source.chat_type == "dm"