"""Tests for Discord channel_prompts resolution and injection."""

import sys
import threading
import types
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest

import gateway.run as gateway_run
from gateway.config import Platform
from gateway.platforms.base import MessageEvent
from gateway.session import SessionSource


def _ensure_discord_mock():
    """Register a minimal stub ``discord`` package unless a real one is loaded.

    A module carrying ``__file__`` is treated as the real library and left
    untouched. Otherwise stub modules for ``discord``, ``discord.ext`` and
    ``discord.ext.commands`` are installed via ``sys.modules.setdefault``, so
    any entries already present are not overwritten.
    """
    if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
        return

    discord_mod = types.ModuleType("discord")
    discord_mod.Intents = MagicMock()
    discord_mod.Intents.default.return_value = MagicMock()
    # Real (empty) classes rather than mocks, presumably so that isinstance()
    # checks in the adapter keep working -- NOTE(review): confirm.
    discord_mod.DMChannel = type("DMChannel", (), {})
    discord_mod.Thread = type("Thread", (), {})
    discord_mod.ForumChannel = type("ForumChannel", (), {})
    discord_mod.Interaction = object

    ext_mod = MagicMock()
    commands_mod = MagicMock()
    commands_mod.Bot = MagicMock
    ext_mod.commands = commands_mod

    sys.modules.setdefault("discord", discord_mod)
    sys.modules.setdefault("discord.ext", ext_mod)
    sys.modules.setdefault("discord.ext.commands", commands_mod)


class _CapturingAgent:
    """Fake ``run_agent.AIAgent`` that records its constructor keyword args."""

    # Keyword arguments of the most recent construction (class-level state).
    last_init = None

    def __init__(self, *args, **kwargs):
        type(self).last_init = dict(kwargs)
        self.tools = []

    def run_conversation(
        self,
        user_message,
        conversation_history=None,
        task_id=None,
        persist_user_message=None,
    ):
        """Return a fixed, successful single-call conversation result."""
        return {
            "final_response": "ok",
            "messages": [],
            "api_calls": 1,
            "completed": True,
        }


def _install_fake_agent(monkeypatch):
    """Expose ``_CapturingAgent`` as ``run_agent.AIAgent`` for this test only."""
    fake_run_agent = types.ModuleType("run_agent")
    fake_run_agent.AIAgent = _CapturingAgent
    monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)


def _make_adapter():
    """Build a ``DiscordAdapter`` without running ``__init__``.

    The returned adapter has a mocked ``config`` whose ``extra`` dict starts
    empty; tests populate ``channel_prompts`` on it as needed.
    """
    _ensure_discord_mock()
    from gateway.platforms.discord import DiscordAdapter

    adapter = object.__new__(DiscordAdapter)
    adapter.config = MagicMock()
    adapter.config.extra = {}
    return adapter


def _make_runner():
    """Build a ``GatewayRunner`` without ``__init__``, stubbing the state tests touch."""
    runner = object.__new__(gateway_run.GatewayRunner)
    runner.adapters = {}
    runner._ephemeral_system_prompt = "Global prompt"
    runner._prefill_messages = []
    runner._reasoning_config = None
    runner._service_tier = None
    runner._provider_routing = {}
    runner._fallback_model = None
    runner._running_agents = {}
    runner._pending_model_notes = {}
    runner._session_db = None
    runner._agent_cache = {}
    runner._agent_cache_lock = threading.Lock()
    runner._session_model_overrides = {}
    runner.hooks = SimpleNamespace(loaded_hooks=False)
    runner.config = SimpleNamespace(streaming=None)
    runner.session_store = SimpleNamespace(
        get_or_create_session=lambda source: SimpleNamespace(session_id="session-1"),
        load_transcript=lambda session_id: [],
    )
    runner._get_or_create_gateway_honcho = lambda session_key: (None, None)
    runner._enrich_message_with_vision = AsyncMock(return_value="ENRICHED")
    return runner


def _make_source() -> SessionSource:
    """Return a Discord thread ``SessionSource`` for chat ``12345``."""
    return SessionSource(
        platform=Platform.DISCORD,
        chat_id="12345",
        chat_type="thread",
        user_id="user-1",
    )


