# ─── gateway/platforms/helpers.py (new file) ─────────────────────────────────
"""Shared helper classes for gateway platform adapters.

Extracts common patterns that were duplicated across 5-7 adapters:
message deduplication, text batch aggregation, markdown stripping,
and thread participation tracking.
"""

import asyncio
import json
import logging
import re
import time
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional

if TYPE_CHECKING:
    from gateway.platforms.base import BasePlatformAdapter, MessageEvent

logger = logging.getLogger(__name__)


# ─── Message Deduplication ───────────────────────────────────────────────────


class MessageDeduplicator:
    """TTL-based message deduplication cache.

    Replaces the identical ``_seen_messages`` / ``_is_duplicate()`` pattern
    previously duplicated in discord, slack, dingtalk, wecom, weixin,
    mattermost, and feishu adapters.

    Usage::

        self._dedup = MessageDeduplicator()

        # In message handler:
        if self._dedup.is_duplicate(msg_id):
            return
    """

    def __init__(self, max_size: int = 2000, ttl_seconds: float = 300):
        # msg_id -> timestamp of the most recent (non-duplicate) sighting
        self._seen: Dict[str, float] = {}
        self._max_size = max_size
        self._ttl = ttl_seconds

    def is_duplicate(self, msg_id: str) -> bool:
        """Return True if *msg_id* was already seen within the TTL window.

        A falsy *msg_id* is never considered a duplicate.  Seeing a new (or
        expired) id records it with the current timestamp.
        """
        if not msg_id:
            return False
        now = time.time()
        seen_at = self._seen.get(msg_id)
        # BUGFIX: honor the TTL on lookup.  The original returned True for
        # any previously-seen id regardless of age, so an entry older than
        # the window was still reported as a duplicate even though the
        # docstring promises a TTL-bounded cache.
        if seen_at is not None and now - seen_at < self._ttl:
            return True
        self._seen[msg_id] = now
        # Opportunistic pruning keeps the cache bounded without a timer.
        if len(self._seen) > self._max_size:
            cutoff = now - self._ttl
            self._seen = {k: v for k, v in self._seen.items() if v > cutoff}
        return False

    def clear(self):
        """Clear all tracked messages."""
        self._seen.clear()
# ─── Text Batch Aggregation ──────────────────────────────────────────────────


class TextBatchAggregator:
    """Aggregates rapid-fire text events into single messages.

    Replaces the ``_enqueue_text_event`` / ``_flush_text_batch`` pattern
    previously duplicated in telegram, discord, matrix, wecom, and feishu.

    Usage::

        self._text_batcher = TextBatchAggregator(
            handler=self._message_handler,
            batch_delay=0.6,
            split_threshold=1900,
        )

        # In message dispatch:
        if msg_type == MessageType.TEXT and self._text_batcher.is_enabled():
            self._text_batcher.enqueue(event, session_key)
            return
    """

    def __init__(
        self,
        handler,
        *,
        batch_delay: float = 0.6,
        split_delay: float = 2.0,
        split_threshold: int = 4000,
    ):
        # handler: async callable invoked with the merged MessageEvent.
        self._handler = handler
        self._batch_delay = batch_delay
        self._split_delay = split_delay
        self._split_threshold = split_threshold
        # key -> accumulated event, and key -> pending flush timer task
        self._pending: Dict[str, "MessageEvent"] = {}
        self._pending_tasks: Dict[str, asyncio.Task] = {}

    def is_enabled(self) -> bool:
        """Return True if batching is active (delay > 0)."""
        return self._batch_delay > 0

    def enqueue(self, event: "MessageEvent", key: str) -> None:
        """Add *event* to the pending batch for *key* and (re)arm the timer."""
        chunk_len = len(event.text or "")
        existing = self._pending.get(key)
        if not existing:
            event._last_chunk_len = chunk_len  # type: ignore[attr-defined]
            self._pending[key] = event
        else:
            # Merge text into the first event of the batch; remember the
            # size of the newest chunk so _flush can detect split messages.
            existing.text = f"{existing.text}\n{event.text}"
            existing._last_chunk_len = chunk_len  # type: ignore[attr-defined]

        # Cancel prior flush timer, start a new one
        prior = self._pending_tasks.get(key)
        if prior and not prior.done():
            prior.cancel()
        self._pending_tasks[key] = asyncio.create_task(self._flush(key))

    async def _flush(self, key: str) -> None:
        """Wait then dispatch the batched event for *key*."""
        current_task = self._pending_tasks.get(key)
        pending = self._pending.get(key)
        last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0

        # Use longer delay when the last chunk looks like a split message
        delay = self._split_delay if last_len >= self._split_threshold else self._batch_delay
        await asyncio.sleep(delay)

        event = self._pending.pop(key, None)
        if event:
            try:
                await self._handler(event)
            except Exception:
                logger.exception("[TextBatchAggregator] Error dispatching batched event for %s", key)

        # Only the task that is still the registered timer cleans up its
        # slot; a superseded (cancelled-and-replaced) task must not remove
        # the newer timer.
        if self._pending_tasks.get(key) is current_task:
            self._pending_tasks.pop(key, None)

    def cancel_all(self) -> None:
        """Cancel all pending flush tasks."""
        for task in self._pending_tasks.values():
            if not task.done():
                task.cancel()
        self._pending_tasks.clear()
        self._pending.clear()


