Compare commits

...

2 Commits

Author SHA1 Message Date
kshitijk4poor
e47abf5c69 fix: follow-up improvements for watch notification routing (#9537)
- Populate watcher_* routing fields for watch-only processes (not just
  notify_on_complete), so watch-pattern events carry direct metadata
  instead of relying solely on session_key parsing fallback
- Extract _parse_session_key() helper to dedupe session key parsing
  at two call sites in gateway/run.py
- Add negative test proving cross-thread leakage doesn't happen
- Add edge-case tests for _build_process_event_source returning None
  (empty evt, invalid platform, short session_key)
- Add unit tests for _parse_session_key helper
2026-04-15 22:52:30 +05:30
etcircle
f871ec5a69 fix(gateway): route synthetic background events by session 2026-04-15 22:47:15 +05:30
6 changed files with 375 additions and 41 deletions

View File

@@ -482,6 +482,23 @@ def _resolve_hermes_bin() -> Optional[list[str]]:
return None return None
def _parse_session_key(session_key: str) -> "dict | None":
"""Parse a session key into its component parts.
Session keys follow the format ``agent:main:{platform}:{chat_type}:{chat_id}``.
Returns a dict with ``platform``, ``chat_type``, and ``chat_id`` keys,
or None if the key doesn't match the expected format.
"""
parts = session_key.split(":")
if len(parts) >= 5 and parts[0] == "agent" and parts[1] == "main":
return {
"platform": parts[2],
"chat_type": parts[3],
"chat_id": parts[4],
}
return None
def _format_gateway_process_notification(evt: dict) -> "str | None": def _format_gateway_process_notification(evt: dict) -> "str | None":
"""Format a watch pattern event from completion_queue into a [SYSTEM:] message.""" """Format a watch pattern event from completion_queue into a [SYSTEM:] message."""
evt_type = evt.get("type", "completion") evt_type = evt.get("type", "completion")
@@ -1489,12 +1506,11 @@ class GatewayRunner:
notified: set = set() notified: set = set()
for session_key in active: for session_key in active:
# Parse platform + chat_id from the session key. # Parse platform + chat_id from the session key.
# Format: agent:main:{platform}:{chat_type}:{chat_id}[:{extra}...] _parsed = _parse_session_key(session_key)
parts = session_key.split(":") if not _parsed:
if len(parts) < 5:
continue continue
platform_str = parts[2] platform_str = _parsed["platform"]
chat_id = parts[4] chat_id = _parsed["chat_id"]
# Deduplicate: one notification per chat, even if multiple # Deduplicate: one notification per chat, even if multiple
# sessions (different users/threads) share the same chat. # sessions (different users/threads) share the same chat.
@@ -3958,7 +3974,7 @@ class GatewayRunner:
synth_text = _format_gateway_process_notification(evt) synth_text = _format_gateway_process_notification(evt)
if synth_text: if synth_text:
try: try:
await self._inject_watch_notification(synth_text, event) await self._inject_watch_notification(synth_text, evt)
except Exception as e2: except Exception as e2:
logger.error("Watch notification injection error: %s", e2) logger.error("Watch notification injection error: %s", e2)
except Exception as e: except Exception as e:
@@ -7452,14 +7468,75 @@ class GatewayRunner:
return prefix return prefix
return user_text return user_text
def _build_process_event_source(self, evt: dict):
    """Work out where a synthetic background-process event should be routed.

    The persisted session-store origin for the event's session key wins
    when one exists. Otherwise a ``SessionSource`` is assembled from the
    routing fields on ``evt``, with platform / chat type / chat id falling
    back to the values parsed out of the session key. The currently active
    foreground event is deliberately never consulted, because that is what
    causes cross-topic bleed.

    Args:
        evt: Queued process event. Reads ``session_key``, ``platform``,
            ``chat_type``, ``chat_id``, ``thread_id``, ``user_id`` and
            ``user_name`` when present.

    Returns:
        The stored origin, a newly built ``SessionSource``, or ``None`` when
        platform, chat type or chat id cannot be determined or the platform
        name is not a valid ``Platform``.
    """
    from gateway.session import SessionSource

    def _clean(value) -> str:
        # Normalise missing / falsy values to "" and trim whitespace.
        return str(value or "").strip()

    session_key = _clean(evt.get("session_key"))
    fallback = {"platform": "", "chat_type": "", "chat_id": ""}

    if session_key:
        try:
            self.session_store._ensure_loaded()
            stored = self.session_store._entries.get(session_key)
            if stored and getattr(stored, "origin", None):
                return stored.origin
        except Exception as exc:
            # Best-effort lookup: a store failure falls through to the
            # metadata carried by the event itself.
            logger.debug(
                "Synthetic process-event session-store lookup failed for %s: %s",
                session_key,
                exc,
            )
        fallback = _parse_session_key(session_key) or fallback

    platform_name = _clean(evt.get("platform") or fallback["platform"]).lower()
    chat_type = _clean(evt.get("chat_type") or fallback["chat_type"]).lower()
    chat_id = _clean(evt.get("chat_id") or fallback["chat_id"])

    if not (platform_name and chat_type and chat_id):
        return None

    try:
        platform = Platform(platform_name)
    except Exception:
        logger.warning(
            "Synthetic process event has invalid platform metadata: %r",
            platform_name,
        )
        return None

    return SessionSource(
        platform=platform,
        chat_id=chat_id,
        chat_type=chat_type,
        thread_id=_clean(evt.get("thread_id")) or None,
        user_id=_clean(evt.get("user_id")) or None,
        user_name=_clean(evt.get("user_name")) or None,
    )
async def _inject_watch_notification(self, synth_text: str, evt: dict) -> None:
"""Inject a watch-pattern notification as a synthetic message event. """Inject a watch-pattern notification as a synthetic message event.
Uses the source from the original user event to route the notification Routing must come from the queued watch event itself, not from whatever
back to the correct chat/adapter. foreground message happened to be active when the queue was drained.
""" """
source = getattr(original_event, "source", None) source = self._build_process_event_source(evt)
if not source: if not source:
logger.warning(
"Dropping watch notification with no routing metadata for process %s",
evt.get("session_id", "unknown"),
)
return return
platform_name = source.platform.value if hasattr(source.platform, "value") else str(source.platform) platform_name = source.platform.value if hasattr(source.platform, "value") else str(source.platform)
adapter = None adapter = None
@@ -7477,7 +7554,12 @@ class GatewayRunner:
source=source, source=source,
internal=True, internal=True,
) )
logger.info("Watch pattern notification — injecting for %s", platform_name) logger.info(
"Watch pattern notification — injecting for %s chat=%s thread=%s",
platform_name,
source.chat_id,
source.thread_id,
)
await adapter.handle_message(synth_event) await adapter.handle_message(synth_event)
except Exception as e: except Exception as e:
logger.error("Watch notification injection error: %s", e) logger.error("Watch notification injection error: %s", e)
@@ -7547,33 +7629,42 @@ class GatewayRunner:
f"Command: {session.command}\n" f"Command: {session.command}\n"
f"Output:\n{_out}]" f"Output:\n{_out}]"
) )
source = self._build_process_event_source({
"session_id": session_id,
"session_key": session_key,
"platform": platform_name,
"chat_id": chat_id,
"thread_id": thread_id,
"user_id": user_id,
"user_name": user_name,
})
if not source:
logger.warning(
"Dropping completion notification with no routing metadata for process %s",
session_id,
)
break
adapter = None adapter = None
for p, a in self.adapters.items(): for p, a in self.adapters.items():
if p.value == platform_name: if p == source.platform:
adapter = a adapter = a
break break
if adapter and chat_id: if adapter and source.chat_id:
try: try:
from gateway.platforms.base import MessageEvent, MessageType from gateway.platforms.base import MessageEvent, MessageType
from gateway.session import SessionSource
from gateway.config import Platform
_platform_enum = Platform(platform_name)
_source = SessionSource(
platform=_platform_enum,
chat_id=chat_id,
thread_id=thread_id or None,
user_id=user_id or None,
user_name=user_name or None,
)
synth_event = MessageEvent( synth_event = MessageEvent(
text=synth_text, text=synth_text,
message_type=MessageType.TEXT, message_type=MessageType.TEXT,
source=_source, source=source,
internal=True, internal=True,
) )
logger.info( logger.info(
"Process %s finished — injecting agent notification for session %s", "Process %s finished — injecting agent notification for session %s chat=%s thread=%s",
session_id, session_key, session_id,
session_key,
source.chat_id,
source.thread_id,
) )
await adapter.handle_message(synth_event) await adapter.handle_message(synth_event)
except Exception as e: except Exception as e:

