mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-07-04 09:07:20 +08:00
Compare commits
6 Commits
fix/tui-qu
...
fix/gatewa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bf952f5705 | ||
|
|
71e13373de | ||
|
|
6ea0f72885 | ||
|
|
b8cd48db9c | ||
|
|
32a326585f | ||
|
|
b37d81e302 |
@@ -4335,7 +4335,8 @@ class BasePlatformAdapter(ABC):
|
||||
# Rewrite ``event.source.thread_id`` via the installed recovery hook
|
||||
# (Telegram DM topic mode) so the session key, guard checks, and
|
||||
# downstream delivery all agree on the same lane.
|
||||
self._apply_topic_recovery(event)
|
||||
# Offloaded: the sync hook must not block the loop.
|
||||
await asyncio.to_thread(self._apply_topic_recovery, event)
|
||||
|
||||
session_key = build_session_key(
|
||||
event.source,
|
||||
|
||||
104
gateway/run.py
104
gateway/run.py
@@ -2777,8 +2777,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
# Initialize session database for session_search tool support
|
||||
self._session_db = None
|
||||
try:
|
||||
from hermes_state import SessionDB
|
||||
self._session_db = SessionDB()
|
||||
from hermes_state import AsyncSessionDB, SessionDB
|
||||
self._session_db = AsyncSessionDB(SessionDB())
|
||||
except Exception as e:
|
||||
# WARNING (not DEBUG) so the failure appears in errors.log — matches
|
||||
# cli.py's handling of the same init path. Users hitting NFS-mounted
|
||||
@@ -2799,7 +2799,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
from hermes_cli.config import load_config as _load_full_config
|
||||
_sess_cfg = (_load_full_config().get("sessions") or {})
|
||||
if _sess_cfg.get("auto_prune", False):
|
||||
self._session_db.maybe_auto_prune_and_vacuum(
|
||||
# Construction-time, before the loop serves traffic; sync DB is fine.
|
||||
self._session_db._db.maybe_auto_prune_and_vacuum(
|
||||
retention_days=int(_sess_cfg.get("retention_days", 90)),
|
||||
min_interval_hours=int(_sess_cfg.get("min_interval_hours", 24)),
|
||||
vacuum=bool(_sess_cfg.get("vacuum_after_prune", True)),
|
||||
@@ -3254,6 +3255,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
session_db = getattr(self, "_session_db", None)
|
||||
if session_db is None:
|
||||
return False
|
||||
# Runs off-loop (always via asyncio.to_thread); use the sync handle.
|
||||
session_db = getattr(session_db, "_db", session_db)
|
||||
try:
|
||||
raw = session_db.is_telegram_topic_mode_enabled(
|
||||
chat_id=str(source.chat_id),
|
||||
@@ -3351,6 +3354,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
session_db = getattr(self, "_session_db", None)
|
||||
if session_db is None or not source.chat_id or not source.thread_id:
|
||||
return
|
||||
# Runs off-loop (always via asyncio.to_thread); use the sync handle.
|
||||
session_db = getattr(session_db, "_db", session_db)
|
||||
session_db.bind_telegram_topic(
|
||||
chat_id=str(source.chat_id),
|
||||
thread_id=str(source.thread_id),
|
||||
@@ -3419,6 +3424,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
session_db = getattr(self, "_session_db", None)
|
||||
if session_db is None:
|
||||
return None
|
||||
# Runs off-loop (always via asyncio.to_thread); use the sync handle.
|
||||
session_db = getattr(session_db, "_db", session_db)
|
||||
try:
|
||||
bindings = session_db.list_telegram_topic_bindings_for_chat(
|
||||
chat_id=str(source.chat_id),
|
||||
@@ -6552,23 +6559,23 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
if self._session_db is None:
|
||||
await asyncio.sleep(interval)
|
||||
continue
|
||||
pending = await asyncio.to_thread(self._session_db.list_pending_handoffs)
|
||||
pending = await self._session_db.list_pending_handoffs()
|
||||
for row in pending:
|
||||
session_id = row.get("id")
|
||||
if not session_id:
|
||||
continue
|
||||
if not await asyncio.to_thread(self._session_db.claim_handoff, session_id):
|
||||
if not await self._session_db.claim_handoff(session_id):
|
||||
# Another tick or another gateway already claimed it.
|
||||
continue
|
||||
try:
|
||||
await self._process_handoff(row)
|
||||
await asyncio.to_thread(self._session_db.complete_handoff, session_id)
|
||||
await self._session_db.complete_handoff(session_id)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Handoff for session %s failed: %s",
|
||||
session_id, exc, exc_info=True,
|
||||
)
|
||||
await asyncio.to_thread(self._session_db.fail_handoff, session_id, str(exc))
|
||||
await self._session_db.fail_handoff(session_id, str(exc))
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
@@ -7417,8 +7424,11 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
# old gateway's connection holding the WAL lock until Python
|
||||
# actually exits — causing 'database is locked' errors when
|
||||
# the new gateway tries to open the same file.
|
||||
for _db_holder in (self, getattr(self, "session_store", None)):
|
||||
_db = getattr(_db_holder, "_db", None) if _db_holder else None
|
||||
# ``self`` holds the DB at ``_session_db`` (an AsyncSessionDB facade);
|
||||
# unwrap to the sync handle. ``session_store`` holds it at ``_db``.
|
||||
_self_db = getattr(self, "_session_db", None)
|
||||
_self_db = getattr(_self_db, "_db", _self_db)
|
||||
for _db in (_self_db, getattr(getattr(self, "session_store", None), "_db", None)):
|
||||
if _db is None or not hasattr(_db, "close"):
|
||||
continue
|
||||
try:
|
||||
@@ -8656,7 +8666,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
break
|
||||
|
||||
if canonical == "new":
|
||||
if self._is_telegram_topic_root_lobby(source):
|
||||
if await asyncio.to_thread(self._is_telegram_topic_root_lobby, source):
|
||||
return self._telegram_topic_root_new_message()
|
||||
async def _do_reset():
|
||||
return await self._handle_reset_command(event)
|
||||
@@ -9117,7 +9127,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
# No bare text matching — "yes" in normal conversation must not trigger
|
||||
# execution of a dangerous command.
|
||||
|
||||
if self._is_telegram_topic_root_lobby(source):
|
||||
if await asyncio.to_thread(self._is_telegram_topic_root_lobby, source):
|
||||
# Debounce the lobby reminder so a user who forgets about
|
||||
# topic mode and fires ten prompts doesn't get ten copies.
|
||||
if self._should_send_telegram_lobby_reminder(source):
|
||||
@@ -9598,7 +9608,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
# Topic-mode DMs: rewrite a stale/foreign thread_id to the user's
|
||||
# last-active topic so a cross-topic Reply or stripped plain reply
|
||||
# doesn't fragment the conversation across sessions.
|
||||
recovered = self._recover_telegram_topic_thread_id(source)
|
||||
recovered = await asyncio.to_thread(self._recover_telegram_topic_thread_id, source)
|
||||
if recovered is not None:
|
||||
logger.info(
|
||||
"telegram topic recovery: chat=%s user=%s %r -> %s",
|
||||
@@ -9613,12 +9623,12 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
session_key = session_entry.session_key
|
||||
self._cache_session_source(session_key, source)
|
||||
if self._is_telegram_topic_lane(source):
|
||||
if await asyncio.to_thread(self._is_telegram_topic_lane, source):
|
||||
try:
|
||||
binding = self._session_db.get_telegram_topic_binding(
|
||||
binding = (await self._session_db.get_telegram_topic_binding(
|
||||
chat_id=str(source.chat_id),
|
||||
thread_id=str(source.thread_id),
|
||||
) if self._session_db else None
|
||||
)) if self._session_db else None
|
||||
except Exception:
|
||||
logger.debug("Failed to read Telegram topic binding", exc_info=True)
|
||||
binding = None
|
||||
@@ -9632,7 +9642,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
# a compression parent, so this is cheap and safe.
|
||||
if bound_session_id and self._session_db is not None:
|
||||
try:
|
||||
canonical_session_id = self._session_db.get_compression_tip(
|
||||
canonical_session_id = await self._session_db.get_compression_tip(
|
||||
bound_session_id,
|
||||
)
|
||||
except Exception:
|
||||
@@ -9661,12 +9671,13 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
bound_session_id
|
||||
and bound_session_id != str(binding.get("session_id") or "")
|
||||
):
|
||||
self._sync_telegram_topic_binding(
|
||||
await asyncio.to_thread(
|
||||
self._sync_telegram_topic_binding,
|
||||
source, session_entry, reason="compression-tip-walk",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
self._record_telegram_topic_binding(source, session_entry)
|
||||
await asyncio.to_thread(self._record_telegram_topic_binding, source, session_entry)
|
||||
except Exception:
|
||||
logger.debug("Failed to record Telegram topic binding", exc_info=True)
|
||||
# Capture and immediately consume was_auto_reset so it does not
|
||||
@@ -10037,7 +10048,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
skip_memory=True,
|
||||
enabled_toolsets=["memory"],
|
||||
session_id=session_entry.session_id,
|
||||
session_db=self._session_db,
|
||||
session_db=getattr(self._session_db, "_db", self._session_db),
|
||||
)
|
||||
try:
|
||||
# The hygiene agent rotates the session
|
||||
@@ -10070,7 +10081,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
if _hyg_rotated:
|
||||
session_entry.session_id = _hyg_new_sid
|
||||
self.session_store._save()
|
||||
self._sync_telegram_topic_binding(
|
||||
await asyncio.to_thread(
|
||||
self._sync_telegram_topic_binding,
|
||||
source, session_entry,
|
||||
reason="hygiene-compression",
|
||||
)
|
||||
@@ -10424,7 +10436,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
# prompt caching. Refreshing here makes the guard fire only on a
|
||||
# DIFFERENT process's writes. Uses the (possibly compaction-
|
||||
# updated) live session_id. Fail-safe inside the helper.
|
||||
self._refresh_agent_cache_message_count(
|
||||
await self._refresh_agent_cache_message_count(
|
||||
session_key, session_entry.session_id
|
||||
)
|
||||
|
||||
@@ -10461,7 +10473,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
if agent_result.get("session_id") and agent_result["session_id"] != session_entry.session_id:
|
||||
session_entry.session_id = agent_result["session_id"]
|
||||
self.session_store._save()
|
||||
self._sync_telegram_topic_binding(
|
||||
await asyncio.to_thread(
|
||||
self._sync_telegram_topic_binding,
|
||||
source, session_entry, reason="agent-result-compression",
|
||||
)
|
||||
|
||||
@@ -10664,7 +10677,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
# forever (#35809 — regression of the #9893/#10063 auto-reset).
|
||||
# No-op on non-topic lanes.
|
||||
session_entry = new_entry
|
||||
self._sync_telegram_topic_binding(
|
||||
await asyncio.to_thread(
|
||||
self._sync_telegram_topic_binding,
|
||||
source, session_entry, reason="compression-exhausted-reset",
|
||||
)
|
||||
response = (response or "") + (
|
||||
@@ -12055,7 +12069,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
chat_name=source.chat_name,
|
||||
chat_type=source.chat_type,
|
||||
thread_id=source.thread_id,
|
||||
session_db=self._session_db,
|
||||
session_db=getattr(self._session_db, "_db", self._session_db),
|
||||
fallback_model=self._fallback_model,
|
||||
)
|
||||
try:
|
||||
@@ -12274,7 +12288,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
title: str,
|
||||
) -> None:
|
||||
"""Best-effort rename of a Telegram DM topic when Hermes auto-titles a session."""
|
||||
if not self._is_telegram_topic_lane(source) or not source.chat_id or not source.thread_id:
|
||||
if not await asyncio.to_thread(self._is_telegram_topic_lane, source) or not source.chat_id or not source.thread_id:
|
||||
return
|
||||
|
||||
# Operator can fully disable per-topic auto-rename via
|
||||
@@ -12308,7 +12322,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
session_db = getattr(self, "_session_db", None)
|
||||
if session_db is not None:
|
||||
try:
|
||||
binding = session_db.get_telegram_topic_binding(
|
||||
binding = await session_db.get_telegram_topic_binding(
|
||||
chat_id=str(source.chat_id),
|
||||
thread_id=str(source.thread_id),
|
||||
)
|
||||
@@ -12455,7 +12469,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
"5. /topic <id> inside a topic restores an old session into it."
|
||||
)
|
||||
|
||||
def _disable_telegram_topic_mode_for_chat(self, source: SessionSource) -> str:
|
||||
async def _disable_telegram_topic_mode_for_chat(self, source: SessionSource) -> str:
|
||||
"""Cleanly disable topic mode for a chat via /topic off."""
|
||||
if not self._session_db:
|
||||
from hermes_state import format_session_db_unavailable
|
||||
@@ -12465,7 +12479,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
return "Could not determine chat ID."
|
||||
# No-op if never enabled.
|
||||
try:
|
||||
currently_enabled = self._session_db.is_telegram_topic_mode_enabled(
|
||||
currently_enabled = await self._session_db.is_telegram_topic_mode_enabled(
|
||||
chat_id=chat_id,
|
||||
user_id=str(source.user_id or ""),
|
||||
)
|
||||
@@ -12474,7 +12488,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
if not currently_enabled:
|
||||
return "Multi-session topic mode is not currently enabled for this chat."
|
||||
try:
|
||||
self._session_db.disable_telegram_topic_mode(chat_id=chat_id)
|
||||
await self._session_db.disable_telegram_topic_mode(chat_id=chat_id)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to disable Telegram topic mode")
|
||||
return f"Failed to disable topic mode: {exc}"
|
||||
@@ -12492,7 +12506,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
)
|
||||
|
||||
|
||||
def _telegram_topic_root_status_message(self, source: SessionSource) -> str:
|
||||
async def _telegram_topic_root_status_message(self, source: SessionSource) -> str:
|
||||
lines = [
|
||||
"Telegram multi-session topics are enabled.",
|
||||
"",
|
||||
@@ -12502,7 +12516,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
"",
|
||||
]
|
||||
try:
|
||||
sessions = self._session_db.list_unlinked_telegram_sessions_for_user(
|
||||
sessions = await self._session_db.list_unlinked_telegram_sessions_for_user(
|
||||
chat_id=str(source.chat_id),
|
||||
user_id=str(source.user_id),
|
||||
limit=10,
|
||||
@@ -12541,11 +12555,11 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
async def _restore_telegram_topic_session(self, event: MessageEvent, raw_session_id: str) -> str:
|
||||
"""Restore an existing Telegram-owned Hermes session into this topic."""
|
||||
source = event.source
|
||||
session_id = self._session_db.resolve_session_id(raw_session_id.strip())
|
||||
session_id = await self._session_db.resolve_session_id(raw_session_id.strip())
|
||||
if not session_id:
|
||||
return f"Session not found: {raw_session_id.strip()}"
|
||||
|
||||
session = self._session_db.get_session(session_id)
|
||||
session = await self._session_db.get_session(session_id)
|
||||
if not session:
|
||||
return f"Session not found: {raw_session_id.strip()}"
|
||||
if str(session.get("source") or "") != "telegram":
|
||||
@@ -12553,8 +12567,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
if str(session.get("user_id") or "") != str(source.user_id):
|
||||
return "That session does not belong to this Telegram user."
|
||||
|
||||
linked = self._session_db.is_telegram_session_linked_to_topic(session_id=session_id)
|
||||
current_binding = self._session_db.get_telegram_topic_binding(
|
||||
linked = await self._session_db.is_telegram_session_linked_to_topic(session_id=session_id)
|
||||
current_binding = await self._session_db.get_telegram_topic_binding(
|
||||
chat_id=str(source.chat_id),
|
||||
thread_id=str(source.thread_id),
|
||||
)
|
||||
@@ -12564,7 +12578,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
|
||||
session_key = self._session_key_for_source(source)
|
||||
try:
|
||||
self._session_db.bind_telegram_topic(
|
||||
await self._session_db.bind_telegram_topic(
|
||||
chat_id=str(source.chat_id),
|
||||
thread_id=str(source.thread_id),
|
||||
user_id=str(source.user_id),
|
||||
@@ -12577,10 +12591,10 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
return "That session is already linked to another Telegram topic."
|
||||
raise
|
||||
|
||||
title = self._session_db.get_session_title(session_id) or session_id
|
||||
title = await self._session_db.get_session_title(session_id) or session_id
|
||||
last_assistant = None
|
||||
try:
|
||||
for message in reversed(self._session_db.get_messages(session_id)):
|
||||
for message in reversed(await self._session_db.get_messages(session_id)):
|
||||
if message.get("role") == "assistant" and message.get("content"):
|
||||
last_assistant = str(message.get("content"))
|
||||
break
|
||||
@@ -14605,7 +14619,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
if release_running_state:
|
||||
self._release_running_agent_state(session_key)
|
||||
|
||||
def _refresh_agent_cache_message_count(
|
||||
async def _refresh_agent_cache_message_count(
|
||||
self, session_key: str, session_id: Optional[str]
|
||||
) -> None:
|
||||
"""Re-baseline a cached agent's stored message_count after THIS turn.
|
||||
@@ -14637,7 +14651,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
if not _cache_lock or _cache is None:
|
||||
return
|
||||
try:
|
||||
_sess_row = self._session_db.get_session(session_id)
|
||||
_sess_row = await self._session_db.get_session(session_id)
|
||||
_live = _sess_row.get("message_count", 0) if _sess_row else None
|
||||
except Exception:
|
||||
return
|
||||
@@ -16320,7 +16334,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
_current_msg_count = None
|
||||
if self._session_db is not None and session_id:
|
||||
try:
|
||||
_sess_row = self._session_db.get_session(session_id)
|
||||
# run_sync is off-loop (executor); sync DB is fine.
|
||||
_sess_row = self._session_db._db.get_session(session_id)
|
||||
if _sess_row:
|
||||
_current_msg_count = _sess_row.get("message_count", 0)
|
||||
except Exception:
|
||||
@@ -16431,7 +16446,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
chat_type=source.chat_type,
|
||||
thread_id=source.thread_id,
|
||||
gateway_session_key=session_key,
|
||||
session_db=self._session_db,
|
||||
session_db=getattr(self._session_db, "_db", self._session_db),
|
||||
fallback_model=self._fallback_model,
|
||||
)
|
||||
if _cache_lock and _cache is not None:
|
||||
@@ -17018,7 +17033,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
and self._session_db is not None
|
||||
):
|
||||
try:
|
||||
_binding = self._session_db.get_telegram_topic_binding_by_session(
|
||||
# run_sync is off-loop (executor); sync DB is fine.
|
||||
_binding = self._session_db._db.get_telegram_topic_binding_by_session(
|
||||
session_id=agent_session_id,
|
||||
)
|
||||
if _binding and _binding.get("thread_id"):
|
||||
@@ -17143,7 +17159,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
||||
title,
|
||||
)
|
||||
maybe_auto_title(
|
||||
self._session_db,
|
||||
getattr(self._session_db, "_db", self._session_db),
|
||||
effective_session_id,
|
||||
message,
|
||||
final_response,
|
||||
|
||||
@@ -228,11 +228,11 @@ class GatewaySlashCommandsMixin:
|
||||
session_info = ""
|
||||
|
||||
if new_entry:
|
||||
header = self._telegram_topic_new_header(source) or t("gateway.reset.header_default")
|
||||
header = await asyncio.to_thread(self._telegram_topic_new_header, source) or t("gateway.reset.header_default")
|
||||
else:
|
||||
# No existing session, just create one
|
||||
new_entry = self.session_store.get_or_create_session(source, force_new=True)
|
||||
header = self._telegram_topic_new_header(source) or t("gateway.reset.header_new")
|
||||
header = await asyncio.to_thread(self._telegram_topic_new_header, source) or t("gateway.reset.header_new")
|
||||
|
||||
# Set session title if provided with /new <title>
|
||||
_title_arg = event.get_command_args().strip()
|
||||
@@ -246,7 +246,7 @@ class GatewaySlashCommandsMixin:
|
||||
_title_note = t("gateway.reset.title_rejected", error=str(e))
|
||||
if sanitized:
|
||||
try:
|
||||
self._session_db.set_session_title(new_entry.session_id, sanitized)
|
||||
await self._session_db.set_session_title(new_entry.session_id, sanitized)
|
||||
header = t("gateway.reset.header_titled", title=sanitized)
|
||||
except ValueError as e:
|
||||
_title_note = t("gateway.reset.title_error_untitled", error=str(e))
|
||||
@@ -262,9 +262,9 @@ class GatewaySlashCommandsMixin:
|
||||
# uses the freshly-created session. Without this, the binding
|
||||
# still points at the old session and the binding-lookup at the
|
||||
# top of _handle_message_with_agent would switch right back.
|
||||
if self._is_telegram_topic_lane(source) and new_entry is not None:
|
||||
if await asyncio.to_thread(self._is_telegram_topic_lane, source) and new_entry is not None:
|
||||
try:
|
||||
self._record_telegram_topic_binding(source, new_entry)
|
||||
await asyncio.to_thread(self._record_telegram_topic_binding, source, new_entry)
|
||||
except Exception:
|
||||
logger.debug("Failed to rebind Telegram topic after /new", exc_info=True)
|
||||
|
||||
@@ -498,11 +498,11 @@ class GatewaySlashCommandsMixin:
|
||||
db_total_tokens = 0
|
||||
if self._session_db:
|
||||
try:
|
||||
title = self._session_db.get_session_title(session_entry.session_id)
|
||||
title = await self._session_db.get_session_title(session_entry.session_id)
|
||||
except Exception:
|
||||
title = None
|
||||
try:
|
||||
row = self._session_db.get_session(session_entry.session_id)
|
||||
row = await self._session_db.get_session(session_entry.session_id)
|
||||
if isinstance(row, dict):
|
||||
session_row = row
|
||||
db_total_tokens = (
|
||||
@@ -1175,7 +1175,7 @@ class GatewaySlashCommandsMixin:
|
||||
# (Telegram DM topic recovery) before deriving the override key, so
|
||||
# the override is stored under the key the next message turn reads
|
||||
# (#30479).
|
||||
source = self._normalize_source_for_session_key(source)
|
||||
source = await asyncio.to_thread(self._normalize_source_for_session_key, source)
|
||||
session_key = self._session_key_for_source(source)
|
||||
override = self._session_model_overrides.get(session_key, {})
|
||||
if override:
|
||||
@@ -1308,7 +1308,7 @@ class GatewaySlashCommandsMixin:
|
||||
_sess_entry = _self.session_store.get_or_create_session(
|
||||
event.source
|
||||
)
|
||||
_sess_db.update_session_model(
|
||||
await _sess_db.update_session_model(
|
||||
_sess_entry.session_id, result.new_model
|
||||
)
|
||||
except Exception as exc:
|
||||
@@ -1539,7 +1539,7 @@ class GatewaySlashCommandsMixin:
|
||||
# override just stored below (Closes #48031).
|
||||
if getattr(_sess_entry, "was_auto_reset", False):
|
||||
_sess_entry.was_auto_reset = False
|
||||
_sess_db.update_session_model(
|
||||
await _sess_db.update_session_model(
|
||||
_sess_entry.session_id, result.new_model
|
||||
)
|
||||
except Exception as exc:
|
||||
@@ -2331,7 +2331,7 @@ class GatewaySlashCommandsMixin:
|
||||
# Normalize the source (Telegram DM topic recovery) before deriving
|
||||
# the override key so storage matches the key the next message turn
|
||||
# reads — same fix as /model (#30479).
|
||||
_reasoning_source = self._normalize_source_for_session_key(event.source)
|
||||
_reasoning_source = await asyncio.to_thread(self._normalize_source_for_session_key, event.source)
|
||||
session_key = self._session_key_for_source(_reasoning_source)
|
||||
self._show_reasoning = self._load_show_reasoning()
|
||||
self._reasoning_config = self._resolve_session_reasoning_config(
|
||||
@@ -2825,7 +2825,7 @@ class GatewaySlashCommandsMixin:
|
||||
skip_memory=True,
|
||||
enabled_toolsets=["memory"],
|
||||
session_id=session_entry.session_id,
|
||||
session_db=self._session_db,
|
||||
session_db=getattr(self._session_db, "_db", self._session_db),
|
||||
)
|
||||
try:
|
||||
tmp_agent._print_fn = lambda *a, **kw: None
|
||||
@@ -2870,7 +2870,8 @@ class GatewaySlashCommandsMixin:
|
||||
if rotated:
|
||||
session_entry.session_id = new_session_id
|
||||
self.session_store._save()
|
||||
self._sync_telegram_topic_binding(
|
||||
await asyncio.to_thread(
|
||||
self._sync_telegram_topic_binding,
|
||||
source, session_entry, reason="compress-command",
|
||||
)
|
||||
|
||||
@@ -2983,7 +2984,7 @@ class GatewaySlashCommandsMixin:
|
||||
|
||||
# /topic off — clean disable path so users don't have to edit the DB.
|
||||
if args.lower() in {"off", "disable", "stop"}:
|
||||
return self._disable_telegram_topic_mode_for_chat(source)
|
||||
return await self._disable_telegram_topic_mode_for_chat(source)
|
||||
|
||||
if args:
|
||||
if not source.thread_id:
|
||||
@@ -3004,7 +3005,7 @@ class GatewaySlashCommandsMixin:
|
||||
return t("gateway.topic.topics_user_disallowed")
|
||||
|
||||
try:
|
||||
self._session_db.enable_telegram_topic_mode(
|
||||
await self._session_db.enable_telegram_topic_mode(
|
||||
chat_id=str(source.chat_id),
|
||||
user_id=str(source.user_id),
|
||||
has_topics_enabled=capabilities.get("has_topics_enabled"),
|
||||
@@ -3019,7 +3020,7 @@ class GatewaySlashCommandsMixin:
|
||||
|
||||
if source.thread_id:
|
||||
try:
|
||||
binding = self._session_db.get_telegram_topic_binding(
|
||||
binding = await self._session_db.get_telegram_topic_binding(
|
||||
chat_id=str(source.chat_id),
|
||||
thread_id=str(source.thread_id),
|
||||
)
|
||||
@@ -3030,7 +3031,7 @@ class GatewaySlashCommandsMixin:
|
||||
session_id = str(binding.get("session_id") or "")
|
||||
title = None
|
||||
try:
|
||||
title = self._session_db.get_session_title(session_id)
|
||||
title = await self._session_db.get_session_title(session_id)
|
||||
except Exception:
|
||||
title = None
|
||||
session_label = title or t("gateway.topic.untitled_session")
|
||||
@@ -3041,7 +3042,7 @@ class GatewaySlashCommandsMixin:
|
||||
)
|
||||
return t("gateway.topic.thread_ready")
|
||||
|
||||
return self._telegram_topic_root_status_message(source)
|
||||
return await self._telegram_topic_root_status_message(source)
|
||||
|
||||
async def _handle_title_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /title command — set or show the current session's title."""
|
||||
@@ -3055,11 +3056,11 @@ class GatewaySlashCommandsMixin:
|
||||
|
||||
# Ensure session exists in SQLite DB (it may only exist in session_store
|
||||
# if this is the first command in a new session)
|
||||
existing_title = self._session_db.get_session_title(session_id)
|
||||
existing_title = await self._session_db.get_session_title(session_id)
|
||||
if existing_title is None:
|
||||
# Session doesn't exist in DB yet — create it
|
||||
try:
|
||||
self._session_db.create_session(
|
||||
await self._session_db.create_session(
|
||||
session_id=session_id,
|
||||
source=source.platform.value if source.platform else "unknown",
|
||||
user_id=source.user_id,
|
||||
@@ -3071,14 +3072,15 @@ class GatewaySlashCommandsMixin:
|
||||
if title_arg:
|
||||
# Sanitize the title before setting
|
||||
try:
|
||||
sanitized = self._session_db.sanitize_title(title_arg)
|
||||
from hermes_state import SessionDB
|
||||
sanitized = SessionDB.sanitize_title(title_arg)
|
||||
except ValueError as e:
|
||||
return t("gateway.shared.warn_passthrough", error=e)
|
||||
if not sanitized:
|
||||
return t("gateway.title.empty_after_clean")
|
||||
# Set the title
|
||||
try:
|
||||
if self._session_db.set_session_title(session_id, sanitized):
|
||||
if await self._session_db.set_session_title(session_id, sanitized):
|
||||
# Propagate the user-chosen title to the visible Telegram
|
||||
# forum topic name too. Auto-generated titles already rename
|
||||
# the topic; without this, /title only updated the DB title
|
||||
@@ -3089,7 +3091,7 @@ class GatewaySlashCommandsMixin:
|
||||
)
|
||||
if callable(schedule_rename):
|
||||
try:
|
||||
schedule_rename(source, session_id, sanitized)
|
||||
await asyncio.to_thread(schedule_rename, source, session_id, sanitized)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Failed to rename Telegram topic from /title",
|
||||
@@ -3102,7 +3104,7 @@ class GatewaySlashCommandsMixin:
|
||||
return t("gateway.shared.warn_passthrough", error=e)
|
||||
else:
|
||||
# Show the current title and session ID
|
||||
title = self._session_db.get_session_title(session_id)
|
||||
title = await self._session_db.get_session_title(session_id)
|
||||
if title:
|
||||
return t("gateway.title.current_with_title", session_id=session_id, title=title)
|
||||
else:
|
||||
@@ -3135,15 +3137,15 @@ class GatewaySlashCommandsMixin:
|
||||
):
|
||||
name = name[1:-1].strip()
|
||||
|
||||
def _list_titled_sessions() -> list[dict]:
|
||||
async def _list_titled_sessions() -> list[dict]:
|
||||
user_source = source.platform.value if source.platform else None
|
||||
sessions = self._session_db.list_sessions_rich(source=user_source, limit=10)
|
||||
sessions = await self._session_db.list_sessions_rich(source=user_source, limit=10)
|
||||
return [s for s in sessions if s.get("title")][:10]
|
||||
|
||||
if not name:
|
||||
# List recent titled sessions for this user/platform
|
||||
try:
|
||||
titled = _list_titled_sessions()
|
||||
titled = await _list_titled_sessions()
|
||||
if source.platform == Platform.MATRIX and not allow_all:
|
||||
scoped = []
|
||||
for s in titled:
|
||||
@@ -3174,7 +3176,7 @@ class GatewaySlashCommandsMixin:
|
||||
# Resolve a numbered choice or a title to a session ID.
|
||||
if name.isdigit():
|
||||
try:
|
||||
titled = _list_titled_sessions()
|
||||
titled = await _list_titled_sessions()
|
||||
if source.platform == Platform.MATRIX and not allow_all:
|
||||
scoped = []
|
||||
for s in titled:
|
||||
@@ -3194,17 +3196,17 @@ class GatewaySlashCommandsMixin:
|
||||
else:
|
||||
# Try direct session ID lookup first (so `/resume <session_id>`
|
||||
# works in the gateway, not just `/resume <title>`).
|
||||
session = self._session_db.get_session(name)
|
||||
session = await self._session_db.get_session(name)
|
||||
if session:
|
||||
target_id = session["id"]
|
||||
else:
|
||||
target_id = self._session_db.resolve_session_by_title(name)
|
||||
target_id = await self._session_db.resolve_session_by_title(name)
|
||||
if not target_id:
|
||||
return t("gateway.resume.not_found", name=name)
|
||||
# Compression creates child continuations that hold the live transcript.
|
||||
# Follow that chain so gateway /resume matches CLI behavior (#15000).
|
||||
try:
|
||||
target_id = self._session_db.resolve_resume_session_id(target_id)
|
||||
target_id = await self._session_db.resolve_resume_session_id(target_id)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to resolve resume continuation for %s: %s", target_id, e)
|
||||
|
||||
@@ -3255,7 +3257,7 @@ class GatewaySlashCommandsMixin:
|
||||
self._evict_cached_agent(session_key)
|
||||
|
||||
# Get the title for confirmation
|
||||
title = self._session_db.get_session_title(target_id) or name
|
||||
title = await self._session_db.get_session_title(target_id) or name
|
||||
|
||||
# Count messages for context
|
||||
history = self.session_store.load_transcript(target_id)
|
||||
@@ -3299,8 +3301,9 @@ class GatewaySlashCommandsMixin:
|
||||
return await self._handle_resume_command(resume_event)
|
||||
|
||||
current_entry = self.session_store.get_or_create_session(source)
|
||||
rows = query_session_listing(
|
||||
self._session_db,
|
||||
rows = await asyncio.to_thread(
|
||||
query_session_listing,
|
||||
getattr(self._session_db, "_db", self._session_db),
|
||||
source=source.platform.value if source.platform else None,
|
||||
current_session_id=current_entry.session_id,
|
||||
include_all_sources=include_all,
|
||||
@@ -3356,9 +3359,9 @@ class GatewaySlashCommandsMixin:
|
||||
if branch_name:
|
||||
branch_title = branch_name
|
||||
else:
|
||||
current_title = self._session_db.get_session_title(current_entry.session_id)
|
||||
current_title = await self._session_db.get_session_title(current_entry.session_id)
|
||||
base = current_title or "branch"
|
||||
branch_title = self._session_db.get_next_title_in_lineage(base)
|
||||
branch_title = await self._session_db.get_next_title_in_lineage(base)
|
||||
|
||||
parent_session_id = current_entry.session_id
|
||||
|
||||
@@ -3368,7 +3371,7 @@ class GatewaySlashCommandsMixin:
|
||||
# /sessions even after the parent is reopened and re-ended with a
|
||||
# different end_reason (e.g. tui_shutdown overwriting 'branched').
|
||||
try:
|
||||
self._session_db.create_session(
|
||||
await self._session_db.create_session(
|
||||
session_id=new_session_id,
|
||||
source=source.platform.value if source.platform else "gateway",
|
||||
model=(self.config.get("model", {}) or {}).get("default") if isinstance(self.config, dict) else None,
|
||||
@@ -3382,7 +3385,7 @@ class GatewaySlashCommandsMixin:
|
||||
# Copy conversation history to the new session
|
||||
for msg in history:
|
||||
try:
|
||||
self._session_db.append_message(
|
||||
await self._session_db.append_message(
|
||||
session_id=new_session_id,
|
||||
role=msg.get("role", "user"),
|
||||
content=msg.get("content"),
|
||||
@@ -3401,7 +3404,7 @@ class GatewaySlashCommandsMixin:
|
||||
|
||||
# Set title
|
||||
try:
|
||||
self._session_db.set_session_title(new_session_id, branch_title)
|
||||
await self._session_db.set_session_title(new_session_id, branch_title)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -3484,7 +3487,7 @@ class GatewaySlashCommandsMixin:
|
||||
if not provider and getattr(self, "_session_db", None) is not None:
|
||||
try:
|
||||
_entry_for_billing = self.session_store.get_or_create_session(source)
|
||||
persisted = self._session_db.get_session(_entry_for_billing.session_id) or {}
|
||||
persisted = await self._session_db.get_session(_entry_for_billing.session_id) or {}
|
||||
except Exception:
|
||||
persisted = {}
|
||||
provider = provider or persisted.get("billing_provider")
|
||||
|
||||
@@ -14,6 +14,7 @@ Key design decisions:
|
||||
- Session source tagging ('cli', 'telegram', 'discord', etc.) for filtering
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
@@ -5525,3 +5526,20 @@ class SessionDB:
|
||||
(error[:500], session_id),
|
||||
)
|
||||
self._execute_write(_do)
|
||||
|
||||
|
||||
class AsyncSessionDB:
|
||||
"""Async door onto SessionDB: offloads each call via asyncio.to_thread so a blocking SQLite call never freezes the event loop. Generic forwarder — the audit confirms no method returns a live cursor/generator."""
|
||||
|
||||
def __init__(self, db: "SessionDB") -> None:
|
||||
self._db = db
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
attr = getattr(self._db, name)
|
||||
if not callable(attr):
|
||||
return attr
|
||||
|
||||
async def _offloaded(*args, **kwargs):
|
||||
return await asyncio.to_thread(attr, *args, **kwargs)
|
||||
|
||||
return _offloaded
|
||||
|
||||
@@ -264,11 +264,18 @@ def make_adapter(platform: Platform, runner=None):
|
||||
|
||||
|
||||
async def send_and_capture(adapter, text: str, platform: Platform, **event_kwargs) -> AsyncMock:
|
||||
"""Send a message through the full e2e flow and return the send mock."""
|
||||
"""Send a message through the full e2e flow and return the send mock.
|
||||
|
||||
Polls for the send rather than waiting a fixed delay: handler DB work now
|
||||
hops to worker threads (AsyncSessionDB), so completion latency varies.
|
||||
"""
|
||||
event = make_event(platform, text, **event_kwargs)
|
||||
adapter.send.reset_mock()
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0.3)
|
||||
for _ in range(40): # up to ~2s; returns as soon as the send lands
|
||||
if adapter.send.called:
|
||||
break
|
||||
await asyncio.sleep(0.05)
|
||||
return adapter.send
|
||||
|
||||
|
||||
|
||||
@@ -39,6 +39,15 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
|
||||
def make_async_session_db(sync_mock=None):
|
||||
"""Wrap a sync mock SessionDB in AsyncSessionDB so gateway code that awaits
|
||||
the facade works in tests. Returns (facade, sync_mock); configure return
|
||||
values and assert calls on sync_mock."""
|
||||
from hermes_state import AsyncSessionDB
|
||||
sync_mock = sync_mock if sync_mock is not None else MagicMock()
|
||||
return AsyncSessionDB(sync_mock), sync_mock
|
||||
|
||||
|
||||
def _ensure_telegram_mock() -> None:
|
||||
"""Install a comprehensive telegram mock in sys.modules.
|
||||
|
||||
|
||||
@@ -102,13 +102,25 @@ class TestAutoResetBlockReSyncsBinding:
|
||||
"""The block must re-sync the topic binding so the next inbound message
|
||||
cannot ``switch_session`` back onto the bloated compressed child."""
|
||||
block = _find_compression_exhausted_reset_block()
|
||||
sync_calls = [
|
||||
sub
|
||||
for sub in ast.walk(block)
|
||||
if isinstance(sub, ast.Call)
|
||||
and isinstance(sub.func, ast.Attribute)
|
||||
and sub.func.attr == "_sync_telegram_topic_binding"
|
||||
]
|
||||
|
||||
def _references_helper(node):
|
||||
# Direct call: self._sync_telegram_topic_binding(...)
|
||||
if (
|
||||
isinstance(node, ast.Call)
|
||||
and isinstance(node.func, ast.Attribute)
|
||||
and node.func.attr == "_sync_telegram_topic_binding"
|
||||
):
|
||||
return True
|
||||
# Offloaded: await asyncio.to_thread(self._sync_telegram_topic_binding, ...)
|
||||
# — the helper is passed as an argument, not the call's func.
|
||||
if (
|
||||
isinstance(node, ast.Attribute)
|
||||
and node.attr == "_sync_telegram_topic_binding"
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
sync_calls = [sub for sub in ast.walk(block) if _references_helper(sub)]
|
||||
assert sync_calls, (
|
||||
"gateway/run.py auto-reset block does not call "
|
||||
"_sync_telegram_topic_binding after reset_session. Without it the "
|
||||
|
||||
@@ -12,6 +12,8 @@ Verifies that the agent cache correctly:
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
|
||||
def _make_runner():
|
||||
@@ -1565,8 +1567,11 @@ class TestAgentCacheMessageCountRebaseline:
|
||||
"""
|
||||
|
||||
def _runner_with_db(self, db):
|
||||
from hermes_state import AsyncSessionDB
|
||||
|
||||
runner = _make_runner()
|
||||
runner._session_db = db
|
||||
# The gateway holds the async facade; the production refresh awaits it.
|
||||
runner._session_db = AsyncSessionDB(db)
|
||||
return runner
|
||||
|
||||
@staticmethod
|
||||
@@ -1577,7 +1582,7 @@ class TestAgentCacheMessageCountRebaseline:
|
||||
the cached agent (or either side is None / it's a legacy 2-tuple).
|
||||
"""
|
||||
try:
|
||||
row = runner._session_db.get_session(session_id)
|
||||
row = runner._session_db._db.get_session(session_id)
|
||||
live = row.get("message_count", 0) if row else None
|
||||
except Exception:
|
||||
live = None
|
||||
@@ -1591,7 +1596,8 @@ class TestAgentCacheMessageCountRebaseline:
|
||||
)
|
||||
return not invalidate
|
||||
|
||||
def test_same_process_turns_preserve_cached_agent(self, tmp_path):
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_process_turns_preserve_cached_agent(self, tmp_path):
|
||||
"""The regression guard: consecutive same-process turns must REUSE
|
||||
the cached agent (prompt cache preserved), not rebuild every turn.
|
||||
|
||||
@@ -1619,7 +1625,7 @@ class TestAgentCacheMessageCountRebaseline:
|
||||
db.append_message("s1", role="user", content="u")
|
||||
db.append_message("s1", role="assistant", content="a")
|
||||
# Post-turn re-baseline (the fix).
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
# Next turn's guard decision.
|
||||
if self._guard_would_reuse(runner, "telegram:s1", "s1"):
|
||||
reuses += 1
|
||||
@@ -1630,7 +1636,8 @@ class TestAgentCacheMessageCountRebaseline:
|
||||
with runner._agent_cache_lock:
|
||||
assert runner._agent_cache["telegram:s1"][0] is agent
|
||||
|
||||
def test_cross_process_write_still_invalidates(self, tmp_path):
|
||||
@pytest.mark.asyncio
|
||||
async def test_cross_process_write_still_invalidates(self, tmp_path):
|
||||
"""After the re-baseline, a DIFFERENT process appending to the same
|
||||
session must still flip the guard to rebuild (the #45966 fix holds).
|
||||
"""
|
||||
@@ -1650,7 +1657,7 @@ class TestAgentCacheMessageCountRebaseline:
|
||||
# Our own turn + re-baseline -> reuse next turn.
|
||||
db.append_message("s1", role="user", content="u")
|
||||
db.append_message("s1", role="assistant", content="a")
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
assert self._guard_would_reuse(runner, "telegram:s1", "s1") is True
|
||||
|
||||
# ANOTHER process (e.g. the desktop dashboard backend) appends a turn
|
||||
@@ -1660,10 +1667,11 @@ class TestAgentCacheMessageCountRebaseline:
|
||||
# Guard must now reject reuse so the agent rebuilds from fresh disk.
|
||||
assert self._guard_would_reuse(runner, "telegram:s1", "s1") is False
|
||||
|
||||
def test_rebaseline_is_fail_safe_and_skips_legacy_and_pending(self, tmp_path):
|
||||
@pytest.mark.asyncio
|
||||
async def test_rebaseline_is_fail_safe_and_skips_legacy_and_pending(self, tmp_path):
|
||||
"""Re-baseline must never crash and must leave legacy 2-tuples and
|
||||
pending-sentinel entries untouched."""
|
||||
from hermes_state import SessionDB
|
||||
from hermes_state import AsyncSessionDB, SessionDB
|
||||
from gateway.run import _AGENT_PENDING_SENTINEL
|
||||
|
||||
db = SessionDB(db_path=tmp_path / "sessions.db")
|
||||
@@ -1673,24 +1681,24 @@ class TestAgentCacheMessageCountRebaseline:
|
||||
|
||||
# No session_db -> no-op, no crash.
|
||||
runner._session_db = None
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
runner._session_db = db
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
runner._session_db = AsyncSessionDB(db)
|
||||
|
||||
# Falsy session_id -> no-op.
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", "")
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", None)
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", "")
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", None)
|
||||
|
||||
# Legacy 2-tuple is left untouched (it opts out of the guard).
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["telegram:s1"] = (object(), "sig")
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
with runner._agent_cache_lock:
|
||||
assert len(runner._agent_cache["telegram:s1"]) == 2
|
||||
|
||||
# Pending sentinel entry is left untouched.
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["telegram:s1"] = (_AGENT_PENDING_SENTINEL, "sig", 0)
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
with runner._agent_cache_lock:
|
||||
assert runner._agent_cache["telegram:s1"][0] is _AGENT_PENDING_SENTINEL
|
||||
assert runner._agent_cache["telegram:s1"][2] == 0
|
||||
@@ -1700,10 +1708,10 @@ class TestAgentCacheMessageCountRebaseline:
|
||||
def get_session(self, _sid):
|
||||
raise RuntimeError("db locked")
|
||||
|
||||
runner._session_db = _BoomDB() # type: ignore[assignment]
|
||||
runner._session_db = AsyncSessionDB(_BoomDB()) # type: ignore[assignment]
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["telegram:s1"] = (object(), "sig", 5)
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
with runner._agent_cache_lock:
|
||||
assert runner._agent_cache["telegram:s1"][2] == 5
|
||||
|
||||
|
||||
402
tests/gateway/test_async_session_db.py
Normal file
402
tests/gateway/test_async_session_db.py
Normal file
@@ -0,0 +1,402 @@
|
||||
"""AsyncSessionDB offload facade + gateway raw-call guard.
|
||||
|
||||
The gateway runs one asyncio loop for every session; SessionDB is synchronous,
|
||||
so a raw call on the loop freezes every conversation until it returns.
|
||||
AsyncSessionDB offloads each call via asyncio.to_thread. These tests pin the
|
||||
facade's contract and lock the gateway boundary so a 39th raw call can't regress.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
import hermes_state
|
||||
from hermes_state import AsyncSessionDB
|
||||
|
||||
|
||||
class _SpyDB:
|
||||
"""SessionDB stand-in recording the thread each call ran on."""
|
||||
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
self.attr = "plain-value"
|
||||
|
||||
def _ran_on(self, name):
|
||||
self.calls.append((name, threading.get_ident()))
|
||||
|
||||
def returns_none(self):
|
||||
self._ran_on("returns_none")
|
||||
return None
|
||||
|
||||
def returns_bool(self):
|
||||
self._ran_on("returns_bool")
|
||||
return True
|
||||
|
||||
def returns_str(self):
|
||||
self._ran_on("returns_str")
|
||||
return "title"
|
||||
|
||||
def returns_dict(self):
|
||||
self._ran_on("returns_dict")
|
||||
return {"id": "s1"}
|
||||
|
||||
def returns_list(self):
|
||||
self._ran_on("returns_list")
|
||||
return [{"id": "s1"}, {"id": "s2"}]
|
||||
|
||||
def raises(self):
|
||||
self._ran_on("raises")
|
||||
raise ValueError("boom")
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Facade behaviour
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_offloads_off_calling_thread():
|
||||
"""A call must execute on a worker thread, not the caller's loop thread."""
|
||||
db = _SpyDB()
|
||||
facade = AsyncSessionDB(db)
|
||||
caller_ident = threading.get_ident()
|
||||
|
||||
await facade.returns_none()
|
||||
|
||||
ran_idents = [ident for _name, ident in db.calls]
|
||||
assert ran_idents and all(i != caller_ident for i in ran_idents)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_offload_goes_through_to_thread(monkeypatch):
|
||||
"""The offload must route through asyncio.to_thread (where the facade lives)."""
|
||||
db = _SpyDB()
|
||||
facade = AsyncSessionDB(db)
|
||||
|
||||
seen = []
|
||||
real = asyncio.to_thread
|
||||
|
||||
async def _spy(func, *args, **kwargs):
|
||||
seen.append(getattr(func, "__name__", repr(func)))
|
||||
return await real(func, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(hermes_state.asyncio, "to_thread", _spy)
|
||||
await facade.returns_str()
|
||||
assert "returns_str" in seen
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"method,expected",
|
||||
[
|
||||
("returns_none", None),
|
||||
("returns_bool", True),
|
||||
("returns_str", "title"),
|
||||
("returns_dict", {"id": "s1"}),
|
||||
("returns_list", [{"id": "s1"}, {"id": "s2"}]),
|
||||
],
|
||||
)
|
||||
async def test_returns_underlying_value_unchanged(method, expected):
|
||||
facade = AsyncSessionDB(_SpyDB())
|
||||
assert await getattr(facade, method)() == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_propagates_exception():
|
||||
facade = AsyncSessionDB(_SpyDB())
|
||||
with pytest.raises(ValueError, match="boom"):
|
||||
await facade.raises()
|
||||
|
||||
|
||||
def test_non_callable_attribute_passes_through():
|
||||
facade = AsyncSessionDB(_SpyDB())
|
||||
assert facade.attr == "plain-value"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Guard: no raw self._session_db.<method>( on the gateway loop
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
_GATEWAY_FILES = ("gateway/run.py", "gateway/slash_commands.py")
|
||||
# The only legitimate non-loop paths:
|
||||
# - SessionDB.sanitize_title: pure @staticmethod string cleaning, no DB.
|
||||
# - self._session_db._db.<x>: the sync escape, allowed ONLY where the call is
|
||||
# provably off the event loop — construction (__init__, before the loop
|
||||
# serves) and the run_sync closure (executed in a thread-pool executor).
|
||||
# Three such sites today; a fourth must be justified and this count bumped.
|
||||
_ALLOWED_SYNC_DB_ESCAPES = 3
|
||||
|
||||
# Sync helpers that touch SessionDB but are NEVER invoked bare on the loop:
|
||||
# every loop-side call wraps them in ``asyncio.to_thread(...)`` and the only
|
||||
# bare calls live in the run_sync thread-pool closure. Their DB calls therefore
|
||||
# run off-loop. The guard exempts their bodies AND enforces the contract — see
|
||||
# test_offloaded_helpers_never_called_bare_on_loop. Adding a helper here without
|
||||
# wrapping its loop call sites makes that test fail.
|
||||
_OFFLOADED_SYNC_HELPERS = frozenset({
|
||||
"_telegram_topic_mode_enabled",
|
||||
"_is_telegram_topic_lane",
|
||||
"_is_telegram_topic_root_lobby",
|
||||
"_recover_telegram_topic_thread_id",
|
||||
"_normalize_source_for_session_key",
|
||||
"_record_telegram_topic_binding",
|
||||
"_sync_telegram_topic_binding",
|
||||
"_telegram_topic_new_header",
|
||||
"_schedule_telegram_topic_title_rename",
|
||||
"_apply_topic_recovery",
|
||||
})
|
||||
|
||||
|
||||
def _repo_root() -> Path:
|
||||
return Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
class _RawCallVisitor:
|
||||
"""Collect non-awaited SessionDB calls reachable on the gateway loop.
|
||||
|
||||
Catches both shapes:
|
||||
* direct: self._session_db.<method>(...)
|
||||
* aliased: db = getattr(self, "_session_db", None) / db = self._session_db
|
||||
then db.<method>(...)
|
||||
An ``await x.y()`` is Await(value=Call(...)); those Calls are exempt (the
|
||||
migrated path). The self._session_db._db.<x> sync escape is counted
|
||||
separately. SessionDB.sanitize_title is a staticmethod called on the class,
|
||||
so it never matches either shape.
|
||||
|
||||
Alias detection scans, per function scope, for locals bound to the gateway's
|
||||
_session_db (incl. closures that bind it off a captured ``self``-like param),
|
||||
then flags non-awaited calls on those names. The literal-grep blind spot that
|
||||
let six loop-reachable calls hide behind ``getattr(self, "_session_db")`` is
|
||||
exactly what this closes.
|
||||
"""
|
||||
|
||||
def __init__(self, tree: ast.AST):
|
||||
self.raw_calls = [] # (method, lineno) — direct, non-awaited, on-loop
|
||||
self.alias_calls = [] # (method, lineno) — via a _session_db-bound local, on-loop
|
||||
self.db_escapes = [] # self._session_db._db.<x> sites (lineno)
|
||||
# BARE self.<helper>(...) call sites of offloaded helpers — i.e. the
|
||||
# helper is actually *called*, not passed to asyncio.to_thread (which
|
||||
# references it as an attribute, producing no Call node here). Each is
|
||||
# (helper, lineno, enclosing_fn) for the contract test.
|
||||
self.bare_helper_calls = []
|
||||
|
||||
awaited = {id(n.value) for n in ast.walk(tree)
|
||||
if isinstance(n, ast.Await) and isinstance(n.value, ast.Call)}
|
||||
alias_names = self._collect_alias_names(tree)
|
||||
# Map each node to the name of the function whose body lexically encloses
|
||||
# it, so DB calls inside an offloaded helper (which runs off-loop) are
|
||||
# exempt while bare on-loop calls are not.
|
||||
enclosing = self._enclosing_fn_map(tree)
|
||||
ancestry = self._ancestor_fns(tree) # id(node) -> frozenset of enclosing fn names
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.Call):
|
||||
continue
|
||||
func = node.func
|
||||
if not isinstance(func, ast.Attribute):
|
||||
continue
|
||||
encl_fn = enclosing.get(id(node))
|
||||
in_offloaded_helper = encl_fn in _OFFLOADED_SYNC_HELPERS
|
||||
# Bare call of an offloaded helper (self._helper(...)). A to_thread
|
||||
# offload passes the helper as an attribute arg, not a Call, so it
|
||||
# never lands here — exactly the distinction the contract test needs.
|
||||
if (
|
||||
isinstance(func.value, ast.Name) and func.value.id == "self"
|
||||
and func.attr in _OFFLOADED_SYNC_HELPERS
|
||||
):
|
||||
self.bare_helper_calls.append(
|
||||
(func.attr, node.lineno, ancestry.get(id(node), frozenset()))
|
||||
)
|
||||
# alias.<method>(...) -> aliased loop call (var bound to _session_db)
|
||||
if (
|
||||
isinstance(func.value, ast.Name)
|
||||
and func.value.id in alias_names
|
||||
and func.attr not in ("_db",)
|
||||
and id(node) not in awaited
|
||||
and not in_offloaded_helper
|
||||
):
|
||||
self.alias_calls.append((func.attr, node.lineno))
|
||||
continue
|
||||
if not isinstance(func.value, ast.Attribute):
|
||||
continue
|
||||
inner = func.value
|
||||
# self._session_db._db.<method>(...) -> sync escape
|
||||
if (
|
||||
inner.attr == "_db"
|
||||
and isinstance(inner.value, ast.Attribute)
|
||||
and inner.value.attr == "_session_db"
|
||||
and isinstance(inner.value.value, ast.Name)
|
||||
and inner.value.value.id == "self"
|
||||
):
|
||||
self.db_escapes.append(inner.lineno)
|
||||
# self._session_db.<method>(...) not wrapped in await -> raw loop call
|
||||
elif (
|
||||
inner.attr == "_session_db"
|
||||
and isinstance(inner.value, ast.Name)
|
||||
and inner.value.id == "self"
|
||||
and id(node) not in awaited
|
||||
and not in_offloaded_helper
|
||||
):
|
||||
self.raw_calls.append((func.attr, node.lineno))
|
||||
|
||||
@staticmethod
|
||||
def _enclosing_fn_map(tree: ast.AST) -> dict:
|
||||
"""Map id(node) -> name of the nearest lexically-enclosing function."""
|
||||
out = {}
|
||||
|
||||
def walk(node, fn_name):
|
||||
this_fn = fn_name
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
this_fn = node.name
|
||||
for child in ast.iter_child_nodes(node):
|
||||
out[id(child)] = this_fn
|
||||
walk(child, this_fn)
|
||||
|
||||
walk(tree, None)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _ancestor_fns(tree: ast.AST) -> dict:
|
||||
"""Map id(node) -> frozenset of ALL enclosing function names (any depth)."""
|
||||
out = {}
|
||||
|
||||
def walk(node, stack):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
stack = stack + (node.name,)
|
||||
for child in ast.iter_child_nodes(node):
|
||||
out[id(child)] = frozenset(stack)
|
||||
walk(child, stack)
|
||||
|
||||
walk(tree, ())
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _is_session_db_source(value: ast.AST) -> bool:
|
||||
"""True if an assignment RHS resolves to <obj>._session_db.
|
||||
|
||||
Matches both ``<obj>._session_db`` and ``getattr(<obj>, "_session_db", ...)``
|
||||
where <obj> is any Name (covers ``self`` and captured closure params like
|
||||
``_self``). Excludes the ``._db`` sync handle.
|
||||
"""
|
||||
if isinstance(value, ast.Attribute):
|
||||
return value.attr == "_session_db" and isinstance(value.value, ast.Name)
|
||||
if (
|
||||
isinstance(value, ast.Call)
|
||||
and isinstance(value.func, ast.Name)
|
||||
and value.func.id == "getattr"
|
||||
and len(value.args) >= 2
|
||||
and isinstance(value.args[1], ast.Constant)
|
||||
and value.args[1].value == "_session_db"
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _collect_alias_names(cls, tree: ast.AST) -> set:
|
||||
names = set()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Assign) and cls._is_session_db_source(node.value):
|
||||
for tgt in node.targets:
|
||||
if isinstance(tgt, ast.Name):
|
||||
names.add(tgt.id)
|
||||
elif isinstance(node, ast.AnnAssign) and node.value is not None \
|
||||
and cls._is_session_db_source(node.value) \
|
||||
and isinstance(node.target, ast.Name):
|
||||
names.add(node.target.id)
|
||||
return names
|
||||
|
||||
|
||||
def _scan(rel_path: str) -> _RawCallVisitor:
|
||||
source = (_repo_root() / rel_path).read_text(encoding="utf-8")
|
||||
return _RawCallVisitor(ast.parse(source))
|
||||
|
||||
|
||||
def test_no_raw_session_db_calls_on_gateway_loop():
|
||||
"""Fail if any non-awaited SessionDB call appears in gateway files.
|
||||
|
||||
Every loop-reachable DB call must go through AsyncSessionDB (await), whether
|
||||
spelled directly (self._session_db.<method>(...)) or via a local alias
|
||||
(db = getattr(self, "_session_db", None); db.<method>(...)). The
|
||||
sanitize_title staticmethod is called on the class, not self/an alias, so it
|
||||
is not matched; the _db. sync escape is checked separately below.
|
||||
"""
|
||||
violations = []
|
||||
for rel in _GATEWAY_FILES:
|
||||
v = _scan(rel)
|
||||
violations.extend(f"{rel}:{ln} self._session_db.{m}(" for m, ln in v.raw_calls)
|
||||
violations.extend(f"{rel}:{ln} <alias>.{m}( (binds _session_db)" for m, ln in v.alias_calls)
|
||||
assert not violations, (
|
||||
"Non-awaited SessionDB calls on the gateway loop — route through "
|
||||
"AsyncSessionDB (await ...):\n " + "\n ".join(violations)
|
||||
)
|
||||
|
||||
|
||||
def test_sync_db_escape_confined_to_off_loop_sites():
|
||||
"""The self._session_db._db. sync escape must stay confined to known sites.
|
||||
|
||||
It is legitimate only where the call is provably off the loop: construction
|
||||
(before the loop serves) and the run_sync executor closure. More occurrences
|
||||
than the reviewed count means a blocking call may have leaked back onto the
|
||||
loop through the escape hatch.
|
||||
"""
|
||||
total = sum(len(_scan(rel).db_escapes) for rel in _GATEWAY_FILES)
|
||||
assert total <= _ALLOWED_SYNC_DB_ESCAPES, (
|
||||
f"self._session_db._db. sync escape used {total} times; "
|
||||
f"at most {_ALLOWED_SYNC_DB_ESCAPES} (construction + run_sync) is allowed."
|
||||
)
|
||||
|
||||
|
||||
def test_offloaded_helpers_never_called_bare_on_loop():
|
||||
"""The offloaded sync helpers must never be called bare on the event loop.
|
||||
|
||||
They touch SessionDB synchronously, so a bare ``self._helper(...)`` on the
|
||||
loop would freeze it. The contract: loop-side callers wrap them in
|
||||
``await asyncio.to_thread(self._helper, ...)`` (which references the helper
|
||||
as an attribute — no Call node — so it never appears here). A bare call is
|
||||
only legitimate when it runs off-loop: inside the ``run_sync`` thread-pool
|
||||
closure, or inside another offloaded helper (sync->sync, same thread). Any
|
||||
other bare call means a helper whose body the guard exempts is being invoked
|
||||
on the loop anyway — re-freezing the loop through the exemption.
|
||||
"""
|
||||
off_loop_ok = _OFFLOADED_SYNC_HELPERS | {"run_sync"}
|
||||
violations = []
|
||||
for rel in _GATEWAY_FILES:
|
||||
v = _scan(rel)
|
||||
for helper, ln, ancestors in v.bare_helper_calls:
|
||||
if not (ancestors & off_loop_ok):
|
||||
violations.append(f"{rel}:{ln} bare self.{helper}( on the loop")
|
||||
assert not violations, (
|
||||
"Offloaded sync helper called bare on the gateway loop — wrap in "
|
||||
"await asyncio.to_thread(self.<helper>, ...):\n " + "\n ".join(violations)
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Interleaving safety: offloading opens await points where coroutines can
|
||||
# interleave against the same session rows. The gateway relies on SessionDB's
|
||||
# atomic operations (compare-and-set, INSERT OR IGNORE) to stay single-winner.
|
||||
# These pin that the defenses hold when driven concurrently through the facade.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_claim_handoff_single_winner(tmp_path):
|
||||
db = AsyncSessionDB(hermes_state.SessionDB(db_path=tmp_path / "state.db"))
|
||||
sid = "s-handoff"
|
||||
await db.create_session(sid, "test")
|
||||
await db.request_handoff(sid, "telegram")
|
||||
|
||||
results = await asyncio.gather(*(db.claim_handoff(sid) for _ in range(20)))
|
||||
|
||||
assert sum(results) == 1, f"exactly one claim must win, got {sum(results)}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_create_session_idempotent(tmp_path):
|
||||
db = AsyncSessionDB(hermes_state.SessionDB(db_path=tmp_path / "state.db"))
|
||||
sid = "s-create"
|
||||
|
||||
await asyncio.gather(*(db.create_session(sid, "test") for _ in range(20)))
|
||||
|
||||
rows = await db.list_sessions_rich(limit=100)
|
||||
assert sum(1 for r in rows if r["id"] == sid) == 1
|
||||
@@ -5,13 +5,14 @@ The Discord gateway heartbeat was stalling because the handoff watcher
|
||||
SQLite-backed ``SessionDB`` directly on the asyncio event loop every 2s
|
||||
('Shard ID None heartbeat blocked for more than N seconds').
|
||||
|
||||
The fix (mirroring PR #40782) wraps every blocking ``SessionDB`` call inside
|
||||
the watcher loop in ``asyncio.to_thread(...)`` so the SQLite I/O runs on a
|
||||
worker thread and never blocks the event loop / Discord heartbeat.
|
||||
The fix routes every blocking ``SessionDB`` call in the watcher through the
|
||||
``AsyncSessionDB`` facade, which offloads each call via ``asyncio.to_thread`` so
|
||||
the SQLite I/O runs on a worker thread and never blocks the event loop / Discord
|
||||
heartbeat.
|
||||
|
||||
These tests assert that behaviour contract. They are mutation-survivable:
|
||||
reverting any ``asyncio.to_thread(self._session_db.<call>)`` wrap back to a
|
||||
direct synchronous call on the loop makes the relevant assertion fail.
|
||||
reverting any ``await self._session_db.<call>(...)`` back to a direct synchronous
|
||||
call on the loop makes the relevant assertion fail.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -62,9 +63,15 @@ class _RecordingSessionDB:
|
||||
|
||||
|
||||
def _make_fake_runner(session_db, *, fail_process=False):
|
||||
"""Build a minimal object that exposes exactly what the loop body touches."""
|
||||
"""Build a minimal object that exposes exactly what the loop body touches.
|
||||
|
||||
The watcher now talks to the SessionDB through the AsyncSessionDB facade,
|
||||
so wrap the recording stand-in the same way the gateway does.
|
||||
"""
|
||||
from hermes_state import AsyncSessionDB
|
||||
|
||||
fake = types.SimpleNamespace()
|
||||
fake._session_db = session_db
|
||||
fake._session_db = AsyncSessionDB(session_db)
|
||||
# _running yields True for the first loop check, then False so the loop
|
||||
# exits after a single tick.
|
||||
states = iter([True, False])
|
||||
@@ -141,21 +148,23 @@ async def test_watcher_offloads_fail_handoff_to_thread(monkeypatch):
|
||||
async def test_watcher_wraps_calls_via_asyncio_to_thread(monkeypatch):
|
||||
"""Explicitly assert the offload goes through asyncio.to_thread.
|
||||
|
||||
Patches ``run.asyncio.to_thread`` and records which SessionDB callables
|
||||
were handed to it. Mutation-survivable: dropping any wrap removes its
|
||||
callable from the recorded set.
|
||||
Patches the AsyncSessionDB facade's ``asyncio.to_thread`` (it lives in
|
||||
hermes_state) and records which SessionDB callables were handed to it.
|
||||
Mutation-survivable: dropping any await removes its callable from the set.
|
||||
"""
|
||||
import hermes_state
|
||||
|
||||
db = _RecordingSessionDB(loop_thread_ident=-1)
|
||||
fake = _make_fake_runner(db, fail_process=False)
|
||||
|
||||
wrapped = []
|
||||
real_to_thread = run.asyncio.to_thread
|
||||
real_to_thread = hermes_state.asyncio.to_thread
|
||||
|
||||
async def _spy_to_thread(func, *args, **kwargs):
|
||||
wrapped.append(getattr(func, "__name__", repr(func)))
|
||||
return await real_to_thread(func, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(run.asyncio, "to_thread", _spy_to_thread)
|
||||
monkeypatch.setattr(hermes_state.asyncio, "to_thread", _spy_to_thread)
|
||||
|
||||
await _run_one_tick(fake, monkeypatch)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from hermes_state import AsyncSessionDB
|
||||
from gateway.session import (
|
||||
SessionContext,
|
||||
SessionEntry,
|
||||
@@ -343,16 +344,16 @@ def _make_runner(current_source: SessionSource, entries: list[SessionEntry]):
|
||||
runner._clear_session_boundary_security_state = MagicMock()
|
||||
runner._evict_cached_agent = MagicMock()
|
||||
runner._queue_depth = MagicMock(return_value=0)
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.list_sessions_rich.return_value = [
|
||||
runner._session_db = AsyncSessionDB(MagicMock())
|
||||
runner._session_db._db.list_sessions_rich.return_value = [
|
||||
{"id": entry.session_id, "title": entry.display_name, "preview": ""}
|
||||
for entry in entries
|
||||
]
|
||||
runner._session_db.resolve_resume_session_id.side_effect = lambda sid: sid
|
||||
runner._session_db.get_session_title.side_effect = lambda sid: {
|
||||
runner._session_db._db.resolve_resume_session_id.side_effect = lambda sid: sid
|
||||
runner._session_db._db.get_session_title.side_effect = lambda sid: {
|
||||
entry.session_id: entry.display_name for entry in entries
|
||||
}.get(sid)
|
||||
runner._session_db.get_session.return_value = None
|
||||
runner._session_db._db.get_session.return_value = None
|
||||
return runner
|
||||
|
||||
|
||||
@@ -388,7 +389,7 @@ async def test_matrix_resume_does_not_cross_rooms_by_default():
|
||||
entry_a = _entry(source_a, "session-a", "Project A Plan")
|
||||
entry_b = _entry(source_b, "session-b", "Project B Plan")
|
||||
runner = _make_runner(source_b, [entry_a, entry_b])
|
||||
runner._session_db.resolve_session_by_title.return_value = "session-a"
|
||||
runner._session_db._db.resolve_session_by_title.return_value = "session-a"
|
||||
|
||||
result = await runner._handle_resume_command(_event("/resume Project A Plan", source_b))
|
||||
|
||||
@@ -406,7 +407,7 @@ async def test_matrix_resume_allows_same_room_session():
|
||||
source_b, "session-b-current", "Current Project B"
|
||||
)
|
||||
runner.session_store.switch_session.return_value = entry_b
|
||||
runner._session_db.resolve_session_by_title.return_value = "session-b-old"
|
||||
runner._session_db._db.resolve_session_by_title.return_value = "session-b-old"
|
||||
|
||||
result = await runner._handle_resume_command(_event("/resume Project B Plan", source_b))
|
||||
|
||||
@@ -423,14 +424,14 @@ async def test_matrix_resume_quoted_title_same_room():
|
||||
source_b, "session-b-current", "Current Project B"
|
||||
)
|
||||
runner.session_store.switch_session.return_value = entry_b
|
||||
runner._session_db.resolve_session_by_title.return_value = "session-b-old"
|
||||
runner._session_db._db.resolve_session_by_title.return_value = "session-b-old"
|
||||
|
||||
result = await runner._handle_resume_command(
|
||||
_event('/resume "Project B Plan"', source_b)
|
||||
)
|
||||
|
||||
assert "Resumed session" in result
|
||||
runner._session_db.resolve_session_by_title.assert_called_once_with("Project B Plan")
|
||||
runner._session_db._db.resolve_session_by_title.assert_called_once_with("Project B Plan")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -440,7 +441,7 @@ async def test_matrix_resume_quoted_title_cross_room_blocked():
|
||||
entry_a = _entry(source_a, "session-a", "Project A Plan")
|
||||
entry_b = _entry(source_b, "session-b", "Project B Plan")
|
||||
runner = _make_runner(source_b, [entry_a, entry_b])
|
||||
runner._session_db.resolve_session_by_title.return_value = "session-a"
|
||||
runner._session_db._db.resolve_session_by_title.return_value = "session-a"
|
||||
|
||||
result = await runner._handle_resume_command(
|
||||
_event('/resume "Project A Plan"', source_b)
|
||||
@@ -471,7 +472,7 @@ async def test_matrix_resume_cross_room_requires_explicit_flag_and_warns():
|
||||
entry_b = _entry(source_b, "session-b", "Project B Plan")
|
||||
runner = _make_runner(source_b, [entry_a, entry_b])
|
||||
runner.session_store.switch_session.return_value = entry_a
|
||||
runner._session_db.resolve_session_by_title.return_value = "session-a"
|
||||
runner._session_db._db.resolve_session_by_title.return_value = "session-a"
|
||||
|
||||
result = await runner._handle_resume_command(
|
||||
_event("/resume --cross-room Project A Plan", source_b)
|
||||
|
||||
@@ -39,6 +39,10 @@ def _make_runner(session_db=None, current_session_id="current_session_001",
|
||||
runner.adapters = {}
|
||||
runner.config = SimpleNamespace(platforms={})
|
||||
runner._voice_mode = {}
|
||||
# Gateway holds the async facade; the slash handlers await it.
|
||||
if session_db is not None:
|
||||
from hermes_state import AsyncSessionDB
|
||||
session_db = AsyncSessionDB(session_db)
|
||||
runner._session_db = session_db
|
||||
runner._running_agents = {}
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from hermes_state import AsyncSessionDB
|
||||
"""Regression tests for approval-state cleanup on session boundaries."""
|
||||
|
||||
from datetime import datetime
|
||||
@@ -86,9 +87,9 @@ def _make_resume_runner():
|
||||
runner.session_store.get_or_create_session.return_value = current_entry
|
||||
runner.session_store.switch_session.return_value = resumed_entry
|
||||
runner.session_store.load_transcript.return_value = []
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.resolve_session_by_title.return_value = "resumed-session"
|
||||
runner._session_db.get_session_title.return_value = "Resumed Work"
|
||||
runner._session_db = AsyncSessionDB(MagicMock())
|
||||
runner._session_db._db.resolve_session_by_title.return_value = "resumed-session"
|
||||
runner._session_db._db.get_session_title.return_value = "Resumed Work"
|
||||
return runner, session_key
|
||||
|
||||
|
||||
@@ -116,9 +117,9 @@ def _make_branch_runner():
|
||||
{"role": "assistant", "content": "world"},
|
||||
]
|
||||
runner.session_store.switch_session.return_value = branched_entry
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.get_session_title.return_value = "Current Work"
|
||||
runner._session_db.get_next_title_in_lineage.return_value = "Current Work #2"
|
||||
runner._session_db = AsyncSessionDB(MagicMock())
|
||||
runner._session_db._db.get_session_title.return_value = "Current Work"
|
||||
runner._session_db._db.get_next_title_in_lineage.return_value = "Current Work #2"
|
||||
return runner, session_key
|
||||
|
||||
|
||||
@@ -208,7 +209,7 @@ async def test_branch_preserves_persisted_assistant_metadata():
|
||||
result = await runner._handle_branch_command(_make_event("/branch"))
|
||||
|
||||
assert "Branched to" in result
|
||||
append_calls = runner._session_db.append_message.call_args_list
|
||||
append_calls = runner._session_db._db.append_message.call_args_list
|
||||
assert len(append_calls) == 2
|
||||
assistant_kwargs = append_calls[1].kwargs
|
||||
assert assistant_kwargs["role"] == "assistant"
|
||||
|
||||
@@ -171,8 +171,12 @@ async def test_second_message_during_sentinel_queued_not_duplicate():
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", slow_inner):
|
||||
# Start first message (will block at barrier)
|
||||
task1 = asyncio.create_task(runner._handle_message(event1))
|
||||
# Yield so task1 enters slow_inner and sentinel is set
|
||||
await asyncio.sleep(0)
|
||||
# Yield until task1 has claimed the sentinel (it crosses a few awaits
|
||||
# before the claim; don't assume a fixed number of scheduler slices).
|
||||
for _ in range(50):
|
||||
await asyncio.sleep(0)
|
||||
if runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL:
|
||||
break
|
||||
|
||||
# Verify sentinel is set
|
||||
assert runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL
|
||||
@@ -417,7 +421,10 @@ async def test_stop_during_sentinel_force_cleans_session():
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", slow_inner):
|
||||
task1 = asyncio.create_task(runner._handle_message(event1))
|
||||
await asyncio.sleep(0)
|
||||
for _ in range(50):
|
||||
await asyncio.sleep(0)
|
||||
if runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL:
|
||||
break
|
||||
|
||||
# Sentinel should be set
|
||||
assert runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from hermes_state import AsyncSessionDB
|
||||
"""Tests for gateway /status behavior and token persistence."""
|
||||
|
||||
from datetime import datetime
|
||||
@@ -53,11 +54,11 @@ def _make_runner(session_entry: SessionEntry, *, platform: Platform = Platform.T
|
||||
runner._session_run_generation = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.get_session_title.return_value = None
|
||||
runner._session_db = AsyncSessionDB(MagicMock())
|
||||
runner._session_db._db.get_session_title.return_value = None
|
||||
# Default: no DB row → /status reports 0 tokens. Tests that exercise
|
||||
# the populated path override this.
|
||||
runner._session_db.get_session.return_value = None
|
||||
runner._session_db._db.get_session.return_value = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
@@ -86,7 +87,7 @@ async def test_status_command_reports_running_agent_without_interrupt(monkeypatc
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
# Token total comes from the SQLite SessionDB, not SessionEntry.
|
||||
runner._session_db.get_session.return_value = {
|
||||
runner._session_db._db.get_session.return_value = {
|
||||
"input_tokens": 200,
|
||||
"output_tokens": 121,
|
||||
"cache_read_tokens": 0,
|
||||
@@ -118,7 +119,7 @@ async def test_status_command_includes_session_title_when_present():
|
||||
total_tokens=321,
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner._session_db.get_session_title.return_value = "My titled session"
|
||||
runner._session_db._db.get_session_title.return_value = "My titled session"
|
||||
|
||||
result = await runner._handle_message(_make_event("/status"))
|
||||
|
||||
@@ -141,7 +142,7 @@ async def test_status_command_reads_token_totals_from_session_db():
|
||||
total_tokens=0, # SessionEntry never gets written to — always 0.
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner._session_db.get_session.return_value = {
|
||||
runner._session_db._db.get_session.return_value = {
|
||||
"input_tokens": 1000,
|
||||
"output_tokens": 250,
|
||||
"cache_read_tokens": 500,
|
||||
@@ -169,7 +170,7 @@ async def test_status_command_tokens_zero_when_session_db_row_missing():
|
||||
total_tokens=999, # This should be ignored.
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner._session_db.get_session.return_value = None
|
||||
runner._session_db._db.get_session.return_value = None
|
||||
|
||||
result = await runner._handle_message(_make_event("/status"))
|
||||
|
||||
@@ -188,7 +189,7 @@ async def test_status_command_includes_live_agent_model_and_context():
|
||||
total_tokens=0,
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner._session_db.get_session.return_value = {
|
||||
runner._session_db._db.get_session.return_value = {
|
||||
"input_tokens": 1000,
|
||||
"output_tokens": 250,
|
||||
"cache_read_tokens": 0,
|
||||
@@ -228,7 +229,7 @@ async def test_status_command_includes_persisted_model_and_context_when_agent_no
|
||||
last_prompt_tokens=24_000,
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner._session_db.get_session.return_value = {
|
||||
runner._session_db._db.get_session.return_value = {
|
||||
"input_tokens": 2000,
|
||||
"output_tokens": 500,
|
||||
"cache_read_tokens": 0,
|
||||
|
||||
@@ -123,6 +123,10 @@ def _make_runner(session_db=None):
|
||||
runner._busy_ack_ts = {}
|
||||
runner._session_model_overrides = {}
|
||||
runner._pending_model_notes = {}
|
||||
# Gateway holds the async facade; the slash handlers await it.
|
||||
if session_db is not None:
|
||||
from hermes_state import AsyncSessionDB
|
||||
session_db = AsyncSessionDB(session_db)
|
||||
runner._session_db = session_db
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
@@ -1399,7 +1403,8 @@ def test_session_split_restores_source_thread_id_from_binding(tmp_path):
|
||||
)
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._session_db = db
|
||||
from hermes_state import AsyncSessionDB
|
||||
runner._session_db = AsyncSessionDB(db)
|
||||
|
||||
# Build a source that looks like it came from a synthetic/recovered event:
|
||||
# platform and chat_type match a Telegram DM, but thread_id is None.
|
||||
@@ -1416,7 +1421,9 @@ def test_session_split_restores_source_thread_id_from_binding(tmp_path):
|
||||
and runner._session_db is not None
|
||||
):
|
||||
try:
|
||||
_binding = runner._session_db.get_telegram_topic_binding_by_session(
|
||||
# Mirror production: this block runs in the run_sync executor, so it
|
||||
# uses the sync handle (self._session_db._db), not the async facade.
|
||||
_binding = runner._session_db._db.get_telegram_topic_binding_by_session(
|
||||
session_id="sess-split-new",
|
||||
)
|
||||
if _binding and _binding.get("thread_id"):
|
||||
|
||||
@@ -32,6 +32,10 @@ def _make_runner(session_db=None):
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
# Gateway holds the async facade; the slash handlers await it.
|
||||
if session_db is not None:
|
||||
from hermes_state import AsyncSessionDB
|
||||
session_db = AsyncSessionDB(session_db)
|
||||
runner._session_db = session_db
|
||||
|
||||
# Mock session_store that returns a session entry with a known session_id
|
||||
@@ -296,7 +300,7 @@ class TestResetCommandWithTitle:
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db = AsyncMock()
|
||||
runner._agent_cache = {}
|
||||
runner._agent_cache_lock = None
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
@@ -356,7 +360,7 @@ class TestResetCommandWithTitle:
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db = AsyncMock()
|
||||
runner._session_db.set_session_title.side_effect = ValueError(
|
||||
"Title 'Dup' is already in use by session abc-123"
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from hermes_state import AsyncSessionDB
|
||||
"""Tests for gateway /usage command — agent cache lookup and output fields."""
|
||||
|
||||
import threading
|
||||
@@ -197,8 +198,8 @@ class TestUsageAccountSection:
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_command_uses_persisted_provider_when_agent_not_running(self, monkeypatch):
|
||||
runner = _make_runner(SK)
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.get_session.return_value = {
|
||||
runner._session_db = AsyncSessionDB(MagicMock())
|
||||
runner._session_db._db.get_session.return_value = {
|
||||
"billing_provider": "openai-codex",
|
||||
"billing_base_url": "https://chatgpt.com/backend-api/codex",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user