# ─── Markdown Stripping ──────────────────────────────────────────────────────

# Pre-compiled regexes for performance
_RE_BOLD = re.compile(r"\*\*(.+?)\*\*", re.DOTALL)
# BUGFIX: single-delimiter emphasis requires a non-word boundary on both
# sides so identifiers like ``snake_case_name`` and arithmetic like
# ``2*3*4`` survive intact (the originals mangled them).
_RE_ITALIC_STAR = re.compile(r"(?<!\w)\*(.+?)\*(?!\w)", re.DOTALL)
_RE_BOLD_UNDER = re.compile(r"__(.+?)__", re.DOTALL)
_RE_ITALIC_UNDER = re.compile(r"(?<!\w)_(.+?)_(?!\w)", re.DOTALL)
_RE_CODE_BLOCK = re.compile(r"```[a-zA-Z0-9_+-]*\n?")
_RE_INLINE_CODE = re.compile(r"`(.+?)`")
_RE_HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE)
_RE_LINK = re.compile(r"\[([^\]]+)\]\([^\)]+\)")
_RE_MULTI_NEWLINE = re.compile(r"\n{3,}")


def strip_markdown(text: str) -> str:
    """Strip markdown formatting for plain-text platforms (SMS, iMessage, etc.).

    Replaces the identical ``_strip_markdown()`` functions previously
    duplicated in sms.py, bluebubbles.py, and feishu.py.

    NOTE(review): ``__dunder__`` names still match the bold-underscore
    pattern, as in the original implementations.
    """
    text = _RE_BOLD.sub(r"\1", text)
    text = _RE_ITALIC_STAR.sub(r"\1", text)
    text = _RE_BOLD_UNDER.sub(r"\1", text)
    text = _RE_ITALIC_UNDER.sub(r"\1", text)
    text = _RE_CODE_BLOCK.sub("", text)
    text = _RE_INLINE_CODE.sub(r"\1", text)
    text = _RE_HEADING.sub("", text)
    text = _RE_LINK.sub(r"\1", text)
    text = _RE_MULTI_NEWLINE.sub("\n\n", text)
    return text.strip()
# ─── Thread Participation Tracking ───────────────────────────────────────────


class ThreadParticipationTracker:
    """Persistent tracking of threads the bot has participated in.

    Replaces the identical ``_load/_save_participated_threads`` +
    ``_mark_thread_participated`` pattern previously duplicated in
    discord.py and matrix.py.

    Usage::

        self._threads = ThreadParticipationTracker("discord")

        # Check membership:
        if thread_id in self._threads:
            ...

        # Mark participation:
        self._threads.mark(thread_id)
    """

    # Default cap on persisted thread ids.  CONSISTENCY FIX: used as the
    # ``max_tracked`` default so the constant and the literal can never
    # drift apart (previously the literal 500 was duplicated and this
    # constant was dead).
    _MAX_TRACKED = 500

    def __init__(self, platform_name: str, max_tracked: int = _MAX_TRACKED):
        self._platform = platform_name
        self._max_tracked = max_tracked
        # Seed from the persisted JSON state file, if present.
        self._threads: set = self._load()

    def _state_path(self) -> Path:
        """Return the per-platform JSON state file path under HERMES_HOME."""
        from hermes_constants import get_hermes_home
        return get_hermes_home() / "platforms" / f"{self._platform}_threads.json"

    def _load(self) -> set:
        """Load the persisted thread-id set; best-effort (corrupt file -> empty)."""
        path = self._state_path()
        if path.exists():
            try:
                return set(json.loads(path.read_text(encoding="utf-8")))
            except Exception:
                pass
        return set()

    def _save(self) -> None:
        """Persist the thread-id set, trimming to ``max_tracked`` entries.

        NOTE(review): sets are unordered, so when trimming kicks in an
        arbitrary subset is dropped rather than the oldest entries — this
        matches the behavior of the original duplicated implementations.
        """
        path = self._state_path()
        path.parent.mkdir(parents=True, exist_ok=True)
        thread_list = list(self._threads)
        if len(thread_list) > self._max_tracked:
            thread_list = thread_list[-self._max_tracked:]
            self._threads = set(thread_list)
        path.write_text(json.dumps(thread_list), encoding="utf-8")

    def mark(self, thread_id: str) -> None:
        """Mark *thread_id* as participated and persist (no-op if already known)."""
        if thread_id not in self._threads:
            self._threads.add(thread_id)
            self._save()

    def __contains__(self, thread_id: str) -> bool:
        return thread_id in self._threads

    def clear(self) -> None:
        """Forget all tracked threads (in memory only; state file untouched)."""
        self._threads.clear()