View File

@@ -14,7 +14,7 @@ from unittest.mock import AsyncMock, patch
import pytest import pytest
from gateway.config import GatewayConfig, Platform from gateway.config import GatewayConfig, Platform
from gateway.run import GatewayRunner from gateway.run import GatewayRunner, _parse_session_key
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -45,7 +45,7 @@ def _build_runner(monkeypatch, tmp_path, mode: str) -> GatewayRunner:
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
runner = GatewayRunner(GatewayConfig()) runner = GatewayRunner(GatewayConfig())
adapter = SimpleNamespace(send=AsyncMock()) adapter = SimpleNamespace(send=AsyncMock(), handle_message=AsyncMock())
runner.adapters[Platform.TELEGRAM] = adapter runner.adapters[Platform.TELEGRAM] = adapter
return runner return runner
@@ -243,3 +243,156 @@ async def test_no_thread_id_sends_no_metadata(monkeypatch, tmp_path):
assert adapter.send.await_count == 1 assert adapter.send.await_count == 1
_, kwargs = adapter.send.call_args _, kwargs = adapter.send.call_args
assert kwargs["metadata"] is None assert kwargs["metadata"] is None
@pytest.mark.asyncio
async def test_inject_watch_notification_routes_from_session_store_origin(monkeypatch, tmp_path):
    """The injected watch event carries the origin stored for its session key."""
    from gateway.session import SessionSource

    session_key = "agent:main:telegram:group:-100:42"
    runner = _build_runner(monkeypatch, tmp_path, "all")
    telegram = runner.adapters[Platform.TELEGRAM]

    stored_origin = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-100",
        chat_type="group",
        thread_id="42",
        user_id="123",
        user_name="Emiliyan",
    )
    runner.session_store._entries[session_key] = SimpleNamespace(origin=stored_origin)

    # The event itself only names the session; all routing detail must come
    # from the stored origin.
    await runner._inject_watch_notification(
        "[SYSTEM: Background process matched]",
        {"session_id": "proc_watch", "session_key": session_key},
    )

    telegram.handle_message.assert_awaited_once()
    injected = telegram.handle_message.await_args.args[0]
    assert injected.internal is True

    routed = injected.source
    assert routed.platform == Platform.TELEGRAM
    assert routed.chat_id == "-100"
    assert routed.chat_type == "group"
    assert routed.thread_id == "42"
    assert routed.user_id == "123"
    assert routed.user_name == "Emiliyan"
