diff --git a/agent/display.py b/agent/display.py index 604b7a298c..1820645768 100644 --- a/agent/display.py +++ b/agent/display.py @@ -4,7 +4,6 @@ Pure display functions and classes with no AIAgent dependency. Used by AIAgent._execute_tool_calls for CLI feedback. """ -import json import logging import os import sys @@ -14,6 +13,8 @@ from dataclasses import dataclass, field from difflib import unified_diff from pathlib import Path +from utils import safe_json_loads + # ANSI escape codes for coloring tool failure indicators _RED = "\033[31m" _RESET = "\033[0m" @@ -372,9 +373,8 @@ def _result_succeeded(result: str | None) -> bool: """Conservatively detect whether a tool result represents success.""" if not result: return False - try: - data = json.loads(result) - except (json.JSONDecodeError, TypeError): + data = safe_json_loads(result) + if data is None: return False if not isinstance(data, dict): return False @@ -423,10 +423,7 @@ def extract_edit_diff( ) -> str | None: """Extract a unified diff from a file-edit tool result.""" if tool_name == "patch" and result: - try: - data = json.loads(result) - except (json.JSONDecodeError, TypeError): - data = None + data = safe_json_loads(result) if isinstance(data, dict): diff = data.get("diff") if isinstance(diff, str) and diff.strip(): @@ -780,23 +777,19 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str] return False, "" if tool_name == "terminal": - try: - data = json.loads(result) + data = safe_json_loads(result) + if isinstance(data, dict): exit_code = data.get("exit_code") if exit_code is not None and exit_code != 0: return True, f" [exit {exit_code}]" - except (json.JSONDecodeError, TypeError, AttributeError): - logger.debug("Could not parse terminal result as JSON for exit code check") return False, "" # Memory-specific: distinguish "full" from real errors if tool_name == "memory": - try: - data = json.loads(result) + data = safe_json_loads(result) + if isinstance(data, dict): if data.get("success") is False and "exceed the limit" in data.get("error", ""): return True, " [full]" - except (json.JSONDecodeError, TypeError, AttributeError): - logger.debug("Could not parse memory result as JSON for capacity check") # Generic heuristic for non-terminal tools lower = result[:500].lower() diff --git a/agent/model_metadata.py b/agent/model_metadata.py index 2ef6830e58..f12801777d 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -179,6 +179,12 @@ _MAX_COMPLETION_KEYS = ( # Local server hostnames / address patterns _LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0") +# Docker / Podman / Lima DNS names that resolve to the host machine +_CONTAINER_LOCAL_SUFFIXES = ( + ".docker.internal", + ".containers.internal", + ".lima.internal", +) def _normalize_base_url(base_url: str) -> str: @@ -254,6 +260,9 @@ def is_local_endpoint(base_url: str) -> bool: return False if host in _LOCAL_HOSTS: return True + # Docker / Podman / Lima internal DNS names (e.g. 
host.docker.internal) + if any(host.endswith(suffix) for suffix in _CONTAINER_LOCAL_SUFFIXES): + return True # RFC-1918 private ranges and link-local import ipaddress try: diff --git a/agent/prompt_builder.py b/agent/prompt_builder.py index 08b8fe0a6a..26d913a029 100644 --- a/agent/prompt_builder.py +++ b/agent/prompt_builder.py @@ -12,7 +12,7 @@ import threading from collections import OrderedDict from pathlib import Path -from hermes_constants import get_hermes_home +from hermes_constants import get_hermes_home, get_skills_dir from typing import Optional from agent.skill_utils import ( @@ -548,8 +548,7 @@ def build_skills_system_prompt( are read-only — they appear in the index but new skills are always created in the local dir. Local skills take precedence when names collide. """ - hermes_home = get_hermes_home() - skills_dir = hermes_home / "skills" + skills_dir = get_skills_dir() external_dirs = get_all_skills_dirs()[1:] # skip local (index 0) if not skills_dir.exists() and not external_dirs: diff --git a/agent/skill_utils.py b/agent/skill_utils.py index ba606b358d..97ba92b735 100644 --- a/agent/skill_utils.py +++ b/agent/skill_utils.py @@ -12,7 +12,7 @@ import sys from pathlib import Path from typing import Any, Dict, List, Set, Tuple -from hermes_constants import get_hermes_home +from hermes_constants import get_config_path, get_skills_dir logger = logging.getLogger(__name__) @@ -130,7 +130,7 @@ def get_disabled_skill_names(platform: str | None = None) -> Set[str]: Reads the config file directly (no CLI config imports) to stay lightweight. """ - config_path = get_hermes_home() / "config.yaml" + config_path = get_config_path() if not config_path.exists(): return set() try: @@ -178,7 +178,7 @@ def get_external_skills_dirs() -> List[Path]: path. Only directories that actually exist are returned. Duplicates and paths that resolve to the local ``~/.hermes/skills/`` are silently skipped. """ - config_path = get_hermes_home() / "config.yaml" + config_path = get_config_path() if not config_path.exists(): return [] try: @@ -200,7 +200,7 @@ def get_external_skills_dirs() -> List[Path]: if not isinstance(raw_dirs, list): return [] - local_skills = (get_hermes_home() / "skills").resolve() + local_skills = get_skills_dir().resolve() seen: Set[Path] = set() result: List[Path] = [] @@ -230,7 +230,7 @@ def get_all_skills_dirs() -> List[Path]: The local dir is always first (and always included even if it doesn't exist yet — callers handle that). External dirs follow in config order. """ - dirs = [get_hermes_home() / "skills"] + dirs = [get_skills_dir()] dirs.extend(get_external_skills_dirs()) return dirs @@ -384,7 +384,7 @@ def resolve_skill_config_values( current values (or the declared default if the key isn't set). Path values are expanded via ``os.path.expanduser``. """ - config_path = get_hermes_home() / "config.yaml" + config_path = get_config_path() config: Dict[str, Any] = {} if config_path.exists(): try: diff --git a/cli.py b/cli.py index 18f6df6711..1a57dd3eb2 100644 --- a/cli.py +++ b/cli.py @@ -2748,6 +2748,15 @@ class HermesCLI: self.api_key = api_key self.base_url = base_url + # When a custom_provider entry carries an explicit `model` field, + # use it as the effective model name. Without this, running + # `hermes chat --model ` sends the provider name + # (e.g. "my-provider") as the model string to the API instead of + # the configured model (e.g. "qwen3.6-plus"), causing 400 errors. 
+ runtime_model = runtime.get("model") + if runtime_model and isinstance(runtime_model, str): + self.model = runtime_model + # Normalize model for the resolved provider (e.g. swap non-Codex # models when provider is openai-codex). Fixes #651. model_changed = self._normalize_model_for_provider(resolved_provider) diff --git a/cron/scheduler.py b/cron/scheduler.py index 0e04fb047b..870ebe1418 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -722,6 +722,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: provider_sort=pr.get("sort"), disabled_toolsets=["cronjob", "messaging", "clarify"], quiet_mode=True, + skip_context_files=True, # Don't inject SOUL.md/AGENTS.md from scheduler cwd skip_memory=True, # Cron system prompts would corrupt user representations platform="cron", session_id=_cron_session_id, diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index baada7e058..1954a2b9e5 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -53,6 +53,7 @@ DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 8642 MAX_STORED_RESPONSES = 100 MAX_REQUEST_BYTES = 1_000_000 # 1 MB default limit for POST bodies +CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS = 30.0 def check_api_server_requirements() -> bool: @@ -762,7 +763,11 @@ class APIServerAdapter(BasePlatformAdapter): """ import queue as _q - sse_headers = {"Content-Type": "text/event-stream", "Cache-Control": "no-cache"} + sse_headers = { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + } # CORS middleware can't inject headers into StreamResponse after # prepare() flushes them, so resolve CORS headers up front. origin = request.headers.get("Origin", "") @@ -775,6 +780,8 @@ class APIServerAdapter(BasePlatformAdapter): await response.prepare(request) try: + last_activity = time.monotonic() + # Role chunk role_chunk = { "id": completion_id, "object": "chat.completion.chunk", @@ -782,6 +789,7 @@ class APIServerAdapter(BasePlatformAdapter): "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}], } await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode()) + last_activity = time.monotonic() # Helper — route a queue item to the correct SSE event. 
async def _emit(item): @@ -805,6 +813,7 @@ class APIServerAdapter(BasePlatformAdapter): "choices": [{"index": 0, "delta": {"content": item}, "finish_reason": None}], } await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode()) + return time.monotonic() # Stream content chunks as they arrive from the agent loop = asyncio.get_event_loop() @@ -819,16 +828,19 @@ class APIServerAdapter(BasePlatformAdapter): delta = stream_q.get_nowait() if delta is None: break - await _emit(delta) + last_activity = await _emit(delta) except _q.Empty: break break + if time.monotonic() - last_activity >= CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS: + await response.write(b": keepalive\n\n") + last_activity = time.monotonic() continue if delta is None: # End of stream sentinel break - await _emit(delta) + last_activity = await _emit(delta) # Get usage from completed agent usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index b4c84f3119..45cb3694a7 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -823,7 +823,36 @@ class BasePlatformAdapter(ABC): result = handler(self) if asyncio.iscoroutine(result): await result - + + def _acquire_platform_lock(self, scope: str, identity: str, resource_desc: str) -> bool: + """Acquire a scoped lock for this adapter. Returns True on success.""" + from gateway.status import acquire_scoped_lock + self._platform_lock_scope = scope + self._platform_lock_identity = identity + acquired, existing = acquire_scoped_lock( + scope, identity, metadata={'platform': self.platform.value} + ) + if acquired: + return True + owner_pid = existing.get('pid') if isinstance(existing, dict) else None + message = ( + f'{resource_desc} already in use' + + (f' (PID {owner_pid})' if owner_pid else '') + + '. Stop the other gateway first.' 
+ ) + logger.error('[%s] %s', self.name, message) + self._set_fatal_error(f'{scope}_lock', message, retryable=False) + return False + + def _release_platform_lock(self) -> None: + """Release the scoped lock acquired by _acquire_platform_lock.""" + identity = getattr(self, '_platform_lock_identity', None) + if not identity: + return + from gateway.status import release_scoped_lock + release_scoped_lock(self._platform_lock_scope, identity) + self._platform_lock_identity = None + @property def name(self) -> str: """Human-readable name for this adapter.""" diff --git a/gateway/platforms/bluebubbles.py b/gateway/platforms/bluebubbles.py index f50cd9503c..1150009965 100644 --- a/gateway/platforms/bluebubbles.py +++ b/gateway/platforms/bluebubbles.py @@ -30,6 +30,7 @@ from gateway.platforms.base import ( cache_audio_from_bytes, cache_document_from_bytes, ) +from gateway.platforms.helpers import strip_markdown logger = logging.getLogger(__name__) @@ -89,18 +90,7 @@ def _normalize_server_url(raw: str) -> str: return value.rstrip("/") -def _strip_markdown(text: str) -> str: - """Strip common markdown formatting for iMessage plain-text delivery.""" - text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL) - text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL) - text = re.sub(r"__(.+?)__", r"\1", text, flags=re.DOTALL) - text = re.sub(r"_(.+?)_", r"\1", text, flags=re.DOTALL) - text = re.sub(r"```[a-zA-Z0-9_+-]*\n?", "", text) - text = re.sub(r"`(.+?)`", r"\1", text) - text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE) - text = re.sub(r"\[([^\]]+)\]\(([^\)]+)\)", r"\1", text) - text = re.sub(r"\n{3,}", "\n\n", text) - return text.strip() + # --------------------------------------------------------------------------- @@ -393,7 +383,7 @@ class BlueBubblesAdapter(BasePlatformAdapter): reply_to: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: - text = _strip_markdown(content or "") + text = strip_markdown(content or "") if not text: return SendResult(success=False, error="BlueBubbles send requires text") chunks = self.truncate_message(text, max_length=self.MAX_MESSAGE_LENGTH) @@ -679,7 +669,7 @@ class BlueBubblesAdapter(BasePlatformAdapter): return info def format_message(self, content: str) -> str: - return _strip_markdown(content) + return strip_markdown(content) # ------------------------------------------------------------------ # Inbound attachment downloading (from #4588) diff --git a/gateway/platforms/dingtalk.py b/gateway/platforms/dingtalk.py index e83b902dfb..5d50deca58 100644 --- a/gateway/platforms/dingtalk.py +++ b/gateway/platforms/dingtalk.py @@ -42,6 +42,7 @@ except ImportError: httpx = None # type: ignore[assignment] from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -52,8 +53,6 @@ from gateway.platforms.base import ( logger = logging.getLogger(__name__) MAX_MESSAGE_LENGTH = 20000 -DEDUP_WINDOW_SECONDS = 300 -DEDUP_MAX_SIZE = 1000 RECONNECT_BACKOFF = [2, 5, 10, 30, 60] _SESSION_WEBHOOKS_MAX = 500 _DINGTALK_WEBHOOK_RE = re.compile(r'^https://api\.dingtalk\.com/') @@ -89,8 +88,8 @@ class DingTalkAdapter(BasePlatformAdapter): self._stream_task: Optional[asyncio.Task] = None self._http_client: Optional["httpx.AsyncClient"] = None - # Message deduplication: msg_id -> timestamp - self._seen_messages: Dict[str, float] = {} + # Message deduplication + self._dedup = MessageDeduplicator(max_size=1000) # Map 
chat_id -> session_webhook for reply routing self._session_webhooks: Dict[str, str] = {} @@ -170,7 +169,7 @@ class DingTalkAdapter(BasePlatformAdapter): self._stream_client = None self._session_webhooks.clear() - self._seen_messages.clear() + self._dedup.clear() logger.info("[%s] Disconnected", self.name) # -- Inbound message processing ----------------------------------------- @@ -178,7 +177,7 @@ class DingTalkAdapter(BasePlatformAdapter): async def _on_message(self, message: "ChatbotMessage") -> None: """Process an incoming DingTalk chatbot message.""" msg_id = getattr(message, "message_id", None) or uuid.uuid4().hex - if self._is_duplicate(msg_id): + if self._dedup.is_duplicate(msg_id): logger.debug("[%s] Duplicate message %s, skipping", self.name, msg_id) return @@ -256,20 +255,6 @@ class DingTalkAdapter(BasePlatformAdapter): content = " ".join(parts).strip() return content - # -- Deduplication ------------------------------------------------------ - - def _is_duplicate(self, msg_id: str) -> bool: - """Check and record a message ID. Returns True if already seen.""" - now = time.time() - if len(self._seen_messages) > DEDUP_MAX_SIZE: - cutoff = now - DEDUP_WINDOW_SECONDS - self._seen_messages = {k: v for k, v in self._seen_messages.items() if v > cutoff} - - if msg_id in self._seen_messages: - return True - self._seen_messages[msg_id] = now - return False - # -- Outbound messaging ------------------------------------------------- async def send( diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index dcf05a1625..b1d07e5d65 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -45,6 +45,7 @@ sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) from gateway.config import Platform, PlatformConfig import re +from gateway.platforms.helpers import MessageDeduplicator, ThreadParticipationTracker from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -450,18 +451,14 @@ class DiscordAdapter(BasePlatformAdapter): # Track threads where the bot has participated so follow-up messages # in those threads don't require @mention. Persisted to disk so the # set survives gateway restarts. - self._bot_participated_threads: set = self._load_participated_threads() + self._threads = ThreadParticipationTracker("discord") # Persistent typing indicator loops per channel (DMs don't reliably # show the standard typing gateway event for bots) self._typing_tasks: Dict[str, asyncio.Task] = {} self._bot_task: Optional[asyncio.Task] = None - # Cap to prevent unbounded growth (Discord threads get archived). - self._MAX_TRACKED_THREADS = 500 - # Dedup cache: message_id → timestamp. Prevents duplicate bot - # responses when Discord RESUME replays events after reconnects. - self._seen_messages: Dict[str, float] = {} - self._SEEN_TTL = 300 # 5 minutes - self._SEEN_MAX = 2000 # prune threshold + # Dedup cache: prevents duplicate bot responses when Discord + # RESUME replays events after reconnects. + self._dedup = MessageDeduplicator() # Reply threading mode: "off" (no replies), "first" (reply on first # chunk only, default), "all" (reply-reference on every chunk). 
self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first' @@ -502,18 +499,9 @@ class DiscordAdapter(BasePlatformAdapter): return False try: - # Acquire scoped lock to prevent duplicate bot token usage - from gateway.status import acquire_scoped_lock - self._token_lock_identity = self.config.token - acquired, existing = acquire_scoped_lock('discord-bot-token', self._token_lock_identity, metadata={'platform': 'discord'}) - if not acquired: - owner_pid = existing.get('pid') if isinstance(existing, dict) else None - message = f'Discord bot token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.' - logger.error('[%s] %s', self.name, message) - self._set_fatal_error('discord_token_lock', message, retryable=False) + if not self._acquire_platform_lock('discord-bot-token', self.config.token, 'Discord bot token'): return False - # Parse allowed user entries (may contain usernames or IDs) allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "") if allowed_env: @@ -569,17 +557,8 @@ class DiscordAdapter(BasePlatformAdapter): @self._client.event async def on_message(message: DiscordMessage): # Dedup: Discord RESUME replays events after reconnects (#4777) - msg_id = str(message.id) - now = time.time() - if msg_id in adapter_self._seen_messages: + if adapter_self._dedup.is_duplicate(str(message.id)): return - adapter_self._seen_messages[msg_id] = now - if len(adapter_self._seen_messages) > adapter_self._SEEN_MAX: - cutoff = now - adapter_self._SEEN_TTL - adapter_self._seen_messages = { - k: v for k, v in adapter_self._seen_messages.items() - if v > cutoff - } # Always ignore our own messages if message.author == self._client.user: @@ -685,23 +664,11 @@ class DiscordAdapter(BasePlatformAdapter): except asyncio.TimeoutError: logger.error("[%s] Timeout waiting for connection to Discord", self.name, exc_info=True) - try: - from gateway.status import release_scoped_lock - if getattr(self, '_token_lock_identity', None): - release_scoped_lock('discord-bot-token', self._token_lock_identity) - self._token_lock_identity = None - except Exception: - pass + self._release_platform_lock() return False except Exception as e: # pragma: no cover - defensive logging logger.error("[%s] Failed to connect to Discord: %s", self.name, e, exc_info=True) - try: - from gateway.status import release_scoped_lock - if getattr(self, '_token_lock_identity', None): - release_scoped_lock('discord-bot-token', self._token_lock_identity) - self._token_lock_identity = None - except Exception: - pass + self._release_platform_lock() return False async def disconnect(self) -> None: @@ -723,14 +690,7 @@ class DiscordAdapter(BasePlatformAdapter): self._client = None self._ready_event.clear() - # Release the token lock - try: - from gateway.status import release_scoped_lock - if getattr(self, '_token_lock_identity', None): - release_scoped_lock('discord-bot-token', self._token_lock_identity) - self._token_lock_identity = None - except Exception: - pass + self._release_platform_lock() logger.info("[%s] Disconnected", self.name) @@ -1870,7 +1830,7 @@ class DiscordAdapter(BasePlatformAdapter): # Track thread participation so follow-ups don't require @mention if thread_id: - self._track_thread(thread_id) + self._threads.mark(thread_id) # If a message was provided, kick off a new Hermes session in the thread starter = (message or "").strip() @@ -2241,49 +2201,6 @@ class DiscordAdapter(BasePlatformAdapter): return f"{parent_name} / {thread_name}" return thread_name - # 
------------------------------------------------------------------ - # Thread participation persistence - # ------------------------------------------------------------------ - - @staticmethod - def _thread_state_path() -> Path: - """Path to the persisted thread participation set.""" - from hermes_cli.config import get_hermes_home - return get_hermes_home() / "discord_threads.json" - - @classmethod - def _load_participated_threads(cls) -> set: - """Load persisted thread IDs from disk.""" - path = cls._thread_state_path() - try: - if path.exists(): - data = json.loads(path.read_text(encoding="utf-8")) - if isinstance(data, list): - return set(data) - except Exception as e: - logger.debug("Could not load discord thread state: %s", e) - return set() - - def _save_participated_threads(self) -> None: - """Persist the current thread set to disk (best-effort).""" - path = self._thread_state_path() - try: - # Trim to most recent entries if over cap - thread_list = list(self._bot_participated_threads) - if len(thread_list) > self._MAX_TRACKED_THREADS: - thread_list = thread_list[-self._MAX_TRACKED_THREADS:] - self._bot_participated_threads = set(thread_list) - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(thread_list), encoding="utf-8") - except Exception as e: - logger.debug("Could not save discord thread state: %s", e) - - def _track_thread(self, thread_id: str) -> None: - """Add a thread to the participation set and persist.""" - if thread_id not in self._bot_participated_threads: - self._bot_participated_threads.add(thread_id) - self._save_participated_threads() - async def _handle_message(self, message: DiscordMessage) -> None: """Handle incoming Discord messages.""" # In server channels (not DMs), require the bot to be @mentioned @@ -2335,7 +2252,7 @@ class DiscordAdapter(BasePlatformAdapter): # Skip the mention check if the message is in a thread where # the bot has previously participated (auto-created or replied in). - in_bot_thread = is_thread and thread_id in self._bot_participated_threads + in_bot_thread = is_thread and thread_id in self._threads if require_mention and not is_free_channel and not in_bot_thread: if self._client.user not in message.mentions: @@ -2361,7 +2278,7 @@ class DiscordAdapter(BasePlatformAdapter): is_thread = True thread_id = str(thread.id) auto_threaded_channel = thread - self._track_thread(thread_id) + self._threads.mark(thread_id) # Determine message type msg_type = MessageType.TEXT @@ -2545,7 +2462,7 @@ class DiscordAdapter(BasePlatformAdapter): # Track thread participation so the bot won't require @mention for # follow-up messages in threads it has already engaged in. if thread_id: - self._track_thread(thread_id) + self._threads.mark(thread_id) # Only batch plain text messages — commands, media, etc. dispatch # immediately since they won't be split by the Discord client. diff --git a/gateway/platforms/feishu.py b/gateway/platforms/feishu.py index a88c7e52b9..16f5467b22 100644 --- a/gateway/platforms/feishu.py +++ b/gateway/platforms/feishu.py @@ -360,19 +360,21 @@ def _render_code_block_element(element: Dict[str, Any]) -> str: def _strip_markdown_to_plain_text(text: str) -> str: + """Strip markdown formatting to plain text for Feishu text fallbacks. + + Delegates common markdown stripping to the shared helper and adds + Feishu-specific patterns (blockquotes, strikethrough, underline tags, + horizontal rules, \\r\\n normalisation). 
+ """ + from gateway.platforms.helpers import strip_markdown plain = text.replace("\r\n", "\n") plain = _MARKDOWN_LINK_RE.sub(lambda m: f"{m.group(1)} ({m.group(2).strip()})", plain) - plain = re.sub(r"^#{1,6}\s+", "", plain, flags=re.MULTILINE) plain = re.sub(r"^>\s?", "", plain, flags=re.MULTILINE) plain = re.sub(r"^\s*---+\s*$", "---", plain, flags=re.MULTILINE) - plain = re.sub(r"```(?:[^\n]*\n)?([\s\S]*?)```", lambda m: m.group(1).strip("\n"), plain) - plain = re.sub(r"`([^`\n]+)`", r"\1", plain) - plain = re.sub(r"\*\*([^*\n]+)\*\*", r"\1", plain) - plain = re.sub(r"\*([^*\n]+)\*", r"\1", plain) plain = re.sub(r"~~([^~\n]+)~~", r"\1", plain) plain = re.sub(r"([\s\S]*?)", r"\1", plain) - plain = re.sub(r"\n{3,}", "\n\n", plain) - return plain.strip() + plain = strip_markdown(plain) + return plain def _coerce_int(value: Any, default: Optional[int] = None, min_value: int = 0) -> Optional[int]: diff --git a/gateway/platforms/helpers.py b/gateway/platforms/helpers.py new file mode 100644 index 0000000000..c834dd89ca --- /dev/null +++ b/gateway/platforms/helpers.py @@ -0,0 +1,261 @@ +"""Shared helper classes for gateway platform adapters. + +Extracts common patterns that were duplicated across 5-7 adapters: +message deduplication, text batch aggregation, markdown stripping, +and thread participation tracking. +""" + +import asyncio +import json +import logging +import re +import time +from pathlib import Path +from typing import TYPE_CHECKING, Dict, Optional + +if TYPE_CHECKING: + from gateway.platforms.base import BasePlatformAdapter, MessageEvent + +logger = logging.getLogger(__name__) + + +# ─── Message Deduplication ──────────────────────────────────────────────────── + + +class MessageDeduplicator: + """TTL-based message deduplication cache. + + Replaces the identical ``_seen_messages`` / ``_is_duplicate()`` pattern + previously duplicated in discord, slack, dingtalk, wecom, weixin, + mattermost, and feishu adapters. + + Usage:: + + self._dedup = MessageDeduplicator() + + # In message handler: + if self._dedup.is_duplicate(msg_id): + return + """ + + def __init__(self, max_size: int = 2000, ttl_seconds: float = 300): + self._seen: Dict[str, float] = {} + self._max_size = max_size + self._ttl = ttl_seconds + + def is_duplicate(self, msg_id: str) -> bool: + """Return True if *msg_id* was already seen within the TTL window.""" + if not msg_id: + return False + now = time.time() + if msg_id in self._seen: + return True + self._seen[msg_id] = now + if len(self._seen) > self._max_size: + cutoff = now - self._ttl + self._seen = {k: v for k, v in self._seen.items() if v > cutoff} + return False + + def clear(self): + """Clear all tracked messages.""" + self._seen.clear() + + +# ─── Text Batch Aggregation ────────────────────────────────────────────────── + + +class TextBatchAggregator: + """Aggregates rapid-fire text events into single messages. + + Replaces the ``_enqueue_text_event`` / ``_flush_text_batch`` pattern + previously duplicated in telegram, discord, matrix, wecom, and feishu. 
+ + Usage:: + + self._text_batcher = TextBatchAggregator( + handler=self._message_handler, + batch_delay=0.6, + split_threshold=1900, + ) + + # In message dispatch: + if msg_type == MessageType.TEXT and self._text_batcher.is_enabled(): + self._text_batcher.enqueue(event, session_key) + return + """ + + def __init__( + self, + handler, + *, + batch_delay: float = 0.6, + split_delay: float = 2.0, + split_threshold: int = 4000, + ): + self._handler = handler + self._batch_delay = batch_delay + self._split_delay = split_delay + self._split_threshold = split_threshold + self._pending: Dict[str, "MessageEvent"] = {} + self._pending_tasks: Dict[str, asyncio.Task] = {} + + def is_enabled(self) -> bool: + """Return True if batching is active (delay > 0).""" + return self._batch_delay > 0 + + def enqueue(self, event: "MessageEvent", key: str) -> None: + """Add *event* to the pending batch for *key*.""" + chunk_len = len(event.text or "") + existing = self._pending.get(key) + if not existing: + event._last_chunk_len = chunk_len # type: ignore[attr-defined] + self._pending[key] = event + else: + existing.text = f"{existing.text}\n{event.text}" + existing._last_chunk_len = chunk_len # type: ignore[attr-defined] + + # Cancel prior flush timer, start a new one + prior = self._pending_tasks.get(key) + if prior and not prior.done(): + prior.cancel() + self._pending_tasks[key] = asyncio.create_task(self._flush(key)) + + async def _flush(self, key: str) -> None: + """Wait then dispatch the batched event for *key*.""" + current_task = self._pending_tasks.get(key) + pending = self._pending.get(key) + last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0 + + # Use longer delay when the last chunk looks like a split message + delay = self._split_delay if last_len >= self._split_threshold else self._batch_delay + await asyncio.sleep(delay) + + event = self._pending.pop(key, None) + if event: + try: + await self._handler(event) + except Exception: + logger.exception("[TextBatchAggregator] Error dispatching batched event for %s", key) + + if self._pending_tasks.get(key) is current_task: + self._pending_tasks.pop(key, None) + + def cancel_all(self) -> None: + """Cancel all pending flush tasks.""" + for task in self._pending_tasks.values(): + if not task.done(): + task.cancel() + self._pending_tasks.clear() + self._pending.clear() + + +# ─── Markdown Stripping ────────────────────────────────────────────────────── + +# Pre-compiled regexes for performance +_RE_BOLD = re.compile(r"\*\*(.+?)\*\*", re.DOTALL) +_RE_ITALIC_STAR = re.compile(r"\*(.+?)\*", re.DOTALL) +_RE_BOLD_UNDER = re.compile(r"__(.+?)__", re.DOTALL) +_RE_ITALIC_UNDER = re.compile(r"_(.+?)_", re.DOTALL) +_RE_CODE_BLOCK = re.compile(r"```[a-zA-Z0-9_+-]*\n?") +_RE_INLINE_CODE = re.compile(r"`(.+?)`") +_RE_HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE) +_RE_LINK = re.compile(r"\[([^\]]+)\]\([^\)]+\)") +_RE_MULTI_NEWLINE = re.compile(r"\n{3,}") + + +def strip_markdown(text: str) -> str: + """Strip markdown formatting for plain-text platforms (SMS, iMessage, etc.). + + Replaces the identical ``_strip_markdown()`` functions previously + duplicated in sms.py, bluebubbles.py, and feishu.py. 
+ """ + text = _RE_BOLD.sub(r"\1", text) + text = _RE_ITALIC_STAR.sub(r"\1", text) + text = _RE_BOLD_UNDER.sub(r"\1", text) + text = _RE_ITALIC_UNDER.sub(r"\1", text) + text = _RE_CODE_BLOCK.sub("", text) + text = _RE_INLINE_CODE.sub(r"\1", text) + text = _RE_HEADING.sub("", text) + text = _RE_LINK.sub(r"\1", text) + text = _RE_MULTI_NEWLINE.sub("\n\n", text) + return text.strip() + + +# ─── Thread Participation Tracking ─────────────────────────────────────────── + + +class ThreadParticipationTracker: + """Persistent tracking of threads the bot has participated in. + + Replaces the identical ``_load/_save_participated_threads`` + + ``_mark_thread_participated`` pattern previously duplicated in + discord.py and matrix.py. + + Usage:: + + self._threads = ThreadParticipationTracker("discord") + + # Check membership: + if thread_id in self._threads: + ... + + # Mark participation: + self._threads.mark(thread_id) + """ + + _MAX_TRACKED = 500 + + def __init__(self, platform_name: str, max_tracked: int = 500): + self._platform = platform_name + self._max_tracked = max_tracked + self._threads: set = self._load() + + def _state_path(self) -> Path: + from hermes_constants import get_hermes_home + return get_hermes_home() / f"{self._platform}_threads.json" + + def _load(self) -> set: + path = self._state_path() + if path.exists(): + try: + return set(json.loads(path.read_text(encoding="utf-8"))) + except Exception: + pass + return set() + + def _save(self) -> None: + path = self._state_path() + path.parent.mkdir(parents=True, exist_ok=True) + thread_list = list(self._threads) + if len(thread_list) > self._max_tracked: + thread_list = thread_list[-self._max_tracked:] + self._threads = set(thread_list) + path.write_text(json.dumps(thread_list), encoding="utf-8") + + def mark(self, thread_id: str) -> None: + """Mark *thread_id* as participated and persist.""" + if thread_id not in self._threads: + self._threads.add(thread_id) + self._save() + + def __contains__(self, thread_id: str) -> bool: + return thread_id in self._threads + + def clear(self) -> None: + self._threads.clear() + + +# ─── Phone Number Redaction ────────────────────────────────────────────────── + + +def redact_phone(phone: str) -> str: + """Redact a phone number for logging, preserving country code and last 4. + + Replaces the identical ``_redact_phone()`` functions in signal.py, + sms.py, and bluebubbles.py. + """ + if not phone: + return "" + if len(phone) <= 8: + return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****" + return phone[:4] + "****" + phone[-4:] diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index 7daf2e70e1..349f962d2e 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -92,6 +92,7 @@ from gateway.platforms.base import ( ProcessingOutcome, SendResult, ) +from gateway.platforms.helpers import ThreadParticipationTracker logger = logging.getLogger(__name__) @@ -216,8 +217,7 @@ class MatrixAdapter(BasePlatformAdapter): self._pending_megolm: list = [] # Thread participation tracking (for require_mention bypass) - self._bot_participated_threads: set = self._load_participated_threads() - self._MAX_TRACKED_THREADS = 500 + self._threads = ThreadParticipationTracker("matrix") # Mention/thread gating — parsed once from env vars. self._require_mention: bool = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no") @@ -1019,7 +1019,7 @@ class MatrixAdapter(BasePlatformAdapter): # Require-mention gating. 
if not is_dm: is_free_room = room_id in self._free_rooms - in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads) + in_bot_thread = bool(thread_id and thread_id in self._threads) if self._require_mention and not is_free_room and not in_bot_thread: if not is_mentioned: return None @@ -1027,7 +1027,7 @@ class MatrixAdapter(BasePlatformAdapter): # DM mention-thread. if is_dm and not thread_id and self._dm_mention_threads and is_mentioned: thread_id = event_id - self._track_thread(thread_id) + self._threads.mark(thread_id) # Strip mention from body. if is_mentioned: @@ -1036,7 +1036,7 @@ class MatrixAdapter(BasePlatformAdapter): # Auto-thread. if not is_dm and not thread_id and self._auto_thread: thread_id = event_id - self._track_thread(thread_id) + self._threads.mark(thread_id) display_name = await self._get_display_name(room_id, sender) source = self.build_source( @@ -1048,7 +1048,7 @@ class MatrixAdapter(BasePlatformAdapter): ) if thread_id: - self._track_thread(thread_id) + self._threads.mark(thread_id) self._background_read_receipt(room_id, event_id) @@ -1697,48 +1697,6 @@ class MatrixAdapter(BasePlatformAdapter): for rid in self._joined_rooms } - # ------------------------------------------------------------------ - # Thread participation tracking - # ------------------------------------------------------------------ - - @staticmethod - def _thread_state_path() -> Path: - """Path to the persisted thread participation set.""" - from hermes_cli.config import get_hermes_home - return get_hermes_home() / "matrix_threads.json" - - @classmethod - def _load_participated_threads(cls) -> set: - """Load persisted thread IDs from disk.""" - path = cls._thread_state_path() - try: - if path.exists(): - data = json.loads(path.read_text(encoding="utf-8")) - if isinstance(data, list): - return set(data) - except Exception as e: - logger.debug("Could not load matrix thread state: %s", e) - return set() - - def _save_participated_threads(self) -> None: - """Persist the current thread set to disk (best-effort).""" - path = self._thread_state_path() - try: - thread_list = list(self._bot_participated_threads) - if len(thread_list) > self._MAX_TRACKED_THREADS: - thread_list = thread_list[-self._MAX_TRACKED_THREADS:] - self._bot_participated_threads = set(thread_list) - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(thread_list), encoding="utf-8") - except Exception as e: - logger.debug("Could not save matrix thread state: %s", e) - - def _track_thread(self, thread_id: str) -> None: - """Add a thread to the participation set and persist.""" - if thread_id not in self._bot_participated_threads: - self._bot_participated_threads.add(thread_id) - self._save_participated_threads() - # ------------------------------------------------------------------ # Mention detection helpers # ------------------------------------------------------------------ diff --git a/gateway/platforms/mattermost.py b/gateway/platforms/mattermost.py index 56f29e8760..23a86f02b1 100644 --- a/gateway/platforms/mattermost.py +++ b/gateway/platforms/mattermost.py @@ -18,11 +18,11 @@ import json import logging import os import re -import time from pathlib import Path from typing import Any, Dict, List, Optional from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -96,10 +96,8 @@ class MattermostAdapter(BasePlatformAdapter): or os.getenv("MATTERMOST_REPLY_MODE", 
"off") ).lower() - # Dedup cache: post_id → timestamp (prevent reprocessing) - self._seen_posts: Dict[str, float] = {} - self._SEEN_MAX = 2000 - self._SEEN_TTL = 300 # 5 minutes + # Dedup cache (prevent reprocessing) + self._dedup = MessageDeduplicator() # ------------------------------------------------------------------ # HTTP helpers @@ -604,10 +602,8 @@ class MattermostAdapter(BasePlatformAdapter): post_id = post.get("id", "") # Dedup. - self._prune_seen() - if post_id in self._seen_posts: + if self._dedup.is_duplicate(post_id): return - self._seen_posts[post_id] = time.time() # Build message event. channel_id = post.get("channel_id", "") @@ -734,13 +730,4 @@ class MattermostAdapter(BasePlatformAdapter): await self.handle_message(msg_event) - def _prune_seen(self) -> None: - """Remove expired entries from the dedup cache.""" - if len(self._seen_posts) < self._SEEN_MAX: - return - now = time.time() - self._seen_posts = { - pid: ts - for pid, ts in self._seen_posts.items() - if now - ts < self._SEEN_TTL - } + diff --git a/gateway/platforms/signal.py b/gateway/platforms/signal.py index 08b62f2a6d..8ef7bd0d60 100644 --- a/gateway/platforms/signal.py +++ b/gateway/platforms/signal.py @@ -37,6 +37,7 @@ from gateway.platforms.base import ( cache_document_from_bytes, cache_image_from_url, ) +from gateway.platforms.helpers import redact_phone logger = logging.getLogger(__name__) @@ -51,22 +52,10 @@ SSE_RETRY_DELAY_MAX = 60.0 HEALTH_CHECK_INTERVAL = 30.0 # seconds between health checks HEALTH_CHECK_STALE_THRESHOLD = 120.0 # seconds without SSE activity before concern -# E.164 phone number pattern for redaction -_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}") - - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -def _redact_phone(phone: str) -> str: - """Redact a phone number for logging: +15551234567 -> +155****4567.""" - if not phone: - return "" - if len(phone) <= 8: - return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****" - return phone[:4] + "****" + phone[-4:] - def _parse_comma_list(value: str) -> List[str]: """Split a comma-separated string into a list, stripping whitespace.""" @@ -184,10 +173,8 @@ class SignalAdapter(BasePlatformAdapter): self._recent_sent_timestamps: set = set() self._max_recent_timestamps = 50 - self._phone_lock_identity: Optional[str] = None - logger.info("Signal adapter initialized: url=%s account=%s groups=%s", - self.http_url, _redact_phone(self.account), + self.http_url, redact_phone(self.account), "enabled" if self.group_allow_from else "disabled") # ------------------------------------------------------------------ @@ -202,23 +189,7 @@ class SignalAdapter(BasePlatformAdapter): # Acquire scoped lock to prevent duplicate Signal listeners for the same phone try: - from gateway.status import acquire_scoped_lock - - self._phone_lock_identity = self.account - acquired, existing = acquire_scoped_lock( - "signal-phone", - self._phone_lock_identity, - metadata={"platform": self.platform.value}, - ) - if not acquired: - owner_pid = existing.get("pid") if isinstance(existing, dict) else None - message = ( - "Another local Hermes gateway is already using this Signal account" - + (f" (PID {owner_pid})." if owner_pid else ".") - + " Stop the other gateway before starting a second Signal listener." 
- ) - logger.error("Signal: %s", message) - self._set_fatal_error("signal_phone_lock", message, retryable=False) + if not self._acquire_platform_lock('signal-phone', self.account, 'Signal account'): return False except Exception as e: logger.warning("Signal: Could not acquire phone lock (non-fatal): %s", e) @@ -270,13 +241,7 @@ class SignalAdapter(BasePlatformAdapter): await self.client.aclose() self.client = None - if self._phone_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("signal-phone", self._phone_lock_identity) - except Exception as e: - logger.warning("Signal: Error releasing phone lock: %s", e, exc_info=True) - self._phone_lock_identity = None + self._release_platform_lock() logger.info("Signal: disconnected") @@ -542,7 +507,7 @@ class SignalAdapter(BasePlatformAdapter): ) logger.debug("Signal: message from %s in %s: %s", - _redact_phone(sender), chat_id[:20], (text or "")[:50]) + redact_phone(sender), chat_id[:20], (text or "")[:50]) await self.handle_message(event) diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py index 361f74882e..8f9934cf7a 100644 --- a/gateway/platforms/slack.py +++ b/gateway/platforms/slack.py @@ -33,6 +33,7 @@ from pathlib import Path as _Path sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -89,11 +90,9 @@ class SlackAdapter(BasePlatformAdapter): self._team_clients: Dict[str, AsyncWebClient] = {} # team_id → WebClient self._team_bot_user_ids: Dict[str, str] = {} # team_id → bot_user_id self._channel_team: Dict[str, str] = {} # channel_id → team_id - # Dedup cache: event_ts → timestamp. Prevents duplicate bot - # responses when Socket Mode reconnects redeliver events. - self._seen_messages: Dict[str, float] = {} - self._SEEN_TTL = 300 # 5 minutes - self._SEEN_MAX = 2000 # prune threshold + # Dedup cache: prevents duplicate bot responses when Socket Mode + # reconnects redeliver events. + self._dedup = MessageDeduplicator() # Track pending approval message_ts → resolved flag to prevent # double-clicks on approval buttons. self._approval_resolved: Dict[str, bool] = {} @@ -152,15 +151,7 @@ class SlackAdapter(BasePlatformAdapter): logger.warning("[Slack] Failed to read %s: %s", tokens_file, e) try: - # Acquire scoped lock to prevent duplicate app token usage - from gateway.status import acquire_scoped_lock - self._token_lock_identity = app_token - acquired, existing = acquire_scoped_lock('slack-app-token', app_token, metadata={'platform': 'slack'}) - if not acquired: - owner_pid = existing.get('pid') if isinstance(existing, dict) else None - message = f'Slack app token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.' 
- logger.error('[%s] %s', self.name, message) - self._set_fatal_error('slack_token_lock', message, retryable=False) + if not self._acquire_platform_lock('slack-app-token', app_token, 'Slack app token'): return False # First token is the primary — used for AsyncApp / Socket Mode @@ -247,14 +238,7 @@ class SlackAdapter(BasePlatformAdapter): logger.warning("[Slack] Error while closing Socket Mode handler: %s", e, exc_info=True) self._running = False - # Release the token lock (use stored identity, not re-read env) - try: - from gateway.status import release_scoped_lock - if getattr(self, '_token_lock_identity', None): - release_scoped_lock('slack-app-token', self._token_lock_identity) - self._token_lock_identity = None - except Exception: - pass + self._release_platform_lock() logger.info("[Slack] Disconnected") @@ -953,17 +937,8 @@ class SlackAdapter(BasePlatformAdapter): """Handle an incoming Slack message event.""" # Dedup: Slack Socket Mode can redeliver events after reconnects (#4777) event_ts = event.get("ts", "") - if event_ts: - now = time.time() - if event_ts in self._seen_messages: - return - self._seen_messages[event_ts] = now - if len(self._seen_messages) > self._SEEN_MAX: - cutoff = now - self._SEEN_TTL - self._seen_messages = { - k: v for k, v in self._seen_messages.items() - if v > cutoff - } + if event_ts and self._dedup.is_duplicate(event_ts): + return # Bot message filtering (SLACK_ALLOW_BOTS / config allow_bots): # "none" — ignore all bot messages (default, backward-compatible) diff --git a/gateway/platforms/sms.py b/gateway/platforms/sms.py index a0760199ba..161949dab3 100644 --- a/gateway/platforms/sms.py +++ b/gateway/platforms/sms.py @@ -10,6 +10,9 @@ Shares credentials with the optional telephony skill — same env vars: Gateway-specific env vars: - SMS_WEBHOOK_PORT (default 8080) + - SMS_WEBHOOK_HOST (default 0.0.0.0) + - SMS_WEBHOOK_URL (public URL for Twilio signature validation — required) + - SMS_INSECURE_NO_SIGNATURE (true to disable signature validation — dev only) - SMS_ALLOWED_USERS (comma-separated E.164 phone numbers) - SMS_ALLOW_ALL_USERS (true/false) - SMS_HOME_CHANNEL (phone number for cron delivery) @@ -17,9 +20,10 @@ Gateway-specific env vars: import asyncio import base64 +import hashlib +import hmac import logging import os -import re import urllib.parse from typing import Any, Dict, Optional @@ -30,24 +34,14 @@ from gateway.platforms.base import ( MessageType, SendResult, ) +from gateway.platforms.helpers import redact_phone, strip_markdown logger = logging.getLogger(__name__) TWILIO_API_BASE = "https://api.twilio.com/2010-04-01/Accounts" MAX_SMS_LENGTH = 1600 # ~10 SMS segments DEFAULT_WEBHOOK_PORT = 8080 - -# E.164 phone number pattern for redaction -_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}") - - -def _redact_phone(phone: str) -> str: - """Redact a phone number for logging: +15551234567 -> +1555***4567.""" - if not phone: - return "" - if len(phone) <= 8: - return phone[:2] + "***" + phone[-2:] if len(phone) > 4 else "****" - return phone[:5] + "***" + phone[-4:] +DEFAULT_WEBHOOK_HOST = "0.0.0.0" def check_sms_requirements() -> bool: @@ -77,6 +71,8 @@ class SmsAdapter(BasePlatformAdapter): self._webhook_port: int = int( os.getenv("SMS_WEBHOOK_PORT", str(DEFAULT_WEBHOOK_PORT)) ) + self._webhook_host: str = os.getenv("SMS_WEBHOOK_HOST", DEFAULT_WEBHOOK_HOST) + self._webhook_url: str = os.getenv("SMS_WEBHOOK_URL", "").strip() self._runner = None self._http_session: Optional["aiohttp.ClientSession"] = None @@ -98,13 +94,33 @@ class 
SmsAdapter(BasePlatformAdapter): logger.error("[sms] TWILIO_PHONE_NUMBER not set — cannot send replies") return False + insecure_no_sig = os.getenv("SMS_INSECURE_NO_SIGNATURE", "").lower() == "true" + + if not self._webhook_url and not insecure_no_sig: + logger.error( + "[sms] Refusing to start: SMS_WEBHOOK_URL is required for Twilio " + "signature validation. Set it to the public URL configured in your " + "Twilio console (e.g. https://example.com/webhooks/twilio). " + "For local development without validation, set " + "SMS_INSECURE_NO_SIGNATURE=true (NOT recommended for production).", + ) + return False + + if insecure_no_sig and not self._webhook_url: + logger.warning( + "[sms] SMS_INSECURE_NO_SIGNATURE=true — Twilio signature validation " + "is DISABLED. Any client that can reach port %d can inject messages. " + "Do NOT use this in production.", + self._webhook_port, + ) + app = web.Application() app.router.add_post("/webhooks/twilio", self._handle_webhook) app.router.add_get("/health", lambda _: web.Response(text="ok")) self._runner = web.AppRunner(app) await self._runner.setup() - site = web.TCPSite(self._runner, "0.0.0.0", self._webhook_port) + site = web.TCPSite(self._runner, self._webhook_host, self._webhook_port) await site.start() self._http_session = aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=30), @@ -112,9 +128,10 @@ class SmsAdapter(BasePlatformAdapter): self._running = True logger.info( - "[sms] Twilio webhook server listening on port %d, from: %s", + "[sms] Twilio webhook server listening on %s:%d, from: %s", + self._webhook_host, self._webhook_port, - _redact_phone(self._from_number), + redact_phone(self._from_number), ) return True @@ -163,7 +180,7 @@ class SmsAdapter(BasePlatformAdapter): error_msg = body.get("message", str(body)) logger.error( "[sms] send failed to %s: %s %s", - _redact_phone(chat_id), + redact_phone(chat_id), resp.status, error_msg, ) @@ -174,7 +191,7 @@ class SmsAdapter(BasePlatformAdapter): msg_sid = body.get("sid", "") last_result = SendResult(success=True, message_id=msg_sid) except Exception as e: - logger.error("[sms] send error to %s: %s", _redact_phone(chat_id), e) + logger.error("[sms] send error to %s: %s", redact_phone(chat_id), e) return SendResult(success=False, error=str(e)) finally: # Close session only if we created a fallback (no persistent session) @@ -192,16 +209,75 @@ class SmsAdapter(BasePlatformAdapter): def format_message(self, content: str) -> str: """Strip markdown — SMS renders it as literal characters.""" - content = re.sub(r"\*\*(.+?)\*\*", r"\1", content, flags=re.DOTALL) - content = re.sub(r"\*(.+?)\*", r"\1", content, flags=re.DOTALL) - content = re.sub(r"__(.+?)__", r"\1", content, flags=re.DOTALL) - content = re.sub(r"_(.+?)_", r"\1", content, flags=re.DOTALL) - content = re.sub(r"```[a-z]*\n?", "", content) - content = re.sub(r"`(.+?)`", r"\1", content) - content = re.sub(r"^#{1,6}\s+", "", content, flags=re.MULTILINE) - content = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", content) - content = re.sub(r"\n{3,}", "\n\n", content) - return content.strip() + return strip_markdown(content) + + # ------------------------------------------------------------------ + # Twilio signature validation + # ------------------------------------------------------------------ + + def _validate_twilio_signature( + self, url: str, post_params: dict, signature: str, + ) -> bool: + """Validate ``X-Twilio-Signature`` header (HMAC-SHA1, base64). 
+ + Tries both with and without the default port for the URL scheme, + since Twilio may sign with either variant. + + Algorithm: https://www.twilio.com/docs/usage/security#validating-requests + """ + if self._check_signature(url, post_params, signature): + return True + + variant = self._port_variant_url(url) + if variant and self._check_signature(variant, post_params, signature): + return True + + return False + + def _check_signature( + self, url: str, post_params: dict, signature: str, + ) -> bool: + """Compute and compare a single Twilio signature.""" + data_to_sign = url + for key in sorted(post_params.keys()): + data_to_sign += key + post_params[key] + mac = hmac.new( + self._auth_token.encode("utf-8"), + data_to_sign.encode("utf-8"), + hashlib.sha1, + ) + computed = base64.b64encode(mac.digest()).decode("utf-8") + return hmac.compare_digest(computed, signature) + + @staticmethod + def _port_variant_url(url: str) -> str | None: + """Return the URL with the default port toggled, or None. + + Only toggles default ports (443 for https, 80 for http). + Non-standard ports are never modified. + """ + parsed = urllib.parse.urlparse(url) + default_ports = {"https": 443, "http": 80} + default_port = default_ports.get(parsed.scheme) + if default_port is None: + return None + + if parsed.port == default_port: + # Has explicit default port → strip it + return urllib.parse.urlunparse( + (parsed.scheme, parsed.hostname, parsed.path, + parsed.params, parsed.query, parsed.fragment) + ) + elif parsed.port is None: + # No port → add default + netloc = f"{parsed.hostname}:{default_port}" + return urllib.parse.urlunparse( + (parsed.scheme, netloc, parsed.path, + parsed.params, parsed.query, parsed.fragment) + ) + + # Non-standard port — no variant + return None # ------------------------------------------------------------------ # Twilio webhook handler @@ -213,7 +289,7 @@ class SmsAdapter(BasePlatformAdapter): try: raw = await request.read() # Twilio sends form-encoded data, not JSON - form = urllib.parse.parse_qs(raw.decode("utf-8")) + form = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True) except Exception as e: logger.error("[sms] webhook parse error: %s", e) return web.Response( @@ -222,6 +298,27 @@ class SmsAdapter(BasePlatformAdapter): status=400, ) + # Validate Twilio request signature when SMS_WEBHOOK_URL is configured + if self._webhook_url: + twilio_sig = request.headers.get("X-Twilio-Signature", "") + if not twilio_sig: + logger.warning("[sms] Rejected: missing X-Twilio-Signature header") + return web.Response( + text='', + content_type="application/xml", + status=403, + ) + flat_params = {k: v[0] for k, v in form.items() if v} + if not self._validate_twilio_signature( + self._webhook_url, flat_params, twilio_sig + ): + logger.warning("[sms] Rejected: invalid Twilio signature") + return web.Response( + text='', + content_type="application/xml", + status=403, + ) + # Extract fields (parse_qs returns lists) from_number = (form.get("From", [""]))[0].strip() to_number = (form.get("To", [""]))[0].strip() @@ -236,7 +333,7 @@ class SmsAdapter(BasePlatformAdapter): # Ignore messages from our own number (echo prevention) if from_number == self._from_number: - logger.debug("[sms] ignoring echo from own number %s", _redact_phone(from_number)) + logger.debug("[sms] ignoring echo from own number %s", redact_phone(from_number)) return web.Response( text='', content_type="application/xml", @@ -244,8 +341,8 @@ class SmsAdapter(BasePlatformAdapter): logger.info( "[sms] inbound from %s -> 
%s: %s", - _redact_phone(from_number), - _redact_phone(to_number), + redact_phone(from_number), + redact_phone(to_number), text[:80], ) diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 8b4e43514b..2653296026 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -147,7 +147,6 @@ class TelegramAdapter(BasePlatformAdapter): self._text_batch_split_delay_seconds = float(os.getenv("HERMES_TELEGRAM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0")) self._pending_text_batches: Dict[str, MessageEvent] = {} self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {} - self._token_lock_identity: Optional[str] = None self._polling_error_task: Optional[asyncio.Task] = None self._polling_conflict_count: int = 0 self._polling_network_error_count: int = 0 @@ -300,9 +299,11 @@ class TelegramAdapter(BasePlatformAdapter): # Exhausted retries — fatal message = ( - "Another Telegram bot poller is already using this token. " + "Another process is already polling this Telegram bot token " + "(possibly OpenClaw or another Hermes instance). " "Hermes stopped Telegram polling after %d retries. " - "Make sure only one gateway instance is running for this bot token." + "Only one poller can run per token — stop the other process " + "and restart with 'hermes start'." % MAX_CONFLICT_RETRIES ) logger.error("[%s] %s Original error: %s", self.name, message, error) @@ -497,23 +498,7 @@ class TelegramAdapter(BasePlatformAdapter): return False try: - from gateway.status import acquire_scoped_lock - - self._token_lock_identity = self.config.token - acquired, existing = acquire_scoped_lock( - "telegram-bot-token", - self._token_lock_identity, - metadata={"platform": self.platform.value}, - ) - if not acquired: - owner_pid = existing.get("pid") if isinstance(existing, dict) else None - message = ( - "Another local Hermes gateway is already using this Telegram bot token" - + (f" (PID {owner_pid})." if owner_pid else ".") - + " Stop the other gateway before starting a second Telegram poller." 
- ) - logger.error("[%s] %s", self.name, message) - self._set_fatal_error("telegram_token_lock", message, retryable=False) + if not self._acquire_platform_lock('telegram-bot-token', self.config.token, 'Telegram bot token'): return False # Build the application @@ -737,12 +722,7 @@ class TelegramAdapter(BasePlatformAdapter): return True except Exception as e: - if self._token_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("telegram-bot-token", self._token_lock_identity) - except Exception: - pass + self._release_platform_lock() message = f"Telegram startup failed: {e}" self._set_fatal_error("telegram_connect_error", message, retryable=True) logger.error("[%s] Failed to connect to Telegram: %s", self.name, e, exc_info=True) @@ -768,12 +748,7 @@ class TelegramAdapter(BasePlatformAdapter): await self._app.shutdown() except Exception as e: logger.warning("[%s] Error during Telegram disconnect: %s", self.name, e, exc_info=True) - if self._token_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("telegram-bot-token", self._token_lock_identity) - except Exception as e: - logger.warning("[%s] Error releasing Telegram token lock: %s", self.name, e, exc_info=True) + self._release_platform_lock() for task in self._pending_photo_batch_tasks.values(): if task and not task.done(): @@ -784,7 +759,6 @@ class TelegramAdapter(BasePlatformAdapter): self._mark_disconnected() self._app = None self._bot = None - self._token_lock_identity = None logger.info("[%s] Disconnected from Telegram", self.name) def _should_thread_reply(self, reply_to: Optional[str], chunk_index: int) -> bool: diff --git a/gateway/platforms/wecom.py b/gateway/platforms/wecom.py index 6fde73927b..a0e71e01b6 100644 --- a/gateway/platforms/wecom.py +++ b/gateway/platforms/wecom.py @@ -59,6 +59,7 @@ except ImportError: httpx = None # type: ignore[assignment] from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -92,7 +93,6 @@ REQUEST_TIMEOUT_SECONDS = 15.0 HEARTBEAT_INTERVAL_SECONDS = 30.0 RECONNECT_BACKOFF = [2, 5, 10, 30, 60] -DEDUP_WINDOW_SECONDS = 300 DEDUP_MAX_SIZE = 1000 IMAGE_MAX_BYTES = 10 * 1024 * 1024 @@ -172,7 +172,7 @@ class WeComAdapter(BasePlatformAdapter): self._listen_task: Optional[asyncio.Task] = None self._heartbeat_task: Optional[asyncio.Task] = None self._pending_responses: Dict[str, asyncio.Future] = {} - self._seen_messages: Dict[str, float] = {} + self._dedup = MessageDeduplicator(max_size=DEDUP_MAX_SIZE) self._reply_req_ids: Dict[str, str] = {} # Text batching: merge rapid successive messages (Telegram-style). 
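The context line above is wecom.py's own copy of the rapid-message batching that the new `TextBatchAggregator` in `gateway/platforms/helpers.py` was extracted to replace (its docstring lists wecom among the duplicates). A minimal adoption sketch, assuming an adapter whose inherited `handle_message` coroutine consumes the merged `MessageEvent` and that already computes a per-chat `session_key`; the adapter class and `_dispatch` method below are illustrative, not part of this diff:

```python
from gateway.platforms.base import BasePlatformAdapter, MessageType
from gateway.platforms.helpers import TextBatchAggregator


class ExampleAdapter(BasePlatformAdapter):
    """Hypothetical adapter showing how the shared batcher is wired up."""

    def __init__(self, config):
        super().__init__(config)  # constructor signature follows the existing adapters
        # One aggregator per adapter; delays mirror the helpers.py defaults.
        self._text_batcher = TextBatchAggregator(
            handler=self.handle_message,  # coroutine that receives the merged event
            batch_delay=0.6,              # wait briefly for rapid follow-up chunks
            split_threshold=1900,         # long chunks switch to the longer split_delay
        )

    async def _dispatch(self, event, msg_type, session_key: str) -> None:
        # Plain text is batched per session key; everything else dispatches directly.
        if msg_type == MessageType.TEXT and self._text_batcher.is_enabled():
            self._text_batcher.enqueue(event, session_key)
            return
        await self.handle_message(event)

    async def disconnect(self) -> None:
        self._text_batcher.cancel_all()  # drop pending flush timers on shutdown
```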
@@ -250,7 +250,7 @@ class WeComAdapter(BasePlatformAdapter): await self._http_client.aclose() self._http_client = None - self._seen_messages.clear() + self._dedup.clear() logger.info("[%s] Disconnected", self.name) async def _cleanup_ws(self) -> None: @@ -476,7 +476,7 @@ class WeComAdapter(BasePlatformAdapter): return msg_id = str(body.get("msgid") or self._payload_req_id(payload) or uuid.uuid4().hex) - if self._is_duplicate(msg_id): + if self._dedup.is_duplicate(msg_id): logger.debug("[%s] Duplicate message %s ignored", self.name, msg_id) return self._remember_reply_req_id(msg_id, self._payload_req_id(payload)) @@ -636,6 +636,13 @@ class WeComAdapter(BasePlatformAdapter): if voice_text: text_parts.append(voice_text) + # Extract appmsg title (filename) for WeCom AI Bot attachments + if msgtype == "appmsg": + appmsg = body.get("appmsg") if isinstance(body.get("appmsg"), dict) else {} + title = str(appmsg.get("title") or "").strip() + if title: + text_parts.append(title) + quote = body.get("quote") if isinstance(body.get("quote"), dict) else {} quote_type = str(quote.get("msgtype") or "").lower() if quote_type == "text": @@ -668,6 +675,13 @@ class WeComAdapter(BasePlatformAdapter): refs.append(("image", body["image"])) if msgtype == "file" and isinstance(body.get("file"), dict): refs.append(("file", body["file"])) + # Handle appmsg (WeCom AI Bot attachments with PDF/Word/Excel) + if msgtype == "appmsg" and isinstance(body.get("appmsg"), dict): + appmsg = body["appmsg"] + if isinstance(appmsg.get("file"), dict): + refs.append(("file", appmsg["file"])) + elif isinstance(appmsg.get("image"), dict): + refs.append(("image", appmsg["image"])) quote = body.get("quote") if isinstance(body.get("quote"), dict) else {} quote_type = str(quote.get("msgtype") or "").lower() @@ -825,24 +839,6 @@ class WeComAdapter(BasePlatformAdapter): wildcard = self._groups.get("*") return wildcard if isinstance(wildcard, dict) else {} - def _is_duplicate(self, msg_id: str) -> bool: - now = time.time() - if len(self._seen_messages) > DEDUP_MAX_SIZE: - cutoff = now - DEDUP_WINDOW_SECONDS - self._seen_messages = { - key: ts for key, ts in self._seen_messages.items() if ts > cutoff - } - if self._reply_req_ids: - self._reply_req_ids = { - key: value for key, value in self._reply_req_ids.items() if key in self._seen_messages - } - - if msg_id in self._seen_messages: - return True - - self._seen_messages[msg_id] = now - return False - def _remember_reply_req_id(self, message_id: str, req_id: str) -> None: normalized_message_id = str(message_id or "").strip() normalized_req_id = str(req_id or "").strip() diff --git a/gateway/platforms/weixin.py b/gateway/platforms/weixin.py index e25bb350f8..5821d922f8 100644 --- a/gateway/platforms/weixin.py +++ b/gateway/platforms/weixin.py @@ -53,6 +53,7 @@ except ImportError: # pragma: no cover - dependency gate CRYPTO_AVAILABLE = False from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -63,6 +64,7 @@ from gateway.platforms.base import ( cache_image_from_bytes, ) from hermes_constants import get_hermes_home +from utils import atomic_json_write ILINK_BASE_URL = "https://ilinkai.weixin.qq.com" WEIXIN_CDN_BASE_URL = "https://novac2c.cdn.weixin.qq.com/c2c" @@ -206,7 +208,7 @@ def save_weixin_account( "saved_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), } path = _account_file(hermes_home, account_id) - path.write_text(json.dumps(payload, indent=2), 
encoding="utf-8") + atomic_json_write(path, payload) try: path.chmod(0o600) except OSError: @@ -269,7 +271,7 @@ class ContextTokenStore: if key.startswith(prefix) } try: - self._path(account_id).write_text(json.dumps(payload), encoding="utf-8") + atomic_json_write(self._path(account_id), payload) except Exception as exc: logger.warning("weixin: failed to persist context tokens for %s: %s", _safe_id(account_id), exc) @@ -868,7 +870,7 @@ def _load_sync_buf(hermes_home: str, account_id: str) -> str: def _save_sync_buf(hermes_home: str, account_id: str, sync_buf: str) -> None: path = _sync_buf_path(hermes_home, account_id) - path.write_text(json.dumps({"get_updates_buf": sync_buf}), encoding="utf-8") + atomic_json_write(path, {"get_updates_buf": sync_buf}) async def qr_login( @@ -1007,8 +1009,7 @@ class WeixinAdapter(BasePlatformAdapter): self._typing_cache = TypingTicketCache() self._session: Optional[aiohttp.ClientSession] = None self._poll_task: Optional[asyncio.Task] = None - self._seen_messages: Dict[str, float] = {} - self._token_lock_identity: Optional[str] = None + self._dedup = MessageDeduplicator(ttl_seconds=MESSAGE_DEDUP_TTL_SECONDS) self._account_id = str(extra.get("account_id") or os.getenv("WEIXIN_ACCOUNT_ID", "")).strip() self._token = str(config.token or extra.get("token") or os.getenv("WEIXIN_TOKEN", "")).strip() @@ -1016,6 +1017,16 @@ class WeixinAdapter(BasePlatformAdapter): self._cdn_base_url = str( extra.get("cdn_base_url") or os.getenv("WEIXIN_CDN_BASE_URL", WEIXIN_CDN_BASE_URL) ).strip().rstrip("/") + self._send_chunk_delay_seconds = float( + extra.get("send_chunk_delay_seconds") or os.getenv("WEIXIN_SEND_CHUNK_DELAY_SECONDS", "0.35") + ) + self._send_chunk_retries = int( + extra.get("send_chunk_retries") or os.getenv("WEIXIN_SEND_CHUNK_RETRIES", "2") + ) + self._send_chunk_retry_delay_seconds = float( + extra.get("send_chunk_retry_delay_seconds") + or os.getenv("WEIXIN_SEND_CHUNK_RETRY_DELAY_SECONDS", "1.0") + ) self._dm_policy = str(extra.get("dm_policy") or os.getenv("WEIXIN_DM_POLICY", "open")).strip().lower() self._group_policy = str(extra.get("group_policy") or os.getenv("WEIXIN_GROUP_POLICY", "disabled")).strip().lower() allow_from = extra.get("allow_from") @@ -1066,23 +1077,7 @@ class WeixinAdapter(BasePlatformAdapter): return False try: - from gateway.status import acquire_scoped_lock - - self._token_lock_identity = self._token - acquired, existing = acquire_scoped_lock( - "weixin-bot-token", - self._token_lock_identity, - metadata={"platform": self.platform.value}, - ) - if not acquired: - owner_pid = existing.get("pid") if isinstance(existing, dict) else None - message = ( - "Another local Hermes gateway is already using this Weixin token" - + (f" (PID {owner_pid})." if owner_pid else ".") - + " Stop the other gateway before starting a second Weixin poller." 
- ) - logger.error("[%s] %s", self.name, message) - self._set_fatal_error("weixin_token_lock", message, retryable=False) + if not self._acquire_platform_lock('weixin-bot-token', self._token, 'Weixin bot token'): return False except Exception as exc: logger.debug("[%s] Token lock unavailable (non-fatal): %s", self.name, exc) @@ -1106,12 +1101,7 @@ class WeixinAdapter(BasePlatformAdapter): if self._session and not self._session.closed: await self._session.close() self._session = None - if self._token_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("weixin-bot-token", self._token_lock_identity) - except Exception as exc: - logger.warning("[%s] Error releasing Weixin token lock: %s", self.name, exc, exc_info=True) + self._release_platform_lock() self._mark_disconnected() logger.info("[%s] Disconnected", self.name) @@ -1189,16 +1179,8 @@ class WeixinAdapter(BasePlatformAdapter): return message_id = str(message.get("message_id") or "").strip() - if message_id: - now = time.time() - self._seen_messages = { - key: value - for key, value in self._seen_messages.items() - if now - value < MESSAGE_DEDUP_TTL_SECONDS - } - if message_id in self._seen_messages: - return - self._seen_messages[message_id] = now + if message_id and self._dedup.is_duplicate(message_id): + return chat_type, effective_chat_id = _guess_chat_type(message, self._account_id) if chat_type == "group": @@ -1374,6 +1356,47 @@ class WeixinAdapter(BasePlatformAdapter): content, self.MAX_MESSAGE_LENGTH, self._split_multiline_messages, ) + async def _send_text_chunk( + self, + *, + chat_id: str, + chunk: str, + context_token: Optional[str], + client_id: str, + ) -> None: + """Send a single text chunk with per-chunk retry and backoff.""" + last_error: Optional[Exception] = None + for attempt in range(self._send_chunk_retries + 1): + try: + await _send_message( + self._session, + base_url=self._base_url, + token=self._token, + to=chat_id, + text=chunk, + context_token=context_token, + client_id=client_id, + ) + return + except Exception as exc: + last_error = exc + if attempt >= self._send_chunk_retries: + break + wait = self._send_chunk_retry_delay_seconds * (attempt + 1) + logger.warning( + "[%s] send chunk failed to=%s attempt=%d/%d, retrying in %.2fs: %s", + self.name, + _safe_id(chat_id), + attempt + 1, + self._send_chunk_retries + 1, + wait, + exc, + ) + if wait > 0: + await asyncio.sleep(wait) + assert last_error is not None + raise last_error + async def send( self, chat_id: str, @@ -1388,19 +1411,16 @@ class WeixinAdapter(BasePlatformAdapter): try: chunks = self._split_text(self.format_message(content)) for idx, chunk in enumerate(chunks): - if idx > 0: - await asyncio.sleep(0.3) client_id = f"hermes-weixin-{uuid.uuid4().hex}" - await _send_message( - self._session, - base_url=self._base_url, - token=self._token, - to=chat_id, - text=chunk, + await self._send_text_chunk( + chat_id=chat_id, + chunk=chunk, context_token=context_token, client_id=client_id, ) last_message_id = client_id + if idx < len(chunks) - 1 and self._send_chunk_delay_seconds > 0: + await asyncio.sleep(self._send_chunk_delay_seconds) return SendResult(success=True, message_id=last_message_id) except Exception as exc: logger.error("[%s] send failed to=%s: %s", self.name, _safe_id(chat_id), exc) diff --git a/gateway/platforms/whatsapp.py b/gateway/platforms/whatsapp.py index a6475dcb80..c616f72448 100644 --- a/gateway/platforms/whatsapp.py +++ b/gateway/platforms/whatsapp.py @@ -145,7 +145,6 @@ class 
WhatsAppAdapter(BasePlatformAdapter): self._bridge_log: Optional[Path] = None self._poll_task: Optional[asyncio.Task] = None self._http_session: Optional["aiohttp.ClientSession"] = None - self._session_lock_identity: Optional[str] = None def _whatsapp_require_mention(self) -> bool: configured = self.config.extra.get("require_mention") @@ -290,23 +289,7 @@ class WhatsAppAdapter(BasePlatformAdapter): # Acquire scoped lock to prevent duplicate sessions try: - from gateway.status import acquire_scoped_lock - - self._session_lock_identity = str(self._session_path) - acquired, existing = acquire_scoped_lock( - "whatsapp-session", - self._session_lock_identity, - metadata={"platform": self.platform.value}, - ) - if not acquired: - owner_pid = existing.get("pid") if isinstance(existing, dict) else None - message = ( - "Another local Hermes gateway is already using this WhatsApp session" - + (f" (PID {owner_pid})." if owner_pid else ".") - + " Stop the other gateway before starting a second WhatsApp bridge." - ) - logger.error("[%s] %s", self.name, message) - self._set_fatal_error("whatsapp_session_lock", message, retryable=False) + if not self._acquire_platform_lock('whatsapp-session', str(self._session_path), 'WhatsApp session'): return False except Exception as e: logger.warning("[%s] Could not acquire session lock (non-fatal): %s", self.name, e) @@ -468,12 +451,7 @@ class WhatsAppAdapter(BasePlatformAdapter): return True except Exception as e: - if self._session_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("whatsapp-session", self._session_lock_identity) - except Exception: - pass + self._release_platform_lock() logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True) self._close_bridge_log() return False @@ -546,17 +524,11 @@ class WhatsAppAdapter(BasePlatformAdapter): await self._http_session.close() self._http_session = None - if self._session_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("whatsapp-session", self._session_lock_identity) - except Exception as e: - logger.warning("[%s] Error releasing WhatsApp session lock: %s", self.name, e, exc_info=True) + self._release_platform_lock() self._mark_disconnected() self._bridge_process = None self._close_bridge_log() - self._session_lock_identity = None print(f"[{self.name}] Disconnected") async def send( diff --git a/gateway/run.py b/gateway/run.py index 362b8650b6..d577fd34e1 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -352,19 +352,14 @@ def _build_media_placeholder(event) -> str: return "\n".join(parts) -def _dequeue_pending_text(adapter, session_key: str) -> str | None: - """Consume and return the text of a pending queued message. +def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None: + """Consume and return the full pending event for a session. - Preserves media context for captionless photo/document events by - building a placeholder so the message isn't silently dropped. + Queued follow-ups must preserve their media metadata so they can re-enter + the normal image/STT/document preprocessing path instead of being reduced + to a placeholder string. 
""" - event = adapter.get_pending_message(session_key) - if not event: - return None - text = event.text - if not text and getattr(event, "media_urls", None): - text = _build_media_placeholder(event) - return text + return adapter.get_pending_message(session_key) def _check_unavailable_skill(command_name: str) -> str | None: @@ -1465,7 +1460,18 @@ class GatewayRunner: logger.info("Recovered %s background process(es) from previous run", recovered) except Exception as e: logger.warning("Process checkpoint recovery: %s", e) - + + # Suspend sessions that were active when the gateway last exited. + # This prevents stuck sessions from being blindly resumed on restart, + # which can create an unrecoverable loop (#7536). Suspended sessions + # auto-reset on the next incoming message, giving the user a clean start. + try: + suspended = self.session_store.suspend_recently_active() + if suspended: + logger.info("Suspended %d in-flight session(s) from previous run", suspended) + except Exception as e: + logger.warning("Session suspension on startup failed: %s", e) + connected_count = 0 enabled_platform_count = 0 startup_nonretryable_errors: list[str] = [] @@ -2221,6 +2227,13 @@ class GatewayRunner: # are system-generated and must skip user authorization. if getattr(event, "internal", False): pass + elif source.user_id is None: + # Messages with no user identity (Telegram service messages, + # channel forwards, anonymous admin actions) cannot be + # authorized — drop silently instead of triggering the pairing + # flow with a None user_id. + logger.debug("Ignoring message with no user_id from %s", source.platform.value) + return None elif not self._is_user_authorized(source): logger.warning("Unauthorized user: %s (%s) on %s", source.user_id, source.user_name, source.platform.value) # In DMs: offer pairing code. In groups: silently ignore. @@ -2370,8 +2383,11 @@ class GatewayRunner: self._pending_messages.pop(_quick_key, None) if _quick_key in self._running_agents: del self._running_agents[_quick_key] - logger.info("HARD STOP for session %s — session lock released", _quick_key[:20]) - return "⚡ Force-stopped. The session is unlocked — you can send a new message." + # Mark session suspended so the next message starts fresh + # instead of resuming the stuck context (#7536). + self.session_store.suspend_session(_quick_key) + logger.info("HARD STOP for session %s — suspended, session lock released", _quick_key[:20]) + return "⚡ Force-stopped. The session is suspended — your next message will start fresh." # /reset and /new must bypass the running-agent guard so they # actually dispatch as commands instead of being queued as user @@ -2761,6 +2777,162 @@ class GatewayRunner: del self._running_agents[_quick_key] self._running_agents_ts.pop(_quick_key, None) + async def _prepare_inbound_message_text( + self, + *, + event: MessageEvent, + source: SessionSource, + history: List[Dict[str, Any]], + ) -> Optional[str]: + """Prepare inbound event text for the agent. + + Keep the normal inbound path and the queued follow-up path on the same + preprocessing pipeline so sender attribution, image enrichment, STT, + document notes, reply context, and @ references all behave the same. 
+ """ + history = history or [] + message_text = event.text or "" + + _is_shared_thread = ( + source.chat_type != "dm" + and source.thread_id + and not getattr(self.config, "thread_sessions_per_user", False) + ) + if _is_shared_thread and source.user_name: + message_text = f"[{source.user_name}] {message_text}" + + if event.media_urls: + image_paths = [] + audio_paths = [] + for i, path in enumerate(event.media_urls): + mtype = event.media_types[i] if i < len(event.media_types) else "" + if mtype.startswith("image/") or event.message_type == MessageType.PHOTO: + image_paths.append(path) + if mtype.startswith("audio/") or event.message_type in (MessageType.VOICE, MessageType.AUDIO): + audio_paths.append(path) + + if image_paths: + message_text = await self._enrich_message_with_vision( + message_text, + image_paths, + ) + + if audio_paths: + message_text = await self._enrich_message_with_transcription( + message_text, + audio_paths, + ) + _stt_fail_markers = ( + "No STT provider", + "STT is disabled", + "can't listen", + "VOICE_TOOLS_OPENAI_KEY", + ) + if any(marker in message_text for marker in _stt_fail_markers): + _stt_adapter = self.adapters.get(source.platform) + _stt_meta = {"thread_id": source.thread_id} if source.thread_id else None + if _stt_adapter: + try: + _stt_msg = ( + "🎤 I received your voice message but can't transcribe it — " + "no speech-to-text provider is configured.\n\n" + "To enable voice: install faster-whisper " + "(`pip install faster-whisper` in the Hermes venv) " + "and set `stt.enabled: true` in config.yaml, " + "then /restart the gateway." + ) + if self._has_setup_skill(): + _stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`" + await _stt_adapter.send( + source.chat_id, + _stt_msg, + metadata=_stt_meta, + ) + except Exception: + pass + + if event.media_urls and event.message_type == MessageType.DOCUMENT: + import mimetypes as _mimetypes + + _TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"} + for i, path in enumerate(event.media_urls): + mtype = event.media_types[i] if i < len(event.media_types) else "" + if mtype in ("", "application/octet-stream"): + import os as _os2 + + _ext = _os2.path.splitext(path)[1].lower() + if _ext in _TEXT_EXTENSIONS: + mtype = "text/plain" + else: + guessed, _ = _mimetypes.guess_type(path) + if guessed: + mtype = guessed + if not mtype.startswith(("application/", "text/")): + continue + + import os as _os + import re as _re + + basename = _os.path.basename(path) + parts = basename.split("_", 2) + display_name = parts[2] if len(parts) >= 3 else basename + display_name = _re.sub(r'[^\w.\- ]', '_', display_name) + + if mtype.startswith("text/"): + context_note = ( + f"[The user sent a text document: '{display_name}'. " + f"Its content has been included below. " + f"The file is also saved at: {path}]" + ) + else: + context_note = ( + f"[The user sent a document: '{display_name}'. " + f"The file is saved at: {path}. 
" + f"Ask the user what they'd like you to do with it.]" + ) + message_text = f"{context_note}\n\n{message_text}" + + if getattr(event, "reply_to_text", None) and event.reply_to_message_id: + reply_snippet = event.reply_to_text[:500] + found_in_history = any( + reply_snippet[:200] in (msg.get("content") or "") + for msg in history + if msg.get("role") in ("assistant", "user", "tool") + ) + if not found_in_history: + message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}' + + if "@" in message_text: + try: + from agent.context_references import preprocess_context_references_async + from agent.model_metadata import get_model_context_length + + _msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~")) + _msg_ctx_len = get_model_context_length( + self._model, + base_url=self._base_url or "", + ) + _ctx_result = await preprocess_context_references_async( + message_text, + cwd=_msg_cwd, + context_length=_msg_ctx_len, + allowed_root=_msg_cwd, + ) + if _ctx_result.blocked: + _adapter = self.adapters.get(source.platform) + if _adapter: + await _adapter.send( + source.chat_id, + "\n".join(_ctx_result.warnings) or "Context injection refused.", + ) + return None + if _ctx_result.expanded: + message_text = _ctx_result.message + except Exception as exc: + logger.debug("@ context reference expansion failed: %s", exc) + + return message_text + async def _handle_message_with_agent(self, event, source, _quick_key: str): """Inner handler that runs under the _running_agents sentinel guard.""" _msg_start_time = time.time() @@ -2812,7 +2984,9 @@ class GatewayRunner: # so the agent knows this is a fresh conversation (not an intentional /reset). if getattr(session_entry, 'was_auto_reset', False): reset_reason = getattr(session_entry, 'auto_reset_reason', None) or 'idle' - if reset_reason == "daily": + if reset_reason == "suspended": + context_note = "[System note: The user's previous session was stopped and suspended. This is a fresh conversation with no prior context.]" + elif reset_reason == "daily": context_note = "[System note: The user's session was automatically reset by the daily schedule. This is a fresh conversation with no prior context.]" else: context_note = "[System note: The user's previous session expired due to inactivity. This is a fresh conversation with no prior context.]" @@ -2829,7 +3003,9 @@ class GatewayRunner: ) platform_name = source.platform.value if source.platform else "" had_activity = getattr(session_entry, 'reset_had_activity', False) - should_notify = ( + # Suspended sessions always notify (they were explicitly stopped + # or crashed mid-operation) — skip the policy check. + should_notify = reset_reason == "suspended" or ( policy.notify and had_activity and platform_name not in policy.notify_exclude_platforms @@ -2837,7 +3013,9 @@ class GatewayRunner: if should_notify: adapter = self.adapters.get(source.platform) if adapter: - if reset_reason == "daily": + if reset_reason == "suspended": + reason_text = "previous session was stopped or interrupted" + elif reset_reason == "daily": reason_text = f"daily schedule at {policy.at_hour}:00" else: hours = policy.idle_minutes // 60 @@ -3195,149 +3373,13 @@ class GatewayRunner: # attachments (documents, audio, etc.) are not sent to the vision # tool even when they appear in the same message. # ----------------------------------------------------------------- - message_text = event.text or "" - - # ----------------------------------------------------------------- - # Sender attribution for shared thread sessions. 
- # - # When multiple users share a single thread session (the default for - # threads), prefix each message with [sender name] so the agent can - # tell participants apart. Skip for DMs (single-user by nature) and - # when per-user thread isolation is explicitly enabled. - # ----------------------------------------------------------------- - _is_shared_thread = ( - source.chat_type != "dm" - and source.thread_id - and not getattr(self.config, "thread_sessions_per_user", False) + message_text = await self._prepare_inbound_message_text( + event=event, + source=source, + history=history, ) - if _is_shared_thread and source.user_name: - message_text = f"[{source.user_name}] {message_text}" - - if event.media_urls: - image_paths = [] - for i, path in enumerate(event.media_urls): - # Check media_types if available; otherwise infer from message type - mtype = event.media_types[i] if i < len(event.media_types) else "" - is_image = ( - mtype.startswith("image/") - or event.message_type == MessageType.PHOTO - ) - if is_image: - image_paths.append(path) - if image_paths: - message_text = await self._enrich_message_with_vision( - message_text, image_paths - ) - - # ----------------------------------------------------------------- - # Auto-transcribe voice/audio messages sent by the user - # ----------------------------------------------------------------- - if event.media_urls: - audio_paths = [] - for i, path in enumerate(event.media_urls): - mtype = event.media_types[i] if i < len(event.media_types) else "" - is_audio = ( - mtype.startswith("audio/") - or event.message_type in (MessageType.VOICE, MessageType.AUDIO) - ) - if is_audio: - audio_paths.append(path) - if audio_paths: - message_text = await self._enrich_message_with_transcription( - message_text, audio_paths - ) - # If STT failed, send a direct message to the user so they - # know voice isn't configured — don't rely on the agent to - # relay the error clearly. - _stt_fail_markers = ( - "No STT provider", - "STT is disabled", - "can't listen", - "VOICE_TOOLS_OPENAI_KEY", - ) - if any(m in message_text for m in _stt_fail_markers): - _stt_adapter = self.adapters.get(source.platform) - _stt_meta = {"thread_id": source.thread_id} if source.thread_id else None - if _stt_adapter: - try: - _stt_msg = ( - "🎤 I received your voice message but can't transcribe it — " - "no speech-to-text provider is configured.\n\n" - "To enable voice: install faster-whisper " - "(`pip install faster-whisper` in the Hermes venv) " - "and set `stt.enabled: true` in config.yaml, " - "then /restart the gateway." - ) - # Point to setup skill if it's installed - if self._has_setup_skill(): - _stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`" - await _stt_adapter.send( - source.chat_id, _stt_msg, - metadata=_stt_meta, - ) - except Exception: - pass - - # ----------------------------------------------------------------- - # Enrich document messages with context notes for the agent - # ----------------------------------------------------------------- - if event.media_urls and event.message_type == MessageType.DOCUMENT: - import mimetypes as _mimetypes - _TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"} - for i, path in enumerate(event.media_urls): - mtype = event.media_types[i] if i < len(event.media_types) else "" - # Fall back to extension-based detection when MIME type is unreliable. 
- if mtype in ("", "application/octet-stream"): - import os as _os2 - _ext = _os2.path.splitext(path)[1].lower() - if _ext in _TEXT_EXTENSIONS: - mtype = "text/plain" - else: - guessed, _ = _mimetypes.guess_type(path) - if guessed: - mtype = guessed - if not mtype.startswith(("application/", "text/")): - continue - # Extract display filename by stripping the doc_{uuid12}_ prefix - import os as _os - basename = _os.path.basename(path) - # Format: doc_<12hex>_ - parts = basename.split("_", 2) - display_name = parts[2] if len(parts) >= 3 else basename - # Sanitize to prevent prompt injection via filenames - import re as _re - display_name = _re.sub(r'[^\w.\- ]', '_', display_name) - - if mtype.startswith("text/"): - context_note = ( - f"[The user sent a text document: '{display_name}'. " - f"Its content has been included below. " - f"The file is also saved at: {path}]" - ) - else: - context_note = ( - f"[The user sent a document: '{display_name}'. " - f"The file is saved at: {path}. " - f"Ask the user what they'd like you to do with it.]" - ) - message_text = f"{context_note}\n\n{message_text}" - - # ----------------------------------------------------------------- - # Inject reply context when user replies to a message not in history. - # Telegram (and other platforms) let users reply to specific messages, - # but if the quoted message is from a previous session, cron delivery, - # or background task, the agent has no context about what's being - # referenced. Prepend the quoted text so the agent understands. (#1594) - # ----------------------------------------------------------------- - if getattr(event, 'reply_to_text', None) and event.reply_to_message_id: - reply_snippet = event.reply_to_text[:500] - found_in_history = any( - reply_snippet[:200] in (msg.get("content") or "") - for msg in history - if msg.get("role") in ("assistant", "user", "tool") - ) - if not found_in_history: - message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}' + if message_text is None: + return try: # Emit agent:start hook @@ -3349,30 +3391,6 @@ class GatewayRunner: } await self.hooks.emit("agent:start", hook_ctx) - # Expand @ context references (@file:, @folder:, @diff, etc.) - if "@" in message_text: - try: - from agent.context_references import preprocess_context_references_async - from agent.model_metadata import get_model_context_length - _msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~")) - _msg_ctx_len = get_model_context_length( - self._model, base_url=self._base_url or "") - _ctx_result = await preprocess_context_references_async( - message_text, cwd=_msg_cwd, - context_length=_msg_ctx_len, allowed_root=_msg_cwd) - if _ctx_result.blocked: - _adapter = self.adapters.get(source.platform) - if _adapter: - await _adapter.send( - source.chat_id, - "\n".join(_ctx_result.warnings) or "Context injection refused.", - ) - return - if _ctx_result.expanded: - message_text = _ctx_result.message - except Exception as exc: - logger.debug("@ context reference expansion failed: %s", exc) - # Run the agent agent_result = await self._run_agent( message=message_text, @@ -4010,25 +4028,31 @@ class GatewayRunner: handles /stop before this method is reached. This handler fires only through normal command dispatch (no running agent) or as a fallback. Force-clean the session lock in all cases for safety. + + When there IS a running/pending agent, the session is also marked + as *suspended* so the next message starts a fresh session instead + of resuming the stuck context (#7536). 
""" source = event.source session_entry = self.session_store.get_or_create_session(source) session_key = session_entry.session_key - + agent = self._running_agents.get(session_key) if agent is _AGENT_PENDING_SENTINEL: # Force-clean the sentinel so the session is unlocked. if session_key in self._running_agents: del self._running_agents[session_key] - logger.info("HARD STOP (pending) for session %s — sentinel cleared", session_key[:20]) - return "⚡ Force-stopped. The agent was still starting — session unlocked." + self.session_store.suspend_session(session_key) + logger.info("HARD STOP (pending) for session %s — suspended, sentinel cleared", session_key[:20]) + return "⚡ Force-stopped. The agent was still starting — your next message will start fresh." if agent: agent.interrupt("Stop requested") # Force-clean the session lock so a truly hung agent doesn't # keep it locked forever. if session_key in self._running_agents: del self._running_agents[session_key] - return "⚡ Force-stopped. The session is unlocked — you can send a new message." + self.session_store.suspend_session(session_key) + return "⚡ Force-stopped. Your next message will start a fresh session." else: return "No active task to stop." @@ -6694,6 +6718,8 @@ class GatewayRunner: chat_id=context.source.chat_id, chat_name=context.source.chat_name or "", thread_id=str(context.source.thread_id) if context.source.thread_id else "", + user_id=str(context.source.user_id) if context.source.user_id else "", + user_name=str(context.source.user_name) if context.source.user_name else "", ) def _clear_session_env(self, tokens: list) -> None: @@ -6906,6 +6932,8 @@ class GatewayRunner: platform_name = watcher.get("platform", "") chat_id = watcher.get("chat_id", "") thread_id = watcher.get("thread_id", "") + user_id = watcher.get("user_id", "") + user_name = watcher.get("user_name", "") agent_notify = watcher.get("notify_on_complete", False) notify_mode = self._load_background_notifications_mode() @@ -6961,6 +6989,8 @@ class GatewayRunner: platform=_platform_enum, chat_id=chat_id, thread_id=thread_id or None, + user_id=user_id or None, + user_name=user_name or None, ) synth_event = MessageEvent( text=synth_text, @@ -8115,17 +8145,16 @@ class GatewayRunner: # Get pending message from adapter. # Use session_key (not source.chat_id) to match adapter's storage keys. + pending_event = None pending = None if result and adapter and session_key: - if result.get("interrupted"): - pending = _dequeue_pending_text(adapter, session_key) - if not pending and result.get("interrupt_message"): - pending = result.get("interrupt_message") - else: - pending = _dequeue_pending_text(adapter, session_key) - if pending: - logger.debug("Processing queued message after agent completion: '%s...'", pending[:40]) - + pending_event = _dequeue_pending_event(adapter, session_key) + if result.get("interrupted") and not pending_event and result.get("interrupt_message"): + pending = result.get("interrupt_message") + elif pending_event: + pending = pending_event.text or _build_media_placeholder(pending_event) + logger.debug("Processing queued message after agent completion: '%s...'", pending[:40]) + # Safety net: if the pending text is a slash command (e.g. "/stop", # "/new"), discard it — commands should never be passed to the agent # as user input. 
The primary fix is in base.py (commands bypass the @@ -8143,27 +8172,29 @@ class GatewayRunner: "commands must not be passed as agent input", _pending_cmd_word, ) + pending_event = None pending = None except Exception: pass - if self._draining and pending: + if self._draining and (pending_event or pending): logger.info( "Discarding pending follow-up for session %s during gateway %s", session_key[:20] if session_key else "?", self._status_action_label(), ) + pending_event = None pending = None - if pending: + if pending_event or pending: logger.debug("Processing pending message: '%s...'", pending[:40]) - + # Clear the adapter's interrupt event so the next _run_agent call # doesn't immediately re-trigger the interrupt before the new agent # even makes its first API call (this was causing an infinite loop). if adapter and hasattr(adapter, '_active_sessions') and session_key and session_key in adapter._active_sessions: adapter._active_sessions[session_key].clear() - + # Cap recursion depth to prevent resource exhaustion when the # user sends multiple messages while the agent keeps failing. (#816) if _interrupt_depth >= self._MAX_INTERRUPT_DEPTH: @@ -8172,9 +8203,10 @@ class GatewayRunner: "queueing message instead of recursing.", _interrupt_depth, session_key, ) - # Queue the pending message for normal processing on next turn adapter = self.adapters.get(source.platform) - if adapter and hasattr(adapter, 'queue_message'): + if adapter and pending_event: + merge_pending_message_event(adapter._pending_messages, session_key, pending_event) + elif adapter and hasattr(adapter, 'queue_message'): adapter.queue_message(session_key, pending) return result_holder[0] or {"final_response": response, "messages": history} @@ -8189,23 +8221,37 @@ class GatewayRunner: if first_response and not _already_streamed: try: await adapter.send(source.chat_id, first_response, - metadata=getattr(event, "metadata", None)) + metadata={"thread_id": source.thread_id} if source.thread_id else None) except Exception as e: logger.warning("Failed to send first response before queued message: %s", e) # else: interrupted — discard the interrupted response ("Operation # interrupted." is just noise; the user already knows they sent a # new message). - # Process the pending message with updated history updated_history = result.get("messages", history) + next_source = source + next_message = pending + next_message_id = None + if pending_event is not None: + next_source = getattr(pending_event, "source", None) or source + next_message = await self._prepare_inbound_message_text( + event=pending_event, + source=next_source, + history=updated_history, + ) + if next_message is None: + return result + next_message_id = getattr(pending_event, "message_id", None) + return await self._run_agent( - message=pending, + message=next_message, context_prompt=context_prompt, history=updated_history, - source=source, + source=next_source, session_id=session_id, session_key=session_key, _interrupt_depth=_interrupt_depth + 1, + event_message_id=next_message_id, ) finally: # Stop progress sender, interrupt monitor, and notification task diff --git a/gateway/session.py b/gateway/session.py index 2b32c18895..96013df513 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -368,6 +368,11 @@ class SessionEntry: # survives gateway restarts (the old in-memory _pre_flushed_sessions # set was lost on restart, causing redundant re-flushes). 
memory_flushed: bool = False + + # When True the next call to get_or_create_session() will auto-reset + # this session (create a new session_id) so the user starts fresh. + # Set by /stop to break stuck-resume loops (#7536). + suspended: bool = False def to_dict(self) -> Dict[str, Any]: result = { @@ -387,6 +392,7 @@ class SessionEntry: "estimated_cost_usd": self.estimated_cost_usd, "cost_status": self.cost_status, "memory_flushed": self.memory_flushed, + "suspended": self.suspended, } if self.origin: result["origin"] = self.origin.to_dict() @@ -423,6 +429,7 @@ class SessionEntry: estimated_cost_usd=data.get("estimated_cost_usd", 0.0), cost_status=data.get("cost_status", "unknown"), memory_flushed=data.get("memory_flushed", False), + suspended=data.get("suspended", False), ) @@ -698,7 +705,12 @@ class SessionStore: if session_key in self._entries and not force_new: entry = self._entries[session_key] - reset_reason = self._should_reset(entry, source) + # Auto-reset sessions marked as suspended (e.g. after /stop + # broke a stuck loop — #7536). + if entry.suspended: + reset_reason = "suspended" + else: + reset_reason = self._should_reset(entry, source) if not reset_reason: entry.updated_at = now self._save() @@ -771,6 +783,44 @@ class SessionStore: entry.last_prompt_tokens = last_prompt_tokens self._save() + def suspend_session(self, session_key: str) -> bool: + """Mark a session as suspended so it auto-resets on next access. + + Used by ``/stop`` to prevent stuck sessions from being resumed + after a gateway restart (#7536). Returns True if the session + existed and was marked. + """ + with self._lock: + self._ensure_loaded_locked() + if session_key in self._entries: + self._entries[session_key].suspended = True + self._save() + return True + return False + + def suspend_recently_active(self, max_age_seconds: int = 120) -> int: + """Mark recently-active sessions as suspended. + + Called on gateway startup to prevent sessions that were likely + in-flight when the gateway last exited from being blindly resumed + (#7536). Only suspends sessions updated within *max_age_seconds* + to avoid resetting long-idle sessions that are harmless to resume. + Returns the number of sessions that were suspended. 
+ """ + import time as _time + + cutoff = _time.time() - max_age_seconds + count = 0 + with self._lock: + self._ensure_loaded_locked() + for entry in self._entries.values(): + if not entry.suspended and entry.updated_at >= cutoff: + entry.suspended = True + count += 1 + if count: + self._save() + return count + def reset_session(self, session_key: str) -> Optional[SessionEntry]: """Force reset a session, creating a new session ID.""" db_end_session_id = None diff --git a/gateway/session_context.py b/gateway/session_context.py index 775cd8698b..6d676dc1ec 100644 --- a/gateway/session_context.py +++ b/gateway/session_context.py @@ -46,12 +46,16 @@ _SESSION_PLATFORM: ContextVar[str] = ContextVar("HERMES_SESSION_PLATFORM", defau _SESSION_CHAT_ID: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_ID", default="") _SESSION_CHAT_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_NAME", default="") _SESSION_THREAD_ID: ContextVar[str] = ContextVar("HERMES_SESSION_THREAD_ID", default="") +_SESSION_USER_ID: ContextVar[str] = ContextVar("HERMES_SESSION_USER_ID", default="") +_SESSION_USER_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_USER_NAME", default="") _VAR_MAP = { "HERMES_SESSION_PLATFORM": _SESSION_PLATFORM, "HERMES_SESSION_CHAT_ID": _SESSION_CHAT_ID, "HERMES_SESSION_CHAT_NAME": _SESSION_CHAT_NAME, "HERMES_SESSION_THREAD_ID": _SESSION_THREAD_ID, + "HERMES_SESSION_USER_ID": _SESSION_USER_ID, + "HERMES_SESSION_USER_NAME": _SESSION_USER_NAME, } @@ -60,6 +64,8 @@ def set_session_vars( chat_id: str = "", chat_name: str = "", thread_id: str = "", + user_id: str = "", + user_name: str = "", ) -> list: """Set all session context variables and return reset tokens. @@ -74,6 +80,8 @@ def set_session_vars( _SESSION_CHAT_ID.set(chat_id), _SESSION_CHAT_NAME.set(chat_name), _SESSION_THREAD_ID.set(thread_id), + _SESSION_USER_ID.set(user_id), + _SESSION_USER_NAME.set(user_name), ] return tokens @@ -87,6 +95,8 @@ def clear_session_vars(tokens: list) -> None: _SESSION_CHAT_ID, _SESSION_CHAT_NAME, _SESSION_THREAD_ID, + _SESSION_USER_ID, + _SESSION_USER_NAME, ] for var, token in zip(vars_in_order, tokens): var.reset(token) diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index fcb7c2dc5a..56b9fb63c2 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -261,6 +261,28 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = { } +# ============================================================================= +# Anthropic Key Helper +# ============================================================================= + +def get_anthropic_key() -> str: + """Return the first usable Anthropic credential, or ``""``. + + Checks both the ``.env`` file (via ``get_env_value``) and the process + environment (``os.getenv``). 
The fallback order mirrors the + ``PROVIDER_REGISTRY["anthropic"].api_key_env_vars`` tuple: + + ANTHROPIC_API_KEY -> ANTHROPIC_TOKEN -> CLAUDE_CODE_OAUTH_TOKEN + """ + from hermes_cli.config import get_env_value + + for var in PROVIDER_REGISTRY["anthropic"].api_key_env_vars: + value = get_env_value(var) or os.getenv(var, "") + if value: + return value + return "" + + # ============================================================================= # Kimi Code Endpoint Detection # ============================================================================= diff --git a/hermes_cli/claw.py b/hermes_cli/claw.py index 3ab6bf9a8d..d0bfd73d23 100644 --- a/hermes_cli/claw.py +++ b/hermes_cli/claw.py @@ -52,6 +52,41 @@ _OPENCLAW_SCRIPT_INSTALLED = ( # Known OpenClaw directory names (current + legacy) _OPENCLAW_DIR_NAMES = (".openclaw", ".clawdbot", ".moldbot") +def _warn_if_gateway_running(auto_yes: bool) -> None: + """Check if a Hermes gateway is running with connected platforms. + + Migrating bot tokens while the gateway is polling will cause conflicts + (e.g. Telegram 409 "terminated by other getUpdates request"). Warn the + user and let them decide whether to continue. + """ + from gateway.status import get_running_pid, read_runtime_status + + if not get_running_pid(): + return + + data = read_runtime_status() or {} + platforms = data.get("platforms") or {} + connected = [name for name, info in platforms.items() + if isinstance(info, dict) and info.get("state") == "connected"] + if not connected: + return + + print() + print_error( + "Hermes gateway is running with active connections: " + + ", ".join(connected) + ) + print_info( + "Migrating bot tokens while the gateway is active will cause " + "conflicts (Telegram, Discord, and Slack only allow one active " + "session per token)." + ) + print_info("Recommendation: stop the gateway first with 'hermes stop'.") + print() + if not auto_yes and not prompt_yes_no("Continue anyway?", default=False): + print_info("Migration cancelled. Stop the gateway and try again.") + sys.exit(0) + # State files commonly found in OpenClaw workspace directories that cause # confusion after migration (the agent discovers them and writes to them) _WORKSPACE_STATE_GLOBS = ( @@ -252,6 +287,10 @@ def _cmd_migrate(args): print_info(f"Workspace: {workspace_target}") print() + # Check if a gateway is running with connected platforms — migrating tokens + # while the gateway is active will cause conflicts (e.g. Telegram 409). + _warn_if_gateway_running(auto_yes) + # Ensure config.yaml exists before migration tries to read it config_path = get_config_path() if not config_path.exists(): diff --git a/hermes_cli/cli_output.py b/hermes_cli/cli_output.py new file mode 100644 index 0000000000..3d454eb308 --- /dev/null +++ b/hermes_cli/cli_output.py @@ -0,0 +1,79 @@ +"""Shared CLI output helpers for Hermes CLI modules. + +Extracts the identical ``print_info/success/warning/error`` and ``prompt()`` +functions previously duplicated across setup.py, tools_config.py, +mcp_config.py, and memory_setup.py. 
+""" + +import getpass +import sys + +from hermes_cli.colors import Colors, color + + +# ─── Print Helpers ──────────────────────────────────────────────────────────── + + +def print_info(text: str) -> None: + """Print a dim informational message.""" + print(color(f" {text}", Colors.DIM)) + + +def print_success(text: str) -> None: + """Print a green success message with ✓ prefix.""" + print(color(f"✓ {text}", Colors.GREEN)) + + +def print_warning(text: str) -> None: + """Print a yellow warning message with ⚠ prefix.""" + print(color(f"⚠ {text}", Colors.YELLOW)) + + +def print_error(text: str) -> None: + """Print a red error message with ✗ prefix.""" + print(color(f"✗ {text}", Colors.RED)) + + +def print_header(text: str) -> None: + """Print a bold yellow header.""" + print(color(f"\n {text}", Colors.YELLOW)) + + +# ─── Input Prompts ──────────────────────────────────────────────────────────── + + +def prompt( + question: str, + default: str | None = None, + password: bool = False, +) -> str: + """Prompt the user for input with optional default and password masking. + + Replaces the four independent ``_prompt()`` / ``prompt()`` implementations + in setup.py, tools_config.py, mcp_config.py, and memory_setup.py. + + Returns the user's input (stripped), or *default* if the user presses Enter. + Returns empty string on Ctrl-C or EOF. + """ + suffix = f" [{default}]" if default else "" + display = color(f" {question}{suffix}: ", Colors.YELLOW) + + try: + if password: + value = getpass.getpass(display) + else: + value = input(display) + value = value.strip() + return value if value else (default or "") + except (KeyboardInterrupt, EOFError): + print() + return "" + + +def prompt_yes_no(question: str, default: bool = True) -> bool: + """Prompt for a yes/no answer. Returns bool.""" + hint = "Y/n" if default else "y/N" + answer = prompt(f"{question} ({hint})") + if not answer: + return default + return answer.lower().startswith("y") diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 4661455d12..1545d15aad 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -1497,7 +1497,7 @@ _KNOWN_ROOT_KEYS = { # Valid fields inside a custom_providers list entry _VALID_CUSTOM_PROVIDER_FIELDS = { - "name", "base_url", "api_key", "api_mode", "models", + "name", "base_url", "api_key", "api_mode", "model", "models", "context_length", "rate_limit_delay", } @@ -2582,7 +2582,8 @@ def show_config(): for env_key, name in keys: value = get_env_value(env_key) print(f" {name:<14} {redact_key(value)}") - anthropic_value = get_env_value("ANTHROPIC_TOKEN") or get_env_value("ANTHROPIC_API_KEY") + from hermes_cli.auth import get_anthropic_key + anthropic_value = get_anthropic_key() print(f" {'Anthropic':<14} {redact_key(anthropic_value)}") # Model settings @@ -2798,8 +2799,8 @@ def set_config_value(key: str, value: str): # Write only user config back (not the full merged defaults) ensure_hermes_home() - with open(config_path, 'w', encoding="utf-8") as f: - yaml.dump(user_config, f, default_flow_style=False, sort_keys=False) + from utils import atomic_yaml_write + atomic_yaml_write(config_path, user_config, sort_keys=False) # Keep .env in sync for keys that terminal_tool reads directly from env vars. # config.yaml is authoritative, but terminal_tool only reads TERMINAL_ENV etc. 
diff --git a/hermes_cli/doctor.py b/hermes_cli/doctor.py index f5f8a228a7..13c904692c 100644 --- a/hermes_cli/doctor.py +++ b/hermes_cli/doctor.py @@ -336,8 +336,8 @@ def run_doctor(args): model_section[k] = raw_config.pop(k) else: raw_config.pop(k) - with open(config_path, "w") as f: - yaml.dump(raw_config, f, default_flow_style=False) + from utils import atomic_yaml_write + atomic_yaml_write(config_path, raw_config) check_ok("Migrated stale root-level keys into model section") fixed_count += 1 else: @@ -686,7 +686,8 @@ def run_doctor(args): else: check_warn("OpenRouter API", "(not configured)") - anthropic_key = os.getenv("ANTHROPIC_TOKEN") or os.getenv("ANTHROPIC_API_KEY") + from hermes_cli.auth import get_anthropic_key + anthropic_key = get_anthropic_key() if anthropic_key: print(" Checking Anthropic API...", end="", flush=True) try: diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index b29511dd59..505bad0b51 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -157,30 +157,54 @@ def _request_gateway_self_restart(pid: int) -> bool: return True -def find_gateway_pids(exclude_pids: set | None = None) -> list: +def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = False) -> list: """Find PIDs of running gateway processes. Args: exclude_pids: PIDs to exclude from the result (e.g. service-managed PIDs that should not be killed during a stale-process sweep). + all_profiles: When ``True``, return gateway PIDs across **all** + profiles (the pre-7923 global behaviour). ``hermes update`` + needs this because a code update affects every profile. + When ``False`` (default), only PIDs belonging to the current + Hermes profile are returned. """ - pids = [] _exclude = exclude_pids or set() + pids = [pid for pid in _get_service_pids() if pid not in _exclude] patterns = [ "hermes_cli.main gateway", + "hermes_cli.main --profile", + "hermes_cli.main -p", "hermes_cli/main.py gateway", + "hermes_cli/main.py --profile", + "hermes_cli/main.py -p", "hermes gateway", "gateway/run.py", ] + current_home = str(get_hermes_home().resolve()) + current_profile_arg = _profile_arg(current_home) + current_profile_name = current_profile_arg.split()[-1] if current_profile_arg else "" + + def _matches_current_profile(command: str) -> bool: + if current_profile_name: + return ( + f"--profile {current_profile_name}" in command + or f"-p {current_profile_name}" in command + or f"HERMES_HOME={current_home}" in command + ) + + if "--profile " in command or " -p " in command: + return False + if "HERMES_HOME=" in command and f"HERMES_HOME={current_home}" not in command: + return False + return True try: if is_windows(): - # Windows: use wmic to search command lines result = subprocess.run( ["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"], capture_output=True, text=True, timeout=10 ) - # Parse WMIC LIST output: blocks of "CommandLine=...\nProcessId=...\n" current_cmd = "" for line in result.stdout.split('\n'): line = line.strip() @@ -188,7 +212,7 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list: current_cmd = line[len("CommandLine="):] elif line.startswith("ProcessId="): pid_str = line[len("ProcessId="):] - if any(p in current_cmd for p in patterns): + if any(p in current_cmd for p in patterns) and (all_profiles or _matches_current_profile(current_cmd)): try: pid = int(pid_str) if pid != os.getpid() and pid not in pids and pid not in _exclude: @@ -198,41 +222,57 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list: current_cmd = "" 
else: result = subprocess.run( - ["ps", "aux"], + ["ps", "eww", "-ax", "-o", "pid=,command="], capture_output=True, text=True, timeout=10, ) for line in result.stdout.split('\n'): - # Skip grep and current process - if 'grep' in line or str(os.getpid()) in line: + stripped = line.strip() + if not stripped or 'grep' in stripped: continue - for pattern in patterns: - if pattern in line: - parts = line.split() - if len(parts) > 1: - try: - pid = int(parts[1]) - if pid not in pids and pid not in _exclude: - pids.append(pid) - except ValueError: - continue - break - except Exception: + + pid = None + command = "" + + parts = stripped.split(None, 1) + if len(parts) == 2: + try: + pid = int(parts[0]) + command = parts[1] + except ValueError: + pid = None + + if pid is None: + aux_parts = stripped.split() + if len(aux_parts) > 10 and aux_parts[1].isdigit(): + pid = int(aux_parts[1]) + command = " ".join(aux_parts[10:]) + + if pid is None: + continue + if pid == os.getpid() or pid in pids or pid in _exclude: + continue + if any(pattern in command for pattern in patterns) and (all_profiles or _matches_current_profile(command)): + pids.append(pid) + except (OSError, subprocess.TimeoutExpired): pass return pids -def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None) -> int: +def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None, + all_profiles: bool = False) -> int: """Kill any running gateway processes. Returns count killed. Args: force: Use the platform's force-kill mechanism instead of graceful terminate. exclude_pids: PIDs to skip (e.g. service-managed PIDs that were just restarted and should not be killed). + all_profiles: When ``True``, kill across all profiles. Passed + through to :func:`find_gateway_pids`. """ - pids = find_gateway_pids(exclude_pids=exclude_pids) + pids = find_gateway_pids(exclude_pids=exclude_pids, all_profiles=all_profiles) killed = 0 for pid in pids: @@ -633,6 +673,17 @@ def print_systemd_linger_guidance() -> None: print(" If you want the gateway user service to survive logout, run:") print(" sudo loginctl enable-linger $USER") +def _launchd_user_home() -> Path: + """Return the real macOS user home for launchd artifacts. + + Profile-mode Hermes often sets ``HOME`` to a profile-scoped directory, but + launchd user agents still live under the actual account home. + """ + import pwd + + return Path(pwd.getpwuid(os.getuid()).pw_dir) + + def get_launchd_plist_path() -> Path: """Return the launchd plist path, scoped per profile. @@ -641,7 +692,7 @@ def get_launchd_plist_path() -> Path: """ suffix = _profile_suffix() name = f"ai.hermes.gateway-{suffix}" if suffix else "ai.hermes.gateway" - return Path.home() / "Library" / "LaunchAgents" / f"{name}.plist" + return _launchd_user_home() / "Library" / "LaunchAgents" / f"{name}.plist" def _detect_venv_dir() -> Path | None: """Detect the active virtualenv directory. @@ -839,6 +890,25 @@ def _normalize_service_definition(text: str) -> str: return "\n".join(line.rstrip() for line in text.strip().splitlines()) +def _normalize_launchd_plist_for_comparison(text: str) -> str: + """Normalize launchd plist text for staleness checks. + + The generated plist intentionally captures a broad PATH assembled from the + invoking shell so user-installed tools remain reachable under launchd. + That makes raw text comparison unstable across shells, so ignore the PATH + payload when deciding whether the installed plist is stale. 
+    """
+    import re
+
+    normalized = _normalize_service_definition(text)
+    return re.sub(
+        r'(<key>PATH</key>\s*<string>)(.*?)(</string>)',
+        r'\1__HERMES_PATH__\3',
+        normalized,
+        flags=re.S,
+    )
+
+
 def systemd_unit_is_current(system: bool = False) -> bool:
     unit_path = get_systemd_unit_path(system=system)
     if not unit_path.exists():
@@ -1220,7 +1290,7 @@ def launchd_plist_is_current() -> bool:
     installed = plist_path.read_text(encoding="utf-8")
     expected = generate_launchd_plist()
 
-    return _normalize_service_definition(installed) == _normalize_service_definition(expected)
+    return _normalize_launchd_plist_for_comparison(installed) == _normalize_launchd_plist_for_comparison(expected)
 
 
 def refresh_launchd_plist_if_needed() -> bool:
@@ -1981,6 +2051,36 @@ def _setup_whatsapp():
     cmd_whatsapp(argparse.Namespace())
 
 
+def _setup_email():
+    """Configure Email via the standard platform setup."""
+    email_platform = next(p for p in _PLATFORMS if p["key"] == "email")
+    _setup_standard_platform(email_platform)
+
+
+def _setup_sms():
+    """Configure SMS (Twilio) via the standard platform setup."""
+    sms_platform = next(p for p in _PLATFORMS if p["key"] == "sms")
+    _setup_standard_platform(sms_platform)
+
+
+def _setup_dingtalk():
+    """Configure DingTalk via the standard platform setup."""
+    dingtalk_platform = next(p for p in _PLATFORMS if p["key"] == "dingtalk")
+    _setup_standard_platform(dingtalk_platform)
+
+
+def _setup_feishu():
+    """Configure Feishu / Lark via the standard platform setup."""
+    feishu_platform = next(p for p in _PLATFORMS if p["key"] == "feishu")
+    _setup_standard_platform(feishu_platform)
+
+
+def _setup_wecom():
+    """Configure WeCom (Enterprise WeChat) via the standard platform setup."""
+    wecom_platform = next(p for p in _PLATFORMS if p["key"] == "wecom")
+    _setup_standard_platform(wecom_platform)
+
+
 def _is_service_installed() -> bool:
     """Check if the gateway is installed as a system service."""
     if supports_systemd_services():
@@ -2540,7 +2640,7 @@ def gateway_command(args):
             service_available = True
         except subprocess.CalledProcessError:
             pass
-        killed = kill_gateway_processes()
+        killed = kill_gateway_processes(all_profiles=True)
         total = killed + (1 if service_available else 0)
         if total:
             print(f"✓ Stopped {total} gateway process(es) across all profiles")
diff --git a/hermes_cli/main.py b/hermes_cli/main.py
index 1e4159c9af..8d1a10000b 100644
--- a/hermes_cli/main.py
+++ b/hermes_cli/main.py
@@ -2758,13 +2758,8 @@ def _model_flow_anthropic(config, current_model=""):
     from hermes_cli.models import _PROVIDER_MODELS
 
     # Check ALL credential sources
-    existing_key = (
-        get_env_value("ANTHROPIC_TOKEN")
-        or os.getenv("ANTHROPIC_TOKEN", "")
-        or get_env_value("ANTHROPIC_API_KEY")
-        or os.getenv("ANTHROPIC_API_KEY", "")
-        or os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "")
-    )
+    from hermes_cli.auth import get_anthropic_key
+    existing_key = get_anthropic_key()
     cc_available = False
     try:
         from agent.anthropic_adapter import read_claude_code_credentials, is_claude_code_token_valid
@@ -4090,7 +4085,7 @@ def cmd_update(args):
     # Exclude PIDs that belong to just-restarted services so we don't
     # immediately kill the process that systemd/launchd just spawned.
service_pids = _get_service_pids() - manual_pids = find_gateway_pids(exclude_pids=service_pids) + manual_pids = find_gateway_pids(exclude_pids=service_pids, all_profiles=True) for pid in manual_pids: try: os.kill(pid, _signal.SIGTERM) diff --git a/hermes_cli/mcp_config.py b/hermes_cli/mcp_config.py index 9154ed50a3..cf2dde0892 100644 --- a/hermes_cli/mcp_config.py +++ b/hermes_cli/mcp_config.py @@ -57,19 +57,8 @@ def _confirm(question: str, default: bool = True) -> bool: def _prompt(question: str, *, password: bool = False, default: str = "") -> str: - display = f" {question}" - if default: - display += f" [{default}]" - display += ": " - try: - if password: - value = getpass.getpass(color(display, Colors.YELLOW)) - else: - value = input(color(display, Colors.YELLOW)) - return value.strip() or default - except (KeyboardInterrupt, EOFError): - print() - return default + from hermes_cli.cli_output import prompt as _shared_prompt + return _shared_prompt(question, default=default, password=password) # ─── Config Helpers ─────────────────────────────────────────────────────────── diff --git a/hermes_cli/memory_setup.py b/hermes_cli/memory_setup.py index 2843f4f444..1aa4313676 100644 --- a/hermes_cli/memory_setup.py +++ b/hermes_cli/memory_setup.py @@ -25,85 +25,13 @@ def _curses_select(title: str, items: list[tuple[str, str]], default: int = 0) - items: list of (label, description) tuples. Returns selected index, or default on escape/quit. """ - try: - import curses - result = [default] - - def _menu(stdscr): - curses.curs_set(0) - if curses.has_colors(): - curses.start_color() - curses.use_default_colors() - curses.init_pair(1, curses.COLOR_GREEN, -1) - curses.init_pair(2, curses.COLOR_YELLOW, -1) - curses.init_pair(3, curses.COLOR_CYAN, -1) - cursor = default - - while True: - stdscr.clear() - max_y, max_x = stdscr.getmaxyx() - - # Title - try: - stdscr.addnstr(0, 0, title, max_x - 1, - curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0)) - stdscr.addnstr(1, 0, " ↑↓ navigate ⏎ select q quit", max_x - 1, - curses.color_pair(3) if curses.has_colors() else curses.A_DIM) - except curses.error: - pass - - for i, (label, desc) in enumerate(items): - y = i + 3 - if y >= max_y - 1: - break - arrow = "→" if i == cursor else " " - line = f" {arrow} {label}" - if desc: - line += f" {desc}" - - attr = curses.A_NORMAL - if i == cursor: - attr = curses.A_BOLD - if curses.has_colors(): - attr |= curses.color_pair(1) - try: - stdscr.addnstr(y, 0, line[:max_x - 1], max_x - 1, attr) - except curses.error: - pass - - stdscr.refresh() - key = stdscr.getch() - - if key in (curses.KEY_UP, ord('k')): - cursor = (cursor - 1) % len(items) - elif key in (curses.KEY_DOWN, ord('j')): - cursor = (cursor + 1) % len(items) - elif key in (curses.KEY_ENTER, 10, 13): - result[0] = cursor - return - elif key in (27, ord('q')): - return - - curses.wrapper(_menu) - return result[0] - - except Exception: - # Fallback: numbered input - print(f"\n {title}\n") - for i, (label, desc) in enumerate(items): - marker = "→" if i == default else " " - d = f" {desc}" if desc else "" - print(f" {marker} {i + 1}. 
{label}{d}") - while True: - try: - val = input(f"\n Select [1-{len(items)}] ({default + 1}): ") - if not val: - return default - idx = int(val) - 1 - if 0 <= idx < len(items): - return idx - except (ValueError, EOFError): - return default + from hermes_cli.curses_ui import curses_radiolist + # Format (label, desc) tuples into display strings + display_items = [ + f"{label} {desc}" if desc else label + for label, desc in items + ] + return curses_radiolist(title, display_items, selected=default, cancel_returns=default) def _prompt(label: str, default: str | None = None, secret: bool = False) -> str: diff --git a/hermes_cli/platforms.py b/hermes_cli/platforms.py new file mode 100644 index 0000000000..18307912b1 --- /dev/null +++ b/hermes_cli/platforms.py @@ -0,0 +1,45 @@ +""" +Shared platform registry for Hermes Agent. + +Single source of truth for platform metadata consumed by both +skills_config (label display) and tools_config (default toolset +resolution). Import ``PLATFORMS`` from here instead of maintaining +duplicate dicts in each module. +""" + +from collections import OrderedDict +from typing import NamedTuple + + +class PlatformInfo(NamedTuple): + """Metadata for a single platform entry.""" + label: str + default_toolset: str + + +# Ordered so that TUI menus are deterministic. +PLATFORMS: OrderedDict[str, PlatformInfo] = OrderedDict([ + ("cli", PlatformInfo(label="🖥️ CLI", default_toolset="hermes-cli")), + ("telegram", PlatformInfo(label="📱 Telegram", default_toolset="hermes-telegram")), + ("discord", PlatformInfo(label="💬 Discord", default_toolset="hermes-discord")), + ("slack", PlatformInfo(label="💼 Slack", default_toolset="hermes-slack")), + ("whatsapp", PlatformInfo(label="📱 WhatsApp", default_toolset="hermes-whatsapp")), + ("signal", PlatformInfo(label="📡 Signal", default_toolset="hermes-signal")), + ("bluebubbles", PlatformInfo(label="💙 BlueBubbles", default_toolset="hermes-bluebubbles")), + ("email", PlatformInfo(label="📧 Email", default_toolset="hermes-email")), + ("homeassistant", PlatformInfo(label="🏠 Home Assistant", default_toolset="hermes-homeassistant")), + ("mattermost", PlatformInfo(label="💬 Mattermost", default_toolset="hermes-mattermost")), + ("matrix", PlatformInfo(label="💬 Matrix", default_toolset="hermes-matrix")), + ("dingtalk", PlatformInfo(label="💬 DingTalk", default_toolset="hermes-dingtalk")), + ("feishu", PlatformInfo(label="🪽 Feishu", default_toolset="hermes-feishu")), + ("wecom", PlatformInfo(label="💬 WeCom", default_toolset="hermes-wecom")), + ("weixin", PlatformInfo(label="💬 Weixin", default_toolset="hermes-weixin")), + ("webhook", PlatformInfo(label="🔗 Webhook", default_toolset="hermes-webhook")), + ("api_server", PlatformInfo(label="🌐 API Server", default_toolset="hermes-api-server")), +]) + + +def platform_label(key: str, default: str = "") -> str: + """Return the display label for a platform key, or *default*.""" + info = PLATFORMS.get(key) + return info.label if info is not None else default diff --git a/hermes_cli/runtime_provider.py b/hermes_cli/runtime_provider.py index 3d1333c26f..cd0b667225 100644 --- a/hermes_cli/runtime_provider.py +++ b/hermes_cli/runtime_provider.py @@ -304,6 +304,9 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An api_mode = _parse_api_mode(entry.get("api_mode")) if api_mode: result["api_mode"] = api_mode + model_name = str(entry.get("model", "") or "").strip() + if model_name: + result["model"] = model_name return result return None @@ -329,6 +332,11 @@ def 
_resolve_named_custom_runtime( # Check if a credential pool exists for this custom endpoint pool_result = _try_resolve_from_custom_pool(base_url, "custom", custom_provider.get("api_mode")) if pool_result: + # Propagate the model name even when using pooled credentials — + # the pool doesn't know about the custom_providers model field. + model_name = custom_provider.get("model") + if model_name: + pool_result["model"] = model_name return pool_result api_key_candidates = [ @@ -339,7 +347,7 @@ def _resolve_named_custom_runtime( ] api_key = next((candidate for candidate in api_key_candidates if has_usable_secret(candidate)), "") - return { + result = { "provider": "custom", "api_mode": custom_provider.get("api_mode") or _detect_api_mode_for_url(base_url) @@ -348,6 +356,11 @@ def _resolve_named_custom_runtime( "api_key": api_key or "no-key-required", "source": f"custom_provider:{custom_provider.get('name', requested_provider)}", } + # Propagate the model name so callers can override self.model when the + # provider name differs from the actual model string the API expects. + if custom_provider.get("model"): + result["model"] = custom_provider["model"] + return result def _resolve_openrouter_runtime( diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index ca877606fd..a25ce84914 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -197,24 +197,12 @@ def print_header(title: str): print(color(f"◆ {title}", Colors.CYAN, Colors.BOLD)) -def print_info(text: str): - """Print info text.""" - print(color(f" {text}", Colors.DIM)) - - -def print_success(text: str): - """Print success message.""" - print(color(f"✓ {text}", Colors.GREEN)) - - -def print_warning(text: str): - """Print warning message.""" - print(color(f"⚠ {text}", Colors.YELLOW)) - - -def print_error(text: str): - """Print error message.""" - print(color(f"✗ {text}", Colors.RED)) +from hermes_cli.cli_output import ( # noqa: E402 + print_error, + print_info, + print_success, + print_warning, +) def is_interactive_stdin() -> bool: @@ -269,80 +257,9 @@ def prompt(question: str, default: str = None, password: bool = False) -> str: def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int: - """Single-select menu using curses to avoid simple_term_menu rendering bugs.""" - try: - import curses - result_holder = [default] - - def _curses_menu(stdscr): - curses.curs_set(0) - if curses.has_colors(): - curses.start_color() - curses.use_default_colors() - curses.init_pair(1, curses.COLOR_GREEN, -1) - curses.init_pair(2, curses.COLOR_YELLOW, -1) - cursor = default - scroll_offset = 0 - - while True: - stdscr.clear() - max_y, max_x = stdscr.getmaxyx() - - # Rows available for list items: rows 2..(max_y-2) inclusive. - visible = max(1, max_y - 3) - - # Scroll the viewport so the cursor is always visible. 
- if cursor < scroll_offset: - scroll_offset = cursor - elif cursor >= scroll_offset + visible: - scroll_offset = cursor - visible + 1 - scroll_offset = max(0, min(scroll_offset, max(0, len(choices) - visible))) - - try: - stdscr.addnstr( - 0, - 0, - question, - max_x - 1, - curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0), - ) - except curses.error: - pass - - for row, i in enumerate(range(scroll_offset, min(scroll_offset + visible, len(choices)))): - y = row + 2 - if y >= max_y - 1: - break - arrow = "→" if i == cursor else " " - line = f" {arrow} {choices[i]}" - attr = curses.A_NORMAL - if i == cursor: - attr = curses.A_BOLD - if curses.has_colors(): - attr |= curses.color_pair(1) - try: - stdscr.addnstr(y, 0, line, max_x - 1, attr) - except curses.error: - pass - - stdscr.refresh() - key = stdscr.getch() - if key in (curses.KEY_UP, ord("k")): - cursor = (cursor - 1) % len(choices) - elif key in (curses.KEY_DOWN, ord("j")): - cursor = (cursor + 1) % len(choices) - elif key in (curses.KEY_ENTER, 10, 13): - result_holder[0] = cursor - return - elif key in (27, ord("q")): - return - - curses.wrapper(_curses_menu) - from hermes_cli.curses_ui import flush_stdin - flush_stdin() - return result_holder[0] - except Exception: - return -1 + """Single-select menu using curses. Delegates to curses_radiolist.""" + from hermes_cli.curses_ui import curses_radiolist + return curses_radiolist(question, choices, selected=default, cancel_returns=-1) @@ -2052,6 +1969,42 @@ def _setup_weixin(): _gateway_setup_weixin() +def _setup_signal(): + """Configure Signal via gateway setup.""" + from hermes_cli.gateway import _setup_signal as _gateway_setup_signal + _gateway_setup_signal() + + +def _setup_email(): + """Configure Email via gateway setup.""" + from hermes_cli.gateway import _setup_email as _gateway_setup_email + _gateway_setup_email() + + +def _setup_sms(): + """Configure SMS (Twilio) via gateway setup.""" + from hermes_cli.gateway import _setup_sms as _gateway_setup_sms + _gateway_setup_sms() + + +def _setup_dingtalk(): + """Configure DingTalk via gateway setup.""" + from hermes_cli.gateway import _setup_dingtalk as _gateway_setup_dingtalk + _gateway_setup_dingtalk() + + +def _setup_feishu(): + """Configure Feishu / Lark via gateway setup.""" + from hermes_cli.gateway import _setup_feishu as _gateway_setup_feishu + _gateway_setup_feishu() + + +def _setup_wecom(): + """Configure WeCom (Enterprise WeChat) via gateway setup.""" + from hermes_cli.gateway import _setup_wecom as _gateway_setup_wecom + _gateway_setup_wecom() + + def _setup_bluebubbles(): """Configure BlueBubbles iMessage gateway.""" print_header("BlueBubbles (iMessage)") @@ -2168,9 +2121,15 @@ _GATEWAY_PLATFORMS = [ ("Telegram", "TELEGRAM_BOT_TOKEN", _setup_telegram), ("Discord", "DISCORD_BOT_TOKEN", _setup_discord), ("Slack", "SLACK_BOT_TOKEN", _setup_slack), + ("Signal", "SIGNAL_HTTP_URL", _setup_signal), + ("Email", "EMAIL_ADDRESS", _setup_email), + ("SMS (Twilio)", "TWILIO_ACCOUNT_SID", _setup_sms), ("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix), ("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost), ("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp), + ("DingTalk", "DINGTALK_CLIENT_ID", _setup_dingtalk), + ("Feishu / Lark", "FEISHU_APP_ID", _setup_feishu), + ("WeCom (Enterprise WeChat)", "WECOM_BOT_ID", _setup_wecom), ("Weixin (WeChat)", "WEIXIN_ACCOUNT_ID", _setup_weixin), ("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles), ("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", 
_setup_webhooks), @@ -2212,10 +2171,17 @@ def setup_gateway(config: dict): get_env_value("TELEGRAM_BOT_TOKEN") or get_env_value("DISCORD_BOT_TOKEN") or get_env_value("SLACK_BOT_TOKEN") + or get_env_value("SIGNAL_HTTP_URL") + or get_env_value("EMAIL_ADDRESS") + or get_env_value("TWILIO_ACCOUNT_SID") or get_env_value("MATTERMOST_TOKEN") or get_env_value("MATRIX_ACCESS_TOKEN") or get_env_value("MATRIX_PASSWORD") or get_env_value("WHATSAPP_ENABLED") + or get_env_value("DINGTALK_CLIENT_ID") + or get_env_value("FEISHU_APP_ID") + or get_env_value("WECOM_BOT_ID") + or get_env_value("WEIXIN_ACCOUNT_ID") or get_env_value("BLUEBUBBLES_SERVER_URL") or get_env_value("WEBHOOK_ENABLED") ) @@ -2404,12 +2370,30 @@ def _get_section_config_summary(config: dict, section_key: str) -> Optional[str] platforms.append("Discord") if get_env_value("SLACK_BOT_TOKEN"): platforms.append("Slack") - if get_env_value("WHATSAPP_PHONE_NUMBER_ID"): - platforms.append("WhatsApp") if get_env_value("SIGNAL_ACCOUNT"): platforms.append("Signal") + if get_env_value("EMAIL_ADDRESS"): + platforms.append("Email") + if get_env_value("TWILIO_ACCOUNT_SID"): + platforms.append("SMS") + if get_env_value("MATRIX_ACCESS_TOKEN") or get_env_value("MATRIX_PASSWORD"): + platforms.append("Matrix") + if get_env_value("MATTERMOST_TOKEN"): + platforms.append("Mattermost") + if get_env_value("WHATSAPP_PHONE_NUMBER_ID"): + platforms.append("WhatsApp") + if get_env_value("DINGTALK_CLIENT_ID"): + platforms.append("DingTalk") + if get_env_value("FEISHU_APP_ID"): + platforms.append("Feishu") + if get_env_value("WECOM_BOT_ID"): + platforms.append("WeCom") + if get_env_value("WEIXIN_ACCOUNT_ID"): + platforms.append("Weixin") if get_env_value("BLUEBUBBLES_SERVER_URL"): platforms.append("BlueBubbles") + if get_env_value("WEBHOOK_ENABLED"): + platforms.append("Webhooks") if platforms: return ", ".join(platforms) return None # No platforms configured — section must run diff --git a/hermes_cli/skills_config.py b/hermes_cli/skills_config.py index b017361fee..92424a0ca3 100644 --- a/hermes_cli/skills_config.py +++ b/hermes_cli/skills_config.py @@ -15,25 +15,12 @@ from typing import List, Optional, Set from hermes_cli.config import load_config, save_config from hermes_cli.colors import Colors, color +from hermes_cli.platforms import PLATFORMS as _PLATFORMS, platform_label -PLATFORMS = { - "cli": "🖥️ CLI", - "telegram": "📱 Telegram", - "discord": "💬 Discord", - "slack": "💼 Slack", - "whatsapp": "📱 WhatsApp", - "signal": "📡 Signal", - "bluebubbles": "💬 BlueBubbles", - "email": "📧 Email", - "homeassistant": "🏠 Home Assistant", - "mattermost": "💬 Mattermost", - "matrix": "💬 Matrix", - "dingtalk": "💬 DingTalk", - "feishu": "🪽 Feishu", - "wecom": "💬 WeCom", - "weixin": "💬 Weixin", - "webhook": "🔗 Webhook", -} +# Backward-compatible view: {key: label_string} so existing code that +# iterates ``PLATFORMS.items()`` or calls ``PLATFORMS.get(key)`` keeps +# working without changes to every call site. 
+PLATFORMS = {k: info.label for k, info in _PLATFORMS.items() if k != "api_server"} # ─── Config Helpers ─────────────────────────────────────────────────────────── diff --git a/hermes_cli/status.py b/hermes_cli/status.py index baba4f359d..7a7a9c645d 100644 --- a/hermes_cli/status.py +++ b/hermes_cli/status.py @@ -141,11 +141,8 @@ def show_status(args): display = redact_key(value) if not show_all else value print(f" {name:<12} {check_mark(has_key)} {display}") - anthropic_value = ( - get_env_value("ANTHROPIC_TOKEN") - or get_env_value("ANTHROPIC_API_KEY") - or "" - ) + from hermes_cli.auth import get_anthropic_key + anthropic_value = get_anthropic_key() anthropic_display = redact_key(anthropic_value) if not show_all else anthropic_value print(f" {'Anthropic':<12} {check_mark(bool(anthropic_value))} {anthropic_display}") diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index 91c41dce5d..343007cabc 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -33,33 +33,13 @@ PROJECT_ROOT = Path(__file__).parent.parent.resolve() # ─── UI Helpers (shared with setup.py) ──────────────────────────────────────── -def _print_info(text: str): - print(color(f" {text}", Colors.DIM)) - -def _print_success(text: str): - print(color(f"✓ {text}", Colors.GREEN)) - -def _print_warning(text: str): - print(color(f"⚠ {text}", Colors.YELLOW)) - -def _print_error(text: str): - print(color(f"✗ {text}", Colors.RED)) - -def _prompt(question: str, default: str = None, password: bool = False) -> str: - if default: - display = f"{question} [{default}]: " - else: - display = f"{question}: " - try: - if password: - import getpass - value = getpass.getpass(color(display, Colors.YELLOW)) - else: - value = input(color(display, Colors.YELLOW)) - return value.strip() or default or "" - except (KeyboardInterrupt, EOFError): - print() - return default or "" +from hermes_cli.cli_output import ( # noqa: E402 — late import block + print_error as _print_error, + print_info as _print_info, + print_success as _print_success, + print_warning as _print_warning, + prompt as _prompt, +) # ─── Toolset Registry ───────────────────────────────────────────────────────── @@ -118,25 +98,14 @@ def _get_plugin_toolset_keys() -> set: except Exception: return set() -# Platform display config +# Platform display config — derived from the canonical registry so every +# module shares the same data. Kept as dict-of-dicts for backward +# compatibility with existing ``PLATFORMS[key]["label"]`` access patterns. 
+from hermes_cli.platforms import PLATFORMS as _PLATFORMS_REGISTRY + PLATFORMS = { - "cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"}, - "telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"}, - "discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"}, - "slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"}, - "whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"}, - "signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"}, - "bluebubbles": {"label": "💙 BlueBubbles", "default_toolset": "hermes-bluebubbles"}, - "homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"}, - "email": {"label": "📧 Email", "default_toolset": "hermes-email"}, - "matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"}, - "dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"}, - "feishu": {"label": "🪽 Feishu", "default_toolset": "hermes-feishu"}, - "wecom": {"label": "💬 WeCom", "default_toolset": "hermes-wecom"}, - "weixin": {"label": "💬 Weixin", "default_toolset": "hermes-weixin"}, - "api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"}, - "mattermost": {"label": "💬 Mattermost", "default_toolset": "hermes-mattermost"}, - "webhook": {"label": "🔗 Webhook", "default_toolset": "hermes-webhook"}, + k: {"label": info.label, "default_toolset": info.default_toolset} + for k, info in _PLATFORMS_REGISTRY.items() } @@ -677,86 +646,9 @@ def _toolset_has_keys(ts_key: str, config: dict = None) -> bool: # ─── Menu Helpers ───────────────────────────────────────────────────────────── def _prompt_choice(question: str, choices: list, default: int = 0) -> int: - """Single-select menu (arrow keys). Uses curses to avoid simple_term_menu - rendering bugs in tmux, iTerm, and other non-standard terminals.""" - - # Curses-based single-select — works in tmux, iTerm, and standard terminals - try: - import curses - result_holder = [default] - - def _curses_menu(stdscr): - curses.curs_set(0) - if curses.has_colors(): - curses.start_color() - curses.use_default_colors() - curses.init_pair(1, curses.COLOR_GREEN, -1) - curses.init_pair(2, curses.COLOR_YELLOW, -1) - cursor = default - - while True: - stdscr.clear() - max_y, max_x = stdscr.getmaxyx() - try: - stdscr.addnstr(0, 0, question, max_x - 1, - curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0)) - except curses.error: - pass - - for i, c in enumerate(choices): - y = i + 2 - if y >= max_y - 1: - break - arrow = "→" if i == cursor else " " - line = f" {arrow} {c}" - attr = curses.A_NORMAL - if i == cursor: - attr = curses.A_BOLD - if curses.has_colors(): - attr |= curses.color_pair(1) - try: - stdscr.addnstr(y, 0, line, max_x - 1, attr) - except curses.error: - pass - - stdscr.refresh() - key = stdscr.getch() - - if key in (curses.KEY_UP, ord('k')): - cursor = (cursor - 1) % len(choices) - elif key in (curses.KEY_DOWN, ord('j')): - cursor = (cursor + 1) % len(choices) - elif key in (curses.KEY_ENTER, 10, 13): - result_holder[0] = cursor - return - elif key in (27, ord('q')): - return - - curses.wrapper(_curses_menu) - from hermes_cli.curses_ui import flush_stdin - flush_stdin() - return result_holder[0] - - except Exception: - pass - - # Fallback: numbered input (Windows without curses, etc.) - print(color(question, Colors.YELLOW)) - for i, c in enumerate(choices): - marker = "●" if i == default else "○" - style = Colors.GREEN if i == default else "" - print(color(f" {marker} {i+1}. 
{c}", style) if style else f" {marker} {i+1}. {c}") - while True: - try: - val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM)) - if not val: - return default - idx = int(val) - 1 - if 0 <= idx < len(choices): - return idx - except (ValueError, KeyboardInterrupt, EOFError): - print() - return default + """Single-select menu (arrow keys). Delegates to curses_radiolist.""" + from hermes_cli.curses_ui import curses_radiolist + return curses_radiolist(question, choices, selected=default, cancel_returns=default) # ─── Token Estimation ──────────────────────────────────────────────────────── diff --git a/hermes_constants.py b/hermes_constants.py index 7d149f404e..85955d5482 100644 --- a/hermes_constants.py +++ b/hermes_constants.py @@ -189,6 +189,33 @@ def is_wsl() -> bool: return _wsl_detected +# ─── Well-Known Paths ───────────────────────────────────────────────────────── + + +def get_config_path() -> Path: + """Return the path to ``config.yaml`` under HERMES_HOME. + + Replaces the ``get_hermes_home() / "config.yaml"`` pattern repeated + in 7+ files (skill_utils.py, hermes_logging.py, hermes_time.py, etc.). + """ + return get_hermes_home() / "config.yaml" + + +def get_skills_dir() -> Path: + """Return the path to the skills directory under HERMES_HOME.""" + return get_hermes_home() / "skills" + + +def get_logs_dir() -> Path: + """Return the path to the logs directory under HERMES_HOME.""" + return get_hermes_home() / "logs" + + +def get_env_path() -> Path: + """Return the path to the ``.env`` file under HERMES_HOME.""" + return get_hermes_home() / ".env" + + OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1" OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models" diff --git a/hermes_logging.py b/hermes_logging.py index 5d71590c3f..b765e94640 100644 --- a/hermes_logging.py +++ b/hermes_logging.py @@ -18,7 +18,7 @@ from logging.handlers import RotatingFileHandler from pathlib import Path from typing import Optional -from hermes_constants import get_hermes_home +from hermes_constants import get_config_path, get_hermes_home # Sentinel to track whether setup_logging() has already run. The function # is idempotent — calling it twice is safe but the second call is a no-op @@ -246,7 +246,7 @@ def _read_logging_config(): """ try: import yaml - config_path = get_hermes_home() / "config.yaml" + config_path = get_config_path() if config_path.exists(): with open(config_path, "r", encoding="utf-8") as f: cfg = yaml.safe_load(f) or {} diff --git a/hermes_time.py b/hermes_time.py index f7d085544b..9f172d28ff 100644 --- a/hermes_time.py +++ b/hermes_time.py @@ -16,7 +16,7 @@ crashes due to a bad timezone string. import logging import os from datetime import datetime -from hermes_constants import get_hermes_home +from hermes_constants import get_config_path from typing import Optional logger = logging.getLogger(__name__) @@ -48,8 +48,7 @@ def _resolve_timezone_name() -> str: # 2. 
config.yaml ``timezone`` key try: import yaml - hermes_home = get_hermes_home() - config_path = hermes_home / "config.yaml" + config_path = get_config_path() if config_path.exists(): with open(config_path) as f: cfg = yaml.safe_load(f) or {} diff --git a/run_agent.py b/run_agent.py index 60d82f541f..d9066fa6fb 100644 --- a/run_agent.py +++ b/run_agent.py @@ -739,6 +739,7 @@ class AIAgent: # Interrupt mechanism for breaking out of tool loops self._interrupt_requested = False self._interrupt_message = None # Optional message that triggered interrupt + self._execution_thread_id: int | None = None # Set at run_conversation() start self._client_lock = threading.RLock() # Subagent delegation state @@ -2832,8 +2833,10 @@ class AIAgent: """ self._interrupt_requested = True self._interrupt_message = message - # Signal all tools to abort any in-flight operations immediately - _set_interrupt(True) + # Signal all tools to abort any in-flight operations immediately. + # Scope the interrupt to this agent's execution thread so other + # agents running in the same process (gateway) are not affected. + _set_interrupt(True, self._execution_thread_id) # Propagate interrupt to any running child agents (subagent delegation) with self._active_children_lock: children_copy = list(self._active_children) @@ -2846,10 +2849,10 @@ class AIAgent: print("\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else "")) def clear_interrupt(self) -> None: - """Clear any pending interrupt request and the global tool interrupt signal.""" + """Clear any pending interrupt request and the per-thread tool interrupt signal.""" self._interrupt_requested = False self._interrupt_message = None - _set_interrupt(False) + _set_interrupt(False, self._execution_thread_id) def _touch_activity(self, desc: str) -> None: """Update the last-activity timestamp and description (thread-safe).""" @@ -3443,6 +3446,7 @@ class AIAgent: def _chat_messages_to_responses_input(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Convert internal chat-style messages to Responses input items.""" items: List[Dict[str, Any]] = [] + seen_item_ids: set = set() for msg in messages: if not isinstance(msg, dict): @@ -3463,7 +3467,12 @@ class AIAgent: if isinstance(codex_reasoning, list): for ri in codex_reasoning: if isinstance(ri, dict) and ri.get("encrypted_content"): + item_id = ri.get("id") + if item_id and item_id in seen_item_ids: + continue items.append(ri) + if item_id: + seen_item_ids.add(item_id) has_codex_reasoning = True if content_text.strip(): @@ -3543,6 +3552,7 @@ class AIAgent: raise ValueError("Codex Responses input must be a list of input items.") normalized: List[Dict[str, Any]] = [] + seen_ids: set = set() for idx, item in enumerate(raw_items): if not isinstance(item, dict): raise ValueError(f"Codex Responses input[{idx}] must be an object.") @@ -3595,8 +3605,12 @@ class AIAgent: if item_type == "reasoning": encrypted = item.get("encrypted_content") if isinstance(encrypted, str) and encrypted: - reasoning_item = {"type": "reasoning", "encrypted_content": encrypted} item_id = item.get("id") + if isinstance(item_id, str) and item_id: + if item_id in seen_ids: + continue + seen_ids.add(item_id) + reasoning_item = {"type": "reasoning", "encrypted_content": encrypted} if isinstance(item_id, str) and item_id: reasoning_item["id"] = item_id summary = item.get("summary") @@ -7800,6 +7814,11 @@ class AIAgent: compression_attempts = 0 _turn_exit_reason = "unknown" # 
Diagnostic: why the loop ended
+        # Record the execution thread so interrupt()/clear_interrupt() can
+        # scope the tool-level interrupt signal to THIS agent's thread only.
+        # Must be set before clear_interrupt() which uses it.
+        self._execution_thread_id = threading.current_thread().ident
+
         # Clear any stale interrupt state at start
         self.clear_interrupt()
@@ -8278,8 +8297,24 @@ class AIAgent:
                         _text_parts.append(getattr(_blk, "text", ""))
                     _trunc_content = "\n".join(_text_parts) if _text_parts else None
+                # A response is "thinking exhausted" only when the model
+                # actually produced reasoning blocks but no visible text after
+                # them. Models that do not use <think> tags (e.g. GLM-4.7 on
+                # NVIDIA Build, minimax) may return content=None or an empty
+                # string for unrelated reasons — treat those as normal
+                # truncations that deserve continuation retries, not as
+                # thinking-budget exhaustion.
+                _has_think_tags = bool(
+                    _trunc_content and re.search(
+                        r'<(?:think|thinking|reasoning|REASONING_SCRATCHPAD)[^>]*>',
+                        _trunc_content,
+                        re.IGNORECASE,
+                    )
+                )
                 _thinking_exhausted = (
-                    not _trunc_has_tool_calls and (
+                    not _trunc_has_tool_calls
+                    and _has_think_tags
+                    and (
                         (_trunc_content is not None and not self._has_content_after_think_block(_trunc_content))
                         or _trunc_content is None
                     )
@@ -9507,12 +9542,41 @@ class AIAgent:
                            invalid_json_args.append((tc.function.name, str(e)))
 
                     if invalid_json_args:
+                        # Check if the invalid JSON is due to truncation rather
+                        # than a model formatting mistake. Routers sometimes
+                        # rewrite finish_reason from "length" to "tool_calls",
+                        # hiding the truncation from the length handler above.
+                        # Detect truncation: args that don't end with } or ]
+                        # (after stripping whitespace) are cut off mid-stream.
+                        _truncated = any(
+                            not (tc.function.arguments or "").rstrip().endswith(("}", "]"))
+                            for tc in assistant_message.tool_calls
+                            if tc.function.name in {n for n, _ in invalid_json_args}
+                        )
+                        if _truncated:
+                            self._vprint(
+                                f"{self.log_prefix}⚠️ Truncated tool call arguments detected "
+                                f"(finish_reason={finish_reason!r}) — refusing to execute.",
+                                force=True,
+                            )
+                            self._invalid_json_retries = 0
+                            self._cleanup_task_resources(effective_task_id)
+                            self._persist_session(messages, conversation_history)
+                            return {
+                                "final_response": None,
+                                "messages": messages,
+                                "api_calls": api_call_count,
+                                "completed": False,
+                                "partial": True,
+                                "error": "Response truncated due to output length limit",
+                            }
+
                         # Track retries for invalid JSON arguments
                         self._invalid_json_retries += 1
-
+
                         tool_name, error_msg = invalid_json_args[0]
                         self._vprint(f"{self.log_prefix}⚠️ Invalid JSON in tool call arguments for '{tool_name}': {error_msg}")
-
+
                         if self._invalid_json_retries < 3:
                             self._vprint(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_json_retries}/3)...")
                             # Don't add anything to messages, just retry the API call
diff --git a/scripts/whatsapp-bridge/package-lock.json b/scripts/whatsapp-bridge/package-lock.json
index 01af1c15a0..23ea30a092 100644
--- a/scripts/whatsapp-bridge/package-lock.json
+++ b/scripts/whatsapp-bridge/package-lock.json
@@ -8,7 +8,7 @@
       "name": "hermes-whatsapp-bridge",
       "version": "1.0.0",
       "dependencies": {
-        "@whiskeysockets/baileys": "7.0.0-rc.9",
+        "@whiskeysockets/baileys": "WhiskeySockets/Baileys#fix/abprops-abt-fetch",
         "express": "^4.21.0",
         "pino": "^9.0.0",
         "qrcode-terminal": "^0.12.0"
@@ -730,21 +730,22 @@
       }
     },
     "node_modules/@whiskeysockets/baileys": {
+      "name": "baileys",
       "version": "7.0.0-rc.9",
-      "resolved":
"https://registry.npmjs.org/@whiskeysockets/baileys/-/baileys-7.0.0-rc.9.tgz", - "integrity": "sha512-YFm5gKXfDP9byCXCW3OPHKXLzrAKzolzgVUlRosHHgwbnf2YOO3XknkMm6J7+F0ns8OA0uuSBhgkRHTDtqkacw==", + "resolved": "git+ssh://git@github.com/WhiskeySockets/Baileys.git#01047debd81beb20da7b7779b08edcb06aa03770", "hasInstallScript": true, "license": "MIT", "dependencies": { "@cacheable/node-cache": "^1.4.0", "@hapi/boom": "^9.1.3", "async-mutex": "^0.5.0", - "libsignal": "git+https://github.com/whiskeysockets/libsignal-node.git", + "libsignal": "git+https://github.com/whiskeysockets/libsignal-node", "lru-cache": "^11.1.0", "music-metadata": "^11.7.0", "p-queue": "^9.0.0", "pino": "^9.6", "protobufjs": "^7.2.4", + "whatsapp-rust-bridge": "0.5.2", "ws": "^8.13.0" }, "engines": { @@ -2125,6 +2126,12 @@ "node": ">= 0.8" } }, + "node_modules/whatsapp-rust-bridge": { + "version": "0.5.2", + "resolved": "https://registry.npmjs.org/whatsapp-rust-bridge/-/whatsapp-rust-bridge-0.5.2.tgz", + "integrity": "sha512-6KBRNvxg6WMIwZ/euA8qVzj16qxMBzLllfmaJIP1JGAAfSvwn6nr8JDOMXeqpXPEOl71UfOG+79JwKEoT2b1Fw==", + "license": "MIT" + }, "node_modules/win-guid": { "version": "0.2.1", "resolved": "https://registry.npmjs.org/win-guid/-/win-guid-0.2.1.tgz", diff --git a/scripts/whatsapp-bridge/package.json b/scripts/whatsapp-bridge/package.json index 7db81f699e..2d32560f44 100644 --- a/scripts/whatsapp-bridge/package.json +++ b/scripts/whatsapp-bridge/package.json @@ -8,7 +8,7 @@ "start": "node bridge.js" }, "dependencies": { - "@whiskeysockets/baileys": "7.0.0-rc.9", + "@whiskeysockets/baileys": "WhiskeySockets/Baileys#fix/abprops-abt-fetch", "express": "^4.21.0", "qrcode-terminal": "^0.12.0", "pino": "^9.0.0" diff --git a/tests/agent/test_local_stream_timeout.py b/tests/agent/test_local_stream_timeout.py index 929f2e3c84..8184dd2d49 100644 --- a/tests/agent/test_local_stream_timeout.py +++ b/tests/agent/test_local_stream_timeout.py @@ -22,6 +22,9 @@ class TestLocalStreamReadTimeout: "http://0.0.0.0:5000", "http://192.168.1.100:8000", "http://10.0.0.5:1234", + "http://host.docker.internal:11434", + "http://host.containers.internal:11434", + "http://host.lima.internal:11434", ]) def test_local_endpoint_bumps_read_timeout(self, base_url): """Local endpoint + default timeout -> bumps to base_timeout.""" @@ -68,3 +71,38 @@ class TestLocalStreamReadTimeout: if _stream_read_timeout == 120.0 and base_url and is_local_endpoint(base_url): _stream_read_timeout = _base_timeout assert _stream_read_timeout == 120.0 + + +class TestIsLocalEndpoint: + """Direct unit tests for is_local_endpoint.""" + + @pytest.mark.parametrize("url", [ + "http://localhost:11434", + "http://127.0.0.1:8080", + "http://0.0.0.0:5000", + "http://[::1]:11434", + "http://192.168.1.100:8000", + "http://10.0.0.5:1234", + "http://172.17.0.1:11434", + ]) + def test_classic_local_addresses(self, url): + assert is_local_endpoint(url) is True + + @pytest.mark.parametrize("url", [ + "http://host.docker.internal:11434", + "http://host.docker.internal:8080/v1", + "http://gateway.docker.internal:11434", + "http://host.containers.internal:11434", + "http://host.lima.internal:11434", + ]) + def test_container_dns_names(self, url): + assert is_local_endpoint(url) is True + + @pytest.mark.parametrize("url", [ + "https://api.openai.com", + "https://openrouter.ai/api", + "https://api.anthropic.com", + "https://evil.docker.internal.example.com", + ]) + def test_remote_endpoints(self, url): + assert is_local_endpoint(url) is False diff --git a/tests/e2e/conftest.py 
b/tests/e2e/conftest.py index ef17af10bc..d9ca627c4f 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -211,7 +211,8 @@ def make_adapter(platform: Platform, runner=None): config = PlatformConfig(enabled=True, token="e2e-test-token") if platform == Platform.DISCORD: - with patch.object(DiscordAdapter, "_load_participated_threads", return_value=set()): + from gateway.platforms.helpers import ThreadParticipationTracker + with patch.object(ThreadParticipationTracker, "_load", return_value=set()): adapter = DiscordAdapter(config) platform_key = Platform.DISCORD elif platform == Platform.SLACK: diff --git a/tests/gateway/test_api_server.py b/tests/gateway/test_api_server.py index afc3ce9ce9..2be01fc2d1 100644 --- a/tests/gateway/test_api_server.py +++ b/tests/gateway/test_api_server.py @@ -409,11 +409,50 @@ class TestChatCompletionsEndpoint: ) assert resp.status == 200 assert "text/event-stream" in resp.headers.get("Content-Type", "") + assert resp.headers.get("X-Accel-Buffering") == "no" body = await resp.text() assert "data: " in body assert "[DONE]" in body assert "Hello!" in body + @pytest.mark.asyncio + async def test_stream_sends_keepalive_during_quiet_tool_gap(self, adapter): + """Idle SSE streams should send keepalive comments while tools run silently.""" + import asyncio + import gateway.platforms.api_server as api_server_mod + + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + async def _mock_run_agent(**kwargs): + cb = kwargs.get("stream_delta_callback") + if cb: + cb("Working") + await asyncio.sleep(0.65) + cb("...done") + return ( + {"final_response": "Working...done", "messages": [], "api_calls": 1}, + {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + ) + + with ( + patch.object(api_server_mod, "CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS", 0.01), + patch.object(adapter, "_run_agent", side_effect=_mock_run_agent), + ): + resp = await cli.post( + "/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "do the thing"}], + "stream": True, + }, + ) + assert resp.status == 200 + body = await resp.text() + assert ": keepalive" in body + assert "Working" in body + assert "...done" in body + assert "[DONE]" in body + @pytest.mark.asyncio async def test_stream_survives_tool_call_none_sentinel(self, adapter): """stream_delta_callback(None) mid-stream (tool calls) must NOT kill the SSE stream. 
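For context on what the keepalive test above is asserting: an OpenAI-compatible SSE stream goes quiet while a tool call runs, so the endpoint periodically emits SSE comment frames (lines starting with ':') that clients ignore but that keep proxies and load balancers from buffering or dropping the idle connection; the X-Accel-Buffering: no header serves the same purpose for nginx-style proxies. Below is a minimal sketch of that pattern, assuming deltas arrive on an asyncio.Queue; the queue, constant, and function names are illustrative, not the adapter's actual code.

    import asyncio
    from typing import AsyncIterator, Optional

    KEEPALIVE_SECONDS = 15.0  # illustrative interval; the test above patches the real constant

    async def sse_frames(deltas: "asyncio.Queue[Optional[str]]") -> AsyncIterator[str]:
        """Yield SSE frames, inserting ': keepalive' comments during quiet gaps."""
        while True:
            try:
                delta = await asyncio.wait_for(deltas.get(), timeout=KEEPALIVE_SECONDS)
            except asyncio.TimeoutError:
                # No delta within the window: emit a comment frame so the
                # connection does not look dead to intermediaries.
                yield ": keepalive\n\n"
                continue
            if delta is None:  # end-of-stream sentinel
                yield "data: [DONE]\n\n"
                return
            yield f"data: {delta}\n\n"
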
diff --git a/tests/gateway/test_dingtalk.py b/tests/gateway/test_dingtalk.py index 5c73253fbf..5271136502 100644 --- a/tests/gateway/test_dingtalk.py +++ b/tests/gateway/test_dingtalk.py @@ -119,28 +119,29 @@ class TestDeduplication: def test_first_message_not_duplicate(self): from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) - assert adapter._is_duplicate("msg-1") is False + assert adapter._dedup.is_duplicate("msg-1") is False def test_second_same_message_is_duplicate(self): from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) - adapter._is_duplicate("msg-1") - assert adapter._is_duplicate("msg-1") is True + adapter._dedup.is_duplicate("msg-1") + assert adapter._dedup.is_duplicate("msg-1") is True def test_different_messages_not_duplicate(self): from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) - adapter._is_duplicate("msg-1") - assert adapter._is_duplicate("msg-2") is False + adapter._dedup.is_duplicate("msg-1") + assert adapter._dedup.is_duplicate("msg-2") is False def test_cache_cleanup_on_overflow(self): - from gateway.platforms.dingtalk import DingTalkAdapter, DEDUP_MAX_SIZE + from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + max_size = adapter._dedup._max_size # Fill beyond max - for i in range(DEDUP_MAX_SIZE + 10): - adapter._is_duplicate(f"msg-{i}") + for i in range(max_size + 10): + adapter._dedup.is_duplicate(f"msg-{i}") # Cache should have been pruned - assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 10 + assert len(adapter._dedup._seen) <= max_size + 10 # --------------------------------------------------------------------------- @@ -253,13 +254,13 @@ class TestConnect: from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) adapter._session_webhooks["a"] = "http://x" - adapter._seen_messages["b"] = 1.0 + adapter._dedup._seen["b"] = 1.0 adapter._http_client = AsyncMock() adapter._stream_task = None await adapter.disconnect() assert len(adapter._session_webhooks) == 0 - assert len(adapter._seen_messages) == 0 + assert len(adapter._dedup._seen) == 0 assert adapter._http_client is None diff --git a/tests/gateway/test_discord_connect.py b/tests/gateway/test_discord_connect.py index dd594cf7ed..9f094dd0dd 100644 --- a/tests/gateway/test_discord_connect.py +++ b/tests/gateway/test_discord_connect.py @@ -137,4 +137,4 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch): assert ok is False assert released == [("discord-bot-token", "test-token")] - assert adapter._token_lock_identity is None + assert adapter._platform_lock_identity is None diff --git a/tests/gateway/test_discord_free_response.py b/tests/gateway/test_discord_free_response.py index bc63c14f5a..29f65efc67 100644 --- a/tests/gateway/test_discord_free_response.py +++ b/tests/gateway/test_discord_free_response.py @@ -302,7 +302,7 @@ async def test_discord_bot_thread_skips_mention_requirement(adapter, monkeypatch monkeypatch.setenv("DISCORD_AUTO_THREAD", "false") # Simulate bot having previously participated in thread 456 - adapter._bot_participated_threads.add("456") + adapter._threads.mark("456") thread = FakeThread(channel_id=456, name="existing thread") message = make_message(channel=thread, content="follow-up without mention") @@ -344,7 +344,7 @@ async def 
test_discord_auto_thread_tracks_participation(adapter, monkeypatch): await adapter._handle_message(message) - assert "555" in adapter._bot_participated_threads + assert "555" in adapter._threads @pytest.mark.asyncio @@ -358,4 +358,4 @@ async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeyp await adapter._handle_message(message) - assert "777" in adapter._bot_participated_threads + assert "777" in adapter._threads diff --git a/tests/gateway/test_discord_thread_persistence.py b/tests/gateway/test_discord_thread_persistence.py index 0288b620d2..083f61ac7c 100644 --- a/tests/gateway/test_discord_thread_persistence.py +++ b/tests/gateway/test_discord_thread_persistence.py @@ -1,6 +1,6 @@ """Tests for Discord thread participation persistence. -Verifies that _bot_participated_threads survives adapter restarts by +Verifies that _threads (ThreadParticipationTracker) survives adapter restarts by being persisted to ~/.hermes/discord_threads.json. """ @@ -25,13 +25,13 @@ class TestDiscordThreadPersistence: def test_starts_empty_when_no_state_file(self, tmp_path): adapter = self._make_adapter(tmp_path) - assert adapter._bot_participated_threads == set() + assert "$nonexistent" not in adapter._threads def test_track_thread_persists_to_disk(self, tmp_path): adapter = self._make_adapter(tmp_path) with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): - adapter._track_thread("111") - adapter._track_thread("222") + adapter._threads.mark("111") + adapter._threads.mark("222") state_file = tmp_path / "discord_threads.json" assert state_file.exists() @@ -42,42 +42,43 @@ class TestDiscordThreadPersistence: """Threads tracked by one adapter instance are visible to the next.""" adapter1 = self._make_adapter(tmp_path) with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): - adapter1._track_thread("aaa") - adapter1._track_thread("bbb") + adapter1._threads.mark("aaa") + adapter1._threads.mark("bbb") adapter2 = self._make_adapter(tmp_path) - assert "aaa" in adapter2._bot_participated_threads - assert "bbb" in adapter2._bot_participated_threads + assert "aaa" in adapter2._threads + assert "bbb" in adapter2._threads def test_duplicate_track_does_not_double_save(self, tmp_path): adapter = self._make_adapter(tmp_path) with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): - adapter._track_thread("111") - adapter._track_thread("111") # no-op + adapter._threads.mark("111") + adapter._threads.mark("111") # no-op saved = json.loads((tmp_path / "discord_threads.json").read_text()) assert saved.count("111") == 1 def test_caps_at_max_tracked_threads(self, tmp_path): adapter = self._make_adapter(tmp_path) - adapter._MAX_TRACKED_THREADS = 5 + adapter._threads._max_tracked = 5 with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): for i in range(10): - adapter._track_thread(str(i)) + adapter._threads.mark(str(i)) - assert len(adapter._bot_participated_threads) == 5 + saved = json.loads((tmp_path / "discord_threads.json").read_text()) + assert len(saved) == 5 def test_corrupted_state_file_falls_back_to_empty(self, tmp_path): state_file = tmp_path / "discord_threads.json" state_file.write_text("not valid json{{{") adapter = self._make_adapter(tmp_path) - assert adapter._bot_participated_threads == set() + assert "$nonexistent" not in adapter._threads def test_missing_hermes_home_does_not_crash(self, tmp_path): """Load/save tolerate missing directories.""" fake_home = tmp_path / "nonexistent" / "deep" with patch.dict(os.environ, {"HERMES_HOME": str(fake_home)}): - from 
gateway.platforms.discord import DiscordAdapter - # _load should return empty set, not crash - threads = DiscordAdapter._load_participated_threads() - assert threads == set() + from gateway.platforms.helpers import ThreadParticipationTracker + # ThreadParticipationTracker should return empty set, not crash + tracker = ThreadParticipationTracker("discord") + assert "$test" not in tracker diff --git a/tests/gateway/test_internal_event_bypass_pairing.py b/tests/gateway/test_internal_event_bypass_pairing.py index 05b093b04a..46a96e5aa2 100644 --- a/tests/gateway/test_internal_event_bypass_pairing.py +++ b/tests/gateway/test_internal_event_bypass_pairing.py @@ -195,6 +195,105 @@ async def test_internal_event_does_not_trigger_pairing(monkeypatch, tmp_path): ) +@pytest.mark.asyncio +async def test_notify_on_complete_preserves_user_identity(monkeypatch, tmp_path): + """Synthetic completion event should carry user_id and user_name from the watcher.""" + import tools.process_registry as pr_module + + sessions = [ + SimpleNamespace( + output_buffer="done\n", exited=True, exit_code=0, command="echo test" + ), + ] + monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions)) + + async def _instant_sleep(*_a, **_kw): + pass + monkeypatch.setattr(asyncio, "sleep", _instant_sleep) + + runner = _build_runner(monkeypatch, tmp_path) + adapter = runner.adapters[Platform.DISCORD] + + watcher = _watcher_dict_with_notify() + watcher["user_id"] = "user-42" + watcher["user_name"] = "alice" + + await runner._run_process_watcher(watcher) + + assert adapter.handle_message.await_count == 1 + event = adapter.handle_message.await_args.args[0] + assert event.source.user_id == "user-42" + assert event.source.user_name == "alice" + + +@pytest.mark.asyncio +async def test_none_user_id_skips_pairing(monkeypatch, tmp_path): + """A non-internal event with user_id=None should be silently dropped.""" + import gateway.run as gateway_run + + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + (tmp_path / "config.yaml").write_text("", encoding="utf-8") + + runner = GatewayRunner(GatewayConfig()) + adapter = SimpleNamespace(send=AsyncMock()) + runner.adapters[Platform.TELEGRAM] = adapter + + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123", + chat_type="dm", + user_id=None, + ) + event = MessageEvent( + text="service message", + source=source, + internal=False, + ) + + result = await runner._handle_message(event) + + # Should return None (dropped) and NOT send any pairing message + assert result is None + assert adapter.send.await_count == 0 + + +@pytest.mark.asyncio +async def test_none_user_id_does_not_generate_pairing_code(monkeypatch, tmp_path): + """A message with user_id=None must never call generate_code.""" + import gateway.run as gateway_run + + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + (tmp_path / "config.yaml").write_text("", encoding="utf-8") + + runner = GatewayRunner(GatewayConfig()) + adapter = SimpleNamespace(send=AsyncMock()) + runner.adapters[Platform.DISCORD] = adapter + + generate_called = False + original_generate = runner.pairing_store.generate_code + + def tracking_generate(*args, **kwargs): + nonlocal generate_called + generate_called = True + return original_generate(*args, **kwargs) + + runner.pairing_store.generate_code = tracking_generate + + source = SessionSource( + platform=Platform.DISCORD, + chat_id="456", + chat_type="dm", + user_id=None, + ) + event = MessageEvent(text="anonymous", source=source, internal=False) + + await 
runner._handle_message(event) + + assert not generate_called, ( + "Pairing code should NOT be generated for messages with user_id=None" + ) + + @pytest.mark.asyncio async def test_non_internal_event_without_user_triggers_pairing(monkeypatch, tmp_path): """Verify the normal (non-internal) path still triggers pairing for unknown users.""" diff --git a/tests/gateway/test_matrix_mention.py b/tests/gateway/test_matrix_mention.py index d36c2b7657..873b873c23 100644 --- a/tests/gateway/test_matrix_mention.py +++ b/tests/gateway/test_matrix_mention.py @@ -247,7 +247,7 @@ async def test_require_mention_bot_participated_thread(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - adapter._bot_participated_threads.add("$thread1") + adapter._threads.mark("$thread1") event = _make_event("hello without mention", thread_id="$thread1") @@ -298,7 +298,7 @@ async def test_auto_thread_preserves_existing_thread(monkeypatch): monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False) adapter = _make_adapter() - adapter._bot_participated_threads.add("$thread_root") + adapter._threads.mark("$thread_root") event = _make_event("reply in thread", thread_id="$thread_root") await adapter._on_room_message(event) @@ -340,17 +340,17 @@ async def test_auto_thread_disabled(monkeypatch): @pytest.mark.asyncio async def test_auto_thread_tracks_participation(monkeypatch): - """Auto-created threads are tracked in _bot_participated_threads.""" + """Auto-created threads are tracked in _threads.""" monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "false") monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False) adapter = _make_adapter() event = _make_event("hello", event_id="$msg1") - with patch.object(adapter, "_save_participated_threads"): + with patch.object(adapter._threads, "_save"): await adapter._on_room_message(event) - assert "$msg1" in adapter._bot_participated_threads + assert "$msg1" in adapter._threads # --------------------------------------------------------------------------- @@ -361,56 +361,54 @@ async def test_auto_thread_tracks_participation(monkeypatch): class TestThreadPersistence: def test_empty_state_file(self, tmp_path, monkeypatch): """No state file → empty set.""" - from gateway.platforms.matrix import MatrixAdapter + from gateway.platforms.helpers import ThreadParticipationTracker monkeypatch.setattr( - MatrixAdapter, "_thread_state_path", - staticmethod(lambda: tmp_path / "matrix_threads.json"), + ThreadParticipationTracker, "_state_path", + lambda self: tmp_path / "matrix_threads.json", ) adapter = _make_adapter() - loaded = adapter._load_participated_threads() - assert loaded == set() + assert "$nonexistent" not in adapter._threads def test_track_thread_persists(self, tmp_path, monkeypatch): - """_track_thread writes to disk.""" - from gateway.platforms.matrix import MatrixAdapter + """mark() writes to disk.""" + from gateway.platforms.helpers import ThreadParticipationTracker state_path = tmp_path / "matrix_threads.json" monkeypatch.setattr( - MatrixAdapter, "_thread_state_path", - staticmethod(lambda: state_path), + ThreadParticipationTracker, "_state_path", + lambda self: state_path, ) adapter = _make_adapter() - adapter._track_thread("$thread_abc") + adapter._threads.mark("$thread_abc") data = json.loads(state_path.read_text()) assert "$thread_abc" in data def test_threads_survive_reload(self, tmp_path, monkeypatch): """Persisted threads are loaded by a new adapter instance.""" - from gateway.platforms.matrix import MatrixAdapter + from gateway.platforms.helpers 
import ThreadParticipationTracker state_path = tmp_path / "matrix_threads.json" state_path.write_text(json.dumps(["$t1", "$t2"])) monkeypatch.setattr( - MatrixAdapter, "_thread_state_path", - staticmethod(lambda: state_path), + ThreadParticipationTracker, "_state_path", + lambda self: state_path, ) adapter = _make_adapter() - assert "$t1" in adapter._bot_participated_threads - assert "$t2" in adapter._bot_participated_threads + assert "$t1" in adapter._threads + assert "$t2" in adapter._threads def test_cap_max_tracked_threads(self, tmp_path, monkeypatch): - """Thread set is trimmed to _MAX_TRACKED_THREADS.""" - from gateway.platforms.matrix import MatrixAdapter + """Thread set is trimmed to max_tracked.""" + from gateway.platforms.helpers import ThreadParticipationTracker state_path = tmp_path / "matrix_threads.json" monkeypatch.setattr( - MatrixAdapter, "_thread_state_path", - staticmethod(lambda: state_path), + ThreadParticipationTracker, "_state_path", + lambda self: state_path, ) adapter = _make_adapter() - adapter._MAX_TRACKED_THREADS = 5 + adapter._threads._max_tracked = 5 for i in range(10): - adapter._bot_participated_threads.add(f"$t{i}") - adapter._save_participated_threads() + adapter._threads.mark(f"$t{i}") data = json.loads(state_path.read_text()) assert len(data) == 5 @@ -447,7 +445,7 @@ async def test_dm_mention_thread_creates_thread(monkeypatch): _set_dm(adapter) event = _make_event("@hermes:example.org help me", event_id="$dm1") - with patch.object(adapter, "_save_participated_threads"): + with patch.object(adapter._threads, "_save"): await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() @@ -480,7 +478,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch): adapter = _make_adapter() _set_dm(adapter) - adapter._bot_participated_threads.add("$existing_thread") + adapter._threads.mark("$existing_thread") event = _make_event("@hermes:example.org help me", thread_id="$existing_thread") await adapter._on_room_message(event) @@ -491,7 +489,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch): @pytest.mark.asyncio async def test_dm_mention_thread_tracks_participation(monkeypatch): - """DM mention-thread tracks the thread in _bot_participated_threads.""" + """DM mention-thread tracks the thread in _threads.""" monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true") monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") @@ -499,10 +497,10 @@ async def test_dm_mention_thread_tracks_participation(monkeypatch): _set_dm(adapter) event = _make_event("@hermes:example.org help", event_id="$dm1") - with patch.object(adapter, "_save_participated_threads"): + with patch.object(adapter._threads, "_save"): await adapter._on_room_message(event) - assert "$dm1" in adapter._bot_participated_threads + assert "$dm1" in adapter._threads # --------------------------------------------------------------------------- diff --git a/tests/gateway/test_mattermost.py b/tests/gateway/test_mattermost.py index 7d47c0a3e1..56e46f6364 100644 --- a/tests/gateway/test_mattermost.py +++ b/tests/gateway/test_mattermost.py @@ -614,25 +614,27 @@ class TestMattermostDedup: assert self.adapter.handle_message.call_count == 2 def test_prune_seen_clears_expired(self): - """_prune_seen should remove entries older than _SEEN_TTL.""" + """Dedup cache should remove entries older than TTL on overflow.""" now = time.time() + dedup = self.adapter._dedup # Fill with enough expired entries to trigger pruning - for i in range(self.adapter._SEEN_MAX + 10): - 
self.adapter._seen_posts[f"old_{i}"] = now - 600 # 10 min ago + for i in range(dedup._max_size + 10): + dedup._seen[f"old_{i}"] = now - 600 # 10 min ago (older than default TTL) # Add a fresh one - self.adapter._seen_posts["fresh"] = now + dedup._seen["fresh"] = now - self.adapter._prune_seen() + # Trigger pruning by calling is_duplicate with a new entry (over max_size) + dedup.is_duplicate("trigger_prune") # Old entries should be pruned, fresh one kept - assert "fresh" in self.adapter._seen_posts - assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX + assert "fresh" in dedup._seen + assert len(dedup._seen) < dedup._max_size + 10 def test_seen_cache_tracks_post_ids(self): - """Posts are tracked in _seen_posts dict.""" - self.adapter._seen_posts["test_post"] = time.time() - assert "test_post" in self.adapter._seen_posts + """Posts are tracked in the dedup cache.""" + self.adapter._dedup._seen["test_post"] = time.time() + assert "test_post" in self.adapter._dedup._seen # --------------------------------------------------------------------------- diff --git a/tests/gateway/test_queue_consumption.py b/tests/gateway/test_queue_consumption.py index 2a4dd4ff02..50effc139d 100644 --- a/tests/gateway/test_queue_consumption.py +++ b/tests/gateway/test_queue_consumption.py @@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from gateway.run import _dequeue_pending_event from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -79,6 +80,26 @@ class TestQueueMessageStorage: # Should be consumed (cleared) assert adapter.get_pending_message(session_key) is None + def test_dequeue_pending_event_preserves_voice_media_metadata(self): + adapter = _StubAdapter() + session_key = "telegram:user:voice" + event = MessageEvent( + text="", + message_type=MessageType.VOICE, + source=MagicMock(chat_id="123", platform=Platform.TELEGRAM), + message_id="voice-q1", + media_urls=["/tmp/voice.ogg"], + media_types=["audio/ogg"], + ) + adapter._pending_messages[session_key] = event + + retrieved = _dequeue_pending_event(adapter, session_key) + + assert retrieved is event + assert retrieved.media_urls == ["/tmp/voice.ogg"] + assert retrieved.media_types == ["audio/ogg"] + assert adapter.get_pending_message(session_key) is None + def test_queue_does_not_set_interrupt_event(self): """The whole point of /queue — no interrupt signal.""" adapter = _StubAdapter() diff --git a/tests/gateway/test_session_env.py b/tests/gateway/test_session_env.py index a7f1345b77..b75e267f11 100644 --- a/tests/gateway/test_session_env.py +++ b/tests/gateway/test_session_env.py @@ -18,6 +18,8 @@ def test_set_session_env_sets_contextvars(monkeypatch): chat_id="-1001", chat_name="Group", chat_type="group", + user_id="123456", + user_name="alice", thread_id="17585", ) context = SessionContext(source=source, connected_platforms=[], home_channels={}) @@ -25,6 +27,8 @@ def test_set_session_env_sets_contextvars(monkeypatch): monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False) monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False) monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False) + monkeypatch.delenv("HERMES_SESSION_USER_ID", raising=False) + monkeypatch.delenv("HERMES_SESSION_USER_NAME", raising=False) monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False) tokens = runner._set_session_env(context) @@ -33,6 +37,8 @@ def test_set_session_env_sets_contextvars(monkeypatch): assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram" assert 
get_session_env("HERMES_SESSION_CHAT_ID") == "-1001" assert get_session_env("HERMES_SESSION_CHAT_NAME") == "Group" + assert get_session_env("HERMES_SESSION_USER_ID") == "123456" + assert get_session_env("HERMES_SESSION_USER_NAME") == "alice" assert get_session_env("HERMES_SESSION_THREAD_ID") == "17585" # os.environ should NOT be touched @@ -50,6 +56,8 @@ def test_clear_session_env_restores_previous_state(monkeypatch): monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False) monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False) monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False) + monkeypatch.delenv("HERMES_SESSION_USER_ID", raising=False) + monkeypatch.delenv("HERMES_SESSION_USER_NAME", raising=False) monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False) source = SessionSource( @@ -57,12 +65,15 @@ def test_clear_session_env_restores_previous_state(monkeypatch): chat_id="-1001", chat_name="Group", chat_type="group", + user_id="123456", + user_name="alice", thread_id="17585", ) context = SessionContext(source=source, connected_platforms=[], home_channels={}) tokens = runner._set_session_env(context) assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram" + assert get_session_env("HERMES_SESSION_USER_ID") == "123456" runner._clear_session_env(tokens) @@ -70,6 +81,8 @@ def test_clear_session_env_restores_previous_state(monkeypatch): assert get_session_env("HERMES_SESSION_PLATFORM") == "" assert get_session_env("HERMES_SESSION_CHAT_ID") == "" assert get_session_env("HERMES_SESSION_CHAT_NAME") == "" + assert get_session_env("HERMES_SESSION_USER_ID") == "" + assert get_session_env("HERMES_SESSION_USER_NAME") == "" assert get_session_env("HERMES_SESSION_THREAD_ID") == "" diff --git a/tests/gateway/test_signal.py b/tests/gateway/test_signal.py index ae985300d1..265f9be783 100644 --- a/tests/gateway/test_signal.py +++ b/tests/gateway/test_signal.py @@ -114,16 +114,16 @@ class TestSignalAdapterInit: class TestSignalHelpers: def test_redact_phone_long(self): - from gateway.platforms.signal import _redact_phone - assert _redact_phone("+15551234567") == "+155****4567" + from gateway.platforms.helpers import redact_phone + assert redact_phone("+155****4567") == "+155****4567" def test_redact_phone_short(self): - from gateway.platforms.signal import _redact_phone - assert _redact_phone("+12345") == "+1****45" + from gateway.platforms.helpers import redact_phone + assert redact_phone("+12345") == "+1****45" def test_redact_phone_empty(self): - from gateway.platforms.signal import _redact_phone - assert _redact_phone("") == "" + from gateway.platforms.helpers import redact_phone + assert redact_phone("") == "" def test_parse_comma_list(self): from gateway.platforms.signal import _parse_comma_list diff --git a/tests/gateway/test_sms.py b/tests/gateway/test_sms.py index 54c1edf237..d8a1589bdf 100644 --- a/tests/gateway/test_sms.py +++ b/tests/gateway/test_sms.py @@ -1,11 +1,14 @@ """Tests for SMS (Twilio) platform integration. Covers config loading, format/truncate, echo prevention, -requirements check, and toolset verification. +requirements check, toolset verification, and Twilio signature validation. 
""" +import base64 +import hashlib +import hmac import os -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -213,3 +216,335 @@ class TestSmsToolset: from tools.cronjob_tools import CRONJOB_SCHEMA deliver_desc = CRONJOB_SCHEMA["parameters"]["properties"]["deliver"]["description"] assert "sms" in deliver_desc.lower() + + +# ── Webhook host configuration ───────────────────────────────────── + +class TestWebhookHostConfig: + """Verify SMS_WEBHOOK_HOST env var and default.""" + + def test_default_host_is_all_interfaces(self): + from gateway.platforms.sms import DEFAULT_WEBHOOK_HOST + assert DEFAULT_WEBHOOK_HOST == "0.0.0.0" + + def test_host_from_env(self): + from gateway.platforms.sms import SmsAdapter + + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": "tok", + "TWILIO_PHONE_NUMBER": "+15550001111", + "SMS_WEBHOOK_HOST": "127.0.0.1", + } + with patch.dict(os.environ, env): + pc = PlatformConfig(enabled=True, api_key="tok") + adapter = SmsAdapter(pc) + assert adapter._webhook_host == "127.0.0.1" + + def test_webhook_url_from_env(self): + from gateway.platforms.sms import SmsAdapter + + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": "tok", + "TWILIO_PHONE_NUMBER": "+15550001111", + "SMS_WEBHOOK_URL": "https://example.com/webhooks/twilio", + } + with patch.dict(os.environ, env): + pc = PlatformConfig(enabled=True, api_key="tok") + adapter = SmsAdapter(pc) + assert adapter._webhook_url == "https://example.com/webhooks/twilio" + + def test_webhook_url_stripped(self): + from gateway.platforms.sms import SmsAdapter + + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": "tok", + "TWILIO_PHONE_NUMBER": "+15550001111", + "SMS_WEBHOOK_URL": " https://example.com/webhooks/twilio ", + } + with patch.dict(os.environ, env): + pc = PlatformConfig(enabled=True, api_key="tok") + adapter = SmsAdapter(pc) + assert adapter._webhook_url == "https://example.com/webhooks/twilio" + + +# ── Startup guard (fail-closed) ──────────────────────────────────── + +class TestStartupGuard: + """Adapter must refuse to start without SMS_WEBHOOK_URL.""" + + def _make_adapter(self, extra_env=None): + from gateway.platforms.sms import SmsAdapter + + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": "tok", + "TWILIO_PHONE_NUMBER": "+15550001111", + } + if extra_env: + env.update(extra_env) + with patch.dict(os.environ, env, clear=False): + pc = PlatformConfig(enabled=True, api_key="tok") + adapter = SmsAdapter(pc) + return adapter + + @pytest.mark.asyncio + async def test_refuses_start_without_webhook_url(self): + adapter = self._make_adapter() + result = await adapter.connect() + assert result is False + + @pytest.mark.asyncio + async def test_insecure_flag_allows_start_without_url(self): + mock_session = AsyncMock() + with patch.dict(os.environ, {"SMS_INSECURE_NO_SIGNATURE": "true"}), \ + patch("aiohttp.web.AppRunner") as mock_runner_cls, \ + patch("aiohttp.web.TCPSite") as mock_site_cls, \ + patch("aiohttp.ClientSession", return_value=mock_session): + mock_runner_cls.return_value.setup = AsyncMock() + mock_runner_cls.return_value.cleanup = AsyncMock() + mock_site_cls.return_value.start = AsyncMock() + adapter = self._make_adapter() + result = await adapter.connect() + assert result is True + await adapter.disconnect() + + @pytest.mark.asyncio + async def test_webhook_url_allows_start(self): + mock_session = AsyncMock() + with patch("aiohttp.web.AppRunner") as mock_runner_cls, \ + patch("aiohttp.web.TCPSite") as 
mock_site_cls, \ + patch("aiohttp.ClientSession", return_value=mock_session): + mock_runner_cls.return_value.setup = AsyncMock() + mock_runner_cls.return_value.cleanup = AsyncMock() + mock_site_cls.return_value.start = AsyncMock() + adapter = self._make_adapter( + extra_env={"SMS_WEBHOOK_URL": "https://example.com/webhooks/twilio"} + ) + result = await adapter.connect() + assert result is True + await adapter.disconnect() + + +# ── Twilio signature validation ──────────────────────────────────── + +def _compute_twilio_signature(auth_token, url, params): + """Reference implementation of Twilio's signature algorithm.""" + data_to_sign = url + for key in sorted(params.keys()): + data_to_sign += key + params[key] + mac = hmac.new( + auth_token.encode("utf-8"), + data_to_sign.encode("utf-8"), + hashlib.sha1, + ) + return base64.b64encode(mac.digest()).decode("utf-8") + + +class TestTwilioSignatureValidation: + """Unit tests for SmsAdapter._validate_twilio_signature.""" + + def _make_adapter(self, auth_token="test_token_secret"): + from gateway.platforms.sms import SmsAdapter + + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": auth_token, + "TWILIO_PHONE_NUMBER": "+15550001111", + } + with patch.dict(os.environ, env): + pc = PlatformConfig(enabled=True, api_key=auth_token) + adapter = SmsAdapter(pc) + return adapter + + def test_valid_signature_accepted(self): + adapter = self._make_adapter() + url = "https://example.com/webhooks/twilio" + params = {"From": "+15551234567", "Body": "hello", "To": "+15550001111"} + sig = _compute_twilio_signature("test_token_secret", url, params) + assert adapter._validate_twilio_signature(url, params, sig) is True + + def test_invalid_signature_rejected(self): + adapter = self._make_adapter() + url = "https://example.com/webhooks/twilio" + params = {"From": "+15551234567", "Body": "hello"} + assert adapter._validate_twilio_signature(url, params, "badsig") is False + + def test_wrong_token_rejected(self): + adapter = self._make_adapter(auth_token="correct_token") + url = "https://example.com/webhooks/twilio" + params = {"From": "+15551234567", "Body": "hello"} + sig = _compute_twilio_signature("wrong_token", url, params) + assert adapter._validate_twilio_signature(url, params, sig) is False + + def test_params_sorted_by_key(self): + """Signature must be computed with params sorted alphabetically.""" + adapter = self._make_adapter() + url = "https://example.com/webhooks/twilio" + params = {"Zebra": "last", "Alpha": "first", "Middle": "mid"} + sig = _compute_twilio_signature("test_token_secret", url, params) + assert adapter._validate_twilio_signature(url, params, sig) is True + + def test_empty_param_values_included(self): + """Blank values must be included in signature computation.""" + adapter = self._make_adapter() + url = "https://example.com/webhooks/twilio" + params = {"From": "+15551234567", "Body": "", "SmsStatus": "received"} + sig = _compute_twilio_signature("test_token_secret", url, params) + assert adapter._validate_twilio_signature(url, params, sig) is True + + def test_url_matters(self): + """Different URLs produce different signatures.""" + adapter = self._make_adapter() + params = {"Body": "hello"} + sig = _compute_twilio_signature( + "test_token_secret", "https://a.com/webhooks/twilio", params + ) + assert adapter._validate_twilio_signature( + "https://b.com/webhooks/twilio", params, sig + ) is False + + def test_port_variant_443_matches_without_port(self): + """Signature for https URL with :443 validates against URL without 
port.""" + adapter = self._make_adapter() + params = {"From": "+15551234567", "Body": "hello"} + sig = _compute_twilio_signature( + "test_token_secret", "https://example.com:443/webhooks/twilio", params + ) + assert adapter._validate_twilio_signature( + "https://example.com/webhooks/twilio", params, sig + ) is True + + def test_port_variant_without_port_matches_443(self): + """Signature for https URL without port validates against URL with :443.""" + adapter = self._make_adapter() + params = {"From": "+15551234567", "Body": "hello"} + sig = _compute_twilio_signature( + "test_token_secret", "https://example.com/webhooks/twilio", params + ) + assert adapter._validate_twilio_signature( + "https://example.com:443/webhooks/twilio", params, sig + ) is True + + def test_non_standard_port_no_variant(self): + """Non-standard port must NOT match URL without port.""" + adapter = self._make_adapter() + params = {"From": "+15551234567", "Body": "hello"} + sig = _compute_twilio_signature( + "test_token_secret", "https://example.com/webhooks/twilio", params + ) + assert adapter._validate_twilio_signature( + "https://example.com:8080/webhooks/twilio", params, sig + ) is False + + def test_port_variant_http_80(self): + """Port variant also works for http with port 80.""" + adapter = self._make_adapter() + params = {"From": "+15551234567", "Body": "hello"} + sig = _compute_twilio_signature( + "test_token_secret", "http://example.com:80/webhooks/twilio", params + ) + assert adapter._validate_twilio_signature( + "http://example.com/webhooks/twilio", params, sig + ) is True + + +# ── Webhook signature enforcement (handler-level) ────────────────── + +class TestWebhookSignatureEnforcement: + """Integration tests for signature validation in _handle_webhook.""" + + def _make_adapter(self, webhook_url=""): + from gateway.platforms.sms import SmsAdapter + + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": "test_token_secret", + "TWILIO_PHONE_NUMBER": "+15550001111", + "SMS_WEBHOOK_URL": webhook_url, + } + with patch.dict(os.environ, env): + pc = PlatformConfig(enabled=True, api_key="test_token_secret") + adapter = SmsAdapter(pc) + adapter._message_handler = AsyncMock() + return adapter + + def _mock_request(self, body, headers=None): + request = MagicMock() + request.read = AsyncMock(return_value=body) + request.headers = headers or {} + return request + + @pytest.mark.asyncio + async def test_insecure_flag_skips_validation(self): + """With SMS_INSECURE_NO_SIGNATURE=true and no URL, requests are accepted.""" + env = {"SMS_INSECURE_NO_SIGNATURE": "true"} + with patch.dict(os.environ, env): + adapter = self._make_adapter(webhook_url="") + body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123" + request = self._mock_request(body) + resp = await adapter._handle_webhook(request) + assert resp.status == 200 + + @pytest.mark.asyncio + async def test_insecure_flag_with_url_still_validates(self): + """When both SMS_WEBHOOK_URL and SMS_INSECURE_NO_SIGNATURE are set, + validation stays active (URL takes precedence).""" + adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio") + body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123" + request = self._mock_request(body, headers={}) + resp = await adapter._handle_webhook(request) + assert resp.status == 403 + + @pytest.mark.asyncio + async def test_missing_signature_returns_403(self): + adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio") + body = 
b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123" + request = self._mock_request(body, headers={}) + resp = await adapter._handle_webhook(request) + assert resp.status == 403 + + @pytest.mark.asyncio + async def test_invalid_signature_returns_403(self): + adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio") + body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123" + request = self._mock_request(body, headers={"X-Twilio-Signature": "invalid"}) + resp = await adapter._handle_webhook(request) + assert resp.status == 403 + + @pytest.mark.asyncio + async def test_valid_signature_returns_200(self): + webhook_url = "https://example.com/webhooks/twilio" + adapter = self._make_adapter(webhook_url=webhook_url) + params = { + "From": "+15551234567", + "To": "+15550001111", + "Body": "hello", + "MessageSid": "SM123", + } + sig = _compute_twilio_signature("test_token_secret", webhook_url, params) + body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123" + request = self._mock_request(body, headers={"X-Twilio-Signature": sig}) + resp = await adapter._handle_webhook(request) + assert resp.status == 200 + + @pytest.mark.asyncio + async def test_port_variant_signature_returns_200(self): + """Signature computed with :443 should pass when URL configured without port.""" + webhook_url = "https://example.com/webhooks/twilio" + adapter = self._make_adapter(webhook_url=webhook_url) + params = { + "From": "+15551234567", + "To": "+15550001111", + "Body": "hello", + "MessageSid": "SM123", + } + sig = _compute_twilio_signature( + "test_token_secret", "https://example.com:443/webhooks/twilio", params + ) + body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123" + request = self._mock_request(body, headers={"X-Twilio-Signature": sig}) + resp = await adapter._handle_webhook(request) + assert resp.status == 200 diff --git a/tests/gateway/test_stt_config.py b/tests/gateway/test_stt_config.py index a49e402151..23ba06af22 100644 --- a/tests/gateway/test_stt_config.py +++ b/tests/gateway/test_stt_config.py @@ -6,7 +6,9 @@ from unittest.mock import AsyncMock, patch import pytest import yaml -from gateway.config import GatewayConfig, load_gateway_config +from gateway.config import GatewayConfig, Platform, load_gateway_config +from gateway.platforms.base import MessageEvent, MessageType +from gateway.session import SessionSource def test_gateway_config_stt_disabled_from_dict_nested(): @@ -69,3 +71,46 @@ async def test_enrich_message_with_transcription_avoids_bogus_no_provider_messag assert "No STT provider is configured" not in result assert "trouble transcribing" in result assert "caption" in result + + +@pytest.mark.asyncio +async def test_prepare_inbound_message_text_transcribes_queued_voice_event(): + from gateway.run import GatewayRunner + + runner = GatewayRunner.__new__(GatewayRunner) + runner.config = GatewayConfig(stt_enabled=True) + runner.adapters = {} + runner._model = "test-model" + runner._base_url = "" + runner._has_setup_skill = lambda: False + + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123", + chat_type="dm", + ) + event = MessageEvent( + text="", + message_type=MessageType.VOICE, + source=source, + media_urls=["/tmp/queued-voice.ogg"], + media_types=["audio/ogg"], + ) + + with patch( + "tools.transcription_tools.transcribe_audio", + return_value={ + "success": True, + "transcript": "queued voice transcript", + "provider": "local_command", + }, + ): + result = await 
runner._prepare_inbound_message_text( + event=event, + source=source, + history=[], + ) + + assert result is not None + assert "queued voice transcript" in result + assert "voice message" in result.lower() diff --git a/tests/gateway/test_telegram_conflict.py b/tests/gateway/test_telegram_conflict.py index 47a67f229b..dcf3116884 100644 --- a/tests/gateway/test_telegram_conflict.py +++ b/tests/gateway/test_telegram_conflict.py @@ -43,6 +43,8 @@ def _no_auto_discovery(monkeypatch): async def _noop(): return [] monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop) + # Mock HTTPXRequest so the builder chain doesn't fail + monkeypatch.setattr("gateway.platforms.telegram.HTTPXRequest", lambda **kwargs: MagicMock()) @pytest.mark.asyncio @@ -57,9 +59,9 @@ async def test_connect_rejects_same_host_token_lock(monkeypatch): ok = await adapter.connect() assert ok is False - assert adapter.fatal_error_code == "telegram_token_lock" + assert adapter.fatal_error_code == "telegram-bot-token_lock" assert adapter.has_fatal_error is True - assert "already using this Telegram bot token" in adapter.fatal_error_message + assert "already in use" in adapter.fatal_error_message @pytest.mark.asyncio @@ -98,6 +100,8 @@ async def test_polling_conflict_retries_before_fatal(monkeypatch): ) builder = MagicMock() builder.token.return_value = builder + builder.request.return_value = builder + builder.get_updates_request.return_value = builder builder.build.return_value = app monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder))) @@ -172,6 +176,8 @@ async def test_polling_conflict_becomes_fatal_after_retries(monkeypatch): ) builder = MagicMock() builder.token.return_value = builder + builder.request.return_value = builder + builder.get_updates_request.return_value = builder builder.build.return_value = app monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder))) @@ -216,6 +222,8 @@ async def test_connect_marks_retryable_fatal_error_for_startup_network_failure(m builder = MagicMock() builder.token.return_value = builder + builder.request.return_value = builder + builder.get_updates_request.return_value = builder app = SimpleNamespace( bot=SimpleNamespace(delete_webhook=AsyncMock(), set_my_commands=AsyncMock()), updater=SimpleNamespace(), @@ -265,6 +273,8 @@ async def test_connect_clears_webhook_before_polling(monkeypatch): ) builder = MagicMock() builder.token.return_value = builder + builder.request.return_value = builder + builder.get_updates_request.return_value = builder builder.build.return_value = app monkeypatch.setattr( "gateway.platforms.telegram.Application", diff --git a/tests/gateway/test_weixin.py b/tests/gateway/test_weixin.py index caf4a7ebab..bb439fa9a6 100644 --- a/tests/gateway/test_weixin.py +++ b/tests/gateway/test_weixin.py @@ -1,12 +1,14 @@ """Tests for the Weixin platform adapter.""" import asyncio +import json import os from unittest.mock import AsyncMock, patch from gateway.config import PlatformConfig from gateway.config import GatewayConfig, HomeChannel, Platform, _apply_env_overrides -from gateway.platforms.weixin import WeixinAdapter +from gateway.platforms import weixin +from gateway.platforms.weixin import ContextTokenStore, WeixinAdapter from tools.send_message_tool import _parse_target_ref, _send_to_platform @@ -187,6 +189,70 @@ class TestWeixinConfig: assert config.get_connected_platforms() == [] +class TestWeixinStatePersistence: + def 
test_save_weixin_account_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch): + account_path = tmp_path / "weixin" / "accounts" / "acct.json" + account_path.parent.mkdir(parents=True, exist_ok=True) + original = {"token": "old-token", "base_url": "https://old.example.com"} + account_path.write_text(json.dumps(original), encoding="utf-8") + + def _boom(_src, _dst): + raise OSError("disk full") + + monkeypatch.setattr("utils.os.replace", _boom) + + try: + weixin.save_weixin_account( + str(tmp_path), + account_id="acct", + token="new-token", + base_url="https://new.example.com", + user_id="wxid_new", + ) + except OSError: + pass + else: + raise AssertionError("expected save_weixin_account to propagate replace failure") + + assert json.loads(account_path.read_text(encoding="utf-8")) == original + + def test_context_token_persist_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch): + token_path = tmp_path / "weixin" / "accounts" / "acct.context-tokens.json" + token_path.parent.mkdir(parents=True, exist_ok=True) + token_path.write_text(json.dumps({"user-a": "old-token"}), encoding="utf-8") + + def _boom(_src, _dst): + raise OSError("disk full") + + monkeypatch.setattr("utils.os.replace", _boom) + + store = ContextTokenStore(str(tmp_path)) + with patch.object(weixin.logger, "warning") as warning_mock: + store.set("acct", "user-b", "new-token") + + assert json.loads(token_path.read_text(encoding="utf-8")) == {"user-a": "old-token"} + warning_mock.assert_called_once() + + def test_save_sync_buf_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch): + sync_path = tmp_path / "weixin" / "accounts" / "acct.sync.json" + sync_path.parent.mkdir(parents=True, exist_ok=True) + sync_path.write_text(json.dumps({"get_updates_buf": "old-sync"}), encoding="utf-8") + + def _boom(_src, _dst): + raise OSError("disk full") + + monkeypatch.setattr("utils.os.replace", _boom) + + try: + weixin._save_sync_buf(str(tmp_path), "acct", "new-sync") + except OSError: + pass + else: + raise AssertionError("expected _save_sync_buf to propagate replace failure") + + assert json.loads(sync_path.read_text(encoding="utf-8")) == {"get_updates_buf": "old-sync"} + + class TestWeixinSendMessageIntegration: def test_parse_target_ref_accepts_weixin_ids(self): assert _parse_target_ref("weixin", "wxid_test123") == ("wxid_test123", None, True) @@ -217,6 +283,55 @@ class TestWeixinSendMessageIntegration: ) +class TestWeixinChunkDelivery: + def _connected_adapter(self) -> WeixinAdapter: + adapter = _make_adapter() + adapter._session = object() + adapter._token = "test-token" + adapter._base_url = "https://weixin.example.com" + adapter._token_store.get = lambda account_id, chat_id: "ctx-token" + return adapter + + @patch("gateway.platforms.weixin.asyncio.sleep", new_callable=AsyncMock) + @patch("gateway.platforms.weixin._send_message", new_callable=AsyncMock) + def test_send_waits_between_multiple_chunks(self, send_message_mock, sleep_mock): + adapter = self._connected_adapter() + adapter.MAX_MESSAGE_LENGTH = 12 + + # Use double newlines so _pack_markdown_blocks splits into 3 blocks + result = asyncio.run(adapter.send("wxid_test123", "first\n\nsecond\n\nthird")) + + assert result.success is True + assert send_message_mock.await_count == 3 + assert sleep_mock.await_count == 2 + + @patch("gateway.platforms.weixin.asyncio.sleep", new_callable=AsyncMock) + @patch("gateway.platforms.weixin._send_message", new_callable=AsyncMock) + def test_send_retries_failed_chunk_before_continuing(self, 
send_message_mock, sleep_mock): + adapter = self._connected_adapter() + adapter.MAX_MESSAGE_LENGTH = 12 + calls = {"count": 0} + + async def flaky_send(*args, **kwargs): + calls["count"] += 1 + if calls["count"] == 2: + raise RuntimeError("temporary iLink failure") + + send_message_mock.side_effect = flaky_send + + # Use double newlines so _pack_markdown_blocks splits into 3 blocks + result = asyncio.run(adapter.send("wxid_test123", "first\n\nsecond\n\nthird")) + + assert result.success is True + # 3 chunks, but chunk 2 fails once and retries → 4 _send_message calls total + assert send_message_mock.await_count == 4 + # The retried chunk should reuse the same client_id for deduplication + first_try = send_message_mock.await_args_list[1].kwargs + retry = send_message_mock.await_args_list[2].kwargs + assert first_try["text"] == retry["text"] + assert first_try["client_id"] == retry["client_id"] + + class TestWeixinRemoteMediaSafety: def test_download_remote_media_blocks_unsafe_urls(self): adapter = _make_adapter() diff --git a/tests/hermes_cli/test_gateway.py b/tests/hermes_cli/test_gateway.py index 955449547c..fd88a26c6a 100644 --- a/tests/hermes_cli/test_gateway.py +++ b/tests/hermes_cli/test_gateway.py @@ -260,7 +260,7 @@ class TestWaitForGatewayExit: def test_kill_gateway_processes_force_uses_helper(self, monkeypatch): calls = [] - monkeypatch.setattr(gateway, "find_gateway_pids", lambda exclude_pids=None: [11, 22]) + monkeypatch.setattr(gateway, "find_gateway_pids", lambda exclude_pids=None, all_profiles=False: [11, 22]) monkeypatch.setattr(gateway, "terminate_pid", lambda pid, force=False: calls.append((pid, force))) killed = gateway.kill_gateway_processes(force=True) diff --git a/tests/hermes_cli/test_gateway_service.py b/tests/hermes_cli/test_gateway_service.py index c5d4cb4f5d..cba3a8192f 100644 --- a/tests/hermes_cli/test_gateway_service.py +++ b/tests/hermes_cli/test_gateway_service.py @@ -1,6 +1,7 @@ """Tests for gateway service management helpers.""" import os +import pwd from pathlib import Path from types import SimpleNamespace @@ -129,7 +130,7 @@ class TestGatewayStopCleanup: monkeypatch.setattr( gateway_cli, "kill_gateway_processes", - lambda force=False: kill_calls.append(force) or 2, + lambda force=False, all_profiles=False: kill_calls.append(force) or 2, ) gateway_cli.gateway_command(SimpleNamespace(gateway_command="stop")) @@ -155,7 +156,7 @@ class TestGatewayStopCleanup: monkeypatch.setattr( gateway_cli, "kill_gateway_processes", - lambda force=False: kill_calls.append(force) or 2, + lambda force=False, all_profiles=False: kill_calls.append(force) or 2, ) gateway_cli.gateway_command(SimpleNamespace(gateway_command="stop", **{"all": True})) @@ -924,6 +925,23 @@ class TestProfileArg: assert "--profile" in plist assert "mybot" in plist + def test_launchd_plist_path_uses_real_user_home_not_profile_home(self, tmp_path, monkeypatch): + profile_dir = tmp_path / ".hermes" / "profiles" / "orcha" + profile_dir.mkdir(parents=True) + machine_home = tmp_path / "machine-home" + machine_home.mkdir() + profile_home = profile_dir / "home" + profile_home.mkdir() + + monkeypatch.setattr(Path, "home", lambda: profile_home) + monkeypatch.setenv("HERMES_HOME", str(profile_dir)) + monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir) + monkeypatch.setattr(pwd, "getpwuid", lambda uid: SimpleNamespace(pw_dir=str(machine_home))) + + plist_path = gateway_cli.get_launchd_plist_path() + + assert plist_path == machine_home / "Library" / "LaunchAgents" / 
"ai.hermes.gateway-orcha.plist" + class TestRemapPathForUser: """Unit tests for _remap_path_for_user().""" diff --git a/tests/hermes_cli/test_runtime_provider_resolution.py b/tests/hermes_cli/test_runtime_provider_resolution.py index f46b2dd133..20486a805b 100644 --- a/tests/hermes_cli/test_runtime_provider_resolution.py +++ b/tests/hermes_cli/test_runtime_provider_resolution.py @@ -1214,3 +1214,115 @@ def test_openrouter_provider_not_affected_by_custom_fix(monkeypatch): resolved = rp.resolve_runtime_provider(requested="openrouter") assert resolved["provider"] == "openrouter" + + +# ------------------------------------------------------------------ +# fix #7828 — custom_providers model field must propagate to runtime +# ------------------------------------------------------------------ + + +def test_get_named_custom_provider_includes_model(monkeypatch): + """_get_named_custom_provider should include the model field from config.""" + monkeypatch.setattr(rp, "load_config", lambda: { + "custom_providers": [{ + "name": "my-dashscope", + "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "api_key": "test-key", + "api_mode": "chat_completions", + "model": "qwen3.6-plus", + }], + }) + + result = rp._get_named_custom_provider("my-dashscope") + assert result is not None + assert result["model"] == "qwen3.6-plus" + + +def test_get_named_custom_provider_excludes_empty_model(monkeypatch): + """Empty or whitespace-only model field should not appear in result.""" + for model_val in ["", " ", None]: + entry = { + "name": "test-ep", + "base_url": "https://example.com/v1", + "api_key": "key", + } + if model_val is not None: + entry["model"] = model_val + + monkeypatch.setattr(rp, "load_config", lambda e=entry: { + "custom_providers": [e], + }) + + result = rp._get_named_custom_provider("test-ep") + assert result is not None + assert "model" not in result, ( + f"model field {model_val!r} should not be included in result" + ) + + +def test_named_custom_runtime_propagates_model_direct_path(monkeypatch): + """Model should propagate through the direct (non-pool) resolution path.""" + monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server") + monkeypatch.setattr( + rp, "_get_named_custom_provider", + lambda p: { + "name": "my-server", + "base_url": "http://localhost:8000/v1", + "api_key": "test-key", + "model": "qwen3.6-plus", + }, + ) + # Ensure pool doesn't intercept + monkeypatch.setattr(rp, "_try_resolve_from_custom_pool", lambda *a, **k: None) + + resolved = rp.resolve_runtime_provider(requested="my-server") + assert resolved["model"] == "qwen3.6-plus" + assert resolved["provider"] == "custom" + + +def test_named_custom_runtime_propagates_model_pool_path(monkeypatch): + """Model should propagate even when credential pool handles credentials.""" + monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server") + monkeypatch.setattr( + rp, "_get_named_custom_provider", + lambda p: { + "name": "my-server", + "base_url": "http://localhost:8000/v1", + "api_key": "test-key", + "model": "qwen3.6-plus", + }, + ) + # Pool returns a result (intercepting the normal path) + monkeypatch.setattr( + rp, "_try_resolve_from_custom_pool", + lambda *a, **k: { + "provider": "custom", + "api_mode": "chat_completions", + "base_url": "http://localhost:8000/v1", + "api_key": "pool-key", + "source": "pool:custom:my-server", + }, + ) + + resolved = rp.resolve_runtime_provider(requested="my-server") + assert resolved["model"] == "qwen3.6-plus", ( + "model must be injected into pool result" + ) 
+ assert resolved["api_key"] == "pool-key", "pool credentials should be used" + + +def test_named_custom_runtime_no_model_when_absent(monkeypatch): + """When custom_providers entry has no model field, runtime should not either.""" + monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server") + monkeypatch.setattr( + rp, "_get_named_custom_provider", + lambda p: { + "name": "my-server", + "base_url": "http://localhost:8000/v1", + "api_key": "test-key", + }, + ) + monkeypatch.setattr(rp, "_try_resolve_from_custom_pool", lambda *a, **k: None) + + resolved = rp.resolve_runtime_provider(requested="my-server") + assert "model" not in resolved diff --git a/tests/hermes_cli/test_update_gateway_restart.py b/tests/hermes_cli/test_update_gateway_restart.py index ceb05f65c9..822b22742d 100644 --- a/tests/hermes_cli/test_update_gateway_restart.py +++ b/tests/hermes_cli/test_update_gateway_restart.py @@ -191,6 +191,19 @@ class TestLaunchdPlistPath: raise AssertionError("PATH key not found in plist") +class TestLaunchdPlistCurrentness: + def test_launchd_plist_is_current_ignores_path_drift(self, tmp_path, monkeypatch): + plist_path = tmp_path / "ai.hermes.gateway.plist" + monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path) + + monkeypatch.setenv("PATH", "/custom/bin:/usr/bin:/bin") + plist_path.write_text(gateway_cli.generate_launchd_plist(), encoding="utf-8") + + monkeypatch.setenv("PATH", "/opt/homebrew/bin:/usr/local/bin:/usr/bin:/bin") + + assert gateway_cli.launchd_plist_is_current() is True + + # --------------------------------------------------------------------------- # cmd_update — macOS launchd detection # --------------------------------------------------------------------------- @@ -536,7 +549,7 @@ class TestServicePidExclusion: gateway_cli, "_get_service_pids", return_value={SERVICE_PID} ), patch.object( gateway_cli, "find_gateway_pids", - side_effect=lambda exclude_pids=None: ( + side_effect=lambda exclude_pids=None, all_profiles=False: ( [SERVICE_PID] if not exclude_pids else [p for p in [SERVICE_PID] if p not in exclude_pids] ), @@ -579,7 +592,7 @@ class TestServicePidExclusion: gateway_cli, "_get_service_pids", return_value={SERVICE_PID} ), patch.object( gateway_cli, "find_gateway_pids", - side_effect=lambda exclude_pids=None: ( + side_effect=lambda exclude_pids=None, all_profiles=False: ( [SERVICE_PID] if not exclude_pids else [p for p in [SERVICE_PID] if p not in exclude_pids] ), @@ -618,7 +631,7 @@ class TestServicePidExclusion: launchctl_loaded=True, ) - def fake_find(exclude_pids=None): + def fake_find(exclude_pids=None, all_profiles=False): _exclude = exclude_pids or set() return [p for p in [SERVICE_PID, MANUAL_PID] if p not in _exclude] @@ -760,3 +773,28 @@ class TestFindGatewayPidsExclude: pids = gateway_cli.find_gateway_pids() assert 100 in pids assert 200 in pids + + def test_filters_to_current_profile(self, monkeypatch, tmp_path): + profile_dir = tmp_path / ".hermes" / "profiles" / "orcha" + profile_dir.mkdir(parents=True) + monkeypatch.setattr(gateway_cli, "is_windows", lambda: False) + monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir) + + def fake_run(cmd, **kwargs): + return subprocess.CompletedProcess( + cmd, 0, + stdout=( + "100 /Users/dgrieco/.hermes/hermes-agent/venv/bin/python -m hermes_cli.main --profile orcha gateway run --replace\n" + "200 /Users/dgrieco/.hermes/hermes-agent/venv/bin/python -m hermes_cli.main --profile other gateway run --replace\n" + ), + stderr="", + ) + + 
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run) + monkeypatch.setattr("os.getpid", lambda: 999) + monkeypatch.setattr(gateway_cli, "_get_service_pids", lambda: set()) + monkeypatch.setattr(gateway_cli, "_profile_arg", lambda hermes_home=None: "--profile orcha") + + pids = gateway_cli.find_gateway_pids() + + assert pids == [100] diff --git a/tests/run_agent/test_interrupt_propagation.py b/tests/run_agent/test_interrupt_propagation.py index 7f8cb01c35..a746efdac1 100644 --- a/tests/run_agent/test_interrupt_propagation.py +++ b/tests/run_agent/test_interrupt_propagation.py @@ -22,23 +22,22 @@ class TestInterruptPropagationToChild(unittest.TestCase): def tearDown(self): set_interrupt(False) + def _make_bare_agent(self): + """Create a bare AIAgent via __new__ with all interrupt-related attrs.""" + from run_agent import AIAgent + agent = AIAgent.__new__(AIAgent) + agent._interrupt_requested = False + agent._interrupt_message = None + agent._execution_thread_id = None # defaults to current thread in set_interrupt + agent._active_children = [] + agent._active_children_lock = threading.Lock() + agent.quiet_mode = True + return agent + def test_parent_interrupt_sets_child_flag(self): """When parent.interrupt() is called, child._interrupt_requested should be set.""" - from run_agent import AIAgent - - parent = AIAgent.__new__(AIAgent) - parent._interrupt_requested = False - parent._interrupt_message = None - parent._active_children = [] - parent._active_children_lock = threading.Lock() - parent.quiet_mode = True - - child = AIAgent.__new__(AIAgent) - child._interrupt_requested = False - child._interrupt_message = None - child._active_children = [] - child._active_children_lock = threading.Lock() - child.quiet_mode = True + parent = self._make_bare_agent() + child = self._make_bare_agent() parent._active_children.append(child) @@ -49,40 +48,26 @@ class TestInterruptPropagationToChild(unittest.TestCase): assert child._interrupt_message == "new user message" assert is_interrupted() is True - def test_child_clear_interrupt_at_start_clears_global(self): - """child.clear_interrupt() at start of run_conversation clears the GLOBAL event. - - This is the intended behavior at startup, but verify it doesn't - accidentally clear an interrupt intended for a running child. + def test_child_clear_interrupt_at_start_clears_thread(self): + """child.clear_interrupt() at start of run_conversation clears the + per-thread interrupt flag for the current thread. 
""" - from run_agent import AIAgent - - child = AIAgent.__new__(AIAgent) + child = self._make_bare_agent() child._interrupt_requested = True child._interrupt_message = "msg" - child.quiet_mode = True - child._active_children = [] - child._active_children_lock = threading.Lock() - # Global is set + # Interrupt for current thread is set set_interrupt(True) assert is_interrupted() is True - # child.clear_interrupt() clears both + # child.clear_interrupt() clears both instance flag and thread flag child.clear_interrupt() assert child._interrupt_requested is False assert is_interrupted() is False def test_interrupt_during_child_api_call_detected(self): """Interrupt set during _interruptible_api_call is detected within 0.5s.""" - from run_agent import AIAgent - - child = AIAgent.__new__(AIAgent) - child._interrupt_requested = False - child._interrupt_message = None - child._active_children = [] - child._active_children_lock = threading.Lock() - child.quiet_mode = True + child = self._make_bare_agent() child.api_mode = "chat_completions" child.log_prefix = "" child._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1234"} @@ -117,21 +102,8 @@ class TestInterruptPropagationToChild(unittest.TestCase): def test_concurrent_interrupt_propagation(self): """Simulates exact CLI flow: parent runs delegate in thread, main thread interrupts.""" - from run_agent import AIAgent - - parent = AIAgent.__new__(AIAgent) - parent._interrupt_requested = False - parent._interrupt_message = None - parent._active_children = [] - parent._active_children_lock = threading.Lock() - parent.quiet_mode = True - - child = AIAgent.__new__(AIAgent) - child._interrupt_requested = False - child._interrupt_message = None - child._active_children = [] - child._active_children_lock = threading.Lock() - child.quiet_mode = True + parent = self._make_bare_agent() + child = self._make_bare_agent() # Register child (simulating what _run_single_child does) parent._active_children.append(child) @@ -157,5 +129,79 @@ class TestInterruptPropagationToChild(unittest.TestCase): set_interrupt(False) +class TestPerThreadInterruptIsolation(unittest.TestCase): + """Verify that interrupting one agent does NOT affect another agent's thread. + + This is the core fix for the gateway cross-session interrupt leak: + multiple agents run in separate threads within the same process, and + interrupting agent A must not kill agent B's running tools. 
+ """ + + def setUp(self): + set_interrupt(False) + + def tearDown(self): + set_interrupt(False) + + def test_interrupt_only_affects_target_thread(self): + """set_interrupt(True, tid) only makes is_interrupted() True on that thread.""" + results = {} + barrier = threading.Barrier(2) + + def thread_a(): + """Agent A's execution thread — will be interrupted.""" + tid = threading.current_thread().ident + results["a_tid"] = tid + barrier.wait(timeout=5) # sync with thread B + time.sleep(0.2) # let the interrupt arrive + results["a_interrupted"] = is_interrupted() + + def thread_b(): + """Agent B's execution thread — should NOT be affected.""" + tid = threading.current_thread().ident + results["b_tid"] = tid + barrier.wait(timeout=5) # sync with thread A + time.sleep(0.2) + results["b_interrupted"] = is_interrupted() + + ta = threading.Thread(target=thread_a) + tb = threading.Thread(target=thread_b) + ta.start() + tb.start() + + # Wait for both threads to register their TIDs + time.sleep(0.05) + while "a_tid" not in results or "b_tid" not in results: + time.sleep(0.01) + + # Interrupt ONLY thread A (simulates gateway interrupting agent A) + set_interrupt(True, results["a_tid"]) + + ta.join(timeout=3) + tb.join(timeout=3) + + assert results["a_interrupted"] is True, "Thread A should see the interrupt" + assert results["b_interrupted"] is False, "Thread B must NOT see thread A's interrupt" + + def test_clear_interrupt_only_clears_target_thread(self): + """Clearing one thread's interrupt doesn't clear another's.""" + tid_a = 99990001 + tid_b = 99990002 + set_interrupt(True, tid_a) + set_interrupt(True, tid_b) + + # Clear only A + set_interrupt(False, tid_a) + + # Simulate checking from thread B's perspective + from tools.interrupt import _interrupted_threads, _lock + with _lock: + assert tid_a not in _interrupted_threads + assert tid_b in _interrupted_threads + + # Cleanup + set_interrupt(False, tid_b) + + if __name__ == "__main__": unittest.main() diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py index 0f2d1d4de9..61137fe90a 100644 --- a/tests/run_agent/test_run_agent.py +++ b/tests/run_agent/test_run_agent.py @@ -2087,8 +2087,9 @@ class TestRunConversation: assert "Thinking Budget Exhausted" in result["final_response"] assert "/thinkon" in result["final_response"] - def test_length_empty_content_detected_as_thinking_exhausted(self, agent): - """When finish_reason='length' and content is None/empty, detect exhaustion.""" + def test_length_empty_content_without_think_tags_retries_normally(self, agent): + """When finish_reason='length' and content is None but no think tags, + fall through to normal continuation retry (not thinking-exhaustion).""" self._setup_agent(agent) resp = _mock_response(content=None, finish_reason="length") agent.client.chat.completions.create.return_value = resp @@ -2100,12 +2101,10 @@ class TestRunConversation: ): result = agent.run_conversation("hello") + # Without think tags, the agent should attempt continuation retries + # (up to 3), not immediately fire thinking-exhaustion. 
+ assert result["api_calls"] == 3 assert result["completed"] is False - assert result["api_calls"] == 1 - assert "reasoning" in result["error"].lower() - # User-friendly message is returned - assert result["final_response"] is not None - assert "Thinking Budget Exhausted" in result["final_response"] def test_length_with_tool_calls_returns_partial_without_executing_tools(self, agent): self._setup_agent(agent) @@ -2169,6 +2168,35 @@ class TestRunConversation: mock_hfc.assert_called_once() assert result["final_response"] == "Done!" + def test_truncated_tool_args_detected_when_finish_reason_not_length(self, agent): + """When a router rewrites finish_reason from 'length' to 'tool_calls', + truncated JSON arguments should still be detected and refused rather + than wasting 3 retry attempts.""" + self._setup_agent(agent) + agent.valid_tool_names.add("write_file") + bad_tc = _mock_tool_call( + name="write_file", + arguments='{"path":"report.md","content":"partial', + call_id="c1", + ) + resp = _mock_response( + content="", finish_reason="tool_calls", tool_calls=[bad_tc], + ) + agent.client.chat.completions.create.return_value = resp + + with ( + patch("run_agent.handle_function_call") as mock_handle_function_call, + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + result = agent.run_conversation("write the report") + + assert result["completed"] is False + assert result["partial"] is True + assert "truncated due to output length limit" in result["error"] + mock_handle_function_call.assert_not_called() + class TestRetryExhaustion: """Regression: retry_count > max_retries was dead code (off-by-one). diff --git a/tests/run_agent/test_run_agent_codex_responses.py b/tests/run_agent/test_run_agent_codex_responses.py index 6756ed6fde..17a70624d8 100644 --- a/tests/run_agent/test_run_agent_codex_responses.py +++ b/tests/run_agent/test_run_agent_codex_responses.py @@ -1104,3 +1104,58 @@ def test_duplicate_detection_distinguishes_different_codex_reasoning(monkeypatch ] assert "enc_first" in encrypted_contents assert "enc_second" in encrypted_contents + + +def test_chat_messages_to_responses_input_deduplicates_reasoning_ids(monkeypatch): + """Duplicate reasoning item IDs across multi-turn incomplete responses + must be deduplicated so the Responses API doesn't reject with HTTP 400.""" + agent = _build_agent(monkeypatch) + messages = [ + {"role": "user", "content": "think hard"}, + { + "role": "assistant", + "content": "", + "codex_reasoning_items": [ + {"type": "reasoning", "id": "rs_aaa", "encrypted_content": "enc_1"}, + {"type": "reasoning", "id": "rs_bbb", "encrypted_content": "enc_2"}, + ], + }, + { + "role": "assistant", + "content": "partial answer", + "codex_reasoning_items": [ + # rs_aaa is duplicated from the previous turn + {"type": "reasoning", "id": "rs_aaa", "encrypted_content": "enc_1"}, + {"type": "reasoning", "id": "rs_ccc", "encrypted_content": "enc_3"}, + ], + }, + ] + items = agent._chat_messages_to_responses_input(messages) + + reasoning_ids = [it["id"] for it in items if it.get("type") == "reasoning"] + # rs_aaa should appear only once (first occurrence kept) + assert reasoning_ids.count("rs_aaa") == 1 + # rs_bbb and rs_ccc should each appear once + assert reasoning_ids.count("rs_bbb") == 1 + assert reasoning_ids.count("rs_ccc") == 1 + assert len(reasoning_ids) == 3 + + +def test_preflight_codex_input_deduplicates_reasoning_ids(monkeypatch): + """_preflight_codex_input_items should also 
deduplicate reasoning items by ID.""" + agent = _build_agent(monkeypatch) + raw_input = [ + {"role": "user", "content": [{"type": "input_text", "text": "hello"}]}, + {"type": "reasoning", "id": "rs_xyz", "encrypted_content": "enc_a"}, + {"role": "assistant", "content": "ok"}, + {"type": "reasoning", "id": "rs_xyz", "encrypted_content": "enc_a"}, + {"type": "reasoning", "id": "rs_zzz", "encrypted_content": "enc_b"}, + {"role": "assistant", "content": "done"}, + ] + normalized = agent._preflight_codex_input_items(raw_input) + + reasoning_items = [it for it in normalized if it.get("type") == "reasoning"] + reasoning_ids = [it["id"] for it in reasoning_items] + assert reasoning_ids.count("rs_xyz") == 1 + assert reasoning_ids.count("rs_zzz") == 1 + assert len(reasoning_items) == 2 diff --git a/tests/tools/test_browser_orphan_reaper.py b/tests/tools/test_browser_orphan_reaper.py new file mode 100644 index 0000000000..254dad7db7 --- /dev/null +++ b/tests/tools/test_browser_orphan_reaper.py @@ -0,0 +1,158 @@ +"""Tests for _reap_orphaned_browser_sessions() — kills orphaned agent-browser +daemons whose Python parent exited without cleaning up.""" + +import os +import signal +import textwrap +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest + + +@pytest.fixture +def fake_tmpdir(tmp_path): + """Patch _socket_safe_tmpdir to return a temp dir we control.""" + with patch("tools.browser_tool._socket_safe_tmpdir", return_value=str(tmp_path)): + yield tmp_path + + +@pytest.fixture(autouse=True) +def _isolate_sessions(): + """Ensure _active_sessions is empty for each test.""" + import tools.browser_tool as bt + orig = bt._active_sessions.copy() + bt._active_sessions.clear() + yield + bt._active_sessions.clear() + bt._active_sessions.update(orig) + + +def _make_socket_dir(tmpdir, session_name, pid=None): + """Create a fake agent-browser socket directory with optional PID file.""" + d = tmpdir / f"agent-browser-{session_name}" + d.mkdir() + if pid is not None: + (d / f"{session_name}.pid").write_text(str(pid)) + return d + + +class TestReapOrphanedBrowserSessions: + """Tests for the orphan reaper function.""" + + def test_no_socket_dirs_is_noop(self, fake_tmpdir): + """No socket dirs => nothing happens, no errors.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + _reap_orphaned_browser_sessions() # should not raise + + def test_stale_dir_without_pid_file_is_removed(self, fake_tmpdir): + """Socket dir with no PID file is cleaned up.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + d = _make_socket_dir(fake_tmpdir, "h_abc1234567") + assert d.exists() + _reap_orphaned_browser_sessions() + assert not d.exists() + + def test_stale_dir_with_dead_pid_is_removed(self, fake_tmpdir): + """Socket dir whose daemon PID is dead gets cleaned up.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + d = _make_socket_dir(fake_tmpdir, "h_dead123456", pid=999999999) + assert d.exists() + _reap_orphaned_browser_sessions() + assert not d.exists() + + def test_orphaned_alive_daemon_is_killed(self, fake_tmpdir): + """Alive daemon not tracked by _active_sessions gets SIGTERM.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + + d = _make_socket_dir(fake_tmpdir, "h_orphan12345", pid=12345) + + kill_calls = [] + original_kill = os.kill + + def mock_kill(pid, sig): + kill_calls.append((pid, sig)) + if sig == 0: + return # pretend process exists + # Don't actually kill anything + + with patch("os.kill", side_effect=mock_kill): 
+ _reap_orphaned_browser_sessions() + + # Should have checked existence (sig 0) then killed (SIGTERM) + assert (12345, 0) in kill_calls + assert (12345, signal.SIGTERM) in kill_calls + + def test_tracked_session_is_not_reaped(self, fake_tmpdir): + """Sessions tracked in _active_sessions are left alone.""" + import tools.browser_tool as bt + from tools.browser_tool import _reap_orphaned_browser_sessions + + session_name = "h_tracked1234" + d = _make_socket_dir(fake_tmpdir, session_name, pid=12345) + + # Register the session as actively tracked + bt._active_sessions["some_task"] = {"session_name": session_name} + + kill_calls = [] + + def mock_kill(pid, sig): + kill_calls.append((pid, sig)) + + with patch("os.kill", side_effect=mock_kill): + _reap_orphaned_browser_sessions() + + # Should NOT have tried to kill anything + assert len(kill_calls) == 0 + # Dir should still exist + assert d.exists() + + def test_permission_error_on_kill_check_skips(self, fake_tmpdir): + """If we can't check the PID (PermissionError), skip it.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + + d = _make_socket_dir(fake_tmpdir, "h_perm1234567", pid=12345) + + def mock_kill(pid, sig): + if sig == 0: + raise PermissionError("not our process") + + with patch("os.kill", side_effect=mock_kill): + _reap_orphaned_browser_sessions() + + # Dir should still exist (we didn't touch someone else's process) + assert d.exists() + + def test_cdp_sessions_are_also_reaped(self, fake_tmpdir): + """CDP sessions (cdp_ prefix) are also scanned.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + + d = _make_socket_dir(fake_tmpdir, "cdp_abc1234567") + assert d.exists() + _reap_orphaned_browser_sessions() + # No PID file → cleaned up + assert not d.exists() + + def test_non_hermes_dirs_are_ignored(self, fake_tmpdir): + """Socket dirs that don't match our naming pattern are left alone.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + + # Create a dir that doesn't match h_* or cdp_* pattern + d = fake_tmpdir / "agent-browser-other_session" + d.mkdir() + (d / "other_session.pid").write_text("12345") + + _reap_orphaned_browser_sessions() + + # Should NOT be touched + assert d.exists() + + def test_corrupt_pid_file_is_cleaned(self, fake_tmpdir): + """PID file with non-integer content is cleaned up.""" + from tools.browser_tool import _reap_orphaned_browser_sessions + + d = _make_socket_dir(fake_tmpdir, "h_corrupt1234") + (d / "h_corrupt1234.pid").write_text("not-a-number") + + _reap_orphaned_browser_sessions() + assert not d.exists() diff --git a/tests/tools/test_checkpoint_manager.py b/tests/tools/test_checkpoint_manager.py index ef843465f1..ba9da6da1f 100644 --- a/tests/tools/test_checkpoint_manager.py +++ b/tests/tools/test_checkpoint_manager.py @@ -1,9 +1,6 @@ """Tests for tools/checkpoint_manager.py — CheckpointManager.""" import logging -import os -import json -import shutil import subprocess import pytest from pathlib import Path @@ -42,6 +39,19 @@ def checkpoint_base(tmp_path): return tmp_path / "checkpoints" +@pytest.fixture() +def fake_home(tmp_path, monkeypatch): + """Set a deterministic fake home for expanduser/path-home behavior.""" + home = tmp_path / "home" + home.mkdir() + monkeypatch.setenv("HOME", str(home)) + monkeypatch.setenv("USERPROFILE", str(home)) + monkeypatch.delenv("HOMEDRIVE", raising=False) + monkeypatch.delenv("HOMEPATH", raising=False) + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + return home + + @pytest.fixture() def 
mgr(work_dir, checkpoint_base, monkeypatch): """CheckpointManager with redirected checkpoint base.""" @@ -78,6 +88,16 @@ class TestShadowRepoPath: p = _shadow_repo_path(str(work_dir)) assert str(p).startswith(str(checkpoint_base)) + def test_tilde_and_expanded_home_share_shadow_repo(self, fake_home, checkpoint_base, monkeypatch): + monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base) + project = fake_home / "project" + project.mkdir() + + tilde_path = f"~/{project.name}" + expanded_path = str(project) + + assert _shadow_repo_path(tilde_path) == _shadow_repo_path(expanded_path) + # ========================================================================= # Shadow repo init @@ -221,6 +241,20 @@ class TestListCheckpoints: assert result[0]["reason"] == "third" assert result[2]["reason"] == "first" + def test_tilde_path_lists_same_checkpoints_as_expanded_path(self, checkpoint_base, fake_home, monkeypatch): + monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base) + mgr = CheckpointManager(enabled=True, max_snapshots=50) + project = fake_home / "project" + project.mkdir() + (project / "main.py").write_text("v1\n") + + tilde_path = f"~/{project.name}" + assert mgr.ensure_checkpoint(tilde_path, "initial") is True + + listed = mgr.list_checkpoints(str(project)) + assert len(listed) == 1 + assert listed[0]["reason"] == "initial" + # ========================================================================= # CheckpointManager — restoring @@ -271,6 +305,28 @@ class TestRestore: assert len(all_cps) >= 2 assert "pre-rollback" in all_cps[0]["reason"] + def test_tilde_path_supports_diff_and_restore_flow(self, checkpoint_base, fake_home, monkeypatch): + monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base) + mgr = CheckpointManager(enabled=True, max_snapshots=50) + project = fake_home / "project" + project.mkdir() + file_path = project / "main.py" + file_path.write_text("original\n") + + tilde_path = f"~/{project.name}" + assert mgr.ensure_checkpoint(tilde_path, "initial") is True + mgr.new_turn() + + file_path.write_text("changed\n") + checkpoints = mgr.list_checkpoints(str(project)) + diff_result = mgr.diff(tilde_path, checkpoints[0]["hash"]) + assert diff_result["success"] is True + assert "main.py" in diff_result["diff"] + + restore_result = mgr.restore(tilde_path, checkpoints[0]["hash"]) + assert restore_result["success"] is True + assert file_path.read_text() == "original\n" + # ========================================================================= # CheckpointManager — working dir resolution @@ -310,6 +366,19 @@ class TestWorkingDirResolution: result = mgr.get_working_dir_for_path(str(filepath)) assert result == str(filepath.parent) + def test_resolves_tilde_path_to_project_root(self, fake_home): + mgr = CheckpointManager(enabled=True) + project = fake_home / "myproject" + project.mkdir() + (project / "pyproject.toml").write_text("[project]\n") + subdir = project / "src" + subdir.mkdir() + filepath = subdir / "main.py" + filepath.write_text("x\n") + + result = mgr.get_working_dir_for_path(f"~/{project.name}/src/main.py") + assert result == str(project) + # ========================================================================= # Git env isolation @@ -333,6 +402,14 @@ class TestGitEnvIsolation: env = _git_env(shadow, str(tmp_path)) assert "GIT_INDEX_FILE" not in env + def test_expands_tilde_in_work_tree(self, fake_home, tmp_path): + shadow = tmp_path / "shadow" + work = fake_home / "work" + work.mkdir() + + env = 
_git_env(shadow, f"~/{work.name}") + assert env["GIT_WORK_TREE"] == str(work.resolve()) + # ========================================================================= # format_checkpoint_list @@ -384,6 +461,8 @@ class TestErrorResilience: assert result is False def test_run_git_allows_expected_nonzero_without_error_log(self, tmp_path, caplog): + work = tmp_path / "work" + work.mkdir() completed = subprocess.CompletedProcess( args=["git", "diff", "--cached", "--quiet"], returncode=1, @@ -395,7 +474,7 @@ class TestErrorResilience: ok, stdout, stderr = _run_git( ["diff", "--cached", "--quiet"], tmp_path / "shadow", - str(tmp_path / "work"), + str(work), allowed_returncodes={1}, ) assert ok is False @@ -403,6 +482,38 @@ class TestErrorResilience: assert stderr == "" assert not caplog.records + def test_run_git_invalid_working_dir_reports_path_error(self, tmp_path, caplog): + missing = tmp_path / "missing" + with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"): + ok, stdout, stderr = _run_git( + ["status"], + tmp_path / "shadow", + str(missing), + ) + assert ok is False + assert stdout == "" + assert "working directory not found" in stderr + assert not any("Git executable not found" in r.getMessage() for r in caplog.records) + + def test_run_git_missing_git_reports_git_not_found(self, tmp_path, monkeypatch, caplog): + work = tmp_path / "work" + work.mkdir() + + def raise_missing_git(*args, **kwargs): + raise FileNotFoundError(2, "No such file or directory", "git") + + monkeypatch.setattr("tools.checkpoint_manager.subprocess.run", raise_missing_git) + with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"): + ok, stdout, stderr = _run_git( + ["status"], + tmp_path / "shadow", + str(work), + ) + assert ok is False + assert stdout == "" + assert stderr == "git not found" + assert any("Git executable not found" in r.getMessage() for r in caplog.records) + def test_checkpoint_failure_does_not_raise(self, mgr, work_dir, monkeypatch): """Checkpoint failures should never raise — they're silently logged.""" def broken_run_git(*args, **kwargs): @@ -411,3 +522,68 @@ class TestErrorResilience: # Should not raise result = mgr.ensure_checkpoint(str(work_dir), "test") assert result is False + + +# ========================================================================= +# Security / Input validation +# ========================================================================= + +class TestSecurity: + def test_restore_rejects_argument_injection(self, mgr, work_dir): + mgr.ensure_checkpoint(str(work_dir), "initial") + # Try to pass a git flag as a commit hash + result = mgr.restore(str(work_dir), "--patch") + assert result["success"] is False + assert "Invalid commit hash" in result["error"] + assert "must not start with '-'" in result["error"] + + result = mgr.restore(str(work_dir), "-p") + assert result["success"] is False + assert "Invalid commit hash" in result["error"] + + def test_restore_rejects_invalid_hex_chars(self, mgr, work_dir): + mgr.ensure_checkpoint(str(work_dir), "initial") + # Git hashes should not contain characters like ;, &, | + result = mgr.restore(str(work_dir), "abc; rm -rf /") + assert result["success"] is False + assert "expected 4-64 hex characters" in result["error"] + + result = mgr.diff(str(work_dir), "abc&def") + assert result["success"] is False + assert "expected 4-64 hex characters" in result["error"] + + def test_restore_rejects_path_traversal(self, mgr, work_dir): + mgr.ensure_checkpoint(str(work_dir), "initial") + # Real commit hash but 
malicious path + checkpoints = mgr.list_checkpoints(str(work_dir)) + target_hash = checkpoints[0]["hash"] + + # Absolute path outside + result = mgr.restore(str(work_dir), target_hash, file_path="/etc/passwd") + assert result["success"] is False + assert "got absolute path" in result["error"] + + # Relative traversal outside path + result = mgr.restore(str(work_dir), target_hash, file_path="../outside_file.txt") + assert result["success"] is False + assert "escapes the working directory" in result["error"] + + def test_restore_accepts_valid_file_path(self, mgr, work_dir): + mgr.ensure_checkpoint(str(work_dir), "initial") + checkpoints = mgr.list_checkpoints(str(work_dir)) + target_hash = checkpoints[0]["hash"] + + # Valid path inside directory + result = mgr.restore(str(work_dir), target_hash, file_path="main.py") + assert result["success"] is True + + # Another valid path with subdirectories + (work_dir / "subdir").mkdir() + (work_dir / "subdir" / "test.txt").write_text("hello") + mgr.new_turn() + mgr.ensure_checkpoint(str(work_dir), "second") + checkpoints = mgr.list_checkpoints(str(work_dir)) + target_hash = checkpoints[0]["hash"] + + result = mgr.restore(str(work_dir), target_hash, file_path="subdir/test.txt") + assert result["success"] is True diff --git a/tests/tools/test_code_execution.py b/tests/tools/test_code_execution.py index 33653c3607..e015e5d42b 100644 --- a/tests/tools/test_code_execution.py +++ b/tests/tools/test_code_execution.py @@ -780,14 +780,18 @@ class TestLoadConfig(unittest.TestCase): @unittest.skipIf(sys.platform == "win32", "UDS not available on Windows") class TestInterruptHandling(unittest.TestCase): def test_interrupt_event_stops_execution(self): - """When _interrupt_event is set, execute_code should stop the script.""" + """When interrupt is set for the execution thread, execute_code should stop.""" code = "import time; time.sleep(60); print('should not reach')" + from tools.interrupt import set_interrupt + + # Capture the main thread ID so we can target the interrupt correctly. + # execute_code runs in the current thread; set_interrupt needs its ID. 
+ main_tid = threading.current_thread().ident def set_interrupt_after_delay(): import time as _t _t.sleep(1) - from tools.terminal_tool import _interrupt_event - _interrupt_event.set() + set_interrupt(True, main_tid) t = threading.Thread(target=set_interrupt_after_delay, daemon=True) t.start() @@ -804,8 +808,7 @@ class TestInterruptHandling(unittest.TestCase): self.assertEqual(result["status"], "interrupted") self.assertIn("interrupted", result["output"]) finally: - from tools.terminal_tool import _interrupt_event - _interrupt_event.clear() + set_interrupt(False, main_tid) t.join(timeout=3) diff --git a/tests/tools/test_notify_on_complete.py b/tests/tools/test_notify_on_complete.py index ff6f14922f..411f95f7e0 100644 --- a/tests/tools/test_notify_on_complete.py +++ b/tests/tools/test_notify_on_complete.py @@ -227,6 +227,8 @@ class TestCheckpointNotify: "session_key": "sk1", "watcher_platform": "telegram", "watcher_chat_id": "123", + "watcher_user_id": "u123", + "watcher_user_name": "alice", "watcher_thread_id": "42", "watcher_interval": 5, "notify_on_complete": True, @@ -236,6 +238,8 @@ class TestCheckpointNotify: assert recovered == 1 assert len(registry.pending_watchers) == 1 assert registry.pending_watchers[0]["notify_on_complete"] is True + assert registry.pending_watchers[0]["user_id"] == "u123" + assert registry.pending_watchers[0]["user_name"] == "alice" def test_recover_defaults_false(self, registry, tmp_path): """Old checkpoint entries without the field default to False.""" diff --git a/tests/tools/test_process_registry.py b/tests/tools/test_process_registry.py index a61da9dd3e..d981878a31 100644 --- a/tests/tools/test_process_registry.py +++ b/tests/tools/test_process_registry.py @@ -438,6 +438,8 @@ class TestCheckpoint: s = _make_session() s.watcher_platform = "telegram" s.watcher_chat_id = "999" + s.watcher_user_id = "u123" + s.watcher_user_name = "alice" s.watcher_thread_id = "42" s.watcher_interval = 60 registry._running[s.id] = s @@ -447,6 +449,8 @@ class TestCheckpoint: assert len(data) == 1 assert data[0]["watcher_platform"] == "telegram" assert data[0]["watcher_chat_id"] == "999" + assert data[0]["watcher_user_id"] == "u123" + assert data[0]["watcher_user_name"] == "alice" assert data[0]["watcher_thread_id"] == "42" assert data[0]["watcher_interval"] == 60 @@ -460,6 +464,8 @@ class TestCheckpoint: "session_key": "sk1", "watcher_platform": "telegram", "watcher_chat_id": "123", + "watcher_user_id": "u123", + "watcher_user_name": "alice", "watcher_thread_id": "42", "watcher_interval": 60, }])) @@ -471,6 +477,8 @@ class TestCheckpoint: assert w["session_id"] == "proc_live" assert w["platform"] == "telegram" assert w["chat_id"] == "123" + assert w["user_id"] == "u123" + assert w["user_name"] == "alice" assert w["thread_id"] == "42" assert w["check_interval"] == 60 diff --git a/tests/tools/test_skill_manager_tool.py b/tests/tools/test_skill_manager_tool.py index 7b9e49d4f2..dd0ae17f8c 100644 --- a/tests/tools/test_skill_manager_tool.py +++ b/tests/tools/test_skill_manager_tool.py @@ -348,7 +348,7 @@ word word result = _patch_skill("my-skill", "old text", "new text", file_path="references/evil.md") assert result["success"] is False - assert "boundary" in result["error"].lower() + assert "escapes" in result["error"].lower() assert outside_file.read_text() == "old text here" @@ -412,7 +412,7 @@ class TestWriteFile: result = _write_file("my-skill", "references/escape/owned.md", "malicious") assert result["success"] is False - assert "boundary" in result["error"].lower() + assert 
"escapes" in result["error"].lower() assert not (outside_dir / "owned.md").exists() @@ -449,7 +449,7 @@ class TestRemoveFile: result = _remove_file("my-skill", "references/escape/keep.txt") assert result["success"] is False - assert "boundary" in result["error"].lower() + assert "escapes" in result["error"].lower() assert outside_file.exists() diff --git a/tests/tools/test_tool_result_storage.py b/tests/tools/test_tool_result_storage.py index f95b5dc08a..0bbb95bbd6 100644 --- a/tests/tools/test_tool_result_storage.py +++ b/tests/tools/test_tool_result_storage.py @@ -124,6 +124,34 @@ class TestWriteToSandbox: cmd = env.execute.call_args[0][0] assert "mkdir -p /data/data/com.termux/files/usr/tmp/hermes-results" in cmd + def test_path_with_spaces_is_quoted(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + remote_path = "/tmp/hermes results/abc file.txt" + _write_to_sandbox("content", remote_path, env) + cmd = env.execute.call_args[0][0] + assert "'/tmp/hermes results'" in cmd + assert "'/tmp/hermes results/abc file.txt'" in cmd + + def test_shell_metacharacters_neutralized(self): + """Paths with shell metacharacters must be quoted to prevent injection.""" + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + malicious_path = "/tmp/hermes-results/$(whoami).txt" + _write_to_sandbox("content", malicious_path, env) + cmd = env.execute.call_args[0][0] + # The $() must not appear unquoted — shlex.quote wraps it + assert "'/tmp/hermes-results/$(whoami).txt'" in cmd + + def test_semicolon_injection_neutralized(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + malicious_path = "/tmp/x; rm -rf /; echo .txt" + _write_to_sandbox("content", malicious_path, env) + cmd = env.execute.call_args[0][0] + # The semicolons must be inside quotes, not acting as command separators + assert "'/tmp/x; rm -rf /; echo .txt'" in cmd + class TestResolveStorageDir: def test_defaults_to_storage_dir_without_env(self): diff --git a/tools/browser_tool.py b/tools/browser_tool.py index ed3cfbb9bb..bb24866066 100644 --- a/tools/browser_tool.py +++ b/tools/browser_tool.py @@ -473,13 +473,104 @@ def _cleanup_inactive_browser_sessions(): logger.warning("Error cleaning up inactive session %s: %s", task_id, e) +def _reap_orphaned_browser_sessions(): + """Scan for orphaned agent-browser daemon processes from previous runs. + + When the Python process that created a browser session exits uncleanly + (SIGKILL, crash, gateway restart), the in-memory ``_active_sessions`` + tracking is lost but the node + Chromium processes keep running. + + This function scans the tmp directory for ``agent-browser-*`` socket dirs + left behind by previous runs, reads the daemon PID files, and kills any + daemons that are still alive but not tracked by the current process. + + Called once on cleanup-thread startup — not every 30 seconds — to avoid + races with sessions being actively created. 
+ """ + import glob + + tmpdir = _socket_safe_tmpdir() + pattern = os.path.join(tmpdir, "agent-browser-h_*") + socket_dirs = glob.glob(pattern) + # Also pick up CDP sessions + socket_dirs += glob.glob(os.path.join(tmpdir, "agent-browser-cdp_*")) + + if not socket_dirs: + return + + # Build set of session_names currently tracked by this process + with _cleanup_lock: + tracked_names = { + info.get("session_name") + for info in _active_sessions.values() + if info.get("session_name") + } + + reaped = 0 + for socket_dir in socket_dirs: + dir_name = os.path.basename(socket_dir) + # dir_name is "agent-browser-{session_name}" + session_name = dir_name.removeprefix("agent-browser-") + if not session_name: + continue + + # Skip sessions that we are actively tracking + if session_name in tracked_names: + continue + + pid_file = os.path.join(socket_dir, f"{session_name}.pid") + if not os.path.isfile(pid_file): + # No PID file — just a stale dir, remove it + shutil.rmtree(socket_dir, ignore_errors=True) + continue + + try: + daemon_pid = int(Path(pid_file).read_text().strip()) + except (ValueError, OSError): + shutil.rmtree(socket_dir, ignore_errors=True) + continue + + # Check if the daemon is still alive + try: + os.kill(daemon_pid, 0) # signal 0 = existence check + except ProcessLookupError: + # Already dead, just clean up the dir + shutil.rmtree(socket_dir, ignore_errors=True) + continue + except PermissionError: + # Alive but owned by someone else — leave it alone + continue + + # Daemon is alive and not tracked — orphan. Kill it. + try: + os.kill(daemon_pid, signal.SIGTERM) + logger.info("Reaped orphaned browser daemon PID %d (session %s)", + daemon_pid, session_name) + reaped += 1 + except (ProcessLookupError, PermissionError, OSError): + pass + + # Clean up the socket directory + shutil.rmtree(socket_dir, ignore_errors=True) + + if reaped: + logger.info("Reaped %d orphaned browser session(s) from previous run(s)", reaped) + + def _browser_cleanup_thread_worker(): """ Background thread that periodically cleans up inactive browser sessions. Runs every 30 seconds and checks for sessions that haven't been used within the BROWSER_SESSION_INACTIVITY_TIMEOUT period. + On first run, also reaps orphaned sessions from previous process lifetimes. """ + # One-time orphan reap on startup + try: + _reap_orphaned_browser_sessions() + except Exception as e: + logger.warning("Orphan reap error: %s", e) + while _cleanup_running: try: _cleanup_inactive_browser_sessions() diff --git a/tools/checkpoint_manager.py b/tools/checkpoint_manager.py index c298aa0bb6..42900a643d 100644 --- a/tools/checkpoint_manager.py +++ b/tools/checkpoint_manager.py @@ -21,6 +21,7 @@ into the user's project directory. import hashlib import logging import os +import re import shutil import subprocess from pathlib import Path @@ -64,23 +65,72 @@ _GIT_TIMEOUT: int = max(10, min(60, int(os.getenv("HERMES_CHECKPOINT_TIMEOUT", " # Max files to snapshot — skip huge directories to avoid slowdowns. _MAX_FILES = 50_000 +# Valid git commit hash pattern: 4–40 hex chars (short or full SHA-1/SHA-256). +_COMMIT_HASH_RE = re.compile(r'^[0-9a-fA-F]{4,64}$') + + +# --------------------------------------------------------------------------- +# Input validation helpers +# --------------------------------------------------------------------------- + +def _validate_commit_hash(commit_hash: str) -> Optional[str]: + """Validate a commit hash to prevent git argument injection. + + Returns an error string if invalid, None if valid. 
+ Values starting with '-' would be interpreted as git flags + (e.g., '--patch', '-p') instead of revision specifiers. + """ + if not commit_hash or not commit_hash.strip(): + return "Empty commit hash" + if commit_hash.startswith("-"): + return f"Invalid commit hash (must not start with '-'): {commit_hash!r}" + if not _COMMIT_HASH_RE.match(commit_hash): + return f"Invalid commit hash (expected 4-64 hex characters): {commit_hash!r}" + return None + + +def _validate_file_path(file_path: str, working_dir: str) -> Optional[str]: + """Validate a file path to prevent path traversal outside the working directory. + + Returns an error string if invalid, None if valid. + """ + if not file_path or not file_path.strip(): + return "Empty file path" + # Reject absolute paths — restore targets must be relative to the workdir + if os.path.isabs(file_path): + return f"File path must be relative, got absolute path: {file_path!r}" + # Resolve and check containment within working_dir + abs_workdir = _normalize_path(working_dir) + resolved = (abs_workdir / file_path).resolve() + try: + resolved.relative_to(abs_workdir) + except ValueError: + return f"File path escapes the working directory via traversal: {file_path!r}" + return None + # --------------------------------------------------------------------------- # Shadow repo helpers # --------------------------------------------------------------------------- +def _normalize_path(path_value: str) -> Path: + """Return a canonical absolute path for checkpoint operations.""" + return Path(path_value).expanduser().resolve() + + def _shadow_repo_path(working_dir: str) -> Path: """Deterministic shadow repo path: sha256(abs_path)[:16].""" - abs_path = str(Path(working_dir).resolve()) + abs_path = str(_normalize_path(working_dir)) dir_hash = hashlib.sha256(abs_path.encode()).hexdigest()[:16] return CHECKPOINT_BASE / dir_hash def _git_env(shadow_repo: Path, working_dir: str) -> dict: """Build env dict that redirects git to the shadow repo.""" + normalized_working_dir = _normalize_path(working_dir) env = os.environ.copy() env["GIT_DIR"] = str(shadow_repo) - env["GIT_WORK_TREE"] = str(Path(working_dir).resolve()) + env["GIT_WORK_TREE"] = str(normalized_working_dir) env.pop("GIT_INDEX_FILE", None) env.pop("GIT_NAMESPACE", None) env.pop("GIT_ALTERNATE_OBJECT_DIRECTORIES", None) @@ -100,7 +150,17 @@ def _run_git( exits while preserving the normal ``ok = (returncode == 0)`` contract. Example: ``git diff --cached --quiet`` returns 1 when changes exist. 
""" - env = _git_env(shadow_repo, working_dir) + normalized_working_dir = _normalize_path(working_dir) + if not normalized_working_dir.exists(): + msg = f"working directory not found: {normalized_working_dir}" + logger.error("Git command skipped: %s (%s)", " ".join(["git"] + list(args)), msg) + return False, "", msg + if not normalized_working_dir.is_dir(): + msg = f"working directory is not a directory: {normalized_working_dir}" + logger.error("Git command skipped: %s (%s)", " ".join(["git"] + list(args)), msg) + return False, "", msg + + env = _git_env(shadow_repo, str(normalized_working_dir)) cmd = ["git"] + list(args) allowed_returncodes = allowed_returncodes or set() try: @@ -110,7 +170,7 @@ def _run_git( text=True, timeout=timeout, env=env, - cwd=str(Path(working_dir).resolve()), + cwd=str(normalized_working_dir), ) ok = result.returncode == 0 stdout = result.stdout.strip() @@ -125,9 +185,14 @@ def _run_git( msg = f"git timed out after {timeout}s: {' '.join(cmd)}" logger.error(msg, exc_info=True) return False, "", msg - except FileNotFoundError: - logger.error("Git executable not found: %s", " ".join(cmd), exc_info=True) - return False, "", "git not found" + except FileNotFoundError as exc: + missing_target = getattr(exc, "filename", None) + if missing_target == "git": + logger.error("Git executable not found: %s", " ".join(cmd), exc_info=True) + return False, "", "git not found" + msg = f"working directory not found: {normalized_working_dir}" + logger.error("Git command failed before execution: %s (%s)", " ".join(cmd), msg, exc_info=True) + return False, "", msg except Exception as exc: logger.error("Unexpected git error running %s: %s", " ".join(cmd), exc, exc_info=True) return False, "", str(exc) @@ -154,7 +219,7 @@ def _init_shadow_repo(shadow_repo: Path, working_dir: str) -> Optional[str]: ) (shadow_repo / "HERMES_WORKDIR").write_text( - str(Path(working_dir).resolve()) + "\n", encoding="utf-8" + str(_normalize_path(working_dir)) + "\n", encoding="utf-8" ) logger.debug("Initialised checkpoint repo at %s for %s", shadow_repo, working_dir) @@ -229,7 +294,7 @@ class CheckpointManager: if not self._git_available: return False - abs_dir = str(Path(working_dir).resolve()) + abs_dir = str(_normalize_path(working_dir)) # Skip root, home, and other overly broad directories if abs_dir in ("/", str(Path.home())): @@ -254,7 +319,7 @@ class CheckpointManager: Returns a list of dicts with keys: hash, short_hash, timestamp, reason, files_changed, insertions, deletions. Most recent first. """ - abs_dir = str(Path(working_dir).resolve()) + abs_dir = str(_normalize_path(working_dir)) shadow = _shadow_repo_path(abs_dir) if not (shadow / "HEAD").exists(): @@ -311,7 +376,12 @@ class CheckpointManager: Returns dict with success, diff text, and stat summary. """ - abs_dir = str(Path(working_dir).resolve()) + # Validate commit_hash to prevent git argument injection + hash_err = _validate_commit_hash(commit_hash) + if hash_err: + return {"success": False, "error": hash_err} + + abs_dir = str(_normalize_path(working_dir)) shadow = _shadow_repo_path(abs_dir) if not (shadow / "HEAD").exists(): @@ -364,7 +434,19 @@ class CheckpointManager: Returns dict with success/error info. 
""" - abs_dir = str(Path(working_dir).resolve()) + # Validate commit_hash to prevent git argument injection + hash_err = _validate_commit_hash(commit_hash) + if hash_err: + return {"success": False, "error": hash_err} + + abs_dir = str(_normalize_path(working_dir)) + + # Validate file_path to prevent path traversal outside the working dir + if file_path: + path_err = _validate_file_path(file_path, abs_dir) + if path_err: + return {"success": False, "error": path_err} + shadow = _shadow_repo_path(abs_dir) if not (shadow / "HEAD").exists(): @@ -413,7 +495,7 @@ class CheckpointManager: (directory containing .git, pyproject.toml, package.json, etc.). Falls back to the file's parent directory. """ - path = Path(file_path).resolve() + path = _normalize_path(file_path) if path.is_dir(): candidate = path else: diff --git a/tools/code_execution_tool.py b/tools/code_execution_tool.py index 7837d70d6c..d6c561e2c3 100644 --- a/tools/code_execution_tool.py +++ b/tools/code_execution_tool.py @@ -924,8 +924,8 @@ def execute_code( # --- Local execution path (UDS) --- below this line is unchanged --- - # Import interrupt event from terminal_tool (cooperative cancellation) - from tools.terminal_tool import _interrupt_event + # Import per-thread interrupt check (cooperative cancellation) + from tools.interrupt import is_interrupted as _is_interrupted # Resolve config _cfg = _load_config() @@ -1114,7 +1114,7 @@ def execute_code( status = "success" while proc.poll() is None: - if _interrupt_event.is_set(): + if _is_interrupted(): _kill_process_group(proc) status = "interrupted" break diff --git a/tools/credential_files.py b/tools/credential_files.py index 6ddcd07708..7998321e63 100644 --- a/tools/credential_files.py +++ b/tools/credential_files.py @@ -80,20 +80,18 @@ def register_credential_file( # Resolve symlinks and normalise ``..`` before the containment check so # that traversal like ``../. ssh/id_rsa`` cannot escape HERMES_HOME. 
- try: - resolved = host_path.resolve() - hermes_home_resolved = hermes_home.resolve() - resolved.relative_to(hermes_home_resolved) # raises ValueError if outside - except ValueError: + from tools.path_security import validate_within_dir + + containment_error = validate_within_dir(host_path, hermes_home) + if containment_error: logger.warning( - "credential_files: rejected path traversal %r " - "(resolves to %s, outside HERMES_HOME %s)", + "credential_files: rejected path traversal %r (%s)", relative_path, - resolved, - hermes_home_resolved, + containment_error, ) return False + resolved = host_path.resolve() if not resolved.is_file(): logger.debug("credential_files: skipping %s (not found)", resolved) return False @@ -142,7 +140,8 @@ def _load_config_files() -> List[Dict[str, str]]: cfg = read_raw_config() cred_files = cfg.get("terminal", {}).get("credential_files") if isinstance(cred_files, list): - hermes_home_resolved = hermes_home.resolve() + from tools.path_security import validate_within_dir + for item in cred_files: if isinstance(item, str) and item.strip(): rel = item.strip() @@ -151,20 +150,19 @@ def _load_config_files() -> List[Dict[str, str]]: "credential_files: rejected absolute config path %r", rel, ) continue - host_path = (hermes_home / rel).resolve() - try: - host_path.relative_to(hermes_home_resolved) - except ValueError: + host_path = hermes_home / rel + containment_error = validate_within_dir(host_path, hermes_home) + if containment_error: logger.warning( - "credential_files: rejected config path traversal %r " - "(resolves to %s, outside HERMES_HOME %s)", - rel, host_path, hermes_home_resolved, + "credential_files: rejected config path traversal %r (%s)", + rel, containment_error, ) continue - if host_path.is_file(): + resolved_path = host_path.resolve() + if resolved_path.is_file(): container_path = f"/root/.hermes/{rel}" result.append({ - "host_path": str(host_path), + "host_path": str(resolved_path), "container_path": container_path, }) except Exception as e: diff --git a/tools/cronjob_tools.py b/tools/cronjob_tools.py index 3018b8731f..e2db933813 100644 --- a/tools/cronjob_tools.py +++ b/tools/cronjob_tools.py @@ -165,12 +165,12 @@ def _validate_cron_script_path(script: Optional[str]) -> Optional[str]: ) # Validate containment after resolution + from tools.path_security import validate_within_dir + scripts_dir = get_hermes_home() / "scripts" scripts_dir.mkdir(parents=True, exist_ok=True) - resolved = (scripts_dir / raw).resolve() - try: - resolved.relative_to(scripts_dir.resolve()) - except ValueError: + containment_error = validate_within_dir(scripts_dir / raw, scripts_dir) + if containment_error: return ( f"Script path escapes the scripts directory via traversal: {raw!r}" ) diff --git a/tools/interrupt.py b/tools/interrupt.py index e5c9b1e27e..9bc8b83ae4 100644 --- a/tools/interrupt.py +++ b/tools/interrupt.py @@ -1,8 +1,12 @@ -"""Shared interrupt signaling for all tools. +"""Per-thread interrupt signaling for all tools. -Provides a global threading.Event that any tool can check to determine -if the user has requested an interrupt. The agent's interrupt() method -sets this event, and tools poll it during long-running operations. +Provides thread-scoped interrupt tracking so that interrupting one agent +session does not kill tools running in other sessions. This is critical +in the gateway where multiple agents run concurrently in the same process. 
+ +The agent stores its execution thread ID at the start of run_conversation() +and passes it to set_interrupt()/clear_interrupt(). Tools call +is_interrupted() which checks the CURRENT thread — no argument needed. Usage in tools: from tools.interrupt import is_interrupted @@ -12,17 +16,61 @@ Usage in tools: import threading -_interrupt_event = threading.Event() +# Set of thread idents that have been interrupted. +_interrupted_threads: set[int] = set() +_lock = threading.Lock() -def set_interrupt(active: bool) -> None: - """Called by the agent to signal or clear the interrupt.""" - if active: - _interrupt_event.set() - else: - _interrupt_event.clear() +def set_interrupt(active: bool, thread_id: int | None = None) -> None: + """Set or clear interrupt for a specific thread. + + Args: + active: True to signal interrupt, False to clear it. + thread_id: Target thread ident. When None, targets the + current thread (backward compat for CLI/tests). + """ + tid = thread_id if thread_id is not None else threading.current_thread().ident + with _lock: + if active: + _interrupted_threads.add(tid) + else: + _interrupted_threads.discard(tid) def is_interrupted() -> bool: - """Check if an interrupt has been requested. Safe to call from any thread.""" - return _interrupt_event.is_set() + """Check if an interrupt has been requested for the current thread. + + Safe to call from any thread — each thread only sees its own + interrupt state. + """ + tid = threading.current_thread().ident + with _lock: + return tid in _interrupted_threads + + +# --------------------------------------------------------------------------- +# Backward-compatible _interrupt_event proxy +# --------------------------------------------------------------------------- +# Some legacy call sites (code_execution_tool, process_registry, tests) +# import _interrupt_event directly and call .is_set() / .set() / .clear(). +# This shim maps those calls to the per-thread functions above so existing +# code keeps working while the underlying mechanism is thread-scoped. + +class _ThreadAwareEventProxy: + """Drop-in proxy that maps threading.Event methods to per-thread state.""" + + def is_set(self) -> bool: + return is_interrupted() + + def set(self) -> None: # noqa: A003 + set_interrupt(True) + + def clear(self) -> None: + set_interrupt(False) + + def wait(self, timeout: float | None = None) -> bool: + """Not truly supported — returns current state immediately.""" + return self.is_set() + + +_interrupt_event = _ThreadAwareEventProxy() diff --git a/tools/path_security.py b/tools/path_security.py new file mode 100644 index 0000000000..828011e5d7 --- /dev/null +++ b/tools/path_security.py @@ -0,0 +1,43 @@ +"""Shared path validation helpers for tool implementations. + +Extracts the ``resolve() + relative_to()`` and ``..`` traversal check +patterns previously duplicated across skill_manager_tool, skills_tool, +skills_hub, cronjob_tools, and credential_files. +""" + +import logging +from pathlib import Path +from typing import Optional + +logger = logging.getLogger(__name__) + + +def validate_within_dir(path: Path, root: Path) -> Optional[str]: + """Ensure *path* resolves to a location within *root*. + + Returns an error message string if validation fails, or ``None`` if the + path is safe. Uses ``Path.resolve()`` to follow symlinks and normalize + ``..`` components. 
+ + Usage:: + + error = validate_within_dir(user_path, allowed_root) + if error: + return json.dumps({"error": error}) + """ + try: + resolved = path.resolve() + root_resolved = root.resolve() + resolved.relative_to(root_resolved) + except (ValueError, OSError) as exc: + return f"Path escapes allowed directory: {exc}" + return None + + +def has_traversal_component(path_str: str) -> bool: + """Return True if *path_str* contains ``..`` traversal components. + + Quick check for obvious traversal attempts before doing full resolution. + """ + parts = Path(path_str).parts + return ".." in parts diff --git a/tools/process_registry.py b/tools/process_registry.py index 18d0b1de22..0e8e04b3b0 100644 --- a/tools/process_registry.py +++ b/tools/process_registry.py @@ -96,6 +96,8 @@ class ProcessSession: # Watcher/notification metadata (persisted for crash recovery) watcher_platform: str = "" watcher_chat_id: str = "" + watcher_user_id: str = "" + watcher_user_name: str = "" watcher_thread_id: str = "" watcher_interval: int = 0 # 0 = no watcher configured notify_on_complete: bool = False # Queue agent notification on exit @@ -695,7 +697,7 @@ class ProcessRegistry: and output snapshot. """ from tools.ansi_strip import strip_ansi - from tools.terminal_tool import _interrupt_event + from tools.interrupt import is_interrupted as _is_interrupted try: default_timeout = int(os.getenv("TERMINAL_TIMEOUT", "180")) @@ -732,7 +734,7 @@ class ProcessRegistry: result["timeout_note"] = timeout_note return result - if _interrupt_event.is_set(): + if _is_interrupted(): result = { "status": "interrupted", "output": strip_ansi(session.output_buffer[-1000:]), @@ -981,6 +983,8 @@ class ProcessRegistry: "session_key": s.session_key, "watcher_platform": s.watcher_platform, "watcher_chat_id": s.watcher_chat_id, + "watcher_user_id": s.watcher_user_id, + "watcher_user_name": s.watcher_user_name, "watcher_thread_id": s.watcher_thread_id, "watcher_interval": s.watcher_interval, "notify_on_complete": s.notify_on_complete, @@ -1042,6 +1046,8 @@ class ProcessRegistry: detached=True, # Can't read output, but can report status + kill watcher_platform=entry.get("watcher_platform", ""), watcher_chat_id=entry.get("watcher_chat_id", ""), + watcher_user_id=entry.get("watcher_user_id", ""), + watcher_user_name=entry.get("watcher_user_name", ""), watcher_thread_id=entry.get("watcher_thread_id", ""), watcher_interval=entry.get("watcher_interval", 0), notify_on_complete=entry.get("notify_on_complete", False), @@ -1060,6 +1066,8 @@ class ProcessRegistry: "session_key": session.session_key, "platform": session.watcher_platform, "chat_id": session.watcher_chat_id, + "user_id": session.watcher_user_id, + "user_name": session.watcher_user_name, "thread_id": session.watcher_thread_id, "notify_on_complete": session.notify_on_complete, }) diff --git a/tools/skill_manager_tool.py b/tools/skill_manager_tool.py index 2273d75fa6..2b2625fa0d 100644 --- a/tools/skill_manager_tool.py +++ b/tools/skill_manager_tool.py @@ -219,13 +219,15 @@ def _validate_file_path(file_path: str) -> Optional[str]: Validate a file path for write_file/remove_file. Must be under an allowed subdirectory and not escape the skill dir. """ + from tools.path_security import has_traversal_component + if not file_path: return "file_path is required." normalized = Path(file_path) # Prevent path traversal - if ".." in normalized.parts: + if has_traversal_component(file_path): return "Path traversal ('..') is not allowed." 
# Must be under an allowed subdirectory @@ -242,15 +244,12 @@ def _validate_file_path(file_path: str) -> Optional[str]: def _resolve_skill_target(skill_dir: Path, file_path: str) -> Tuple[Optional[Path], Optional[str]]: """Resolve a supporting-file path and ensure it stays within the skill directory.""" + from tools.path_security import validate_within_dir + target = skill_dir / file_path - try: - resolved = target.resolve(strict=False) - skill_dir_resolved = skill_dir.resolve() - resolved.relative_to(skill_dir_resolved) - except ValueError: - return None, "Path escapes skill directory boundary." - except OSError as e: - return None, f"Invalid file path '{file_path}': {e}" + error = validate_within_dir(target, skill_dir) + if error: + return None, error return target, None diff --git a/tools/skills_tool.py b/tools/skills_tool.py index 085ed00550..94b7c235b7 100644 --- a/tools/skills_tool.py +++ b/tools/skills_tool.py @@ -447,17 +447,8 @@ def _get_category_from_path(skill_path: Path) -> Optional[str]: return None -def _estimate_tokens(content: str) -> int: - """ - Rough token estimate (4 chars per token average). - - Args: - content: Text content - - Returns: - Estimated token count - """ - return len(content) // 4 +# Token estimation — use the shared implementation from model_metadata. +from agent.model_metadata import estimate_tokens_rough as _estimate_tokens def _parse_tags(tags_value) -> List[str]: @@ -947,9 +938,10 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: # If a specific file path is requested, read that instead if file_path and skill_dir: + from tools.path_security import validate_within_dir, has_traversal_component + # Security: Prevent path traversal attacks - normalized_path = Path(file_path) - if ".." 
in normalized_path.parts: + if has_traversal_component(file_path): return json.dumps( { "success": False, @@ -962,24 +954,13 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: target_file = skill_dir / file_path # Security: Verify resolved path is still within skill directory - try: - resolved = target_file.resolve() - skill_dir_resolved = skill_dir.resolve() - if not resolved.is_relative_to(skill_dir_resolved): - return json.dumps( - { - "success": False, - "error": "Path escapes skill directory boundary.", - "hint": "Use a relative path within the skill directory", - }, - ensure_ascii=False, - ) - except (OSError, ValueError): + traversal_error = validate_within_dir(target_file, skill_dir) + if traversal_error: return json.dumps( { "success": False, - "error": f"Invalid file path: '{file_path}'", - "hint": "Use a valid relative path within the skill directory", + "error": traversal_error, + "hint": "Use a relative path within the skill directory", }, ensure_ascii=False, ) diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index 859f0f1f36..f0cbff0f4c 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -1427,8 +1427,12 @@ def terminal_tool( if _gw_platform and not check_interval: _gw_chat_id = _gse("HERMES_SESSION_CHAT_ID", "") _gw_thread_id = _gse("HERMES_SESSION_THREAD_ID", "") + _gw_user_id = _gse("HERMES_SESSION_USER_ID", "") + _gw_user_name = _gse("HERMES_SESSION_USER_NAME", "") proc_session.watcher_platform = _gw_platform proc_session.watcher_chat_id = _gw_chat_id + proc_session.watcher_user_id = _gw_user_id + proc_session.watcher_user_name = _gw_user_name proc_session.watcher_thread_id = _gw_thread_id proc_session.watcher_interval = 5 process_registry.pending_watchers.append({ @@ -1437,6 +1441,8 @@ def terminal_tool( "session_key": session_key, "platform": _gw_platform, "chat_id": _gw_chat_id, + "user_id": _gw_user_id, + "user_name": _gw_user_name, "thread_id": _gw_thread_id, "notify_on_complete": True, }) @@ -1457,10 +1463,14 @@ def terminal_tool( watcher_platform = _gse2("HERMES_SESSION_PLATFORM", "") watcher_chat_id = _gse2("HERMES_SESSION_CHAT_ID", "") watcher_thread_id = _gse2("HERMES_SESSION_THREAD_ID", "") + watcher_user_id = _gse2("HERMES_SESSION_USER_ID", "") + watcher_user_name = _gse2("HERMES_SESSION_USER_NAME", "") # Store on session for checkpoint persistence proc_session.watcher_platform = watcher_platform proc_session.watcher_chat_id = watcher_chat_id + proc_session.watcher_user_id = watcher_user_id + proc_session.watcher_user_name = watcher_user_name proc_session.watcher_thread_id = watcher_thread_id proc_session.watcher_interval = effective_interval @@ -1470,6 +1480,8 @@ def terminal_tool( "session_key": session_key, "platform": watcher_platform, "chat_id": watcher_chat_id, + "user_id": watcher_user_id, + "user_name": watcher_user_name, "thread_id": watcher_thread_id, }) diff --git a/tools/tool_result_storage.py b/tools/tool_result_storage.py index a8ec5440bc..4342264482 100644 --- a/tools/tool_result_storage.py +++ b/tools/tool_result_storage.py @@ -24,6 +24,7 @@ Defense against context-window overflow operates at three levels: import logging import os +import shlex import uuid from tools.budget_config import ( @@ -79,7 +80,7 @@ def _write_to_sandbox(content: str, remote_path: str, env) -> bool: marker = _heredoc_marker(content) storage_dir = os.path.dirname(remote_path) cmd = ( - f"mkdir -p {storage_dir} && cat > {remote_path} << '{marker}'\n" + f"mkdir -p {shlex.quote(storage_dir)} && cat > 
{shlex.quote(remote_path)} << '{marker}'\n" f"{content}\n" f"{marker}" ) diff --git a/utils.py b/utils.py index 9a2105d54f..bd2a6b70f5 100644 --- a/utils.py +++ b/utils.py @@ -1,13 +1,16 @@ """Shared utility functions for hermes-agent.""" import json +import logging import os import tempfile from pathlib import Path -from typing import Any, Union +from typing import Any, List, Optional, Union import yaml +logger = logging.getLogger(__name__) + TRUTHY_STRINGS = frozenset({"1", "true", "yes", "on"}) @@ -124,3 +127,88 @@ def atomic_yaml_write( except OSError: pass raise + + +# ─── JSON Helpers ───────────────────────────────────────────────────────────── + + +def safe_json_loads(text: str, default: Any = None) -> Any: + """Parse JSON, returning *default* on any parse error. + + Replaces the ``try: json.loads(x) except (JSONDecodeError, TypeError)`` + pattern duplicated across display.py, anthropic_adapter.py, + auxiliary_client.py, and others. + """ + try: + return json.loads(text) + except (json.JSONDecodeError, TypeError, ValueError): + return default + + +def read_json_file(path: Path, default: Any = None) -> Any: + """Read and parse a JSON file, returning *default* on any error. + + Replaces the repeated ``try: json.loads(path.read_text()) except ...`` + pattern in anthropic_adapter.py, auxiliary_client.py, credential_pool.py, + and skill_utils.py. + """ + try: + return json.loads(Path(path).read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError, IOError, ValueError) as exc: + logger.debug("Failed to read %s: %s", path, exc) + return default + + +def read_jsonl(path: Path) -> List[dict]: + """Read a JSONL file (one JSON object per line). + + Returns a list of parsed objects, skipping blank lines. + """ + entries = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + entries.append(json.loads(line)) + return entries + + +def append_jsonl(path: Path, entry: dict) -> None: + """Append a single JSON object as a new line to a JSONL file.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "a", encoding="utf-8") as f: + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + + +# ─── Environment Variable Helpers ───────────────────────────────────────────── + + +def env_str(key: str, default: str = "") -> str: + """Read an environment variable, stripped of whitespace. + + Replaces the ``os.getenv("X", "").strip()`` pattern repeated 50+ times + across runtime_provider.py, anthropic_adapter.py, models.py, etc. + """ + return os.getenv(key, default).strip() + + +def env_lower(key: str, default: str = "") -> str: + """Read an environment variable, stripped and lowercased.""" + return os.getenv(key, default).strip().lower() + + +def env_int(key: str, default: int = 0) -> int: + """Read an environment variable as an integer, with fallback.""" + raw = os.getenv(key, "").strip() + if not raw: + return default + try: + return int(raw) + except (ValueError, TypeError): + return default + + +def env_bool(key: str, default: bool = False) -> bool: + """Read an environment variable as a boolean.""" + return is_truthy_value(os.getenv(key, ""), default=default) diff --git a/website/docs/reference/environment-variables.md b/website/docs/reference/environment-variables.md index 958faa61f9..a548a6ff6d 100644 --- a/website/docs/reference/environment-variables.md +++ b/website/docs/reference/environment-variables.md @@ -195,9 +195,12 @@ For cloud sandbox backends, persistence is filesystem-oriented. 
`TERMINAL_LIFETI | `SIGNAL_IGNORE_STORIES` | Ignore Signal stories/status updates | | `SIGNAL_ALLOW_ALL_USERS` | Allow all Signal users without an allowlist | | `TWILIO_ACCOUNT_SID` | Twilio Account SID (shared with telephony skill) | -| `TWILIO_AUTH_TOKEN` | Twilio Auth Token (shared with telephony skill) | +| `TWILIO_AUTH_TOKEN` | Twilio Auth Token (shared with telephony skill; also used for webhook signature validation) | | `TWILIO_PHONE_NUMBER` | Twilio phone number in E.164 format (shared with telephony skill) | +| `SMS_WEBHOOK_URL` | Public URL for Twilio signature validation — must match the webhook URL in Twilio Console (required) | | `SMS_WEBHOOK_PORT` | Webhook listener port for inbound SMS (default: `8080`) | +| `SMS_WEBHOOK_HOST` | Webhook bind address (default: `0.0.0.0`) | +| `SMS_INSECURE_NO_SIGNATURE` | Set to `true` to disable Twilio signature validation (local dev only — not for production) | | `SMS_ALLOWED_USERS` | Comma-separated E.164 phone numbers allowed to chat | | `SMS_ALLOW_ALL_USERS` | Allow all SMS senders without an allowlist | | `SMS_HOME_CHANNEL` | Phone number for cron job / notification delivery | diff --git a/website/docs/user-guide/messaging/index.md b/website/docs/user-guide/messaging/index.md index 335c6530bc..41b0314379 100644 --- a/website/docs/user-guide/messaging/index.md +++ b/website/docs/user-guide/messaging/index.md @@ -178,6 +178,8 @@ EMAIL_ALLOWED_USERS=trusted@example.com,colleague@work.com MATTERMOST_ALLOWED_USERS=3uo8dkh1p7g1mfk49ear5fzs5c MATRIX_ALLOWED_USERS=@alice:matrix.org DINGTALK_ALLOWED_USERS=user-id-1 +FEISHU_ALLOWED_USERS=ou_xxxxxxxx,ou_yyyyyyyy +WECOM_ALLOWED_USERS=user-id-1,user-id-2 # Or allow GATEWAY_ALLOWED_USERS=123456789,987654321 diff --git a/website/docs/user-guide/messaging/sms.md b/website/docs/user-guide/messaging/sms.md index 84a3b8fa2f..c5b28cd6fd 100644 --- a/website/docs/user-guide/messaging/sms.md +++ b/website/docs/user-guide/messaging/sms.md @@ -84,6 +84,13 @@ ngrok http 8080 Set the resulting public URL as your Twilio webhook. ::: +**Set `SMS_WEBHOOK_URL` to the same URL you configured in Twilio.** This is required for Twilio signature validation — the adapter will refuse to start without it: + +```bash +# Must match the webhook URL in your Twilio Console +SMS_WEBHOOK_URL=https://your-server:8080/webhooks/twilio +``` + The webhook port defaults to `8080`. Override with: ```bash @@ -101,9 +108,11 @@ hermes gateway You should see: ``` -[sms] Twilio webhook server listening on port 8080, from: +1555***4567 +[sms] Twilio webhook server listening on 0.0.0.0:8080, from: +1555***4567 ``` +If you see `Refusing to start: SMS_WEBHOOK_URL is required`, set `SMS_WEBHOOK_URL` to the public URL configured in your Twilio Console (see Step 3). + Text your Twilio number — Hermes will respond via SMS. --- @@ -113,9 +122,12 @@ Text your Twilio number — Hermes will respond via SMS. 
| Variable | Required | Description | |----------|----------|-------------| | `TWILIO_ACCOUNT_SID` | Yes | Twilio Account SID (starts with `AC`) | -| `TWILIO_AUTH_TOKEN` | Yes | Twilio Auth Token | +| `TWILIO_AUTH_TOKEN` | Yes | Twilio Auth Token (also used for webhook signature validation) | | `TWILIO_PHONE_NUMBER` | Yes | Your Twilio phone number (E.164 format) | +| `SMS_WEBHOOK_URL` | Yes | Public URL for Twilio signature validation — must match the webhook URL in your Twilio Console | | `SMS_WEBHOOK_PORT` | No | Webhook listener port (default: `8080`) | +| `SMS_WEBHOOK_HOST` | No | Webhook bind address (default: `0.0.0.0`) | +| `SMS_INSECURE_NO_SIGNATURE` | No | Set to `true` to disable signature validation (local dev only — **not for production**) | | `SMS_ALLOWED_USERS` | No | Comma-separated E.164 phone numbers allowed to chat | | `SMS_ALLOW_ALL_USERS` | No | Set to `true` to allow anyone (not recommended) | | `SMS_HOME_CHANNEL` | No | Phone number for cron job / notification delivery | @@ -134,6 +146,21 @@ Text your Twilio number — Hermes will respond via SMS. ## Security +### Webhook signature validation + +Hermes validates that inbound webhooks genuinely originate from Twilio by verifying the `X-Twilio-Signature` header (HMAC-SHA1). This prevents attackers from injecting forged messages. + +**`SMS_WEBHOOK_URL` is required.** Set it to the public URL configured in your Twilio Console. The adapter will refuse to start without it. + +For local development without a public URL, you can disable validation: + +```bash +# Local dev only — NOT for production +SMS_INSECURE_NO_SIGNATURE=true +``` + +### User allowlists + **The gateway denies all users by default.** Configure an allowlist: ```bash