# ─── Phone Number Redaction ──────────────────────────────────────────────────


def redact_phone(phone: str) -> str:
    """Redact a phone number for logging, preserving country code and last 4.

    Replaces the identical ``_redact_phone()`` functions in signal.py,
    sms.py, and bluebubbles.py.

    Short numbers (5-8 chars) keep only 2 leading/trailing characters;
    very short ones (<= 4) are fully masked.  (The original expressed the
    short branch as a precedence-dependent ternary; behavior is unchanged.)
    """
    if not phone:
        return ""
    if len(phone) <= 4:
        return "****"
    if len(phone) <= 8:
        return phone[:2] + "****" + phone[-2:]
    return phone[:4] + "****" + phone[-4:]
# ─── hermes_cli/cli_output.py (new file) ─────────────────────────────────────
"""Shared CLI output helpers for Hermes CLI modules.

Extracts the identical ``print_info/success/warning/error`` and ``prompt()``
functions previously duplicated across setup.py, tools_config.py,
mcp_config.py, and memory_setup.py.
"""

import getpass
import sys

from hermes_cli.colors import Colors, color


# ─── Print Helpers ───────────────────────────────────────────────────────────


def _emit(prefix: str, text: str, tone) -> None:
    """Print *text* behind *prefix*, colorized with *tone*."""
    print(color(f"{prefix}{text}", tone))


def print_info(text: str) -> None:
    """Print a dim informational message."""
    _emit(" ", text, Colors.DIM)


def print_success(text: str) -> None:
    """Print a green success message with ✓ prefix."""
    _emit("✓ ", text, Colors.GREEN)


def print_warning(text: str) -> None:
    """Print a yellow warning message with ⚠ prefix."""
    _emit("⚠ ", text, Colors.YELLOW)


def print_error(text: str) -> None:
    """Print a red error message with ✗ prefix."""
    _emit("✗ ", text, Colors.RED)


def print_header(text: str) -> None:
    """Print a bold yellow header."""
    _emit("\n ", text, Colors.YELLOW)


# ─── Input Prompts ───────────────────────────────────────────────────────────


def prompt(
    question: str,
    default: str | None = None,
    password: bool = False,
) -> str:
    """Prompt the user for input with optional default and password masking.

    Replaces the four independent ``_prompt()`` / ``prompt()`` implementations
    in setup.py, tools_config.py, mcp_config.py, and memory_setup.py.

    Returns the user's input (stripped), or *default* if the user presses
    Enter.  Returns empty string on Ctrl-C or EOF.
    """
    suffix = f" [{default}]" if default else ""
    display = color(f" {question}{suffix}: ", Colors.DIM)

    try:
        raw = getpass.getpass(display) if password else input(display)
    except (KeyboardInterrupt, EOFError):
        print()
        return ""
    raw = raw.strip()
    return raw if raw else (default or "")
+ """ + suffix = f" [{default}]" if default else "" + display = color(f" {question}{suffix}: ", Colors.DIM) + + try: + if password: + value = getpass.getpass(display) + else: + value = input(display) + value = value.strip() + return value if value else (default or "") + except (KeyboardInterrupt, EOFError): + print() + return "" + + +def prompt_yes_no(question: str, default: bool = True) -> bool: + """Prompt for a yes/no answer. Returns bool.""" + hint = "Y/n" if default else "y/N" + answer = prompt(f"{question} ({hint})") + if not answer: + return default + return answer.lower().startswith("y") diff --git a/hermes_constants.py b/hermes_constants.py index 7d149f404e..85955d5482 100644 --- a/hermes_constants.py +++ b/hermes_constants.py @@ -189,6 +189,33 @@ def is_wsl() -> bool: return _wsl_detected +# ─── Well-Known Paths ───────────────────────────────────────────────────────── + + +def get_config_path() -> Path: + """Return the path to ``config.yaml`` under HERMES_HOME. + + Replaces the ``get_hermes_home() / "config.yaml"`` pattern repeated + in 7+ files (skill_utils.py, hermes_logging.py, hermes_time.py, etc.). + """ + return get_hermes_home() / "config.yaml" + + +def get_skills_dir() -> Path: + """Return the path to the skills directory under HERMES_HOME.""" + return get_hermes_home() / "skills" + + +def get_logs_dir() -> Path: + """Return the path to the logs directory under HERMES_HOME.""" + return get_hermes_home() / "logs" + + +def get_env_path() -> Path: + """Return the path to the ``.env`` file under HERMES_HOME.""" + return get_hermes_home() / ".env" + + OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1" OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models" diff --git a/tools/path_security.py b/tools/path_security.py new file mode 100644 index 0000000000..828011e5d7 --- /dev/null +++ b/tools/path_security.py @@ -0,0 +1,43 @@ +"""Shared path validation helpers for tool implementations. 
# ─── tools/path_security.py (new file) ───────────────────────────────────────
"""Shared path validation helpers for tool implementations.

Extracts the ``resolve() + relative_to()`` and ``..`` traversal check
patterns previously duplicated across skill_manager_tool, skills_tool,
skills_hub, cronjob_tools, and credential_files.
"""