class TestResolveChannelPrompts:
    """Behaviour of ``DiscordAdapter._resolve_channel_prompt`` and its callers."""

    def test_no_prompt_returns_none(self):
        adapter = _make_adapter()
        assert adapter._resolve_channel_prompt("123") is None

    def test_match_by_channel_id(self):
        adapter = _make_adapter()
        adapter.config.extra = {"channel_prompts": {"100": "Research mode"}}
        assert adapter._resolve_channel_prompt("100") == "Research mode"

    def test_numeric_yaml_keys_normalized_at_config_load(self):
        """Numeric YAML keys are normalized to strings by config bridging.

        The resolver itself expects string keys (config.py handles
        normalization), so raw numeric keys will not match — this is
        intentional.
        """
        adapter = _make_adapter()
        # Simulates post-bridging state: keys are already strings
        adapter.config.extra = {"channel_prompts": {"100": "Research mode"}}
        assert adapter._resolve_channel_prompt("100") == "Research mode"
        # Pre-bridging numeric key would not match (bridging is responsible)
        adapter.config.extra = {"channel_prompts": {100: "Research mode"}}
        assert adapter._resolve_channel_prompt("100") is None

    def test_match_by_parent_id(self):
        adapter = _make_adapter()
        adapter.config.extra = {"channel_prompts": {"200": "Forum prompt"}}
        assert adapter._resolve_channel_prompt("999", parent_id="200") == "Forum prompt"

    def test_exact_channel_overrides_parent(self):
        adapter = _make_adapter()
        adapter.config.extra = {
            "channel_prompts": {
                "999": "Thread override",
                "200": "Forum prompt",
            }
        }
        assert adapter._resolve_channel_prompt("999", parent_id="200") == "Thread override"

    def test_build_message_event_sets_channel_prompt(self):
        adapter = _make_adapter()
        adapter.config.extra = {"channel_prompts": {"321": "Command prompt"}}
        adapter.build_source = MagicMock(return_value=SimpleNamespace())
        interaction = SimpleNamespace(
            channel_id=321,
            channel=SimpleNamespace(name="general", guild=None, parent_id=None),
            user=SimpleNamespace(id=1, display_name="Brenner"),
        )
        adapter._get_effective_topic = MagicMock(return_value=None)

        event = adapter._build_slash_event(interaction, "/retry")

        assert event.channel_prompt == "Command prompt"

    @pytest.mark.asyncio
    async def test_dispatch_thread_session_inherits_parent_channel_prompt(self):
        """A new thread ("999") picks up the prompt configured on its parent ("200")."""
        adapter = _make_adapter()
        adapter.config.extra = {"channel_prompts": {"200": "Parent prompt"}}
        adapter.build_source = MagicMock(return_value=SimpleNamespace())
        adapter._get_effective_topic = MagicMock(return_value=None)
        adapter.handle_message = AsyncMock()
        interaction = SimpleNamespace(
            guild=SimpleNamespace(name="Wetlands"),
            channel=SimpleNamespace(id=200, parent=None),
            user=SimpleNamespace(id=1, display_name="Brenner"),
        )

        await adapter._dispatch_thread_session(interaction, "999", "new-thread", "hello")

        dispatched_event = adapter.handle_message.await_args.args[0]
        assert dispatched_event.channel_prompt == "Parent prompt"

    def test_blank_prompts_are_ignored(self):
        adapter = _make_adapter()
        adapter.config.extra = {"channel_prompts": {"100": " "}}
        assert adapter._resolve_channel_prompt("100") is None


@pytest.mark.asyncio
async def test_retry_preserves_channel_prompt():
    """``/retry`` re-dispatches the last user message with the same channel prompt."""
    runner = _make_runner()
    runner.session_store = SimpleNamespace(
        get_or_create_session=lambda source: SimpleNamespace(session_id="session-1", last_prompt_tokens=10),
        load_transcript=lambda session_id: [
            {"role": "user", "content": "original message"},
            {"role": "assistant", "content": "old reply"},
        ],
        rewrite_transcript=MagicMock(),
    )
    runner._handle_message = AsyncMock(return_value="ok")

    event = MessageEvent(
        text="/retry",
        message_type=gateway_run.MessageType.COMMAND,
        source=_make_source(),
        raw_message=SimpleNamespace(),
        channel_prompt="Channel prompt",
    )

    result = await runner._handle_retry_command(event)

    assert result == "ok"
    retried_event = runner._handle_message.await_args.args[0]
    assert retried_event.channel_prompt == "Channel prompt"


@pytest.mark.asyncio
async def test_run_agent_appends_channel_prompt_to_ephemeral_system_prompt(monkeypatch, tmp_path):
    """The agent's ephemeral system prompt is context, then channel, then global prompt."""
    _install_fake_agent(monkeypatch)
    runner = _make_runner()

    (tmp_path / "config.yaml").write_text("agent:\n system_prompt: Global prompt\n", encoding="utf-8")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_env_path", tmp_path / ".env")
    monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
    monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {})
    monkeypatch.setattr(gateway_run, "_resolve_gateway_model", lambda config=None: "gpt-5.4")
    monkeypatch.setattr(
        gateway_run,
        "_resolve_runtime_agent_kwargs",
        lambda: {
            "provider": "openrouter",
            "api_mode": "chat_completions",
            "base_url": "https://openrouter.ai/api/v1",
            "api_key": "***",
        },
    )

    import hermes_cli.tools_config as tools_config

    monkeypatch.setattr(tools_config, "_get_platform_tools", lambda user_config, platform_key: {"core"})

    # Reset via monkeypatch so the value _CapturingAgent.__init__ writes onto the
    # class is rolled back at teardown instead of leaking into later tests.
    monkeypatch.setattr(_CapturingAgent, "last_init", None)

    event = MessageEvent(
        text="hi",
        source=_make_source(),
        message_id="m1",
        channel_prompt="Channel prompt",
    )

    result = await runner._run_agent(
        message="hi",
        context_prompt="Context prompt",
        history=[],
        source=_make_source(),
        session_id="session-1",
        session_key="agent:main:discord:thread:12345",
        channel_prompt=event.channel_prompt,
    )

    assert result["final_response"] == "ok"
    assert _CapturingAgent.last_init["ephemeral_system_prompt"] == (
        "Context prompt\n\nChannel prompt\n\nGlobal prompt"
    )