Compare commits

...

5 Commits

Author SHA1 Message Date
Teknium
c4c3a57a3a fix: restore agent.close() cleanup and correct /restart category
- Add agent.close() call to _finalize_shutdown_agents() to prevent
  zombie processes (terminal sandboxes, browser daemons, httpx clients)
- Global cleanup (process_registry, environments, browsers) preserved
  in _stop_impl() during conflict resolution
- Move /restart CommandDef from 'Info' to 'Session' category to match
  /stop and /status
2026-04-10 18:55:28 -07:00
Kenny Xie
eb3f021e2a fix(gateway): address restart review feedback 2026-04-10 18:54:48 -07:00
aquaright1
b4cb803954 fix(gateway): self-request service restarts when invoked in-process 2026-04-10 18:54:48 -07:00
Kenny Xie
b5928530d3 fix(gateway): tolerate partial runner construction 2026-04-10 18:54:48 -07:00
Kenny Xie
ef120b5422 fix(gateway): drain in-flight work before restart 2026-04-10 18:54:48 -07:00
14 changed files with 1089 additions and 161 deletions

View File

@@ -480,6 +480,12 @@ agent:
# Fires once per run when inactivity reaches this threshold (seconds).
# Set to 0 to disable the warning.
# gateway_timeout_warning: 900
# Graceful drain timeout for gateway stop/restart (seconds).
# The gateway stops accepting new work, waits for in-flight agents to
# finish, then interrupts anything still running after this timeout.
# 0 = no drain, interrupt immediately.
# restart_drain_timeout: 60
# Enable verbose logging
verbose: false

View File

@@ -673,6 +673,32 @@ class SendResult:
retryable: bool = False # True for transient connection errors — base will retry automatically
def merge_pending_message_event(
pending_messages: Dict[str, MessageEvent],
session_key: str,
event: MessageEvent,
) -> None:
"""Store or merge a pending event for a session.
Photo bursts/albums often arrive as multiple near-simultaneous PHOTO
events. Merge those into the existing queued event so the next turn sees
the whole burst, while non-photo follow-ups still replace the pending
event normally.
"""
existing = pending_messages.get(session_key)
if (
existing
and getattr(existing, "message_type", None) == MessageType.PHOTO
and event.message_type == MessageType.PHOTO
):
existing.media_urls.extend(event.media_urls)
existing.media_types.extend(event.media_types)
if event.text:
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
return
pending_messages[session_key] = event
# Error substrings that indicate a transient *connection* failure worth retrying.
# "timeout" / "timed out" / "readtimeout" / "writetimeout" are intentionally
# excluded: a read/write timeout on a non-idempotent call (e.g. send_message)
@@ -727,6 +753,7 @@ class BasePlatformAdapter(ABC):
# working on a task after --replace or manual restarts.
self._background_tasks: set[asyncio.Task] = set()
self._expected_cancelled_tasks: set[asyncio.Task] = set()
self._busy_session_handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]] = None
# Chats where auto-TTS on voice input is disabled (set by /voice off)
self._auto_tts_disabled_chats: set = set()
# Chats where typing indicator is paused (e.g. during approval waits).
@@ -815,6 +842,10 @@ class BasePlatformAdapter(ABC):
an optional response string.
"""
self._message_handler = handler
def set_busy_session_handler(self, handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]]) -> None:
"""Set an optional handler for messages arriving during active sessions."""
self._busy_session_handler = handler
def set_session_store(self, session_store: Any) -> None:
"""
@@ -1396,7 +1427,7 @@ class BasePlatformAdapter(ABC):
# session lifecycle and its cleanup races with the running task
# (see PR #4926).
cmd = event.get_command()
if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background"):
if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background", "restart"):
logger.debug(
"[%s] Command '/%s' bypassing active-session guard for %s",
self.name, cmd, session_key,
@@ -1415,19 +1446,19 @@ class BasePlatformAdapter(ABC):
logger.error("[%s] Command '/%s' dispatch failed: %s", self.name, cmd, e, exc_info=True)
return
if self._busy_session_handler is not None:
try:
if await self._busy_session_handler(event, session_key):
return
except Exception as e:
logger.error("[%s] Busy-session handler failed: %s", self.name, e, exc_info=True)
# Special case: photo bursts/albums frequently arrive as multiple near-
# simultaneous messages. Queue them without interrupting the active run,
# then process them immediately after the current task finishes.
if event.message_type == MessageType.PHOTO:
logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key)
existing = self._pending_messages.get(session_key)
if existing and existing.message_type == MessageType.PHOTO:
existing.media_urls.extend(event.media_urls)
existing.media_types.extend(event.media_types)
if event.text:
existing.text = self._merge_caption(existing.text, event.text)
else:
self._pending_messages[session_key] = event
merge_pending_message_event(self._pending_messages, session_key, event)
return # Don't interrupt now - will run after current task completes
# Default behavior for non-photo follow-ups: interrupt the running agent

20
gateway/restart.py Normal file
View File

@@ -0,0 +1,20 @@
"""Shared gateway restart constants and parsing helpers."""
from hermes_cli.config import DEFAULT_CONFIG
# EX_TEMPFAIL from sysexits.h — used to ask the service manager to restart
# the gateway after a graceful drain/reload path completes.
GATEWAY_SERVICE_RESTART_EXIT_CODE = 75
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT = float(
DEFAULT_CONFIG["agent"]["restart_drain_timeout"]
)
def parse_restart_drain_timeout(raw: object) -> float:
"""Parse a configured drain timeout, falling back to the shared default."""
try:
value = float(raw) if str(raw or "").strip() else DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
except (TypeError, ValueError):
return DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
return max(0.0, value)

View File

