From 6f1889b0fa228dc74af0efe7b3c804f3c3c725d2 Mon Sep 17 00:00:00 2001 From: teknium1 Date: Sat, 14 Mar 2026 00:17:04 -0700 Subject: [PATCH] fix: preserve current approval semantics for tirith guard Restore gateway/run.py to current main behavior while keeping tirith startup and pattern_keys replay, preserve yolo and non-interactive bypass semantics in the combined guard, and add regression tests for yolo and view-full flows. --- gateway/run.py | 1883 +++++++++++++++++++++++++++- tests/tools/test_approval.py | 12 + tests/tools/test_command_guards.py | 17 +- tests/tools/test_yolo_mode.py | 38 +- tools/approval.py | 22 +- 5 files changed, 1959 insertions(+), 13 deletions(-) diff --git a/gateway/run.py b/gateway/run.py index 11106584dd..1b7a2ed6ec 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -248,8 +248,8 @@ class GatewayRunner: self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt # Track pending exec approvals per session - # Key: session_key, Value: {"command": str, "pattern_key": str} - self._pending_approvals: Dict[str, Dict[str, str]] = {} + # Key: session_key, Value: {"command": str, "pattern_key": str, ...} + self._pending_approvals: Dict[str, Dict[str, Any]] = {} # Persistent Honcho managers keyed by gateway session key. # This preserves write_frequency="session" semantics across short-lived @@ -1996,3 +1996,1882 @@ class GatewayRunner: """Handle /undo command - remove the last user/assistant exchange.""" source = event.source session_entry = self.session_store.get_or_create_session(source) + history = self.session_store.load_transcript(session_entry.session_id) + + # Find the last user message and remove everything from it onward + last_user_idx = None + for i in range(len(history) - 1, -1, -1): + if history[i].get("role") == "user": + last_user_idx = i + break + + if last_user_idx is None: + return "Nothing to undo." + + removed_msg = history[last_user_idx].get("content", "") + removed_count = len(history) - last_user_idx + self.session_store.rewrite_transcript(session_entry.session_id, history[:last_user_idx]) + # Reset stored token count — transcript was truncated + session_entry.last_prompt_tokens = 0 + + preview = removed_msg[:40] + "..." if len(removed_msg) > 40 else removed_msg + return f"↩️ Undid {removed_count} message(s).\nRemoved: \"{preview}\"" + + async def _handle_set_home_command(self, event: MessageEvent) -> str: + """Handle /sethome command -- set the current chat as the platform's home channel.""" + source = event.source + platform_name = source.platform.value if source.platform else "unknown" + chat_id = source.chat_id + chat_name = source.chat_name or chat_id + + env_key = f"{platform_name.upper()}_HOME_CHANNEL" + + # Save to config.yaml + try: + import yaml + config_path = _hermes_home / 'config.yaml' + user_config = {} + if config_path.exists(): + with open(config_path, encoding="utf-8") as f: + user_config = yaml.safe_load(f) or {} + user_config[env_key] = chat_id + with open(config_path, 'w', encoding="utf-8") as f: + yaml.dump(user_config, f, default_flow_style=False) + # Also set in the current environment so it takes effect immediately + os.environ[env_key] = str(chat_id) + except Exception as e: + return f"Failed to save home channel: {e}" + + return ( + f"✅ Home channel set to **{chat_name}** (ID: {chat_id}).\n" + f"Cron jobs and cross-platform messages will be delivered here." + ) + + async def _handle_rollback_command(self, event: MessageEvent) -> str: + """Handle /rollback command — list or restore filesystem checkpoints.""" + from tools.checkpoint_manager import CheckpointManager, format_checkpoint_list + + # Read checkpoint config from config.yaml + cp_cfg = {} + try: + import yaml as _y + _cfg_path = _hermes_home / "config.yaml" + if _cfg_path.exists(): + with open(_cfg_path, encoding="utf-8") as _f: + _data = _y.safe_load(_f) or {} + cp_cfg = _data.get("checkpoints", {}) + if isinstance(cp_cfg, bool): + cp_cfg = {"enabled": cp_cfg} + except Exception: + pass + + if not cp_cfg.get("enabled", False): + return ( + "Checkpoints are not enabled.\n" + "Enable in config.yaml:\n```\ncheckpoints:\n enabled: true\n```" + ) + + mgr = CheckpointManager( + enabled=True, + max_snapshots=cp_cfg.get("max_snapshots", 50), + ) + + cwd = os.getenv("MESSAGING_CWD", str(Path.home())) + arg = event.get_command_args().strip() + + if not arg: + checkpoints = mgr.list_checkpoints(cwd) + return format_checkpoint_list(checkpoints, cwd) + + # Restore by number or hash + checkpoints = mgr.list_checkpoints(cwd) + if not checkpoints: + return f"No checkpoints found for {cwd}" + + target_hash = None + try: + idx = int(arg) - 1 + if 0 <= idx < len(checkpoints): + target_hash = checkpoints[idx]["hash"] + else: + return f"Invalid checkpoint number. Use 1-{len(checkpoints)}." + except ValueError: + target_hash = arg + + result = mgr.restore(cwd, target_hash) + if result["success"]: + return ( + f"✅ Restored to checkpoint {result['restored_to']}: {result['reason']}\n" + f"A pre-rollback snapshot was saved automatically." + ) + return f"❌ {result['error']}" + + async def _handle_background_command(self, event: MessageEvent) -> str: + """Handle /background — run a prompt in a separate background session. + + Spawns a new AIAgent in a background thread with its own session. + When it completes, sends the result back to the same chat without + modifying the active session's conversation history. + """ + prompt = event.get_command_args().strip() + if not prompt: + return ( + "Usage: /background \n" + "Example: /background Summarize the top HN stories today\n\n" + "Runs the prompt in a separate session. " + "You can keep chatting — the result will appear here when done." + ) + + source = event.source + task_id = f"bg_{datetime.now().strftime('%H%M%S')}_{os.urandom(3).hex()}" + + # Fire-and-forget the background task + asyncio.create_task( + self._run_background_task(prompt, source, task_id) + ) + + preview = prompt[:60] + ("..." if len(prompt) > 60 else "") + return f'🔄 Background task started: "{preview}"\nTask ID: {task_id}\nYou can keep chatting — results will appear when done.' + + async def _run_background_task( + self, prompt: str, source: "SessionSource", task_id: str + ) -> None: + """Execute a background agent task and deliver the result to the chat.""" + from run_agent import AIAgent + + adapter = self.adapters.get(source.platform) + if not adapter: + logger.warning("No adapter for platform %s in background task %s", source.platform, task_id) + return + + _thread_metadata = {"thread_id": source.thread_id} if source.thread_id else None + + try: + runtime_kwargs = _resolve_runtime_agent_kwargs() + if not runtime_kwargs.get("api_key"): + await adapter.send( + source.chat_id, + f"❌ Background task {task_id} failed: no provider credentials configured.", + metadata=_thread_metadata, + ) + return + + # Read model from config via shared helper + model = _resolve_gateway_model() + + # Determine toolset (same logic as _run_agent) + default_toolset_map = { + Platform.LOCAL: "hermes-cli", + Platform.TELEGRAM: "hermes-telegram", + Platform.DISCORD: "hermes-discord", + Platform.WHATSAPP: "hermes-whatsapp", + Platform.SLACK: "hermes-slack", + Platform.SIGNAL: "hermes-signal", + Platform.HOMEASSISTANT: "hermes-homeassistant", + Platform.EMAIL: "hermes-email", + } + platform_toolsets_config = {} + try: + config_path = _hermes_home / 'config.yaml' + if config_path.exists(): + import yaml + with open(config_path, 'r', encoding="utf-8") as f: + user_config = yaml.safe_load(f) or {} + platform_toolsets_config = user_config.get("platform_toolsets", {}) + except Exception: + pass + + platform_config_key = { + Platform.LOCAL: "cli", + Platform.TELEGRAM: "telegram", + Platform.DISCORD: "discord", + Platform.WHATSAPP: "whatsapp", + Platform.SLACK: "slack", + Platform.SIGNAL: "signal", + Platform.HOMEASSISTANT: "homeassistant", + Platform.EMAIL: "email", + }.get(source.platform, "telegram") + + config_toolsets = platform_toolsets_config.get(platform_config_key) + if config_toolsets and isinstance(config_toolsets, list): + enabled_toolsets = config_toolsets + else: + default_toolset = default_toolset_map.get(source.platform, "hermes-telegram") + enabled_toolsets = [default_toolset] + + platform_key = "cli" if source.platform == Platform.LOCAL else source.platform.value + + pr = self._provider_routing + max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90")) + + def run_sync(): + agent = AIAgent( + model=model, + **runtime_kwargs, + max_iterations=max_iterations, + quiet_mode=True, + verbose_logging=False, + enabled_toolsets=enabled_toolsets, + reasoning_config=self._reasoning_config, + providers_allowed=pr.get("only"), + providers_ignored=pr.get("ignore"), + providers_order=pr.get("order"), + provider_sort=pr.get("sort"), + provider_require_parameters=pr.get("require_parameters", False), + provider_data_collection=pr.get("data_collection"), + session_id=task_id, + platform=platform_key, + session_db=self._session_db, + fallback_model=self._fallback_model, + ) + + return agent.run_conversation( + user_message=prompt, + task_id=task_id, + ) + + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, run_sync) + + response = result.get("final_response", "") if result else "" + if not response and result and result.get("error"): + response = f"Error: {result['error']}" + + # Extract media files from the response + if response: + media_files, response = adapter.extract_media(response) + images, text_content = adapter.extract_images(response) + + preview = prompt[:60] + ("..." if len(prompt) > 60 else "") + header = f'✅ Background task complete\nPrompt: "{preview}"\n\n' + + if text_content: + await adapter.send( + chat_id=source.chat_id, + content=header + text_content, + metadata=_thread_metadata, + ) + elif not images and not media_files: + await adapter.send( + chat_id=source.chat_id, + content=header + "(No response generated)", + metadata=_thread_metadata, + ) + + # Send extracted images + for image_url, alt_text in (images or []): + try: + await adapter.send_image( + chat_id=source.chat_id, + image_url=image_url, + caption=alt_text, + ) + except Exception: + pass + + # Send media files + for media_path in (media_files or []): + try: + await adapter.send_file( + chat_id=source.chat_id, + file_path=media_path, + ) + except Exception: + pass + else: + preview = prompt[:60] + ("..." if len(prompt) > 60 else "") + await adapter.send( + chat_id=source.chat_id, + content=f'✅ Background task complete\nPrompt: "{preview}"\n\n(No response generated)', + metadata=_thread_metadata, + ) + + except Exception as e: + logger.exception("Background task %s failed", task_id) + try: + await adapter.send( + chat_id=source.chat_id, + content=f"❌ Background task {task_id} failed: {e}", + metadata=_thread_metadata, + ) + except Exception: + pass + + async def _handle_reasoning_command(self, event: MessageEvent) -> str: + """Handle /reasoning command — manage reasoning effort and display toggle. + + Usage: + /reasoning Show current effort level and display state + /reasoning Set reasoning effort (none, low, medium, high, xhigh) + /reasoning show|on Show model reasoning in responses + /reasoning hide|off Hide model reasoning from responses + """ + import yaml + + args = event.get_command_args().strip().lower() + config_path = _hermes_home / "config.yaml" + + def _save_config_key(key_path: str, value): + """Save a dot-separated key to config.yaml.""" + try: + user_config = {} + if config_path.exists(): + with open(config_path, encoding="utf-8") as f: + user_config = yaml.safe_load(f) or {} + keys = key_path.split(".") + current = user_config + for k in keys[:-1]: + if k not in current or not isinstance(current[k], dict): + current[k] = {} + current = current[k] + current[keys[-1]] = value + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(user_config, f, default_flow_style=False, sort_keys=False) + return True + except Exception as e: + logger.error("Failed to save config key %s: %s", key_path, e) + return False + + if not args: + # Show current state + rc = self._reasoning_config + if rc is None: + level = "medium (default)" + elif rc.get("enabled") is False: + level = "none (disabled)" + else: + level = rc.get("effort", "medium") + display_state = "on ✓" if self._show_reasoning else "off" + return ( + "🧠 **Reasoning Settings**\n\n" + f"**Effort:** `{level}`\n" + f"**Display:** {display_state}\n\n" + "_Usage:_ `/reasoning `" + ) + + # Display toggle + if args in ("show", "on"): + self._show_reasoning = True + _save_config_key("display.show_reasoning", True) + return "🧠 ✓ Reasoning display: **ON**\nModel thinking will be shown before each response." + + if args in ("hide", "off"): + self._show_reasoning = False + _save_config_key("display.show_reasoning", False) + return "🧠 ✓ Reasoning display: **OFF**" + + # Effort level change + effort = args.strip() + if effort == "none": + parsed = {"enabled": False} + elif effort in ("xhigh", "high", "medium", "low", "minimal"): + parsed = {"enabled": True, "effort": effort} + else: + return ( + f"⚠️ Unknown argument: `{effort}`\n\n" + "**Valid levels:** none, low, minimal, medium, high, xhigh\n" + "**Display:** show, hide" + ) + + self._reasoning_config = parsed + if _save_config_key("agent.reasoning_effort", effort): + return f"🧠 ✓ Reasoning effort set to `{effort}` (saved to config)\n_(takes effect on next message)_" + else: + return f"🧠 ✓ Reasoning effort set to `{effort}` (this session only)" + + async def _handle_compress_command(self, event: MessageEvent) -> str: + """Handle /compress command -- manually compress conversation context.""" + source = event.source + session_entry = self.session_store.get_or_create_session(source) + history = self.session_store.load_transcript(session_entry.session_id) + + if not history or len(history) < 4: + return "Not enough conversation to compress (need at least 4 messages)." + + try: + from run_agent import AIAgent + from agent.model_metadata import estimate_messages_tokens_rough + + runtime_kwargs = _resolve_runtime_agent_kwargs() + if not runtime_kwargs.get("api_key"): + return "No provider configured -- cannot compress." + + # Resolve model from config (same reason as memory flush above). + model = _resolve_gateway_model() + + msgs = [ + {"role": m.get("role"), "content": m.get("content")} + for m in history + if m.get("role") in ("user", "assistant") and m.get("content") + ] + original_count = len(msgs) + approx_tokens = estimate_messages_tokens_rough(msgs) + + tmp_agent = AIAgent( + **runtime_kwargs, + model=model, + max_iterations=4, + quiet_mode=True, + enabled_toolsets=["memory"], + session_id=session_entry.session_id, + ) + + loop = asyncio.get_event_loop() + compressed, _ = await loop.run_in_executor( + None, + lambda: tmp_agent._compress_context(msgs, "", approx_tokens=approx_tokens), + ) + + self.session_store.rewrite_transcript(session_entry.session_id, compressed) + # Reset stored token count — transcript changed, old value is stale + self.session_store.update_session( + session_entry.session_key, last_prompt_tokens=0, + ) + new_count = len(compressed) + new_tokens = estimate_messages_tokens_rough(compressed) + + return ( + f"🗜️ Compressed: {original_count} → {new_count} messages\n" + f"~{approx_tokens:,} → ~{new_tokens:,} tokens" + ) + except Exception as e: + logger.warning("Manual compress failed: %s", e) + return f"Compression failed: {e}" + + async def _handle_title_command(self, event: MessageEvent) -> str: + """Handle /title command — set or show the current session's title.""" + source = event.source + session_entry = self.session_store.get_or_create_session(source) + session_id = session_entry.session_id + + if not self._session_db: + return "Session database not available." + + title_arg = event.get_command_args().strip() + if title_arg: + # Sanitize the title before setting + try: + sanitized = self._session_db.sanitize_title(title_arg) + except ValueError as e: + return f"⚠️ {e}" + if not sanitized: + return "⚠️ Title is empty after cleanup. Please use printable characters." + # Set the title + try: + if self._session_db.set_session_title(session_id, sanitized): + return f"✏️ Session title set: **{sanitized}**" + else: + return "Session not found in database." + except ValueError as e: + return f"⚠️ {e}" + else: + # Show the current title + title = self._session_db.get_session_title(session_id) + if title: + return f"📌 Session title: **{title}**" + else: + return "No title set. Usage: `/title My Session Name`" + + async def _handle_resume_command(self, event: MessageEvent) -> str: + """Handle /resume command — switch to a previously-named session.""" + if not self._session_db: + return "Session database not available." + + source = event.source + session_key = build_session_key(source) + name = event.get_command_args().strip() + + if not name: + # List recent titled sessions for this user/platform + try: + user_source = source.platform.value if source.platform else None + sessions = self._session_db.list_sessions_rich( + source=user_source, limit=10 + ) + titled = [s for s in sessions if s.get("title")] + if not titled: + return ( + "No named sessions found.\n" + "Use `/title My Session` to name your current session, " + "then `/resume My Session` to return to it later." + ) + lines = ["📋 **Named Sessions**\n"] + for s in titled[:10]: + title = s["title"] + preview = s.get("preview", "")[:40] + preview_part = f" — _{preview}_" if preview else "" + lines.append(f"• **{title}**{preview_part}") + lines.append("\nUsage: `/resume `") + return "\n".join(lines) + except Exception as e: + logger.debug("Failed to list titled sessions: %s", e) + return f"Could not list sessions: {e}" + + # Resolve the name to a session ID + target_id = self._session_db.resolve_session_by_title(name) + if not target_id: + return ( + f"No session found matching '**{name}**'.\n" + "Use `/resume` with no arguments to see available sessions." + ) + + # Check if already on that session + current_entry = self.session_store.get_or_create_session(source) + if current_entry.session_id == target_id: + return f"📌 Already on session **{name}**." + + # Flush memories for current session before switching + try: + asyncio.create_task(self._async_flush_memories(current_entry.session_id)) + except Exception as e: + logger.debug("Memory flush on resume failed: %s", e) + + self._shutdown_gateway_honcho(session_key) + + # Clear any running agent for this session key + if session_key in self._running_agents: + del self._running_agents[session_key] + + # Switch the session entry to point at the old session + new_entry = self.session_store.switch_session(session_key, target_id) + if not new_entry: + return "Failed to switch session." + + # Get the title for confirmation + title = self._session_db.get_session_title(target_id) or name + + # Count messages for context + history = self.session_store.load_transcript(target_id) + msg_count = len([m for m in history if m.get("role") == "user"]) if history else 0 + msg_part = f" ({msg_count} message{'s' if msg_count != 1 else ''})" if msg_count else "" + + return f"↻ Resumed session **{title}**{msg_part}. Conversation restored." + + async def _handle_usage_command(self, event: MessageEvent) -> str: + """Handle /usage command -- show token usage for the session's last agent run.""" + source = event.source + session_key = build_session_key(source) + + agent = self._running_agents.get(session_key) + if agent and hasattr(agent, "session_total_tokens") and agent.session_api_calls > 0: + lines = [ + "📊 **Session Token Usage**", + f"Prompt (input): {agent.session_prompt_tokens:,}", + f"Completion (output): {agent.session_completion_tokens:,}", + f"Total: {agent.session_total_tokens:,}", + f"API calls: {agent.session_api_calls}", + ] + ctx = agent.context_compressor + if ctx.last_prompt_tokens: + pct = ctx.last_prompt_tokens / ctx.context_length * 100 if ctx.context_length else 0 + lines.append(f"Context: {ctx.last_prompt_tokens:,} / {ctx.context_length:,} ({pct:.0f}%)") + if ctx.compression_count: + lines.append(f"Compressions: {ctx.compression_count}") + return "\n".join(lines) + + # No running agent -- check session history for a rough count + session_entry = self.session_store.get_or_create_session(source) + history = self.session_store.load_transcript(session_entry.session_id) + if history: + from agent.model_metadata import estimate_messages_tokens_rough + msgs = [m for m in history if m.get("role") in ("user", "assistant") and m.get("content")] + approx = estimate_messages_tokens_rough(msgs) + return ( + f"📊 **Session Info**\n" + f"Messages: {len(msgs)}\n" + f"Estimated context: ~{approx:,} tokens\n" + f"_(Detailed usage available during active conversations)_" + ) + return "No usage data available for this session." + + async def _handle_insights_command(self, event: MessageEvent) -> str: + """Handle /insights command -- show usage insights and analytics.""" + import asyncio as _asyncio + + args = event.get_command_args().strip() + days = 30 + source = None + + # Parse simple args: /insights 7 or /insights --days 7 + if args: + parts = args.split() + i = 0 + while i < len(parts): + if parts[i] == "--days" and i + 1 < len(parts): + try: + days = int(parts[i + 1]) + except ValueError: + return f"Invalid --days value: {parts[i + 1]}" + i += 2 + elif parts[i] == "--source" and i + 1 < len(parts): + source = parts[i + 1] + i += 2 + elif parts[i].isdigit(): + days = int(parts[i]) + i += 1 + else: + i += 1 + + try: + from hermes_state import SessionDB + from agent.insights import InsightsEngine + + loop = _asyncio.get_event_loop() + + def _run_insights(): + db = SessionDB() + engine = InsightsEngine(db) + report = engine.generate(days=days, source=source) + result = engine.format_gateway(report) + db.close() + return result + + return await loop.run_in_executor(None, _run_insights) + except Exception as e: + logger.error("Insights command error: %s", e, exc_info=True) + return f"Error generating insights: {e}" + + async def _handle_reload_mcp_command(self, event: MessageEvent) -> str: + """Handle /reload-mcp command -- disconnect and reconnect all MCP servers.""" + loop = asyncio.get_event_loop() + try: + from tools.mcp_tool import shutdown_mcp_servers, discover_mcp_tools, _load_mcp_config, _servers, _lock + + # Capture old server names before shutdown + with _lock: + old_servers = set(_servers.keys()) + + # Read new config before shutting down, so we know what will be added/removed + new_config = _load_mcp_config() + new_server_names = set(new_config.keys()) + + # Shutdown existing connections + await loop.run_in_executor(None, shutdown_mcp_servers) + + # Reconnect by discovering tools (reads config.yaml fresh) + new_tools = await loop.run_in_executor(None, discover_mcp_tools) + + # Compute what changed + with _lock: + connected_servers = set(_servers.keys()) + + added = connected_servers - old_servers + removed = old_servers - connected_servers + reconnected = connected_servers & old_servers + + lines = ["🔄 **MCP Servers Reloaded**\n"] + if reconnected: + lines.append(f"♻️ Reconnected: {', '.join(sorted(reconnected))}") + if added: + lines.append(f"➕ Added: {', '.join(sorted(added))}") + if removed: + lines.append(f"➖ Removed: {', '.join(sorted(removed))}") + if not connected_servers: + lines.append("No MCP servers connected.") + else: + lines.append(f"\n🔧 {len(new_tools)} tool(s) available from {len(connected_servers)} server(s)") + + # Inject a message at the END of the session history so the + # model knows tools changed on its next turn. Appended after + # all existing messages to preserve prompt-cache for the prefix. + change_parts = [] + if added: + change_parts.append(f"Added servers: {', '.join(sorted(added))}") + if removed: + change_parts.append(f"Removed servers: {', '.join(sorted(removed))}") + if reconnected: + change_parts.append(f"Reconnected servers: {', '.join(sorted(reconnected))}") + tool_summary = f"{len(new_tools)} MCP tool(s) now available" if new_tools else "No MCP tools available" + change_detail = ". ".join(change_parts) + ". " if change_parts else "" + reload_msg = { + "role": "user", + "content": f"[SYSTEM: MCP servers have been reloaded. {change_detail}{tool_summary}. The tool list for this conversation has been updated accordingly.]", + } + try: + session_entry = self.session_store.get_or_create_session(event.source) + self.session_store.append_to_transcript( + session_entry.session_id, reload_msg + ) + except Exception: + pass # Best-effort; don't fail the reload over a transcript write + + return "\n".join(lines) + + except Exception as e: + logger.warning("MCP reload failed: %s", e) + return f"❌ MCP reload failed: {e}" + + async def _handle_update_command(self, event: MessageEvent) -> str: + """Handle /update command — update Hermes Agent to the latest version. + + Spawns ``hermes update`` in a separate systemd scope so it survives the + gateway restart that ``hermes update`` triggers at the end. A marker + file is written so the *new* gateway process can notify the user of the + result on startup. + """ + import json + import shutil + import subprocess + from datetime import datetime + + project_root = Path(__file__).parent.parent.resolve() + git_dir = project_root / '.git' + + if not git_dir.exists(): + return "✗ Not a git repository — cannot update." + + hermes_bin = shutil.which("hermes") + if not hermes_bin: + return "✗ `hermes` command not found on PATH." + + # Write marker so the restarted gateway can notify this chat + pending_path = _hermes_home / ".update_pending.json" + output_path = _hermes_home / ".update_output.txt" + pending = { + "platform": event.source.platform.value, + "chat_id": event.source.chat_id, + "user_id": event.source.user_id, + "timestamp": datetime.now().isoformat(), + } + pending_path.write_text(json.dumps(pending)) + + # Spawn `hermes update` in a separate cgroup so it survives gateway + # restart. systemd-run --user --scope creates a transient scope unit. + update_cmd = f"{hermes_bin} update > {output_path} 2>&1" + try: + systemd_run = shutil.which("systemd-run") + if systemd_run: + subprocess.Popen( + [systemd_run, "--user", "--scope", + "--unit=hermes-update", "--", + "bash", "-c", update_cmd], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + start_new_session=True, + ) + else: + # Fallback: best-effort detach with start_new_session + subprocess.Popen( + ["bash", "-c", f"nohup {update_cmd} &"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + start_new_session=True, + ) + except Exception as e: + pending_path.unlink(missing_ok=True) + return f"✗ Failed to start update: {e}" + + return "⚕ Starting Hermes update… I'll notify you when it's done." + + async def _send_update_notification(self) -> None: + """If the gateway is starting after a ``/update``, notify the user.""" + import json + import re as _re + + pending_path = _hermes_home / ".update_pending.json" + output_path = _hermes_home / ".update_output.txt" + + if not pending_path.exists(): + return + + try: + pending = json.loads(pending_path.read_text()) + platform_str = pending.get("platform") + chat_id = pending.get("chat_id") + + # Read the captured update output + output = "" + if output_path.exists(): + output = output_path.read_text() + + # Resolve adapter + platform = Platform(platform_str) + adapter = self.adapters.get(platform) + + if adapter and chat_id: + # Strip ANSI escape codes for clean display + output = _re.sub(r'\x1b\[[0-9;]*m', '', output).strip() + if output: + # Truncate if too long for a single message + if len(output) > 3500: + output = "…" + output[-3500:] + msg = f"✅ Hermes update finished — gateway restarted.\n\n```\n{output}\n```" + else: + msg = "✅ Hermes update finished — gateway restarted successfully." + await adapter.send(chat_id, msg) + logger.info("Sent post-update notification to %s:%s", platform_str, chat_id) + except Exception as e: + logger.warning("Post-update notification failed: %s", e) + finally: + pending_path.unlink(missing_ok=True) + output_path.unlink(missing_ok=True) + + def _set_session_env(self, context: SessionContext) -> None: + """Set environment variables for the current session.""" + os.environ["HERMES_SESSION_PLATFORM"] = context.source.platform.value + os.environ["HERMES_SESSION_CHAT_ID"] = context.source.chat_id + if context.source.chat_name: + os.environ["HERMES_SESSION_CHAT_NAME"] = context.source.chat_name + + def _clear_session_env(self) -> None: + """Clear session environment variables.""" + for var in ["HERMES_SESSION_PLATFORM", "HERMES_SESSION_CHAT_ID", "HERMES_SESSION_CHAT_NAME"]: + if var in os.environ: + del os.environ[var] + + async def _enrich_message_with_vision( + self, + user_text: str, + image_paths: List[str], + ) -> str: + """ + Auto-analyze user-attached images with the vision tool and prepend + the descriptions to the message text. + + Each image is analyzed with a general-purpose prompt. The resulting + description *and* the local cache path are injected so the model can: + 1. Immediately understand what the user sent (no extra tool call). + 2. Re-examine the image with vision_analyze if it needs more detail. + + Args: + user_text: The user's original caption / message text. + image_paths: List of local file paths to cached images. + + Returns: + The enriched message string with vision descriptions prepended. + """ + from tools.vision_tools import vision_analyze_tool + import json as _json + + analysis_prompt = ( + "Describe everything visible in this image in thorough detail. " + "Include any text, code, data, objects, people, layout, colors, " + "and any other notable visual information." + ) + + enriched_parts = [] + for path in image_paths: + try: + logger.debug("Auto-analyzing user image: %s", path) + result_json = await vision_analyze_tool( + image_url=path, + user_prompt=analysis_prompt, + ) + result = _json.loads(result_json) + if result.get("success"): + description = result.get("analysis", "") + enriched_parts.append( + f"[The user sent an image~ Here's what I can see:\n{description}]\n" + f"[If you need a closer look, use vision_analyze with " + f"image_url: {path} ~]" + ) + else: + enriched_parts.append( + "[The user sent an image but I couldn't quite see it " + "this time (>_<) You can try looking at it yourself " + f"with vision_analyze using image_url: {path}]" + ) + except Exception as e: + logger.error("Vision auto-analysis error: %s", e) + enriched_parts.append( + f"[The user sent an image but something went wrong when I " + f"tried to look at it~ You can try examining it yourself " + f"with vision_analyze using image_url: {path}]" + ) + + # Combine: vision descriptions first, then the user's original text + if enriched_parts: + prefix = "\n\n".join(enriched_parts) + if user_text: + return f"{prefix}\n\n{user_text}" + return prefix + return user_text + + async def _enrich_message_with_transcription( + self, + user_text: str, + audio_paths: List[str], + ) -> str: + """ + Auto-transcribe user voice/audio messages using OpenAI Whisper API + and prepend the transcript to the message text. + + Args: + user_text: The user's original caption / message text. + audio_paths: List of local file paths to cached audio files. + + Returns: + The enriched message string with transcriptions prepended. + """ + from tools.transcription_tools import transcribe_audio + import asyncio + + enriched_parts = [] + for path in audio_paths: + try: + logger.debug("Transcribing user voice: %s", path) + result = await asyncio.to_thread(transcribe_audio, path) + if result["success"]: + transcript = result["transcript"] + enriched_parts.append( + f'[The user sent a voice message~ ' + f'Here\'s what they said: "{transcript}"]' + ) + else: + error = result.get("error", "unknown error") + if "OPENAI_API_KEY" in error or "VOICE_TOOLS_OPENAI_KEY" in error: + enriched_parts.append( + "[The user sent a voice message but I can't listen " + "to it right now~ VOICE_TOOLS_OPENAI_KEY isn't set up yet " + "(';w;') Let them know!]" + ) + else: + enriched_parts.append( + "[The user sent a voice message but I had trouble " + f"transcribing it~ ({error})]" + ) + except Exception as e: + logger.error("Transcription error: %s", e) + enriched_parts.append( + "[The user sent a voice message but something went wrong " + "when I tried to listen to it~ Let them know!]" + ) + + if enriched_parts: + prefix = "\n\n".join(enriched_parts) + if user_text: + return f"{prefix}\n\n{user_text}" + return prefix + return user_text + + async def _run_process_watcher(self, watcher: dict) -> None: + """ + Periodically check a background process and push updates to the user. + + Runs as an asyncio task. Stays silent when nothing changed. + Auto-removes when the process exits or is killed. + + Notification mode (from ``display.background_process_notifications``): + - ``all`` — running-output updates + final message + - ``result`` — final completion message only + - ``error`` — final message only when exit code != 0 + - ``off`` — no messages at all + """ + from tools.process_registry import process_registry + + session_id = watcher["session_id"] + interval = watcher["check_interval"] + session_key = watcher.get("session_key", "") + platform_name = watcher.get("platform", "") + chat_id = watcher.get("chat_id", "") + notify_mode = self._load_background_notifications_mode() + + logger.debug("Process watcher started: %s (every %ss, notify=%s)", + session_id, interval, notify_mode) + + if notify_mode == "off": + # Still wait for the process to exit so we can log it, but don't + # push any messages to the user. + while True: + await asyncio.sleep(interval) + session = process_registry.get(session_id) + if session is None or session.exited: + break + logger.debug("Process watcher ended (silent): %s", session_id) + return + + last_output_len = 0 + while True: + await asyncio.sleep(interval) + + session = process_registry.get(session_id) + if session is None: + break + + current_output_len = len(session.output_buffer) + has_new_output = current_output_len > last_output_len + last_output_len = current_output_len + + if session.exited: + # Decide whether to notify based on mode + should_notify = ( + notify_mode in ("all", "result") + or (notify_mode == "error" and session.exit_code not in (0, None)) + ) + if should_notify: + new_output = session.output_buffer[-1000:] if session.output_buffer else "" + message_text = ( + f"[Background process {session_id} finished with exit code {session.exit_code}~ " + f"Here's the final output:\n{new_output}]" + ) + adapter = None + for p, a in self.adapters.items(): + if p.value == platform_name: + adapter = a + break + if adapter and chat_id: + try: + await adapter.send(chat_id, message_text) + except Exception as e: + logger.error("Watcher delivery error: %s", e) + break + + elif has_new_output and notify_mode == "all": + # New output available -- deliver status update (only in "all" mode) + new_output = session.output_buffer[-500:] if session.output_buffer else "" + message_text = ( + f"[Background process {session_id} is still running~ " + f"New output:\n{new_output}]" + ) + adapter = None + for p, a in self.adapters.items(): + if p.value == platform_name: + adapter = a + break + if adapter and chat_id: + try: + await adapter.send(chat_id, message_text) + except Exception as e: + logger.error("Watcher delivery error: %s", e) + + logger.debug("Process watcher ended: %s", session_id) + + async def _run_agent( + self, + message: str, + context_prompt: str, + history: List[Dict[str, Any]], + source: SessionSource, + session_id: str, + session_key: str = None + ) -> Dict[str, Any]: + """ + Run the agent with the given message and context. + + Returns the full result dict from run_conversation, including: + - "final_response": str (the text to send back) + - "messages": list (full conversation including tool calls) + - "api_calls": int + - "completed": bool + + This is run in a thread pool to not block the event loop. + Supports interruption via new messages. + """ + from run_agent import AIAgent + import queue + + # Determine toolset based on platform. + # Check config.yaml for per-platform overrides, fallback to hardcoded defaults. + default_toolset_map = { + Platform.LOCAL: "hermes-cli", + Platform.TELEGRAM: "hermes-telegram", + Platform.DISCORD: "hermes-discord", + Platform.WHATSAPP: "hermes-whatsapp", + Platform.SLACK: "hermes-slack", + Platform.SIGNAL: "hermes-signal", + Platform.HOMEASSISTANT: "hermes-homeassistant", + Platform.EMAIL: "hermes-email", + } + + # Try to load platform_toolsets from config + platform_toolsets_config = {} + try: + config_path = _hermes_home / 'config.yaml' + if config_path.exists(): + import yaml + with open(config_path, 'r', encoding="utf-8") as f: + user_config = yaml.safe_load(f) or {} + platform_toolsets_config = user_config.get("platform_toolsets", {}) + except Exception as e: + logger.debug("Could not load platform_toolsets config: %s", e) + + # Map platform enum to config key + platform_config_key = { + Platform.LOCAL: "cli", + Platform.TELEGRAM: "telegram", + Platform.DISCORD: "discord", + Platform.WHATSAPP: "whatsapp", + Platform.SLACK: "slack", + Platform.SIGNAL: "signal", + Platform.HOMEASSISTANT: "homeassistant", + Platform.EMAIL: "email", + }.get(source.platform, "telegram") + + # Use config override if present (list of toolsets), otherwise hardcoded default + config_toolsets = platform_toolsets_config.get(platform_config_key) + if config_toolsets and isinstance(config_toolsets, list): + enabled_toolsets = config_toolsets + else: + default_toolset = default_toolset_map.get(source.platform, "hermes-telegram") + enabled_toolsets = [default_toolset] + + # Tool progress mode from config.yaml: "all", "new", "verbose", "off" + # Falls back to env vars for backward compatibility + _progress_cfg = {} + try: + _tp_cfg_path = _hermes_home / "config.yaml" + if _tp_cfg_path.exists(): + import yaml as _tp_yaml + with open(_tp_cfg_path, encoding="utf-8") as _tp_f: + _tp_data = _tp_yaml.safe_load(_tp_f) or {} + _progress_cfg = _tp_data.get("display", {}) + except Exception: + pass + progress_mode = ( + _progress_cfg.get("tool_progress") + or os.getenv("HERMES_TOOL_PROGRESS_MODE") + or "all" + ) + tool_progress_enabled = progress_mode != "off" + + # Queue for progress messages (thread-safe) + progress_queue = queue.Queue() if tool_progress_enabled else None + last_tool = [None] # Mutable container for tracking in closure + last_progress_msg = [None] # Track last message for dedup + repeat_count = [0] # How many times the same message repeated + + def progress_callback(tool_name: str, preview: str = None, args: dict = None): + """Callback invoked by agent when a tool is called.""" + if not progress_queue: + return + + # "new" mode: only report when tool changes + if progress_mode == "new" and tool_name == last_tool[0]: + return + last_tool[0] = tool_name + + # Build progress message with primary argument preview + tool_emojis = { + "terminal": "💻", + "process": "⚙️", + "web_search": "🔍", + "web_extract": "📄", + "read_file": "📖", + "write_file": "✍️", + "patch": "🔧", + "search": "🔎", + "search_files": "🔎", + "list_directory": "📂", + "image_generate": "🎨", + "text_to_speech": "🔊", + "browser_navigate": "🌐", + "browser_click": "👆", + "browser_type": "⌨️", + "browser_snapshot": "📸", + "browser_scroll": "📜", + "browser_back": "◀️", + "browser_press": "⌨️", + "browser_close": "🚪", + "browser_get_images": "🖼️", + "browser_vision": "👁️", + "moa_query": "🧠", + "mixture_of_agents": "🧠", + "vision_analyze": "👁️", + "skill_view": "📚", + "skills_list": "📋", + "todo": "📋", + "memory": "🧠", + "session_search": "🔍", + "send_message": "📨", + "schedule_cronjob": "⏰", + "list_cronjobs": "⏰", + "remove_cronjob": "⏰", + "execute_code": "🐍", + "delegate_task": "🔀", + "clarify": "❓", + "skill_manage": "📝", + } + emoji = tool_emojis.get(tool_name, "⚙️") + + # Verbose mode: show detailed arguments + if progress_mode == "verbose" and args: + import json as _json + args_str = _json.dumps(args, ensure_ascii=False, default=str) + if len(args_str) > 200: + args_str = args_str[:197] + "..." + msg = f"{emoji} {tool_name}({list(args.keys())})\n{args_str}" + progress_queue.put(msg) + return + + if preview: + # Truncate preview to keep messages clean + if len(preview) > 80: + preview = preview[:77] + "..." + msg = f"{emoji} {tool_name}: \"{preview}\"" + else: + msg = f"{emoji} {tool_name}..." + + # Dedup: collapse consecutive identical progress messages. + # Common with execute_code where models iterate with the same + # code (same boilerplate imports → identical previews). + if msg == last_progress_msg[0]: + repeat_count[0] += 1 + # Update the last line in progress_lines with a counter + # via a special "dedup" queue message. + progress_queue.put(("__dedup__", msg, repeat_count[0])) + return + last_progress_msg[0] = msg + repeat_count[0] = 0 + + progress_queue.put(msg) + + # Background task to send progress messages + # Accumulates tool lines into a single message that gets edited + _progress_metadata = {"thread_id": source.thread_id} if source.thread_id else None + + async def send_progress_messages(): + if not progress_queue: + return + + adapter = self.adapters.get(source.platform) + if not adapter: + return + + progress_lines = [] # Accumulated tool lines + progress_msg_id = None # ID of the progress message to edit + can_edit = True # False once an edit fails (platform doesn't support it) + + while True: + try: + raw = progress_queue.get_nowait() + + # Handle dedup messages: update last line with repeat counter + if isinstance(raw, tuple) and len(raw) == 3 and raw[0] == "__dedup__": + _, base_msg, count = raw + if progress_lines: + progress_lines[-1] = f"{base_msg} (×{count + 1})" + msg = progress_lines[-1] if progress_lines else base_msg + else: + msg = raw + progress_lines.append(msg) + + if can_edit and progress_msg_id is not None: + # Try to edit the existing progress message + full_text = "\n".join(progress_lines) + result = await adapter.edit_message( + chat_id=source.chat_id, + message_id=progress_msg_id, + content=full_text, + ) + if not result.success: + # Platform doesn't support editing — stop trying, + # send just this new line as a separate message + can_edit = False + await adapter.send(chat_id=source.chat_id, content=msg, metadata=_progress_metadata) + else: + if can_edit: + # First tool: send all accumulated text as new message + full_text = "\n".join(progress_lines) + result = await adapter.send(chat_id=source.chat_id, content=full_text, metadata=_progress_metadata) + else: + # Editing unsupported: send just this line + result = await adapter.send(chat_id=source.chat_id, content=msg, metadata=_progress_metadata) + if result.success and result.message_id: + progress_msg_id = result.message_id + + # Restore typing indicator + await asyncio.sleep(0.3) + await adapter.send_typing(source.chat_id, metadata=_progress_metadata) + + except queue.Empty: + await asyncio.sleep(0.3) + except asyncio.CancelledError: + # Drain remaining queued messages + while not progress_queue.empty(): + try: + raw = progress_queue.get_nowait() + if isinstance(raw, tuple) and len(raw) == 3 and raw[0] == "__dedup__": + _, base_msg, count = raw + if progress_lines: + progress_lines[-1] = f"{base_msg} (×{count + 1})" + else: + progress_lines.append(raw) + except Exception: + break + # Final edit with all remaining tools (only if editing works) + if can_edit and progress_lines and progress_msg_id: + full_text = "\n".join(progress_lines) + try: + await adapter.edit_message( + chat_id=source.chat_id, + message_id=progress_msg_id, + content=full_text, + ) + except Exception: + pass + return + except Exception as e: + logger.error("Progress message error: %s", e) + await asyncio.sleep(1) + + # We need to share the agent instance for interrupt support + agent_holder = [None] # Mutable container for the agent instance + result_holder = [None] # Mutable container for the result + tools_holder = [None] # Mutable container for the tool definitions + + # Bridge sync step_callback → async hooks.emit for agent:step events + _loop_for_step = asyncio.get_event_loop() + _hooks_ref = self.hooks + + def _step_callback_sync(iteration: int, tool_names: list) -> None: + try: + asyncio.run_coroutine_threadsafe( + _hooks_ref.emit("agent:step", { + "platform": source.platform.value if source.platform else "", + "user_id": source.user_id, + "session_id": session_id, + "iteration": iteration, + "tool_names": tool_names, + }), + _loop_for_step, + ) + except Exception as _e: + logger.debug("agent:step hook error: %s", _e) + + def run_sync(): + # Pass session_key to process registry via env var so background + # processes can be mapped back to this gateway session + os.environ["HERMES_SESSION_KEY"] = session_key or "" + + # Read from env var or use default (same as CLI) + max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90")) + + # Map platform enum to the platform hint key the agent understands. + # Platform.LOCAL ("local") maps to "cli"; others pass through as-is. + platform_key = "cli" if source.platform == Platform.LOCAL else source.platform.value + + # Combine platform context with user-configured ephemeral system prompt + combined_ephemeral = context_prompt or "" + if self._ephemeral_system_prompt: + combined_ephemeral = (combined_ephemeral + "\n\n" + self._ephemeral_system_prompt).strip() + + # Re-read .env and config for fresh credentials (gateway is long-lived, + # keys may change without restart). + try: + load_dotenv(_env_path, override=True, encoding="utf-8") + except UnicodeDecodeError: + load_dotenv(_env_path, override=True, encoding="latin-1") + except Exception: + pass + + model = _resolve_gateway_model() + + try: + runtime_kwargs = _resolve_runtime_agent_kwargs() + except Exception as exc: + return { + "final_response": f"⚠️ Provider authentication failed: {exc}", + "messages": [], + "api_calls": 0, + "tools": [], + } + + pr = self._provider_routing + honcho_manager, honcho_config = self._get_or_create_gateway_honcho(session_key) + agent = AIAgent( + model=model, + **runtime_kwargs, + max_iterations=max_iterations, + quiet_mode=True, + verbose_logging=False, + enabled_toolsets=enabled_toolsets, + ephemeral_system_prompt=combined_ephemeral or None, + prefill_messages=self._prefill_messages or None, + reasoning_config=self._reasoning_config, + providers_allowed=pr.get("only"), + providers_ignored=pr.get("ignore"), + providers_order=pr.get("order"), + provider_sort=pr.get("sort"), + provider_require_parameters=pr.get("require_parameters", False), + provider_data_collection=pr.get("data_collection"), + session_id=session_id, + tool_progress_callback=progress_callback if tool_progress_enabled else None, + step_callback=_step_callback_sync if _hooks_ref.loaded_hooks else None, + platform=platform_key, + honcho_session_key=session_key, + honcho_manager=honcho_manager, + honcho_config=honcho_config, + session_db=self._session_db, + fallback_model=self._fallback_model, + ) + + # Store agent reference for interrupt support + agent_holder[0] = agent + # Capture the full tool definitions for transcript logging + tools_holder[0] = agent.tools if hasattr(agent, 'tools') else None + + # Convert history to agent format. + # Two cases: + # 1. Normal path (from transcript): simple {role, content, timestamp} dicts + # - Strip timestamps, keep role+content + # 2. Interrupt path (from agent result["messages"]): full agent messages + # that may include tool_calls, tool_call_id, reasoning, etc. + # - These must be passed through intact so the API sees valid + # assistant→tool sequences (dropping tool_calls causes 500 errors) + agent_history = [] + for msg in history: + role = msg.get("role") + if not role: + continue + + # Skip metadata entries (tool definitions, session info) + # -- these are for transcript logging, not for the LLM + if role in ("session_meta",): + continue + + # Skip system messages -- the agent rebuilds its own system prompt + if role == "system": + continue + + # Rich agent messages (tool_calls, tool results) must be passed + # through intact so the API sees valid assistant→tool sequences + has_tool_calls = "tool_calls" in msg + has_tool_call_id = "tool_call_id" in msg + is_tool_message = role == "tool" + + if has_tool_calls or has_tool_call_id or is_tool_message: + clean_msg = {k: v for k, v in msg.items() if k != "timestamp"} + agent_history.append(clean_msg) + else: + # Simple text message - just need role and content + content = msg.get("content") + if content: + # Tag cross-platform mirror messages so the agent knows their origin + if msg.get("mirror"): + mirror_src = msg.get("mirror_source", "another session") + content = f"[Delivered from {mirror_src}] {content}" + agent_history.append({"role": role, "content": content}) + + # Collect MEDIA paths already in history so we can exclude them + # from the current turn's extraction. This is compression-safe: + # even if the message list shrinks, we know which paths are old. + _history_media_paths: set = set() + for _hm in agent_history: + if _hm.get("role") in ("tool", "function"): + _hc = _hm.get("content", "") + if "MEDIA:" in _hc: + for _match in re.finditer(r'MEDIA:(\S+)', _hc): + _p = _match.group(1).strip().rstrip('",}') + if _p: + _history_media_paths.add(_p) + + result = agent.run_conversation(message, conversation_history=agent_history, task_id=session_id) + result_holder[0] = result + + # Return final response, or a message if something went wrong + final_response = result.get("final_response") + + # Extract last actual prompt token count from the agent's compressor + _last_prompt_toks = 0 + _agent = agent_holder[0] + if _agent and hasattr(_agent, "context_compressor"): + _last_prompt_toks = getattr(_agent.context_compressor, "last_prompt_tokens", 0) + + if not final_response: + error_msg = f"⚠️ {result['error']}" if result.get("error") else "(No response generated)" + return { + "final_response": error_msg, + "messages": result.get("messages", []), + "api_calls": result.get("api_calls", 0), + "tools": tools_holder[0] or [], + "history_offset": len(agent_history), + "last_prompt_tokens": _last_prompt_toks, + } + + # Scan tool results for MEDIA: tags that need to be delivered + # as native audio/file attachments. The TTS tool embeds MEDIA: tags + # in its JSON response, but the model's final text reply usually + # doesn't include them. We collect unique tags from tool results and + # append any that aren't already present in the final response, so the + # adapter's extract_media() can find and deliver the files exactly once. + # + # Uses path-based deduplication against _history_media_paths (collected + # before run_conversation) instead of index slicing. This is safe even + # when context compression shrinks the message list. (Fixes #160) + if "MEDIA:" not in final_response: + media_tags = [] + has_voice_directive = False + for msg in result.get("messages", []): + if msg.get("role") in ("tool", "function"): + content = msg.get("content", "") + if "MEDIA:" in content: + for match in re.finditer(r'MEDIA:(\S+)', content): + path = match.group(1).strip().rstrip('",}') + if path and path not in _history_media_paths: + media_tags.append(f"MEDIA:{path}") + if "[[audio_as_voice]]" in content: + has_voice_directive = True + + if media_tags: + seen = set() + unique_tags = [] + for tag in media_tags: + if tag not in seen: + seen.add(tag) + unique_tags.append(tag) + if has_voice_directive: + unique_tags.insert(0, "[[audio_as_voice]]") + final_response = final_response + "\n" + "\n".join(unique_tags) + + # Sync session_id: the agent may have created a new session during + # mid-run context compression (_compress_context splits sessions). + # If so, update the session store entry so the NEXT message loads + # the compressed transcript, not the stale pre-compression one. + agent = agent_holder[0] + if agent and session_key and hasattr(agent, 'session_id') and agent.session_id != session_id: + logger.info( + "Session split detected: %s → %s (compression)", + session_id, agent.session_id, + ) + entry = self.session_store._entries.get(session_key) + if entry: + entry.session_id = agent.session_id + self.session_store._save() + + effective_session_id = getattr(agent, 'session_id', session_id) if agent else session_id + + return { + "final_response": final_response, + "last_reasoning": result.get("last_reasoning"), + "messages": result_holder[0].get("messages", []) if result_holder[0] else [], + "api_calls": result_holder[0].get("api_calls", 0) if result_holder[0] else 0, + "tools": tools_holder[0] or [], + "history_offset": len(agent_history), + "last_prompt_tokens": _last_prompt_toks, + "session_id": effective_session_id, + } + + # Start progress message sender if enabled + progress_task = None + if tool_progress_enabled: + progress_task = asyncio.create_task(send_progress_messages()) + + # Track this agent as running for this session (for interrupt support) + # We do this in a callback after the agent is created + async def track_agent(): + # Wait for agent to be created + while agent_holder[0] is None: + await asyncio.sleep(0.05) + if session_key: + self._running_agents[session_key] = agent_holder[0] + + tracking_task = asyncio.create_task(track_agent()) + + # Monitor for interrupts from the adapter (new messages arriving) + async def monitor_for_interrupt(): + adapter = self.adapters.get(source.platform) + if not adapter or not session_key: + return + + while True: + await asyncio.sleep(0.2) # Check every 200ms + # Check if adapter has a pending interrupt for this session. + # Must use session_key (build_session_key output) — NOT + # source.chat_id — because the adapter stores interrupt events + # under the full session key. + if hasattr(adapter, 'has_pending_interrupt') and adapter.has_pending_interrupt(session_key): + agent = agent_holder[0] + if agent: + pending_event = adapter.get_pending_message(session_key) + pending_text = pending_event.text if pending_event else None + logger.debug("Interrupt detected from adapter, signaling agent...") + agent.interrupt(pending_text) + break + + interrupt_monitor = asyncio.create_task(monitor_for_interrupt()) + + try: + # Run in thread pool to not block + loop = asyncio.get_event_loop() + response = await loop.run_in_executor(None, run_sync) + + # Check if we were interrupted and have a pending message + result = result_holder[0] + adapter = self.adapters.get(source.platform) + + # Get pending message from adapter if interrupted. + # Use session_key (not source.chat_id) to match adapter's storage keys. + pending = None + if result and result.get("interrupted") and adapter: + pending_event = adapter.get_pending_message(session_key) if session_key else None + if pending_event: + pending = pending_event.text + elif result.get("interrupt_message"): + pending = result.get("interrupt_message") + + if pending: + logger.debug("Processing interrupted message: '%s...'", pending[:40]) + + # Clear the adapter's interrupt event so the next _run_agent call + # doesn't immediately re-trigger the interrupt before the new agent + # even makes its first API call (this was causing an infinite loop). + if adapter and hasattr(adapter, '_active_sessions') and session_key and session_key in adapter._active_sessions: + adapter._active_sessions[session_key].clear() + + # Don't send the interrupted response to the user — it's just noise + # like "Operation interrupted." They already know they sent a new + # message, so go straight to processing it. + + # Now process the pending message with updated history + updated_history = result.get("messages", history) + return await self._run_agent( + message=pending, + context_prompt=context_prompt, + history=updated_history, + source=source, + session_id=session_id, + session_key=session_key + ) + finally: + # Stop progress sender and interrupt monitor + if progress_task: + progress_task.cancel() + interrupt_monitor.cancel() + + # Clean up tracking + tracking_task.cancel() + if session_key and session_key in self._running_agents: + del self._running_agents[session_key] + + # Wait for cancelled tasks + for task in [progress_task, interrupt_monitor, tracking_task]: + if task: + try: + await task + except asyncio.CancelledError: + pass + + return response + + +def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int = 60): + """ + Background thread that ticks the cron scheduler at a regular interval. + + Runs inside the gateway process so cronjobs fire automatically without + needing a separate `hermes cron daemon` or system cron entry. + + Also refreshes the channel directory every 5 minutes and prunes the + image/audio/document cache once per hour. + """ + from cron.scheduler import tick as cron_tick + from gateway.platforms.base import cleanup_image_cache, cleanup_document_cache + + IMAGE_CACHE_EVERY = 60 # ticks — once per hour at default 60s interval + CHANNEL_DIR_EVERY = 5 # ticks — every 5 minutes + + logger.info("Cron ticker started (interval=%ds)", interval) + tick_count = 0 + while not stop_event.is_set(): + try: + cron_tick(verbose=False) + except Exception as e: + logger.debug("Cron tick error: %s", e) + + tick_count += 1 + + if tick_count % CHANNEL_DIR_EVERY == 0 and adapters: + try: + from gateway.channel_directory import build_channel_directory + build_channel_directory(adapters) + except Exception as e: + logger.debug("Channel directory refresh error: %s", e) + + if tick_count % IMAGE_CACHE_EVERY == 0: + try: + removed = cleanup_image_cache(max_age_hours=24) + if removed: + logger.info("Image cache cleanup: removed %d stale file(s)", removed) + except Exception as e: + logger.debug("Image cache cleanup error: %s", e) + try: + removed = cleanup_document_cache(max_age_hours=24) + if removed: + logger.info("Document cache cleanup: removed %d stale file(s)", removed) + except Exception as e: + logger.debug("Document cache cleanup error: %s", e) + + stop_event.wait(timeout=interval) + logger.info("Cron ticker stopped") + + +async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = False) -> bool: + """ + Start the gateway and run until interrupted. + + This is the main entry point for running the gateway. + Returns True if the gateway ran successfully, False if it failed to start. + A False return causes a non-zero exit code so systemd can auto-restart. + + Args: + config: Optional gateway configuration override. + replace: If True, kill any existing gateway instance before starting. + Useful for systemd services to avoid restart-loop deadlocks + when the previous process hasn't fully exited yet. + """ + # ── Duplicate-instance guard ────────────────────────────────────── + # Prevent two gateways from running under the same HERMES_HOME. + # The PID file is scoped to HERMES_HOME, so future multi-profile + # setups (each profile using a distinct HERMES_HOME) will naturally + # allow concurrent instances without tripping this guard. + import time as _time + from gateway.status import get_running_pid, remove_pid_file + existing_pid = get_running_pid() + if existing_pid is not None and existing_pid != os.getpid(): + if replace: + logger.info( + "Replacing existing gateway instance (PID %d) with --replace.", + existing_pid, + ) + try: + os.kill(existing_pid, signal.SIGTERM) + except ProcessLookupError: + pass # Already gone + except PermissionError: + logger.error( + "Permission denied killing PID %d. Cannot replace.", + existing_pid, + ) + return False + # Wait up to 10 seconds for the old process to exit + for _ in range(20): + try: + os.kill(existing_pid, 0) + _time.sleep(0.5) + except (ProcessLookupError, PermissionError): + break # Process is gone + else: + # Still alive after 10s — force kill + logger.warning( + "Old gateway (PID %d) did not exit after SIGTERM, sending SIGKILL.", + existing_pid, + ) + try: + os.kill(existing_pid, signal.SIGKILL) + _time.sleep(0.5) + except (ProcessLookupError, PermissionError): + pass + remove_pid_file() + else: + hermes_home = os.getenv("HERMES_HOME", "~/.hermes") + logger.error( + "Another gateway instance is already running (PID %d, HERMES_HOME=%s). " + "Use 'hermes gateway restart' to replace it, or 'hermes gateway stop' first.", + existing_pid, hermes_home, + ) + print( + f"\n❌ Gateway already running (PID {existing_pid}).\n" + f" Use 'hermes gateway restart' to replace it,\n" + f" or 'hermes gateway stop' to kill it first.\n" + f" Or use 'hermes gateway run --replace' to auto-replace.\n" + ) + return False + + # Sync bundled skills on gateway start (fast -- skips unchanged) + try: + from tools.skills_sync import sync_skills + sync_skills(quiet=True) + except Exception: + pass + + # Configure rotating file log so gateway output is persisted for debugging + log_dir = _hermes_home / 'logs' + log_dir.mkdir(parents=True, exist_ok=True) + file_handler = RotatingFileHandler( + log_dir / 'gateway.log', + maxBytes=5 * 1024 * 1024, + backupCount=3, + ) + from agent.redact import RedactingFormatter + file_handler.setFormatter(RedactingFormatter('%(asctime)s %(levelname)s %(name)s: %(message)s')) + logging.getLogger().addHandler(file_handler) + logging.getLogger().setLevel(logging.INFO) + + # Separate errors-only log for easy debugging + error_handler = RotatingFileHandler( + log_dir / 'errors.log', + maxBytes=2 * 1024 * 1024, + backupCount=2, + ) + error_handler.setLevel(logging.WARNING) + error_handler.setFormatter(RedactingFormatter('%(asctime)s %(levelname)s %(name)s: %(message)s')) + logging.getLogger().addHandler(error_handler) + + runner = GatewayRunner(config) + + # Set up signal handlers + def signal_handler(): + asyncio.create_task(runner.stop()) + + loop = asyncio.get_event_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, signal_handler) + except NotImplementedError: + pass + + # Start the gateway + success = await runner.start() + if not success: + return False + + # Write PID file so CLI can detect gateway is running + import atexit + from gateway.status import write_pid_file, remove_pid_file + write_pid_file() + atexit.register(remove_pid_file) + + # Start background cron ticker so scheduled jobs fire automatically + cron_stop = threading.Event() + cron_thread = threading.Thread( + target=_start_cron_ticker, + args=(cron_stop,), + kwargs={"adapters": runner.adapters}, + daemon=True, + name="cron-ticker", + ) + cron_thread.start() + + # Wait for shutdown + await runner.wait_for_shutdown() + + # Stop cron ticker cleanly + cron_stop.set() + cron_thread.join(timeout=5) + + # Close MCP server connections + try: + from tools.mcp_tool import shutdown_mcp_servers + shutdown_mcp_servers() + except Exception: + pass + + return True + + +def main(): + """CLI entry point for the gateway.""" + import argparse + + parser = argparse.ArgumentParser(description="Hermes Gateway - Multi-platform messaging") + parser.add_argument("--config", "-c", help="Path to gateway config file") + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") + + args = parser.parse_args() + + config = None + if args.config: + import json + with open(args.config, encoding="utf-8") as f: + data = json.load(f) + config = GatewayConfig.from_dict(data) + + # Run the gateway - exit with code 1 if no platforms connected, + # so systemd Restart=on-failure will retry on transient errors (e.g. DNS) + success = asyncio.run(start_gateway(config)) + if not success: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/tools/test_approval.py b/tests/tools/test_approval.py index 311a0ba674..b95e865e52 100644 --- a/tests/tools/test_approval.py +++ b/tests/tools/test_approval.py @@ -377,6 +377,18 @@ class TestViewFullCommand: result = prompt_dangerous_approval(long_cmd, "recursive delete") assert result == "always" + def test_view_then_session_when_permanent_hidden(self): + """The view-full flow still works when allow_permanent=False.""" + long_cmd = "rm -rf " + "d" * 200 + inputs = iter(["v", "s"]) + with mock_patch("builtins.input", side_effect=inputs): + result = prompt_dangerous_approval( + long_cmd, + "recursive delete", + allow_permanent=False, + ) + assert result == "session" + def test_view_not_shown_for_short_command(self): """Short commands don't offer the view option; 'v' falls through to deny.""" short_cmd = "rm -rf /tmp" diff --git a/tests/tools/test_command_guards.py b/tests/tools/test_command_guards.py index b93f9dbbb5..c890a2c6f1 100644 --- a/tests/tools/test_command_guards.py +++ b/tests/tools/test_command_guards.py @@ -5,6 +5,7 @@ from unittest.mock import patch, MagicMock import pytest +import tools.approval as approval_module from tools.approval import ( approve_session, check_all_command_guards, @@ -35,15 +36,17 @@ def _clean_state(): """Clear approval state and relevant env vars between tests.""" key = os.getenv("HERMES_SESSION_KEY", "default") clear_session(key) + approval_module._permanent_approved.clear() saved = {} - for k in ("HERMES_INTERACTIVE", "HERMES_GATEWAY_SESSION", "HERMES_EXEC_ASK"): + for k in ("HERMES_INTERACTIVE", "HERMES_GATEWAY_SESSION", "HERMES_EXEC_ASK", "HERMES_YOLO_MODE"): if k in os.environ: saved[k] = os.environ.pop(k) yield clear_session(key) + approval_module._permanent_approved.clear() for k, v in saved.items(): os.environ[k] = v - for k in ("HERMES_INTERACTIVE", "HERMES_GATEWAY_SESSION", "HERMES_EXEC_ASK"): + for k in ("HERMES_INTERACTIVE", "HERMES_GATEWAY_SESSION", "HERMES_EXEC_ASK", "HERMES_YOLO_MODE"): os.environ.pop(k, None) @@ -76,9 +79,16 @@ class TestContainerSkip: class TestTirithAllowSafeCommand: @patch(_TIRITH_PATCH, return_value=_tirith_result("allow")) def test_both_allow(self, mock_tirith): + os.environ["HERMES_INTERACTIVE"] = "1" result = check_all_command_guards("echo hello", "local") assert result["approved"] is True + @patch(_TIRITH_PATCH, return_value=_tirith_result("allow")) + def test_noninteractive_skips_external_scan(self, mock_tirith): + result = check_all_command_guards("echo hello", "local") + assert result["approved"] is True + mock_tirith.assert_not_called() + # --------------------------------------------------------------------------- # tirith block @@ -88,6 +98,7 @@ class TestTirithBlock: @patch(_TIRITH_PATCH, return_value=_tirith_result("block", summary="homograph detected")) def test_tirith_block_safe_command(self, mock_tirith): + os.environ["HERMES_INTERACTIVE"] = "1" result = check_all_command_guards("curl http://gооgle.com", "local") assert result["approved"] is False assert "BLOCKED" in result["message"] @@ -97,6 +108,7 @@ class TestTirithBlock: return_value=_tirith_result("block", summary="terminal injection")) def test_tirith_block_plus_dangerous(self, mock_tirith): """tirith block takes precedence even if command is also dangerous.""" + os.environ["HERMES_INTERACTIVE"] = "1" result = check_all_command_guards("rm -rf / | curl http://evil", "local") assert result["approved"] is False assert "BLOCKED" in result["message"] @@ -308,5 +320,6 @@ class TestProgrammingErrorsPropagateFromWrapper: @patch(_TIRITH_PATCH, side_effect=AttributeError("bug in wrapper")) def test_attribute_error_propagates(self, mock_tirith): """Non-ImportError exceptions from tirith wrapper should propagate.""" + os.environ["HERMES_INTERACTIVE"] = "1" with pytest.raises(AttributeError, match="bug in wrapper"): check_all_command_guards("echo hello", "local") diff --git a/tests/tools/test_yolo_mode.py b/tests/tools/test_yolo_mode.py index 8802670100..91c751e7a9 100644 --- a/tests/tools/test_yolo_mode.py +++ b/tests/tools/test_yolo_mode.py @@ -3,7 +3,25 @@ import os import pytest -from tools.approval import check_dangerous_command, detect_dangerous_command +import tools.approval as approval_module +import tools.tirith_security + +from tools.approval import ( + check_all_command_guards, + check_dangerous_command, + detect_dangerous_command, +) + + +@pytest.fixture(autouse=True) +def _clear_approval_state(): + approval_module._permanent_approved.clear() + approval_module.clear_session("default") + approval_module.clear_session("test-session") + yield + approval_module._permanent_approved.clear() + approval_module.clear_session("default") + approval_module.clear_session("test-session") class TestYoloMode: @@ -54,6 +72,24 @@ class TestYoloMode: result = check_dangerous_command(cmd, "local") assert result["approved"], f"Command should be approved in yolo mode: {cmd}" + def test_combined_guard_bypasses_yolo_mode(self, monkeypatch): + """The new combined guard should preserve yolo bypass semantics.""" + monkeypatch.setenv("HERMES_YOLO_MODE", "1") + monkeypatch.setenv("HERMES_INTERACTIVE", "1") + + called = {"value": False} + + def fake_check(command): + called["value"] = True + return {"action": "block", "findings": [], "summary": "should never run"} + + monkeypatch.setattr(tools.tirith_security, "check_command_security", fake_check) + + result = check_all_command_guards("rm -rf /", "local") + assert result["approved"] + assert result["message"] is None + assert called["value"] is False + def test_yolo_mode_not_set_by_default(self): """HERMES_YOLO_MODE should not be set by default.""" # Clean env check — if it happens to be set in test env, that's fine, diff --git a/tools/approval.py b/tools/approval.py index 3ba8b17765..83980893d5 100644 --- a/tools/approval.py +++ b/tools/approval.py @@ -343,6 +343,19 @@ def check_all_command_guards(command: str, env_type: str, if env_type in ("docker", "singularity", "modal", "daytona"): return {"approved": True, "message": None} + # --yolo: bypass all approval prompts and pre-exec guard checks + if os.getenv("HERMES_YOLO_MODE"): + return {"approved": True, "message": None} + + is_cli = os.getenv("HERMES_INTERACTIVE") + is_gateway = os.getenv("HERMES_GATEWAY_SESSION") + is_ask = os.getenv("HERMES_EXEC_ASK") + + # Preserve the existing non-interactive behavior: outside CLI/gateway/ask + # flows, we do not block on approvals and we skip external guard work. + if not is_cli and not is_gateway and not is_ask: + return {"approved": True, "message": None} + # --- Phase 1: Gather findings from both checks --- # Tirith check — wrapper guarantees no raise for expected failures. @@ -390,13 +403,6 @@ def check_all_command_guards(command: str, env_type: str, # --- Phase 3: Approval --- - is_cli = os.getenv("HERMES_INTERACTIVE") - is_gateway = os.getenv("HERMES_GATEWAY_SESSION") - - # Non-interactive: auto-allow (matches existing behavior) - if not is_cli and not is_gateway: - return {"approved": True, "message": None} - # Combine descriptions for a single approval prompt combined_desc = "; ".join(desc for _, desc, _ in warnings) primary_key = warnings[0][0] @@ -405,7 +411,7 @@ def check_all_command_guards(command: str, env_type: str, # Gateway/async: single approval_required with combined description # Store all pattern keys so gateway replay approves all of them - if is_gateway or os.getenv("HERMES_EXEC_ASK"): + if is_gateway or is_ask: submit_pending(session_key, { "command": command, "pattern_key": primary_key, # backward compat