diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index aebae49b4a..d932d39a16 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -87,8 +87,9 @@ class VoiceReceiver: SAMPLE_RATE = 48000 # Discord native rate CHANNELS = 2 # Discord sends stereo - def __init__(self, voice_client): + def __init__(self, voice_client, allowed_user_ids: set = None): self._vc = voice_client + self._allowed_user_ids = allowed_user_ids or set() self._running = False # Decryption @@ -274,19 +275,21 @@ class VoiceReceiver: if self._dave_session: with self._lock: user_id = self._ssrc_to_user.get(ssrc, 0) - if user_id == 0: - if self._packet_debug_count <= 10: - logger.warning("DAVE skip: unknown user for ssrc=%d", ssrc) - return # unknown user, can't DAVE-decrypt - try: - import davey - decrypted = self._dave_session.decrypt( - user_id, davey.MediaType.audio, decrypted - ) - except Exception as e: - if self._packet_debug_count <= 10: - logger.warning("DAVE decrypt failed for ssrc=%d: %s", ssrc, e) - return + if user_id: + try: + import davey + decrypted = self._dave_session.decrypt( + user_id, davey.MediaType.audio, decrypted + ) + except Exception as e: + # Unencrypted passthrough — use NaCl-decrypted data as-is + if "Unencrypted" not in str(e): + if self._packet_debug_count <= 10: + logger.warning("DAVE decrypt failed for ssrc=%d: %s", ssrc, e) + return + # If SSRC unknown (no SPEAKING event yet), skip DAVE and try + # Opus decode directly — audio may be in passthrough mode. + # Buffer will get a user_id when SPEAKING event arrives later. # --- Opus decode -> PCM --- try: @@ -304,6 +307,32 @@ class VoiceReceiver: # Silence detection # ------------------------------------------------------------------ + def _infer_user_for_ssrc(self, ssrc: int) -> int: + """Try to infer user_id for an unmapped SSRC. + + When the bot rejoins a voice channel, Discord may not resend + SPEAKING events for users already speaking. If exactly one + allowed user is in the channel, map the SSRC to them. + """ + try: + channel = self._vc.channel + if not channel: + return 0 + bot_id = self._vc.user.id if self._vc.user else 0 + allowed = self._allowed_user_ids + candidates = [ + m.id for m in channel.members + if m.id != bot_id and (not allowed or str(m.id) in allowed) + ] + if len(candidates) == 1: + uid = candidates[0] + self._ssrc_to_user[ssrc] = uid + logger.info("Auto-mapped ssrc=%d -> user=%d (sole allowed member)", ssrc, uid) + return uid + except Exception: + pass + return 0 + def check_silence(self) -> list: """Return list of (user_id, pcm_bytes) for completed utterances.""" now = time.monotonic() @@ -322,6 +351,10 @@ class VoiceReceiver: if silence_duration >= self.SILENCE_THRESHOLD and buf_duration >= self.MIN_SPEECH_DURATION: user_id = ssrc_user_map.get(ssrc, 0) + if not user_id: + # SSRC not mapped (SPEAKING event missing after bot rejoin). + # Infer from allowed users in the voice channel. + user_id = self._infer_user_for_ssrc(ssrc) if user_id: completed.append((user_id, bytes(buf))) self._buffers[ssrc] = bytearray() @@ -695,13 +728,14 @@ class DiscordAdapter(BasePlatformAdapter): ) -> SendResult: """Play auto-TTS audio. - When the bot is in a voice channel for this chat's guild, skip the - file attachment — the gateway runner plays audio in the VC instead. + When the bot is in a voice channel for this chat's guild, play + directly in the VC instead of sending as a file attachment. """ for gid, text_ch_id in self._voice_text_channels.items(): if str(text_ch_id) == str(chat_id) and self.is_in_voice_channel(gid): - logger.debug("[%s] Skipping play_tts for %s — VC playback handled by runner", self.name, chat_id) - return SendResult(success=True) + logger.info("[%s] Playing TTS in voice channel (guild=%d)", self.name, gid) + success = await self.play_in_voice_channel(gid, audio_path) + return SendResult(success=success) return await self.send_voice(chat_id=chat_id, audio_path=audio_path, **kwargs) async def send_voice( @@ -805,7 +839,7 @@ class DiscordAdapter(BasePlatformAdapter): # Start voice receiver (Phase 2: listen to users) try: - receiver = VoiceReceiver(vc) + receiver = VoiceReceiver(vc, allowed_user_ids=self._allowed_user_ids) receiver.start() self._voice_receivers[guild_id] = receiver self._voice_listen_tasks[guild_id] = asyncio.ensure_future( @@ -1001,14 +1035,32 @@ class DiscordAdapter(BasePlatformAdapter): # Voice listening (Phase 2) # ------------------------------------------------------------------ + # UDP keepalive interval in seconds — prevents Discord from dropping + # the UDP route after ~60s of silence. + _KEEPALIVE_INTERVAL = 15 + async def _voice_listen_loop(self, guild_id: int): """Periodically check for completed utterances and process them.""" receiver = self._voice_receivers.get(guild_id) if not receiver: return + last_keepalive = time.monotonic() try: while receiver._running: await asyncio.sleep(0.2) + + # Send periodic UDP keepalive to prevent Discord from + # dropping the UDP session after ~60s of silence. + now = time.monotonic() + if now - last_keepalive >= self._KEEPALIVE_INTERVAL: + last_keepalive = now + try: + vc = self._voice_clients.get(guild_id) + if vc and vc.is_connected(): + vc._connection.send_packet(b'\xf8\xff\xfe') + except Exception: + pass + completed = receiver.check_silence() for user_id, pcm_data in completed: if not self._is_allowed_user(str(user_id)): diff --git a/gateway/run.py b/gateway/run.py index 716e981f22..43ec892696 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -2435,6 +2435,13 @@ class GatewayRunner: except Exception as e: logger.warning("Failed to join voice channel: %s", e) adapter._voice_input_callback = None + err_lower = str(e).lower() + if "pynacl" in err_lower or "nacl" in err_lower or "davey" in err_lower: + return ( + "Voice dependencies are missing (PyNaCl / davey). " + "Install or reinstall Hermes with the messaging extra, e.g. " + "`pip install hermes-agent[messaging]`." + ) return f"Failed to join voice channel: {e}" if success: @@ -2575,18 +2582,9 @@ class GatewayRunner: if has_agent_tts: return False - # Dedup: base adapter auto-TTS already handles voice input. - # Exception: Discord voice channel — play_tts override is a no-op, - # so the runner must handle VC playback. - skip_double = is_voice_input - if skip_double: - adapter = self.adapters.get(event.source.platform) - guild_id = self._get_guild_id(event) - if (guild_id and adapter - and hasattr(adapter, "is_in_voice_channel") - and adapter.is_in_voice_channel(guild_id)): - skip_double = False - if skip_double: + # Dedup: base adapter auto-TTS already handles voice input + # (play_tts plays in VC when connected, so runner can skip). + if is_voice_input: return False return True diff --git a/scripts/discord-voice-doctor.py b/scripts/discord-voice-doctor.py new file mode 100755 index 0000000000..4fd55f9e8e --- /dev/null +++ b/scripts/discord-voice-doctor.py @@ -0,0 +1,389 @@ +#!/usr/bin/env python3 +"""Discord Voice Doctor — diagnostic tool for voice channel support. + +Checks all dependencies, configuration, and bot permissions needed +for Discord voice mode to work correctly. + +Usage: + python scripts/discord-voice-doctor.py + .venv/bin/python scripts/discord-voice-doctor.py +""" + +import os +import sys +import shutil +from pathlib import Path + +# Resolve project root +SCRIPT_DIR = Path(__file__).resolve().parent +PROJECT_ROOT = SCRIPT_DIR.parent +sys.path.insert(0, str(PROJECT_ROOT)) + +HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) +ENV_FILE = HERMES_HOME / ".env" + +OK = "\033[92m\u2713\033[0m" +FAIL = "\033[91m\u2717\033[0m" +WARN = "\033[93m!\033[0m" + +# Track whether discord.py is available for later sections +_discord_available = False + + +def mask(value): + """Mask sensitive value: show only first 4 chars.""" + if not value or len(value) < 8: + return "****" + return f"{value[:4]}{'*' * (len(value) - 4)}" + + +def check(label, ok, detail=""): + symbol = OK if ok else FAIL + msg = f" {symbol} {label}" + if detail: + msg += f" ({detail})" + print(msg) + return ok + + +def warn(label, detail=""): + msg = f" {WARN} {label}" + if detail: + msg += f" ({detail})" + print(msg) + + +def section(title): + print(f"\n\033[1m{title}\033[0m") + + +def check_packages(): + """Check Python package dependencies. Returns True if all critical deps OK.""" + global _discord_available + section("Python Packages") + ok = True + + # discord.py + try: + import discord + _discord_available = True + check("discord.py", True, f"v{discord.__version__}") + except ImportError: + check("discord.py", False, "pip install discord.py[voice]") + ok = False + + # PyNaCl + try: + import nacl + ver = getattr(nacl, "__version__", "unknown") + try: + import nacl.secret + nacl.secret.Aead(bytes(32)) + check("PyNaCl", True, f"v{ver}") + except (AttributeError, Exception): + check("PyNaCl (Aead)", False, f"v{ver} — need >=1.5.0") + ok = False + except ImportError: + check("PyNaCl", False, "pip install PyNaCl>=1.5.0") + ok = False + + # davey (DAVE E2EE) + try: + import davey + check("davey (DAVE E2EE)", True, f"v{getattr(davey, '__version__', '?')}") + except ImportError: + check("davey (DAVE E2EE)", False, "pip install davey") + ok = False + + # Optional: local STT + try: + import faster_whisper + check("faster-whisper (local STT)", True) + except ImportError: + warn("faster-whisper (local STT)", "not installed — local STT unavailable") + + # Optional: TTS providers + try: + import edge_tts + check("edge-tts", True) + except ImportError: + warn("edge-tts", "not installed — edge TTS unavailable") + + try: + import elevenlabs + check("elevenlabs SDK", True) + except ImportError: + warn("elevenlabs SDK", "not installed — premium TTS unavailable") + + return ok + + +def check_system_tools(): + """Check system-level tools (opus, ffmpeg). Returns True if all OK.""" + section("System Tools") + ok = True + + # Opus codec + if _discord_available: + try: + import discord + opus_loaded = discord.opus.is_loaded() + if not opus_loaded: + import ctypes.util + opus_path = ctypes.util.find_library("opus") + if not opus_path: + # Platform-specific fallback paths + candidates = [ + "/opt/homebrew/lib/libopus.dylib", # macOS Apple Silicon + "/usr/local/lib/libopus.dylib", # macOS Intel + "/usr/lib/x86_64-linux-gnu/libopus.so.0", # Debian/Ubuntu x86 + "/usr/lib/aarch64-linux-gnu/libopus.so.0", # Debian/Ubuntu ARM + "/usr/lib/libopus.so", # Arch Linux + "/usr/lib64/libopus.so", # RHEL/Fedora + ] + for p in candidates: + if os.path.isfile(p): + opus_path = p + break + if opus_path: + discord.opus.load_opus(opus_path) + opus_loaded = discord.opus.is_loaded() + if opus_loaded: + check("Opus codec", True) + else: + check("Opus codec", False, "brew install opus / apt install libopus0") + ok = False + except Exception as e: + check("Opus codec", False, str(e)) + ok = False + else: + warn("Opus codec", "skipped — discord.py not installed") + + # ffmpeg + ffmpeg_path = shutil.which("ffmpeg") + if ffmpeg_path: + check("ffmpeg", True, ffmpeg_path) + else: + check("ffmpeg", False, "brew install ffmpeg / apt install ffmpeg") + ok = False + + return ok + + +def check_env_vars(): + """Check environment variables. Returns (ok, token, groq_key, eleven_key).""" + section("Environment Variables") + + # Load .env + try: + from dotenv import load_dotenv + if ENV_FILE.exists(): + load_dotenv(ENV_FILE) + except ImportError: + pass + + ok = True + + token = os.getenv("DISCORD_BOT_TOKEN", "") + if token: + check("DISCORD_BOT_TOKEN", True, mask(token)) + else: + check("DISCORD_BOT_TOKEN", False, "not set") + ok = False + + # Allowed users — resolve usernames if possible + allowed = os.getenv("DISCORD_ALLOWED_USERS", "") + if allowed: + users = [u.strip() for u in allowed.split(",") if u.strip()] + user_labels = [] + for uid in users: + label = mask(uid) + if token and uid.isdigit(): + try: + import requests + r = requests.get( + f"https://discord.com/api/v10/users/{uid}", + headers={"Authorization": f"Bot {token}"}, + timeout=3, + ) + if r.status_code == 200: + label = f"{r.json().get('username', '?')} ({mask(uid)})" + except Exception: + pass + user_labels.append(label) + check("DISCORD_ALLOWED_USERS", True, f"{len(users)} user(s): {', '.join(user_labels)}") + else: + warn("DISCORD_ALLOWED_USERS", "not set — all users can use voice") + + groq_key = os.getenv("GROQ_API_KEY", "") + eleven_key = os.getenv("ELEVENLABS_API_KEY", "") + + if groq_key: + check("GROQ_API_KEY (STT)", True, mask(groq_key)) + else: + warn("GROQ_API_KEY", "not set — Groq STT unavailable") + + if eleven_key: + check("ELEVENLABS_API_KEY (TTS)", True, mask(eleven_key)) + else: + warn("ELEVENLABS_API_KEY", "not set — ElevenLabs TTS unavailable") + + return ok, token, groq_key, eleven_key + + +def check_config(groq_key, eleven_key): + """Check hermes config.yaml.""" + section("Configuration") + + config_path = HERMES_HOME / "config.yaml" + if config_path.exists(): + try: + import yaml + with open(config_path) as f: + cfg = yaml.safe_load(f) or {} + + stt_provider = cfg.get("stt", {}).get("provider", "local") + tts_provider = cfg.get("tts", {}).get("provider", "edge") + check("STT provider", True, stt_provider) + check("TTS provider", True, tts_provider) + + if stt_provider == "groq" and not groq_key: + warn("STT config says groq but GROQ_API_KEY is missing") + if tts_provider == "elevenlabs" and not eleven_key: + warn("TTS config says elevenlabs but ELEVENLABS_API_KEY is missing") + except Exception as e: + warn("config.yaml", f"parse error: {e}") + else: + warn("config.yaml", "not found — using defaults") + + # Voice mode state + voice_mode_path = HERMES_HOME / "gateway_voice_mode.json" + if voice_mode_path.exists(): + try: + import json + modes = json.loads(voice_mode_path.read_text()) + off_count = sum(1 for v in modes.values() if v == "off") + all_count = sum(1 for v in modes.values() if v == "all") + check("Voice mode state", True, f"{all_count} on, {off_count} off, {len(modes)} total") + except Exception: + warn("Voice mode state", "parse error") + else: + check("Voice mode state", True, "no saved state (fresh)") + + +def check_bot_permissions(token): + """Check bot permissions via Discord API. Returns True if all OK.""" + section("Bot Permissions") + + if not token: + warn("Bot permissions", "no token — skipping") + return True + + try: + import requests + except ImportError: + warn("Bot permissions", "requests not installed — skipping") + return True + + VOICE_PERMS = { + "Priority Speaker": 8, + "Stream": 9, + "View Channel": 10, + "Send Messages": 11, + "Embed Links": 14, + "Attach Files": 15, + "Read Message History": 16, + "Connect": 20, + "Speak": 21, + "Mute Members": 22, + "Deafen Members": 23, + "Move Members": 24, + "Use VAD": 25, + "Send Voice Messages": 46, + } + REQUIRED_PERMS = {"Connect", "Speak", "View Channel", "Send Messages"} + ok = True + + try: + headers = {"Authorization": f"Bot {token}"} + r = requests.get("https://discord.com/api/v10/users/@me", headers=headers, timeout=5) + + if r.status_code == 401: + check("Bot login", False, "invalid token (401)") + return False + if r.status_code != 200: + check("Bot login", False, f"HTTP {r.status_code}") + return False + + bot = r.json() + bot_name = bot.get("username", "?") + check("Bot login", True, f"{bot_name[:3]}{'*' * (len(bot_name) - 3)}") + + # Check guilds + r2 = requests.get("https://discord.com/api/v10/users/@me/guilds", headers=headers, timeout=5) + if r2.status_code != 200: + warn("Guilds", f"HTTP {r2.status_code}") + return ok + + guilds = r2.json() + check("Guilds", True, f"{len(guilds)} guild(s)") + + for g in guilds[:5]: + perms = int(g.get("permissions", 0)) + is_admin = bool(perms & (1 << 3)) + + if is_admin: + print(f" {OK} {g['name']}: Administrator (all permissions)") + continue + + has = [] + missing = [] + for name, bit in sorted(VOICE_PERMS.items(), key=lambda x: x[1]): + if perms & (1 << bit): + has.append(name) + elif name in REQUIRED_PERMS: + missing.append(name) + + if missing: + print(f" {FAIL} {g['name']}: missing {', '.join(missing)}") + ok = False + else: + print(f" {OK} {g['name']}: {', '.join(has)}") + + except requests.exceptions.Timeout: + warn("Bot permissions", "Discord API timeout") + except requests.exceptions.ConnectionError: + warn("Bot permissions", "cannot reach Discord API") + except Exception as e: + warn("Bot permissions", f"check failed: {e}") + + return ok + + +def main(): + print() + print("\033[1m" + "=" * 50 + "\033[0m") + print("\033[1m Discord Voice Doctor\033[0m") + print("\033[1m" + "=" * 50 + "\033[0m") + + all_ok = True + + all_ok &= check_packages() + all_ok &= check_system_tools() + env_ok, token, groq_key, eleven_key = check_env_vars() + all_ok &= env_ok + check_config(groq_key, eleven_key) + all_ok &= check_bot_permissions(token) + + # Summary + print() + print("\033[1m" + "-" * 50 + "\033[0m") + if all_ok: + print(f" {OK} \033[92mAll checks passed — voice mode ready!\033[0m") + else: + print(f" {FAIL} \033[91mSome checks failed — fix issues above.\033[0m") + print() + + +if __name__ == "__main__": + main() diff --git a/tests/gateway/test_voice_command.py b/tests/gateway/test_voice_command.py index 545f2b28fb..9c9d5753a4 100644 --- a/tests/gateway/test_voice_command.py +++ b/tests/gateway/test_voice_command.py @@ -1,5 +1,6 @@ """Tests for the /voice command and auto voice reply in the gateway.""" +import importlib.util import json import os import queue @@ -206,9 +207,11 @@ class TestAutoVoiceReply: 2. gateway _send_voice_reply: fires based on voice_mode setting To prevent double audio, _send_voice_reply is skipped when voice input - already triggered base adapter auto-TTS (skip_double = is_voice_input). - Exception: Discord voice channel — both auto-TTS and Discord play_tts - override skip, so the runner must handle it via play_in_voice_channel. + already triggered base adapter auto-TTS. + + For Discord voice channels, the base adapter now routes play_tts directly + into VC playback, so the runner should still skip voice-input follow-ups to + avoid double playback. """ @pytest.fixture @@ -292,14 +295,14 @@ class TestAutoVoiceReply: # -- Discord VC exception: runner must handle -------------------------- - def test_discord_vc_voice_input_runner_fires(self, runner): - """Discord VC + voice input: base play_tts skips (VC override), - so runner must handle via play_in_voice_channel.""" - assert self._call(runner, "all", MessageType.VOICE, in_voice_channel=True) is True + def test_discord_vc_voice_input_base_handles(self, runner): + """Discord VC + voice input: base adapter play_tts plays in VC, + so runner skips to avoid double playback.""" + assert self._call(runner, "all", MessageType.VOICE, in_voice_channel=True) is False - def test_discord_vc_voice_only_runner_fires(self, runner): - """Discord VC + voice_only + voice: runner must handle.""" - assert self._call(runner, "voice_only", MessageType.VOICE, in_voice_channel=True) is True + def test_discord_vc_voice_only_base_handles(self, runner): + """Discord VC + voice_only + voice: base adapter handles.""" + assert self._call(runner, "voice_only", MessageType.VOICE, in_voice_channel=True) is False # -- Edge cases -------------------------------------------------------- @@ -422,17 +425,23 @@ class TestDiscordPlayTtsSkip: return adapter @pytest.mark.asyncio - async def test_play_tts_skipped_when_in_vc(self): + async def test_play_tts_plays_in_vc_when_connected(self): adapter = self._make_discord_adapter() # Simulate bot in voice channel for guild 111, text channel 123 mock_vc = MagicMock() mock_vc.is_connected.return_value = True + mock_vc.is_playing.return_value = False adapter._voice_clients[111] = mock_vc adapter._voice_text_channels[111] = 123 + # Mock play_in_voice_channel to avoid actual ffmpeg call + async def fake_play(gid, path): + return True + adapter.play_in_voice_channel = fake_play + result = await adapter.play_tts(chat_id="123", audio_path="/tmp/test.ogg") + # play_tts now plays in VC instead of being a no-op assert result.success is True - # send_voice should NOT have been called (no client, would fail) @pytest.mark.asyncio async def test_play_tts_not_skipped_when_not_in_vc(self): @@ -728,6 +737,24 @@ class TestVoiceChannelCommands: result = await runner._handle_voice_channel_join(event) assert "failed" in result.lower() + @pytest.mark.asyncio + async def test_join_missing_voice_dependencies(self, runner): + """Missing PyNaCl/davey should return a user-actionable install hint.""" + mock_channel = MagicMock() + mock_channel.name = "General" + mock_adapter = AsyncMock() + mock_adapter.join_voice_channel = AsyncMock( + side_effect=RuntimeError("PyNaCl library needed in order to use voice") + ) + mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel) + event = self._make_discord_event() + runner.adapters[event.source.platform] = mock_adapter + + result = await runner._handle_voice_channel_join(event) + + assert "voice dependencies are missing" in result.lower() + assert "hermes-agent[messaging]" in result + # -- _handle_voice_channel_leave -- @pytest.mark.asyncio @@ -2031,3 +2058,534 @@ class TestDisconnectVoiceCleanup: assert len(adapter._voice_receivers) == 0 assert len(adapter._voice_listen_tasks) == 0 assert len(adapter._voice_timeout_tasks) == 0 + + +# ===================================================================== +# Discord Voice Channel Flow Tests +# ===================================================================== + + +@pytest.mark.skipif( + importlib.util.find_spec("nacl") is None, + reason="PyNaCl not installed", +) +class TestVoiceReception: + """Audio reception: SSRC mapping, DAVE passthrough, buffer lifecycle.""" + + @staticmethod + def _make_receiver(allowed_ids=None, members=None, dave=False, bot_id=9999): + from gateway.platforms.discord import VoiceReceiver + vc = MagicMock() + vc._connection.secret_key = [0] * 32 + vc._connection.dave_session = MagicMock() if dave else None + vc._connection.ssrc = bot_id + vc._connection.add_socket_listener = MagicMock() + vc._connection.remove_socket_listener = MagicMock() + vc._connection.hook = None + vc.user = SimpleNamespace(id=bot_id) + vc.channel = MagicMock() + vc.channel.members = members or [] + receiver = VoiceReceiver(vc, allowed_user_ids=allowed_ids) + return receiver + + @staticmethod + def _fill_buffer(receiver, ssrc, duration_s=1.0, age_s=3.0): + """Add PCM data to buffer. 48kHz stereo 16-bit = 192000 bytes/sec.""" + size = int(192000 * duration_s) + receiver._buffers[ssrc] = bytearray(b"\x00" * size) + receiver._last_packet_time[ssrc] = time.monotonic() - age_s + + # -- Known SSRC (normal flow) -- + + def test_known_ssrc_returns_completed(self): + receiver = self._make_receiver() + receiver.start() + receiver.map_ssrc(100, 42) + self._fill_buffer(receiver, 100) + completed = receiver.check_silence() + assert len(completed) == 1 + assert completed[0][0] == 42 + assert len(receiver._buffers[100]) == 0 # cleared + + def test_known_ssrc_short_buffer_ignored(self): + receiver = self._make_receiver() + receiver.start() + receiver.map_ssrc(100, 42) + self._fill_buffer(receiver, 100, duration_s=0.1) # too short + completed = receiver.check_silence() + assert len(completed) == 0 + + def test_known_ssrc_recent_audio_waits(self): + receiver = self._make_receiver() + receiver.start() + receiver.map_ssrc(100, 42) + self._fill_buffer(receiver, 100, age_s=0.0) # just arrived + completed = receiver.check_silence() + assert len(completed) == 0 + + # -- Unknown SSRC + DAVE passthrough -- + + def test_unknown_ssrc_no_automap_no_completed(self): + """Unknown SSRC, no members to infer — buffer cleared, not returned.""" + receiver = self._make_receiver(dave=True, members=[]) + receiver.start() + self._fill_buffer(receiver, 100) + completed = receiver.check_silence() + assert len(completed) == 0 + assert len(receiver._buffers[100]) == 0 + + def test_unknown_ssrc_late_speaking_event(self): + """Audio buffered before SPEAKING → SPEAKING maps → next check returns it.""" + receiver = self._make_receiver(dave=True) + receiver.start() + self._fill_buffer(receiver, 100, age_s=0.0) # still receiving + # No user yet + assert receiver.check_silence() == [] + # SPEAKING event arrives + receiver.map_ssrc(100, 42) + # Silence kicks in + receiver._last_packet_time[100] = time.monotonic() - 3.0 + completed = receiver.check_silence() + assert len(completed) == 1 + assert completed[0][0] == 42 + + # -- SSRC auto-mapping -- + + def test_automap_single_allowed_user(self): + members = [ + SimpleNamespace(id=9999, name="Bot"), + SimpleNamespace(id=42, name="Alice"), + ] + receiver = self._make_receiver(allowed_ids={"42"}, members=members) + receiver.start() + self._fill_buffer(receiver, 100) + completed = receiver.check_silence() + assert len(completed) == 1 + assert completed[0][0] == 42 + assert receiver._ssrc_to_user[100] == 42 + + def test_automap_multiple_allowed_users_no_map(self): + members = [ + SimpleNamespace(id=9999, name="Bot"), + SimpleNamespace(id=42, name="Alice"), + SimpleNamespace(id=43, name="Bob"), + ] + receiver = self._make_receiver(allowed_ids={"42", "43"}, members=members) + receiver.start() + self._fill_buffer(receiver, 100) + completed = receiver.check_silence() + assert len(completed) == 0 + + def test_automap_no_allowlist_single_member(self): + """No allowed_user_ids → sole non-bot member inferred.""" + members = [ + SimpleNamespace(id=9999, name="Bot"), + SimpleNamespace(id=42, name="Alice"), + ] + receiver = self._make_receiver(allowed_ids=None, members=members) + receiver.start() + self._fill_buffer(receiver, 100) + completed = receiver.check_silence() + assert len(completed) == 1 + assert completed[0][0] == 42 + + def test_automap_unallowed_user_rejected(self): + """User in channel but not in allowed list — not mapped.""" + members = [ + SimpleNamespace(id=9999, name="Bot"), + SimpleNamespace(id=42, name="Alice"), + ] + receiver = self._make_receiver(allowed_ids={"99"}, members=members) + receiver.start() + self._fill_buffer(receiver, 100) + completed = receiver.check_silence() + assert len(completed) == 0 + + def test_automap_only_bot_in_channel(self): + """Only bot in channel — no one to map to.""" + members = [SimpleNamespace(id=9999, name="Bot")] + receiver = self._make_receiver(allowed_ids=None, members=members) + receiver.start() + self._fill_buffer(receiver, 100) + completed = receiver.check_silence() + assert len(completed) == 0 + + def test_automap_persists_across_calls(self): + """Auto-mapped SSRC stays mapped for subsequent checks.""" + members = [ + SimpleNamespace(id=9999, name="Bot"), + SimpleNamespace(id=42, name="Alice"), + ] + receiver = self._make_receiver(allowed_ids={"42"}, members=members) + receiver.start() + self._fill_buffer(receiver, 100) + receiver.check_silence() + assert receiver._ssrc_to_user[100] == 42 + # Second utterance — should use cached mapping + self._fill_buffer(receiver, 100) + completed = receiver.check_silence() + assert len(completed) == 1 + assert completed[0][0] == 42 + + # -- Stale buffer cleanup -- + + def test_stale_unknown_buffer_discarded(self): + """Buffer with no user and very old timestamp is discarded.""" + receiver = self._make_receiver() + receiver.start() + receiver._buffers[200] = bytearray(b"\x00" * 100) + receiver._last_packet_time[200] = time.monotonic() - 10.0 + receiver.check_silence() + assert 200 not in receiver._buffers + + # -- Pause / resume (echo prevention) -- + + def test_paused_receiver_ignores_packets(self): + receiver = self._make_receiver() + receiver.start() + receiver.pause() + receiver._on_packet(b"\x00" * 100) + assert len(receiver._buffers) == 0 + + def test_resumed_receiver_accepts_packets(self): + receiver = self._make_receiver() + receiver.start() + receiver.pause() + receiver.resume() + assert receiver._paused is False + + # -- _on_packet DAVE passthrough behavior -- + + def _make_receiver_with_nacl(self, dave_session=None, mapped_ssrcs=None): + """Create a receiver that can process _on_packet with mocked NaCl + Opus.""" + from gateway.platforms.discord import VoiceReceiver + vc = MagicMock() + vc._connection.secret_key = [0] * 32 + vc._connection.dave_session = dave_session + vc._connection.ssrc = 9999 + vc._connection.add_socket_listener = MagicMock() + vc._connection.remove_socket_listener = MagicMock() + vc._connection.hook = None + vc.user = SimpleNamespace(id=9999) + vc.channel = MagicMock() + vc.channel.members = [] + receiver = VoiceReceiver(vc) + receiver.start() + # Pre-map SSRCs if provided + if mapped_ssrcs: + for ssrc, uid in mapped_ssrcs.items(): + receiver.map_ssrc(ssrc, uid) + return receiver + + @staticmethod + def _build_rtp_packet(ssrc=100, seq=1, timestamp=960): + """Build a minimal valid RTP packet for _on_packet. + + We need: RTP header (12 bytes) + encrypted payload + 4-byte nonce. + NaCl decrypt is mocked so payload content doesn't matter. + """ + import struct + # RTP header: version=2, payload_type=0x78, no extension, no CSRC + header = struct.pack(">BBHII", 0x80, 0x78, seq, timestamp, ssrc) + # Fake encrypted payload (NaCl will be mocked) + 4 byte nonce + payload = b"\x00" * 20 + b"\x00\x00\x00\x01" + return header + payload + + def _inject_mock_decoder(self, receiver, ssrc): + """Pre-inject a mock Opus decoder for the given SSRC.""" + mock_decoder = MagicMock() + mock_decoder.decode.return_value = b"\x00" * 3840 + receiver._decoders[ssrc] = mock_decoder + return mock_decoder + + def test_on_packet_dave_known_user_decrypt_ok(self): + """Known SSRC + DAVE decrypt success → audio buffered.""" + dave = MagicMock() + dave.decrypt.return_value = b"\xf8\xff\xfe" + receiver = self._make_receiver_with_nacl( + dave_session=dave, mapped_ssrcs={100: 42} + ) + self._inject_mock_decoder(receiver, 100) + + with patch("nacl.secret.Aead") as mock_aead: + mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe" + receiver._on_packet(self._build_rtp_packet(ssrc=100)) + + assert 100 in receiver._buffers + assert len(receiver._buffers[100]) > 0 + dave.decrypt.assert_called_once() + + def test_on_packet_dave_unknown_ssrc_passthrough(self): + """Unknown SSRC + DAVE → skip DAVE, attempt Opus decode (passthrough).""" + dave = MagicMock() + receiver = self._make_receiver_with_nacl(dave_session=dave) + self._inject_mock_decoder(receiver, 100) + + with patch("nacl.secret.Aead") as mock_aead: + mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe" + receiver._on_packet(self._build_rtp_packet(ssrc=100)) + + dave.decrypt.assert_not_called() + assert 100 in receiver._buffers + assert len(receiver._buffers[100]) > 0 + + def test_on_packet_dave_unencrypted_error_passthrough(self): + """DAVE decrypt 'Unencrypted' error → use data as-is, don't drop.""" + dave = MagicMock() + dave.decrypt.side_effect = Exception( + "Failed to decrypt: DecryptionFailed(UnencryptedWhenPassthroughDisabled)" + ) + receiver = self._make_receiver_with_nacl( + dave_session=dave, mapped_ssrcs={100: 42} + ) + self._inject_mock_decoder(receiver, 100) + + with patch("nacl.secret.Aead") as mock_aead: + mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe" + receiver._on_packet(self._build_rtp_packet(ssrc=100)) + + assert 100 in receiver._buffers + assert len(receiver._buffers[100]) > 0 + + def test_on_packet_dave_other_error_drops(self): + """DAVE decrypt non-Unencrypted error → packet dropped.""" + dave = MagicMock() + dave.decrypt.side_effect = Exception("KeyRotationFailed") + receiver = self._make_receiver_with_nacl( + dave_session=dave, mapped_ssrcs={100: 42} + ) + + with patch("nacl.secret.Aead") as mock_aead: + mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe" + receiver._on_packet(self._build_rtp_packet(ssrc=100)) + + assert len(receiver._buffers.get(100, b"")) == 0 + + def test_on_packet_no_dave_direct_decode(self): + """No DAVE session → decode directly.""" + receiver = self._make_receiver_with_nacl(dave_session=None) + self._inject_mock_decoder(receiver, 100) + + with patch("nacl.secret.Aead") as mock_aead: + mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe" + receiver._on_packet(self._build_rtp_packet(ssrc=100)) + + assert 100 in receiver._buffers + assert len(receiver._buffers[100]) > 0 + + def test_on_packet_bot_own_ssrc_ignored(self): + """Bot's own SSRC → dropped (echo prevention).""" + receiver = self._make_receiver_with_nacl() + with patch("nacl.secret.Aead"): + receiver._on_packet(self._build_rtp_packet(ssrc=9999)) + assert len(receiver._buffers) == 0 + + def test_on_packet_multiple_ssrcs_separate_buffers(self): + """Different SSRCs → separate buffers.""" + receiver = self._make_receiver_with_nacl(dave_session=None) + self._inject_mock_decoder(receiver, 100) + self._inject_mock_decoder(receiver, 200) + + with patch("nacl.secret.Aead") as mock_aead: + mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe" + receiver._on_packet(self._build_rtp_packet(ssrc=100)) + receiver._on_packet(self._build_rtp_packet(ssrc=200)) + + assert 100 in receiver._buffers + assert 200 in receiver._buffers + + +class TestVoiceTTSPlayback: + """TTS playback: play_tts in VC, dedup, fallback.""" + + @staticmethod + def _make_discord_adapter(): + from gateway.platforms.discord import DiscordAdapter + from gateway.config import PlatformConfig, Platform + config = PlatformConfig(enabled=True, extra={}) + config.token = "fake-token" + adapter = object.__new__(DiscordAdapter) + adapter.platform = Platform.DISCORD + adapter.config = config + adapter._voice_clients = {} + adapter._voice_text_channels = {} + adapter._voice_receivers = {} + return adapter + + # -- play_tts behavior -- + + @pytest.mark.asyncio + async def test_play_tts_plays_in_vc(self): + """play_tts calls play_in_voice_channel when bot is in VC.""" + adapter = self._make_discord_adapter() + mock_vc = MagicMock() + mock_vc.is_connected.return_value = True + adapter._voice_clients[111] = mock_vc + adapter._voice_text_channels[111] = 123 + + played = [] + async def fake_play(gid, path): + played.append((gid, path)) + return True + adapter.play_in_voice_channel = fake_play + + result = await adapter.play_tts(chat_id="123", audio_path="/tmp/tts.ogg") + assert result.success is True + assert played == [(111, "/tmp/tts.ogg")] + + @pytest.mark.asyncio + async def test_play_tts_fallback_when_not_in_vc(self): + """play_tts sends as file attachment when bot is not in VC.""" + adapter = self._make_discord_adapter() + from gateway.platforms.base import SendResult + adapter.send_voice = AsyncMock(return_value=SendResult(success=False, error="no client")) + result = await adapter.play_tts(chat_id="123", audio_path="/tmp/tts.ogg") + assert result.success is False + adapter.send_voice.assert_called_once() + + @pytest.mark.asyncio + async def test_play_tts_wrong_channel_no_match(self): + """play_tts doesn't match if chat_id is for a different channel.""" + adapter = self._make_discord_adapter() + mock_vc = MagicMock() + mock_vc.is_connected.return_value = True + adapter._voice_clients[111] = mock_vc + adapter._voice_text_channels[111] = 123 + + from gateway.platforms.base import SendResult + adapter.send_voice = AsyncMock(return_value=SendResult(success=True)) + # Different chat_id — shouldn't match VC + result = await adapter.play_tts(chat_id="999", audio_path="/tmp/tts.ogg") + adapter.send_voice.assert_called_once() + + # -- Runner dedup -- + + @staticmethod + def _make_runner(): + from gateway.run import GatewayRunner + runner = object.__new__(GatewayRunner) + runner._voice_mode = {} + runner.adapters = {} + return runner + + def _call_should_reply(self, runner, voice_mode, msg_type, response="Hello", agent_msgs=None): + from gateway.platforms.base import MessageType, MessageEvent, SessionSource + from gateway.config import Platform + runner._voice_mode["ch1"] = voice_mode + source = SessionSource( + platform=Platform.DISCORD, chat_id="ch1", + user_id="1", user_name="test", chat_type="channel", + ) + event = MessageEvent(source=source, text="test", message_type=msg_type) + return runner._should_send_voice_reply(event, response, agent_msgs or []) + + def test_voice_input_runner_skips(self): + """Voice input: runner skips — base adapter handles via play_tts.""" + from gateway.platforms.base import MessageType + runner = self._make_runner() + assert self._call_should_reply(runner, "all", MessageType.VOICE) is False + + def test_text_input_voice_all_runner_fires(self): + """Text input + voice_mode=all: runner generates TTS.""" + from gateway.platforms.base import MessageType + runner = self._make_runner() + assert self._call_should_reply(runner, "all", MessageType.TEXT) is True + + def test_text_input_voice_off_no_tts(self): + """Text input + voice_mode=off: no TTS.""" + from gateway.platforms.base import MessageType + runner = self._make_runner() + assert self._call_should_reply(runner, "off", MessageType.TEXT) is False + + def test_text_input_voice_only_no_tts(self): + """Text input + voice_mode=voice_only: no TTS for text.""" + from gateway.platforms.base import MessageType + runner = self._make_runner() + assert self._call_should_reply(runner, "voice_only", MessageType.TEXT) is False + + def test_error_response_no_tts(self): + """Error response: no TTS regardless of voice_mode.""" + from gateway.platforms.base import MessageType + runner = self._make_runner() + assert self._call_should_reply(runner, "all", MessageType.TEXT, response="Error: boom") is False + + def test_empty_response_no_tts(self): + """Empty response: no TTS.""" + from gateway.platforms.base import MessageType + runner = self._make_runner() + assert self._call_should_reply(runner, "all", MessageType.TEXT, response="") is False + + def test_agent_tts_tool_dedup(self): + """Agent already called text_to_speech tool: runner skips.""" + from gateway.platforms.base import MessageType + runner = self._make_runner() + agent_msgs = [{"role": "assistant", "tool_calls": [ + {"id": "1", "type": "function", "function": {"name": "text_to_speech", "arguments": "{}"}} + ]}] + assert self._call_should_reply(runner, "all", MessageType.TEXT, agent_msgs=agent_msgs) is False + + +class TestUDPKeepalive: + """UDP keepalive prevents Discord from dropping the voice session.""" + + def test_keepalive_interval_is_reasonable(self): + from gateway.platforms.discord import DiscordAdapter + interval = DiscordAdapter._KEEPALIVE_INTERVAL + assert 5 <= interval <= 30, f"Keepalive interval {interval}s should be between 5-30s" + + @pytest.mark.asyncio + async def test_keepalive_sends_silence_frame(self): + """Listen loop sends silence frame via send_packet after interval.""" + from gateway.platforms.discord import DiscordAdapter + from gateway.config import PlatformConfig, Platform + + config = PlatformConfig(enabled=True, extra={}) + config.token = "fake" + adapter = object.__new__(DiscordAdapter) + adapter.platform = Platform.DISCORD + adapter.config = config + adapter._voice_clients = {} + adapter._voice_text_channels = {} + adapter._voice_receivers = {} + adapter._voice_listen_tasks = {} + + # Mock VC and receiver + mock_vc = MagicMock() + mock_vc.is_connected.return_value = True + mock_conn = MagicMock() + adapter._voice_clients[111] = mock_vc + mock_vc._connection = mock_conn + + from gateway.platforms.discord import VoiceReceiver + mock_receiver_vc = MagicMock() + mock_receiver_vc._connection.secret_key = [0] * 32 + mock_receiver_vc._connection.dave_session = None + mock_receiver_vc._connection.ssrc = 9999 + mock_receiver_vc._connection.add_socket_listener = MagicMock() + mock_receiver_vc._connection.remove_socket_listener = MagicMock() + mock_receiver_vc._connection.hook = None + receiver = VoiceReceiver(mock_receiver_vc) + receiver.start() + adapter._voice_receivers[111] = receiver + + # Set keepalive interval very short for test + original_interval = DiscordAdapter._KEEPALIVE_INTERVAL + DiscordAdapter._KEEPALIVE_INTERVAL = 0.1 + + try: + # Run listen loop briefly + import asyncio + loop_task = asyncio.create_task(adapter._voice_listen_loop(111)) + await asyncio.sleep(0.3) + receiver._running = False # stop loop + await asyncio.sleep(0.1) + loop_task.cancel() + try: + await loop_task + except asyncio.CancelledError: + pass + + # send_packet should have been called with silence frame + mock_conn.send_packet.assert_called_with(b'\xf8\xff\xfe') + finally: + DiscordAdapter._KEEPALIVE_INTERVAL = original_interval diff --git a/tests/integration/test_voice_channel_flow.py b/tests/integration/test_voice_channel_flow.py new file mode 100644 index 0000000000..096ef9d3f3 --- /dev/null +++ b/tests/integration/test_voice_channel_flow.py @@ -0,0 +1,611 @@ +"""Integration tests for Discord voice channel audio flow. + +Uses real NaCl encryption and Opus codec (no mocks for crypto/codec). +Does NOT require a Discord connection — tests the VoiceReceiver +packet processing pipeline end-to-end. + +Requires: PyNaCl>=1.5.0, discord.py[voice] (opus codec) +""" + +import struct +import time +import pytest + +pytestmark = pytest.mark.integration + +# Skip entire module if voice deps are missing +pytest.importorskip("nacl.secret", reason="PyNaCl required for voice integration tests") +discord = pytest.importorskip("discord", reason="discord.py required for voice integration tests") + +import nacl.secret + +try: + if not discord.opus.is_loaded(): + import ctypes.util + opus_path = ctypes.util.find_library("opus") + if not opus_path: + import sys + for p in ("/opt/homebrew/lib/libopus.dylib", "/usr/local/lib/libopus.dylib"): + import os + if os.path.isfile(p): + opus_path = p + break + if opus_path: + discord.opus.load_opus(opus_path) + OPUS_AVAILABLE = discord.opus.is_loaded() +except Exception: + OPUS_AVAILABLE = False + +from types import SimpleNamespace +from unittest.mock import MagicMock +from gateway.platforms.discord import VoiceReceiver + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_secret_key(): + """Generate a random 32-byte key.""" + import os + return os.urandom(32) + + +def _build_encrypted_rtp_packet(secret_key, opus_payload, ssrc=100, seq=1, timestamp=960): + """Build a real NaCl-encrypted RTP packet matching Discord's format. + + Format: RTP header (12 bytes) + encrypted(opus) + 4-byte nonce + Encryption: aead_xchacha20_poly1305 with RTP header as AAD. + """ + # RTP header: version=2, payload_type=0x78, no extension, no CSRC + header = struct.pack(">BBHII", 0x80, 0x78, seq, timestamp, ssrc) + + # Encrypt with NaCl AEAD + box = nacl.secret.Aead(secret_key) + nonce_counter = struct.pack(">I", seq) # 4-byte counter as nonce seed + # Full 24-byte nonce: counter in first 4 bytes, rest zeros + full_nonce = nonce_counter + b'\x00' * 20 + + enc_msg = box.encrypt(opus_payload, header, full_nonce) + ciphertext = enc_msg.ciphertext # without nonce prefix + + # Discord format: header + ciphertext + 4-byte nonce + return header + ciphertext + nonce_counter + + +def _make_voice_receiver(secret_key, dave_session=None, bot_ssrc=9999, + allowed_user_ids=None, members=None): + """Create a VoiceReceiver with real secret key.""" + vc = MagicMock() + vc._connection.secret_key = list(secret_key) + vc._connection.dave_session = dave_session + vc._connection.ssrc = bot_ssrc + vc._connection.add_socket_listener = MagicMock() + vc._connection.remove_socket_listener = MagicMock() + vc._connection.hook = None + vc.user = SimpleNamespace(id=bot_ssrc) + vc.channel = MagicMock() + vc.channel.members = members or [] + receiver = VoiceReceiver(vc, allowed_user_ids=allowed_user_ids) + receiver.start() + return receiver + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestRealNaClDecrypt: + """End-to-end: real NaCl encrypt → _on_packet decrypt → buffer.""" + + def test_valid_encrypted_packet_buffered(self): + """Real NaCl encrypted packet → decrypted → buffered.""" + key = _make_secret_key() + opus_silence = b'\xf8\xff\xfe' + receiver = _make_voice_receiver(key) + + packet = _build_encrypted_rtp_packet(key, opus_silence, ssrc=100) + receiver._on_packet(packet) + + assert 100 in receiver._buffers + assert len(receiver._buffers[100]) > 0 + + def test_wrong_key_packet_dropped(self): + """Packet encrypted with wrong key → NaCl fails → not buffered.""" + real_key = _make_secret_key() + wrong_key = _make_secret_key() + opus_silence = b'\xf8\xff\xfe' + receiver = _make_voice_receiver(real_key) + + packet = _build_encrypted_rtp_packet(wrong_key, opus_silence, ssrc=100) + receiver._on_packet(packet) + + assert len(receiver._buffers.get(100, b"")) == 0 + + def test_bot_ssrc_ignored(self): + """Packet from bot's own SSRC → ignored.""" + key = _make_secret_key() + receiver = _make_voice_receiver(key, bot_ssrc=9999) + + packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=9999) + receiver._on_packet(packet) + + assert len(receiver._buffers) == 0 + + def test_multiple_packets_accumulate(self): + """Multiple valid packets → buffer grows.""" + key = _make_secret_key() + receiver = _make_voice_receiver(key) + + for seq in range(1, 6): + packet = _build_encrypted_rtp_packet( + key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq + ) + receiver._on_packet(packet) + + assert 100 in receiver._buffers + buf_size = len(receiver._buffers[100]) + assert buf_size > 0, "Multiple packets should accumulate in buffer" + + def test_different_ssrcs_separate_buffers(self): + """Packets from different SSRCs → separate buffers.""" + key = _make_secret_key() + receiver = _make_voice_receiver(key) + + for ssrc in [100, 200, 300]: + packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=ssrc) + receiver._on_packet(packet) + + assert len(receiver._buffers) == 3 + for ssrc in [100, 200, 300]: + assert ssrc in receiver._buffers + + +class TestRealNaClWithDAVE: + """NaCl decrypt + DAVE passthrough scenarios with real crypto.""" + + def test_dave_unknown_ssrc_passthrough(self): + """DAVE enabled but SSRC unknown → skip DAVE, buffer audio.""" + key = _make_secret_key() + dave = MagicMock() # DAVE session present but SSRC not mapped + receiver = _make_voice_receiver(key, dave_session=dave) + + packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100) + receiver._on_packet(packet) + + # DAVE decrypt not called (SSRC unknown) + dave.decrypt.assert_not_called() + # Audio still buffered via passthrough + assert 100 in receiver._buffers + assert len(receiver._buffers[100]) > 0 + + def test_dave_unencrypted_error_passthrough(self): + """DAVE raises 'Unencrypted' → use NaCl-decrypted data as-is.""" + key = _make_secret_key() + dave = MagicMock() + dave.decrypt.side_effect = Exception( + "DecryptionFailed(UnencryptedWhenPassthroughDisabled)" + ) + receiver = _make_voice_receiver(key, dave_session=dave) + receiver.map_ssrc(100, 42) + + packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100) + receiver._on_packet(packet) + + # DAVE was called but failed → passthrough + dave.decrypt.assert_called_once() + assert 100 in receiver._buffers + assert len(receiver._buffers[100]) > 0 + + def test_dave_real_error_drops(self): + """DAVE raises non-Unencrypted error → packet dropped.""" + key = _make_secret_key() + dave = MagicMock() + dave.decrypt.side_effect = Exception("KeyRotationFailed") + receiver = _make_voice_receiver(key, dave_session=dave) + receiver.map_ssrc(100, 42) + + packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100) + receiver._on_packet(packet) + + assert len(receiver._buffers.get(100, b"")) == 0 + + +class TestFullVoiceFlow: + """End-to-end: encrypt → receive → buffer → silence detect → complete.""" + + def test_single_utterance_flow(self): + """Encrypt packets → buffer → silence → check_silence returns utterance.""" + key = _make_secret_key() + receiver = _make_voice_receiver(key) + receiver.map_ssrc(100, 42) + + # Send enough packets to exceed MIN_SPEECH_DURATION (0.5s) + # At 48kHz stereo 16-bit, each Opus silence frame decodes to ~3840 bytes + # Need 96000 bytes = ~25 frames + for seq in range(1, 30): + packet = _build_encrypted_rtp_packet( + key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq + ) + receiver._on_packet(packet) + + # Simulate silence by setting last_packet_time in the past + receiver._last_packet_time[100] = time.monotonic() - 3.0 + + completed = receiver.check_silence() + assert len(completed) == 1 + user_id, pcm_data = completed[0] + assert user_id == 42 + assert len(pcm_data) > 0 + + def test_utterance_with_ssrc_automap(self): + """No SPEAKING event → auto-map sole allowed user → utterance processed.""" + key = _make_secret_key() + members = [ + SimpleNamespace(id=9999, name="Bot"), + SimpleNamespace(id=42, name="Alice"), + ] + receiver = _make_voice_receiver( + key, allowed_user_ids={"42"}, members=members + ) + # No map_ssrc call — simulating missing SPEAKING event + + for seq in range(1, 30): + packet = _build_encrypted_rtp_packet( + key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq + ) + receiver._on_packet(packet) + + receiver._last_packet_time[100] = time.monotonic() - 3.0 + + completed = receiver.check_silence() + assert len(completed) == 1 + assert completed[0][0] == 42 # auto-mapped to sole allowed user + + def test_pause_blocks_during_playback(self): + """Pause receiver → packets ignored → resume → packets accepted.""" + key = _make_secret_key() + receiver = _make_voice_receiver(key) + + # Pause (echo prevention during TTS playback) + receiver.pause() + packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100) + receiver._on_packet(packet) + assert len(receiver._buffers.get(100, b"")) == 0 + + # Resume + receiver.resume() + receiver._on_packet(packet) + assert 100 in receiver._buffers + assert len(receiver._buffers[100]) > 0 + + def test_corrupted_packet_ignored(self): + """Corrupted/truncated packet → silently ignored.""" + key = _make_secret_key() + receiver = _make_voice_receiver(key) + + # Too short + receiver._on_packet(b"\x00" * 5) + assert len(receiver._buffers) == 0 + + # Wrong RTP version + bad_header = struct.pack(">BBHII", 0x00, 0x78, 1, 960, 100) + receiver._on_packet(bad_header + b"\x00" * 20) + assert len(receiver._buffers) == 0 + + # Wrong payload type + bad_pt = struct.pack(">BBHII", 0x80, 0x00, 1, 960, 100) + receiver._on_packet(bad_pt + b"\x00" * 20) + assert len(receiver._buffers) == 0 + + def test_stop_cleans_everything(self): + """stop() clears all state cleanly.""" + key = _make_secret_key() + receiver = _make_voice_receiver(key) + receiver.map_ssrc(100, 42) + + for seq in range(1, 10): + packet = _build_encrypted_rtp_packet( + key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq + ) + receiver._on_packet(packet) + + assert len(receiver._buffers[100]) > 0 + + receiver.stop() + assert receiver._running is False + assert len(receiver._buffers) == 0 + assert len(receiver._ssrc_to_user) == 0 + assert len(receiver._decoders) == 0 + + +class TestSPEAKINGHook: + """SPEAKING event hook correctly maps SSRC to user_id.""" + + def test_speaking_hook_installed(self): + """start() installs speaking hook on connection.""" + key = _make_secret_key() + receiver = _make_voice_receiver(key) + conn = receiver._vc._connection + # hook should be set (wrapped) + assert conn.hook is not None + + def test_map_ssrc_via_speaking(self): + """SPEAKING op 5 event maps SSRC to user_id.""" + key = _make_secret_key() + receiver = _make_voice_receiver(key) + receiver.map_ssrc(500, 12345) + assert receiver._ssrc_to_user[500] == 12345 + + def test_map_ssrc_overwrites(self): + """New SPEAKING event for same SSRC overwrites old mapping.""" + key = _make_secret_key() + receiver = _make_voice_receiver(key) + receiver.map_ssrc(500, 111) + receiver.map_ssrc(500, 222) + assert receiver._ssrc_to_user[500] == 222 + + def test_speaking_mapped_audio_processed(self): + """After SSRC is mapped, audio from that SSRC gets correct user_id.""" + key = _make_secret_key() + receiver = _make_voice_receiver(key) + receiver.map_ssrc(100, 42) + + for seq in range(1, 30): + packet = _build_encrypted_rtp_packet( + key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq + ) + receiver._on_packet(packet) + + receiver._last_packet_time[100] = time.monotonic() - 3.0 + completed = receiver.check_silence() + assert len(completed) == 1 + assert completed[0][0] == 42 + + +class TestAuthFiltering: + """Only allowed users' audio should be processed.""" + + def test_allowed_user_audio_processed(self): + """Allowed user's utterance is returned by check_silence.""" + key = _make_secret_key() + members = [ + SimpleNamespace(id=9999, name="Bot"), + SimpleNamespace(id=42, name="Alice"), + ] + receiver = _make_voice_receiver( + key, allowed_user_ids={"42"}, members=members, + ) + receiver.map_ssrc(100, 42) + + for seq in range(1, 30): + packet = _build_encrypted_rtp_packet( + key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq + ) + receiver._on_packet(packet) + + receiver._last_packet_time[100] = time.monotonic() - 3.0 + completed = receiver.check_silence() + assert len(completed) == 1 + assert completed[0][0] == 42 + + def test_automap_rejects_unallowed_user(self): + """Auto-map refuses to map SSRC to user not in allowed list.""" + key = _make_secret_key() + members = [ + SimpleNamespace(id=9999, name="Bot"), + SimpleNamespace(id=42, name="Alice"), + ] + receiver = _make_voice_receiver( + key, allowed_user_ids={"99"}, # Alice not allowed + members=members, + ) + # No map_ssrc — SSRC unknown, auto-map should reject + + for seq in range(1, 30): + packet = _build_encrypted_rtp_packet( + key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq + ) + receiver._on_packet(packet) + + receiver._last_packet_time[100] = time.monotonic() - 3.0 + completed = receiver.check_silence() + assert len(completed) == 0 + + def test_empty_allowlist_allows_all(self): + """Empty allowed_user_ids means no restriction.""" + key = _make_secret_key() + members = [ + SimpleNamespace(id=9999, name="Bot"), + SimpleNamespace(id=42, name="Alice"), + ] + receiver = _make_voice_receiver( + key, allowed_user_ids=None, members=members, + ) + + for seq in range(1, 30): + packet = _build_encrypted_rtp_packet( + key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq + ) + receiver._on_packet(packet) + + receiver._last_packet_time[100] = time.monotonic() - 3.0 + completed = receiver.check_silence() + # Auto-mapped to sole non-bot member + assert len(completed) == 1 + assert completed[0][0] == 42 + + +class TestRejoinFlow: + """Leave and rejoin: state cleanup and fresh receiver.""" + + def test_stop_then_new_receiver_clean_state(self): + """After stop(), a new receiver starts with empty state.""" + key = _make_secret_key() + receiver1 = _make_voice_receiver(key) + receiver1.map_ssrc(100, 42) + + for seq in range(1, 10): + packet = _build_encrypted_rtp_packet( + key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq + ) + receiver1._on_packet(packet) + + assert len(receiver1._buffers[100]) > 0 + receiver1.stop() + + # New receiver (simulates rejoin) + receiver2 = _make_voice_receiver(key) + assert len(receiver2._buffers) == 0 + assert len(receiver2._ssrc_to_user) == 0 + assert len(receiver2._decoders) == 0 + + def test_rejoin_new_ssrc_works(self): + """After rejoin, user may get new SSRC — still works.""" + key = _make_secret_key() + receiver1 = _make_voice_receiver(key) + receiver1.map_ssrc(100, 42) # old SSRC + receiver1.stop() + + receiver2 = _make_voice_receiver(key) + receiver2.map_ssrc(200, 42) # new SSRC after rejoin + + for seq in range(1, 30): + packet = _build_encrypted_rtp_packet( + key, b'\xf8\xff\xfe', ssrc=200, seq=seq, timestamp=960 * seq + ) + receiver2._on_packet(packet) + + receiver2._last_packet_time[200] = time.monotonic() - 3.0 + completed = receiver2.check_silence() + assert len(completed) == 1 + assert completed[0][0] == 42 + + def test_rejoin_without_speaking_event_automap(self): + """Rejoin without SPEAKING event — auto-map sole allowed user.""" + key = _make_secret_key() + members = [ + SimpleNamespace(id=9999, name="Bot"), + SimpleNamespace(id=42, name="Alice"), + ] + + # First session + receiver1 = _make_voice_receiver( + key, allowed_user_ids={"42"}, members=members, + ) + receiver1.stop() + + # Rejoin — new key (Discord may assign new secret_key) + new_key = _make_secret_key() + receiver2 = _make_voice_receiver( + new_key, allowed_user_ids={"42"}, members=members, + ) + # No map_ssrc — simulating missing SPEAKING event + + for seq in range(1, 30): + packet = _build_encrypted_rtp_packet( + new_key, b'\xf8\xff\xfe', ssrc=300, seq=seq, timestamp=960 * seq + ) + receiver2._on_packet(packet) + + receiver2._last_packet_time[300] = time.monotonic() - 3.0 + completed = receiver2.check_silence() + assert len(completed) == 1 + assert completed[0][0] == 42 + + +class TestMultiGuildIsolation: + """Each guild has independent voice state.""" + + def test_separate_receivers_independent(self): + """Two receivers (different guilds) don't interfere.""" + key1 = _make_secret_key() + key2 = _make_secret_key() + + receiver1 = _make_voice_receiver(key1, bot_ssrc=1111) + receiver2 = _make_voice_receiver(key2, bot_ssrc=2222) + + receiver1.map_ssrc(100, 42) + receiver2.map_ssrc(200, 99) + + # Send to receiver1 + for seq in range(1, 10): + packet = _build_encrypted_rtp_packet( + key1, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq + ) + receiver1._on_packet(packet) + + # receiver2 should be empty + assert len(receiver2._buffers) == 0 + assert 100 in receiver1._buffers + + def test_stop_one_doesnt_affect_other(self): + """Stopping one receiver doesn't affect another.""" + key1 = _make_secret_key() + key2 = _make_secret_key() + + receiver1 = _make_voice_receiver(key1) + receiver2 = _make_voice_receiver(key2) + + receiver1.map_ssrc(100, 42) + receiver2.map_ssrc(200, 99) + + for seq in range(1, 10): + packet = _build_encrypted_rtp_packet( + key2, b'\xf8\xff\xfe', ssrc=200, seq=seq, timestamp=960 * seq + ) + receiver2._on_packet(packet) + + receiver1.stop() + + # receiver2 still has data + assert receiver2._running is True + assert len(receiver2._buffers[200]) > 0 + + +class TestEchoPreventionFlow: + """Receiver pause/resume during TTS playback prevents echo.""" + + def test_audio_during_pause_ignored(self): + """Audio arriving while paused is completely ignored.""" + key = _make_secret_key() + receiver = _make_voice_receiver(key) + receiver.map_ssrc(100, 42) + receiver.pause() + + for seq in range(1, 30): + packet = _build_encrypted_rtp_packet( + key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq + ) + receiver._on_packet(packet) + + assert len(receiver._buffers.get(100, b"")) == 0 + + def test_audio_after_resume_processed(self): + """Audio arriving after resume is processed normally.""" + key = _make_secret_key() + receiver = _make_voice_receiver(key) + receiver.map_ssrc(100, 42) + + # Pause → send packets → resume → send more packets + receiver.pause() + for seq in range(1, 5): + packet = _build_encrypted_rtp_packet( + key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq + ) + receiver._on_packet(packet) + assert len(receiver._buffers.get(100, b"")) == 0 + + receiver.resume() + for seq in range(5, 35): + packet = _build_encrypted_rtp_packet( + key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq + ) + receiver._on_packet(packet) + + assert len(receiver._buffers[100]) > 0 + receiver._last_packet_time[100] = time.monotonic() - 3.0 + completed = receiver.check_silence() + assert len(completed) == 1 + assert completed[0][0] == 42