import logging
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)


def validate_within_dir(path: Path, root: Path) -> Optional[str]:
    """Ensure *path* resolves to a location within *root*.

    Returns an error message string if validation fails, or ``None`` if the
    path is safe. Uses ``Path.resolve()`` to follow symlinks and normalize
    ``..`` components.

    Usage::

        error = validate_within_dir(user_path, allowed_root)
        if error:
            return json.dumps({"error": error})
    """
    try:
        # relative_to raises ValueError when the resolved path is outside
        # the resolved root; resolve() itself may raise OSError.
        path.resolve().relative_to(root.resolve())
    except (ValueError, OSError) as exc:
        return f"Path escapes allowed directory: {exc}"
    return None


def has_traversal_component(path_str: str) -> bool:
    """Return True if *path_str* contains ``..`` traversal components.

    Quick check for obvious traversal attempts before doing full resolution.
    """
    return any(part == ".." for part in Path(path_str).parts)


# ─── utils.py (modified) ─────────────────────────────────────────────────────
# Import-block additions (the file already imports json/os/tempfile/yaml
# and typing's Any/Union):
import logging
from typing import Any, List, Optional

logger = logging.getLogger(__name__)


# ─── JSON Helpers ────────────────────────────────────────────────────────────


def safe_json_loads(text: str, default: Any = None) -> Any:
    """Parse JSON, returning *default* on any parse error.

    Replaces the ``try: json.loads(x) except (JSONDecodeError, TypeError)``
    pattern duplicated across display.py, anthropic_adapter.py,
    auxiliary_client.py, and others.
    """
    try:
        return json.loads(text)
    except (TypeError, ValueError):
        # json.JSONDecodeError subclasses ValueError, so this tuple is
        # equivalent to the three-way original.
        return default
# (utils.py continued)
logger = logging.getLogger(__name__)


def read_json_file(path: Path, default: Any = None) -> Any:
    """Read and parse a JSON file, returning *default* on any error.

    Replaces the repeated ``try: json.loads(path.read_text()) except ...``
    pattern in anthropic_adapter.py, auxiliary_client.py, credential_pool.py,
    and skill_utils.py.
    """
    try:
        return json.loads(Path(path).read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:
        # IOError has been an alias of OSError since Python 3.3 and
        # json.JSONDecodeError subclasses ValueError, so the original
        # four-way tuple collapses to two classes with identical behavior.
        logger.debug("Failed to read %s: %s", path, exc)
        return default


def read_jsonl(path: Path) -> List[dict]:
    """Read a JSONL file (one JSON object per line).

    Returns a list of parsed objects, skipping blank lines.  Raises
    ``OSError`` for a missing file and ``json.JSONDecodeError`` for a
    malformed line — callers decide the error policy.
    """
    entries: List[dict] = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                entries.append(json.loads(line))
    return entries


def append_jsonl(path: Path, entry: dict) -> None:
    """Append a single JSON object as a new line to a JSONL file.

    Creates parent directories as needed.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry, ensure_ascii=False) + "\n")


# ─── Environment Variable Helpers ────────────────────────────────────────────


def env_str(key: str, default: str = "") -> str:
    """Read an environment variable, stripped of whitespace.

    Replaces the ``os.getenv("X", "").strip()`` pattern repeated 50+ times
    across runtime_provider.py, anthropic_adapter.py, models.py, etc.
    """
    return os.getenv(key, default).strip()
def env_lower(key: str, default: str = "") -> str:
    """Read an environment variable, stripped and lowercased."""
    value = os.getenv(key, default)
    return value.strip().lower()


def env_int(key: str, default: int = 0) -> int:
    """Read an environment variable as an integer, with fallback."""
    raw = os.getenv(key, "").strip()
    try:
        # Empty / unset values short-circuit to the default without parsing.
        return int(raw) if raw else default
    except (ValueError, TypeError):
        return default


def env_bool(key: str, default: bool = False) -> bool:
    """Read an environment variable as a boolean."""
    raw = os.getenv(key, "")
    return is_truthy_value(raw, default=default)