Files
hermes-agent/tests/gateway/test_discord_channel_prompts.py
Teknium 620c296b1d fix: discord mock setup and AUTHOR_MAP for channel_prompts tests
Move _ensure_discord_mock() from module level to _make_adapter() so it
doesn't poison sys.modules for other discord test files. Use
types.ModuleType instead of MagicMock for the mock module to avoid
auto-generated __file__ attribute confusing hasattr checks.

Add BrennerSpear to AUTHOR_MAP.
2026-04-15 16:31:28 -07:00

260 lines
9.2 KiB
Python

"""Tests for Discord channel_prompts resolution and injection."""
import sys
import threading
import types
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
def _ensure_discord_mock():
    """Register stand-in ``discord`` modules in ``sys.modules`` when needed.

    Does nothing if ``sys.modules`` already holds a ``discord`` entry that
    exposes ``__file__`` (presumably the real, installed package). Otherwise
    a bare ``types.ModuleType`` is populated with the handful of names set
    below and registered together with mocked ``discord.ext`` and
    ``discord.ext.commands`` entries. Registration goes through
    ``setdefault``, so entries that are already present are left untouched.
    """
    existing = sys.modules.get("discord")
    if existing is not None and hasattr(existing, "__file__"):
        return

    # A plain module object (not a MagicMock) so attribute probes such as
    # hasattr(module, "__file__") only see what is set explicitly here.
    stub = types.ModuleType("discord")
    intents = MagicMock()
    intents.default.return_value = MagicMock()
    stub.Intents = intents
    for class_name in ("DMChannel", "Thread", "ForumChannel"):
        setattr(stub, class_name, type(class_name, (), {}))
    stub.Interaction = object

    commands_stub = MagicMock()
    commands_stub.Bot = MagicMock
    ext_stub = MagicMock()
    ext_stub.commands = commands_stub

    registrations = (
        ("discord", stub),
        ("discord.ext", ext_stub),
        ("discord.ext.commands", commands_stub),
    )
    for module_name, module in registrations:
        sys.modules.setdefault(module_name, module)
import gateway.run as gateway_run
from gateway.config import Platform
from gateway.platforms.base import MessageEvent
from gateway.session import SessionSource
class _CapturingAgent:
    """Agent stand-in that remembers how it was constructed.

    The keyword arguments of the most recent construction are stored on the
    class itself so a test can inspect them without holding the instance.
    """

    # Copy of the kwargs from the latest __init__ call; None until one happens.
    last_init = None

    def __init__(self, *args, **kwargs):
        """Record ``kwargs`` on the class; positional arguments are ignored."""
        self.tools = []
        type(self).last_init = {**kwargs}

    def run_conversation(self, user_message, conversation_history=None, task_id=None, persist_user_message=None):
        """Return a fixed successful result regardless of the arguments."""
        canned_result = dict(
            final_response="ok",
            messages=[],
            api_calls=1,
            completed=True,
        )
        return canned_result
def _install_fake_agent(monkeypatch):
    """Place a stub ``run_agent`` module in ``sys.modules`` via ``monkeypatch``.

    The stub's ``AIAgent`` attribute is ``_CapturingAgent``. Because the entry
    is installed with ``monkeypatch.setitem``, pytest undoes it after the test.
    """
    stub_module = types.ModuleType("run_agent")
    setattr(stub_module, "AIAgent", _CapturingAgent)
    monkeypatch.setitem(sys.modules, "run_agent", stub_module)
def _make_adapter():
    """Return a ``DiscordAdapter`` created without running its ``__init__``.

    The discord stand-in is ensured first, and the adapter module is imported
    only afterwards, inside this function. The returned adapter has a mocked
    ``config`` whose ``extra`` is an empty dict.
    """
    _ensure_discord_mock()
    from gateway.platforms.discord import DiscordAdapter

    bare_adapter = object.__new__(DiscordAdapter)
    mocked_config = MagicMock()
    bare_adapter.config = mocked_config
    mocked_config.extra = {}
    return bare_adapter