def test_build_process_event_source_falls_back_to_session_key_chat_type(monkeypatch, tmp_path):
    """With no ``chat_type`` on the event, the session key's chat type is used."""
    runner = _build_runner(monkeypatch, tmp_path, "all")

    # Note: no "chat_type" entry here; "group" only appears in the session key.
    source = runner._build_process_event_source(
        {
            "session_id": "proc_watch",
            "session_key": "agent:main:telegram:group:-100:42",
            "platform": "telegram",
            "chat_id": "-100",
            "thread_id": "42",
            "user_id": "123",
            "user_name": "Emiliyan",
        }
    )

    assert source is not None
    assert source.platform == Platform.TELEGRAM
    assert source.chat_id == "-100"
    assert source.chat_type == "group"
    assert source.thread_id == "42"
    assert source.user_id == "123"
    assert source.user_name == "Emiliyan"
@pytest.mark.asyncio
async def test_inject_watch_notification_ignores_foreground_event_source(monkeypatch, tmp_path):
    """Negative test: watch notification must NOT route to the foreground thread."""
    from gateway.session import SessionSource

    process_session_key = "agent:main:telegram:group:-100:42"
    runner = _build_runner(monkeypatch, tmp_path, "all")
    telegram = runner.adapters[Platform.TELEGRAM]

    # The session store remembers the thread the process was started from
    # (thread 42).
    process_origin = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-100",
        chat_type="group",
        thread_id="42",
        user_id="proc_owner",
        user_name="alice",
    )
    runner.session_store._entries[process_session_key] = SimpleNamespace(origin=process_origin)

    # Only the queued event dict is handed over; no foreground event exists
    # for the runner to borrow a source from.
    queued_evt = {
        "session_id": "proc_cross_thread",
        "session_key": process_session_key,
    }
    await runner._inject_watch_notification("[SYSTEM: watch match]", queued_evt)

    telegram.handle_message.assert_awaited_once()
    injected = telegram.handle_message.await_args.args[0]
    # Routing must land on thread 42 (the process origin), not any other thread.
    assert injected.source.thread_id == "42"
    assert injected.source.user_id == "proc_owner"
