mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-29 15:31:38 +08:00
Compare commits
2 Commits
fix/plugin
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5debb231fe | ||
|
|
9ceacc4e65 |
@@ -14,7 +14,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import Dict, Optional, Any
|
||||
from typing import Dict, Optional, Any, Tuple
|
||||
|
||||
try:
|
||||
from slack_bolt.async_app import AsyncApp
|
||||
@@ -95,6 +95,12 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
# respond to ALL subsequent messages in that thread automatically.
|
||||
self._mentioned_threads: set = set()
|
||||
self._MENTIONED_THREADS_MAX = 5000
|
||||
# Assistant thread metadata keyed by (channel_id, thread_ts). Slack's
|
||||
# AI Assistant lifecycle events can arrive before/alongside message
|
||||
# events, and they carry the user/thread identity needed for stable
|
||||
# session + memory scoping.
|
||||
self._assistant_threads: Dict[Tuple[str, str], Dict[str, str]] = {}
|
||||
self._ASSISTANT_THREADS_MAX = 5000
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Slack via Socket Mode."""
|
||||
@@ -181,6 +187,14 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
async def handle_app_mention(event, say):
|
||||
pass
|
||||
|
||||
@self._app.event("assistant_thread_started")
|
||||
async def handle_assistant_thread_started(event, say):
|
||||
await self._handle_assistant_thread_lifecycle_event(event)
|
||||
|
||||
@self._app.event("assistant_thread_context_changed")
|
||||
async def handle_assistant_thread_context_changed(event, say):
|
||||
await self._handle_assistant_thread_lifecycle_event(event)
|
||||
|
||||
# Register slash command handler
|
||||
@self._app.command("/hermes")
|
||||
async def handle_hermes_command(ack, command):
|
||||
@@ -755,6 +769,135 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
# ----- Internal handlers -----
|
||||
|
||||
def _assistant_thread_key(self, channel_id: str, thread_ts: str) -> Optional[Tuple[str, str]]:
|
||||
"""Return a stable cache key for Slack assistant thread metadata."""
|
||||
if not channel_id or not thread_ts:
|
||||
return None
|
||||
return (str(channel_id), str(thread_ts))
|
||||
|
||||
def _extract_assistant_thread_metadata(self, event: dict) -> Dict[str, str]:
|
||||
"""Extract Slack Assistant thread identity data from an event payload."""
|
||||
assistant_thread = event.get("assistant_thread") or {}
|
||||
context = assistant_thread.get("context") or event.get("context") or {}
|
||||
|
||||
channel_id = (
|
||||
assistant_thread.get("channel_id")
|
||||
or event.get("channel")
|
||||
or context.get("channel_id")
|
||||
or ""
|
||||
)
|
||||
thread_ts = (
|
||||
assistant_thread.get("thread_ts")
|
||||
or event.get("thread_ts")
|
||||
or event.get("message_ts")
|
||||
or ""
|
||||
)
|
||||
user_id = (
|
||||
assistant_thread.get("user_id")
|
||||
or event.get("user")
|
||||
or context.get("user_id")
|
||||
or ""
|
||||
)
|
||||
team_id = (
|
||||
event.get("team")
|
||||
or event.get("team_id")
|
||||
or assistant_thread.get("team_id")
|
||||
or ""
|
||||
)
|
||||
context_channel_id = context.get("channel_id") or ""
|
||||
|
||||
return {
|
||||
"channel_id": str(channel_id) if channel_id else "",
|
||||
"thread_ts": str(thread_ts) if thread_ts else "",
|
||||
"user_id": str(user_id) if user_id else "",
|
||||
"team_id": str(team_id) if team_id else "",
|
||||
"context_channel_id": str(context_channel_id) if context_channel_id else "",
|
||||
}
|
||||
|
||||
def _cache_assistant_thread_metadata(self, metadata: Dict[str, str]) -> None:
|
||||
"""Remember assistant thread identity data for later message events."""
|
||||
channel_id = metadata.get("channel_id", "")
|
||||
thread_ts = metadata.get("thread_ts", "")
|
||||
key = self._assistant_thread_key(channel_id, thread_ts)
|
||||
if not key:
|
||||
return
|
||||
|
||||
existing = self._assistant_threads.get(key, {})
|
||||
merged = dict(existing)
|
||||
merged.update({k: v for k, v in metadata.items() if v})
|
||||
self._assistant_threads[key] = merged
|
||||
|
||||
# Evict oldest entries when the cache exceeds the limit
|
||||
if len(self._assistant_threads) > self._ASSISTANT_THREADS_MAX:
|
||||
excess = len(self._assistant_threads) - self._ASSISTANT_THREADS_MAX // 2
|
||||
for old_key in list(self._assistant_threads)[:excess]:
|
||||
del self._assistant_threads[old_key]
|
||||
|
||||
team_id = merged.get("team_id", "")
|
||||
if team_id and channel_id:
|
||||
self._channel_team[channel_id] = team_id
|
||||
|
||||
def _lookup_assistant_thread_metadata(
|
||||
self,
|
||||
event: dict,
|
||||
channel_id: str = "",
|
||||
thread_ts: str = "",
|
||||
) -> Dict[str, str]:
|
||||
"""Load cached assistant-thread metadata that matches the current event."""
|
||||
metadata = self._extract_assistant_thread_metadata(event)
|
||||
if channel_id and not metadata.get("channel_id"):
|
||||
metadata["channel_id"] = channel_id
|
||||
if thread_ts and not metadata.get("thread_ts"):
|
||||
metadata["thread_ts"] = thread_ts
|
||||
|
||||
key = self._assistant_thread_key(
|
||||
metadata.get("channel_id", ""),
|
||||
metadata.get("thread_ts", ""),
|
||||
)
|
||||
cached = self._assistant_threads.get(key, {}) if key else {}
|
||||
if cached:
|
||||
merged = dict(cached)
|
||||
merged.update({k: v for k, v in metadata.items() if v})
|
||||
return merged
|
||||
return metadata
|
||||
|
||||
def _seed_assistant_thread_session(self, metadata: Dict[str, str]) -> None:
|
||||
"""Prime the session store so assistant threads get stable user scoping."""
|
||||
session_store = getattr(self, "_session_store", None)
|
||||
if not session_store:
|
||||
return
|
||||
|
||||
channel_id = metadata.get("channel_id", "")
|
||||
thread_ts = metadata.get("thread_ts", "")
|
||||
user_id = metadata.get("user_id", "")
|
||||
if not channel_id or not thread_ts or not user_id:
|
||||
return
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=channel_id,
|
||||
chat_name=channel_id,
|
||||
chat_type="dm",
|
||||
user_id=user_id,
|
||||
thread_id=thread_ts,
|
||||
chat_topic=metadata.get("context_channel_id") or None,
|
||||
)
|
||||
|
||||
try:
|
||||
session_store.get_or_create_session(source)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"[Slack] Failed to seed assistant thread session for %s/%s",
|
||||
channel_id,
|
||||
thread_ts,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _handle_assistant_thread_lifecycle_event(self, event: dict) -> None:
|
||||
"""Handle Slack Assistant lifecycle events that carry user/thread identity."""
|
||||
metadata = self._extract_assistant_thread_metadata(event)
|
||||
self._cache_assistant_thread_metadata(metadata)
|
||||
self._seed_assistant_thread_session(metadata)
|
||||
|
||||
async def _handle_slack_message(self, event: dict) -> None:
|
||||
"""Handle an incoming Slack message event."""
|
||||
# Dedup: Slack Socket Mode can redeliver events after reconnects (#4777)
|
||||
@@ -781,10 +924,21 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
return
|
||||
|
||||
text = event.get("text", "")
|
||||
user_id = event.get("user", "")
|
||||
channel_id = event.get("channel", "")
|
||||
ts = event.get("ts", "")
|
||||
team_id = event.get("team", "")
|
||||
assistant_meta = self._lookup_assistant_thread_metadata(
|
||||
event,
|
||||
channel_id=channel_id,
|
||||
thread_ts=event.get("thread_ts", ""),
|
||||
)
|
||||
user_id = event.get("user") or assistant_meta.get("user_id", "")
|
||||
if not channel_id:
|
||||
channel_id = assistant_meta.get("channel_id", "")
|
||||
team_id = (
|
||||
event.get("team")
|
||||
or event.get("team_id")
|
||||
or assistant_meta.get("team_id", "")
|
||||
)
|
||||
|
||||
# Track which workspace owns this channel
|
||||
if team_id and channel_id:
|
||||
@@ -792,6 +946,8 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
# Determine if this is a DM or channel message
|
||||
channel_type = event.get("channel_type", "")
|
||||
if not channel_type and channel_id.startswith("D"):
|
||||
channel_type = "im"
|
||||
is_dm = channel_type == "im"
|
||||
|
||||
# Build thread_ts for session keying.
|
||||
@@ -800,7 +956,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
# In DMs: only use the real thread_ts — top-level DMs should share
|
||||
# one continuous session, threaded DMs get their own session.
|
||||
if is_dm:
|
||||
thread_ts = event.get("thread_ts") # None for top-level DMs
|
||||
thread_ts = event.get("thread_ts") or assistant_meta.get("thread_ts") # None for top-level DMs
|
||||
else:
|
||||
thread_ts = event.get("thread_ts") or ts # ts fallback for channels
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ class TestAppMentionHandler:
|
||||
"""Verify that the app_mention event handler is registered."""
|
||||
|
||||
def test_app_mention_registered_on_connect(self):
|
||||
"""connect() should register both 'message' and 'app_mention' handlers."""
|
||||
"""connect() should register message + assistant lifecycle handlers."""
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake")
|
||||
adapter = SlackAdapter(config)
|
||||
|
||||
@@ -145,6 +145,8 @@ class TestAppMentionHandler:
|
||||
|
||||
assert "message" in registered_events
|
||||
assert "app_mention" in registered_events
|
||||
assert "assistant_thread_started" in registered_events
|
||||
assert "assistant_thread_context_changed" in registered_events
|
||||
assert "/hermes" in registered_commands
|
||||
|
||||
|
||||
@@ -840,6 +842,114 @@ class TestThreadReplyHandling:
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestAssistantThreadLifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAssistantThreadLifecycle:
|
||||
"""Slack Assistant lifecycle events should seed session/user context."""
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_session_store(self):
|
||||
store = MagicMock()
|
||||
store._entries = {}
|
||||
store._ensure_loaded = MagicMock()
|
||||
store.config = MagicMock()
|
||||
store.config.group_sessions_per_user = True
|
||||
store.get_or_create_session = MagicMock()
|
||||
return store
|
||||
|
||||
@pytest.fixture()
|
||||
def assistant_adapter(self, mock_session_store):
|
||||
config = PlatformConfig(enabled=True, token="***")
|
||||
a = SlackAdapter(config)
|
||||
a._app = MagicMock()
|
||||
a._app.client = AsyncMock()
|
||||
a._bot_user_id = "U_BOT"
|
||||
a._team_bot_user_ids = {"T_TEAM": "U_BOT"}
|
||||
a._running = True
|
||||
a.handle_message = AsyncMock()
|
||||
a.set_session_store(mock_session_store)
|
||||
return a
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_event_seeds_session_store(self, assistant_adapter, mock_session_store):
|
||||
event = {
|
||||
"type": "assistant_thread_started",
|
||||
"team_id": "T_TEAM",
|
||||
"assistant_thread": {
|
||||
"channel_id": "D123",
|
||||
"thread_ts": "171.000",
|
||||
"user_id": "U_USER",
|
||||
"context": {"channel_id": "C_ORIGIN"},
|
||||
},
|
||||
}
|
||||
|
||||
await assistant_adapter._handle_assistant_thread_lifecycle_event(event)
|
||||
|
||||
assert assistant_adapter._assistant_threads[("D123", "171.000")]["user_id"] == "U_USER"
|
||||
mock_session_store.get_or_create_session.assert_called_once()
|
||||
source = mock_session_store.get_or_create_session.call_args[0][0]
|
||||
assert source.chat_id == "D123"
|
||||
assert source.chat_type == "dm"
|
||||
assert source.user_id == "U_USER"
|
||||
assert source.thread_id == "171.000"
|
||||
assert source.chat_topic == "C_ORIGIN"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_uses_cached_assistant_thread_identity(self, assistant_adapter):
|
||||
assistant_adapter._assistant_threads[("D123", "171.000")] = {
|
||||
"channel_id": "D123",
|
||||
"thread_ts": "171.000",
|
||||
"user_id": "U_USER",
|
||||
"team_id": "T_TEAM",
|
||||
}
|
||||
assistant_adapter._app.client.users_info = AsyncMock(return_value={
|
||||
"user": {"profile": {"display_name": "Tyler"}}
|
||||
})
|
||||
assistant_adapter._app.client.reactions_add = AsyncMock()
|
||||
assistant_adapter._app.client.reactions_remove = AsyncMock()
|
||||
|
||||
event = {
|
||||
"text": "hello from assistant dm",
|
||||
"channel": "D123",
|
||||
"channel_type": "im",
|
||||
"thread_ts": "171.000",
|
||||
"ts": "171.111",
|
||||
"team": "T_TEAM",
|
||||
}
|
||||
|
||||
await assistant_adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = assistant_adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.source.user_id == "U_USER"
|
||||
assert msg_event.source.thread_id == "171.000"
|
||||
assert msg_event.source.user_name == "Tyler"
|
||||
|
||||
def test_assistant_threads_cache_eviction(self, assistant_adapter):
|
||||
"""Cache should evict oldest entries when exceeding the size limit."""
|
||||
assistant_adapter._ASSISTANT_THREADS_MAX = 10
|
||||
# Fill to the limit
|
||||
for i in range(10):
|
||||
assistant_adapter._cache_assistant_thread_metadata({
|
||||
"channel_id": f"D{i}",
|
||||
"thread_ts": f"{i}.000",
|
||||
"user_id": f"U{i}",
|
||||
})
|
||||
assert len(assistant_adapter._assistant_threads) == 10
|
||||
|
||||
# Adding one more should trigger eviction (down to max // 2 = 5)
|
||||
assistant_adapter._cache_assistant_thread_metadata({
|
||||
"channel_id": "D999",
|
||||
"thread_ts": "999.000",
|
||||
"user_id": "U999",
|
||||
})
|
||||
assert len(assistant_adapter._assistant_threads) <= 10
|
||||
# The newest entry must survive eviction
|
||||
assert ("D999", "999.000") in assistant_adapter._assistant_threads
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestUserNameResolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user