def _make_runner():
    """Return a ``GatewayRunner`` created without ``__init__``, with attributes set by hand.

    The attributes below are presumably the ones the code paths under test
    read; verify against ``GatewayRunner`` when it changes.
    """

    def get_or_create_session(source):
        return SimpleNamespace(session_id="session-1")

    def load_transcript(session_id):
        return []

    initial_state = {
        "adapters": {},
        "_ephemeral_system_prompt": "Global prompt",
        "_prefill_messages": [],
        "_reasoning_config": None,
        "_service_tier": None,
        "_provider_routing": {},
        "_fallback_model": None,
        "_smart_model_routing": {},
        "_running_agents": {},
        "_pending_model_notes": {},
        "_session_db": None,
        "_agent_cache": {},
        "_agent_cache_lock": threading.Lock(),
        "_session_model_overrides": {},
        "hooks": SimpleNamespace(loaded_hooks=False),
        "config": SimpleNamespace(streaming=None),
        "session_store": SimpleNamespace(
            get_or_create_session=get_or_create_session,
            load_transcript=load_transcript,
        ),
        "_get_or_create_gateway_honcho": lambda session_key: (None, None),
        "_enrich_message_with_vision": AsyncMock(return_value="ENRICHED"),
    }

    gateway_runner = object.__new__(gateway_run.GatewayRunner)
    for attribute, value in initial_state.items():
        setattr(gateway_runner, attribute, value)
    return gateway_runner
def _make_source() -> SessionSource:
    """Return the Discord thread ``SessionSource`` shared by the runner tests."""
    source_fields = {
        "platform": Platform.DISCORD,
        "chat_id": "12345",
        "chat_type": "thread",
        "user_id": "user-1",
    }
    return SessionSource(**source_fields)
class TestResolveChannelPrompts:
    """Checks for ``DiscordAdapter._resolve_channel_prompt`` and the events that carry its result."""

    def test_no_prompt_returns_none(self):
        """With an empty ``config.extra`` nothing is resolved."""
        discord_adapter = _make_adapter()
        resolved = discord_adapter._resolve_channel_prompt("123")
        assert resolved is None

    def test_match_by_channel_id(self):
        """A prompt keyed by the channel's own id is returned."""
        discord_adapter = _make_adapter()
        prompts = {"100": "Research mode"}
        discord_adapter.config.extra = {"channel_prompts": prompts}
        resolved = discord_adapter._resolve_channel_prompt("100")
        assert resolved == "Research mode"

    def test_numeric_yaml_keys_normalized_at_config_load(self):
        """Numeric YAML keys are turned into strings by config bridging.

        The resolver only looks up string keys (normalization is config.py's
        job), so a raw numeric key is intentionally not matched.
        """
        discord_adapter = _make_adapter()

        # State after bridging: the key is already a string.
        bridged = {"channel_prompts": {"100": "Research mode"}}
        discord_adapter.config.extra = bridged
        assert discord_adapter._resolve_channel_prompt("100") == "Research mode"

        # State before bridging: the integer key must not match.
        unbridged = {"channel_prompts": {100: "Research mode"}}
        discord_adapter.config.extra = unbridged
        assert discord_adapter._resolve_channel_prompt("100") is None

    def test_match_by_parent_id(self):
        """A prompt keyed by ``parent_id`` is used when the channel id has none."""
        discord_adapter = _make_adapter()
        prompts = {"200": "Forum prompt"}
        discord_adapter.config.extra = {"channel_prompts": prompts}
        resolved = discord_adapter._resolve_channel_prompt("999", parent_id="200")
        assert resolved == "Forum prompt"

    def test_exact_channel_overrides_parent(self):
        """The channel's own prompt wins over the parent's prompt."""
        discord_adapter = _make_adapter()
        prompts = {
            "999": "Thread override",
            "200": "Forum prompt",
        }
        discord_adapter.config.extra = {"channel_prompts": prompts}
        resolved = discord_adapter._resolve_channel_prompt("999", parent_id="200")
        assert resolved == "Thread override"

    def test_build_message_event_sets_channel_prompt(self):
        """``_build_slash_event`` puts the resolved prompt on the event."""
        discord_adapter = _make_adapter()
        discord_adapter.config.extra = {"channel_prompts": {"321": "Command prompt"}}
        discord_adapter.build_source = MagicMock(return_value=SimpleNamespace())
        discord_adapter._get_effective_topic = MagicMock(return_value=None)

        channel = SimpleNamespace(name="general", guild=None, parent_id=None)
        author = SimpleNamespace(id=1, display_name="Brenner")
        # channel_id is an int here while the configured key is the string "321".
        interaction = SimpleNamespace(channel_id=321, channel=channel, user=author)

        slash_event = discord_adapter._build_slash_event(interaction, "/retry")
        assert slash_event.channel_prompt == "Command prompt"

    @pytest.mark.asyncio
    async def test_dispatch_thread_session_inherits_parent_channel_prompt(self):
        """The event dispatched for thread id 999 carries the prompt keyed by channel 200."""
        discord_adapter = _make_adapter()
        discord_adapter.config.extra = {"channel_prompts": {"200": "Parent prompt"}}
        discord_adapter.build_source = MagicMock(return_value=SimpleNamespace())
        discord_adapter._get_effective_topic = MagicMock(return_value=None)
        discord_adapter.handle_message = AsyncMock()

        origin_channel = SimpleNamespace(id=200, parent=None)
        interaction = SimpleNamespace(
            guild=SimpleNamespace(name="Wetlands"),
            channel=origin_channel,
            user=SimpleNamespace(id=1, display_name="Brenner"),
        )

        await discord_adapter._dispatch_thread_session(interaction, "999", "new-thread", "hello")

        awaited_call = discord_adapter.handle_message.await_args
        dispatched_event = awaited_call.args[0]
        assert dispatched_event.channel_prompt == "Parent prompt"

    def test_blank_prompts_are_ignored(self):
        """A whitespace-only prompt resolves to ``None``."""
        discord_adapter = _make_adapter()
        prompts = {"100": " "}
        discord_adapter.config.extra = {"channel_prompts": prompts}
        resolved = discord_adapter._resolve_channel_prompt("100")
        assert resolved is None
