mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
fix(gateway): avoid cross-user mirror writes in per-user group sessions
This commit is contained in:
@@ -28,6 +28,7 @@ def mirror_to_session(
|
||||
message_text: str,
|
||||
source_label: str = "cli",
|
||||
thread_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Append a delivery-mirror message to the target session's transcript.
|
||||
@@ -39,9 +40,20 @@ def mirror_to_session(
|
||||
All errors are caught -- this is never fatal.
|
||||
"""
|
||||
try:
|
||||
session_id = _find_session_id(platform, str(chat_id), thread_id=thread_id)
|
||||
session_id = _find_session_id(
|
||||
platform,
|
||||
str(chat_id),
|
||||
thread_id=thread_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not session_id:
|
||||
logger.debug("Mirror: no session found for %s:%s:%s", platform, chat_id, thread_id)
|
||||
logger.debug(
|
||||
"Mirror: no session found for %s:%s:%s:%s",
|
||||
platform,
|
||||
chat_id,
|
||||
thread_id,
|
||||
user_id,
|
||||
)
|
||||
return False
|
||||
|
||||
mirror_msg = {
|
||||
@@ -59,17 +71,33 @@ def mirror_to_session(
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Mirror failed for %s:%s:%s: %s", platform, chat_id, thread_id, e)
|
||||
logger.debug(
|
||||
"Mirror failed for %s:%s:%s:%s: %s",
|
||||
platform,
|
||||
chat_id,
|
||||
thread_id,
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def _find_session_id(platform: str, chat_id: str, thread_id: Optional[str] = None) -> Optional[str]:
|
||||
def _find_session_id(
|
||||
platform: str,
|
||||
chat_id: str,
|
||||
thread_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Find the active session_id for a platform + chat_id pair.
|
||||
|
||||
Scans sessions.json entries and matches where origin.chat_id == chat_id
|
||||
on the right platform. DM session keys don't embed the chat_id
|
||||
(e.g. "agent:main:telegram:dm"), so we check the origin dict.
|
||||
|
||||
When *user_id* is provided, prefer exact sender matches. If multiple
|
||||
same-chat candidates exist and none matches the user, return None instead
|
||||
of guessing and contaminating another participant's session.
|
||||
"""
|
||||
if not _SESSIONS_INDEX.exists():
|
||||
return None
|
||||
@@ -81,8 +109,7 @@ def _find_session_id(platform: str, chat_id: str, thread_id: Optional[str] = Non
|
||||
return None
|
||||
|
||||
platform_lower = platform.lower()
|
||||
best_match = None
|
||||
best_updated = ""
|
||||
candidates = []
|
||||
|
||||
for _key, entry in data.items():
|
||||
origin = entry.get("origin") or {}
|
||||
@@ -96,12 +123,31 @@ def _find_session_id(platform: str, chat_id: str, thread_id: Optional[str] = Non
|
||||
origin_thread_id = origin.get("thread_id")
|
||||
if thread_id is not None and str(origin_thread_id or "") != str(thread_id):
|
||||
continue
|
||||
updated = entry.get("updated_at", "")
|
||||
if updated > best_updated:
|
||||
best_updated = updated
|
||||
best_match = entry.get("session_id")
|
||||
candidates.append(entry)
|
||||
|
||||
return best_match
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
if user_id:
|
||||
exact_user_matches = [
|
||||
entry for entry in candidates
|
||||
if str((entry.get("origin") or {}).get("user_id") or "") == str(user_id)
|
||||
]
|
||||
if exact_user_matches:
|
||||
candidates = exact_user_matches
|
||||
elif len(candidates) > 1:
|
||||
return None
|
||||
elif len(candidates) > 1:
|
||||
distinct_user_ids = {
|
||||
str((entry.get("origin") or {}).get("user_id") or "").strip()
|
||||
for entry in candidates
|
||||
if str((entry.get("origin") or {}).get("user_id") or "").strip()
|
||||
}
|
||||
if len(distinct_user_ids) > 1:
|
||||
return None
|
||||
|
||||
best_entry = max(candidates, key=lambda entry: entry.get("updated_at", ""))
|
||||
return best_entry.get("session_id")
|
||||
|
||||
|
||||
def _append_to_jsonl(session_id: str, message: dict) -> None:
|
||||
|
||||
@@ -77,6 +77,46 @@ class TestFindSessionId:
|
||||
|
||||
assert result == "sess_topic_a"
|
||||
|
||||
def test_user_id_disambiguates_same_group_chat(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {
|
||||
"alice": {
|
||||
"session_id": "sess_alice",
|
||||
"origin": {"platform": "telegram", "chat_id": "-1001", "user_id": "alice"},
|
||||
"updated_at": "2026-01-01T00:00:00",
|
||||
},
|
||||
"bob": {
|
||||
"session_id": "sess_bob",
|
||||
"origin": {"platform": "telegram", "chat_id": "-1001", "user_id": "bob"},
|
||||
"updated_at": "2026-02-01T00:00:00",
|
||||
},
|
||||
})
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
|
||||
patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
|
||||
result = _find_session_id("telegram", "-1001", user_id="alice")
|
||||
|
||||
assert result == "sess_alice"
|
||||
|
||||
def test_ambiguous_same_group_chat_without_user_id_returns_none(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {
|
||||
"alice": {
|
||||
"session_id": "sess_alice",
|
||||
"origin": {"platform": "telegram", "chat_id": "-1001", "user_id": "alice"},
|
||||
"updated_at": "2026-01-01T00:00:00",
|
||||
},
|
||||
"bob": {
|
||||
"session_id": "sess_bob",
|
||||
"origin": {"platform": "telegram", "chat_id": "-1001", "user_id": "bob"},
|
||||
"updated_at": "2026-02-01T00:00:00",
|
||||
},
|
||||
})
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
|
||||
patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
|
||||
result = _find_session_id("telegram", "-1001")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_no_match_returns_none(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {
|
||||
"sess": {
|
||||
@@ -189,6 +229,35 @@ class TestMirrorToSession:
|
||||
assert (sessions_dir / "sess_topic_a.jsonl").exists()
|
||||
assert not (sessions_dir / "sess_topic_b.jsonl").exists()
|
||||
|
||||
def test_successful_mirror_uses_user_id_for_group_session(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {
|
||||
"alice": {
|
||||
"session_id": "sess_alice",
|
||||
"origin": {"platform": "telegram", "chat_id": "-1001", "user_id": "alice"},
|
||||
"updated_at": "2026-01-01T00:00:00",
|
||||
},
|
||||
"bob": {
|
||||
"session_id": "sess_bob",
|
||||
"origin": {"platform": "telegram", "chat_id": "-1001", "user_id": "bob"},
|
||||
"updated_at": "2026-02-01T00:00:00",
|
||||
},
|
||||
})
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
|
||||
patch.object(mirror_mod, "_SESSIONS_INDEX", index_file), \
|
||||
patch("gateway.mirror._append_to_sqlite"):
|
||||
result = mirror_to_session(
|
||||
"telegram",
|
||||
"-1001",
|
||||
"Hello group!",
|
||||
source_label="cli",
|
||||
user_id="alice",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert (sessions_dir / "sess_alice.jsonl").exists()
|
||||
assert not (sessions_dir / "sess_bob.jsonl").exists()
|
||||
|
||||
def test_no_matching_session(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {})
|
||||
|
||||
|
||||
@@ -167,6 +167,39 @@ class TestSendMessageTool:
|
||||
media_files=[],
|
||||
)
|
||||
|
||||
def test_mirror_receives_current_session_user_id(self):
|
||||
config, _telegram_cfg = _make_config()
|
||||
|
||||
with patch("gateway.config.load_gateway_config", return_value=config), \
|
||||
patch("tools.interrupt.is_interrupted", return_value=False), \
|
||||
patch("model_tools._run_async", side_effect=_run_async_immediately), \
|
||||
patch("tools.send_message_tool._send_to_platform", new=AsyncMock(return_value={"success": True})), \
|
||||
patch("gateway.session_context.get_session_env") as get_session_env_mock, \
|
||||
patch("gateway.mirror.mirror_to_session", return_value=True) as mirror_mock:
|
||||
get_session_env_mock.side_effect = lambda name, default="": {
|
||||
"HERMES_SESSION_PLATFORM": "telegram",
|
||||
"HERMES_SESSION_USER_ID": "user-123",
|
||||
}.get(name, default)
|
||||
result = json.loads(
|
||||
send_message_tool(
|
||||
{
|
||||
"action": "send",
|
||||
"target": "telegram:12345",
|
||||
"message": "hello",
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
mirror_mock.assert_called_once_with(
|
||||
"telegram",
|
||||
"12345",
|
||||
"hello",
|
||||
source_label="telegram",
|
||||
thread_id=None,
|
||||
user_id="user-123",
|
||||
)
|
||||
|
||||
def test_top_level_send_failure_redacts_query_token(self):
|
||||
config, _telegram_cfg = _make_config()
|
||||
leaked = "very-secret-query-token-123456"
|
||||
|
||||
@@ -299,7 +299,15 @@ def _handle_send(args):
|
||||
from gateway.mirror import mirror_to_session
|
||||
from gateway.session_context import get_session_env
|
||||
source_label = get_session_env("HERMES_SESSION_PLATFORM", "cli")
|
||||
if mirror_to_session(platform_name, chat_id, mirror_text, source_label=source_label, thread_id=thread_id):
|
||||
user_id = get_session_env("HERMES_SESSION_USER_ID", "") or None
|
||||
if mirror_to_session(
|
||||
platform_name,
|
||||
chat_id,
|
||||
mirror_text,
|
||||
source_label=source_label,
|
||||
thread_id=thread_id,
|
||||
user_id=user_id,
|
||||
):
|
||||
result["mirrored"] = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user