@@ -186,6 +186,12 @@ if _config_path.exists():
os.environ["HERMES_AGENT_TIMEOUT"] = str(_agent_cfg["gateway_timeout"])
if "gateway_timeout_warning" in _agent_cfg and "HERMES_AGENT_TIMEOUT_WARNING" not in os.environ:
os.environ["HERMES_AGENT_TIMEOUT_WARNING"] = str(_agent_cfg["gateway_timeout_warning"])
if "restart_drain_timeout" in _agent_cfg and "HERMES_RESTART_DRAIN_TIMEOUT" not in os.environ:
os.environ["HERMES_RESTART_DRAIN_TIMEOUT"] = str(_agent_cfg["restart_drain_timeout"])
_display_cfg = _cfg.get("display", {})
if _display_cfg and isinstance(_display_cfg, dict):
if "busy_input_mode" in _display_cfg and "HERMES_GATEWAY_BUSY_INPUT_MODE" not in os.environ:
os.environ["HERMES_GATEWAY_BUSY_INPUT_MODE"] = str(_display_cfg["busy_input_mode"])
# Timezone: bridge config.yaml → HERMES_TIMEZONE env var.
# HERMES_TIMEZONE from .env takes precedence (already in os.environ).
_tz_cfg = _cfg.get("timezone", "")
@@ -235,7 +241,17 @@ from gateway.session import (
build_session_key,
)
from gateway.delivery import DeliveryRouter
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
MessageType,
merge_pending_message_event,
)
from gateway.restart import (
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
GATEWAY_SERVICE_RESTART_EXIT_CODE,
parse_restart_drain_timeout,
)
def _normalize_whatsapp_identifier(value: str) -> str:
@@ -471,6 +487,16 @@ class GatewayRunner:
# Class-level defaults so partial construction in tests doesn't
# blow up on attribute access.
_running_agents_ts: Dict[str, float] = {}
_busy_input_mode: str = "interrupt"
_restart_drain_timeout: float = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
_exit_code: Optional[int] = None
_draining: bool = False
_restart_requested: bool = False
_restart_task_started: bool = False
_restart_detached: bool = False
_restart_via_service: bool = False
_stop_task: Optional[asyncio.Task] = None
_session_model_overrides: Dict[str, Dict[str, str]] = {}
def __init__(self, config: Optional[GatewayConfig] = None):
self.config = config or load_gateway_config()
@@ -483,6 +509,8 @@ class GatewayRunner:
self._reasoning_config = self._load_reasoning_config()
self._service_tier = self._load_service_tier()
self._show_reasoning = self._load_show_reasoning()
self._busy_input_mode = self._load_busy_input_mode()
self._restart_drain_timeout = self._load_restart_drain_timeout()
self._provider_routing = self._load_provider_routing()
self._fallback_model = self._load_fallback_model()
self._smart_model_routing = self._load_smart_model_routing()
@@ -499,6 +527,13 @@ class GatewayRunner:
self._exit_cleanly = False
self._exit_with_failure = False
self._exit_reason: Optional[str] = None
self._exit_code: Optional[int] = None
self._draining = False
self._restart_requested = False
self._restart_task_started = False
self._restart_detached = False
self._restart_via_service = False
self._stop_task: Optional[asyncio.Task] = None
# Track running agents per session for interrupt support
# Key: session_key, Value: AIAgent instance
@@ -759,6 +794,10 @@ class GatewayRunner:
def exit_reason(self) -> Optional[str]:
return self._exit_reason
@property
def exit_code(self) -> Optional[int]:
return self._exit_code
def _session_key_for_source(self, source: SessionSource) -> str:
"""Resolve the current session key for a source, honoring gateway config when available."""
if hasattr(self, "session_store") and self.session_store is not None:
@@ -868,6 +907,30 @@ class GatewayRunner:
self._exit_cleanly = True
self._exit_reason = reason
self._shutdown_event.set()
def _running_agent_count(self) -> int:
return len(self._running_agents)
def _status_action_label(self) -> str:
return "restart" if self._restart_requested else "shutdown"
def _status_action_gerund(self) -> str:
return "restarting" if self._restart_requested else "shutting down"
def _queue_during_drain_enabled(self) -> bool:
return self._restart_requested and self._busy_input_mode == "queue"
def _update_runtime_status(self, gateway_state: Optional[str] = None, exit_reason: Optional[str] = None) -> None:
try:
from gateway.status import write_runtime_status
write_runtime_status(
gateway_state=gateway_state,
exit_reason=exit_reason,
restart_requested=self._restart_requested,
active_agents=self._running_agent_count(),
)
except Exception:
pass
@staticmethod
def _load_prefill_messages() -> List[Dict[str, Any]]:
@@ -994,6 +1057,48 @@ class GatewayRunner:
pass
return False
@staticmethod
def _load_busy_input_mode() -> str:
"""Load gateway drain-time busy-input behavior from config/env."""
mode = os.getenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "").strip().lower()
if not mode:
try:
import yaml as _y
cfg_path = _hermes_home / "config.yaml"
if cfg_path.exists():
with open(cfg_path, encoding="utf-8") as _f:
cfg = _y.safe_load(_f) or {}
mode = str(cfg.get("display", {}).get("busy_input_mode", "") or "").strip().lower()
except Exception:
pass
return "queue" if mode == "queue" else "interrupt"
@staticmethod
def _load_restart_drain_timeout() -> float:
"""Load graceful gateway restart/stop drain timeout in seconds."""
raw = os.getenv("HERMES_RESTART_DRAIN_TIMEOUT", "").strip()
if not raw:
try:
import yaml as _y
cfg_path = _hermes_home / "config.yaml"
if cfg_path.exists():
with open(cfg_path, encoding="utf-8") as _f:
cfg = _y.safe_load(_f) or {}
raw = str(cfg.get("agent", {}).get("restart_drain_timeout", "") or "").strip()
except Exception:
pass
value = parse_restart_drain_timeout(raw)
if raw and value == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT:
try:
float(raw)
except (TypeError, ValueError):
logger.warning(
"Invalid restart_drain_timeout '%s', using default %.0fs",
raw,
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
)
return value
@staticmethod
def _load_background_notifications_mode() -> str:
"""Load background process notification mode from config or env var.
@@ -1078,6 +1183,155 @@ class GatewayRunner:
pass
return {}
def _snapshot_running_agents(self) -> Dict[str, Any]:
return {
session_key: agent
for session_key, agent in self._running_agents.items()
if agent is not _AGENT_PENDING_SENTINEL
}
def _queue_or_replace_pending_event(self, session_key: str, event: MessageEvent) -> None:
adapter = self.adapters.get(event.source.platform)
if not adapter:
return
merge_pending_message_event(adapter._pending_messages, session_key, event)
async def _handle_active_session_busy_message(self, event: MessageEvent, session_key: str) -> bool:
if not self._draining:
return False
adapter = self.adapters.get(event.source.platform)
if not adapter:
return True
thread_meta = {"thread_id": event.source.thread_id} if event.source.thread_id else None
if self._queue_during_drain_enabled():
self._queue_or_replace_pending_event(session_key, event)
message = f"⏳ Gateway {self._status_action_gerund()} — queued for the next turn after it comes back."
else:
message = f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now."
await adapter._send_with_retry(
chat_id=event.source.chat_id,
content=message,
reply_to=event.message_id,
metadata=thread_meta,
)
return True
async def _drain_active_agents(self, timeout: float) -> tuple[Dict[str, Any], bool]:
snapshot = self._snapshot_running_agents()
last_active_count = self._running_agent_count()
last_status_at = 0.0
def _maybe_update_status(force: bool = False) -> None:
nonlocal last_active_count, last_status_at
now = asyncio.get_running_loop().time()
active_count = self._running_agent_count()
if force or active_count != last_active_count or (now - last_status_at) >= 1.0:
self._update_runtime_status("draining")
last_active_count = active_count
last_status_at = now
if not self._running_agents:
_maybe_update_status(force=True)
return snapshot, False
_maybe_update_status(force=True)
if timeout <= 0:
return snapshot, True
deadline = asyncio.get_running_loop().time() + timeout
while self._running_agents and asyncio.get_running_loop().time() < deadline:
_maybe_update_status()
await asyncio.sleep(0.1)
timed_out = bool(self._running_agents)
_maybe_update_status(force=True)
return snapshot, timed_out
def _interrupt_running_agents(self, reason: str) -> None:
for session_key, agent in list(self._running_agents.items()):
if agent is _AGENT_PENDING_SENTINEL:
continue
try:
agent.interrupt(reason)
logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
except Exception as e:
logger.debug("Failed interrupting agent during shutdown: %s", e)
def _finalize_shutdown_agents(self, active_agents: Dict[str, Any]) -> None:
for agent in active_agents.values():
try:
from hermes_cli.plugins import invoke_hook as _invoke_hook
_invoke_hook(
"on_session_finalize",
session_id=getattr(agent, "session_id", None),
platform="gateway",
)
except Exception:
pass
try:
if hasattr(agent, "shutdown_memory_provider"):
agent.shutdown_memory_provider()
except Exception:
pass
# Close tool resources (terminal sandboxes, browser daemons,
# background processes, httpx clients) to prevent zombie
# process accumulation.
try:
if hasattr(agent, 'close'):
agent.close()
except Exception:
pass
async def _launch_detached_restart_command(self) -> None:
import shutil
import subprocess
hermes_cmd = _resolve_hermes_bin()
if not hermes_cmd:
logger.error("Could not locate hermes binary for detached /restart")
return
current_pid = os.getpid()
cmd = " ".join(shlex.quote(part) for part in hermes_cmd)
shell_cmd = (
f"while kill -0 {current_pid} 2>/dev/null; do sleep 0.2; done; "
f"{cmd} gateway restart"
)
setsid_bin = shutil.which("setsid")
if setsid_bin:
subprocess.Popen(
[setsid_bin, "bash", "-lc", shell_cmd],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
start_new_session=True,
)
else:
subprocess.Popen(
["bash", "-lc", shell_cmd],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
start_new_session=True,
)
def request_restart(self, *, detached: bool = False, via_service: bool = False) -> bool:
if self._restart_task_started:
return False
self._restart_requested = True
self._restart_detached = detached
self._restart_via_service = via_service
self._restart_task_started = True
async def _run_restart() -> None:
await asyncio.sleep(0.05)
await self.stop(restart=True, detached_restart=detached, service_restart=via_service)
task = asyncio.create_task(_run_restart())
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return True
async def start(self) -> bool:
"""
Start the gateway and all configured platform adapters.
@@ -1165,6 +1419,7 @@ class GatewayRunner:
adapter.set_message_handler(self._handle_message)
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
adapter.set_session_store(self.session_store)
adapter.set_busy_session_handler(self._handle_active_session_busy_message)
# Try to connect
logger.info("Connecting to %s...", platform.value)
@@ -1240,11 +1495,7 @@ class GatewayRunner:
self.delivery_router.adapters = self.adapters
self._running = True
try:
from gateway.status import write_runtime_status
write_runtime_status(gateway_state="running", exit_reason=None)
except Exception:
pass
self._update_runtime_status("running")
# Emit gateway:startup hook
hook_count = len(self.hooks.loaded_hooks)
@@ -1479,6 +1730,7 @@ class GatewayRunner:
adapter.set_message_handler(self._handle_message)
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
adapter.set_session_store(self.session_store)
adapter.set_busy_session_handler(self._handle_active_session_busy_message)
success = await adapter.connect()
if success:
@@ -1525,90 +1777,108 @@ class GatewayRunner:
return
await asyncio.sleep(1)
async def stop(self) -> None:
async def stop(
self,
*,
restart: bool = False,
detached_restart: bool = False,
service_restart: bool = False,
) -> None:
"""Stop the gateway and disconnect all adapters."""
logger.info("Stopping gateway...")
self._running = False
if restart:
self._restart_requested = True
self._restart_detached = detached_restart
self._restart_via_service = service_restart
if self._stop_task is not None:
await self._stop_task
return
for session_key, agent in list(self._running_agents.items()):
if agent is _AGENT_PENDING_SENTINEL:
continue
async def _stop_impl() -> None:
logger.info(
"Stopping gateway%s...",
" for restart" if self._restart_requested else "",
)
self._running = False
self._draining = True
timeout = self._restart_drain_timeout
active_agents, timed_out = await self._drain_active_agents(timeout)
if timed_out:
logger.warning(
"Gateway drain timed out after %.1fs with %d active agent(s); interrupting remaining work.",
timeout,
self._running_agent_count(),
)
self._interrupt_running_agents(
"Gateway restarting" if self._restart_requested else "Gateway shutting down"
)
interrupt_deadline = asyncio.get_running_loop().time() + 5.0
while self._running_agents and asyncio.get_running_loop().time() < interrupt_deadline:
self._update_runtime_status("draining")
await asyncio.sleep(0.1)
if self._restart_requested and self._restart_detached:
try:
await self._launch_detached_restart_command()
except Exception as e:
logger.error("Failed to launch detached gateway restart: %s", e)
self._finalize_shutdown_agents(active_agents)
for platform, adapter in list(self.adapters.items()):
try:
await adapter.cancel_background_tasks()
except Exception as e:
logger.debug("%s background-task cancel error: %s", platform.value, e)
try:
await adapter.disconnect()
logger.info("%s disconnected", platform.value)
except Exception as e:
logger.error("%s disconnect error: %s", platform.value, e)
for _task in list(self._background_tasks):
if _task is self._stop_task:
continue
_task.cancel()
self._background_tasks.clear()
self.adapters.clear()
self._running_agents.clear()
self._pending_messages.clear()
self._pending_approvals.clear()
self._shutdown_event.set()
# Global cleanup: kill any remaining tool subprocesses not tied
# to a specific agent (catch-all for zombie prevention).
try:
agent.interrupt("Gateway shutting down")
logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
except Exception as e:
logger.debug("Failed interrupting agent during shutdown: %s", e)
# Fire plugin on_session_finalize hook before memory shutdown
try:
from hermes_cli.plugins import invoke_hook as _invoke_hook
_invoke_hook("on_session_finalize",
session_id=getattr(agent, 'session_id', None),
platform="gateway")
from tools.process_registry import process_registry
process_registry.kill_all()
except Exception:
pass
# Shut down memory provider at actual session boundary
try:
if hasattr(agent, 'shutdown_memory_provider'):
agent.shutdown_memory_provider()
from tools.terminal_tool import cleanup_all_environments
cleanup_all_environments()
except Exception:
pass
# Close tool resources (terminal sandboxes, browser daemons,
# background processes, httpx clients) to prevent zombie
# process accumulation.
try:
if hasattr(agent, 'close'):
agent.close()
from tools.browser_tool import cleanup_all_browsers
cleanup_all_browsers()
except Exception:
pass
for platform, adapter in list(self.adapters.items()):
try:
await adapter.cancel_background_tasks()
except Exception as e:
logger.debug("%s background-task cancel error: %s", platform.value, e)
try:
await adapter.disconnect()
logger.info("%s disconnected", platform.value)
except Exception as e:
logger.error("%s disconnect error: %s", platform.value, e)
from gateway.status import remove_pid_file
remove_pid_file()
# Cancel any pending background tasks
for _task in list(self._background_tasks):
_task.cancel()
self._background_tasks.clear()
if self._restart_requested and self._restart_via_service:
self._exit_code = GATEWAY_SERVICE_RESTART_EXIT_CODE
self._exit_reason = self._exit_reason or "Gateway restart requested"
self.adapters.clear()
self._running_agents.clear()
self._pending_messages.clear()
self._pending_approvals.clear()
self._shutdown_event.set()
self._draining = False
self._update_runtime_status("stopped", self._exit_reason)
logger.info("Gateway stopped")
# Global cleanup: kill any remaining tool subprocesses not tied
# to a specific agent (catch-all for zombie prevention).
try:
from tools.process_registry import process_registry
process_registry.kill_all()
except Exception:
pass
try:
from tools.terminal_tool import cleanup_all_environments
cleanup_all_environments()
except Exception:
pass
try:
from tools.browser_tool import cleanup_all_browsers
cleanup_all_browsers()
except Exception:
pass
from gateway.status import remove_pid_file, write_runtime_status
remove_pid_file()
try:
write_runtime_status(gateway_state="stopped", exit_reason=self._exit_reason)
except Exception:
pass
logger.info("Gateway stopped")
self._stop_task = asyncio.create_task(_stop_impl())
await self._stop_task
async def wait_for_shutdown(self) -> None:
"""Wait for shutdown signal."""
@@ -2014,6 +2284,9 @@ class GatewayRunner:
_evt_cmd = event.get_command()
_cmd_def_inner = _resolve_cmd_inner(_evt_cmd) if _evt_cmd else None
if _cmd_def_inner and _cmd_def_inner.name == "restart":
return await self._handle_restart_command(event)
# /stop must hard-kill the session when an agent is running.
# A soft interrupt (agent.interrupt()) doesn't help when the agent
# is truly hung — the executor thread is blocked and never checks
@@ -2094,18 +2367,7 @@ class GatewayRunner:
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
adapter = self.adapters.get(source.platform)
if adapter:
# Reuse adapter queue semantics so photo bursts merge cleanly.
if _quick_key in adapter._pending_messages:
existing = adapter._pending_messages[_quick_key]
if getattr(existing, "message_type", None) == MessageType.PHOTO:
existing.media_urls.extend(event.media_urls)
existing.media_types.extend(event.media_types)
if event.text:
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
else:
adapter._pending_messages[_quick_key] = event
else:
adapter._pending_messages[_quick_key] = event
merge_pending_message_event(adapter._pending_messages, _quick_key, event)
return None
running_agent = self._running_agents.get(_quick_key)
@@ -2123,6 +2385,14 @@ class GatewayRunner:
if adapter:
adapter._pending_messages[_quick_key] = event
return None
if self._draining:
if self._queue_during_drain_enabled():
self._queue_or_replace_pending_event(_quick_key, event)
return (
f"⏳ Gateway {self._status_action_gerund()} — queued for the next turn after it comes back."
if self._queue_during_drain_enabled()
else f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now."
)
logger.debug("PRIORITY interrupt for session %s", _quick_key[:20])
running_agent.interrupt(event.text)
if _quick_key in self._pending_messages:
@@ -2164,6 +2434,9 @@ class GatewayRunner:
if canonical == "status":
return await self._handle_status_command(event)
if canonical == "restart":
return await self._handle_restart_command(event)
if canonical == "stop":
return await self._handle_stop_command(event)
@@ -2262,6 +2535,9 @@ class GatewayRunner:
if canonical == "voice":
return await self._handle_voice_command(event)
if self._draining:
return f"⏳ Gateway is {self._status_action_gerund()} and is not accepting new work right now."
# User-defined quick commands (bypass agent loop, no LLM call)
if command:
if isinstance(self.config, dict):
@@ -3556,7 +3832,21 @@ class GatewayRunner:
return "⚡ Force-stopped. The session is unlocked — you can send a new message."
else:
return "No active task to stop."
async def _handle_restart_command(self, event: MessageEvent) -> str:
"""Handle /restart command - drain active work, then restart the gateway."""
if self._restart_requested or self._draining:
count = self._running_agent_count()
if count:
return f"⏳ Draining {count} active agent(s) before restart..."
return "⏳ Gateway restart already in progress..."
active_agents = self._running_agent_count()
self.request_restart(detached=True, via_service=False)
if active_agents:
return f"⏳ Draining {active_agents} active agent(s) before restart..."
return "♻ Restarting gateway..."
async def _handle_help_command(self, event: MessageEvent) -> str:
"""Handle /help command - list available commands."""
from hermes_cli.commands import gateway_help_lines
@@ -3679,7 +3969,7 @@ class GatewayRunner:
# Check for session override
source = event.source
session_key = self._session_key_for_source(source)
override = getattr(self, "_session_model_overrides", {}).get(session_key, {})
override = self._session_model_overrides.get(session_key, {})
if override:
current_model = override.get("model", current_model)
current_provider = override.get("provider", current_provider)
@@ -3761,8 +4051,6 @@ class GatewayRunner:
f"via {result.provider_label or result.target_provider}. "
f"Adjust your self-identification accordingly.]"
)
if not hasattr(_self, "_session_model_overrides"):
_self._session_model_overrides = {}
_self._session_model_overrides[_session_key] = {
"model": result.new_model,
"provider": result.target_provider,
@@ -3876,8 +4164,6 @@ class GatewayRunner:
)
# Store session override so next agent creation uses the new model
if not hasattr(self, "_session_model_overrides"):
self._session_model_overrides = {}
self._session_model_overrides[session_key] = {
"model": result.new_model,
"provider": result.target_provider,
@@ -7363,6 +7649,8 @@ class GatewayRunner:
await asyncio.sleep(0.05)
if session_key:
self._running_agents[session_key] = agent_holder[0]
if self._draining:
self._update_runtime_status("draining")
tracking_task = asyncio.create_task(track_agent())
@@ -7608,6 +7896,14 @@ class GatewayRunner:
except Exception:
pass
if self._draining and pending:
logger.info(
"Discarding pending follow-up for session %s during gateway %s",
session_key[:20] if session_key else "?",
self._status_action_label(),
)
pending = None
if pending:
logger.debug("Processing pending message: '%s...'", pending[:40])
@@ -7684,6 +7980,8 @@ class GatewayRunner:
del self._running_agents[session_key]
if session_key:
self._running_agents_ts.pop(session_key, None)
if self._draining:
self._update_runtime_status("draining")
# Wait for cancelled tasks
for task in [progress_task, interrupt_monitor, tracking_task, _notify_task]:
@@ -7881,13 +8179,21 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
runner = GatewayRunner(config)
# Set up signal handlers
def signal_handler():
def shutdown_signal_handler():
asyncio.create_task(runner.stop())
def restart_signal_handler():
runner.request_restart(detached=False, via_service=True)
loop = asyncio.get_event_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
try:
loop.add_signal_handler(sig, signal_handler)
loop.add_signal_handler(sig, shutdown_signal_handler)
except NotImplementedError:
pass
if hasattr(signal, "SIGUSR1"):
try:
loop.add_signal_handler(signal.SIGUSR1, restart_signal_handler)
except NotImplementedError:
pass
@@ -7937,6 +8243,9 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
except Exception:
pass
if runner.exit_code is not None:
raise SystemExit(runner.exit_code)
return True

View File

@@ -158,6 +158,8 @@ def _build_runtime_status_record() -> dict[str, Any]:
payload.update({
"gateway_state": "starting",
"exit_reason": None,
"restart_requested": False,
"active_agents": 0,
"platforms": {},
"updated_at": _utc_now_iso(),
})
@@ -218,6 +220,8 @@ def write_runtime_status(
*,
gateway_state: Optional[str] = None,
exit_reason: Optional[str] = None,
restart_requested: Optional[bool] = None,
active_agents: Optional[int] = None,
platform: Optional[str] = None,
platform_state: Optional[str] = None,
error_code: Optional[str] = None,
@@ -236,6 +240,10 @@ def write_runtime_status(
payload["gateway_state"] = gateway_state
if exit_reason is not None:
payload["exit_reason"] = exit_reason
if restart_requested is not None:
payload["restart_requested"] = bool(restart_requested)
if active_agents is not None:
payload["active_agents"] = max(0, int(active_agents))
if platform is not None:
platform_payload = payload["platforms"].get(platform, {})

View File

@@ -140,6 +140,8 @@ COMMAND_REGISTRY: list[CommandDef] = [
CommandDef("commands", "Browse all commands and skills (paginated)", "Info",
gateway_only=True, args_hint="[page]"),
CommandDef("help", "Show available commands", "Info"),
CommandDef("restart", "Gracefully restart the gateway after draining active runs", "Session",
gateway_only=True),
CommandDef("usage", "Show token usage and rate limits for the current session", "Info"),
CommandDef("insights", "Show usage insights and analytics", "Info",
args_hint="[days]"),

View File

@@ -269,6 +269,11 @@ DEFAULT_CONFIG = {
# tools or receiving API responses. Only fires when the agent has
# been completely idle for this duration. 0 = unlimited.
"gateway_timeout": 1800,
# Graceful drain timeout for gateway stop/restart (seconds).
# The gateway stops accepting new work, waits for running agents
# to finish, then interrupts any remaining runs after the timeout.
# 0 = no drain, interrupt immediately.
"restart_drain_timeout": 60,
"service_tier": "",
# Tool-use enforcement: injects system prompt guidance that tells the
# model to actually call tools instead of describing intended actions.

View File

@@ -15,7 +15,19 @@ from pathlib import Path
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
from gateway.status import terminate_pid
from hermes_cli.config import get_env_value, get_hermes_home, save_env_value, is_managed, managed_error
from gateway.restart import (
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
GATEWAY_SERVICE_RESTART_EXIT_CODE,
parse_restart_drain_timeout,
)
from hermes_cli.config import (
get_env_value,
get_hermes_home,
is_managed,
managed_error,
read_raw_config,
save_env_value,
)
# display_hermes_home is imported lazily at call sites to avoid ImportError
# when hermes_constants is cached from a pre-update version during `hermes update`.
from hermes_cli.setup import (
@@ -92,6 +104,59 @@ def _get_service_pids() -> set:
return pids
def _get_parent_pid(pid: int) -> int | None:
"""Return the parent PID for ``pid``, or ``None`` when unavailable."""
if pid <= 1:
return None
try:
result = subprocess.run(
["ps", "-o", "ppid=", "-p", str(pid)],
capture_output=True,
text=True,
timeout=5,
)
except (FileNotFoundError, subprocess.TimeoutExpired):
return None
if result.returncode != 0:
return None
raw = result.stdout.strip()
if not raw:
return None
try:
parent_pid = int(raw.splitlines()[-1].strip())
except ValueError:
return None
return parent_pid if parent_pid > 0 else None
def _is_pid_ancestor_of_current_process(target_pid: int) -> bool:
"""Return True when ``target_pid`` is this process or one of its ancestors."""
if target_pid <= 0:
return False
pid = os.getpid()
seen: set[int] = set()
while pid and pid not in seen:
if pid == target_pid:
return True
seen.add(pid)
pid = _get_parent_pid(pid) or 0
return False
def _request_gateway_self_restart(pid: int) -> bool:
"""Ask a running gateway ancestor to restart itself asynchronously."""
if not hasattr(signal, "SIGUSR1"):
return False
if not _is_pid_ancestor_of_current_process(pid):
return False
try:
os.kill(pid, signal.SIGUSR1)
except (ProcessLookupError, PermissionError, OSError):
return False
return True
def find_gateway_pids(exclude_pids: set | None = None) -> list:
"""Find PIDs of running gateway processes.
@@ -665,6 +730,7 @@ def generate_systemd_unit(system: bool = False, run_as_user: str | None = None)
path_entries.append(resolved_node_dir)
common_bin_paths = ["/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"]
restart_timeout = max(60, int(_get_restart_drain_timeout() or 0))
if system:
username, group_name, home_dir = _system_service_identity(run_as_user)
@@ -703,9 +769,11 @@ Environment="VIRTUAL_ENV={venv_dir}"
Environment="HERMES_HOME={hermes_home}"
Restart=on-failure
RestartSec=30
RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}
KillMode=mixed
KillSignal=SIGTERM
TimeoutStopSec=60
ExecReload=/bin/kill -USR1 $MAINPID
TimeoutStopSec={restart_timeout}
StandardOutput=journal
StandardError=journal
@@ -733,9 +801,11 @@ Environment="VIRTUAL_ENV={venv_dir}"
Environment="HERMES_HOME={hermes_home}"
Restart=on-failure
RestartSec=30
RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}
KillMode=mixed
KillSignal=SIGTERM
TimeoutStopSec=60
ExecReload=/bin/kill -USR1 $MAINPID
TimeoutStopSec={restart_timeout}
StandardOutput=journal
StandardError=journal
@@ -838,6 +908,20 @@ def _select_systemd_scope(system: bool = False) -> bool:
return get_systemd_unit_path(system=True).exists() and not get_systemd_unit_path(system=False).exists()
def _get_restart_drain_timeout() -> float:
"""Return the configured gateway restart drain timeout in seconds."""
raw = os.getenv("HERMES_RESTART_DRAIN_TIMEOUT", "").strip()
if not raw:
cfg = read_raw_config()
agent_cfg = cfg.get("agent", {}) if isinstance(cfg, dict) else {}
raw = str(
agent_cfg.get(
"restart_drain_timeout", DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
)
)
return parse_restart_drain_timeout(raw)
def systemd_install(force: bool = False, system: bool = False, run_as_user: str | None = None):
if system:
_require_root_for_system_service("install")
@@ -923,7 +1007,13 @@ def systemd_restart(system: bool = False):
if system:
_require_root_for_system_service("restart")
refresh_systemd_unit_if_needed(system=system)
subprocess.run(_systemctl_cmd(system) + ["restart", get_service_name()], check=True, timeout=90)
from gateway.status import get_running_pid
pid = get_running_pid()
if pid is not None and _request_gateway_self_restart(pid):
print(f"{_service_scope_label(system).capitalize()} service restart requested")
return
subprocess.run(_systemctl_cmd(system) + ["reload-or-restart", get_service_name()], check=True, timeout=90)
print(f"{_service_scope_label(system).capitalize()} service restarted")
@@ -1211,7 +1301,7 @@ def launchd_stop():
_wait_for_gateway_exit(timeout=10.0, force_after=5.0)
print("✓ Service stopped")
def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float | None = 5.0) -> bool:
"""Wait for the gateway process (by saved PID) to exit.
Uses the PID from the gateway.pid file — not launchd labels — so this
@@ -1226,21 +1316,21 @@ def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
from gateway.status import get_running_pid
deadline = time.monotonic() + timeout
force_deadline = time.monotonic() + force_after
force_deadline = (time.monotonic() + force_after) if force_after is not None else None
force_sent = False
while time.monotonic() < deadline:
pid = get_running_pid()
if pid is None:
return # Process exited cleanly.
return True # Process exited cleanly.
if not force_sent and time.monotonic() >= force_deadline:
if force_after is not None and not force_sent and time.monotonic() >= force_deadline:
# Grace period expired — force-kill the specific PID.
try:
terminate_pid(pid, force=True)
print(f"⚠ Gateway PID {pid} did not exit gracefully; sent SIGKILL")
except (ProcessLookupError, PermissionError, OSError):
return # Already gone or we can't touch it.
return True # Already gone or we can't touch it.
force_sent = True
time.sleep(0.3)
@@ -1249,15 +1339,30 @@ def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
remaining_pid = get_running_pid()
if remaining_pid is not None:
print(f"⚠ Gateway PID {remaining_pid} still running after {timeout}s — restart may fail")
return False
return True
def launchd_restart():
label = get_launchd_label()
target = f"{_launchd_domain()}/{label}"
# Use kickstart -k so launchd performs an atomic kill+restart.
# A two-step stop/start from inside the gateway's own process tree
# would kill the shell before the start command is reached.
drain_timeout = _get_restart_drain_timeout()
from gateway.status import get_running_pid
try:
pid = get_running_pid()
if pid is not None and _request_gateway_self_restart(pid):
print("✓ Service restart requested")
return
if pid is not None:
try:
terminate_pid(pid, force=False)
except (ProcessLookupError, PermissionError, OSError):
pid = None
if pid is not None:
exited = _wait_for_gateway_exit(timeout=drain_timeout, force_after=None)
if not exited:
print(f"⚠ Gateway drain timed out after {drain_timeout:.0f}s — forcing launchd restart")
subprocess.run(["launchctl", "kickstart", "-k", target], check=True, timeout=90)
print("✓ Service restarted")
except subprocess.CalledProcessError as e:
@@ -1728,6 +1833,8 @@ def _runtime_health_lines() -> list[str]:
lines: list[str] = []
gateway_state = state.get("gateway_state")
exit_reason = state.get("exit_reason")
active_agents = state.get("active_agents")
restart_requested = state.get("restart_requested")
platforms = state.get("platforms", {}) or {}
for platform, pdata in platforms.items():
@@ -1737,6 +1844,10 @@ def _runtime_health_lines() -> list[str]:
if gateway_state == "startup_failed" and exit_reason:
lines.append(f"⚠ Last startup issue: {exit_reason}")
elif gateway_state == "draining":
action = "restart" if restart_requested else "shutdown"
count = int(active_agents or 0)
lines.append(f"⏳ Gateway draining for {action} ({count} active agent(s))")
elif gateway_state == "stopped" and exit_reason:
lines.append(f"⚠ Last shutdown reason: {exit_reason}")

View File

@@ -0,0 +1,110 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
from gateway.run import GatewayRunner
from gateway.session import SessionSource
class RestartTestAdapter(BasePlatformAdapter):
def __init__(self):
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
self.sent: list[str] = []
async def connect(self):
return True
async def disconnect(self):
return None
async def send(self, chat_id, content, reply_to=None, metadata=None):
self.sent.append(content)
return SendResult(success=True, message_id="1")
async def send_typing(self, chat_id, metadata=None):
return None
async def get_chat_info(self, chat_id):
return {"id": chat_id}
def make_restart_source(chat_id: str = "123456", chat_type: str = "dm") -> SessionSource:
return SessionSource(
platform=Platform.TELEGRAM,
chat_id=chat_id,
chat_type=chat_type,
)
def make_restart_runner(
adapter: BasePlatformAdapter | None = None,
) -> tuple[GatewayRunner, BasePlatformAdapter]:
runner = object.__new__(GatewayRunner)
runner.config = GatewayConfig(
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
)
runner._running = True
runner._shutdown_event = asyncio.Event()
runner._exit_reason = None
runner._exit_code = None
runner._running_agents = {}
runner._running_agents_ts = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._pending_model_notes = {}
runner._background_tasks = set()
runner._draining = False
runner._restart_requested = False
runner._restart_task_started = False
runner._restart_detached = False
runner._restart_via_service = False
runner._restart_drain_timeout = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
runner._stop_task = None
runner._busy_input_mode = "interrupt"
runner._update_prompt_pending = {}
runner._voice_mode = {}
runner._session_model_overrides = {}
runner._shutdown_all_gateway_honcho = lambda: None
runner._update_runtime_status = MagicMock()
runner._queue_or_replace_pending_event = GatewayRunner._queue_or_replace_pending_event.__get__(
runner, GatewayRunner
)
runner._session_key_for_source = GatewayRunner._session_key_for_source.__get__(
runner, GatewayRunner
)
runner._handle_active_session_busy_message = (
GatewayRunner._handle_active_session_busy_message.__get__(runner, GatewayRunner)
)
runner._handle_restart_command = GatewayRunner._handle_restart_command.__get__(
runner, GatewayRunner
)
runner._status_action_label = GatewayRunner._status_action_label.__get__(
runner, GatewayRunner
)
runner._status_action_gerund = GatewayRunner._status_action_gerund.__get__(
runner, GatewayRunner
)
runner._queue_during_drain_enabled = GatewayRunner._queue_during_drain_enabled.__get__(
runner, GatewayRunner
)
runner._running_agent_count = GatewayRunner._running_agent_count.__get__(
runner, GatewayRunner
)
runner._launch_detached_restart_command = GatewayRunner._launch_detached_restart_command.__get__(
runner, GatewayRunner
)
runner.request_restart = GatewayRunner.request_restart.__get__(runner, GatewayRunner)
runner._is_user_authorized = lambda _source: True
runner.hooks = MagicMock()
runner.hooks.emit = AsyncMock()
runner.pairing_store = MagicMock()
runner.session_store = MagicMock()
runner.delivery_router = MagicMock()
platform_adapter = adapter or RestartTestAdapter()
platform_adapter.set_message_handler(AsyncMock(return_value=None))
platform_adapter.set_busy_session_handler(runner._handle_active_session_busy_message)
runner.adapters = {Platform.TELEGRAM: platform_adapter}
return runner, platform_adapter

View File

@@ -3,43 +3,15 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
from gateway.run import GatewayRunner
from gateway.session import SessionSource, build_session_key
class StubAdapter(BasePlatformAdapter):
def __init__(self):
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
async def connect(self):
return True
async def disconnect(self):
return None
async def send(self, chat_id, content, reply_to=None, metadata=None):
return SendResult(success=True, message_id="1")
async def send_typing(self, chat_id, metadata=None):
return None
async def get_chat_info(self, chat_id):
return {"id": chat_id}
def _source(chat_id="123456", chat_type="dm"):
return SessionSource(
platform=Platform.TELEGRAM,
chat_id=chat_id,
chat_type=chat_type,
)
from gateway.platforms.base import MessageEvent
from gateway.restart import GATEWAY_SERVICE_RESTART_EXIT_CODE
from gateway.session import build_session_key
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
@pytest.mark.asyncio
async def test_cancel_background_tasks_cancels_inflight_message_processing():
adapter = StubAdapter()
_runner, adapter = make_restart_runner()
release = asyncio.Event()
async def block_forever(_event):
@@ -47,7 +19,7 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing():
return None
adapter.set_message_handler(block_forever)
event = MessageEvent(text="work", source=_source(), message_id="1")
event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
await adapter.handle_message(event)
await asyncio.sleep(0)
@@ -65,17 +37,11 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing():
@pytest.mark.asyncio
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
runner = object.__new__(GatewayRunner)
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
runner._running = True
runner._shutdown_event = asyncio.Event()
runner._exit_reason = None
runner, adapter = make_restart_runner()
runner._pending_messages = {"session": "pending text"}
runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
runner._background_tasks = set()
runner._shutdown_all_gateway_honcho = lambda: None
runner._restart_drain_timeout = 0.0
adapter = StubAdapter()
release = asyncio.Event()
async def block_forever(_event):
@@ -83,7 +49,7 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
return None
adapter.set_message_handler(block_forever)
event = MessageEvent(text="work", source=_source(), message_id="1")
event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
await adapter.handle_message(event)
await asyncio.sleep(0)
@@ -93,7 +59,6 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
session_key = build_session_key(event.source)
running_agent = MagicMock()
runner._running_agents = {session_key: running_agent}
runner.adapters = {Platform.TELEGRAM: adapter}
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
await runner.stop()
@@ -105,3 +70,78 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
assert runner._pending_messages == {}
assert runner._pending_approvals == {}
assert runner._shutdown_event.is_set() is True
@pytest.mark.asyncio
async def test_gateway_stop_drains_running_agents_before_disconnect():
runner, adapter = make_restart_runner()
disconnect_mock = AsyncMock()
adapter.disconnect = disconnect_mock
running_agent = MagicMock()
runner._running_agents = {"session": running_agent}
async def finish_agent():
await asyncio.sleep(0.05)
runner._running_agents.clear()
asyncio.create_task(finish_agent())
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
await runner.stop()
running_agent.interrupt.assert_not_called()
disconnect_mock.assert_awaited_once()
assert runner._shutdown_event.is_set() is True
@pytest.mark.asyncio
async def test_gateway_stop_interrupts_after_drain_timeout():
runner, adapter = make_restart_runner()
runner._restart_drain_timeout = 0.05
disconnect_mock = AsyncMock()
adapter.disconnect = disconnect_mock
running_agent = MagicMock()
runner._running_agents = {"session": running_agent}
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
await runner.stop()
running_agent.interrupt.assert_called_once_with("Gateway shutting down")
disconnect_mock.assert_awaited_once()
assert runner._shutdown_event.is_set() is True
@pytest.mark.asyncio
async def test_gateway_stop_service_restart_sets_named_exit_code():
runner, adapter = make_restart_runner()
adapter.disconnect = AsyncMock()
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
await runner.stop(restart=True, service_restart=True)
assert runner._exit_code == GATEWAY_SERVICE_RESTART_EXIT_CODE
@pytest.mark.asyncio
async def test_drain_active_agents_throttles_status_updates():
runner, _adapter = make_restart_runner()
runner._update_runtime_status = MagicMock()
runner._running_agents = {"a": MagicMock(), "b": MagicMock()}
async def finish_agents():
await asyncio.sleep(0.12)
runner._running_agents.pop("a")
await asyncio.sleep(0.12)
runner._running_agents.clear()
task = asyncio.create_task(finish_agents())
await runner._drain_active_agents(1.0)
await task
# Start, one count-change update, and final update. Allow one extra update
# if the loop observes the zero-agent state before exiting.
assert 3 <= runner._update_runtime_status.call_count <= 4

View File

@@ -0,0 +1,160 @@
import asyncio
import shutil
import subprocess
from unittest.mock import AsyncMock, MagicMock
import pytest
import gateway.run as gateway_run
from gateway.platforms.base import MessageEvent, MessageType
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
from gateway.session import build_session_key
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
@pytest.mark.asyncio
async def test_restart_command_while_busy_requests_drain_without_interrupt():
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock(return_value=True)
event = MessageEvent(
text="/restart",
message_type=MessageType.TEXT,
source=make_restart_source(),
message_id="m1",
)
session_key = build_session_key(event.source)
running_agent = MagicMock()
runner._running_agents[session_key] = running_agent
result = await runner._handle_message(event)
assert result == "⏳ Draining 1 active agent(s) before restart..."
running_agent.interrupt.assert_not_called()
runner.request_restart.assert_called_once_with(detached=True, via_service=False)
@pytest.mark.asyncio
async def test_drain_queue_mode_queues_follow_up_without_interrupt():
runner, adapter = make_restart_runner()
runner._draining = True
runner._restart_requested = True
runner._busy_input_mode = "queue"
event = MessageEvent(
text="follow up",
message_type=MessageType.TEXT,
source=make_restart_source(),
message_id="m2",
)
session_key = build_session_key(event.source)
adapter._active_sessions[session_key] = asyncio.Event()
await adapter.handle_message(event)
assert session_key in adapter._pending_messages
assert adapter._pending_messages[session_key].text == "follow up"
assert not adapter._active_sessions[session_key].is_set()
assert any("queued for the next turn" in message for message in adapter.sent)
@pytest.mark.asyncio
async def test_draining_rejects_new_session_messages():
runner, _adapter = make_restart_runner()
runner._draining = True
runner._restart_requested = True
event = MessageEvent(
text="hello",
message_type=MessageType.TEXT,
source=make_restart_source("fresh"),
message_id="m3",
)
result = await runner._handle_message(event)
assert result == "⏳ Gateway is restarting and is not accepting new work right now."
def test_load_busy_input_mode_prefers_env_then_config_then_default(tmp_path, monkeypatch):
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.delenv("HERMES_GATEWAY_BUSY_INPUT_MODE", raising=False)
assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt"
(tmp_path / "config.yaml").write_text(
"display:\n busy_input_mode: queue\n", encoding="utf-8"
)
assert gateway_run.GatewayRunner._load_busy_input_mode() == "queue"
monkeypatch.setenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "interrupt")
assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt"
def test_load_restart_drain_timeout_prefers_env_then_config_then_default(
tmp_path, monkeypatch, caplog
):
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.delenv("HERMES_RESTART_DRAIN_TIMEOUT", raising=False)
assert (
gateway_run.GatewayRunner._load_restart_drain_timeout()
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
)
(tmp_path / "config.yaml").write_text(
"agent:\n restart_drain_timeout: 12\n", encoding="utf-8"
)
assert gateway_run.GatewayRunner._load_restart_drain_timeout() == 12.0
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "7")
assert gateway_run.GatewayRunner._load_restart_drain_timeout() == 7.0
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "invalid")
assert (
gateway_run.GatewayRunner._load_restart_drain_timeout()
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
)
assert "Invalid restart_drain_timeout" in caplog.text
@pytest.mark.asyncio
async def test_request_restart_is_idempotent():
runner, _adapter = make_restart_runner()
runner.stop = AsyncMock()
assert runner.request_restart(detached=True, via_service=False) is True
first_task = next(iter(runner._background_tasks))
assert runner.request_restart(detached=True, via_service=False) is False
await first_task
runner.stop.assert_awaited_once_with(
restart=True, detached_restart=True, service_restart=False
)
@pytest.mark.asyncio
async def test_launch_detached_restart_command_uses_setsid(monkeypatch):
runner, _adapter = make_restart_runner()
popen_calls = []
monkeypatch.setattr(gateway_run, "_resolve_hermes_bin", lambda: ["/usr/bin/hermes"])
monkeypatch.setattr(gateway_run.os, "getpid", lambda: 321)
monkeypatch.setattr(shutil, "which", lambda cmd: "/usr/bin/setsid" if cmd == "setsid" else None)
def fake_popen(cmd, **kwargs):
popen_calls.append((cmd, kwargs))
return MagicMock()
monkeypatch.setattr(subprocess, "Popen", fake_popen)
await runner._launch_detached_restart_command()
assert len(popen_calls) == 1
cmd, kwargs = popen_calls[0]
assert cmd[:2] == ["/usr/bin/setsid", "bash"]
assert "gateway restart" in cmd[-1]
assert "kill -0 321" in cmd[-1]
assert kwargs["start_new_session"] is True
assert kwargs["stdout"] is subprocess.DEVNULL
assert kwargs["stderr"] is subprocess.DEVNULL

View File

@@ -127,6 +127,16 @@ async def test_shutdown_fires_finalize_for_active_agents(mock_invoke_hook):
runner._shutdown_event = MagicMock()
runner.adapters = {}
runner._exit_reason = "test"
runner._exit_code = None
runner._draining = False
runner._restart_requested = False
runner._restart_task_started = False
runner._restart_detached = False
runner._restart_via_service = False
runner._restart_drain_timeout = 0.0
runner._stop_task = None
runner._running_agents_ts = {}
runner._update_runtime_status = MagicMock()
agent1 = MagicMock()
agent1.session_id = "sess-a"

View File

@@ -41,6 +41,15 @@ def _make_runner():
runner._pending_approvals = {}
runner._voice_mode = {}
runner._background_tasks = set()
runner._draining = False
runner._restart_requested = False
runner._restart_task_started = False
runner._restart_detached = False
runner._restart_via_service = False
runner._restart_drain_timeout = 0.0
runner._stop_task = None
runner._exit_code = None
runner._update_runtime_status = MagicMock()
runner._is_user_authorized = lambda _source: True
runner.hooks = MagicMock()
runner.hooks.emit = AsyncMock()

View File

@@ -5,6 +5,10 @@ from pathlib import Path
from types import SimpleNamespace
import hermes_cli.gateway as gateway_cli
from gateway.restart import (
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
GATEWAY_SERVICE_RESTART_EXIT_CODE,
)
class TestSystemdServiceRefresh:
@@ -74,7 +78,7 @@ class TestSystemdServiceRefresh:
assert unit_path.read_text(encoding="utf-8") == "new unit\n"
assert calls[:2] == [
["systemctl", "--user", "daemon-reload"],
["systemctl", "--user", "restart", gateway_cli.get_service_name()],
["systemctl", "--user", "reload-or-restart", gateway_cli.get_service_name()],
]
@@ -84,6 +88,8 @@ class TestGeneratedSystemdUnits:
assert "ExecStart=" in unit
assert "ExecStop=" not in unit
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit
assert "TimeoutStopSec=60" in unit
def test_user_unit_includes_resolved_node_directory_in_path(self, monkeypatch):
@@ -98,6 +104,8 @@ class TestGeneratedSystemdUnits:
assert "ExecStart=" in unit
assert "ExecStop=" not in unit
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit
assert "TimeoutStopSec=60" in unit
assert "WantedBy=multi-user.target" in unit
@@ -157,6 +165,31 @@ class TestGatewayStopCleanup:
class TestLaunchdServiceRecovery:
def test_get_restart_drain_timeout_prefers_env_then_config_then_default(self, monkeypatch):
monkeypatch.delenv("HERMES_RESTART_DRAIN_TIMEOUT", raising=False)
monkeypatch.setattr(gateway_cli, "read_raw_config", lambda: {})
assert (
gateway_cli._get_restart_drain_timeout()
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
)
monkeypatch.setattr(
gateway_cli,
"read_raw_config",
lambda: {"agent": {"restart_drain_timeout": 14}},
)
assert gateway_cli._get_restart_drain_timeout() == 14.0
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "9")
assert gateway_cli._get_restart_drain_timeout() == 9.0
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "invalid")
assert (
gateway_cli._get_restart_drain_timeout()
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
)
def test_launchd_install_repairs_outdated_plist_without_force(self, tmp_path, monkeypatch):
plist_path = tmp_path / "ai.hermes.gateway.plist"
plist_path.write_text("<plist>old content</plist>", encoding="utf-8")
@@ -234,6 +267,55 @@ class TestLaunchdServiceRecovery:
["launchctl", "kickstart", target],
]
def test_launchd_restart_drains_running_gateway_before_kickstart(self, monkeypatch):
calls = []
target = f"{gateway_cli._launchd_domain()}/{gateway_cli.get_launchd_label()}"
monkeypatch.setattr(gateway_cli, "_get_restart_drain_timeout", lambda: 12.0)
monkeypatch.setattr(gateway_cli, "_request_gateway_self_restart", lambda pid: False)
monkeypatch.setattr(gateway_cli, "_wait_for_gateway_exit", lambda timeout, force_after=None: True)
monkeypatch.setattr(gateway_cli, "terminate_pid", lambda pid, force=False: calls.append(("term", pid, force)))
monkeypatch.setattr(
"gateway.status.get_running_pid",
lambda: 321,
)
def fake_run(cmd, check=False, **kwargs):
calls.append(cmd)
return SimpleNamespace(returncode=0, stdout="", stderr="")
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
gateway_cli.launchd_restart()
assert calls == [
("term", 321, False),
["launchctl", "kickstart", "-k", target],
]
def test_launchd_restart_self_requests_graceful_restart_without_kickstart(self, monkeypatch, capsys):
calls = []
monkeypatch.setattr(
"gateway.status.get_running_pid",
lambda: 321,
)
monkeypatch.setattr(
gateway_cli,
"_request_gateway_self_restart",
lambda pid: calls.append(("self", pid)) or True,
)
monkeypatch.setattr(
gateway_cli.subprocess,
"run",
lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("launchctl should not run")),
)
gateway_cli.launchd_restart()
assert calls == [("self", 321)]
assert "restart requested" in capsys.readouterr().out.lower()
def test_launchd_stop_uses_bootout_not_kill(self, monkeypatch):
"""launchd_stop must bootout the service so KeepAlive doesn't respawn it."""
label = gateway_cli.get_launchd_label()
@@ -337,6 +419,31 @@ class TestGatewayServiceDetection:
class TestGatewaySystemServiceRouting:
def test_systemd_restart_self_requests_graceful_restart_without_reload_or_restart(self, monkeypatch, capsys):
calls = []
monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False)
monkeypatch.setattr(gateway_cli, "refresh_systemd_unit_if_needed", lambda system=False: calls.append(("refresh", system)))
monkeypatch.setattr(
"gateway.status.get_running_pid",
lambda: 654,
)
monkeypatch.setattr(
gateway_cli,
"_request_gateway_self_restart",
lambda pid: calls.append(("self", pid)) or True,
)
monkeypatch.setattr(
gateway_cli.subprocess,
"run",
lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("systemctl should not run")),
)
gateway_cli.systemd_restart()
assert calls == [("refresh", False), ("self", 654)]
assert "restart requested" in capsys.readouterr().out.lower()
def test_gateway_install_passes_system_flags(self, monkeypatch):
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)