@pytest.mark.asyncio
async def test_retry_preserves_channel_prompt(monkeypatch):
    """The event re-dispatched by ``/retry`` keeps the original ``channel_prompt``."""
    stored_transcript = [
        {"role": "user", "content": "original message"},
        {"role": "assistant", "content": "old reply"},
    ]

    def get_or_create_session(source):
        return SimpleNamespace(session_id="session-1", last_prompt_tokens=10)

    def load_transcript(session_id):
        # Hand out a fresh list per call, as the original lambda did.
        return [dict(entry) for entry in stored_transcript]

    gateway_runner = _make_runner()
    gateway_runner.session_store = SimpleNamespace(
        get_or_create_session=get_or_create_session,
        load_transcript=load_transcript,
        rewrite_transcript=MagicMock(),
    )
    gateway_runner._handle_message = AsyncMock(return_value="ok")

    retry_event = MessageEvent(
        text="/retry",
        message_type=gateway_run.MessageType.COMMAND,
        source=_make_source(),
        raw_message=SimpleNamespace(),
        channel_prompt="Channel prompt",
    )

    outcome = await gateway_runner._handle_retry_command(retry_event)
    assert outcome == "ok"

    awaited_call = gateway_runner._handle_message.await_args
    retried_event = awaited_call.args[0]
    assert retried_event.channel_prompt == "Channel prompt"
@pytest.mark.asyncio
async def test_run_agent_appends_channel_prompt_to_ephemeral_system_prompt(monkeypatch, tmp_path):
    """The agent's ``ephemeral_system_prompt`` is context, channel and global prompts in that order."""
    _install_fake_agent(monkeypatch)
    gateway_runner = _make_runner()

    config_file = tmp_path / "config.yaml"
    config_file.write_text("agent:\n system_prompt: Global prompt\n", encoding="utf-8")

    runtime_kwargs = {
        "provider": "openrouter",
        "api_mode": "chat_completions",
        "base_url": "https://openrouter.ai/api/v1",
        "api_key": "***",
    }
    # Point gateway_run's module-level paths and loaders at the temp directory and stubs.
    gateway_patches = {
        "_hermes_home": tmp_path,
        "_env_path": tmp_path / ".env",
        "load_dotenv": lambda *args, **kwargs: None,
        "_load_gateway_config": lambda: {},
        "_resolve_gateway_model": lambda config=None: "gpt-5.4",
        "_resolve_runtime_agent_kwargs": lambda: dict(runtime_kwargs),
    }
    for attribute, replacement in gateway_patches.items():
        monkeypatch.setattr(gateway_run, attribute, replacement)

    import hermes_cli.tools_config as tools_config

    monkeypatch.setattr(tools_config, "_get_platform_tools", lambda user_config, platform_key: {"core"})

    # Clear any capture left behind by an earlier test.
    _CapturingAgent.last_init = None

    incoming = MessageEvent(
        text="hi",
        source=_make_source(),
        message_id="m1",
        channel_prompt="Channel prompt",
    )

    outcome = await gateway_runner._run_agent(
        message="hi",
        context_prompt="Context prompt",
        history=[],
        source=_make_source(),
        session_id="session-1",
        session_key="agent:main:discord:thread:12345",
        channel_prompt=incoming.channel_prompt,
    )

    assert outcome["final_response"] == "ok"
    expected_prompt = "Context prompt\n\nChannel prompt\n\nGlobal prompt"
    assert _CapturingAgent.last_init["ephemeral_system_prompt"] == expected_prompt