From 9ceacc4e652e7e88a00cd8a619ab18fa9be6229d Mon Sep 17 00:00:00 2001 From: helix4u <4317663+helix4u@users.noreply.github.com> Date: Wed, 8 Apr 2026 14:24:59 -0600 Subject: [PATCH] fix(slack): handle assistant thread lifecycle events --- gateway/platforms/slack.py | 157 +++++++++++++++++++++++++++++++++++- tests/gateway/test_slack.py | 90 ++++++++++++++++++++- 2 files changed, 242 insertions(+), 5 deletions(-) diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py index 7af313d325e..8685b92ed00 100644 --- a/gateway/platforms/slack.py +++ b/gateway/platforms/slack.py @@ -14,7 +14,7 @@ import logging import os import re import time -from typing import Dict, Optional, Any +from typing import Dict, Optional, Any, Tuple try: from slack_bolt.async_app import AsyncApp @@ -95,6 +95,11 @@ class SlackAdapter(BasePlatformAdapter): # respond to ALL subsequent messages in that thread automatically. self._mentioned_threads: set = set() self._MENTIONED_THREADS_MAX = 5000 + # Assistant thread metadata keyed by (channel_id, thread_ts). Slack's + # AI Assistant lifecycle events can arrive before/alongside message + # events, and they carry the user/thread identity needed for stable + # session + memory scoping. + self._assistant_threads: Dict[Tuple[str, str], Dict[str, str]] = {} async def connect(self) -> bool: """Connect to Slack via Socket Mode.""" @@ -181,6 +186,14 @@ class SlackAdapter(BasePlatformAdapter): async def handle_app_mention(event, say): pass + @self._app.event("assistant_thread_started") + async def handle_assistant_thread_started(event, say): + await self._handle_assistant_thread_lifecycle_event(event) + + @self._app.event("assistant_thread_context_changed") + async def handle_assistant_thread_context_changed(event, say): + await self._handle_assistant_thread_lifecycle_event(event) + # Register slash command handler @self._app.command("/hermes") async def handle_hermes_command(ack, command): @@ -755,6 +768,129 @@ class SlackAdapter(BasePlatformAdapter): # ----- Internal handlers ----- + def _assistant_thread_key(self, channel_id: str, thread_ts: str) -> Optional[Tuple[str, str]]: + """Return a stable cache key for Slack assistant thread metadata.""" + if not channel_id or not thread_ts: + return None + return (str(channel_id), str(thread_ts)) + + def _extract_assistant_thread_metadata(self, event: dict) -> Dict[str, str]: + """Extract Slack Assistant thread identity data from an event payload.""" + assistant_thread = event.get("assistant_thread") or {} + context = assistant_thread.get("context") or event.get("context") or {} + + channel_id = ( + assistant_thread.get("channel_id") + or event.get("channel") + or context.get("channel_id") + or "" + ) + thread_ts = ( + assistant_thread.get("thread_ts") + or event.get("thread_ts") + or event.get("message_ts") + or "" + ) + user_id = ( + assistant_thread.get("user_id") + or event.get("user") + or context.get("user_id") + or "" + ) + team_id = ( + event.get("team") + or event.get("team_id") + or assistant_thread.get("team_id") + or "" + ) + context_channel_id = context.get("channel_id") or "" + + return { + "channel_id": str(channel_id) if channel_id else "", + "thread_ts": str(thread_ts) if thread_ts else "", + "user_id": str(user_id) if user_id else "", + "team_id": str(team_id) if team_id else "", + "context_channel_id": str(context_channel_id) if context_channel_id else "", + } + + def _cache_assistant_thread_metadata(self, metadata: Dict[str, str]) -> None: + """Remember assistant thread identity data for later message events.""" + channel_id = metadata.get("channel_id", "") + thread_ts = metadata.get("thread_ts", "") + key = self._assistant_thread_key(channel_id, thread_ts) + if not key: + return + + existing = self._assistant_threads.get(key, {}) + merged = dict(existing) + merged.update({k: v for k, v in metadata.items() if v}) + self._assistant_threads[key] = merged + + team_id = merged.get("team_id", "") + if team_id and channel_id: + self._channel_team[channel_id] = team_id + + def _lookup_assistant_thread_metadata( + self, + event: dict, + channel_id: str = "", + thread_ts: str = "", + ) -> Dict[str, str]: + """Load cached assistant-thread metadata that matches the current event.""" + metadata = self._extract_assistant_thread_metadata(event) + if channel_id and not metadata.get("channel_id"): + metadata["channel_id"] = channel_id + if thread_ts and not metadata.get("thread_ts"): + metadata["thread_ts"] = thread_ts + + key = self._assistant_thread_key( + metadata.get("channel_id", ""), + metadata.get("thread_ts", ""), + ) + cached = self._assistant_threads.get(key, {}) if key else {} + if cached: + merged = dict(cached) + merged.update({k: v for k, v in metadata.items() if v}) + return merged + return metadata + + def _seed_assistant_thread_session(self, metadata: Dict[str, str]) -> None: + """Prime the session store so assistant threads get stable user scoping.""" + session_store = getattr(self, "_session_store", None) + if not session_store: + return + + channel_id = metadata.get("channel_id", "") + thread_ts = metadata.get("thread_ts", "") + user_id = metadata.get("user_id", "") + if not channel_id or not thread_ts or not user_id: + return + + source = self.build_source( + chat_id=channel_id, + chat_name=channel_id, + chat_type="dm", + user_id=user_id, + thread_id=thread_ts, + chat_topic=metadata.get("context_channel_id") or None, + ) + + try: + session_store.get_or_create_session(source) + except Exception: + logger.debug( + "[Slack] Failed to seed assistant thread session for %s/%s", + channel_id, + thread_ts, + exc_info=True, + ) + + async def _handle_assistant_thread_lifecycle_event(self, event: dict) -> None: + """Handle Slack Assistant lifecycle events that carry user/thread identity.""" + metadata = self._extract_assistant_thread_metadata(event) + self._cache_assistant_thread_metadata(metadata) + self._seed_assistant_thread_session(metadata) + async def _handle_slack_message(self, event: dict) -> None: """Handle an incoming Slack message event.""" # Dedup: Slack Socket Mode can redeliver events after reconnects (#4777) @@ -781,10 +917,21 @@ class SlackAdapter(BasePlatformAdapter): return text = event.get("text", "") - user_id = event.get("user", "") channel_id = event.get("channel", "") ts = event.get("ts", "") - team_id = event.get("team", "") + assistant_meta = self._lookup_assistant_thread_metadata( + event, + channel_id=channel_id, + thread_ts=event.get("thread_ts", ""), + ) + user_id = event.get("user") or assistant_meta.get("user_id", "") + if not channel_id: + channel_id = assistant_meta.get("channel_id", "") + team_id = ( + event.get("team") + or event.get("team_id") + or assistant_meta.get("team_id", "") + ) # Track which workspace owns this channel if team_id and channel_id: @@ -792,6 +939,8 @@ class SlackAdapter(BasePlatformAdapter): # Determine if this is a DM or channel message channel_type = event.get("channel_type", "") + if not channel_type and channel_id.startswith("D"): + channel_type = "im" is_dm = channel_type == "im" # Build thread_ts for session keying. @@ -800,7 +949,7 @@ class SlackAdapter(BasePlatformAdapter): # In DMs: only use the real thread_ts — top-level DMs should share # one continuous session, threaded DMs get their own session. if is_dm: - thread_ts = event.get("thread_ts") # None for top-level DMs + thread_ts = event.get("thread_ts") or assistant_meta.get("thread_ts") # None for top-level DMs else: thread_ts = event.get("thread_ts") or ts # ts fallback for channels diff --git a/tests/gateway/test_slack.py b/tests/gateway/test_slack.py index 89b44718344..0bad0abe567 100644 --- a/tests/gateway/test_slack.py +++ b/tests/gateway/test_slack.py @@ -96,7 +96,7 @@ class TestAppMentionHandler: """Verify that the app_mention event handler is registered.""" def test_app_mention_registered_on_connect(self): - """connect() should register both 'message' and 'app_mention' handlers.""" + """connect() should register message + assistant lifecycle handlers.""" config = PlatformConfig(enabled=True, token="xoxb-fake") adapter = SlackAdapter(config) @@ -145,6 +145,8 @@ class TestAppMentionHandler: assert "message" in registered_events assert "app_mention" in registered_events + assert "assistant_thread_started" in registered_events + assert "assistant_thread_context_changed" in registered_events assert "/hermes" in registered_commands @@ -840,6 +842,92 @@ class TestThreadReplyHandling: adapter.handle_message.assert_not_called() +# --------------------------------------------------------------------------- +# TestAssistantThreadLifecycle +# --------------------------------------------------------------------------- + + +class TestAssistantThreadLifecycle: + """Slack Assistant lifecycle events should seed session/user context.""" + + @pytest.fixture() + def mock_session_store(self): + store = MagicMock() + store._entries = {} + store._ensure_loaded = MagicMock() + store.config = MagicMock() + store.config.group_sessions_per_user = True + store.get_or_create_session = MagicMock() + return store + + @pytest.fixture() + def assistant_adapter(self, mock_session_store): + config = PlatformConfig(enabled=True, token="***") + a = SlackAdapter(config) + a._app = MagicMock() + a._app.client = AsyncMock() + a._bot_user_id = "U_BOT" + a._team_bot_user_ids = {"T_TEAM": "U_BOT"} + a._running = True + a.handle_message = AsyncMock() + a.set_session_store(mock_session_store) + return a + + @pytest.mark.asyncio + async def test_lifecycle_event_seeds_session_store(self, assistant_adapter, mock_session_store): + event = { + "type": "assistant_thread_started", + "team_id": "T_TEAM", + "assistant_thread": { + "channel_id": "D123", + "thread_ts": "171.000", + "user_id": "U_USER", + "context": {"channel_id": "C_ORIGIN"}, + }, + } + + await assistant_adapter._handle_assistant_thread_lifecycle_event(event) + + assert assistant_adapter._assistant_threads[("D123", "171.000")]["user_id"] == "U_USER" + mock_session_store.get_or_create_session.assert_called_once() + source = mock_session_store.get_or_create_session.call_args[0][0] + assert source.chat_id == "D123" + assert source.chat_type == "dm" + assert source.user_id == "U_USER" + assert source.thread_id == "171.000" + assert source.chat_topic == "C_ORIGIN" + + @pytest.mark.asyncio + async def test_message_uses_cached_assistant_thread_identity(self, assistant_adapter): + assistant_adapter._assistant_threads[("D123", "171.000")] = { + "channel_id": "D123", + "thread_ts": "171.000", + "user_id": "U_USER", + "team_id": "T_TEAM", + } + assistant_adapter._app.client.users_info = AsyncMock(return_value={ + "user": {"profile": {"display_name": "Tyler"}} + }) + assistant_adapter._app.client.reactions_add = AsyncMock() + assistant_adapter._app.client.reactions_remove = AsyncMock() + + event = { + "text": "hello from assistant dm", + "channel": "D123", + "channel_type": "im", + "thread_ts": "171.000", + "ts": "171.111", + "team": "T_TEAM", + } + + await assistant_adapter._handle_slack_message(event) + + msg_event = assistant_adapter.handle_message.call_args[0][0] + assert msg_event.source.user_id == "U_USER" + assert msg_event.source.thread_id == "171.000" + assert msg_event.source.user_name == "Tyler" + + # --------------------------------------------------------------------------- # TestUserNameResolution # ---------------------------------------------------------------------------