def test_build_process_event_source_returns_none_for_empty_evt(monkeypatch, tmp_path):
    """An event with no session_key and no platform metadata resolves to None."""
    orphan_evt = {"session_id": "proc_orphan"}
    runner = _build_runner(monkeypatch, tmp_path, "all")
    assert runner._build_process_event_source(orphan_evt) is None
def test_build_process_event_source_returns_none_for_invalid_platform(monkeypatch, tmp_path):
    """A platform string that is not a known platform resolves to None."""
    runner = _build_runner(monkeypatch, tmp_path, "all")
    resolved = runner._build_process_event_source(
        {
            "session_id": "proc_bad",
            "platform": "not_a_real_platform",
            "chat_type": "dm",
            "chat_id": "123",
        }
    )
    assert resolved is None
def test_build_process_event_source_returns_none_for_short_session_key(monkeypatch, tmp_path):
    """A session key with fewer than 5 parts yields no metadata, so the result is None."""
    runner = _build_runner(monkeypatch, tmp_path, "all")
    truncated_key = "agent:main:telegram"  # Too few parts
    resolved = runner._build_process_event_source(
        {"session_id": "proc_short", "session_key": truncated_key}
    )
    assert resolved is None
# ---------------------------------------------------------------------------
# _parse_session_key helper
# ---------------------------------------------------------------------------
def test_parse_session_key_valid():
    """A well-formed five-part key yields all three routing fields."""
    expected = {"platform": "telegram", "chat_type": "group", "chat_id": "-100"}
    assert _parse_session_key("agent:main:telegram:group:-100") == expected
def test_parse_session_key_with_extra_parts():
    """Trailing parts (thread_id etc.) are dropped; only the first 5 matter."""
    expected = {"platform": "discord", "chat_type": "group", "chat_id": "chan123"}
    assert _parse_session_key("agent:main:discord:group:chan123:thread456") == expected
def test_parse_session_key_too_short():
    """Keys with fewer than five parts, including the empty string, give None."""
    for short_key in ("agent:main:telegram", ""):
        assert _parse_session_key(short_key) is None
def test_parse_session_key_wrong_prefix():
    """Keys that do not start with ``agent:main`` give None."""
    for bad_prefix_key in ("cron:main:telegram:dm:123", "agent:cron:telegram:dm:123"):
        assert _parse_session_key(bad_prefix_key) is None

View File

@@ -230,6 +230,59 @@ async def test_notify_on_complete_preserves_user_identity(monkeypatch, tmp_path)
assert event.source.user_name == "alice" assert event.source.user_name == "alice"
@pytest.mark.asyncio
async def test_notify_on_complete_uses_session_store_origin_for_group_topic(monkeypatch, tmp_path):
    """A completion notification for a group topic is sourced from the stored origin.

    The watcher dict carries no user identity, so the asserted ``user_id`` /
    ``user_name`` can only come from the session-store entry.
    """
    import tools.process_registry as pr_module
    from gateway.session import SessionSource

    session_key = "agent:main:telegram:group:-100:42"

    finished = SimpleNamespace(
        output_buffer="done\n", exited=True, exit_code=0, command="echo test"
    )
    monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry([finished]))

    async def _no_wait(*_args, **_kwargs):
        return None

    # Replace asyncio.sleep so the watcher loop does not actually wait.
    monkeypatch.setattr(asyncio, "sleep", _no_wait)

    runner = GatewayRunner(GatewayConfig())
    telegram = SimpleNamespace(send=AsyncMock(), handle_message=AsyncMock())
    runner.adapters[Platform.TELEGRAM] = telegram

    stored_origin = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-100",
        chat_type="group",
        thread_id="42",
        user_id="user-42",
        user_name="alice",
    )
    runner.session_store._entries[session_key] = SimpleNamespace(origin=stored_origin)

    await runner._run_process_watcher(
        {
            "session_id": "proc_test_internal",
            "check_interval": 0,
            "session_key": session_key,
            "platform": "telegram",
            "chat_id": "-100",
            "thread_id": "42",
            "notify_on_complete": True,
        }
    )

    assert telegram.handle_message.await_count == 1
    injected = telegram.handle_message.await_args.args[0]
    assert injected.internal is True

    routed = injected.source
    assert routed.platform == Platform.TELEGRAM
    assert routed.chat_id == "-100"
    assert routed.chat_type == "group"
    assert routed.thread_id == "42"
    assert routed.user_id == "user-42"
    assert routed.user_name == "alice"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_none_user_id_skips_pairing(monkeypatch, tmp_path): async def test_none_user_id_skips_pairing(monkeypatch, tmp_path):
"""A non-internal event with user_id=None should be silently dropped.""" """A non-internal event with user_id=None should be silently dropped."""

View File

@@ -92,6 +92,25 @@ class TestCheckWatchPatterns:
assert "disk full" in evt["output"] assert "disk full" in evt["output"]
assert evt["session_id"] == "proc_test_watch" assert evt["session_id"] == "proc_test_watch"
def test_match_carries_session_key_and_watcher_routing_metadata(self, registry):
    """A watch-match event carries the session key and the watcher_* routing fields."""
    routing = {
        "platform": "telegram",
        "chat_id": "-100",
        "user_id": "u123",
        "user_name": "alice",
        "thread_id": "42",
    }

    session = _make_session(watch_patterns=["ERROR"])
    session.session_key = "agent:main:telegram:group:-100:42"
    # Each routing field is stored on the session as ``watcher_<field>``.
    for field, value in routing.items():
        setattr(session, f"watcher_{field}", value)

    registry._check_watch_patterns(session, "ERROR: disk full\n")
    queued = registry.completion_queue.get_nowait()

    assert queued["session_key"] == "agent:main:telegram:group:-100:42"
    for field, value in routing.items():
        assert queued[field] == value
def test_multiple_patterns(self, registry): def test_multiple_patterns(self, registry):
"""First matching pattern is reported.""" """First matching pattern is reported."""
session = _make_session(watch_patterns=["WARN", "ERROR"]) session = _make_session(watch_patterns=["WARN", "ERROR"])

View File

@@ -191,9 +191,15 @@ class ProcessRegistry:
session._watch_disabled = True session._watch_disabled = True
self.completion_queue.put({ self.completion_queue.put({
"session_id": session.id, "session_id": session.id,
"session_key": session.session_key,
"command": session.command, "command": session.command,
"type": "watch_disabled", "type": "watch_disabled",
"suppressed": session._watch_suppressed, "suppressed": session._watch_suppressed,
"platform": session.watcher_platform,
"chat_id": session.watcher_chat_id,
"user_id": session.watcher_user_id,
"user_name": session.watcher_user_name,
"thread_id": session.watcher_thread_id,
"message": ( "message": (
f"Watch patterns disabled for process {session.id}" f"Watch patterns disabled for process {session.id}"
f"too many matches ({session._watch_suppressed} suppressed). " f"too many matches ({session._watch_suppressed} suppressed). "
@@ -219,11 +225,17 @@ class ProcessRegistry:
self.completion_queue.put({ self.completion_queue.put({
"session_id": session.id, "session_id": session.id,
"session_key": session.session_key,
"command": session.command, "command": session.command,
"type": "watch_match", "type": "watch_match",
"pattern": matched_pattern, "pattern": matched_pattern,
"output": output, "output": output,
"suppressed": suppressed, "suppressed": suppressed,
"platform": session.watcher_platform,
"chat_id": session.watcher_chat_id,
"user_id": session.watcher_user_id,
"user_name": session.watcher_user_name,
"thread_id": session.watcher_thread_id,
}) })
@staticmethod @staticmethod

View File

@@ -1384,14 +1384,10 @@ def terminal_tool(
if pty_disabled_reason: if pty_disabled_reason:
result_data["pty_note"] = pty_disabled_reason result_data["pty_note"] = pty_disabled_reason
# Mark for agent notification on completion # Populate routing metadata on the session so that
if notify_on_complete and background: # watch-pattern and completion notifications can be
proc_session.notify_on_complete = True # routed back to the correct chat/thread.
result_data["notify_on_complete"] = True if background and (notify_on_complete or watch_patterns):
# In gateway mode, auto-register a fast watcher so the
# gateway can detect completion and trigger a new agent
# turn. CLI mode uses the completion_queue directly.
from gateway.session_context import get_session_env as _gse from gateway.session_context import get_session_env as _gse
_gw_platform = _gse("HERMES_SESSION_PLATFORM", "") _gw_platform = _gse("HERMES_SESSION_PLATFORM", "")
if _gw_platform: if _gw_platform:
@@ -1404,16 +1400,26 @@ def terminal_tool(
proc_session.watcher_user_id = _gw_user_id proc_session.watcher_user_id = _gw_user_id
proc_session.watcher_user_name = _gw_user_name proc_session.watcher_user_name = _gw_user_name
proc_session.watcher_thread_id = _gw_thread_id proc_session.watcher_thread_id = _gw_thread_id
# Mark for agent notification on completion
if notify_on_complete and background:
proc_session.notify_on_complete = True
result_data["notify_on_complete"] = True
# In gateway mode, auto-register a fast watcher so the
# gateway can detect completion and trigger a new agent
# turn. CLI mode uses the completion_queue directly.
if proc_session.watcher_platform:
proc_session.watcher_interval = 5 proc_session.watcher_interval = 5
process_registry.pending_watchers.append({ process_registry.pending_watchers.append({
"session_id": proc_session.id, "session_id": proc_session.id,
"check_interval": 5, "check_interval": 5,
"session_key": session_key, "session_key": session_key,
"platform": _gw_platform, "platform": proc_session.watcher_platform,
"chat_id": _gw_chat_id, "chat_id": proc_session.watcher_chat_id,
"user_id": _gw_user_id, "user_id": proc_session.watcher_user_id,
"user_name": _gw_user_name, "user_name": proc_session.watcher_user_name,
"thread_id": _gw_thread_id, "thread_id": proc_session.watcher_thread_id,
"notify_on_complete": True, "notify_on_complete": True,
}) })